1use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12
/// Category of work that the profiler keeps statistics for.
///
/// The `Eq` and `Hash` derives allow values to be used as `HashMap` keys, which is how
/// the profiler groups its measurements.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OperationType {
    Training,
    Inference,
    SimilarityComputation,
    VectorSearch,
    ModelSaving,
    ModelLoading,
    BatchProcessing,
    EntityEmbedding,
    RelationEmbedding,
    TripleScoring,
    Prediction,
    /// Caller-defined category; the wrapped string is also its display label.
    Custom(String),
}
29
30impl std::fmt::Display for OperationType {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Training => write!(f, "Training"),
34 Self::Inference => write!(f, "Inference"),
35 Self::SimilarityComputation => write!(f, "Similarity"),
36 Self::VectorSearch => write!(f, "VectorSearch"),
37 Self::ModelSaving => write!(f, "ModelSave"),
38 Self::ModelLoading => write!(f, "ModelLoad"),
39 Self::BatchProcessing => write!(f, "BatchProcessing"),
40 Self::EntityEmbedding => write!(f, "EntityEmbedding"),
41 Self::RelationEmbedding => write!(f, "RelationEmbedding"),
42 Self::TripleScoring => write!(f, "TripleScoring"),
43 Self::Prediction => write!(f, "Prediction"),
44 Self::Custom(name) => write!(f, "{}", name),
45 }
46 }
47}
48
/// Aggregated timing and error figures for a single [`OperationType`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationStats {
    /// Operation these figures belong to.
    pub operation_type: OperationType,
    /// Number of recorded operations, successful or not.
    pub total_count: u64,
    /// Sum of all recorded durations.
    pub total_duration: Duration,
    /// Shortest recorded duration (`Duration::MAX` until the first update).
    pub min_duration: Duration,
    /// Longest recorded duration.
    pub max_duration: Duration,
    /// `total_duration` divided by `total_count`, refreshed on every update.
    pub average_duration: Duration,
    /// 95th-percentile duration; stays zero until percentiles are calculated by the profiler.
    pub percentile_95: Duration,
    /// 99th-percentile duration; stays zero until percentiles are calculated by the profiler.
    pub percentile_99: Duration,
    /// Number of recorded operations that were flagged as errors.
    pub error_count: u64,
}
62
63impl OperationStats {
64 fn new(operation_type: OperationType) -> Self {
65 Self {
66 operation_type,
67 total_count: 0,
68 total_duration: Duration::ZERO,
69 min_duration: Duration::MAX,
70 max_duration: Duration::ZERO,
71 average_duration: Duration::ZERO,
72 percentile_95: Duration::ZERO,
73 percentile_99: Duration::ZERO,
74 error_count: 0,
75 }
76 }
77
78 fn update(&mut self, duration: Duration, is_error: bool) {
79 self.total_count += 1;
80 self.total_duration += duration;
81 self.min_duration = self.min_duration.min(duration);
82 self.max_duration = self.max_duration.max(duration);
83 self.average_duration = self.total_duration / self.total_count as u32;
84
85 if is_error {
86 self.error_count += 1;
87 }
88 }
89
90 pub fn success_rate(&self) -> f64 {
92 if self.total_count == 0 {
93 0.0
94 } else {
95 ((self.total_count - self.error_count) as f64 / self.total_count as f64) * 100.0
96 }
97 }
98
99 pub fn throughput(&self) -> f64 {
101 if self.total_duration.as_secs_f64() > 0.0 {
102 self.total_count as f64 / self.total_duration.as_secs_f64()
103 } else {
104 0.0
105 }
106 }
107}
108
/// Collects per-operation timing statistics behind shared, lock-protected maps.
///
/// Cloning shares the two `Arc`-wrapped maps with the original, but copies the
/// `enabled` flag by value, so enabling or disabling one clone does not affect another.
#[derive(Debug, Clone)]
pub struct PerformanceProfiler {
    /// Running aggregates, keyed by operation type.
    stats: Arc<RwLock<HashMap<OperationType, OperationStats>>>,
    /// Recent raw duration samples per operation type, used for percentile calculation.
    durations_buffer: Arc<RwLock<HashMap<OperationType, Vec<Duration>>>>,
    /// When `false`, `record_operation` returns without recording anything.
    enabled: bool,
}
116
/// `Default` delegates to [`PerformanceProfiler::new`].
impl Default for PerformanceProfiler {
    fn default() -> Self {
        Self::new()
    }
}
122
123impl PerformanceProfiler {
124 pub fn new() -> Self {
126 Self {
127 stats: Arc::new(RwLock::new(HashMap::new())),
128 durations_buffer: Arc::new(RwLock::new(HashMap::new())),
129 enabled: true,
130 }
131 }
132
133 pub fn enable(&mut self) {
135 self.enabled = true;
136 }
137
138 pub fn disable(&mut self) {
140 self.enabled = false;
141 }
142
143 pub fn is_enabled(&self) -> bool {
145 self.enabled
146 }
147
148 pub fn start_operation(&self, operation_type: OperationType) -> OperationTimer {
150 OperationTimer::new(operation_type, self.clone())
151 }
152
153 pub fn record_operation(
155 &self,
156 operation_type: OperationType,
157 duration: Duration,
158 is_error: bool,
159 ) {
160 if !self.enabled {
161 return;
162 }
163
164 let mut stats = self.stats.write().unwrap();
166 stats
167 .entry(operation_type.clone())
168 .or_insert_with(|| OperationStats::new(operation_type.clone()))
169 .update(duration, is_error);
170
171 let mut durations = self.durations_buffer.write().unwrap();
173 durations
174 .entry(operation_type.clone())
175 .or_default()
176 .push(duration);
177
178 if let Some(buffer) = durations.get_mut(&operation_type) {
180 if buffer.len() > 1000 {
181 buffer.remove(0);
182 }
183 }
184 }
185
186 pub fn get_stats(&self, operation_type: OperationType) -> Option<OperationStats> {
188 let stats = self.stats.read().unwrap();
189 stats.get(&operation_type).cloned()
190 }
191
192 pub fn get_all_stats(&self) -> HashMap<OperationType, OperationStats> {
194 let stats = self.stats.read().unwrap();
195 stats.clone()
196 }
197
198 pub fn calculate_percentiles(&self, operation_type: OperationType) -> Option<OperationStats> {
200 let durations = self.durations_buffer.read().unwrap();
201 let mut stats = self.stats.write().unwrap();
202
203 if let Some(durations_vec) = durations.get(&operation_type) {
204 if let Some(op_stats) = stats.get_mut(&operation_type) {
205 let mut sorted_durations = durations_vec.clone();
206 sorted_durations.sort();
207
208 if !sorted_durations.is_empty() {
209 let p95_index = (sorted_durations.len() as f64 * 0.95) as usize;
210 let p99_index = (sorted_durations.len() as f64 * 0.99) as usize;
211
212 op_stats.percentile_95 =
213 sorted_durations[p95_index.min(sorted_durations.len() - 1)];
214 op_stats.percentile_99 =
215 sorted_durations[p99_index.min(sorted_durations.len() - 1)];
216 }
217
218 return Some(op_stats.clone());
219 }
220 }
221
222 None
223 }
224
225 pub fn reset(&self) {
227 let mut stats = self.stats.write().unwrap();
228 let mut durations = self.durations_buffer.write().unwrap();
229 stats.clear();
230 durations.clear();
231 }
232
233 pub fn generate_report(&self) -> PerformanceReport {
235 let stats = self.get_all_stats();
236
237 let total_operations: u64 = stats.values().map(|s| s.total_count).sum();
238 let total_errors: u64 = stats.values().map(|s| s.error_count).sum();
239 let total_duration: Duration = stats.values().map(|s| s.total_duration).sum();
240
241 PerformanceReport {
242 total_operations,
243 total_errors,
244 total_duration,
245 overall_success_rate: if total_operations > 0 {
246 ((total_operations - total_errors) as f64 / total_operations as f64) * 100.0
247 } else {
248 0.0
249 },
250 operation_stats: stats,
251 }
252 }
253
254 pub fn export_json(&self) -> Result<String> {
256 let report = self.generate_report();
257 serde_json::to_string_pretty(&report)
258 .map_err(|e| anyhow::anyhow!("Failed to serialize report: {}", e))
259 }
260}
261
/// Measures one operation from creation until it is stopped or dropped, then
/// reports the elapsed time to its profiler.
pub struct OperationTimer {
    /// Operation type the measurement is recorded under.
    operation_type: OperationType,
    /// Moment the timer was created.
    start_time: Instant,
    /// Profiler handle that receives the measurement.
    profiler: PerformanceProfiler,
    /// Set once the measurement has been reported, so it is reported only once.
    recorded: bool,
}
269
270impl OperationTimer {
271 fn new(operation_type: OperationType, profiler: PerformanceProfiler) -> Self {
272 Self {
273 operation_type,
274 start_time: Instant::now(),
275 profiler,
276 recorded: false,
277 }
278 }
279
280 pub fn stop(mut self) {
282 self.record(false);
283 }
284
285 pub fn stop_with_error(mut self) {
287 self.record(true);
288 }
289
290 fn record(&mut self, is_error: bool) {
291 if !self.recorded {
292 let duration = self.start_time.elapsed();
293 self.profiler
294 .record_operation(self.operation_type.clone(), duration, is_error);
295 self.recorded = true;
296 }
297 }
298}
299
300impl Drop for OperationTimer {
301 fn drop(&mut self) {
302 if !self.recorded {
304 self.record(false);
305 }
306 }
307}
308
/// Snapshot of profiler statistics across all operation types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceReport {
    /// Sum of `total_count` over all operation types.
    pub total_operations: u64,
    /// Sum of `error_count` over all operation types.
    pub total_errors: u64,
    /// Sum of `total_duration` over all operation types.
    pub total_duration: Duration,
    /// Percentage (0.0 to 100.0) of operations that were not errors; `0.0` when empty.
    pub overall_success_rate: f64,
    /// Per-operation statistics at the time the report was generated.
    pub operation_stats: HashMap<OperationType, OperationStats>,
}
318
319impl PerformanceReport {
320 pub fn summary(&self) -> String {
322 let mut output = String::new();
323 output.push_str("╔════════════════════════════════════════════════════════════════════╗\n");
324 output.push_str("║ Embedding Performance Profiling Report ║\n");
325 output
326 .push_str("╚════════════════════════════════════════════════════════════════════╝\n\n");
327
328 output.push_str(&format!("Total Operations: {}\n", self.total_operations));
329 output.push_str(&format!("Total Errors: {}\n", self.total_errors));
330 output.push_str(&format!(
331 "Overall Success Rate: {:.2}%\n",
332 self.overall_success_rate
333 ));
334 output.push_str(&format!(
335 "Total Duration: {:.2}s\n\n",
336 self.total_duration.as_secs_f64()
337 ));
338
339 output.push_str("Operation Statistics:\n");
340 output.push_str("─────────────────────────────────────────────────────────────────────\n");
341
342 let mut sorted_ops: Vec<_> = self.operation_stats.iter().collect();
343 sorted_ops.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_count));
344
345 for (_, stats) in sorted_ops {
346 output.push_str(&format!("\n{} Operations:\n", stats.operation_type));
347 output.push_str(&format!(" Count: {}\n", stats.total_count));
348 output.push_str(&format!(" Success Rate: {:.2}%\n", stats.success_rate()));
349 output.push_str(&format!(
350 " Average Duration: {:.2}ms\n",
351 stats.average_duration.as_secs_f64() * 1000.0
352 ));
353 output.push_str(&format!(
354 " Min Duration: {:.2}ms\n",
355 stats.min_duration.as_secs_f64() * 1000.0
356 ));
357 output.push_str(&format!(
358 " Max Duration: {:.2}ms\n",
359 stats.max_duration.as_secs_f64() * 1000.0
360 ));
361 output.push_str(&format!(
362 " P95 Duration: {:.2}ms\n",
363 stats.percentile_95.as_secs_f64() * 1000.0
364 ));
365 output.push_str(&format!(
366 " P99 Duration: {:.2}ms\n",
367 stats.percentile_99.as_secs_f64() * 1000.0
368 ));
369 output.push_str(&format!(
370 " Throughput: {:.2} ops/sec\n",
371 stats.throughput()
372 ));
373 }
374
375 output
376 }
377}
378
// Unit tests for the profiler, its timer and the generated report.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // A freshly created profiler starts enabled.
    #[test]
    fn test_profiler_creation() {
        let profiler = PerformanceProfiler::new();
        assert!(profiler.is_enabled());
    }

    // Counts, error counts and success rate reflect the recorded operations.
    #[test]
    fn test_operation_recording() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        profiler.record_operation(OperationType::Training, Duration::from_millis(150), false);
        profiler.record_operation(OperationType::Training, Duration::from_millis(120), true);

        let stats = profiler.get_stats(OperationType::Training).unwrap();
        assert_eq!(stats.total_count, 3);
        assert_eq!(stats.error_count, 1);
        // Two successes out of three.
        assert!((stats.success_rate() - 66.67).abs() < 0.1);
    }

    // Dropping a timer without stopping it still records the measurement.
    #[test]
    fn test_operation_timer() {
        let profiler = PerformanceProfiler::new();

        {
            let _timer = profiler.start_operation(OperationType::Inference);
            thread::sleep(Duration::from_millis(50));
        }

        let stats = profiler.get_stats(OperationType::Inference).unwrap();
        assert_eq!(stats.total_count, 1);
        assert!(stats.total_duration >= Duration::from_millis(50));
    }

    // Each operation type gets its own statistics entry.
    #[test]
    fn test_multiple_operation_types() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        profiler.record_operation(OperationType::Inference, Duration::from_millis(50), false);
        profiler.record_operation(
            OperationType::SimilarityComputation,
            Duration::from_millis(25),
            false,
        );

        let all_stats = profiler.get_all_stats();
        assert_eq!(all_stats.len(), 3);
    }

    // `reset` removes every recorded entry.
    #[test]
    fn test_profiler_reset() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        assert_eq!(profiler.get_all_stats().len(), 1);

        profiler.reset();
        assert_eq!(profiler.get_all_stats().len(), 0);
    }

    // The report aggregates across operation types and renders a summary.
    #[test]
    fn test_performance_report_generation() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);
        profiler.record_operation(OperationType::Inference, Duration::from_millis(50), false);

        let report = profiler.generate_report();
        assert_eq!(report.total_operations, 2);
        assert_eq!(report.total_errors, 0);
        assert_eq!(report.overall_success_rate, 100.0);

        let summary = report.summary();
        assert!(summary.contains("Total Operations: 2"));
    }

    // With samples of 1..=100 ms, P95 and P99 land near the top of the range.
    #[test]
    fn test_percentile_calculation() {
        let profiler = PerformanceProfiler::new();

        for i in 1..=100 {
            profiler.record_operation(OperationType::Inference, Duration::from_millis(i), false);
        }

        let stats = profiler
            .calculate_percentiles(OperationType::Inference)
            .unwrap();
        assert!(stats.percentile_95 >= Duration::from_millis(90));
        assert!(stats.percentile_99 >= Duration::from_millis(95));
    }

    // A disabled profiler ignores recorded operations.
    #[test]
    fn test_profiler_disable() {
        let mut profiler = PerformanceProfiler::new();
        profiler.disable();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);

        assert_eq!(profiler.get_all_stats().len(), 0);
    }

    // The JSON export contains the report fields and the operation name.
    #[test]
    fn test_json_export() {
        let profiler = PerformanceProfiler::new();

        profiler.record_operation(OperationType::Training, Duration::from_millis(100), false);

        let json = profiler.export_json().unwrap();
        assert!(json.contains("total_operations"));
        assert!(json.contains("Training"));
    }
}