// kalosm_language_model/model/ext.rs

1use futures_channel::mpsc::UnboundedReceiver;
2use futures_channel::oneshot::Receiver;
3use futures_util::Future;
4use futures_util::FutureExt;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use std::any::Any;
8use std::error::Error;
9use std::future::IntoFuture;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::sync::Mutex;
13use std::sync::OnceLock;
14use std::sync::RwLock;
15use std::task::Poll;
16
17use crate::GenerationParameters;
18use crate::ModelConstraints;
19use crate::NoConstraints;
20
21use super::BoxedStructuredTextCompletionModel;
22use super::BoxedTextCompletionModel;
23use super::CreateDefaultCompletionConstraintsForType;
24use super::CreateTextCompletionSession;
25use super::StructuredTextCompletionModel;
26use super::TextCompletionModel;
27use super::TextCompletionSession;
28
#[doc = include_str!("../../docs/completion.md")]
pub trait TextCompletionModelExt: CreateTextCompletionSession {
    /// Create a new text completion builder for this model. See [`TextCompletionBuilder`] for more details.
    ///
    /// The builder is lazy: no text is generated until it is awaited (for the
    /// full completion) or polled as a stream (for token-by-token output).
    fn complete(&self, text: impl ToString) -> TextCompletionBuilder<Self>
    where
        Self: Clone,
    {
        // Create the builder that will run the completion once it is awaited
        // or streamed. The model defaults to an unconstrained response with
        // the default sampler; both can be overridden on the builder.
        TextCompletionBuilder {
            text: text.to_string(),
            model: Some(self.clone()),
            constraints: None,
            sampler: Some(GenerationParameters::default()),
            task: OnceLock::new(),
            queued_tokens: None,
            result: None,
        }
    }

    /// Erase the type of the text completion model. This can be used to make multiple implementations of
    /// [`TextCompletionModel`] compatible with the same type.
    fn boxed_completion_model(self) -> BoxedTextCompletionModel
    where
        Self: TextCompletionModel<
                Error: Send + Sync + std::error::Error + 'static,
                Session: TextCompletionSession<Error: std::error::Error + Send + Sync + 'static>
                             + Clone
                             + Send
                             + Sync
                             + 'static,
            > + Sized
            + Send
            + Sync
            + 'static,
    {
        BoxedTextCompletionModel::new(self)
    }

    /// Erase the type of the structured text completion model. This can be used to make multiple implementations of
    /// [`StructuredTextCompletionModel`] compatible with the same type.
    ///
    /// Unlike [`TextCompletionModelExt::boxed_completion_model`], this keeps the
    /// output type `T` and requires the model to provide default constraints for it.
    fn boxed_typed_completion_model<T>(self) -> BoxedStructuredTextCompletionModel<T>
    where
        Self: StructuredTextCompletionModel<
                Self::DefaultConstraints,
                Error: Send + Sync + Error + 'static,
                Session: TextCompletionSession<Error: Error + Send + Sync + 'static>
                             + Clone
                             + Send
                             + Sync
                             + 'static,
            > + CreateDefaultCompletionConstraintsForType<T>
            + Sized
            + Send
            + Sync
            + 'static,
        T: 'static,
    {
        BoxedStructuredTextCompletionModel::new(self)
    }
}
89
// Blanket implementation: any type that can create a text completion session
// automatically gains the extension methods above.
impl<M: CreateTextCompletionSession> TextCompletionModelExt for M {}
91
/// A builder for a text completion response. This is returned by [`TextCompletionModelExt::complete`]
/// and can be modified with [`TextCompletionBuilder::with_sampler`] and [`TextCompletionBuilder::with_constraints`]
/// until you start awaiting the response.
///
/// Once you are done setting up the response, you can call `.await` to get the full
/// text completion (or, if constraints were set with [`TextCompletionBuilder::with_constraints`],
/// the parsed [`ModelConstraints::Output`] value):
/// ```rust, no_run
/// use kalosm::language::*;
///
/// #[tokio::main]
/// async fn main() {
///     let mut llm = Llama::new().await.unwrap();
///     let prompt = "The following is a 300 word essay about why the capital of France is Paris:";
///     print!("{prompt}");
///     let mut completion = llm.complete(prompt).await.unwrap();
///     println!("{completion}");
/// }
/// ```
///
/// Or use the response as a [`Stream`] of tokens:
///
/// ```rust, no_run
/// use kalosm::language::*;
/// use std::io::Write;
///
/// #[tokio::main]
/// async fn main() {
///     let mut llm = Llama::new().await.unwrap();
///     let prompt = "The following is a 300 word essay about why the capital of France is Paris:";
///     print!("{prompt}");
///     let mut completion = llm.complete(prompt);
///     while let Some(token) = completion.next().await {
///         print!("{token}");
///         std::io::stdout().flush().unwrap();
///     }
/// }
/// ```
pub struct TextCompletionBuilder<
    M: CreateTextCompletionSession,
    Constraints = NoConstraints,
    Sampler = GenerationParameters,
> {
    /// The prompt text the model will complete.
    text: String,
    /// The model to run generation with; taken (set to `None`) when the
    /// generation task is started.
    model: Option<M>,
    /// Optional constraints that restrict the generated text to a parser.
    constraints: Option<Constraints>,
    /// The sampler used to pick tokens; taken when the task is started.
    sampler: Option<Sampler>,
    /// The lazily-created generation future. It is not spawned anywhere; it
    /// only makes progress when polled via `Stream::poll_next` or `IntoFuture`.
    task: OnceLock<RwLock<Pin<Box<dyn Future<Output = ()> + Send>>>>,
    /// Receives the final result (full text or parsed output, type-erased)
    /// once the generation task completes.
    #[allow(clippy::type_complexity)]
    result: Option<Receiver<Result<Box<dyn Any + Send>, M::Error>>>,
    /// Receives individual tokens streamed from the generation task.
    queued_tokens: Option<UnboundedReceiver<String>>,
}
143
144impl<M: CreateTextCompletionSession, Constraints, Sampler>
145    TextCompletionBuilder<M, Constraints, Sampler>
146{
147    /// Constrains the model's response to the given parser. This can be used to make the model start with a certain phrase, or to make the model respond in a certain way.
148    ///
149    /// # Example
150    /// ```rust, no_run
151    /// # use kalosm::language::*;
152    /// # #[tokio::main]
153    /// # async fn main() {
154    /// #[derive(Parse, Clone, Debug)]
155    /// struct Pet {
156    ///     name: String,
157    ///     age: u32,
158    ///     description: String,
159    /// }
160    ///
161    /// // First create a model
162    /// let model = Llama::new().await.unwrap();
163    /// // Then create a parser for your data. Any type that implements the `Parse` trait has the `new_parser` method
164    /// let parser = Pet::new_parser();
165    /// // Create a text completion stream with the constraints
166    /// let description = model.complete("JSON for an adorable dog named ruffles: ")
167    ///     .with_constraints(parser);
168    /// // Finally, await the stream to get the parsed response
169    /// let pet: Pet = description.await.unwrap();
170    /// println!("{pet:?}");
171    /// # }
172    /// ```
173    pub fn with_constraints<NewConstraints: ModelConstraints>(
174        self,
175        constraints: NewConstraints,
176    ) -> TextCompletionBuilder<M, NewConstraints, Sampler> {
177        TextCompletionBuilder {
178            text: self.text,
179            model: self.model,
180            constraints: Some(constraints),
181            sampler: self.sampler,
182            queued_tokens: None,
183            result: None,
184            task: OnceLock::new(),
185        }
186    }
187
188    /// Constrains the model's response to the the default parser for the given type. This can be used to make the model return a specific type.
189    ///
190    /// # Example
191    /// ```rust, no_run
192    /// # use kalosm::language::*;
193    /// # #[tokio::main]
194    /// # async fn main() {
195    /// #[derive(Parse, Clone, Debug)]
196    /// struct Pet {
197    ///     name: String,
198    ///     age: u32,
199    ///     description: String,
200    /// }
201    ///
202    /// // First create a model
203    /// let model = Llama::new().await.unwrap();
204    /// // Create a text completion stream with the typed response
205    /// let description = model
206    ///     .complete("JSON for an adorable dog named ruffles: ")
207    ///     .typed();
208    /// // Finally, await the stream to get the parsed response
209    /// let pet: Pet = description.await.unwrap();
210    /// println!("{pet:?}");
211    /// # }
212    /// ```
213    pub fn typed<T>(
214        self,
215    ) -> TextCompletionBuilder<
216        M,
217        <M as CreateDefaultCompletionConstraintsForType<T>>::DefaultConstraints,
218        Sampler,
219    >
220    where
221        M: CreateDefaultCompletionConstraintsForType<T>,
222    {
223        self.with_constraints(M::create_default_constraints())
224    }
225
226    /// Sets the sampler to use for generating responses. The sampler determines how tokens are chosen from the probability distribution
227    /// the model generates. They can be used to make the model more or less predictable and prevent repetition.
228    ///
229    /// # Example
230    /// ```rust, no_run
231    /// # use kalosm::language::*;
232    /// # #[tokio::main]
233    /// # async fn main() {
234    /// let model = Llama::new().await.unwrap();
235    /// // Create the sampler to use for the text completion
236    /// let sampler = GenerationParameters::default().sampler();
237    /// // Create a completion request with the sampler
238    /// let mut stream = model
239    ///     .complete("Here is a list of 5 primes: ")
240    ///     .with_sampler(sampler);
241    /// stream.to_std_out().await.unwrap();
242    /// # }
243    /// ```
244    pub fn with_sampler<NewSampler>(
245        self,
246        sampler: NewSampler,
247    ) -> TextCompletionBuilder<M, Constraints, NewSampler> {
248        TextCompletionBuilder {
249            text: self.text,
250            model: self.model,
251            constraints: self.constraints,
252            sampler: Some(sampler),
253            queued_tokens: None,
254            result: None,
255            task: OnceLock::new(),
256        }
257    }
258}
259
impl<M, Sampler> TextCompletionBuilder<M, NoConstraints, Sampler>
where
    Sampler: Send + Unpin + 'static,
    M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Send + Sync + Unpin + 'static,
{
    /// Lazily start the unconstrained generation task if it has not been
    /// started yet.
    ///
    /// On first call this takes `model` and `sampler` out of the builder,
    /// wires up two channels (an unbounded channel for streamed tokens and a
    /// oneshot channel for the final result), and stores the generation
    /// future in `self.task`. The future is not spawned anywhere — it only
    /// makes progress when polled through `Stream::poll_next` or the
    /// `IntoFuture` implementation.
    ///
    /// # Panics
    /// Panics if the task was already started and `model`/`sampler` were
    /// already taken, i.e. if the builder is turned into a future twice.
    fn ensure_unstructured_task_started(&mut self) {
        if self.task.get().is_none() {
            let text = std::mem::take(&mut self.text);
            let model = self
                .model
                .take()
                .expect("TextCompletionBuilder cannot be turned into a future twice");
            let sampler = self
                .sampler
                .take()
                .expect("TextCompletionBuilder cannot be turned into a future twice");
            // Token channel: the callback below sends each token as it is generated.
            let (mut tx, rx) = futures_channel::mpsc::unbounded();
            // Oneshot channel for the final accumulated text (or the error).
            let (result_tx, result_rx) = futures_channel::oneshot::channel();
            self.queued_tokens = Some(rx);
            self.result = Some(result_rx);
            // Accumulates the full completion so awaiting the builder can
            // return the whole text while individual tokens are streamed out.
            let all_text = Arc::new(Mutex::new(String::new()));
            let on_token = {
                let all_text = all_text.clone();
                move |tok: String| {
                    all_text.lock().unwrap().push_str(&tok);
                    // Ignore send errors: the token receiver may no longer be listening.
                    _ = tx.start_send(tok);
                    Ok(())
                }
            };
            let future = async move {
                let mut session = model.new_session()?;
                model
                    .stream_text_with_callback(&mut session, &text, sampler, on_token)
                    .await?;
                // Take the accumulated text out of the mutex and type-erase it
                // so structured and unstructured results share one channel type.
                let mut all_text = all_text.lock().unwrap();
                let all_text = std::mem::take(&mut *all_text);
                Ok(Box::new(all_text) as Box<dyn Any + Send>)
            };
            // Wrap the future so its result (success or error) is delivered
            // through the oneshot channel instead of being returned directly.
            let wrapped = async move {
                let result: Result<Box<dyn Any + Send>, M::Error> = future.await;
                _ = result_tx.send(result);
            };
            let task = Box::pin(wrapped);
            self.task
                .set(RwLock::new(task))
                .unwrap_or_else(|_| panic!("Task already set"));
        }
    }
}
310
311impl<M, Sampler> Stream for TextCompletionBuilder<M, NoConstraints, Sampler>
312where
313    Sampler: Send + Unpin + 'static,
314    M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
315    M::Session: Send + Sync + Unpin + 'static,
316{
317    type Item = String;
318
319    fn poll_next(
320        self: Pin<&mut Self>,
321        cx: &mut std::task::Context<'_>,
322    ) -> std::task::Poll<Option<Self::Item>> {
323        let myself = Pin::get_mut(self);
324        myself.ensure_unstructured_task_started();
325        {
326            if let Some(token) = &mut myself.queued_tokens {
327                if let Poll::Ready(Some(token)) = token.poll_next_unpin(cx) {
328                    return Poll::Ready(Some(token));
329                }
330            }
331        }
332        let mut task = myself.task.get().unwrap().write().unwrap();
333        task.poll_unpin(cx).map(|_| None)
334    }
335}
336
337impl<M, Sampler> IntoFuture for TextCompletionBuilder<M, NoConstraints, Sampler>
338where
339    Sampler: Send + Unpin + 'static,
340    M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
341    M::Session: Clone + Send + Sync + Unpin + 'static,
342{
343    type Output = Result<String, M::Error>;
344    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
345
346    fn into_future(mut self) -> Self::IntoFuture {
347        self.ensure_unstructured_task_started();
348
349        Box::pin(async move {
350            if self.result.is_none() {
351                self.task.into_inner().unwrap().into_inner().unwrap().await;
352            }
353            let result = self.result.take().unwrap().await.unwrap();
354            result.map(|boxed| *boxed.downcast::<String>().unwrap())
355        })
356    }
357}
358
impl<M, Constraints, Sampler> TextCompletionBuilder<M, Constraints, Sampler>
where
    Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
    Sampler: Send + Unpin + 'static,
    M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Clone + Send + Sync + Unpin + 'static,
    Constraints::Output: Send + 'static,
{
    /// Lazily start the constrained (structured) generation task if it has
    /// not been started yet.
    ///
    /// On first call this takes `model`, `sampler` and `constraints` out of
    /// the builder, wires up two channels (an unbounded channel for streamed
    /// tokens and a oneshot channel for the parsed result), and stores the
    /// generation future in `self.task`. The future is not spawned anywhere —
    /// it only makes progress when polled through `Stream::poll_next` or the
    /// `IntoFuture` implementation.
    ///
    /// # Panics
    /// Panics if the task was already started and `model`/`sampler`/
    /// `constraints` were already taken, i.e. if the builder is turned into a
    /// future twice.
    fn ensure_structured_task_started(&mut self) {
        if self.task.get().is_none() {
            let text = std::mem::take(&mut self.text);
            let model = self
                .model
                .take()
                .expect("TextCompletionBuilder cannot be turned into a future twice");
            let sampler = self
                .sampler
                .take()
                .expect("TextCompletionBuilder cannot be turned into a future twice");
            let constraints = self
                .constraints
                .take()
                .expect("TextCompletionBuilder cannot be turned into a future twice");
            // Token channel: the callback below sends each token as it is generated.
            let (mut tx, rx) = futures_channel::mpsc::unbounded();
            // Oneshot channel for the final parsed result (or the error).
            let (result_tx, result_rx) = futures_channel::oneshot::channel();
            self.queued_tokens = Some(rx);
            self.result = Some(result_rx);
            let on_token = move |tok: String| {
                // Ignore send errors: the token receiver may no longer be listening.
                _ = tx.start_send(tok);
                Ok(())
            };
            let future = async move {
                let mut session = model.new_session()?;
                model
                    .stream_text_with_callback_and_parser(
                        &mut session,
                        &text,
                        sampler,
                        constraints,
                        on_token,
                    )
                    .await
                    // Type-erase the parsed value so structured and
                    // unstructured results share one channel type.
                    .map(|value| Box::new(value) as Box<dyn Any + Send>)
            };
            // Wrap the future so its result (success or error) is delivered
            // through the oneshot channel instead of being returned directly.
            let wrapped = async move {
                let result: Result<Box<dyn Any + Send>, M::Error> = future.await;
                _ = result_tx.send(result);
            };
            let task = Box::pin(wrapped);
            self.task
                .set(RwLock::new(task))
                .unwrap_or_else(|_| panic!("Task already set"));
        }
    }
}
414
415impl<M, Constraints, Sampler> Stream for TextCompletionBuilder<M, Constraints, Sampler>
416where
417    Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
418    Sampler: Send + Unpin + 'static,
419    M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
420    M::Session: Clone + Send + Sync + Unpin + 'static,
421    Constraints::Output: Send + 'static,
422{
423    type Item = String;
424
425    fn poll_next(
426        self: Pin<&mut Self>,
427        cx: &mut std::task::Context<'_>,
428    ) -> std::task::Poll<Option<Self::Item>> {
429        let myself = Pin::get_mut(self);
430        myself.ensure_structured_task_started();
431        {
432            if let Some(token) = &mut myself.queued_tokens {
433                if let Poll::Ready(Some(token)) = token.poll_next_unpin(cx) {
434                    return Poll::Ready(Some(token));
435                }
436            }
437        }
438        let mut task = myself.task.get().unwrap().write().unwrap();
439        task.poll_unpin(cx).map(|_| None)
440    }
441}
442
443impl<M, Constraints, Sampler> IntoFuture for TextCompletionBuilder<M, Constraints, Sampler>
444where
445    Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
446    Sampler: Send + Unpin + 'static,
447    M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
448    M::Session: Clone + Send + Sync + Unpin + 'static,
449    Constraints::Output: Send + 'static,
450{
451    type Output = Result<Constraints::Output, M::Error>;
452    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
453
454    fn into_future(mut self) -> Self::IntoFuture {
455        self.ensure_structured_task_started();
456
457        Box::pin(async move {
458            if self.result.is_none() {
459                self.task.into_inner().unwrap().into_inner().unwrap().await;
460            }
461            let result = self.result.take().unwrap().await.unwrap();
462            result.map(|boxed| *boxed.downcast::<Constraints::Output>().unwrap())
463        })
464    }
465}