pub mod activations;
pub mod attention;
pub mod config;
pub mod network;
pub mod training;
pub use config::{
ColorPalette, CustomTheme, DownsamplingStrategy, FontConfig, GridConfig, ImageFormat,
InteractiveConfig, LayoutConfig, Margins, PerformanceConfig, StyleConfig, Theme,
VisualizationConfig,
};
pub use network::{
ArrowStyle, BoundingBox, Connection, ConnectionType, ConnectionVisualProps, DataFlowInfo,
LayerIOInfo, LayerInfo, LayerPosition, LayerVisualProps, LayoutAlgorithm, LineStyle,
NetworkLayout, NetworkVisualizer, Point2D, Size2D, ThroughputInfo,
};
pub use training::{
AxisConfig, AxisScale, LineStyleConfig, MarkerConfig, MarkerShape, PlotConfig, PlotType,
SeriesConfig, SystemMetrics, TickConfig, TickFormat, TrainingMetrics, TrainingVisualizer,
UpdateMode,
};
pub use activations::{
ActivationHistogram, ActivationNormalization, ActivationStatistics,
ActivationVisualizationOptions, ActivationVisualizationType, ActivationVisualizer,
ChannelAggregation, Colormap, FeatureMapInfo,
};
pub use attention::{
AttentionData, AttentionStatistics, AttentionVisualizationOptions, AttentionVisualizationType,
AttentionVisualizer, CompressionSettings, DataFormat, ExportFormat, ExportOptions,
ExportQuality, HeadAggregation, HeadInfo, HeadSelection, HighlightConfig, HighlightStyle,
Resolution, VideoFormat,
};
pub type NetworkViz<F> = NetworkVisualizer<F>;
pub type TrainingViz<F> = TrainingVisualizer<F>;
pub type ActivationViz<F> = ActivationVisualizer<F>;
pub type AttentionViz<F> = AttentionVisualizer<F>;
/// Bundles one visualizer of each kind (network, training, activation,
/// attention) behind a single shared [`VisualizationConfig`].
///
/// The sub-visualizers are public so callers can drive them directly; use
/// [`VisualizationSuite::update_config`] to push a new configuration to all
/// of them at once so they stay in sync with [`VisualizationSuite::get_config`].
pub struct VisualizationSuite<F>
where
    F: scirs2_core::numeric::Float
        + std::fmt::Debug
        + scirs2_core::numeric::NumAssign
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + scirs2_core::numeric::FromPrimitive
        + Send
        + Sync
        + serde::Serialize,
{
    /// Network (model structure) visualizer.
    pub network: NetworkVisualizer<F>,
    /// Training-metrics visualizer.
    pub training: TrainingVisualizer<F>,
    /// Layer-activation visualizer.
    pub activation: ActivationVisualizer<F>,
    /// Attention visualizer.
    pub attention: AttentionVisualizer<F>,
    /// The configuration most recently applied to every sub-visualizer.
    config: VisualizationConfig,
}

impl<F> VisualizationSuite<F>
where
    F: scirs2_core::numeric::Float
        + std::fmt::Debug
        + std::fmt::Display
        + scirs2_core::numeric::NumAssign
        + scirs2_core::ndarray::ScalarOperand
        + 'static
        + scirs2_core::numeric::FromPrimitive
        + Send
        + Sync
        + serde::Serialize,
{
    /// Creates a suite whose model-based visualizers all inspect `model` and
    /// which all share a copy of `config`.
    ///
    /// The network, activation and attention visualizers each take ownership
    /// of a model. The activation and attention visualizers therefore receive
    /// clones of `model`, and the network visualizer receives the original.
    /// Cloning, rather than handing them an empty `Sequential::default()`, is
    /// what lets them see the caller's layers. Note that this clones the model
    /// twice, so the cost scales with model size.
    pub fn new(
        model: crate::models::sequential::Sequential<F>,
        config: VisualizationConfig,
    ) -> Self {
        let training = TrainingVisualizer::new(config.clone());
        let activation = ActivationVisualizer::new(model.clone(), config.clone());
        let attention = AttentionVisualizer::new(model.clone(), config.clone());
        // The original model is moved last, into the network visualizer.
        let network = NetworkVisualizer::new(model, config.clone());
        Self {
            network,
            training,
            activation,
            attention,
            config,
        }
    }

    /// Replaces the suite configuration and forwards a copy of it to every
    /// sub-visualizer via its own `update_config`.
    pub fn update_config(&mut self, config: VisualizationConfig) {
        self.config = config.clone();
        self.network.update_config(config.clone());
        self.training.update_config(config.clone());
        self.activation.update_config(config.clone());
        self.attention.update_config(config);
    }

    /// Returns the configuration most recently applied to the suite.
    pub fn get_config(&self) -> &VisualizationConfig {
        &self.config
    }

    /// Clears cached state in every sub-visualizer.
    ///
    /// For the training visualizer this calls `clear_history`, so any
    /// recorded training history is discarded as well, not just cached data.
    pub fn clear_all_caches(&mut self) {
        self.network.clear_cache();
        self.training.clear_history();
        self.activation.clear_cache();
        self.attention.clear_cache();
    }
}
/// Fluent, by-value builder for [`VisualizationConfig`].
///
/// Starts from `VisualizationConfig::default()` and overrides individual
/// settings. Call [`VisualizationConfigBuilder::build`] to obtain the
/// finished config.
///
/// Every setter consumes the builder and returns the updated one, so the type
/// is `#[must_use]`. A statement such as `builder.theme(Theme::Dark);` whose
/// result is dropped would otherwise silently discard the setting.
#[must_use = "builder methods return the updated builder; chain them and call `.build()`"]
pub struct VisualizationConfigBuilder {
    config: VisualizationConfig,
}

impl VisualizationConfigBuilder {
    /// Creates a builder seeded with `VisualizationConfig::default()`.
    pub fn new() -> Self {
        Self {
            config: VisualizationConfig::default(),
        }
    }

    /// Sets the directory that rendered output is written to.
    pub fn output_dir<P: Into<std::path::PathBuf>>(mut self, path: P) -> Self {
        self.config.output_dir = path.into();
        self
    }

    /// Sets the output image format.
    pub fn image_format(mut self, format: ImageFormat) -> Self {
        self.config.image_format = format;
        self
    }

    /// Sets the color palette used by the style settings.
    pub fn color_palette(mut self, palette: ColorPalette) -> Self {
        self.config.style.color_palette = palette;
        self
    }

    /// Sets the visual theme.
    pub fn theme(mut self, theme: Theme) -> Self {
        self.config.style.theme = theme;
        self
    }

    /// Enables or disables interaction (`interactive.enable_interaction`).
    pub fn interactive(mut self, enable: bool) -> Self {
        self.config.interactive.enable_interaction = enable;
        self
    }

    /// Sets the layout width and height. Presumably in pixels; confirm
    /// against `LayoutConfig`.
    pub fn canvas_size(mut self, width: u32, height: u32) -> Self {
        self.config.style.layout.width = width;
        self.config.style.layout.height = height;
        self
    }

    /// Sets `performance.max_points_per_plot`.
    pub fn max_points(mut self, max_points: usize) -> Self {
        self.config.performance.max_points_per_plot = max_points;
        self
    }

    /// Selects a downsampling strategy.
    ///
    /// This also turns downsampling on (`enable_downsampling = true`), since
    /// choosing a strategy without enabling it would have no effect.
    pub fn downsampling(mut self, strategy: DownsamplingStrategy) -> Self {
        self.config.performance.enable_downsampling = true;
        self.config.performance.downsampling_strategy = strategy;
        self
    }

    /// Consumes the builder and returns the assembled configuration.
    #[must_use]
    pub fn build(self) -> VisualizationConfig {
        self.config
    }
}

impl Default for VisualizationConfigBuilder {
    /// Equivalent to [`VisualizationConfigBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Ready-made [`VisualizationConfig`] presets for common scenarios.
pub mod utils {
    use super::*;

    /// Small static preset: an 800x600 canvas with the light theme, the
    /// default palette, and interaction turned off.
    pub fn quick_config() -> VisualizationConfig {
        let sized = VisualizationConfigBuilder::new().canvas_size(800, 600);
        let styled = sized
            .theme(Theme::Light)
            .color_palette(ColorPalette::Default);
        styled.interactive(false).build()
    }

    /// Print-oriented preset: a 1920x1080 canvas exported as PDF using the
    /// colorblind-friendly palette. Other settings keep their defaults.
    pub fn publication_config() -> VisualizationConfig {
        let sized = VisualizationConfigBuilder::new().canvas_size(1920, 1080);
        sized
            .color_palette(ColorPalette::ColorblindFriendly)
            .image_format(ImageFormat::PDF)
            .build()
    }

    /// Interactive dashboard preset: a 1200x800 dark, high-contrast HTML view
    /// capped at 50 000 points per plot with LTTB downsampling enabled.
    pub fn dashboard_config() -> VisualizationConfig {
        let styled = VisualizationConfigBuilder::new()
            .canvas_size(1200, 800)
            .theme(Theme::Dark)
            .color_palette(ColorPalette::HighContrast);
        let output = styled.image_format(ImageFormat::HTML).interactive(true);
        let tuned = output
            .max_points(50_000)
            .downsampling(DownsamplingStrategy::LTTB);
        tuned.build()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::random::SeedableRng;

    /// Builds the model shared by several tests: a `Sequential<f32>` holding a
    /// single `Dense::new(10, 5, Some("relu"), ..)` layer, initialised from a
    /// fixed seed (42) so runs are reproducible.
    fn seeded_model() -> crate::models::sequential::Sequential<f32> {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let mut model = crate::models::sequential::Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
        model
    }

    #[test]
    fn test_visualization_suite_creation() {
        let model = seeded_model();
        let _suite = VisualizationSuite::new(model, VisualizationConfig::default());
    }

    #[test]
    fn test_config_builder() {
        let built = VisualizationConfigBuilder::new()
            .canvas_size(1920, 1080)
            .theme(Theme::Dark)
            .color_palette(ColorPalette::HighContrast)
            .interactive(true)
            .build();
        let layout = &built.style.layout;
        assert_eq!(layout.width, 1920);
        assert_eq!(layout.height, 1080);
        assert_eq!(built.style.theme, Theme::Dark);
        assert_eq!(built.style.color_palette, ColorPalette::HighContrast);
        assert!(built.interactive.enable_interaction);
    }

    #[test]
    fn test_utility_configs() {
        {
            let preset = utils::quick_config();
            assert_eq!(preset.style.layout.width, 800);
            assert_eq!(preset.style.layout.height, 600);
            assert_eq!(preset.style.theme, Theme::Light);
            assert!(!preset.interactive.enable_interaction);
        }
        {
            let preset = utils::publication_config();
            assert_eq!(preset.image_format, ImageFormat::PDF);
            assert_eq!(preset.style.color_palette, ColorPalette::ColorblindFriendly);
        }
        {
            let preset = utils::dashboard_config();
            assert_eq!(preset.image_format, ImageFormat::HTML);
            assert_eq!(preset.style.theme, Theme::Dark);
            assert!(preset.interactive.enable_interaction);
        }
    }

    #[test]
    fn test_type_aliases() {
        let model = seeded_model();
        let cfg = VisualizationConfig::default();
        let _network_viz: NetworkViz<f32> = NetworkVisualizer::new(model.clone(), cfg.clone());
        let _training_viz: TrainingViz<f32> = TrainingVisualizer::new(cfg.clone());
        let _activation_viz: ActivationViz<f32> =
            ActivationVisualizer::new(model.clone(), cfg.clone());
        let _attention_viz: AttentionViz<f32> = AttentionVisualizer::new(model, cfg);
    }

    #[test]
    fn test_suite_operations() {
        let model = seeded_model();
        let initial = VisualizationConfig::default();
        let mut suite = VisualizationSuite::new(model, initial.clone());
        assert_eq!(
            suite.get_config().style.layout.width,
            initial.style.layout.width
        );

        suite.clear_all_caches();

        let resized = VisualizationConfigBuilder::new()
            .canvas_size(1024, 768)
            .build();
        suite.update_config(resized);
        assert_eq!(suite.get_config().style.layout.width, 1024);
    }

    #[test]
    fn test_module_integration() {
        use super::activations::*;
        use super::attention::*;
        use super::config::*;
        use super::network::*;
        use super::training::*;

        // Default-constructible option/config types.
        let _viz_cfg = VisualizationConfig::default();
        let _plot_cfg = PlotConfig::default();
        let _activation_opts = ActivationVisualizationOptions::default();
        let _attention_opts = AttentionVisualizationOptions::default();
        let _export_opts = ExportOptions::default();

        // Representative enum variants from each submodule.
        let _format = ImageFormat::SVG;
        let _layout = LayoutAlgorithm::Hierarchical;
        let _plot = PlotType::Line;
        let _cmap = Colormap::Viridis;
        let _attn_kind = AttentionVisualizationType::Heatmap;
    }
}