use crate::{
error::{RealizarError, Result},
layers::softmax,
tensor::Tensor,
};
mod algorithms;
mod sampler;
pub use algorithms::{
analyze_token_healing, apply_cfg, apply_dry_penalty, apply_xtc, sample_eta, sample_min_p,
sample_mirostat, sample_tfs, sample_typical, CfgConfig, DryConfig, EtaConfig, MirostatState,
TokenHealingConfig, TokenHealingResult, XtcConfig,
};
pub use sampler::{
apply_all_penalties, apply_dynamic_temperature, apply_infill_sampling, apply_logit_bias,
apply_presence_frequency_penalty, apply_repetition_penalty, AdvancedGenerationConfig,
BeamHypothesis, BeamSearchConfig, BeamSearchState, DynTempConfig, DynTempSampler,
GenerationPipeline, GenerativeModel, InfillConfig, InfillResult, InfillSampler, LogitBias,
LogitProcessor, LogitProcessorChain, LogitProcessorContext, PresenceFrequencyPenalty,
PromptCache, PromptCacheEntry, PromptCacheStats, RepetitionPenalty, RepetitionPenaltyConfig,
RepetitionPenaltySampler, Sampler, SamplerChain, SamplerContext, StopSequenceDetector,
StreamingGenerator, TemperatureSampler, TemperatureScaler, TokenSuppressor, TopKSampler,
TopPSampler,
};
/// Draws one candidate from a discrete distribution by inverse-CDF sampling.
///
/// `probs` and `indices` are parallel slices. The running total of `probs` is
/// accumulated left to right. The entry of `indices` at the first position where
/// `rng_value` falls strictly below that total is returned. If no position
/// qualifies (for example because rounding left the total just under
/// `rng_value`), the last entry of `indices` is returned.
///
/// `rng_value` is presumably a uniform draw in `[0, 1)` supplied by the caller.
///
/// # Panics
///
/// Panics if `indices` is empty, or if the selected position is out of bounds
/// for `indices`.
pub(crate) fn sample_from_distribution(probs: &[f32], indices: &[usize], rng_value: f32) -> usize {
    let mut running_total = 0.0;
    let chosen = probs.iter().position(|&mass| {
        running_total += mass;
        rng_value < running_total
    });
    match chosen {
        Some(pos) => indices[pos],
        // Fallback absorbs any floating-point shortfall in the cumulative sum.
        None => indices[indices.len() - 1],
    }
}
/// Converts `(token_id, logit)` pairs into softmax probabilities.
///
/// The returned vector is parallel to `indexed`: entry `i` is the probability
/// assigned to `indexed[i]`. The maximum logit is subtracted before
/// exponentiating so that large logits do not overflow `exp`. An empty input
/// yields an empty vector.
pub(crate) fn logits_to_probs(indexed: &[(usize, f32)]) -> Vec<f32> {
    // Scan for the true maximum instead of trusting `indexed[0]`. Callers are
    // not forced to pass a descending-sorted slice; with an unsorted slice a
    // logit far above the first one would overflow `exp` and poison the sum.
    // For sorted input this is the same value as `indexed[0].1`.
    let max_logit = indexed
        .iter()
        .map(|&(_, logit)| logit)
        .fold(f32::NEG_INFINITY, f32::max);
    let exp_vals: Vec<f32> = indexed
        .iter()
        .map(|&(_, logit)| (logit - max_logit).exp())
        .collect();
    let sum_exp: f32 = exp_vals.iter().sum();
    exp_vals.into_iter().map(|e| e / sum_exp).collect()
}
/// Returns the shortest prefix of `indexed` whose cumulative probability
/// reaches `p`.
///
/// Elements are `(token_id, probability)` pairs. The element that pushes the
/// running total to `>= p` is included. If the total never reaches `p`, the
/// whole input is returned.
///
/// The prefix is only a true nucleus when `indexed` is sorted by descending
/// probability, which is how `sample_top_p` calls it.
pub(crate) fn build_nucleus(indexed: &[(usize, f32)], p: f32) -> Vec<(usize, f32)> {
    let mut mass = 0.0;
    let cutoff = indexed
        .iter()
        .position(|&(_, prob)| {
            mass += prob;
            mass >= p
        })
        .map_or(indexed.len(), |last| last + 1);
    indexed[..cutoff].to_vec()
}
/// Rule used to pick the next token from (temperature-scaled) logits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SamplingStrategy {
    /// Always take the highest-scoring token.
    Greedy,
    /// Sample among the `k` highest-scoring tokens.
    TopK { k: usize },
    /// Sample from the smallest set of most-probable tokens whose cumulative
    /// probability reaches `p`.
    TopP { p: f32 },
}

/// Settings consumed by the token sampler.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Generation length limit (default 100). Not read in this module;
    /// presumably enforced by the generation loop.
    pub max_tokens: usize,
    /// Which sampler picks the next token (default: greedy).
    pub strategy: SamplingStrategy,
    /// Divisor applied to the logits before sampling (default 1.0).
    pub temperature: f32,
    /// Optional end-of-sequence token id (unset by default).
    pub eos_token_id: Option<usize>,
    /// Optional RNG seed (unset by default). Not read in this module.
    pub seed: Option<u64>,
}

impl Default for GenerationConfig {
    /// Greedy decoding, 100-token limit, temperature 1.0, no EOS id, no seed.
    fn default() -> Self {
        Self {
            strategy: SamplingStrategy::Greedy,
            temperature: 1.0,
            max_tokens: 100,
            seed: None,
            eos_token_id: None,
        }
    }
}

impl GenerationConfig {
    /// Greedy configuration; identical to [`GenerationConfig::default`],
    /// whose strategy is already [`SamplingStrategy::Greedy`].
    #[must_use]
    pub fn greedy() -> Self {
        Self::default()
    }

    /// Default configuration switched to top-k sampling.
    #[must_use]
    pub fn top_k(k: usize) -> Self {
        Self {
            strategy: SamplingStrategy::TopK { k },
            ..Self::default()
        }
    }

    /// Default configuration switched to nucleus (top-p) sampling.
    #[must_use]
    pub fn top_p(p: f32) -> Self {
        Self {
            strategy: SamplingStrategy::TopP { p },
            ..Self::default()
        }
    }

    /// Returns this configuration with `temperature` replaced.
    #[must_use]
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self { temperature, ..self }
    }

    /// Returns this configuration with `max_tokens` replaced.
    #[must_use]
    pub fn with_max_tokens(self, max_tokens: usize) -> Self {
        Self { max_tokens, ..self }
    }

    /// Returns this configuration with the EOS token id set.
    #[must_use]
    pub fn with_eos_token_id(self, eos_token_id: usize) -> Self {
        Self {
            eos_token_id: Some(eos_token_id),
            ..self
        }
    }

    /// Returns this configuration with the RNG seed set.
    #[must_use]
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }
}
/// Divides every logit by `temperature`.
///
/// A temperature above 1.0 shrinks the gaps between logits; one below 1.0
/// widens them. A temperature within `1e-6` of 1.0 returns an unscaled clone
/// of `logits`.
///
/// # Errors
///
/// - Returns [`RealizarError::InvalidShape`] if `temperature` is not strictly
///   positive. This includes NaN, which would otherwise turn every logit into
///   NaN.
/// - Propagates any error from `Tensor::from_vec`.
pub fn apply_temperature(logits: &Tensor<f32>, temperature: f32) -> Result<Tensor<f32>> {
    // NaN fails every ordered comparison, so `<= 0.0` alone would let it through.
    if temperature.is_nan() || temperature <= 0.0 {
        return Err(RealizarError::InvalidShape {
            reason: "Temperature must be positive".to_string(),
        });
    }
    if (temperature - 1.0).abs() < 1e-6 {
        return Ok(logits.clone());
    }
    let scaled: Vec<f32> = logits.data().iter().map(|&x| x / temperature).collect();
    Tensor::from_vec(logits.shape().to_vec(), scaled)
}
/// Returns the index of the largest logit (argmax).
///
/// Ties resolve to the lowest index, because a later value replaces the
/// current best only when it is strictly greater. Comparisons involving NaN
/// are false, so a NaN never becomes the best. If position 0 holds NaN,
/// nothing displaces it and index 0 is returned.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] if `logits` is empty.
pub fn sample_greedy(logits: &Tensor<f32>) -> Result<usize> {
    let data = logits.data();
    let Some(&first) = data.first() else {
        return Err(RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    };
    let (best_idx, _) = data
        .iter()
        .copied()
        .enumerate()
        .skip(1)
        .fold((0, first), |(best_idx, best_val), (idx, val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        });
    Ok(best_idx)
}
/// Samples a token id from the `k` highest logits.
///
/// The `k` largest logits (all of them when `k` exceeds the vocabulary size)
/// are softmax-normalised among themselves, and one is drawn with `rng_value`
/// by inverse-CDF sampling.
///
/// Candidates are ranked with [`f32::total_cmp`], which is a genuine total
/// order (NaN included), as slice sorting and selection require. Equal logits
/// are ordered by ascending token id, so the result is deterministic. Only the
/// top `k` are fully sorted, after an O(n) partition, which avoids sorting the
/// whole vocabulary on every step.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] if `logits` is empty or `k == 0`.
pub fn sample_top_k(logits: &Tensor<f32>, k: usize, rng_value: f32) -> Result<usize> {
    let data = logits.data();
    if data.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    }
    if k == 0 {
        return Err(RealizarError::InvalidShape {
            reason: "k must be > 0".to_string(),
        });
    }
    // Descending by logit, ties by ascending token id. `partial_cmp` with an
    // `Equal` fallback is not transitive once NaN appears, and `sort_by` may
    // panic on such comparators.
    let by_logit_desc =
        |a: &(usize, f32), b: &(usize, f32)| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0));
    let mut candidates: Vec<(usize, f32)> = data.iter().copied().enumerate().collect();
    let k = k.min(candidates.len());
    if k < candidates.len() {
        // Partition so the k best occupy candidates[..k] (unordered), then drop the rest.
        candidates.select_nth_unstable_by(k - 1, by_logit_desc);
        candidates.truncate(k);
    }
    // `logits_to_probs` expects the best candidate first.
    candidates.sort_unstable_by(by_logit_desc);
    let probs = logits_to_probs(&candidates);
    let indices: Vec<usize> = candidates.iter().map(|&(idx, _)| idx).collect();
    Ok(sample_from_distribution(&probs, &indices, rng_value))
}
/// Nucleus (top-p) sampling.
///
/// The steps are:
/// 1. Pass `logits` through `softmax`.
/// 2. Rank tokens by descending probability.
/// 3. Keep the smallest prefix whose cumulative probability reaches `p`.
/// 4. Renormalise that prefix and draw one token id with `rng_value`.
///
/// # Errors
///
/// - Returns [`RealizarError::InvalidShape`] if `logits` is empty, or if `p`
///   is not in `(0, 1]` (NaN included).
/// - Propagates any error from `softmax`.
pub fn sample_top_p(logits: &Tensor<f32>, p: f32, rng_value: f32) -> Result<usize> {
    let data = logits.data();
    if data.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    }
    // NaN fails both range comparisons; without the explicit check a NaN `p`
    // would never satisfy the cutoff and silently keep the whole vocabulary.
    if p.is_nan() || p <= 0.0 || p > 1.0 {
        return Err(RealizarError::InvalidShape {
            reason: "p must be in (0, 1]".to_string(),
        });
    }
    let probs_tensor = softmax(logits)?;
    let mut ranked: Vec<(usize, f32)> = probs_tensor.data().iter().copied().enumerate().collect();
    // `total_cmp` is a total order, as `sort_by` requires. The sort is stable,
    // so equal probabilities keep ascending token-id order.
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    let nucleus = build_nucleus(&ranked, p);
    let nucleus_sum: f32 = nucleus.iter().map(|&(_, prob)| prob).sum();
    let (indices, normalized_probs): (Vec<usize>, Vec<f32>) = nucleus
        .iter()
        .map(|&(idx, prob)| (idx, prob / nucleus_sum))
        .unzip();
    Ok(sample_from_distribution(
        &normalized_probs,
        &indices,
        rng_value,
    ))
}
pub fn sample_token(
logits: &Tensor<f32>,
config: &GenerationConfig,
rng_value: f32,
) -> Result<usize> {
let scaled_logits = apply_temperature(logits, config.temperature)?;
match config.strategy {
SamplingStrategy::Greedy => sample_greedy(&scaled_logits),
SamplingStrategy::TopK { k } => sample_top_k(&scaled_logits, k, rng_value),
SamplingStrategy::TopP { p } => sample_top_p(&scaled_logits, p, rng_value),
}
}
#[cfg(test)]
#[path = "tests.rs"]
mod generate_tests;
#[cfg(test)]
#[path = "tests_sample_greedy.rs"]
mod generate_tests_part_02;
#[cfg(test)]
mod algorithms_tests;
#[cfg(test)]
#[path = "tests_sampling_contract.rs"]
mod tests_sampling_contract;