oxirs_embed/
performance_profiler.rs

1//! Performance Profiling for Embedding Operations
2//!
3//! This module provides comprehensive performance profiling capabilities for
4//! knowledge graph embedding operations, including training, inference, and
5//! similarity computations.
6
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12
13/// Operation types for profiling
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub enum OperationType {
16    Training,
17    Inference,
18    SimilarityComputation,
19    VectorSearch,
20    ModelSaving,
21    ModelLoading,
22    BatchProcessing,
23    EntityEmbedding,
24    RelationEmbedding,
25    TripleScoring,
26    Prediction,
27    Custom(String),
28}
29
30impl std::fmt::Display for OperationType {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Training => write!(f, "Training"),
34            Self::Inference => write!(f, "Inference"),
35            Self::SimilarityComputation => write!(f, "Similarity"),
36            Self::VectorSearch => write!(f, "VectorSearch"),
37            Self::ModelSaving => write!(f, "ModelSave"),
38            Self::ModelLoading => write!(f, "ModelLoad"),
39            Self::BatchProcessing => write!(f, "BatchProcessing"),
40            Self::EntityEmbedding => write!(f, "EntityEmbedding"),
41            Self::RelationEmbedding => write!(f, "RelationEmbedding"),
42            Self::TripleScoring => write!(f, "TripleScoring"),
43            Self::Prediction => write!(f, "Prediction"),
44            Self::Custom(name) => write!(f, "{}", name),
45        }
46    }
47}
48
49/// Statistics for a specific operation type
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct OperationStats {
52    pub operation_type: OperationType,
53    pub total_count: u64,
54    pub total_duration: Duration,
55    pub min_duration: Duration,
56    pub max_duration: Duration,
57    pub average_duration: Duration,
58    pub percentile_95: Duration,
59    pub percentile_99: Duration,
60    pub error_count: u64,
61}
62
63impl OperationStats {
64    fn new(operation_type: OperationType) -> Self {
65        Self {
66            operation_type,
67            total_count: 0,
68            total_duration: Duration::ZERO,
69            min_duration: Duration::MAX,
70            max_duration: Duration::ZERO,
71            average_duration: Duration::ZERO,
72            percentile_95: Duration::ZERO,
73            percentile_99: Duration::ZERO,
74            error_count: 0,
75        }
76    }
77
78    fn update(&mut self, duration: Duration, is_error: bool) {
79        self.total_count += 1;
80        self.total_duration += duration;
81        self.min_duration = self.min_duration.min(duration);
82        self.max_duration = self.max_duration.max(duration);
83        self.average_duration = self.total_duration / self.total_count as u32;
84
85        if is_error {
86            self.error_count += 1;
87        }
88    }
89
90    /// Calculate success rate
91    pub fn success_rate(&self) -> f64 {
92        if self.total_count == 0 {
93            0.0
94        } else {
95            ((self.total_count - self.error_count) as f64 / self.total_count as f64) * 100.0
96        }
97    }
98
99    /// Calculate throughput (operations per second)
100    pub fn throughput(&self) -> f64 {
101        if self.total_duration.as_secs_f64() > 0.0 {
102            self.total_count as f64 / self.total_duration.as_secs_f64()
103        } else {
104            0.0
105        }
106    }
107}
108
109/// Performance profiler for embedding operations
110#[derive(Debug, Clone)]
111pub struct PerformanceProfiler {
112    stats: Arc<RwLock<HashMap<OperationType, OperationStats>>>,
113    durations_buffer: Arc<RwLock<HashMap<OperationType, Vec<Duration>>>>,
114    enabled: bool,
115}
116
117impl Default for PerformanceProfiler {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl PerformanceProfiler {
124    /// Create a new performance profiler
125    pub fn new() -> Self {
126        Self {
127            stats: Arc::new(RwLock::new(HashMap::new())),
128            durations_buffer: Arc::new(RwLock::new(HashMap::new())),
129            enabled: true,
130        }
131    }
132
133    /// Enable profiling
134    pub fn enable(&mut self) {
135        self.enabled = true;
136    }
137
138    /// Disable profiling
139    pub fn disable(&mut self) {
140        self.enabled = false;
141    }
142
143    /// Check if profiling is enabled
144    pub fn is_enabled(&self) -> bool {
145        self.enabled
146    }
147
148    /// Start timing an operation
149    pub fn start_operation(&self, operation_type: OperationType) -> OperationTimer {
150        OperationTimer::new(operation_type, self.clone())
151    }
152
153    /// Record an operation duration
154    pub fn record_operation(
155        &self,
156        operation_type: OperationType,
157        duration: Duration,
158        is_error: bool,
159    ) {
160        if !self.enabled {
161            return;
162        }
163
164        // Update stats
165        let mut stats = self.stats.write().unwrap();
166        stats
167            .entry(operation_type.clone())
168            .or_insert_with(|| OperationStats::new(operation_type.clone()))
169            .update(duration, is_error);
170
171        // Store duration for percentile calculation
172        let mut durations = self.durations_buffer.write().unwrap();
173        durations
174            .entry(operation_type.clone())
175            .or_default()
176            .push(duration);
177
178        // Keep buffer size manageable (last 1000 operations)
179        if let Some(buffer) = durations.get_mut(&operation_type) {
180            if buffer.len() > 1000 {
181                buffer.remove(0);
182            }
183        }
184    }
185
186    /// Get statistics for a specific operation type
187    pub fn get_stats(&self, operation_type: OperationType) -> Option<OperationStats> {
188        let stats = self.stats.read().unwrap();
189        stats.get(&operation_type).cloned()
190    }
191
192    /// Get all statistics
193    pub fn get_all_stats(&self) -> HashMap<OperationType, OperationStats> {
194        let stats = self.stats.read().unwrap();
195        stats.clone()
196    }
197
198    /// Calculate percentiles for an operation type
199    pub fn calculate_percentiles(&self, operation_type: OperationType) -> Option<OperationStats> {
200        let durations = self.durations_buffer.read().unwrap();
201        let mut stats = self.stats.write().unwrap();
202
203        if let Some(durations_vec) = durations.get(&operation_type) {
204            if let Some(op_stats) = stats.get_mut(&operation_type) {
205                let mut sorted_durations = durations_vec.clone();
206                sorted_durations.sort();
207
208                if !sorted_durations.is_empty() {
209                    let p95_index = (sorted_durations.len() as f64 * 0.95) as usize;
210                    let p99_index = (sorted_durations.len() as f64 * 0.99) as usize;
211
212                    op_stats.percentile_95 =
213                        sorted_durations[p95_index.min(sorted_durations.len() - 1)];
214                    op_stats.percentile_99 =
215                        sorted_durations[p99_index.min(sorted_durations.len() - 1)];
216                }
217
218                return Some(op_stats.clone());
219            }
220        }
221
222        None
223    }
224
225    /// Reset all statistics
226    pub fn reset(&self) {
227        let mut stats = self.stats.write().unwrap();
228        let mut durations = self.durations_buffer.write().unwrap();
229        stats.clear();
230        durations.clear();
231    }
232
233    /// Generate a performance report
234    pub fn generate_report(&self) -> PerformanceReport {
235        let stats = self.get_all_stats();
236
237        let total_operations: u64 = stats.values().map(|s| s.total_count).sum();
238        let total_errors: u64 = stats.values().map(|s| s.error_count).sum();
239        let total_duration: Duration = stats.values().map(|s| s.total_duration).sum();
240
241        PerformanceReport {
242            total_operations,
243            total_errors,
244            total_duration,
245            overall_success_rate: if total_operations > 0 {
246                ((total_operations - total_errors) as f64 / total_operations as f64) * 100.0
247            } else {
248                0.0
249            },
250            operation_stats: stats,
251        }
252    }
253
254    /// Export statistics to JSON
255    pub fn export_json(&self) -> Result<String> {
256        let report = self.generate_report();
257        serde_json::to_string_pretty(&report)
258            .map_err(|e| anyhow::anyhow!("Failed to serialize report: {}", e))
259    }
260}
261
262/// Timer for tracking operation duration
263pub struct OperationTimer {
264    operation_type: OperationType,
265    start_time: Instant,
266    profiler: PerformanceProfiler,
267    recorded: bool,
268}
269
270impl OperationTimer {
271    fn new(operation_type: OperationType, profiler: PerformanceProfiler) -> Self {
272        Self {
273            operation_type,
274            start_time: Instant::now(),
275            profiler,
276            recorded: false,
277        }
278    }
279
280    /// Stop the timer and record the duration
281    pub fn stop(mut self) {
282        self.record(false);
283    }
284
285    /// Stop the timer and record as an error
286    pub fn stop_with_error(mut self) {
287        self.record(true);
288    }
289
290    fn record(&mut self, is_error: bool) {
291        if !self.recorded {
292            let duration = self.start_time.elapsed();
293            self.profiler
294                .record_operation(self.operation_type.clone(), duration, is_error);
295            self.recorded = true;
296        }
297    }
298}
299
300impl Drop for OperationTimer {
301    fn drop(&mut self) {
302        // Auto-record if not explicitly stopped
303        if !self.recorded {
304            self.record(false);
305        }
306    }
307}
308
309/// Comprehensive performance report
310#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct PerformanceReport {
312    pub total_operations: u64,
313    pub total_errors: u64,
314    pub total_duration: Duration,
315    pub overall_success_rate: f64,
316    pub operation_stats: HashMap<OperationType, OperationStats>,
317}
318
319impl PerformanceReport {
320    /// Generate a human-readable summary
321    pub fn summary(&self) -> String {
322        let mut output = String::new();
323        output.push_str("╔════════════════════════════════════════════════════════════════════╗\n");
324        output.push_str("║           Embedding Performance Profiling Report                  ║\n");
325        output
326            .push_str("╚════════════════════════════════════════════════════════════════════╝\n\n");
327
328        output.push_str(&format!("Total Operations: {}\n", self.total_operations));
329        output.push_str(&format!("Total Errors: {}\n", self.total_errors));
330        output.push_str(&format!(
331            "Overall Success Rate: {:.2}%\n",
332            self.overall_success_rate
333        ));
334        output.push_str(&format!(
335            "Total Duration: {:.2}s\n\n",
336            self.total_duration.as_secs_f64()
337        ));
338
339        output.push_str("Operation Statistics:\n");
340        output.push_str("─────────────────────────────────────────────────────────────────────\n");
341
342        let mut sorted_ops: Vec<_> = self.operation_stats.iter().collect();
343        sorted_ops.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_count));
344
345        for (_, stats) in sorted_ops {
346            output.push_str(&format!("\n{} Operations:\n", stats.operation_type));
347            output.push_str(&format!("  Count: {}\n", stats.total_count));
348            output.push_str(&format!("  Success Rate: {:.2}%\n", stats.success_rate()));
349            output.push_str(&format!(
350                "  Average Duration: {:.2}ms\n",
351                stats.average_duration.as_secs_f64() * 1000.0
352            ));
353            output.push_str(&format!(
354                "  Min Duration: {:.2}ms\n",
355                stats.min_duration.as_secs_f64() * 1000.0
356            ));
357            output.push_str(&format!(
358                "  Max Duration: {:.2}ms\n",
359                stats.max_duration.as_secs_f64() * 1000.0
360            ));
361            output.push_str(&format!(
362                "  P95 Duration: {:.2}ms\n",
363                stats.percentile_95.as_secs_f64() * 1000.0
364            ));
365            output.push_str(&format!(
366                "  P99 Duration: {:.2}ms\n",
367                stats.percentile_99.as_secs_f64() * 1000.0
368            ));
369            output.push_str(&format!(
370                "  Throughput: {:.2} ops/sec\n",
371                stats.throughput()
372            ));
373        }
374
375        output
376    }
377}
378
#[cfg(test)]
mod tests {
    //! Unit tests for the profiler, its timer guard, and report generation.

    use super::*;
    use std::thread;

    #[test]
    fn test_profiler_creation() {
        let profiler = PerformanceProfiler::new();
        assert!(profiler.is_enabled());
    }

    #[test]
    fn test_operation_recording() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        profiler.record_operation(OperationType::Training, Duration::from_millis(150), false);
        profiler.record_operation(OperationType::Training, Duration::from_millis(120), true);

        let stats = profiler.get_stats(OperationType::Training).unwrap();
        assert_eq!(stats.total_count, 3);
        assert_eq!(stats.error_count, 1);
        assert!((stats.success_rate() - 66.67).abs() < 0.1);
    }

    #[test]
    fn test_duration_aggregates() {
        let profiler = PerformanceProfiler::new();
        for ms in [100, 150, 120] {
            profiler.record_operation(
                OperationType::TripleScoring,
                Duration::from_millis(ms),
                false,
            );
        }

        let stats = profiler.get_stats(OperationType::TripleScoring).unwrap();
        assert_eq!(stats.min_duration, Duration::from_millis(100));
        assert_eq!(stats.max_duration, Duration::from_millis(150));
        assert_eq!(stats.total_duration, Duration::from_millis(370));
        assert_eq!(stats.average_duration, Duration::from_nanos(123_333_333));
    }

    #[test]
    fn test_operation_timer() {
        let profiler = PerformanceProfiler::new();

        {
            let _timer = profiler.start_operation(OperationType::Inference);
            thread::sleep(Duration::from_millis(50));
        }

        let stats = profiler.get_stats(OperationType::Inference).unwrap();
        assert_eq!(stats.total_count, 1);
        assert!(stats.total_duration >= Duration::from_millis(50));
    }

    #[test]
    fn test_timer_explicit_stop_records_once() {
        let profiler = PerformanceProfiler::new();
        profiler
            .start_operation(OperationType::VectorSearch)
            .stop_with_error();
        profiler.start_operation(OperationType::VectorSearch).stop();

        // Explicitly stopped timers must not be recorded again on drop.
        let stats = profiler.get_stats(OperationType::VectorSearch).unwrap();
        assert_eq!(stats.total_count, 2);
        assert_eq!(stats.error_count, 1);
    }

    #[test]
    fn test_multiple_operation_types() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        profiler.record_operation(OperationType::Inference, Duration::from_millis(50), false);
        profiler.record_operation(
            OperationType::SimilarityComputation,
            Duration::from_millis(25),
            false,
        );

        let all_stats = profiler.get_all_stats();
        assert_eq!(all_stats.len(), 3);
    }

    #[test]
    fn test_operation_type_display() {
        assert_eq!(OperationType::SimilarityComputation.to_string(), "Similarity");
        assert_eq!(OperationType::ModelLoading.to_string(), "ModelLoad");
        assert_eq!(
            OperationType::Custom("Rerank".to_string()).to_string(),
            "Rerank"
        );
    }

    #[test]
    fn test_profiler_reset() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        assert_eq!(profiler.get_all_stats().len(), 1);

        profiler.reset();
        assert_eq!(profiler.get_all_stats().len(), 0);
    }

    #[test]
    fn test_performance_report_generation() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        profiler.record_operation(OperationType::Inference, Duration::from_millis(50), false);

        let report = profiler.generate_report();
        assert_eq!(report.total_operations, 2);
        assert_eq!(report.total_errors, 0);
        assert_eq!(report.overall_success_rate, 100.0);

        let summary = report.summary();
        assert!(summary.contains("Total Operations: 2"));
    }

    #[test]
    fn test_empty_report() {
        let report = PerformanceProfiler::new().generate_report();
        assert_eq!(report.total_operations, 0);
        assert_eq!(report.total_errors, 0);
        assert_eq!(report.total_duration, Duration::ZERO);
        assert_eq!(report.overall_success_rate, 0.0);
        assert!(report.operation_stats.is_empty());
        assert!(report.summary().contains("Total Operations: 0"));
    }

    #[test]
    fn test_percentile_calculation() {
        let profiler = PerformanceProfiler::new();

        // Record 100 operations with varying durations
        for i in 1..=100 {
            profiler.record_operation(OperationType::Inference, Duration::from_millis(i), false);
        }

        let stats = profiler
            .calculate_percentiles(OperationType::Inference)
            .unwrap();
        assert!(stats.percentile_95 >= Duration::from_millis(90));
        assert!(stats.percentile_99 >= Duration::from_millis(95));
    }

    #[test]
    fn test_percentiles_for_unrecorded_operation() {
        let profiler = PerformanceProfiler::new();
        assert!(profiler
            .calculate_percentiles(OperationType::Prediction)
            .is_none());
    }

    #[test]
    fn test_duration_buffer_is_bounded() {
        let profiler = PerformanceProfiler::new();
        for i in 0..1500 {
            profiler.record_operation(
                OperationType::BatchProcessing,
                Duration::from_micros(i),
                false,
            );
        }

        let stats = profiler.get_stats(OperationType::BatchProcessing).unwrap();
        assert_eq!(stats.total_count, 1500);

        let buffers = profiler.durations_buffer.read().unwrap();
        let buffer = &buffers[&OperationType::BatchProcessing];
        assert_eq!(buffer.len(), 1000);
        // The oldest samples are evicted first.
        assert_eq!(buffer[0], Duration::from_micros(500));
    }

    #[test]
    fn test_profiler_disable() {
        let mut profiler = PerformanceProfiler::new();
        profiler.disable();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);

        assert_eq!(profiler.get_all_stats().len(), 0);
    }

    #[test]
    fn test_json_export() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);

        let json = profiler.export_json().unwrap();
        assert!(json.contains("total_operations"));
        assert!(json.contains("Training"));
    }
}