1use ferrum_types::{Result, SamplingParams, TokenId};
8use rand::RngCore;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12#[derive(Debug)]
14pub struct SamplingContext<'a> {
15 pub step: usize,
17 pub sampling_params: &'a SamplingParams,
19 pub logits: &'a mut [f32],
21 pub previous_tokens: &'a [TokenId],
23 pub token_frequencies: &'a HashMap<TokenId, usize>,
25 pub vocab_size: usize,
27 pub metadata: HashMap<String, f32>,
29}
30
31impl<'a> SamplingContext<'a> {
32 pub fn new(
34 step: usize,
35 sampling_params: &'a SamplingParams,
36 logits: &'a mut [f32],
37 previous_tokens: &'a [TokenId],
38 token_frequencies: &'a HashMap<TokenId, usize>,
39 vocab_size: usize,
40 ) -> Self {
41 Self {
42 step,
43 sampling_params,
44 logits,
45 previous_tokens,
46 token_frequencies,
47 vocab_size,
48 metadata: HashMap::new(),
49 }
50 }
51
52 pub fn get_logit(&self, token_id: TokenId) -> Option<f32> {
54 if usize::from(token_id) < self.logits.len() {
55 Some(self.logits[usize::from(token_id)])
56 } else {
57 None
58 }
59 }
60
61 pub fn set_logit(&mut self, token_id: TokenId, value: f32) -> bool {
63 if usize::from(token_id) < self.logits.len() {
64 self.logits[usize::from(token_id)] = value;
65 true
66 } else {
67 false
68 }
69 }
70
71 pub fn mask_tokens(&mut self, token_ids: &[TokenId]) {
73 for &token_id in token_ids {
74 if usize::from(token_id) < self.logits.len() {
75 self.logits[usize::from(token_id)] = f32::NEG_INFINITY;
76 }
77 }
78 }
79}
80
81pub trait LogitsProcessor: Send + Sync {
83 fn process(&self, ctx: &mut SamplingContext) -> Result<()>;
85
86 fn name(&self) -> &str;
88
89 fn priority(&self) -> ProcessorPriority {
91 ProcessorPriority::Normal
92 }
93}
94
/// Relative ordering of processors inside a chain.
///
/// The derived ordering follows the explicit discriminants, so
/// `High > Normal > Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ProcessorPriority {
    /// Largest discriminant; compares greater than the other variants.
    High = 3,
    /// Middle discriminant.
    Normal = 2,
    /// Smallest discriminant; compares less than the other variants.
    Low = 1,
}
105
106pub trait Sampler: Send + Sync {
108 fn sample(&self, logits: &[f32], rng: &mut dyn RngCore) -> Result<TokenId>;
110
111 fn sample_with_context(&self, ctx: &SamplingContext, rng: &mut dyn RngCore) -> Result<TokenId> {
113 self.sample(ctx.logits, rng)
114 }
115
116 fn name(&self) -> &str;
118
119 fn is_deterministic(&self) -> bool;
121}
122
123pub trait MultiSampler: Sampler {
125 fn sample_multiple(
127 &self,
128 logits: &[f32],
129 num_samples: usize,
130 rng: &mut dyn RngCore,
131 ) -> Result<Vec<TokenId>>;
132
133 fn sample_with_probabilities(
135 &self,
136 logits: &[f32],
137 rng: &mut dyn RngCore,
138 ) -> Result<(TokenId, Vec<f32>)>;
139}
140
141pub struct LogitsProcessorChain {
143 processors: Vec<Box<dyn LogitsProcessor>>,
144}
145
146impl LogitsProcessorChain {
147 pub fn new() -> Self {
149 Self {
150 processors: Vec::new(),
151 }
152 }
153
154 pub fn add_processor(mut self, processor: Box<dyn LogitsProcessor>) -> Self {
156 self.processors.push(processor);
157 self.processors
159 .sort_by(|a, b| b.priority().cmp(&a.priority()));
160 self
161 }
162
163 pub fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
165 for processor in &self.processors {
166 processor.process(ctx)?;
167 }
168 Ok(())
169 }
170
171 pub fn processor_names(&self) -> Vec<&str> {
173 self.processors.iter().map(|p| p.name()).collect()
174 }
175}
176
177impl Default for LogitsProcessorChain {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
/// Logits processor parameterised by a temperature value.
pub struct TemperatureProcessor {
    /// Temperature value stored by the constructor.
    pub temperature: f32,
}

impl TemperatureProcessor {
    /// Wraps `temperature` without validating it.
    pub fn new(temperature: f32) -> Self {
        TemperatureProcessor { temperature }
    }
}
195
196impl LogitsProcessor for TemperatureProcessor {
197 fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
198 if self.temperature > 0.0 && self.temperature != 1.0 {
199 for logit in ctx.logits.iter_mut() {
200 *logit /= self.temperature;
201 }
202 }
203 Ok(())
204 }
205
206 fn name(&self) -> &str {
207 "temperature"
208 }
209
210 fn priority(&self) -> ProcessorPriority {
211 ProcessorPriority::Low }
213}
214
/// Logits processor that keeps only the `k` highest logits.
pub struct TopKProcessor {
    /// Number of top entries to keep; `0` disables the filter.
    pub k: usize,
}

impl TopKProcessor {
    /// Wraps `k` without validating it.
    pub fn new(k: usize) -> Self {
        TopKProcessor { k }
    }
}
225
226impl LogitsProcessor for TopKProcessor {
227 fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
228 if self.k > 0 && self.k < ctx.logits.len() {
229 let mut indices: Vec<usize> = (0..ctx.logits.len()).collect();
231 indices.sort_by(|&a, &b| {
232 ctx.logits[b]
233 .partial_cmp(&ctx.logits[a])
234 .unwrap_or(std::cmp::Ordering::Equal)
235 });
236
237 let threshold = ctx.logits[indices[self.k - 1]];
238
239 for logit in ctx.logits.iter_mut() {
241 if *logit < threshold {
242 *logit = f32::NEG_INFINITY;
243 }
244 }
245 }
246 Ok(())
247 }
248
249 fn name(&self) -> &str {
250 "top_k"
251 }
252}
253
/// Logits processor that filters by cumulative probability mass.
pub struct TopPProcessor {
    /// Cumulative-probability cutoff; only values strictly between 0 and 1
    /// activate the filter.
    pub p: f32,
}

impl TopPProcessor {
    /// Wraps `p` without validating it.
    pub fn new(p: f32) -> Self {
        TopPProcessor { p }
    }
}
264
265impl LogitsProcessor for TopPProcessor {
266 fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
267 if self.p < 1.0 && self.p > 0.0 {
268 let max_logit = ctx.logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270 let mut probs: Vec<f32> = ctx
271 .logits
272 .iter()
273 .map(|&logit| (logit - max_logit).exp())
274 .collect();
275
276 let sum: f32 = probs.iter().sum();
277 for prob in probs.iter_mut() {
278 *prob /= sum;
279 }
280
281 let mut indices: Vec<usize> = (0..probs.len()).collect();
283 indices.sort_by(|&a, &b| {
284 probs[b]
285 .partial_cmp(&probs[a])
286 .unwrap_or(std::cmp::Ordering::Equal)
287 });
288
289 let mut cum_prob = 0.0;
291 let mut cutoff_idx = probs.len();
292
293 for (i, &idx) in indices.iter().enumerate() {
294 cum_prob += probs[idx];
295 if cum_prob > self.p {
296 cutoff_idx = i + 1;
297 break;
298 }
299 }
300
301 for (i, &idx) in indices.iter().enumerate() {
303 if i >= cutoff_idx {
304 ctx.logits[idx] = f32::NEG_INFINITY;
305 }
306 }
307 }
308 Ok(())
309 }
310
311 fn name(&self) -> &str {
312 "top_p"
313 }
314}
315
/// Logits processor that penalises tokens which already occurred.
pub struct RepetitionPenaltyProcessor {
    /// Penalty base; `1.0` means no penalty.
    pub penalty: f32,
}

impl RepetitionPenaltyProcessor {
    /// Wraps `penalty` without validating it.
    pub fn new(penalty: f32) -> Self {
        RepetitionPenaltyProcessor { penalty }
    }
}
326
327impl LogitsProcessor for RepetitionPenaltyProcessor {
328 fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
329 if self.penalty != 1.0 {
330 for &token_id in ctx.previous_tokens {
331 if let Some(freq) = ctx.token_frequencies.get(&token_id) {
332 if usize::from(token_id) < ctx.logits.len() {
333 let idx = usize::from(token_id);
334 let current_logit = ctx.logits[idx];
335 let penalty_factor = self.penalty.powi(*freq as i32);
336
337 if current_logit > 0.0 {
338 ctx.logits[idx] = current_logit / penalty_factor;
339 } else {
340 ctx.logits[idx] = current_logit * penalty_factor;
341 }
342 }
343 }
344 }
345 }
346 Ok(())
347 }
348
349 fn name(&self) -> &str {
350 "repetition_penalty"
351 }
352
353 fn priority(&self) -> ProcessorPriority {
354 ProcessorPriority::High }
356}
357
358pub struct GreedySampler;
362
363impl Sampler for GreedySampler {
364 fn sample(&self, logits: &[f32], _rng: &mut dyn RngCore) -> Result<TokenId> {
365 let max_idx = logits
366 .iter()
367 .enumerate()
368 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
369 .map(|(idx, _)| idx)
370 .ok_or_else(|| ferrum_types::FerrumError::backend("Empty logits for sampling"))?;
371
372 Ok(TokenId::new(max_idx as u32))
373 }
374
375 fn name(&self) -> &str {
376 "greedy"
377 }
378
379 fn is_deterministic(&self) -> bool {
380 true
381 }
382}
383
384pub struct MultinomialSampler;
386
387impl Sampler for MultinomialSampler {
388 fn sample(&self, logits: &[f32], rng: &mut dyn RngCore) -> Result<TokenId> {
389 let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
391
392 let mut probs: Vec<f32> = logits
393 .iter()
394 .map(|&logit| {
395 if logit.is_finite() && logit > f32::NEG_INFINITY {
396 (logit - max_logit).exp()
397 } else {
398 0.0
399 }
400 })
401 .collect();
402
403 let sum: f32 = probs.iter().sum();
404 if sum <= 0.0 {
405 return Err(ferrum_types::FerrumError::backend(
406 "No valid tokens for sampling",
407 ));
408 }
409
410 for prob in probs.iter_mut() {
411 *prob /= sum;
412 }
413
414 let threshold = rng.next_u32() as f32 / u32::MAX as f32;
416 let mut cumulative = 0.0;
417
418 for (idx, prob) in probs.iter().enumerate() {
419 cumulative += prob;
420 if cumulative >= threshold {
421 return Ok(TokenId::new(idx as u32));
422 }
423 }
424
425 Ok(TokenId::new((probs.len() - 1) as u32))
427 }
428
429 fn name(&self) -> &str {
430 "multinomial"
431 }
432
433 fn is_deterministic(&self) -> bool {
434 false
435 }
436}
437
438pub struct SamplingConfigBuilder {
440 processors: Vec<Box<dyn LogitsProcessor>>,
441 sampler: Option<Box<dyn Sampler>>,
442}
443
444impl SamplingConfigBuilder {
445 pub fn new() -> Self {
447 Self {
448 processors: Vec::new(),
449 sampler: None,
450 }
451 }
452
453 pub fn with_temperature(mut self, temperature: f32) -> Self {
455 if temperature > 0.0 && temperature != 1.0 {
456 self.processors
457 .push(Box::new(TemperatureProcessor::new(temperature)));
458 }
459 self
460 }
461
462 pub fn with_top_k(mut self, k: usize) -> Self {
464 if k > 0 {
465 self.processors.push(Box::new(TopKProcessor::new(k)));
466 }
467 self
468 }
469
470 pub fn with_top_p(mut self, p: f32) -> Self {
472 if p > 0.0 && p < 1.0 {
473 self.processors.push(Box::new(TopPProcessor::new(p)));
474 }
475 self
476 }
477
478 pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
480 if penalty != 1.0 {
481 self.processors
482 .push(Box::new(RepetitionPenaltyProcessor::new(penalty)));
483 }
484 self
485 }
486
487 pub fn with_sampler(mut self, sampler: Box<dyn Sampler>) -> Self {
489 self.sampler = Some(sampler);
490 self
491 }
492
493 pub fn build(self) -> SamplingConfig {
495 let mut chain = LogitsProcessorChain::new();
496 for processor in self.processors {
497 chain = chain.add_processor(processor);
498 }
499
500 let sampler = self.sampler.unwrap_or_else(|| Box::new(MultinomialSampler));
501
502 SamplingConfig {
503 processor_chain: chain,
504 sampler,
505 }
506 }
507}
508
509impl Default for SamplingConfigBuilder {
510 fn default() -> Self {
511 Self::new()
512 }
513}
514
515pub struct SamplingConfig {
517 pub processor_chain: LogitsProcessorChain,
518 pub sampler: Box<dyn Sampler>,
519}
520
521impl SamplingConfig {
522 pub fn from_params(params: &SamplingParams) -> Self {
524 let mut builder = SamplingConfigBuilder::new()
525 .with_temperature(params.temperature)
526 .with_repetition_penalty(params.repetition_penalty);
527
528 if let Some(top_k) = params.top_k {
529 builder = builder.with_top_k(top_k);
530 }
531
532 if params.top_p < 1.0 {
533 builder = builder.with_top_p(params.top_p);
534 }
535
536 let sampler: Box<dyn Sampler> = if params.temperature == 0.0 {
538 Box::new(GreedySampler)
539 } else {
540 Box::new(MultinomialSampler)
541 };
542
543 builder.with_sampler(sampler).build()
544 }
545
546 pub fn sample(&self, mut ctx: SamplingContext, rng: &mut dyn RngCore) -> Result<TokenId> {
548 self.processor_chain.process(&mut ctx)?;
550
551 self.sampler.sample_with_context(&ctx, rng)
553 }
554}
555
556#[derive(Debug, Clone, Serialize, Deserialize)]
558pub struct SamplingStats {
559 pub total_samples: u64,
561 pub avg_sample_time_us: f64,
563 pub token_distribution: HashMap<TokenId, u64>,
565 pub effective_temperature: f32,
567 pub processor_times: HashMap<String, f64>,
569}