use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Collects per-layer attention weight tensors and renders them as
/// heatmaps, ASCII plots, JSON exports, or HTML tables.
#[derive(Debug)]
pub struct AttentionVisualizer {
    // Registered attention tensors, keyed by layer name.
    attention_weights: HashMap<String, AttentionWeights>,
    // Optional token vocabulary, set via `set_token_vocab`.
    // NOTE(review): write-only in this file — confirm it is consumed elsewhere.
    token_vocab: Option<Vec<String>>,
    // Analysis/rendering options (thresholds, color scheme, ...).
    config: AttentionVisualizerConfig,
}
/// Options controlling attention analysis and rendering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionVisualizerConfig {
    // Whether weights should be normalized before display.
    // NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub normalize: bool,
    // Weights below this threshold count as "sparse" in `analyze` and are
    // dropped from attention-flow extraction.
    pub min_weight: f64,
    // Maximum number of tokens to handle.
    // NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub max_tokens: usize,
    // Color palette for rendered heatmaps.
    // NOTE(review): `weight_to_color` currently hard-codes a red scale.
    pub color_scheme: ColorScheme,
}
/// Color palettes selectable for heatmap rendering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ColorScheme {
    BlueRed,
    Grayscale,
    Viridis,
    Plasma,
}
impl Default for AttentionVisualizerConfig {
    /// Defaults: normalization on, 1% sparsity threshold, 512-token cap,
    /// blue-red palette.
    fn default() -> Self {
        Self {
            normalize: true,
            min_weight: 0.01,
            max_tokens: 512,
            color_scheme: ColorScheme::BlueRed,
        }
    }
}
/// Attention weights captured for a single layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionWeights {
    pub layer_name: String,
    // Number of heads; set to `weights.len()` at registration time.
    pub num_heads: usize,
    // Indexed as `weights[head][source_pos][target_pos]` (see
    // `create_heatmap` and `export_to_bertviz`).
    pub weights: Vec<Vec<Vec<f64>>>,
    // Tokens labelling the rows (attending positions).
    pub source_tokens: Vec<String>,
    // Tokens labelling the columns (attended-to positions).
    pub target_tokens: Vec<String>,
    pub attention_type: AttentionType,
}
/// Kind of attention mechanism the registered weights came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionType {
    SelfAttention,
    CrossAttention,
    EncoderDecoderAttention,
}
/// Summary statistics produced by [`AttentionVisualizer::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionAnalysis {
    pub layer_name: String,
    // Mean per-row Shannon entropy (bits), one entry per head.
    pub entropy_per_head: Vec<f64>,
    // Fraction of weights below `config.min_weight`, one entry per head.
    pub sparsity_per_head: Vec<f64>,
    // Up to 10 target positions with the highest total incoming weight,
    // as (token_index, summed_weight), descending.
    pub most_attended_tokens: Vec<(usize, f64)>,
    // Up to 100 strongest individual attention edges, descending by weight.
    pub flow_patterns: Vec<AttentionFlow>,
}
/// A single attention edge: `from` (source position) attends to `to`
/// (target position) with `weight`, in head `head`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionFlow {
    pub from: usize,
    pub to: usize,
    pub weight: f64,
    pub head: usize,
}
/// One head's attention matrix plus row/column token labels, ready for
/// rendering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionHeatmap {
    pub layer_name: String,
    pub head: usize,
    // `weights[row][col]`, rows = source tokens, cols = target tokens.
    pub weights: Vec<Vec<f64>>,
    pub row_labels: Vec<String>,
    pub col_labels: Vec<String>,
}
impl AttentionVisualizer {
pub fn new() -> Self {
Self {
attention_weights: HashMap::new(),
token_vocab: None,
config: AttentionVisualizerConfig::default(),
}
}
/// Creates an empty visualizer using the supplied configuration.
pub fn with_config(config: AttentionVisualizerConfig) -> Self {
    Self {
        config,
        attention_weights: HashMap::new(),
        token_vocab: None,
    }
}
/// Stores a token vocabulary, replacing any previously set one.
// NOTE(review): the vocabulary is never read in this file — confirm a consumer exists.
pub fn set_token_vocab(&mut self, tokens: Vec<String>) {
    self.token_vocab = Some(tokens);
}
pub fn register(
&mut self,
layer_name: &str,
weights: Vec<Vec<Vec<f64>>>,
source_tokens: Vec<String>,
target_tokens: Vec<String>,
attention_type: AttentionType,
) -> Result<()> {
let num_heads = weights.len();
let attention_weights = AttentionWeights {
layer_name: layer_name.to_string(),
num_heads,
weights,
source_tokens,
target_tokens,
attention_type,
};
self.attention_weights.insert(layer_name.to_string(), attention_weights);
Ok(())
}
/// Returns the registered weights for `layer_name`, if any.
pub fn get_attention(&self, layer_name: &str) -> Option<&AttentionWeights> {
    self.attention_weights.get(layer_name)
}
/// Builds a labelled heatmap for one head of a registered layer.
///
/// # Errors
/// Fails when the layer is unknown, the layer has no heads, or `head` is
/// out of range. The explicit zero-head check is required because the
/// "max head" in the range error is computed as `num_heads - 1`, which
/// would underflow `usize` (debug panic / release wrap-around) for an
/// empty weight tensor.
pub fn create_heatmap(&self, layer_name: &str, head: usize) -> Result<AttentionHeatmap> {
    let attention = self
        .attention_weights
        .get(layer_name)
        .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
    if attention.num_heads == 0 {
        anyhow::bail!("Layer {} has no attention heads", layer_name);
    }
    if head >= attention.num_heads {
        anyhow::bail!(
            "Head {} out of range (max: {})",
            head,
            attention.num_heads - 1
        );
    }
    let weights = &attention.weights[head];
    Ok(AttentionHeatmap {
        layer_name: layer_name.to_string(),
        head,
        weights: weights.clone(),
        row_labels: attention.source_tokens.clone(),
        col_labels: attention.target_tokens.clone(),
    })
}
/// Computes per-head entropy and sparsity plus the most-attended tokens
/// and the strongest attention flows for a registered layer.
///
/// # Errors
/// Fails when `layer_name` has not been registered.
pub fn analyze(&self, layer_name: &str) -> Result<AttentionAnalysis> {
    let attention = self
        .get_attention(layer_name)
        .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
    let mut entropy_per_head = Vec::with_capacity(attention.num_heads);
    let mut sparsity_per_head = Vec::with_capacity(attention.num_heads);
    for head_weights in &attention.weights {
        entropy_per_head.push(compute_entropy(head_weights));
        sparsity_per_head.push(compute_sparsity(head_weights, self.config.min_weight));
    }
    Ok(AttentionAnalysis {
        layer_name: layer_name.to_string(),
        entropy_per_head,
        sparsity_per_head,
        most_attended_tokens: find_most_attended_tokens(&attention.weights),
        flow_patterns: extract_attention_flows(&attention.weights, self.config.min_weight),
    })
}
/// Renders one head's attention matrix as an ASCII-art heatmap, capped at
/// 20 rows and 20 columns; a trailing note reports any truncation.
///
/// # Errors
/// Fails when the layer/head pair cannot be resolved (see `create_heatmap`).
pub fn plot_heatmap_ascii(&self, layer_name: &str, head: usize) -> Result<String> {
    let heatmap = self.create_heatmap(layer_name, head)?;
    let max_display = 20;
    let total_rows = heatmap.row_labels.len();
    let total_cols = heatmap.col_labels.len();
    let display_rows = total_rows.min(max_display);
    let display_cols = total_cols.min(max_display);
    let mut out = format!("Attention Heatmap: {} (Head {})\n", layer_name, head);
    out.push_str(&"=".repeat(60));
    out.push('\n');
    // Column header: first character of each column label.
    out.push_str(" ");
    for label in heatmap.col_labels.iter().take(display_cols) {
        out.push_str(&format!("{:4}", label.chars().next().unwrap_or('?')));
    }
    out.push('\n');
    // One row per source token: truncated label, then one symbol per cell.
    for (row, label) in heatmap.row_labels.iter().take(display_rows).enumerate() {
        out.push_str(&format!("{:6} ", label.chars().take(6).collect::<String>()));
        for col in 0..display_cols {
            out.push_str(&format!("{:4}", weight_to_symbol(heatmap.weights[row][col])));
        }
        out.push('\n');
    }
    if display_rows < total_rows || display_cols < total_cols {
        out.push_str(&format!(
            "\n(Showing {}/{} rows, {}/{} cols)\n",
            display_rows, total_rows, display_cols, total_cols
        ));
    }
    Ok(out)
}
/// Serializes a registered layer's attention data as pretty-printed JSON
/// and writes it to `output_path`.
///
/// # Errors
/// Fails when the layer is unknown, serialization fails, or the file
/// cannot be written.
pub fn export_to_json(&self, layer_name: &str, output_path: &Path) -> Result<()> {
    match self.attention_weights.get(layer_name) {
        Some(attention) => {
            let json = serde_json::to_string_pretty(attention)?;
            std::fs::write(output_path, json)?;
            Ok(())
        }
        None => Err(anyhow::anyhow!("Layer {} not found", layer_name)),
    }
}
/// Writes a standalone HTML page with one color-coded table per head,
/// showing every source→target attention weight for the layer.
///
/// # Errors
/// Fails when the layer is unknown or the file cannot be written.
pub fn export_to_bertviz(&self, layer_name: &str, output_path: &Path) -> Result<()> {
    let attention = self
        .attention_weights
        .get(layer_name)
        .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
    let mut html =
        String::from("<html><head><title>Attention Visualization</title></head><body>");
    html.push_str(&format!("<h1>{}</h1>", layer_name));
    for (head, head_weights) in attention.weights.iter().enumerate() {
        html.push_str(&format!("<h2>Head {}</h2>", head));
        // Header row: one <th> per target token.
        html.push_str("<table border='1'><tr><th></th>");
        for token in &attention.target_tokens {
            html.push_str(&format!("<th>{}</th>", html_escape(token)));
        }
        html.push_str("</tr>");
        // Body: one row per source token, cells tinted by weight.
        for (row_idx, source_token) in attention.source_tokens.iter().enumerate() {
            html.push_str(&format!("<tr><th>{}</th>", html_escape(source_token)));
            for col_idx in 0..attention.target_tokens.len() {
                let weight = head_weights[row_idx][col_idx];
                html.push_str(&format!(
                    "<td style='background-color: {}'>{:.3}</td>",
                    weight_to_color(weight),
                    weight
                ));
            }
            html.push_str("</tr>");
        }
        html.push_str("</table>");
    }
    html.push_str("</body></html>");
    std::fs::write(output_path, html)?;
    Ok(())
}
/// Returns a human-readable summary of every registered layer.
///
/// Layers are listed in sorted name order so the output is deterministic
/// (HashMap iteration order is randomized per process). The average
/// entropy/sparsity lines are emitted only when the layer actually has
/// heads; previously a zero-head layer divided by `len() == 0` and
/// printed `NaN`.
pub fn summary(&self) -> String {
    let mut output = String::new();
    output.push_str("Attention Summary\n");
    output.push_str(&"=".repeat(80));
    output.push('\n');
    let mut layer_names: Vec<&String> = self.attention_weights.keys().collect();
    layer_names.sort();
    for layer_name in layer_names {
        let attention = &self.attention_weights[layer_name];
        output.push_str(&format!("\nLayer: {}\n", layer_name));
        output.push_str(&format!(" Num Heads: {}\n", attention.num_heads));
        output.push_str(&format!(
            " Seq Length: {}\n",
            attention.source_tokens.len()
        ));
        output.push_str(&format!(
            " Attention Type: {:?}\n",
            attention.attention_type
        ));
        if let Ok(analysis) = self.analyze(layer_name) {
            if !analysis.entropy_per_head.is_empty() {
                output.push_str(&format!(
                    " Avg Entropy: {:.4}\n",
                    analysis.entropy_per_head.iter().sum::<f64>()
                        / analysis.entropy_per_head.len() as f64
                ));
            }
            if !analysis.sparsity_per_head.is_empty() {
                output.push_str(&format!(
                    " Avg Sparsity: {:.4}\n",
                    analysis.sparsity_per_head.iter().sum::<f64>()
                        / analysis.sparsity_per_head.len() as f64
                ));
            }
        }
    }
    output
}
/// Removes all registered layers (the token vocabulary is kept).
pub fn clear(&mut self) {
    self.attention_weights.clear();
}
/// Number of layers currently registered.
pub fn num_layers(&self) -> usize {
    self.attention_weights.len()
}
}
impl Default for AttentionVisualizer {
    /// Equivalent to [`AttentionVisualizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Mean Shannon entropy (in bits) of the per-row attention distributions.
///
/// Each row is normalized by its own sum before the entropy is taken;
/// zero-valued cells contribute nothing and rows whose total weight is not
/// positive are skipped. Returns 0.0 when no row qualifies.
fn compute_entropy(weights: &[Vec<f64>]) -> f64 {
    let row_entropies: Vec<f64> = weights
        .iter()
        .filter_map(|row| {
            let total: f64 = row.iter().sum();
            if total > 0.0 {
                let h: f64 = row
                    .iter()
                    .filter(|&&w| w > 0.0)
                    .map(|&w| {
                        let p = w / total;
                        -p * p.log2()
                    })
                    .sum();
                Some(h)
            } else {
                None
            }
        })
        .collect();
    if row_entropies.is_empty() {
        0.0
    } else {
        row_entropies.iter().sum::<f64>() / row_entropies.len() as f64
    }
}
/// Fraction of weights strictly below `threshold`, over all cells of all
/// rows. Returns 0.0 for an empty matrix.
fn compute_sparsity(weights: &[Vec<f64>], threshold: f64) -> f64 {
    let mut total = 0usize;
    let mut below = 0usize;
    for row in weights {
        total += row.len();
        below += row.iter().filter(|&&w| w < threshold).count();
    }
    if total == 0 {
        0.0
    } else {
        below as f64 / total as f64
    }
}
/// Sums the attention received by each target position across all heads
/// and rows, then returns up to the top 10 positions as
/// `(token_index, total_weight)`, sorted by weight descending.
///
/// The sequence length is taken from the first row of the first head;
/// returns an empty vector when there are no heads or no rows.
fn find_most_attended_tokens(weights: &[Vec<Vec<f64>>]) -> Vec<(usize, f64)> {
    let seq_len = match weights.first().and_then(|head| head.first()) {
        Some(row) => row.len(),
        None => return Vec::new(),
    };
    let mut totals = vec![0.0f64; seq_len];
    for head in weights {
        for row in head {
            for (idx, &w) in row.iter().enumerate() {
                totals[idx] += w;
            }
        }
    }
    let mut ranked: Vec<(usize, f64)> = totals.into_iter().enumerate().collect();
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked.truncate(10);
    ranked
}
/// Collects every attention edge with weight at or above `threshold` and
/// returns at most the 100 strongest, sorted by weight descending.
fn extract_attention_flows(weights: &[Vec<Vec<f64>>], threshold: f64) -> Vec<AttentionFlow> {
    let mut flows: Vec<AttentionFlow> = weights
        .iter()
        .enumerate()
        .flat_map(|(head, head_weights)| {
            head_weights.iter().enumerate().flat_map(move |(from, row)| {
                row.iter().enumerate().filter_map(move |(to, &weight)| {
                    if weight >= threshold {
                        Some(AttentionFlow { from, to, weight, head })
                    } else {
                        None
                    }
                })
            })
        })
        .collect();
    flows.sort_by(|a, b| b.weight.partial_cmp(&a.weight).unwrap_or(std::cmp::Ordering::Equal));
    flows.truncate(100);
    flows
}
/// Maps an attention weight to an ASCII-art density symbol.
///
/// The original literals were mojibake ("â–ˆ" etc.) — the UTF-8 bytes of
/// the Unicode block characters decoded as Latin-1. Restored to the
/// intended `█ ▓ ▒ ░` shading ramp (thresholds exclusive: > 0.8, > 0.6,
/// > 0.4, > 0.2, else blank).
fn weight_to_symbol(weight: f64) -> &'static str {
    if weight > 0.8 {
        "█"
    } else if weight > 0.6 {
        "▓"
    } else if weight > 0.4 {
        "▒"
    } else if weight > 0.2 {
        "░"
    } else {
        " "
    }
}
/// Maps a weight in [0, 1] to a CSS color on a white→red scale
/// (0.0 -> white, 1.0 -> pure red). Out-of-range values saturate via the
/// `as u8` cast.
fn weight_to_color(weight: f64) -> String {
    let fade = 255 - (weight * 255.0) as u8;
    format!("rgb(255, {}, {})", fade, fade)
}
/// Escapes HTML special characters so `s` can be embedded safely in
/// element content or a double-quoted attribute value.
///
/// The original replacements were identity no-ops (`'&'` -> `"&"` etc.) —
/// the entity strings had been de-escaped, leaving token text injectable
/// into the exported HTML. Restored proper entities; `&` is replaced
/// first so the other entities are not double-escaped.
fn html_escape(s: &str) -> String {
    s.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh visualizer starts with no registered layers.
    #[test]
    fn test_attention_visualizer_creation() {
        let visualizer = AttentionVisualizer::new();
        assert_eq!(visualizer.num_layers(), 0);
    }

    // Registering one 1-head 3x3 tensor yields exactly one layer.
    #[test]
    fn test_register_attention() {
        let mut visualizer = AttentionVisualizer::new();
        let weights = vec![vec![
            vec![0.5, 0.3, 0.2],
            vec![0.1, 0.6, 0.3],
            vec![0.2, 0.3, 0.5],
        ]];
        let tokens = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");
        assert_eq!(visualizer.num_layers(), 1);
    }

    // Heatmap extraction carries the layer name, head index, and row count.
    #[test]
    fn test_create_heatmap() {
        let mut visualizer = AttentionVisualizer::new();
        let weights = vec![vec![
            vec![0.5, 0.3, 0.2],
            vec![0.1, 0.6, 0.3],
            vec![0.2, 0.3, 0.5],
        ]];
        let tokens = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");
        let heatmap = visualizer.create_heatmap("layer.0", 0).expect("operation failed in test");
        assert_eq!(heatmap.layer_name, "layer.0");
        assert_eq!(heatmap.head, 0);
        assert_eq!(heatmap.weights.len(), 3);
    }

    // analyze() produces one entropy/sparsity entry per head and a
    // non-empty most-attended list for a non-trivial tensor.
    #[test]
    fn test_analyze_attention() {
        let mut visualizer = AttentionVisualizer::new();
        let weights = vec![vec![
            vec![0.7, 0.2, 0.1],
            vec![0.1, 0.8, 0.1],
            vec![0.1, 0.1, 0.8],
        ]];
        let tokens = vec!["A".to_string(), "B".to_string(), "C".to_string()];
        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");
        let analysis = visualizer.analyze("layer.0").expect("operation failed in test");
        assert_eq!(analysis.entropy_per_head.len(), 1);
        assert_eq!(analysis.sparsity_per_head.len(), 1);
        assert!(!analysis.most_attended_tokens.is_empty());
    }

    // JSON export writes a file to the temp directory.
    // NOTE(review): the fixed file name "attention.json" could collide if
    // tests run concurrently in the same temp dir — consider a unique name.
    #[test]
    fn test_export_to_json() {
        use std::env;
        let temp_dir = env::temp_dir();
        let output_path = temp_dir.join("attention.json");
        let mut visualizer = AttentionVisualizer::new();
        let weights = vec![vec![vec![1.0]]];
        let tokens = vec!["A".to_string()];
        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");
        visualizer
            .export_to_json("layer.0", &output_path)
            .expect("operation failed in test");
        assert!(output_path.exists());
        let _ = std::fs::remove_file(output_path);
    }

    // A mixed/peaked pair of rows has strictly positive mean entropy.
    #[test]
    fn test_compute_entropy() {
        let weights = vec![vec![0.5, 0.3, 0.2], vec![1.0, 0.0, 0.0]];
        let entropy = compute_entropy(&weights);
        assert!(entropy > 0.0);
    }

    // Sparsity is a proper fraction in (0, 1] for this input.
    #[test]
    fn test_compute_sparsity() {
        let weights = vec![vec![0.9, 0.05, 0.05], vec![0.01, 0.01, 0.98]];
        let sparsity = compute_sparsity(&weights, 0.1);
        assert!(sparsity > 0.0);
        assert!(sparsity <= 1.0);
    }
}