//! Generation-time sampling controls: stop-sequence detection,
//! repetition/presence/frequency penalties, per-token logit bias, and a
//! prompt cache.

use super::{apply_temperature, sample_greedy, sample_token, GenerationConfig};
use crate::error::Result;
use crate::tensor::Tensor;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
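
/// Detects stop conditions during generation: exact token-ID sequences are
/// matched incrementally via `check_token`, and string patterns are matched
/// against decoded text via `check_text`.
///
/// A minimal usage sketch; the token IDs below are illustrative placeholders,
/// not from any real vocabulary:
///
/// ```ignore
/// let mut detector = StopSequenceDetector::new()
///     .with_token_sequence(vec![13, 13])
///     .with_string_pattern("</s>");
/// assert!(!detector.check_token(13));
/// assert!(detector.check_token(13)); // buffer now ends with [13, 13]
/// assert_eq!(detector.check_text("done</s>"), Some(4));
/// ```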
#[derive(Debug, Clone, Default)]
pub struct StopSequenceDetector {
    /// Token-ID sequences that stop generation when matched exactly.
    token_sequences: Vec<Vec<usize>>,
    /// Substrings that stop generation when found in decoded text.
    string_patterns: Vec<String>,
    /// Sliding window of the most recent tokens, for sequence matching.
    token_buffer: Vec<usize>,
    /// Length of the longest registered token sequence.
    max_seq_len: usize,
}
impl StopSequenceDetector {
    pub fn new() -> Self {
        Self {
            token_sequences: Vec::new(),
            string_patterns: Vec::new(),
            token_buffer: Vec::new(),
            max_seq_len: 0,
        }
    }

    #[must_use]
    pub fn with_token_sequence(mut self, sequence: Vec<usize>) -> Self {
        if !sequence.is_empty() {
            self.max_seq_len = self.max_seq_len.max(sequence.len());
            self.token_sequences.push(sequence);
        }
        self
    }

    #[must_use]
    pub fn with_string_pattern(mut self, pattern: impl Into<String>) -> Self {
        let pattern = pattern.into();
        if !pattern.is_empty() {
            self.string_patterns.push(pattern);
        }
        self
    }

    #[must_use]
    pub fn with_stop_strings(mut self, stops: Vec<String>) -> Self {
        for stop in stops {
            if !stop.is_empty() {
                self.string_patterns.push(stop);
            }
        }
        self
    }
    /// Feeds one generated token into the detector; returns `true` when the
    /// recent tokens end with any registered stop sequence.
    pub fn check_token(&mut self, token_id: usize) -> bool {
        // With no registered sequences nothing can match, and pushing here
        // would grow the buffer without bound.
        if self.max_seq_len == 0 {
            return false;
        }
        self.token_buffer.push(token_id);
        // Keep only as many tokens as the longest sequence needs.
        if self.token_buffer.len() > self.max_seq_len {
            self.token_buffer.remove(0);
        }
        self.token_sequences
            .iter()
            .any(|seq| self.token_buffer.ends_with(seq))
    }
    /// Returns the byte offset of the earliest stop-string match in `text`,
    /// or `None` if no pattern matches.
    pub fn check_text(&self, text: &str) -> Option<usize> {
        // Take the earliest match across all patterns rather than the first
        // registered pattern that happens to match.
        self.string_patterns
            .iter()
            .filter_map(|pattern| text.find(pattern))
            .min()
    }
    pub fn reset(&mut self) {
        self.token_buffer.clear();
    }

    pub fn has_conditions(&self) -> bool {
        !self.token_sequences.is_empty() || !self.string_patterns.is_empty()
    }
}
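
/// Configuration for a multiplicative repetition penalty over recently
/// generated tokens. A `penalty` of 1.0 is a no-op; values above 1.0
/// discourage repeats. `window_size` limits the penalty to the trailing
/// tokens of the context; 0 penalizes the entire context.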
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepetitionPenaltyConfig {
    /// Multiplicative penalty; 1.0 disables it.
    pub penalty: f32,
    /// Number of trailing context tokens to penalize; 0 means the whole context.
    pub window_size: usize,
}
impl Default for RepetitionPenaltyConfig {
    fn default() -> Self {
        Self {
            penalty: 1.0,
            window_size: 64,
        }
    }
}

impl RepetitionPenaltyConfig {
    pub fn new(penalty: f32) -> Self {
        Self {
            penalty,
            window_size: 64,
        }
    }

    #[must_use]
    pub fn with_window(mut self, window_size: usize) -> Self {
        self.window_size = window_size;
        self
    }

    pub fn is_enabled(&self) -> bool {
        (self.penalty - 1.0).abs() > 1e-6
    }
}
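
/// Applies the repetition penalty to `logits`: positive logits of penalized
/// tokens are divided by the penalty and negative logits multiplied by it,
/// the sign-aware scheme popularized by CTRL (Keskar et al., 2019), so the
/// token's probability always decreases for penalties above 1.0.
///
/// Illustrative sketch with made-up values:
///
/// ```ignore
/// let logits = Tensor::from_vec(vec![4], vec![2.0, -1.0, 0.5, 0.0]).unwrap();
/// let config = RepetitionPenaltyConfig::new(2.0);
/// let penalized = apply_repetition_penalty(&logits, &[0, 1], &config);
/// // token 0: 2.0 / 2.0 = 1.0; token 1: -1.0 * 2.0 = -2.0
/// ```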
pub fn apply_repetition_penalty(
    logits: &Tensor<f32>,
    context_tokens: &[usize],
    config: &RepetitionPenaltyConfig,
) -> Tensor<f32> {
    if !config.is_enabled() || context_tokens.is_empty() {
        return logits.clone();
    }
    let data = logits.data();
    let mut penalized = data.to_vec();
    let vocab_size = data.len();
    // Restrict the penalty to the trailing window of context tokens.
    let window_start = if config.window_size > 0 && context_tokens.len() > config.window_size {
        context_tokens.len() - config.window_size
    } else {
        0
    };
    let relevant_tokens = &context_tokens[window_start..];
    for &token_id in relevant_tokens {
        if token_id < vocab_size {
            // Divide positive logits and multiply negative ones so the
            // penalty always reduces the token's probability.
            let logit = penalized[token_id];
            penalized[token_id] = if logit > 0.0 {
                logit / config.penalty
            } else {
                logit * config.penalty
            };
        }
    }
    Tensor::from_vec(logits.shape().to_vec(), penalized)
        .expect("Shape should match original logits")
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PresenceFrequencyPenalty {
    /// Flat penalty subtracted from every token present in the context.
    pub presence_penalty: f32,
    /// Penalty subtracted once per occurrence of a token in the context.
    pub frequency_penalty: f32,
}
impl PresenceFrequencyPenalty {
    pub fn new(presence: f32, frequency: f32) -> Self {
        Self {
            presence_penalty: presence,
            frequency_penalty: frequency,
        }
    }

    pub fn is_enabled(&self) -> bool {
        self.presence_penalty.abs() > 1e-6 || self.frequency_penalty.abs() > 1e-6
    }
}
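
/// Applies presence and frequency penalties to `logits`, computing
/// `logit[t] -= presence_penalty + frequency_penalty * count(t)` for every
/// token `t` that occurs in `context_tokens`.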
pub fn apply_presence_frequency_penalty(
    logits: &Tensor<f32>,
    context_tokens: &[usize],
    config: &PresenceFrequencyPenalty,
) -> Tensor<f32> {
    if !config.is_enabled() || context_tokens.is_empty() {
        return logits.clone();
    }
    let data = logits.data();
    let mut penalized = data.to_vec();
    let vocab_size = data.len();
    // Count how often each in-vocabulary token appears in the context.
    let mut token_counts: HashMap<usize, usize> = HashMap::new();
    for &token_id in context_tokens {
        if token_id < vocab_size {
            *token_counts.entry(token_id).or_insert(0) += 1;
        }
    }
    for (token_id, count) in token_counts {
        // Every counted token appeared at least once, so the presence
        // penalty is a flat subtraction; the frequency penalty scales
        // with the occurrence count.
        penalized[token_id] -= config.presence_penalty;
        penalized[token_id] -= config.frequency_penalty * (count as f32);
    }
    Tensor::from_vec(logits.shape().to_vec(), penalized)
        .expect("Shape should match original logits")
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LogitBias {
    biases: HashMap<usize, f32>,
}

impl LogitBias {
    pub fn new() -> Self {
        Self {
            biases: HashMap::new(),
        }
    }

    #[must_use]
    pub fn with_bias(mut self, token_id: usize, bias: f32) -> Self {
        self.biases.insert(token_id, bias);
        self
    }

    #[must_use]
    pub fn with_biases(mut self, biases: HashMap<usize, f32>) -> Self {
        self.biases.extend(biases);
        self
    }

    pub fn is_empty(&self) -> bool {
        self.biases.is_empty()
    }

    pub fn get(&self, token_id: usize) -> f32 {
        self.biases.get(&token_id).copied().unwrap_or(0.0)
    }
}
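
/// Adds the configured biases to the matching logits; token IDs outside the
/// vocabulary are silently ignored.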
pub fn apply_logit_bias(logits: &Tensor<f32>, bias: &LogitBias) -> Tensor<f32> {
    if bias.is_empty() {
        return logits.clone();
    }
    let data = logits.data();
    let mut biased = data.to_vec();
    let vocab_size = data.len();
    for (&token_id, &bias_value) in &bias.biases {
        if token_id < vocab_size {
            biased[token_id] += bias_value;
        }
    }
    Tensor::from_vec(logits.shape().to_vec(), biased).expect("Shape should match original logits")
}
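
/// A cached prompt prefix: its token IDs, a hash identifying the associated
/// KV-cache state, and access statistics (hit count and last access time).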
#[derive(Debug, Clone)]
pub struct PromptCacheEntry {
    /// The cached prompt's token IDs.
    pub tokens: Vec<usize>,
    /// Hash identifying the KV-cache state associated with this prompt.
    pub kv_hash: u64,
    /// Number of times this entry has been reused.
    pub hit_count: usize,
    /// Time of the most recent access.
    pub last_access: std::time::Instant,
}
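
/// Store of [`PromptCacheEntry`] values keyed by a `u64` hash and bounded by
/// `max_entries` (100 by default). The methods live in `prompt_cache.rs`,
/// spliced in below via `include!`.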
#[derive(Debug)]
pub struct PromptCache {
    entries: HashMap<u64, PromptCacheEntry>,
    max_entries: usize,
}

impl Default for PromptCache {
    fn default() -> Self {
        Self::new(100)
    }
}
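
// The remaining pieces (prompt-cache methods, dynamic temperature, top-k
// sampling, and the logit-processing chain) are textually spliced in from
// sibling files via `include!`, so they share this module's imports.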
include!("prompt_cache.rs");
include!("dynamic_temperature.rs");
include!("sampler_topk.rs");
include!("sampler_logit_chain.rs");