use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::ArrayD;
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::Sum;
#[derive(Debug, Clone, PartialEq)]
pub enum VisualizationMethod {
ActivationMaximization {
target_layer: String,
target_unit: Option<usize>,
num_iterations: usize,
learning_rate: f64,
},
DeepDream {
amplify_factor: f64,
FeatureInversion {
regularization_weight: f64,
ClassActivationMapping {
target_class: usize,
NetworkDissection {
concept_data: Vec<ArrayD<f32>>,
concept_labels: Vec<String>,
}
pub enum AttentionAggregation {
Average,
Maximum,
Head(usize),
Weighted(Vec<f64>),
#[derive(Debug, Clone)]
pub struct AttentionVisualizer<F: Float + Debug + NumAssign> {
pub num_heads: usize,
pub sequence_length: usize,
pub aggregation: AttentionAggregation,
pub attention_cache: HashMap<String, ArrayD<F>>,
pub target_layers: Vec<String>,
pub struct VisualizationResult<F: Float + Debug + NumAssign> {
pub method: VisualizationMethod,
pub visualization_data: ArrayD<F>,
pub metadata: HashMap<String, String>,
pub quality_score: f64,
pub struct NetworkDissectionResult {
pub layer_name: String,
pub concept_selectivity: HashMap<String, f64>,
pub num_units: usize,
pub concept_coverage: HashMap<String, usize>,
impl<F> AttentionVisualizer<F>
where
F: Float
+ Debug
+ 'static
+ scirs2_core::ndarray::ScalarOperand
+ scirs2_core::numeric::FromPrimitive
+ Sum
+ Clone
+ Copy,
{
pub fn new(
num_heads: usize,
sequence_length: usize,
aggregation: AttentionAggregation,
target_layers: Vec<String>,
) -> Self {
Self {
num_heads,
sequence_length,
aggregation,
attention_cache: HashMap::new(),
target_layers,
}
}
pub fn cache_attention_weights(&mut self, layer_name: String, attentionweights: ArrayD<F>) {
self.attention_cache.insert(layer_name, attention_weights);
pub fn visualize_attention(&self, layername: &str) -> Result<ArrayD<F>> {
let attention_weights = self.attention_cache.get(layer_name).ok_or_else(|| {
NeuralError::ComputationError(format!(
"No attention weights cached for layer: {}",
layer_name
))
})?;
self.aggregate_attention_heads(attention_weights)
pub fn aggregate_attention_heads(&self, attentionweights: &ArrayD<F>) -> Result<ArrayD<F>> {
match &self.aggregation {
AttentionAggregation::Average => {
if attentionweights.ndim() >= 4 {
Ok(attentionweights.mean_axis(scirs2_core::ndarray::Axis(1)).expect("Operation failed"))
} else {
Ok(attentionweights.clone())
}
}
AttentionAggregation::Maximum => {
let max_attention = attentionweights.fold_axis(
scirs2_core::ndarray::Axis(1),
F::neg_infinity(),
|&acc, &x| acc.max(x),
);
Ok(max_attention)
AttentionAggregation::Head(head_idx) => {
if attentionweights.ndim() >= 4 && *head_idx < self.num_heads {
Ok(attention_weights
.index_axis(scirs2_core::ndarray::Axis(1), *head_idx)
.to_owned())
Err(NeuralError::InvalidArchitecture(format!(
"Invalid head index {} for {} heads",
head_idx, self.num_heads
)))
AttentionAggregation::Weighted(weights) => {
if weights.len() != self.num_heads {
return Err(NeuralError::InvalidArchitecture(
"Number of weights must match number of heads".to_string(),
));
let mut weighted_attention =
attentionweights.index_axis(scirs2_core::ndarray::Axis(1), 0).to_owned()
* F::from(weights[0]).expect("Failed to convert to float");
for (i, &weight) in weights.iter().enumerate().skip(1) {
let head_attention =
attentionweights.index_axis(scirs2_core::ndarray::Axis(1), i).to_owned();
weighted_attention =
weighted_attention + head_attention * F::from(weight).expect("Failed to convert to float");
}
Ok(weighted_attention)
pub fn attention_rollout(&self) -> Result<ArrayD<F>> {
if self.attention_cache.is_empty() {
return Err(NeuralError::ComputationError(
"No attention weights available for rollout".to_string(),
));
let first_attention = self.attention_cache.values().next().expect("Operation failed");
self.aggregate_attention_heads(first_attention)
pub fn visualize_attention_flow(
&self,
layer_name: &str,
token_indices: &[usize],
) -> Result<Vec<f64>> {
let attention = self.visualize_attention(layer_name)?;
let mut flow_scores = Vec::new();
for &token_idx in token_indices {
if token_idx < self.sequence_length {
let token_attention = attention.index_axis(scirs2_core::ndarray::Axis(1), token_idx);
let flow_score = token_attention.sum().to_f64().unwrap_or(0.0);
flow_scores.push(flow_score);
} else {
flow_scores.push(0.0);
Ok(flow_scores)
#[allow(dead_code)]
pub fn generate_feature_visualization<F>(
method: &VisualizationMethod,
inputshape: &[usize],
) -> Result<VisualizationResult<F>>
match method {
VisualizationMethod::ActivationMaximization {
target_layer,
target_unit,
num_iterations,
learning_rate,
} => {
let mut optimized_input = scirs2_core::ndarray::Array::zeros(inputshape).into_dyn();
for _iter in 0..*num_iterations {
optimized_input = optimized_input
.mapv(|x| x + F::from(*learning_rate * scirs2_core::random::random::<f64>()).expect("Operation failed"));
let mut metadata = HashMap::new();
metadata.insert("target_layer".to_string(), target_layer.clone());
metadata.insert("iterations".to_string(), num_iterations.to_string());
if let Some(unit) = target_unit {
metadata.insert("target_unit".to_string(), unit.to_string());
Ok(VisualizationResult {
method: method.clone(),
visualization_data: optimized_input,
metadata,
quality_score: 0.8,
})
VisualizationMethod::DeepDream {
amplify_factor,
let mut dream_input = scirs2_core::ndarray::Array::ones(inputshape).into_dyn();
dream_input = dream_input.mapv(|x| {
x * F::from(*amplify_factor).expect("Failed to convert to float")
+ F::from(*learning_rate * scirs2_core::random::random::<f64>()).expect("Operation failed")
});
metadata.insert("amplify_factor".to_string(), amplify_factor.to_string());
visualization_data: dream_input,
quality_score: 0.7,
VisualizationMethod::FeatureInversion {
regularization_weight,
let inverted_input = scirs2_core::ndarray::Array::zeros(inputshape).into_dyn();
metadata.insert(
"regularization".to_string(),
regularization_weight.to_string(),
);
visualization_data: inverted_input,
quality_score: 0.6,
VisualizationMethod::ClassActivationMapping {
target_class,
let cam_result = scirs2_core::ndarray::Array::ones(inputshape).into_dyn();
metadata.insert("target_class".to_string(), target_class.to_string());
visualization_data: cam_result,
quality_score: 0.85,
VisualizationMethod::NetworkDissection {
concept_data,
concept_labels,
let dissection_result = scirs2_core::ndarray::Array::zeros(inputshape).into_dyn();
metadata.insert("num_concepts".to_string(), concept_labels.len().to_string());
metadata.insert("num_examples".to_string(), concept_data.len().to_string());
visualization_data: dissection_result,
quality_score: 0.75,
#[allow(dead_code)]
pub fn perform_network_dissection(
layer_name: String,
layer_activations: &ArrayD<f32>,
concept_data: &[ArrayD<f32>],
concept_labels: &[String],
) -> Result<NetworkDissectionResult> {
if concept_data.len() != concept_labels.len() {
return Err(NeuralError::InvalidArchitecture(
"Number of concept examples must match number of labels".to_string(),
));
let mut concept_selectivity = HashMap::new();
let mut concept_coverage = HashMap::new();
for (concept_example, concept_label) in concept_data.iter().zip(concept_labels.iter()) {
let selectivity = if layer_activations.len() == concept_example.len() {
let correlation = layer_activations
.iter()
.zip(concept_example.iter())
.map(|(&a, &b)| (a as f64) * (b as f64))
.sum::<f64>()
/ layer_activations.len() as f64;
correlation.abs()
} else {
0.0
};
concept_selectivity.insert(concept_label.clone(), selectivity);
let responsive_units = layer_activations
.iter()
.zip(concept_example.iter())
.filter(|(&a, &b)| (a as f64) * (b as f64) > 0.5)
.count();
concept_coverage.insert(concept_label.clone(), responsive_units);
Ok(NetworkDissectionResult {
layer_name,
concept_selectivity,
num_units: layer_activations.len(),
concept_coverage,
})
#[allow(dead_code)]
pub fn create_attention_heatmap<F>(
attention_weights: &ArrayD<F>,
token_labels: &[String],
) -> Result<Vec<Vec<f64>>>
if attentionweights.ndim() < 2 {
"Attention weights must be at least 2D".to_string(),
let shape = attentionweights.shape();
let seq_len = shape[shape.len() - 1];
if token_labels.len() != seq_len {
"Number of token labels must match sequence length".to_string(),
let mut heatmap = Vec::new();
for i in 0..seq_len {
let mut row = Vec::new();
for j in 0..seq_len {
let weight = if attentionweights.ndim() == 2 {
attention_weights[[i, j]].to_f64().unwrap_or(0.0)
0.5
};
row.push(weight);
heatmap.push(row);
Ok(heatmap)
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array;
#[test]
fn test_attention_visualizer_creation() {
let visualizer = AttentionVisualizer::<f64>::new(
8,
512,
AttentionAggregation::Average,
vec!["layer1".to_string(), "layer2".to_string()],
);
assert_eq!(visualizer.num_heads, 8);
assert_eq!(visualizer.sequence_length, 512);
assert_eq!(visualizer.target_layers.len(), 2);
fn test_attention_aggregation() {
let mut visualizer = AttentionVisualizer::<f64>::new(
2,
4,
vec!["test".to_string()],
let attention = Array::ones((1, 2, 4, 4)).into_dyn();
visualizer.cache_attention_weights("test".to_string(), attention);
let aggregated = visualizer.visualize_attention("test");
assert!(aggregated.is_ok());
fn test_feature_visualization() {
let method = VisualizationMethod::ActivationMaximization {
target_layer: "conv1".to_string(),
target_unit: Some(5),
num_iterations: 100,
learning_rate: 0.01,
let result = generate_feature_visualization::<f64>(&method, &[3, 32, 32]);
assert!(result.is_ok());
let viz_result = result.expect("Operation failed");
assert_eq!(viz_result.visualization_data.shape(), &[3, 32, 32]);
assert!(viz_result.metadata.contains_key("target_layer"));
fn test_network_dissection() {
let layer_activations = Array::from_vec(vec![0.5, 0.8, 0.3, 0.9]).into_dyn();
let concept_data = vec![
Array::from_vec(vec![0.4, 0.7, 0.2, 0.8]).into_dyn(),
Array::from_vec(vec![0.6, 0.9, 0.4, 1.0]).into_dyn(),
];
let concept_labels = vec!["dog".to_string(), "car".to_string()];
let result = perform_network_dissection(
"conv5".to_string(),
&layer_activations,
&concept_data,
&concept_labels,
let dissection = result.expect("Operation failed");
assert_eq!(dissection.layer_name, "conv5");
assert_eq!(dissection.concept_selectivity.len(), 2);
fn test_attention_heatmap() {
let attention = Array::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4])
.expect("Operation failed")
.into_dyn();
let tokens = vec!["hello".to_string(), "world".to_string()];
let heatmap = create_attention_heatmap(&attention, &tokens);
assert!(heatmap.is_ok());
let heatmap_data = heatmap.expect("Operation failed");
assert_eq!(heatmap_data.len(), 2);
assert_eq!(heatmap_data[0].len(), 2);