Skip to main content

ferrum_sampler/
lib.rs

//! # Ferrum Sampler
//!
//! MVP sampler implementation for the Ferrum inference stack.
//!
//! This crate provides a thin wrapper around the sampling interfaces defined in
//! `ferrum-interfaces`, offering convenient factory functions and utilities for
//! building sampling pipelines from `SamplingParams`.
//!
//! ## Design
//!
//! - **Re-export Interface Types**: All core types from `ferrum-interfaces::sampler`
//! - **Factory Pattern**: Simple factory for creating samplers and configs
//! - **Zero Overhead**: Direct delegation to interface implementations
//!
//! ## Usage
//!
//! ```no_run
//! use ferrum_sampler::{build_sampling_config, sampler_from_params};
//! use ferrum_types::SamplingParams;
//!
//! let params = SamplingParams::default();
//! let config = build_sampling_config(&params);
//! let sampler = sampler_from_params(&params);
//! ```
25
26pub mod json_mode;
27
28// Re-export all sampler types from ferrum-interfaces
29pub use ferrum_interfaces::sampler::{
30    GreedySampler, LogitsProcessor, LogitsProcessorChain, MultiSampler, MultinomialSampler,
31    ProcessorPriority, RepetitionPenaltyProcessor, Sampler, SamplingConfig, SamplingConfigBuilder,
32    SamplingContext, SamplingStats, TemperatureProcessor, TopKProcessor, TopPProcessor,
33};
34
35// Re-export types from ferrum-types
36pub use ferrum_types::{Result, SamplingParams, TokenId};
37
38use rand::RngCore;
39use std::collections::HashMap;
40
/// Default sampler factory for creating samplers and configurations.
///
/// The factory is a stateless unit struct, so it is free to copy, compare and
/// hash; the common std traits are derived so it can be stored in collections
/// or embedded in types that themselves derive those traits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DefaultSamplerFactory;
44
45impl DefaultSamplerFactory {
46    /// Create new factory instance
47    pub fn new() -> Self {
48        Self
49    }
50
51    /// Build sampling configuration from parameters
52    pub fn build_config(&self, params: &SamplingParams) -> SamplingConfig {
53        SamplingConfig::from_params(params)
54    }
55
56    /// Create sampler instance based on temperature
57    /// - temperature == 0.0 → GreedySampler (deterministic)
58    /// - temperature > 0.0 → MultinomialSampler (stochastic)
59    pub fn create_sampler(&self, params: &SamplingParams) -> Box<dyn Sampler + Send + Sync> {
60        if params.temperature == 0.0 {
61            Box::new(GreedySampler)
62        } else {
63            Box::new(MultinomialSampler)
64        }
65    }
66
67    /// Create sampling pipeline with config and sampler
68    pub fn build_pipeline(&self, params: &SamplingParams) -> SamplingPipeline {
69        let config = self.build_config(params);
70        SamplingPipeline { config }
71    }
72}
73
74/// Sampling pipeline that combines config and execution logic.
75///
76/// This struct holds a `SamplingConfig` and provides a convenient interface
77/// for sampling tokens with context.
78pub struct SamplingPipeline {
79    config: SamplingConfig,
80}
81
82impl SamplingPipeline {
83    /// Create new pipeline from parameters
84    pub fn new(params: &SamplingParams) -> Self {
85        let config = SamplingConfig::from_params(params);
86        Self { config }
87    }
88
89    /// Get reference to sampling config
90    pub fn config(&self) -> &SamplingConfig {
91        &self.config
92    }
93
94    /// Sample next token with full context
95    ///
96    /// # Arguments
97    /// * `step` - Current generation step (0-based)
98    /// * `logits` - Mutable logits array to process
99    /// * `previous_tokens` - Previously generated tokens
100    /// * `token_frequencies` - Token frequency map for penalties
101    /// * `sampling_params` - Sampling parameters for this step
102    /// * `rng` - Random number generator
103    pub fn sample_next(
104        &self,
105        step: usize,
106        logits: &mut [f32],
107        previous_tokens: &[TokenId],
108        token_frequencies: &HashMap<TokenId, usize>,
109        sampling_params: &SamplingParams,
110        rng: &mut dyn RngCore,
111    ) -> Result<TokenId> {
112        let vocab_size = logits.len();
113        let ctx = SamplingContext::new(
114            step,
115            sampling_params,
116            logits,
117            previous_tokens,
118            token_frequencies,
119            vocab_size,
120        );
121        self.config.sample(ctx, rng)
122    }
123
124    /// Simple sampling without context (uses default params)
125    pub fn sample_simple(&self, logits: &mut [f32], rng: &mut dyn RngCore) -> Result<TokenId> {
126        let params = SamplingParams::default();
127        let empty_tokens = Vec::new();
128        let empty_freqs = HashMap::new();
129        self.sample_next(0, logits, &empty_tokens, &empty_freqs, &params, rng)
130    }
131}
132
133// ============================================================================
134// Convenience Functions
135// ============================================================================
136
137/// Build sampling configuration from parameters.
138///
139/// This is the primary entry point for creating a `SamplingConfig`.
140pub fn build_sampling_config(params: &SamplingParams) -> SamplingConfig {
141    SamplingConfig::from_params(params)
142}
143
144/// Create sampler instance from parameters.
145///
146/// Returns a boxed `Sampler` trait object based on the temperature setting.
147pub fn sampler_from_params(params: &SamplingParams) -> Box<dyn Sampler + Send + Sync> {
148    DefaultSamplerFactory::new().create_sampler(params)
149}
150
151/// Build complete sampling pipeline from parameters.
152pub fn pipeline_from_params(params: &SamplingParams) -> SamplingPipeline {
153    DefaultSamplerFactory::new().build_pipeline(params)
154}
155
156/// Create a greedy sampler (always picks highest logit).
157pub fn greedy_sampler() -> Box<dyn Sampler + Send + Sync> {
158    Box::new(GreedySampler)
159}
160
161/// Create a multinomial sampler (probabilistic sampling).
162pub fn multinomial_sampler() -> Box<dyn Sampler + Send + Sync> {
163    Box::new(MultinomialSampler)
164}
165
166// ============================================================================
167// Tests
168// ============================================================================
169
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Fixed-seed RNG so the stochastic tests are reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    // A zero temperature must yield the deterministic sampler.
    #[test]
    fn test_factory_creates_greedy_for_zero_temp() {
        let zero_temp = SamplingParams {
            temperature: 0.0,
            ..Default::default()
        };
        let chosen = DefaultSamplerFactory::new().create_sampler(&zero_temp);
        assert!(chosen.is_deterministic());
    }

    // A non-zero temperature must yield the stochastic sampler.
    #[test]
    fn test_factory_creates_multinomial_for_nonzero_temp() {
        let unit_temp = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let chosen = DefaultSamplerFactory::new().create_sampler(&unit_temp);
        assert!(!chosen.is_deterministic());
    }

    // Expected processors: temperature, top_k, top_p, repetition_penalty.
    #[test]
    fn test_build_sampling_config() {
        let config = build_sampling_config(&SamplingParams {
            temperature: 0.8,
            top_k: Some(50),
            top_p: 0.95,
            repetition_penalty: 1.1,
            ..Default::default()
        });
        let processor_count = config.processor_chain.processor_names().len();
        assert_eq!(processor_count, 4);
    }

    // Greedy pipeline picks the index holding the largest logit (index 1).
    #[test]
    fn test_pipeline_sample_simple() {
        let pipeline = pipeline_from_params(&SamplingParams::greedy());
        let mut rng = seeded_rng();
        let mut logits = [1.0_f32, 5.0, 2.0, 0.5];

        let picked = pipeline.sample_simple(&mut logits, &mut rng).unwrap();

        assert_eq!(picked.get(), 1);
    }

    #[test]
    fn test_greedy_sampler_deterministic() {
        let greedy = greedy_sampler();
        assert!(greedy.is_deterministic());
        assert_eq!(greedy.name(), "greedy");
    }

    #[test]
    fn test_multinomial_sampler_stochastic() {
        let multinomial = multinomial_sampler();
        assert!(!multinomial.is_deterministic());
        assert_eq!(multinomial.name(), "multinomial");
    }

    // Sampling with history and a repetition penalty still returns an
    // in-vocabulary token.
    #[test]
    fn test_pipeline_with_context() {
        let params = SamplingParams {
            temperature: 1.0,
            repetition_penalty: 1.2,
            ..Default::default()
        };
        let pipeline = SamplingPipeline::new(&params);
        let mut rng = seeded_rng();

        let mut logits = [1.0_f32, 2.0, 3.0, 2.0];
        // Token 2 was generated before, exactly once.
        let history = [TokenId::new(2)];
        let freqs: HashMap<TokenId, usize> = std::iter::once((TokenId::new(2), 1)).collect();

        let sampled = pipeline.sample_next(0, &mut logits, &history, &freqs, &params, &mut rng);
        let token = sampled.unwrap();

        assert!(token.get() < 4);
    }
}