pub mod activation_stats;
pub mod attention;
pub mod gradient_flow;
pub mod memory;
pub mod metrics;
pub mod profiler;
pub mod tensorboard;
pub use activation_stats::*;
pub use attention::*;
pub use gradient_flow::*;
pub use memory::*;
pub use metrics::*;
pub use profiler::*;
pub use tensorboard::*;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Facade bundling the memory tracker, attention visualizer, profiler and metrics
/// collector behind a single session-based API (see [`ModelMonitor::start_forward_pass`]
/// and [`ModelMonitor::end_session`]).
#[derive(Debug, Clone)]
pub struct ModelMonitor {
    /// Per-session memory tracking and snapshots.
    memory_tracker: MemoryTracker,
    /// Per-session attention-weight tracking.
    attention_visualizer: AttentionVisualizer,
    /// Per-session profiling.
    profiler: ModelProfiler,
    /// Source of the extra metrics attached to every report.
    metrics_collector: MetricsCollector,
    /// Global switch checked when a session is started; when `false`, new sessions are inert.
    enabled: bool,
}
impl Default for ModelMonitor {
fn default() -> Self {
Self::new()
}
}
impl ModelMonitor {
    /// Creates a monitor whose components are built with their own `new()` constructors.
    /// Monitoring starts enabled.
    pub fn new() -> Self {
        Self {
            memory_tracker: MemoryTracker::new(),
            attention_visualizer: AttentionVisualizer::new(),
            profiler: ModelProfiler::new(),
            metrics_collector: MetricsCollector::new(),
            enabled: true,
        }
    }

    /// Creates a monitor from an explicit [`MonitoringConfig`].
    ///
    /// Each sub-config is handed to the matching component's `with_config`, and the
    /// initial enabled state is taken from `config.enabled`.
    pub fn with_config(config: MonitoringConfig) -> Self {
        Self {
            memory_tracker: MemoryTracker::with_config(config.memory_config),
            attention_visualizer: AttentionVisualizer::with_config(config.attention_config),
            profiler: ModelProfiler::with_config(config.profiler_config),
            metrics_collector: MetricsCollector::with_config(config.metrics_config),
            enabled: config.enabled,
        }
    }

    /// Opens a monitoring session for one forward pass.
    ///
    /// When the monitor is disabled, this returns an inert session (`enabled == false`)
    /// without touching any component. `track_attention`, `track_memory` and `end_session`
    /// all treat such a session as a no-op. Otherwise a fresh UUID v4 id is generated, and
    /// memory tracking, attention tracking and profiling are started under it.
    ///
    /// # Errors
    /// Propagates the first error from starting the memory tracker, attention visualizer
    /// or profiler. Components started before the failing one are not rolled back.
    pub fn start_forward_pass(
        &mut self,
        batch_size: usize,
        sequence_length: usize,
    ) -> Result<MonitoringSession> {
        if !self.enabled {
            return Ok(MonitoringSession::disabled());
        }
        let session_id = uuid::Uuid::new_v4().to_string();
        let start_time = Instant::now();
        self.memory_tracker.start_tracking(&session_id)?;
        self.attention_visualizer.start_tracking(&session_id)?;
        self.profiler.start_profiling(&session_id)?;
        Ok(MonitoringSession {
            id: session_id,
            start_time,
            batch_size,
            sequence_length,
            enabled: true,
        })
    }

    /// Records the attention weights of layer `layer_idx` under `session`.
    ///
    /// Does nothing for a disabled session. `input_tokens` is forwarded unchanged to the
    /// attention visualizer. Presumably these are token labels aligned with the attention
    /// axes; TODO confirm against `AttentionVisualizer::track_attention`.
    ///
    /// # Errors
    /// Propagates any error from the attention visualizer.
    pub fn track_attention(
        &mut self,
        session: &MonitoringSession,
        layer_idx: usize,
        attention_weights: &crate::tensor::Tensor,
        input_tokens: Option<&[String]>,
    ) -> Result<()> {
        if !session.enabled {
            return Ok(());
        }
        self.attention_visualizer.track_attention(
            &session.id,
            layer_idx,
            attention_weights,
            input_tokens,
        )
    }

    /// Takes a memory snapshot labelled `checkpoint` for `session`.
    ///
    /// Returns `MemorySnapshot::default()` for a disabled session.
    ///
    /// # Errors
    /// Propagates any error from the memory tracker.
    pub fn track_memory(
        &mut self,
        session: &MonitoringSession,
        checkpoint: &str,
    ) -> Result<MemorySnapshot> {
        if !session.enabled {
            return Ok(MemorySnapshot::default());
        }
        self.memory_tracker.take_snapshot(&session.id, checkpoint)
    }

    /// Closes `session` and assembles its [`MonitoringReport`].
    ///
    /// A disabled session yields `MonitoringReport::default()`. The reported duration is
    /// the time elapsed since the session was started.
    ///
    /// # Errors
    /// Returns the first error from the memory tracker, profiler, attention visualizer or
    /// metrics collector, in that order. Even when one of these fails, the others are
    /// still torn down first.
    pub fn end_session(&mut self, session: MonitoringSession) -> Result<MonitoringReport> {
        if !session.enabled {
            return Ok(MonitoringReport::default());
        }
        let duration = session.start_time.elapsed();
        // The session is consumed by this call, so any component left open here could never
        // be closed later. End every per-session component before propagating an error.
        let memory_report = self.memory_tracker.end_tracking(&session.id);
        let profiling_report = self.profiler.end_profiling(&session.id);
        let attention_report = self.attention_visualizer.get_report(&session.id);
        Ok(MonitoringReport {
            session_id: session.id,
            duration,
            batch_size: session.batch_size,
            sequence_length: session.sequence_length,
            memory_report: memory_report?,
            profiling_report: profiling_report?,
            attention_report: attention_report?,
            metrics: self.metrics_collector.collect_metrics()?,
        })
    }

    /// Turns monitoring on or off.
    ///
    /// Only sessions started afterwards are affected. Already-open sessions carry their
    /// own `enabled` flag.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Whether new sessions will actually be monitored.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Clears the state of all four components.
    ///
    /// # Errors
    /// Every component is cleared even if an earlier one fails. The first error, in the
    /// order memory tracker, attention visualizer, profiler, metrics collector, is returned.
    pub fn clear(&mut self) -> Result<()> {
        // Attempt every clear before propagating so one failure cannot leave the other
        // components holding stale data.
        let memory = self.memory_tracker.clear();
        let attention = self.attention_visualizer.clear();
        let profiler = self.profiler.clear();
        let metrics = self.metrics_collector.clear();
        memory?;
        attention?;
        profiler?;
        metrics?;
        Ok(())
    }
}
/// Configuration consumed by [`ModelMonitor::with_config`].
#[derive(Debug, Clone)]
pub struct MonitoringConfig {
    /// Initial value of the monitor's enabled switch (`true` in the default config).
    pub enabled: bool,
    /// Passed to `MemoryTracker::with_config`.
    pub memory_config: MemoryTrackerConfig,
    /// Passed to `AttentionVisualizer::with_config`.
    pub attention_config: AttentionVisualizerConfig,
    /// Passed to `ModelProfiler::with_config`.
    pub profiler_config: ProfilerConfig,
    /// Passed to `MetricsCollector::with_config`.
    pub metrics_config: MetricsCollectorConfig,
}
impl Default for MonitoringConfig {
fn default() -> Self {
Self {
enabled: true,
memory_config: MemoryTrackerConfig::default(),
attention_config: AttentionVisualizerConfig::default(),
profiler_config: ProfilerConfig::default(),
metrics_config: MetricsCollectorConfig::default(),
}
}
}
/// Handle for one forward pass.
///
/// Returned by [`ModelMonitor::start_forward_pass`] and consumed by
/// [`ModelMonitor::end_session`].
#[derive(Debug, Clone)]
pub struct MonitoringSession {
    /// Id shared by all components for this session: a UUID v4 string, or empty for a
    /// disabled session.
    pub id: String,
    /// When the session was opened; `end_session` reports the time elapsed since then.
    pub start_time: Instant,
    /// Batch size supplied by the caller; copied into the final report.
    pub batch_size: usize,
    /// Sequence length supplied by the caller; copied into the final report.
    pub sequence_length: usize,
    /// `false` for placeholder sessions handed out while monitoring is disabled. Tracking
    /// and `end_session` calls on such a session do nothing.
    pub enabled: bool,
}
impl MonitoringSession {
    /// Builds the inert placeholder returned while monitoring is disabled.
    ///
    /// It has an empty id, zero sizes and `enabled == false`.
    fn disabled() -> Self {
        let start_time = Instant::now();
        Self {
            enabled: false,
            id: String::default(),
            batch_size: 0,
            sequence_length: 0,
            start_time,
        }
    }
}
/// Summary produced by [`ModelMonitor::end_session`].
///
/// Serializable to and from JSON via `save_to_file` / `load_from_file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringReport {
    /// Id of the session this report belongs to (empty for a disabled session).
    pub session_id: String,
    /// Time elapsed between opening the session and calling `end_session`.
    pub duration: Duration,
    /// Batch size recorded when the session was opened.
    pub batch_size: usize,
    /// Sequence length recorded when the session was opened.
    pub sequence_length: usize,
    /// Memory tracker output for the session.
    pub memory_report: MemoryReport,
    /// Profiler output for the session.
    pub profiling_report: ProfilingReport,
    /// Attention visualizer output for the session.
    pub attention_report: AttentionReport,
    /// Named values returned by the metrics collector when the session ended.
    pub metrics: HashMap<String, f64>,
}
impl Default for MonitoringReport {
fn default() -> Self {
Self {
session_id: String::new(),
duration: Duration::from_secs(0),
batch_size: 0,
sequence_length: 0,
memory_report: MemoryReport::default(),
profiling_report: ProfilingReport::default(),
attention_report: AttentionReport::default(),
metrics: HashMap::new(),
}
}
}
impl MonitoringReport {
pub fn save_to_file(&self, path: &str) -> Result<()> {
let json = serde_json::to_string_pretty(self)?;
std::fs::write(path, json)?;
Ok(())
}
pub fn load_from_file(path: &str) -> Result<Self> {
let content = std::fs::read_to_string(path)?;
let report = serde_json::from_str(&content)?;
Ok(report)
}
pub fn print_summary(&self) {
println!("Monitoring Report Summary");
println!("========================");
println!("Session ID: {}", self.session_id);
println!("Duration: {:.2}ms", self.duration.as_millis());
println!("Batch Size: {}", self.batch_size);
println!("Sequence Length: {}", self.sequence_length);
println!();
self.memory_report.print_summary();
println!();
self.profiling_report.print_summary();
println!();
self.attention_report.print_summary();
println!();
if !self.metrics.is_empty() {
println!("Additional Metrics:");
for (name, value) in &self.metrics {
println!(" {}: {:.4}", name, value);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_monitor_creation() {
        let monitor = ModelMonitor::new();
        assert!(monitor.is_enabled());
    }

    #[test]
    fn test_monitor_with_config() {
        let config = MonitoringConfig {
            enabled: false,
            ..Default::default()
        };
        let monitor = ModelMonitor::with_config(config);
        assert!(!monitor.is_enabled());
    }

    #[test]
    fn test_monitoring_session() -> Result<()> {
        let mut monitor = ModelMonitor::new();
        let session = monitor.start_forward_pass(4, 128)?;
        assert_eq!(session.batch_size, 4);
        assert_eq!(session.sequence_length, 128);
        assert!(session.enabled);
        let report = monitor.end_session(session)?;
        assert!(report.duration > Duration::from_nanos(0));
        Ok(())
    }

    #[test]
    fn test_disabled_monitoring() -> Result<()> {
        let mut monitor = ModelMonitor::new();
        monitor.set_enabled(false);
        let session = monitor.start_forward_pass(4, 128)?;
        assert!(!session.enabled);
        let report = monitor.end_session(session)?;
        assert_eq!(report.session_id, "");
        Ok(())
    }

    #[test]
    fn test_monitor_clear() -> Result<()> {
        let mut monitor = ModelMonitor::new();
        let session = monitor.start_forward_pass(4, 128)?;
        let _report = monitor.end_session(session)?;
        monitor.clear()?;
        Ok(())
    }

    #[test]
    fn test_monitoring_config_default() {
        let config = MonitoringConfig::default();
        assert!(config.enabled);
    }

    #[test]
    fn test_monitoring_report_serialization() -> Result<()> {
        let report = MonitoringReport::default();
        // Use the platform temp dir rather than a hard-coded `/tmp`, with a per-process file
        // name, so the test is portable and concurrent test runs do not clobber each other.
        let temp_file = std::env::temp_dir().join(format!(
            "test_monitoring_report_{}.json",
            std::process::id()
        ));
        let temp_path = temp_file
            .to_str()
            .expect("temporary directory path should be valid UTF-8");
        report.save_to_file(temp_path)?;
        // Remove the file before propagating a load error so a failure does not leak it.
        let loaded = MonitoringReport::load_from_file(temp_path);
        std::fs::remove_file(temp_path).ok();
        let loaded_report = loaded?;
        assert_eq!(report.session_id, loaded_report.session_id);
        assert_eq!(report.batch_size, loaded_report.batch_size);
        assert_eq!(report.sequence_length, loaded_report.sequence_length);
        assert_eq!(report.duration, loaded_report.duration);
        Ok(())
    }
}