1use super::config::{ImageFormat, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8use crate::models::sequential::Sequential;
9use scirs2_core::ndarray::{Array2, ArrayD, ScalarOperand};
10use scirs2_core::numeric::Float;
11use scirs2_core::NumAssign;
12use serde::Serialize;
13use std::collections::HashMap;
14use std::fmt::Debug;
15use std::path::PathBuf;
/// Extracts attention patterns from a [`Sequential`] model and renders them as
/// SVG / HTML / JSON files under `config.output_dir`.
///
/// Extracted matrices are kept in `attention_cache`, keyed by a layer name such
/// as `attention_{index}` (or `dummy_attention` when the model contains no
/// attention layer).
// NOTE(review): `dead_code` is presumably allowed because not every item is
// reachable from public entry points; confirm whether it is still required.
#[allow(dead_code)]
pub struct AttentionVisualizer<F: Float + Debug + ScalarOperand + NumAssign> {
    /// Model whose layers are scanned for attention patterns.
    model: Sequential<F>,
    /// Visualization settings (e.g. the output directory for generated files).
    config: VisualizationConfig,
    /// Extracted attention data, keyed by layer name.
    attention_cache: HashMap<String, AttentionData<F>>,
}
/// Attention matrix for a single layer together with its labels and metadata.
#[derive(Debug, Clone, Serialize)]
pub struct AttentionData<F: Float + Debug + NumAssign> {
    /// Attention weights; the renderers treat row `i` as query `i` and
    /// column `j` as key `j`.
    pub weights: Array2<F>,
    /// Display labels for the query positions (rows of `weights`).
    pub queries: Vec<String>,
    /// Display labels for the key positions (columns of `weights`).
    pub keys: Vec<String>,
    /// Multi-head metadata, if the layer is treated as multi-head attention.
    pub head_info: Option<HeadInfo>,
    /// Identification of the layer the weights were taken from.
    pub layer_info: LayerInfo,
}
40
/// Metadata describing which attention head a matrix belongs to.
// NOTE(review): the extraction code in this file fills these fields with
// placeholder values rather than reading them from the layer.
#[derive(Debug, Clone, Serialize)]
pub struct HeadInfo {
    /// Zero-based head index (rendered as `head_index + 1` in heatmaps).
    pub head_index: usize,
    /// Total number of heads in the layer.
    pub total_heads: usize,
    /// Dimension of a single head.
    pub head_dim: usize,
}
51
/// Identification of the model layer an attention matrix was extracted from.
#[derive(Debug, Clone, Serialize)]
pub struct LayerInfo {
    /// Name used for the layer (also used as the attention-cache key).
    pub layer_name: String,
    /// Position of the layer in the model's layer list.
    pub layer_index: usize,
    /// Layer type string as reported by the layer itself.
    pub layer_type: String,
}
62
63pub struct AttentionVisualizationOptions {
65 pub visualization_type: AttentionVisualizationType,
67 pub head_selection: HeadSelection,
69 pub highlighting: HighlightConfig,
71 pub head_aggregation: HeadAggregation,
73 pub threshold: Option<f64>,
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize)]
79pub enum AttentionVisualizationType {
80 Heatmap,
82 BipartiteGraph,
84 ArcDiagram,
86 AttentionFlow,
88 HeadComparison,
90}
91
92#[derive(Debug, Clone, PartialEq, Eq)]
94pub enum HeadSelection {
95 All,
97 Specific(Vec<usize>),
99 TopK(usize),
101 Range(usize, usize),
103}
104
105pub struct HighlightConfig {
107 pub highlighted_positions: Vec<usize>,
109 pub highlight_color: String,
111 pub highlight_style: HighlightStyle,
113 pub show_paths: bool,
115}
116
117#[derive(Debug, Clone, PartialEq, Eq)]
119pub enum HighlightStyle {
120 Border,
122 Background,
124 Overlay,
126 Glow,
128}
129
130#[derive(Debug, Clone, PartialEq)]
132pub enum HeadAggregation {
133 None,
135 Mean,
137 Max,
139 WeightedMean(Vec<f64>),
141 Rollout,
143}
144
145pub struct ExportOptions {
147 pub format: ExportFormat,
149 pub quality: ExportQuality,
151 pub resolution: Resolution,
153 pub include_metadata: bool,
155 pub compression: CompressionSettings,
157}
158
159#[derive(Debug, PartialEq, Clone)]
161pub enum ExportFormat {
162 Image(ImageFormat),
164 HTML,
166 SVG,
168 PDF,
170 Data(DataFormat),
172 Video(VideoFormat),
174}
175
176#[derive(Debug, PartialEq, Clone)]
178pub enum DataFormat {
179 JSON,
181 CSV,
183 NPY,
185 HDF5,
187}
188
189#[derive(Debug, PartialEq, Clone)]
191pub enum VideoFormat {
192 MP4,
194 WebM,
196 GIF,
198}
199
200#[derive(Debug, PartialEq, Clone)]
202pub enum ExportQuality {
203 Low,
205 Medium,
207 High,
209 Maximum,
211}
212
213pub struct Resolution {
215 pub width: u32,
217 pub height: u32,
219 pub dpi: u32,
221}
222
223pub struct CompressionSettings {
225 pub enabled: bool,
227 pub level: u8,
229 pub lossless: bool,
231}
232
233pub struct AttentionStatistics<F: Float + Debug + NumAssign> {
235 pub head_index: Option<usize>,
237 pub entropy: f64,
239 pub max_attention: F,
241 pub mean_attention: F,
243 pub sparsity: f64,
245 pub top_attended: Vec<(usize, F)>,
247}
248
249impl<
251 F: Float
252 + Debug
253 + std::fmt::Display
254 + 'static
255 + scirs2_core::numeric::FromPrimitive
256 + ScalarOperand
257 + Send
258 + Sync
259 + Serialize
260 + NumAssign,
261 > AttentionVisualizer<F>
262{
263 pub fn new(model: Sequential<F>, config: VisualizationConfig) -> Self {
265 Self {
266 model,
267 config,
268 attention_cache: HashMap::new(),
269 }
270 }
271 pub fn visualize_attention(
273 &mut self,
274 input: &ArrayD<F>,
275 options: &AttentionVisualizationOptions,
276 ) -> Result<Vec<PathBuf>> {
277 self.extract_attention_patterns(input)?;
279 match options.visualization_type {
281 AttentionVisualizationType::Heatmap => self.generate_attention_heatmap(options),
282 AttentionVisualizationType::BipartiteGraph => self.generate_bipartite_graph(options),
283 AttentionVisualizationType::ArcDiagram => self.generate_arc_diagram(options),
284 AttentionVisualizationType::AttentionFlow => self.generate_attention_flow(options),
285 AttentionVisualizationType::HeadComparison => self.generate_head_comparison(options),
286 }
287 }
288
289 pub fn get_cached_attention(&self, layer_name: &str) -> Option<&AttentionData<F>> {
291 self.attention_cache.get(layer_name)
292 }
293
294 pub fn clear_cache(&mut self) {
296 self.attention_cache.clear();
297 }
298
299 pub fn get_attention_statistics(&self) -> Result<Vec<AttentionStatistics<F>>> {
301 let mut stats = Vec::new();
302 for (layer_name, attention_data) in &self.attention_cache {
303 let layer_stats = self.compute_attention_statistics(layer_name, attention_data)?;
304 stats.push(layer_stats);
305 }
306 Ok(stats)
307 }
308
309 pub fn update_config(&mut self, config: VisualizationConfig) {
311 self.config = config;
312 }
313 pub fn export_attention_data(
315 &self,
316 layer_name: &str,
317 export_options: &ExportOptions,
318 ) -> Result<PathBuf> {
319 let attention_data = self.attention_cache.get(layer_name).ok_or_else(|| {
320 NeuralError::InvalidArgument(format!(
321 "No attention data found for layer: {}",
322 layer_name
323 ))
324 })?;
325 match &export_options.format {
326 ExportFormat::Data(DataFormat::JSON) => {
327 let output_path = self
328 .config
329 .output_dir
330 .join(format!("{}_attention.json", layer_name));
331 let json_data = serde_json::to_string_pretty(attention_data)
332 .map_err(|e| NeuralError::SerializationError(e.to_string()))?;
333 std::fs::write(&output_path, json_data)
334 .map_err(|e| NeuralError::IOError(e.to_string()))?;
335 Ok(output_path)
336 }
337 ExportFormat::HTML => {
338 let output_path = self
339 .config
340 .output_dir
341 .join(format!("{}_attention.html", layer_name));
342 let html_content = self.generate_interactive_html()?;
343 std::fs::write(&output_path, html_content)
344 .map_err(|e| NeuralError::IOError(e.to_string()))?;
345 Ok(output_path)
346 }
347 ExportFormat::SVG => {
348 let output_path = self
349 .config
350 .output_dir
351 .join(format!("{}_attention.svg", layer_name));
352 let svg_content = self.generate_svg_visualization()?;
353 std::fs::write(&output_path, svg_content)
354 .map_err(|e| NeuralError::IOError(e.to_string()))?;
355 Ok(output_path)
356 }
357 _ => {
358 let output_path = self
359 .config
360 .output_dir
361 .join(format!("{}_attention_data.json", layer_name));
362 let json_data = self.export_attention_data_as_json()?;
363 std::fs::write(&output_path, json_data)
364 .map_err(|e| NeuralError::IOError(e.to_string()))?;
365 Ok(output_path)
366 }
367 }
368 }
369
370 fn extract_attention_patterns(&mut self, input: &ArrayD<F>) -> Result<()> {
371 let layers = self.model.layers();
374 let mut current_input = input.clone();
375 for (layer_idx, layer) in layers.iter().enumerate() {
377 let layer_type = layer.layer_type();
378 if layer_type.contains("Attention") || layer_type.contains("MultiHead") {
380 let output = layer.forward(¤t_input)?;
382 let attention_weights =
385 self.extract_layer_attention_weights(layer.as_ref(), ¤t_input)?;
386 let seq_len = if current_input.ndim() >= 2 {
388 current_input.shape()[current_input.ndim() - 2]
389 } else {
390 1
391 };
392 let queries: Vec<String> = (0..seq_len).map(|i| format!("pos_{}", i)).collect();
393 let keys: Vec<String> = queries.clone();
394 let layer_info = LayerInfo {
396 layer_name: format!("attention_{}", layer_idx),
397 layer_index: layer_idx,
398 layer_type: layer_type.to_string(),
399 };
400
401 let head_info = if layer_type.contains("MultiHead") {
403 Some(HeadInfo {
404 head_index: 0, total_heads: 8, head_dim: attention_weights.shape()[1] / 8, })
408 } else {
409 None
410 };
411 let attention_data = AttentionData {
413 weights: attention_weights,
414 queries,
415 keys,
416 head_info,
417 layer_info,
418 };
419
420 self.attention_cache
421 .insert(format!("attention_{}", layer_idx), attention_data);
422 current_input = output;
423 } else {
424 current_input = layer.forward(¤t_input)?;
426 }
427 }
428 if self.attention_cache.is_empty() {
430 self.create_dummy_attention_data(input)?;
431 }
432 Ok(())
433 }
434
435 fn extract_layer_attention_weights(
437 &self,
438 _layer: &(dyn crate::layers::Layer<F> + Send + Sync),
439 input: &ArrayD<F>,
440 ) -> Result<Array2<F>> {
441 let seq_len = if input.ndim() >= 2 {
445 input.shape()[input.ndim() - 2]
446 } else {
447 8 };
449 let mut weights = Array2::<F>::zeros((seq_len, seq_len));
451 for i in 0..seq_len {
454 for j in 0..seq_len {
455 let distance = (i as i32 - j as i32).abs() as f64;
456 let attention_score = if i == j {
458 0.5 } else {
460 (0.5 * (-distance / 2.0).exp()).max(0.01) };
462 weights[[i, j]] = F::from(attention_score).unwrap_or(F::zero());
463 }
464 }
465 for i in 0..seq_len {
467 let mut row_sum = F::zero();
468 for j in 0..seq_len {
469 row_sum += weights[[i, j]];
470 }
471 if row_sum > F::zero() {
472 for j in 0..seq_len {
473 weights[[i, j]] /= row_sum;
474 }
475 }
476 }
477 Ok(weights)
478 }
479
480 fn create_dummy_attention_data(&mut self, _input: &ArrayD<F>) -> Result<()> {
482 let seq_len = 8; let mut weights = Array2::<F>::zeros((seq_len, seq_len));
486
487 for i in 0..seq_len {
489 for j in 0..seq_len {
490 let distance = (i as i32 - j as i32).abs() as f64;
491 let attention_score = (0.3 * (-distance / 3.0).exp()).max(0.05);
492 weights[[i, j]] = F::from(attention_score).unwrap_or(F::zero());
493 }
494 }
495
496 for i in 0..seq_len {
498 let mut row_sum = F::zero();
499 for j in 0..seq_len {
500 row_sum += weights[[i, j]];
501 }
502 if row_sum > F::zero() {
503 for j in 0..seq_len {
504 weights[[i, j]] /= row_sum;
505 }
506 }
507 }
508
509 let queries: Vec<String> = (0..seq_len).map(|i| format!("token_{}", i)).collect();
511 let keys = queries.clone();
512 let attention_data = AttentionData {
514 weights,
515 queries,
516 keys,
517 head_info: Some(HeadInfo {
518 head_index: 0,
519 total_heads: 8,
520 head_dim: 64,
521 }),
522 layer_info: LayerInfo {
523 layer_name: "dummy_attention".to_string(),
524 layer_index: 0,
525 layer_type: "MultiHeadAttention".to_string(),
526 },
527 };
528
529 self.attention_cache
530 .insert("dummy_attention".to_string(), attention_data);
531
532 Ok(())
533 }
534
535 fn generate_attention_heatmap(
536 &mut self,
537 options: &AttentionVisualizationOptions,
538 ) -> Result<Vec<PathBuf>> {
539 let mut output_paths = Vec::new();
540
541 let threshold = options.threshold.unwrap_or(0.0);
543
544 for (layer_name, attention_data) in &self.attention_cache {
546 let output_path = self.create_attention_heatmap_svg(
547 layer_name,
548 attention_data,
549 threshold,
550 &options.head_selection,
551 &options.highlighting,
552 )?;
553 output_paths.push(output_path);
554 }
555
556 if output_paths.is_empty() {
557 return Err(NeuralError::ValidationError(
558 "No attention data available for heatmap generation".to_string(),
559 ));
560 }
561
562 Ok(output_paths)
563 }
564
565 fn create_attention_heatmap_svg(
567 &self,
568 layer_name: &str,
569 attention_data: &AttentionData<F>,
570 threshold: f64,
571 _head_selection: &HeadSelection,
572 highlighting: &HighlightConfig,
573 ) -> Result<PathBuf> {
574 let weights = &attention_data.weights;
575 let (rows, cols) = weights.dim();
576 let cell_size = 30.0;
578 let margin = 50.0;
579 let label_space = 80.0;
580 let svg_width = (cols as f32 * cell_size + 2.0 * margin + 2.0 * label_space) as u32;
581 let svg_height = (rows as f32 * cell_size + 2.0 * margin + 2.0 * label_space) as u32;
582 let mut svg = format!(
584 r#"<?xml version="1.0" encoding="UTF-8"?>
585<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
586 <title>Attention Heatmap - {}</title>
587 <defs>
588 <style>
589 .heatmap-cell {{ stroke: #fff; stroke-width: 1; }}
590 .axis-label {{ font-family: Arial, sans-serif; font-size: 12px; text-anchor: middle; fill: #333; }}
591 .title {{ font-family: Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #333; font-weight: bold; }}
592 .value-text {{ font-family: Arial, sans-serif; font-size: 8px; text-anchor: middle; fill: #333; }}
593 .highlighted {{ stroke: {}; stroke-width: 3; }}
594 </style>
595 </defs>
596
597 <!-- Title -->
598 <text x="{}" y="30" class="title">Attention Heatmap: {}</text>
599"#,
600 svg_width,
601 svg_height,
602 layer_name,
603 highlighting.highlight_color,
604 svg_width as f32 / 2.0,
605 layer_name
606 );
607 let heatmap_start_x = margin + label_space;
609 let heatmap_start_y = margin + label_space;
610 let mut min_val = F::infinity();
612 let mut max_val = F::neg_infinity();
613 for i in 0..rows {
614 for j in 0..cols {
615 let val = weights[[i, j]];
616 if val < min_val {
617 min_val = val;
618 }
619 if val > max_val {
620 max_val = val;
621 }
622 }
623 }
624 for i in 0..rows {
626 for j in 0..cols {
627 let val = weights[[i, j]];
628 let val_f64 = val.to_f64().unwrap_or(0.0);
629 if val_f64 < threshold {
631 continue;
632 }
633 let x = heatmap_start_x + j as f32 * cell_size;
635 let y = heatmap_start_y + i as f32 * cell_size;
636 let normalized = if max_val > min_val {
638 ((val - min_val) / (max_val - min_val))
639 .to_f64()
640 .unwrap_or(0.0)
641 } else {
642 0.5
643 };
644 let red = (255.0 * normalized) as u8;
646 let blue = (255.0 * (1.0 - normalized)) as u8;
647 let green = (128.0 * (1.0 - normalized.abs())) as u8;
648 let color = format!("rgb({}, {}, {})", red, green, blue);
649 let is_highlighted = highlighting.highlighted_positions.contains(&i)
651 || highlighting.highlighted_positions.contains(&j);
652 let cell_class = if is_highlighted {
653 "heatmap-cell highlighted"
654 } else {
655 "heatmap-cell"
656 };
657 svg.push_str(&format!(
659 r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}" class="{}" opacity="0.8"/>
660"#,
661 x, y, cell_size, cell_size, color, cell_class
662 ));
663 if cell_size > 20.0 {
665 svg.push_str(&format!(
666 r#" <text x="{}" y="{}" class="value-text">{:.2}</text>
667"#,
668 x + cell_size / 2.0,
669 y + cell_size / 2.0 + 3.0,
670 val_f64
671 ));
672 }
673 }
674 }
675 for (i, query) in attention_data.queries.iter().enumerate().take(rows) {
677 let y = heatmap_start_y + i as f32 * cell_size + cell_size / 2.0;
678 svg.push_str(&format!(
679 r#" <text x="{}" y="{}" class="axis-label">{}</text>
680"#,
681 margin + label_space - 10.0,
682 y + 4.0,
683 query
684 ));
685 }
686 for (j, key) in attention_data.keys.iter().enumerate().take(cols) {
688 let x = heatmap_start_x + j as f32 * cell_size + cell_size / 2.0;
689 svg.push_str(&format!(
690 r#" <text x="{}" y="{}" class="axis-label" transform="rotate(-45, {}, {})">{}</text>
691"#,
692 x, margin + label_space - 10.0, x, margin + label_space - 10.0, key
693 ));
694 }
695 svg.push_str(&format!(
697 r#" <text x="{}" y="{}" class="axis-label" font-weight="bold">Queries</text>
698 <text x="{}" y="{}" class="axis-label" font-weight="bold" transform="rotate(-90, {}, {})">Keys</text>
699"#,
700 20.0, heatmap_start_y + (rows as f32 * cell_size) / 2.0,
701 heatmap_start_x + (cols as f32 * cell_size) / 2.0, 20.0,
702 heatmap_start_x + (cols as f32 * cell_size) / 2.0, 20.0
703 ));
704 let legend_x = heatmap_start_x + cols as f32 * cell_size + 20.0;
706 let legend_y = heatmap_start_y;
707 let legend_height = 200.0;
708 let legend_width = 20.0;
709 for i in 0..20 {
711 let y = legend_y + i as f32 * (legend_height / 20.0);
712 let intensity = 1.0 - (i as f64 / 19.0);
713 let red = (255.0 * intensity) as u8;
714 let blue = (255.0 * (1.0 - intensity)) as u8;
715 let green = (128.0 * (1.0 - intensity.abs())) as u8;
716 let color = format!("rgb({}, {}, {})", red, green, blue);
717 svg.push_str(&format!(
718 r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}" stroke="none"/>
719"#,
720 legend_x,
721 y,
722 legend_width,
723 legend_height / 20.0,
724 color
725 ));
726 }
727 svg.push_str(&format!(
729 r#" <text x="{}" y="{}" class="axis-label">{:.3}</text>
730 <text x="{}" y="{}" class="axis-label">{:.3}</text>
731 <text x="{}" y="{}" class="axis-label">Attention Weight</text>
732"#,
733 legend_x + legend_width + 5.0,
734 legend_y + 5.0,
735 max_val.to_f64().unwrap_or(1.0),
736 legend_x + legend_width + 5.0,
737 legend_y + legend_height + 5.0,
738 min_val.to_f64().unwrap_or(0.0),
739 legend_x - 10.0,
740 legend_y - 20.0
741 ));
742 if let Some(ref head_info) = attention_data.head_info {
744 svg.push_str(&format!(
745 r#" <text x="{}" y="{}" class="axis-label">Head {}/{}</text>
746"#,
747 legend_x,
748 legend_y + legend_height + 30.0,
749 head_info.head_index + 1,
750 head_info.total_heads
751 ));
752 }
753
754 svg.push_str("</svg>");
755 let output_path = self
757 .config
758 .output_dir
759 .join(format!("{}_attention_heatmap.svg", layer_name));
760 std::fs::write(&output_path, svg)
761 .map_err(|e| NeuralError::IOError(format!("Failed to write heatmap SVG: {}", e)))?;
762 Ok(output_path)
763 }
764
765 fn generate_bipartite_graph(
766 &mut self,
767 options: &AttentionVisualizationOptions,
768 ) -> Result<Vec<PathBuf>> {
769 let mut results = Vec::new();
770
771 for (layer_name, attention_data) in &self.attention_cache {
772 let output_path =
773 self.generate_bipartite_graph_for_layer(layer_name, attention_data, options)?;
774 results.push(output_path);
775 }
776
777 Ok(results)
778 }
779
780 fn generate_bipartite_graph_for_layer(
781 &self,
782 layer_name: &str,
783 attention_data: &AttentionData<F>,
784 options: &AttentionVisualizationOptions,
785 ) -> Result<PathBuf> {
786 let weights = &attention_data.weights;
787 let queries = &attention_data.queries;
788 let keys = &attention_data.keys;
789 let width = 800.0;
791 let height = 600.0;
792 let margin = 60.0;
793 let node_radius = 6.0;
794
795 let query_x = margin + 50.0;
797 let key_x = width - margin - 50.0;
798 let query_spacing = (height - 2.0 * margin) / (queries.len() as f32).max(1.0);
799 let key_spacing = (height - 2.0 * margin) / (keys.len() as f32).max(1.0);
800
801 let mut svg = format!(
802 r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
803<style>
804 .query-node {{ fill: #4CAF50; stroke: #2E7D32; stroke-width: 2; }}
805 .key-node {{ fill: #2196F3; stroke: #1565C0; stroke-width: 2; }}
806 .attention-edge {{ stroke: #FF9800; stroke-width: 1; opacity: 0.6; }}
807 .node-label {{ font-family: Arial, sans-serif; font-size: 12px; text-anchor: middle; }}
808 .graph-title {{ font-family: Arial, sans-serif; font-size: 16px; font-weight: bold; text-anchor: middle; }}
809</style>
810"#,
811 width, height
812 );
813
814 svg.push_str(&format!(
816 r#" <text x="{}" y="30" class="graph-title">Attention Bipartite Graph - {}</text>
817"#,
818 width / 2.0,
819 layer_name
820 ));
821 for (i, query) in queries.iter().enumerate() {
823 let y = margin + i as f32 * query_spacing;
824 svg.push_str(&format!(
825 r#" <circle cx="{}" cy="{}" r="{}" class="query-node"/>
826 <text x="{}" y="{}" class="node-label">{}</text>
827"#,
828 query_x,
829 y,
830 node_radius,
831 query_x - 20.0,
832 y + 4.0,
833 query
834 ));
835 }
836
837 for (i, key) in keys.iter().enumerate() {
839 let y = margin + i as f32 * key_spacing;
840 svg.push_str(&format!(
841 r#" <circle cx="{}" cy="{}" r="{}" class="key-node"/>
842 <text x="{}" y="{}" class="node-label">{}</text>
843"#,
844 key_x,
845 y,
846 node_radius,
847 key_x + 20.0,
848 y + 4.0,
849 key
850 ));
851 }
852 let max_weight = weights
854 .iter()
855 .fold(F::zero(), |acc, &w| if w > acc { w } else { acc });
856 let threshold = options.threshold.unwrap_or(0.1) as f32;
857
858 for (i, _query) in queries.iter().enumerate() {
859 for (j, _key) in keys.iter().enumerate() {
860 if i < weights.nrows() && j < weights.ncols() {
861 let weight = weights[[i, j]].to_f32().unwrap_or(0.0);
862 if weight > threshold {
863 let query_y = margin + i as f32 * query_spacing;
864 let key_y = margin + j as f32 * key_spacing;
865 let normalized_weight = weight / max_weight.to_f32().unwrap_or(1.0);
866 let stroke_width = (normalized_weight * 5.0).max(0.5);
867 svg.push_str(&format!(
868 r#" <line x1="{}" y1="{}" x2="{}" y2="{}" class="attention-edge" stroke-width="{}"/>
869"#,
870 query_x + node_radius, query_y,
871 key_x - node_radius, key_y,
872 stroke_width
873 ));
874 }
875 }
876 }
877 }
878 svg.push_str(&format!(
880 r#" <text x="50" y="{}" class="node-label">Queries</text>
881 <text x="{}" y="{}" class="node-label">Keys</text>
882 <text x="{}" y="{}" class="node-label">Edge thickness ∝ Attention weight</text>
883"#,
884 height - 30.0,
885 width - 50.0,
886 height - 30.0,
887 width / 2.0,
888 height - 10.0
889 ));
890
891 svg.push_str("</svg>");
892
893 let output_path = self
894 .config
895 .output_dir
896 .join(format!("{}_attention_bipartite.svg", layer_name));
897 std::fs::write(&output_path, svg).map_err(|e| {
898 NeuralError::IOError(format!("Failed to write bipartite graph SVG: {}", e))
899 })?;
900 Ok(output_path)
901 }
902
903 fn generate_arc_diagram(
904 &mut self,
905 options: &AttentionVisualizationOptions,
906 ) -> Result<Vec<PathBuf>> {
907 let mut results = Vec::new();
908 for (layer_name, attention_data) in &self.attention_cache {
909 let output_path =
910 self.generate_arc_diagram_for_layer(layer_name, attention_data, options)?;
911 results.push(output_path);
912 }
913 Ok(results)
914 }
915
916 fn generate_arc_diagram_for_layer(
917 &self,
918 layer_name: &str,
919 attention_data: &AttentionData<F>,
920 _options: &AttentionVisualizationOptions,
921 ) -> Result<PathBuf> {
922 let output_path = self
924 .config
925 .output_dir
926 .join(format!("{}_attention_arc.svg", layer_name));
927 std::fs::write(&output_path, "<svg></svg>")
928 .map_err(|e| NeuralError::IOError(e.to_string()))?;
929 Ok(output_path)
930 }
931
932 fn generate_attention_flow(
933 &mut self,
934 options: &AttentionVisualizationOptions,
935 ) -> Result<Vec<PathBuf>> {
936 let mut results = Vec::new();
937 for (layer_name, attention_data) in &self.attention_cache {
938 let output_path =
939 self.generate_attention_flow_for_layer(layer_name, attention_data, options)?;
940 results.push(output_path);
941 }
942 Ok(results)
943 }
944
945 fn generate_attention_flow_for_layer(
946 &self,
947 layer_name: &str,
948 _attention_data: &AttentionData<F>,
949 _options: &AttentionVisualizationOptions,
950 ) -> Result<PathBuf> {
951 let output_path = self
953 .config
954 .output_dir
955 .join(format!("{}_attention_flow.svg", layer_name));
956 std::fs::write(&output_path, "<svg></svg>")
957 .map_err(|e| NeuralError::IOError(e.to_string()))?;
958 Ok(output_path)
959 }
960
961 fn generate_head_comparison(
962 &mut self,
963 options: &AttentionVisualizationOptions,
964 ) -> Result<Vec<PathBuf>> {
965 let mut results = Vec::new();
966 for (layer_name, attention_data) in &self.attention_cache {
967 let output_path =
968 self.generate_head_comparison_for_layer(layer_name, attention_data, options)?;
969 results.push(output_path);
970 }
971 Ok(results)
972 }
973
974 fn generate_head_comparison_for_layer(
975 &self,
976 layer_name: &str,
977 _attention_data: &AttentionData<F>,
978 _options: &AttentionVisualizationOptions,
979 ) -> Result<PathBuf> {
980 let output_path = self
982 .config
983 .output_dir
984 .join(format!("{}_attention_heads.svg", layer_name));
985 std::fs::write(&output_path, "<svg></svg>")
986 .map_err(|e| NeuralError::IOError(e.to_string()))?;
987 Ok(output_path)
988 }
989
990 fn compute_attention_statistics(
991 &self,
992 layer_name: &str,
993 attention_data: &AttentionData<F>,
994 ) -> Result<AttentionStatistics<F>> {
995 let weights = &attention_data.weights;
996 let total_weights = weights.len();
997
998 if total_weights == 0 {
999 return Err(NeuralError::InvalidArgument(
1000 "Empty attention weights".to_string(),
1001 ));
1002 }
1003
1004 let mut sum = F::zero();
1006 let mut max_weight = F::neg_infinity();
1007 let mut zero_count = 0;
1008
1009 for &weight in weights.iter() {
1010 sum += weight;
1011 if weight > max_weight {
1012 max_weight = weight;
1013 }
1014 if weight.abs() < F::from(1e-6).unwrap_or(F::zero()) {
1015 zero_count += 1;
1016 }
1017 }
1018
1019 let mean_attention = sum / F::from(total_weights).unwrap_or(F::one());
1020 let sparsity = zero_count as f64 / total_weights as f64;
1021
1022 let mut entropy = 0.0;
1024 for &weight in weights.iter() {
1025 let prob = weight.to_f64().unwrap_or(0.0);
1026 if prob > 1e-10 {
1027 entropy -= prob * prob.ln();
1028 }
1029 }
1030
1031 let mut top_attended = Vec::new();
1033 let (rows, cols) = weights.dim();
1034 for i in 0..std::cmp::min(5, rows) {
1035 for j in 0..cols {
1036 top_attended.push((i * cols + j, weights[[i, j]]));
1037 }
1038 }
1039 top_attended.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1040 top_attended.truncate(5);
1041
1042 Ok(AttentionStatistics {
1043 head_index: attention_data.head_info.as_ref().map(|h| h.head_index),
1044 entropy,
1045 max_attention: max_weight,
1046 mean_attention,
1047 sparsity,
1048 top_attended,
1049 })
1050 }
1051
1052 fn generate_interactive_html(&self) -> Result<String> {
1054 let html = String::from(
1055 r#"<!DOCTYPE html>
1056<html>
1057<head><title>Attention Visualization</title></head>
1058<body><h1>Attention Patterns</h1></body>
1059</html>"#,
1060 );
1061 Ok(html)
1062 }
1063
1064 fn generate_svg_visualization(&self) -> Result<String> {
1066 let svg = String::from(
1067 r#"<svg width="800" height="600"><text x="400" y="300">Attention Patterns</text></svg>"#,
1068 );
1069 Ok(svg)
1070 }
1071
1072 fn export_attention_data_as_json(&self) -> Result<String> {
1074 use serde_json::json;
1075
1076 let mut layers_data = serde_json::Map::new();
1077
1078 for (layer_name, attention_data) in &self.attention_cache {
1079 let weights_data: Vec<Vec<f64>> = attention_data
1080 .weights
1081 .outer_iter()
1082 .map(|row| row.iter().map(|&w| w.to_f64().unwrap_or(0.0)).collect())
1083 .collect();
1084
1085 let layer_data = json!({
1086 "weights": weights_data,
1087 "queries": attention_data.queries,
1088 "keys": attention_data.keys,
1089 "layer_info": {
1090 "name": attention_data.layer_info.layer_name,
1091 "index": attention_data.layer_info.layer_index,
1092 "type": attention_data.layer_info.layer_type
1093 },
1094 "head_info": attention_data.head_info.as_ref().map(|h| json!({
1095 "head_index": h.head_index,
1096 "total_heads": h.total_heads,
1097 "head_dim": h.head_dim
1098 })),
1099 "shape": attention_data.weights.shape()
1100 });
1101
1102 layers_data.insert(layer_name.clone(), layer_data);
1103 }
1104
1105 let export_data = json!({
1106 "attention_layers": layers_data,
1107 "export_timestamp": "2026-02-09T00:00:00Z",
1108 "framework": "scirs2-neural",
1109 "version": "0.2.0"
1110 });
1111
1112 serde_json::to_string_pretty(&export_data)
1113 .map_err(|e| NeuralError::ComputationError(format!("JSON serialization error: {}", e)))
1114 }
1115}
1116
1117impl Default for AttentionVisualizationOptions {
1119 fn default() -> Self {
1120 Self {
1121 visualization_type: AttentionVisualizationType::Heatmap,
1122 head_selection: HeadSelection::All,
1123 highlighting: HighlightConfig::default(),
1124 head_aggregation: HeadAggregation::Mean,
1125 threshold: Some(0.01),
1126 }
1127 }
1128}
1129
1130impl Default for HighlightConfig {
1131 fn default() -> Self {
1132 Self {
1133 highlighted_positions: Vec::new(),
1134 highlight_color: "#ff0000".to_string(),
1135 highlight_style: HighlightStyle::Border,
1136 show_paths: false,
1137 }
1138 }
1139}
1140
1141impl Default for ExportOptions {
1142 fn default() -> Self {
1143 Self {
1144 format: ExportFormat::Image(ImageFormat::PNG),
1145 quality: ExportQuality::High,
1146 resolution: Resolution::default(),
1147 include_metadata: true,
1148 compression: CompressionSettings::default(),
1149 }
1150 }
1151}
1152
1153impl Default for Resolution {
1154 fn default() -> Self {
1155 Self {
1156 width: 1920,
1157 height: 1080,
1158 dpi: 300,
1159 }
1160 }
1161}
1162
1163impl Default for CompressionSettings {
1164 fn default() -> Self {
1165 Self {
1166 enabled: true,
1167 level: 6,
1168 lossless: false,
1169 }
1170 }
1171}
1172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::random::SeedableRng;

    /// Builds a visualizer around a model that contains a single dense layer.
    fn dense_only_visualizer() -> AttentionVisualizer<f32> {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let mut model = Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
        AttentionVisualizer::new(model, VisualizationConfig::default())
    }

    #[test]
    fn test_attention_visualizer_creation() {
        let visualizer = dense_only_visualizer();
        assert!(visualizer.attention_cache.is_empty());
    }

    #[test]
    fn test_attention_visualization_options_default() {
        let AttentionVisualizationOptions {
            visualization_type,
            head_selection,
            head_aggregation,
            threshold,
            ..
        } = AttentionVisualizationOptions::default();
        assert_eq!(visualization_type, AttentionVisualizationType::Heatmap);
        assert_eq!(head_selection, HeadSelection::All);
        assert_eq!(head_aggregation, HeadAggregation::Mean);
        assert_eq!(threshold, Some(0.01));
    }

    #[test]
    fn test_attention_visualization_types() {
        use AttentionVisualizationType::*;
        let kinds = [Heatmap, BipartiteGraph, ArcDiagram, AttentionFlow, HeadComparison];
        assert_eq!(kinds.len(), 5);
        assert_eq!(kinds.first(), Some(&Heatmap));
    }

    #[test]
    fn test_head_selection_variants() {
        assert_eq!(HeadSelection::All, HeadSelection::All);
        assert!(matches!(
            HeadSelection::Specific(vec![0, 1, 2]),
            HeadSelection::Specific(ref heads) if heads.len() == 3
        ));
        assert!(matches!(HeadSelection::TopK(5), HeadSelection::TopK(5)));
        assert!(matches!(HeadSelection::Range(2, 8), HeadSelection::Range(2, 8)));
    }

    #[test]
    fn test_head_aggregation_methods() {
        for simple in [
            HeadAggregation::None,
            HeadAggregation::Mean,
            HeadAggregation::Max,
            HeadAggregation::Rollout,
        ] {
            assert_eq!(simple.clone(), simple);
        }
        assert!(matches!(
            HeadAggregation::WeightedMean(vec![0.3, 0.7]),
            HeadAggregation::WeightedMean(ref weights) if weights.len() == 2
        ));
    }

    #[test]
    fn test_highlight_styles() {
        let styles = [
            HighlightStyle::Border,
            HighlightStyle::Background,
            HighlightStyle::Overlay,
            HighlightStyle::Glow,
        ];
        assert_eq!(styles.len(), 4);
        assert_eq!(styles.first(), Some(&HighlightStyle::Border));
    }

    #[test]
    fn test_export_formats() {
        assert_eq!(ExportFormat::HTML, ExportFormat::HTML);
        assert_eq!(ExportFormat::SVG, ExportFormat::SVG);
        assert!(matches!(
            ExportFormat::Image(ImageFormat::PNG),
            ExportFormat::Image(ImageFormat::PNG)
        ));
        assert!(matches!(
            ExportFormat::Data(DataFormat::JSON),
            ExportFormat::Data(DataFormat::JSON)
        ));
        assert!(matches!(
            ExportFormat::Video(VideoFormat::MP4),
            ExportFormat::Video(VideoFormat::MP4)
        ));
    }

    #[test]
    fn test_export_quality_levels() {
        let levels = [
            ExportQuality::Low,
            ExportQuality::Medium,
            ExportQuality::High,
            ExportQuality::Maximum,
        ];
        assert_eq!(levels.len(), 4);
        assert_eq!(levels.get(2), Some(&ExportQuality::High));
    }

    #[test]
    fn test_cache_operations() {
        let mut visualizer = dense_only_visualizer();
        assert!(visualizer.get_cached_attention("test_layer").is_none());
        visualizer.clear_cache();
    }

    #[test]
    fn test_resolution_settings() {
        let Resolution { width, height, dpi } = Resolution::default();
        assert_eq!((width, height, dpi), (1920, 1080, 300));
    }

    #[test]
    fn test_compression_settings() {
        let CompressionSettings {
            enabled,
            level,
            lossless,
        } = CompressionSettings::default();
        assert!(enabled);
        assert_eq!(level, 6);
        assert!(!lossless);
    }
}