// oxibonsai_runtime/continuous_batch.rs

//! Continuous (iteration-level) batching for OxiBonsai.
//!
//! Continuous batching processes multiple inference requests simultaneously by
//! adding new requests to the active set as slots free up — unlike static
//! batching, where every request in a batch must start and finish together.
//!
//! The [`ContinuousBatchScheduler`] maintains three collections:
//!
//! 1. **Waiting queue** — requests awaiting a slot.
//! 2. **Active set** — at most `max_concurrent` requests currently being decoded.
//! 3. **Completed list** — finished requests available for result retrieval.
//!
//! Each call to [`ContinuousBatchScheduler::step`] first promotes waiting
//! requests into free slots and then advances every active request by one
//! decoding step (at most one new token per request).

use std::collections::VecDeque;

use crate::engine::InferenceEngine;
use crate::sampling::SamplingParams;

// ─── Priority ──────────────────────────────────────────────────────────────

/// Priority level for request scheduling.
///
/// Variants are declared in ascending order of urgency, so the derived
/// [`Ord`] implementation yields `Low < Normal < High < Critical`.
/// Higher-priority requests are promoted from the waiting queue before
/// lower-priority ones when slots become available.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum RequestPriority {
    /// Lowest priority — background work.
    Low = 0,
    /// Default priority for most requests.
    #[default]
    Normal = 1,
    /// Elevated priority — user-facing interactive requests.
    High = 2,
    /// Highest priority — real-time / SLA-bound requests.
    Critical = 3,
}

// ─── State ─────────────────────────────────────────────────────────────────

/// Lifecycle state of a [`BatchRequest`].
///
/// A request normally moves `Waiting → Prefilling → Decoding → Completed`;
/// `Failed` can be entered from any active state when the engine reports an
/// error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RequestState {
    /// Sitting in the waiting queue (the state of a freshly created request).
    Waiting,
    /// Promoted into the active set; no engine call has succeeded yet.
    Prefilling,
    /// Actively generating tokens one by one.
    Decoding,
    /// Generation stopped normally: the token budget was reached or the
    /// engine produced no further token (treated as EOS).
    Completed,
    /// Generation failed; the payload is the engine error rendered as text.
    Failed(String),
}

57// ─── BatchRequest ──────────────────────────────────────────────────────────
58
59/// A single inference request managed by the continuous-batch scheduler.
60pub struct BatchRequest {
61    /// Unique request identifier returned by [`ContinuousBatchScheduler::submit`].
62    pub id: u64,
63    /// Tokenised prompt.
64    pub prompt_tokens: Vec<u32>,
65    /// Sampling parameters for this request.
66    pub params: SamplingParams,
67    /// Maximum number of tokens to generate.
68    pub max_tokens: usize,
69    /// Scheduling priority.
70    pub priority: RequestPriority,
71    /// Current lifecycle state.
72    pub state: RequestState,
73    /// Tokens generated so far (not including the prompt).
74    pub generated_tokens: Vec<u32>,
75    /// Wall-clock time at which the request was submitted.
76    pub created_at: std::time::Instant,
77    /// Wall-clock time at which the first token was generated (prefill complete).
78    pub started_at: Option<std::time::Instant>,
79    /// Wall-clock time at which generation finished.
80    pub completed_at: Option<std::time::Instant>,
81}
82
83impl BatchRequest {
84    /// Create a new request with `Normal` priority.
85    pub fn new(
86        id: u64,
87        prompt_tokens: Vec<u32>,
88        params: SamplingParams,
89        max_tokens: usize,
90    ) -> Self {
91        Self {
92            id,
93            prompt_tokens,
94            params,
95            max_tokens,
96            priority: RequestPriority::Normal,
97            state: RequestState::Waiting,
98            generated_tokens: Vec::new(),
99            created_at: std::time::Instant::now(),
100            started_at: None,
101            completed_at: None,
102        }
103    }
104
105    /// Override the priority, returning `self` for builder-style chaining.
106    pub fn with_priority(mut self, priority: RequestPriority) -> Self {
107        self.priority = priority;
108        self
109    }
110
111    /// Elapsed time from submission to first generated token.
112    ///
113    /// Returns `None` if the first token has not yet been produced.
114    pub fn time_to_first_token(&self) -> Option<std::time::Duration> {
115        self.started_at.map(|s| s.duration_since(self.created_at))
116    }
117
118    /// Elapsed time from submission to completion.
119    ///
120    /// Returns `None` if the request has not yet completed.
121    pub fn total_latency(&self) -> Option<std::time::Duration> {
122        self.completed_at.map(|c| c.duration_since(self.created_at))
123    }
124
125    /// Number of tokens generated so far.
126    pub fn tokens_generated(&self) -> usize {
127        self.generated_tokens.len()
128    }
129
130    /// `true` when the request is in [`RequestState::Completed`] or
131    /// [`RequestState::Failed`].
132    pub fn is_finished(&self) -> bool {
133        matches!(
134            self.state,
135            RequestState::Completed | RequestState::Failed(_)
136        )
137    }
138}
139
140// ─── Errors ────────────────────────────────────────────────────────────────
141
142/// Errors returned by the continuous-batch scheduler.
143#[derive(Debug, thiserror::Error)]
144pub enum SchedulerError {
145    /// The waiting queue is at capacity.
146    #[error("Queue full: {max_queue_size} requests waiting")]
147    QueueFull {
148        /// The configured maximum queue size.
149        max_queue_size: usize,
150    },
151    /// No request with the given ID was found.
152    #[error("Request {id} not found")]
153    NotFound {
154        /// The unknown request ID.
155        id: u64,
156    },
157}
158
159// ─── Stats ─────────────────────────────────────────────────────────────────
160
161/// Throughput statistics snapshot.
162#[derive(Debug, serde::Serialize)]
163pub struct SchedulerStats {
164    /// Total requests submitted since the scheduler was created.
165    pub total_requests: u64,
166    /// Total tokens generated across all completed requests.
167    pub total_tokens_generated: u64,
168    /// Current depth of the waiting queue.
169    pub queue_depth: usize,
170    /// Number of actively decoding requests.
171    pub active_count: usize,
172}
173
174// ─── Scheduler ─────────────────────────────────────────────────────────────
175
176/// Continuous-batch scheduler.
177///
178/// Manages request lifecycle from submission through generation to completion,
179/// interleaving multiple requests at the token level.
180pub struct ContinuousBatchScheduler {
181    /// Maximum number of requests decoding simultaneously.
182    pub max_concurrent: usize,
183    /// Maximum number of requests that may wait in the queue.
184    pub max_queue_size: usize,
185
186    queue: VecDeque<BatchRequest>,
187    active: Vec<BatchRequest>,
188    completed: Vec<BatchRequest>,
189    next_id: u64,
190    total_requests: u64,
191    total_tokens_generated: u64,
192}
193
194impl ContinuousBatchScheduler {
195    /// Create a new scheduler.
196    ///
197    /// `max_concurrent` — at most this many requests decode in parallel.
198    /// `max_queue_size` — queue rejects new submissions beyond this count.
199    pub fn new(max_concurrent: usize, max_queue_size: usize) -> Self {
200        Self {
201            max_concurrent: max_concurrent.max(1),
202            max_queue_size: max_queue_size.max(1),
203            queue: VecDeque::new(),
204            active: Vec::new(),
205            completed: Vec::new(),
206            next_id: 1,
207            total_requests: 0,
208            total_tokens_generated: 0,
209        }
210    }
211
212    /// Submit a request with `Normal` priority.
213    ///
214    /// Returns the assigned request ID, or [`SchedulerError::QueueFull`] if the
215    /// waiting queue is already at capacity.
216    pub fn submit(
217        &mut self,
218        prompt_tokens: Vec<u32>,
219        params: SamplingParams,
220        max_tokens: usize,
221    ) -> Result<u64, SchedulerError> {
222        self.submit_with_priority(prompt_tokens, params, max_tokens, RequestPriority::Normal)
223    }
224
225    /// Submit a request with an explicit priority.
226    pub fn submit_with_priority(
227        &mut self,
228        prompt_tokens: Vec<u32>,
229        params: SamplingParams,
230        max_tokens: usize,
231        priority: RequestPriority,
232    ) -> Result<u64, SchedulerError> {
233        if self.queue.len() >= self.max_queue_size {
234            return Err(SchedulerError::QueueFull {
235                max_queue_size: self.max_queue_size,
236            });
237        }
238
239        let id = self.next_id;
240        self.next_id += 1;
241        self.total_requests += 1;
242
243        let request =
244            BatchRequest::new(id, prompt_tokens, params, max_tokens).with_priority(priority);
245
246        // Insert maintaining priority order (higher priority → closer to front)
247        let pos = self
248            .queue
249            .iter()
250            .position(|r| r.priority < priority)
251            .unwrap_or(self.queue.len());
252        self.queue.insert(pos, request);
253
254        Ok(id)
255    }
256
257    /// Advance one iteration of the batch.
258    ///
259    /// 1. Promotes waiting requests into the active set until `max_concurrent`
260    ///    slots are full (or the queue is drained).
261    /// 2. Steps every active request by generating one token.
262    /// 3. Moves finished requests to the completed list.
263    pub fn step(&mut self, engine: &mut InferenceEngine<'_>) {
264        // --- Promote waiting requests into the active set ---
265        while self.active.len() < self.max_concurrent {
266            match self.queue.pop_front() {
267                Some(mut req) => {
268                    req.state = RequestState::Prefilling;
269                    self.active.push(req);
270                }
271                None => break,
272            }
273        }
274
275        if self.active.is_empty() {
276            return;
277        }
278
279        // --- Step each active request by one token ---
280        let mut finished_indices: Vec<usize> = Vec::new();
281
282        for (idx, req) in self.active.iter_mut().enumerate() {
283            // Build the full context: prompt + already-generated tokens
284            let context: Vec<u32> = req
285                .prompt_tokens
286                .iter()
287                .chain(req.generated_tokens.iter())
288                .copied()
289                .collect();
290
291            // Run the engine for a single token
292            engine.reset();
293            let generated = engine.generate(&context, 1);
294
295            match generated {
296                Ok(new_tokens) => {
297                    if req.started_at.is_none() {
298                        req.started_at = Some(std::time::Instant::now());
299                        req.state = RequestState::Decoding;
300                    }
301
302                    if let Some(&token) = new_tokens.first() {
303                        req.generated_tokens.push(token);
304                    }
305
306                    // Check stopping conditions
307                    let hit_max = req.generated_tokens.len() >= req.max_tokens;
308                    let hit_eos = new_tokens.is_empty(); // engine stopped at EOS
309
310                    if hit_max || hit_eos {
311                        req.state = RequestState::Completed;
312                        req.completed_at = Some(std::time::Instant::now());
313                        finished_indices.push(idx);
314                    }
315                }
316                Err(e) => {
317                    req.state = RequestState::Failed(e.to_string());
318                    req.completed_at = Some(std::time::Instant::now());
319                    finished_indices.push(idx);
320                }
321            }
322        }
323
324        // Move finished requests to completed list (iterate in reverse to
325        // preserve index validity during removal)
326        for &idx in finished_indices.iter().rev() {
327            let req = self.active.remove(idx);
328            self.total_tokens_generated += req.generated_tokens.len() as u64;
329            self.completed.push(req);
330        }
331    }
332
333    /// Run all pending and active requests to completion, blocking until the
334    /// scheduler is idle.
335    pub fn run_to_completion(&mut self, engine: &mut InferenceEngine<'_>) {
336        while !self.is_idle() {
337            self.step(engine);
338        }
339    }
340
341    /// Look up a completed (or failed) request by ID.
342    ///
343    /// Returns `None` if the request is still waiting/decoding or does not exist.
344    pub fn get_result(&self, id: u64) -> Option<&BatchRequest> {
345        self.completed.iter().find(|r| r.id == id)
346    }
347
348    /// Number of requests currently waiting in the queue.
349    pub fn queue_depth(&self) -> usize {
350        self.queue.len()
351    }
352
353    /// Number of requests currently being decoded.
354    pub fn active_count(&self) -> usize {
355        self.active.len()
356    }
357
358    /// Number of completed (or failed) requests.
359    pub fn completed_count(&self) -> usize {
360        self.completed.len()
361    }
362
363    /// `true` when there are no waiting or active requests.
364    pub fn is_idle(&self) -> bool {
365        self.queue.is_empty() && self.active.is_empty()
366    }
367
368    /// Snapshot of current throughput statistics.
369    pub fn throughput_stats(&self) -> SchedulerStats {
370        SchedulerStats {
371            total_requests: self.total_requests,
372            total_tokens_generated: self.total_tokens_generated,
373            queue_depth: self.queue.len(),
374            active_count: self.active.len(),
375        }
376    }
377
378    /// Remove and return all completed requests, clearing the completed list.
379    pub fn drain_completed(&mut self) -> Vec<BatchRequest> {
380        std::mem::take(&mut self.completed)
381    }
382}
383
// ─── Tests ─────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Build an engine from the `bonsai_8b` configuration with a fixed seed.
    fn make_engine() -> InferenceEngine<'static> {
        let config = Qwen3Config::bonsai_8b();
        InferenceEngine::new(config, SamplingParams::default(), 42)
    }

    /// Default sampling parameters with the temperature forced to zero.
    fn default_params() -> SamplingParams {
        SamplingParams {
            temperature: 0.0, // greedy for determinism
            ..Default::default()
        }
    }

    // ── Submit / queue tests ───────────────────────────────────────────────

    #[test]
    fn test_scheduler_submit_returns_id() {
        let mut sched = ContinuousBatchScheduler::new(4, 64);
        let id1 = sched
            .submit(vec![1, 2, 3], default_params(), 10)
            .expect("submit should succeed");
        let id2 = sched
            .submit(vec![4, 5, 6], default_params(), 10)
            .expect("submit should succeed");
        assert_ne!(id1, id2, "IDs must be unique");
        assert!(id1 > 0 && id2 > 0);
    }

    #[test]
    fn test_scheduler_queue_depth() {
        let mut sched = ContinuousBatchScheduler::new(1, 64);
        assert_eq!(sched.queue_depth(), 0);

        sched
            .submit(vec![1], default_params(), 5)
            .expect("submit should succeed");
        sched
            .submit(vec![2], default_params(), 5)
            .expect("submit should succeed");
        assert_eq!(sched.queue_depth(), 2);
    }

    #[test]
    fn test_scheduler_max_queue_enforced() {
        let mut sched = ContinuousBatchScheduler::new(8, 2);
        sched
            .submit(vec![1], default_params(), 5)
            .expect("first submit should succeed");
        sched
            .submit(vec![2], default_params(), 5)
            .expect("second submit should succeed");

        let err = sched
            .submit(vec![3], default_params(), 5)
            .expect_err("third submit should be rejected");

        assert!(
            matches!(err, SchedulerError::QueueFull { max_queue_size: 2 }),
            "unexpected error variant: {err}"
        );
    }

    // ── Priority tests ─────────────────────────────────────────────────────

    #[test]
    fn test_request_priority_ordering() {
        assert!(RequestPriority::Critical > RequestPriority::High);
        assert!(RequestPriority::High > RequestPriority::Normal);
        assert!(RequestPriority::Normal > RequestPriority::Low);
    }

    #[test]
    fn test_priority_queue_ordering() {
        let mut sched = ContinuousBatchScheduler::new(1, 64);

        // Submit low priority first, then high priority
        sched
            .submit_with_priority(vec![1], default_params(), 5, RequestPriority::Low)
            .expect("submit low");
        sched
            .submit_with_priority(vec![2], default_params(), 5, RequestPriority::High)
            .expect("submit high");

        // The high-priority request should be at the front of the queue
        let front = sched.queue.front().expect("queue should not be empty");
        assert_eq!(front.priority, RequestPriority::High);
    }

    #[test]
    fn test_priority_fifo_within_same_level() {
        let mut sched = ContinuousBatchScheduler::new(1, 64);

        let first = sched
            .submit(vec![1], default_params(), 5)
            .expect("submit first");
        let second = sched
            .submit(vec![2], default_params(), 5)
            .expect("submit second");

        // Equal-priority requests must keep their submission order
        let order: Vec<u64> = sched.queue.iter().map(|r| r.id).collect();
        assert_eq!(order, vec![first, second]);
    }

    // ── State transition tests ─────────────────────────────────────────────

    #[test]
    fn test_request_state_transitions() {
        let req = BatchRequest::new(1, vec![10, 11], default_params(), 5);
        assert_eq!(req.state, RequestState::Waiting);
        assert!(!req.is_finished());

        let mut req = req;
        req.state = RequestState::Prefilling;
        assert!(!req.is_finished());

        req.state = RequestState::Decoding;
        assert!(!req.is_finished());

        req.state = RequestState::Completed;
        assert!(req.is_finished());

        req.state = RequestState::Failed("oops".into());
        assert!(req.is_finished());
    }

    // ── Latency measurement tests ──────────────────────────────────────────

    #[test]
    fn test_batch_request_time_to_first_token() {
        let mut req = BatchRequest::new(42, vec![1, 2, 3], default_params(), 10);
        assert!(req.time_to_first_token().is_none());
        assert!(req.total_latency().is_none());

        // Simulate first-token timing
        req.started_at = Some(req.created_at + std::time::Duration::from_millis(10));
        let ttft = req.time_to_first_token().expect("should have TTFT");
        assert!(ttft.as_millis() >= 10, "TTFT should be >= 10ms");

        req.completed_at = Some(req.created_at + std::time::Duration::from_millis(50));
        let lat = req.total_latency().expect("should have latency");
        assert!(lat.as_millis() >= 50, "latency should be >= 50ms");
    }

    // ── drain_completed tests ──────────────────────────────────────────────

    #[test]
    fn test_scheduler_drain_completed() {
        let mut sched = ContinuousBatchScheduler::new(4, 64);
        let mut engine = make_engine();

        let _id = sched
            .submit(vec![], default_params(), 2)
            .expect("submit should succeed");

        sched.run_to_completion(&mut engine);

        let drained = sched.drain_completed();
        assert!(
            !drained.is_empty(),
            "should have at least one completed request"
        );
        assert_eq!(
            sched.completed_count(),
            0,
            "completed list should be empty after drain"
        );
    }

    // ── Stats tests ────────────────────────────────────────────────────────

    #[test]
    fn test_scheduler_stats() {
        let mut sched = ContinuousBatchScheduler::new(4, 64);
        sched
            .submit(vec![1, 2], default_params(), 5)
            .expect("submit should succeed");
        sched
            .submit(vec![3, 4], default_params(), 5)
            .expect("submit should succeed");

        let stats = sched.throughput_stats();
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.queue_depth, 2);
        assert_eq!(stats.active_count, 0);
        assert_eq!(stats.total_tokens_generated, 0);
    }

    // ── run_to_completion tests ────────────────────────────────────────────

    #[test]
    fn test_scheduler_run_to_completion() {
        let mut sched = ContinuousBatchScheduler::new(4, 64);
        let mut engine = make_engine();

        // Empty prompt — engine returns immediately with EOS
        let id = sched
            .submit(vec![], default_params(), 5)
            .expect("submit should succeed");

        sched.run_to_completion(&mut engine);

        assert!(sched.is_idle(), "scheduler should be idle after completion");

        let result = sched.get_result(id).expect("result should be available");
        assert!(
            result.is_finished(),
            "request should be finished, state={:?}",
            result.state
        );
    }

    #[test]
    fn test_scheduler_is_idle_initially() {
        let sched = ContinuousBatchScheduler::new(4, 64);
        assert!(sched.is_idle());
        assert_eq!(sched.active_count(), 0);
        assert_eq!(sched.queue_depth(), 0);
        assert_eq!(sched.completed_count(), 0);
    }
}