impl PromptCache {
    /// Creates an empty cache holding at most `max_entries` prompts.
    ///
    /// NOTE(review): with `max_entries == 0` a single entry can still be
    /// inserted by `add` (eviction runs before insertion) — confirm intended.
    pub fn new(max_entries: usize) -> Self {
        Self {
            entries: std::collections::HashMap::new(),
            max_entries,
        }
    }

    /// Hashes a token sequence with the std `DefaultHasher`.
    ///
    /// NOTE(review): entries are keyed by this hash alone, so two distinct
    /// prompts that collide would alias to one entry — acceptable for a
    /// cache, but worth confirming.
    fn hash_tokens(tokens: &[usize]) -> u64 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        tokens.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the longest cached prefix of `tokens` as `(prefix_len, kv_hash)`,
    /// bumping that entry's hit count and recency timestamp.
    ///
    /// Scans from the full sequence down to length 1 so the longest match
    /// wins; each candidate prefix is re-hashed, making a complete miss
    /// O(n^2) in the token count.
    pub fn find_prefix(&mut self, tokens: &[usize]) -> Option<(usize, u64)> {
        for len in (1..=tokens.len()).rev() {
            let hash = Self::hash_tokens(&tokens[..len]);
            if let Some(entry) = self.entries.get_mut(&hash) {
                entry.hit_count += 1;
                entry.last_access = std::time::Instant::now();
                return Some((len, entry.kv_hash));
            }
        }
        None
    }

    /// Inserts (or refreshes) the cache entry for `tokens`, resetting its
    /// hit count and recency.
    ///
    /// Fix: the original evicted an LRU entry whenever the cache was at
    /// capacity, even when the incoming key already existed — an overwrite
    /// does not grow the map, so re-adding a cached prompt needlessly
    /// shrank the cache. Eviction now runs only for genuinely new keys.
    pub fn add(&mut self, tokens: Vec<usize>, kv_hash: u64) {
        let hash = Self::hash_tokens(&tokens);
        if !self.entries.contains_key(&hash) && self.entries.len() >= self.max_entries {
            self.evict_lru();
        }
        self.entries.insert(
            hash,
            PromptCacheEntry {
                tokens,
                kv_hash,
                hit_count: 0,
                last_access: std::time::Instant::now(),
            },
        );
    }

    /// Removes the entry with the oldest `last_access` timestamp, if any.
    fn evict_lru(&mut self) {
        if let Some((&key, _)) = self.entries.iter().min_by_key(|(_, v)| v.last_access) {
            self.entries.remove(&key);
        }
    }

    /// Number of cached prompts.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// True when no prompts are cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Drops every cached entry.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Point-in-time snapshot of occupancy and cumulative hits.
    /// Hits on entries that were later evicted are not counted.
    pub fn stats(&self) -> PromptCacheStats {
        let total_hits: usize = self.entries.values().map(|e| e.hit_count).sum();
        PromptCacheStats {
            entries: self.entries.len(),
            total_hits,
            max_entries: self.max_entries,
        }
    }
}
/// Point-in-time summary of a `PromptCache`, as returned by `stats()`.
#[derive(Debug, Clone)]
pub struct PromptCacheStats {
    /// Number of prompts currently cached.
    pub entries: usize,
    /// Sum of hit counts over all live entries (hits on evicted entries are lost).
    pub total_hits: usize,
    /// Configured capacity limit of the cache.
    pub max_entries: usize,
}
/// A single beam-search hypothesis: a token sequence plus its cumulative
/// log-probability score.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Token ids generated so far.
    pub tokens: Vec<usize>,
    /// Sum of the log-probabilities of the tokens in `tokens`.
    pub score: f32,
    /// True once the hypothesis has emitted an end-of-sequence token.
    pub finished: bool,
}

impl BeamHypothesis {
    /// Creates an unfinished hypothesis from `tokens` with the given score.
    pub fn new(tokens: Vec<usize>, score: f32) -> Self {
        Self {
            tokens,
            score,
            finished: false,
        }
    }

    /// Returns a new hypothesis extended by `token`, adding `log_prob` to the
    /// cumulative score and marking it finished when `is_eos` is set.
    #[must_use]
    pub fn extend(&self, token: usize, log_prob: f32, is_eos: bool) -> Self {
        let mut new_tokens = self.tokens.clone();
        new_tokens.push(token);
        Self {
            tokens: new_tokens,
            score: self.score + log_prob,
            finished: is_eos,
        }
    }

    /// Length-normalized score: `score / len^length_penalty`.
    ///
    /// Fix: the original divided by `tokens.len()` directly, so an empty
    /// hypothesis (e.g. the initial beam before any step) yielded ±inf or
    /// NaN, which then poisoned score-based sorting. Length is clamped to a
    /// minimum of 1; all non-empty hypotheses are scored exactly as before.
    pub fn normalized_score(&self, length_penalty: f32) -> f32 {
        let len = self.tokens.len().max(1) as f32;
        self.score / len.powf(length_penalty)
    }
}
/// Parameters controlling beam-search decoding.
#[derive(Debug, Clone)]
pub struct BeamSearchConfig {
    /// Number of hypotheses kept alive at each step.
    pub num_beams: usize,
    /// Exponent applied to hypothesis length when normalizing scores.
    pub length_penalty: f32,
    /// Stop as soon as `num_beams` finished hypotheses have been collected.
    pub early_stopping: bool,
    /// Number of hypotheses returned at the end of the search.
    pub num_return: usize,
}

impl Default for BeamSearchConfig {
    /// Defaults: 4 beams, neutral length penalty, early stopping on, one result.
    fn default() -> Self {
        BeamSearchConfig {
            num_beams: 4,
            length_penalty: 1.0,
            early_stopping: true,
            num_return: 1,
        }
    }
}

impl BeamSearchConfig {
    /// Creates a config with `num_beams` beams and defaults for everything else.
    pub fn new(num_beams: usize) -> Self {
        BeamSearchConfig {
            num_beams,
            ..Self::default()
        }
    }

    /// Sets the length-normalization exponent.
    #[must_use]
    pub fn with_length_penalty(self, penalty: f32) -> Self {
        Self {
            length_penalty: penalty,
            ..self
        }
    }

    /// Enables or disables early stopping.
    #[must_use]
    pub fn with_early_stopping(self, early: bool) -> Self {
        Self {
            early_stopping: early,
            ..self
        }
    }

    /// Sets how many hypotheses to return.
    #[must_use]
    pub fn with_num_return(self, n: usize) -> Self {
        Self {
            num_return: n,
            ..self
        }
    }
}
/// Mutable state of a beam search in progress.
#[derive(Debug, Clone)]
pub struct BeamSearchState {
    /// Active (unfinished) hypotheses, at most `config.num_beams` after each step.
    pub hypotheses: Vec<BeamHypothesis>,
    /// Hypotheses that have emitted EOS; pooled with live ones for final results.
    pub finished: Vec<BeamHypothesis>,
    /// Search parameters.
    pub config: BeamSearchConfig,
}
impl BeamSearchState {
    /// Starts a search with one hypothesis holding `initial_tokens` at score 0.0.
    pub fn new(config: BeamSearchConfig, initial_tokens: Vec<usize>) -> Self {
        let hypotheses = vec![BeamHypothesis::new(initial_tokens, 0.0)];
        Self {
            hypotheses,
            finished: Vec::new(),
            config,
        }
    }

    /// Advances the search by one decoding step.
    ///
    /// `log_probs_per_hyp[i]` must hold per-token log-probabilities for
    /// `self.hypotheses[i]`; indexing panics if the slice is shorter than
    /// the hypothesis list. Expansions that hit `eos_token` move to
    /// `finished`; the rest compete for the top `num_beams` slots by
    /// normalized score.
    pub fn step(&mut self, log_probs_per_hyp: &[Vec<f32>], eos_token: Option<usize>) {
        let mut candidates: Vec<BeamHypothesis> = Vec::new();
        for (hyp_idx, hyp) in self.hypotheses.iter().enumerate() {
            // Already-finished hypotheses pass through unchanged.
            if hyp.finished {
                candidates.push(hyp.clone());
                continue;
            }
            let log_probs = &log_probs_per_hyp[hyp_idx];
            // Rank every vocabulary token by log-probability, highest first.
            // NaN comparisons fall back to Equal rather than panicking.
            let mut indexed: Vec<(usize, f32)> = log_probs
                .iter()
                .enumerate()
                .map(|(i, &lp)| (i, lp))
                .collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            // Expand 2x num_beams candidates per hypothesis so that beams
            // lost to EOS still leave enough live continuations.
            for &(token, log_prob) in indexed.iter().take(self.config.num_beams * 2) {
                let is_eos = eos_token == Some(token);
                let new_hyp = hyp.extend(token, log_prob, is_eos);
                if is_eos {
                    self.finished.push(new_hyp);
                } else {
                    candidates.push(new_hyp);
                }
            }
        }
        // Keep only the best num_beams candidates by length-normalized score.
        candidates.sort_by(|a, b| {
            let score_a = a.normalized_score(self.config.length_penalty);
            let score_b = b.normalized_score(self.config.length_penalty);
            score_b
                .partial_cmp(&score_a)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.hypotheses = candidates.into_iter().take(self.config.num_beams).collect();
    }

    /// True when the search should terminate: early stopping has collected
    /// `num_beams` finished hypotheses, or no live hypothesis remains.
    pub fn should_stop(&self) -> bool {
        if self.config.early_stopping && self.finished.len() >= self.config.num_beams {
            return true;
        }
        self.hypotheses.is_empty() || self.hypotheses.iter().all(|h| h.finished)
    }

    /// Returns the top `config.num_return` hypotheses — finished and live
    /// pooled together — ordered by descending normalized score.
    pub fn best_hypotheses(&self) -> Vec<BeamHypothesis> {
        let mut all: Vec<_> = self
            .finished
            .iter()
            .chain(self.hypotheses.iter())
            .cloned()
            .collect();
        all.sort_by(|a, b| {
            let score_a = a.normalized_score(self.config.length_penalty);
            let score_b = b.normalized_score(self.config.length_penalty);
            score_b
                .partial_cmp(&score_a)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        all.into_iter().take(self.config.num_return).collect()
    }
}
/// Accumulates token ids and decoded text during incremental generation.
#[derive(Debug)]
pub struct StreamingGenerator {
    /// Token ids received so far.
    pub tokens: Vec<usize>,
    /// Decoded text accumulated so far.
    pub text: String,
    /// Whether the stream has been marked complete.
    pub finished: bool,
    /// Running count of tokens added.
    pub total_tokens: usize,
}

impl StreamingGenerator {
    /// Creates an empty, unfinished generator.
    pub fn new() -> Self {
        StreamingGenerator {
            tokens: Vec::new(),
            text: String::new(),
            finished: false,
            total_tokens: 0,
        }
    }

    /// Records one generated token, appending its decoded text when provided.
    pub fn add_token(&mut self, token_id: usize, token_text: Option<&str>) {
        self.tokens.push(token_id);
        if let Some(piece) = token_text {
            self.text.push_str(piece);
        }
        self.total_tokens += 1;
    }

    /// Marks the stream as complete.
    pub fn finish(&mut self) {
        self.finished = true;
    }

    /// Number of tokens recorded so far.
    pub fn token_count(&self) -> usize {
        self.total_tokens
    }
}

impl Default for StreamingGenerator {
    fn default() -> Self {
        Self::new()
    }
}
/// Generation settings bundling a base config with optional logit processors.
/// Each `Option` field is `None` until enabled via the builder methods.
#[derive(Debug, Clone, Default)]
pub struct AdvancedGenerationConfig {
    /// Core generation parameters.
    pub base: GenerationConfig,
    /// Optional stop-sequence detection.
    pub stop_detector: Option<StopSequenceDetector>,
    /// Optional repetition-penalty settings.
    pub repetition_penalty: Option<RepetitionPenaltyConfig>,
    /// Optional presence/frequency penalty settings.
    pub presence_frequency: Option<PresenceFrequencyPenalty>,
    /// Optional per-token logit bias.
    pub logit_bias: Option<LogitBias>,
}
impl AdvancedGenerationConfig {
pub fn new(base: GenerationConfig) -> Self {
Self {
base,
..Default::default()
}
}
#[must_use]
pub fn with_stop_sequences(mut self, stops: Vec<String>) -> Self {
self.stop_detector = Some(StopSequenceDetector::new().with_stop_strings(stops));
self
}
#[must_use]
pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
self.repetition_penalty = Some(RepetitionPenaltyConfig::new(penalty));
self
}
#[must_use]
pub fn with_presence_frequency(mut self, presence: f32, frequency: f32) -> Self {
self.presence_frequency = Some(PresenceFrequencyPenalty::new(presence, frequency));
self
}
#[must_use]
pub fn with_logit_bias(mut self, bias: LogitBias) -> Self {
self.logit_bias = Some(bias);
self
}
}
/// Applies every processor enabled in `config` to `logits`, in a fixed order:
/// repetition penalty, then presence/frequency penalty, then logit bias.
///
/// Returns a new tensor; the input `logits` is left untouched. Disabled
/// (`None`) processors are skipped, so with an empty config this is just a
/// clone of the input.
pub fn apply_all_penalties(
    logits: &Tensor<f32>,
    context_tokens: &[usize],
    config: &AdvancedGenerationConfig,
) -> Tensor<f32> {
    let mut out = logits.clone();
    if let Some(rep) = config.repetition_penalty.as_ref() {
        out = apply_repetition_penalty(&out, context_tokens, rep);
    }
    if let Some(pf) = config.presence_frequency.as_ref() {
        out = apply_presence_frequency_penalty(&out, context_tokens, pf);
    }
    if let Some(bias) = config.logit_bias.as_ref() {
        out = apply_logit_bias(&out, bias);
    }
    out
}