1use super::config::VisualizationConfig;
7use crate::error::{NeuralError, Result};
8use crate::models::sequential::Sequential;
9use scirs2_core::ndarray::ArrayStatCompat;
10use scirs2_core::ndarray::{ArrayD, ScalarOperand};
11use scirs2_core::numeric::Float;
12use scirs2_core::NumAssign;
13use serde::Serialize;
14use statrs::statistics::Statistics;
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::path::PathBuf;
18#[allow(dead_code)]
20pub struct ActivationVisualizer<F: Float + Debug + ScalarOperand + NumAssign> {
21 model: Sequential<F>,
23 config: VisualizationConfig,
25 activation_cache: HashMap<String, ArrayD<F>>,
27}
28#[derive(Debug, Clone, Serialize)]
30pub struct ActivationVisualizationOptions {
31 pub target_layers: Vec<String>,
33 pub visualization_type: ActivationVisualizationType,
35 pub normalization: ActivationNormalization,
37 pub colormap: Colormap,
39 pub aggregation: ChannelAggregation,
41}
42
43#[derive(Debug, Clone, PartialEq, Serialize)]
45pub enum ActivationVisualizationType {
46 FeatureMaps,
48 Histograms,
50 Statistics,
52 AttentionMaps,
54 ActivationFlow,
56}
57
58#[derive(Debug, Clone, PartialEq, Serialize)]
60pub enum ActivationNormalization {
61 None,
63 MinMax,
65 ZScore,
67 Percentile(f64, f64),
69 Custom(String),
71}
72
73#[derive(Debug, Clone, PartialEq, Serialize)]
75pub enum Colormap {
76 Viridis,
78 Plasma,
80 Inferno,
82 Jet,
84 Gray,
86 RdBu,
88 Custom(Vec<String>),
90}
91
92#[derive(Debug, Clone, PartialEq, Serialize)]
94pub enum ChannelAggregation {
95 None,
97 Mean,
99 Max,
101 Min,
103 Std,
105 Select(Vec<usize>),
107}
108
109#[derive(Debug, Clone, Serialize)]
111pub struct ActivationStatistics<F: Float + Debug + serde::Serialize + NumAssign> {
112 pub layer_name: String,
114 pub mean: F,
116 pub std: F,
118 pub min: F,
120 pub max: F,
122 pub percentiles: [F; 5],
124 pub sparsity: f64,
126 pub dead_neurons: usize,
128 pub total_neurons: usize,
130}
131
132#[derive(Debug, Clone, Serialize)]
134pub struct FeatureMapInfo {
135 pub feature_index: usize,
137 pub spatial_dims: (usize, usize),
139 pub channels: usize,
141 pub activation_range: (f64, f64),
143}
144
145#[derive(Debug, Clone, Serialize)]
147pub struct ActivationHistogram<F: Float + Debug + NumAssign> {
148 pub layer_name: String,
150 pub bins: Vec<F>,
152 pub counts: Vec<usize>,
154 pub edges: Vec<F>,
156 pub total_samples: usize,
158}
159
160impl<
162 F: Float
163 + Debug
164 + std::fmt::Display
165 + 'static
166 + scirs2_core::numeric::FromPrimitive
167 + ScalarOperand
168 + Send
169 + Sync
170 + serde::Serialize
171 + NumAssign,
172 > ActivationVisualizer<F>
173{
174 pub fn new(model: Sequential<F>, config: VisualizationConfig) -> Self {
176 Self {
177 model,
178 config,
179 activation_cache: HashMap::new(),
180 }
181 }
182 pub fn visualize_activations(
184 &mut self,
185 input: &ArrayD<F>,
186 options: &ActivationVisualizationOptions,
187 ) -> Result<Vec<PathBuf>> {
188 self.compute_activations(input, &options.target_layers)?;
190 match options.visualization_type {
192 ActivationVisualizationType::FeatureMaps => self.generate_feature_maps(options),
193 ActivationVisualizationType::Histograms => self.generate_histograms(options),
194 ActivationVisualizationType::Statistics => self.generate_statistics(options),
195 ActivationVisualizationType::AttentionMaps => self.generate_attention_maps(options),
196 ActivationVisualizationType::ActivationFlow => self.generate_activation_flow(options),
197 }
198 }
199
200 pub fn get_cached_activations(&self, layer_name: &str) -> Option<&ArrayD<F>> {
202 self.activation_cache.get(layer_name)
203 }
204
205 pub fn clear_cache(&mut self) {
207 self.activation_cache.clear();
208 }
209
210 pub fn get_activation_statistics(&self) -> Result<Vec<ActivationStatistics<F>>> {
212 let mut stats = Vec::new();
213 for (layer_name, activations) in &self.activation_cache {
214 let layer_stats = self.compute_layer_statistics(layer_name, activations)?;
215 stats.push(layer_stats);
216 }
217 Ok(stats)
218 }
219
220 pub fn update_config(&mut self, config: VisualizationConfig) {
222 self.config = config;
223 }
224 fn compute_activations(&mut self, input: &ArrayD<F>, target_layers: &[String]) -> Result<()> {
225 let mut current_output = input.clone();
226 if target_layers.is_empty() || target_layers.contains(&"input".to_string()) {
228 self.activation_cache
229 .insert("input".to_string(), input.clone());
230 }
231 for (layer_idx, layer) in self.model.layers().iter().enumerate() {
233 current_output = layer.forward(¤t_output)?;
234 let layer_name = format!("layer_{}", layer_idx);
235 if target_layers.is_empty() || target_layers.contains(&layer_name) {
237 self.activation_cache
238 .insert(layer_name, current_output.clone());
239 }
240 }
241 Ok(())
242 }
243 fn generate_feature_maps(
244 &self,
245 options: &ActivationVisualizationOptions,
246 ) -> Result<Vec<PathBuf>> {
247 let mut output_paths = Vec::new();
248 for layer_name in &options.target_layers {
249 if let Some(activations) = self.activation_cache.get(layer_name) {
250 let feature_maps = self.process_activations_for_visualization(
251 activations,
252 &options.normalization,
253 &options.aggregation,
254 )?;
255 let svg_content =
257 self.create_feature_map_svg(&feature_maps, layer_name, &options.colormap)?;
258 let output_path = self
259 .config
260 .output_dir
261 .join(format!("{}_feature_maps.svg", layer_name));
262 std::fs::write(&output_path, svg_content).map_err(|e| {
263 NeuralError::IOError(format!("Failed to write feature map: {}", e))
264 })?;
265 output_paths.push(output_path);
266 }
267 }
268 Ok(output_paths)
269 }
270 fn generate_histograms(
271 &self,
272 options: &ActivationVisualizationOptions,
273 ) -> Result<Vec<PathBuf>> {
274 let mut output_paths = Vec::new();
275 for layer_name in &options.target_layers {
276 if let Some(activations) = self.activation_cache.get(layer_name) {
277 let histogram = self.compute_activation_histogram(layer_name, activations, 50)?;
278 let svg_content = self.create_histogram_svg(&histogram)?;
280 let output_path = self
281 .config
282 .output_dir
283 .join(format!("{}_histogram.svg", layer_name));
284 std::fs::write(&output_path, svg_content).map_err(|e| {
285 NeuralError::IOError(format!("Failed to write histogram: {}", e))
286 })?;
287 output_paths.push(output_path);
288 }
289 }
290 Ok(output_paths)
291 }
292 fn generate_statistics(
293 &self,
294 options: &ActivationVisualizationOptions,
295 ) -> Result<Vec<PathBuf>> {
296 let mut all_stats = Vec::new();
297 for layer_name in &options.target_layers {
298 if let Some(activations) = self.activation_cache.get(layer_name) {
299 let stats = self.compute_layer_statistics(layer_name, activations)?;
300 all_stats.push(stats);
301 }
302 }
303 let json_content = serde_json::to_string_pretty(&all_stats).map_err(|e| {
305 NeuralError::SerializationError(format!("Failed to serialize statistics: {}", e))
306 })?;
307 let json_path = self.config.output_dir.join("activation_statistics.json");
308 std::fs::write(&json_path, json_content)
309 .map_err(|e| NeuralError::IOError(format!("Failed to write statistics: {}", e)))?;
310 let svg_content = self.create_statistics_svg(&all_stats)?;
312 let svg_path = self.config.output_dir.join("activation_statistics.svg");
313 std::fs::write(&svg_path, svg_content).map_err(|e| {
314 NeuralError::IOError(format!("Failed to write statistics visualization: {}", e))
315 })?;
316 Ok(vec![json_path, svg_path])
317 }
318 fn generate_attention_maps(
319 &self,
320 options: &ActivationVisualizationOptions,
321 ) -> Result<Vec<PathBuf>> {
322 let mut output_paths = Vec::new();
323 for layer_name in &options.target_layers {
324 if let Some(activations) = self.activation_cache.get(layer_name) {
325 if activations.ndim() >= 3 {
327 let attention_map = self.compute_spatial_attention(activations)?;
328 let svg_content = self.create_attention_map_svg(&attention_map, layer_name)?;
329 let output_path = self
330 .config
331 .output_dir
332 .join(format!("{}_attention.svg", layer_name));
333 std::fs::write(&output_path, svg_content).map_err(|e| {
334 NeuralError::IOError(format!("Failed to write attention map: {}", e))
335 })?;
336 output_paths.push(output_path);
337 }
338 }
339 }
340 Ok(output_paths)
341 }
342 fn generate_activation_flow(
343 &self,
344 options: &ActivationVisualizationOptions,
345 ) -> Result<Vec<PathBuf>> {
346 let mut flow_data = Vec::new();
348 let sorted_layers: Vec<_> = options.target_layers.iter().collect();
349 for i in 0..sorted_layers.len().saturating_sub(1) {
350 let from_layer = sorted_layers[i];
351 let to_layer = sorted_layers[i + 1];
352 if let (Some(from_activations), Some(to_activations)) = (
353 self.activation_cache.get(from_layer),
354 self.activation_cache.get(to_layer),
355 ) {
356 let flow_intensity =
357 self.compute_activation_flow(from_activations, to_activations)?;
358 flow_data.push((from_layer.clone(), to_layer.clone(), flow_intensity));
359 }
360 }
361 if !flow_data.is_empty() {
362 let svg_content = self.create_flow_diagram_svg(&flow_data)?;
363 let output_path = self.config.output_dir.join("activation_flow.svg");
364 std::fs::write(&output_path, svg_content).map_err(|e| {
365 NeuralError::IOError(format!("Failed to write activation flow: {}", e))
366 })?;
367 Ok(vec![output_path])
368 } else {
369 Ok(Vec::new())
370 }
371 }
372 fn compute_layer_statistics(
373 &self,
374 layer_name: &str,
375 activations: &ArrayD<F>,
376 ) -> Result<ActivationStatistics<F>> {
377 let total_elements = activations.len();
378 if total_elements == 0 {
379 return Err(NeuralError::InvalidArgument(
380 "Empty activation tensor".to_string(),
381 ));
382 }
383 let mut sum = F::zero();
385 let mut min_val = F::infinity();
386 let mut max_val = F::neg_infinity();
387 let mut zero_count = 0;
388 for &val in activations.iter() {
389 sum += val;
390 if val < min_val {
391 min_val = val;
392 }
393 if val > max_val {
394 max_val = val;
395 }
396 if val.abs() < F::from(1e-6).unwrap_or(F::zero()) {
397 zero_count += 1;
398 }
399 }
400 let mean = sum / F::from(total_elements).unwrap_or(F::one());
401 let mut variance_sum = F::zero();
403 for &val in activations.iter() {
404 let diff = val - mean;
405 variance_sum += diff * diff;
406 }
407 let variance = variance_sum / F::from(total_elements - 1).unwrap_or(F::one());
408 let std = variance.sqrt();
409 let mut sorted_values: Vec<F> = activations.iter().copied().collect();
411 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
412 let percentiles = [
413 sorted_values[total_elements * 5 / 100], sorted_values[total_elements * 25 / 100], sorted_values[total_elements * 50 / 100], sorted_values[total_elements * 75 / 100], sorted_values[total_elements * 95 / 100], ];
419 let sparsity = zero_count as f64 / total_elements as f64;
420 Ok(ActivationStatistics {
421 layer_name: layer_name.to_string(),
422 mean,
423 std,
424 min: min_val,
425 max: max_val,
426 percentiles,
427 sparsity,
428 dead_neurons: zero_count,
429 total_neurons: total_elements,
430 })
431 }
432 fn process_activations_for_visualization(
433 &self,
434 activations: &ArrayD<F>,
435 normalization: &ActivationNormalization,
436 aggregation: &ChannelAggregation,
437 ) -> Result<ArrayD<F>> {
438 let mut processed = activations.clone();
439 processed = match aggregation {
441 ChannelAggregation::None => processed,
442 ChannelAggregation::Mean => {
443 if processed.ndim() > 2 {
444 let mean_axis = processed.ndim() - 1; processed
446 .mean_axis(scirs2_core::ndarray::Axis(mean_axis))
447 .expect("Operation failed")
448 .insert_axis(scirs2_core::ndarray::Axis(mean_axis))
449 } else {
450 processed
451 }
452 }
453 ChannelAggregation::Max => {
454 if processed.ndim() > 2 {
455 let max_axis = processed.ndim() - 1;
456 let max_values = processed.fold_axis(
457 scirs2_core::ndarray::Axis(max_axis),
458 F::neg_infinity(),
459 |&acc, &x| acc.max(x),
460 );
461 max_values.insert_axis(scirs2_core::ndarray::Axis(max_axis))
462 } else {
463 processed
464 }
465 }
466 ChannelAggregation::Min => {
467 if processed.ndim() > 2 {
468 let min_axis = processed.ndim() - 1;
469 let min_values = processed.fold_axis(
470 scirs2_core::ndarray::Axis(min_axis),
471 F::infinity(),
472 |&acc, &x| acc.min(x),
473 );
474 min_values.insert_axis(scirs2_core::ndarray::Axis(min_axis))
475 } else {
476 processed
477 }
478 }
479 ChannelAggregation::Std => {
480 if processed.ndim() > 2 {
481 let std_axis = processed.ndim() - 1;
482 let mean = processed
483 .mean_axis(scirs2_core::ndarray::Axis(std_axis))
484 .expect("Operation failed");
485 let variance =
486 processed.map_axis(scirs2_core::ndarray::Axis(std_axis), |channel| {
487 let mean_val = mean.iter().next().copied().unwrap_or(F::zero());
488 let variance_sum = channel
489 .iter()
490 .map(|&x| (x - mean_val) * (x - mean_val))
491 .fold(F::zero(), |acc, x| acc + x);
492 (variance_sum / F::from(channel.len()).unwrap_or(F::one())).sqrt()
493 });
494 variance.insert_axis(scirs2_core::ndarray::Axis(std_axis))
495 } else {
496 processed
497 }
498 }
499 ChannelAggregation::Select(channels) => {
500 if processed.ndim() > 2 && !channels.is_empty() {
501 let channel_axis = processed.ndim() - 1;
502 let mut selected_slices = Vec::new();
503 for &channel_idx in channels {
504 if channel_idx < processed.shape()[channel_axis] {
505 let slice = processed
506 .index_axis(scirs2_core::ndarray::Axis(channel_axis), channel_idx);
507 selected_slices
508 .push(slice.insert_axis(scirs2_core::ndarray::Axis(channel_axis)));
509 }
510 }
511 if !selected_slices.is_empty() {
512 scirs2_core::ndarray::concatenate(
513 scirs2_core::ndarray::Axis(channel_axis),
514 &selected_slices.iter().map(|x| x.view()).collect::<Vec<_>>(),
515 )
516 .map_err(|_| {
517 NeuralError::DimensionMismatch(
518 "Failed to concatenate selected channels".to_string(),
519 )
520 })?
521 } else {
522 processed
523 }
524 } else {
525 processed
526 }
527 }
528 };
529 processed = match normalization {
531 ActivationNormalization::None => processed,
532 ActivationNormalization::MinMax => {
533 let min_val = processed.iter().copied().fold(F::infinity(), F::min);
534 let max_val = processed.iter().copied().fold(F::neg_infinity(), F::max);
535 let range = max_val - min_val;
536 if range > F::zero() {
537 processed.mapv(|x| (x - min_val) / range)
538 } else {
539 processed.mapv(|_| F::zero())
540 }
541 }
542 ActivationNormalization::ZScore => {
543 let mean = processed.mean_or(F::zero());
544 let variance = processed
545 .iter()
546 .map(|&x| (x - mean) * (x - mean))
547 .fold(F::zero(), |acc, x| acc + x)
548 / F::from(processed.len()).unwrap_or(F::one());
549 let std = variance.sqrt();
550 if std > F::zero() {
551 processed.mapv(|x| (x - mean) / std)
552 } else {
553 processed.mapv(|_| F::zero())
554 }
555 }
556 ActivationNormalization::Percentile(low, high) => {
557 let mut values: Vec<F> = processed.iter().copied().collect();
558 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
559 let n = values.len();
560 let low_idx = ((low / 100.0) * n as f64) as usize;
561 let high_idx = ((high / 100.0) * n as f64) as usize;
562 if low_idx < n && high_idx < n && low_idx < high_idx {
563 let low_val = values[low_idx];
564 let high_val = values[high_idx];
565 let range = high_val - low_val;
566 if range > F::zero() {
567 processed.mapv(|x| ((x - low_val) / range).max(F::zero()).min(F::one()))
568 } else {
569 processed.mapv(|_| F::zero())
570 }
571 } else {
572 processed
573 }
574 }
575 ActivationNormalization::Custom(_) => {
576 processed
579 }
580 };
581 Ok(processed)
582 }
583 fn create_feature_map_svg(
584 &self,
585 feature_maps: &ArrayD<F>,
586 layer_name: &str,
587 colormap: &Colormap,
588 ) -> Result<String> {
589 let width = self.config.style.layout.width;
590 let height = self.config.style.layout.height;
591 let colors = self.get_colormap_colors(colormap);
593 let mut svg = format!(
594 r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">
595"#,
596 width, height, width, height
597 );
598 svg.push_str(&format!(
600 "<text x=\"{}\" y=\"30\" text-anchor=\"middle\" font-family=\"{}\" font-size=\"{}\" fill=\"#333\">{} Feature Maps</text>\n",
601 width / 2, self.config.style.font.family,
602 (self.config.style.font.size as f32 * self.config.style.font.title_scale) as u32,
603 layer_name
604 ));
605 if feature_maps.ndim() >= 2 {
607 let shape = feature_maps.shape();
608 let map_height = shape[0].min(32); let map_width = shape[1].min(32);
610 let cell_width = (width - 100) / map_width as u32;
611 let cell_height = (height - 100) / map_height as u32;
612 for i in 0..map_height {
613 for j in 0..map_width {
614 let value = if let Some(&val) = feature_maps.get([i, j].as_slice()) {
615 val
616 } else {
617 F::zero()
618 };
619 let intensity = (value.to_f64().unwrap_or(0.0) * 255.0).clamp(0.0, 255.0) as u8;
620 let color = if colors.len() > 1 {
621 let color_idx =
623 (intensity as f64 / 255.0 * (colors.len() - 1) as f64) as usize;
624 colors[color_idx.min(colors.len() - 1)].clone()
625 } else {
626 format!("rgb({},{},{})", intensity, intensity, intensity)
627 };
628 svg.push_str(&format!(
629 "<rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"{}\" stroke=\"#ccc\" stroke-width=\"0.5\"/>\n",
630 50 + j * cell_width as usize,
631 50 + i * cell_height as usize,
632 cell_width,
633 cell_height,
634 color
635 ));
636 }
637 }
638 }
639 svg.push_str("</svg>");
640 Ok(svg)
641 }
642 fn compute_activation_histogram(
643 &self,
644 layer_name: &str,
645 activations: &ArrayD<F>,
646 num_bins: usize,
647 ) -> Result<ActivationHistogram<F>> {
648 let values: Vec<F> = activations.iter().copied().collect();
649 if values.is_empty() {
650 return Err(NeuralError::ValidationError(
651 "Empty activations for histogram".to_string(),
652 ));
653 }
654 let min_val = values.iter().copied().fold(F::infinity(), F::min);
655 let max_val = values.iter().copied().fold(F::neg_infinity(), F::max);
656 let range = max_val - min_val;
657 if range <= F::zero() {
658 return Ok(ActivationHistogram {
660 layer_name: layer_name.to_string(),
661 bins: vec![min_val],
662 counts: vec![values.len()],
663 edges: vec![min_val, max_val],
664 total_samples: values.len(),
665 });
666 }
667 let bin_width = range / F::from(num_bins).unwrap_or(F::one());
668 let mut bins = Vec::with_capacity(num_bins);
669 let mut counts = vec![0; num_bins];
670 let mut edges = Vec::with_capacity(num_bins + 1);
671 for i in 0..=num_bins {
673 edges.push(min_val + F::from(i).unwrap_or(F::zero()) * bin_width);
674 }
675 for i in 0..num_bins {
676 bins.push(
677 min_val
678 + (F::from(i).unwrap_or(F::zero()) + F::from(0.5).unwrap_or(F::zero()))
679 * bin_width,
680 );
681 }
682 for &value in &values {
684 let bin_idx = ((value - min_val) / bin_width)
685 .to_usize()
686 .unwrap_or(0)
687 .min(num_bins - 1);
688 counts[bin_idx] += 1;
689 }
690 Ok(ActivationHistogram {
691 layer_name: layer_name.to_string(),
692 bins,
693 counts,
694 edges,
695 total_samples: values.len(),
696 })
697 }
698
699 fn create_histogram_svg(&self, histogram: &ActivationHistogram<F>) -> Result<String> {
700 let width = self.config.style.layout.width;
701 let height = self.config.style.layout.height;
702 let mut svg = format!(
703 "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
704 <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
705 width, height
706 );
707 svg.push_str("</svg>");
708 Ok(svg)
709 }
710
711 fn create_statistics_svg(&self, stats: &[ActivationStatistics<F>]) -> Result<String> {
712 let width = self.config.style.layout.width;
713 let height = self.config.style.layout.height;
714 let mut svg = format!(
715 "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
716 <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
717 width, height
718 );
719 svg.push_str("</svg>");
720 Ok(svg)
721 }
722
723 fn compute_spatial_attention(&self, activations: &ArrayD<F>) -> Result<ArrayD<F>> {
724 Ok(activations.clone())
726 }
727
728 fn create_attention_map_svg(
729 &self,
730 attention_map: &ArrayD<F>,
731 layer_name: &str,
732 ) -> Result<String> {
733 let width = self.config.style.layout.width;
734 let height = self.config.style.layout.height;
735 let mut svg = format!(
736 "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
737 <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
738 width, height
739 );
740 svg.push_str("</svg>");
741 Ok(svg)
742 }
743
744 fn compute_activation_flow(
745 &self,
746 from_activations: &ArrayD<F>,
747 to_activations: &ArrayD<F>,
748 ) -> Result<f64> {
749 Ok(0.0)
750 }
751
752 fn create_flow_diagram_svg(&self, flow_data: &[(String, String, f64)]) -> Result<String> {
753 let width = self.config.style.layout.width;
754 let height = self.config.style.layout.height;
755 let mut svg = format!(
756 "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
757 <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
758 width, height
759 );
760 svg.push_str(&format!(
761 "<text x=\"{}\" y=\"30\" text-anchor=\"middle\" font-family=\"{}\" font-size=\"{}\" fill=\"#333\">Activation Flow Diagram</text>\n",
762 width / 2, self.config.style.font.family, self.config.style.font.size
763 ));
764 svg.push_str("</svg>");
765 Ok(svg)
766 }
767
768 fn get_colormap_colors(&self, colormap: &Colormap) -> Vec<String> {
769 match colormap {
770 Colormap::Viridis => vec![
771 "#440154".to_string(),
772 "#482677".to_string(),
773 "#3f4a8a".to_string(),
774 "#31678e".to_string(),
775 "#26838f".to_string(),
776 "#1f9d8a".to_string(),
777 "#6cce5a".to_string(),
778 "#b6de2b".to_string(),
779 "#fee825".to_string(),
780 ],
781 Colormap::Plasma => vec![
782 "#0c0786".to_string(),
783 "#40039a".to_string(),
784 "#6a0a83".to_string(),
785 "#8b0aa5".to_string(),
786 "#a83eaf".to_string(),
787 "#c06fad".to_string(),
788 "#d8a1a3".to_string(),
789 "#f0d3a3".to_string(),
790 "#fcffa4".to_string(),
791 ],
792 Colormap::Inferno => vec![
793 "#000003".to_string(),
794 "#1f0c48".to_string(),
795 "#581845".to_string(),
796 "#8b1538".to_string(),
797 "#b71f2b".to_string(),
798 "#db4c26".to_string(),
799 "#ed7953".to_string(),
800 "#fbad76".to_string(),
801 ],
802 Colormap::Jet => vec![
803 "#00007f".to_string(),
804 "#0000ff".to_string(),
805 "#007fff".to_string(),
806 "#00ffff".to_string(),
807 "#7fff00".to_string(),
808 "#ffff00".to_string(),
809 "#ff7f00".to_string(),
810 "#ff0000".to_string(),
811 "#7f0000".to_string(),
812 ],
813 Colormap::Gray => vec![
814 "#000000".to_string(),
815 "#404040".to_string(),
816 "#808080".to_string(),
817 "#c0c0c0".to_string(),
818 "#ffffff".to_string(),
819 ],
820 Colormap::RdBu => vec![
821 "#053061".to_string(),
822 "#2166ac".to_string(),
823 "#4393c3".to_string(),
824 "#92c5de".to_string(),
825 "#d1e5f0".to_string(),
826 "#f7f7f7".to_string(),
827 "#fddbc7".to_string(),
828 "#f4a582".to_string(),
829 "#d6604d".to_string(),
830 "#b2182b".to_string(),
831 "#67001f".to_string(),
832 ],
833 Colormap::Custom(colors) => colors.clone(),
834 }
835 }
836}
837
838impl Default for ActivationVisualizationOptions {
840 fn default() -> Self {
841 Self {
842 target_layers: Vec::new(),
843 visualization_type: ActivationVisualizationType::FeatureMaps,
844 normalization: ActivationNormalization::MinMax,
845 colormap: Colormap::Viridis,
846 aggregation: ChannelAggregation::Mean,
847 }
848 }
849}
850
851impl Default for FeatureMapInfo {
852 fn default() -> Self {
853 Self {
854 feature_index: 0,
855 spatial_dims: (1, 1),
856 channels: 1,
857 activation_range: (0.0, 1.0),
858 }
859 }
860}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::random::SeedableRng;

    /// Builds a visualizer around one seeded 10 -> 5 dense layer with ReLU activation.
    fn single_dense_visualizer() -> ActivationVisualizer<f32> {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let mut model = Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
        ActivationVisualizer::new(model, VisualizationConfig::default())
    }

    #[test]
    fn test_activation_visualizer_creation() {
        let visualizer = single_dense_visualizer();
        assert!(visualizer.activation_cache.is_empty());
    }

    #[test]
    fn test_activation_visualization_options_default() {
        let ActivationVisualizationOptions {
            visualization_type,
            normalization,
            colormap,
            aggregation,
            ..
        } = ActivationVisualizationOptions::default();
        assert_eq!(visualization_type, ActivationVisualizationType::FeatureMaps);
        assert_eq!(normalization, ActivationNormalization::MinMax);
        assert_eq!(colormap, Colormap::Viridis);
        assert_eq!(aggregation, ChannelAggregation::Mean);
    }

    #[test]
    fn test_activation_visualization_types() {
        let kinds = [
            ActivationVisualizationType::FeatureMaps,
            ActivationVisualizationType::Histograms,
            ActivationVisualizationType::Statistics,
            ActivationVisualizationType::AttentionMaps,
            ActivationVisualizationType::ActivationFlow,
        ];
        assert_eq!(kinds.len(), 5);
        assert_eq!(kinds.first(), Some(&ActivationVisualizationType::FeatureMaps));
    }

    #[test]
    fn test_normalization_methods() {
        assert_eq!(ActivationNormalization::None, ActivationNormalization::None);
        assert_eq!(ActivationNormalization::MinMax, ActivationNormalization::MinMax);
        assert_eq!(ActivationNormalization::ZScore, ActivationNormalization::ZScore);

        let percentile = ActivationNormalization::Percentile(5.0, 95.0);
        if let ActivationNormalization::Percentile(low, high) = percentile {
            assert_eq!(low, 5.0);
            assert_eq!(high, 95.0);
        } else {
            unreachable!("Expected percentile normalization");
        }
    }

    #[test]
    fn test_colormaps() {
        let builtin = [
            Colormap::Viridis,
            Colormap::Plasma,
            Colormap::Inferno,
            Colormap::Jet,
            Colormap::Gray,
            Colormap::RdBu,
        ];
        assert_eq!(builtin.len(), 6);
        assert_eq!(builtin.first(), Some(&Colormap::Viridis));

        let custom = Colormap::Custom(vec!["#ff0000".to_string(), "#00ff00".to_string()]);
        if let Colormap::Custom(colors) = custom {
            assert_eq!(colors.len(), 2);
        } else {
            unreachable!("Expected custom colormap");
        }
    }

    #[test]
    fn test_channel_aggregation() {
        let aggregations = [
            ChannelAggregation::None,
            ChannelAggregation::Mean,
            ChannelAggregation::Max,
            ChannelAggregation::Min,
            ChannelAggregation::Std,
            ChannelAggregation::Select(vec![0, 1, 2]),
        ];
        assert_eq!(aggregations.len(), 6);
        assert_eq!(aggregations[1], ChannelAggregation::Mean);
        if let ChannelAggregation::Select(channels) = &aggregations[5] {
            assert_eq!(channels.len(), 3);
        } else {
            unreachable!("Expected select aggregation");
        }
    }

    #[test]
    fn test_feature_map_info_default() {
        let FeatureMapInfo {
            feature_index,
            spatial_dims,
            channels,
            activation_range,
        } = FeatureMapInfo::default();
        assert_eq!(feature_index, 0);
        assert_eq!(spatial_dims, (1, 1));
        assert_eq!(channels, 1);
        assert_eq!(activation_range, (0.0, 1.0));
    }

    #[test]
    fn test_cache_operations() {
        let mut visualizer = single_dense_visualizer();
        assert!(visualizer.get_cached_activations("test_layer").is_none());
        visualizer.clear_cache();
    }
}