ds_r1_rs/inference/
sampling.rs1use crate::utils::error::{ModelError, Result};
6use rand::Rng;
7
/// Tunable knobs consulted by [`Sampler`] when it turns logits into a token index.
#[derive(Clone, Debug)]
pub struct SamplingConfig {
    /// Divisor applied to every logit before the softmax in
    /// `Sampler::sample_temperature`. A value that is not greater than zero
    /// makes that method fall back to greedy (arg-max) decoding.
    pub temperature: f32,
    /// Number of highest-scoring logits that `Sampler::sample_top_k` keeps;
    /// `None` keeps every logit.
    pub top_k: Option<usize>,
    /// Presumably a nucleus (top-p) probability threshold.
    /// NOTE(review): no sampler method in this file reads this field — confirm
    /// where (or whether) it is applied.
    pub top_p: Option<f32>,
    /// Presumably a penalty factor for already-generated tokens.
    /// NOTE(review): no sampler method in this file reads this field — confirm
    /// where (or whether) it is applied.
    pub repetition_penalty: f32,
}
16
17impl Default for SamplingConfig {
18 fn default() -> Self {
19 Self {
20 temperature: 1.0,
21 top_k: None,
22 top_p: None,
23 repetition_penalty: 1.0,
24 }
25 }
26}
27
28pub struct Sampler {
30 config: SamplingConfig,
31}
32
33impl Sampler {
34 pub fn new(config: SamplingConfig) -> Self {
36 Self { config }
37 }
38
39 pub fn sample_greedy(&self, logits: &[f32]) -> Result<u32> {
41 if logits.is_empty() {
42 return Err(ModelError::Forward("Empty logits for sampling".to_string()));
43 }
44
45 let max_idx = logits
47 .iter()
48 .enumerate()
49 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
50 .map(|(idx, _)| idx)
51 .ok_or_else(|| ModelError::Forward("Failed to find max logit".to_string()))?;
52
53 Ok(max_idx as u32)
54 }
55
56 pub fn sample_temperature(&self, logits: &[f32]) -> Result<u32> {
58 if logits.is_empty() {
59 return Err(ModelError::Forward("Empty logits for sampling".to_string()));
60 }
61
62 let mut rng = rand::rng();
63
64 let scaled_logits: Vec<f32> = if self.config.temperature > 0.0 {
66 logits
67 .iter()
68 .map(|&x| x / self.config.temperature)
69 .collect()
70 } else {
71 return self.sample_greedy(logits);
73 };
74
75 let probs = self.softmax(&scaled_logits)?;
77
78 let sample: f32 = rng.random::<f32>();
80 let mut cumulative = 0.0;
81
82 for (idx, &prob) in probs.iter().enumerate() {
83 cumulative += prob;
84 if sample <= cumulative {
85 return Ok(idx as u32);
86 }
87 }
88
89 Ok((probs.len() - 1) as u32)
91 }
92
93 pub fn sample_top_k(&self, logits: &[f32]) -> Result<u32> {
95 if logits.is_empty() {
96 return Err(ModelError::Forward("Empty logits for sampling".to_string()));
97 }
98
99 let k = self.config.top_k.unwrap_or(logits.len());
100 if k == 0 {
101 return Err(ModelError::Forward(
102 "Top-k must be greater than 0".to_string(),
103 ));
104 }
105
106 let mut indexed_logits: Vec<(usize, f32)> = logits
108 .iter()
109 .enumerate()
110 .map(|(idx, &logit)| (idx, logit))
111 .collect();
112
113 indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
115
116 indexed_logits.truncate(k);
118
119 let mut filtered_logits = vec![f32::NEG_INFINITY; logits.len()];
121 for (idx, logit) in indexed_logits {
122 filtered_logits[idx] = logit;
123 }
124
125 self.sample_temperature(&filtered_logits)
127 }
128
129 fn softmax(&self, logits: &[f32]) -> Result<Vec<f32>> {
131 if logits.is_empty() {
132 return Ok(vec![]);
133 }
134
135 let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
137
138 let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
140
141 let sum_exp: f32 = exp_logits.iter().sum();
143
144 if sum_exp <= 0.0 {
145 return Err(ModelError::Forward(
146 "Invalid softmax computation".to_string(),
147 ));
148 }
149
150 let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
152
153 Ok(probs)
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Logits shared by the decoding tests; the largest value sits at index 1.
    const LOGITS: [f32; 4] = [0.1, 0.8, 0.3, 0.2];

    /// Builds a sampler around the default configuration.
    fn default_sampler() -> Sampler {
        Sampler::new(SamplingConfig::default())
    }

    #[test]
    fn test_sampling_config_default() {
        let SamplingConfig {
            temperature,
            top_k,
            top_p,
            repetition_penalty,
        } = SamplingConfig::default();

        assert_eq!(temperature, 1.0);
        assert_eq!(repetition_penalty, 1.0);
        assert!(top_k.is_none());
        assert!(top_p.is_none());
    }

    #[test]
    fn test_sampler_creation() {
        // Construction alone must succeed.
        let _sampler = default_sampler();
    }

    #[test]
    fn test_greedy_sampling() {
        let picked = default_sampler().sample_greedy(&LOGITS).unwrap();
        assert_eq!(picked, 1);
    }

    #[test]
    fn test_temperature_sampling() {
        // A temperature of 0.0 routes through the greedy path, so the result
        // is deterministic.
        let sampler = Sampler::new(SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        });

        let picked = sampler.sample_temperature(&LOGITS).unwrap();
        assert_eq!(picked, 1);
    }

    #[test]
    fn test_top_k_sampling() {
        // Top-2 filtering combined with temperature 0.0 is deterministic.
        let sampler = Sampler::new(SamplingConfig {
            top_k: Some(2),
            temperature: 0.0,
            ..SamplingConfig::default()
        });

        let picked = sampler.sample_top_k(&LOGITS).unwrap();
        assert_eq!(picked, 1);
    }

    #[test]
    fn test_softmax() {
        let probs = default_sampler().softmax(&[1.0, 2.0, 3.0]).unwrap();

        // Probabilities must sum to one...
        let total: f32 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);

        // ...and each one must be strictly positive.
        assert!(probs.iter().all(|&p| p > 0.0));
    }

    #[test]
    fn test_empty_logits() {
        let sampler = default_sampler();
        let no_logits: Vec<f32> = Vec::new();

        assert!(sampler.sample_greedy(&no_logits).is_err());
        assert!(sampler.sample_temperature(&no_logits).is_err());
        assert!(sampler.sample_top_k(&no_logits).is_err());
    }
}