1use std::collections::VecDeque;
17
18use crate::engine::InferenceEngine;
19use crate::sampling::SamplingParams;
20
/// Scheduling urgency attached to a request.
///
/// Variants are declared in ascending order, so the derived `Ord` gives
/// `Low < Normal < High < Critical`. The default is [`RequestPriority::Normal`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum RequestPriority {
    /// Lowest urgency (discriminant `0`).
    Low = 0,
    /// Used when no priority is specified (discriminant `1`).
    #[default]
    Normal = 1,
    /// Ordered above `Normal` (discriminant `2`).
    High = 2,
    /// Highest urgency (discriminant `3`).
    Critical = 3,
}
39
/// Lifecycle stage of a [`BatchRequest`].
///
/// `Completed` and `Failed` are the two terminal stages; see
/// [`BatchRequest::is_finished`].
#[derive(Clone, Debug, PartialEq)]
pub enum RequestState {
    /// Initial state assigned by [`BatchRequest::new`].
    Waiting,
    /// Set by the scheduler when the request moves from the queue into the
    /// active set.
    Prefilling,
    /// Set once the first generation call for the request has succeeded.
    Decoding,
    /// Generation ended normally (token budget reached or the engine returned
    /// no token).
    Completed,
    /// Generation ended with an error; the payload is the error rendered as
    /// text.
    Failed(String),
}
56
57pub struct BatchRequest {
61 pub id: u64,
63 pub prompt_tokens: Vec<u32>,
65 pub params: SamplingParams,
67 pub max_tokens: usize,
69 pub priority: RequestPriority,
71 pub state: RequestState,
73 pub generated_tokens: Vec<u32>,
75 pub created_at: std::time::Instant,
77 pub started_at: Option<std::time::Instant>,
79 pub completed_at: Option<std::time::Instant>,
81}
82
83impl BatchRequest {
84 pub fn new(
86 id: u64,
87 prompt_tokens: Vec<u32>,
88 params: SamplingParams,
89 max_tokens: usize,
90 ) -> Self {
91 Self {
92 id,
93 prompt_tokens,
94 params,
95 max_tokens,
96 priority: RequestPriority::Normal,
97 state: RequestState::Waiting,
98 generated_tokens: Vec::new(),
99 created_at: std::time::Instant::now(),
100 started_at: None,
101 completed_at: None,
102 }
103 }
104
105 pub fn with_priority(mut self, priority: RequestPriority) -> Self {
107 self.priority = priority;
108 self
109 }
110
111 pub fn time_to_first_token(&self) -> Option<std::time::Duration> {
115 self.started_at.map(|s| s.duration_since(self.created_at))
116 }
117
118 pub fn total_latency(&self) -> Option<std::time::Duration> {
122 self.completed_at.map(|c| c.duration_since(self.created_at))
123 }
124
125 pub fn tokens_generated(&self) -> usize {
127 self.generated_tokens.len()
128 }
129
130 pub fn is_finished(&self) -> bool {
133 matches!(
134 self.state,
135 RequestState::Completed | RequestState::Failed(_)
136 )
137 }
138}
139
140#[derive(Debug, thiserror::Error)]
144pub enum SchedulerError {
145 #[error("Queue full: {max_queue_size} requests waiting")]
147 QueueFull {
148 max_queue_size: usize,
150 },
151 #[error("Request {id} not found")]
153 NotFound {
154 id: u64,
156 },
157}
158
159#[derive(Debug, serde::Serialize)]
163pub struct SchedulerStats {
164 pub total_requests: u64,
166 pub total_tokens_generated: u64,
168 pub queue_depth: usize,
170 pub active_count: usize,
172}
173
174pub struct ContinuousBatchScheduler {
181 pub max_concurrent: usize,
183 pub max_queue_size: usize,
185
186 queue: VecDeque<BatchRequest>,
187 active: Vec<BatchRequest>,
188 completed: Vec<BatchRequest>,
189 next_id: u64,
190 total_requests: u64,
191 total_tokens_generated: u64,
192}
193
194impl ContinuousBatchScheduler {
195 pub fn new(max_concurrent: usize, max_queue_size: usize) -> Self {
200 Self {
201 max_concurrent: max_concurrent.max(1),
202 max_queue_size: max_queue_size.max(1),
203 queue: VecDeque::new(),
204 active: Vec::new(),
205 completed: Vec::new(),
206 next_id: 1,
207 total_requests: 0,
208 total_tokens_generated: 0,
209 }
210 }
211
212 pub fn submit(
217 &mut self,
218 prompt_tokens: Vec<u32>,
219 params: SamplingParams,
220 max_tokens: usize,
221 ) -> Result<u64, SchedulerError> {
222 self.submit_with_priority(prompt_tokens, params, max_tokens, RequestPriority::Normal)
223 }
224
225 pub fn submit_with_priority(
227 &mut self,
228 prompt_tokens: Vec<u32>,
229 params: SamplingParams,
230 max_tokens: usize,
231 priority: RequestPriority,
232 ) -> Result<u64, SchedulerError> {
233 if self.queue.len() >= self.max_queue_size {
234 return Err(SchedulerError::QueueFull {
235 max_queue_size: self.max_queue_size,
236 });
237 }
238
239 let id = self.next_id;
240 self.next_id += 1;
241 self.total_requests += 1;
242
243 let request =
244 BatchRequest::new(id, prompt_tokens, params, max_tokens).with_priority(priority);
245
246 let pos = self
248 .queue
249 .iter()
250 .position(|r| r.priority < priority)
251 .unwrap_or(self.queue.len());
252 self.queue.insert(pos, request);
253
254 Ok(id)
255 }
256
257 pub fn step(&mut self, engine: &mut InferenceEngine<'_>) {
264 while self.active.len() < self.max_concurrent {
266 match self.queue.pop_front() {
267 Some(mut req) => {
268 req.state = RequestState::Prefilling;
269 self.active.push(req);
270 }
271 None => break,
272 }
273 }
274
275 if self.active.is_empty() {
276 return;
277 }
278
279 let mut finished_indices: Vec<usize> = Vec::new();
281
282 for (idx, req) in self.active.iter_mut().enumerate() {
283 let context: Vec<u32> = req
285 .prompt_tokens
286 .iter()
287 .chain(req.generated_tokens.iter())
288 .copied()
289 .collect();
290
291 engine.reset();
293 let generated = engine.generate(&context, 1);
294
295 match generated {
296 Ok(new_tokens) => {
297 if req.started_at.is_none() {
298 req.started_at = Some(std::time::Instant::now());
299 req.state = RequestState::Decoding;
300 }
301
302 if let Some(&token) = new_tokens.first() {
303 req.generated_tokens.push(token);
304 }
305
306 let hit_max = req.generated_tokens.len() >= req.max_tokens;
308 let hit_eos = new_tokens.is_empty(); if hit_max || hit_eos {
311 req.state = RequestState::Completed;
312 req.completed_at = Some(std::time::Instant::now());
313 finished_indices.push(idx);
314 }
315 }
316 Err(e) => {
317 req.state = RequestState::Failed(e.to_string());
318 req.completed_at = Some(std::time::Instant::now());
319 finished_indices.push(idx);
320 }
321 }
322 }
323
324 for &idx in finished_indices.iter().rev() {
327 let req = self.active.remove(idx);
328 self.total_tokens_generated += req.generated_tokens.len() as u64;
329 self.completed.push(req);
330 }
331 }
332
333 pub fn run_to_completion(&mut self, engine: &mut InferenceEngine<'_>) {
336 while !self.is_idle() {
337 self.step(engine);
338 }
339 }
340
341 pub fn get_result(&self, id: u64) -> Option<&BatchRequest> {
345 self.completed.iter().find(|r| r.id == id)
346 }
347
348 pub fn queue_depth(&self) -> usize {
350 self.queue.len()
351 }
352
353 pub fn active_count(&self) -> usize {
355 self.active.len()
356 }
357
358 pub fn completed_count(&self) -> usize {
360 self.completed.len()
361 }
362
363 pub fn is_idle(&self) -> bool {
365 self.queue.is_empty() && self.active.is_empty()
366 }
367
368 pub fn throughput_stats(&self) -> SchedulerStats {
370 SchedulerStats {
371 total_requests: self.total_requests,
372 total_tokens_generated: self.total_tokens_generated,
373 queue_depth: self.queue.len(),
374 active_count: self.active.len(),
375 }
376 }
377
378 pub fn drain_completed(&mut self) -> Vec<BatchRequest> {
380 std::mem::take(&mut self.completed)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;

    /// Builds an engine from the `bonsai_8b` configuration with default
    /// sampling parameters and a fixed third argument of `42`.
    fn make_engine() -> InferenceEngine<'static> {
        InferenceEngine::new(Qwen3Config::bonsai_8b(), SamplingParams::default(), 42)
    }

    /// Default sampling parameters with `temperature` forced to `0.0`.
    fn default_params() -> SamplingParams {
        SamplingParams {
            temperature: 0.0,
            ..Default::default()
        }
    }

    /// Two submissions must yield distinct, non-zero ids.
    #[test]
    fn test_scheduler_submit_returns_id() {
        let mut scheduler = ContinuousBatchScheduler::new(4, 64);

        let first = scheduler
            .submit(vec![1, 2, 3], default_params(), 10)
            .expect("submit should succeed");
        let second = scheduler
            .submit(vec![4, 5, 6], default_params(), 10)
            .expect("submit should succeed");

        assert_ne!(first, second, "IDs must be unique");
        assert!(first > 0 && second > 0);
    }

    /// Queue depth grows by one per accepted submission.
    #[test]
    fn test_scheduler_queue_depth() {
        let mut scheduler = ContinuousBatchScheduler::new(1, 64);
        assert_eq!(scheduler.queue_depth(), 0);

        for prompt in [vec![1], vec![2]] {
            scheduler
                .submit(prompt, default_params(), 5)
                .expect("submit should succeed");
        }

        assert_eq!(scheduler.queue_depth(), 2);
    }

    /// A submission beyond `max_queue_size` is rejected with `QueueFull`.
    #[test]
    fn test_scheduler_max_queue_enforced() {
        let mut scheduler = ContinuousBatchScheduler::new(8, 2);

        scheduler
            .submit(vec![1], default_params(), 5)
            .expect("first submit should succeed");
        scheduler
            .submit(vec![2], default_params(), 5)
            .expect("second submit should succeed");

        let err = scheduler
            .submit(vec![3], default_params(), 5)
            .expect_err("third submit should be rejected");

        assert!(
            matches!(err, SchedulerError::QueueFull { max_queue_size: 2 }),
            "unexpected error variant: {err}"
        );
    }

    /// Priorities compare in the order Low < Normal < High < Critical.
    #[test]
    fn test_request_priority_ordering() {
        assert!(RequestPriority::Critical > RequestPriority::High);
        assert!(RequestPriority::High > RequestPriority::Normal);
        assert!(RequestPriority::Normal > RequestPriority::Low);
    }

    /// A later high-priority submission jumps ahead of an earlier low one.
    #[test]
    fn test_priority_queue_ordering() {
        let mut scheduler = ContinuousBatchScheduler::new(1, 64);

        scheduler
            .submit_with_priority(vec![1], default_params(), 5, RequestPriority::Low)
            .expect("submit low");
        scheduler
            .submit_with_priority(vec![2], default_params(), 5, RequestPriority::High)
            .expect("submit high");

        let head = scheduler.queue.front().expect("queue should not be empty");
        assert_eq!(head.priority, RequestPriority::High);
    }

    /// Only `Completed` and `Failed` count as finished.
    #[test]
    fn test_request_state_transitions() {
        let mut request = BatchRequest::new(1, vec![10, 11], default_params(), 5);
        assert_eq!(request.state, RequestState::Waiting);
        assert!(!request.is_finished());

        for in_flight in [RequestState::Prefilling, RequestState::Decoding] {
            request.state = in_flight;
            assert!(!request.is_finished());
        }

        for terminal in [RequestState::Completed, RequestState::Failed("oops".into())] {
            request.state = terminal;
            assert!(request.is_finished());
        }
    }

    /// Timing accessors are `None` until their timestamps are set and then
    /// measure from `created_at`.
    #[test]
    fn test_batch_request_time_to_first_token() {
        let mut request = BatchRequest::new(42, vec![1, 2, 3], default_params(), 10);
        assert!(request.time_to_first_token().is_none());
        assert!(request.total_latency().is_none());

        let ten_ms = std::time::Duration::from_millis(10);
        request.started_at = Some(request.created_at + ten_ms);
        let ttft = request.time_to_first_token().expect("should have TTFT");
        assert!(ttft.as_millis() >= 10, "TTFT should be >= 10ms");

        let fifty_ms = std::time::Duration::from_millis(50);
        request.completed_at = Some(request.created_at + fifty_ms);
        let latency = request.total_latency().expect("should have latency");
        assert!(latency.as_millis() >= 50, "latency should be >= 50ms");
    }

    /// Draining hands back the finished requests and empties the list.
    #[test]
    fn test_scheduler_drain_completed() {
        let mut scheduler = ContinuousBatchScheduler::new(4, 64);
        let mut engine = make_engine();

        scheduler
            .submit(vec![], default_params(), 2)
            .expect("submit should succeed");

        scheduler.run_to_completion(&mut engine);

        let finished = scheduler.drain_completed();
        assert!(
            !finished.is_empty(),
            "should have at least one completed request"
        );
        assert_eq!(
            scheduler.completed_count(),
            0,
            "completed list should be empty after drain"
        );
    }

    /// Stats reflect submissions before any step has run.
    #[test]
    fn test_scheduler_stats() {
        let mut scheduler = ContinuousBatchScheduler::new(4, 64);

        for prompt in [vec![1, 2], vec![3, 4]] {
            scheduler
                .submit(prompt, default_params(), 5)
                .expect("submit should succeed");
        }

        let SchedulerStats {
            total_requests,
            total_tokens_generated,
            queue_depth,
            active_count,
        } = scheduler.throughput_stats();

        assert_eq!(total_requests, 2);
        assert_eq!(queue_depth, 2);
        assert_eq!(active_count, 0);
        assert_eq!(total_tokens_generated, 0);
    }

    /// After `run_to_completion` the scheduler is idle and the result is
    /// retrievable in a finished state.
    #[test]
    fn test_scheduler_run_to_completion() {
        let mut scheduler = ContinuousBatchScheduler::new(4, 64);
        let mut engine = make_engine();

        let request_id = scheduler
            .submit(vec![], default_params(), 5)
            .expect("submit should succeed");

        scheduler.run_to_completion(&mut engine);

        assert!(scheduler.is_idle(), "scheduler should be idle after completion");

        let outcome = scheduler
            .get_result(request_id)
            .expect("result should be available");
        assert!(
            outcome.is_finished(),
            "request should be finished, state={:?}",
            outcome.state
        );
    }

    /// A freshly constructed scheduler holds no work at all.
    #[test]
    fn test_scheduler_is_idle_initially() {
        let scheduler = ContinuousBatchScheduler::new(4, 64);

        assert!(scheduler.is_idle());
        assert_eq!(scheduler.active_count(), 0);
        assert_eq!(scheduler.queue_depth(), 0);
        assert_eq!(scheduler.completed_count(), 0);
    }
}