// cbtop/continuous_batcher/schedule.rs
//! Scheduling types for continuous batching.

use std::fmt;
use std::time::Instant;

use crate::paged_kv::SeqId;

use super::request::Token;
/// Scheduling policy for request prioritization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchedulingPolicy {
    /// First-come, first-served: requests run in arrival order.
    FCFS,
    /// Shortest job first (by estimated tokens).
    SJF,
    /// Priority-based (API tiers).
    Priority {
        /// Whether higher-priority requests may preempt running ones.
        preempt_enabled: bool,
    },
    /// Fair share (equal GPU time per user).
    FairShare,
}

impl Default for SchedulingPolicy {
    /// FCFS is the default scheduling policy.
    fn default() -> Self {
        SchedulingPolicy::FCFS
    }
}

impl fmt::Display for SchedulingPolicy {
    /// Renders the policy as a short human-readable label,
    /// e.g. `"FCFS"` or `"Priority(preempt=true)"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SchedulingPolicy::FCFS => write!(f, "FCFS"),
            SchedulingPolicy::SJF => write!(f, "SJF"),
            SchedulingPolicy::Priority { preempt_enabled } => {
                write!(f, "Priority(preempt={})", preempt_enabled)
            }
            SchedulingPolicy::FairShare => write!(f, "FairShare"),
        }
    }
}

42/// Batch schedule result.
43#[derive(Debug, Clone)]
44pub struct BatchSchedule {
45    /// Sequence IDs in this batch
46    pub sequence_ids: Vec<SeqId>,
47    /// Number of sequences in batch
48    pub batch_size: usize,
49    /// Total tokens to process
50    pub total_tokens: usize,
51    /// Prefill sequences (first token)
52    pub prefill_count: usize,
53    /// Decode sequences (continuation)
54    pub decode_count: usize,
55}
56
57impl BatchSchedule {
58    /// Create empty schedule.
59    pub fn empty() -> Self {
60        Self {
61            sequence_ids: Vec::new(),
62            batch_size: 0,
63            total_tokens: 0,
64            prefill_count: 0,
65            decode_count: 0,
66        }
67    }
68
69    /// Check if schedule is empty.
70    pub fn is_empty(&self) -> bool {
71        self.batch_size == 0
72    }
73}
74
75/// Token output from a decode step.
76#[derive(Debug, Clone)]
77pub struct TokenOutput {
78    /// Sequence ID
79    pub seq_id: SeqId,
80    /// Generated token
81    pub token: Token,
82    /// Is EOS token?
83    pub is_eos: bool,
84}
85
/// Batcher statistics, accumulated over the batcher's lifetime.
#[derive(Debug, Clone, Default)]
pub struct BatcherStats {
    /// Total tokens processed.
    pub total_tokens: u64,
    /// Total requests completed.
    pub total_requests: u64,
    /// Total preemptions.
    pub total_preemptions: u64,
    /// Total swaps (CPU<->GPU).
    pub total_swaps: u64,
    /// Processing start time; `None` until processing begins.
    pub start_time: Option<Instant>,
}

impl BatcherStats {
    /// Calculate throughput (tokens/sec).
    ///
    /// Returns `0.0` when processing has not started (`start_time`
    /// is `None`) or when no measurable time has elapsed yet,
    /// guarding against division by zero.
    pub fn throughput(&self) -> f64 {
        match self.start_time {
            Some(start) => {
                let elapsed = start.elapsed().as_secs_f64();
                if elapsed > 0.0 {
                    self.total_tokens as f64 / elapsed
                } else {
                    0.0
                }
            }
            None => 0.0,
        }
    }
}