1use crate::error::{InferenceError, InferenceResult};
11use scirs2_core::ndarray::{Array1, Array2};
12
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct SamplingConfig {
16 pub strategy: SamplingStrategy,
18 pub temperature: f32,
20 pub top_k: Option<usize>,
22 pub top_p: Option<f32>,
24 pub beam_width: usize,
26 pub seed: Option<u64>,
28}
29
30impl Default for SamplingConfig {
31 fn default() -> Self {
32 Self {
33 strategy: SamplingStrategy::Greedy,
34 temperature: 1.0,
35 top_k: None,
36 top_p: None,
37 beam_width: 1,
38 seed: None,
39 }
40 }
41}
42
43impl SamplingConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn strategy(mut self, strategy: SamplingStrategy) -> Self {
51 self.strategy = strategy;
52 self
53 }
54
55 pub fn temperature(mut self, temp: f32) -> Self {
57 self.temperature = temp;
58 self
59 }
60
61 pub fn top_k(mut self, k: usize) -> Self {
63 self.strategy = SamplingStrategy::TopK;
64 self.top_k = Some(k);
65 self
66 }
67
68 pub fn top_p(mut self, p: f32) -> Self {
70 self.strategy = SamplingStrategy::TopP;
71 self.top_p = Some(p);
72 self
73 }
74
75 pub fn beam_search(mut self, width: usize) -> Self {
77 self.strategy = SamplingStrategy::BeamSearch;
78 self.beam_width = width;
79 self
80 }
81
82 pub fn seed(mut self, seed: u64) -> Self {
84 self.seed = Some(seed);
85 self
86 }
87}
88
89#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
91pub enum SamplingStrategy {
92 Greedy,
94 Temperature,
96 TopK,
98 TopP,
100 BeamSearch,
102 Custom,
104}
105
106pub type CustomSamplingFn = Arc<dyn Fn(&Array1<f32>, f32) -> InferenceResult<f32> + Send + Sync>;
109
110pub struct Sampler {
112 config: SamplingConfig,
113 custom_fn: Option<CustomSamplingFn>,
115}
116
117impl Sampler {
118 pub fn new(config: SamplingConfig) -> Self {
120 Self {
121 config,
122 custom_fn: None,
123 }
124 }
125
126 pub fn with_custom_fn(mut config: SamplingConfig, custom_fn: CustomSamplingFn) -> Self {
128 config.strategy = SamplingStrategy::Custom;
129 Self {
130 config,
131 custom_fn: Some(custom_fn),
132 }
133 }
134
135 pub fn set_custom_fn(&mut self, custom_fn: CustomSamplingFn) {
137 self.custom_fn = Some(custom_fn);
138 self.config.strategy = SamplingStrategy::Custom;
139 }
140
141 pub fn sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
149 if logits.is_empty() {
150 return Err(InferenceError::DimensionMismatch {
151 expected: 1,
152 got: 0,
153 });
154 }
155
156 match self.config.strategy {
157 SamplingStrategy::Greedy => Ok(self.greedy_sample(logits)),
158 SamplingStrategy::Temperature => self.temperature_sample(logits),
159 SamplingStrategy::TopK => self.top_k_sample(logits),
160 SamplingStrategy::TopP => self.top_p_sample(logits),
161 SamplingStrategy::BeamSearch => {
162 Ok(self.greedy_sample(logits))
164 }
165 SamplingStrategy::Custom => {
166 if let Some(ref custom_fn) = self.custom_fn {
167 custom_fn(logits, self.config.temperature)
168 } else {
169 Ok(self.greedy_sample(logits))
171 }
172 }
173 }
174 }
175
176 pub fn sample_batch(&mut self, logits: &Array2<f32>) -> InferenceResult<Array1<f32>> {
178 let batch_size = logits.nrows();
179 let mut results = Vec::with_capacity(batch_size);
180
181 for i in 0..batch_size {
182 let logit_row = logits.row(i).to_owned();
183 results.push(self.sample(&logit_row)?);
184 }
185
186 Ok(Array1::from_vec(results))
187 }
188
189 fn greedy_sample(&self, logits: &Array1<f32>) -> f32 {
191 logits
192 .iter()
193 .enumerate()
194 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
195 .map(|(idx, _)| idx as f32)
196 .unwrap_or(0.0)
197 }
198
199 fn temperature_sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
201 let scaled = if (self.config.temperature - 1.0).abs() > 1e-6 {
202 logits.mapv(|x| x / self.config.temperature)
203 } else {
204 logits.clone()
205 };
206
207 let probs = softmax(&scaled);
208 self.sample_categorical(&probs)
209 }
210
211 fn top_k_sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
213 let k = self.config.top_k.unwrap_or(10);
214
215 let mut indexed: Vec<_> = logits.iter().enumerate().collect();
217 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
218 let top_k_indices: Vec<usize> = indexed.iter().take(k).map(|(idx, _)| *idx).collect();
219
220 let mut filtered = Array1::from_elem(logits.len(), f32::NEG_INFINITY);
222 for &idx in &top_k_indices {
223 filtered[idx] = logits[idx];
224 }
225
226 let probs = softmax(&filtered);
227 self.sample_categorical(&probs)
228 }
229
230 fn top_p_sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
232 let p = self.config.top_p.unwrap_or(0.9);
233
234 let probs = softmax(logits);
236 let mut indexed: Vec<_> = probs.iter().enumerate().collect();
237 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
238
239 let mut cumsum = 0.0;
241 let mut nucleus_size = 0;
242 for (_, &prob) in &indexed {
243 cumsum += prob;
244 nucleus_size += 1;
245 if cumsum >= p {
246 break;
247 }
248 }
249
250 let nucleus_indices: Vec<usize> = indexed
252 .iter()
253 .take(nucleus_size)
254 .map(|(idx, _)| *idx)
255 .collect();
256 let mut filtered = Array1::from_elem(logits.len(), f32::NEG_INFINITY);
257 for &idx in &nucleus_indices {
258 filtered[idx] = logits[idx];
259 }
260
261 let filtered_probs = softmax(&filtered);
262 self.sample_categorical(&filtered_probs)
263 }
264
265 fn sample_categorical(&mut self, probs: &Array1<f32>) -> InferenceResult<f32> {
267 use scirs2_core::random::{rng, Rng};
269
270 let mut rng_gen = rng();
271 let uniform: f32 = rng_gen.random();
272 let mut cumsum = 0.0;
273 for (idx, &prob) in probs.iter().enumerate() {
274 cumsum += prob;
275 if uniform < cumsum {
276 return Ok(idx as f32);
277 }
278 }
279 Ok((probs.len() - 1) as f32)
281 }
282
283 pub fn config(&self) -> &SamplingConfig {
285 &self.config
286 }
287}
288
289fn softmax(logits: &Array1<f32>) -> Array1<f32> {
291 let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
293 let exp_logits = logits.mapv(|x| (x - max_logit).exp());
294 let sum_exp: f32 = exp_logits.sum();
295
296 if sum_exp > 0.0 {
297 exp_logits / sum_exp
298 } else {
299 Array1::from_elem(logits.len(), 1.0 / logits.len() as f32)
301 }
302}
303
304#[derive(Debug, Clone)]
306pub struct Beam {
307 pub sequence: Vec<f32>,
309 pub log_prob: f32,
311 pub states: Vec<kizzasi_core::HiddenState>,
313}
314
315impl Beam {
316 pub fn new() -> Self {
318 Self {
319 sequence: Vec::new(),
320 log_prob: 0.0,
321 states: Vec::new(),
322 }
323 }
324
325 pub fn extend(&mut self, value: f32, log_prob: f32) {
327 self.sequence.push(value);
328 self.log_prob += log_prob;
329 }
330
331 pub fn avg_log_prob(&self) -> f32 {
333 if self.sequence.is_empty() {
334 0.0
335 } else {
336 self.log_prob / self.sequence.len() as f32
337 }
338 }
339}
340
341impl Default for Beam {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
347pub struct BeamSearch {
349 beam_width: usize,
351 beams: Vec<Beam>,
353}
354
355impl BeamSearch {
356 pub fn new(beam_width: usize) -> Self {
358 let beams = vec![Beam::new()];
359 Self { beam_width, beams }
360 }
361
362 pub fn expand(&mut self, logits: &Array2<f32>) -> InferenceResult<()> {
364 if logits.nrows() != self.beams.len() {
365 return Err(InferenceError::DimensionMismatch {
366 expected: self.beams.len(),
367 got: logits.nrows(),
368 });
369 }
370
371 let mut candidates = Vec::new();
372
373 for (beam_idx, beam) in self.beams.iter().enumerate() {
374 let beam_logits = logits.row(beam_idx).to_owned();
375 let probs = softmax(&beam_logits);
376
377 let mut indexed: Vec<_> = probs.iter().enumerate().collect();
379 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
380
381 for (idx, &prob) in indexed.iter().take(self.beam_width) {
382 let mut new_beam = beam.clone();
383 new_beam.extend(*idx as f32, prob.ln());
384 candidates.push(new_beam);
385 }
386 }
387
388 candidates.sort_by(|a, b| {
390 b.avg_log_prob()
391 .partial_cmp(&a.avg_log_prob())
392 .unwrap_or(std::cmp::Ordering::Equal)
393 });
394 self.beams = candidates.into_iter().take(self.beam_width).collect();
395
396 Ok(())
397 }
398
399 pub fn best(&self) -> Option<&Beam> {
401 self.beams.first()
402 }
403
404 pub fn beams(&self) -> &[Beam] {
406 &self.beams
407 }
408}
409
/// Shareable, thread-safe predicate over a candidate sequence; it returns
/// `true` when the sequence satisfies the constraint.
pub type ConstraintFn =
    Arc<dyn Fn(&[f32]) -> bool + Send + Sync>;
413
414pub struct ConstrainedBeamSearch {
416 beam_search: BeamSearch,
418 constraints: Vec<ConstraintFn>,
420 soft_constraints: bool,
422 constraint_penalty: f32,
424}
425
426impl ConstrainedBeamSearch {
427 pub fn new(beam_width: usize) -> Self {
429 Self {
430 beam_search: BeamSearch::new(beam_width),
431 constraints: Vec::new(),
432 soft_constraints: false,
433 constraint_penalty: 1.0,
434 }
435 }
436
437 pub fn add_constraint(mut self, constraint: ConstraintFn) -> Self {
439 self.constraints.push(constraint);
440 self
441 }
442
443 pub fn with_soft_constraints(mut self, penalty: f32) -> Self {
445 self.soft_constraints = true;
446 self.constraint_penalty = penalty;
447 self
448 }
449
450 fn satisfies_constraints(&self, sequence: &[f32]) -> bool {
452 self.constraints.iter().all(|c| c(sequence))
453 }
454
455 pub fn expand(&mut self, logits: &Array2<f32>) -> InferenceResult<()> {
457 self.beam_search.expand(logits)?;
459
460 if self.soft_constraints {
462 let violations: Vec<bool> = self
465 .beam_search
466 .beams
467 .iter()
468 .map(|beam| !self.satisfies_constraints(&beam.sequence))
469 .collect();
470
471 let penalty = self.constraint_penalty;
473 for (beam, &violates) in self.beam_search.beams.iter_mut().zip(violations.iter()) {
474 if violates {
475 beam.log_prob -= penalty;
476 }
477 }
478
479 self.beam_search.beams.sort_by(|a, b| {
481 b.log_prob
482 .partial_cmp(&a.log_prob)
483 .unwrap_or(std::cmp::Ordering::Equal)
484 });
485 } else {
486 let valid_beams: Vec<Beam> = self
488 .beam_search
489 .beams
490 .iter()
491 .filter(|beam| self.satisfies_constraints(&beam.sequence))
492 .cloned()
493 .collect();
494
495 if !valid_beams.is_empty() {
496 self.beam_search.beams = valid_beams;
497 }
498 }
501
502 Ok(())
503 }
504
505 pub fn best(&self) -> Option<&Beam> {
507 self.beam_search.best()
508 }
509
510 pub fn beams(&self) -> &[Beam] {
512 self.beam_search.beams()
513 }
514
515 pub fn num_constraints(&self) -> usize {
517 self.constraints.len()
518 }
519}
520
521use std::sync::Arc;
522
523pub struct RejectionSampler {
529 base_sampler: Sampler,
531 constraints: Vec<ConstraintFn>,
533 max_attempts: usize,
535 fallback_strategy: FallbackStrategy,
537}
538
/// What [`RejectionSampler`] does when no proposal satisfies every constraint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FallbackStrategy {
    /// Return the proposal that violated the fewest constraints.
    BestCandidate,
    /// Return the greedy choice, ignoring the constraints.
    Greedy,
    /// Report the failure as an error.
    Error,
}
549
550impl RejectionSampler {
551 pub fn new(config: SamplingConfig) -> Self {
553 Self {
554 base_sampler: Sampler::new(config),
555 constraints: Vec::new(),
556 max_attempts: 100,
557 fallback_strategy: FallbackStrategy::BestCandidate,
558 }
559 }
560
561 pub fn add_constraint(mut self, constraint: ConstraintFn) -> Self {
563 self.constraints.push(constraint);
564 self
565 }
566
567 pub fn max_attempts(mut self, attempts: usize) -> Self {
569 self.max_attempts = attempts;
570 self
571 }
572
573 pub fn fallback_strategy(mut self, strategy: FallbackStrategy) -> Self {
575 self.fallback_strategy = strategy;
576 self
577 }
578
579 pub fn sample_with_rejection(
588 &mut self,
589 logits: &Array1<f32>,
590 context: &[f32],
591 ) -> InferenceResult<f32> {
592 if self.constraints.is_empty() {
593 return self.base_sampler.sample(logits);
595 }
596
597 let mut best_candidate = None;
598 let mut min_violations = usize::MAX;
599
600 for attempt in 0..self.max_attempts {
601 let candidate = self.base_sampler.sample(logits)?;
602
603 let mut test_sequence = context.to_vec();
605 test_sequence.push(candidate);
606
607 let violations = self.count_violations(&test_sequence);
609
610 if violations == 0 {
611 return Ok(candidate);
613 }
614
615 if violations < min_violations {
617 min_violations = violations;
618 best_candidate = Some(candidate);
619 }
620
621 if attempt > self.max_attempts / 2 && violations < self.constraints.len() / 2 {
623 break;
624 }
625 }
626
627 match self.fallback_strategy {
629 FallbackStrategy::BestCandidate => best_candidate.ok_or_else(|| {
630 InferenceError::ForwardError(
631 "Rejection sampling failed: no candidates generated".to_string(),
632 )
633 }),
634 FallbackStrategy::Greedy => {
635 let greedy_config = SamplingConfig::new().strategy(SamplingStrategy::Greedy);
636 let mut greedy_sampler = Sampler::new(greedy_config);
637 greedy_sampler.sample(logits)
638 }
639 FallbackStrategy::Error => Err(InferenceError::ForwardError(format!(
640 "Rejection sampling failed after {} attempts",
641 self.max_attempts
642 ))),
643 }
644 }
645
646 fn count_violations(&self, sequence: &[f32]) -> usize {
648 self.constraints
649 .iter()
650 .filter(|constraint| !constraint(sequence))
651 .count()
652 }
653
654 pub fn base_sampler(&self) -> &Sampler {
656 &self.base_sampler
657 }
658
659 pub fn base_sampler_mut(&mut self) -> &mut Sampler {
661 &mut self.base_sampler
662 }
663
664 pub fn num_constraints(&self) -> usize {
666 self.constraints.len()
667 }
668}
669
670pub struct AdaptiveRejectionSampler {
672 rejection_sampler: RejectionSampler,
674 rejection_counts: Vec<usize>,
676 total_samples: usize,
678}
679
680impl AdaptiveRejectionSampler {
681 pub fn new(config: SamplingConfig, vocab_size: usize) -> Self {
683 Self {
684 rejection_sampler: RejectionSampler::new(config),
685 rejection_counts: vec![0; vocab_size],
686 total_samples: 0,
687 }
688 }
689
690 pub fn add_constraint(mut self, constraint: ConstraintFn) -> Self {
692 self.rejection_sampler = self.rejection_sampler.add_constraint(constraint);
693 self
694 }
695
696 pub fn sample_adaptive(
698 &mut self,
699 logits: &Array1<f32>,
700 context: &[f32],
701 ) -> InferenceResult<f32> {
702 self.total_samples += 1;
703
704 let mut adjusted_logits = logits.clone();
706 if self.total_samples > 10 {
707 let max_rejections = *self.rejection_counts.iter().max().unwrap_or(&1) as f32;
708 for (i, &count) in self.rejection_counts.iter().enumerate() {
709 if i < adjusted_logits.len() && count > 0 {
710 let penalty = (count as f32 / max_rejections) * 2.0;
712 adjusted_logits[i] -= penalty;
713 }
714 }
715 }
716
717 let result = self
719 .rejection_sampler
720 .sample_with_rejection(&adjusted_logits, context);
721
722 if let Ok(value) = result {
724 Ok(value)
725 } else {
726 let greedy_config = SamplingConfig::new().strategy(SamplingStrategy::Greedy);
728 let mut greedy_sampler = Sampler::new(greedy_config);
729 if let Ok(fallback) = greedy_sampler.sample(&adjusted_logits) {
730 let idx = fallback as usize;
731 if idx < self.rejection_counts.len() {
732 self.rejection_counts[idx] += 1;
733 }
734 }
735 Err(InferenceError::ForwardError(
736 "Adaptive rejection sampling failed".to_string(),
737 ))
738 }
739 }
740
741 pub fn rejection_rate(&self) -> f32 {
743 if self.total_samples == 0 {
744 return 0.0;
745 }
746 let total_rejections: usize = self.rejection_counts.iter().sum();
747 total_rejections as f32 / self.total_samples as f32
748 }
749
750 pub fn reset_stats(&mut self) {
752 self.rejection_counts.fill(0);
753 self.total_samples = 0;
754 }
755}
756
#[cfg(test)]
mod tests {
    use super::*;

    /// Five logits whose maximum sits at index 3.
    fn demo_logits() -> Array1<f32> {
        Array1::from_vec(vec![0.1, 0.5, 0.3, 0.8, 0.2])
    }

    /// Greedy sampling returns the index of the largest logit.
    #[test]
    fn test_greedy_sampling() {
        let mut sampler = Sampler::new(SamplingConfig::new().strategy(SamplingStrategy::Greedy));

        let picked = sampler.sample(&demo_logits()).unwrap();
        assert_eq!(picked, 3.0);
    }

    /// Temperature sampling succeeds on a well-formed input.
    #[test]
    fn test_temperature_sampling() {
        let config = SamplingConfig::new()
            .strategy(SamplingStrategy::Temperature)
            .temperature(0.5)
            .seed(42);
        let mut sampler = Sampler::new(config);

        assert!(sampler.sample(&demo_logits()).is_ok());
    }

    /// Top-k sampling succeeds on a well-formed input.
    #[test]
    fn test_top_k_sampling() {
        let mut sampler = Sampler::new(SamplingConfig::new().top_k(3).seed(42));

        assert!(sampler.sample(&demo_logits()).is_ok());
    }

    /// Top-p sampling succeeds on a well-formed input.
    #[test]
    fn test_top_p_sampling() {
        let mut sampler = Sampler::new(SamplingConfig::new().top_p(0.9).seed(42));

        assert!(sampler.sample(&demo_logits()).is_ok());
    }

    /// Softmax output sums to one and preserves the position of the maximum.
    #[test]
    fn test_softmax() {
        let probs = softmax(&Array1::from_vec(vec![1.0, 2.0, 3.0]));

        let total: f32 = probs.sum();
        assert!((total - 1.0).abs() < 1e-6);

        let (max_idx, _) = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap();
        assert_eq!(max_idx, 2);
    }

    /// Two expansions keep the beam count at the width and grow each sequence.
    #[test]
    fn test_beam_search() {
        let mut bs = BeamSearch::new(2);

        // First step: one row for the single initial hypothesis.
        let first_step = Array2::from_shape_vec((1, 3), vec![0.5, 0.3, 0.2]).unwrap();
        bs.expand(&first_step).unwrap();
        assert_eq!(bs.beams().len(), 2);

        // Second step: one row per surviving hypothesis.
        let second_step =
            Array2::from_shape_vec((2, 3), vec![0.4, 0.3, 0.3, 0.5, 0.3, 0.2]).unwrap();
        bs.expand(&second_step).unwrap();
        assert_eq!(bs.beams().len(), 2);

        let best = bs.best().unwrap();
        assert_eq!(best.sequence.len(), 2);
    }

    /// The average log-probability is the total divided by the step count.
    #[test]
    fn test_beam_avg_log_prob() {
        let mut beam = Beam::new();
        for (value, log_prob) in [(1.0, -0.5), (2.0, -0.3)] {
            beam.extend(value, log_prob);
        }

        let avg = beam.avg_log_prob();
        assert!((avg - (-0.4)).abs() < 1e-6);
    }

    /// Batch sampling applies greedy selection to each row independently.
    #[test]
    fn test_sample_batch() {
        let mut sampler = Sampler::new(SamplingConfig::new().strategy(SamplingStrategy::Greedy));

        let rows = vec![
            0.1, 0.5, 0.3, 0.2, // row 0: maximum at index 1
            0.8, 0.2, 0.1, 0.3, // row 1: maximum at index 0
            0.2, 0.3, 0.9, 0.1, // row 2: maximum at index 2
        ];
        let logits = Array2::from_shape_vec((3, 4), rows).unwrap();

        let results = sampler.sample_batch(&logits).unwrap();
        assert_eq!(results[0], 1.0);
        assert_eq!(results[1], 0.0);
        assert_eq!(results[2], 2.0);
    }
}