kalosm_language_model/model/
ext.rs1use futures_channel::mpsc::UnboundedReceiver;
2use futures_channel::oneshot::Receiver;
3use futures_util::Future;
4use futures_util::FutureExt;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use std::any::Any;
8use std::error::Error;
9use std::future::IntoFuture;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::sync::Mutex;
13use std::sync::OnceLock;
14use std::sync::RwLock;
15use std::task::Poll;
16
17use crate::GenerationParameters;
18use crate::ModelConstraints;
19use crate::NoConstraints;
20
21use super::BoxedStructuredTextCompletionModel;
22use super::BoxedTextCompletionModel;
23use super::CreateDefaultCompletionConstraintsForType;
24use super::CreateTextCompletionSession;
25use super::StructuredTextCompletionModel;
26use super::TextCompletionModel;
27use super::TextCompletionSession;
28
29#[doc = include_str!("../../docs/completion.md")]
30pub trait TextCompletionModelExt: CreateTextCompletionSession {
31 fn complete(&self, text: impl ToString) -> TextCompletionBuilder<Self>
33 where
34 Self: Clone,
35 {
36 TextCompletionBuilder {
38 text: text.to_string(),
39 model: Some(self.clone()),
40 constraints: None,
41 sampler: Some(GenerationParameters::default()),
42 task: OnceLock::new(),
43 queued_tokens: None,
44 result: None,
45 }
46 }
47
48 fn boxed_completion_model(self) -> BoxedTextCompletionModel
51 where
52 Self: TextCompletionModel<
53 Error: Send + Sync + std::error::Error + 'static,
54 Session: TextCompletionSession<Error: std::error::Error + Send + Sync + 'static>
55 + Clone
56 + Send
57 + Sync
58 + 'static,
59 > + Sized
60 + Send
61 + Sync
62 + 'static,
63 {
64 BoxedTextCompletionModel::new(self)
65 }
66
67 fn boxed_typed_completion_model<T>(self) -> BoxedStructuredTextCompletionModel<T>
70 where
71 Self: StructuredTextCompletionModel<
72 Self::DefaultConstraints,
73 Error: Send + Sync + Error + 'static,
74 Session: TextCompletionSession<Error: Error + Send + Sync + 'static>
75 + Clone
76 + Send
77 + Sync
78 + 'static,
79 > + CreateDefaultCompletionConstraintsForType<T>
80 + Sized
81 + Send
82 + Sync
83 + 'static,
84 T: 'static,
85 {
86 BoxedStructuredTextCompletionModel::new(self)
87 }
88}
89
90impl<M: CreateTextCompletionSession> TextCompletionModelExt for M {}
91
92pub struct TextCompletionBuilder<
130 M: CreateTextCompletionSession,
131 Constraints = NoConstraints,
132 Sampler = GenerationParameters,
133> {
134 text: String,
135 model: Option<M>,
136 constraints: Option<Constraints>,
137 sampler: Option<Sampler>,
138 task: OnceLock<RwLock<Pin<Box<dyn Future<Output = ()> + Send>>>>,
139 #[allow(clippy::type_complexity)]
140 result: Option<Receiver<Result<Box<dyn Any + Send>, M::Error>>>,
141 queued_tokens: Option<UnboundedReceiver<String>>,
142}
143
144impl<M: CreateTextCompletionSession, Constraints, Sampler>
145 TextCompletionBuilder<M, Constraints, Sampler>
146{
147 pub fn with_constraints<NewConstraints: ModelConstraints>(
174 self,
175 constraints: NewConstraints,
176 ) -> TextCompletionBuilder<M, NewConstraints, Sampler> {
177 TextCompletionBuilder {
178 text: self.text,
179 model: self.model,
180 constraints: Some(constraints),
181 sampler: self.sampler,
182 queued_tokens: None,
183 result: None,
184 task: OnceLock::new(),
185 }
186 }
187
188 pub fn typed<T>(
214 self,
215 ) -> TextCompletionBuilder<
216 M,
217 <M as CreateDefaultCompletionConstraintsForType<T>>::DefaultConstraints,
218 Sampler,
219 >
220 where
221 M: CreateDefaultCompletionConstraintsForType<T>,
222 {
223 self.with_constraints(M::create_default_constraints())
224 }
225
226 pub fn with_sampler<NewSampler>(
245 self,
246 sampler: NewSampler,
247 ) -> TextCompletionBuilder<M, Constraints, NewSampler> {
248 TextCompletionBuilder {
249 text: self.text,
250 model: self.model,
251 constraints: self.constraints,
252 sampler: Some(sampler),
253 queued_tokens: None,
254 result: None,
255 task: OnceLock::new(),
256 }
257 }
258}
259
260impl<M, Sampler> TextCompletionBuilder<M, NoConstraints, Sampler>
261where
262 Sampler: Send + Unpin + 'static,
263 M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
264 M::Session: Send + Sync + Unpin + 'static,
265{
266 fn ensure_unstructured_task_started(&mut self) {
267 if self.task.get().is_none() {
268 let text = std::mem::take(&mut self.text);
269 let model = self
270 .model
271 .take()
272 .expect("TextCompletionBuilder cannot be turned into a future twice");
273 let sampler = self
274 .sampler
275 .take()
276 .expect("TextCompletionBuilder cannot be turned into a future twice");
277 let (mut tx, rx) = futures_channel::mpsc::unbounded();
278 let (result_tx, result_rx) = futures_channel::oneshot::channel();
279 self.queued_tokens = Some(rx);
280 self.result = Some(result_rx);
281 let all_text = Arc::new(Mutex::new(String::new()));
282 let on_token = {
283 let all_text = all_text.clone();
284 move |tok: String| {
285 all_text.lock().unwrap().push_str(&tok);
286 _ = tx.start_send(tok);
287 Ok(())
288 }
289 };
290 let future = async move {
291 let mut session = model.new_session()?;
292 model
293 .stream_text_with_callback(&mut session, &text, sampler, on_token)
294 .await?;
295 let mut all_text = all_text.lock().unwrap();
296 let all_text = std::mem::take(&mut *all_text);
297 Ok(Box::new(all_text) as Box<dyn Any + Send>)
298 };
299 let wrapped = async move {
300 let result: Result<Box<dyn Any + Send>, M::Error> = future.await;
301 _ = result_tx.send(result);
302 };
303 let task = Box::pin(wrapped);
304 self.task
305 .set(RwLock::new(task))
306 .unwrap_or_else(|_| panic!("Task already set"));
307 }
308 }
309}
310
311impl<M, Sampler> Stream for TextCompletionBuilder<M, NoConstraints, Sampler>
312where
313 Sampler: Send + Unpin + 'static,
314 M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
315 M::Session: Send + Sync + Unpin + 'static,
316{
317 type Item = String;
318
319 fn poll_next(
320 self: Pin<&mut Self>,
321 cx: &mut std::task::Context<'_>,
322 ) -> std::task::Poll<Option<Self::Item>> {
323 let myself = Pin::get_mut(self);
324 myself.ensure_unstructured_task_started();
325 {
326 if let Some(token) = &mut myself.queued_tokens {
327 if let Poll::Ready(Some(token)) = token.poll_next_unpin(cx) {
328 return Poll::Ready(Some(token));
329 }
330 }
331 }
332 let mut task = myself.task.get().unwrap().write().unwrap();
333 task.poll_unpin(cx).map(|_| None)
334 }
335}
336
337impl<M, Sampler> IntoFuture for TextCompletionBuilder<M, NoConstraints, Sampler>
338where
339 Sampler: Send + Unpin + 'static,
340 M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
341 M::Session: Clone + Send + Sync + Unpin + 'static,
342{
343 type Output = Result<String, M::Error>;
344 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
345
346 fn into_future(mut self) -> Self::IntoFuture {
347 self.ensure_unstructured_task_started();
348
349 Box::pin(async move {
350 if self.result.is_none() {
351 self.task.into_inner().unwrap().into_inner().unwrap().await;
352 }
353 let result = self.result.take().unwrap().await.unwrap();
354 result.map(|boxed| *boxed.downcast::<String>().unwrap())
355 })
356 }
357}
358
359impl<M, Constraints, Sampler> TextCompletionBuilder<M, Constraints, Sampler>
360where
361 Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
362 Sampler: Send + Unpin + 'static,
363 M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
364 M::Session: Clone + Send + Sync + Unpin + 'static,
365 Constraints::Output: Send + 'static,
366{
367 fn ensure_structured_task_started(&mut self) {
368 if self.task.get().is_none() {
369 let text = std::mem::take(&mut self.text);
370 let model = self
371 .model
372 .take()
373 .expect("TextCompletionBuilder cannot be turned into a future twice");
374 let sampler = self
375 .sampler
376 .take()
377 .expect("TextCompletionBuilder cannot be turned into a future twice");
378 let constraints = self
379 .constraints
380 .take()
381 .expect("TextCompletionBuilder cannot be turned into a future twice");
382 let (mut tx, rx) = futures_channel::mpsc::unbounded();
383 let (result_tx, result_rx) = futures_channel::oneshot::channel();
384 self.queued_tokens = Some(rx);
385 self.result = Some(result_rx);
386 let on_token = move |tok: String| {
387 _ = tx.start_send(tok);
388 Ok(())
389 };
390 let future = async move {
391 let mut session = model.new_session()?;
392 model
393 .stream_text_with_callback_and_parser(
394 &mut session,
395 &text,
396 sampler,
397 constraints,
398 on_token,
399 )
400 .await
401 .map(|value| Box::new(value) as Box<dyn Any + Send>)
402 };
403 let wrapped = async move {
404 let result: Result<Box<dyn Any + Send>, M::Error> = future.await;
405 _ = result_tx.send(result);
406 };
407 let task = Box::pin(wrapped);
408 self.task
409 .set(RwLock::new(task))
410 .unwrap_or_else(|_| panic!("Task already set"));
411 }
412 }
413}
414
415impl<M, Constraints, Sampler> Stream for TextCompletionBuilder<M, Constraints, Sampler>
416where
417 Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
418 Sampler: Send + Unpin + 'static,
419 M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
420 M::Session: Clone + Send + Sync + Unpin + 'static,
421 Constraints::Output: Send + 'static,
422{
423 type Item = String;
424
425 fn poll_next(
426 self: Pin<&mut Self>,
427 cx: &mut std::task::Context<'_>,
428 ) -> std::task::Poll<Option<Self::Item>> {
429 let myself = Pin::get_mut(self);
430 myself.ensure_structured_task_started();
431 {
432 if let Some(token) = &mut myself.queued_tokens {
433 if let Poll::Ready(Some(token)) = token.poll_next_unpin(cx) {
434 return Poll::Ready(Some(token));
435 }
436 }
437 }
438 let mut task = myself.task.get().unwrap().write().unwrap();
439 task.poll_unpin(cx).map(|_| None)
440 }
441}
442
443impl<M, Constraints, Sampler> IntoFuture for TextCompletionBuilder<M, Constraints, Sampler>
444where
445 Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
446 Sampler: Send + Unpin + 'static,
447 M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
448 M::Session: Clone + Send + Sync + Unpin + 'static,
449 Constraints::Output: Send + 'static,
450{
451 type Output = Result<Constraints::Output, M::Error>;
452 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
453
454 fn into_future(mut self) -> Self::IntoFuture {
455 self.ensure_structured_task_started();
456
457 Box::pin(async move {
458 if self.result.is_none() {
459 self.task.into_inner().unwrap().into_inner().unwrap().await;
460 }
461 let result = self.result.take().unwrap().await.unwrap();
462 result.map(|boxed| *boxed.downcast::<Constraints::Output>().unwrap())
463 })
464 }
465}