Skip to main content

trueno/brick/tracing/
logit.rs

1// ============================================================================
2// E.11.4: LogitEvolutionTrace (MLT-03)
3// ============================================================================
4
/// Logit evolution for a single token through layers.
///
/// Tracks how a token's logit value and rank change as hidden states
/// pass through transformer layers. [`record_layer`](Self::record_layer)
/// appends to `per_layer_logit` and `per_layer_rank` together, so the two
/// vectors stay index-aligned as long as only that method is used to fill them.
#[derive(Debug, Clone, Default)]
pub struct TokenLogitEvolution {
    /// Token ID being tracked.
    pub token_id: u32,
    /// Token string representation (for display).
    pub token_str: String,
    /// Logit value recorded after each layer, in recording order.
    pub per_layer_logit: Vec<f32>,
    /// Rank among vocabulary at each layer (0 = highest probability).
    pub per_layer_rank: Vec<usize>,
    /// Final probability after softmax (not populated by `new`/`record_layer`).
    pub final_probability: f32,
    /// Final rank, 0 = selected token (not populated by `new`/`record_layer`).
    pub final_rank: usize,
}

impl TokenLogitEvolution {
    /// Create a new tracker for `token_id` with an empty per-layer history.
    pub fn new(token_id: u32, token_str: String) -> Self {
        Self { token_id, token_str, ..Default::default() }
    }

    /// Append the logit value and vocabulary rank observed after one layer.
    pub fn record_layer(&mut self, logit: f32, rank: usize) {
        self.per_layer_logit.push(logit);
        self.per_layer_rank.push(rank);
    }

    /// Get the layer where this token's rank changed most dramatically.
    ///
    /// The change attributed to layer `i` is `|rank[i] - rank[i - 1]|`, so layer 0
    /// (which has no predecessor) is never reported. On ties, the earliest layer wins.
    ///
    /// Returns `None` when fewer than two layers have been recorded, or when the
    /// rank never changed. In both cases no layer actually moved the token.
    pub fn decisive_layer(&self) -> Option<usize> {
        // (layer index, absolute rank change) of the best candidate so far.
        let mut best: Option<(usize, usize)> = None;

        for (offset, pair) in self.per_layer_rank.windows(2).enumerate() {
            // `abs_diff` avoids the signed cast and cannot overflow.
            let change = pair[0].abs_diff(pair[1]);
            // Strict `>` keeps the earliest layer on ties; a zero change never qualifies.
            if change > best.map_or(0, |(_, c)| c) {
                best = Some((offset + 1, change));
            }
        }

        best.map(|(layer, _)| layer)
    }
}
57
58/// Full logit trace for one generation step.
59#[derive(Debug, Clone, Default)]
60pub struct LogitEvolutionTrace {
61    /// Position being generated
62    pub position: usize,
63    /// Tokens being tracked (typically top-k candidates + ground truth)
64    pub tracked_tokens: Vec<TokenLogitEvolution>,
65    /// Which layer had the largest impact on the selected token
66    pub decisive_layer: usize,
67    /// Temperature used for sampling
68    pub temperature: f32,
69    /// Top-p (nucleus) value used
70    pub top_p: f32,
71}
72
73impl LogitEvolutionTrace {
74    /// Create a new logit evolution trace.
75    pub fn new(position: usize, temperature: f32, top_p: f32) -> Self {
76        Self { position, temperature, top_p, ..Default::default() }
77    }
78
79    /// Add a token to track.
80    pub fn track_token(&mut self, token_id: u32, token_str: String) -> &mut TokenLogitEvolution {
81        self.tracked_tokens.push(TokenLogitEvolution::new(token_id, token_str));
82        self.tracked_tokens.last_mut().expect("invariant: just pushed")
83    }
84
85    /// Compute rank of a token in a logit distribution.
86    pub fn compute_rank(logits: &[f32], token_id: u32) -> usize {
87        let target_logit = logits.get(token_id as usize).copied().unwrap_or(f32::MIN);
88
89        logits.iter().filter(|&&l| l > target_logit).count()
90    }
91
92    /// Finalize the trace after generation completes.
93    pub fn finalize(&mut self, selected_token_id: u32) {
94        // Find the decisive layer for the selected token
95        for token in &self.tracked_tokens {
96            if token.token_id == selected_token_id {
97                if let Some(layer) = token.decisive_layer() {
98                    self.decisive_layer = layer;
99                }
100                break;
101            }
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_logit_evolution() {
        let mut evo = TokenLogitEvolution::new(42, "test".to_string());
        evo.record_layer(1.0, 100);
        evo.record_layer(2.0, 50);
        evo.record_layer(3.0, 10);

        assert_eq!(evo.per_layer_logit.len(), 3);
        assert_eq!(evo.per_layer_rank.len(), 3);
        assert_eq!(evo.decisive_layer(), Some(1)); // 100->50 is biggest jump
    }

    #[test]
    fn test_logit_evolution_trace_compute_rank() {
        // Sorted descending: 5.0, 3.0, 2.0, 1.0
        let logits: [f32; 4] = [1.0, 5.0, 3.0, 2.0];
        // Token 0 has logit 1.0: three values above it -> rank 3.
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 0), 3);
        // Token 1 has logit 5.0: nothing above it -> rank 0.
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 1), 0);
    }

    #[test]
    fn test_compute_rank_ties_share_rank() {
        let logits: [f32; 3] = [2.0, 2.0, 1.0];
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 0), 0);
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 1), 0);
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 2), 2);
    }

    #[test]
    fn test_compute_rank_out_of_range_is_last() {
        let logits: [f32; 2] = [1.0, 2.0];
        assert_eq!(LogitEvolutionTrace::compute_rank(&logits, 5), 2);
    }

    #[test]
    fn test_track_token_and_finalize() {
        let mut trace = LogitEvolutionTrace::new(3, 0.7, 0.9);
        assert_eq!(trace.position, 3);

        let selected = trace.track_token(7, "a".to_string());
        selected.record_layer(0.1, 50);
        selected.record_layer(0.2, 40);
        selected.record_layer(0.9, 5);

        let other = trace.track_token(8, "b".to_string());
        other.record_layer(0.3, 0);
        other.record_layer(0.1, 30);

        assert_eq!(trace.tracked_tokens.len(), 2);

        trace.finalize(7);
        assert_eq!(trace.decisive_layer, 2); // 40->5 beats 50->40

        // An untracked token leaves the previously finalized value in place.
        trace.finalize(999);
        assert_eq!(trace.decisive_layer, 2);
    }
}