use crate::error::{ModelError, Result};
use crate::local::backends::LocalBackend;
use crate::local::sampling::do_sample;
use candle_core::{DType, Device, Tensor};
use tracing::debug;
pub struct BatchRequest {
pub prompt: String,
pub input_ids: Vec<u32>,
pub max_tokens: usize,
pub temperature: f32,
pub eos_token: Option<u32>,
}
pub struct BatchResult {
pub text: String,
pub tokens: Vec<u32>,
pub token_count: usize,
}
pub fn generate_batch<F>(
backend: &mut LocalBackend,
requests: Vec<BatchRequest>,
device: &Device,
sample_fn: F,
) -> Result<Vec<BatchResult>>
where
F: Fn(&[f32], f32, f32, Option<usize>) -> Result<u32> + Copy,
{
if requests.is_empty() {
return Ok(Vec::new());
}
let mut results = Vec::with_capacity(requests.len());
for request in requests {
let generated_tokens = match backend {
LocalBackend::Llama { model, config } => generate_llama_single(
model,
config,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Mistral { model, .. } => generate_mistral_single(
model,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Phi3 { model, .. } => generate_phi3_single(
model,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Mamba { model, config } => generate_mamba_single(
model,
config,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::GraniteMoeHybrid { model, config } => generate_granite_moe_single(
model,
config,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Bert { .. } => {
return Err(ModelError::LocalModelError(
"Encoder-only models (BERT) cannot generate text. Use embeddings instead."
.to_string(),
));
}
#[cfg(feature = "gguf")]
LocalBackend::Gguf { .. } => {
return Err(ModelError::LocalModelError(
"GGUF batch generation not yet supported.".to_string(),
));
}
};
results.push(BatchResult {
text: String::new(), tokens: generated_tokens.clone(),
token_count: generated_tokens.len(),
});
}
Ok(results)
}
pub fn generate_batch_stream<F, C>(
backend: &mut LocalBackend,
requests: Vec<BatchRequest>,
device: &Device,
sample_fn: C,
mut callback: F,
) -> Result<Vec<BatchResult>>
where
F: FnMut(usize, String) -> Result<()>,
C: Fn(&[f32], f32, f32, Option<usize>) -> Result<u32> + Copy,
{
if requests.is_empty() {
return Ok(Vec::new());
}
let mut results = Vec::with_capacity(requests.len());
for (idx, request) in requests.into_iter().enumerate() {
let generated_tokens = match backend {
LocalBackend::Llama { model, config } => generate_llama_single(
model,
config,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Mistral { model, .. } => generate_mistral_single(
model,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Phi3 { model, .. } => generate_phi3_single(
model,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Mamba { model, config } => generate_mamba_single(
model,
config,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::GraniteMoeHybrid { model, config } => generate_granite_moe_single(
model,
config,
&request.input_ids,
request.max_tokens,
request.temperature,
0.9,
None,
request.eos_token,
device,
sample_fn,
)?,
LocalBackend::Bert { .. } => {
return Err(ModelError::LocalModelError(
"Encoder-only models (BERT) cannot generate text. Use embeddings instead."
.to_string(),
));
}
#[cfg(feature = "gguf")]
LocalBackend::Gguf { .. } => {
return Err(ModelError::LocalModelError(
"GGUF batch generation not yet supported.".to_string(),
));
}
};
for token in &generated_tokens {
callback(idx, format!("<token_{}>", token))?;
}
results.push(BatchResult {
text: String::new(),
tokens: generated_tokens.clone(),
token_count: generated_tokens.len(),
});
}
Ok(results)
}
fn generate_llama_single<F>(
model: &candle_transformers::models::llama::Llama,
config: &candle_transformers::models::llama::Config,
input_ids: &[u32],
max_tokens: usize,
temperature: f32,
top_p: f32,
top_k: Option<usize>,
eos_token: Option<u32>,
device: &Device,
sample_fn: F,
) -> Result<Vec<u32>>
where
F: Fn(&[f32], f32, f32, Option<usize>) -> Result<u32>,
{
use candle_transformers::models::llama::Cache;
let mut cache = Cache::new(true, DType::F32, config, device)
.map_err(|e| ModelError::LocalModelError(format!("Failed to create cache: {}", e)))?;
let prompt_tensor = Tensor::new(input_ids, device)?.unsqueeze(0)?;
let logits = model.forward(&prompt_tensor, 0, &mut cache)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let last_logits = &logits_vec[0];
let next = sample_fn(last_logits, temperature, top_p, top_k)?;
let mut generated = vec![next];
for idx in 1..max_tokens {
if let Some(eos) = eos_token {
if next == eos {
break;
}
}
let tensor = Tensor::new(&[next], device)?.unsqueeze(0)?;
let logits = model.forward(&tensor, input_ids.len() + idx - 1, &mut cache)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let token_logits = &logits_vec[0];
let next = sample_fn(token_logits, temperature, top_p, top_k)?;
generated.push(next);
}
Ok(generated)
}
fn generate_mistral_single<F>(
model: &mut candle_transformers::models::mistral::Model,
input_ids: &[u32],
max_tokens: usize,
temperature: f32,
top_p: f32,
top_k: Option<usize>,
eos_token: Option<u32>,
device: &Device,
sample_fn: F,
) -> Result<Vec<u32>>
where
F: Fn(&[f32], f32, f32, Option<usize>) -> Result<u32>,
{
let prompt_tensor = Tensor::new(input_ids, device)?.unsqueeze(0)?;
let logits = model.forward(&prompt_tensor, 0)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec3::<f32>()?;
let seq_len = logits_vec[0].len();
let last_logits = &logits_vec[0][seq_len - 1];
let next = sample_fn(last_logits, temperature, top_p, top_k)?;
let mut generated = vec![next];
for idx in 1..max_tokens {
if let Some(eos) = eos_token {
if next == eos {
break;
}
}
let tensor = Tensor::new(&[next], device)?.unsqueeze(0)?;
let logits = model.forward(&tensor, input_ids.len() + idx)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec3::<f32>()?;
let token_logits = &logits_vec[0][0];
let next = sample_fn(token_logits, temperature, top_p, top_k)?;
generated.push(next);
}
Ok(generated)
}
fn generate_phi3_single<F>(
model: &mut candle_transformers::models::phi3::Model,
input_ids: &[u32],
max_tokens: usize,
temperature: f32,
top_p: f32,
top_k: Option<usize>,
eos_token: Option<u32>,
device: &Device,
sample_fn: F,
) -> Result<Vec<u32>>
where
F: Fn(&[f32], f32, f32, Option<usize>) -> Result<u32>,
{
model.clear_kv_cache();
let prompt_tensor = Tensor::new(input_ids, device)?.unsqueeze(0)?;
let logits = model.forward(&prompt_tensor, 0)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec3::<f32>()?;
let seq_len = logits_vec[0].len();
let last_logits = &logits_vec[0][seq_len - 1];
let mut next = sample_fn(last_logits, temperature, top_p, top_k)?;
let mut generated = vec![next];
for idx in 1..max_tokens {
if let Some(eos) = eos_token {
if next == eos {
break;
}
}
let tensor = Tensor::new(&[next], device)?.unsqueeze(0)?;
let logits = model.forward(&tensor, input_ids.len() + idx - 1)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec3::<f32>()?;
let token_logits = &logits_vec[0][0];
next = sample_fn(token_logits, temperature, top_p, top_k)?;
generated.push(next);
}
Ok(generated)
}
fn generate_mamba_single<F>(
model: &candle_transformers::models::mamba::Model,
config: &candle_transformers::models::mamba::Config,
input_ids: &[u32],
max_tokens: usize,
temperature: f32,
top_p: f32,
top_k: Option<usize>,
eos_token: Option<u32>,
device: &Device,
sample_fn: F,
) -> Result<Vec<u32>>
where
F: Fn(&[f32], f32, f32, Option<usize>) -> Result<u32>,
{
use candle_transformers::models::mamba::State as MambaState;
let mut state = MambaState::new(1, config, DType::F32, device)
.map_err(|e| ModelError::LocalModelError(format!("Failed to create state: {}", e)))?;
let mut last_logits: Option<Vec<f32>> = None;
for &tok in input_ids.iter() {
let token_tensor = Tensor::new(&[tok], device)?;
let logits = model.forward(&token_tensor, &mut state)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let row = logits_vec
.get(0)
.ok_or_else(|| ModelError::LocalModelError("Mamba logits were empty".to_string()))?;
last_logits = Some(row.clone());
}
let last_logits =
last_logits.ok_or_else(|| ModelError::LocalModelError("Empty prompt".to_string()))?;
let token_logits = last_logits.as_slice();
let mut next = sample_fn(token_logits, temperature, top_p, top_k)?;
let mut generated = vec![next];
for _idx in 1..max_tokens {
if let Some(eos) = eos_token {
if next == eos {
break;
}
}
let token_tensor = Tensor::new(&[next], device)?;
let logits = model.forward(&token_tensor, &mut state)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let token_logits = logits_vec
.get(0)
.ok_or_else(|| ModelError::LocalModelError("Mamba logits were empty".to_string()))?;
next = sample_fn(token_logits, temperature, top_p, top_k)?;
generated.push(next);
}
Ok(generated)
}
fn generate_granite_moe_single<F>(
model: &candle_transformers::models::granitemoehybrid::GraniteMoeHybrid,
config: &candle_transformers::models::granitemoehybrid::GraniteMoeHybridInternalConfig,
input_ids: &[u32],
max_tokens: usize,
temperature: f32,
top_p: f32,
top_k: Option<usize>,
eos_token: Option<u32>,
device: &Device,
sample_fn: F,
) -> Result<Vec<u32>>
where
F: Fn(&[f32], f32, f32, Option<usize>) -> Result<u32>,
{
use candle_transformers::models::granitemoehybrid::GraniteMoeHybridCache;
let mut cache = GraniteMoeHybridCache::new(true, DType::F32, config, device)
.map_err(|e| ModelError::LocalModelError(format!("Failed to create cache: {}", e)))?;
let prompt_tensor = Tensor::new(input_ids, device)?.unsqueeze(0)?;
let logits = model.forward(&prompt_tensor, 0, &mut cache)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let token_logits = &logits_vec[0];
let mut next = sample_fn(token_logits, temperature, top_p, top_k)?;
let mut generated = vec![next];
for idx in 1..max_tokens {
if let Some(eos) = eos_token {
if next == eos {
break;
}
}
let tensor = Tensor::new(&[next], device)?.unsqueeze(0)?;
let logits = model.forward(&tensor, input_ids.len() + idx - 1, &mut cache)?;
let logits_vec = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let token_logits = &logits_vec[0];
next = sample_fn(token_logits, temperature, top_p, top_k)?;
generated.push(next);
}
Ok(generated)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty_batch() {
let results: Vec<BatchResult> = Vec::new();
assert!(results.is_empty());
}
#[test]
fn test_batch_request_creation() {
let request = BatchRequest {
prompt: "Hello".to_string(),
input_ids: vec![1, 2, 3],
max_tokens: 100,
temperature: 0.7,
eos_token: Some(0),
};
assert_eq!(request.input_ids.len(), 3);
assert_eq!(request.max_tokens, 100);
}
#[test]
fn test_batch_result_creation() {
let result = BatchResult {
text: "Hello world".to_string(),
tokens: vec![1, 2, 3, 4, 5],
token_count: 5,
};
assert_eq!(result.tokens.len(), 5);
assert_eq!(result.token_count, 5);
}
}