use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Granularity of profiling requested through `ProfilerConfig::level`.
///
/// NOTE(review): the per-level behavior is not implemented in this file; the
/// variants presumably range from no profiling (`Disabled`) to the most
/// detailed (`Verbose`) — confirm against the code that consumes `level`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProfilingLevel {
    /// Presumably turns profiling off — TODO confirm with consumers.
    Disabled,
    Basic,
    /// The default level (see the `Default` impl for this enum).
    Standard,
    Comprehensive,
    Verbose,
}
impl Default for ProfilingLevel {
fn default() -> Self {
ProfilingLevel::Standard
}
}
/// Settings that control a profiling session.
///
/// A clone of this config is stored in every `ProfilingSession` and in every
/// `SessionSnapshot`. NOTE(review): apart from being copied around, none of
/// these fields is read by the code in this file; their effect depends on
/// consumers elsewhere.
#[derive(Debug, Clone)]
pub struct ProfilerConfig {
    /// Requested profiling granularity.
    pub level: ProfilingLevel,
    /// Flag requesting memory profiling.
    pub enable_memory_profiling: bool,
    /// Flag requesting GPU profiling.
    pub enable_gpu_profiling: bool,
    /// Flag requesting system-level metrics.
    pub enable_system_metrics: bool,
    /// Flag requesting call-stack capture.
    pub enable_call_stack: bool,
    /// Session length limit; presumably `None` means unbounded.
    /// NOTE(review): units are not stated (the default `Some(3600)` suggests
    /// seconds) and the limit is not enforced in this file — confirm.
    pub max_session_duration: Option<u64>,
    /// Metrics buffer capacity. NOTE(review): presumably a count of entries — confirm.
    pub metrics_buffer_size: usize,
    /// Sampling rate for periodic metrics. NOTE(review): units unknown
    /// (presumably Hz) — confirm.
    pub sampling_rate: f64,
    /// Flag requesting Chrome-trace export.
    pub export_chrome_trace: bool,
    /// Flag requesting TensorBoard export.
    pub export_tensorboard: bool,
    /// Flag requesting JSON export.
    pub export_json: bool,
}
impl Default for ProfilerConfig {
    /// Standard-level profiling with memory, GPU, system-metric and call-stack
    /// collection switched on, exporting Chrome-trace and JSON output.
    ///
    /// NOTE(review): `max_session_duration = Some(3600)` presumably means one
    /// hour in seconds and `sampling_rate = 10.0` presumably Hz — confirm with
    /// the consumers of these fields.
    fn default() -> Self {
        let level = ProfilingLevel::Standard;
        Self {
            level,
            // Collection switches: all enabled by default.
            enable_memory_profiling: true,
            enable_gpu_profiling: true,
            enable_system_metrics: true,
            enable_call_stack: true,
            // Limits and rates.
            max_session_duration: Some(3600),
            metrics_buffer_size: 10_000,
            sampling_rate: 10.0,
            // Export formats: TensorBoard is opt-in.
            export_chrome_trace: true,
            export_tensorboard: false,
            export_json: true,
        }
    }
}
/// Lifecycle state of a `ProfilingSession`.
///
/// `Copy` is derived (as for `ProfilingLevel`) because this is a small
/// field-less enum: comparing or passing it around never needs a clone.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionState {
    /// Created, but `start` has not been called yet.
    NotStarted,
    /// Operations are being recorded.
    Running,
    /// Recording is suspended; `resume` returns to `Running`.
    Paused,
    /// Stopped; the session accepts no further transitions.
    Completed,
    /// NOTE(review): never assigned by the code in this file.
    Error,
}
/// A single profiling run: its lifecycle state, timing window, and the
/// per-operation metrics recorded while it was `Running`.
#[derive(Debug, Clone)]
pub struct ProfilingSession {
    /// Identifier produced by `generate_session_id` at construction.
    pub session_id: String,
    /// Name supplied by the caller.
    pub session_name: String,
    /// Current lifecycle state, changed by `start`/`pause`/`resume`/`stop`.
    pub state: SessionState,
    /// Set at construction and reset when the session is started.
    pub start_time: Instant,
    /// Set by `stop`; `None` until then.
    pub end_time: Option<Instant>,
    /// Configuration the session was created with.
    pub config: ProfilerConfig,
    /// Aggregated metrics keyed by operation name.
    pub operations: HashMap<String, OperationMetrics>,
    /// Largest call depth passed to `record_operation` while running.
    pub max_call_depth: usize,
    /// Number of operations recorded while running, across all names.
    pub total_operations: usize,
    /// NOTE(review): never set by the code in this file (always `None`).
    pub error_message: Option<String>,
}
impl ProfilingSession {
pub fn new(name: String, config: ProfilerConfig) -> Self {
let session_id = generate_session_id();
Self {
session_id,
session_name: name,
state: SessionState::NotStarted,
start_time: Instant::now(),
end_time: None,
config,
operations: HashMap::new(),
max_call_depth: 0,
total_operations: 0,
error_message: None,
}
}
pub fn start(&mut self) -> RusTorchResult<()> {
if self.state != SessionState::NotStarted {
return Err(RusTorchError::Profiling {
message: "Session already started".to_string(),
});
}
self.state = SessionState::Running;
self.start_time = Instant::now();
Ok(())
}
pub fn stop(&mut self) -> RusTorchResult<SessionSnapshot> {
if self.state != SessionState::Running {
return Err(RusTorchError::Profiling {
message: "Session not running".to_string(),
});
}
self.state = SessionState::Completed;
self.end_time = Some(Instant::now());
Ok(self.create_snapshot())
}
pub fn pause(&mut self) -> RusTorchResult<()> {
if self.state != SessionState::Running {
return Err(RusTorchError::Profiling {
message: "Session not running".to_string(),
});
}
self.state = SessionState::Paused;
Ok(())
}
pub fn resume(&mut self) -> RusTorchResult<()> {
if self.state != SessionState::Paused {
return Err(RusTorchError::Profiling {
message: "Session not paused".to_string(),
});
}
self.state = SessionState::Running;
Ok(())
}
pub fn record_operation(&mut self, name: &str, duration: Duration, call_depth: usize) {
if self.state != SessionState::Running {
return;
}
let metrics = self
.operations
.entry(name.to_string())
.or_insert_with(|| OperationMetrics::new(name.to_string()));
metrics.record_timing(duration);
self.max_call_depth = self.max_call_depth.max(call_depth);
self.total_operations += 1;
}
pub fn duration(&self) -> Duration {
match self.end_time {
Some(end) => end.duration_since(self.start_time),
None => self.start_time.elapsed(),
}
}
pub fn create_snapshot(&self) -> SessionSnapshot {
let operations: Vec<_> = self.operations.values().cloned().collect();
SessionSnapshot {
session_id: self.session_id.clone(),
session_name: self.session_name.clone(),
start_time: self.start_time,
duration: self.duration(),
operations,
total_operations: self.total_operations,
max_call_depth: self.max_call_depth,
config: self.config.clone(),
}
}
}
/// Running timing statistics for one named operation.
#[derive(Debug, Clone)]
pub struct OperationMetrics {
    /// Operation name (also the key under which a session stores these metrics).
    pub name: String,
    /// Number of timings recorded.
    pub call_count: usize,
    /// Sum of all recorded timings.
    pub total_time: Duration,
    /// `total_time / call_count`; zero before the first timing.
    pub avg_time: Duration,
    /// Shortest timing; starts at `Duration::MAX` until the first timing arrives.
    pub min_time: Duration,
    /// Longest timing; zero before the first timing.
    pub max_time: Duration,
    /// Standard deviation in seconds, computed from `timing_samples`;
    /// stays 0.0 until at least two samples exist.
    pub std_dev: f64,
    /// Most recent raw timings. Bounded: once it exceeds 1000 entries the
    /// oldest 500 are dropped, so it may not cover every call.
    pub timing_samples: Vec<Duration>,
    /// CPU utilisation, if known. NOTE(review): not set in this file.
    pub cpu_percentage: Option<f64>,
    /// Memory usage; `get_statistics` divides it by 1024² to report MB, which
    /// implies bytes. NOTE(review): not set in this file.
    pub memory_usage: Option<u64>,
    /// GPU time, if measured. NOTE(review): not set in this file.
    pub gpu_time: Option<Duration>,
}
impl OperationMetrics {
    /// Maximum number of raw samples kept in `timing_samples`.
    const MAX_SAMPLES: usize = 1000;
    /// Number of oldest samples discarded once `MAX_SAMPLES` is exceeded.
    const SAMPLES_TO_DROP: usize = 500;

    /// Creates empty metrics for the operation `name`.
    ///
    /// `min_time` starts at `Duration::MAX` so the first recorded timing always
    /// replaces it.
    pub fn new(name: String) -> Self {
        Self {
            name,
            call_count: 0,
            total_time: Duration::ZERO,
            avg_time: Duration::ZERO,
            min_time: Duration::MAX,
            max_time: Duration::ZERO,
            std_dev: 0.0,
            timing_samples: Vec::new(),
            cpu_percentage: None,
            memory_usage: None,
            gpu_time: None,
        }
    }

    /// Folds one timing into the lifetime aggregates and the bounded sample
    /// window, then refreshes `std_dev`.
    pub fn record_timing(&mut self, duration: Duration) {
        self.call_count += 1;
        self.total_time += duration;
        self.min_time = self.min_time.min(duration);
        self.max_time = self.max_time.max(duration);
        self.avg_time = Self::mean_duration(self.total_time, self.call_count);
        self.timing_samples.push(duration);
        if self.timing_samples.len() > Self::MAX_SAMPLES {
            // Trim in bulk so the O(n) shift happens once per 500 calls, not every call.
            self.timing_samples.drain(..Self::SAMPLES_TO_DROP);
        }
        self.update_std_dev();
    }

    /// Floor of `total / count`, computed in nanoseconds (zero when `count` is 0).
    ///
    /// Unlike `total / count as u32`, this never truncates `count`, which would
    /// divide by zero (and panic) at exactly 2^32 calls.
    fn mean_duration(total: Duration, count: usize) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        if count == 0 {
            return Duration::ZERO;
        }
        let nanos = total.as_nanos() / count as u128;
        // `nanos <= total.as_nanos()`, so the seconds part always fits in a u64.
        Duration::new((nanos / NANOS_PER_SEC) as u64, (nanos % NANOS_PER_SEC) as u32)
    }

    /// Recomputes the Bessel-corrected sample standard deviation (in seconds)
    /// of `timing_samples`; leaves `std_dev` untouched with fewer than 2 samples.
    fn update_std_dev(&mut self) {
        let samples = &self.timing_samples;
        let n = samples.len();
        if n < 2 {
            return;
        }
        // Centre on the window's own mean, not `avg_time`: once old samples are
        // trimmed, `avg_time` still spans every call and would bias the result.
        let mean = samples.iter().map(Duration::as_secs_f64).sum::<f64>() / n as f64;
        let variance = samples
            .iter()
            .map(|d| {
                let diff = d.as_secs_f64() - mean;
                diff * diff
            })
            .sum::<f64>()
            / (n - 1) as f64;
        self.std_dev = variance.sqrt();
    }

    /// Converts the metrics into a millisecond-based `PerformanceStatistics`.
    ///
    /// Throughput is `1 / avg_time` (per second), or 0.0 when the average is
    /// zero. Unknown CPU/memory values are reported as 0.0. With no recorded
    /// timings, `min_time_ms` is 0.0 rather than the `Duration::MAX` sentinel.
    pub fn get_statistics(&self) -> PerformanceStatistics {
        let to_ms = |d: Duration| d.as_secs_f64() * 1000.0;
        let avg_secs = self.avg_time.as_secs_f64();
        let min_time_ms = if self.call_count == 0 {
            0.0
        } else {
            to_ms(self.min_time)
        };
        PerformanceStatistics {
            operation_name: self.name.clone(),
            call_count: self.call_count,
            total_time_ms: to_ms(self.total_time),
            avg_time_ms: to_ms(self.avg_time),
            min_time_ms,
            max_time_ms: to_ms(self.max_time),
            std_dev_ms: self.std_dev * 1000.0,
            throughput_ops_per_sec: if avg_secs > 0.0 { 1.0 / avg_secs } else { 0.0 },
            cpu_percentage: self.cpu_percentage.unwrap_or(0.0),
            memory_usage_mb: self.memory_usage.unwrap_or(0) as f64 / (1024.0 * 1024.0),
            gpu_time_ms: self.gpu_time.map(to_ms),
        }
    }
}
/// Immutable copy of a session's results, produced by
/// `ProfilingSession::create_snapshot`.
#[derive(Debug, Clone)]
pub struct SessionSnapshot {
    /// Id of the session the snapshot was taken from.
    pub session_id: String,
    /// Name of that session.
    pub session_name: String,
    /// When the session was started.
    pub start_time: Instant,
    /// Start-to-stop time for a stopped session; otherwise the time elapsed
    /// when the snapshot was taken.
    pub duration: Duration,
    /// Per-operation metrics, in unspecified (hash-map iteration) order.
    pub operations: Vec<OperationMetrics>,
    /// Total operations recorded across all names.
    pub total_operations: usize,
    /// Deepest call depth observed.
    pub max_call_depth: usize,
    /// Configuration the session ran with.
    pub config: ProfilerConfig,
}
/// Report-friendly view of one operation's metrics, produced by
/// `OperationMetrics::get_statistics`. Times are in milliseconds.
#[derive(Debug, Clone)]
pub struct PerformanceStatistics {
    /// Name of the operation.
    pub operation_name: String,
    /// Number of recorded calls.
    pub call_count: usize,
    /// Sum of all call times.
    pub total_time_ms: f64,
    /// Mean call time.
    pub avg_time_ms: f64,
    /// Shortest recorded call.
    pub min_time_ms: f64,
    /// Longest recorded call.
    pub max_time_ms: f64,
    /// Standard deviation of the retained timing samples.
    pub std_dev_ms: f64,
    /// `1 / avg_time` in calls per second; 0.0 when the average is zero.
    pub throughput_ops_per_sec: f64,
    /// CPU utilisation; 0.0 when unknown.
    pub cpu_percentage: f64,
    /// Memory usage divided by 1024² (MiB if the source value is in bytes); 0.0 when unknown.
    pub memory_usage_mb: f64,
    /// GPU time; `None` when none was recorded.
    pub gpu_time_ms: Option<f64>,
}
/// Drives profiling sessions: owns the active `ProfilingSession`, the named
/// timers that are currently running, and snapshots of completed sessions.
#[derive(Debug)]
pub struct ProfilerCore {
    /// Session in progress, if any; at most one is active at a time.
    current_session: Option<ProfilingSession>,
    /// Configuration cloned into every new session.
    config: ProfilerConfig,
    /// Start instants of running timers, keyed by timer name. Each name maps to
    /// a stack so the same name can be timed re-entrantly (e.g. recursion);
    /// `stop_timer` pops the innermost start.
    active_timers: HashMap<String, Vec<Instant>>,
    /// Names of running timers in start order; its length is the call depth
    /// recorded with each operation.
    call_stack: Vec<String>,
    /// Snapshots of completed sessions, oldest first.
    session_history: Vec<SessionSnapshot>,
}

impl ProfilerCore {
    /// Creates an idle profiler that will use `config` for future sessions.
    pub fn new(config: ProfilerConfig) -> Self {
        Self {
            current_session: None,
            config,
            active_timers: HashMap::new(),
            call_stack: Vec::new(),
            session_history: Vec::new(),
        }
    }

    /// Creates and starts a new session named `name` using the current config.
    ///
    /// # Errors
    /// `RusTorchError::Profiling` if a session is already active.
    pub fn start_session(&mut self, name: String) -> RusTorchResult<()> {
        if self.current_session.is_some() {
            return Err(RusTorchError::Profiling {
                message: "Session already active".to_string(),
            });
        }
        let mut session = ProfilingSession::new(name, self.config.clone());
        session.start()?;
        self.current_session = Some(session);
        Ok(())
    }

    /// Stops the active session, appends its snapshot to the history and
    /// discards any timers still running.
    ///
    /// # Errors
    /// `RusTorchError::Profiling` if there is no active session, or if the
    /// session itself refuses to stop; in that case it stays active.
    pub fn stop_session(&mut self) -> RusTorchResult<SessionSnapshot> {
        let session = self
            .current_session
            .as_mut()
            .ok_or_else(|| RusTorchError::Profiling {
                message: "No active session".to_string(),
            })?;
        let snapshot = session.stop()?;
        self.session_history.push(snapshot.clone());
        self.current_session = None;
        self.call_stack.clear();
        self.active_timers.clear();
        Ok(snapshot)
    }

    /// Starts a timer called `name` and pushes it onto the call stack.
    ///
    /// Starting a name that is already running nests a new instance; each
    /// `stop_timer` call closes the most recently started one.
    ///
    /// # Errors
    /// `RusTorchError::Profiling` if there is no session or it is not `Running`.
    pub fn start_timer(&mut self, name: String) -> RusTorchResult<()> {
        let session = self
            .current_session
            .as_ref()
            .ok_or_else(|| RusTorchError::Profiling {
                message: "No active session".to_string(),
            })?;
        if session.state != SessionState::Running {
            return Err(RusTorchError::Profiling {
                message: "Session not running".to_string(),
            });
        }
        self.active_timers
            .entry(name.clone())
            .or_default()
            .push(Instant::now());
        self.call_stack.push(name);
        Ok(())
    }

    /// Stops the innermost running timer called `name`, records it in the
    /// session, and returns the elapsed time in milliseconds.
    ///
    /// # Errors
    /// `RusTorchError::Profiling` ("Timer not found") if no timer with that
    /// name is running.
    pub fn stop_timer(&mut self, name: &str) -> RusTorchResult<f64> {
        let start_time = self
            .pop_timer(name)
            .ok_or_else(|| RusTorchError::Profiling {
                message: "Timer not found".to_string(),
            })?;
        let duration = start_time.elapsed();
        // Depth is measured before this timer leaves the stack, so a top-level
        // timer is recorded at depth 1.
        let depth = self.call_stack.len();
        if let Some(session) = self.current_session.as_mut() {
            session.record_operation(name, duration, depth);
        }
        // Remove the innermost entry with this name, matching the LIFO pop above.
        if let Some(pos) = self.call_stack.iter().rposition(|x| x == name) {
            self.call_stack.remove(pos);
        }
        Ok(duration.as_secs_f64() * 1000.0)
    }

    /// Pops the most recent start instant for `name`, removing the map entry
    /// once no instance of that timer remains.
    fn pop_timer(&mut self, name: &str) -> Option<Instant> {
        let starts = self.active_timers.get_mut(name)?;
        let start = starts.pop();
        if starts.is_empty() {
            self.active_timers.remove(name);
        }
        start
    }

    /// Records `value` as a timing of operation `name` at the current call depth.
    ///
    /// `value` is interpreted as milliseconds. The `as u64` conversion
    /// saturates, so negative or NaN values become zero. The metric type is
    /// currently ignored.
    ///
    /// # Errors
    /// `RusTorchError::Profiling` if there is no active session.
    pub fn record_custom_metric(
        &mut self,
        name: &str,
        value: f64,
        _metric_type: super::metrics_collector::MetricType,
    ) -> RusTorchResult<()> {
        let depth = self.call_stack.len();
        let session = self
            .current_session
            .as_mut()
            .ok_or_else(|| RusTorchError::Profiling {
                message: "No active session".to_string(),
            })?;
        let duration = Duration::from_nanos((value * 1_000_000.0) as u64);
        session.record_operation(name, duration, depth);
        Ok(())
    }

    /// Statistics for every operation in the active session, in unspecified
    /// (hash-map) order; `None` when no session is active.
    pub fn get_current_statistics(&self) -> Option<Vec<PerformanceStatistics>> {
        self.current_session.as_ref().map(|session| {
            session
                .operations
                .values()
                .map(|op| op.get_statistics())
                .collect()
        })
    }

    /// Snapshots of completed sessions, oldest first.
    pub fn get_session_history(&self) -> &[SessionSnapshot] {
        &self.session_history
    }

    /// Drops all stored session snapshots.
    pub fn clear_history(&mut self) {
        self.session_history.clear();
    }

    /// Configuration that will be applied to the next session.
    pub fn get_config(&self) -> &ProfilerConfig {
        &self.config
    }

    /// Replaces the configuration used for future sessions.
    ///
    /// # Errors
    /// `RusTorchError::Profiling` while a session is active.
    pub fn update_config(&mut self, config: ProfilerConfig) -> RusTorchResult<()> {
        if self.current_session.is_some() {
            return Err(RusTorchError::Profiling {
                message: "Cannot update config during active session".to_string(),
            });
        }
        self.config = config;
        Ok(())
    }
}
/// Generates an identifier of the form `session_<n>`, where `<n>` is normally
/// the current Unix time in milliseconds.
///
/// IDs are unique within the process. If two sessions are created in the same
/// millisecond, or the clock steps backwards, `<n>` is bumped past the last
/// value handed out. A clock set before the Unix epoch yields a base of `0`
/// instead of panicking.
fn generate_session_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Last numeric suffix handed out, shared by every call in this process.
    static LAST_ID: AtomicU64 = AtomicU64::new(0);

    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0);

    let next = |last: u64| now_ms.max(last.saturating_add(1));
    // The closure always returns `Some`, so `fetch_update` always yields
    // `Ok(previous)`; `next(previous)` is exactly the value it stored.
    let previous = LAST_ID
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| Some(next(last)))
        .unwrap_or_else(|last| last);
    format!("session_{}", next(previous))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A profiler configured with `ProfilerConfig::default()`.
    fn default_profiler() -> ProfilerCore {
        ProfilerCore::new(ProfilerConfig::default())
    }

    #[test]
    fn test_profiler_core_creation() {
        let profiler = default_profiler();
        assert!(profiler.current_session.is_none());
    }

    #[test]
    fn test_session_lifecycle() {
        let mut profiler = default_profiler();

        let started = profiler.start_session("test".to_string());
        assert!(started.is_ok(), "first session should start");
        assert!(profiler.current_session.is_some());

        let stopped = profiler.stop_session();
        assert!(stopped.is_ok(), "running session should stop");
        assert!(profiler.current_session.is_none());
        assert_eq!(profiler.session_history.len(), 1);
    }

    #[test]
    fn test_timer_operations() {
        let mut profiler = default_profiler();
        profiler.start_session("test".to_string()).unwrap();

        let started = profiler.start_timer("test_op".to_string());
        assert!(started.is_ok(), "timer should start inside a running session");
        std::thread::sleep(Duration::from_millis(10));

        let elapsed_ms = profiler
            .stop_timer("test_op")
            .expect("timer was started above");
        assert!(elapsed_ms >= 10.0);

        let stats = profiler
            .get_current_statistics()
            .expect("a session is active");
        let first = stats.first().expect("one operation was recorded");
        assert_eq!(first.operation_name, "test_op");
        assert_eq!(first.call_count, 1);
    }

    #[test]
    fn test_operation_metrics() {
        let mut metrics = OperationMetrics::new("test".to_string());
        for &ms in &[100u64, 200, 150] {
            metrics.record_timing(Duration::from_millis(ms));
        }

        assert_eq!(metrics.call_count, 3);
        assert_eq!(metrics.min_time, Duration::from_millis(100));
        assert_eq!(metrics.max_time, Duration::from_millis(200));
        assert_eq!(metrics.avg_time, Duration::from_millis(150));

        let stats = metrics.get_statistics();
        assert_eq!(stats.call_count, 3);
        assert!((stats.avg_time_ms - 150.0).abs() < 0.1);
    }
}