1use std::time::Instant;
8
9use crate::engine::InferenceEngine;
10use crate::error::{RuntimeError, RuntimeResult};
11use crate::sampling::SamplingParams;
12
/// Outcome of one generation request processed as part of a batch.
#[derive(Debug, Clone)]
pub struct BatchResult {
    // Number of tokens in the input prompt (length of the caller's token slice).
    pub prompt_tokens: usize,
    // Tokens produced by the engine for this request; the prompt is not included.
    pub generated_tokens: Vec<u32>,
    // Why generation stopped (budget hit, EOS, error, or post-hoc timeout).
    pub finish_reason: FinishReason,
    // Wall-clock time spent on this request, measured around `engine.generate`.
    pub elapsed_seconds: f64,
}
27
/// Reason a generation request stopped producing tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
    // The configured token budget was reached (generated length >= limit).
    MaxTokens,
    // The engine stopped on its own before the budget was exhausted.
    Eos,
    // Generation failed with an error.
    Error,
    // Elapsed time exceeded the configured per-request timeout
    // (checked after generation completes — see `batch_generate_with_timeout`).
    Timeout,
}
40
41impl std::fmt::Display for FinishReason {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 Self::MaxTokens => write!(f, "max_tokens"),
45 Self::Eos => write!(f, "eos"),
46 Self::Error => write!(f, "error"),
47 Self::Timeout => write!(f, "timeout"),
48 }
49 }
50}
51
/// Tunables for batched generation with per-request limits.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    // Maximum number of prompts processed per call; excess prompts are dropped
    // by `batch_generate_with_timeout`.
    pub max_batch_size: usize,
    // Token budget applied to every request in the batch.
    pub max_tokens_per_request: usize,
    // Per-request wall-clock limit in milliseconds; `None` disables the check.
    pub timeout_per_request_ms: Option<u64>,
}
64
65impl Default for BatchConfig {
66 fn default() -> Self {
67 Self {
68 max_batch_size: 8,
69 max_tokens_per_request: 512,
70 timeout_per_request_ms: Some(30_000),
71 }
72 }
73}
74
75pub fn batch_generate(
82 engine: &mut InferenceEngine<'_>,
83 prompts: &[Vec<u32>],
84 max_tokens: usize,
85) -> Vec<RuntimeResult<BatchResult>> {
86 prompts
87 .iter()
88 .map(|prompt| {
89 engine.reset();
90 let start = Instant::now();
91
92 match engine.generate(prompt, max_tokens) {
93 Ok(tokens) => {
94 let finish_reason = if tokens.len() >= max_tokens {
95 FinishReason::MaxTokens
96 } else {
97 FinishReason::Eos
98 };
99 Ok(BatchResult {
100 prompt_tokens: prompt.len(),
101 generated_tokens: tokens,
102 finish_reason,
103 elapsed_seconds: start.elapsed().as_secs_f64(),
104 })
105 }
106 Err(e) => Err(e),
107 }
108 })
109 .collect()
110}
111
112pub fn batch_generate_with_timeout(
118 engine: &mut InferenceEngine<'_>,
119 prompts: &[Vec<u32>],
120 config: &BatchConfig,
121) -> Vec<RuntimeResult<BatchResult>> {
122 let effective_prompts = if prompts.len() > config.max_batch_size {
123 &prompts[..config.max_batch_size]
124 } else {
125 prompts
126 };
127
128 effective_prompts
129 .iter()
130 .map(|prompt| {
131 engine.reset();
132 let start = Instant::now();
133 let timeout = config
134 .timeout_per_request_ms
135 .map(std::time::Duration::from_millis);
136
137 match engine.generate(prompt, config.max_tokens_per_request) {
138 Ok(tokens) => {
139 let elapsed = start.elapsed();
140 let timed_out = timeout.is_some_and(|t| elapsed > t);
141
142 let finish_reason = if timed_out {
143 FinishReason::Timeout
144 } else if tokens.len() >= config.max_tokens_per_request {
145 FinishReason::MaxTokens
146 } else {
147 FinishReason::Eos
148 };
149
150 Ok(BatchResult {
151 prompt_tokens: prompt.len(),
152 generated_tokens: tokens,
153 finish_reason,
154 elapsed_seconds: elapsed.as_secs_f64(),
155 })
156 }
157 Err(e) => Err(e),
158 }
159 })
160 .collect()
161}
162
/// A single queued generation request awaiting batching.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    // Tokenized prompt to feed the engine.
    pub prompt_tokens: Vec<u32>,
    // Token budget for this request.
    pub max_tokens: usize,
    // Sampling configuration for this request.
    pub params: SamplingParams,
}
175
/// Bounded FIFO queue of pending [`BatchRequest`]s.
///
/// Requests are pushed at the back and drained from the front in batches;
/// pushes beyond `max_size` are rejected.
pub struct RequestQueue {
    // Pending requests in arrival order (front = oldest).
    pending: Vec<BatchRequest>,
    // Capacity limit; always >= 1 (clamped in `new`).
    max_size: usize,
}
184
185impl RequestQueue {
186 pub fn new(max_size: usize) -> Self {
188 Self {
189 pending: Vec::with_capacity(max_size.min(1024)),
190 max_size: max_size.max(1),
191 }
192 }
193
194 pub fn push(&mut self, request: BatchRequest) -> Result<(), RuntimeError> {
198 if self.pending.len() >= self.max_size {
199 return Err(RuntimeError::Server(format!(
200 "request queue full (capacity: {})",
201 self.max_size
202 )));
203 }
204 self.pending.push(request);
205 Ok(())
206 }
207
208 pub fn drain_batch(&mut self, batch_size: usize) -> Vec<BatchRequest> {
212 let n = batch_size.min(self.pending.len());
213 self.pending.drain(..n).collect()
214 }
215
216 pub fn len(&self) -> usize {
218 self.pending.len()
219 }
220
221 pub fn is_empty(&self) -> bool {
223 self.pending.is_empty()
224 }
225
226 pub fn is_full(&self) -> bool {
228 self.pending.len() >= self.max_size
229 }
230
231 pub fn capacity(&self) -> usize {
233 self.max_size
234 }
235}
236
// Unit tests. NOTE(review): several assertions pin behavior of
// `InferenceEngine::generate` for empty prompts (zero tokens, EOS finish)
// that is not visible in this file — confirm against the engine implementation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingParams;
    use oxibonsai_core::config::Qwen3Config;

    // Builds a small deterministic engine (fixed seed 42) for tests.
    fn make_engine() -> InferenceEngine<'static> {
        let config = Qwen3Config::bonsai_8b();
        InferenceEngine::new(config, SamplingParams::default(), 42)
    }

    // An empty prompt slice yields an empty result vector.
    #[test]
    fn batch_generate_empty_prompts() {
        let mut engine = make_engine();
        let results = batch_generate(&mut engine, &[], 10);
        assert!(results.is_empty());
    }

    // A single empty prompt still produces one Ok result with zero
    // prompt tokens, no generated tokens, and an EOS finish reason.
    #[test]
    fn batch_generate_single_empty_prompt() {
        let mut engine = make_engine();
        let prompts = vec![vec![]];
        let results = batch_generate(&mut engine, &prompts, 10);
        assert_eq!(results.len(), 1);
        let result = results.into_iter().next().expect("should have one result");
        assert!(result.is_ok());
        let br = result.expect("should be ok");
        assert_eq!(br.prompt_tokens, 0);
        assert!(br.generated_tokens.is_empty());
        assert_eq!(br.finish_reason, FinishReason::Eos);
    }

    // One result per prompt, all Ok, in input order.
    #[test]
    fn batch_generate_multiple_prompts() {
        let mut engine = make_engine();
        let prompts = vec![vec![], vec![], vec![]];
        let results = batch_generate(&mut engine, &prompts, 5);
        assert_eq!(results.len(), 3);
        for result in &results {
            assert!(result.is_ok());
        }
    }

    // Prompts beyond max_batch_size are dropped: 5 in, 2 out.
    #[test]
    fn batch_generate_with_timeout_respects_batch_size() {
        let mut engine = make_engine();
        let config = BatchConfig {
            max_batch_size: 2,
            max_tokens_per_request: 10,
            timeout_per_request_ms: Some(5_000),
        };
        let prompts = vec![vec![]; 5];
        let results = batch_generate_with_timeout(&mut engine, &prompts, &config);
        assert_eq!(results.len(), 2);
    }

    // Pins the Default impl values so accidental changes are caught.
    #[test]
    fn batch_config_default_values() {
        let config = BatchConfig::default();
        assert_eq!(config.max_batch_size, 8);
        assert_eq!(config.max_tokens_per_request, 512);
        assert_eq!(config.timeout_per_request_ms, Some(30_000));
    }

    // Display output is a stable identifier used in logs/APIs.
    #[test]
    fn finish_reason_display() {
        assert_eq!(format!("{}", FinishReason::MaxTokens), "max_tokens");
        assert_eq!(format!("{}", FinishReason::Eos), "eos");
        assert_eq!(format!("{}", FinishReason::Error), "error");
        assert_eq!(format!("{}", FinishReason::Timeout), "timeout");
    }

    // A fresh queue is empty, not full, and reports its capacity.
    #[test]
    fn queue_new_empty() {
        let queue = RequestQueue::new(10);
        assert!(queue.is_empty());
        assert!(!queue.is_full());
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.capacity(), 10);
    }

    // A requested capacity of 0 is clamped up to 1.
    #[test]
    fn queue_min_capacity_is_one() {
        let queue = RequestQueue::new(0);
        assert_eq!(queue.capacity(), 1);
    }

    // FIFO order: the drained batch holds the oldest requests first.
    #[test]
    fn queue_push_and_drain() {
        let mut queue = RequestQueue::new(10);
        for i in 0..5 {
            let req = BatchRequest {
                prompt_tokens: vec![i as u32],
                max_tokens: 10,
                params: SamplingParams::default(),
            };
            queue.push(req).expect("should succeed");
        }
        assert_eq!(queue.len(), 5);
        assert!(!queue.is_full());

        let batch = queue.drain_batch(3);
        assert_eq!(batch.len(), 3);
        assert_eq!(queue.len(), 2);

        assert_eq!(batch[0].prompt_tokens, vec![0]);
        assert_eq!(batch[1].prompt_tokens, vec![1]);
        assert_eq!(batch[2].prompt_tokens, vec![2]);
    }

    // Draining more than available returns only what's queued.
    #[test]
    fn queue_drain_more_than_available() {
        let mut queue = RequestQueue::new(10);
        let req = BatchRequest {
            prompt_tokens: vec![42],
            max_tokens: 10,
            params: SamplingParams::default(),
        };
        queue.push(req).expect("should succeed");

        let batch = queue.drain_batch(100);
        assert_eq!(batch.len(), 1);
        assert!(queue.is_empty());
    }

    // Pushing into a full queue fails and leaves the queue unchanged.
    #[test]
    fn queue_full_rejects_push() {
        let mut queue = RequestQueue::new(2);
        let req1 = BatchRequest {
            prompt_tokens: vec![1],
            max_tokens: 10,
            params: SamplingParams::default(),
        };
        let req2 = BatchRequest {
            prompt_tokens: vec![2],
            max_tokens: 10,
            params: SamplingParams::default(),
        };
        let req3 = BatchRequest {
            prompt_tokens: vec![3],
            max_tokens: 10,
            params: SamplingParams::default(),
        };

        queue.push(req1).expect("should succeed");
        queue.push(req2).expect("should succeed");
        assert!(queue.is_full());

        let result = queue.push(req3);
        assert!(result.is_err());
    }

    // Draining an empty queue yields an empty batch without panicking.
    #[test]
    fn queue_drain_empty() {
        let mut queue = RequestQueue::new(5);
        let batch = queue.drain_batch(3);
        assert!(batch.is_empty());
    }
}