use super::{sample_from_distribution, sample_greedy};
use crate::error::{RealizarError, Result};
use crate::layers::softmax;
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
/// Min-p sampling: keeps every token whose probability is at least
/// `min_p` times the highest probability, then samples from the
/// renormalized survivors using the caller-supplied `rng_value`.
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` when `logits` is empty or
/// `min_p` lies outside `[0, 1]` (NaN also fails the range check).
pub fn sample_min_p(logits: &Tensor<f32>, min_p: f32, rng_value: f32) -> Result<usize> {
    if logits.data().is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    }
    if !(0.0..=1.0).contains(&min_p) {
        return Err(RealizarError::InvalidShape {
            reason: "min_p must be in [0, 1]".to_string(),
        });
    }
    let probs_tensor = softmax(logits)?;
    let probs = probs_tensor.data();
    // The cutoff scales with the most likely token's probability.
    let peak = probs.iter().copied().fold(0.0_f32, f32::max);
    let cutoff = min_p * peak;
    let mut kept = Vec::new();
    for (token, prob) in probs.iter().copied().enumerate() {
        if prob >= cutoff {
            kept.push((token, prob));
        }
    }
    kept.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
    // Defensive fallback; with min_p <= 1 the argmax always survives the filter.
    if kept.is_empty() {
        return sample_greedy(logits);
    }
    // Renormalize the surviving mass and sample.
    let total: f32 = kept.iter().map(|(_, prob)| prob).sum();
    let weights: Vec<f32> = kept.iter().map(|(_, prob)| prob / total).collect();
    let tokens: Vec<usize> = kept.iter().map(|(token, _)| *token).collect();
    Ok(sample_from_distribution(&weights, &tokens, rng_value))
}
/// Mutable state carried across calls to Mirostat sampling.
#[derive(Debug, Clone)]
pub struct MirostatState {
    /// Target surprise the sampler steers toward.
    pub tau: f32,
    /// Learning rate applied when correcting `mu`.
    pub eta: f32,
    /// Current surprise budget; adjusted after each sampled token.
    pub mu: f32,
}

impl Default for MirostatState {
    /// Equivalent to `MirostatState::new(5.0)`:
    /// `tau = 5.0`, `eta = 0.1`, `mu = 10.0`.
    fn default() -> Self {
        Self::new(5.0)
    }
}

impl MirostatState {
    /// Builds state for a target surprise `tau`; `mu` starts at `2 * tau`
    /// and `eta` at the conventional `0.1`.
    pub fn new(tau: f32) -> Self {
        MirostatState {
            tau,
            eta: 0.1,
            mu: tau * 2.0,
        }
    }

    /// Builder-style override of the learning rate.
    #[must_use]
    pub fn with_eta(self, eta: f32) -> Self {
        MirostatState { eta, ..self }
    }

    /// Nudges the budget toward the target:
    /// `mu -= eta * (observed_surprise - tau)`.
    pub fn update(&mut self, observed_surprise: f32) {
        let error = observed_surprise - self.tau;
        self.mu -= self.eta * error;
    }
}
/// Samples a token with Mirostat: an adaptive cutoff that keeps only the
/// tokens whose surprise (`-ln p`) fits within the current budget
/// `state.mu`, then moves `mu` toward the target `state.tau` based on the
/// surprise of the token actually chosen.
///
/// NOTE(review): surprise here uses the natural log; the Mirostat paper is
/// usually stated in log base 2, so `tau` values are only comparable up to
/// that scale factor — confirm the intended base against callers.
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` when `logits` is empty; also
/// propagates any error from `softmax`.
pub fn sample_mirostat(
    logits: &Tensor<f32>,
    state: &mut MirostatState,
    rng_value: f32,
) -> Result<usize> {
    let data = logits.data();
    if data.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    }
    let probs_tensor = softmax(logits)?;
    let probs = probs_tensor.data();
    // Sort (token, prob) pairs by probability, descending.
    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    // Remember the argmax so we never end up with an empty candidate set.
    let top_candidate = indexed[0];
    let mut candidates = Vec::new();
    for (idx, prob) in indexed {
        let surprise = -prob.ln();
        // Descending probabilities mean ascending surprise, so the first
        // token over budget ends the scan.
        if surprise > state.mu {
            break;
        }
        candidates.push((idx, prob));
    }
    // If even the most likely token exceeds the budget, keep it anyway.
    if candidates.is_empty() {
        candidates.push(top_candidate);
    }
    // Renormalize the kept mass and sample.
    let sum: f32 = candidates.iter().map(|(_, p)| p).sum();
    let normalized: Vec<f32> = candidates.iter().map(|(_, p)| p / sum).collect();
    let indices: Vec<usize> = candidates.iter().map(|(idx, _)| *idx).collect();
    // Assumes sample_from_distribution returns one of the entries of
    // `indices` (a vocabulary index) — TODO confirm against its definition.
    let selected = sample_from_distribution(&normalized, &indices, rng_value);
    // Map the vocabulary index back to its position in `candidates` to
    // recover the selected probability; falls back to 0 defensively.
    let selected_idx = indices.iter().position(|&i| i == selected).unwrap_or(0);
    let selected_prob = candidates[selected_idx].1;
    // Feed the observed surprise back so mu tracks tau over time.
    let observed_surprise = -selected_prob.ln();
    state.update(observed_surprise);
    Ok(selected)
}
/// Samples a token with Tail-Free Sampling (TFS).
///
/// Sorts token probabilities in descending order, computes the absolute
/// second derivative of that curve, and cuts off the tail once the
/// normalized curvature mass exceeds `z`. The remaining head is
/// renormalized and sampled with the caller-supplied `rng_value`.
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` when `logits` is empty.
pub fn sample_tfs(logits: &Tensor<f32>, z: f32, rng_value: f32) -> Result<usize> {
    let data = logits.data();
    if data.is_empty() {
        return Err(crate::error::RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    }
    // Softmax computed inline, max-shifted for numerical stability.
    let max_logit = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exp_logits: Vec<f32> = data.iter().map(|&x| (x - max_logit).exp()).collect();
    let sum: f32 = exp_logits.iter().sum();
    let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum).collect();
    let mut indexed: Vec<(usize, f32)> = probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    // Fewer than 3 tokens: no second derivative exists; return the argmax.
    if indexed.len() < 3 {
        return Ok(indexed[0].0);
    }
    // First differences of the sorted probability curve (length n - 1).
    let first_derivatives: Vec<f32> = indexed
        .windows(2)
        .map(|w| (w[0].1 - w[1].1).abs())
        .collect();
    // Differences of the differences, i.e. curvature (length n - 2).
    let second_derivatives: Vec<f32> = first_derivatives
        .windows(2)
        .map(|w| (w[0] - w[1]).abs())
        .collect();
    let sum_second: f32 = second_derivatives.iter().sum();
    // Normalize curvature to a distribution; if it is essentially flat,
    // fall back to a uniform weighting so the scan below still terminates
    // at a sensible point.
    let normalized: Vec<f32> = if sum_second > 1e-9 {
        second_derivatives.iter().map(|&x| x / sum_second).collect()
    } else {
        vec![1.0 / second_derivatives.len() as f32; second_derivatives.len()]
    };
    // Walk the curvature mass; once it exceeds z, `i + 2` converts the
    // second-derivative index back into a count of tokens to keep (each
    // curvature value spans three consecutive tokens). If the threshold
    // is never crossed, every token is kept.
    let mut cumsum = 0.0;
    let mut cutoff_idx = indexed.len();
    for (i, &val) in normalized.iter().enumerate() {
        cumsum += val;
        if cumsum > z {
            cutoff_idx = i + 2;
            break;
        }
    }
    // Keep at least one token, renormalize the head, and sample from it.
    let kept: Vec<(usize, f32)> = indexed.into_iter().take(cutoff_idx.max(1)).collect();
    let sum_kept: f32 = kept.iter().map(|(_, p)| p).sum();
    let normalized_kept: Vec<f32> = kept.iter().map(|(_, p)| p / sum_kept).collect();
    let indices: Vec<usize> = kept.iter().map(|(idx, _)| *idx).collect();
    Ok(sample_from_distribution(
        &normalized_kept,
        &indices,
        rng_value,
    ))
}
/// Locally-typical sampling: ranks tokens by how close their information
/// content (`-ln p`) is to the distribution's entropy, keeps the most
/// typical tokens until their cumulative probability reaches `p`, and
/// samples from that renormalized set.
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` when `logits` is empty.
pub fn sample_typical(logits: &Tensor<f32>, p: f32, rng_value: f32) -> Result<usize> {
    let data = logits.data();
    if data.is_empty() {
        return Err(crate::error::RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    }
    // Numerically stable softmax over the raw logits.
    let peak = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = data.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|&e| e / total).collect();
    // Shannon entropy (nats), skipping numerically-zero probabilities.
    let entropy: f32 = -probs
        .iter()
        .filter(|&&q| q > 1e-10)
        .map(|&q| q * q.ln())
        .sum::<f32>();
    // Rank tokens by |information - entropy|, most "typical" first.
    let mut ranked: Vec<(usize, f32, f32)> = probs
        .iter()
        .enumerate()
        .filter(|(_, &q)| q > 1e-10)
        .map(|(token, &q)| {
            let info = -q.ln();
            (token, q, (info - entropy).abs())
        })
        .collect();
    ranked.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
    // Accumulate typical tokens until their mass reaches p; the token
    // that crosses the boundary is included.
    let mut kept: Vec<(usize, f32)> = Vec::new();
    let mut mass = 0.0;
    for (token, prob, _) in ranked {
        kept.push((token, prob));
        mass += prob;
        if mass >= p {
            break;
        }
    }
    // Defensive fallback: degenerate distributions resolve to the argmax.
    if kept.is_empty() {
        let best = probs
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(token, _)| token);
        return Ok(best);
    }
    // Renormalize the kept mass and sample.
    let kept_mass: f32 = kept.iter().map(|(_, q)| q).sum();
    let weights: Vec<f32> = kept.iter().map(|(_, q)| q / kept_mass).collect();
    let tokens: Vec<usize> = kept.iter().map(|(token, _)| *token).collect();
    Ok(sample_from_distribution(&weights, &tokens, rng_value))
}
/// Configuration for the DRY ("Don't Repeat Yourself") repetition penalty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DryConfig {
    /// Penalty strength; any value `<= 0.0` disables DRY (see `is_enabled`).
    pub multiplier: f32,
    /// Exponential base — presumably scales the penalty with repeat length;
    /// confirm against the penalty implementation.
    pub base: f32,
    /// Repetition length that goes unpenalized — TODO confirm exact semantics
    /// in the penalty code.
    pub allowed_length: usize,
    /// Window of trailing tokens considered — presumably the scan range for
    /// repeats; confirm against the penalty implementation.
    pub penalty_last_n: usize,
}

impl Default for DryConfig {
    /// Defaults: `multiplier = 0.8`, `base = 1.75`, `allowed_length = 2`,
    /// `penalty_last_n = 256`.
    fn default() -> Self {
        DryConfig {
            multiplier: 0.8,
            base: 1.75,
            allowed_length: 2,
            penalty_last_n: 256,
        }
    }
}

impl DryConfig {
    /// Creates a config with the given `multiplier`; every other field keeps
    /// its default value.
    pub fn new(multiplier: f32) -> Self {
        DryConfig {
            multiplier,
            ..DryConfig::default()
        }
    }

    /// Builder-style override of the exponential base.
    #[must_use]
    pub fn with_base(mut self, base: f32) -> Self {
        self.base = base;
        self
    }

    /// Builder-style override of the unpenalized repetition length.
    #[must_use]
    pub fn with_allowed_length(mut self, len: usize) -> Self {
        self.allowed_length = len;
        self
    }

    /// Builder-style override of the trailing-token window.
    #[must_use]
    pub fn with_penalty_last_n(mut self, n: usize) -> Self {
        self.penalty_last_n = n;
        self
    }

    /// DRY is active only for a strictly positive multiplier.
    pub fn is_enabled(&self) -> bool {
        self.multiplier > 0.0
    }
}
include!("dry_penalty.rs");
include!("classifier_free_guidance.rs");
include!("sampling_tests.rs");