use std::time::Instant;
use crate::engine::InferenceEngine;
use crate::error::{RuntimeError, RuntimeResult};
use crate::sampling::SamplingParams;
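/// Outcome of running a single prompt through the engine as part of a batch.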
#[derive(Debug, Clone)]
pub struct BatchResult {
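    /// Number of tokens in the input prompt.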
pub prompt_tokens: usize,
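    /// Token IDs produced by the engine for this prompt.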
pub generated_tokens: Vec<u32>,
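    /// Why generation stopped for this request.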
pub finish_reason: FinishReason,
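    /// Wall-clock time spent generating this result, in seconds.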
pub elapsed_seconds: f64,
}
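/// Why generation stopped for a request.
///
/// `Display` renders each variant as a snake_case string (e.g. `max_tokens`).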
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
MaxTokens,
Eos,
Error,
Timeout,
}
impl std::fmt::Display for FinishReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::MaxTokens => write!(f, "max_tokens"),
Self::Eos => write!(f, "eos"),
Self::Error => write!(f, "error"),
Self::Timeout => write!(f, "timeout"),
}
}
}
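/// Limits applied by [`batch_generate_with_timeout`]: batch size, per-request
/// token budget, and an optional per-request timeout.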
#[derive(Debug, Clone)]
pub struct BatchConfig {
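    /// Maximum number of prompts handled in one call; prompts beyond this are dropped.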
pub max_batch_size: usize,
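    /// Upper bound on the number of tokens generated for each request.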
pub max_tokens_per_request: usize,
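    /// Per-request wall-clock budget in milliseconds; `None` disables the timeout check.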
pub timeout_per_request_ms: Option<u64>,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_batch_size: 8,
max_tokens_per_request: 512,
timeout_per_request_ms: Some(30_000),
}
}
}
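/// Runs the engine over each prompt in turn, resetting engine state between
/// prompts.
///
/// Each entry in the returned vector corresponds to the prompt at the same
/// index. A result finishes with [`FinishReason::MaxTokens`] when generation
/// hit the `max_tokens` budget and [`FinishReason::Eos`] otherwise; engine
/// errors are returned as `Err` entries without aborting the rest of the
/// batch.
///
/// Illustrative sketch (engine construction mirrors the test setup in this
/// module; module paths may need adjusting for the actual crate layout):
///
/// ```ignore
/// let config = Qwen3Config::bonsai_8b();
/// let mut engine = InferenceEngine::new(config, SamplingParams::default(), 42);
/// let prompts = vec![vec![1_u32, 2, 3], vec![4, 5]];
/// for result in batch_generate(&mut engine, &prompts, 16) {
///     if let Ok(batch) = result {
///         println!("{} tokens ({})", batch.generated_tokens.len(), batch.finish_reason);
///     }
/// }
/// ```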
pub fn batch_generate(
engine: &mut InferenceEngine<'_>,
prompts: &[Vec<u32>],
max_tokens: usize,
) -> Vec<RuntimeResult<BatchResult>> {
prompts
.iter()
.map(|prompt| {
engine.reset();
let start = Instant::now();
match engine.generate(prompt, max_tokens) {
Ok(tokens) => {
let finish_reason = if tokens.len() >= max_tokens {
FinishReason::MaxTokens
} else {
FinishReason::Eos
};
Ok(BatchResult {
prompt_tokens: prompt.len(),
generated_tokens: tokens,
finish_reason,
elapsed_seconds: start.elapsed().as_secs_f64(),
})
}
Err(e) => Err(e),
}
})
.collect()
}
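/// Like [`batch_generate`], but bounded by a [`BatchConfig`].
///
/// Only the first `config.max_batch_size` prompts are processed; any extras
/// are silently dropped. The timeout is not enforced while the engine runs:
/// it is checked once `generate` returns, and a request whose elapsed time
/// exceeded the budget is reported with [`FinishReason::Timeout`].
///
/// Minimal sketch of a call, assuming `engine` and `prompts` prepared as in
/// the [`batch_generate`] example:
///
/// ```ignore
/// let config = BatchConfig {
///     max_batch_size: 4,
///     max_tokens_per_request: 128,
///     timeout_per_request_ms: Some(10_000),
/// };
/// let results = batch_generate_with_timeout(&mut engine, &prompts, &config);
/// ```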
pub fn batch_generate_with_timeout(
engine: &mut InferenceEngine<'_>,
prompts: &[Vec<u32>],
config: &BatchConfig,
) -> Vec<RuntimeResult<BatchResult>> {
let effective_prompts = if prompts.len() > config.max_batch_size {
&prompts[..config.max_batch_size]
} else {
prompts
};
effective_prompts
.iter()
.map(|prompt| {
engine.reset();
let start = Instant::now();
let timeout = config
.timeout_per_request_ms
.map(std::time::Duration::from_millis);
match engine.generate(prompt, config.max_tokens_per_request) {
Ok(tokens) => {
let elapsed = start.elapsed();
let timed_out = timeout.is_some_and(|t| elapsed > t);
let finish_reason = if timed_out {
FinishReason::Timeout
} else if tokens.len() >= config.max_tokens_per_request {
FinishReason::MaxTokens
} else {
FinishReason::Eos
};
Ok(BatchResult {
prompt_tokens: prompt.len(),
generated_tokens: tokens,
finish_reason,
elapsed_seconds: elapsed.as_secs_f64(),
})
}
Err(e) => Err(e),
}
})
.collect()
}
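/// A pending generation request held in a [`RequestQueue`].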
#[derive(Debug, Clone)]
pub struct BatchRequest {
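    /// Token IDs of the prompt (the full sequence, unlike the count stored in
    /// [`BatchResult::prompt_tokens`]).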
pub prompt_tokens: Vec<u32>,
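    /// Maximum number of tokens to generate for this request.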
pub max_tokens: usize,
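    /// Sampling parameters for this request.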
pub params: SamplingParams,
}
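/// A bounded FIFO queue of pending [`BatchRequest`]s.
///
/// Requests are drained in arrival order via [`RequestQueue::drain_batch`];
/// pushing onto a full queue is rejected with a `RuntimeError::Server`.
///
/// Minimal usage sketch (module paths omitted; adjust to the crate layout):
///
/// ```ignore
/// let mut queue = RequestQueue::new(16);
/// queue
///     .push(BatchRequest {
///         prompt_tokens: vec![1, 2, 3],
///         max_tokens: 32,
///         params: SamplingParams::default(),
///     })
///     .expect("queue has room");
/// // Take up to 8 requests to run as the next batch.
/// let next_batch = queue.drain_batch(8);
/// assert_eq!(next_batch.len(), 1);
/// ```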
pub struct RequestQueue {
pending: Vec<BatchRequest>,
max_size: usize,
}
impl RequestQueue {
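    /// Creates an empty queue that holds at most `max_size` pending requests
    /// (clamped to a minimum of 1).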
    pub fn new(max_size: usize) -> Self {
        // Clamp to at least one slot so the queue can always accept a request;
        // pre-allocate at most 1024 entries up front.
        let max_size = max_size.max(1);
        Self {
            pending: Vec::with_capacity(max_size.min(1024)),
            max_size,
        }
    }
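    /// Appends a request, or returns a `RuntimeError::Server` if the queue is full.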
pub fn push(&mut self, request: BatchRequest) -> Result<(), RuntimeError> {
if self.pending.len() >= self.max_size {
return Err(RuntimeError::Server(format!(
"request queue full (capacity: {})",
self.max_size
)));
}
self.pending.push(request);
Ok(())
}
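    /// Removes and returns up to `batch_size` requests from the front of the
    /// queue, in arrival order.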
pub fn drain_batch(&mut self, batch_size: usize) -> Vec<BatchRequest> {
let n = batch_size.min(self.pending.len());
self.pending.drain(..n).collect()
}
pub fn len(&self) -> usize {
self.pending.len()
}
pub fn is_empty(&self) -> bool {
self.pending.is_empty()
}
pub fn is_full(&self) -> bool {
self.pending.len() >= self.max_size
}
pub fn capacity(&self) -> usize {
self.max_size
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sampling::SamplingParams;
use oxibonsai_core::config::Qwen3Config;
fn make_engine() -> InferenceEngine<'static> {
let config = Qwen3Config::bonsai_8b();
InferenceEngine::new(config, SamplingParams::default(), 42)
}
#[test]
fn batch_generate_empty_prompts() {
let mut engine = make_engine();
let results = batch_generate(&mut engine, &[], 10);
assert!(results.is_empty());
}
#[test]
fn batch_generate_single_empty_prompt() {
let mut engine = make_engine();
let prompts = vec![vec![]];
let results = batch_generate(&mut engine, &prompts, 10);
assert_eq!(results.len(), 1);
let result = results.into_iter().next().expect("should have one result");
assert!(result.is_ok());
let br = result.expect("should be ok");
assert_eq!(br.prompt_tokens, 0);
assert!(br.generated_tokens.is_empty());
assert_eq!(br.finish_reason, FinishReason::Eos);
}
#[test]
fn batch_generate_multiple_prompts() {
let mut engine = make_engine();
let prompts = vec![vec![], vec![], vec![]];
let results = batch_generate(&mut engine, &prompts, 5);
assert_eq!(results.len(), 3);
for result in &results {
assert!(result.is_ok());
}
}
#[test]
fn batch_generate_with_timeout_respects_batch_size() {
let mut engine = make_engine();
let config = BatchConfig {
max_batch_size: 2,
max_tokens_per_request: 10,
timeout_per_request_ms: Some(5_000),
};
let prompts = vec![vec![]; 5];
let results = batch_generate_with_timeout(&mut engine, &prompts, &config);
assert_eq!(results.len(), 2);
}
#[test]
fn batch_config_default_values() {
let config = BatchConfig::default();
assert_eq!(config.max_batch_size, 8);
assert_eq!(config.max_tokens_per_request, 512);
assert_eq!(config.timeout_per_request_ms, Some(30_000));
}
#[test]
fn finish_reason_display() {
assert_eq!(format!("{}", FinishReason::MaxTokens), "max_tokens");
assert_eq!(format!("{}", FinishReason::Eos), "eos");
assert_eq!(format!("{}", FinishReason::Error), "error");
assert_eq!(format!("{}", FinishReason::Timeout), "timeout");
}
#[test]
fn queue_new_empty() {
let queue = RequestQueue::new(10);
assert!(queue.is_empty());
assert!(!queue.is_full());
assert_eq!(queue.len(), 0);
assert_eq!(queue.capacity(), 10);
}
#[test]
fn queue_min_capacity_is_one() {
let queue = RequestQueue::new(0);
assert_eq!(queue.capacity(), 1);
}
#[test]
fn queue_push_and_drain() {
let mut queue = RequestQueue::new(10);
for i in 0..5 {
let req = BatchRequest {
prompt_tokens: vec![i as u32],
max_tokens: 10,
params: SamplingParams::default(),
};
queue.push(req).expect("should succeed");
}
assert_eq!(queue.len(), 5);
assert!(!queue.is_full());
let batch = queue.drain_batch(3);
assert_eq!(batch.len(), 3);
assert_eq!(queue.len(), 2);
assert_eq!(batch[0].prompt_tokens, vec![0]);
assert_eq!(batch[1].prompt_tokens, vec![1]);
assert_eq!(batch[2].prompt_tokens, vec![2]);
}
#[test]
fn queue_drain_more_than_available() {
let mut queue = RequestQueue::new(10);
let req = BatchRequest {
prompt_tokens: vec![42],
max_tokens: 10,
params: SamplingParams::default(),
};
queue.push(req).expect("should succeed");
let batch = queue.drain_batch(100);
assert_eq!(batch.len(), 1);
assert!(queue.is_empty());
}
#[test]
fn queue_full_rejects_push() {
let mut queue = RequestQueue::new(2);
let req1 = BatchRequest {
prompt_tokens: vec![1],
max_tokens: 10,
params: SamplingParams::default(),
};
let req2 = BatchRequest {
prompt_tokens: vec![2],
max_tokens: 10,
params: SamplingParams::default(),
};
let req3 = BatchRequest {
prompt_tokens: vec![3],
max_tokens: 10,
params: SamplingParams::default(),
};
queue.push(req1).expect("should succeed");
queue.push(req2).expect("should succeed");
assert!(queue.is_full());
let result = queue.push(req3);
assert!(result.is_err());
}
#[test]
fn queue_drain_empty() {
let mut queue = RequestQueue::new(5);
let batch = queue.drain_batch(3);
assert!(batch.is_empty());
}
}