// oxibonsai_runtime/batch_engine.rs

//! Batched inference: process multiple prompts efficiently.
//!
//! Groups prompts into batches for prefill, then generates independently.
//! Provides a [`RequestQueue`] for continuous batching scenarios where
//! requests arrive over time and are drained in configurable batch sizes.

use std::time::Instant;

use crate::engine::InferenceEngine;
use crate::error::{RuntimeError, RuntimeResult};
use crate::sampling::SamplingParams;
// ─── Result types ──────────────────────────────────────────────────────

/// Result of a single batch element.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Number of prompt tokens processed.
    pub prompt_tokens: usize,
    /// Generated token IDs (not including the prompt).
    pub generated_tokens: Vec<u32>,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Wall-clock time for this request in seconds.
    pub elapsed_seconds: f64,
}

/// Reason why token generation stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
    /// Reached the maximum token limit.
    MaxTokens,
    /// Generated the end-of-sequence token.
    Eos,
    /// An error occurred during generation.
    Error,
    /// The request's wall-clock time exceeded its timeout.
    Timeout,
}

impl std::fmt::Display for FinishReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxTokens => write!(f, "max_tokens"),
            Self::Eos => write!(f, "eos"),
            Self::Error => write!(f, "error"),
            Self::Timeout => write!(f, "timeout"),
        }
    }
}

// ─── Batch configuration ───────────────────────────────────────────────

/// Batch inference configuration.
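///
/// # Examples
///
/// A minimal sketch of overriding one field while keeping the rest of the
/// defaults (paths assume this module is public as
/// `oxibonsai_runtime::batch_engine`; the values are illustrative):
///
/// ```
/// use oxibonsai_runtime::batch_engine::BatchConfig;
///
/// let config = BatchConfig {
///     max_batch_size: 4,
///     ..BatchConfig::default()
/// };
/// assert_eq!(config.max_tokens_per_request, 512);
/// assert_eq!(config.timeout_per_request_ms, Some(30_000));
/// ```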
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of prompts per batch.
    pub max_batch_size: usize,
    /// Maximum tokens to generate per request.
    pub max_tokens_per_request: usize,
    /// Optional timeout per request in milliseconds.
    pub timeout_per_request_ms: Option<u64>,
}

impl Default for BatchConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 8,
            max_tokens_per_request: 512,
            timeout_per_request_ms: Some(30_000),
        }
    }
}

// ─── Batch generation ──────────────────────────────────────────────────

/// Process a batch of prompts sequentially (sharing the engine).
///
/// Each prompt is processed independently: the engine state is reset
/// between prompts. Returns one result per prompt.
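///
/// # Example
///
/// A minimal sketch; the engine setup mirrors the tests below, and the
/// crate paths (`oxibonsai_runtime`, `oxibonsai_core`) are assumed to be
/// public:
///
/// ```no_run
/// use oxibonsai_runtime::batch_engine::batch_generate;
/// use oxibonsai_runtime::engine::InferenceEngine;
/// use oxibonsai_runtime::sampling::SamplingParams;
/// use oxibonsai_core::config::Qwen3Config;
///
/// let config = Qwen3Config::bonsai_8b();
/// let mut engine = InferenceEngine::new(config, SamplingParams::default(), 42);
///
/// // Two pre-tokenized prompts, at most 32 generated tokens each.
/// let prompts = vec![vec![1_u32, 2, 3], vec![4, 5]];
/// for result in batch_generate(&mut engine, &prompts, 32) {
///     let r = result.expect("generation failed");
///     println!(
///         "{} tokens in {:.2}s ({})",
///         r.generated_tokens.len(),
///         r.elapsed_seconds,
///         r.finish_reason
///     );
/// }
/// ```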
pub fn batch_generate(
    engine: &mut InferenceEngine<'_>,
    prompts: &[Vec<u32>],
    max_tokens: usize,
) -> Vec<RuntimeResult<BatchResult>> {
    prompts
        .iter()
        .map(|prompt| {
            engine.reset();
            let start = Instant::now();

            match engine.generate(prompt, max_tokens) {
                Ok(tokens) => {
                    let finish_reason = if tokens.len() >= max_tokens {
                        FinishReason::MaxTokens
                    } else {
                        FinishReason::Eos
                    };
                    Ok(BatchResult {
                        prompt_tokens: prompt.len(),
                        generated_tokens: tokens,
                        finish_reason,
                        elapsed_seconds: start.elapsed().as_secs_f64(),
                    })
                }
                Err(e) => Err(e),
            }
        })
        .collect()
}

/// Process a batch with a per-request timeout.
///
/// Uses [`BatchConfig`] to control generation limits and timeouts. At most
/// [`BatchConfig::max_batch_size`] prompts are processed; any extras are
/// dropped. The timeout is checked after each request completes: a request
/// whose wall-clock time exceeds its timeout is marked with
/// [`FinishReason::Timeout`], but generation itself is not interrupted.
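///
/// # Example
///
/// A minimal sketch (same assumed setup as [`batch_generate`]):
///
/// ```no_run
/// use oxibonsai_runtime::batch_engine::{batch_generate_with_timeout, BatchConfig};
/// use oxibonsai_runtime::engine::InferenceEngine;
/// use oxibonsai_runtime::sampling::SamplingParams;
/// use oxibonsai_core::config::Qwen3Config;
///
/// let mut engine =
///     InferenceEngine::new(Qwen3Config::bonsai_8b(), SamplingParams::default(), 42);
/// let config = BatchConfig {
///     max_batch_size: 2, // a third prompt would be dropped
///     max_tokens_per_request: 64,
///     timeout_per_request_ms: Some(1_000),
/// };
/// let prompts = vec![vec![1_u32, 2], vec![3, 4]];
/// let results = batch_generate_with_timeout(&mut engine, &prompts, &config);
/// assert_eq!(results.len(), 2);
/// ```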
pub fn batch_generate_with_timeout(
    engine: &mut InferenceEngine<'_>,
    prompts: &[Vec<u32>],
    config: &BatchConfig,
) -> Vec<RuntimeResult<BatchResult>> {
    let effective_prompts = if prompts.len() > config.max_batch_size {
        &prompts[..config.max_batch_size]
    } else {
        prompts
    };
    // The timeout does not vary per request, so compute it once.
    let timeout = config
        .timeout_per_request_ms
        .map(std::time::Duration::from_millis);

    effective_prompts
        .iter()
        .map(|prompt| {
            engine.reset();
            let start = Instant::now();

            match engine.generate(prompt, config.max_tokens_per_request) {
                Ok(tokens) => {
                    let elapsed = start.elapsed();
                    let timed_out = timeout.is_some_and(|t| elapsed > t);

                    let finish_reason = if timed_out {
                        FinishReason::Timeout
                    } else if tokens.len() >= config.max_tokens_per_request {
                        FinishReason::MaxTokens
                    } else {
                        FinishReason::Eos
                    };

                    Ok(BatchResult {
                        prompt_tokens: prompt.len(),
                        generated_tokens: tokens,
                        finish_reason,
                        elapsed_seconds: elapsed.as_secs_f64(),
                    })
                }
                Err(e) => Err(e),
            }
        })
        .collect()
}

// ─── Request queue ─────────────────────────────────────────────────────

/// A single queued inference request.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Tokenized prompt.
    pub prompt_tokens: Vec<u32>,
    /// Maximum tokens to generate.
    pub max_tokens: usize,
    /// Sampling parameters for this request.
    pub params: SamplingParams,
}

/// Request queue for continuous batching.
///
/// Accumulates incoming requests and drains them in configurable
/// batch sizes for efficient processing.
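///
/// # Examples
///
/// A sketch of a simple drain loop (the request contents are illustrative,
/// and the paths assume this module is public):
///
/// ```
/// use oxibonsai_runtime::batch_engine::{BatchRequest, RequestQueue};
/// use oxibonsai_runtime::sampling::SamplingParams;
///
/// let mut queue = RequestQueue::new(16);
/// queue
///     .push(BatchRequest {
///         prompt_tokens: vec![1, 2, 3],
///         max_tokens: 32,
///         params: SamplingParams::default(),
///     })
///     .expect("queue has room");
///
/// // Drain in FIFO order, at most 8 requests at a time.
/// while !queue.is_empty() {
///     let batch = queue.drain_batch(8);
///     for request in batch {
///         // ... hand the request to an engine ...
///         let _ = request;
///     }
/// }
/// ```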
pub struct RequestQueue {
    pending: Vec<BatchRequest>,
    max_size: usize,
}

impl RequestQueue {
    /// Create a new request queue with the given maximum capacity.
    ///
    /// A `max_size` of zero is clamped to one.
    pub fn new(max_size: usize) -> Self {
        Self {
            pending: Vec::with_capacity(max_size.min(1024)),
            max_size: max_size.max(1),
        }
    }

    /// Push a new request onto the queue.
    ///
    /// Returns an error if the queue is full.
    pub fn push(&mut self, request: BatchRequest) -> Result<(), RuntimeError> {
        if self.pending.len() >= self.max_size {
            return Err(RuntimeError::Server(format!(
                "request queue full (capacity: {})",
                self.max_size
            )));
        }
        self.pending.push(request);
        Ok(())
    }

    /// Drain up to `batch_size` requests from the front of the queue.
    ///
    /// Returns the drained requests in FIFO order.
    pub fn drain_batch(&mut self, batch_size: usize) -> Vec<BatchRequest> {
        let n = batch_size.min(self.pending.len());
        self.pending.drain(..n).collect()
    }

    /// Number of pending requests.
    pub fn len(&self) -> usize {
        self.pending.len()
    }

    /// Whether the queue is empty.
    pub fn is_empty(&self) -> bool {
        self.pending.is_empty()
    }

    /// Whether the queue is at capacity.
    pub fn is_full(&self) -> bool {
        self.pending.len() >= self.max_size
    }

    /// Maximum queue capacity.
    pub fn capacity(&self) -> usize {
        self.max_size
    }
}

// ─── Tests ─────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingParams;
    use oxibonsai_core::config::Qwen3Config;

    fn make_engine() -> InferenceEngine<'static> {
        let config = Qwen3Config::bonsai_8b();
        InferenceEngine::new(config, SamplingParams::default(), 42)
    }

    #[test]
    fn batch_generate_empty_prompts() {
        let mut engine = make_engine();
        let results = batch_generate(&mut engine, &[], 10);
        assert!(results.is_empty());
    }

    #[test]
    fn batch_generate_single_empty_prompt() {
        let mut engine = make_engine();
        let prompts = vec![vec![]];
        let results = batch_generate(&mut engine, &prompts, 10);
        assert_eq!(results.len(), 1);
        let result = results.into_iter().next().expect("should have one result");
        assert!(result.is_ok());
        let br = result.expect("should be ok");
        assert_eq!(br.prompt_tokens, 0);
        assert!(br.generated_tokens.is_empty());
        assert_eq!(br.finish_reason, FinishReason::Eos);
    }

    #[test]
    fn batch_generate_multiple_prompts() {
        let mut engine = make_engine();
        let prompts = vec![vec![], vec![], vec![]];
        let results = batch_generate(&mut engine, &prompts, 5);
        assert_eq!(results.len(), 3);
        for result in &results {
            assert!(result.is_ok());
        }
    }

    #[test]
    fn batch_generate_with_timeout_respects_batch_size() {
        let mut engine = make_engine();
        let config = BatchConfig {
            max_batch_size: 2,
            max_tokens_per_request: 10,
            timeout_per_request_ms: Some(5_000),
        };
        // Provide 5 prompts but limit to 2
        let prompts = vec![vec![]; 5];
        let results = batch_generate_with_timeout(&mut engine, &prompts, &config);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn batch_config_default_values() {
        let config = BatchConfig::default();
        assert_eq!(config.max_batch_size, 8);
        assert_eq!(config.max_tokens_per_request, 512);
        assert_eq!(config.timeout_per_request_ms, Some(30_000));
    }

    #[test]
    fn finish_reason_display() {
        assert_eq!(format!("{}", FinishReason::MaxTokens), "max_tokens");
        assert_eq!(format!("{}", FinishReason::Eos), "eos");
        assert_eq!(format!("{}", FinishReason::Error), "error");
        assert_eq!(format!("{}", FinishReason::Timeout), "timeout");
    }

    // ── RequestQueue tests ─────────────────────────────────────────────

    #[test]
    fn queue_new_empty() {
        let queue = RequestQueue::new(10);
        assert!(queue.is_empty());
        assert!(!queue.is_full());
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.capacity(), 10);
    }

    #[test]
    fn queue_min_capacity_is_one() {
        let queue = RequestQueue::new(0);
        assert_eq!(queue.capacity(), 1);
    }

    #[test]
    fn queue_push_and_drain() {
        let mut queue = RequestQueue::new(10);
        for i in 0..5 {
            let req = BatchRequest {
                prompt_tokens: vec![i as u32],
                max_tokens: 10,
                params: SamplingParams::default(),
            };
            queue.push(req).expect("should succeed");
        }
        assert_eq!(queue.len(), 5);
        assert!(!queue.is_full());

        let batch = queue.drain_batch(3);
        assert_eq!(batch.len(), 3);
        assert_eq!(queue.len(), 2);

        // Check FIFO order
        assert_eq!(batch[0].prompt_tokens, vec![0]);
        assert_eq!(batch[1].prompt_tokens, vec![1]);
        assert_eq!(batch[2].prompt_tokens, vec![2]);
    }

    #[test]
    fn queue_drain_more_than_available() {
        let mut queue = RequestQueue::new(10);
        let req = BatchRequest {
            prompt_tokens: vec![42],
            max_tokens: 10,
            params: SamplingParams::default(),
        };
        queue.push(req).expect("should succeed");

        let batch = queue.drain_batch(100);
        assert_eq!(batch.len(), 1);
        assert!(queue.is_empty());
    }

    #[test]
    fn queue_full_rejects_push() {
        let mut queue = RequestQueue::new(2);
        let req1 = BatchRequest {
            prompt_tokens: vec![1],
            max_tokens: 10,
            params: SamplingParams::default(),
        };
        let req2 = BatchRequest {
            prompt_tokens: vec![2],
            max_tokens: 10,
            params: SamplingParams::default(),
        };
        let req3 = BatchRequest {
            prompt_tokens: vec![3],
            max_tokens: 10,
            params: SamplingParams::default(),
        };

        queue.push(req1).expect("should succeed");
        queue.push(req2).expect("should succeed");
        assert!(queue.is_full());

        let result = queue.push(req3);
        assert!(result.is_err());
    }

    #[test]
    fn queue_drain_empty() {
        let mut queue = RequestQueue::new(5);
        let batch = queue.drain_batch(3);
        assert!(batch.is_empty());
    }
400    }
401}