Skip to main content

ferrum_interfaces/
sampler.rs

1//! Sampling and logits processing interfaces
2//!
3//! This module provides abstractions for sampling tokens from model outputs,
4//! including various sampling strategies and logits processors. These are
5//! completely separate from model execution to allow for flexible composition.
6
7use ferrum_types::{Result, SamplingParams, TokenId};
8use rand::RngCore;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Sampling context passed to logits processors and samplers
13#[derive(Debug)]
14pub struct SamplingContext<'a> {
15    /// Current generation step (0-based)
16    pub step: usize,
17    /// Request-specific sampling parameters
18    pub sampling_params: &'a SamplingParams,
19    /// Current logits (mutable for processing)
20    pub logits: &'a mut [f32],
21    /// Previous token IDs in sequence  
22    pub previous_tokens: &'a [TokenId],
23    /// Token frequencies for repetition penalty
24    pub token_frequencies: &'a HashMap<TokenId, usize>,
25    /// Vocabulary size
26    pub vocab_size: usize,
27    /// Additional metadata
28    pub metadata: HashMap<String, f32>,
29}
30
31impl<'a> SamplingContext<'a> {
32    /// Create new sampling context
33    pub fn new(
34        step: usize,
35        sampling_params: &'a SamplingParams,
36        logits: &'a mut [f32],
37        previous_tokens: &'a [TokenId],
38        token_frequencies: &'a HashMap<TokenId, usize>,
39        vocab_size: usize,
40    ) -> Self {
41        Self {
42            step,
43            sampling_params,
44            logits,
45            previous_tokens,
46            token_frequencies,
47            vocab_size,
48            metadata: HashMap::new(),
49        }
50    }
51
52    /// Get logit value for specific token
53    pub fn get_logit(&self, token_id: TokenId) -> Option<f32> {
54        if usize::from(token_id) < self.logits.len() {
55            Some(self.logits[usize::from(token_id)])
56        } else {
57            None
58        }
59    }
60
61    /// Set logit value for specific token
62    pub fn set_logit(&mut self, token_id: TokenId, value: f32) -> bool {
63        if usize::from(token_id) < self.logits.len() {
64            self.logits[usize::from(token_id)] = value;
65            true
66        } else {
67            false
68        }
69    }
70
71    /// Mask (set to negative infinity) specific tokens
72    pub fn mask_tokens(&mut self, token_ids: &[TokenId]) {
73        for &token_id in token_ids {
74            if usize::from(token_id) < self.logits.len() {
75                self.logits[usize::from(token_id)] = f32::NEG_INFINITY;
76            }
77        }
78    }
79}
80
81/// Logits processor trait for modifying raw model outputs
82pub trait LogitsProcessor: Send + Sync {
83    /// Process logits in-place
84    fn process(&self, ctx: &mut SamplingContext) -> Result<()>;
85
86    /// Get processor name for debugging/logging
87    fn name(&self) -> &str;
88
89    /// Whether this processor should be applied before others
90    fn priority(&self) -> ProcessorPriority {
91        ProcessorPriority::Normal
92    }
93}
94
/// Relative ordering of logits processors; larger values run earlier.
///
/// The derived `Ord` compares the explicit discriminants, so
/// `High > Normal > Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ProcessorPriority {
    /// Runs before everything else (e.g. the repetition penalty, hard masks).
    High = 3,
    /// Default slot for processors with no ordering requirement.
    Normal = 2,
    /// Runs after the other levels (e.g. temperature scaling).
    Low = 1,
}
105
106/// Token sampler trait for selecting next token from processed logits
107pub trait Sampler: Send + Sync {
108    /// Sample next token from logits
109    fn sample(&self, logits: &[f32], rng: &mut dyn RngCore) -> Result<TokenId>;
110
111    /// Sample with additional context (default implementation ignores context)
112    fn sample_with_context(&self, ctx: &SamplingContext, rng: &mut dyn RngCore) -> Result<TokenId> {
113        self.sample(ctx.logits, rng)
114    }
115
116    /// Get sampler name
117    fn name(&self) -> &str;
118
119    /// Whether this sampler is deterministic
120    fn is_deterministic(&self) -> bool;
121}
122
123/// Multi-sample capability for beam search and parallel sampling
124pub trait MultiSampler: Sampler {
125    /// Sample multiple tokens at once
126    fn sample_multiple(
127        &self,
128        logits: &[f32],
129        num_samples: usize,
130        rng: &mut dyn RngCore,
131    ) -> Result<Vec<TokenId>>;
132
133    /// Sample with probabilities for each token
134    fn sample_with_probabilities(
135        &self,
136        logits: &[f32],
137        rng: &mut dyn RngCore,
138    ) -> Result<(TokenId, Vec<f32>)>;
139}
140
141/// Logits processor chain for composing multiple processors
142pub struct LogitsProcessorChain {
143    processors: Vec<Box<dyn LogitsProcessor>>,
144}
145
146impl LogitsProcessorChain {
147    /// Create new processor chain
148    pub fn new() -> Self {
149        Self {
150            processors: Vec::new(),
151        }
152    }
153
154    /// Add processor to chain
155    pub fn add_processor(mut self, processor: Box<dyn LogitsProcessor>) -> Self {
156        self.processors.push(processor);
157        // Sort by priority (high to low)
158        self.processors
159            .sort_by(|a, b| b.priority().cmp(&a.priority()));
160        self
161    }
162
163    /// Process logits through entire chain
164    pub fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
165        for processor in &self.processors {
166            processor.process(ctx)?;
167        }
168        Ok(())
169    }
170
171    /// Get all processor names in order
172    pub fn processor_names(&self) -> Vec<&str> {
173        self.processors.iter().map(|p| p.name()).collect()
174    }
175}
176
177impl Default for LogitsProcessorChain {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
// ---------------------------------------------------------------------------
// Common logits processors
// ---------------------------------------------------------------------------
// (Plain comment on purpose: a `///` here would be attached to the next item
// and become `TemperatureProcessor`'s rustdoc summary.)

/// Temperature scaling: divides every logit by `temperature`.
///
/// The processor treats non-positive values and exactly `1.0` as no-ops.
pub struct TemperatureProcessor {
    /// Divisor applied to each logit.
    pub temperature: f32,
}

impl TemperatureProcessor {
    /// Create a temperature processor with the given `temperature`.
    pub fn new(temperature: f32) -> Self {
        Self { temperature }
    }
}
195
196impl LogitsProcessor for TemperatureProcessor {
197    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
198        if self.temperature > 0.0 && self.temperature != 1.0 {
199            for logit in ctx.logits.iter_mut() {
200                *logit /= self.temperature;
201            }
202        }
203        Ok(())
204    }
205
206    fn name(&self) -> &str {
207        "temperature"
208    }
209
210    fn priority(&self) -> ProcessorPriority {
211        ProcessorPriority::Low // Apply temperature scaling last
212    }
213}
214
/// Top-k filter: keeps the `k` highest logits and masks the rest.
pub struct TopKProcessor {
    /// Number of candidates to keep; the processor treats `0` as "disabled".
    pub k: usize,
}

impl TopKProcessor {
    /// Build a top-k filter that retains `k` candidates.
    pub fn new(k: usize) -> Self {
        TopKProcessor { k }
    }
}
225
226impl LogitsProcessor for TopKProcessor {
227    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
228        if self.k > 0 && self.k < ctx.logits.len() {
229            // Find k-th largest logit
230            let mut indices: Vec<usize> = (0..ctx.logits.len()).collect();
231            indices.sort_by(|&a, &b| {
232                ctx.logits[b]
233                    .partial_cmp(&ctx.logits[a])
234                    .unwrap_or(std::cmp::Ordering::Equal)
235            });
236
237            let threshold = ctx.logits[indices[self.k - 1]];
238
239            // Mask tokens below threshold
240            for logit in ctx.logits.iter_mut() {
241                if *logit < threshold {
242                    *logit = f32::NEG_INFINITY;
243                }
244            }
245        }
246        Ok(())
247    }
248
249    fn name(&self) -> &str {
250        "top_k"
251    }
252}
253
/// Top-p (nucleus) filter: keeps the most probable tokens until their
/// cumulative probability exceeds `p`.
pub struct TopPProcessor {
    /// Probability mass to retain; the processor ignores values outside `(0, 1)`.
    pub p: f32,
}

impl TopPProcessor {
    /// Build a nucleus filter with mass `p`.
    pub fn new(p: f32) -> Self {
        TopPProcessor { p }
    }
}
264
265impl LogitsProcessor for TopPProcessor {
266    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
267        if self.p < 1.0 && self.p > 0.0 {
268            // Convert logits to probabilities
269            let max_logit = ctx.logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270            let mut probs: Vec<f32> = ctx
271                .logits
272                .iter()
273                .map(|&logit| (logit - max_logit).exp())
274                .collect();
275
276            let sum: f32 = probs.iter().sum();
277            for prob in probs.iter_mut() {
278                *prob /= sum;
279            }
280
281            // Sort by probability
282            let mut indices: Vec<usize> = (0..probs.len()).collect();
283            indices.sort_by(|&a, &b| {
284                probs[b]
285                    .partial_cmp(&probs[a])
286                    .unwrap_or(std::cmp::Ordering::Equal)
287            });
288
289            // Find cumulative probability threshold
290            let mut cum_prob = 0.0;
291            let mut cutoff_idx = probs.len();
292
293            for (i, &idx) in indices.iter().enumerate() {
294                cum_prob += probs[idx];
295                if cum_prob > self.p {
296                    cutoff_idx = i + 1;
297                    break;
298                }
299            }
300
301            // Mask tokens beyond cutoff
302            for (i, &idx) in indices.iter().enumerate() {
303                if i >= cutoff_idx {
304                    ctx.logits[idx] = f32::NEG_INFINITY;
305                }
306            }
307        }
308        Ok(())
309    }
310
311    fn name(&self) -> &str {
312        "top_p"
313    }
314}
315
/// Repetition penalty: discourages tokens that already appeared in the sequence.
pub struct RepetitionPenaltyProcessor {
    /// Penalty base; the processor treats exactly `1.0` as "disabled".
    pub penalty: f32,
}

impl RepetitionPenaltyProcessor {
    /// Build a repetition-penalty processor with base `penalty`.
    pub fn new(penalty: f32) -> Self {
        RepetitionPenaltyProcessor { penalty }
    }
}
326
327impl LogitsProcessor for RepetitionPenaltyProcessor {
328    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
329        if self.penalty != 1.0 {
330            for &token_id in ctx.previous_tokens {
331                if let Some(freq) = ctx.token_frequencies.get(&token_id) {
332                    if usize::from(token_id) < ctx.logits.len() {
333                        let idx = usize::from(token_id);
334                        let current_logit = ctx.logits[idx];
335                        let penalty_factor = self.penalty.powi(*freq as i32);
336
337                        if current_logit > 0.0 {
338                            ctx.logits[idx] = current_logit / penalty_factor;
339                        } else {
340                            ctx.logits[idx] = current_logit * penalty_factor;
341                        }
342                    }
343                }
344            }
345        }
346        Ok(())
347    }
348
349    fn name(&self) -> &str {
350        "repetition_penalty"
351    }
352
353    fn priority(&self) -> ProcessorPriority {
354        ProcessorPriority::High // Apply penalties early
355    }
356}
357
// ---------------------------------------------------------------------------
// Common samplers
// ---------------------------------------------------------------------------
// (Plain comment on purpose: a `///` here would be attached to `GreedySampler`.)

/// Greedy sampler: always returns the index of the highest logit.
pub struct GreedySampler;
362
363impl Sampler for GreedySampler {
364    fn sample(&self, logits: &[f32], _rng: &mut dyn RngCore) -> Result<TokenId> {
365        let max_idx = logits
366            .iter()
367            .enumerate()
368            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
369            .map(|(idx, _)| idx)
370            .ok_or_else(|| ferrum_types::FerrumError::backend("Empty logits for sampling"))?;
371
372        Ok(TokenId::new(max_idx as u32))
373    }
374
375    fn name(&self) -> &str {
376        "greedy"
377    }
378
379    fn is_deterministic(&self) -> bool {
380        true
381    }
382}
383
/// Stochastic sampler: draws a token from the softmax distribution of the logits.
pub struct MultinomialSampler;
386
387impl Sampler for MultinomialSampler {
388    fn sample(&self, logits: &[f32], rng: &mut dyn RngCore) -> Result<TokenId> {
389        // Convert logits to probabilities
390        let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
391
392        let mut probs: Vec<f32> = logits
393            .iter()
394            .map(|&logit| {
395                if logit.is_finite() && logit > f32::NEG_INFINITY {
396                    (logit - max_logit).exp()
397                } else {
398                    0.0
399                }
400            })
401            .collect();
402
403        let sum: f32 = probs.iter().sum();
404        if sum <= 0.0 {
405            return Err(ferrum_types::FerrumError::backend(
406                "No valid tokens for sampling",
407            ));
408        }
409
410        for prob in probs.iter_mut() {
411            *prob /= sum;
412        }
413
414        // Sample from categorical distribution
415        let threshold = rng.next_u32() as f32 / u32::MAX as f32;
416        let mut cumulative = 0.0;
417
418        for (idx, prob) in probs.iter().enumerate() {
419            cumulative += prob;
420            if cumulative >= threshold {
421                return Ok(TokenId::new(idx as u32));
422            }
423        }
424
425        // Fallback to last token (shouldn't happen with proper normalization)
426        Ok(TokenId::new((probs.len() - 1) as u32))
427    }
428
429    fn name(&self) -> &str {
430        "multinomial"
431    }
432
433    fn is_deterministic(&self) -> bool {
434        false
435    }
436}
437
438/// Sampling configuration builder
439pub struct SamplingConfigBuilder {
440    processors: Vec<Box<dyn LogitsProcessor>>,
441    sampler: Option<Box<dyn Sampler>>,
442}
443
444impl SamplingConfigBuilder {
445    /// Create new builder
446    pub fn new() -> Self {
447        Self {
448            processors: Vec::new(),
449            sampler: None,
450        }
451    }
452
453    /// Add temperature scaling
454    pub fn with_temperature(mut self, temperature: f32) -> Self {
455        if temperature > 0.0 && temperature != 1.0 {
456            self.processors
457                .push(Box::new(TemperatureProcessor::new(temperature)));
458        }
459        self
460    }
461
462    /// Add top-k filtering
463    pub fn with_top_k(mut self, k: usize) -> Self {
464        if k > 0 {
465            self.processors.push(Box::new(TopKProcessor::new(k)));
466        }
467        self
468    }
469
470    /// Add top-p filtering
471    pub fn with_top_p(mut self, p: f32) -> Self {
472        if p > 0.0 && p < 1.0 {
473            self.processors.push(Box::new(TopPProcessor::new(p)));
474        }
475        self
476    }
477
478    /// Add repetition penalty
479    pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
480        if penalty != 1.0 {
481            self.processors
482                .push(Box::new(RepetitionPenaltyProcessor::new(penalty)));
483        }
484        self
485    }
486
487    /// Set sampler (greedy vs multinomial)
488    pub fn with_sampler(mut self, sampler: Box<dyn Sampler>) -> Self {
489        self.sampler = Some(sampler);
490        self
491    }
492
493    /// Build sampling configuration
494    pub fn build(self) -> SamplingConfig {
495        let mut chain = LogitsProcessorChain::new();
496        for processor in self.processors {
497            chain = chain.add_processor(processor);
498        }
499
500        let sampler = self.sampler.unwrap_or_else(|| Box::new(MultinomialSampler));
501
502        SamplingConfig {
503            processor_chain: chain,
504            sampler,
505        }
506    }
507}
508
509impl Default for SamplingConfigBuilder {
510    fn default() -> Self {
511        Self::new()
512    }
513}
514
515/// Complete sampling configuration
516pub struct SamplingConfig {
517    pub processor_chain: LogitsProcessorChain,
518    pub sampler: Box<dyn Sampler>,
519}
520
521impl SamplingConfig {
522    /// Create from sampling parameters
523    pub fn from_params(params: &SamplingParams) -> Self {
524        let mut builder = SamplingConfigBuilder::new()
525            .with_temperature(params.temperature)
526            .with_repetition_penalty(params.repetition_penalty);
527
528        if let Some(top_k) = params.top_k {
529            builder = builder.with_top_k(top_k);
530        }
531
532        if params.top_p < 1.0 {
533            builder = builder.with_top_p(params.top_p);
534        }
535
536        // Choose sampler based on temperature
537        let sampler: Box<dyn Sampler> = if params.temperature == 0.0 {
538            Box::new(GreedySampler)
539        } else {
540            Box::new(MultinomialSampler)
541        };
542
543        builder.with_sampler(sampler).build()
544    }
545
546    /// Process logits and sample token
547    pub fn sample(&self, mut ctx: SamplingContext, rng: &mut dyn RngCore) -> Result<TokenId> {
548        // Apply all logits processors
549        self.processor_chain.process(&mut ctx)?;
550
551        // Sample token
552        self.sampler.sample_with_context(&ctx, rng)
553    }
554}
555
556/// Sampling statistics for monitoring
557#[derive(Debug, Clone, Serialize, Deserialize)]
558pub struct SamplingStats {
559    /// Total sampling operations
560    pub total_samples: u64,
561    /// Average sampling time in microseconds
562    pub avg_sample_time_us: f64,
563    /// Distribution of sampled tokens
564    pub token_distribution: HashMap<TokenId, u64>,
565    /// Effective temperature (entropy-based measure)
566    pub effective_temperature: f32,
567    /// Processor execution times
568    pub processor_times: HashMap<String, f64>,
569}