/// Token predictions read out at a single transformer layer.
#[derive(Debug, Clone)]
pub struct LogitLensResult {
    /// Index of the layer this result was taken from.
    pub layer: usize,
    /// Candidate tokens for this layer. Consumers treat `first()` as the
    /// top-ranked prediction, so this is assumed sorted by descending
    /// probability — TODO confirm producers keep it sorted.
    pub predictions: Vec<TokenPrediction>,
}
/// A single predicted token together with its probability.
#[derive(Debug, Clone)]
pub struct TokenPrediction {
    /// Vocabulary id of the token.
    pub token_id: u32,
    /// Decoded text of the token.
    pub token: String,
    /// Probability assigned to this token (presumably in [0, 1];
    /// display code multiplies by 100 to show a percentage).
    pub probability: f32,
}
/// Accumulated logit-lens results for one input, one entry per layer.
#[derive(Debug)]
pub struct LogitLensAnalysis {
    /// The analyzed input text (used for display in the print methods).
    pub input_text: String,
    /// Per-layer results, in the order they were pushed.
    pub layer_results: Vec<LogitLensResult>,
    /// Expected number of layers; used only to pre-size `layer_results`.
    pub n_layers: usize,
}
impl LogitLensAnalysis {
    /// Creates an empty analysis for `input_text`, pre-allocating room for
    /// `n_layers` per-layer results.
    #[must_use]
    pub fn new(input_text: String, n_layers: usize) -> Self {
        let layer_results = Vec::with_capacity(n_layers);
        Self {
            input_text,
            layer_results,
            n_layers,
        }
    }

    /// Appends one layer's result to the analysis.
    pub fn push(&mut self, result: LogitLensResult) {
        self.layer_results.push(result);
    }

    /// Returns the top-ranked `(token, probability)` pair for every recorded
    /// layer, skipping layers that have no predictions.
    #[must_use]
    pub fn top_predictions(&self) -> Vec<(&str, f32)> {
        let mut tops = Vec::with_capacity(self.layer_results.len());
        for result in &self.layer_results {
            if let Some(top) = result.predictions.first() {
                tops.push((top.token.as_str(), top.probability));
            }
        }
        tops
    }

    /// Earliest layer whose top-`k` predictions contain `token` as a
    /// substring (substring match deliberately tolerates tokenizer prefixes
    /// such as leading spaces). Returns `None` if it never appears.
    #[must_use]
    pub fn first_appearance(&self, token: &str, k: usize) -> Option<usize> {
        self.layer_results
            .iter()
            .find(|result| {
                result
                    .predictions
                    .iter()
                    .take(k)
                    .any(|p| p.token.contains(token))
            })
            .map(|result| result.layer)
    }

    /// Prints a one-line-per-layer summary of the top prediction.
    pub fn print_summary(&self) {
        println!("=== Logit Lens Analysis ===");
        println!("Input: {}", self.input_text);
        println!("\nTop prediction at each layer:");
        for result in &self.layer_results {
            if let Some(top) = result.predictions.first() {
                // Quote the (escaped) token so surrounding whitespace is visible.
                let quoted = format!("\"{}\"", format_token(&top.token));
                println!(
                    " Layer {:2}: {:>12} ({})",
                    result.layer,
                    quoted,
                    format_probability(top.probability),
                );
            }
        }
    }

    /// Prints the top-`top_k` predictions for every recorded layer.
    pub fn print_detailed(&self, top_k: usize) {
        println!("=== Logit Lens Detailed Analysis ===");
        println!("Input: {}", self.input_text);
        for result in &self.layer_results {
            println!("\nLayer {}:", result.layer);
            for (rank, pred) in result.predictions.iter().take(top_k).enumerate() {
                let quoted = format!("\"{}\"", format_token(&pred.token));
                println!(
                    " {}. {:>15} ({})",
                    rank + 1,
                    quoted,
                    format_probability(pred.probability),
                );
            }
        }
    }
}
/// Converts raw `(token_id, probability)` pairs into [`TokenPrediction`]s,
/// decoding each id to text with the caller-supplied `decode_fn`.
#[must_use]
pub fn decode_predictions_with(
    predictions: &[(u32, f32)],
    decode_fn: impl Fn(u32) -> String,
) -> Vec<TokenPrediction> {
    let mut decoded = Vec::with_capacity(predictions.len());
    for &(token_id, probability) in predictions {
        decoded.push(TokenPrediction {
            token_id,
            token: decode_fn(token_id),
            probability,
        });
    }
    decoded
}
/// Escapes the control characters `\n`, `\t`, and `\r` in a token so it can
/// be printed on a single line.
///
/// Built in a single pass over the input; the previous chain of three
/// `str::replace` calls allocated a fresh intermediate `String` per pattern.
#[must_use]
pub fn format_token(token: &str) -> String {
    // Output is at least as long as the input (each escape grows by one byte),
    // and most tokens contain no control characters at all.
    let mut escaped = String::with_capacity(token.len());
    for c in token.chars() {
        match c {
            '\n' => escaped.push_str("\\n"),
            '\t' => escaped.push_str("\\t"),
            '\r' => escaped.push_str("\\r"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Formats a probability as a percentage with precision adapted to its size:
/// one decimal at or above 1%, three decimals down to 0.01%, and scientific
/// notation below that.
#[must_use]
pub fn format_probability(prob: f32) -> String {
    let pct = prob * 100.0;
    // Guards mirror the original comparison directions so non-finite values
    // (e.g. NaN fails both `>=` tests) take the same branch as before.
    match () {
        () if pct >= 1.0 => format!("{pct:.1}%"),
        () if pct >= 0.01 => format!("{pct:.3}%"),
        () => format!("{pct:.1e}%"),
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a layer result holding exactly one prediction.
    fn layer_with(layer: usize, token_id: u32, token: &str, probability: f32) -> LogitLensResult {
        LogitLensResult {
            layer,
            predictions: vec![TokenPrediction {
                token_id,
                token: token.to_string(),
                probability,
            }],
        }
    }

    #[test]
    fn logit_lens_result_basic() {
        let predictions = vec![
            TokenPrediction {
                token_id: 1,
                token: "fn".to_string(),
                probability: 0.5,
            },
            TokenPrediction {
                token_id: 2,
                token: "def".to_string(),
                probability: 0.3,
            },
        ];
        let result = LogitLensResult {
            layer: 0,
            predictions,
        };
        assert_eq!(result.layer, 0);
        assert_eq!(result.predictions.len(), 2);
        assert_eq!(result.predictions.first().unwrap().token, "fn");
    }

    #[test]
    fn first_appearance_found() {
        let mut analysis = LogitLensAnalysis::new("test".to_string(), 3);
        analysis.push(layer_with(0, 1, "a", 0.5));
        analysis.push(layer_with(1, 2, "#[test]", 0.5));
        assert_eq!(analysis.first_appearance("#[test]", 1), Some(1));
        assert_eq!(analysis.first_appearance("notfound", 1), None);
    }

    #[test]
    fn decode_predictions_with_custom_fn() {
        let decoded = decode_predictions_with(&[(1, 0.5), (2, 0.3)], |id| format!("tok_{id}"));
        assert_eq!(decoded.len(), 2);
        let first = decoded.first().unwrap();
        assert_eq!(first.token, "tok_1");
        assert_eq!(first.token_id, 1);
    }

    #[test]
    fn format_token_escapes() {
        assert_eq!(format_token("hello\nworld"), "hello\\nworld");
        assert_eq!(format_token("tab\there"), "tab\\there");
        assert_eq!(format_token("no_escapes"), "no_escapes");
    }

    #[test]
    fn top_predictions_empty() {
        let empty = LogitLensAnalysis::new("test".to_string(), 0);
        assert!(empty.top_predictions().is_empty());
    }
}