burn_lm_inference/
job.rs

1use std::{
2    any::Any,
3    io::Write,
4    marker::PhantomData,
5    sync::{
6        atomic::{AtomicBool, Ordering},
7        mpsc::SyncSender,
8        Arc,
9    },
10};
11
12use crate::{Message, Prompt};
13
14/// Defines a job to be run during inference.
15pub struct InferenceJob {
16    /// The task to be performed by the job.
17    pub task: InferenceTask,
18    /// The emitter for the current job.
19    pub emitter: GeneratedItemEmitter,
20}
21
22/// An emitter is responsible to send [generated items](GeneratedItem) to the [inference job](InferenceJob)
23/// channel.
24pub struct GeneratedItemEmitter {
25    sender: SyncSender<Msg>,
26    done: Arc<AtomicBool>,
27}
28
29/// The potential tasks that can be executed by an [inference job](InferenceJob) using the
30/// [inference server](crate::InferenceServer).
31pub enum InferenceTask {
32    /// A single message to be processed by the server.
33    Message(Message),
34    /// Multiple messages to be processed by the server.
35    ///
36    /// This could be useful to restore a previous session based on text history.
37    Context(Vec<Message>),
38    /// Run with a simple prompt.
39    Prompt(Prompt),
40}
41
/// Every kind of item an [inference job](InferenceJob) can generate.
///
/// Only text generation exists for now, but this could be extended to other kinds of output
/// artifacts.
pub enum GeneratedItem {
    /// A piece of generated text. Intermediary tokens are included, and receiving one does not
    /// mark the end of a text generation job.
    Text(String),
}
51
52impl InferenceJob {
53    /// Start a new inference job and process it on another thread.
54    ///
55    /// When a task is performed for the current job, it should be registered using the
56    /// [completed method](Self::completed).
57    pub fn create<L: InferenceJobListener>(
58        task: InferenceTask,
59        listener: L,
60    ) -> (Self, JobHandle<L>) {
61        let (emitter, handle) = GeneratedItemEmitter::init(listener);
62
63        (Self { task, emitter }, handle)
64    }
65}
66
67impl GeneratedItemEmitter {
68    pub fn init<L: InferenceJobListener>(mut listener: L) -> (Self, JobHandle<L>) {
69        let (sender, receiver) = std::sync::mpsc::sync_channel::<Msg>(1);
70
71        let handle = JobHandle {
72            sender: sender.clone(),
73            _c: PhantomData,
74        };
75        let done = Arc::new(AtomicBool::new(false));
76        let emitter = GeneratedItemEmitter {
77            sender,
78            done: done.clone(),
79        };
80
81        // TODO: We could use a threadpool for inference jobs.
82        std::thread::spawn(move || {
83            for msg in receiver {
84                match msg {
85                    Msg::Text(text) => listener.on_text(text),
86                    Msg::Finished(c) => {
87                        let result = listener.on_finished();
88                        let result: Box<dyn Any + Send> = Box::new(result);
89                        c.send(result).unwrap();
90                        done.store(true, Ordering::Relaxed);
91                        return;
92                    }
93                }
94            }
95        });
96
97        (emitter, handle)
98    }
99
100    /// Register the completion of an [inference generation](InferenceGeneration).
101    pub fn completed(&self, item: GeneratedItem) {
102        if !self.done.load(Ordering::Relaxed) {
103            let msg = match item {
104                GeneratedItem::Text(text) => Msg::Text(text),
105            };
106            self.sender.send(msg).unwrap();
107        }
108    }
109}
110
/// Receives the events produced while an [inference job](InferenceJob) is running.
pub trait InferenceJobListener: Send + 'static {
    /// The value handed back by the listener once the job is over.
    type CompletedItem: Send;

    /// Invoked each time an [inference job](InferenceJob) generates new text.
    fn on_text(&mut self, text: String);

    /// Invoked when the job is finished.
    ///
    /// The listener is consumed and may return an item, which is made available through the
    /// [job handle join method](JobHandle::join).
    fn on_finished(self) -> Self::CompletedItem;
}
125
126#[derive(Default)]
127/// The text generation listener accumulate the generated text in a string that can be
128/// obtained at the end of the job with the [handle finished method](JobHandle::finished).
129pub struct TextGenerationListener {
130    value: String,
131}
132
133impl InferenceJobListener for TextGenerationListener {
134    type CompletedItem = String;
135
136    fn on_text(&mut self, text: String) {
137        self.value += &text;
138    }
139
140    fn on_finished(self) -> Self::CompletedItem {
141        self.value
142    }
143}
144
145#[derive(Default)]
146/// The stdout listener directly writes the intermediary [generated item](GeneratedItem) to
147/// [std::io::stdout].
148pub struct StdOutListener {}
149
150impl InferenceJobListener for StdOutListener {
151    type CompletedItem = ();
152
153    fn on_text(&mut self, text: String) {
154        let mut io = std::io::stdout();
155
156        write!(io, "{text}").unwrap();
157        io.flush().unwrap();
158    }
159
160    fn on_finished(self) -> Self::CompletedItem {}
161}
162
163/// The handle returned by [InferenceJob::start].
164///
165/// This handle should be used to indicate when a job is finished using the
166/// [join method](JobHandle::join).
167pub struct JobHandle<C: InferenceJobListener> {
168    sender: SyncSender<Msg>,
169    _c: PhantomData<C>,
170}
171
172impl<C: InferenceJobListener> JobHandle<C> {
173    /// Wait for the job to complete and returns the
174    /// [InferenceJobListener::CompletedItem] from the job listener.
175    ///
176    /// # Warning
177    ///
178    /// Make sure to actually launch the inference job before calling this method, otherwise this
179    /// method will never complete.
180    pub fn join(&self) -> C::CompletedItem {
181        let (sender, rec) = std::sync::mpsc::sync_channel(1);
182        self.sender.send(Msg::Finished(sender)).unwrap();
183
184        match rec.recv() {
185            Ok(any) => *any.downcast().unwrap(),
186            Err(err) => panic!("{err}"),
187        }
188    }
189}
190
/// Messages exchanged on the channel of an inference job.
enum Msg {
    /// A piece of generated text to forward to the listener.
    Text(String),
    /// Finish request; carries the channel on which the type-erased result of the listener is
    /// sent back.
    Finished(SyncSender<Box<dyn Any + Send>>),
}