1use sapient_core::error::{Result, SapientError};
10
/// How [`Sampler::sample`] turns a logit vector into a token id.
///
/// For every variant that carries a `temperature`, a value `<= 0.0` makes the
/// sampler fall back to greedy (argmax) selection instead of sampling.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Always pick the highest-scoring token.
    Greedy,
    /// Divide the logits by the temperature, then sample from their softmax.
    Temperature(f32),
    /// Temperature sampling restricted to the `k` highest logits.
    ///
    /// Logits tied with the k-th largest are kept as well; `k == 0` (or `k`
    /// at least the vocabulary size) disables the filter.
    TopK { k: usize, temperature: f32 },
    /// Temperature sampling restricted to the most likely tokens whose
    /// cumulative probability first reaches `p` (nucleus sampling).
    TopP { p: f32, temperature: f32 },
    /// Repetition penalty, then temperature scaling, then top-k, then top-p.
    ///
    /// The repetition penalty is applied even when the temperature forces the
    /// greedy fallback.
    Combined {
        /// Top-k cutoff; `0` disables it.
        top_k: usize,
        /// Cumulative probability mass kept by the nucleus filter.
        top_p: f32,
        /// Softmax temperature; `<= 0.0` selects greedy decoding.
        temperature: f32,
        /// Penalty for tokens already in `prev_tokens`: non-negative logits
        /// are divided by it and negative logits multiplied by it; `1.0`
        /// disables the penalty.
        repetition_penalty: f32,
    },
}
31
32impl Default for SamplingStrategy {
33 fn default() -> Self {
34 Self::Greedy
35 }
36}
37
38pub struct Sampler {
41 pub strategy: SamplingStrategy,
42 rng_seed: u64,
43 counter: u64,
44}
45
46impl Sampler {
47 pub fn new(strategy: SamplingStrategy) -> Self {
48 let seed = std::time::SystemTime::now()
49 .duration_since(std::time::UNIX_EPOCH)
50 .map(|d| d.as_nanos() as u64)
51 .unwrap_or(42);
52 Self {
53 strategy,
54 rng_seed: seed,
55 counter: 0,
56 }
57 }
58
59 pub fn with_seed(strategy: SamplingStrategy, seed: u64) -> Self {
60 Self {
61 strategy,
62 rng_seed: seed,
63 counter: 0,
64 }
65 }
66
67 pub fn sample(&mut self, logits: &[f32], prev_tokens: &[u32]) -> Result<u32> {
70 match &self.strategy {
71 SamplingStrategy::Greedy => Ok(argmax(logits)),
72
73 SamplingStrategy::Temperature(t) => {
74 let t = *t;
75 if t <= 0.0 {
76 return Ok(argmax(logits));
77 }
78 let scaled = scale_logits(logits, t);
79 let probs = softmax(&scaled);
80 Ok(self.random_sample(&probs))
81 }
82
83 SamplingStrategy::TopK { k, temperature } => {
84 let (k, t) = (*k, *temperature);
85 if t <= 0.0 {
86 return Ok(argmax(logits));
87 }
88 let scaled = scale_logits(logits, t);
89 let filtered = top_k_filter(&scaled, k);
90 let probs = softmax(&filtered);
91 Ok(self.random_sample(&probs))
92 }
93
94 SamplingStrategy::TopP { p, temperature } => {
95 let (p, t) = (*p, *temperature);
96 if t <= 0.0 {
97 return Ok(argmax(logits));
98 }
99 let scaled = scale_logits(logits, t);
100 let filtered = top_p_filter(&scaled, p);
101 let probs = softmax(&filtered);
102 Ok(self.random_sample(&probs))
103 }
104
105 SamplingStrategy::Combined {
106 top_k,
107 top_p,
108 temperature,
109 repetition_penalty,
110 } => {
111 let (k, p, t, rp) = (*top_k, *top_p, *temperature, *repetition_penalty);
112 let mut penalized = apply_repetition_penalty(logits, prev_tokens, rp);
113 if t <= 0.0 {
114 return Ok(argmax(&penalized));
115 }
116 penalized = scale_logits(&penalized, t);
117 penalized = top_k_filter(&penalized, k);
118 penalized = top_p_filter(&penalized, p);
119 let probs = softmax(&penalized);
120 Ok(self.random_sample(&probs))
121 }
122 }
123 }
124
125 fn random_u64(&mut self) -> u64 {
127 self.counter += 1;
128 let mut x = self
129 .rng_seed
130 .wrapping_add(self.counter.wrapping_mul(6364136223846793005));
131 x ^= x >> 30;
132 x = x.wrapping_mul(0xbf58476d1ce4e5b9);
133 x ^= x >> 27;
134 x = x.wrapping_mul(0x94d049bb133111eb);
135 x ^= x >> 31;
136 x
137 }
138
139 fn random_f32(&mut self) -> f32 {
140 (self.random_u64() >> 11) as f32 / (1u64 << 53) as f32
141 }
142
143 fn random_sample(&mut self, probs: &[f32]) -> u32 {
144 let r = self.random_f32();
145 let mut cum = 0.0f32;
146 for (i, &p) in probs.iter().enumerate() {
147 cum += p;
148 if r < cum {
149 return i as u32;
150 }
151 }
152 (probs.len() - 1) as u32
153 }
154}
155
/// Returns the index of the largest logit.
///
/// NaN entries are ignored and ties resolve to the lowest index (the numpy /
/// PyTorch convention). An empty or all-NaN slice yields `0`.
pub fn argmax(logits: &[f32]) -> u32 {
    logits
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, x)| !x.is_nan())
        // Strict `>` keeps the earlier index when values are equal.
        .reduce(|best, cur| if cur.1 > best.1 { cur } else { best })
        .map_or(0, |(i, _)| i as u32)
}
166
/// Numerically stable softmax: subtracts the maximum before exponentiating.
///
/// If the maximum is infinite (every logit is `-inf`, or some logit is
/// `+inf`), `x - max` would be NaN. In that case the limiting distribution is
/// returned instead: the probability mass is shared equally by the entries
/// equal to the maximum (uniform when everything is `-inf`). An empty input
/// yields an empty output.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max.is_infinite() {
        let hits = logits.iter().filter(|&&x| x == max).count() as f32;
        return logits
            .iter()
            .map(|&x| if x == max { 1.0 / hits } else { 0.0 })
            .collect();
    }
    let mut out: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = out.iter().sum();
    out.iter_mut().for_each(|x| *x /= sum);
    out
}
174
/// Returns `logits` divided element-wise by `temperature`.
///
/// A temperature of exactly `1.0` is the identity, and non-positive
/// temperatures are ignored (the sampler routes those to greedy decoding
/// before scaling), so both cases return an unchanged copy.
fn scale_logits(logits: &[f32], temperature: f32) -> Vec<f32> {
    let is_passthrough = temperature == 1.0 || temperature <= 0.0;
    if is_passthrough {
        logits.to_vec()
    } else {
        logits.iter().copied().map(|logit| logit / temperature).collect()
    }
}
181
/// Masks every logit below the k-th largest to `-inf`.
///
/// Logits tied with the k-th largest value are kept, so more than `k` entries
/// may survive. `k == 0` or `k >= logits.len()` returns an unchanged copy.
fn top_k_filter(logits: &[f32], k: usize) -> Vec<f32> {
    if k == 0 || k >= logits.len() {
        return logits.to_vec();
    }
    // Partial selection finds the k-th largest value in O(n) instead of a
    // full O(n log n) sort. `total_cmp` is a total order, so NaN inputs cannot
    // make the comparator inconsistent.
    let mut values = logits.to_vec();
    let threshold = *values.select_nth_unstable_by(k - 1, |a, b| b.total_cmp(a)).1;
    logits
        .iter()
        .map(|&x| if x >= threshold { x } else { f32::NEG_INFINITY })
        .collect()
}
194
195fn top_p_filter(logits: &[f32], p: f32) -> Vec<f32> {
196 let mut indexed: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
197 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198
199 let probs = softmax(logits);
200 let mut sorted_probs: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
201 sorted_probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
202
203 let mut cum = 0.0f32;
204 let mut cutoff_idx = sorted_probs.len();
205 for (i, (_, prob)) in sorted_probs.iter().enumerate() {
206 cum += prob;
207 if cum >= p {
208 cutoff_idx = i + 1;
209 break;
210 }
211 }
212
213 let keep: std::collections::HashSet<usize> =
214 sorted_probs[..cutoff_idx].iter().map(|(i, _)| *i).collect();
215 logits
216 .iter()
217 .enumerate()
218 .map(|(i, &x)| {
219 if keep.contains(&i) {
220 x
221 } else {
222 f32::NEG_INFINITY
223 }
224 })
225 .collect()
226}
227
/// Applies a CTRL-style repetition penalty to every distinct token in
/// `prev_tokens`.
///
/// Non-negative logits are divided by `penalty` and negative logits are
/// multiplied by it, so `penalty > 1.0` makes repeated tokens less likely.
/// Each token is penalized once, however often it occurs in `prev_tokens`.
/// Token ids outside `logits` are ignored. A penalty within `1e-6` of `1.0`
/// is treated as a no-op.
fn apply_repetition_penalty(logits: &[f32], prev_tokens: &[u32], penalty: f32) -> Vec<f32> {
    if (penalty - 1.0).abs() < 1e-6 {
        return logits.to_vec();
    }
    let mut out = logits.to_vec();
    // Without this, a token seen n times would be penalized penalty^n.
    let mut already_penalized = vec![false; out.len()];
    for &tok in prev_tokens {
        let idx = tok as usize;
        if idx >= out.len() || already_penalized[idx] {
            continue;
        }
        already_penalized[idx] = true;
        let score = &mut out[idx];
        if *score >= 0.0 {
            *score /= penalty;
        } else {
            *score *= penalty;
        }
    }
    out
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Greedy decoding must return the index of the highest logit.
    #[test]
    fn greedy_picks_argmax() {
        let mut sampler = Sampler::with_seed(SamplingStrategy::Greedy, 42);
        let picked = sampler.sample(&[0.1, 0.9, 0.3, 0.5], &[]).unwrap();
        assert_eq!(picked, 1);
    }

    /// With k = 1 only the dominant logit survives; the rest become -inf.
    #[test]
    fn top_k_removes_low_prob() {
        let filtered = top_k_filter(&[10.0, 1.0, 1.0, 1.0], 1);
        assert_eq!(filtered[0], 10.0);
        assert_eq!(filtered[1], f32::NEG_INFINITY);
    }

    /// A penalty above 1.0 lowers a positive logit of a repeated token.
    #[test]
    fn repetition_penalty_reduces_score() {
        let logits = [1.0, 2.0, 3.0];
        let penalized = apply_repetition_penalty(&logits, &[2], 1.3);
        assert!(penalized[2] < logits[2]);
    }
}