/// How a single candidate token's logit and rank evolve across layers.
///
/// Each call to [`TokenLogitEvolution::record_layer`] appends one entry to both
/// `per_layer_logit` and `per_layer_rank`, so the two vectors stay index-aligned
/// as long as they are only filled through that method.
#[derive(Debug, Clone, Default)]
pub struct TokenLogitEvolution {
    /// Id of the tracked token (used as an index into a logits slice elsewhere).
    pub token_id: u32,
    /// Textual form of the token, stored exactly as given.
    pub token_str: String,
    /// Logit recorded for this token at each layer, in recording order.
    pub per_layer_logit: Vec<f32>,
    /// Rank recorded for this token at each layer, in recording order
    /// (presumably 0 = best, as produced by `LogitEvolutionTrace::compute_rank`;
    /// confirm with callers).
    pub per_layer_rank: Vec<usize>,
    /// Final probability of the token. Nothing in this type sets it; presumably
    /// filled in by the caller.
    pub final_probability: f32,
    /// Final rank of the token. Nothing in this type sets it; presumably filled
    /// in by the caller.
    pub final_rank: usize,
}

impl TokenLogitEvolution {
    /// Creates an empty record for `token_id` with no layers recorded yet.
    pub fn new(token_id: u32, token_str: String) -> Self {
        Self { token_id, token_str, ..Default::default() }
    }

    /// Appends the logit and rank observed at the next layer.
    pub fn record_layer(&mut self, logit: f32, rank: usize) {
        self.per_layer_logit.push(logit);
        self.per_layer_rank.push(rank);
    }

    /// Returns the layer index at which this token's rank moved the most.
    ///
    /// The change attributed to layer `i` is the absolute rank difference
    /// between layers `i - 1` and `i`. Ties resolve to the earliest such layer.
    /// If the rank never changes, `Some(0)` is returned.
    ///
    /// Returns `None` when fewer than two layers have been recorded.
    pub fn decisive_layer(&self) -> Option<usize> {
        if self.per_layer_rank.len() < 2 {
            return None;
        }
        // `abs_diff` works directly on `usize`, avoiding `as i64` casts that
        // wrap for ranks above `i64::MAX` and would pick the wrong layer.
        let (decisive, _max_change) = self
            .per_layer_rank
            .windows(2)
            .enumerate()
            .map(|(i, pair)| (i + 1, pair[0].abs_diff(pair[1])))
            // Strict `>` keeps the earliest layer on ties; the `(0, 0)` seed
            // yields layer 0 when no rank ever changes.
            .fold((0, 0), |best, candidate| if candidate.1 > best.1 { candidate } else { best });
        Some(decisive)
    }
}
/// Per-layer logit/rank traces for a set of tracked tokens at one position.
#[derive(Debug, Clone, Default)]
pub struct LogitEvolutionTrace {
    /// Position this trace refers to (presumably the sequence position; confirm
    /// with callers).
    pub position: usize,
    /// Tokens whose per-layer logits and ranks are being followed.
    pub tracked_tokens: Vec<TokenLogitEvolution>,
    /// Decisive layer of the selected token, written by
    /// [`LogitEvolutionTrace::finalize`]; `0` until then.
    pub decisive_layer: usize,
    /// Temperature value stored at construction; not read by the methods here.
    pub temperature: f32,
    /// Top-p value stored at construction; not read by the methods here.
    pub top_p: f32,
}

impl LogitEvolutionTrace {
    /// Creates a trace with no tracked tokens and `decisive_layer` set to 0.
    pub fn new(position: usize, temperature: f32, top_p: f32) -> Self {
        Self { position, temperature, top_p, ..Default::default() }
    }

    /// Starts tracking `token_id` and returns a mutable handle to its new record.
    pub fn track_token(&mut self, token_id: u32, token_str: String) -> &mut TokenLogitEvolution {
        self.tracked_tokens.push(TokenLogitEvolution::new(token_id, token_str));
        self.tracked_tokens.last_mut().expect("invariant: just pushed")
    }

    /// Returns the 0-based rank of `token_id` within `logits`.
    ///
    /// The rank is the number of entries whose logit is strictly greater than
    /// the token's own, so tied tokens share a rank.
    ///
    /// If `token_id` is out of range, or its logit is NaN, the token cannot be
    /// compared meaningfully. It is then ranked last with `logits.len()`, one
    /// past the worst rank a present, comparable token can have.
    pub fn compute_rank(logits: &[f32], token_id: u32) -> usize {
        match logits.get(token_id as usize) {
            // A NaN target compares false against everything, which would
            // otherwise report the best rank (0); route it to the fallback.
            Some(&target) if !target.is_nan() => {
                logits.iter().filter(|&&l| l > target).count()
            }
            // An explicit "last" rank also stays correct when the slice contains
            // `-inf`/`f32::MIN` entries, which a sentinel logit would not beat.
            _ => logits.len(),
        }
    }

    /// Copies the decisive layer of the tracked token `selected_token_id` into
    /// `self.decisive_layer`.
    ///
    /// Only the first tracked entry with that id is consulted. The field is left
    /// unchanged if the token is not tracked or has fewer than two recorded layers.
    pub fn finalize(&mut self, selected_token_id: u32) {
        let decisive = self
            .tracked_tokens
            .iter()
            .find(|token| token.token_id == selected_token_id)
            .and_then(TokenLogitEvolution::decisive_layer);
        if let Some(layer) = decisive {
            self.decisive_layer = layer;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ranks go 100 -> 50 -> 10, so the biggest jump (50) lands on layer 1.
    #[test]
    fn test_token_logit_evolution() {
        let mut evolution = TokenLogitEvolution::new(42, String::from("test"));
        for (logit, rank) in [(1.0, 100), (2.0, 50), (3.0, 10)] {
            evolution.record_layer(logit, rank);
        }

        assert_eq!(evolution.per_layer_logit.len(), 3);
        assert_eq!(evolution.per_layer_rank.len(), 3);
        assert_eq!(evolution.decisive_layer(), Some(1));
    }

    /// Rank counts how many logits strictly beat the token's own logit.
    #[test]
    fn test_logit_evolution_trace_compute_rank() {
        let logits = vec![1.0, 5.0, 3.0, 2.0];

        // Token 0 (logit 1.0) is beaten by all three other entries.
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 0), 3);
        // Token 1 (logit 5.0) holds the top logit.
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 1), 0);
    }
}