kalosm_language_model/model/
ext.rsuse futures_channel::mpsc::UnboundedReceiver;
use futures_channel::oneshot::Receiver;
use futures_util::Future;
use futures_util::FutureExt;
use futures_util::Stream;
use futures_util::StreamExt;
use std::any::Any;
use std::error::Error;
use std::future::IntoFuture;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::RwLock;
use std::task::Poll;
use crate::GenerationParameters;
use crate::ModelConstraints;
use crate::NoConstraints;
use super::BoxedStructuredTextCompletionModel;
use super::BoxedTextCompletionModel;
use super::CreateDefaultCompletionConstraintsForType;
use super::CreateTextCompletionSession;
use super::StructuredTextCompletionModel;
use super::TextCompletionModel;
use super::TextCompletionSession;
#[doc = include_str!("../../docs/completion.md")]
pub trait TextCompletionModelExt: CreateTextCompletionSession {
    /// Starts building a completion of `text` with a clone of this model.
    ///
    /// The returned builder has no constraints and uses `GenerationParameters::default()`
    /// as its sampler. Nothing is generated until the builder is awaited or polled.
    fn complete(&self, text: impl ToString) -> TextCompletionBuilder<Self>
    where
        Self: Clone,
    {
        let prompt = text.to_string();
        let model = self.clone();
        TextCompletionBuilder {
            text: prompt,
            model: Some(model),
            constraints: None,
            sampler: Some(GenerationParameters::default()),
            task: OnceLock::new(),
            queued_tokens: None,
            result: None,
        }
    }

    /// Erases the concrete model type by wrapping it in a [`BoxedTextCompletionModel`].
    fn boxed_completion_model(self) -> BoxedTextCompletionModel
    where
        Self: TextCompletionModel<
                Error: Send + Sync + Error + 'static,
                Session: TextCompletionSession<Error: Error + Send + Sync + 'static>
                             + Clone
                             + Send
                             + Sync
                             + 'static,
            > + Sized
            + Send
            + Sync
            + 'static,
    {
        BoxedTextCompletionModel::new(self)
    }

    /// Erases the concrete model type by wrapping it in a
    /// [`BoxedStructuredTextCompletionModel`] that generates values of type `T` using the
    /// model's default constraints for `T`.
    fn boxed_typed_completion_model<T>(self) -> BoxedStructuredTextCompletionModel<T>
    where
        Self: StructuredTextCompletionModel<
                Self::DefaultConstraints,
                Error: Send + Sync + Error + 'static,
                Session: TextCompletionSession<Error: Error + Send + Sync + 'static>
                             + Clone
                             + Send
                             + Sync
                             + 'static,
            > + CreateDefaultCompletionConstraintsForType<T>
            + Sized
            + Send
            + Sync
            + 'static,
        T: 'static,
    {
        BoxedStructuredTextCompletionModel::new(self)
    }
}

// Every type that can create completion sessions gets these helpers for free.
impl<M: CreateTextCompletionSession> TextCompletionModelExt for M {}
/// Builder for a text completion request, created by `TextCompletionModelExt::complete`.
///
/// Holds the prompt, model, optional constraints and sampler until generation starts.
/// Generation starts lazily: the first time the builder is polled as a stream or turned
/// into a future, those inputs are moved into a boxed task stored in `task`, and the
/// receiving ends of the token and result channels are stored in `queued_tokens` and
/// `result`.
pub struct TextCompletionBuilder<
    M: CreateTextCompletionSession,
    Constraints = NoConstraints,
    Sampler = GenerationParameters,
> {
    /// Prompt to complete; taken (left empty) when the generation task is created.
    text: String,
    /// The model; `None` once it has been moved into the generation task.
    model: Option<M>,
    /// Constraints for structured generation; taken when a structured task is created.
    constraints: Option<Constraints>,
    /// Sampler settings; `None` once moved into the generation task.
    sampler: Option<Sampler>,
    /// The lazily created generation task, set at most once per builder.
    task: OnceLock<RwLock<Pin<Box<dyn Future<Output = ()> + Send>>>>,
    /// Receives the type-erased final output (or the model error) when the task finishes.
    // The nested receiver / result / trait-object type trips clippy's `type_complexity`
    // lint; a type alias would not make it clearer here.
    #[allow(clippy::type_complexity)]
    result: Option<Receiver<Result<Box<dyn Any + Send>, M::Error>>>,
    /// Receives each token as the generation task produces it.
    queued_tokens: Option<UnboundedReceiver<String>>,
}
impl<M: CreateTextCompletionSession, Constraints, Sampler>
    TextCompletionBuilder<M, Constraints, Sampler>
{
    /// Replaces the constraints that guide generation.
    ///
    /// The prompt, model and sampler carry over unchanged. Any generation state (task,
    /// token channel, result channel) is discarded and the returned builder starts fresh.
    pub fn with_constraints<NewConstraints: ModelConstraints>(
        self,
        constraints: NewConstraints,
    ) -> TextCompletionBuilder<M, NewConstraints, Sampler> {
        let Self {
            text,
            model,
            sampler,
            ..
        } = self;
        TextCompletionBuilder {
            text,
            model,
            constraints: Some(constraints),
            sampler,
            task: OnceLock::new(),
            result: None,
            queued_tokens: None,
        }
    }

    /// Switches to the default constraints that the model `M` provides for type `T`.
    pub fn typed<T>(
        self,
    ) -> TextCompletionBuilder<
        M,
        <M as CreateDefaultCompletionConstraintsForType<T>>::DefaultConstraints,
        Sampler,
    >
    where
        M: CreateDefaultCompletionConstraintsForType<T>,
    {
        let defaults = M::create_default_constraints();
        self.with_constraints(defaults)
    }

    /// Replaces the sampler used during generation.
    ///
    /// The prompt, model and constraints carry over unchanged. Any generation state is
    /// discarded and the returned builder starts fresh.
    pub fn with_sampler<NewSampler>(
        self,
        sampler: NewSampler,
    ) -> TextCompletionBuilder<M, Constraints, NewSampler> {
        let Self {
            text,
            model,
            constraints,
            ..
        } = self;
        TextCompletionBuilder {
            text,
            model,
            constraints,
            sampler: Some(sampler),
            task: OnceLock::new(),
            result: None,
            queued_tokens: None,
        }
    }
}
impl<M, Sampler> TextCompletionBuilder<M, NoConstraints, Sampler>
where
    Sampler: Send + Unpin + 'static,
    M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Send + Sync + Unpin + 'static,
{
    /// Lazily creates the unconstrained generation task; a no-op if it already exists.
    ///
    /// Moves the prompt, model and sampler out of the builder into a boxed future. That
    /// future:
    /// - forwards every token into `queued_tokens`;
    /// - accumulates the tokens into one string;
    /// - sends the full text (boxed as `dyn Any`), or the model error, through `result`.
    ///
    /// # Panics
    ///
    /// Panics if the model or sampler has already been moved out of the builder.
    fn ensure_unstructured_task_started(&mut self) {
        if self.task.get().is_some() {
            return;
        }

        let prompt = std::mem::take(&mut self.text);
        let model = self
            .model
            .take()
            .expect("TextCompletionBuilder cannot be turned into a future twice");
        let sampler = self
            .sampler
            .take()
            .expect("TextCompletionBuilder cannot be turned into a future twice");

        let (mut token_tx, token_rx) = futures_channel::mpsc::unbounded();
        let (result_tx, result_rx) = futures_channel::oneshot::channel();
        self.queued_tokens = Some(token_rx);
        self.result = Some(result_rx);

        // The callback appends to this buffer; the task reads the whole text at the end.
        let transcript = Arc::new(Mutex::new(String::new()));
        let callback_transcript = Arc::clone(&transcript);
        let on_token = move |token: String| {
            callback_transcript.lock().unwrap().push_str(&token);
            // The stream side may never be read; failing to queue a token is not an error.
            _ = token_tx.start_send(token);
            Ok(())
        };

        let generation = async move {
            let mut session = model.new_session()?;
            model
                .stream_text_with_callback(&mut session, &prompt, sampler, on_token)
                .await?;
            let full_text = std::mem::take(&mut *transcript.lock().unwrap());
            Ok(Box::new(full_text) as Box<dyn Any + Send>)
        };

        let driver = async move {
            let outcome: Result<Box<dyn Any + Send>, M::Error> = generation.await;
            // Nobody may be waiting for the result (e.g. only the stream was consumed).
            _ = result_tx.send(outcome);
        };

        let boxed = Box::pin(driver);
        if self.task.set(RwLock::new(boxed)).is_err() {
            panic!("Task already set");
        }
    }
}
impl<M, Sampler> Stream for TextCompletionBuilder<M, NoConstraints, Sampler>
where
    Sampler: Send + Unpin + 'static,
    M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Send + Sync + Unpin + 'static,
{
    type Item = String;

    /// Yields each token as the generation task produces it.
    ///
    /// The stream ends only after the generation task has finished AND every token it
    /// produced has been yielded.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let myself = Pin::get_mut(self);
        myself.ensure_unstructured_task_started();

        // Hand out tokens buffered by an earlier poll of the task first.
        if let Some(tokens) = &mut myself.queued_tokens {
            if let Poll::Ready(Some(token)) = tokens.poll_next_unpin(cx) {
                return Poll::Ready(Some(token));
            }
        }

        // Drive the generation task. `&mut self` gives exclusive access, so no lock is needed.
        let task = myself
            .task
            .get_mut()
            .expect("the generation task was just started")
            .get_mut()
            .expect("generation task lock poisoned");
        let finished = task.poll_unpin(cx).is_ready();
        if finished {
            // A completed `async` block panics if polled again. Swap in a future that is
            // always ready so later polls (or a later `.await` of the builder) are harmless.
            *task = Box::pin(std::future::poll_fn(|_| Poll::Ready(())));
        }

        // The poll above may have produced tokens, including the final ones emitted in the
        // same poll in which the task completed. Deliver them before ending the stream.
        if let Some(tokens) = &mut myself.queued_tokens {
            if let Poll::Ready(Some(token)) = tokens.poll_next_unpin(cx) {
                return Poll::Ready(Some(token));
            }
        }

        if finished {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }
}
impl<M, Sampler> IntoFuture for TextCompletionBuilder<M, NoConstraints, Sampler>
where
    Sampler: Send + Unpin + 'static,
    M: TextCompletionModel<Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Clone + Send + Sync + Unpin + 'static,
{
    type Output = Result<String, M::Error>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

    /// Runs the unconstrained generation to completion and resolves to the full text.
    fn into_future(mut self) -> Self::IntoFuture {
        self.ensure_unstructured_task_started();
        Box::pin(async move {
            // `ensure_unstructured_task_started` always installs the result channel.
            let mut result_rx = self
                .result
                .take()
                .expect("the result channel is created when the generation task starts");
            // If the task already ran to completion (e.g. the builder was consumed as a stream
            // first), its result is already waiting, and polling the finished task again would
            // panic. Otherwise the task must be driven here: the result sender lives inside it,
            // so awaiting the receiver alone would never resolve.
            let already_sent = result_rx.try_recv().ok().flatten();
            let result = match already_sent {
                Some(result) => result,
                None => {
                    self.task
                        .into_inner()
                        .expect("the generation task was just started")
                        .into_inner()
                        .expect("generation task lock poisoned")
                        .await;
                    result_rx
                        .await
                        .expect("the generation task always sends its result before finishing")
                }
            };
            result.map(|boxed| *boxed.downcast::<String>().unwrap())
        })
    }
}
impl<M, Constraints, Sampler> TextCompletionBuilder<M, Constraints, Sampler>
where
    Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
    Sampler: Send + Unpin + 'static,
    M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Clone + Send + Sync + Unpin + 'static,
    Constraints::Output: Send + 'static,
{
    /// Lazily creates the constrained generation task; a no-op if it already exists.
    ///
    /// Moves the prompt, model, sampler and constraints out of the builder into a boxed
    /// future. That future:
    /// - forwards every token into `queued_tokens`;
    /// - sends the parser's `Constraints::Output` (boxed as `dyn Any`), or the model error,
    ///   through `result`.
    ///
    /// # Panics
    ///
    /// Panics if the model, sampler or constraints have already been moved out of the
    /// builder.
    fn ensure_structured_task_started(&mut self) {
        if self.task.get().is_some() {
            return;
        }

        let prompt = std::mem::take(&mut self.text);
        let model = self
            .model
            .take()
            .expect("TextCompletionBuilder cannot be turned into a future twice");
        let sampler = self
            .sampler
            .take()
            .expect("TextCompletionBuilder cannot be turned into a future twice");
        let parser = self
            .constraints
            .take()
            .expect("TextCompletionBuilder cannot be turned into a future twice");

        let (mut token_tx, token_rx) = futures_channel::mpsc::unbounded();
        let (result_tx, result_rx) = futures_channel::oneshot::channel();
        self.queued_tokens = Some(token_rx);
        self.result = Some(result_rx);

        let on_token = move |token: String| {
            // The stream side may never be read; failing to queue a token is not an error.
            _ = token_tx.start_send(token);
            Ok(())
        };

        let generation = async move {
            let mut session = model.new_session()?;
            let parsed = model
                .stream_text_with_callback_and_parser(
                    &mut session,
                    &prompt,
                    sampler,
                    parser,
                    on_token,
                )
                .await?;
            Ok(Box::new(parsed) as Box<dyn Any + Send>)
        };

        let driver = async move {
            let outcome: Result<Box<dyn Any + Send>, M::Error> = generation.await;
            // Nobody may be waiting for the result (e.g. only the stream was consumed).
            _ = result_tx.send(outcome);
        };

        let boxed = Box::pin(driver);
        if self.task.set(RwLock::new(boxed)).is_err() {
            panic!("Task already set");
        }
    }
}
impl<M, Constraints, Sampler> Stream for TextCompletionBuilder<M, Constraints, Sampler>
where
    Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
    Sampler: Send + Unpin + 'static,
    M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Clone + Send + Sync + Unpin + 'static,
    Constraints::Output: Send + 'static,
{
    type Item = String;

    /// Yields each token as the constrained generation task produces it.
    ///
    /// The stream ends only after the generation task has finished AND every token it
    /// produced has been yielded.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let myself = Pin::get_mut(self);
        myself.ensure_structured_task_started();

        // Hand out tokens buffered by an earlier poll of the task first.
        if let Some(tokens) = &mut myself.queued_tokens {
            if let Poll::Ready(Some(token)) = tokens.poll_next_unpin(cx) {
                return Poll::Ready(Some(token));
            }
        }

        // Drive the generation task. `&mut self` gives exclusive access, so no lock is needed.
        let task = myself
            .task
            .get_mut()
            .expect("the generation task was just started")
            .get_mut()
            .expect("generation task lock poisoned");
        let finished = task.poll_unpin(cx).is_ready();
        if finished {
            // A completed `async` block panics if polled again. Swap in a future that is
            // always ready so later polls (or a later `.await` of the builder) are harmless.
            *task = Box::pin(std::future::poll_fn(|_| Poll::Ready(())));
        }

        // The poll above may have produced tokens, including the final ones emitted in the
        // same poll in which the task completed. Deliver them before ending the stream.
        if let Some(tokens) = &mut myself.queued_tokens {
            if let Poll::Ready(Some(token)) = tokens.poll_next_unpin(cx) {
                return Poll::Ready(Some(token));
            }
        }

        if finished {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }
}
impl<M, Constraints, Sampler> IntoFuture for TextCompletionBuilder<M, Constraints, Sampler>
where
    Constraints: ModelConstraints + Send + Sync + Unpin + 'static,
    Sampler: Send + Unpin + 'static,
    M: StructuredTextCompletionModel<Constraints, Sampler> + Send + Sync + Unpin + 'static,
    M::Session: Clone + Send + Sync + Unpin + 'static,
    Constraints::Output: Send + 'static,
{
    type Output = Result<Constraints::Output, M::Error>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

    /// Runs the constrained generation to completion and resolves to the parsed
    /// `Constraints::Output`.
    fn into_future(mut self) -> Self::IntoFuture {
        self.ensure_structured_task_started();
        Box::pin(async move {
            // `ensure_structured_task_started` always installs the result channel.
            let mut result_rx = self
                .result
                .take()
                .expect("the result channel is created when the generation task starts");
            // If the task already ran to completion (e.g. the builder was consumed as a stream
            // first), its result is already waiting, and polling the finished task again would
            // panic. Otherwise the task must be driven here: the result sender lives inside it,
            // so awaiting the receiver alone would never resolve.
            let already_sent = result_rx.try_recv().ok().flatten();
            let result = match already_sent {
                Some(result) => result,
                None => {
                    self.task
                        .into_inner()
                        .expect("the generation task was just started")
                        .into_inner()
                        .expect("generation task lock poisoned")
                        .await;
                    result_rx
                        .await
                        .expect("the generation task always sends its result before finishing")
                }
            };
            result.map(|boxed| *boxed.downcast::<Constraints::Output>().unwrap())
        })
    }
}