// cortex_rust 0.6.0
//
// High-performance LLM inference with 4-bit quantization and Test-Time Training (TTT)
// Documentation
//! Speculative Decoding - Accelerated Token Generation
//!
//! Uses a small draft model to speculatively generate multiple tokens,
//! then verifies them with the target model in a single forward pass.
//! Achieves 2-3x speedup when draft model has high acceptance rate.
//!
//! # Algorithm
//! 1. Draft model generates K tokens speculatively
//! 2. Target model verifies all K tokens in one forward pass
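//! 3. Accept each draft token with probability min(1, p_target / p_draft)
//!    (in greedy mode: only if it equals the target's argmax)
//! 4. Replace the first rejected token with one drawn from the adjusted
//!    distribution max(0, p_target - p_draft) / Z and stop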
//!
//! # References
//! - "Fast Inference from Transformers via Speculative Decoding" (2022)
//! - "Accelerating Large Language Model Decoding with Speculative Sampling" (2023)

use candle_core::{Result, Tensor};
use rand::prelude::*;

/// Tunable parameters for speculative decoding.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// How many tokens are speculated per iteration.
    pub num_speculative_tokens: usize,
    /// Sampling temperature; `0` selects greedy (argmax) sampling.
    pub temperature: f32,
    /// Top-p (nucleus) sampling threshold.
    ///
    /// NOTE: the decoder in this file does not apply top-p filtering yet; the
    /// value is stored but has no effect on the computed probabilities.
    pub top_p: f32,
    /// `true` verifies draft tokens by rejection sampling, `false` by greedy
    /// agreement with the target's argmax.
    pub use_rejection_sampling: bool,
}

impl Default for SpeculativeConfig {
    /// Defaults: 4 speculative tokens, temperature 1.0, top-p 0.9, rejection
    /// sampling enabled.
    fn default() -> Self {
        let (num_speculative_tokens, temperature, top_p) = (4, 1.0, 0.9);
        SpeculativeConfig {
            num_speculative_tokens,
            temperature,
            top_p,
            use_rejection_sampling: true,
        }
    }
}

/// Result of speculative decoding iteration
///
/// Produced by `SpeculativeDecoder::verify`.
#[derive(Debug)]
pub struct SpeculativeResult {
    /// Tokens to append to the sequence: the accepted draft tokens followed by
    /// token(s) chosen from the target model's probabilities. `verify` always
    /// returns at least one token and at most `draft_tokens.len() + 1`.
    pub accepted_tokens: Vec<u32>,
    /// Number of draft tokens that were accepted
    pub num_accepted: usize,
    /// Total tokens generated in this iteration (`accepted_tokens.len()`)
    pub total_generated: usize,
    /// Acceptance rate for this iteration: `num_accepted` divided by the number
    /// of draft tokens, or 0.0 when no draft tokens were supplied
    pub acceptance_rate: f32,
}

/// Speculative decoder state
///
/// Holds the decoding configuration plus running acceptance counters that are
/// accumulated by `verify` and cleared by `reset_stats`.
pub struct SpeculativeDecoder {
    /// Configuration supplied to `new`
    config: SpeculativeConfig,
    /// Running statistics
    /// Draft tokens submitted to `verify` since construction or the last reset
    total_draft_tokens: usize,
    /// Draft tokens counted as accepted since construction or the last reset
    total_accepted_tokens: usize,
}

impl SpeculativeDecoder {
    pub fn new(config: SpeculativeConfig) -> Self {
        Self {
            config,
            total_draft_tokens: 0,
            total_accepted_tokens: 0,
        }
    }

    /// Verify draft tokens against target model logits
    ///
    /// # Arguments
    /// * `draft_tokens` - Tokens generated by draft model
    /// * `draft_probs` - Probabilities from draft model for each token
    /// * `target_logits` - Logits from target model [batch, seq, vocab]
    ///
    /// # Returns
    /// * `SpeculativeResult` - Accepted tokens and statistics
    pub fn verify(
        &mut self,
        draft_tokens: &[u32],
        draft_probs: &[f32],
        target_logits: &Tensor,
    ) -> Result<SpeculativeResult> {
        let num_draft = draft_tokens.len();
        let mut accepted_tokens = Vec::with_capacity(num_draft + 1);
        let mut rng = rand::thread_rng();

        // Get target probabilities
        let target_probs = self.compute_probs(target_logits)?;
        // target_probs shape: [num_draft + 1, vocab_size]

        let _vocab_size = target_probs.dim(1)?;

        for i in 0..num_draft {
            let draft_token = draft_tokens[i] as usize;
            let draft_prob = draft_probs[i];

            // Get target probability for this position
            let target_prob_vec: Vec<f32> = target_probs.get(i)?.to_vec1()?;
            let target_prob = target_prob_vec[draft_token];

            // Rejection sampling criterion
            let accept = if self.config.use_rejection_sampling {
                // Accept with probability min(1, target_prob / draft_prob)
                let acceptance_prob = (target_prob / draft_prob.max(1e-10)).min(1.0);
                rng.gen::<f32>() < acceptance_prob
            } else {
                // Greedy: accept if target agrees
                let target_token = argmax(&target_prob_vec);
                target_token == draft_token
            };

            if accept {
                accepted_tokens.push(draft_tokens[i]);
            } else {
                // Rejection: sample from adjusted distribution
                // P_adjusted = max(0, P_target - P_draft) / Z
                if self.config.use_rejection_sampling {
                    let adjusted_token =
                        self.sample_adjusted(&target_prob_vec, draft_prob, draft_token, &mut rng);
                    accepted_tokens.push(adjusted_token);
                } else {
                    // Greedy: use target's argmax
                    accepted_tokens.push(argmax(&target_prob_vec) as u32);
                }
                break; // Stop after first rejection
            }
        }

        // If all draft tokens accepted, sample one more from target
        if accepted_tokens.len() == num_draft {
            let final_probs: Vec<f32> = target_probs.get(num_draft)?.to_vec1()?;
            let next_token = self.sample_from_probs(&final_probs, &mut rng);
            accepted_tokens.push(next_token);
        }

        // Update statistics
        let num_accepted = accepted_tokens.len().saturating_sub(1).min(num_draft);
        self.total_draft_tokens += num_draft;
        self.total_accepted_tokens += num_accepted;

        let acceptance_rate = if num_draft > 0 {
            num_accepted as f32 / num_draft as f32
        } else {
            0.0
        };

        let total_generated = accepted_tokens.len();
        Ok(SpeculativeResult {
            accepted_tokens,
            num_accepted,
            total_generated,
            acceptance_rate,
        })
    }

    /// Compute probabilities from logits with temperature and top-p
    fn compute_probs(&self, logits: &Tensor) -> Result<Tensor> {
        let logits = if self.config.temperature != 1.0 {
            (logits / self.config.temperature as f64)?
        } else {
            logits.clone()
        };

        // Softmax
        let probs = candle_nn::ops::softmax_last_dim(&logits)?;

        // Top-p filtering would go here (optional)

        Ok(probs)
    }

    /// Sample from adjusted distribution after rejection
    fn sample_adjusted(
        &self,
        target_probs: &[f32],
        draft_prob: f32,
        draft_token: usize,
        rng: &mut impl Rng,
    ) -> u32 {
        // Compute adjusted distribution: max(0, P_target - P_draft)
        let mut adjusted: Vec<f32> = target_probs
            .iter()
            .enumerate()
            .map(|(i, &p)| {
                if i == draft_token {
                    (p - draft_prob).max(0.0)
                } else {
                    p
                }
            })
            .collect();

        // Normalize
        let sum: f32 = adjusted.iter().sum();
        if sum > 1e-10 {
            for p in &mut adjusted {
                *p /= sum;
            }
            self.sample_from_probs(&adjusted, rng)
        } else {
            // Fallback to argmax of target
            argmax(target_probs) as u32
        }
    }

    /// Sample a token from probability distribution
    fn sample_from_probs(&self, probs: &[f32], rng: &mut impl Rng) -> u32 {
        if self.config.temperature == 0.0 {
            return argmax(probs) as u32;
        }

        let r: f32 = rng.gen();
        let mut cumsum = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            cumsum += p;
            if r < cumsum {
                return i as u32;
            }
        }
        (probs.len() - 1) as u32
    }

    /// Get overall acceptance rate
    pub fn overall_acceptance_rate(&self) -> f32 {
        if self.total_draft_tokens > 0 {
            self.total_accepted_tokens as f32 / self.total_draft_tokens as f32
        } else {
            0.0
        }
    }

    /// Reset statistics
    pub fn reset_stats(&mut self) {
        self.total_draft_tokens = 0;
        self.total_accepted_tokens = 0;
    }

    /// Get expected speedup based on acceptance rate
    /// Speedup = (1 + γα) / (1 + γ/β) where:
    /// - γ = num_speculative_tokens
    /// - α = acceptance_rate
    /// - β = target_model_time / draft_model_time (assumed ~10)
    pub fn expected_speedup(&self) -> f32 {
        let gamma = self.config.num_speculative_tokens as f32;
        let alpha = self.overall_acceptance_rate();
        let beta = 10.0; // Typical ratio

        (1.0 + gamma * alpha) / (1.0 + gamma / beta)
    }
}

/// Find argmax of a slice
///
/// Returns the index of the largest non-NaN value. NaN entries are skipped
/// rather than causing a panic. When several entries share the maximum, the
/// last such index is returned. An empty or all-NaN slice yields 0.
fn argmax(slice: &[f32]) -> usize {
    slice
        .iter()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        // NaN has been filtered out, so `partial_cmp` always yields an ordering;
        // the `Equal` fallback only exists to avoid an unwrap.
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map_or(0, |(i, _)| i)
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Vocabulary size used by the verification fixtures.
    const VOCAB: usize = 10;

    /// Decoder using greedy verification so that accept/reject decisions are
    /// deterministic.
    fn greedy_decoder() -> SpeculativeDecoder {
        let config = SpeculativeConfig {
            use_rejection_sampling: false,
            ..Default::default()
        };
        SpeculativeDecoder::new(config)
    }

    /// Builds `[winners.len(), VOCAB]` logits in which row `r` gives token
    /// `winners[r]` a logit of 10.0 and every other token 0.0.
    fn peaked_logits(winners: &[usize]) -> Tensor {
        let rows = winners.len();
        let mut data = vec![0.0f32; rows * VOCAB];
        for (row, &token) in winners.iter().enumerate() {
            data[row * VOCAB + token] = 10.0;
        }
        Tensor::from_vec(data, (rows, VOCAB), &Device::Cpu).unwrap()
    }

    #[test]
    fn test_speculative_config_default() {
        let SpeculativeConfig {
            num_speculative_tokens,
            temperature,
            ..
        } = SpeculativeConfig::default();
        assert_eq!(num_speculative_tokens, 4);
        assert_eq!(temperature, 1.0);
    }

    #[test]
    fn test_argmax() {
        let first_case = [0.1, 0.5, 0.3, 0.1];
        let second_case = [0.9, 0.05, 0.05];
        assert_eq!(argmax(&first_case), 1);
        assert_eq!(argmax(&second_case), 0);
    }

    #[test]
    fn test_speculative_decoder_creation() {
        // A fresh decoder has seen no draft tokens, so the rate is zero.
        let fresh = SpeculativeDecoder::new(SpeculativeConfig::default());
        assert_eq!(fresh.overall_acceptance_rate(), 0.0);
    }

    #[test]
    fn test_verify_all_accepted() {
        let mut decoder = greedy_decoder();

        let draft_tokens = [1u32, 2, 3];
        let draft_probs = [0.9f32, 0.8, 0.7];

        // Rows 0..=2 agree with the draft (tokens 1, 2, 3); row 3 favours
        // token 4 for the extra token after the draft.
        let target_logits = peaked_logits(&[1, 2, 3, 4]);

        let outcome = decoder
            .verify(&draft_tokens, &draft_probs, &target_logits)
            .unwrap();

        assert_eq!(outcome.num_accepted, 3);
        assert_eq!(outcome.accepted_tokens.len(), 4); // 3 accepted + 1 bonus
        assert_eq!(&outcome.accepted_tokens[..3], &[1u32, 2, 3]);
    }

    #[test]
    fn test_verify_early_rejection() {
        let mut decoder = greedy_decoder();

        let draft_tokens = [1u32, 2, 3];
        let draft_probs = [0.9f32, 0.8, 0.7];

        // Row 0 agrees (token 1); row 1 prefers token 5 over the draft's
        // token 2, so rows 2 and 3 are never consulted.
        let target_logits = peaked_logits(&[1, 5, 3, 4]);

        let outcome = decoder
            .verify(&draft_tokens, &draft_probs, &target_logits)
            .unwrap();

        assert_eq!(outcome.num_accepted, 1); // Only first token accepted
        assert_eq!(outcome.accepted_tokens.len(), 2); // 1 accepted + 1 from target
        assert_eq!(outcome.accepted_tokens[0], 1); // First draft token
        assert_eq!(outcome.accepted_tokens[1], 5); // Target's choice
    }

    #[test]
    fn test_expected_speedup() {
        let config = SpeculativeConfig {
            num_speculative_tokens: 4,
            ..Default::default()
        };
        let mut decoder = SpeculativeDecoder::new(config);

        // Simulate 80% acceptance rate
        decoder.total_accepted_tokens = 80;
        decoder.total_draft_tokens = 100;

        // Formula: (1 + 4*0.8) / (1 + 4/10) = 4.2 / 1.4 = 3.0
        let speedup = decoder.expected_speedup();
        let in_expected_band = speedup > 2.5 && speedup < 3.5;
        assert!(in_expected_band, "speedup = {}", speedup);
    }
}