use anyhow::Result;
use scirs2_core::ndarray::*; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Analyses per-head attention weights and accumulates the results.
///
/// Each call to `analyze_attention_layer` appends to the stored attention
/// maps and inserts into the per-head analysis table.
#[derive(Debug)]
pub struct AttentionDebugger {
    /// Settings that control the sparsity threshold and the head cap.
    pub config: AttentionDebugConfig,
    /// Every attention map produced so far, in creation order.
    attention_maps: Vec<AttentionMap>,
    /// Latest analysis per head, keyed by head index only (not by layer).
    head_analysis: HashMap<usize, AttentionHeadAnalysis>,
}
/// Configuration for [`AttentionDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDebugConfig {
    // NOTE(review): the three `enable_*` flags are not read by any code in
    // this file — confirm whether other modules consume them.
    pub enable_attention_visualization: bool,
    pub enable_head_analysis: bool,
    pub enable_pattern_detection: bool,
    /// Weights strictly below this value are counted as sparse.
    pub attention_threshold: f32,
    /// Maximum number of heads analysed per layer; later heads are skipped.
    pub max_heads_to_analyze: usize,
}
impl Default for AttentionDebugConfig {
fn default() -> Self {
Self {
enable_attention_visualization: true,
enable_head_analysis: true,
enable_pattern_detection: true,
attention_threshold: 0.01,
max_heads_to_analyze: 16,
}
}
}
/// A single head's attention matrix together with summary statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionMap {
    /// Layer the head belongs to.
    pub layer_index: usize,
    /// Position of the head within the layer's weight slice.
    pub head_index: usize,
    /// Size of the first dimension of the weight matrix.
    pub sequence_length: usize,
    /// The weights copied into nested vectors, one inner vector per row.
    pub attention_weights: Vec<Vec<f32>>,
    /// Pattern assigned by the debugger's pattern detection.
    pub attention_pattern: AttentionPattern,
    /// Base-2 entropy of the whole matrix after normalising by its total sum.
    pub attention_entropy: f32,
    /// Fraction of entries below the configured attention threshold.
    pub sparsity_ratio: f32,
}
/// Per-head analysis result produced by [`AttentionDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionHeadAnalysis {
    /// Index of the head within its layer.
    pub head_id: usize,
    /// Index of the layer the head belongs to.
    pub layer_id: usize,
    /// Specialization assigned from the diagonal / long-range / positional scores.
    pub specialization_type: HeadSpecializationType,
    /// Summary statistics over all weights of the head.
    pub attention_distribution: AttentionDistribution,
    /// Always set to `0.0` by the analysis code in this file.
    pub redundancy_score: f32,
    /// Entropy of the weights divided by the maximum possible entropy.
    pub importance_score: f32,
    /// Patterns found for the head (the analysis code stores exactly one).
    pub patterns_detected: Vec<AttentionPattern>,
}
/// Coarse classification of what an attention head appears to do.
///
/// Only `LocalSyntax`, `LongRange`, `Positional`, `ContentBased` and `Mixed`
/// are assigned by the classifier in this file; the remaining variants are
/// available for callers.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HeadSpecializationType {
    /// Most attention mass lies near the main diagonal.
    LocalSyntax,
    /// Most attention mass lies far from the main diagonal.
    LongRange,
    /// Attention correlates strongly with relative position.
    Positional,
    /// Fallback when no other 2-D criterion matches.
    ContentBased,
    Copying,
    Delimiter,
    /// Assigned when the weights are not two-dimensional.
    Mixed,
    Redundant,
}
/// Structural pattern detected in an attention matrix.
///
/// The detector in this file only produces `Diagonal`, `Sparse`, `Uniform`,
/// `Block` and `Random`; `Concentrated` and `Strided` are never assigned here.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AttentionPattern {
    /// Mass concentrated in a band around the main diagonal.
    Diagonal,
    /// One diagonal block is much denser than the average block.
    Block,
    /// Most entries fall below the attention threshold.
    Sparse,
    /// Entries stay close to the uniform weight.
    Uniform,
    Concentrated,
    Strided,
    /// Fallback when nothing else matches (also used for empty input).
    Random,
}
/// Summary statistics over all weights of one attention head.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDistribution {
    /// Arithmetic mean of all weights.
    pub mean_attention: f32,
    /// Population standard deviation of all weights.
    pub std_attention: f32,
    /// Largest weight.
    pub max_attention: f32,
    /// Smallest weight.
    pub min_attention: f32,
    /// Base-2 entropy of the weights after normalising by their sum.
    pub entropy: f32,
    /// Number of largest weights needed to reach 90% of the total sum.
    pub effective_context_length: f32,
}
impl AttentionDebugger {
/// Creates a debugger that uses `config` and starts with no recorded
/// attention maps or head analyses.
pub fn new(config: AttentionDebugConfig) -> Self {
    let attention_maps = Vec::new();
    let head_analysis = HashMap::new();
    Self {
        config,
        attention_maps,
        head_analysis,
    }
}
/// Analyses the heads of one layer and records the results on `self`.
///
/// `attention_weights` holds one weight array per head. Only the first
/// `config.max_heads_to_analyze` heads are processed, but `num_heads` in the
/// result reports the full slice length.
///
/// Each processed head's map is appended to the debugger's stored maps and
/// its analysis is inserted into the per-head table under the head index,
/// replacing any earlier entry with the same index.
///
/// # Errors
///
/// Returns the first error from building a map or analysing a head. Results
/// already recorded for earlier heads stay on `self`.
pub fn analyze_attention_layer(
    &mut self,
    layer_index: usize,
    attention_weights: &[ArrayD<f32>],
) -> Result<LayerAttentionAnalysis> {
    let head_limit = self.config.max_heads_to_analyze;
    let mut head_analyses = Vec::new();
    let mut attention_maps = Vec::new();

    for (head_index, head_weights) in attention_weights.iter().enumerate().take(head_limit) {
        let map = self.create_attention_map(layer_index, head_index, head_weights)?;
        self.attention_maps.push(map.clone());
        attention_maps.push(map);

        let analysis = self.analyze_attention_head(layer_index, head_index, head_weights)?;
        self.head_analysis.insert(head_index, analysis.clone());
        head_analyses.push(analysis);
    }

    let redundancy_analysis = self.analyze_head_redundancy(&head_analyses);
    let layer_diversity_score = self.compute_layer_diversity(&head_analyses);

    Ok(LayerAttentionAnalysis {
        layer_index,
        num_heads: attention_weights.len(),
        head_analyses,
        attention_maps,
        layer_diversity_score,
        redundancy_analysis,
    })
}
/// Copies a 2-D weight array into an [`AttentionMap`] and attaches its
/// detected pattern, entropy and sparsity ratio.
///
/// # Errors
///
/// Fails when `weights` is not two-dimensional.
fn create_attention_map(
    &self,
    layer_index: usize,
    head_index: usize,
    weights: &ArrayD<f32>,
) -> Result<AttentionMap> {
    let (rows, cols) = match weights.shape() {
        &[rows, cols] => (rows, cols),
        other => anyhow::bail!("Expected 2D attention weights, got {}D", other.len()),
    };

    // Materialise the array as nested vectors so the slice-based helpers apply.
    let attention_weights: Vec<Vec<f32>> = (0..rows)
        .map(|r| (0..cols).map(|c| weights[[r, c]]).collect())
        .collect();

    let attention_pattern = self.detect_attention_pattern(&attention_weights);
    let attention_entropy = self.compute_attention_entropy(&attention_weights);
    let sparsity_ratio = self.compute_sparsity_ratio(&attention_weights);

    Ok(AttentionMap {
        layer_index,
        head_index,
        sequence_length: rows,
        attention_weights,
        attention_pattern,
        attention_entropy,
        sparsity_ratio,
    })
}
/// Produces the full analysis record for one head.
///
/// The record combines the head's specialization, its weight distribution,
/// a single detected pattern and an importance score. `redundancy_score` is
/// always initialised to `0.0` here.
fn analyze_attention_head(
    &self,
    layer_index: usize,
    head_index: usize,
    weights: &ArrayD<f32>,
) -> Result<AttentionHeadAnalysis> {
    let specialization_type = self.classify_head_specialization(weights)?;
    let attention_distribution = self.compute_attention_distribution(weights)?;
    let dominant_pattern = self.detect_attention_pattern_from_weights(weights)?;
    let importance_score = self.compute_head_importance(weights)?;

    Ok(AttentionHeadAnalysis {
        head_id: head_index,
        layer_id: layer_index,
        specialization_type,
        attention_distribution,
        redundancy_score: 0.0,
        importance_score,
        patterns_detected: vec![dominant_pattern],
    })
}
/// Assigns a pattern to a weight matrix using the first matching rule.
///
/// Rules in priority order: diagonal strength above 0.7, sparsity above 0.8,
/// uniformity above 0.8, then block structure. An empty matrix and anything
/// that matches no rule are reported as `Random`.
fn detect_attention_pattern(&self, weights: &[Vec<f32>]) -> AttentionPattern {
    if weights.is_empty() {
        return AttentionPattern::Random;
    }

    // Each later measure is only evaluated when the earlier ones did not match.
    if self.measure_diagonal_strength(weights) > 0.7 {
        AttentionPattern::Diagonal
    } else if self.compute_sparsity_ratio(weights) > 0.8 {
        AttentionPattern::Sparse
    } else if self.measure_uniformity(weights) > 0.8 {
        AttentionPattern::Uniform
    } else if self.has_block_structure(weights) {
        AttentionPattern::Block
    } else {
        AttentionPattern::Random
    }
}
/// Returns the share of total weight that lies within three positions of
/// the main diagonal (`|row - col| <= 3`).
///
/// Returns `0.0` for empty input or when the total weight is not positive.
fn measure_diagonal_strength(&self, weights: &[Vec<f32>]) -> f32 {
    /// Half-width of the diagonal band, in positions.
    const WINDOW_SIZE: usize = 3;

    let mut diagonal_sum = 0.0_f32;
    let mut total_sum = 0.0_f32;
    for (i, row) in weights.iter().enumerate() {
        for (j, &weight) in row.iter().enumerate() {
            total_sum += weight;
            // `abs_diff` stays in `usize`, avoiding the lossy `as i32` casts.
            if i.abs_diff(j) <= WINDOW_SIZE {
                diagonal_sum += weight;
            }
        }
    }

    if total_sum > 0.0 {
        diagonal_sum / total_sum
    } else {
        0.0
    }
}
/// Scores how close the weights are to a uniform distribution.
///
/// The score is `1 - mean(|w - 1/len|)`, where `len` is the length of the
/// row containing `w`. Using the row length (rather than the number of rows)
/// keeps the expected weight correct for non-square matrices; for square
/// matrices the result is unchanged.
///
/// Returns `0.0` when there are no entries.
///
/// NOTE(review): if each row sums to 1, the mean absolute deviation is at
/// most `2/len`, so this score approaches 1 for long rows regardless of their
/// shape — confirm the 0.8 threshold used by the caller is still meaningful.
fn measure_uniformity(&self, weights: &[Vec<f32>]) -> f32 {
    let mut deviation_sum = 0.0_f32;
    let mut count = 0_usize;

    for row in weights {
        if row.is_empty() {
            continue;
        }
        // A uniform row spreads its mass evenly over its own entries.
        let expected_weight = 1.0 / row.len() as f32;
        for &weight in row {
            deviation_sum += (weight - expected_weight).abs();
        }
        count += row.len();
    }

    if count > 0 {
        1.0 - (deviation_sum / count as f32)
    } else {
        0.0
    }
}
/// Reports whether one diagonal block is much denser than the others.
///
/// The rows are cut into consecutive blocks of `rows / 4` and the mean weight
/// of each square block on the diagonal is computed. The result is `true`
/// when the largest block mean exceeds twice the average block mean.
/// Matrices with fewer than four rows never qualify.
fn has_block_structure(&self, weights: &[Vec<f32>]) -> bool {
    let rows = weights.len();
    if rows < 4 {
        return false;
    }

    let block_size = rows / 4;
    let mut block_means: Vec<f32> = Vec::new();
    let mut start = 0;
    while start < rows {
        let end = (start + block_size).min(rows);
        let mut sum = 0.0;
        let mut cells = 0;
        for row in &weights[start..end] {
            // Columns `start..end`, truncated to the row's actual length.
            for &w in row.iter().take(end).skip(start) {
                sum += w;
                cells += 1;
            }
        }
        if cells > 0 {
            block_means.push(sum / cells as f32);
        }
        start += block_size;
    }

    if block_means.len() < 2 {
        return false;
    }

    let peak = block_means.iter().cloned().fold(0.0f32, f32::max);
    let average = block_means.iter().sum::<f32>() / block_means.len() as f32;
    peak > average * 2.0
}
/// Classifies a head from its diagonal, long-range and positional scores.
///
/// Non-2-D input is reported as `Mixed`. Otherwise the first matching rule
/// wins: diagonal strength above 0.7 gives `LocalSyntax`, long-range strength
/// above 0.6 gives `LongRange`, positional bias above 0.8 gives `Positional`,
/// and everything else is `ContentBased`.
fn classify_head_specialization(
    &self,
    weights: &ArrayD<f32>,
) -> Result<HeadSpecializationType> {
    let (rows, cols) = match weights.shape() {
        &[rows, cols] => (rows, cols),
        _ => return Ok(HeadSpecializationType::Mixed),
    };

    let grid: Vec<Vec<f32>> = (0..rows)
        .map(|r| (0..cols).map(|c| weights[[r, c]]).collect())
        .collect();

    let diagonal = self.measure_diagonal_strength(&grid);
    let long_range = self.measure_long_range_attention(&grid);
    let positional = self.measure_positional_bias(&grid);

    let specialization = if diagonal > 0.7 {
        HeadSpecializationType::LocalSyntax
    } else if long_range > 0.6 {
        HeadSpecializationType::LongRange
    } else if positional > 0.8 {
        HeadSpecializationType::Positional
    } else {
        HeadSpecializationType::ContentBased
    };
    Ok(specialization)
}
/// Returns the share of total weight placed on positions further than a
/// quarter of the row count away from the diagonal.
///
/// Returns `0.0` for matrices with fewer than four rows or when the total
/// weight is not positive.
fn measure_long_range_attention(&self, weights: &[Vec<f32>]) -> f32 {
    let seq_len = weights.len();
    if seq_len < 4 {
        return 0.0;
    }

    let long_range_threshold = seq_len / 4;
    let mut long_range_sum = 0.0_f32;
    let mut total_sum = 0.0_f32;
    for (i, row) in weights.iter().enumerate() {
        for (j, &weight) in row.iter().enumerate() {
            total_sum += weight;
            // `abs_diff` stays in `usize`, avoiding the lossy `as i32` casts.
            if i.abs_diff(j) > long_range_threshold {
                long_range_sum += weight;
            }
        }
    }

    if total_sum > 0.0 {
        long_range_sum / total_sum
    } else {
        0.0
    }
}
/// Averages `weight * (1 - |row - col| / rows)` over the entries visited.
///
/// Only the first `rows` columns of each row are considered. Returns `0.0`
/// when no entries are visited.
///
/// NOTE(review): the sum is divided by the number of entries, not by the
/// total weight. If rows sum to 1 the result is bounded by roughly `1/rows`,
/// so the caller's 0.8 threshold looks unreachable for longer sequences —
/// confirm the intended normalisation.
fn measure_positional_bias(&self, weights: &[Vec<f32>]) -> f32 {
    let rows = weights.len();
    if rows == 0 {
        return 0.0;
    }

    let mut weighted_similarity = 0.0;
    let mut visited = 0;
    for (r, row) in weights.iter().enumerate() {
        for (c, &w) in row.iter().enumerate().take(rows) {
            let closeness = 1.0 - (r as f32 - c as f32).abs() / rows as f32;
            weighted_similarity += w * closeness;
            visited += 1;
        }
    }

    if visited > 0 {
        weighted_similarity / visited as f32
    } else {
        0.0
    }
}
/// Computes summary statistics over every element of `weights`.
///
/// An empty array produces a distribution with all fields set to `0.0`.
/// The standard deviation is the population form (divides by the count).
fn compute_attention_distribution(
    &self,
    weights: &ArrayD<f32>,
) -> Result<AttentionDistribution> {
    let flat: Vec<f32> = weights.iter().cloned().collect();
    if flat.is_empty() {
        return Ok(AttentionDistribution {
            mean_attention: 0.0,
            std_attention: 0.0,
            max_attention: 0.0,
            min_attention: 0.0,
            entropy: 0.0,
            effective_context_length: 0.0,
        });
    }

    let element_count = flat.len() as f32;
    let mean_attention = flat.iter().sum::<f32>() / element_count;
    let squared_error_sum = flat.iter().map(|x| (x - mean_attention).powi(2)).sum::<f32>();
    let std_attention = (squared_error_sum / element_count).sqrt();
    let max_attention = flat.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let min_attention = flat.iter().cloned().fold(f32::INFINITY, f32::min);

    Ok(AttentionDistribution {
        mean_attention,
        std_attention,
        max_attention,
        min_attention,
        entropy: self.compute_entropy(&flat),
        effective_context_length: self.compute_effective_context_length(&flat),
    })
}
/// Base-2 Shannon entropy of `values` after dividing each by their sum.
///
/// Non-positive values contribute nothing. Returns `0.0` for empty input or
/// when the sum is not positive.
fn compute_entropy(&self, values: &[f32]) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    let total: f32 = values.iter().sum();
    if total <= 0.0 {
        return 0.0;
    }

    values
        .iter()
        .filter(|&&v| v > 0.0)
        .fold(0.0, |entropy, &v| {
            let p = v / total;
            entropy - p * p.log2()
        })
}
/// Returns how many of the largest values are needed to reach 90% of the
/// total sum.
///
/// Returns `0.0` for empty input or when the sum is not positive. If the
/// running total never reaches the target (for example when the input
/// contains NaN), the number of values is returned.
fn compute_effective_context_length(&self, values: &[f32]) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    let sum: f32 = values.iter().sum();
    if sum <= 0.0 {
        return 0.0;
    }

    // `total_cmp` is a total order even with NaN present. A comparator built
    // from `partial_cmp(..).unwrap_or(Equal)` is not, and `sort_by` may panic
    // when given an inconsistent ordering.
    let mut sorted_values = values.to_vec();
    sorted_values.sort_unstable_by(|a, b| b.total_cmp(a));

    let target_sum = sum * 0.9;
    let mut cumulative_sum = 0.0;
    for (i, &value) in sorted_values.iter().enumerate() {
        cumulative_sum += value;
        if cumulative_sum >= target_sum {
            return (i + 1) as f32;
        }
    }
    values.len() as f32
}
/// Array-based front end for `detect_attention_pattern`.
///
/// Non-2-D input is reported as `Random`; otherwise the array is copied into
/// nested vectors and handed to the slice-based detector.
fn detect_attention_pattern_from_weights(
    &self,
    weights: &ArrayD<f32>,
) -> Result<AttentionPattern> {
    let (rows, cols) = match weights.shape() {
        &[rows, cols] => (rows, cols),
        _ => return Ok(AttentionPattern::Random),
    };

    let grid: Vec<Vec<f32>> = (0..rows)
        .map(|r| (0..cols).map(|c| weights[[r, c]]).collect())
        .collect();
    Ok(self.detect_attention_pattern(&grid))
}
/// Returns the entropy of all weights divided by `log2(element count)`.
///
/// Returns `0.0` for an empty array or when the maximum entropy is not
/// positive (a single element).
fn compute_head_importance(&self, weights: &ArrayD<f32>) -> Result<f32> {
    let flat: Vec<f32> = weights.iter().cloned().collect();
    if flat.is_empty() {
        return Ok(0.0);
    }

    let max_entropy = (flat.len() as f32).log2();
    let entropy = self.compute_entropy(&flat);
    let normalised = if max_entropy > 0.0 {
        entropy / max_entropy
    } else {
        0.0
    };
    Ok(normalised)
}
/// Entropy of the whole matrix, treating all rows as one flat sequence.
fn compute_attention_entropy(&self, weights: &[Vec<f32>]) -> f32 {
    let flat: Vec<f32> = weights.concat();
    self.compute_entropy(&flat)
}
/// Fraction of entries strictly below `config.attention_threshold`.
///
/// Returns `0.0` when there are no entries.
fn compute_sparsity_ratio(&self, weights: &[Vec<f32>]) -> f32 {
    let threshold = self.config.attention_threshold;
    let mut total = 0_usize;
    let mut below = 0_usize;
    for row in weights {
        total += row.len();
        below += row.iter().filter(|&&w| w < threshold).count();
    }

    if total == 0 {
        0.0
    } else {
        below as f32 / total as f32
    }
}
fn compute_layer_diversity(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
if head_analyses.len() < 2 {
return 0.0;
}
let mut specialization_counts: HashMap<HeadSpecializationType, usize> = HashMap::new();
for analysis in head_analyses {
*specialization_counts.entry(analysis.specialization_type.clone()).or_insert(0) += 1;
}
let num_types = specialization_counts.len() as f32;
let max_types = 8.0;
num_types / max_types
}
/// Finds pairs and groups of heads whose analyses are nearly identical.
///
/// Two heads form a redundant pair when their similarity exceeds 0.8. Pairs
/// are reported as `(i, j, similarity)` with `i < j`, where both are
/// positions in `head_analyses`.
///
/// `redundancy_groups` holds the connected components of those pairs: heads
/// linked directly or through a chain of redundant pairs share a group. Only
/// groups with at least two members are reported. Members are in ascending
/// order and groups are ordered by their smallest member.
fn analyze_head_redundancy(
    &self,
    head_analyses: &[AttentionHeadAnalysis],
) -> RedundancyAnalysis {
    /// Follows parent links to the representative of `node`'s component,
    /// shortening the path on the way.
    fn find_root(parent: &mut [usize], mut node: usize) -> usize {
        while parent[node] != node {
            let grandparent = parent[parent[node]];
            parent[node] = grandparent;
            node = grandparent;
        }
        node
    }

    let head_count = head_analyses.len();
    let mut redundant_head_pairs = Vec::new();
    // Disjoint-set forest: each head starts as its own component.
    let mut parent: Vec<usize> = (0..head_count).collect();

    for i in 0..head_count {
        for j in (i + 1)..head_count {
            let similarity = self.compute_head_similarity(&head_analyses[i], &head_analyses[j]);
            if similarity > 0.8 {
                redundant_head_pairs.push((i, j, similarity));
                let root_i = find_root(&mut parent, i);
                let root_j = find_root(&mut parent, j);
                if root_i != root_j {
                    // Attach the larger root to the smaller one so each
                    // component is represented by its smallest member.
                    parent[root_i.max(root_j)] = root_i.min(root_j);
                }
            }
        }
    }

    let mut members_by_root: Vec<Vec<usize>> = vec![Vec::new(); head_count];
    for head in 0..head_count {
        let root = find_root(&mut parent, head);
        members_by_root[root].push(head);
    }
    let redundancy_groups: Vec<Vec<usize>> = members_by_root
        .into_iter()
        .filter(|members| members.len() > 1)
        .collect();

    RedundancyAnalysis {
        redundant_head_pairs,
        redundancy_groups,
        overall_redundancy_score: self.compute_overall_redundancy(head_analyses),
    }
}
fn compute_head_similarity(
&self,
head1: &AttentionHeadAnalysis,
head2: &AttentionHeadAnalysis,
) -> f32 {
let type_similarity =
if head1.specialization_type == head2.specialization_type { 1.0 } else { 0.0 };
let dist_similarity = {
let d1 = &head1.attention_distribution;
let d2 = &head2.attention_distribution;
let mean_diff = (d1.mean_attention - d2.mean_attention).abs();
let std_diff = (d1.std_attention - d2.std_attention).abs();
let entropy_diff = (d1.entropy - d2.entropy).abs();
1.0 - (mean_diff + std_diff + entropy_diff) / 3.0
};
(type_similarity + dist_similarity) / 2.0
}
/// Mean pairwise similarity over all unordered pairs of heads.
///
/// Returns `0.0` when fewer than two heads are supplied.
fn compute_overall_redundancy(&self, head_analyses: &[AttentionHeadAnalysis]) -> f32 {
    if head_analyses.len() < 2 {
        return 0.0;
    }

    let mut similarity_sum = 0.0;
    let mut pairs = 0;
    for (i, first) in head_analyses.iter().enumerate() {
        for second in &head_analyses[i + 1..] {
            similarity_sum += self.compute_head_similarity(first, second);
            pairs += 1;
        }
    }

    if pairs > 0 {
        similarity_sum / pairs as f32
    } else {
        0.0
    }
}
}
/// Result of analysing all heads of one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerAttentionAnalysis {
    /// Index of the analysed layer.
    pub layer_index: usize,
    /// Number of weight arrays supplied for the layer (not capped by the head limit).
    pub num_heads: usize,
    /// One entry per head that was actually analysed.
    pub head_analyses: Vec<AttentionHeadAnalysis>,
    /// One map per head that was actually analysed.
    pub attention_maps: Vec<AttentionMap>,
    /// Distinct specialization types among the heads, divided by 8.
    pub layer_diversity_score: f32,
    /// Pairwise redundancy findings for the layer's heads.
    pub redundancy_analysis: RedundancyAnalysis,
}
/// Redundancy findings for the heads of one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedundancyAnalysis {
    /// `(first, second, similarity)` for every pair whose similarity exceeds
    /// the redundancy threshold; both indices are positions in the analysed
    /// head list.
    pub redundant_head_pairs: Vec<(usize, usize, f32)>,
    /// Groups of head positions considered mutually redundant.
    pub redundancy_groups: Vec<Vec<usize>>,
    /// Mean pairwise similarity over all heads of the layer.
    pub overall_redundancy_score: f32,
}
/// Runs attention analysis across all layers of a transformer.
#[derive(Debug)]
pub struct TransformerDebugger {
    /// Settings for layer limits and cross-layer analysis.
    pub config: TransformerDebugConfig,
    /// Layer results from the most recent `analyze_transformer_attention` call.
    layer_analyses: Vec<LayerAttentionAnalysis>,
    /// Per-layer analyser; built from `config.attention_config`.
    attention_debugger: AttentionDebugger,
}
/// Configuration for [`TransformerDebugger`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerDebugConfig {
    /// Configuration handed to the inner attention debugger.
    pub attention_config: AttentionDebugConfig,
    // NOTE(review): `enable_layer_analysis` is not read by any code in this
    // file — confirm whether it is consumed elsewhere.
    pub enable_layer_analysis: bool,
    /// When `false`, the cross-layer analysis is skipped and reported as `None`.
    pub enable_cross_layer_analysis: bool,
    /// Maximum number of layers analysed; later layers are skipped.
    pub max_layers_to_analyze: usize,
}
impl Default for TransformerDebugConfig {
fn default() -> Self {
Self {
attention_config: AttentionDebugConfig::default(),
enable_layer_analysis: true,
enable_cross_layer_analysis: true,
max_layers_to_analyze: 48, }
}
}
impl TransformerDebugger {
/// Creates a transformer debugger whose inner attention debugger uses a
/// copy of `config.attention_config`.
pub fn new(config: TransformerDebugConfig) -> Self {
    Self {
        // Built first, while `config` is still available to clone from.
        attention_debugger: AttentionDebugger::new(config.attention_config.clone()),
        layer_analyses: Vec::new(),
        config,
    }
}
/// Analyses the attention weights of a whole model.
///
/// `model_attention_weights` holds one vector of per-head arrays per layer.
/// Only the first `config.max_layers_to_analyze` layers are processed, but
/// `num_layers` in the result reports the full slice length.
///
/// The per-layer results replace the ones stored on `self`. The cross-layer
/// analysis is included only when `config.enable_cross_layer_analysis` is set.
///
/// # Errors
///
/// Returns the first error raised by a layer analysis, the cross-layer
/// analysis, or the summary. If a layer fails, the stored layer results are
/// left untouched.
pub fn analyze_transformer_attention(
    &mut self,
    model_attention_weights: &[Vec<ArrayD<f32>>],
) -> Result<TransformerAttentionAnalysis> {
    let layer_limit = self.config.max_layers_to_analyze;
    let mut layer_analyses = Vec::new();
    for (layer_idx, layer_weights) in model_attention_weights.iter().enumerate().take(layer_limit)
    {
        let analysis = self
            .attention_debugger
            .analyze_attention_layer(layer_idx, layer_weights)?;
        layer_analyses.push(analysis);
    }

    // The summary below reads from `self.layer_analyses`, so store a copy first.
    self.layer_analyses = layer_analyses.clone();

    let cross_layer_analysis = self
        .config
        .enable_cross_layer_analysis
        .then(|| self.perform_cross_layer_analysis(&layer_analyses))
        .transpose()?;
    let model_attention_summary = self.generate_model_attention_summary()?;

    Ok(TransformerAttentionAnalysis {
        num_layers: model_attention_weights.len(),
        layer_analyses,
        cross_layer_analysis,
        model_attention_summary,
    })
}
/// Bundles the evolution, consistency, progression and diversity-trend
/// analyses computed over the given layers.
fn perform_cross_layer_analysis(
    &self,
    layer_analyses: &[LayerAttentionAnalysis],
) -> Result<CrossLayerAnalysis> {
    let evolution = self.analyze_attention_evolution(layer_analyses)?;
    let consistency = self.analyze_head_consistency(layer_analyses)?;
    let progression = self.analyze_pattern_progression(layer_analyses)?;
    let diversity_trend = self.compute_layer_diversity_trend(layer_analyses);

    Ok(CrossLayerAnalysis {
        attention_evolution: evolution,
        head_consistency: consistency,
        pattern_progression: progression,
        layer_diversity_trend: diversity_trend,
    })
}
/// Tracks mean entropy and mean sparsity of the attention maps per layer.
///
/// Both trends contain one value per layer, in layer order. A layer without
/// any attention maps contributes `0.0` to both trends; dividing by its zero
/// map count would otherwise yield NaN. The evolution type is derived from
/// the entropy trend.
fn analyze_attention_evolution(
    &self,
    layer_analyses: &[LayerAttentionAnalysis],
) -> Result<AttentionEvolution> {
    let mut entropy_trend = Vec::with_capacity(layer_analyses.len());
    let mut sparsity_trend = Vec::with_capacity(layer_analyses.len());

    for layer in layer_analyses {
        let maps = &layer.attention_maps;
        if maps.is_empty() {
            // No maps to average; keep the trends aligned with the layers.
            entropy_trend.push(0.0);
            sparsity_trend.push(0.0);
            continue;
        }

        let map_count = maps.len() as f32;
        let layer_entropy = maps.iter().map(|map| map.attention_entropy).sum::<f32>() / map_count;
        let layer_sparsity = maps.iter().map(|map| map.sparsity_ratio).sum::<f32>() / map_count;
        entropy_trend.push(layer_entropy);
        sparsity_trend.push(layer_sparsity);
    }

    let evolution_type = self.classify_evolution_type(&entropy_trend);
    Ok(AttentionEvolution {
        entropy_trend,
        sparsity_trend,
        evolution_type,
    })
}
/// Classifies the entropy trend by comparing its last value to its first.
///
/// The relative change `(last - first) / max(first, 1e-8)` maps to
/// `Increasing` above 0.2 and `Decreasing` below -0.2. Everything else, and
/// any trend with fewer than three points, is `Stable`.
fn classify_evolution_type(&self, entropy_trend: &[f32]) -> EvolutionType {
    let (first, last) = match entropy_trend {
        // Matches only slices with at least three elements.
        [first, _, .., last] => (*first, *last),
        _ => return EvolutionType::Stable,
    };

    let relative_change = (last - first) / first.max(1e-8);
    if relative_change > 0.2 {
        EvolutionType::Increasing
    } else if relative_change < -0.2 {
        EvolutionType::Decreasing
    } else {
        EvolutionType::Stable
    }
}
fn analyze_head_consistency(
&self,
layer_analyses: &[LayerAttentionAnalysis],
) -> Result<HeadConsistency> {
let mut specialization_consistency = HashMap::new();
let pattern_consistency = HashMap::new();
for layer in layer_analyses {
for head in &layer.head_analyses {
let spec_type = &head.specialization_type;
let layer_counts =
specialization_consistency.entry(spec_type.clone()).or_insert_with(Vec::new);
layer_counts.push(layer.layer_index);
}
}
Ok(HeadConsistency {
specialization_consistency,
pattern_consistency,
consistency_score: self.compute_consistency_score(layer_analyses),
})
}
/// Mean pairwise similarity between the layers' specialization mixes.
///
/// Each layer is reduced to the share of its heads per specialization type;
/// every unordered pair of layers is then compared with
/// `compute_distribution_similarity`. Returns `1.0` when fewer than two
/// layers are supplied.
fn compute_consistency_score(&self, layer_analyses: &[LayerAttentionAnalysis]) -> f32 {
    if layer_analyses.len() < 2 {
        return 1.0;
    }

    let layer_distributions: Vec<HashMap<HeadSpecializationType, f32>> = layer_analyses
        .iter()
        .map(|layer| {
            let mut shares: HashMap<HeadSpecializationType, f32> = HashMap::new();
            for head in &layer.head_analyses {
                *shares.entry(head.specialization_type.clone()).or_insert(0.0) += 1.0;
            }
            // Turn raw counts into fractions of the layer's heads.
            let total: f32 = shares.values().sum();
            if total > 0.0 {
                shares.values_mut().for_each(|share| *share /= total);
            }
            shares
        })
        .collect();

    let mut similarity_sum = 0.0;
    let mut pairs = 0;
    for (i, first) in layer_distributions.iter().enumerate() {
        for second in &layer_distributions[i + 1..] {
            similarity_sum += self.compute_distribution_similarity(first, second);
            pairs += 1;
        }
    }

    if pairs > 0 {
        similarity_sum / pairs as f32
    } else {
        1.0
    }
}
/// Returns one minus half the summed absolute difference between two
/// distributions, taken over the union of their keys. A key missing from
/// one side counts as `0.0` there.
fn compute_distribution_similarity(
    &self,
    dist1: &HashMap<HeadSpecializationType, f32>,
    dist2: &HashMap<HeadSpecializationType, f32>,
) -> f32 {
    let key_union: std::collections::HashSet<&HeadSpecializationType> =
        dist1.keys().chain(dist2.keys()).collect();

    let mut absolute_difference = 0.0;
    for key in key_union {
        let left = dist1.get(key).copied().unwrap_or(0.0);
        let right = dist2.get(key).copied().unwrap_or(0.0);
        absolute_difference += (left - right).abs();
    }
    1.0 - (absolute_difference / 2.0)
}
/// Counts attention patterns per layer and extracts each layer's dominant
/// pattern.
fn analyze_pattern_progression(
    &self,
    layer_analyses: &[LayerAttentionAnalysis],
) -> Result<PatternProgression> {
    let pattern_evolution: Vec<HashMap<AttentionPattern, usize>> = layer_analyses
        .iter()
        .map(|layer| {
            let mut counts: HashMap<AttentionPattern, usize> = HashMap::new();
            for map in &layer.attention_maps {
                *counts.entry(map.attention_pattern.clone()).or_insert(0) += 1;
            }
            counts
        })
        .collect();

    let dominant_pattern_sequence = self.extract_dominant_patterns(&pattern_evolution);
    Ok(PatternProgression {
        pattern_evolution,
        dominant_pattern_sequence,
    })
}
/// Picks the most frequent pattern of every layer.
///
/// When several patterns share the highest count, the one declared earliest
/// in [`AttentionPattern`] wins. This keeps the result deterministic instead
/// of depending on `HashMap` iteration order. A layer without any patterns
/// yields `Random`.
fn extract_dominant_patterns(
    &self,
    pattern_evolution: &[HashMap<AttentionPattern, usize>],
) -> Vec<AttentionPattern> {
    /// Fixed tie-break priority; lower wins. The match is exhaustive, so
    /// adding a variant forces this ranking to be updated.
    fn tie_break_rank(pattern: &AttentionPattern) -> u8 {
        match pattern {
            AttentionPattern::Diagonal => 0,
            AttentionPattern::Block => 1,
            AttentionPattern::Sparse => 2,
            AttentionPattern::Uniform => 3,
            AttentionPattern::Concentrated => 4,
            AttentionPattern::Strided => 5,
            AttentionPattern::Random => 6,
        }
    }

    pattern_evolution
        .iter()
        .map(|patterns| {
            patterns
                .iter()
                // Highest count first; on equal counts the lowest rank wins.
                .max_by_key(|(pattern, &count)| {
                    (count, std::cmp::Reverse(tie_break_rank(pattern)))
                })
                .map(|(pattern, _)| pattern.clone())
                .unwrap_or(AttentionPattern::Random)
        })
        .collect()
}
/// Collects each layer's diversity score, in layer order.
fn compute_layer_diversity_trend(&self, layer_analyses: &[LayerAttentionAnalysis]) -> Vec<f32> {
    let mut trend = Vec::with_capacity(layer_analyses.len());
    for layer in layer_analyses {
        trend.push(layer.layer_diversity_score);
    }
    trend
}
/// Summarises the stored layer analyses.
///
/// The summary holds head totals, mean diversity and redundancy across
/// layers, a count of heads per specialization type, and a health rating.
/// Returns the default summary when no layers are stored.
fn generate_model_attention_summary(&self) -> Result<ModelAttentionSummary> {
    let layers = &self.layer_analyses;
    if layers.is_empty() {
        return Ok(ModelAttentionSummary::default());
    }

    let layer_count = layers.len() as f32;
    let total_heads: usize = layers.iter().map(|layer| layer.num_heads).sum();
    let average_diversity_score =
        layers.iter().map(|layer| layer.layer_diversity_score).sum::<f32>() / layer_count;
    let average_redundancy_score = layers
        .iter()
        .map(|layer| layer.redundancy_analysis.overall_redundancy_score)
        .sum::<f32>()
        / layer_count;

    let mut specialization_distribution: HashMap<HeadSpecializationType, usize> = HashMap::new();
    for head in layers.iter().flat_map(|layer| layer.head_analyses.iter()) {
        *specialization_distribution
            .entry(head.specialization_type.clone())
            .or_insert(0) += 1;
    }

    let model_attention_health =
        self.assess_model_attention_health(average_diversity_score, average_redundancy_score);

    Ok(ModelAttentionSummary {
        total_layers: layers.len(),
        total_heads,
        average_diversity_score,
        average_redundancy_score,
        specialization_distribution,
        model_attention_health,
    })
}
/// Rates model health from `diversity * (1 - redundancy)`.
///
/// Scores above 0.7 are `Excellent`, above 0.5 `Good`, above 0.3 `Fair`;
/// everything else is `Poor`.
fn assess_model_attention_health(
    &self,
    diversity: f32,
    redundancy: f32,
) -> AttentionHealthStatus {
    let score = diversity * (1.0 - redundancy);
    if score > 0.7 {
        AttentionHealthStatus::Excellent
    } else if score > 0.5 {
        AttentionHealthStatus::Good
    } else if score > 0.3 {
        AttentionHealthStatus::Fair
    } else {
        AttentionHealthStatus::Poor
    }
}
}
/// Complete result of a whole-model attention analysis.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransformerAttentionAnalysis {
    /// Number of layers supplied by the caller (not capped by the layer limit).
    pub num_layers: usize,
    /// One entry per layer that was actually analysed.
    pub layer_analyses: Vec<LayerAttentionAnalysis>,
    /// Present only when cross-layer analysis is enabled in the config.
    pub cross_layer_analysis: Option<CrossLayerAnalysis>,
    /// Aggregated statistics over the analysed layers.
    pub model_attention_summary: ModelAttentionSummary,
}
/// Analyses that compare attention behaviour across layers.
#[derive(Debug, Serialize, Deserialize)]
pub struct CrossLayerAnalysis {
    /// Per-layer entropy and sparsity trends.
    pub attention_evolution: AttentionEvolution,
    /// Where specialization types occur and how consistent the layers are.
    pub head_consistency: HeadConsistency,
    /// Per-layer pattern counts and dominant patterns.
    pub pattern_progression: PatternProgression,
    /// Diversity score of each layer, in layer order.
    pub layer_diversity_trend: Vec<f32>,
}
/// How attention entropy and sparsity change from layer to layer.
#[derive(Debug, Serialize, Deserialize)]
pub struct AttentionEvolution {
    /// Mean attention-map entropy per layer.
    pub entropy_trend: Vec<f32>,
    /// Mean attention-map sparsity ratio per layer.
    pub sparsity_trend: Vec<f32>,
    /// Classification of the entropy trend.
    pub evolution_type: EvolutionType,
}
/// Direction of the entropy trend between the first and last layer.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EvolutionType {
    /// Relative change above +0.2.
    Increasing,
    /// Relative change below -0.2.
    Decreasing,
    /// Anything in between, or a trend with fewer than three points.
    Stable,
}
/// Where head specializations and patterns occur across layers.
#[derive(Debug, Serialize, Deserialize)]
pub struct HeadConsistency {
    /// Layer indices per specialization type, one entry per occurrence.
    pub specialization_consistency: HashMap<HeadSpecializationType, Vec<usize>>,
    /// Layer indices per attention pattern.
    pub pattern_consistency: HashMap<AttentionPattern, Vec<usize>>,
    /// Mean pairwise similarity of the layers' specialization mixes.
    pub consistency_score: f32,
}
/// Attention-pattern statistics per layer.
#[derive(Debug, Serialize, Deserialize)]
pub struct PatternProgression {
    /// Pattern counts for each layer, in layer order.
    pub pattern_evolution: Vec<HashMap<AttentionPattern, usize>>,
    /// Most frequent pattern of each layer, in layer order.
    pub dominant_pattern_sequence: Vec<AttentionPattern>,
}
/// Aggregated attention statistics for a whole model.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelAttentionSummary {
    /// Number of analysed layers.
    pub total_layers: usize,
    /// Sum of the layers' `num_heads` values.
    pub total_heads: usize,
    /// Mean layer diversity score.
    pub average_diversity_score: f32,
    /// Mean layer redundancy score.
    pub average_redundancy_score: f32,
    /// Number of analysed heads per specialization type.
    pub specialization_distribution: HashMap<HeadSpecializationType, usize>,
    /// Health rating derived from diversity and redundancy.
    pub model_attention_health: AttentionHealthStatus,
}
impl Default for ModelAttentionSummary {
fn default() -> Self {
Self {
total_layers: 0,
total_heads: 0,
average_diversity_score: 0.0,
average_redundancy_score: 0.0,
specialization_distribution: HashMap::new(),
model_attention_health: AttentionHealthStatus::Poor,
}
}
}
/// Health rating derived from `diversity * (1 - redundancy)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AttentionHealthStatus {
    /// Score above 0.7.
    Excellent,
    /// Score above 0.5.
    Good,
    /// Score above 0.3.
    Fair,
    /// Any other score; also the default for an empty summary.
    Poor,
}
/// Analyses one layer's attention weights with a default-configured
/// `AttentionDebugger`, using layer index 0.
///
/// Expands to the `Result` returned by `analyze_attention_layer`.
// NOTE(review): the expansion assumes these types are reachable at
// `$crate::neural_network_debugging` — confirm against the crate layout.
#[macro_export]
macro_rules! debug_attention {
    ($attention_weights:expr) => {{
        let mut debugger = $crate::neural_network_debugging::AttentionDebugger::new(
            $crate::neural_network_debugging::AttentionDebugConfig::default(),
        );
        debugger.analyze_attention_layer(0, $attention_weights)
    }};
}
/// Analyses a whole model's attention weights with a default-configured
/// `TransformerDebugger`.
///
/// Expands to the `Result` returned by `analyze_transformer_attention`.
// NOTE(review): the expansion assumes these types are reachable at
// `$crate::neural_network_debugging` — confirm against the crate layout.
#[macro_export]
macro_rules! debug_transformer {
    ($model_weights:expr) => {{
        let mut debugger = $crate::neural_network_debugging::TransformerDebugger::new(
            $crate::neural_network_debugging::TransformerDebugConfig::default(),
        );
        debugger.analyze_transformer_attention($model_weights)
    }};
}
// Unit tests for the attention debugger: configuration defaults, the
// individual measures, pattern detection, and construction of the result types.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{ArrayD, IxDyn};
    // Default configuration shared by most tests.
    fn make_attention_config() -> AttentionDebugConfig {
        AttentionDebugConfig::default()
    }
    // Square matrix where every entry equals `1 / seq_len`.
    fn make_uniform_weights(seq_len: usize) -> ArrayD<f32> {
        let val = 1.0 / seq_len as f32;
        ArrayD::from_elem(IxDyn(&[seq_len, seq_len]), val)
    }
    // Identity-like matrix: 1.0 on the main diagonal, 0.0 elsewhere.
    fn make_diagonal_weights(seq_len: usize) -> ArrayD<f32> {
        let mut weights = ArrayD::zeros(IxDyn(&[seq_len, seq_len]));
        for i in 0..seq_len {
            weights[[i, i]] = 1.0;
        }
        weights
    }
    // Every row puts all of its weight on column 0.
    fn make_sparse_weights(seq_len: usize) -> ArrayD<f32> {
        let mut weights = ArrayD::zeros(IxDyn(&[seq_len, seq_len]));
        for i in 0..seq_len {
            weights[[i, 0]] = 1.0;
        }
        weights
    }
    #[test]
    fn test_attention_debug_config_default() {
        let config = make_attention_config();
        assert!(config.enable_attention_visualization);
        assert!(config.enable_head_analysis);
        assert_eq!(config.max_heads_to_analyze, 16);
    }
    #[test]
    fn test_attention_debugger_creation() {
        let debugger = AttentionDebugger::new(make_attention_config());
        assert!(debugger.attention_maps.is_empty());
        assert!(debugger.head_analysis.is_empty());
    }
    #[test]
    fn test_analyze_attention_layer_single_head() {
        let mut debugger = AttentionDebugger::new(make_attention_config());
        let weights = vec![make_uniform_weights(8)];
        let result = debugger.analyze_attention_layer(0, &weights);
        assert!(result.is_ok());
        let analysis = result.expect("analysis should succeed");
        assert_eq!(analysis.layer_index, 0);
        assert_eq!(analysis.num_heads, 1);
        assert_eq!(analysis.head_analyses.len(), 1);
    }
    #[test]
    fn test_analyze_attention_layer_multiple_heads() {
        let mut debugger = AttentionDebugger::new(make_attention_config());
        let weights = vec![
            make_uniform_weights(8),
            make_diagonal_weights(8),
            make_sparse_weights(8),
        ];
        let result = debugger.analyze_attention_layer(0, &weights);
        assert!(result.is_ok());
        let analysis = result.expect("analysis should succeed");
        assert_eq!(analysis.num_heads, 3);
        assert_eq!(analysis.head_analyses.len(), 3);
    }
    #[test]
    fn test_detect_attention_pattern_uniform() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let seq_len = 10;
        let val = 1.0 / seq_len as f32;
        let weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![val; seq_len]).collect();
        let pattern = debugger.detect_attention_pattern(&weights);
        assert!(matches!(pattern, AttentionPattern::Uniform));
    }
    #[test]
    fn test_detect_attention_pattern_diagonal() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let seq_len = 10;
        // Tridiagonal band: weight 1.0 wherever |i - j| <= 1.
        let weights: Vec<Vec<f32>> = (0..seq_len)
            .map(|i| {
                let mut row = vec![0.0; seq_len];
                for j in 0..seq_len {
                    if (i as i32 - j as i32).abs() <= 1 {
                        row[j] = 1.0;
                    }
                }
                row
            })
            .collect();
        let pattern = debugger.detect_attention_pattern(&weights);
        assert!(matches!(pattern, AttentionPattern::Diagonal));
    }
    #[test]
    fn test_detect_attention_pattern_empty() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let weights: Vec<Vec<f32>> = vec![];
        let pattern = debugger.detect_attention_pattern(&weights);
        assert!(matches!(pattern, AttentionPattern::Random));
    }
    #[test]
    fn test_measure_diagonal_strength() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let seq_len = 8;
        // Dominant diagonal (10.0) over a small background (0.01).
        let weights: Vec<Vec<f32>> = (0..seq_len)
            .map(|i| {
                let mut row = vec![0.01; seq_len];
                row[i] = 10.0;
                row
            })
            .collect();
        let strength = debugger.measure_diagonal_strength(&weights);
        assert!(strength > 0.5);
    }
    #[test]
    fn test_measure_diagonal_strength_empty() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let weights: Vec<Vec<f32>> = vec![];
        assert!((debugger.measure_diagonal_strength(&weights) - 0.0).abs() < f32::EPSILON);
    }
    #[test]
    fn test_measure_uniformity() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let seq_len = 8;
        let val = 1.0 / seq_len as f32;
        let weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![val; seq_len]).collect();
        let uniformity = debugger.measure_uniformity(&weights);
        assert!(uniformity > 0.9);
    }
    #[test]
    fn test_measure_uniformity_empty() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let weights: Vec<Vec<f32>> = vec![];
        assert!((debugger.measure_uniformity(&weights) - 0.0).abs() < f32::EPSILON);
    }
    #[test]
    fn test_has_block_structure_false() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let seq_len = 8;
        let val = 1.0 / seq_len as f32;
        let weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![val; seq_len]).collect();
        assert!(!debugger.has_block_structure(&weights));
    }
    #[test]
    fn test_has_block_structure_small() {
        let debugger = AttentionDebugger::new(make_attention_config());
        // Fewer than four rows never counts as block-structured.
        let weights: Vec<Vec<f32>> = vec![vec![1.0], vec![1.0]];
        assert!(!debugger.has_block_structure(&weights));
    }
    #[test]
    fn test_classify_head_specialization_local() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let weights = make_diagonal_weights(10);
        let result = debugger.classify_head_specialization(&weights);
        assert!(result.is_ok());
        let spec = result.expect("classification should succeed");
        assert!(matches!(spec, HeadSpecializationType::LocalSyntax));
    }
    #[test]
    fn test_compute_sparsity_ratio() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let seq_len = 10;
        let mut weights: Vec<Vec<f32>> = (0..seq_len).map(|_| vec![0.0; seq_len]).collect();
        for i in 0..seq_len {
            weights[i][0] = 1.0;
        }
        let sparsity = debugger.compute_sparsity_ratio(&weights);
        assert!(sparsity > 0.8);
    }
    #[test]
    fn test_measure_long_range_attention() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let seq_len = 20;
        // Every row attends only to the last column.
        let weights: Vec<Vec<f32>> = (0..seq_len)
            .map(|_| {
                let mut row = vec![0.0; seq_len];
                row[seq_len - 1] = 1.0;
                row
            })
            .collect();
        let long_range = debugger.measure_long_range_attention(&weights);
        assert!(long_range > 0.3);
    }
    #[test]
    fn test_measure_long_range_attention_small() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let weights: Vec<Vec<f32>> = vec![vec![1.0]];
        assert!((debugger.measure_long_range_attention(&weights) - 0.0).abs() < f32::EPSILON);
    }
    #[test]
    fn test_model_attention_summary_default() {
        let summary = ModelAttentionSummary::default();
        assert_eq!(summary.total_layers, 0);
        assert_eq!(summary.total_heads, 0);
        assert!(matches!(
            summary.model_attention_health,
            AttentionHealthStatus::Poor
        ));
    }
    #[test]
    fn test_analyze_attention_layer_head_limit() {
        let mut config = make_attention_config();
        config.max_heads_to_analyze = 2;
        let mut debugger = AttentionDebugger::new(config);
        // Four heads supplied, but only two may be analysed.
        let weights = vec![
            make_uniform_weights(4),
            make_uniform_weights(4),
            make_uniform_weights(4),
            make_uniform_weights(4),
        ];
        let result = debugger.analyze_attention_layer(0, &weights);
        assert!(result.is_ok());
        let analysis = result.expect("analysis should succeed");
        assert_eq!(analysis.head_analyses.len(), 2);
    }
    #[test]
    fn test_create_attention_map_wrong_dimensions() {
        let debugger = AttentionDebugger::new(make_attention_config());
        let weights_3d = ArrayD::zeros(IxDyn(&[2, 3, 4]));
        let result = debugger.create_attention_map(0, 0, &weights_3d);
        assert!(result.is_err());
    }
    #[test]
    fn test_attention_entropy_computation() {
        let mut debugger = AttentionDebugger::new(make_attention_config());
        let weights = vec![make_uniform_weights(8)];
        let analysis =
            debugger.analyze_attention_layer(0, &weights).expect("analysis should succeed");
        assert!(analysis.attention_maps[0].attention_entropy > 0.0);
    }
    #[test]
    fn test_attention_pattern_variants() {
        let patterns = [
            AttentionPattern::Diagonal,
            AttentionPattern::Block,
            AttentionPattern::Sparse,
            AttentionPattern::Uniform,
            AttentionPattern::Concentrated,
            AttentionPattern::Strided,
            AttentionPattern::Random,
        ];
        assert_eq!(patterns.len(), 7);
    }
    #[test]
    fn test_head_specialization_variants() {
        let specs = [
            HeadSpecializationType::LocalSyntax,
            HeadSpecializationType::LongRange,
            HeadSpecializationType::Positional,
            HeadSpecializationType::ContentBased,
            HeadSpecializationType::Copying,
            HeadSpecializationType::Delimiter,
            HeadSpecializationType::Mixed,
            HeadSpecializationType::Redundant,
        ];
        assert_eq!(specs.len(), 8);
    }
    #[test]
    fn test_attention_health_status_variants() {
        let statuses = [
            AttentionHealthStatus::Excellent,
            AttentionHealthStatus::Good,
            AttentionHealthStatus::Fair,
            AttentionHealthStatus::Poor,
        ];
        assert_eq!(statuses.len(), 4);
    }
    #[test]
    fn test_attention_distribution_creation() {
        let dist = AttentionDistribution {
            mean_attention: 0.125,
            std_attention: 0.05,
            max_attention: 0.5,
            min_attention: 0.01,
            entropy: 2.8,
            effective_context_length: 6.5,
        };
        assert!(dist.mean_attention > 0.0);
        assert!(dist.max_attention > dist.mean_attention);
        assert!(dist.entropy > 0.0);
    }
    #[test]
    fn test_attention_map_creation() {
        let map = AttentionMap {
            layer_index: 0,
            head_index: 3,
            sequence_length: 8,
            attention_weights: vec![vec![0.125; 8]; 8],
            attention_pattern: AttentionPattern::Uniform,
            attention_entropy: 3.0,
            sparsity_ratio: 0.0,
        };
        assert_eq!(map.layer_index, 0);
        assert_eq!(map.head_index, 3);
        assert_eq!(map.sequence_length, 8);
    }
    #[test]
    fn test_attention_head_analysis_creation() {
        let analysis = AttentionHeadAnalysis {
            head_id: 0,
            layer_id: 0,
            specialization_type: HeadSpecializationType::ContentBased,
            attention_distribution: AttentionDistribution {
                mean_attention: 0.1,
                std_attention: 0.02,
                max_attention: 0.3,
                min_attention: 0.01,
                entropy: 2.5,
                effective_context_length: 5.0,
            },
            redundancy_score: 0.1,
            importance_score: 0.8,
            patterns_detected: vec![AttentionPattern::Random],
        };
        assert_eq!(analysis.head_id, 0);
        assert!(analysis.importance_score > analysis.redundancy_score);
    }
    #[test]
    fn test_pattern_progression_creation() {
        let progression = PatternProgression {
            pattern_evolution: vec![HashMap::new()],
            dominant_pattern_sequence: vec![AttentionPattern::Diagonal, AttentionPattern::Sparse],
        };
        assert_eq!(progression.dominant_pattern_sequence.len(), 2);
    }
}