// cbtop/continuous_batcher/mod.rs
//! ContinuousBatcher Implementation (PMAT-015)
//!
//! Implements continuous batching for LLM inference per cbtop spec §19.
//!
//! # Overview
//!
//! Continuous batching processes inference requests dynamically, allowing
//! new requests to join and completed requests to leave mid-batch.
//!
//! # Citations
//!
//! - [Yu et al. 2022] "ORCA: Continuous Batching for LLM Inference" OSDI
//! - [Leviathan et al. 2023] "Fast Inference from Transformers via Speculative Decoding" ICML
//! - [Chen et al. 2023] "Accelerating LLM Decoding with Speculative Sampling" arXiv

16mod request;
17mod schedule;
18mod speculative;
19
20pub use request::{InferenceRequest, Priority, SequenceGroup, Token};
21pub use schedule::{BatchSchedule, BatcherStats, SchedulingPolicy, TokenOutput};
22pub use speculative::{ExponentialMovingAverage, SpeculativeDecoder, SpeculativeOutput};
23
24use std::collections::VecDeque;
25use std::fmt;
26use std::time::Instant;
27
28use crate::paged_kv::SeqId;
29
30/// Continuous batching scheduler for LLM inference.
31///
32/// Processes requests as they arrive without waiting for batch completion.
33/// Based on ORCA continuous batching algorithm.
34#[derive(Debug)]
35pub struct ContinuousBatcher {
36    /// Maximum batch size (GPU memory limited)
37    max_batch_size: usize,
38    /// Maximum sequence length
39    max_seq_len: usize,
40    /// Active sequences in current batch
41    running: Vec<SequenceGroup>,
42    /// Waiting queue (sorted by policy)
43    waiting: VecDeque<SequenceGroup>,
44    /// Swapped sequences (offloaded to CPU)
45    swapped: Vec<SequenceGroup>,
46    /// Scheduling policy
47    policy: SchedulingPolicy,
48    /// Statistics
49    stats: BatcherStats,
50    /// Memory threshold for preemption (0.0-1.0)
51    memory_threshold: f64,
52}
53
54impl ContinuousBatcher {
55    /// Create a new continuous batcher.
56    pub fn new(max_batch_size: usize, max_seq_len: usize) -> Self {
57        Self {
58            max_batch_size,
59            max_seq_len,
60            running: Vec::new(),
61            waiting: VecDeque::new(),
62            swapped: Vec::new(),
63            policy: SchedulingPolicy::default(),
64            stats: BatcherStats {
65                start_time: Some(Instant::now()),
66                ..Default::default()
67            },
68            memory_threshold: 0.9,
69        }
70    }
71
72    /// Set scheduling policy.
73    pub fn with_policy(mut self, policy: SchedulingPolicy) -> Self {
74        self.policy = policy;
75        self
76    }
77
78    /// Set memory threshold for preemption.
79    pub fn with_memory_threshold(mut self, threshold: f64) -> Self {
80        self.memory_threshold = threshold.clamp(0.0, 1.0);
81        self
82    }
83
84    /// Get scheduling policy.
85    pub fn policy(&self) -> &SchedulingPolicy {
86        &self.policy
87    }
88
89    /// Get max batch size.
90    pub fn max_batch_size(&self) -> usize {
91        self.max_batch_size
92    }
93
94    /// Get max sequence length.
95    pub fn max_seq_len(&self) -> usize {
96        self.max_seq_len
97    }
98
99    /// Get number of running sequences.
100    pub fn running_count(&self) -> usize {
101        self.running.len()
102    }
103
104    /// Get number of waiting sequences.
105    pub fn waiting_count(&self) -> usize {
106        self.waiting.len()
107    }
108
109    /// Get number of swapped sequences.
110    pub fn swapped_count(&self) -> usize {
111        self.swapped.len()
112    }
113
114    /// Get statistics.
115    pub fn stats(&self) -> &BatcherStats {
116        &self.stats
117    }
118
119    /// Current throughput (tokens/sec).
120    pub fn throughput(&self) -> f64 {
121        self.stats.throughput()
122    }
123
124    /// Add new inference request.
125    pub fn add_request(&mut self, request: InferenceRequest) {
126        let seq_group = SequenceGroup::new(request);
127        self.insert_waiting(seq_group);
128    }
129
130    /// Insert into waiting queue according to policy.
131    fn insert_waiting(&mut self, seq_group: SequenceGroup) {
132        match &self.policy {
133            SchedulingPolicy::FCFS => {
134                // Add to back of queue
135                self.waiting.push_back(seq_group);
136            }
137            SchedulingPolicy::SJF => {
138                // Insert sorted by estimated tokens (shortest first)
139                let insert_idx = self
140                    .waiting
141                    .iter()
142                    .position(|s| s.request.estimated_tokens > seq_group.request.estimated_tokens)
143                    .unwrap_or(self.waiting.len());
144                self.waiting.insert(insert_idx, seq_group);
145            }
146            SchedulingPolicy::Priority { .. } => {
147                // Insert sorted by priority (highest first)
148                let insert_idx = self
149                    .waiting
150                    .iter()
151                    .position(|s| s.request.priority < seq_group.request.priority)
152                    .unwrap_or(self.waiting.len());
153                self.waiting.insert(insert_idx, seq_group);
154            }
155            SchedulingPolicy::FairShare => {
156                // Simple round-robin for now
157                self.waiting.push_back(seq_group);
158            }
159        }
160    }
161
162    /// Schedule next iteration batch.
163    pub fn schedule(&mut self) -> BatchSchedule {
164        // Remove finished sequences from running
165        self.running.retain(|s| !s.is_finished);
166
167        // Calculate available slots
168        let _available = self.max_batch_size.saturating_sub(self.running.len());
169
170        // Promote from waiting to running
171        let mut prefill_count = 0;
172        while !self.waiting.is_empty() && self.running.len() < self.max_batch_size {
173            if let Some(seq_group) = self.waiting.pop_front() {
174                // Check if sequence fits
175                if seq_group.total_tokens() <= self.max_seq_len {
176                    prefill_count += 1;
177                    self.running.push(seq_group);
178                } else {
179                    // Too long, put back
180                    self.waiting.push_front(seq_group);
181                    break;
182                }
183            }
184        }
185
186        // Swap in from swapped if space available
187        while !self.swapped.is_empty() && self.running.len() < self.max_batch_size {
188            if let Some(seq_group) = self.swapped.pop() {
189                self.running.push(seq_group);
190                self.stats.total_swaps += 1;
191            }
192        }
193
194        // Build schedule
195        let sequence_ids: Vec<SeqId> = self.running.iter().map(|s| s.request.id).collect();
196        let total_tokens: usize = self.running.iter().map(|s| s.total_tokens()).sum();
197        let decode_count = self.running.len() - prefill_count;
198
199        BatchSchedule {
200            batch_size: sequence_ids.len(),
201            sequence_ids,
202            total_tokens,
203            prefill_count,
204            decode_count,
205        }
206    }
207
208    /// Process outputs from a decode step.
209    pub fn process_outputs(&mut self, outputs: Vec<TokenOutput>) {
210        for output in outputs {
211            // Find the sequence and add the token
212            if let Some(seq_group) = self
213                .running
214                .iter_mut()
215                .find(|s| s.request.id == output.seq_id)
216            {
217                seq_group.add_token(output.token);
218                self.stats.total_tokens += 1;
219
220                // Check for EOS
221                if output.is_eos {
222                    seq_group.finish();
223                    self.stats.total_requests += 1;
224                }
225            }
226        }
227    }
228
229    /// Preempt sequences under memory pressure.
230    pub fn preempt(&mut self, num_to_preempt: usize) -> Vec<SeqId> {
231        let mut preempted = Vec::new();
232
233        // Preempt from running (lowest priority or longest)
234        for _ in 0..num_to_preempt {
235            if self.running.is_empty() {
236                break;
237            }
238
239            // Find victim (longest sequence for simplicity)
240            let victim_idx = self
241                .running
242                .iter()
243                .enumerate()
244                .max_by_key(|(_, s)| s.total_tokens())
245                .map(|(i, _)| i);
246
247            if let Some(idx) = victim_idx {
248                let victim = self.running.remove(idx);
249                preempted.push(victim.request.id);
250                self.swapped.push(victim);
251                self.stats.total_preemptions += 1;
252            }
253        }
254
255        preempted
256    }
257
258    /// Check if preemption is needed (simulated memory pressure).
259    pub fn needs_preemption(&self, current_utilization: f64) -> bool {
260        current_utilization >= self.memory_threshold
261            && !self.running.is_empty()
262            && matches!(
263                self.policy,
264                SchedulingPolicy::Priority {
265                    preempt_enabled: true
266                }
267            )
268    }
269
270    /// Get sequence by ID.
271    pub fn get_sequence(&self, seq_id: SeqId) -> Option<&SequenceGroup> {
272        self.running
273            .iter()
274            .chain(self.waiting.iter())
275            .chain(self.swapped.iter())
276            .find(|s| s.request.id == seq_id)
277    }
278
279    /// Get all sequence IDs.
280    pub fn all_sequence_ids(&self) -> Vec<SeqId> {
281        self.running
282            .iter()
283            .chain(self.waiting.iter())
284            .chain(self.swapped.iter())
285            .map(|s| s.request.id)
286            .collect()
287    }
288}
289
290impl fmt::Display for ContinuousBatcher {
291    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292        writeln!(f, "ContinuousBatcher")?;
293        writeln!(
294            f,
295            "  Policy: {} | Max Batch: {} | Max Seq: {}",
296            self.policy, self.max_batch_size, self.max_seq_len
297        )?;
298        writeln!(
299            f,
300            "  Running: {} | Waiting: {} | Swapped: {}",
301            self.running.len(),
302            self.waiting.len(),
303            self.swapped.len()
304        )?;
305        writeln!(f, "  Throughput: {:.1} tok/s", self.throughput())?;
306        writeln!(
307            f,
308            "  Stats: tokens={}, requests={}, preemptions={}, swaps={}",
309            self.stats.total_tokens,
310            self.stats.total_requests,
311            self.stats.total_preemptions,
312            self.stats.total_swaps
313        )?;
314        Ok(())
315    }
316}
317
#[cfg(test)]
mod tests;