Skip to main content

ferrum_testkit/
sampler.rs

1//! Mock sampler: greedy argmax over logits.
2
3use ferrum_interfaces::Sampler;
4use ferrum_types::{Result, TokenId};
5use rand::RngCore;
6
/// Greedy mock sampler: always picks the token with the highest logit.
///
/// Deterministic — there is no temperature, top-k, or other sampling knob, and
/// the RNG handed to `sample` is never read, so the same logits always yield
/// the same token.
#[derive(Debug, Clone, Copy, Default)]
pub struct MockSampler;
10
11impl Sampler for MockSampler {
12    fn sample(&self, logits: &[f32], _rng: &mut dyn RngCore) -> Result<TokenId> {
13        let (idx, _) = logits
14            .iter()
15            .enumerate()
16            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
17            .ok_or_else(|| ferrum_types::FerrumError::internal("Empty logits"))?;
18        Ok(TokenId::new(idx as u32))
19    }
20
21    fn name(&self) -> &str {
22        "mock-greedy"
23    }
24
25    fn is_deterministic(&self) -> bool {
26        true
27    }
28}