// trueno/brick/tracing/logit.rs

/// Per-token record of logit and rank values collected layer by layer.
///
/// Entries are appended in pairs by `record_layer`, so `per_layer_logit` and
/// `per_layer_rank` stay the same length as long as callers go through that
/// method (the fields are public, so this is a convention, not a guarantee).
#[derive(Debug, Clone, Default)]
pub struct TokenLogitEvolution {
    /// Identifier of the tracked token.
    pub token_id: u32,
    /// Text associated with the token, as supplied by the caller.
    pub token_str: String,
    /// Logit values in the order they were recorded.
    pub per_layer_logit: Vec<f32>,
    /// Rank values in the order they were recorded.
    pub per_layer_rank: Vec<usize>,
    /// Final probability of the token. Defaults to `0.0`; no method visible in
    /// this file writes it, so the caller is expected to fill it in.
    pub final_probability: f32,
    /// Final rank of the token. Defaults to `0`; no method visible in this
    /// file writes it, so the caller is expected to fill it in.
    pub final_rank: usize,
}
24
25impl TokenLogitEvolution {
26 pub fn new(token_id: u32, token_str: String) -> Self {
28 Self { token_id, token_str, ..Default::default() }
29 }
30
31 pub fn record_layer(&mut self, logit: f32, rank: usize) {
33 self.per_layer_logit.push(logit);
34 self.per_layer_rank.push(rank);
35 }
36
37 pub fn decisive_layer(&self) -> Option<usize> {
39 if self.per_layer_rank.len() < 2 {
40 return None;
41 }
42
43 let mut max_change = 0i64;
44 let mut decisive = 0;
45
46 for i in 1..self.per_layer_rank.len() {
47 let change = (self.per_layer_rank[i] as i64 - self.per_layer_rank[i - 1] as i64).abs();
48 if change > max_change {
49 max_change = change;
50 decisive = i;
51 }
52 }
53
54 Some(decisive)
55 }
56}
57
58#[derive(Debug, Clone, Default)]
60pub struct LogitEvolutionTrace {
61 pub position: usize,
63 pub tracked_tokens: Vec<TokenLogitEvolution>,
65 pub decisive_layer: usize,
67 pub temperature: f32,
69 pub top_p: f32,
71}
72
73impl LogitEvolutionTrace {
74 pub fn new(position: usize, temperature: f32, top_p: f32) -> Self {
76 Self { position, temperature, top_p, ..Default::default() }
77 }
78
79 pub fn track_token(&mut self, token_id: u32, token_str: String) -> &mut TokenLogitEvolution {
81 self.tracked_tokens.push(TokenLogitEvolution::new(token_id, token_str));
82 self.tracked_tokens.last_mut().expect("invariant: just pushed")
83 }
84
85 pub fn compute_rank(logits: &[f32], token_id: u32) -> usize {
87 let target_logit = logits.get(token_id as usize).copied().unwrap_or(f32::MIN);
88
89 logits.iter().filter(|&&l| l > target_logit).count()
90 }
91
92 pub fn finalize(&mut self, selected_token_id: u32) {
94 for token in &self.tracked_tokens {
96 if token.token_id == selected_token_id {
97 if let Some(layer) = token.decisive_layer() {
98 self.decisive_layer = layer;
99 }
100 break;
101 }
102 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Recording three layers fills both vectors and picks the largest rank jump.
    #[test]
    fn test_token_logit_evolution() {
        let mut evo = TokenLogitEvolution::new(42, "test".to_string());
        evo.record_layer(1.0, 100);
        evo.record_layer(2.0, 50);
        evo.record_layer(3.0, 10);

        assert_eq!(evo.per_layer_logit.len(), 3);
        assert_eq!(evo.per_layer_rank.len(), 3);
        // Rank steps are 100 -> 50 (50) and 50 -> 10 (40); the first is larger.
        assert_eq!(evo.decisive_layer(), Some(1));
    }

    /// Rank is the count of logits strictly greater than the target's.
    #[test]
    fn test_logit_evolution_trace_compute_rank() {
        let logits = vec![1.0, 5.0, 3.0, 2.0];
        // Index 0 (1.0) is exceeded by the three other logits.
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 0), 3);
        // Index 1 (5.0) is the maximum.
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 1), 0);
    }

    /// `finalize` copies the selected token's decisive layer into the trace
    /// and ignores token ids that are not tracked.
    #[test]
    fn test_logit_evolution_trace_finalize() {
        let mut trace = LogitEvolutionTrace::new(0, 1.0, 0.9);
        {
            let tracked = trace.track_token(7, "tok".to_string());
            tracked.record_layer(0.1, 10);
            tracked.record_layer(0.2, 9);
            tracked.record_layer(0.9, 1);
        }

        trace.finalize(999);
        assert_eq!(trace.decisive_layer, 0);

        trace.finalize(7);
        assert_eq!(trace.decisive_layer, 2);
    }
}