boostr 0.1.0

ML framework built on numr - attention, quantization, model architectures
Documentation
use super::*;
use numr::runtime::cpu::CpuRuntime;

/// Minimal model stub used by the tests in this file.
///
/// It carries no real weights: it only stores the values it should report
/// back (vocabulary size, name, and an optional canned logits vector).
struct MockModel {
    /// Vocabulary size reported by the mock.
    vocab_size: usize,
    /// Human-readable label for the mock (e.g. "draft" or "target").
    name: String,
    /// Canned logits to hand back; `None` means no canned output is set.
    fixed_logits: Option<Vec<f32>>,
}

impl MockModel {
    /// Creates a mock with the given vocabulary size and name and no
    /// canned logits.
    fn new(vocab_size: usize, name: &str) -> Self {
        MockModel {
            fixed_logits: None,
            name: name.to_owned(),
            vocab_size,
        }
    }

    /// Builder-style setter: consumes the mock and returns it with `logits`
    /// installed as the canned output. All other fields are carried over
    /// unchanged.
    fn with_fixed_logits(self, logits: Vec<f32>) -> Self {
        MockModel {
            fixed_logits: Some(logits),
            ..self
        }
    }
}

/// `SpeculativeModel` implementation for the mock on the CPU runtime.
impl SpeculativeModel<CpuRuntime> for MockModel {
    /// Returns a copy of the canned logits when they are set, otherwise a
    /// vector of `vocab_size` zeros.
    ///
    /// Both arguments are ignored, so the output is the same for every call.
    fn forward(&mut self, _input_tokens: &[TokenId], _position: usize) -> Result<Vec<f32>> {
        let logits = match self.fixed_logits.as_ref() {
            Some(fixed) => fixed.clone(),
            None => vec![0.0; self.vocab_size],
        };
        Ok(logits)
    }

    /// Reports the vocabulary size the mock was constructed with.
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// No-op: the mock keeps no per-sequence state, so this always succeeds.
    fn reset(&mut self) -> Result<()> {
        Ok(())
    }

    /// Borrows the name the mock was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}

/// Pins two defaults of `SpeculativeConfig`: a speculation depth of 4 and a
/// draft temperature of roughly 0.7.
#[test]
fn test_speculative_config_default() {
    let defaults = SpeculativeConfig::default();

    assert_eq!(defaults.num_speculative_tokens, 4);

    // Floating-point field: compare within a tolerance, not for equality.
    let temperature_error = (defaults.draft_temperature - 0.7).abs();
    assert!(temperature_error < 1e-6);
}

/// Checks the derived metrics of `SpeculativeStats` on hand-picked counters.
#[test]
fn test_speculative_stats() {
    let sample = SpeculativeStats {
        accepted_tokens: 80,
        rejected_tokens: 20,
        total_tokens: 100,
        target_forward_passes: 25,
        draft_forward_passes: 100,
        ..SpeculativeStats::default()
    };

    // 80 accepted vs. 20 rejected: expected acceptance rate is 0.8.
    let rate_error = (sample.acceptance_rate() - 0.8).abs();
    assert!(rate_error < 1e-6);

    // Expected 4.0 — presumably total_tokens / target_forward_passes
    // (100 / 25); confirm against the SpeculativeStats implementation.
    let throughput_error = (sample.tokens_per_forward() - 4.0).abs();
    assert!(throughput_error < 1e-6);
}

/// A freshly built executor should start at the configured depth (4 with the
/// default config) and with no tokens counted in its stats.
#[test]
fn test_speculative_executor_creation() {
    let executor = SpeculativeExecutor::<CpuRuntime, _, _>::new(
        MockModel::new(100, "draft"),
        MockModel::new(100, "target"),
        SpeculativeConfig::default(),
    );

    assert_eq!(executor.current_depth(), 4);
    assert_eq!(executor.stats().total_tokens, 0);
}

/// Smoke test for `generate`: with draft and target returning identical
/// logits, generation must produce between 1 and 5 tokens and update the
/// statistics.
#[test]
fn test_speculative_executor_generate() {
    // 100 logits, all -10.0 except token 5 (highest, 0.0) and token 2
    // (lowest, -100.0).
    let logits: Vec<f32> = (0..100)
        .map(|token| match token {
            5 => 0.0,
            2 => -100.0,
            _ => -10.0,
        })
        .collect();

    let config = SpeculativeConfig {
        num_speculative_tokens: 2,
        seed: Some(12345),
        ..SpeculativeConfig::default()
    };

    // Both models share the same canned distribution.
    let mut executor = SpeculativeExecutor::<CpuRuntime, _, _>::new(
        MockModel::new(100, "draft").with_fixed_logits(logits.clone()),
        MockModel::new(100, "target").with_fixed_logits(logits),
        config,
    );

    let prompt = vec![1, 2, 3];
    let generated = executor.generate(&prompt, 5).unwrap();

    assert!(!generated.is_empty());
    assert!(generated.len() <= 5);

    let stats = executor.stats();
    assert!(stats.total_tokens > 0);
    assert!(stats.target_forward_passes > 0);
}

/// Constructs a `VerificationResult` by hand and reads its fields back.
#[test]
fn test_verification_result() {
    let verification = VerificationResult {
        num_accepted: 3,
        accepted_tokens: vec![10, 20, 30],
        bonus_token: Some(40),
        // NOTE(review): -1 presumably means "nothing was rejected" — confirm
        // against the VerificationResult definition.
        first_rejection_pos: -1,
    };

    assert_eq!(verification.num_accepted, 3);
    assert_eq!(verification.accepted_tokens.len(), 3);
    assert_eq!(verification.bonus_token, Some(40));
}

/// At temperature 1.0 the returned log-probabilities must exponentiate to a
/// distribution summing to 1 and must keep the ordering of the input logits.
#[test]
fn test_apply_temperature_and_softmax() {
    let executor = SpeculativeExecutor::<CpuRuntime, _, _>::new(
        MockModel::new(4, "draft"),
        MockModel::new(4, "target"),
        SpeculativeConfig::default(),
    );

    let logits = vec![1.0, 2.0, 3.0, 4.0];
    let log_probs = executor.apply_temperature_and_softmax(&logits, 1.0);

    // exp(log p) summed over the vocabulary should be 1 within tolerance.
    let total_probability = log_probs.iter().map(|lp| lp.exp()).sum::<f32>();
    assert!((total_probability - 1.0).abs() < 1e-5);

    // Input logits are strictly increasing, so the outputs must be too.
    assert!(log_probs[2] < log_probs[3]);
    assert!(log_probs[1] < log_probs[2]);
    assert!(log_probs[0] < log_probs[1]);
}

/// With adaptive depth enabled, `adjust_depth` is expected to lower the depth
/// after poor acceptance and raise it after good acceptance.
#[test]
fn test_adaptive_depth() {
    let adaptive_config = SpeculativeConfig {
        num_speculative_tokens: 4,
        adaptive_depth: true,
        min_acceptance_rate: 0.3,
        ..SpeculativeConfig::default()
    };

    let mut executor = SpeculativeExecutor::<CpuRuntime, _, _>::new(
        MockModel::new(100, "draft"),
        MockModel::new(100, "target"),
        adaptive_config,
    );

    // Poor acceptance (10 accepted vs. 90 rejected): depth must drop below
    // the configured 4.
    executor.stats.accepted_tokens = 10;
    executor.stats.rejected_tokens = 90;
    executor.adjust_depth();
    assert!(executor.current_depth() < 4);

    // Good acceptance (90 accepted vs. 10 rejected), starting from a depth
    // forced to 2: depth must grow.
    executor.stats.accepted_tokens = 90;
    executor.stats.rejected_tokens = 10;
    executor.current_depth = 2;
    executor.adjust_depth();
    assert!(executor.current_depth() > 2);
}