use super::config::{ImageFormat, VisualizationConfig};
use crate::error::{NeuralError, Result};
use crate::models::sequential::Sequential;
use scirs2_core::ndarray::{Array2, ArrayD, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::NumAssign;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::PathBuf;
/// Visualizes attention patterns extracted from a [`Sequential`] model.
///
/// Attention weights are cached per layer name so repeated renderings of
/// the same input do not require re-running the forward pass.
#[allow(dead_code)]
pub struct AttentionVisualizer<F: Float + Debug + ScalarOperand + NumAssign> {
// Model whose attention-like layers are inspected.
model: Sequential<F>,
// Output directory and rendering settings.
config: VisualizationConfig,
// Per-layer attention data, keyed by names such as "attention_0".
attention_cache: HashMap<String, AttentionData<F>>,
}
/// Attention weights for a single layer, together with human-readable
/// labels for the query/key positions they relate.
#[derive(Debug, Clone, Serialize)]
pub struct AttentionData<F: Float + Debug + NumAssign> {
/// (queries x keys) attention weight matrix.
pub weights: Array2<F>,
/// Labels for the query positions (matrix rows).
pub queries: Vec<String>,
/// Labels for the key positions (matrix columns).
pub keys: Vec<String>,
/// Multi-head metadata; `None` for single-head attention layers.
pub head_info: Option<HeadInfo>,
/// Identity of the layer these weights came from.
pub layer_info: LayerInfo,
}
/// Metadata describing one head of a multi-head attention layer.
#[derive(Debug, Clone, Serialize)]
pub struct HeadInfo {
/// Zero-based index of this head.
pub head_index: usize,
/// Total number of heads in the layer.
pub total_heads: usize,
/// Dimensionality of each head.
pub head_dim: usize,
}
/// Identifies the model layer an attention pattern belongs to.
#[derive(Debug, Clone, Serialize)]
pub struct LayerInfo {
/// Cache key / display name for the layer.
pub layer_name: String,
/// Position of the layer within the model.
pub layer_index: usize,
/// Layer type string as reported by the layer itself.
pub layer_type: String,
}
/// Options controlling how attention patterns are rendered.
pub struct AttentionVisualizationOptions {
/// Which diagram style to produce.
pub visualization_type: AttentionVisualizationType,
/// Which attention heads to include.
pub head_selection: HeadSelection,
/// Cell/position highlighting settings.
pub highlighting: HighlightConfig,
/// How multiple heads are combined into one view.
pub head_aggregation: HeadAggregation,
/// If set, weights below this value are omitted from the rendering.
pub threshold: Option<f64>,
}
/// The supported attention diagram styles.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum AttentionVisualizationType {
/// Query-by-key colored matrix.
Heatmap,
/// Two-column node graph with weighted edges.
BipartiteGraph,
/// Positions on a line joined by arcs.
ArcDiagram,
/// Flow-style rendering of attention.
AttentionFlow,
/// Side-by-side comparison of heads.
HeadComparison,
}
/// Selects which attention heads participate in a visualization.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HeadSelection {
/// Use every head.
All,
/// Use only the listed head indices.
Specific(Vec<usize>),
/// Use the top K heads (ranking criterion defined by consumers).
TopK(usize),
/// Use a contiguous range of head indices.
Range(usize, usize),
}
/// Controls how selected positions are emphasized in a rendering.
pub struct HighlightConfig {
/// Row/column indices to emphasize.
pub highlighted_positions: Vec<usize>,
/// CSS color used for the highlight stroke.
pub highlight_color: String,
/// Visual treatment applied to highlighted cells.
pub highlight_style: HighlightStyle,
/// Whether to draw attention paths through highlighted positions.
pub show_paths: bool,
}
/// Available highlight treatments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HighlightStyle {
/// Emphasized border stroke.
Border,
/// Filled background.
Background,
/// Semi-transparent overlay.
Overlay,
/// Glow effect.
Glow,
}
/// How weights from multiple heads are combined.
#[derive(Debug, Clone, PartialEq)]
pub enum HeadAggregation {
/// Keep heads separate.
None,
/// Average across heads.
Mean,
/// Element-wise maximum across heads.
Max,
/// Weighted average using the supplied per-head weights.
WeightedMean(Vec<f64>),
/// Attention-rollout style aggregation.
Rollout,
}
/// Options for exporting attention data or renderings to disk.
pub struct ExportOptions {
/// Target file format.
pub format: ExportFormat,
/// Quality preset for lossy formats.
pub quality: ExportQuality,
/// Output resolution for rasterized formats.
pub resolution: Resolution,
/// Whether to embed metadata in the exported artifact.
pub include_metadata: bool,
/// Compression configuration.
pub compression: CompressionSettings,
}
/// Export container formats.
#[derive(Debug, PartialEq, Clone)]
pub enum ExportFormat {
/// Raster image in the given format.
Image(ImageFormat),
/// Interactive HTML page.
HTML,
/// Standalone SVG document.
SVG,
/// PDF document.
PDF,
/// Raw data in the given format.
Data(DataFormat),
/// Animated/video output.
Video(VideoFormat),
}
/// Raw-data export formats.
#[derive(Debug, PartialEq, Clone)]
pub enum DataFormat {
JSON,
CSV,
NPY,
HDF5,
}
/// Video/animation export formats.
#[derive(Debug, PartialEq, Clone)]
pub enum VideoFormat {
MP4,
WebM,
GIF,
}
/// Quality presets for exported artifacts.
#[derive(Debug, PartialEq, Clone)]
pub enum ExportQuality {
Low,
Medium,
High,
Maximum,
}
/// Pixel dimensions and DPI for rasterized exports.
pub struct Resolution {
pub width: u32,
pub height: u32,
pub dpi: u32,
}
/// Compression configuration for exported artifacts.
pub struct CompressionSettings {
/// Whether compression is applied at all.
pub enabled: bool,
/// Compression level (interpretation depends on the codec).
pub level: u8,
/// Prefer lossless compression when the format supports it.
pub lossless: bool,
}
/// Summary statistics computed from one attention weight matrix.
pub struct AttentionStatistics<F: Float + Debug + NumAssign> {
/// Head the statistics refer to, when head metadata is available.
pub head_index: Option<usize>,
/// Shannon-entropy-style sum over all weights (computed globally, not per row).
pub entropy: f64,
/// Largest single attention weight.
pub max_attention: F,
/// Mean over all weights.
pub mean_attention: F,
/// Fraction of weights with magnitude below 1e-6.
pub sparsity: f64,
/// Up to five (flat index, weight) pairs with the strongest attention.
pub top_attended: Vec<(usize, F)>,
}
impl<
F: Float
+ Debug
+ std::fmt::Display
+ 'static
+ scirs2_core::numeric::FromPrimitive
+ ScalarOperand
+ Send
+ Sync
+ Serialize
+ NumAssign,
> AttentionVisualizer<F>
{
pub fn new(model: Sequential<F>, config: VisualizationConfig) -> Self {
Self {
model,
config,
attention_cache: HashMap::new(),
}
}
pub fn visualize_attention(
&mut self,
input: &ArrayD<F>,
options: &AttentionVisualizationOptions,
) -> Result<Vec<PathBuf>> {
self.extract_attention_patterns(input)?;
match options.visualization_type {
AttentionVisualizationType::Heatmap => self.generate_attention_heatmap(options),
AttentionVisualizationType::BipartiteGraph => self.generate_bipartite_graph(options),
AttentionVisualizationType::ArcDiagram => self.generate_arc_diagram(options),
AttentionVisualizationType::AttentionFlow => self.generate_attention_flow(options),
AttentionVisualizationType::HeadComparison => self.generate_head_comparison(options),
}
}
/// Returns the cached attention data for `layer_name`, if any.
pub fn get_cached_attention(&self, layer_name: &str) -> Option<&AttentionData<F>> {
    self.attention_cache.get(layer_name)
}
/// Drops all cached attention patterns.
pub fn clear_cache(&mut self) {
    self.attention_cache.clear();
}
/// Computes summary statistics for every cached layer.
///
/// Fails on the first layer whose statistics cannot be computed.
pub fn get_attention_statistics(&self) -> Result<Vec<AttentionStatistics<F>>> {
    self.attention_cache
        .iter()
        .map(|(name, data)| self.compute_attention_statistics(name, data))
        .collect()
}
/// Replaces the visualization configuration.
pub fn update_config(&mut self, config: VisualizationConfig) {
    self.config = config;
}
/// Exports the cached attention data of `layer_name` to disk in the
/// requested format and returns the path of the written file.
///
/// Formats without a dedicated exporter fall back to a raw JSON dump of
/// all cached layers.
pub fn export_attention_data(
    &self,
    layer_name: &str,
    export_options: &ExportOptions,
) -> Result<PathBuf> {
    let attention_data = self.attention_cache.get(layer_name).ok_or_else(|| {
        NeuralError::InvalidArgument(format!(
            "No attention data found for layer: {}",
            layer_name
        ))
    })?;
    // Resolve the file name and render the payload for the chosen format.
    let (file_name, contents) = match &export_options.format {
        ExportFormat::Data(DataFormat::JSON) => {
            let json_data = serde_json::to_string_pretty(attention_data)
                .map_err(|e| NeuralError::SerializationError(e.to_string()))?;
            (format!("{}_attention.json", layer_name), json_data)
        }
        ExportFormat::HTML => (
            format!("{}_attention.html", layer_name),
            self.generate_interactive_html()?,
        ),
        ExportFormat::SVG => (
            format!("{}_attention.svg", layer_name),
            self.generate_svg_visualization()?,
        ),
        _ => (
            format!("{}_attention_data.json", layer_name),
            self.export_attention_data_as_json()?,
        ),
    };
    let output_path = self.config.output_dir.join(file_name);
    std::fs::write(&output_path, contents)
        .map_err(|e| NeuralError::IOError(e.to_string()))?;
    Ok(output_path)
}
/// Runs `input` through the model layer by layer, caching a synthetic
/// attention matrix for every layer whose type name mentions
/// "Attention" or "MultiHead".
///
/// If the model contains no attention layers, a dummy pattern is cached
/// so downstream visualizations always have something to render.
fn extract_attention_patterns(&mut self, input: &ArrayD<F>) -> Result<()> {
    let layers = self.model.layers();
    let mut current_input = input.clone();
    for (layer_idx, layer) in layers.iter().enumerate() {
        let layer_type = layer.layer_type();
        if layer_type.contains("Attention") || layer_type.contains("MultiHead") {
            // BUG FIX: `&current_input` had been corrupted into the
            // non-compiling residue `¤t_input` (an HTML-entity mangling
            // of `&curren`); the intended references are restored here
            // and in the `else` branch below.
            let output = layer.forward(&current_input)?;
            let attention_weights =
                self.extract_layer_attention_weights(layer.as_ref(), &current_input)?;
            // Sequence length is the second-to-last axis when available.
            let seq_len = if current_input.ndim() >= 2 {
                current_input.shape()[current_input.ndim() - 2]
            } else {
                1
            };
            let queries: Vec<String> = (0..seq_len).map(|i| format!("pos_{}", i)).collect();
            let keys: Vec<String> = queries.clone();
            let layer_info = LayerInfo {
                layer_name: format!("attention_{}", layer_idx),
                layer_index: layer_idx,
                layer_type: layer_type.to_string(),
            };
            // NOTE(review): the head count is hard-coded to 8 because the
            // layer trait does not expose it — confirm against the model.
            let head_info = if layer_type.contains("MultiHead") {
                Some(HeadInfo {
                    head_index: 0,
                    total_heads: 8,
                    head_dim: attention_weights.shape()[1] / 8,
                })
            } else {
                None
            };
            let attention_data = AttentionData {
                weights: attention_weights,
                queries,
                keys,
                head_info,
                layer_info,
            };
            self.attention_cache
                .insert(format!("attention_{}", layer_idx), attention_data);
            current_input = output;
        } else {
            current_input = layer.forward(&current_input)?;
        }
    }
    if self.attention_cache.is_empty() {
        self.create_dummy_attention_data(input)?;
    }
    Ok(())
}
/// Builds a synthetic attention matrix for a layer: strong diagonal with
/// exponentially decaying off-diagonal weights, row-normalized so each
/// row sums to one. The layer itself is not inspected.
fn extract_layer_attention_weights(
    &self,
    _layer: &(dyn crate::layers::Layer<F> + Send + Sync),
    input: &ArrayD<F>,
) -> Result<Array2<F>> {
    // Derive sequence length from the second-to-last axis; fall back to 8.
    let seq_len = match input.ndim() {
        n if n >= 2 => input.shape()[n - 2],
        _ => 8,
    };
    let mut weights = Array2::<F>::zeros((seq_len, seq_len));
    for row in 0..seq_len {
        for col in 0..seq_len {
            let score = if row == col {
                // Self-attention gets a fixed strong weight.
                0.5
            } else {
                // Decay with distance, floored at 0.01.
                let distance = (row as i32 - col as i32).abs() as f64;
                (0.5 * (-distance / 2.0).exp()).max(0.01)
            };
            weights[[row, col]] = F::from(score).unwrap_or(F::zero());
        }
    }
    // Normalize each row into a probability distribution.
    for row in 0..seq_len {
        let row_sum = (0..seq_len).fold(F::zero(), |acc, col| acc + weights[[row, col]]);
        if row_sum > F::zero() {
            for col in 0..seq_len {
                weights[[row, col]] /= row_sum;
            }
        }
    }
    Ok(weights)
}
/// Caches a fixed 8x8 fallback attention pattern under "dummy_attention"
/// so visualizations have data even when the model has no attention layers.
fn create_dummy_attention_data(&mut self, _input: &ArrayD<F>) -> Result<()> {
    let seq_len: usize = 8;
    let mut weights = Array2::<F>::zeros((seq_len, seq_len));
    for row in 0..seq_len {
        for col in 0..seq_len {
            // Exponential decay with distance, floored at 0.05.
            let distance = (row as i32 - col as i32).abs() as f64;
            let score = (0.3 * (-distance / 3.0).exp()).max(0.05);
            weights[[row, col]] = F::from(score).unwrap_or(F::zero());
        }
    }
    // Row-normalize into a probability distribution.
    for row in 0..seq_len {
        let total = (0..seq_len).fold(F::zero(), |acc, col| acc + weights[[row, col]]);
        if total > F::zero() {
            for col in 0..seq_len {
                weights[[row, col]] /= total;
            }
        }
    }
    let queries: Vec<String> = (0..seq_len).map(|i| format!("token_{}", i)).collect();
    let attention_data = AttentionData {
        weights,
        keys: queries.clone(),
        queries,
        head_info: Some(HeadInfo {
            head_index: 0,
            total_heads: 8,
            head_dim: 64,
        }),
        layer_info: LayerInfo {
            layer_name: "dummy_attention".to_string(),
            layer_index: 0,
            layer_type: "MultiHeadAttention".to_string(),
        },
    };
    self.attention_cache
        .insert("dummy_attention".to_string(), attention_data);
    Ok(())
}
/// Renders one heatmap SVG per cached layer and returns the file paths.
///
/// Errors if the cache is empty (nothing to render).
fn generate_attention_heatmap(
    &mut self,
    options: &AttentionVisualizationOptions,
) -> Result<Vec<PathBuf>> {
    let threshold = options.threshold.unwrap_or(0.0);
    let mut output_paths = Vec::with_capacity(self.attention_cache.len());
    for (layer_name, attention_data) in &self.attention_cache {
        output_paths.push(self.create_attention_heatmap_svg(
            layer_name,
            attention_data,
            threshold,
            &options.head_selection,
            &options.highlighting,
        )?);
    }
    if output_paths.is_empty() {
        Err(NeuralError::ValidationError(
            "No attention data available for heatmap generation".to_string(),
        ))
    } else {
        Ok(output_paths)
    }
}
/// Renders one layer's attention matrix as a heatmap SVG and writes it to
/// `<output_dir>/<layer_name>_attention_heatmap.svg`.
///
/// Cells below `threshold` are skipped; rows/columns listed in
/// `highlighting.highlighted_positions` get an emphasized stroke.
/// `_head_selection` is currently unused.
fn create_attention_heatmap_svg(
&self,
layer_name: &str,
attention_data: &AttentionData<F>,
threshold: f64,
_head_selection: &HeadSelection,
highlighting: &HighlightConfig,
) -> Result<PathBuf> {
let weights = &attention_data.weights;
let (rows, cols) = weights.dim();
// Layout constants: square cells plus margins and label gutters.
let cell_size = 30.0;
let margin = 50.0;
let label_space = 80.0;
let svg_width = (cols as f32 * cell_size + 2.0 * margin + 2.0 * label_space) as u32;
let svg_height = (rows as f32 * cell_size + 2.0 * margin + 2.0 * label_space) as u32;
// SVG header with embedded CSS; the highlight stroke color is injected.
let mut svg = format!(
r#"<?xml version="1.0" encoding="UTF-8"?>
<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
<title>Attention Heatmap - {}</title>
<defs>
<style>
.heatmap-cell {{ stroke: #fff; stroke-width: 1; }}
.axis-label {{ font-family: Arial, sans-serif; font-size: 12px; text-anchor: middle; fill: #333; }}
.title {{ font-family: Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #333; font-weight: bold; }}
.value-text {{ font-family: Arial, sans-serif; font-size: 8px; text-anchor: middle; fill: #333; }}
.highlighted {{ stroke: {}; stroke-width: 3; }}
</style>
</defs>
<!-- Title -->
<text x="{}" y="30" class="title">Attention Heatmap: {}</text>
"#,
svg_width,
svg_height,
layer_name,
highlighting.highlight_color,
svg_width as f32 / 2.0,
layer_name
);
let heatmap_start_x = margin + label_space;
let heatmap_start_y = margin + label_space;
// First pass: find the min/max weight for color normalization.
let mut min_val = F::infinity();
let mut max_val = F::neg_infinity();
for i in 0..rows {
for j in 0..cols {
let val = weights[[i, j]];
if val < min_val {
min_val = val;
}
if val > max_val {
max_val = val;
}
}
}
// Second pass: emit one colored rect per cell above the threshold.
for i in 0..rows {
for j in 0..cols {
let val = weights[[i, j]];
let val_f64 = val.to_f64().unwrap_or(0.0);
if val_f64 < threshold {
continue;
}
let x = heatmap_start_x + j as f32 * cell_size;
let y = heatmap_start_y + i as f32 * cell_size;
// Normalize into [0, 1]; a constant matrix maps to 0.5.
let normalized = if max_val > min_val {
((val - min_val) / (max_val - min_val))
.to_f64()
.unwrap_or(0.0)
} else {
0.5
};
// Blue (low) to red (high) gradient.
let red = (255.0 * normalized) as u8;
let blue = (255.0 * (1.0 - normalized)) as u8;
let green = (128.0 * (1.0 - normalized.abs())) as u8;
let color = format!("rgb({}, {}, {})", red, green, blue);
let is_highlighted = highlighting.highlighted_positions.contains(&i)
|| highlighting.highlighted_positions.contains(&j);
let cell_class = if is_highlighted {
"heatmap-cell highlighted"
} else {
"heatmap-cell"
};
svg.push_str(&format!(
r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}" class="{}" opacity="0.8"/>
"#,
x, y, cell_size, cell_size, color, cell_class
));
// Only label cells with the numeric value when they are big enough.
if cell_size > 20.0 {
svg.push_str(&format!(
r#" <text x="{}" y="{}" class="value-text">{:.2}</text>
"#,
x + cell_size / 2.0,
y + cell_size / 2.0 + 3.0,
val_f64
));
}
}
}
// Row (query) labels along the left edge.
for (i, query) in attention_data.queries.iter().enumerate().take(rows) {
let y = heatmap_start_y + i as f32 * cell_size + cell_size / 2.0;
svg.push_str(&format!(
r#" <text x="{}" y="{}" class="axis-label">{}</text>
"#,
margin + label_space - 10.0,
y + 4.0,
query
));
}
// Column (key) labels along the top edge, rotated 45 degrees.
for (j, key) in attention_data.keys.iter().enumerate().take(cols) {
let x = heatmap_start_x + j as f32 * cell_size + cell_size / 2.0;
svg.push_str(&format!(
r#" <text x="{}" y="{}" class="axis-label" transform="rotate(-45, {}, {})">{}</text>
"#,
x, margin + label_space - 10.0, x, margin + label_space - 10.0, key
));
}
// Axis titles.
svg.push_str(&format!(
r#" <text x="{}" y="{}" class="axis-label" font-weight="bold">Queries</text>
<text x="{}" y="{}" class="axis-label" font-weight="bold" transform="rotate(-90, {}, {})">Keys</text>
"#,
20.0, heatmap_start_y + (rows as f32 * cell_size) / 2.0,
heatmap_start_x + (cols as f32 * cell_size) / 2.0, 20.0,
heatmap_start_x + (cols as f32 * cell_size) / 2.0, 20.0
));
// Color-scale legend: 20 gradient swatches to the right of the grid.
let legend_x = heatmap_start_x + cols as f32 * cell_size + 20.0;
let legend_y = heatmap_start_y;
let legend_height = 200.0;
let legend_width = 20.0;
for i in 0..20 {
let y = legend_y + i as f32 * (legend_height / 20.0);
let intensity = 1.0 - (i as f64 / 19.0);
let red = (255.0 * intensity) as u8;
let blue = (255.0 * (1.0 - intensity)) as u8;
let green = (128.0 * (1.0 - intensity.abs())) as u8;
let color = format!("rgb({}, {}, {})", red, green, blue);
svg.push_str(&format!(
r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}" stroke="none"/>
"#,
legend_x,
y,
legend_width,
legend_height / 20.0,
color
));
}
// Legend end-point labels and caption.
svg.push_str(&format!(
r#" <text x="{}" y="{}" class="axis-label">{:.3}</text>
<text x="{}" y="{}" class="axis-label">{:.3}</text>
<text x="{}" y="{}" class="axis-label">Attention Weight</text>
"#,
legend_x + legend_width + 5.0,
legend_y + 5.0,
max_val.to_f64().unwrap_or(1.0),
legend_x + legend_width + 5.0,
legend_y + legend_height + 5.0,
min_val.to_f64().unwrap_or(0.0),
legend_x - 10.0,
legend_y - 20.0
));
// Show which head this matrix belongs to, when known (1-based display).
if let Some(ref head_info) = attention_data.head_info {
svg.push_str(&format!(
r#" <text x="{}" y="{}" class="axis-label">Head {}/{}</text>
"#,
legend_x,
legend_y + legend_height + 30.0,
head_info.head_index + 1,
head_info.total_heads
));
}
svg.push_str("</svg>");
let output_path = self
.config
.output_dir
.join(format!("{}_attention_heatmap.svg", layer_name));
std::fs::write(&output_path, svg)
.map_err(|e| NeuralError::IOError(format!("Failed to write heatmap SVG: {}", e)))?;
Ok(output_path)
}
/// Renders one bipartite-graph SVG per cached layer and returns the paths.
///
/// Fails on the first layer that cannot be rendered.
fn generate_bipartite_graph(
    &mut self,
    options: &AttentionVisualizationOptions,
) -> Result<Vec<PathBuf>> {
    self.attention_cache
        .iter()
        .map(|(layer_name, attention_data)| {
            self.generate_bipartite_graph_for_layer(layer_name, attention_data, options)
        })
        .collect()
}
/// Renders one layer's attention as a two-column bipartite graph SVG:
/// query nodes on the left, key nodes on the right, and an edge per
/// attention weight above the threshold (edge thickness proportional to
/// weight). Writes `<output_dir>/<layer_name>_attention_bipartite.svg`.
fn generate_bipartite_graph_for_layer(
&self,
layer_name: &str,
attention_data: &AttentionData<F>,
options: &AttentionVisualizationOptions,
) -> Result<PathBuf> {
let weights = &attention_data.weights;
let queries = &attention_data.queries;
let keys = &attention_data.keys;
// Fixed canvas layout: queries left, keys right, evenly spaced vertically.
let width = 800.0;
let height = 600.0;
let margin = 60.0;
let node_radius = 6.0;
let query_x = margin + 50.0;
let key_x = width - margin - 50.0;
// .max(1.0) guards against division by zero for empty label lists.
let query_spacing = (height - 2.0 * margin) / (queries.len() as f32).max(1.0);
let key_spacing = (height - 2.0 * margin) / (keys.len() as f32).max(1.0);
let mut svg = format!(
r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
<style>
.query-node {{ fill: #4CAF50; stroke: #2E7D32; stroke-width: 2; }}
.key-node {{ fill: #2196F3; stroke: #1565C0; stroke-width: 2; }}
.attention-edge {{ stroke: #FF9800; stroke-width: 1; opacity: 0.6; }}
.node-label {{ font-family: Arial, sans-serif; font-size: 12px; text-anchor: middle; }}
.graph-title {{ font-family: Arial, sans-serif; font-size: 16px; font-weight: bold; text-anchor: middle; }}
</style>
"#,
width, height
);
svg.push_str(&format!(
r#" <text x="{}" y="30" class="graph-title">Attention Bipartite Graph - {}</text>
"#,
width / 2.0,
layer_name
));
// Left column: query nodes with labels.
for (i, query) in queries.iter().enumerate() {
let y = margin + i as f32 * query_spacing;
svg.push_str(&format!(
r#" <circle cx="{}" cy="{}" r="{}" class="query-node"/>
<text x="{}" y="{}" class="node-label">{}</text>
"#,
query_x,
y,
node_radius,
query_x - 20.0,
y + 4.0,
query
));
}
// Right column: key nodes with labels.
for (i, key) in keys.iter().enumerate() {
let y = margin + i as f32 * key_spacing;
svg.push_str(&format!(
r#" <circle cx="{}" cy="{}" r="{}" class="key-node"/>
<text x="{}" y="{}" class="node-label">{}</text>
"#,
key_x,
y,
node_radius,
key_x + 20.0,
y + 4.0,
key
));
}
// Edge thickness is normalized against the largest weight in the matrix.
let max_weight = weights
.iter()
.fold(F::zero(), |acc, &w| if w > acc { w } else { acc });
let threshold = options.threshold.unwrap_or(0.1) as f32;
for (i, _query) in queries.iter().enumerate() {
for (j, _key) in keys.iter().enumerate() {
// Guard against label lists longer than the weight matrix.
if i < weights.nrows() && j < weights.ncols() {
let weight = weights[[i, j]].to_f32().unwrap_or(0.0);
if weight > threshold {
let query_y = margin + i as f32 * query_spacing;
let key_y = margin + j as f32 * key_spacing;
let normalized_weight = weight / max_weight.to_f32().unwrap_or(1.0);
let stroke_width = (normalized_weight * 5.0).max(0.5);
svg.push_str(&format!(
r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="attention-edge" stroke-width="{}"/>
"#,
query_x + node_radius, query_y,
key_x - node_radius, key_y,
stroke_width
));
}
}
}
}
// Column captions and legend line at the bottom of the canvas.
svg.push_str(&format!(
r#" <text x="50" y="{}" class="node-label">Queries</text>
<text x="{}" y="{}" class="node-label">Keys</text>
<text x="{}" y="{}" class="node-label">Edge thickness ∝ Attention weight</text>
"#,
height - 30.0,
width - 50.0,
height - 30.0,
width / 2.0,
height - 10.0
));
svg.push_str("</svg>");
let output_path = self
.config
.output_dir
.join(format!("{}_attention_bipartite.svg", layer_name));
std::fs::write(&output_path, svg).map_err(|e| {
NeuralError::IOError(format!("Failed to write bipartite graph SVG: {}", e))
})?;
Ok(output_path)
}
/// Renders one arc-diagram SVG per cached layer and returns the paths.
fn generate_arc_diagram(
    &mut self,
    options: &AttentionVisualizationOptions,
) -> Result<Vec<PathBuf>> {
    self.attention_cache
        .iter()
        .map(|(name, data)| self.generate_arc_diagram_for_layer(name, data, options))
        .collect()
}
/// Stub: writes an empty SVG placeholder for the arc-diagram view.
fn generate_arc_diagram_for_layer(
    &self,
    layer_name: &str,
    _attention_data: &AttentionData<F>,
    _options: &AttentionVisualizationOptions,
) -> Result<PathBuf> {
    self.write_placeholder_svg(&format!("{}_attention_arc.svg", layer_name))
}
/// Renders one attention-flow SVG per cached layer and returns the paths.
fn generate_attention_flow(
    &mut self,
    options: &AttentionVisualizationOptions,
) -> Result<Vec<PathBuf>> {
    self.attention_cache
        .iter()
        .map(|(name, data)| self.generate_attention_flow_for_layer(name, data, options))
        .collect()
}
/// Stub: writes an empty SVG placeholder for the attention-flow view.
fn generate_attention_flow_for_layer(
    &self,
    layer_name: &str,
    _attention_data: &AttentionData<F>,
    _options: &AttentionVisualizationOptions,
) -> Result<PathBuf> {
    self.write_placeholder_svg(&format!("{}_attention_flow.svg", layer_name))
}
/// Renders one head-comparison SVG per cached layer and returns the paths.
fn generate_head_comparison(
    &mut self,
    options: &AttentionVisualizationOptions,
) -> Result<Vec<PathBuf>> {
    self.attention_cache
        .iter()
        .map(|(name, data)| self.generate_head_comparison_for_layer(name, data, options))
        .collect()
}
/// Stub: writes an empty SVG placeholder for the head-comparison view.
fn generate_head_comparison_for_layer(
    &self,
    layer_name: &str,
    _attention_data: &AttentionData<F>,
    _options: &AttentionVisualizationOptions,
) -> Result<PathBuf> {
    self.write_placeholder_svg(&format!("{}_attention_heads.svg", layer_name))
}
/// Writes an empty `<svg></svg>` file under the output directory and
/// returns its path. Shared by the not-yet-implemented diagram stubs.
fn write_placeholder_svg(&self, file_name: &str) -> Result<PathBuf> {
    let output_path = self.config.output_dir.join(file_name);
    std::fs::write(&output_path, "<svg></svg>")
        .map_err(|e| NeuralError::IOError(e.to_string()))?;
    Ok(output_path)
}
/// Summarizes one layer's attention matrix: mean and max weight, sparsity
/// (fraction of near-zero entries), a global entropy estimate, and the
/// top attended positions. Errors on an empty matrix.
fn compute_attention_statistics(
    &self,
    _layer_name: &str,
    attention_data: &AttentionData<F>,
) -> Result<AttentionStatistics<F>> {
    let weights = &attention_data.weights;
    let total_weights = weights.len();
    if total_weights == 0 {
        return Err(NeuralError::InvalidArgument(
            "Empty attention weights".to_string(),
        ));
    }
    // Single pass over all entries for sum, max, sparsity, and entropy.
    let near_zero = F::from(1e-6).unwrap_or(F::zero());
    let mut sum = F::zero();
    let mut max_weight = F::neg_infinity();
    let mut zero_count = 0usize;
    let mut entropy = 0.0f64;
    for &weight in weights.iter() {
        sum += weight;
        if weight > max_weight {
            max_weight = weight;
        }
        if weight.abs() < near_zero {
            zero_count += 1;
        }
        // Shannon-entropy contribution; non-positive values are skipped.
        let prob = weight.to_f64().unwrap_or(0.0);
        if prob > 1e-10 {
            entropy -= prob * prob.ln();
        }
    }
    let mean_attention = sum / F::from(total_weights).unwrap_or(F::one());
    let sparsity = zero_count as f64 / total_weights as f64;
    // Candidates come from the first five rows only; keep the 5 largest.
    let (rows, cols) = weights.dim();
    let mut top_attended: Vec<(usize, F)> = Vec::new();
    for i in 0..rows.min(5) {
        for j in 0..cols {
            top_attended.push((i * cols + j, weights[[i, j]]));
        }
    }
    top_attended.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    top_attended.truncate(5);
    Ok(AttentionStatistics {
        head_index: attention_data.head_info.as_ref().map(|h| h.head_index),
        entropy,
        max_attention: max_weight,
        mean_attention,
        sparsity,
        top_attended,
    })
}
/// Returns a minimal static HTML page placeholder for interactive export.
fn generate_interactive_html(&self) -> Result<String> {
    Ok(r#"<!DOCTYPE html>
<html>
<head><title>Attention Visualization</title></head>
<body><h1>Attention Patterns</h1></body>
</html>"#
        .to_string())
}
/// Returns a minimal static SVG placeholder for SVG export.
fn generate_svg_visualization(&self) -> Result<String> {
    Ok(
        r#"<svg width="800" height="600"><text x="400" y="300">Attention Patterns</text></svg>"#
            .to_string(),
    )
}
/// Serializes every cached layer's attention data into a single pretty
/// JSON document (weights as nested arrays of f64, plus labels, layer
/// and head metadata, and the matrix shape).
fn export_attention_data_as_json(&self) -> Result<String> {
use serde_json::json;
let mut layers_data = serde_json::Map::new();
// NOTE(review): HashMap iteration order is unspecified, so the layer
// insertion order in the output may vary between runs — confirm whether
// consumers rely on ordering.
for (layer_name, attention_data) in &self.attention_cache {
// Convert the generic weight matrix to plain f64 rows for JSON.
let weights_data: Vec<Vec<f64>> = attention_data
.weights
.outer_iter()
.map(|row| row.iter().map(|&w| w.to_f64().unwrap_or(0.0)).collect())
.collect();
let layer_data = json!({
"weights": weights_data,
"queries": attention_data.queries,
"keys": attention_data.keys,
"layer_info": {
"name": attention_data.layer_info.layer_name,
"index": attention_data.layer_info.layer_index,
"type": attention_data.layer_info.layer_type
},
"head_info": attention_data.head_info.as_ref().map(|h| json!({
"head_index": h.head_index,
"total_heads": h.total_heads,
"head_dim": h.head_dim
})),
"shape": attention_data.weights.shape()
});
layers_data.insert(layer_name.clone(), layer_data);
}
// NOTE(review): export_timestamp is a hard-coded constant, not the
// actual export time — confirm whether a real timestamp is intended.
let export_data = json!({
"attention_layers": layers_data,
"export_timestamp": "2026-02-09T00:00:00Z",
"framework": "scirs2-neural",
"version": "0.2.0"
});
serde_json::to_string_pretty(&export_data)
.map_err(|e| NeuralError::ComputationError(format!("JSON serialization error: {}", e)))
}
}
impl Default for AttentionVisualizationOptions {
fn default() -> Self {
Self {
visualization_type: AttentionVisualizationType::Heatmap,
head_selection: HeadSelection::All,
highlighting: HighlightConfig::default(),
head_aggregation: HeadAggregation::Mean,
threshold: Some(0.01),
}
}
}
impl Default for HighlightConfig {
fn default() -> Self {
Self {
highlighted_positions: Vec::new(),
highlight_color: "#ff0000".to_string(),
highlight_style: HighlightStyle::Border,
show_paths: false,
}
}
}
impl Default for ExportOptions {
fn default() -> Self {
Self {
format: ExportFormat::Image(ImageFormat::PNG),
quality: ExportQuality::High,
resolution: Resolution::default(),
include_metadata: true,
compression: CompressionSettings::default(),
}
}
}
impl Default for Resolution {
fn default() -> Self {
Self {
width: 1920,
height: 1080,
dpi: 300,
}
}
}
impl Default for CompressionSettings {
fn default() -> Self {
Self {
enabled: true,
level: 6,
lossless: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::layers::Dense;
use scirs2_core::random::SeedableRng;
// A freshly constructed visualizer starts with an empty attention cache.
#[test]
fn test_attention_visualizer_creation() {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
let mut model = Sequential::<f32>::new();
model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
let config = VisualizationConfig::default();
let visualizer = AttentionVisualizer::new(model, config);
assert!(visualizer.attention_cache.is_empty());
}
// Default options: heatmap, all heads, mean aggregation, 0.01 threshold.
#[test]
fn test_attention_visualization_options_default() {
let options = AttentionVisualizationOptions::default();
assert_eq!(
options.visualization_type,
AttentionVisualizationType::Heatmap
);
assert_eq!(options.head_selection, HeadSelection::All);
assert_eq!(options.head_aggregation, HeadAggregation::Mean);
assert_eq!(options.threshold, Some(0.01));
}
// All five diagram styles exist and compare by value.
#[test]
fn test_attention_visualization_types() {
let types = [
AttentionVisualizationType::Heatmap,
AttentionVisualizationType::BipartiteGraph,
AttentionVisualizationType::ArcDiagram,
AttentionVisualizationType::AttentionFlow,
AttentionVisualizationType::HeadComparison,
];
assert_eq!(types.len(), 5);
assert_eq!(types[0], AttentionVisualizationType::Heatmap);
}
// Each HeadSelection variant carries the expected payload.
#[test]
fn test_head_selection_variants() {
let all = HeadSelection::All;
let specific = HeadSelection::Specific(vec![0, 1, 2]);
let top_k = HeadSelection::TopK(5);
let range = HeadSelection::Range(2, 8);
assert_eq!(all, HeadSelection::All);
match specific {
HeadSelection::Specific(heads) => assert_eq!(heads.len(), 3),
_ => panic!("Expected specific head selection"),
}
match top_k {
HeadSelection::TopK(k) => assert_eq!(k, 5),
_ => panic!("Expected top-k head selection"),
}
match range {
HeadSelection::Range(start, end) => {
assert_eq!(start, 2);
assert_eq!(end, 8);
}
_ => panic!("Expected range head selection"),
}
}
// Each HeadAggregation variant carries the expected payload.
#[test]
fn test_head_aggregation_methods() {
let none = HeadAggregation::None;
let mean = HeadAggregation::Mean;
let max = HeadAggregation::Max;
let weighted = HeadAggregation::WeightedMean(vec![0.3, 0.7]);
let rollout = HeadAggregation::Rollout;
assert_eq!(none, HeadAggregation::None);
assert_eq!(mean, HeadAggregation::Mean);
assert_eq!(max, HeadAggregation::Max);
assert_eq!(rollout, HeadAggregation::Rollout);
match weighted {
HeadAggregation::WeightedMean(weights) => assert_eq!(weights.len(), 2),
_ => panic!("Expected weighted mean aggregation"),
}
}
// All four highlight styles exist and compare by value.
#[test]
fn test_highlight_styles() {
let styles = [
HighlightStyle::Border,
HighlightStyle::Background,
HighlightStyle::Overlay,
HighlightStyle::Glow,
];
assert_eq!(styles.len(), 4);
assert_eq!(styles[0], HighlightStyle::Border);
}
// Each ExportFormat variant carries the expected nested format.
#[test]
fn test_export_formats() {
let image = ExportFormat::Image(ImageFormat::PNG);
let html = ExportFormat::HTML;
let svg = ExportFormat::SVG;
let data = ExportFormat::Data(DataFormat::JSON);
let video = ExportFormat::Video(VideoFormat::MP4);
assert_eq!(html, ExportFormat::HTML);
assert_eq!(svg, ExportFormat::SVG);
match image {
ExportFormat::Image(ImageFormat::PNG) => {}
_ => panic!("Expected PNG image format"),
}
match data {
ExportFormat::Data(DataFormat::JSON) => {}
_ => panic!("Expected JSON data format"),
}
match video {
ExportFormat::Video(VideoFormat::MP4) => {}
_ => panic!("Expected MP4 video format"),
}
}
// All four quality levels exist and compare by value.
#[test]
fn test_export_quality_levels() {
let qualities = [
ExportQuality::Low,
ExportQuality::Medium,
ExportQuality::High,
ExportQuality::Maximum,
];
assert_eq!(qualities.len(), 4);
assert_eq!(qualities[2], ExportQuality::High);
}
// Cache lookups miss on unknown layers; clearing an empty cache is a no-op.
#[test]
fn test_cache_operations() {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
let mut model = Sequential::<f32>::new();
model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
let config = VisualizationConfig::default();
let mut visualizer = AttentionVisualizer::new(model, config);
assert!(visualizer.get_cached_attention("test_layer").is_none());
visualizer.clear_cache();
}
// Default resolution is Full HD at 300 DPI.
#[test]
fn test_resolution_settings() {
let resolution = Resolution::default();
assert_eq!(resolution.width, 1920);
assert_eq!(resolution.height, 1080);
assert_eq!(resolution.dpi, 300);
}
// Default compression: enabled, level 6, lossy.
#[test]
fn test_compression_settings() {
let compression = CompressionSettings::default();
assert!(compression.enabled);
assert_eq!(compression.level, 6);
assert!(!compression.lossless);
}
}