ds_r1_rs/inference/sampling.rs

//! # Sampling Strategies
//!
//! Different sampling methods for text generation.
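//!
//! ## Example
//!
//! A minimal usage sketch. The `ds_r1_rs::inference::sampling` path is assumed
//! from this file's location; adjust the import to the crate's actual module tree
//! (the doc-test is marked `ignore` for that reason).
//!
//! ```ignore
//! use ds_r1_rs::inference::sampling::{Sampler, SamplingConfig};
//!
//! let sampler = Sampler::new(SamplingConfig::default());
//! let logits = vec![0.1_f32, 0.8, 0.3, 0.2];
//!
//! // Greedy decoding returns the index of the largest logit (1 here).
//! assert_eq!(sampler.sample_greedy(&logits).unwrap(), 1);
//!
//! // Temperature sampling draws from the softmax distribution instead.
//! let token = sampler.sample_temperature(&logits).unwrap();
//! assert!((token as usize) < logits.len());
//! ```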

use crate::utils::error::{ModelError, Result};
use rand::Rng;

/// Sampling configuration
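///
/// # Example
///
/// A sketch of a hypothetical low-temperature, top-k configuration
/// (ignored doc-test; the field values are illustrative only):
///
/// ```ignore
/// let config = SamplingConfig {
///     temperature: 0.7,
///     top_k: Some(50),
///     top_p: None,
///     repetition_penalty: 1.0,
/// };
/// ```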
#[derive(Debug, Clone)]
pub struct SamplingConfig {
    pub temperature: f32,
    pub top_k: Option<usize>,
    pub top_p: Option<f32>,
    pub repetition_penalty: f32,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            temperature: 1.0,
            top_k: None,
            top_p: None,
            repetition_penalty: 1.0,
        }
    }
}

/// Text sampler for generation
pub struct Sampler {
    config: SamplingConfig,
}

impl Sampler {
    /// Create a new sampler
    pub fn new(config: SamplingConfig) -> Self {
        Self { config }
    }

    /// Sample next token using greedy decoding
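    ///
    /// Returns the index of the largest logit (argmax); no randomness is involved.
    ///
    /// # Example
    ///
    /// A small sketch (ignored doc-test; assumes `Sampler` and `SamplingConfig`
    /// are imported as in the module-level example):
    ///
    /// ```ignore
    /// let sampler = Sampler::new(SamplingConfig::default());
    /// // Index 1 holds the largest logit, so greedy decoding returns 1.
    /// assert_eq!(sampler.sample_greedy(&[0.1, 0.8, 0.3, 0.2]).unwrap(), 1);
    /// ```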
    pub fn sample_greedy(&self, logits: &[f32]) -> Result<u32> {
        if logits.is_empty() {
            return Err(ModelError::Forward("Empty logits for sampling".to_string()));
        }

        // Find the token with highest probability
        let max_idx = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .ok_or_else(|| ModelError::Forward("Failed to find max logit".to_string()))?;

        Ok(max_idx as u32)
    }

    /// Sample next token using temperature sampling
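    ///
    /// Logits are divided by `temperature` and passed through softmax, and a token
    /// is drawn from the resulting distribution. A non-positive temperature falls
    /// back to greedy decoding.
    ///
    /// # Example
    ///
    /// A small sketch (ignored doc-test; assumes the imports from the module-level
    /// example):
    ///
    /// ```ignore
    /// let config = SamplingConfig { temperature: 0.0, ..Default::default() };
    /// let sampler = Sampler::new(config);
    /// // Temperature 0.0 is deterministic: it behaves like greedy decoding.
    /// assert_eq!(sampler.sample_temperature(&[0.1, 0.8, 0.3, 0.2]).unwrap(), 1);
    /// ```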
    pub fn sample_temperature(&self, logits: &[f32]) -> Result<u32> {
        if logits.is_empty() {
            return Err(ModelError::Forward("Empty logits for sampling".to_string()));
        }

        let mut rng = rand::rng();

        // Apply temperature scaling
        let scaled_logits: Vec<f32> = if self.config.temperature > 0.0 {
            logits
                .iter()
                .map(|&x| x / self.config.temperature)
                .collect()
        } else {
            // Temperature 0 means greedy
            return self.sample_greedy(logits);
        };

        // Convert to probabilities using softmax
        let probs = self.softmax(&scaled_logits)?;

        // Sample from the distribution
        let sample: f32 = rng.random::<f32>();
        let mut cumulative = 0.0;

        for (idx, &prob) in probs.iter().enumerate() {
            cumulative += prob;
            if sample <= cumulative {
                return Ok(idx as u32);
            }
        }
        // Fallback if floating-point rounding keeps `cumulative` just below `sample`:
        // pick the most probable token rather than the last index, which can have
        // zero probability when the logits were filtered (e.g. by top-k).
        self.sample_greedy(&probs)
    }

    /// Sample next token using top-k sampling
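    ///
    /// Keeps only the `top_k` highest logits (the rest are masked to
    /// `f32::NEG_INFINITY`) and applies temperature sampling over the survivors.
    /// If `top_k` is `None`, all logits are kept.
    ///
    /// # Example
    ///
    /// A small sketch (ignored doc-test; assumes the imports from the module-level
    /// example):
    ///
    /// ```ignore
    /// let config = SamplingConfig {
    ///     top_k: Some(2),
    ///     temperature: 0.0, // deterministic for the example
    ///     ..Default::default()
    /// };
    /// let sampler = Sampler::new(config);
    /// // Only the two largest logits (0.8 and 0.3) survive; greedy picks 0.8 at index 1.
    /// assert_eq!(sampler.sample_top_k(&[0.1, 0.8, 0.3, 0.2]).unwrap(), 1);
    /// ```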
    pub fn sample_top_k(&self, logits: &[f32]) -> Result<u32> {
        if logits.is_empty() {
            return Err(ModelError::Forward("Empty logits for sampling".to_string()));
        }

        let k = self.config.top_k.unwrap_or(logits.len());
        if k == 0 {
            return Err(ModelError::Forward(
                "Top-k must be greater than 0".to_string(),
            ));
        }

        // Get top-k indices and their logits
        let mut indexed_logits: Vec<(usize, f32)> = logits
            .iter()
            .enumerate()
            .map(|(idx, &logit)| (idx, logit))
            .collect();

        // Sort by logit value (descending)
        indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Take top-k
        indexed_logits.truncate(k);

        // Create filtered logits array
        let mut filtered_logits = vec![f32::NEG_INFINITY; logits.len()];
        for (idx, logit) in indexed_logits {
            filtered_logits[idx] = logit;
        }

        // Sample from filtered distribution
        self.sample_temperature(&filtered_logits)
    }

    /// Apply softmax to convert logits to probabilities
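    ///
    /// Computes `softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))`;
    /// subtracting the maximum logit first keeps the exponentials from overflowing.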
    fn softmax(&self, logits: &[f32]) -> Result<Vec<f32>> {
        if logits.is_empty() {
            return Ok(vec![]);
        }

        // Find max for numerical stability
        let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        // Compute exp(x - max) for each logit
        let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();

        // Compute sum of exponentials
        let sum_exp: f32 = exp_logits.iter().sum();

        if sum_exp <= 0.0 {
            return Err(ModelError::Forward(
                "Invalid softmax computation".to_string(),
            ));
        }

        // Normalize to get probabilities
        let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();

        Ok(probs)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sampling_config_default() {
        let config = SamplingConfig::default();
        assert_eq!(config.temperature, 1.0);
        assert_eq!(config.repetition_penalty, 1.0);
        assert!(config.top_k.is_none());
        assert!(config.top_p.is_none());
    }

    #[test]
    fn test_sampler_creation() {
        let config = SamplingConfig::default();
        let _sampler = Sampler::new(config);
    }

    #[test]
    fn test_greedy_sampling() {
        let config = SamplingConfig::default();
        let sampler = Sampler::new(config);

        let logits = vec![0.1, 0.8, 0.3, 0.2];
        let token = sampler.sample_greedy(&logits).unwrap();
        assert_eq!(token, 1); // Index of highest logit (0.8)
    }

    #[test]
    fn test_temperature_sampling() {
        let mut config = SamplingConfig::default();
        config.temperature = 0.0; // Should behave like greedy
        let sampler = Sampler::new(config);

        let logits = vec![0.1, 0.8, 0.3, 0.2];
        let token = sampler.sample_temperature(&logits).unwrap();
        assert_eq!(token, 1); // Should be greedy when temperature is 0
    }

    #[test]
    fn test_top_k_sampling() {
        let mut config = SamplingConfig::default();
        config.top_k = Some(2);
        config.temperature = 0.0; // Make it deterministic
        let sampler = Sampler::new(config);

        let logits = vec![0.1, 0.8, 0.3, 0.2];
        let token = sampler.sample_top_k(&logits).unwrap();
        assert_eq!(token, 1); // Should pick from top-2: [0.8, 0.3], greedy picks 0.8
    }

    #[test]
    fn test_softmax() {
        let config = SamplingConfig::default();
        let sampler = Sampler::new(config);

        let logits = vec![1.0, 2.0, 3.0];
        let probs = sampler.softmax(&logits).unwrap();

        // Check probabilities sum to 1
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Check probabilities are positive
        for prob in probs {
            assert!(prob > 0.0);
        }
    }

    #[test]
    fn test_empty_logits() {
        let config = SamplingConfig::default();
        let sampler = Sampler::new(config);

        let logits = vec![];
        assert!(sampler.sample_greedy(&logits).is_err());
        assert!(sampler.sample_temperature(&logits).is_err());
        assert!(sampler.sample_top_k(&logits).is_err());
    }
}