#[cfg(feature = "async")]
use futures::future::BoxFuture;
#[cfg(feature = "async")]
use std::sync::Arc;
#[cfg(feature = "async")]
use tokio::task;
use crate::{Context, ContextParams, Model, MullamaError, SamplerChain, SamplerParams, TokenId};
#[cfg(feature = "async")]
#[derive(Clone)]
/// Cloneable handle to a [`Model`] for use from async code.
///
/// The model is held behind an [`Arc`], so cloning this handle shares the
/// same underlying model rather than copying it.
pub struct AsyncModel {
/// Shared reference to the wrapped model.
inner: Arc<Model>,
}
#[cfg(feature = "async")]
impl AsyncModel {
    /// Loads a model from `path` on Tokio's blocking thread pool.
    ///
    /// # Errors
    ///
    /// Returns [`MullamaError::ModelLoadError`] if the blocking task does not
    /// run to completion, or the error reported by [`Model::load`].
    pub async fn load(path: impl AsRef<str> + Send + 'static) -> Result<Self, MullamaError> {
        let path = path.as_ref().to_string();
        Self::load_blocking(move || Model::load(&path)).await
    }

    /// Loads a model from `path` with explicit `params` on Tokio's blocking
    /// thread pool.
    ///
    /// # Errors
    ///
    /// Returns [`MullamaError::ModelLoadError`] if the blocking task does not
    /// run to completion, or the error reported by
    /// [`Model::load_with_params`].
    pub async fn load_with_params(
        path: impl AsRef<str> + Send + 'static,
        params: crate::ModelParams,
    ) -> Result<Self, MullamaError> {
        let path = path.as_ref().to_string();
        Self::load_blocking(move || Model::load_with_params(&path, params)).await
    }

    /// Runs a blocking model `loader` on the blocking pool and wraps the
    /// loaded model in an [`AsyncModel`].
    async fn load_blocking<F>(loader: F) -> Result<Self, MullamaError>
    where
        F: FnOnce() -> Result<Model, MullamaError> + Send + 'static,
    {
        // First `?`: the blocking task itself failed to complete.
        // Second `?`: the loader ran but reported an error.
        let model = task::spawn_blocking(loader)
            .await
            .map_err(|e| MullamaError::ModelLoadError(format!("Async task failed: {}", e)))??;
        Ok(Self {
            inner: Arc::new(model),
        })
    }

    /// Creates a new context for this model on Tokio's blocking thread pool.
    ///
    /// The returned [`AsyncContext`] shares this handle's model.
    ///
    /// # Errors
    ///
    /// Returns [`MullamaError::ContextError`] if the blocking task does not
    /// run to completion, or the error reported by [`Context::new`].
    pub async fn create_context_async(
        &self,
        params: ContextParams,
    ) -> Result<AsyncContext, MullamaError> {
        let model = Arc::clone(&self.inner);
        // First `?`: the blocking task failed; second `?`: `Context::new` failed.
        let context = task::spawn_blocking(move || Context::new(model, params))
            .await
            .map_err(|e| MullamaError::ContextError(format!("Async task failed: {}", e)))??;
        Ok(AsyncContext {
            inner: context,
            model: Arc::clone(&self.inner),
        })
    }

    /// Generates up to `max_tokens` tokens of text continuing `prompt`.
    ///
    /// All work runs on Tokio's blocking thread pool. A fresh context
    /// (`n_ctx` of 2048, other fields default) and a fresh sampler chain
    /// (`temperature` of 0.7, other fields default) are built for every call
    /// and dropped when it finishes. Generation stops early when the sampler
    /// returns the model's EOS token.
    ///
    /// # Errors
    ///
    /// * [`MullamaError::InvalidInput`] if tokenizing `prompt` yields no
    ///   tokens.
    /// * [`MullamaError::GenerationError`] if the blocking task does not run
    ///   to completion.
    /// * Any error reported while creating the context or sampler,
    ///   tokenizing, decoding, or converting a token to text.
    pub async fn generate_async(
        &self,
        prompt: &str,
        max_tokens: usize,
    ) -> Result<String, MullamaError> {
        let model = Arc::clone(&self.inner);
        let prompt = prompt.to_string();
        task::spawn_blocking(move || {
            let ctx_params = ContextParams {
                n_ctx: 2048,
                ..ContextParams::default()
            };
            let mut context = Context::new(Arc::clone(&model), ctx_params)?;
            let sampler_params = SamplerParams {
                temperature: 0.7,
                ..SamplerParams::default()
            };
            let mut sampler = sampler_params.build_chain(Arc::clone(&model))?;
            let tokens = model.tokenize(&prompt, true, false)?;
            if tokens.is_empty() {
                return Err(MullamaError::InvalidInput(
                    "Prompt produced no tokens".to_string(),
                ));
            }
            context.decode(&tokens)?;
            let mut result = String::new();
            let eos = model.token_eos();
            for generated in 0..max_tokens {
                let next_token = sampler.sample(&mut context, -1);
                if next_token == eos {
                    break;
                }
                let text = model.token_to_str(next_token, 0, false)?;
                result.push_str(&text);
                sampler.accept(next_token);
                // Decoding only prepares the context for the next `sample`
                // call. After the last requested token nothing samples again
                // and the context is dropped, so skip that pass: it is wasted
                // work and can fail spuriously when the window is exactly full.
                if generated + 1 < max_tokens {
                    context.decode(std::slice::from_ref(&next_token))?;
                }
            }
            Ok(result)
        })
        .await
        .map_err(|e| MullamaError::GenerationError(format!("Async task failed: {}", e)))?
    }

    /// Returns the shared handle to the underlying model.
    pub fn model(&self) -> &Arc<Model> {
        &self.inner
    }

    /// Collects basic model metadata on Tokio's blocking thread pool.
    ///
    /// This method is infallible by signature: if the blocking task does not
    /// run to completion, the failure is swallowed and a default
    /// [`ModelInfo`] (all fields zero) is returned instead.
    pub async fn info_async(&self) -> ModelInfo {
        let model = Arc::clone(&self.inner);
        task::spawn_blocking(move || ModelInfo {
            vocab_size: model.vocab_size(),
            n_ctx_train: model.n_ctx_train(),
            n_embd: model.n_embd(),
            n_layer: model.n_layer(),
        })
        .await
        .unwrap_or_default()
    }
}
#[cfg(feature = "async")]
/// A [`Context`] bundled with a shared handle to its [`Model`], for use from
/// async code.
pub struct AsyncContext {
/// The wrapped inference context.
inner: Context,
/// Model used for EOS lookup and token-to-text conversion during generation.
model: Arc<Model>,
}
#[cfg(feature = "async")]
impl AsyncContext {
    /// Wraps an existing context together with a shared handle to its model.
    pub fn new(inner: Context, model: Arc<Model>) -> Self {
        Self { inner, model }
    }

    /// Returns a mutable reference to the wrapped context.
    pub fn inner_mut(&mut self) -> &mut Context {
        &mut self.inner
    }

    /// Returns a shared reference to the wrapped context.
    pub fn inner(&self) -> &Context {
        &self.inner
    }

    /// Decodes `tokens` and then generates up to `max_tokens` tokens of text
    /// using the supplied `sampler`.
    ///
    /// The decode/sample loop runs on Tokio's blocking thread pool. This
    /// method takes `self` by value, so the context is consumed and dropped
    /// when generation finishes. Generation stops early when the sampler
    /// returns the model's EOS token.
    ///
    /// # Errors
    ///
    /// * [`MullamaError::InvalidInput`] if `tokens` is empty.
    /// * [`MullamaError::GenerationError`] if the blocking task does not run
    ///   to completion.
    /// * Any error reported while decoding or converting a token to text.
    pub async fn generate_with_sampler_async(
        mut self,
        tokens: &[TokenId],
        max_tokens: usize,
        mut sampler: SamplerChain,
    ) -> Result<String, MullamaError> {
        // Reject invalid input up front instead of paying for a copy of the
        // prompt and a hop to the blocking pool first.
        if tokens.is_empty() {
            return Err(MullamaError::InvalidInput(
                "Cannot generate from empty prompt".to_string(),
            ));
        }
        let model = Arc::clone(&self.model);
        let tokens = tokens.to_vec();
        task::spawn_blocking(move || -> Result<String, MullamaError> {
            self.inner.decode(&tokens)?;
            let mut result = String::new();
            let eos = model.token_eos();
            for generated in 0..max_tokens {
                let next_token = sampler.sample(&mut self.inner, -1);
                if next_token == eos {
                    break;
                }
                let text = model.token_to_str(next_token, 0, false)?;
                result.push_str(&text);
                sampler.accept(next_token);
                // Decoding only prepares the context for the next `sample`
                // call. After the last requested token nothing samples again
                // and the consumed context is dropped, so skip that pass: it
                // is wasted work and can fail spuriously on a full window.
                if generated + 1 < max_tokens {
                    self.inner.decode(std::slice::from_ref(&next_token))?;
                }
            }
            Ok(result)
        })
        .await
        .map_err(|e| MullamaError::GenerationError(format!("Async task failed: {}", e)))?
    }

    /// Consumes the wrapper and returns the wrapped context.
    pub fn into_inner(self) -> Context {
        self.inner
    }
}
#[cfg(feature = "async")]
#[derive(Debug, Clone, Default)]
/// Snapshot of basic model metadata, as filled in by
/// `AsyncModel::info_async`.
///
/// The derived `Default` sets every field to zero.
pub struct ModelInfo {
/// Value returned by `Model::vocab_size`.
pub vocab_size: i32,
/// Value returned by `Model::n_ctx_train`.
pub n_ctx_train: i32,
/// Value returned by `Model::n_embd`.
pub n_embd: i32,
/// Value returned by `Model::n_layer`.
pub n_layer: i32,
}
#[cfg(feature = "async")]
/// Boxed, thread-safe callback that takes an `f32` and returns a boxed
/// `'static` future resolving to `()`.
///
/// NOTE(review): the `f32` is presumably a progress value; its range and
/// units are not established by the code visible here — confirm against the
/// call sites.
pub type ProgressCallback = Box<dyn Fn(f32) -> BoxFuture<'static, ()> + Send + Sync>;
#[cfg(feature = "async")]
#[derive(Clone, Default)]
/// Optional hooks for async operations.
///
/// Both fields default to `None`. `Debug` is implemented by hand because the
/// callback type does not implement it.
///
/// NOTE(review): no code visible in this part of the file reads these fields;
/// confirm where they are consumed.
pub struct AsyncConfig {
/// Optional shared progress callback.
pub progress_callback: Option<Arc<ProgressCallback>>,
/// Optional cancellation token.
pub cancellation_token: Option<tokio_util::sync::CancellationToken>,
}
#[cfg(feature = "async")]
impl std::fmt::Debug for AsyncConfig {
    /// Formats the config as a struct, printing a fixed `"<callback>"`
    /// placeholder in place of the progress callback when one is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The callback is a `dyn Fn` trait object with no `Debug` impl, so
        // only its presence is reported.
        let callback_repr: Option<&str> = self.progress_callback.as_ref().map(|_| "<callback>");

        let mut builder = f.debug_struct("AsyncConfig");
        builder.field("progress_callback", &callback_repr);
        builder.field("cancellation_token", &self.cancellation_token);
        builder.finish()
    }
}
// NOTE(review): the two `cfg` attributes below are both applied to the same
// item and are mutually exclusive (`feature = "async"` and its negation), so
// this `compile_error!` is never compiled in and can never fire. Confirm which
// gate was intended; dropping the first attribute would make every build
// without the `async` feature fail if this module is included unconditionally.
#[cfg(feature = "async")]
#[cfg(not(feature = "async"))]
compile_error!("Async support requires the 'async' feature to be enabled");