use std::sync::Arc;
use crate::error::Error as SdkError;
use crate::model::Stream;
use crate::{ChatCompletionResponse, Model, ModelBuilder, RequestLike, Response, TextModelBuilder};
/// Synchronous wrapper around an async [`Model`].
///
/// Every method drives the matching `Model` future to completion with
/// `Runtime::block_on`. This makes it usable from plain synchronous code, but
/// not from inside an async execution context, where `block_on` panics.
pub struct BlockingModel {
/// The async model that services every request.
/// Declared before `rt`, so it is dropped before this struct's runtime handle
/// (Rust drops struct fields in declaration order).
inner: Model,
/// Runtime that executes `inner`'s futures. Held in an `Arc` so streams
/// returned by `stream_chat_request` can share the same runtime.
rt: Arc<tokio::runtime::Runtime>,
}
impl BlockingModel {
pub fn from_builder(builder: TextModelBuilder) -> crate::error::Result<Self> {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| SdkError::Inference(e.into()))?;
let inner = rt
.block_on(builder.build())
.map_err(|e| SdkError::ModelLoad(e.into()))?;
Ok(Self {
inner,
rt: Arc::new(rt),
})
}
pub fn from_auto_builder(builder: ModelBuilder) -> crate::error::Result<Self> {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.map_err(|e| SdkError::Inference(e.into()))?;
let inner = rt
.block_on(builder.build())
.map_err(|e| SdkError::ModelLoad(e.into()))?;
Ok(Self {
inner,
rt: Arc::new(rt),
})
}
pub fn new(model: Model, rt: Arc<tokio::runtime::Runtime>) -> Self {
Self { inner: model, rt }
}
pub fn send_chat_request<R: RequestLike>(
&self,
request: R,
) -> crate::error::Result<ChatCompletionResponse> {
self.rt.block_on(self.inner.send_chat_request(request))
}
pub fn chat(&self, message: impl ToString) -> crate::error::Result<String> {
self.rt.block_on(self.inner.chat(message))
}
pub fn stream_chat_request<R: RequestLike>(
&self,
request: R,
) -> crate::error::Result<BlockingStream> {
let stream: Stream<'_> = self.rt.block_on(self.inner.stream_chat_request(request))?;
Ok(BlockingStream {
rx: stream.into_receiver(),
rt: self.rt.clone(),
})
}
pub fn generate_structured<T>(
&self,
messages: impl Into<crate::RequestBuilder>,
) -> crate::error::Result<T>
where
T: serde::de::DeserializeOwned + schemars::JsonSchema,
{
self.rt
.block_on(self.inner.generate_structured::<T>(messages))
}
pub fn inner(&self) -> &Model {
&self.inner
}
}
/// Blocking iterator over the `Response` items of a streamed chat request,
/// created by `BlockingModel::stream_chat_request`.
pub struct BlockingStream {
/// Receiving end of the stream's channel (obtained via `Stream::into_receiver`).
/// Declared before `rt`, so it is dropped before the runtime handle.
rx: tokio::sync::mpsc::Receiver<Response>,
/// Shared runtime handle used to await `rx`. Holding this `Arc` keeps the
/// runtime alive while iterating (presumably the producer side runs on it).
rt: Arc<tokio::runtime::Runtime>,
}
impl Iterator for BlockingStream {
    type Item = Response;

    /// Blocks the calling thread until the next `Response` is available.
    /// Returns `None` once the channel is closed and drained.
    ///
    /// `recv` is awaited through `Runtime::block_on` rather than
    /// `blocking_recv`. That way a current-thread runtime (one supplied via
    /// `BlockingModel::new`) is still driven while we wait.
    fn next(&mut self) -> Option<Self::Item> {
        let Self { rx, rt } = self;
        let pending = rx.recv();
        rt.block_on(pending)
    }
}