1use std::collections::VecDeque;
31
32use crate::error::{DnnError, DnnResult};
33
/// Identifier carried by an [`InferenceRequest`] and echoed back in batch
/// slots and scheduling decisions so callers can correlate them.
pub type RequestId = u64;
40
/// Scheduling priority attached to a request.
///
/// The derived `Ord` follows declaration order, so `Low < Normal < High`.
/// Under `SchedulingPolicy::PriorityBased` the prefill queue is sorted by
/// descending priority, with arrival time as the tie-breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Priority {
    /// Lowest priority; scheduled after `Normal` and `High`.
    Low = 0,
    /// Middle priority; also the value assigned to re-queued preempted requests.
    Normal = 1,
    /// Highest priority; scheduled first under priority-based ordering.
    High = 2,
}
51
/// A request waiting to be (or being) served by the batcher.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Caller-chosen identifier; returned by `add_request` and listed in
    /// batch decisions.
    pub request_id: RequestId,
    /// Number of tokens already in the sequence. Must be non-zero and no
    /// larger than `BatchConfig::max_sequence_length`; it is charged against
    /// the prefill token budget when the request is admitted.
    pub sequence_length: usize,
    /// Tokens the request may still generate. Added to `sequence_length` to
    /// form the slot's `max_seq_len`, and used as the sort key for
    /// `SchedulingPolicy::ShortestJobFirst`.
    pub max_new_tokens: usize,
    /// Priority consulted by `SchedulingPolicy::PriorityBased`.
    pub priority: Priority,
    /// Arrival timestamp (nanoseconds, per the field name); the sort key for
    /// FCFS/Orca ordering and the tie-breaker for the other policies.
    pub arrival_time_ns: u64,
    /// Optional deadline used by `SchedulingPolicy::DeadlineAware`; requests
    /// without a deadline sort after all requests that have one.
    pub deadline_ns: Option<u64>,
}
68
/// Bookkeeping for one request that has been admitted into the running batch.
#[derive(Debug, Clone)]
pub struct BatchSlot {
    /// Slot identifier taken from the batcher's increasing counter.
    pub slot_id: usize,
    /// Request occupying this slot.
    pub request_id: RequestId,
    /// Current sequence length in tokens; starts at the request's
    /// `sequence_length` and is incremented by `ContinuousBatcher::step`.
    pub current_seq_len: usize,
    /// `sequence_length + max_new_tokens` as computed at admission.
    pub max_seq_len: usize,
    /// Set while the slot is being admitted; `step` clears it before returning.
    pub is_prefill: bool,
    /// Set to `true` at admission. Nothing in the visible code clears it;
    /// finished or preempted slots are removed instead.
    pub is_active: bool,
}
85
/// Ordering applied to the prefill queue before each admission pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
    /// Earliest `arrival_time_ns` first.
    Fcfs,
    /// Smallest `max_new_tokens` first.
    ShortestJobFirst,
    /// Highest `priority` first, then earliest arrival.
    PriorityBased,
    /// Earliest `deadline_ns` first (missing deadlines last), then earliest arrival.
    DeadlineAware,
    /// Currently ordered exactly like `Fcfs` (by arrival time).
    Orca,
}
100
/// Strategy for handling a preempted request.
///
/// NOTE(review): this enum is not referenced by any code visible in this
/// file section; the variant descriptions below are inferred from the names
/// and should be confirmed against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreemptionPolicy {
    /// Presumably: discard cached state and recompute it on resume — TODO confirm.
    Recompute,
    /// Presumably: move cached state elsewhere and restore it on resume — TODO confirm.
    Swap,
}
109
/// Limits and policy that drive `ContinuousBatcher`.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of slots that may be active at once; admission stops
    /// when this many slots exist.
    pub max_batch_size: usize,
    /// Overall token ceiling; the per-step prefill budget is capped at this
    /// value minus the tokens consumed by the scheduled decode slots.
    pub max_total_tokens: usize,
    /// Longest `sequence_length` accepted by `add_request`.
    pub max_sequence_length: usize,
    /// Maximum number of prompt tokens admitted in one step.
    pub prefill_batch_size: usize,
    /// Maximum number of decode slots scheduled in one step.
    pub decode_batch_size: usize,
    /// Ordering applied to the prefill queue before admission.
    pub scheduling_policy: SchedulingPolicy,
}
126
/// Result of one `ContinuousBatcher::step` call.
#[derive(Debug, Clone)]
pub struct BatchDecision {
    /// Requests admitted for prefill this step, in admission order.
    pub prefill_requests: Vec<RequestId>,
    /// Requests scheduled for a decode iteration this step.
    pub decode_requests: Vec<RequestId>,
    /// Requests preempted this step. `step` as written here always leaves
    /// this empty; preemption happens through `ContinuousBatcher::preempt`.
    pub preempted: Vec<RequestId>,
    /// Sum of `current_seq_len` over all active slots after the step.
    pub total_tokens: usize,
}
139
140#[derive(Debug)]
146struct BatchState {
147 active_slots: Vec<BatchSlot>,
149 total_tokens: usize,
151 prefill_queue: VecDeque<InferenceRequest>,
153 decode_queue: VecDeque<RequestId>,
155 preempted_queue: VecDeque<InferenceRequest>,
157}
158
159impl BatchState {
160 fn new() -> Self {
161 Self {
162 active_slots: Vec::new(),
163 total_tokens: 0,
164 prefill_queue: VecDeque::new(),
165 decode_queue: VecDeque::new(),
166 preempted_queue: VecDeque::new(),
167 }
168 }
169}
170
171#[derive(Debug)]
182pub struct ContinuousBatcher {
183 config: BatchConfig,
184 state: BatchState,
185 next_slot_id: usize,
186 completed_count: u64,
187}
188
189impl ContinuousBatcher {
190 pub fn new(config: BatchConfig) -> Self {
192 Self {
193 config,
194 state: BatchState::new(),
195 next_slot_id: 0,
196 completed_count: 0,
197 }
198 }
199
200 pub fn add_request(&mut self, request: InferenceRequest) -> DnnResult<RequestId> {
202 if request.sequence_length == 0 {
203 return Err(DnnError::InvalidArgument(
204 "sequence_length must be > 0".into(),
205 ));
206 }
207 if request.sequence_length > self.config.max_sequence_length {
208 return Err(DnnError::InvalidArgument(format!(
209 "sequence_length {} exceeds max_sequence_length {}",
210 request.sequence_length, self.config.max_sequence_length
211 )));
212 }
213 let id = request.request_id;
214 self.state.prefill_queue.push_back(request);
215 Ok(id)
216 }
217
218 pub fn step(&mut self) -> DnnResult<BatchDecision> {
223 let mut decision = BatchDecision {
224 prefill_requests: Vec::new(),
225 decode_requests: Vec::new(),
226 preempted: Vec::new(),
227 total_tokens: 0,
228 };
229
230 let decode_ids: Vec<RequestId> = self
232 .state
233 .active_slots
234 .iter()
235 .filter(|s| s.is_active && !s.is_prefill)
236 .map(|s| s.request_id)
237 .collect();
238
239 let decode_count = decode_ids.len().min(self.config.decode_batch_size);
240 let decode_tokens: usize = self
241 .state
242 .active_slots
243 .iter()
244 .filter(|s| s.is_active && !s.is_prefill)
245 .take(decode_count)
246 .map(|s| s.current_seq_len + 1) .sum();
248
249 decision.decode_requests = decode_ids.into_iter().take(decode_count).collect();
250
251 self.sort_prefill_queue();
253
254 let mut prefill_budget = self
256 .config
257 .prefill_batch_size
258 .min(self.config.max_total_tokens.saturating_sub(decode_tokens));
259
260 let mut admitted = Vec::new();
261 while !self.state.prefill_queue.is_empty()
262 && self.state.active_slots.len() + admitted.len() < self.config.max_batch_size
263 {
264 let req = match self.state.prefill_queue.front() {
266 Some(r) => r,
267 None => break,
268 };
269 if req.sequence_length > prefill_budget {
270 break;
271 }
272 let req = self
274 .state
275 .prefill_queue
276 .pop_front()
277 .ok_or_else(|| DnnError::InvalidArgument("empty queue".into()))?;
278
279 prefill_budget = prefill_budget.saturating_sub(req.sequence_length);
280
281 let slot = BatchSlot {
282 slot_id: self.next_slot_id,
283 request_id: req.request_id,
284 current_seq_len: req.sequence_length,
285 max_seq_len: req.sequence_length + req.max_new_tokens,
286 is_prefill: true,
287 is_active: true,
288 };
289 self.next_slot_id += 1;
290 decision.prefill_requests.push(req.request_id);
291 admitted.push(slot);
292 }
293
294 for slot in &mut admitted {
296 slot.is_prefill = false;
297 }
298 self.state.active_slots.extend(admitted);
299
300 for slot in &mut self.state.active_slots {
302 if slot.is_active && !slot.is_prefill {
303 slot.current_seq_len = slot.current_seq_len.saturating_add(1);
304 }
305 }
306
307 decision.total_tokens = self
308 .state
309 .active_slots
310 .iter()
311 .filter(|s| s.is_active)
312 .map(|s| s.current_seq_len)
313 .sum();
314
315 self.state.total_tokens = decision.total_tokens;
316
317 Ok(decision)
318 }
319
320 pub fn complete_request(&mut self, request_id: RequestId) -> DnnResult<()> {
322 let pos = self
323 .state
324 .active_slots
325 .iter()
326 .position(|s| s.request_id == request_id)
327 .ok_or_else(|| {
328 DnnError::InvalidArgument(format!("request {request_id} not in active slots"))
329 })?;
330 let slot = &self.state.active_slots[pos];
331 self.state.total_tokens = self.state.total_tokens.saturating_sub(slot.current_seq_len);
332 self.state.active_slots.remove(pos);
333 self.state.decode_queue.retain(|id| *id != request_id);
334 self.completed_count += 1;
335 Ok(())
336 }
337
338 pub fn preempt(&mut self, request_id: RequestId) -> DnnResult<()> {
341 let pos = self
342 .state
343 .active_slots
344 .iter()
345 .position(|s| s.request_id == request_id)
346 .ok_or_else(|| {
347 DnnError::InvalidArgument(format!("request {request_id} not in active slots"))
348 })?;
349 let slot = self.state.active_slots.remove(pos);
350 self.state.total_tokens = self.state.total_tokens.saturating_sub(slot.current_seq_len);
351 self.state.decode_queue.retain(|id| *id != request_id);
352
353 let preempted_req = InferenceRequest {
355 request_id,
356 sequence_length: slot.current_seq_len,
357 max_new_tokens: slot.max_seq_len.saturating_sub(slot.current_seq_len),
358 priority: Priority::Normal,
359 arrival_time_ns: 0,
360 deadline_ns: None,
361 };
362 self.state.preempted_queue.push_back(preempted_req);
363 Ok(())
364 }
365
366 pub fn active_requests(&self) -> usize {
368 self.state
369 .active_slots
370 .iter()
371 .filter(|s| s.is_active)
372 .count()
373 }
374
375 pub fn pending_requests(&self) -> usize {
377 self.state.prefill_queue.len() + self.state.preempted_queue.len()
378 }
379
380 pub fn throughput_tokens_per_step(&self) -> usize {
382 self.state.total_tokens
383 }
384
385 fn sort_prefill_queue(&mut self) {
388 let queue = &mut self.state.prefill_queue;
389 let policy = self.config.scheduling_policy;
390
391 let mut vec: Vec<InferenceRequest> = queue.drain(..).collect();
392 match policy {
393 SchedulingPolicy::Fcfs => {
394 vec.sort_by_key(|r| r.arrival_time_ns);
396 }
397 SchedulingPolicy::ShortestJobFirst => {
398 vec.sort_by_key(|r| r.max_new_tokens);
399 }
400 SchedulingPolicy::PriorityBased => {
401 vec.sort_by(|a, b| {
403 b.priority
404 .cmp(&a.priority)
405 .then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
406 });
407 }
408 SchedulingPolicy::DeadlineAware => {
409 vec.sort_by(|a, b| {
411 let da = a.deadline_ns.unwrap_or(u64::MAX);
412 let db = b.deadline_ns.unwrap_or(u64::MAX);
413 da.cmp(&db).then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
414 });
415 }
416 SchedulingPolicy::Orca => {
417 vec.sort_by_key(|r| r.arrival_time_ns);
419 }
420 }
421 *queue = VecDeque::from(vec);
422 }
423}
424
/// Tracks how many tokens of a fixed budget are currently handed out.
#[derive(Debug)]
pub struct TokenBudgetAllocator {
    /// Total number of tokens that may be allocated at once.
    max_total_tokens: usize,
    /// Tokens currently allocated; never exceeds `max_total_tokens`.
    allocated: usize,
}

impl TokenBudgetAllocator {
    /// Creates an allocator with an empty budget of `max_total_tokens` tokens.
    pub fn new(max_total_tokens: usize) -> Self {
        Self {
            max_total_tokens,
            allocated: 0,
        }
    }

    /// Reserves `seq_len` tokens as one all-or-nothing allocation.
    ///
    /// Returns the allocation's starting offset (the number of tokens that
    /// were allocated before this call), or `None` when the request does not
    /// fit. Oversized requests, including ones whose sum with the current
    /// allocation would overflow `usize`, are rejected without changing state.
    pub fn allocate_prefill(&mut self, seq_len: usize) -> Option<usize> {
        // `checked_add` turns arithmetic overflow into a plain rejection.
        let end = self.allocated.checked_add(seq_len)?;
        if end > self.max_total_tokens {
            return None;
        }
        let offset = self.allocated;
        self.allocated = end;
        Some(offset)
    }

    /// Reserves up to `count` tokens and returns how many were actually granted.
    pub fn allocate_decode(&mut self, count: usize) -> usize {
        let remaining = self.max_total_tokens.saturating_sub(self.allocated);
        let granted = count.min(remaining);
        // `granted <= remaining`, so this addition cannot exceed the budget.
        self.allocated += granted;
        granted
    }

    /// Returns `tokens` to the budget; releasing more than is allocated
    /// clamps the allocation at zero.
    pub fn release(&mut self, tokens: usize) {
        self.allocated = self.allocated.saturating_sub(tokens);
    }

    /// Fraction of the budget in use, in `[0, 1]`; `0.0` for a zero-sized budget.
    pub fn utilization(&self) -> f64 {
        if self.max_total_tokens == 0 {
            return 0.0;
        }
        self.allocated as f64 / self.max_total_tokens as f64
    }
}
478
479#[derive(Debug)]
489pub struct PagedKvManager {
490 num_blocks: usize,
491 block_size: usize,
492 free_map: Vec<bool>,
494 ref_counts: Vec<usize>,
496}
497
498impl PagedKvManager {
499 pub fn new(num_blocks: usize, block_size: usize) -> Self {
502 Self {
503 num_blocks,
504 block_size,
505 free_map: vec![true; num_blocks],
506 ref_counts: vec![0; num_blocks],
507 }
508 }
509
510 pub fn allocate(&mut self, num_tokens: usize) -> DnnResult<Vec<usize>> {
515 if self.block_size == 0 {
516 return Err(DnnError::InvalidArgument("block_size is 0".into()));
517 }
518 let blocks_needed = num_tokens.div_ceil(self.block_size);
519 if !self.can_allocate(num_tokens) {
520 return Err(DnnError::InvalidArgument(format!(
521 "not enough free blocks: need {blocks_needed}, have {}",
522 self.free_block_count()
523 )));
524 }
525 let mut ids = Vec::with_capacity(blocks_needed);
526 for (i, free) in self.free_map.iter_mut().enumerate() {
527 if ids.len() >= blocks_needed {
528 break;
529 }
530 if *free {
531 *free = false;
532 self.ref_counts[i] = 1;
533 ids.push(i);
534 }
535 }
536 Ok(ids)
537 }
538
539 pub fn free(&mut self, block_ids: &[usize]) {
542 for &id in block_ids {
543 if id < self.num_blocks {
544 self.ref_counts[id] = self.ref_counts[id].saturating_sub(1);
545 if self.ref_counts[id] == 0 {
546 self.free_map[id] = true;
547 }
548 }
549 }
550 }
551
552 pub fn copy_on_write(&mut self, block_id: usize) -> DnnResult<usize> {
557 if block_id >= self.num_blocks {
558 return Err(DnnError::InvalidArgument(format!(
559 "block_id {block_id} out of range (max {})",
560 self.num_blocks
561 )));
562 }
563 let new_id =
565 self.free_map.iter().position(|&free| free).ok_or_else(|| {
566 DnnError::InvalidArgument("no free blocks for copy-on-write".into())
567 })?;
568 self.free_map[new_id] = false;
569 self.ref_counts[new_id] = 1;
570
571 self.ref_counts[block_id] = self.ref_counts[block_id].saturating_sub(1);
573 if self.ref_counts[block_id] == 0 {
574 self.free_map[block_id] = true;
575 }
576
577 Ok(new_id)
578 }
579
580 pub fn usage(&self) -> (usize, usize) {
582 let used = self.free_map.iter().filter(|&&free| !free).count();
583 (used, self.num_blocks)
584 }
585
586 pub fn can_allocate(&self, num_tokens: usize) -> bool {
588 if self.block_size == 0 {
589 return false;
590 }
591 let needed = num_tokens.div_ceil(self.block_size);
592 self.free_block_count() >= needed
593 }
594
595 fn free_block_count(&self) -> usize {
596 self.free_map.iter().filter(|&&f| f).count()
597 }
598}
599
/// Small deterministic 64-bit linear congruential generator.
///
/// The same seed always yields the same sequence. Not suitable for
/// cryptographic use.
#[derive(Debug, Clone)]
pub struct LcgRng {
    /// Current generator state; also the most recently produced value.
    state: u64,
}

impl LcgRng {
    /// LCG multiplier (the value used by Knuth's MMIX generator).
    const MUL: u64 = 6_364_136_223_846_793_005;
    /// LCG increment (the value used by Knuth's MMIX generator).
    const ADD: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator from `seed`.
    ///
    /// The seed is scrambled with a multiply-add so that small or nearby
    /// seeds do not start from nearby states.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed
                .wrapping_mul(0x9E37_79B9_7F4A_7C15)
                .wrapping_add(Self::ADD),
        }
    }

    /// Advances the generator and returns the new 64-bit state.
    #[inline]
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::ADD);
        self.state
    }

    /// Returns a uniform value in `[0, 1)` built from the top 53 bits of the
    /// next output (53 bits is the precision of an `f64` mantissa).
    #[inline]
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Samples an index with probability proportional to its weight.
    ///
    /// Negative weights count as zero, both when accumulating and when
    /// computing the normalising total, so the remaining entries are sampled
    /// in exact proportion to their weights. Weights need not sum to one.
    ///
    /// Returns `None`, without consuming a random draw, when `weights` is
    /// empty, contains a NaN or infinite entry, or has no positive mass.
    pub fn sample_categorical(&mut self, weights: &[f64]) -> Option<usize> {
        if weights.iter().any(|w| !w.is_finite()) {
            return None;
        }
        // The total uses the same clamping as the accumulation below;
        // otherwise negative entries would shrink the threshold range and
        // bias the draw toward earlier indices.
        let total: f64 = weights.iter().map(|w| w.max(0.0)).sum();
        if !total.is_finite() || total <= 0.0 {
            return None;
        }

        let threshold = self.next_f64() * total;
        let mut acc = 0.0;
        for (idx, &w) in weights.iter().enumerate() {
            acc += w.max(0.0);
            if threshold < acc {
                return Some(idx);
            }
        }
        // Floating-point rounding can leave the threshold just above the
        // final running sum; fall back to the last entry with positive mass.
        weights.iter().rposition(|&w| w > 0.0)
    }
}
674
/// Outcome of one speculative verification round.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpeculativeResult {
    /// Tokens emitted this round: the accepted drafted tokens, in order,
    /// followed by exactly one extra token (a correction sampled after a
    /// rejection, or a bonus token when every draft was accepted).
    pub tokens: Vec<u32>,
    /// Number of drafted tokens accepted before the round ended.
    pub accepted: usize,
    /// `1` when the round ended on a rejected draft, `0` otherwise.
    pub rejected: usize,
}
692
/// Speculative-decoding helper: samples draft tokens, verifies them against
/// target distributions by rejection sampling, and keeps running statistics.
///
/// All randomness comes from the embedded [`LcgRng`], so results are
/// reproducible for a given seed and call sequence.
#[derive(Debug)]
pub struct SpeculativeDecoder {
    /// Maximum number of tokens drafted per `propose_tokens` call.
    draft_length: usize,
    /// Deterministic random source for sampling and accept/reject draws.
    rng: LcgRng,
    /// Total drafted tokens submitted to completed verification rounds.
    total_proposed: u64,
    /// Total drafted tokens accepted across completed rounds.
    total_accepted: u64,
    /// Number of verification rounds that completed successfully.
    rounds: u64,
}

impl SpeculativeDecoder {
    /// Seed used by [`SpeculativeDecoder::new`] (the bytes spell "SPEC" in ASCII).
    const DEFAULT_SEED: u64 = 0x5350_4543;

    /// Creates a decoder with the default seed.
    #[must_use]
    pub fn new(draft_length: usize) -> Self {
        Self::with_seed(draft_length, Self::DEFAULT_SEED)
    }

    /// Creates a decoder whose random stream is determined by `seed`.
    #[must_use]
    pub fn with_seed(draft_length: usize, seed: u64) -> Self {
        Self {
            draft_length,
            rng: LcgRng::new(seed),
            total_proposed: 0,
            total_accepted: 0,
            rounds: 0,
        }
    }

    /// Maximum number of tokens drafted per round.
    #[must_use]
    pub fn draft_length(&self) -> usize {
        self.draft_length
    }

    /// Samples one token from each of the first `draft_length` rows of
    /// `draft_probs` (fewer if fewer rows are supplied).
    ///
    /// Each returned [`DraftedToken`] carries the sampled index and that
    /// index's probability after clamping negative entries to zero and
    /// normalising the row.
    ///
    /// # Errors
    ///
    /// Returns `DnnError::InvalidArgument` if a row cannot be sampled from.
    pub fn propose_tokens(&mut self, draft_probs: &[Vec<f64>]) -> DnnResult<Vec<DraftedToken>> {
        let count = draft_probs.len().min(self.draft_length);
        let mut drafted = Vec::with_capacity(count);
        for (position, dist) in draft_probs.iter().take(count).enumerate() {
            let token = self.rng.sample_categorical(dist).ok_or_else(|| {
                DnnError::InvalidArgument(format!(
                    "draft distribution at position {position} has no positive, finite mass"
                ))
            })?;
            // Normalise with negative entries treated as zero mass.
            let total: f64 = dist.iter().map(|p| p.max(0.0)).sum();
            let draft_prob = dist[token].max(0.0) / total;
            drafted.push(DraftedToken {
                token_id: token as u32,
                draft_prob,
            });
        }
        Ok(drafted)
    }

    /// Verifies drafted tokens against the target distributions.
    ///
    /// `target_dists[i]` is the target distribution for drafted position `i`;
    /// one additional row (index `drafted.len()`) supplies the bonus token.
    /// For each drafted token, in order, the token is accepted with
    /// probability `min(1, p_target / draft_prob)` (never, if `draft_prob`
    /// is not positive). At the first rejection a correction token is sampled
    /// from the residual distribution and the round ends. If every draft is
    /// accepted, a bonus token is sampled from the extra row.
    ///
    /// Statistics are updated only when the round completes successfully.
    ///
    /// # Errors
    ///
    /// Returns `DnnError::InvalidArgument` when `target_dists` has fewer than
    /// `drafted.len() + 1` rows, when a drafted token id is outside its target
    /// row, or when the residual or bonus distribution cannot be sampled from.
    pub fn verify_and_accept(
        &mut self,
        drafted: &[DraftedToken],
        target_dists: &[Vec<f64>],
    ) -> DnnResult<SpeculativeResult> {
        let gamma = drafted.len();
        if target_dists.len() <= gamma {
            return Err(DnnError::InvalidArgument(format!(
                "target_dists must have at least {} rows (one per drafted token \
                 plus a bonus row), got {}",
                gamma + 1,
                target_dists.len(),
            )));
        }

        let mut tokens = Vec::with_capacity(gamma + 1);
        for (i, draft) in drafted.iter().enumerate() {
            let token = draft.token_id as usize;
            let target_dist = &target_dists[i];
            let target_total: f64 = target_dist.iter().map(|p| p.max(0.0)).sum();
            let p_target = target_dist
                .get(token)
                .copied()
                .ok_or_else(|| {
                    DnnError::InvalidArgument(format!(
                        "target distribution at position {i} (len {}) does not \
                         contain drafted token id {token}",
                        target_dist.len(),
                    ))
                })?
                .max(0.0);
            // Normalised target probability; a row with no mass gives zero.
            let p_target = if target_total > 0.0 {
                p_target / target_total
            } else {
                0.0
            };

            // A draft probability of zero (or less) forces a rejection.
            let accept_ratio = if draft.draft_prob > 0.0 {
                (p_target / draft.draft_prob).min(1.0)
            } else {
                0.0
            };

            // One uniform draw per examined draft; `r` is in [0, 1), so a
            // ratio of 1.0 always accepts.
            let r = self.rng.next_f64();
            if r < accept_ratio {
                tokens.push(draft.token_id);
                continue;
            }

            // Rejected: replace the draft with a sample from the residual.
            let residual = Self::residual_distribution(target_dist, drafted, i);
            let correction = self.rng.sample_categorical(&residual).ok_or_else(|| {
                DnnError::InvalidArgument(format!(
                    "residual distribution at position {i} has no positive mass"
                ))
            })?;
            tokens.push(correction as u32);

            // Exactly `i` drafts were accepted before this rejection.
            let accepted = i;
            self.record(gamma, accepted);
            return Ok(SpeculativeResult {
                tokens,
                accepted,
                rejected: 1,
            });
        }

        // Every draft was accepted: sample the bonus token from the extra row.
        let bonus_dist = &target_dists[gamma];
        let bonus = self.rng.sample_categorical(bonus_dist).ok_or_else(|| {
            DnnError::InvalidArgument(
                "bonus target distribution has no positive, finite mass".into(),
            )
        })?;
        tokens.push(bonus as u32);

        self.record(gamma, gamma);
        Ok(SpeculativeResult {
            tokens,
            accepted: gamma,
            rejected: 0,
        })
    }

    /// Builds the unnormalised distribution used to sample a correction token
    /// after the draft at `position` was rejected.
    ///
    /// Each entry is `max(0, p_target - p_draft)`, where `p_target` is the
    /// normalised target probability and `p_draft` is the drafted token's
    /// `draft_prob` at the drafted index and zero everywhere else. If that
    /// leaves no mass, the clamped target row is returned instead.
    ///
    /// NOTE(review): only the drafted token's draft probability is known
    /// here, so all other indices keep their full target mass.
    fn residual_distribution(
        target_dist: &[f64],
        drafted: &[DraftedToken],
        position: usize,
    ) -> Vec<f64> {
        let target_total: f64 = target_dist.iter().map(|p| p.max(0.0)).sum();
        let drafted_token = drafted[position].token_id as usize;
        let draft_prob = drafted[position].draft_prob;

        let mut residual: Vec<f64> = Vec::with_capacity(target_dist.len());
        for (idx, &t) in target_dist.iter().enumerate() {
            let p_target = if target_total > 0.0 {
                t.max(0.0) / target_total
            } else {
                0.0
            };
            let p_draft = if idx == drafted_token {
                draft_prob.max(0.0)
            } else {
                0.0
            };
            residual.push((p_target - p_draft).max(0.0));
        }

        let residual_sum: f64 = residual.iter().sum();
        if residual_sum <= 0.0 {
            // No residual mass left: fall back to the clamped target row.
            return target_dist.iter().map(|p| p.max(0.0)).collect();
        }
        residual
    }

    /// Adds one completed round to the running statistics.
    fn record(&mut self, proposed: usize, accepted: usize) {
        self.total_proposed += proposed as u64;
        self.total_accepted += accepted as u64;
        self.rounds += 1;
    }

    /// Fraction of drafted tokens accepted so far; `0.0` before any were proposed.
    #[must_use]
    pub fn acceptance_rate(&self) -> f64 {
        if self.total_proposed == 0 {
            return 0.0;
        }
        self.total_accepted as f64 / self.total_proposed as f64
    }

    /// Total drafted tokens submitted to completed rounds.
    #[must_use]
    pub fn total_proposed(&self) -> u64 {
        self.total_proposed
    }

    /// Total drafted tokens accepted in completed rounds.
    #[must_use]
    pub fn total_accepted(&self) -> u64 {
        self.total_accepted
    }

    /// Number of completed verification rounds.
    #[must_use]
    pub fn rounds(&self) -> u64 {
        self.rounds
    }

    /// Average tokens emitted per round: accepted drafts plus the one
    /// correction/bonus token each round always adds. `0.0` before any round.
    #[must_use]
    pub fn mean_tokens_per_round(&self) -> f64 {
        if self.rounds == 0 {
            return 0.0;
        }
        (self.total_accepted + self.rounds) as f64 / self.rounds as f64
    }
}
1018
/// A token proposed by the draft model together with its draft probability.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DraftedToken {
    /// Index of the token within its distribution row.
    pub token_id: u32,
    /// Probability of `token_id` under the draft distribution; the acceptance
    /// test divides the target probability by this value.
    pub draft_prob: f64,
}
1029
/// Accumulates per-step batch measurements and time-to-first-token samples.
#[derive(Debug)]
pub struct BatchMetrics {
    /// One `(prefill_tokens, decode_tokens, latency_us)` entry per recorded step.
    steps: Vec<(usize, usize, u64)>,
    /// Time-to-first-token samples in microseconds, in recording order.
    ttft_samples: Vec<u64>,
}

impl BatchMetrics {
    /// Creates an empty metrics collector.
    pub fn new() -> Self {
        Self {
            steps: Vec::new(),
            ttft_samples: Vec::new(),
        }
    }

    /// Records one batch step with its token counts and latency in microseconds.
    pub fn record_step(&mut self, prefill_tokens: usize, decode_tokens: usize, latency_us: u64) {
        self.steps.push((prefill_tokens, decode_tokens, latency_us));
    }

    /// Records one time-to-first-token sample in microseconds.
    pub fn record_ttft(&mut self, ttft_us: u64) {
        self.ttft_samples.push(ttft_us);
    }

    /// Mean latency (µs) over steps that processed at least one prefill
    /// token; `0.0` if there were none.
    pub fn avg_prefill_latency(&self) -> f64 {
        Self::mean_latency(
            self.steps
                .iter()
                .filter(|(prefill, _, _)| *prefill > 0)
                .map(|(_, _, latency)| *latency),
        )
    }

    /// Mean latency (µs) over steps that processed at least one decode
    /// token; `0.0` if there were none.
    pub fn avg_decode_latency(&self) -> f64 {
        Self::mean_latency(
            self.steps
                .iter()
                .filter(|(_, decode, _)| *decode > 0)
                .map(|(_, _, latency)| *latency),
        )
    }

    /// Mean number of tokens (prefill + decode) per recorded step; `0.0`
    /// when no steps were recorded.
    pub fn avg_batch_size(&self) -> f64 {
        if self.steps.is_empty() {
            return 0.0;
        }
        self.total_tokens() as f64 / self.steps.len() as f64
    }

    /// Tokens processed per second of recorded latency; `0.0` when there are
    /// no steps or the total latency is zero.
    pub fn token_throughput(&self) -> f64 {
        if self.steps.is_empty() {
            return 0.0;
        }
        let total_us: u128 = self.steps.iter().map(|(_, _, l)| u128::from(*l)).sum();
        if total_us == 0 {
            return 0.0;
        }
        self.total_tokens() as f64 / (total_us as f64 / 1_000_000.0)
    }

    /// Median time-to-first-token in microseconds; `0.0` with no samples.
    ///
    /// For an even number of samples the two middle values are averaged.
    pub fn time_to_first_token_p50(&self) -> f64 {
        if self.ttft_samples.is_empty() {
            return 0.0;
        }
        let mut sorted = self.ttft_samples.clone();
        sorted.sort_unstable();
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            // Widen before adding: two large `u64` samples would overflow.
            (u128::from(sorted[mid - 1]) + u128::from(sorted[mid])) as f64 / 2.0
        } else {
            sorted[mid] as f64
        }
    }

    /// Renders all metrics as a multi-line, human-readable report.
    pub fn format_report(&self) -> String {
        format!(
            "BatchMetrics Report\n\
             ====================\n\
             Steps recorded : {}\n\
             Avg prefill latency : {:.1} us\n\
             Avg decode latency : {:.1} us\n\
             Avg batch size : {:.1} tokens/step\n\
             Token throughput : {:.0} tokens/s\n\
             TTFT p50 : {:.1} us\n\
             TTFT samples : {}",
            self.steps.len(),
            self.avg_prefill_latency(),
            self.avg_decode_latency(),
            self.avg_batch_size(),
            self.token_throughput(),
            self.time_to_first_token_p50(),
            self.ttft_samples.len(),
        )
    }

    /// Total tokens across all recorded steps, accumulated without overflow.
    fn total_tokens(&self) -> u128 {
        self.steps
            .iter()
            .map(|(prefill, decode, _)| *prefill as u128 + *decode as u128)
            .sum()
    }

    /// Arithmetic mean of `latencies`, or `0.0` for an empty sequence.
    /// Computed in a single pass without a temporary buffer.
    fn mean_latency<I: Iterator<Item = u64>>(latencies: I) -> f64 {
        let (count, total) = latencies.fold((0u64, 0u128), |(n, sum), latency| {
            (n + 1, sum + u128::from(latency))
        });
        if count == 0 {
            0.0
        } else {
            total as f64 / count as f64
        }
    }
}

impl Default for BatchMetrics {
    /// Equivalent to [`BatchMetrics::new`].
    fn default() -> Self {
        Self::new()
    }
}
1155
1156#[cfg(test)]
1161mod tests {
1162 use super::*;
1163
1164 fn default_config() -> BatchConfig {
1165 BatchConfig {
1166 max_batch_size: 8,
1167 max_total_tokens: 4096,
1168 max_sequence_length: 2048,
1169 prefill_batch_size: 1024,
1170 decode_batch_size: 8,
1171 scheduling_policy: SchedulingPolicy::Fcfs,
1172 }
1173 }
1174
1175 fn make_request(id: RequestId, seq_len: usize, max_new: usize) -> InferenceRequest {
1176 InferenceRequest {
1177 request_id: id,
1178 sequence_length: seq_len,
1179 max_new_tokens: max_new,
1180 priority: Priority::Normal,
1181 arrival_time_ns: id * 1000,
1182 deadline_ns: None,
1183 }
1184 }
1185
1186 #[test]
1188 fn test_add_single_request() {
1189 let mut batcher = ContinuousBatcher::new(default_config());
1190 let req = make_request(1, 128, 64);
1191 let id = batcher.add_request(req).expect("should succeed");
1192 assert_eq!(id, 1);
1193 assert_eq!(batcher.pending_requests(), 1);
1194 assert_eq!(batcher.active_requests(), 0);
1195 }
1196
1197 #[test]
1199 fn test_batch_step_mixed_prefill_decode() {
1200 let mut batcher = ContinuousBatcher::new(default_config());
1201 batcher.add_request(make_request(1, 64, 32)).expect("add 1");
1203 let d1 = batcher.step().expect("step 1");
1204 assert_eq!(d1.prefill_requests.len(), 1);
1205
1206 batcher.add_request(make_request(2, 32, 16)).expect("add 2");
1208 let d2 = batcher.step().expect("step 2");
1209 assert!(!d2.decode_requests.is_empty(), "should have decode slots");
1210 assert!(!d2.prefill_requests.is_empty(), "should have prefill slots");
1211 }
1212
1213 #[test]
1215 fn test_token_budget_allocation_release() {
1216 let mut alloc = TokenBudgetAllocator::new(1024);
1217 let slot = alloc.allocate_prefill(512);
1218 assert!(slot.is_some());
1219 assert!((alloc.utilization() - 0.5).abs() < 1e-9);
1220
1221 assert!(alloc.allocate_prefill(600).is_none());
1223
1224 alloc.release(256);
1225 assert!((alloc.utilization() - 0.25).abs() < 1e-9);
1226 }
1227
1228 #[test]
1230 fn test_paged_kv_allocation_free() {
1231 let mut mgr = PagedKvManager::new(16, 64);
1232 let blocks = mgr.allocate(128).expect("allocate 128");
1233 assert_eq!(blocks.len(), 2);
1234 let (used, total) = mgr.usage();
1235 assert_eq!(used, 2);
1236 assert_eq!(total, 16);
1237
1238 mgr.free(&blocks);
1239 let (used, _) = mgr.usage();
1240 assert_eq!(used, 0);
1241 }
1242
1243 #[test]
1245 fn test_copy_on_write() {
1246 let mut mgr = PagedKvManager::new(4, 64);
1247 let blocks = mgr.allocate(64).expect("allocate");
1248 assert_eq!(blocks.len(), 1);
1249 let orig = blocks[0];
1250
1251 mgr.ref_counts[orig] = 2;
1253
1254 let new_id = mgr.copy_on_write(orig).expect("cow");
1255 assert_ne!(new_id, orig);
1256 assert!(!mgr.free_map[orig]);
1258 assert_eq!(mgr.ref_counts[orig], 1);
1259 assert_eq!(mgr.ref_counts[new_id], 1);
1260 }
1261
1262 #[test]
1264 fn test_continuous_batching_completion() {
1265 let mut batcher = ContinuousBatcher::new(default_config());
1266 batcher.add_request(make_request(10, 64, 8)).expect("add");
1267 let _ = batcher.step().expect("step");
1268 assert_eq!(batcher.active_requests(), 1);
1269
1270 batcher.complete_request(10).expect("complete");
1271 assert_eq!(batcher.active_requests(), 0);
1272 }
1273
1274 #[test]
1276 fn test_preemption() {
1277 let mut batcher = ContinuousBatcher::new(default_config());
1278 batcher.add_request(make_request(20, 64, 16)).expect("add");
1279 let _ = batcher.step().expect("step");
1280 assert_eq!(batcher.active_requests(), 1);
1281
1282 batcher.preempt(20).expect("preempt");
1283 assert_eq!(batcher.active_requests(), 0);
1284 assert_eq!(batcher.pending_requests(), 1);
1286 }
1287
1288 #[test]
1290 fn test_fcfs_scheduling_order() {
1291 let mut batcher = ContinuousBatcher::new(default_config());
1292 batcher.add_request(make_request(3, 32, 8)).expect("add 3");
1293 batcher.add_request(make_request(1, 32, 8)).expect("add 1");
1294 batcher.add_request(make_request(2, 32, 8)).expect("add 2");
1295 let d = batcher.step().expect("step");
1297 assert_eq!(d.prefill_requests, vec![1, 2, 3]);
1298 }
1299
1300 #[test]
1302 fn test_priority_based_scheduling() {
1303 let mut config = default_config();
1304 config.scheduling_policy = SchedulingPolicy::PriorityBased;
1305 let mut batcher = ContinuousBatcher::new(config);
1306
1307 let mut low = make_request(1, 32, 8);
1308 low.priority = Priority::Low;
1309 low.arrival_time_ns = 100;
1310 let mut high = make_request(2, 32, 8);
1311 high.priority = Priority::High;
1312 high.arrival_time_ns = 200;
1313 let mut normal = make_request(3, 32, 8);
1314 normal.priority = Priority::Normal;
1315 normal.arrival_time_ns = 50;
1316
1317 batcher.add_request(low).expect("add low");
1318 batcher.add_request(high).expect("add high");
1319 batcher.add_request(normal).expect("add normal");
1320
1321 let d = batcher.step().expect("step");
1322 assert_eq!(d.prefill_requests, vec![2, 3, 1]);
1324 }
1325
1326 #[test]
1328 fn test_deadline_aware_scheduling() {
1329 let mut config = default_config();
1330 config.scheduling_policy = SchedulingPolicy::DeadlineAware;
1331 let mut batcher = ContinuousBatcher::new(config);
1332
1333 let mut r1 = make_request(1, 32, 8);
1334 r1.deadline_ns = Some(5000);
1335 let mut r2 = make_request(2, 32, 8);
1336 r2.deadline_ns = Some(1000);
1337 let mut r3 = make_request(3, 32, 8);
1338 r3.deadline_ns = None; batcher.add_request(r1).expect("add r1");
1341 batcher.add_request(r2).expect("add r2");
1342 batcher.add_request(r3).expect("add r3");
1343
1344 let d = batcher.step().expect("step");
1345 assert_eq!(d.prefill_requests, vec![2, 1, 3]);
1346 }
1347
1348 #[test]
1350 fn test_speculative_decoding_propose_samples_draft() {
1351 let mut spec = SpeculativeDecoder::with_seed(3, 12345);
1352 let draft_probs = vec![
1356 vec![0.0, 0.0, 1.0, 0.0],
1357 vec![1.0, 0.0, 0.0, 0.0],
1358 vec![0.0, 0.0, 0.0, 1.0],
1359 ];
1360 let drafted = spec.propose_tokens(&draft_probs).expect("propose");
1361 assert_eq!(drafted.len(), 3);
1362 assert_eq!(drafted[0].token_id, 2);
1363 assert_eq!(drafted[1].token_id, 0);
1364 assert_eq!(drafted[2].token_id, 3);
1365 for d in &drafted {
1367 assert!((d.draft_prob - 1.0).abs() < 1e-12);
1368 }
1369 }
1370
1371 #[test]
1373 fn test_speculative_decoding_propose_caps_and_normalises() {
1374 let mut spec = SpeculativeDecoder::with_seed(2, 99);
1375 let draft_probs = vec![
1378 vec![1.0, 3.0],
1379 vec![3.0, 1.0],
1380 vec![1.0, 0.0],
1381 vec![0.0, 1.0],
1382 ];
1383 let drafted = spec.propose_tokens(&draft_probs).expect("propose");
1384 assert_eq!(drafted.len(), 2, "draft_length caps the count");
1385 for d in &drafted {
1386 assert!((0.0..=1.0).contains(&d.draft_prob));
1388 let p = d.draft_prob;
1391 assert!(
1392 (p - 0.25).abs() < 1e-12 || (p - 0.75).abs() < 1e-12,
1393 "unexpected normalised prob {p}"
1394 );
1395 }
1396 }
1397
1398 #[test]
1400 fn test_speculative_decoding_propose_rejects_zero_dist() {
1401 let mut spec = SpeculativeDecoder::new(2);
1402 let draft_probs = vec![vec![0.0, 0.0, 0.0]];
1403 assert!(spec.propose_tokens(&draft_probs).is_err());
1404 }
1405
1406 #[test]
1408 fn test_categorical_sampling_matches_distribution() {
1409 let mut rng = LcgRng::new(0x00C0_FFEE);
1410 let weights = [0.1_f64, 0.2, 0.3, 0.4];
1412 let trials = 200_000;
1413 let mut counts = [0u64; 4];
1414 for _ in 0..trials {
1415 let idx = rng.sample_categorical(&weights).expect("sample");
1416 counts[idx] += 1;
1417 }
1418 for (i, &w) in weights.iter().enumerate() {
1419 let freq = counts[i] as f64 / trials as f64;
1420 assert!(
1421 (freq - w).abs() < 0.01,
1422 "category {i}: freq {freq} vs expected {w}"
1423 );
1424 }
1425 }
1426
1427 #[test]
1429 fn test_rejection_sampling_acceptance_probability() {
1430 let trials = 100_000;
1434 let mut accepted_rounds = 0u64;
1435 for seed in 0..trials {
1436 let mut spec = SpeculativeDecoder::with_seed(1, seed);
1437 let drafted = vec![DraftedToken {
1438 token_id: 0,
1439 draft_prob: 0.8,
1440 }];
1441 let target = vec![vec![0.4, 0.6], vec![0.5, 0.5]];
1444 let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1445 if res.accepted == 1 {
1446 accepted_rounds += 1;
1447 }
1448 }
1449 let rate = accepted_rounds as f64 / trials as f64;
1450 assert!(
1451 (rate - 0.5).abs() < 0.01,
1452 "acceptance rate {rate} should be ~0.5"
1453 );
1454 }
1455
1456 #[test]
1458 fn test_rejection_sampling_always_accepts_when_target_ge_draft() {
1459 for seed in 0..2000 {
1460 let mut spec = SpeculativeDecoder::with_seed(1, seed);
1461 let drafted = vec![DraftedToken {
1462 token_id: 0,
1463 draft_prob: 0.3,
1464 }];
1465 let target = vec![vec![0.6, 0.4], vec![0.5, 0.5]];
1467 let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1468 assert_eq!(res.accepted, 1, "ratio >= 1 must always accept");
1469 assert_eq!(res.rejected, 0);
1470 }
1471 }
1472
1473 #[test]
1475 fn test_rejection_sampling_rejects_zero_draft_prob() {
1476 let mut spec = SpeculativeDecoder::with_seed(1, 7);
1477 let drafted = vec![DraftedToken {
1478 token_id: 0,
1479 draft_prob: 0.0,
1480 }];
1481 let target = vec![vec![0.9, 0.1], vec![0.5, 0.5]];
1482 let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1483 assert_eq!(res.accepted, 0);
1484 assert_eq!(res.rejected, 1);
1485 assert_eq!(res.tokens.len(), 1);
1487 }
1488
1489 #[test]
1491 fn test_residual_distribution_resampling() {
1492 let trials = 100_000;
1503 let mut counts = [0u64; 3];
1504 for seed in 0..trials {
1505 let mut spec = SpeculativeDecoder::with_seed(1, seed);
1506 let drafted = vec![DraftedToken {
1507 token_id: 0,
1508 draft_prob: 1.0,
1509 }];
1510 let target = vec![vec![0.0, 0.5, 0.5], vec![1.0, 0.0, 0.0]];
1511 let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1512 assert_eq!(res.accepted, 0, "must reject");
1513 let corr = res.tokens[0] as usize;
1514 counts[corr] += 1;
1515 }
1516 let total = trials as f64;
1517 assert_eq!(counts[0], 0, "token 0 has zero residual mass");
1518 assert!((counts[1] as f64 / total - 0.5).abs() < 0.01);
1519 assert!((counts[2] as f64 / total - 0.5).abs() < 0.01);
1520 }
1521
1522 #[test]
1525 fn test_residual_distribution_concentrated() {
1526 for seed in 0..1000 {
1533 let mut spec = SpeculativeDecoder::with_seed(1, seed);
1534 let drafted = vec![DraftedToken {
1535 token_id: 1,
1536 draft_prob: 1.0,
1537 }];
1538 let target = vec![vec![1.0, 0.0], vec![0.5, 0.5]];
1539 let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1540 assert_eq!(res.accepted, 0);
1541 assert_eq!(res.tokens[0], 0, "residual concentrates on token 0");
1542 }
1543 }
1544
1545 #[test]
1547 fn test_residual_distribution_zero_fallback() {
1548 let drafted = [DraftedToken {
1554 token_id: 0,
1555 draft_prob: 1.0,
1556 }];
1557 let target_dist = [1.0_f64, 0.0, 0.0];
1558 let residual = SpeculativeDecoder::residual_distribution(&target_dist, &drafted, 0);
1559 assert_eq!(residual, vec![1.0, 0.0, 0.0]);
1562 }
1563
/// When the draft and target distributions are identical at every position,
/// all `gamma` drafted tokens must be accepted, none rejected, and the round
/// must emit `gamma + 1` tokens. Checked for 3000 different seeds.
#[test]
fn test_speculative_draft_equals_target_accepts_all() {
    let gamma = 5;

    // One shared distribution: `gamma` copies for the draft model and
    // `gamma + 1` copies for the target model.
    let dist = vec![0.15, 0.25, 0.20, 0.40];
    let draft_probs = vec![dist.clone(); gamma];
    let target_dists = vec![dist; gamma + 1];

    for seed in 0..3000 {
        let mut decoder = SpeculativeDecoder::with_seed(gamma, seed);

        let drafted = decoder.propose_tokens(&draft_probs).expect("propose");
        assert_eq!(drafted.len(), gamma);

        let outcome = decoder
            .verify_and_accept(&drafted, &target_dists)
            .expect("verify");
        assert_eq!(outcome.accepted, gamma, "draft==target must accept all");
        assert_eq!(outcome.rejected, 0);
        assert_eq!(outcome.tokens.len(), gamma + 1);
    }
}
1590
/// Runs 5000 rounds on a single seeded decoder with a draft that always
/// proposes token 0 (probability 1.0) against a target of `[0.7, 0.3]`.
///
/// Per round it checks the accept/reject bookkeeping; afterwards it checks
/// the decoder's cumulative counters, an acceptance rate within 0.02 of
/// 0.4433, and a mean tokens-per-round inside `[1, gamma + 1]`.
///
/// NOTE(review): 0.4433 presumably comes from
/// (0.7 + 0.7^2 + 0.7^3 + 0.7^4) / 4 under a min(1, p/q) acceptance rule;
/// confirm against the decoder implementation.
#[test]
fn test_speculative_accepted_length_distribution() {
    let gamma = 4usize;
    let rounds = 5000u64;

    // The distributions are the same every round; only the RNG state advances.
    let draft_probs = vec![vec![1.0, 0.0]; gamma];
    let target_dists = vec![vec![0.7, 0.3]; gamma + 1];

    let mut spec = SpeculativeDecoder::with_seed(gamma, 0xABCD);
    let mut sum_accepted = 0u64;
    for _ in 0..rounds {
        let drafted = spec.propose_tokens(&draft_probs).expect("propose");
        let outcome = spec
            .verify_and_accept(&drafted, &target_dists)
            .expect("verify");

        assert!(outcome.accepted <= gamma, "accepted within [0, gamma]");
        // Exactly one rejection is reported unless the whole draft was accepted.
        let stopped_early = outcome.accepted < gamma;
        assert_eq!(outcome.rejected, usize::from(stopped_early));
        // Accepted tokens plus one extra emitted token.
        assert_eq!(outcome.tokens.len(), outcome.accepted + 1);

        sum_accepted += outcome.accepted as u64;
    }

    // Cumulative counters must agree with what the loop observed.
    assert_eq!(spec.total_proposed(), rounds * gamma as u64);
    assert_eq!(spec.total_accepted(), sum_accepted);
    assert_eq!(spec.rounds(), rounds);

    let rate = spec.acceptance_rate();
    assert!(
        (rate - 0.4433).abs() < 0.02,
        "acceptance rate {rate} should be ~0.4433"
    );

    let mtpr = spec.mean_tokens_per_round();
    let upper = (gamma + 1) as f64;
    assert!((1.0..=upper).contains(&mtpr), "mtpr {mtpr}");
}
1631
/// Two drafted tokens verified against only two target distributions must
/// make `verify_and_accept` return an error.
///
/// NOTE(review): presumably the verifier requires one more target
/// distribution than drafted tokens; confirm against the implementation.
#[test]
fn test_speculative_verify_rejects_short_target() {
    let drafted = vec![
        DraftedToken {
            token_id: 0,
            draft_prob: 0.5,
        },
        DraftedToken {
            token_id: 1,
            draft_prob: 0.5,
        },
    ];
    // Same length as the draft.
    let target = vec![vec![0.5, 0.5]; 2];

    let mut decoder = SpeculativeDecoder::new(2);
    let outcome = decoder.verify_and_accept(&drafted, &target);
    assert!(outcome.is_err());
}
1650
/// A drafted token id of 9 checked against two-entry target distributions
/// must make `verify_and_accept` return an error.
#[test]
fn test_speculative_verify_rejects_token_out_of_range() {
    // Token id 9 has no entry in a two-element distribution.
    let drafted = vec![DraftedToken {
        token_id: 9,
        draft_prob: 0.5,
    }];
    let target = vec![vec![0.5, 0.5]; 2];

    let mut decoder = SpeculativeDecoder::new(1);
    let outcome = decoder.verify_and_accept(&drafted, &target);
    assert!(outcome.is_err());
}
1662
/// Verifying an empty draft against a single target distribution that is
/// one-hot on token 1 must report zero accepted and zero rejected tokens and
/// emit exactly `[1]`.
#[test]
fn test_speculative_empty_draft_emits_bonus() {
    let no_draft: Vec<DraftedToken> = Vec::new();
    let target = vec![vec![0.0, 1.0, 0.0]];

    let mut decoder = SpeculativeDecoder::with_seed(4, 55);
    let outcome = decoder
        .verify_and_accept(&no_draft, &target)
        .expect("verify");

    assert_eq!(outcome.accepted, 0);
    assert_eq!(outcome.rejected, 0);
    assert_eq!(outcome.tokens, vec![1], "bonus drawn from one-hot target");
}
1675
/// Two `LcgRng` instances built from the same seed must produce identical
/// `next_f64` streams; every value must lie in `[0, 1)`, and the mean of
/// 100_000 draws must be within 0.01 of 0.5.
#[test]
fn test_lcg_rng_uniform_and_deterministic() {
    let n = 100_000;
    let mut first = LcgRng::new(2024);
    let mut second = LcgRng::new(2024);

    let mut running_total = 0.0_f64;
    for _ in 0..n {
        let (va, vb) = (first.next_f64(), second.next_f64());
        // Exact equality is intended: the streams must match bit for bit.
        assert_eq!(va, vb, "same seed must yield same stream");
        assert!((0.0..1.0).contains(&va));
        running_total += va;
    }

    let mean = running_total / n as f64;
    assert!((mean - 0.5).abs() < 0.01, "uniform mean {mean}");
}
1694
/// Records three steps and checks the aggregated averages reported by
/// `BatchMetrics`: prefill latency 400, decode latency 200, batch size 68,
/// plus a strictly positive token throughput.
///
/// NOTE(review): the expected values look like means over the steps that had
/// prefill tokens, decode tokens, and any tokens respectively; confirm against
/// `record_step`.
#[test]
fn test_batch_metrics_tracking() {
    let mut metrics = BatchMetrics::new();
    // Each tuple is passed to `record_step` in argument order.
    for (first, second, third) in [(128, 0, 500), (0, 8, 100), (64, 4, 300)] {
        metrics.record_step(first, second, third);
    }

    let prefill_latency = metrics.avg_prefill_latency();
    assert!((prefill_latency - 400.0).abs() < 1e-9);

    let decode_latency = metrics.avg_decode_latency();
    assert!((decode_latency - 200.0).abs() < 1e-9);

    let batch_size = metrics.avg_batch_size();
    assert!((batch_size - 68.0).abs() < 1e-9);

    assert!(metrics.token_throughput() > 0.0);
}
1709
/// With `max_batch_size` set to 2 and four queued requests, a single `step`
/// must schedule at most two prefill requests, and the active-request count
/// must equal the number scheduled.
#[test]
fn test_max_batch_size_enforcement() {
    let limit = 2;
    let mut config = default_config();
    config.max_batch_size = limit;
    let mut batcher = ContinuousBatcher::new(config);

    // Queue twice as many requests as the batch may hold.
    for id in 0..4 {
        let request = make_request(id, 32, 8);
        batcher.add_request(request).expect("add");
    }

    let decision = batcher.step().expect("step");
    let scheduled = decision.prefill_requests.len();
    assert!(scheduled <= limit);
    assert_eq!(batcher.active_requests(), scheduled);
}
1724
/// Follows the pending and active counters through a short lifecycle:
/// empty batcher, two requests queued, one `step`, then completing one
/// request.
#[test]
fn test_queue_management() {
    let mut batcher = ContinuousBatcher::new(default_config());
    assert_eq!(batcher.pending_requests(), 0);

    // Queue two requests; both are pending until the next step.
    for id in [1, 2] {
        batcher.add_request(make_request(id, 32, 8)).expect("add");
    }
    assert_eq!(batcher.pending_requests(), 2);

    // After one step both requests are active and none are pending.
    let _ = batcher.step().expect("step");
    assert_eq!(batcher.pending_requests(), 0);
    assert_eq!(batcher.active_requests(), 2);

    // Completing one request leaves one active.
    batcher.complete_request(1).expect("complete");
    assert_eq!(batcher.active_requests(), 1);
}
1742
/// Checks `TokenBudgetAllocator::utilization` on a budget of 1000: 0.0 when
/// fresh, 0.25 after a 250-token prefill allocation, and 1.0 after a
/// 900-token decode request that is asserted to be granted only 750.
/// A zero-budget allocator must also report 0.0.
#[test]
fn test_utilization_calculation() {
    let mut allocator = TokenBudgetAllocator::new(1000);
    assert!(allocator.utilization().abs() < 1e-9);

    allocator.allocate_prefill(250);
    assert!((allocator.utilization() - 0.25).abs() < 1e-9);

    // Asking for more than what is left yields only the remainder.
    let granted = allocator.allocate_decode(900);
    assert_eq!(granted, 750);
    assert!((allocator.utilization() - 1.0).abs() < 1e-9);

    // A zero budget must report zero utilization.
    let empty_budget = TokenBudgetAllocator::new(0);
    assert!(empty_budget.utilization().abs() < 1e-9);
}
1760
/// After recording two steps and two TTFT samples, the text returned by
/// `format_report` must contain the "Steps recorded", "Token throughput"
/// and "TTFT p50" headings.
#[test]
fn test_format_report() {
    let mut metrics = BatchMetrics::new();
    metrics.record_step(100, 10, 200);
    metrics.record_step(0, 8, 100);
    for sample in [150, 250] {
        metrics.record_ttft(sample);
    }

    let report = metrics.format_report();
    for heading in ["Steps recorded", "Token throughput", "TTFT p50"] {
        assert!(report.contains(heading));
    }
}
1774}