use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
/// Asynchronous performance profiler: tracks per-operation profiling
/// sessions and aggregates results into process-wide metrics.
#[derive(Debug)]
pub struct EnhancedProfiler {
    /// Live sessions keyed by caller-supplied session id.
    sessions: Arc<RwLock<HashMap<String, ProfilingSession>>>,
    /// Counters and suggestions accumulated across all sessions.
    global_metrics: Arc<RwLock<GlobalMetrics>>,
    /// Static configuration captured at construction time.
    config: ProfilerConfig,
}
/// Tunable options controlling what the profiler records and reports.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilerConfig {
    /// Collect hardware information for each session.
    pub hardware_profiling: bool,
    /// Track allocation/deallocation counts to flag suspected leaks.
    pub memory_leak_detection: bool,
    /// Check thresholds and emit alerts as each sample is recorded.
    pub real_time_alerts: bool,
    /// Enable AI-assisted analysis (disabled by default).
    pub ai_powered_analysis: bool,
    /// Target interval between samples, in milliseconds.
    pub sampling_interval_ms: u64,
    /// Maximum samples retained per session (oldest dropped first).
    pub max_samples: usize,
    /// Alerting thresholds checked against each sample.
    pub thresholds: PerformanceThresholds,
    /// Formats accepted by `export_data`.
    pub export_formats: Vec<ExportFormat>,
}
impl Default for ProfilerConfig {
fn default() -> Self {
Self {
hardware_profiling: true,
memory_leak_detection: true,
real_time_alerts: true,
ai_powered_analysis: false, sampling_interval_ms: 100,
max_samples: 10000,
thresholds: PerformanceThresholds::default(),
export_formats: vec![ExportFormat::JSON, ExportFormat::Prometheus],
}
}
}
/// Limits that, when crossed by a sample, trigger performance alerts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceThresholds {
    /// Alert when a sample's latency exceeds this value (ms).
    pub max_latency_ms: f32,
    /// Expected minimum throughput (ops/sec).
    pub min_throughput_ops_per_sec: f32,
    /// Alert when memory usage exceeds this value (MB).
    pub max_memory_usage_mb: f32,
    /// Alert when CPU utilization exceeds this value (percent).
    pub max_cpu_usage_percent: f32,
    /// Alert when GPU utilization exceeds this value (percent).
    pub max_gpu_usage_percent: f32,
    /// Memory growth beyond this amount (MB) is treated as a suspected leak.
    pub memory_leak_threshold_mb: f32,
}
impl Default for PerformanceThresholds {
fn default() -> Self {
Self {
max_latency_ms: 1000.0,
min_throughput_ops_per_sec: 10.0,
max_memory_usage_mb: 1024.0,
max_cpu_usage_percent: 90.0,
max_gpu_usage_percent: 95.0,
memory_leak_threshold_mb: 10.0,
}
}
}
/// Output formats for `EnhancedProfiler::export_data`.
/// Only JSON, CSV, and Prometheus are currently implemented.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExportFormat {
    JSON,
    CSV,
    Prometheus,
    Flamegraph,
    OpenTelemetry,
    Jaeger,
}
/// State for a single profiled operation, from `start_session` to
/// `end_session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingSession {
    /// Caller-supplied unique id.
    pub session_id: String,
    /// Human-readable name of the profiled operation.
    pub operation_name: String,
    // `Instant` is not serializable; it is skipped on serialization and
    // re-seeded with "now" on deserialization.
    #[serde(skip, default = "Instant::now")]
    pub start_time: Instant,
    /// Recorded samples, bounded by `ProfilerConfig::max_samples`.
    pub samples: Vec<PerformanceSample>,
    /// Host hardware snapshot taken at session start.
    pub hardware_info: HardwareInfo,
    /// Memory accounting for this session.
    pub memory_tracker: MemoryTracker,
    /// Lifecycle state of the session.
    pub status: SessionStatus,
}
/// One point-in-time measurement recorded by `record_sample`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceSample {
    // Skipped on serialization; re-seeded with "now" on deserialization.
    #[serde(skip, default = "Instant::now")]
    pub timestamp: Instant,
    /// Elapsed time since session start when this sample was taken (ms).
    pub latency_ms: f32,
    /// Samples recorded per second of session lifetime so far.
    pub throughput_ops_per_sec: f32,
    pub memory_usage_mb: f32,
    pub cpu_usage_percent: f32,
    pub gpu_usage_percent: f32,
    /// Arbitrary caller-supplied metrics attached to this sample.
    pub custom_metrics: HashMap<String, f64>,
}
/// Snapshot of host hardware captured at session start.
/// NOTE(review): currently populated mostly with mock values by
/// `detect_hardware` — only `cpu_cores` and `platform` are real.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareInfo {
    pub cpu_cores: usize,
    pub cpu_model: String,
    pub total_memory_gb: f32,
    pub gpu_info: Vec<GPUInfo>,
    pub platform: Platform,
    pub specialized_hardware: Vec<SpecializedHardware>,
}
/// Description of one GPU on the host.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GPUInfo {
    pub name: String,
    pub memory_gb: f32,
    /// Vendor compute-capability string (e.g. "8.6" for CUDA devices).
    pub compute_capability: String,
    pub utilization_percent: f32,
}
/// Host operating system, determined at compile time via `cfg!`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Platform {
    Linux,
    Windows,
    MacOS,
    Unknown,
}
/// Acceleration backends / specialized compute hardware the host may expose.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SpecializedHardware {
    CUDA,
    ROCm,
    Metal,
    OpenCL,
    TensorRT,
    CoreML,
    ONNX,
    TPU,
    NPU,
}
/// Per-session memory accounting updated on every recorded sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryTracker {
    /// Memory usage measured when the session started (MB).
    pub initial_memory_mb: f32,
    /// Highest memory usage observed across samples (MB).
    pub peak_memory_mb: f32,
    /// Memory usage at the most recent sample (MB).
    pub current_memory_mb: f32,
    // NOTE(review): allocation/deallocation counts are never incremented
    // anywhere in this file — always 0 unless updated externally.
    pub allocation_count: u64,
    pub deallocation_count: u64,
    pub leak_detected: bool,
    /// Time series of memory readings, one per recorded sample.
    pub memory_samples: Vec<MemorySample>,
}
/// One memory reading in `MemoryTracker::memory_samples`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySample {
    // Skipped on serialization; re-seeded with "now" on deserialization.
    #[serde(skip, default = "Instant::now")]
    pub timestamp: Instant,
    pub memory_mb: f32,
    /// Cumulative allocation count at the time of the reading.
    pub allocations: u64,
    /// Cumulative deallocation count at the time of the reading.
    pub deallocations: u64,
}
/// Lifecycle state of a profiling session.
/// Only `Active` and `Completed` are produced by this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionStatus {
    Active,
    Completed,
    Failed,
    Cancelled,
}
/// Process-wide counters aggregated across all sessions of a profiler.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct GlobalMetrics {
    /// Sessions ever started.
    pub total_sessions: u64,
    /// Sessions started but not yet ended.
    pub active_sessions: u64,
    // NOTE(review): never updated by this file — stays at its default.
    pub average_latency_ms: f32,
    // NOTE(review): never updated by this file — stays at its default.
    pub total_operations: u64,
    pub memory_leaks_detected: u64,
    /// Threshold breaches observed by `check_performance_alerts`.
    pub performance_alerts: u64,
    /// Suggestions accumulated from every ended session's analysis.
    pub optimization_suggestions: Vec<OptimizationSuggestion>,
}
/// A single actionable recommendation produced by session analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSuggestion {
    pub category: OptimizationCategory,
    pub severity: SuggestionSeverity,
    /// What was observed.
    pub description: String,
    /// What the operator should try.
    pub suggested_action: String,
    /// Expected benefit, as free-form text.
    pub expected_improvement: String,
    /// Confidence in the suggestion, 0.0–1.0.
    pub confidence: f32,
}
/// Broad area an optimization suggestion applies to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationCategory {
    Memory,
    CPU,
    GPU,
    IO,
    NetworkLatency,
    ModelArchitecture,
    BatchSize,
    Quantization,
    Caching,
    Threading,
}
/// Urgency of an optimization suggestion, from most to least severe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SuggestionSeverity {
    Critical,
    High,
    Medium,
    Low,
    Info,
}
/// Full report returned by `end_session`, combining the sub-analyses
/// produced by `generate_analysis`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceAnalysis {
    pub session_summary: SessionSummary,
    pub performance_trends: PerformanceTrends,
    pub bottleneck_analysis: BottleneckAnalysis,
    pub optimization_recommendations: Vec<OptimizationSuggestion>,
    pub hardware_utilization: HardwareUtilization,
    pub memory_analysis: MemoryAnalysis,
}
/// Headline statistics for a finished session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSummary {
    /// Wall-clock session duration (ms).
    pub total_duration_ms: f32,
    /// Number of samples recorded (one sample per operation).
    pub total_operations: u64,
    pub average_latency_ms: f32,
    /// 95th-percentile latency across samples (ms).
    pub p95_latency_ms: f32,
    /// 99th-percentile latency across samples (ms).
    pub p99_latency_ms: f32,
    pub peak_throughput_ops_per_sec: f32,
    pub peak_memory_mb: f32,
}
/// Direction each key metric moved over the session.
/// NOTE(review): currently filled with static placeholder values by
/// `analyze_trends`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceTrends {
    pub latency_trend: TrendDirection,
    pub throughput_trend: TrendDirection,
    pub memory_trend: TrendDirection,
    /// Confidence in the reported trends, 0.0–1.0.
    pub trend_confidence: f32,
}
/// Qualitative direction of a metric over time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrendDirection {
    Improving,
    Stable,
    Degrading,
    Volatile,
}
/// Result of bottleneck classification for a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckAnalysis {
    pub primary_bottleneck: BottleneckType,
    /// Severity score, 0.0–1.0.
    pub bottleneck_severity: f32,
    pub contributing_factors: Vec<String>,
    /// Free-form description of the expected impact.
    pub impact_analysis: String,
}
/// Resource or pattern identified as the dominant performance limiter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BottleneckType {
    CPU,
    Memory,
    GPU,
    IO,
    Network,
    ModelComplexity,
    DataLoading,
    Synchronization,
    Unknown,
}
/// Average resource utilization over a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareUtilization {
    pub cpu_utilization_percent: f32,
    pub memory_utilization_percent: f32,
    pub gpu_utilization_percent: f32,
    /// Combined efficiency score, 0.0–1.0.
    pub efficiency_score: f32,
    /// Names of resources judged to have headroom.
    pub underutilized_resources: Vec<String>,
}
/// Memory-behavior assessment for a session.
/// NOTE(review): except for `leak_probability`, fields are currently filled
/// with fixed placeholder values by `analyze_memory_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryAnalysis {
    /// Estimated probability of a leak, 0.0–1.0 (based on memory growth).
    pub leak_probability: f32,
    pub fragmentation_level: f32,
    pub allocation_pattern: AllocationPattern,
    pub gc_impact: f32,
    pub optimization_potential: f32,
}
/// Qualitative shape of a session's allocation behavior over time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AllocationPattern {
    Steady,
    Spiky,
    Growing,
    Cyclical,
    Chaotic,
}
impl EnhancedProfiler {
pub fn new(config: ProfilerConfig) -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
global_metrics: Arc::new(RwLock::new(GlobalMetrics::default())),
config,
}
}
pub async fn start_session(
&self,
session_id: String,
operation_name: String,
) -> Result<(), String> {
let hardware_info = self.detect_hardware().await;
let session = ProfilingSession {
session_id: session_id.clone(),
operation_name,
start_time: Instant::now(),
samples: Vec::new(),
hardware_info,
memory_tracker: MemoryTracker {
initial_memory_mb: self.get_current_memory_usage(),
peak_memory_mb: 0.0,
current_memory_mb: 0.0,
allocation_count: 0,
deallocation_count: 0,
leak_detected: false,
memory_samples: Vec::new(),
},
status: SessionStatus::Active,
};
let mut sessions = self.sessions.write().await;
sessions.insert(session_id, session);
let mut global_metrics = self.global_metrics.write().await;
global_metrics.total_sessions += 1;
global_metrics.active_sessions += 1;
Ok(())
}
pub async fn record_sample(
&self,
session_id: &str,
custom_metrics: HashMap<String, f64>,
) -> Result<(), String> {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
let sample = PerformanceSample {
timestamp: Instant::now(),
latency_ms: session.start_time.elapsed().as_millis() as f32,
throughput_ops_per_sec: self.calculate_throughput(session).await,
memory_usage_mb: self.get_current_memory_usage(),
cpu_usage_percent: self.get_cpu_usage().await,
gpu_usage_percent: self.get_gpu_usage().await,
custom_metrics,
};
session.memory_tracker.current_memory_mb = sample.memory_usage_mb;
if sample.memory_usage_mb > session.memory_tracker.peak_memory_mb {
session.memory_tracker.peak_memory_mb = sample.memory_usage_mb;
}
session.memory_tracker.memory_samples.push(MemorySample {
timestamp: sample.timestamp,
memory_mb: sample.memory_usage_mb,
allocations: session.memory_tracker.allocation_count,
deallocations: session.memory_tracker.deallocation_count,
});
session.samples.push(sample);
if self.config.real_time_alerts {
self.check_performance_alerts(session).await;
}
if session.samples.len() > self.config.max_samples {
session.samples.remove(0);
}
Ok(())
} else {
Err(format!("Session {} not found", session_id))
}
}
pub async fn end_session(&self, session_id: &str) -> Result<PerformanceAnalysis, String> {
let mut sessions = self.sessions.write().await;
if let Some(mut session) = sessions.remove(session_id) {
session.status = SessionStatus::Completed;
let mut global_metrics = self.global_metrics.write().await;
global_metrics.active_sessions -= 1;
let analysis = self.generate_analysis(&session).await;
global_metrics
.optimization_suggestions
.extend(analysis.optimization_recommendations.clone());
Ok(analysis)
} else {
Err(format!("Session {} not found", session_id))
}
}
async fn detect_hardware(&self) -> HardwareInfo {
HardwareInfo {
cpu_cores: num_cpus::get(),
cpu_model: "Mock CPU Model".to_string(),
total_memory_gb: 16.0, gpu_info: vec![GPUInfo {
name: "Mock GPU".to_string(),
memory_gb: 8.0,
compute_capability: "8.6".to_string(),
utilization_percent: 0.0,
}],
platform: if cfg!(target_os = "linux") {
Platform::Linux
} else if cfg!(target_os = "windows") {
Platform::Windows
} else if cfg!(target_os = "macos") {
Platform::MacOS
} else {
Platform::Unknown
},
specialized_hardware: vec![
SpecializedHardware::CUDA,
SpecializedHardware::Metal,
SpecializedHardware::ONNX,
],
}
}
    /// Returns the current memory usage in MB.
    ///
    /// NOTE(review): this is a mock — it derives a pseudo-value in
    /// [100.0, 150.0) from the stack address of the `self` reference, which
    /// is nondeterministic run to run. Replace with a real probe (e.g.
    /// /proc/self/status) before trusting the numbers.
    fn get_current_memory_usage(&self) -> f32 {
        100.0 + (std::ptr::addr_of!(self) as usize % 100) as f32 / 2.0
    }
    /// Returns CPU utilization in percent.
    ///
    /// NOTE(review): mock — pseudo-value in [20.0, 80.0) derived from the
    /// stack address of the `self` reference; nondeterministic.
    async fn get_cpu_usage(&self) -> f32 {
        20.0 + (std::ptr::addr_of!(self) as usize % 60) as f32
    }
    /// Returns GPU utilization in percent.
    ///
    /// NOTE(review): mock — pseudo-value in [10.0, 90.0) derived from the
    /// stack address of the `self` reference; nondeterministic.
    async fn get_gpu_usage(&self) -> f32 {
        10.0 + (std::ptr::addr_of!(self) as usize % 80) as f32
    }
async fn calculate_throughput(&self, session: &ProfilingSession) -> f32 {
let duration_sec = session.start_time.elapsed().as_secs_f32();
if duration_sec > 0.0 {
session.samples.len() as f32 / duration_sec
} else {
0.0
}
}
async fn check_performance_alerts(&self, session: &ProfilingSession) {
if let Some(latest_sample) = session.samples.last() {
let mut alerts_triggered = 0;
if latest_sample.latency_ms > self.config.thresholds.max_latency_ms {
alerts_triggered += 1;
println!(
"ALERT: High latency detected: {:.2}ms",
latest_sample.latency_ms
);
}
if latest_sample.memory_usage_mb > self.config.thresholds.max_memory_usage_mb {
alerts_triggered += 1;
println!(
"ALERT: High memory usage: {:.2}MB",
latest_sample.memory_usage_mb
);
}
if latest_sample.cpu_usage_percent > self.config.thresholds.max_cpu_usage_percent {
alerts_triggered += 1;
println!(
"ALERT: High CPU usage: {:.2}%",
latest_sample.cpu_usage_percent
);
}
if alerts_triggered > 0 {
let mut global_metrics = self.global_metrics.write().await;
global_metrics.performance_alerts += alerts_triggered;
}
}
}
async fn generate_analysis(&self, session: &ProfilingSession) -> PerformanceAnalysis {
let session_summary = self.calculate_session_summary(session);
let performance_trends = self.analyze_trends(session);
let bottleneck_analysis = self.analyze_bottlenecks(session);
let optimization_recommendations =
self.generate_optimization_recommendations(session).await;
let hardware_utilization = self.analyze_hardware_utilization(session);
let memory_analysis = self.analyze_memory_usage(session);
PerformanceAnalysis {
session_summary,
performance_trends,
bottleneck_analysis,
optimization_recommendations,
hardware_utilization,
memory_analysis,
}
}
fn calculate_session_summary(&self, session: &ProfilingSession) -> SessionSummary {
let latencies: Vec<f32> = session.samples.iter().map(|s| s.latency_ms).collect();
let throughputs: Vec<f32> =
session.samples.iter().map(|s| s.throughput_ops_per_sec).collect();
SessionSummary {
total_duration_ms: session.start_time.elapsed().as_millis() as f32,
total_operations: session.samples.len() as u64,
average_latency_ms: latencies.iter().sum::<f32>() / latencies.len() as f32,
p95_latency_ms: self.percentile(&latencies, 0.95),
p99_latency_ms: self.percentile(&latencies, 0.99),
peak_throughput_ops_per_sec: throughputs.iter().cloned().fold(0.0f32, f32::max),
peak_memory_mb: session.memory_tracker.peak_memory_mb,
}
}
fn percentile(&self, data: &[f32], percentile: f32) -> f32 {
let mut sorted_data = data.to_vec();
sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let index = ((data.len() as f32 - 1.0) * percentile) as usize;
sorted_data.get(index).copied().unwrap_or(0.0)
}
fn analyze_trends(&self, session: &ProfilingSession) -> PerformanceTrends {
PerformanceTrends {
latency_trend: TrendDirection::Stable,
throughput_trend: TrendDirection::Improving,
memory_trend: TrendDirection::Stable,
trend_confidence: 0.8,
}
}
fn analyze_bottlenecks(&self, session: &ProfilingSession) -> BottleneckAnalysis {
let avg_cpu = session.samples.iter().map(|s| s.cpu_usage_percent).sum::<f32>()
/ session.samples.len() as f32;
let avg_memory = session.samples.iter().map(|s| s.memory_usage_mb).sum::<f32>()
/ session.samples.len() as f32;
let primary_bottleneck = if avg_cpu > 80.0 {
BottleneckType::CPU
} else if avg_memory > 1000.0 {
BottleneckType::Memory
} else {
BottleneckType::Unknown
};
BottleneckAnalysis {
primary_bottleneck,
bottleneck_severity: 0.5,
contributing_factors: vec!["Mock factor 1".to_string(), "Mock factor 2".to_string()],
impact_analysis: "Moderate impact on overall performance".to_string(),
}
}
async fn generate_optimization_recommendations(
&self,
session: &ProfilingSession,
) -> Vec<OptimizationSuggestion> {
let mut suggestions = Vec::new();
if session.memory_tracker.peak_memory_mb > 500.0 {
suggestions.push(OptimizationSuggestion {
category: OptimizationCategory::Memory,
severity: SuggestionSeverity::Medium,
description: "High peak memory usage detected".to_string(),
suggested_action: "Consider implementing memory pooling or reducing batch size"
.to_string(),
expected_improvement: "20-30% reduction in memory usage".to_string(),
confidence: 0.85,
});
}
if session.samples.len() > 100 {
suggestions.push(OptimizationSuggestion {
category: OptimizationCategory::BatchSize,
severity: SuggestionSeverity::Low,
description: "Batch size may be sub-optimal for throughput".to_string(),
suggested_action: "Experiment with larger batch sizes for better GPU utilization"
.to_string(),
expected_improvement: "15-25% improvement in throughput".to_string(),
confidence: 0.7,
});
}
suggestions
}
fn analyze_hardware_utilization(&self, session: &ProfilingSession) -> HardwareUtilization {
let avg_cpu = session.samples.iter().map(|s| s.cpu_usage_percent).sum::<f32>()
/ session.samples.len() as f32;
let avg_gpu = session.samples.iter().map(|s| s.gpu_usage_percent).sum::<f32>()
/ session.samples.len() as f32;
let avg_memory = session.samples.iter().map(|s| s.memory_usage_mb).sum::<f32>()
/ session.samples.len() as f32;
HardwareUtilization {
cpu_utilization_percent: avg_cpu,
memory_utilization_percent: avg_memory / session.hardware_info.total_memory_gb / 10.24, gpu_utilization_percent: avg_gpu,
efficiency_score: (avg_cpu + avg_gpu) / 2.0 / 100.0,
underutilized_resources: vec!["GPU".to_string()], }
}
fn analyze_memory_usage(&self, session: &ProfilingSession) -> MemoryAnalysis {
let memory_growth =
session.memory_tracker.peak_memory_mb - session.memory_tracker.initial_memory_mb;
MemoryAnalysis {
leak_probability: if memory_growth > 50.0 { 0.7 } else { 0.2 },
fragmentation_level: 0.3, allocation_pattern: AllocationPattern::Steady, gc_impact: 0.1, optimization_potential: 0.6, }
}
pub async fn export_data(
&self,
session_id: &str,
format: ExportFormat,
) -> Result<String, String> {
let sessions = self.sessions.read().await;
if let Some(session) = sessions.get(session_id) {
match format {
ExportFormat::JSON => serde_json::to_string_pretty(session)
.map_err(|e| format!("JSON export failed: {}", e)),
ExportFormat::CSV => {
let mut csv =
"timestamp,latency_ms,throughput,memory_mb,cpu_percent,gpu_percent\n"
.to_string();
for sample in &session.samples {
csv.push_str(&format!(
"{:?},{},{},{},{},{}\n",
sample.timestamp,
sample.latency_ms,
sample.throughput_ops_per_sec,
sample.memory_usage_mb,
sample.cpu_usage_percent,
sample.gpu_usage_percent
));
}
Ok(csv)
},
ExportFormat::Prometheus => {
let mut prometheus = String::new();
if let Some(latest_sample) = session.samples.last() {
prometheus.push_str(&format!(
"# HELP trustformers_latency_ms Current latency in milliseconds\n\
# TYPE trustformers_latency_ms gauge\n\
trustformers_latency_ms{{session=\"{}\"}} {}\n",
session_id, latest_sample.latency_ms
));
}
Ok(prometheus)
},
_ => Err("Export format not implemented".to_string()),
}
} else {
Err(format!("Session {} not found", session_id))
}
}
pub async fn get_global_metrics(&self) -> GlobalMetrics {
self.global_metrics.read().await.clone()
}
}
/// Process-wide profiler singleton, set once by `init_global_profiler`.
static GLOBAL_PROFILER: std::sync::OnceLock<Arc<EnhancedProfiler>> = std::sync::OnceLock::new();
/// Initializes the global profiler singleton.
///
/// Only the first call has any effect; later calls are no-ops and their
/// `config` is silently discarded (a consequence of `OnceLock::get_or_init`).
pub fn init_global_profiler(config: ProfilerConfig) {
    let _ = GLOBAL_PROFILER.get_or_init(|| Arc::new(EnhancedProfiler::new(config)));
}
/// Returns a handle to the global profiler, or `None` when
/// `init_global_profiler` has not been called yet.
pub fn global_profiler() -> Option<Arc<EnhancedProfiler>> {
    GLOBAL_PROFILER.get().map(Arc::clone)
}
/// Profiles `$block` end-to-end: starts a session named after the operation
/// plus a nanosecond timestamp, runs the block, records one sample, ends the
/// session, and evaluates to the block's result.
///
/// Must be used inside an async context (the expansion `.await`s) and after
/// `init_global_profiler` — otherwise it panics via `expect`.
///
/// NOTE(review): the expansion calls `global_profiler()` unqualified, so
/// every call site must have that function in scope. A `$crate::...` path
/// would be more hygienic, but only if this module's items are re-exported
/// at the crate root — verify the crate layout before changing it.
#[macro_export]
macro_rules! enhanced_profile_operation {
    ($operation_name:expr, $block:block) => {{
        let profiler = global_profiler().expect("Profiler not initialized");
        // Nanosecond suffix keeps concurrent invocations from colliding.
        let session_id = format!(
            "{}_{}",
            $operation_name,
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("System time is before UNIX_EPOCH")
                .as_nanos()
        );
        profiler
            .start_session(session_id.clone(), $operation_name.to_string())
            .await
            .expect("Failed to start profiler session");
        let result = $block;
        profiler
            .record_sample(&session_id, std::collections::HashMap::new())
            .await
            .expect("Failed to record profiler sample");
        let _analysis =
            profiler.end_session(&session_id).await.expect("Failed to end profiler session");
        result
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    // End-to-end smoke test: start, record > 100 samples (so the batch-size
    // suggestion fires), end, and sanity-check the analysis.
    #[tokio::test]
    async fn test_enhanced_profiler_basic_functionality() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "test_session".to_string();
        let operation_name = "test_operation".to_string();
        profiler
            .start_session(session_id.clone(), operation_name)
            .await
            .expect("async operation failed");
        for i in 0..105 {
            let mut custom_metrics = HashMap::new();
            custom_metrics.insert("iteration".to_string(), i as f64);
            profiler
                .record_sample(&session_id, custom_metrics)
                .await
                .expect("async operation failed");
            // Occasional sleep so the session accrues measurable duration.
            if i % 20 == 0 {
                sleep(Duration::from_millis(1)).await;
            }
        }
        let analysis = profiler.end_session(&session_id).await.expect("async operation failed");
        assert!(analysis.session_summary.total_operations > 0);
        assert!(analysis.session_summary.total_duration_ms > 0.0);
        assert!(!analysis.optimization_recommendations.is_empty());
    }

    // The global singleton can be initialized once and then used.
    #[tokio::test]
    async fn test_global_profiler() {
        let config = ProfilerConfig::default();
        init_global_profiler(config);
        let profiler = global_profiler().expect("Global profiler should be initialized");
        let session_id = "global_test".to_string();
        profiler
            .start_session(session_id.clone(), "global_test".to_string())
            .await
            .expect("operation failed in test");
        let metrics = profiler.get_global_metrics().await;
        assert!(metrics.total_sessions > 0);
    }

    // JSON and CSV exports both succeed for an active session.
    #[tokio::test]
    async fn test_export_functionality() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "export_test".to_string();
        profiler
            .start_session(session_id.clone(), "export_test".to_string())
            .await
            .expect("operation failed in test");
        profiler
            .record_sample(&session_id, HashMap::new())
            .await
            .expect("async operation failed");
        let json_export = profiler.export_data(&session_id, ExportFormat::JSON).await;
        assert!(json_export.is_ok());
        let csv_export = profiler.export_data(&session_id, ExportFormat::CSV).await;
        assert!(csv_export.is_ok());
        profiler.end_session(&session_id).await.expect("async operation failed");
    }

    // Default config values are positive and at least one export format is on.
    #[test]
    fn test_profiler_config_default_values() {
        let config = ProfilerConfig::default();
        assert!(
            config.sampling_interval_ms > 0,
            "sampling_interval_ms should be positive"
        );
        assert!(config.max_samples > 0, "max_samples should be positive");
        assert!(
            !config.export_formats.is_empty(),
            "at least one export format should be enabled by default"
        );
    }

    // Default thresholds fall in sane ranges.
    #[test]
    fn test_performance_thresholds_default_values() {
        let thresholds = PerformanceThresholds::default();
        assert!(
            thresholds.max_latency_ms > 0.0,
            "max_latency_ms should be positive"
        );
        assert!(
            thresholds.min_throughput_ops_per_sec > 0.0,
            "min_throughput_ops_per_sec should be positive"
        );
        assert!(
            thresholds.max_memory_usage_mb > 0.0,
            "max_memory_usage_mb should be positive"
        );
        assert!(
            thresholds.max_cpu_usage_percent > 0.0 && thresholds.max_cpu_usage_percent <= 100.0,
            "max_cpu_usage_percent should be in (0.0, 100.0]"
        );
    }

    // start_session bumps both total and active session counters by one.
    #[tokio::test]
    async fn test_profiler_start_increments_global_total_sessions() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        // Arbitrary LCG-derived suffix to make the id look unique.
        let seed: u64 = 0xABCDEF0123456789;
        let session_id = format!(
            "session_{}",
            seed.wrapping_mul(1103515245).wrapping_add(12345)
        );
        let initial_metrics = profiler.get_global_metrics().await;
        profiler
            .start_session(session_id.clone(), "op1".to_string())
            .await
            .expect("start_session should succeed");
        let after_start = profiler.get_global_metrics().await;
        assert_eq!(
            after_start.total_sessions,
            initial_metrics.total_sessions + 1,
            "total_sessions should increment after start_session"
        );
        assert_eq!(
            after_start.active_sessions,
            initial_metrics.active_sessions + 1,
            "active_sessions should increment after start_session"
        );
    }

    // end_session decrements the active counter (total stays unchanged).
    #[tokio::test]
    async fn test_profiler_end_session_decrements_active_sessions() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "decrement_test".to_string();
        profiler
            .start_session(session_id.clone(), "op".to_string())
            .await
            .expect("start_session should succeed");
        let after_start = profiler.get_global_metrics().await;
        let active_before = after_start.active_sessions;
        profiler.end_session(&session_id).await.expect("end_session should succeed");
        let after_end = profiler.get_global_metrics().await;
        assert_eq!(
            after_end.active_sessions,
            active_before - 1,
            "active_sessions should decrement after end_session"
        );
    }

    // end_session yields an analysis whose summary reflects recorded samples.
    #[tokio::test]
    async fn test_profiler_end_session_returns_analysis() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "analysis_test".to_string();
        profiler
            .start_session(session_id.clone(), "test_op".to_string())
            .await
            .expect("start_session should succeed");
        profiler
            .record_sample(&session_id, HashMap::new())
            .await
            .expect("record_sample should succeed");
        let analysis = profiler
            .end_session(&session_id)
            .await
            .expect("end_session should return analysis");
        assert!(
            analysis.session_summary.total_operations > 0,
            "session_summary should have at least one operation"
        );
        assert!(
            analysis.session_summary.total_duration_ms >= 0.0,
            "total_duration_ms should be non-negative"
        );
    }

    // Duration reported in the summary is non-negative after several samples.
    #[tokio::test]
    async fn test_profiler_total_duration_calculation() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "duration_test".to_string();
        profiler
            .start_session(session_id.clone(), "duration_op".to_string())
            .await
            .expect("start_session should succeed");
        for i in 0..5 {
            let mut metrics = HashMap::new();
            metrics.insert("i".to_string(), i as f64);
            profiler
                .record_sample(&session_id, metrics)
                .await
                .expect("record_sample should succeed");
        }
        let analysis = profiler.end_session(&session_id).await.expect("end_session should succeed");
        assert!(
            analysis.session_summary.total_duration_ms >= 0.0,
            "total_duration_ms should be non-negative"
        );
    }

    // record_sample accepts arbitrary custom metrics without error.
    #[tokio::test]
    async fn test_profiler_record_sample_stores_custom_metrics() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "custom_metrics_test".to_string();
        profiler
            .start_session(session_id.clone(), "custom_op".to_string())
            .await
            .expect("start_session should succeed");
        let mut custom = HashMap::new();
        custom.insert("tokens_per_sec".to_string(), 42.5f64);
        custom.insert("batch_size".to_string(), 32.0f64);
        profiler
            .record_sample(&session_id, custom)
            .await
            .expect("record_sample should succeed");
        profiler.end_session(&session_id).await.expect("end_session should succeed");
    }

    // Two sessions can be active at the same time.
    #[tokio::test]
    async fn test_profiler_multiple_concurrent_sessions() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        // LCG-derived suffixes to keep the two ids distinct.
        let seed: u64 = 0xFEEDFACECAFEBABE;
        let s1 = format!(
            "session_a_{}",
            seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407)
        );
        let s2 = format!(
            "session_b_{}",
            seed.wrapping_mul(1103515245).wrapping_add(12345)
        );
        profiler
            .start_session(s1.clone(), "op_a".to_string())
            .await
            .expect("start first session");
        profiler
            .start_session(s2.clone(), "op_b".to_string())
            .await
            .expect("start second session");
        let metrics = profiler.get_global_metrics().await;
        assert!(
            metrics.active_sessions >= 2,
            "should have at least 2 active sessions, got {}",
            metrics.active_sessions
        );
        profiler.end_session(&s1).await.expect("end first session");
        profiler.end_session(&s2).await.expect("end second session");
    }

    // Ending an unknown session id is a reported error, not a panic.
    #[tokio::test]
    async fn test_profiler_end_nonexistent_session_returns_err() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let result = profiler.end_session("nonexistent-session-xyz").await;
        assert!(
            result.is_err(),
            "ending a non-existent session should return Err"
        );
    }

    // Sampling an unknown session id is a reported error, not a panic.
    #[tokio::test]
    async fn test_profiler_record_sample_nonexistent_returns_err() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let result = profiler.record_sample("nonexistent-xyz", HashMap::new()).await;
        assert!(
            result.is_err(),
            "recording sample for non-existent session should return Err"
        );
    }

    // JSON export round-trips the session id into the output.
    #[tokio::test]
    async fn test_profiler_json_export_contains_session_id() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "json_content_test".to_string();
        profiler
            .start_session(session_id.clone(), "test_op".to_string())
            .await
            .expect("start_session should succeed");
        profiler
            .record_sample(&session_id, HashMap::new())
            .await
            .expect("record_sample should succeed");
        let json = profiler
            .export_data(&session_id, ExportFormat::JSON)
            .await
            .expect("JSON export should succeed");
        assert!(
            json.contains(&session_id),
            "JSON export should contain the session_id"
        );
    }

    // Prometheus export includes the expected metric name.
    #[tokio::test]
    async fn test_profiler_prometheus_export_format() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "prometheus_test".to_string();
        profiler
            .start_session(session_id.clone(), "prom_op".to_string())
            .await
            .expect("start_session should succeed");
        profiler
            .record_sample(&session_id, HashMap::new())
            .await
            .expect("record_sample should succeed");
        let result = profiler.export_data(&session_id, ExportFormat::Prometheus).await;
        assert!(result.is_ok(), "Prometheus export should succeed");
        let content = result.expect("prometheus content");
        assert!(
            content.contains("trustformers_latency_ms"),
            "Prometheus export should contain metric name"
        );
    }

    // Analysis always exposes an optimization-recommendation list (possibly
    // empty); this just exercises the path with pseudo-random metric values.
    #[tokio::test]
    async fn test_profiler_analysis_contains_optimization_recommendations() {
        let config = ProfilerConfig::default();
        let profiler = EnhancedProfiler::new(config);
        let session_id = "opt_recs_test".to_string();
        profiler
            .start_session(session_id.clone(), "opt_op".to_string())
            .await
            .expect("start_session should succeed");
        for i in 0..20u64 {
            let mut metrics = HashMap::new();
            // LCG step producing a value in [0, 1] for the custom metric.
            let val = i.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            metrics.insert(
                "iteration".to_string(),
                (val >> 32) as f64 / u32::MAX as f64,
            );
            profiler
                .record_sample(&session_id, metrics)
                .await
                .expect("record_sample should succeed");
        }
        let analysis = profiler.end_session(&session_id).await.expect("end_session should succeed");
        let _ = analysis.optimization_recommendations.len();
    }
}