/// Apply the DRY ("don't repeat yourself") repetition penalty to `logits`.
///
/// For each vocabulary token, looks for the longest suffix of the recent
/// context that has occurred before and would be continued by that token;
/// matches of at least `config.allowed_length` are penalized exponentially
/// in how far they exceed the allowance. Returns a new tensor with the same
/// shape as `logits`.
pub fn apply_dry_penalty(
    logits: &Tensor<f32>,
    context_tokens: &[usize],
    config: &DryConfig,
) -> Tensor<f32> {
    // Disabled, or the context is too short to ever form a penalizable match.
    if !config.is_enabled() || context_tokens.len() < config.allowed_length {
        return logits.clone();
    }
    // Only the trailing `penalty_last_n` tokens participate in matching.
    let start = context_tokens.len().saturating_sub(config.penalty_last_n);
    let window = &context_tokens[start..];
    let mut adjusted: Vec<f32> = logits.data().to_vec();
    for (candidate, value) in adjusted.iter_mut().enumerate() {
        let run = find_ngram_match_length(window, candidate, config.allowed_length);
        if run >= config.allowed_length {
            // Penalty grows exponentially with how far the repeat exceeds the allowance.
            let penalty =
                config.multiplier * config.base.powi((run - config.allowed_length) as i32);
            *value -= penalty;
        }
    }
    Tensor::from_vec(logits.shape().to_vec(), adjusted)
        .expect("Shape should match original logits")
}
/// Report whether the n-gram `suffix` occurs in `context` at `start` and is
/// immediately followed there by `next_token`.
///
/// Returns `false` when the candidate window (including the follow-up token)
/// would run past the end of `context`.
#[inline]
fn check_ngram_match(
    context: &[usize],
    suffix: &[usize],
    start: usize,
    suffix_len: usize,
    next_token: usize,
) -> bool {
    let end = start + suffix_len;
    // Guard: `context[end]` must exist for the follow-up comparison.
    if end >= context.len() {
        return false;
    }
    context[end] == next_token && &context[start..end] == suffix
}
/// Length of the longest repeat ending in `next_token`.
///
/// Finds the longest suffix of `context` (of at least `min_len` tokens) that
/// also occurs earlier in `context` immediately followed by `next_token`, and
/// returns that suffix length plus one (the continuation token). Returns 0
/// when no such repeat exists or `context` is shorter than `min_len`.
fn find_ngram_match_length(context: &[usize], next_token: usize, min_len: usize) -> usize {
    // Scan candidate suffix lengths from longest to shortest so the first hit
    // is already the maximum; the previous ascending scan kept searching every
    // shorter length (and every later start) after the answer was known.
    // When context.len() < min_len the range is empty and we fall through to 0.
    for suffix_len in (min_len..=context.len()).rev() {
        let suffix = &context[context.len() - suffix_len..];
        // For every start in this range, start + suffix_len < context.len(),
        // so the follow-up index `context[start + suffix_len]` is in bounds.
        let repeated = (0..context.len() - suffix_len).any(|start| {
            context[start..start + suffix_len] == *suffix
                && context[start + suffix_len] == next_token
        });
        if repeated {
            // Match length counts the suffix plus the continuation token.
            return suffix_len + 1;
        }
    }
    0
}
/// Configuration for XTC ("exclude top choices") sampling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XtcConfig {
    /// Chance that exclusion is applied at all on a given sampling step.
    pub probability: f32,
    /// Probability mass a token must reach to become an exclusion candidate.
    pub threshold: f32,
    /// Minimum number of tokens that must survive exclusion.
    pub min_keep: usize,
}

impl Default for XtcConfig {
    fn default() -> Self {
        Self {
            probability: 0.0, // disabled by default
            threshold: 0.5,
            min_keep: 1,
        }
    }
}

impl XtcConfig {
    /// Create a config with the given trigger probability and default knobs.
    pub fn new(probability: f32) -> Self {
        Self {
            probability,
            ..Self::default()
        }
    }

    /// Builder: override the exclusion threshold.
    #[must_use]
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self { threshold, ..self }
    }

    /// Builder: override the minimum number of surviving tokens.
    #[must_use]
    pub fn with_min_keep(self, min_keep: usize) -> Self {
        Self { min_keep, ..self }
    }

    /// XTC is active only when the trigger probability is positive.
    pub fn is_enabled(&self) -> bool {
        self.probability > 0.0
    }
}
/// Apply XTC ("exclude top choices") to `logits`, masking high-probability
/// tokens to `-inf` when the sampler is triggered by `rng_value`.
///
/// Tokens whose softmax probability reaches `config.threshold` are excluded,
/// highest probability first, while at least `config.min_keep` candidates are
/// left intact. Returns a new tensor with the same shape as `logits`.
pub fn apply_xtc(logits: &Tensor<f32>, config: &XtcConfig, rng_value: f32) -> Tensor<f32> {
    // Not triggered this step: either disabled or the RNG draw missed.
    if !config.is_enabled() || rng_value >= config.probability {
        return logits.clone();
    }
    let values = logits.data();
    // Too few candidates to exclude anything while honoring `min_keep`.
    if values.len() <= config.min_keep {
        return logits.clone();
    }
    // Softmax (max-shifted for numerical stability) to get probabilities.
    let peak = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = values.iter().map(|&v| (v - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    // Rank token indices by probability, highest first (stable sort keeps
    // tie order deterministic).
    let mut ranked: Vec<(usize, f32)> = weights.iter().map(|&w| w / total).enumerate().collect();
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let mut result = values.to_vec();
    let mut removed = 0;
    // NOTE(review): this excludes *every* token at or above the threshold
    // (subject only to min_keep), including the least probable of them —
    // confirm this matches the intended XTC variant.
    for &(idx, prob) in &ranked {
        if prob >= config.threshold && values.len() - removed > config.min_keep {
            result[idx] = f32::NEG_INFINITY;
            removed += 1;
        }
    }
    Tensor::from_vec(logits.shape().to_vec(), result).expect("Shape should match original logits")
}
/// Configuration for eta (entropy-adaptive truncation) sampling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EtaConfig {
    /// Base cutoff coefficient; 0 disables eta sampling.
    pub eta: f32,
    /// Absolute floor for the adaptive probability cutoff.
    pub min_p: f32,
}

impl Default for EtaConfig {
    fn default() -> Self {
        Self {
            eta: 0.3,
            min_p: 0.0001,
        }
    }
}

impl EtaConfig {
    /// Create a config with the given `eta` and the default `min_p` floor.
    pub fn new(eta: f32) -> Self {
        Self {
            eta,
            ..Self::default()
        }
    }

    /// Builder: override the probability floor.
    #[must_use]
    pub fn with_min_p(self, min_p: f32) -> Self {
        Self { min_p, ..self }
    }

    /// Eta sampling is active only for positive `eta`.
    pub fn is_enabled(&self) -> bool {
        self.eta > 0.0
    }
}
/// Sample a token index from `logits` using eta (entropy-adaptive) truncation.
///
/// Computes softmax probabilities, derives an entropy-dependent cutoff
/// `max(eta * exp(-H), min_p)`, discards tokens below it, renormalizes the
/// survivors, and samples among them with `rng_value`. Falls back to greedy
/// argmax when nothing survives the cutoff.
///
/// # Errors
/// Returns `InvalidShape` when `logits` is empty.
pub fn sample_eta(logits: &Tensor<f32>, config: &EtaConfig, rng_value: f32) -> Result<usize> {
    let values = logits.data();
    if values.is_empty() {
        return Err(crate::error::RealizarError::InvalidShape {
            reason: "Logits cannot be empty".to_string(),
        });
    }
    // Stable softmax (shift by the max before exponentiating).
    let peak = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = values.iter().map(|&v| (v - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    let probs: Vec<f32> = weights.iter().map(|&w| w / total).collect();
    // Shannon entropy; near-zero probabilities are skipped to avoid ln(0).
    let entropy: f32 = -probs
        .iter()
        .filter(|&&p| p > 1e-10)
        .map(|&p| p * p.ln())
        .sum::<f32>();
    // Entropy-adaptive cutoff, floored at `min_p`.
    let threshold = (config.eta * (-entropy).exp()).max(config.min_p);
    let mut survivors: Vec<(usize, f32)> = probs
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, p)| p >= threshold)
        .collect();
    // Nothing cleared the cutoff: fall back to greedy argmax.
    if survivors.is_empty() {
        let best = probs
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i);
        return Ok(best);
    }
    survivors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    // Renormalize the surviving mass and sample from it.
    let kept_mass: f32 = survivors.iter().map(|(_, p)| p).sum();
    let renormalized: Vec<f32> = survivors.iter().map(|(_, p)| p / kept_mass).collect();
    let token_ids: Vec<usize> = survivors.iter().map(|(idx, _)| *idx).collect();
    Ok(sample_from_distribution(&renormalized, &token_ids, rng_value))
}
/// Configuration for token healing (regenerating a partial trailing token).
#[derive(Debug, Clone, Default)]
pub struct TokenHealingConfig {
    /// Whether healing is attempted at all.
    pub enabled: bool,
    /// Upper bound on how many characters may be backed up.
    pub max_backup_chars: usize,
}

impl TokenHealingConfig {
    /// Build a config with the default 10-character backup budget.
    pub fn new(enabled: bool) -> Self {
        Self {
            enabled,
            max_backup_chars: 10,
        }
    }

    /// Builder: override the backup-character budget.
    #[must_use]
    pub fn with_max_backup(self, chars: usize) -> Self {
        Self {
            max_backup_chars: chars,
            ..self
        }
    }
}
/// Outcome of a token-healing analysis over a tokenized prompt.
#[derive(Debug, Clone)]
pub struct TokenHealingResult {
    /// Prompt tokens after possibly dropping the trailing partial token.
    pub adjusted_tokens: Vec<usize>,
    /// Text the next generated token must start with, when healing applies.
    pub prefix_constraint: Option<String>,
    /// How many tokens were removed from the prompt (0 or 1).
    pub tokens_removed: usize,
}

/// Decide whether the final prompt token looks like a partial word worth healing.
///
/// The last token is healed when its text is non-empty, does not start with a
/// space, is at most 3 bytes long, and is entirely alphanumeric; in that case
/// it is dropped from the prompt and returned as a prefix constraint for
/// regeneration. Otherwise the prompt is returned unchanged.
pub fn analyze_token_healing(
    prompt_tokens: &[usize],
    last_token_text: Option<&str>,
) -> TokenHealingResult {
    // Heuristic for "this token is probably a word fragment".
    // NOTE(review): the 3-byte cap is a byte length, so multi-byte alphanumerics
    // count more than one — presumably intentional for ASCII fragments; confirm.
    let looks_partial = |text: &str| {
        !text.is_empty()
            && !text.starts_with(' ')
            && text.len() <= 3
            && text.chars().all(char::is_alphanumeric)
    };
    match last_token_text {
        Some(text) if looks_partial(text) && !prompt_tokens.is_empty() => TokenHealingResult {
            adjusted_tokens: prompt_tokens[..prompt_tokens.len() - 1].to_vec(),
            prefix_constraint: Some(text.to_string()),
            tokens_removed: 1,
        },
        _ => TokenHealingResult {
            adjusted_tokens: prompt_tokens.to_vec(),
            prefix_constraint: None,
            tokens_removed: 0,
        },
    }
}
/// Configuration for classifier-free guidance (CFG).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CfgConfig {
    /// Guidance strength; 1.0 means no guidance.
    pub scale: f32,
    /// Tokenized negative prompt used for the unconditional pass.
    pub negative_prompt_tokens: Vec<usize>,
}

impl Default for CfgConfig {
    fn default() -> Self {
        Self {
            scale: 1.0,
            negative_prompt_tokens: Vec::new(),
        }
    }
}

impl CfgConfig {
    /// Create a config with the given guidance scale and no negative prompt.
    pub fn new(scale: f32) -> Self {
        Self {
            scale,
            ..Self::default()
        }
    }

    /// Builder: set the tokenized negative prompt.
    #[must_use]
    pub fn with_negative_prompt(self, tokens: Vec<usize>) -> Self {
        Self {
            negative_prompt_tokens: tokens,
            ..self
        }
    }

    /// Guidance only kicks in for scales strictly above 1.0.
    pub fn is_enabled(&self) -> bool {
        self.scale > 1.0
    }
}