// kizzasi_inference/sampling.rs

1//! Sampling strategies for inference
2//!
3//! This module provides various sampling strategies for autoregressive prediction:
4//! - **Greedy**: Always select the highest probability value
5//! - **Temperature**: Scale logits to control randomness
6//! - **Top-k**: Sample from the k most likely values
7//! - **Top-p (nucleus)**: Sample from the smallest set with cumulative probability >= p
8//! - **Beam search**: Maintain multiple hypotheses for multi-step prediction
9
10use crate::error::{InferenceError, InferenceResult};
11use scirs2_core::ndarray::{Array1, Array2};
12
/// Configuration for sampling strategies.
///
/// Built fluently via [`SamplingConfig::new`] plus the builder methods
/// (`strategy`, `temperature`, `top_k`, `top_p`, `beam_search`, `seed`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SamplingConfig {
    /// Sampling strategy to use
    pub strategy: SamplingStrategy,
    /// Temperature for scaling (1.0 = no scaling, <1.0 = sharper, >1.0 = smoother)
    pub temperature: f32,
    /// For top-k sampling: number of top candidates to consider
    pub top_k: Option<usize>,
    /// For top-p sampling: cumulative probability threshold
    pub top_p: Option<f32>,
    /// For beam search: beam width
    pub beam_width: usize,
    /// Random seed for reproducibility.
    /// NOTE(review): the seed is stored but `Sampler::sample_categorical`
    /// pulls a fresh `scirs2_core::random::rng()` on every call — confirm
    /// whether seeding is actually honored anywhere.
    pub seed: Option<u64>,
}
29
30impl Default for SamplingConfig {
31    fn default() -> Self {
32        Self {
33            strategy: SamplingStrategy::Greedy,
34            temperature: 1.0,
35            top_k: None,
36            top_p: None,
37            beam_width: 1,
38            seed: None,
39        }
40    }
41}
42
43impl SamplingConfig {
44    /// Create a new sampling configuration
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Set the sampling strategy
50    pub fn strategy(mut self, strategy: SamplingStrategy) -> Self {
51        self.strategy = strategy;
52        self
53    }
54
55    /// Set temperature (1.0 = no scaling)
56    pub fn temperature(mut self, temp: f32) -> Self {
57        self.temperature = temp;
58        self
59    }
60
61    /// Enable top-k sampling
62    pub fn top_k(mut self, k: usize) -> Self {
63        self.strategy = SamplingStrategy::TopK;
64        self.top_k = Some(k);
65        self
66    }
67
68    /// Enable top-p (nucleus) sampling
69    pub fn top_p(mut self, p: f32) -> Self {
70        self.strategy = SamplingStrategy::TopP;
71        self.top_p = Some(p);
72        self
73    }
74
75    /// Enable beam search
76    pub fn beam_search(mut self, width: usize) -> Self {
77        self.strategy = SamplingStrategy::BeamSearch;
78        self.beam_width = width;
79        self
80    }
81
82    /// Set random seed
83    pub fn seed(mut self, seed: u64) -> Self {
84        self.seed = Some(seed);
85        self
86    }
87}
88
/// Available sampling strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SamplingStrategy {
    /// Always select the highest value (deterministic argmax)
    Greedy,
    /// Sample from the temperature-scaled distribution
    Temperature,
    /// Sample from the k most likely candidates
    TopK,
    /// Sample from the nucleus (smallest set with cumulative probability >= p)
    TopP,
    /// Maintain multiple beams for multi-step prediction
    /// (degrades to greedy for single-step `Sampler::sample`)
    BeamSearch,
    /// Delegate to a user-supplied sampling function
    Custom,
}
105
/// Custom sampling function type.
///
/// Takes the raw logits and the configured temperature and returns the
/// sampled value (index encoded as `f32`). Wrapped in `Arc` and bounded by
/// `Send + Sync` so it can be shared across threads.
pub type CustomSamplingFn = Arc<dyn Fn(&Array1<f32>, f32) -> InferenceResult<f32> + Send + Sync>;
109
/// Sampler for generating predictions from model outputs.
///
/// Dispatches on [`SamplingConfig`]'s strategy; see [`Sampler::sample`].
pub struct Sampler {
    // Strategy, temperature and related knobs.
    config: SamplingConfig,
    /// Custom sampling function (only consulted when strategy is `Custom`)
    custom_fn: Option<CustomSamplingFn>,
}
116
117impl Sampler {
118    /// Create a new sampler with given configuration
119    pub fn new(config: SamplingConfig) -> Self {
120        Self {
121            config,
122            custom_fn: None,
123        }
124    }
125
126    /// Create a sampler with a custom sampling function
127    pub fn with_custom_fn(mut config: SamplingConfig, custom_fn: CustomSamplingFn) -> Self {
128        config.strategy = SamplingStrategy::Custom;
129        Self {
130            config,
131            custom_fn: Some(custom_fn),
132        }
133    }
134
135    /// Set custom sampling function
136    pub fn set_custom_fn(&mut self, custom_fn: CustomSamplingFn) {
137        self.custom_fn = Some(custom_fn);
138        self.config.strategy = SamplingStrategy::Custom;
139    }
140
141    /// Sample a single value from logits
142    ///
143    /// # Arguments
144    /// * `logits` - Raw model outputs (unnormalized)
145    ///
146    /// # Returns
147    /// The sampled value
148    pub fn sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
149        if logits.is_empty() {
150            return Err(InferenceError::DimensionMismatch {
151                expected: 1,
152                got: 0,
153            });
154        }
155
156        match self.config.strategy {
157            SamplingStrategy::Greedy => Ok(self.greedy_sample(logits)),
158            SamplingStrategy::Temperature => self.temperature_sample(logits),
159            SamplingStrategy::TopK => self.top_k_sample(logits),
160            SamplingStrategy::TopP => self.top_p_sample(logits),
161            SamplingStrategy::BeamSearch => {
162                // Beam search requires multi-step context, use greedy for single-step
163                Ok(self.greedy_sample(logits))
164            }
165            SamplingStrategy::Custom => {
166                if let Some(ref custom_fn) = self.custom_fn {
167                    custom_fn(logits, self.config.temperature)
168                } else {
169                    // Fallback to greedy if no custom function is set
170                    Ok(self.greedy_sample(logits))
171                }
172            }
173        }
174    }
175
176    /// Sample multiple values from a batch of logits
177    pub fn sample_batch(&mut self, logits: &Array2<f32>) -> InferenceResult<Array1<f32>> {
178        let batch_size = logits.nrows();
179        let mut results = Vec::with_capacity(batch_size);
180
181        for i in 0..batch_size {
182            let logit_row = logits.row(i).to_owned();
183            results.push(self.sample(&logit_row)?);
184        }
185
186        Ok(Array1::from_vec(results))
187    }
188
189    /// Greedy sampling: select the maximum value
190    fn greedy_sample(&self, logits: &Array1<f32>) -> f32 {
191        logits
192            .iter()
193            .enumerate()
194            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
195            .map(|(idx, _)| idx as f32)
196            .unwrap_or(0.0)
197    }
198
199    /// Temperature sampling with optional scaling
200    fn temperature_sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
201        let scaled = if (self.config.temperature - 1.0).abs() > 1e-6 {
202            logits.mapv(|x| x / self.config.temperature)
203        } else {
204            logits.clone()
205        };
206
207        let probs = softmax(&scaled);
208        self.sample_categorical(&probs)
209    }
210
211    /// Top-k sampling: sample from k most likely candidates
212    fn top_k_sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
213        let k = self.config.top_k.unwrap_or(10);
214
215        // Get top-k indices
216        let mut indexed: Vec<_> = logits.iter().enumerate().collect();
217        indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
218        let top_k_indices: Vec<usize> = indexed.iter().take(k).map(|(idx, _)| *idx).collect();
219
220        // Create filtered logits
221        let mut filtered = Array1::from_elem(logits.len(), f32::NEG_INFINITY);
222        for &idx in &top_k_indices {
223            filtered[idx] = logits[idx];
224        }
225
226        let probs = softmax(&filtered);
227        self.sample_categorical(&probs)
228    }
229
230    /// Top-p (nucleus) sampling: sample from cumulative probability threshold
231    fn top_p_sample(&mut self, logits: &Array1<f32>) -> InferenceResult<f32> {
232        let p = self.config.top_p.unwrap_or(0.9);
233
234        // Sort by probability (descending)
235        let probs = softmax(logits);
236        let mut indexed: Vec<_> = probs.iter().enumerate().collect();
237        indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
238
239        // Find nucleus (smallest set with cumulative prob >= p)
240        let mut cumsum = 0.0;
241        let mut nucleus_size = 0;
242        for (_, &prob) in &indexed {
243            cumsum += prob;
244            nucleus_size += 1;
245            if cumsum >= p {
246                break;
247            }
248        }
249
250        // Create filtered logits
251        let nucleus_indices: Vec<usize> = indexed
252            .iter()
253            .take(nucleus_size)
254            .map(|(idx, _)| *idx)
255            .collect();
256        let mut filtered = Array1::from_elem(logits.len(), f32::NEG_INFINITY);
257        for &idx in &nucleus_indices {
258            filtered[idx] = logits[idx];
259        }
260
261        let filtered_probs = softmax(&filtered);
262        self.sample_categorical(&filtered_probs)
263    }
264
265    /// Sample from a categorical distribution
266    fn sample_categorical(&mut self, probs: &Array1<f32>) -> InferenceResult<f32> {
267        // Use simple random sampling based on system RNG
268        use scirs2_core::random::{rng, Rng};
269
270        let mut rng_gen = rng();
271        let uniform: f32 = rng_gen.random();
272        let mut cumsum = 0.0;
273        for (idx, &prob) in probs.iter().enumerate() {
274            cumsum += prob;
275            if uniform < cumsum {
276                return Ok(idx as f32);
277            }
278        }
279        // Fallback to last index
280        Ok((probs.len() - 1) as f32)
281    }
282
283    /// Get the current configuration
284    pub fn config(&self) -> &SamplingConfig {
285        &self.config
286    }
287}
288
289/// Apply softmax to convert logits to probabilities
290fn softmax(logits: &Array1<f32>) -> Array1<f32> {
291    // Subtract max for numerical stability
292    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
293    let exp_logits = logits.mapv(|x| (x - max_logit).exp());
294    let sum_exp: f32 = exp_logits.sum();
295
296    if sum_exp > 0.0 {
297        exp_logits / sum_exp
298    } else {
299        // All zeros case - uniform distribution
300        Array1::from_elem(logits.len(), 1.0 / logits.len() as f32)
301    }
302}
303
/// Beam search state for multi-step prediction.
#[derive(Debug, Clone)]
pub struct Beam {
    /// Sequence of values chosen so far
    pub sequence: Vec<f32>,
    /// Cumulative log probability of `sequence`
    pub log_prob: f32,
    /// Current hidden states.
    /// NOTE(review): nothing in this module reads or writes `states`
    /// (`extend` leaves it untouched) — presumably managed by callers;
    /// confirm against the inference loop.
    pub states: Vec<kizzasi_core::HiddenState>,
}
314
315impl Beam {
316    /// Create a new beam
317    pub fn new() -> Self {
318        Self {
319            sequence: Vec::new(),
320            log_prob: 0.0,
321            states: Vec::new(),
322        }
323    }
324
325    /// Add a value to the beam
326    pub fn extend(&mut self, value: f32, log_prob: f32) {
327        self.sequence.push(value);
328        self.log_prob += log_prob;
329    }
330
331    /// Get the average log probability (normalized by length)
332    pub fn avg_log_prob(&self) -> f32 {
333        if self.sequence.is_empty() {
334            0.0
335        } else {
336            self.log_prob / self.sequence.len() as f32
337        }
338    }
339}
340
impl Default for Beam {
    /// Equivalent to [`Beam::new`]: an empty beam.
    fn default() -> Self {
        Self::new()
    }
}
346
/// Beam search manager: expands and prunes a fixed-width set of [`Beam`]s.
pub struct BeamSearch {
    /// Number of beams to maintain after each expansion
    beam_width: usize,
    /// Current beams, kept sorted best-first after `expand`
    beams: Vec<Beam>,
}
354
355impl BeamSearch {
356    /// Create a new beam search with given width
357    pub fn new(beam_width: usize) -> Self {
358        let beams = vec![Beam::new()];
359        Self { beam_width, beams }
360    }
361
362    /// Expand beams with new candidates
363    pub fn expand(&mut self, logits: &Array2<f32>) -> InferenceResult<()> {
364        if logits.nrows() != self.beams.len() {
365            return Err(InferenceError::DimensionMismatch {
366                expected: self.beams.len(),
367                got: logits.nrows(),
368            });
369        }
370
371        let mut candidates = Vec::new();
372
373        for (beam_idx, beam) in self.beams.iter().enumerate() {
374            let beam_logits = logits.row(beam_idx).to_owned();
375            let probs = softmax(&beam_logits);
376
377            // Get top-k candidates for this beam
378            let mut indexed: Vec<_> = probs.iter().enumerate().collect();
379            indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
380
381            for (idx, &prob) in indexed.iter().take(self.beam_width) {
382                let mut new_beam = beam.clone();
383                new_beam.extend(*idx as f32, prob.ln());
384                candidates.push(new_beam);
385            }
386        }
387
388        // Select top beam_width candidates
389        candidates.sort_by(|a, b| {
390            b.avg_log_prob()
391                .partial_cmp(&a.avg_log_prob())
392                .unwrap_or(std::cmp::Ordering::Equal)
393        });
394        self.beams = candidates.into_iter().take(self.beam_width).collect();
395
396        Ok(())
397    }
398
399    /// Get the best beam
400    pub fn best(&self) -> Option<&Beam> {
401        self.beams.first()
402    }
403
404    /// Get all beams
405    pub fn beams(&self) -> &[Beam] {
406        &self.beams
407    }
408}
409
/// Constraint function type for constrained beam search and rejection
/// sampling. Returns `true` when the candidate sequence satisfies the
/// constraint; shared via `Arc` so it can be cloned across samplers.
pub type ConstraintFn = Arc<dyn Fn(&[f32]) -> bool + Send + Sync>;
413
/// Constrained beam search that only keeps beams satisfying constraints.
pub struct ConstrainedBeamSearch {
    /// Underlying (unconstrained) beam search
    beam_search: BeamSearch,
    /// Constraint functions checked against each beam's sequence
    constraints: Vec<ConstraintFn>,
    /// Whether to use soft constraints (penalize violators instead of dropping them)
    soft_constraints: bool,
    /// Log-probability penalty subtracted from beams violating soft constraints
    constraint_penalty: f32,
}
425
426impl ConstrainedBeamSearch {
427    /// Create a new constrained beam search
428    pub fn new(beam_width: usize) -> Self {
429        Self {
430            beam_search: BeamSearch::new(beam_width),
431            constraints: Vec::new(),
432            soft_constraints: false,
433            constraint_penalty: 1.0,
434        }
435    }
436
437    /// Add a hard constraint (must be satisfied)
438    pub fn add_constraint(mut self, constraint: ConstraintFn) -> Self {
439        self.constraints.push(constraint);
440        self
441    }
442
443    /// Enable soft constraints with penalty
444    pub fn with_soft_constraints(mut self, penalty: f32) -> Self {
445        self.soft_constraints = true;
446        self.constraint_penalty = penalty;
447        self
448    }
449
450    /// Check if a sequence satisfies all constraints
451    fn satisfies_constraints(&self, sequence: &[f32]) -> bool {
452        self.constraints.iter().all(|c| c(sequence))
453    }
454
455    /// Expand beams with constraint checking
456    pub fn expand(&mut self, logits: &Array2<f32>) -> InferenceResult<()> {
457        // First, perform standard beam expansion
458        self.beam_search.expand(logits)?;
459
460        // Then filter or penalize beams based on constraints
461        if self.soft_constraints {
462            // Soft constraints: penalize violating beams
463            // First collect which beams violate constraints
464            let violations: Vec<bool> = self
465                .beam_search
466                .beams
467                .iter()
468                .map(|beam| !self.satisfies_constraints(&beam.sequence))
469                .collect();
470
471            // Then apply penalties
472            let penalty = self.constraint_penalty;
473            for (beam, &violates) in self.beam_search.beams.iter_mut().zip(violations.iter()) {
474                if violates {
475                    beam.log_prob -= penalty;
476                }
477            }
478
479            // Re-sort by modified scores
480            self.beam_search.beams.sort_by(|a, b| {
481                b.log_prob
482                    .partial_cmp(&a.log_prob)
483                    .unwrap_or(std::cmp::Ordering::Equal)
484            });
485        } else {
486            // Hard constraints: filter out violating beams
487            let valid_beams: Vec<Beam> = self
488                .beam_search
489                .beams
490                .iter()
491                .filter(|beam| self.satisfies_constraints(&beam.sequence))
492                .cloned()
493                .collect();
494
495            if !valid_beams.is_empty() {
496                self.beam_search.beams = valid_beams;
497            }
498            // If all beams violate constraints, keep original beams
499            // (fallback behavior - could also raise error)
500        }
501
502        Ok(())
503    }
504
505    /// Get the best beam
506    pub fn best(&self) -> Option<&Beam> {
507        self.beam_search.best()
508    }
509
510    /// Get all beams
511    pub fn beams(&self) -> &[Beam] {
512        self.beam_search.beams()
513    }
514
515    /// Get number of active constraints
516    pub fn num_constraints(&self) -> usize {
517        self.constraints.len()
518    }
519}
520
521use std::sync::Arc;
522
523// ============================================================================
524// Rejection Sampling with Constraints
525// ============================================================================
526
/// Rejection sampler that re-draws candidates until one satisfies all
/// registered constraints (or the attempt budget runs out).
pub struct RejectionSampler {
    /// Base sampler for generating candidates
    base_sampler: Sampler,
    /// Constraint functions checked against context + candidate
    constraints: Vec<ConstraintFn>,
    /// Maximum number of rejection attempts before giving up
    max_attempts: usize,
    /// What to do when every attempt violates some constraint
    fallback_strategy: FallbackStrategy,
}
538
/// Fallback strategy when rejection sampling exhausts its attempts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FallbackStrategy {
    /// Return the candidate that violated the fewest constraints
    BestCandidate,
    /// Return a fresh greedy sample (may itself violate constraints)
    Greedy,
    /// Return an `InferenceError::ForwardError`
    Error,
}
549
impl RejectionSampler {
    /// Create a new rejection sampler with defaults:
    /// 100 attempts, `BestCandidate` fallback, no constraints.
    pub fn new(config: SamplingConfig) -> Self {
        Self {
            base_sampler: Sampler::new(config),
            constraints: Vec::new(),
            max_attempts: 100,
            fallback_strategy: FallbackStrategy::BestCandidate,
        }
    }

    /// Add a constraint function (builder style).
    pub fn add_constraint(mut self, constraint: ConstraintFn) -> Self {
        self.constraints.push(constraint);
        self
    }

    /// Set maximum number of rejection attempts.
    pub fn max_attempts(mut self, attempts: usize) -> Self {
        self.max_attempts = attempts;
        self
    }

    /// Set the fallback strategy used when all attempts fail.
    pub fn fallback_strategy(mut self, strategy: FallbackStrategy) -> Self {
        self.fallback_strategy = strategy;
        self
    }

    /// Sample with constraint checking and rejection.
    ///
    /// Draws up to `max_attempts` candidates from the base sampler; each is
    /// appended to `context` and checked against every constraint. The first
    /// candidate with zero violations is returned; otherwise the configured
    /// fallback applies.
    ///
    /// # Arguments
    /// * `logits` - Model output logits
    /// * `context` - Current sequence context for constraint checking
    ///
    /// # Returns
    /// Sampled value that satisfies constraints, or fallback value
    pub fn sample_with_rejection(
        &mut self,
        logits: &Array1<f32>,
        context: &[f32],
    ) -> InferenceResult<f32> {
        if self.constraints.is_empty() {
            // No constraints, just sample normally
            return self.base_sampler.sample(logits);
        }

        // Track the least-violating candidate seen so far for the
        // BestCandidate fallback.
        let mut best_candidate = None;
        let mut min_violations = usize::MAX;

        for attempt in 0..self.max_attempts {
            let candidate = self.base_sampler.sample(logits)?;

            // Candidate is checked in context: constraints see the whole
            // hypothetical sequence, not just the new value.
            let mut test_sequence = context.to_vec();
            test_sequence.push(candidate);

            let violations = self.count_violations(&test_sequence);

            if violations == 0 {
                // Found a valid sample!
                return Ok(candidate);
            }

            if violations < min_violations {
                min_violations = violations;
                best_candidate = Some(candidate);
            }

            // NOTE(review): this breaks out (gives up) after half the attempt
            // budget once a candidate violates fewer than half the
            // constraints — i.e. it stops *early* precisely when candidates
            // are close to valid. Reads more like "good enough, settle for
            // the best candidate" than "making progress, keep trying";
            // confirm the intended semantics.
            if attempt > self.max_attempts / 2 && violations < self.constraints.len() / 2 {
                break;
            }
        }

        // All attempts failed (or we settled early); apply the fallback.
        match self.fallback_strategy {
            FallbackStrategy::BestCandidate => best_candidate.ok_or_else(|| {
                InferenceError::ForwardError(
                    "Rejection sampling failed: no candidates generated".to_string(),
                )
            }),
            FallbackStrategy::Greedy => {
                // Deterministic argmax; may itself violate constraints.
                let greedy_config = SamplingConfig::new().strategy(SamplingStrategy::Greedy);
                let mut greedy_sampler = Sampler::new(greedy_config);
                greedy_sampler.sample(logits)
            }
            FallbackStrategy::Error => Err(InferenceError::ForwardError(format!(
                "Rejection sampling failed after {} attempts",
                self.max_attempts
            ))),
        }
    }

    /// Count how many constraints reject the given sequence.
    fn count_violations(&self, sequence: &[f32]) -> usize {
        self.constraints
            .iter()
            .filter(|constraint| !constraint(sequence))
            .count()
    }

    /// Get the base sampler.
    pub fn base_sampler(&self) -> &Sampler {
        &self.base_sampler
    }

    /// Get mutable access to the base sampler.
    pub fn base_sampler_mut(&mut self) -> &mut Sampler {
        &mut self.base_sampler
    }

    /// Get the number of registered constraints.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }
}
669
/// Adaptive rejection sampler that biases future draws away from values
/// whose fallback samples were previously recorded as rejected.
pub struct AdaptiveRejectionSampler {
    /// Base rejection sampler
    rejection_sampler: RejectionSampler,
    /// Per-index rejection tallies, indexed by sampled value (vocab index)
    rejection_counts: Vec<usize>,
    /// Total calls to `sample_adaptive`
    total_samples: usize,
}
679
impl AdaptiveRejectionSampler {
    /// Create a new adaptive rejection sampler.
    /// `vocab_size` fixes the length of the per-index rejection tally.
    pub fn new(config: SamplingConfig, vocab_size: usize) -> Self {
        Self {
            rejection_sampler: RejectionSampler::new(config),
            rejection_counts: vec![0; vocab_size],
            total_samples: 0,
        }
    }

    /// Add a constraint (builder style, forwarded to the inner sampler).
    pub fn add_constraint(mut self, constraint: ConstraintFn) -> Self {
        self.rejection_sampler = self.rejection_sampler.add_constraint(constraint);
        self
    }

    /// Sample with adaptive biasing away from frequently rejected values.
    ///
    /// After a 10-call warm-up, each logit is penalized proportionally to
    /// how often its index has been recorded as rejected (up to -2.0 for
    /// the most-rejected index).
    pub fn sample_adaptive(
        &mut self,
        logits: &Array1<f32>,
        context: &[f32],
    ) -> InferenceResult<f32> {
        self.total_samples += 1;

        // Bias logits away from frequently rejected values.
        let mut adjusted_logits = logits.clone();
        if self.total_samples > 10 {
            // `unwrap_or(&1)` guards the empty-tally case; max_rejections is
            // only 0 if no rejection was ever recorded, and then no index
            // has count > 0 either, so no division by zero occurs below.
            let max_rejections = *self.rejection_counts.iter().max().unwrap_or(&1) as f32;
            for (i, &count) in self.rejection_counts.iter().enumerate() {
                if i < adjusted_logits.len() && count > 0 {
                    // Penalty scales linearly with relative rejection frequency.
                    let penalty = (count as f32 / max_rejections) * 2.0;
                    adjusted_logits[i] -= penalty;
                }
            }
        }

        // Try to sample with rejection on the biased logits.
        let result = self
            .rejection_sampler
            .sample_with_rejection(&adjusted_logits, context);

        if let Ok(value) = result {
            // NOTE(review): successful samples are NOT recorded anywhere —
            // only the failure path below updates `rejection_counts`.
            // Confirm whether successes were meant to feed the statistics
            // too (the original comment claimed they were).
            Ok(value)
        } else {
            // On failure: greedy-sample the biased logits purely to record
            // which index keeps failing, then still report an error.
            // NOTE(review): the greedy fallback value is computed and
            // discarded rather than returned — verify that returning an
            // error here (instead of the fallback) is intentional.
            let greedy_config = SamplingConfig::new().strategy(SamplingStrategy::Greedy);
            let mut greedy_sampler = Sampler::new(greedy_config);
            if let Ok(fallback) = greedy_sampler.sample(&adjusted_logits) {
                let idx = fallback as usize;
                if idx < self.rejection_counts.len() {
                    self.rejection_counts[idx] += 1;
                }
            }
            Err(InferenceError::ForwardError(
                "Adaptive rejection sampling failed".to_string(),
            ))
        }
    }

    /// Recorded rejections divided by total `sample_adaptive` calls.
    /// NOTE(review): since counts only grow on failed calls, this is the
    /// failure rate of adaptive sampling, not a per-draw rejection rate.
    pub fn rejection_rate(&self) -> f32 {
        if self.total_samples == 0 {
            return 0.0;
        }
        let total_rejections: usize = self.rejection_counts.iter().sum();
        total_rejections as f32 / self.total_samples as f32
    }

    /// Reset all tallies and the call counter.
    pub fn reset_stats(&mut self) {
        self.rejection_counts.fill(0);
        self.total_samples = 0;
    }
}
756
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_greedy_sampling() {
        // Greedy sampling must deterministically pick the argmax index.
        let mut sampler = Sampler::new(SamplingConfig::new().strategy(SamplingStrategy::Greedy));
        let logits = Array1::from_vec(vec![0.1, 0.5, 0.3, 0.8, 0.2]);
        assert_eq!(sampler.sample(&logits).unwrap(), 3.0);
    }

    #[test]
    fn test_temperature_sampling() {
        // Temperature sampling is stochastic; just verify it succeeds.
        let config = SamplingConfig::new()
            .strategy(SamplingStrategy::Temperature)
            .temperature(0.5)
            .seed(42);
        let mut sampler = Sampler::new(config);
        let logits = Array1::from_vec(vec![0.1, 0.5, 0.3, 0.8, 0.2]);
        assert!(sampler.sample(&logits).is_ok());
    }

    #[test]
    fn test_top_k_sampling() {
        let mut sampler = Sampler::new(SamplingConfig::new().top_k(3).seed(42));
        let logits = Array1::from_vec(vec![0.1, 0.5, 0.3, 0.8, 0.2]);
        assert!(sampler.sample(&logits).is_ok());
    }

    #[test]
    fn test_top_p_sampling() {
        let mut sampler = Sampler::new(SamplingConfig::new().top_p(0.9).seed(42));
        let logits = Array1::from_vec(vec![0.1, 0.5, 0.3, 0.8, 0.2]);
        assert!(sampler.sample(&logits).is_ok());
    }

    #[test]
    fn test_softmax() {
        let probs = softmax(&Array1::from_vec(vec![1.0, 2.0, 3.0]));

        // Softmax output is a probability distribution.
        assert!((probs.sum() - 1.0).abs() < 1e-6);

        // The largest logit keeps the largest probability.
        let argmax = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();
        assert_eq!(argmax, 2);
    }

    #[test]
    fn test_beam_search() {
        let mut bs = BeamSearch::new(2);

        // One initial beam in, two beams out after the first expansion.
        let step1 = Array2::from_shape_vec((1, 3), vec![0.5, 0.3, 0.2]).unwrap();
        bs.expand(&step1).unwrap();
        assert_eq!(bs.beams().len(), 2);

        // Second step: one row of logits per surviving beam.
        let step2 = Array2::from_shape_vec((2, 3), vec![0.4, 0.3, 0.3, 0.5, 0.3, 0.2]).unwrap();
        bs.expand(&step2).unwrap();
        assert_eq!(bs.beams().len(), 2);

        assert_eq!(bs.best().unwrap().sequence.len(), 2);
    }

    #[test]
    fn test_beam_avg_log_prob() {
        let mut beam = Beam::new();
        beam.extend(1.0, -0.5);
        beam.extend(2.0, -0.3);
        // (-0.5 + -0.3) / 2 = -0.4
        assert!((beam.avg_log_prob() + 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_sample_batch() {
        let mut sampler = Sampler::new(SamplingConfig::new().strategy(SamplingStrategy::Greedy));

        let logits = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.1, 0.5, 0.3, 0.2, // Row 0: max at index 1
                0.8, 0.2, 0.1, 0.3, // Row 1: max at index 0
                0.2, 0.3, 0.9, 0.1, // Row 2: max at index 2
            ],
        )
        .unwrap();

        let picked = sampler.sample_batch(&logits).unwrap();
        assert_eq!(picked[0], 1.0);
        assert_eq!(picked[1], 0.0);
        assert_eq!(picked[2], 2.0);
    }
}