Skip to main content

oxirs_embed/
performance_profiler.rs

//! Performance Profiling for Embedding Operations
//!
//! This module provides comprehensive performance profiling capabilities for
//! knowledge graph embedding operations, including training, inference, and
//! similarity computations.
6
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12
13/// Operation types for profiling
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub enum OperationType {
16    Training,
17    Inference,
18    SimilarityComputation,
19    VectorSearch,
20    ModelSaving,
21    ModelLoading,
22    BatchProcessing,
23    EntityEmbedding,
24    RelationEmbedding,
25    TripleScoring,
26    Prediction,
27    Custom(String),
28}
29
30impl std::fmt::Display for OperationType {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Training => write!(f, "Training"),
34            Self::Inference => write!(f, "Inference"),
35            Self::SimilarityComputation => write!(f, "Similarity"),
36            Self::VectorSearch => write!(f, "VectorSearch"),
37            Self::ModelSaving => write!(f, "ModelSave"),
38            Self::ModelLoading => write!(f, "ModelLoad"),
39            Self::BatchProcessing => write!(f, "BatchProcessing"),
40            Self::EntityEmbedding => write!(f, "EntityEmbedding"),
41            Self::RelationEmbedding => write!(f, "RelationEmbedding"),
42            Self::TripleScoring => write!(f, "TripleScoring"),
43            Self::Prediction => write!(f, "Prediction"),
44            Self::Custom(name) => write!(f, "{}", name),
45        }
46    }
47}
48
49/// Statistics for a specific operation type
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct OperationStats {
52    pub operation_type: OperationType,
53    pub total_count: u64,
54    pub total_duration: Duration,
55    pub min_duration: Duration,
56    pub max_duration: Duration,
57    pub average_duration: Duration,
58    pub percentile_95: Duration,
59    pub percentile_99: Duration,
60    pub error_count: u64,
61}
62
63impl OperationStats {
64    fn new(operation_type: OperationType) -> Self {
65        Self {
66            operation_type,
67            total_count: 0,
68            total_duration: Duration::ZERO,
69            min_duration: Duration::MAX,
70            max_duration: Duration::ZERO,
71            average_duration: Duration::ZERO,
72            percentile_95: Duration::ZERO,
73            percentile_99: Duration::ZERO,
74            error_count: 0,
75        }
76    }
77
78    fn update(&mut self, duration: Duration, is_error: bool) {
79        self.total_count += 1;
80        self.total_duration += duration;
81        self.min_duration = self.min_duration.min(duration);
82        self.max_duration = self.max_duration.max(duration);
83        self.average_duration = self.total_duration / self.total_count as u32;
84
85        if is_error {
86            self.error_count += 1;
87        }
88    }
89
90    /// Calculate success rate
91    pub fn success_rate(&self) -> f64 {
92        if self.total_count == 0 {
93            0.0
94        } else {
95            ((self.total_count - self.error_count) as f64 / self.total_count as f64) * 100.0
96        }
97    }
98
99    /// Calculate throughput (operations per second)
100    pub fn throughput(&self) -> f64 {
101        if self.total_duration.as_secs_f64() > 0.0 {
102            self.total_count as f64 / self.total_duration.as_secs_f64()
103        } else {
104            0.0
105        }
106    }
107}
108
109/// Performance profiler for embedding operations
110#[derive(Debug, Clone)]
111pub struct PerformanceProfiler {
112    stats: Arc<RwLock<HashMap<OperationType, OperationStats>>>,
113    durations_buffer: Arc<RwLock<HashMap<OperationType, Vec<Duration>>>>,
114    enabled: bool,
115}
116
117impl Default for PerformanceProfiler {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl PerformanceProfiler {
124    /// Create a new performance profiler
125    pub fn new() -> Self {
126        Self {
127            stats: Arc::new(RwLock::new(HashMap::new())),
128            durations_buffer: Arc::new(RwLock::new(HashMap::new())),
129            enabled: true,
130        }
131    }
132
133    /// Enable profiling
134    pub fn enable(&mut self) {
135        self.enabled = true;
136    }
137
138    /// Disable profiling
139    pub fn disable(&mut self) {
140        self.enabled = false;
141    }
142
143    /// Check if profiling is enabled
144    pub fn is_enabled(&self) -> bool {
145        self.enabled
146    }
147
148    /// Start timing an operation
149    pub fn start_operation(&self, operation_type: OperationType) -> OperationTimer {
150        OperationTimer::new(operation_type, self.clone())
151    }
152
153    /// Record an operation duration
154    pub fn record_operation(
155        &self,
156        operation_type: OperationType,
157        duration: Duration,
158        is_error: bool,
159    ) {
160        if !self.enabled {
161            return;
162        }
163
164        // Update stats
165        let mut stats = self.stats.write().expect("lock should not be poisoned");
166        stats
167            .entry(operation_type.clone())
168            .or_insert_with(|| OperationStats::new(operation_type.clone()))
169            .update(duration, is_error);
170
171        // Store duration for percentile calculation
172        let mut durations = self
173            .durations_buffer
174            .write()
175            .expect("lock should not be poisoned");
176        durations
177            .entry(operation_type.clone())
178            .or_default()
179            .push(duration);
180
181        // Keep buffer size manageable (last 1000 operations)
182        if let Some(buffer) = durations.get_mut(&operation_type) {
183            if buffer.len() > 1000 {
184                buffer.remove(0);
185            }
186        }
187    }
188
189    /// Get statistics for a specific operation type
190    pub fn get_stats(&self, operation_type: OperationType) -> Option<OperationStats> {
191        let stats = self.stats.read().expect("read lock should not be poisoned");
192        stats.get(&operation_type).cloned()
193    }
194
195    /// Get all statistics
196    pub fn get_all_stats(&self) -> HashMap<OperationType, OperationStats> {
197        let stats = self.stats.read().expect("read lock should not be poisoned");
198        stats.clone()
199    }
200
201    /// Calculate percentiles for an operation type
202    pub fn calculate_percentiles(&self, operation_type: OperationType) -> Option<OperationStats> {
203        let durations = self
204            .durations_buffer
205            .read()
206            .expect("read lock should not be poisoned");
207        let mut stats = self.stats.write().expect("lock should not be poisoned");
208
209        if let Some(durations_vec) = durations.get(&operation_type) {
210            if let Some(op_stats) = stats.get_mut(&operation_type) {
211                let mut sorted_durations = durations_vec.clone();
212                sorted_durations.sort();
213
214                if !sorted_durations.is_empty() {
215                    let p95_index = (sorted_durations.len() as f64 * 0.95) as usize;
216                    let p99_index = (sorted_durations.len() as f64 * 0.99) as usize;
217
218                    op_stats.percentile_95 =
219                        sorted_durations[p95_index.min(sorted_durations.len() - 1)];
220                    op_stats.percentile_99 =
221                        sorted_durations[p99_index.min(sorted_durations.len() - 1)];
222                }
223
224                return Some(op_stats.clone());
225            }
226        }
227
228        None
229    }
230
231    /// Reset all statistics
232    pub fn reset(&self) {
233        let mut stats = self.stats.write().expect("lock should not be poisoned");
234        let mut durations = self
235            .durations_buffer
236            .write()
237            .expect("lock should not be poisoned");
238        stats.clear();
239        durations.clear();
240    }
241
242    /// Generate a performance report
243    pub fn generate_report(&self) -> PerformanceReport {
244        let stats = self.get_all_stats();
245
246        let total_operations: u64 = stats.values().map(|s| s.total_count).sum();
247        let total_errors: u64 = stats.values().map(|s| s.error_count).sum();
248        let total_duration: Duration = stats.values().map(|s| s.total_duration).sum();
249
250        PerformanceReport {
251            total_operations,
252            total_errors,
253            total_duration,
254            overall_success_rate: if total_operations > 0 {
255                ((total_operations - total_errors) as f64 / total_operations as f64) * 100.0
256            } else {
257                0.0
258            },
259            operation_stats: stats,
260        }
261    }
262
263    /// Export statistics to JSON
264    pub fn export_json(&self) -> Result<String> {
265        let report = self.generate_report();
266        serde_json::to_string_pretty(&report)
267            .map_err(|e| anyhow::anyhow!("Failed to serialize report: {}", e))
268    }
269}
270
271/// Timer for tracking operation duration
272pub struct OperationTimer {
273    operation_type: OperationType,
274    start_time: Instant,
275    profiler: PerformanceProfiler,
276    recorded: bool,
277}
278
279impl OperationTimer {
280    fn new(operation_type: OperationType, profiler: PerformanceProfiler) -> Self {
281        Self {
282            operation_type,
283            start_time: Instant::now(),
284            profiler,
285            recorded: false,
286        }
287    }
288
289    /// Stop the timer and record the duration
290    pub fn stop(mut self) {
291        self.record(false);
292    }
293
294    /// Stop the timer and record as an error
295    pub fn stop_with_error(mut self) {
296        self.record(true);
297    }
298
299    fn record(&mut self, is_error: bool) {
300        if !self.recorded {
301            let duration = self.start_time.elapsed();
302            self.profiler
303                .record_operation(self.operation_type.clone(), duration, is_error);
304            self.recorded = true;
305        }
306    }
307}
308
309impl Drop for OperationTimer {
310    fn drop(&mut self) {
311        // Auto-record if not explicitly stopped
312        if !self.recorded {
313            self.record(false);
314        }
315    }
316}
317
318/// Comprehensive performance report
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct PerformanceReport {
321    pub total_operations: u64,
322    pub total_errors: u64,
323    pub total_duration: Duration,
324    pub overall_success_rate: f64,
325    pub operation_stats: HashMap<OperationType, OperationStats>,
326}
327
328impl PerformanceReport {
329    /// Generate a human-readable summary
330    pub fn summary(&self) -> String {
331        let mut output = String::new();
332        output.push_str("╔════════════════════════════════════════════════════════════════════╗\n");
333        output.push_str("║           Embedding Performance Profiling Report                  ║\n");
334        output
335            .push_str("╚════════════════════════════════════════════════════════════════════╝\n\n");
336
337        output.push_str(&format!("Total Operations: {}\n", self.total_operations));
338        output.push_str(&format!("Total Errors: {}\n", self.total_errors));
339        output.push_str(&format!(
340            "Overall Success Rate: {:.2}%\n",
341            self.overall_success_rate
342        ));
343        output.push_str(&format!(
344            "Total Duration: {:.2}s\n\n",
345            self.total_duration.as_secs_f64()
346        ));
347
348        output.push_str("Operation Statistics:\n");
349        output.push_str("─────────────────────────────────────────────────────────────────────\n");
350
351        let mut sorted_ops: Vec<_> = self.operation_stats.iter().collect();
352        sorted_ops.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_count));
353
354        for (_, stats) in sorted_ops {
355            output.push_str(&format!("\n{} Operations:\n", stats.operation_type));
356            output.push_str(&format!("  Count: {}\n", stats.total_count));
357            output.push_str(&format!("  Success Rate: {:.2}%\n", stats.success_rate()));
358            output.push_str(&format!(
359                "  Average Duration: {:.2}ms\n",
360                stats.average_duration.as_secs_f64() * 1000.0
361            ));
362            output.push_str(&format!(
363                "  Min Duration: {:.2}ms\n",
364                stats.min_duration.as_secs_f64() * 1000.0
365            ));
366            output.push_str(&format!(
367                "  Max Duration: {:.2}ms\n",
368                stats.max_duration.as_secs_f64() * 1000.0
369            ));
370            output.push_str(&format!(
371                "  P95 Duration: {:.2}ms\n",
372                stats.percentile_95.as_secs_f64() * 1000.0
373            ));
374            output.push_str(&format!(
375                "  P99 Duration: {:.2}ms\n",
376                stats.percentile_99.as_secs_f64() * 1000.0
377            ));
378            output.push_str(&format!(
379                "  Throughput: {:.2} ops/sec\n",
380                stats.throughput()
381            ));
382        }
383
384        output
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Shorthand for a duration of `millis` milliseconds.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// A new profiler starts out enabled.
    #[test]
    fn test_profiler_creation() {
        assert!(PerformanceProfiler::new().is_enabled());
    }

    /// Counts, error counts and the success rate follow the recorded samples.
    #[test]
    fn test_operation_recording() {
        let profiler = PerformanceProfiler::new();

        for (millis, failed) in [(100, false), (150, false), (120, true)] {
            profiler.record_operation(OperationType::Training, ms(millis), failed);
        }

        let training = profiler.get_stats(OperationType::Training).unwrap();
        assert_eq!(training.total_count, 3);
        assert_eq!(training.error_count, 1);
        assert!((training.success_rate() - 66.67).abs() < 0.1);
    }

    /// A timer that is merely dropped still reports its elapsed time.
    #[test]
    fn test_operation_timer() {
        let profiler = PerformanceProfiler::new();

        {
            let _guard = profiler.start_operation(OperationType::Inference);
            thread::sleep(ms(50));
        }

        let inference = profiler.get_stats(OperationType::Inference).unwrap();
        assert_eq!(inference.total_count, 1);
        assert!(inference.total_duration >= ms(50));
    }

    /// Each operation type gets its own statistics entry.
    #[test]
    fn test_multiple_operation_types() {
        let profiler = PerformanceProfiler::new();

        let samples = [
            (OperationType::Training, 100),
            (OperationType::Inference, 50),
            (OperationType::SimilarityComputation, 25),
        ];
        for (operation, millis) in samples {
            profiler.record_operation(operation, ms(millis), false);
        }

        assert_eq!(profiler.get_all_stats().len(), 3);
    }

    /// `reset` discards everything recorded so far.
    #[test]
    fn test_profiler_reset() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, ms(100), false);
        assert_eq!(profiler.get_all_stats().len(), 1);

        profiler.reset();
        assert_eq!(profiler.get_all_stats().len(), 0);
    }

    /// The report aggregates totals and its summary mentions them.
    #[test]
    fn test_performance_report_generation() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, ms(100), false);
        profiler.record_operation(OperationType::Inference, ms(50), false);

        let report = profiler.generate_report();
        assert_eq!(report.total_operations, 2);
        assert_eq!(report.total_errors, 0);
        assert_eq!(report.overall_success_rate, 100.0);

        assert!(report.summary().contains("Total Operations: 2"));
    }

    /// Percentiles over 1..=100 ms land near the top of the range.
    #[test]
    fn test_percentile_calculation() {
        let profiler = PerformanceProfiler::new();

        // Record 100 operations with varying durations
        (1..=100).for_each(|millis| {
            profiler.record_operation(OperationType::Inference, ms(millis), false);
        });

        let inference = profiler
            .calculate_percentiles(OperationType::Inference)
            .unwrap();
        assert!(inference.percentile_95 >= ms(90));
        assert!(inference.percentile_99 >= ms(95));
    }

    /// A disabled profiler ignores recordings.
    #[test]
    fn test_profiler_disable() {
        let mut profiler = PerformanceProfiler::new();
        profiler.disable();

        profiler.record_operation(OperationType::Training, ms(100), false);

        assert_eq!(profiler.get_all_stats().len(), 0);
    }

    /// The JSON export contains the report fields and operation names.
    #[test]
    fn test_json_export() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, ms(100), false);

        let exported = profiler.export_json().unwrap();
        for needle in ["total_operations", "Training"] {
            assert!(exported.contains(needle));
        }
    }
}