// cbtop/continuous_batcher/mod.rs

mod request;
mod schedule;
mod speculative;

pub use request::{InferenceRequest, Priority, SequenceGroup, Token};
pub use schedule::{BatchSchedule, BatcherStats, SchedulingPolicy, TokenOutput};
pub use speculative::{ExponentialMovingAverage, SpeculativeDecoder, SpeculativeOutput};

use std::collections::VecDeque;
use std::fmt;
use std::time::Instant;

use crate::paged_kv::SeqId;
30#[derive(Debug)]
35pub struct ContinuousBatcher {
36 max_batch_size: usize,
38 max_seq_len: usize,
40 running: Vec<SequenceGroup>,
42 waiting: VecDeque<SequenceGroup>,
44 swapped: Vec<SequenceGroup>,
46 policy: SchedulingPolicy,
48 stats: BatcherStats,
50 memory_threshold: f64,
52}
53
54impl ContinuousBatcher {
55 pub fn new(max_batch_size: usize, max_seq_len: usize) -> Self {
57 Self {
58 max_batch_size,
59 max_seq_len,
60 running: Vec::new(),
61 waiting: VecDeque::new(),
62 swapped: Vec::new(),
63 policy: SchedulingPolicy::default(),
64 stats: BatcherStats {
65 start_time: Some(Instant::now()),
66 ..Default::default()
67 },
68 memory_threshold: 0.9,
69 }
70 }
71
72 pub fn with_policy(mut self, policy: SchedulingPolicy) -> Self {
74 self.policy = policy;
75 self
76 }
77
78 pub fn with_memory_threshold(mut self, threshold: f64) -> Self {
80 self.memory_threshold = threshold.clamp(0.0, 1.0);
81 self
82 }
83
84 pub fn policy(&self) -> &SchedulingPolicy {
86 &self.policy
87 }
88
89 pub fn max_batch_size(&self) -> usize {
91 self.max_batch_size
92 }
93
94 pub fn max_seq_len(&self) -> usize {
96 self.max_seq_len
97 }
98
99 pub fn running_count(&self) -> usize {
101 self.running.len()
102 }
103
104 pub fn waiting_count(&self) -> usize {
106 self.waiting.len()
107 }
108
109 pub fn swapped_count(&self) -> usize {
111 self.swapped.len()
112 }
113
114 pub fn stats(&self) -> &BatcherStats {
116 &self.stats
117 }
118
119 pub fn throughput(&self) -> f64 {
121 self.stats.throughput()
122 }
123
124 pub fn add_request(&mut self, request: InferenceRequest) {
126 let seq_group = SequenceGroup::new(request);
127 self.insert_waiting(seq_group);
128 }
129
130 fn insert_waiting(&mut self, seq_group: SequenceGroup) {
132 match &self.policy {
133 SchedulingPolicy::FCFS => {
134 self.waiting.push_back(seq_group);
136 }
137 SchedulingPolicy::SJF => {
138 let insert_idx = self
140 .waiting
141 .iter()
142 .position(|s| s.request.estimated_tokens > seq_group.request.estimated_tokens)
143 .unwrap_or(self.waiting.len());
144 self.waiting.insert(insert_idx, seq_group);
145 }
146 SchedulingPolicy::Priority { .. } => {
147 let insert_idx = self
149 .waiting
150 .iter()
151 .position(|s| s.request.priority < seq_group.request.priority)
152 .unwrap_or(self.waiting.len());
153 self.waiting.insert(insert_idx, seq_group);
154 }
155 SchedulingPolicy::FairShare => {
156 self.waiting.push_back(seq_group);
158 }
159 }
160 }
161
162 pub fn schedule(&mut self) -> BatchSchedule {
164 self.running.retain(|s| !s.is_finished);
166
167 let _available = self.max_batch_size.saturating_sub(self.running.len());
169
170 let mut prefill_count = 0;
172 while !self.waiting.is_empty() && self.running.len() < self.max_batch_size {
173 if let Some(seq_group) = self.waiting.pop_front() {
174 if seq_group.total_tokens() <= self.max_seq_len {
176 prefill_count += 1;
177 self.running.push(seq_group);
178 } else {
179 self.waiting.push_front(seq_group);
181 break;
182 }
183 }
184 }
185
186 while !self.swapped.is_empty() && self.running.len() < self.max_batch_size {
188 if let Some(seq_group) = self.swapped.pop() {
189 self.running.push(seq_group);
190 self.stats.total_swaps += 1;
191 }
192 }
193
194 let sequence_ids: Vec<SeqId> = self.running.iter().map(|s| s.request.id).collect();
196 let total_tokens: usize = self.running.iter().map(|s| s.total_tokens()).sum();
197 let decode_count = self.running.len() - prefill_count;
198
199 BatchSchedule {
200 batch_size: sequence_ids.len(),
201 sequence_ids,
202 total_tokens,
203 prefill_count,
204 decode_count,
205 }
206 }
207
208 pub fn process_outputs(&mut self, outputs: Vec<TokenOutput>) {
210 for output in outputs {
211 if let Some(seq_group) = self
213 .running
214 .iter_mut()
215 .find(|s| s.request.id == output.seq_id)
216 {
217 seq_group.add_token(output.token);
218 self.stats.total_tokens += 1;
219
220 if output.is_eos {
222 seq_group.finish();
223 self.stats.total_requests += 1;
224 }
225 }
226 }
227 }
228
229 pub fn preempt(&mut self, num_to_preempt: usize) -> Vec<SeqId> {
231 let mut preempted = Vec::new();
232
233 for _ in 0..num_to_preempt {
235 if self.running.is_empty() {
236 break;
237 }
238
239 let victim_idx = self
241 .running
242 .iter()
243 .enumerate()
244 .max_by_key(|(_, s)| s.total_tokens())
245 .map(|(i, _)| i);
246
247 if let Some(idx) = victim_idx {
248 let victim = self.running.remove(idx);
249 preempted.push(victim.request.id);
250 self.swapped.push(victim);
251 self.stats.total_preemptions += 1;
252 }
253 }
254
255 preempted
256 }
257
258 pub fn needs_preemption(&self, current_utilization: f64) -> bool {
260 current_utilization >= self.memory_threshold
261 && !self.running.is_empty()
262 && matches!(
263 self.policy,
264 SchedulingPolicy::Priority {
265 preempt_enabled: true
266 }
267 )
268 }
269
270 pub fn get_sequence(&self, seq_id: SeqId) -> Option<&SequenceGroup> {
272 self.running
273 .iter()
274 .chain(self.waiting.iter())
275 .chain(self.swapped.iter())
276 .find(|s| s.request.id == seq_id)
277 }
278
279 pub fn all_sequence_ids(&self) -> Vec<SeqId> {
281 self.running
282 .iter()
283 .chain(self.waiting.iter())
284 .chain(self.swapped.iter())
285 .map(|s| s.request.id)
286 .collect()
287 }
288}
289
290impl fmt::Display for ContinuousBatcher {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 writeln!(f, "ContinuousBatcher")?;
293 writeln!(
294 f,
295 " Policy: {} | Max Batch: {} | Max Seq: {}",
296 self.policy, self.max_batch_size, self.max_seq_len
297 )?;
298 writeln!(
299 f,
300 " Running: {} | Waiting: {} | Swapped: {}",
301 self.running.len(),
302 self.waiting.len(),
303 self.swapped.len()
304 )?;
305 writeln!(f, " Throughput: {:.1} tok/s", self.throughput())?;
306 writeln!(
307 f,
308 " Stats: tokens={}, requests={}, preemptions={}, swaps={}",
309 self.stats.total_tokens,
310 self.stats.total_requests,
311 self.stats.total_preemptions,
312 self.stats.total_swaps
313 )?;
314 Ok(())
315 }
316}
317
#[cfg(test)]
mod tests;