Skip to main content

oxicuda_dnn/
dynamic_batch.rs

1//! Dynamic batching and continuous batching for inference serving.
2//!
3//! This module implements the core scheduling primitives used by modern LLM
4//! inference engines (vLLM, Orca, TensorRT-LLM):
5//!
6//! - **[`ContinuousBatcher`]** — iteration-level scheduler that decides which
7//!   requests to prefill, decode, or preempt at each step.
8//! - **[`TokenBudgetAllocator`]** — manages a per-step token budget shared
9//!   between prefill and decode phases.
10//! - **[`PagedKvManager`]** — block-level paged KV-cache allocator with
11//!   copy-on-write support for beam search and speculative decoding.
12//! - **[`SpeculativeDecoder`]** — draft-model speculative decoding implementing
13//!   the speculative-sampling algorithm of Leviathan et al. (2023) and
14//!   Chen et al. (2023): drafted tokens are sampled from the draft model's
15//!   categorical distribution and verified against the target model with
16//!   modified rejection sampling.
17//! - **[`BatchMetrics`]** — running statistics for throughput, latency, and
18//!   utilization monitoring.
19//!
20//! # Scheduling Policies
21//!
22//! | Policy | Description |
23//! |--------|-------------|
24//! | [`SchedulingPolicy::Fcfs`] | First-come, first-served |
25//! | [`SchedulingPolicy::ShortestJobFirst`] | Shortest remaining generation |
26//! | [`SchedulingPolicy::PriorityBased`] | User-assigned priority levels |
27//! | [`SchedulingPolicy::DeadlineAware`] | EDF (earliest deadline first) |
28//! | [`SchedulingPolicy::Orca`] | Iteration-level (selective batching) |
29
30use std::collections::VecDeque;
31
32use crate::error::{DnnError, DnnResult};
33
34// ---------------------------------------------------------------------------
35// Basic types
36// ---------------------------------------------------------------------------
37
/// Unique identifier for an inference request.
///
/// The value is supplied by the caller in [`InferenceRequest::request_id`];
/// the scheduler uses it as the lookup key for running slots.
pub type RequestId = u64;
40
/// Priority level for a request.
///
/// Variants are declared in ascending order (and carry matching explicit
/// discriminants), so the derived `Ord` yields `Low < Normal < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Priority {
    /// Lowest priority — best-effort.
    Low = 0,
    /// Default priority.
    Normal = 1,
    /// Expedited processing.
    High = 2,
}
51
/// An incoming inference request.
///
/// All fields are caller-provided; the scheduler only reads them.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Unique request identifier.
    pub request_id: RequestId,
    /// Number of input (prompt) tokens.
    pub sequence_length: usize,
    /// Maximum number of tokens to generate.
    pub max_new_tokens: usize,
    /// Scheduling priority.
    pub priority: Priority,
    /// Monotonic arrival timestamp in nanoseconds.
    pub arrival_time_ns: u64,
    /// Optional hard deadline in nanoseconds (absolute).
    ///
    /// `None` means the request has no deadline.
    pub deadline_ns: Option<u64>,
}
68
/// A slot inside the running batch.
///
/// One slot tracks the progress of exactly one admitted request.
#[derive(Debug, Clone)]
pub struct BatchSlot {
    /// Slot index within the batch.
    pub slot_id: usize,
    /// Request occupying this slot.
    pub request_id: RequestId,
    /// Tokens processed so far (prompt + generated).
    pub current_seq_len: usize,
    /// Maximum sequence length (prompt + max_new_tokens).
    pub max_seq_len: usize,
    /// Whether this slot is currently doing prefill.
    pub is_prefill: bool,
    /// Whether this slot is actively generating.
    pub is_active: bool,
}
85
/// Scheduling algorithm.
///
/// Selects how waiting requests are ordered before admission.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingPolicy {
    /// First-come, first-served.
    Fcfs,
    /// Prefer requests with fewer remaining tokens.
    ShortestJobFirst,
    /// Respect user-assigned [`Priority`] levels.
    PriorityBased,
    /// Earliest-deadline-first (requires `deadline_ns`).
    DeadlineAware,
    /// Orca-style iteration-level selective batching.
    Orca,
}
100
/// How to handle a preempted request.
///
/// NOTE(review): no code in the visible part of this file reads this enum —
/// confirm where (or whether) the policy is consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreemptionPolicy {
    /// Discard KV cache and recompute from scratch when resumed.
    Recompute,
    /// Swap (offload) KV cache blocks to host memory.
    Swap,
}
109
/// Configuration for the continuous batcher.
///
/// All limits are plain counts; the struct performs no validation itself.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of requests in a single batch.
    pub max_batch_size: usize,
    /// Maximum total tokens (prompt + generated) across the batch.
    pub max_total_tokens: usize,
    /// Maximum sequence length for any single request.
    pub max_sequence_length: usize,
    /// Maximum prefill tokens per step.
    pub prefill_batch_size: usize,
    /// Maximum decode slots per step.
    pub decode_batch_size: usize,
    /// Scheduling algorithm.
    pub scheduling_policy: SchedulingPolicy,
}
126
/// Result of a single scheduling step.
///
/// The three request lists are independent; a request id appears in at most
/// one of them for a given step (presumably — verify against the scheduler).
#[derive(Debug, Clone)]
pub struct BatchDecision {
    /// Requests to prefill in this step.
    pub prefill_requests: Vec<RequestId>,
    /// Requests to decode in this step.
    pub decode_requests: Vec<RequestId>,
    /// Requests preempted to free capacity.
    pub preempted: Vec<RequestId>,
    /// Total token count for the step.
    pub total_tokens: usize,
}
139
140// ---------------------------------------------------------------------------
141// BatchState — internal bookkeeping
142// ---------------------------------------------------------------------------
143
/// Internal state of the batcher.
///
/// Pure bookkeeping: running slots plus the waiting queues.
#[derive(Debug)]
struct BatchState {
    /// Currently running slots.
    active_slots: Vec<BatchSlot>,
    /// Total tokens across all active slots.
    total_tokens: usize,
    /// Requests waiting for prefill.
    prefill_queue: VecDeque<InferenceRequest>,
    /// Requests in decode phase (tracked separately for Orca).
    decode_queue: VecDeque<RequestId>,
    /// Preempted requests awaiting resumption.
    preempted_queue: VecDeque<InferenceRequest>,
}

impl BatchState {
    /// Creates an empty state: no running slots, empty queues, zero tokens.
    fn new() -> Self {
        Self {
            active_slots: Vec::new(),
            total_tokens: 0,
            prefill_queue: VecDeque::new(),
            decode_queue: VecDeque::new(),
            preempted_queue: VecDeque::new(),
        }
    }
}
170
171// ---------------------------------------------------------------------------
172// ContinuousBatcher
173// ---------------------------------------------------------------------------
174
175/// Continuous batcher — the main scheduler for LLM inference serving.
176///
177/// Implements iteration-level scheduling inspired by Orca / vLLM.  At each
178/// [`step`](ContinuousBatcher::step) the batcher decides which waiting
179/// requests to admit for prefill, which running requests continue decoding,
180/// and whether any requests must be preempted.
181#[derive(Debug)]
182pub struct ContinuousBatcher {
183    config: BatchConfig,
184    state: BatchState,
185    next_slot_id: usize,
186    completed_count: u64,
187}
188
189impl ContinuousBatcher {
190    /// Create a new batcher with the given configuration.
191    pub fn new(config: BatchConfig) -> Self {
192        Self {
193            config,
194            state: BatchState::new(),
195            next_slot_id: 0,
196            completed_count: 0,
197        }
198    }
199
200    /// Enqueue a new inference request. Returns its `RequestId`.
201    pub fn add_request(&mut self, request: InferenceRequest) -> DnnResult<RequestId> {
202        if request.sequence_length == 0 {
203            return Err(DnnError::InvalidArgument(
204                "sequence_length must be > 0".into(),
205            ));
206        }
207        if request.sequence_length > self.config.max_sequence_length {
208            return Err(DnnError::InvalidArgument(format!(
209                "sequence_length {} exceeds max_sequence_length {}",
210                request.sequence_length, self.config.max_sequence_length
211            )));
212        }
213        let id = request.request_id;
214        self.state.prefill_queue.push_back(request);
215        Ok(id)
216    }
217
218    /// Execute one scheduling step.
219    ///
220    /// Returns a [`BatchDecision`] describing which requests to prefill,
221    /// decode, and preempt during this iteration.
222    pub fn step(&mut self) -> DnnResult<BatchDecision> {
223        let mut decision = BatchDecision {
224            prefill_requests: Vec::new(),
225            decode_requests: Vec::new(),
226            preempted: Vec::new(),
227            total_tokens: 0,
228        };
229
230        // 1. Collect decode requests from active slots.
231        let decode_ids: Vec<RequestId> = self
232            .state
233            .active_slots
234            .iter()
235            .filter(|s| s.is_active && !s.is_prefill)
236            .map(|s| s.request_id)
237            .collect();
238
239        let decode_count = decode_ids.len().min(self.config.decode_batch_size);
240        let decode_tokens: usize = self
241            .state
242            .active_slots
243            .iter()
244            .filter(|s| s.is_active && !s.is_prefill)
245            .take(decode_count)
246            .map(|s| s.current_seq_len + 1) // +1 for the token being generated
247            .sum();
248
249        decision.decode_requests = decode_ids.into_iter().take(decode_count).collect();
250
251        // 2. Sort the prefill queue according to scheduling policy.
252        self.sort_prefill_queue();
253
254        // 3. Admit prefill requests within budget.
255        let mut prefill_budget = self
256            .config
257            .prefill_batch_size
258            .min(self.config.max_total_tokens.saturating_sub(decode_tokens));
259
260        let mut admitted = Vec::new();
261        while !self.state.prefill_queue.is_empty()
262            && self.state.active_slots.len() + admitted.len() < self.config.max_batch_size
263        {
264            // Peek at the front.
265            let req = match self.state.prefill_queue.front() {
266                Some(r) => r,
267                None => break,
268            };
269            if req.sequence_length > prefill_budget {
270                break;
271            }
272            // Safe: we just confirmed front() is Some.
273            let req = self
274                .state
275                .prefill_queue
276                .pop_front()
277                .ok_or_else(|| DnnError::InvalidArgument("empty queue".into()))?;
278
279            prefill_budget = prefill_budget.saturating_sub(req.sequence_length);
280
281            let slot = BatchSlot {
282                slot_id: self.next_slot_id,
283                request_id: req.request_id,
284                current_seq_len: req.sequence_length,
285                max_seq_len: req.sequence_length + req.max_new_tokens,
286                is_prefill: true,
287                is_active: true,
288            };
289            self.next_slot_id += 1;
290            decision.prefill_requests.push(req.request_id);
291            admitted.push(slot);
292        }
293
294        // 4. Transition admitted prefill slots to decode.
295        for slot in &mut admitted {
296            slot.is_prefill = false;
297        }
298        self.state.active_slots.extend(admitted);
299
300        // Increment decode tokens for existing slots.
301        for slot in &mut self.state.active_slots {
302            if slot.is_active && !slot.is_prefill {
303                slot.current_seq_len = slot.current_seq_len.saturating_add(1);
304            }
305        }
306
307        decision.total_tokens = self
308            .state
309            .active_slots
310            .iter()
311            .filter(|s| s.is_active)
312            .map(|s| s.current_seq_len)
313            .sum();
314
315        self.state.total_tokens = decision.total_tokens;
316
317        Ok(decision)
318    }
319
320    /// Mark a request as completed and free its resources.
321    pub fn complete_request(&mut self, request_id: RequestId) -> DnnResult<()> {
322        let pos = self
323            .state
324            .active_slots
325            .iter()
326            .position(|s| s.request_id == request_id)
327            .ok_or_else(|| {
328                DnnError::InvalidArgument(format!("request {request_id} not in active slots"))
329            })?;
330        let slot = &self.state.active_slots[pos];
331        self.state.total_tokens = self.state.total_tokens.saturating_sub(slot.current_seq_len);
332        self.state.active_slots.remove(pos);
333        self.state.decode_queue.retain(|id| *id != request_id);
334        self.completed_count += 1;
335        Ok(())
336    }
337
338    /// Preempt a running request. The request is moved to the preempted queue
339    /// and may be resumed later.
340    pub fn preempt(&mut self, request_id: RequestId) -> DnnResult<()> {
341        let pos = self
342            .state
343            .active_slots
344            .iter()
345            .position(|s| s.request_id == request_id)
346            .ok_or_else(|| {
347                DnnError::InvalidArgument(format!("request {request_id} not in active slots"))
348            })?;
349        let slot = self.state.active_slots.remove(pos);
350        self.state.total_tokens = self.state.total_tokens.saturating_sub(slot.current_seq_len);
351        self.state.decode_queue.retain(|id| *id != request_id);
352
353        // Re-enqueue as a prefill request so it can be recomputed.
354        let preempted_req = InferenceRequest {
355            request_id,
356            sequence_length: slot.current_seq_len,
357            max_new_tokens: slot.max_seq_len.saturating_sub(slot.current_seq_len),
358            priority: Priority::Normal,
359            arrival_time_ns: 0,
360            deadline_ns: None,
361        };
362        self.state.preempted_queue.push_back(preempted_req);
363        Ok(())
364    }
365
366    /// Number of requests currently executing (prefill + decode).
367    pub fn active_requests(&self) -> usize {
368        self.state
369            .active_slots
370            .iter()
371            .filter(|s| s.is_active)
372            .count()
373    }
374
375    /// Number of requests waiting in all queues (prefill + preempted).
376    pub fn pending_requests(&self) -> usize {
377        self.state.prefill_queue.len() + self.state.preempted_queue.len()
378    }
379
380    /// Total tokens that would be processed in the current active batch.
381    pub fn throughput_tokens_per_step(&self) -> usize {
382        self.state.total_tokens
383    }
384
385    // -- private helpers --
386
387    fn sort_prefill_queue(&mut self) {
388        let queue = &mut self.state.prefill_queue;
389        let policy = self.config.scheduling_policy;
390
391        let mut vec: Vec<InferenceRequest> = queue.drain(..).collect();
392        match policy {
393            SchedulingPolicy::Fcfs => {
394                // Already in arrival order — sort by arrival_time_ns.
395                vec.sort_by_key(|r| r.arrival_time_ns);
396            }
397            SchedulingPolicy::ShortestJobFirst => {
398                vec.sort_by_key(|r| r.max_new_tokens);
399            }
400            SchedulingPolicy::PriorityBased => {
401                // Higher priority first, then FCFS within same priority.
402                vec.sort_by(|a, b| {
403                    b.priority
404                        .cmp(&a.priority)
405                        .then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
406                });
407            }
408            SchedulingPolicy::DeadlineAware => {
409                // Earliest deadline first; no-deadline requests go last.
410                vec.sort_by(|a, b| {
411                    let da = a.deadline_ns.unwrap_or(u64::MAX);
412                    let db = b.deadline_ns.unwrap_or(u64::MAX);
413                    da.cmp(&db).then(a.arrival_time_ns.cmp(&b.arrival_time_ns))
414                });
415            }
416            SchedulingPolicy::Orca => {
417                // Orca: iteration-level — same as FCFS for the prefill queue.
418                vec.sort_by_key(|r| r.arrival_time_ns);
419            }
420        }
421        *queue = VecDeque::from(vec);
422    }
423}
424
425// ---------------------------------------------------------------------------
426// TokenBudgetAllocator
427// ---------------------------------------------------------------------------
428
/// Manages the per-step token budget shared between prefill and decode.
///
/// The allocator is a single running counter bounded by a fixed capacity:
/// prefill and decode allocations add to it, [`release`](Self::release)
/// subtracts from it.
#[derive(Debug)]
pub struct TokenBudgetAllocator {
    /// Capacity of the budget, in tokens.
    max_total_tokens: usize,
    /// Tokens currently handed out; never exceeds `max_total_tokens`.
    allocated: usize,
}

impl TokenBudgetAllocator {
    /// Create an allocator with the given capacity.
    pub fn new(max_total_tokens: usize) -> Self {
        Self {
            max_total_tokens,
            allocated: 0,
        }
    }

    /// Tokens still available in the budget.
    fn remaining(&self) -> usize {
        self.max_total_tokens.saturating_sub(self.allocated)
    }

    /// Try to allocate `seq_len` tokens for a prefill request.
    ///
    /// Returns `Some(offset)` on success, where `offset` is the number of
    /// tokens that were already allocated before this call (i.e. the start
    /// of the new range within the budget).  Returns `None` if the request
    /// does not fit; the budget is left unchanged in that case.
    pub fn allocate_prefill(&mut self, seq_len: usize) -> Option<usize> {
        // Comparing against the remaining capacity (instead of computing
        // `allocated + seq_len`) cannot overflow for any `seq_len`.
        if seq_len > self.remaining() {
            return None;
        }
        let offset = self.allocated;
        self.allocated += seq_len;
        Some(offset)
    }

    /// Allocate up to `count` decode slots (each consuming 1 token).
    ///
    /// Returns how many slots were actually granted, which is `count` capped
    /// at the remaining budget.
    pub fn allocate_decode(&mut self, count: usize) -> usize {
        let granted = count.min(self.remaining());
        self.allocated += granted;
        granted
    }

    /// Release `tokens` from the budget.
    ///
    /// Releasing more than is allocated clamps the counter at zero.
    pub fn release(&mut self, tokens: usize) {
        self.allocated = self.allocated.saturating_sub(tokens);
    }

    /// Fraction of the budget currently in use (0.0..=1.0).
    ///
    /// A zero-capacity allocator reports `0.0`.
    pub fn utilization(&self) -> f64 {
        if self.max_total_tokens == 0 {
            return 0.0;
        }
        self.allocated as f64 / self.max_total_tokens as f64
    }
}
478
479// ---------------------------------------------------------------------------
480// PagedKvManager
481// ---------------------------------------------------------------------------
482
483/// Block-level paged KV-cache manager.
484///
485/// Inspired by the paging scheme in vLLM.  Physical blocks are allocated on
486/// demand and freed when a request completes.  Copy-on-write is supported for
487/// speculative / beam-search scenarios.
488#[derive(Debug)]
489pub struct PagedKvManager {
490    num_blocks: usize,
491    block_size: usize,
492    /// `true` ⇒ block is free.
493    free_map: Vec<bool>,
494    /// Reference count per block (for CoW).
495    ref_counts: Vec<usize>,
496}
497
498impl PagedKvManager {
499    /// Create a manager with `num_blocks` blocks, each holding `block_size`
500    /// tokens.
501    pub fn new(num_blocks: usize, block_size: usize) -> Self {
502        Self {
503            num_blocks,
504            block_size,
505            free_map: vec![true; num_blocks],
506            ref_counts: vec![0; num_blocks],
507        }
508    }
509
510    /// Allocate enough blocks to hold `num_tokens` tokens.
511    ///
512    /// Returns the list of allocated block IDs, or an error if there is
513    /// insufficient free space.
514    pub fn allocate(&mut self, num_tokens: usize) -> DnnResult<Vec<usize>> {
515        if self.block_size == 0 {
516            return Err(DnnError::InvalidArgument("block_size is 0".into()));
517        }
518        let blocks_needed = num_tokens.div_ceil(self.block_size);
519        if !self.can_allocate(num_tokens) {
520            return Err(DnnError::InvalidArgument(format!(
521                "not enough free blocks: need {blocks_needed}, have {}",
522                self.free_block_count()
523            )));
524        }
525        let mut ids = Vec::with_capacity(blocks_needed);
526        for (i, free) in self.free_map.iter_mut().enumerate() {
527            if ids.len() >= blocks_needed {
528                break;
529            }
530            if *free {
531                *free = false;
532                self.ref_counts[i] = 1;
533                ids.push(i);
534            }
535        }
536        Ok(ids)
537    }
538
539    /// Free the given blocks. Decrements reference counts and marks blocks as
540    /// free when the count reaches zero.
541    pub fn free(&mut self, block_ids: &[usize]) {
542        for &id in block_ids {
543            if id < self.num_blocks {
544                self.ref_counts[id] = self.ref_counts[id].saturating_sub(1);
545                if self.ref_counts[id] == 0 {
546                    self.free_map[id] = true;
547                }
548            }
549        }
550    }
551
552    /// Copy-on-write: create a new physical copy of `block_id`.
553    ///
554    /// Used when a block is shared (ref_count > 1) and one branch needs to
555    /// diverge (e.g. beam search).
556    pub fn copy_on_write(&mut self, block_id: usize) -> DnnResult<usize> {
557        if block_id >= self.num_blocks {
558            return Err(DnnError::InvalidArgument(format!(
559                "block_id {block_id} out of range (max {})",
560                self.num_blocks
561            )));
562        }
563        // Find a free block.
564        let new_id =
565            self.free_map.iter().position(|&free| free).ok_or_else(|| {
566                DnnError::InvalidArgument("no free blocks for copy-on-write".into())
567            })?;
568        self.free_map[new_id] = false;
569        self.ref_counts[new_id] = 1;
570
571        // Decrement old block ref count.
572        self.ref_counts[block_id] = self.ref_counts[block_id].saturating_sub(1);
573        if self.ref_counts[block_id] == 0 {
574            self.free_map[block_id] = true;
575        }
576
577        Ok(new_id)
578    }
579
580    /// (used, total) block counts.
581    pub fn usage(&self) -> (usize, usize) {
582        let used = self.free_map.iter().filter(|&&free| !free).count();
583        (used, self.num_blocks)
584    }
585
586    /// Whether `num_tokens` tokens can be allocated right now.
587    pub fn can_allocate(&self, num_tokens: usize) -> bool {
588        if self.block_size == 0 {
589            return false;
590        }
591        let needed = num_tokens.div_ceil(self.block_size);
592        self.free_block_count() >= needed
593    }
594
595    fn free_block_count(&self) -> usize {
596        self.free_map.iter().filter(|&&f| f).count()
597    }
598}
599
600// ---------------------------------------------------------------------------
601// LcgRng — workspace-convention pseudo-random number generator
602// ---------------------------------------------------------------------------
603
/// Minimal full-period 64-bit LCG (Knuth MMIX constants).
///
/// Used for the categorical sampling and rejection-sampling steps of
/// [`SpeculativeDecoder`].  The high bits of the state are used for output —
/// the low bits of an MMIX LCG have short periods and must be discarded.
#[derive(Debug, Clone)]
pub struct LcgRng {
    /// Current generator state; also the value returned by `next_u64`.
    state: u64,
}

impl LcgRng {
    /// LCG multiplier (Knuth MMIX).
    const MUL: u64 = 6_364_136_223_846_793_005;
    /// LCG increment (Knuth MMIX).
    const ADD: u64 = 1_442_695_040_888_963_407;

    /// Creates a new generator seeded with `seed`.
    ///
    /// The seed is run through a SplitMix64-style finalising multiply so that
    /// nearby seeds produce well-separated streams.
    #[must_use]
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed
                .wrapping_mul(0x9E37_79B9_7F4A_7C15)
                .wrapping_add(Self::ADD),
        }
    }

    /// Advances the state and returns the next 64-bit value.
    ///
    /// The returned value is the raw state; callers that need high-quality
    /// bits should use only the upper bits (as [`Self::next_f64`] does).
    #[inline]
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::ADD);
        self.state
    }

    /// Returns a uniform `f64` in `[0, 1)`.
    ///
    /// The top 53 bits of the state are used so every representable
    /// double-precision fraction in `[0, 1)` is reachable.
    #[inline]
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Samples a category index from a (not necessarily normalised) weight
    /// vector via inverse-CDF sampling.
    ///
    /// Negative weights are treated as zero mass, both when normalising and
    /// when accumulating the CDF, so the result is a genuine categorical
    /// sample from `max(0, weights) / sum(max(0, weights))`.
    ///
    /// `None` is returned — without consuming any randomness — when no draw
    /// is possible: an empty slice, any non-finite weight, a non-finite
    /// total, or no strictly positive mass.
    pub fn sample_categorical(&mut self, weights: &[f64]) -> Option<usize> {
        if weights.iter().any(|w| !w.is_finite()) {
            return None;
        }
        // Normalise by the same clamped mass that the CDF walk accumulates;
        // using the raw sum here would bias the draw towards early indices
        // whenever a negative weight is present.
        let total: f64 = weights.iter().map(|w| w.max(0.0)).sum();
        if !total.is_finite() || total <= 0.0 {
            return None;
        }
        let threshold = self.next_f64() * total;
        let mut acc = 0.0;
        for (idx, &w) in weights.iter().enumerate() {
            acc += w.max(0.0);
            if threshold < acc {
                return Some(idx);
            }
        }
        // Floating-point round-off: fall back to the last positive-weight index.
        weights.iter().rposition(|&w| w > 0.0)
    }
}
674
675// ---------------------------------------------------------------------------
676// SpeculativeDecoder
677// ---------------------------------------------------------------------------
678
/// Outcome of one [`SpeculativeDecoder::verify_and_accept`] call.
///
/// Plain data; equality compares all three fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpeculativeResult {
    /// Tokens emitted this round: the accepted prefix of drafted tokens
    /// followed by exactly one extra token (a correction token on the first
    /// rejection, or a bonus token when every drafted token was accepted).
    pub tokens: Vec<u32>,
    /// Number of drafted tokens that passed the rejection test.
    pub accepted: usize,
    /// Number of drafted tokens that were rejected (`0` or `1` per round —
    /// at most one rejection can occur because verification stops there).
    pub rejected: usize,
}
692
/// Speculative decoding support (draft + verify).
///
/// A small "draft" model proposes several tokens ahead, and a larger "target"
/// model verifies them, accepting a prefix of the proposed tokens.  This
/// amortises the cost of autoregressive generation.
///
/// This is a faithful host-side implementation of the speculative-sampling
/// algorithm of Leviathan et al., *"Fast Inference from Transformers via
/// Speculative Decoding"* (2023), and Chen et al., *"Accelerating Large
/// Language Model Decoding with Speculative Sampling"* (2023):
///
/// 1. [`Self::propose_tokens`] draws drafted tokens by **categorical sampling**
///    from the draft model's per-position probability distributions.
/// 2. [`Self::verify_and_accept`] performs **modified rejection sampling**:
///    for the drafted token `t` at position `i`, a uniform `r ∈ [0, 1)` is
///    drawn and the token is accepted iff `r < min(1, p_target(t) / p_draft(t))`.
///    The longest passing prefix is kept; on the first rejection a correction
///    token is sampled from the normalised residual
///    `normalize(max(0, p_target − p_draft))`.  If every drafted token is
///    accepted, one bonus token is sampled from `p_target`.
///
/// The decoder operates on the probability vectors supplied by the caller
/// (which would come from running the draft and target model forward passes).
/// The sampling and acceptance arithmetic is exact — no values are fabricated.
#[derive(Debug)]
pub struct SpeculativeDecoder {
    /// Maximum number of tokens drafted per round.
    draft_length: usize,
    /// Source of randomness for drafting and verification draws.
    rng: LcgRng,
    /// Running count of drafted tokens (presumably maintained by
    /// `verify_and_accept` — TODO confirm).
    total_proposed: u64,
    /// Running count of accepted drafted tokens (same caveat as above).
    total_accepted: u64,
    /// Running count of speculation rounds (same caveat as above).
    rounds: u64,
}
725
726impl SpeculativeDecoder {
    /// Default RNG seed used by [`SpeculativeDecoder::new`].
    ///
    /// The bytes `0x53 0x50 0x45 0x43` spell "SPEC" in ASCII.
    const DEFAULT_SEED: u64 = 0x5350_4543; // "SPEC"
729
    /// Creates a speculative decoder that proposes `draft_length` tokens at a
    /// time, using the default RNG seed.
    ///
    /// Equivalent to `with_seed(draft_length, DEFAULT_SEED)`, so two decoders
    /// built this way start from the same random stream.
    #[must_use]
    pub fn new(draft_length: usize) -> Self {
        Self::with_seed(draft_length, Self::DEFAULT_SEED)
    }
736
    /// Creates a speculative decoder with an explicit RNG seed.
    ///
    /// A fixed seed makes the categorical sampling and rejection draws
    /// reproducible, which is useful for tests and deterministic replay.
    /// All running counters start at zero.
    #[must_use]
    pub fn with_seed(draft_length: usize, seed: u64) -> Self {
        Self {
            draft_length,
            rng: LcgRng::new(seed),
            total_proposed: 0,
            total_accepted: 0,
            rounds: 0,
        }
    }
751
    /// Number of tokens proposed per speculation round (γ).
    ///
    /// This is the value passed at construction; it is an upper bound on the
    /// length of a single draft.
    #[must_use]
    pub fn draft_length(&self) -> usize {
        self.draft_length
    }
757
758    /// Proposes a sequence of drafted tokens from the draft model.
759    ///
760    /// `draft_probs` holds one probability distribution per draft position:
761    /// `draft_probs[i]` is the draft model's categorical distribution over the
762    /// vocabulary for the `i`-th drafted token (its length is the vocabulary
763    /// size).  Each drafted token is obtained by **categorical sampling** from
764    /// the corresponding distribution — this is genuine sampling from the draft
765    /// model, not a deterministic placeholder.
766    ///
767    /// Returns one token id per position, paired with the probability the
768    /// draft model assigned to the sampled token (`p_draft(t_i)`).  That
769    /// probability is exactly what [`Self::verify_and_accept`] needs as the
770    /// denominator of the acceptance ratio.
771    ///
772    /// At most `draft_length` positions are consumed; supplying fewer
773    /// distributions simply produces a shorter draft.
774    ///
775    /// # Errors
776    ///
777    /// Returns [`DnnError::InvalidArgument`] if any consumed distribution is
778    /// empty, has non-finite or all-zero mass, so that no token can be drawn.
779    pub fn propose_tokens(&mut self, draft_probs: &[Vec<f64>]) -> DnnResult<Vec<DraftedToken>> {
780        let count = draft_probs.len().min(self.draft_length);
781        let mut drafted = Vec::with_capacity(count);
782        for (position, dist) in draft_probs.iter().take(count).enumerate() {
783            let token = self.rng.sample_categorical(dist).ok_or_else(|| {
784                DnnError::InvalidArgument(format!(
785                    "draft distribution at position {position} has no positive, finite mass"
786                ))
787            })?;
788            let total: f64 = dist.iter().map(|p| p.max(0.0)).sum();
789            // `total > 0` is guaranteed: sample_categorical returned `Some`.
790            let draft_prob = dist[token].max(0.0) / total;
791            drafted.push(DraftedToken {
792                token_id: token as u32,
793                draft_prob,
794            });
795        }
796        Ok(drafted)
797    }
798
799    /// Verifies drafted tokens against the target model with modified
800    /// rejection sampling, returning the tokens to emit this round.
801    ///
802    /// # Arguments
803    ///
804    /// * `drafted` — tokens proposed by [`Self::propose_tokens`], each carrying
805    ///   the draft probability `p_draft(t_i)`.
806    /// * `target_dists` — one target-model probability distribution per drafted
807    ///   position: `target_dists[i]` is the target model's categorical
808    ///   distribution over the vocabulary at position `i`.  It must be long
809    ///   enough to cover every drafted token.
810    ///
811    /// # Algorithm
812    ///
813    /// For each drafted token `t_i` in order, a uniform `r ∈ [0, 1)` is drawn
814    /// and `t_i` is accepted iff `r < min(1, p_target(t_i) / p_draft(t_i))`.
815    /// On the first rejection at position `i`, a correction token is sampled
816    /// from the normalised residual distribution
817    /// `normalize(max(0, p_target[i] − p_draft[i]))` and verification stops.
818    /// If every drafted token is accepted, one bonus token is sampled from the
819    /// extra target distribution at position `drafted.len()`
820    /// (`target_dists` must therefore contain `drafted.len() + 1` rows in that
821    /// case — see the error conditions).
822    ///
823    /// # Errors
824    ///
825    /// Returns [`DnnError::InvalidArgument`] if `target_dists` does not cover
826    /// every drafted position (plus one extra row for the all-accepted bonus
827    /// token), if a target distribution referenced by a drafted token is too
828    /// short to contain that token, or if a distribution required for sampling
829    /// has no positive, finite mass.
830    pub fn verify_and_accept(
831        &mut self,
832        drafted: &[DraftedToken],
833        target_dists: &[Vec<f64>],
834    ) -> DnnResult<SpeculativeResult> {
835        let gamma = drafted.len();
836        // The all-accepted path needs one extra target distribution to draw
837        // the bonus token from, so `gamma + 1` rows are always required.
838        if target_dists.len() <= gamma {
839            return Err(DnnError::InvalidArgument(format!(
840                "target_dists must have at least {} rows (one per drafted token \
841                 plus a bonus row), got {}",
842                gamma + 1,
843                target_dists.len(),
844            )));
845        }
846
847        let mut tokens = Vec::with_capacity(gamma + 1);
848        for (i, draft) in drafted.iter().enumerate() {
849            let token = draft.token_id as usize;
850            let target_dist = &target_dists[i];
851            let target_total: f64 = target_dist.iter().map(|p| p.max(0.0)).sum();
852            let p_target = target_dist
853                .get(token)
854                .copied()
855                .ok_or_else(|| {
856                    DnnError::InvalidArgument(format!(
857                        "target distribution at position {i} (len {}) does not \
858                         contain drafted token id {token}",
859                        target_dist.len(),
860                    ))
861                })?
862                .max(0.0);
863            // Normalise the target probability of the drafted token so the
864            // acceptance ratio is between two genuine probabilities.
865            let p_target = if target_total > 0.0 {
866                p_target / target_total
867            } else {
868                0.0
869            };
870
871            // Acceptance ratio  min(1, p_target / p_draft).
872            // A draft probability of zero means the draft model could never
873            // have produced this token — treat it as an unconditional reject.
874            let accept_ratio = if draft.draft_prob > 0.0 {
875                (p_target / draft.draft_prob).min(1.0)
876            } else {
877                0.0
878            };
879
880            let r = self.rng.next_f64();
881            if r < accept_ratio {
882                tokens.push(draft.token_id);
883                continue;
884            }
885
886            // ---- First rejection: sample the correction token. ----------
887            let residual = Self::residual_distribution(target_dist, drafted, i);
888            let correction = self.rng.sample_categorical(&residual).ok_or_else(|| {
889                DnnError::InvalidArgument(format!(
890                    "residual distribution at position {i} has no positive mass"
891                ))
892            })?;
893            tokens.push(correction as u32);
894
895            let accepted = i;
896            self.record(gamma, accepted);
897            return Ok(SpeculativeResult {
898                tokens,
899                accepted,
900                rejected: 1,
901            });
902        }
903
904        // ---- Every drafted token accepted: sample one bonus token. ------
905        let bonus_dist = &target_dists[gamma];
906        let bonus = self.rng.sample_categorical(bonus_dist).ok_or_else(|| {
907            DnnError::InvalidArgument(
908                "bonus target distribution has no positive, finite mass".into(),
909            )
910        })?;
911        tokens.push(bonus as u32);
912
913        self.record(gamma, gamma);
914        Ok(SpeculativeResult {
915            tokens,
916            accepted: gamma,
917            rejected: 0,
918        })
919    }
920
921    /// Builds the normalised residual distribution
922    /// `normalize(max(0, p_target − p_draft))` at the rejection position `i`.
923    ///
924    /// The draft model's distribution at position `i` is reconstructed as a
925    /// one-hot vector on the drafted token: the draft sampled exactly that
926    /// token, so all of the draft mass relevant to the residual sits there.
927    /// This matches the speculative-sampling residual `p(x) − q(x)` clamped to
928    /// non-negative values.  When the residual sums to zero (the target placed
929    /// no extra mass anywhere) the raw target distribution is returned so the
930    /// caller still draws a valid token from `p_target`.
931    fn residual_distribution(
932        target_dist: &[f64],
933        drafted: &[DraftedToken],
934        position: usize,
935    ) -> Vec<f64> {
936        let target_total: f64 = target_dist.iter().map(|p| p.max(0.0)).sum();
937        let drafted_token = drafted[position].token_id as usize;
938        let draft_prob = drafted[position].draft_prob;
939
940        let mut residual: Vec<f64> = Vec::with_capacity(target_dist.len());
941        for (idx, &t) in target_dist.iter().enumerate() {
942            let p_target = if target_total > 0.0 {
943                t.max(0.0) / target_total
944            } else {
945                0.0
946            };
947            // q(x): the draft distribution is one-hot on the drafted token.
948            let p_draft = if idx == drafted_token {
949                draft_prob.max(0.0)
950            } else {
951                0.0
952            };
953            residual.push((p_target - p_draft).max(0.0));
954        }
955
956        let residual_sum: f64 = residual.iter().sum();
957        if residual_sum <= 0.0 {
958            // Degenerate residual — draw the correction straight from p_target.
959            return target_dist.iter().map(|p| p.max(0.0)).collect();
960        }
961        residual
962    }
963
964    /// Updates the running accept/propose counters for one finished round.
965    fn record(&mut self, proposed: usize, accepted: usize) {
966        self.total_proposed += proposed as u64;
967        self.total_accepted += accepted as u64;
968        self.rounds += 1;
969    }
970
971    /// Running acceptance rate across all calls to [`Self::verify_and_accept`].
972    ///
973    /// This is `total accepted drafted tokens / total drafted tokens` and
974    /// reflects the real algorithm — it excludes correction and bonus tokens,
975    /// which are emitted regardless of acceptance.
976    #[must_use]
977    pub fn acceptance_rate(&self) -> f64 {
978        if self.total_proposed == 0 {
979            return 0.0;
980        }
981        self.total_accepted as f64 / self.total_proposed as f64
982    }
983
984    /// Total number of drafted tokens proposed across all rounds.
985    #[must_use]
986    pub fn total_proposed(&self) -> u64 {
987        self.total_proposed
988    }
989
990    /// Total number of drafted tokens accepted across all rounds.
991    #[must_use]
992    pub fn total_accepted(&self) -> u64 {
993        self.total_accepted
994    }
995
996    /// Number of completed speculation rounds.
997    #[must_use]
998    pub fn rounds(&self) -> u64 {
999        self.rounds
1000    }
1001
1002    /// Average number of tokens emitted per round, including the correction or
1003    /// bonus token.
1004    ///
1005    /// With a draft length of γ this lies in `[1, γ + 1]`: each round always
1006    /// emits one correction-or-bonus token on top of the accepted prefix.
1007    /// It is the practical speed-up factor of speculative decoding versus
1008    /// plain autoregressive decoding (one target forward pass per round).
1009    #[must_use]
1010    pub fn mean_tokens_per_round(&self) -> f64 {
1011        if self.rounds == 0 {
1012            return 0.0;
1013        }
1014        // accepted drafted tokens + one emitted token (correction/bonus) per round.
1015        (self.total_accepted + self.rounds) as f64 / self.rounds as f64
1016    }
1017}
1018
/// One token proposed by the draft model, paired with the probability the
/// draft model gave it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DraftedToken {
    /// Vocabulary id of the sampled token.
    pub token_id: u32,
    /// Probability `p_draft(token_id)` under the draft model, expected to be
    /// normalised into `[0, 1]`.
    pub draft_prob: f64,
}
1029
1030// ---------------------------------------------------------------------------
1031// BatchMetrics
1032// ---------------------------------------------------------------------------
1033
/// Running statistics for the inference serving loop.
///
/// Collects one record per scheduling step plus a list of
/// time-to-first-token samples, and derives averages, throughput, and the
/// median TTFT from them on demand.  All latencies are in microseconds.
///
/// Aggregates are accumulated in `u128`, so they cannot overflow even for
/// extreme sample values.
#[derive(Debug)]
pub struct BatchMetrics {
    /// (prefill_tokens, decode_tokens, latency_us) per step.
    steps: Vec<(usize, usize, u64)>,
    /// Time-to-first-token in microseconds for each request.
    ttft_samples: Vec<u64>,
}

impl BatchMetrics {
    /// Create an empty metrics collector.
    pub fn new() -> Self {
        Self {
            steps: Vec::new(),
            ttft_samples: Vec::new(),
        }
    }

    /// Record one scheduling step: how many prefill and decode tokens it
    /// processed and how long it took (microseconds).
    pub fn record_step(&mut self, prefill_tokens: usize, decode_tokens: usize, latency_us: u64) {
        self.steps.push((prefill_tokens, decode_tokens, latency_us));
    }

    /// Record a time-to-first-token sample (in microseconds).
    pub fn record_ttft(&mut self, ttft_us: u64) {
        self.ttft_samples.push(ttft_us);
    }

    /// Mean latency over the steps for which `include(prefill, decode)` holds,
    /// or `0.0` when no step qualifies.  Single pass, no temporary buffer.
    fn mean_latency_where(&self, include: impl Fn(usize, usize) -> bool) -> f64 {
        let (sum_us, count) = self
            .steps
            .iter()
            .filter(|&&(prefill, decode, _)| include(prefill, decode))
            .fold((0u128, 0usize), |(sum, n), &(_, _, latency_us)| {
                (sum + u128::from(latency_us), n + 1)
            });
        if count == 0 {
            0.0
        } else {
            sum_us as f64 / count as f64
        }
    }

    /// Total number of tokens (prefill + decode) over all recorded steps.
    fn total_tokens(&self) -> u128 {
        self.steps
            .iter()
            .map(|&(prefill, decode, _)| prefill as u128 + decode as u128)
            .sum()
    }

    /// Average latency of steps that included prefill tokens (microseconds).
    ///
    /// Returns `0.0` when no such step has been recorded.
    pub fn avg_prefill_latency(&self) -> f64 {
        self.mean_latency_where(|prefill, _| prefill > 0)
    }

    /// Average latency of steps that included decode tokens (microseconds).
    ///
    /// Returns `0.0` when no such step has been recorded.
    pub fn avg_decode_latency(&self) -> f64 {
        self.mean_latency_where(|_, decode| decode > 0)
    }

    /// Average batch size (total tokens per step).
    ///
    /// Returns `0.0` when no step has been recorded.
    pub fn avg_batch_size(&self) -> f64 {
        if self.steps.is_empty() {
            return 0.0;
        }
        self.total_tokens() as f64 / self.steps.len() as f64
    }

    /// Estimated token throughput (tokens / second).
    ///
    /// Total tokens divided by the summed step latency.  Returns `0.0` when
    /// nothing was recorded or the summed latency is zero.
    pub fn token_throughput(&self) -> f64 {
        let total_us: u128 = self
            .steps
            .iter()
            .map(|&(_, _, latency_us)| u128::from(latency_us))
            .sum();
        // Also covers the "no steps" case: an empty sum is zero.
        if total_us == 0 {
            return 0.0;
        }
        self.total_tokens() as f64 / (total_us as f64 / 1_000_000.0)
    }

    /// Median (p50) time-to-first-token in microseconds.
    ///
    /// For an even number of samples this is the mean of the two middle
    /// values.  Returns `0.0` when no sample has been recorded.
    pub fn time_to_first_token_p50(&self) -> f64 {
        if self.ttft_samples.is_empty() {
            return 0.0;
        }
        let mut sorted = self.ttft_samples.clone();
        sorted.sort_unstable();
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            // Non-empty and even means at least two samples, so `mid - 1` is
            // valid.  Widen before adding: two large `u64` samples would
            // otherwise overflow.
            (u128::from(sorted[mid - 1]) + u128::from(sorted[mid])) as f64 / 2.0
        } else {
            sorted[mid] as f64
        }
    }

    /// Human-readable performance report.
    pub fn format_report(&self) -> String {
        format!(
            "BatchMetrics Report\n\
             ====================\n\
             Steps recorded       : {}\n\
             Avg prefill latency  : {:.1} us\n\
             Avg decode latency   : {:.1} us\n\
             Avg batch size       : {:.1} tokens/step\n\
             Token throughput     : {:.0} tokens/s\n\
             TTFT p50             : {:.1} us\n\
             TTFT samples         : {}",
            self.steps.len(),
            self.avg_prefill_latency(),
            self.avg_decode_latency(),
            self.avg_batch_size(),
            self.token_throughput(),
            self.time_to_first_token_p50(),
            self.ttft_samples.len(),
        )
    }
}

impl Default for BatchMetrics {
    fn default() -> Self {
        Self::new()
    }
}
1155
1156// ---------------------------------------------------------------------------
1157// Tests
1158// ---------------------------------------------------------------------------
1159
1160#[cfg(test)]
1161mod tests {
1162    use super::*;
1163
1164    fn default_config() -> BatchConfig {
1165        BatchConfig {
1166            max_batch_size: 8,
1167            max_total_tokens: 4096,
1168            max_sequence_length: 2048,
1169            prefill_batch_size: 1024,
1170            decode_batch_size: 8,
1171            scheduling_policy: SchedulingPolicy::Fcfs,
1172        }
1173    }
1174
1175    fn make_request(id: RequestId, seq_len: usize, max_new: usize) -> InferenceRequest {
1176        InferenceRequest {
1177            request_id: id,
1178            sequence_length: seq_len,
1179            max_new_tokens: max_new,
1180            priority: Priority::Normal,
1181            arrival_time_ns: id * 1000,
1182            deadline_ns: None,
1183        }
1184    }
1185
1186    // 1. Add single request
1187    #[test]
1188    fn test_add_single_request() {
1189        let mut batcher = ContinuousBatcher::new(default_config());
1190        let req = make_request(1, 128, 64);
1191        let id = batcher.add_request(req).expect("should succeed");
1192        assert_eq!(id, 1);
1193        assert_eq!(batcher.pending_requests(), 1);
1194        assert_eq!(batcher.active_requests(), 0);
1195    }
1196
1197    // 2. Batch step with mixed prefill/decode
1198    #[test]
1199    fn test_batch_step_mixed_prefill_decode() {
1200        let mut batcher = ContinuousBatcher::new(default_config());
1201        // First request: prefill + become decode
1202        batcher.add_request(make_request(1, 64, 32)).expect("add 1");
1203        let d1 = batcher.step().expect("step 1");
1204        assert_eq!(d1.prefill_requests.len(), 1);
1205
1206        // Second request while first is decoding
1207        batcher.add_request(make_request(2, 32, 16)).expect("add 2");
1208        let d2 = batcher.step().expect("step 2");
1209        assert!(!d2.decode_requests.is_empty(), "should have decode slots");
1210        assert!(!d2.prefill_requests.is_empty(), "should have prefill slots");
1211    }
1212
1213    // 3. Token budget allocation/release
1214    #[test]
1215    fn test_token_budget_allocation_release() {
1216        let mut alloc = TokenBudgetAllocator::new(1024);
1217        let slot = alloc.allocate_prefill(512);
1218        assert!(slot.is_some());
1219        assert!((alloc.utilization() - 0.5).abs() < 1e-9);
1220
1221        // Allocate more than remaining.
1222        assert!(alloc.allocate_prefill(600).is_none());
1223
1224        alloc.release(256);
1225        assert!((alloc.utilization() - 0.25).abs() < 1e-9);
1226    }
1227
1228    // 4. Paged KV allocation/free
1229    #[test]
1230    fn test_paged_kv_allocation_free() {
1231        let mut mgr = PagedKvManager::new(16, 64);
1232        let blocks = mgr.allocate(128).expect("allocate 128");
1233        assert_eq!(blocks.len(), 2);
1234        let (used, total) = mgr.usage();
1235        assert_eq!(used, 2);
1236        assert_eq!(total, 16);
1237
1238        mgr.free(&blocks);
1239        let (used, _) = mgr.usage();
1240        assert_eq!(used, 0);
1241    }
1242
1243    // 5. Copy-on-write
1244    #[test]
1245    fn test_copy_on_write() {
1246        let mut mgr = PagedKvManager::new(4, 64);
1247        let blocks = mgr.allocate(64).expect("allocate");
1248        assert_eq!(blocks.len(), 1);
1249        let orig = blocks[0];
1250
1251        // Bump ref count to simulate sharing.
1252        mgr.ref_counts[orig] = 2;
1253
1254        let new_id = mgr.copy_on_write(orig).expect("cow");
1255        assert_ne!(new_id, orig);
1256        // Old block should still be allocated (ref_count decremented to 1).
1257        assert!(!mgr.free_map[orig]);
1258        assert_eq!(mgr.ref_counts[orig], 1);
1259        assert_eq!(mgr.ref_counts[new_id], 1);
1260    }
1261
1262    // 6. Continuous batching with request completion
1263    #[test]
1264    fn test_continuous_batching_completion() {
1265        let mut batcher = ContinuousBatcher::new(default_config());
1266        batcher.add_request(make_request(10, 64, 8)).expect("add");
1267        let _ = batcher.step().expect("step");
1268        assert_eq!(batcher.active_requests(), 1);
1269
1270        batcher.complete_request(10).expect("complete");
1271        assert_eq!(batcher.active_requests(), 0);
1272    }
1273
1274    // 7. Preemption
1275    #[test]
1276    fn test_preemption() {
1277        let mut batcher = ContinuousBatcher::new(default_config());
1278        batcher.add_request(make_request(20, 64, 16)).expect("add");
1279        let _ = batcher.step().expect("step");
1280        assert_eq!(batcher.active_requests(), 1);
1281
1282        batcher.preempt(20).expect("preempt");
1283        assert_eq!(batcher.active_requests(), 0);
1284        // Preempted request is in the preempted queue.
1285        assert_eq!(batcher.pending_requests(), 1);
1286    }
1287
1288    // 8. FCFS scheduling order
1289    #[test]
1290    fn test_fcfs_scheduling_order() {
1291        let mut batcher = ContinuousBatcher::new(default_config());
1292        batcher.add_request(make_request(3, 32, 8)).expect("add 3");
1293        batcher.add_request(make_request(1, 32, 8)).expect("add 1");
1294        batcher.add_request(make_request(2, 32, 8)).expect("add 2");
1295        // arrival_time_ns = id * 1000, so order is 1, 2, 3.
1296        let d = batcher.step().expect("step");
1297        assert_eq!(d.prefill_requests, vec![1, 2, 3]);
1298    }
1299
1300    // 9. Priority-based scheduling
1301    #[test]
1302    fn test_priority_based_scheduling() {
1303        let mut config = default_config();
1304        config.scheduling_policy = SchedulingPolicy::PriorityBased;
1305        let mut batcher = ContinuousBatcher::new(config);
1306
1307        let mut low = make_request(1, 32, 8);
1308        low.priority = Priority::Low;
1309        low.arrival_time_ns = 100;
1310        let mut high = make_request(2, 32, 8);
1311        high.priority = Priority::High;
1312        high.arrival_time_ns = 200;
1313        let mut normal = make_request(3, 32, 8);
1314        normal.priority = Priority::Normal;
1315        normal.arrival_time_ns = 50;
1316
1317        batcher.add_request(low).expect("add low");
1318        batcher.add_request(high).expect("add high");
1319        batcher.add_request(normal).expect("add normal");
1320
1321        let d = batcher.step().expect("step");
1322        // High (2) first, then Normal (3), then Low (1).
1323        assert_eq!(d.prefill_requests, vec![2, 3, 1]);
1324    }
1325
1326    // 10. Deadline-aware scheduling
1327    #[test]
1328    fn test_deadline_aware_scheduling() {
1329        let mut config = default_config();
1330        config.scheduling_policy = SchedulingPolicy::DeadlineAware;
1331        let mut batcher = ContinuousBatcher::new(config);
1332
1333        let mut r1 = make_request(1, 32, 8);
1334        r1.deadline_ns = Some(5000);
1335        let mut r2 = make_request(2, 32, 8);
1336        r2.deadline_ns = Some(1000);
1337        let mut r3 = make_request(3, 32, 8);
1338        r3.deadline_ns = None; // No deadline → goes last.
1339
1340        batcher.add_request(r1).expect("add r1");
1341        batcher.add_request(r2).expect("add r2");
1342        batcher.add_request(r3).expect("add r3");
1343
1344        let d = batcher.step().expect("step");
1345        assert_eq!(d.prefill_requests, vec![2, 1, 3]);
1346    }
1347
1348    // 11. Speculative decoding — propose samples from the draft distribution.
1349    #[test]
1350    fn test_speculative_decoding_propose_samples_draft() {
1351        let mut spec = SpeculativeDecoder::with_seed(3, 12345);
1352        // Three positions, each a 4-token vocabulary. Position 0 always
1353        // produces token 2 (the only one with mass); position 1 produces
1354        // token 0; position 2 produces token 3.
1355        let draft_probs = vec![
1356            vec![0.0, 0.0, 1.0, 0.0],
1357            vec![1.0, 0.0, 0.0, 0.0],
1358            vec![0.0, 0.0, 0.0, 1.0],
1359        ];
1360        let drafted = spec.propose_tokens(&draft_probs).expect("propose");
1361        assert_eq!(drafted.len(), 3);
1362        assert_eq!(drafted[0].token_id, 2);
1363        assert_eq!(drafted[1].token_id, 0);
1364        assert_eq!(drafted[2].token_id, 3);
1365        // The draft probability of a one-hot distribution is exactly 1.0.
1366        for d in &drafted {
1367            assert!((d.draft_prob - 1.0).abs() < 1e-12);
1368        }
1369    }
1370
1371    // 11b. propose_tokens respects the draft_length cap and normalises probs.
1372    #[test]
1373    fn test_speculative_decoding_propose_caps_and_normalises() {
1374        let mut spec = SpeculativeDecoder::with_seed(2, 99);
1375        // Four positions supplied but draft_length == 2 → only 2 consumed.
1376        // Un-normalised weights: token 1 has 3/4 of the mass.
1377        let draft_probs = vec![
1378            vec![1.0, 3.0],
1379            vec![3.0, 1.0],
1380            vec![1.0, 0.0],
1381            vec![0.0, 1.0],
1382        ];
1383        let drafted = spec.propose_tokens(&draft_probs).expect("propose");
1384        assert_eq!(drafted.len(), 2, "draft_length caps the count");
1385        for d in &drafted {
1386            // draft_prob must be a genuine probability in [0, 1].
1387            assert!((0.0..=1.0).contains(&d.draft_prob));
1388            // For these two-element weight vectors the sampled token's
1389            // normalised probability is either 1/4 or 3/4.
1390            let p = d.draft_prob;
1391            assert!(
1392                (p - 0.25).abs() < 1e-12 || (p - 0.75).abs() < 1e-12,
1393                "unexpected normalised prob {p}"
1394            );
1395        }
1396    }
1397
1398    // 11c. propose_tokens rejects a degenerate (all-zero) distribution.
1399    #[test]
1400    fn test_speculative_decoding_propose_rejects_zero_dist() {
1401        let mut spec = SpeculativeDecoder::new(2);
1402        let draft_probs = vec![vec![0.0, 0.0, 0.0]];
1403        assert!(spec.propose_tokens(&draft_probs).is_err());
1404    }
1405
1406    // 11d. Categorical sampling reproduces the target frequencies (statistical).
1407    #[test]
1408    fn test_categorical_sampling_matches_distribution() {
1409        let mut rng = LcgRng::new(0x00C0_FFEE);
1410        // Target distribution over 4 categories.
1411        let weights = [0.1_f64, 0.2, 0.3, 0.4];
1412        let trials = 200_000;
1413        let mut counts = [0u64; 4];
1414        for _ in 0..trials {
1415            let idx = rng.sample_categorical(&weights).expect("sample");
1416            counts[idx] += 1;
1417        }
1418        for (i, &w) in weights.iter().enumerate() {
1419            let freq = counts[i] as f64 / trials as f64;
1420            assert!(
1421                (freq - w).abs() < 0.01,
1422                "category {i}: freq {freq} vs expected {w}"
1423            );
1424        }
1425    }
1426
1427    // 11e. Rejection sampling accepts with probability min(1, p_t / p_d).
1428    #[test]
1429    fn test_rejection_sampling_acceptance_probability() {
1430        // Drafted token id 0 with draft prob 0.8. Target prob 0.4 → the
1431        // acceptance ratio is 0.4 / 0.8 = 0.5: across many independent
1432        // single-token rounds, ~half should be accepted.
1433        let trials = 100_000;
1434        let mut accepted_rounds = 0u64;
1435        for seed in 0..trials {
1436            let mut spec = SpeculativeDecoder::with_seed(1, seed);
1437            let drafted = vec![DraftedToken {
1438                token_id: 0,
1439                draft_prob: 0.8,
1440            }];
1441            // Position 0 target dist: token 0 has prob 0.4, token 1 has 0.6.
1442            // Bonus row (position 1) is required even on rejection paths.
1443            let target = vec![vec![0.4, 0.6], vec![0.5, 0.5]];
1444            let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1445            if res.accepted == 1 {
1446                accepted_rounds += 1;
1447            }
1448        }
1449        let rate = accepted_rounds as f64 / trials as f64;
1450        assert!(
1451            (rate - 0.5).abs() < 0.01,
1452            "acceptance rate {rate} should be ~0.5"
1453        );
1454    }
1455
1456    // 11f. p_target >= p_draft → unconditional acceptance.
1457    #[test]
1458    fn test_rejection_sampling_always_accepts_when_target_ge_draft() {
1459        for seed in 0..2000 {
1460            let mut spec = SpeculativeDecoder::with_seed(1, seed);
1461            let drafted = vec![DraftedToken {
1462                token_id: 0,
1463                draft_prob: 0.3,
1464            }];
1465            // Target prob of token 0 (normalised) is 0.6 >= 0.3 draft prob.
1466            let target = vec![vec![0.6, 0.4], vec![0.5, 0.5]];
1467            let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1468            assert_eq!(res.accepted, 1, "ratio >= 1 must always accept");
1469            assert_eq!(res.rejected, 0);
1470        }
1471    }
1472
1473    // 11g. Zero draft probability → unconditional rejection.
1474    #[test]
1475    fn test_rejection_sampling_rejects_zero_draft_prob() {
1476        let mut spec = SpeculativeDecoder::with_seed(1, 7);
1477        let drafted = vec![DraftedToken {
1478            token_id: 0,
1479            draft_prob: 0.0,
1480        }];
1481        let target = vec![vec![0.9, 0.1], vec![0.5, 0.5]];
1482        let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1483        assert_eq!(res.accepted, 0);
1484        assert_eq!(res.rejected, 1);
1485        // Exactly one correction token is emitted.
1486        assert_eq!(res.tokens.len(), 1);
1487    }
1488
1489    // 11h. Residual-distribution resampling is correct (statistical).
1490    #[test]
1491    fn test_residual_distribution_resampling() {
1492        // Force a guaranteed rejection so the correction token is always
1493        // drawn from the residual. Drafted token id 0 with draft prob 1.0 and
1494        // target prob 0.0 for token 0 → acceptance ratio 0 → always reject.
1495        //
1496        // Target distribution: [0.0, 0.5, 0.5]. Draft is one-hot on token 0.
1497        // Residual = max(0, p_target - p_draft):
1498        //   token 0: max(0, 0.0 - 1.0) = 0.0
1499        //   token 1: max(0, 0.5 - 0.0) = 0.5
1500        //   token 2: max(0, 0.5 - 0.0) = 0.5
1501        // Normalised residual = [0.0, 0.5, 0.5].
1502        let trials = 100_000;
1503        let mut counts = [0u64; 3];
1504        for seed in 0..trials {
1505            let mut spec = SpeculativeDecoder::with_seed(1, seed);
1506            let drafted = vec![DraftedToken {
1507                token_id: 0,
1508                draft_prob: 1.0,
1509            }];
1510            let target = vec![vec![0.0, 0.5, 0.5], vec![1.0, 0.0, 0.0]];
1511            let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1512            assert_eq!(res.accepted, 0, "must reject");
1513            let corr = res.tokens[0] as usize;
1514            counts[corr] += 1;
1515        }
1516        let total = trials as f64;
1517        assert_eq!(counts[0], 0, "token 0 has zero residual mass");
1518        assert!((counts[1] as f64 / total - 0.5).abs() < 0.01);
1519        assert!((counts[2] as f64 / total - 0.5).abs() < 0.01);
1520    }
1521
1522    // 11i. Rejection of a token the target does not place mass on draws the
1523    //      correction from the residual concentrated elsewhere.
1524    #[test]
1525    fn test_residual_distribution_concentrated() {
1526        // Drafted token id 1 with draft prob 1.0; target dist [1.0, 0.0].
1527        // p_target(token 1) == 0 → acceptance ratio 0 → always reject.
1528        // residual = max(0, p_target - p_draft) with draft one-hot on token 1:
1529        //   token 0: max(0, 1.0 - 0.0) = 1.0
1530        //   token 1: max(0, 0.0 - 1.0) = 0.0
1531        // residual = [1.0, 0.0] (non-zero) → correction is always token 0.
1532        for seed in 0..1000 {
1533            let mut spec = SpeculativeDecoder::with_seed(1, seed);
1534            let drafted = vec![DraftedToken {
1535                token_id: 1,
1536                draft_prob: 1.0,
1537            }];
1538            let target = vec![vec![1.0, 0.0], vec![0.5, 0.5]];
1539            let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1540            assert_eq!(res.accepted, 0);
1541            assert_eq!(res.tokens[0], 0, "residual concentrates on token 0");
1542        }
1543    }
1544
1545    // 11i-bis. A residual that sums to zero falls back to the target dist.
1546    #[test]
1547    fn test_residual_distribution_zero_fallback() {
1548        // residual_distribution returns the raw target distribution when the
1549        // clamped residual max(0, p_target - p_draft) sums to zero. This
1550        // happens when the draft one-hot mass fully covers the target mass at
1551        // the drafted token and the target has no mass anywhere else.
1552        // Construct that case directly: target one-hot on the drafted token.
1553        let drafted = [DraftedToken {
1554            token_id: 0,
1555            draft_prob: 1.0,
1556        }];
1557        let target_dist = [1.0_f64, 0.0, 0.0];
1558        let residual = SpeculativeDecoder::residual_distribution(&target_dist, &drafted, 0);
1559        // p_target = [1,0,0], p_draft one-hot on token 0 = [1,0,0]:
1560        // residual = [0,0,0] → sum 0 → fallback returns the target dist.
1561        assert_eq!(residual, vec![1.0, 0.0, 0.0]);
1562    }
1563
1564    // 11j. Draft == target accepts every token (statistical).
1565    #[test]
1566    fn test_speculative_draft_equals_target_accepts_all() {
1567        // When the draft and target distributions are identical, the
1568        // acceptance ratio p_target / p_draft is exactly 1.0 for every
1569        // drafted token, so all gamma tokens must always be accepted.
1570        let gamma = 5;
1571        for seed in 0..3000 {
1572            let mut spec = SpeculativeDecoder::with_seed(gamma, seed);
1573            // Identical draft/target distributions for every position.
1574            let dist = vec![0.15, 0.25, 0.20, 0.40];
1575            let draft_probs = vec![dist.clone(); gamma];
1576            let drafted = spec.propose_tokens(&draft_probs).expect("propose");
1577            assert_eq!(drafted.len(), gamma);
1578
1579            // Target: same distribution at every drafted position + bonus row.
1580            let target_dists = vec![dist.clone(); gamma + 1];
1581            let res = spec
1582                .verify_and_accept(&drafted, &target_dists)
1583                .expect("verify");
1584            assert_eq!(res.accepted, gamma, "draft==target must accept all");
1585            assert_eq!(res.rejected, 0);
1586            // Accepted prefix + 1 bonus token.
1587            assert_eq!(res.tokens.len(), gamma + 1);
1588        }
1589    }
1590
1591    // 11k. Accepted-length distribution is sane and accounting is correct.
1592    #[test]
1593    fn test_speculative_accepted_length_distribution() {
1594        let gamma = 4usize;
1595        let mut spec = SpeculativeDecoder::with_seed(gamma, 0xABCD);
1596        let rounds = 5000u64;
1597        let mut sum_accepted = 0u64;
1598        for _ in 0..rounds {
1599            // Draft distribution: token 0 always sampled (one-hot).
1600            let draft_probs = vec![vec![1.0, 0.0]; gamma];
1601            let drafted = spec.propose_tokens(&draft_probs).expect("propose");
1602            // Target: token 0 (drafted) has acceptance ratio 0.7/1.0 = 0.7.
1603            let target_dists = vec![vec![0.7, 0.3]; gamma + 1];
1604            let res = spec
1605                .verify_and_accept(&drafted, &target_dists)
1606                .expect("verify");
1607            assert!(res.accepted <= gamma, "accepted within [0, gamma]");
1608            assert_eq!(res.rejected, usize::from(res.accepted < gamma));
1609            // Emitted tokens = accepted prefix + exactly one extra token.
1610            assert_eq!(res.tokens.len(), res.accepted + 1);
1611            sum_accepted += res.accepted as u64;
1612        }
1613        // total_proposed = rounds * gamma; acceptance rate should track 0.7.
1614        assert_eq!(spec.total_proposed(), rounds * gamma as u64);
1615        assert_eq!(spec.total_accepted(), sum_accepted);
1616        assert_eq!(spec.rounds(), rounds);
1617        let rate = spec.acceptance_rate();
1618        // Per-position acceptance is p = 0.7, but the first rejection truncates
1619        // the round, so the realised rate is well below p. The expected number
1620        // of accepted drafted tokens per gamma=4 round is
1621        //   sum_{j=1}^{3} j*p^j*(1-p) + 4*p^4 = 1.7731,
1622        // giving an expected rate of 1.7731 / 4 ≈ 0.4433.
1623        assert!(
1624            (rate - 0.4433).abs() < 0.02,
1625            "acceptance rate {rate} should be ~0.4433"
1626        );
1627        // Mean tokens per round must lie in [1, gamma + 1].
1628        let mtpr = spec.mean_tokens_per_round();
1629        assert!(mtpr >= 1.0 && mtpr <= (gamma + 1) as f64, "mtpr {mtpr}");
1630    }
1631
1632    // 11l. verify_and_accept errors when target_dists is too short.
1633    #[test]
1634    fn test_speculative_verify_rejects_short_target() {
1635        let mut spec = SpeculativeDecoder::new(2);
1636        let drafted = vec![
1637            DraftedToken {
1638                token_id: 0,
1639                draft_prob: 0.5,
1640            },
1641            DraftedToken {
1642                token_id: 1,
1643                draft_prob: 0.5,
1644            },
1645        ];
1646        // Only 2 rows supplied; need gamma + 1 == 3.
1647        let target = vec![vec![0.5, 0.5], vec![0.5, 0.5]];
1648        assert!(spec.verify_and_accept(&drafted, &target).is_err());
1649    }
1650
1651    // 11m. verify_and_accept errors when a drafted token is out of vocab.
1652    #[test]
1653    fn test_speculative_verify_rejects_token_out_of_range() {
1654        let mut spec = SpeculativeDecoder::new(1);
1655        let drafted = vec![DraftedToken {
1656            token_id: 9, // out of range for a 2-token target distribution
1657            draft_prob: 0.5,
1658        }];
1659        let target = vec![vec![0.5, 0.5], vec![0.5, 0.5]];
1660        assert!(spec.verify_and_accept(&drafted, &target).is_err());
1661    }
1662
1663    // 11n. Empty draft → only a bonus token is emitted.
1664    #[test]
1665    fn test_speculative_empty_draft_emits_bonus() {
1666        let mut spec = SpeculativeDecoder::with_seed(4, 55);
1667        let drafted: Vec<DraftedToken> = Vec::new();
1668        // gamma == 0, so a single bonus row is required.
1669        let target = vec![vec![0.0, 1.0, 0.0]];
1670        let res = spec.verify_and_accept(&drafted, &target).expect("verify");
1671        assert_eq!(res.accepted, 0);
1672        assert_eq!(res.rejected, 0);
1673        assert_eq!(res.tokens, vec![1], "bonus drawn from one-hot target");
1674    }
1675
1676    // 11o. LcgRng produces uniform f64 in [0, 1) and is deterministic.
1677    #[test]
1678    fn test_lcg_rng_uniform_and_deterministic() {
1679        let mut a = LcgRng::new(2024);
1680        let mut b = LcgRng::new(2024);
1681        let mut sum = 0.0_f64;
1682        let n = 100_000;
1683        for _ in 0..n {
1684            let va = a.next_f64();
1685            let vb = b.next_f64();
1686            assert_eq!(va, vb, "same seed must yield same stream");
1687            assert!((0.0..1.0).contains(&va));
1688            sum += va;
1689        }
1690        // Mean of a uniform [0,1) sample should be close to 0.5.
1691        let mean = sum / n as f64;
1692        assert!((mean - 0.5).abs() < 0.01, "uniform mean {mean}");
1693    }
1694
1695    // 12. Batch metrics tracking
1696    #[test]
1697    fn test_batch_metrics_tracking() {
1698        let mut m = BatchMetrics::new();
1699        m.record_step(128, 0, 500);
1700        m.record_step(0, 8, 100);
1701        m.record_step(64, 4, 300);
1702
1703        assert!((m.avg_prefill_latency() - 400.0).abs() < 1e-9);
1704        assert!((m.avg_decode_latency() - 200.0).abs() < 1e-9);
1705        // (128+0+8+64+4) / 3 = 68.0
1706        assert!((m.avg_batch_size() - 68.0).abs() < 1e-9);
1707        assert!(m.token_throughput() > 0.0);
1708    }
1709
1710    // 13. Max batch size enforcement
1711    #[test]
1712    fn test_max_batch_size_enforcement() {
1713        let mut config = default_config();
1714        config.max_batch_size = 2;
1715        let mut batcher = ContinuousBatcher::new(config);
1716
1717        for i in 0..4 {
1718            batcher.add_request(make_request(i, 32, 8)).expect("add");
1719        }
1720        let d = batcher.step().expect("step");
1721        assert!(d.prefill_requests.len() <= 2);
1722        assert_eq!(batcher.active_requests(), d.prefill_requests.len());
1723    }
1724
1725    // 14. Queue management
1726    #[test]
1727    fn test_queue_management() {
1728        let mut batcher = ContinuousBatcher::new(default_config());
1729        assert_eq!(batcher.pending_requests(), 0);
1730
1731        batcher.add_request(make_request(1, 32, 8)).expect("add");
1732        batcher.add_request(make_request(2, 32, 8)).expect("add");
1733        assert_eq!(batcher.pending_requests(), 2);
1734
1735        let _ = batcher.step().expect("step");
1736        assert_eq!(batcher.pending_requests(), 0);
1737        assert_eq!(batcher.active_requests(), 2);
1738
1739        batcher.complete_request(1).expect("complete");
1740        assert_eq!(batcher.active_requests(), 1);
1741    }
1742
1743    // 15. Utilization calculation
1744    #[test]
1745    fn test_utilization_calculation() {
1746        let mut alloc = TokenBudgetAllocator::new(1000);
1747        assert!((alloc.utilization() - 0.0).abs() < 1e-9);
1748
1749        alloc.allocate_prefill(250);
1750        assert!((alloc.utilization() - 0.25).abs() < 1e-9);
1751
1752        let fitted = alloc.allocate_decode(900);
1753        assert_eq!(fitted, 750);
1754        assert!((alloc.utilization() - 1.0).abs() < 1e-9);
1755
1756        // Edge case: zero-capacity allocator.
1757        let zero = TokenBudgetAllocator::new(0);
1758        assert!((zero.utilization() - 0.0).abs() < 1e-9);
1759    }
1760
1761    // 16. Format report
1762    #[test]
1763    fn test_format_report() {
1764        let mut m = BatchMetrics::new();
1765        m.record_step(100, 10, 200);
1766        m.record_step(0, 8, 100);
1767        m.record_ttft(150);
1768        m.record_ttft(250);
1769        let report = m.format_report();
1770        assert!(report.contains("Steps recorded"));
1771        assert!(report.contains("Token throughput"));
1772        assert!(report.contains("TTFT p50"));
1773    }
1774}