// burn_lm_inference/job.rs
1use std::{
2 any::Any,
3 io::Write,
4 marker::PhantomData,
5 sync::{
6 atomic::{AtomicBool, Ordering},
7 mpsc::SyncSender,
8 Arc,
9 },
10};
11
12use crate::{Message, Prompt};
13
/// Defines a job to be run during inference.
///
/// A job pairs the [task](InferenceTask) to execute with the
/// [emitter](GeneratedItemEmitter) used to stream generated items while the job runs.
pub struct InferenceJob {
    /// The task to be performed by the job.
    pub task: InferenceTask,
    /// The emitter for the current job.
    pub emitter: GeneratedItemEmitter,
}
21
/// An emitter is responsible for sending [generated items](GeneratedItem) to the
/// [inference job](InferenceJob) channel.
pub struct GeneratedItemEmitter {
    // Channel end used to forward generated items to the listener thread spawned in
    // [GeneratedItemEmitter::init].
    sender: SyncSender<Msg>,
    // Set to `true` by the listener thread once it has handled the finish message;
    // [GeneratedItemEmitter::completed] stops sending after that point.
    done: Arc<AtomicBool>,
}
28
/// The potential tasks that can be executed by an [inference job](InferenceJob) using the
/// [inference server](crate::InferenceServer).
pub enum InferenceTask {
    /// A single message to be processed by the server.
    Message(Message),
    /// Multiple messages to be processed by the server.
    ///
    /// This could be useful to restore a previous session based on text history.
    Context(Vec<Message>),
    /// Run with a simple prompt.
    Prompt(Prompt),
}
41
/// Defines all the potential items that can be generated by an [inference job](InferenceJob).
///
/// For now there is only text generation, but this could be extended to other kinds of output
/// artifacts.
pub enum GeneratedItem {
    /// Generated text includes intermediary tokens and doesn't mark the end of a text generation
    /// job.
    Text(String),
}
51
52impl InferenceJob {
53 /// Start a new inference job and process it on another thread.
54 ///
55 /// When a task is performed for the current job, it should be registered using the
56 /// [completed method](Self::completed).
57 pub fn create<L: InferenceJobListener>(
58 task: InferenceTask,
59 listener: L,
60 ) -> (Self, JobHandle<L>) {
61 let (emitter, handle) = GeneratedItemEmitter::init(listener);
62
63 (Self { task, emitter }, handle)
64 }
65}
66
67impl GeneratedItemEmitter {
68 pub fn init<L: InferenceJobListener>(mut listener: L) -> (Self, JobHandle<L>) {
69 let (sender, receiver) = std::sync::mpsc::sync_channel::<Msg>(1);
70
71 let handle = JobHandle {
72 sender: sender.clone(),
73 _c: PhantomData,
74 };
75 let done = Arc::new(AtomicBool::new(false));
76 let emitter = GeneratedItemEmitter {
77 sender,
78 done: done.clone(),
79 };
80
81 // TODO: We could use a threadpool for inference jobs.
82 std::thread::spawn(move || {
83 for msg in receiver {
84 match msg {
85 Msg::Text(text) => listener.on_text(text),
86 Msg::Finished(c) => {
87 let result = listener.on_finished();
88 let result: Box<dyn Any + Send> = Box::new(result);
89 c.send(result).unwrap();
90 done.store(true, Ordering::Relaxed);
91 return;
92 }
93 }
94 }
95 });
96
97 (emitter, handle)
98 }
99
100 /// Register the completion of an [inference generation](InferenceGeneration).
101 pub fn completed(&self, item: GeneratedItem) {
102 if !self.done.load(Ordering::Relaxed) {
103 let msg = match item {
104 GeneratedItem::Text(text) => Msg::Text(text),
105 };
106 self.sender.send(msg).unwrap();
107 }
108 }
109}
110
/// An inference job listener receives events while the [inference job](InferenceJob) is running.
pub trait InferenceJobListener: Send + 'static {
    /// The item that is returned by the listener.
    type CompletedItem: Send;

    /// Called when new text is generated from an [inference job](InferenceJob).
    fn on_text(&mut self, text: String);

    /// Called when the job is finished.
    ///
    /// The inference job listener can return an item when a job is finished.
    /// The item is going to be available through the [job handle join method](JobHandle::join).
    fn on_finished(self) -> Self::CompletedItem;
}
125
/// The text generation listener accumulates the generated text in a string that can be
/// obtained at the end of the job with the [job handle join method](JobHandle::join).
#[derive(Default)]
pub struct TextGenerationListener {
    // The text accumulated so far; appended to on every `on_text` call.
    value: String,
}
132
133impl InferenceJobListener for TextGenerationListener {
134 type CompletedItem = String;
135
136 fn on_text(&mut self, text: String) {
137 self.value += &text;
138 }
139
140 fn on_finished(self) -> Self::CompletedItem {
141 self.value
142 }
143}
144
/// The stdout listener directly writes the intermediary [generated item](GeneratedItem) to
/// [std::io::stdout].
#[derive(Default)]
pub struct StdOutListener {}
149
150impl InferenceJobListener for StdOutListener {
151 type CompletedItem = ();
152
153 fn on_text(&mut self, text: String) {
154 let mut io = std::io::stdout();
155
156 write!(io, "{text}").unwrap();
157 io.flush().unwrap();
158 }
159
160 fn on_finished(self) -> Self::CompletedItem {}
161}
162
/// The handle returned by [InferenceJob::create].
///
/// This handle should be used to indicate when a job is finished using the
/// [join method](JobHandle::join).
pub struct JobHandle<C: InferenceJobListener> {
    // Clone of the job channel's sending end; `join` uses it to request completion.
    sender: SyncSender<Msg>,
    // Ties the handle to the listener type so `join` can downcast the boxed result
    // back to `C::CompletedItem`.
    _c: PhantomData<C>,
}
171
172impl<C: InferenceJobListener> JobHandle<C> {
173 /// Wait for the job to complete and returns the
174 /// [InferenceJobListener::CompletedItem] from the job listener.
175 ///
176 /// # Warning
177 ///
178 /// Make sure to actually launch the inference job before calling this method, otherwise this
179 /// method will never complete.
180 pub fn join(&self) -> C::CompletedItem {
181 let (sender, rec) = std::sync::mpsc::sync_channel(1);
182 self.sender.send(Msg::Finished(sender)).unwrap();
183
184 match rec.recv() {
185 Ok(any) => *any.downcast().unwrap(),
186 Err(err) => panic!("{err}"),
187 }
188 }
189}
190
/// Internal messages exchanged between the emitter/handle and the listener thread.
enum Msg {
    /// A chunk of generated text to forward to the listener.
    Text(String),
    /// Request job completion; the boxed [InferenceJobListener::CompletedItem] is sent
    /// back on the provided channel and picked up by [JobHandle::join].
    Finished(SyncSender<Box<dyn Any + Send>>),
}