1use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12
/// Category of work being timed by a [`PerformanceProfiler`].
///
/// Used as the key of the profiler's per-operation statistics maps. The
/// `Display` impl does not always echo the variant name (for example
/// `SimilarityComputation` prints as `"Similarity"`), and `Custom` prints its
/// inner string verbatim.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OperationType {
    Training,
    Inference,
    SimilarityComputation,
    VectorSearch,
    ModelSaving,
    ModelLoading,
    BatchProcessing,
    EntityEmbedding,
    RelationEmbedding,
    TripleScoring,
    Prediction,
    /// Caller-defined operation name. It can display identically to a built-in
    /// variant (e.g. `Custom("Training".into())`) while remaining a distinct key.
    Custom(String),
}
29
30impl std::fmt::Display for OperationType {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Training => write!(f, "Training"),
34 Self::Inference => write!(f, "Inference"),
35 Self::SimilarityComputation => write!(f, "Similarity"),
36 Self::VectorSearch => write!(f, "VectorSearch"),
37 Self::ModelSaving => write!(f, "ModelSave"),
38 Self::ModelLoading => write!(f, "ModelLoad"),
39 Self::BatchProcessing => write!(f, "BatchProcessing"),
40 Self::EntityEmbedding => write!(f, "EntityEmbedding"),
41 Self::RelationEmbedding => write!(f, "RelationEmbedding"),
42 Self::TripleScoring => write!(f, "TripleScoring"),
43 Self::Prediction => write!(f, "Prediction"),
44 Self::Custom(name) => write!(f, "{}", name),
45 }
46 }
47}
48
/// Aggregated timing statistics for one [`OperationType`].
///
/// Count, total, extremes and mean are maintained incrementally as samples are
/// recorded. `min_duration` starts at `Duration::MAX` and stays there until
/// the first sample arrives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationStats {
    /// The operation these statistics describe.
    pub operation_type: OperationType,
    /// Number of recorded samples, failed ones included.
    pub total_count: u64,
    /// Sum of all recorded durations.
    pub total_duration: Duration,
    /// Shortest recorded duration.
    pub min_duration: Duration,
    /// Longest recorded duration.
    pub max_duration: Duration,
    /// Mean duration (`total_duration / total_count`).
    pub average_duration: Duration,
    /// 95th-percentile duration. Not updated per sample: it stays zero until
    /// percentiles are explicitly computed.
    pub percentile_95: Duration,
    /// 99th-percentile duration; same caveat as `percentile_95`.
    pub percentile_99: Duration,
    /// Number of samples recorded as errors.
    pub error_count: u64,
}
62
63impl OperationStats {
64 fn new(operation_type: OperationType) -> Self {
65 Self {
66 operation_type,
67 total_count: 0,
68 total_duration: Duration::ZERO,
69 min_duration: Duration::MAX,
70 max_duration: Duration::ZERO,
71 average_duration: Duration::ZERO,
72 percentile_95: Duration::ZERO,
73 percentile_99: Duration::ZERO,
74 error_count: 0,
75 }
76 }
77
78 fn update(&mut self, duration: Duration, is_error: bool) {
79 self.total_count += 1;
80 self.total_duration += duration;
81 self.min_duration = self.min_duration.min(duration);
82 self.max_duration = self.max_duration.max(duration);
83 self.average_duration = self.total_duration / self.total_count as u32;
84
85 if is_error {
86 self.error_count += 1;
87 }
88 }
89
90 pub fn success_rate(&self) -> f64 {
92 if self.total_count == 0 {
93 0.0
94 } else {
95 ((self.total_count - self.error_count) as f64 / self.total_count as f64) * 100.0
96 }
97 }
98
99 pub fn throughput(&self) -> f64 {
101 if self.total_duration.as_secs_f64() > 0.0 {
102 self.total_count as f64 / self.total_duration.as_secs_f64()
103 } else {
104 0.0
105 }
106 }
107}
108
/// Recorder of per-operation timing statistics.
///
/// Clones share the same statistics and sample buffers through `Arc`, so the
/// timers handed out by `start_operation` (which hold a clone) report into the
/// same maps. The `enabled` flag, however, is a plain `bool` copied on clone:
/// `enable`/`disable` only affect the handle they are called on.
#[derive(Debug, Clone)]
pub struct PerformanceProfiler {
    /// Running aggregates per operation type.
    stats: Arc<RwLock<HashMap<OperationType, OperationStats>>>,
    /// Most recent raw durations per operation type, kept bounded by
    /// `record_operation` (1000 entries) and used for percentile estimation.
    durations_buffer: Arc<RwLock<HashMap<OperationType, Vec<Duration>>>>,
    /// When `false`, `record_operation` ignores incoming samples.
    enabled: bool,
}
116
117impl Default for PerformanceProfiler {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl PerformanceProfiler {
124 pub fn new() -> Self {
126 Self {
127 stats: Arc::new(RwLock::new(HashMap::new())),
128 durations_buffer: Arc::new(RwLock::new(HashMap::new())),
129 enabled: true,
130 }
131 }
132
133 pub fn enable(&mut self) {
135 self.enabled = true;
136 }
137
138 pub fn disable(&mut self) {
140 self.enabled = false;
141 }
142
143 pub fn is_enabled(&self) -> bool {
145 self.enabled
146 }
147
148 pub fn start_operation(&self, operation_type: OperationType) -> OperationTimer {
150 OperationTimer::new(operation_type, self.clone())
151 }
152
153 pub fn record_operation(
155 &self,
156 operation_type: OperationType,
157 duration: Duration,
158 is_error: bool,
159 ) {
160 if !self.enabled {
161 return;
162 }
163
164 let mut stats = self.stats.write().expect("lock should not be poisoned");
166 stats
167 .entry(operation_type.clone())
168 .or_insert_with(|| OperationStats::new(operation_type.clone()))
169 .update(duration, is_error);
170
171 let mut durations = self
173 .durations_buffer
174 .write()
175 .expect("lock should not be poisoned");
176 durations
177 .entry(operation_type.clone())
178 .or_default()
179 .push(duration);
180
181 if let Some(buffer) = durations.get_mut(&operation_type) {
183 if buffer.len() > 1000 {
184 buffer.remove(0);
185 }
186 }
187 }
188
189 pub fn get_stats(&self, operation_type: OperationType) -> Option<OperationStats> {
191 let stats = self.stats.read().expect("read lock should not be poisoned");
192 stats.get(&operation_type).cloned()
193 }
194
195 pub fn get_all_stats(&self) -> HashMap<OperationType, OperationStats> {
197 let stats = self.stats.read().expect("read lock should not be poisoned");
198 stats.clone()
199 }
200
201 pub fn calculate_percentiles(&self, operation_type: OperationType) -> Option<OperationStats> {
203 let durations = self
204 .durations_buffer
205 .read()
206 .expect("read lock should not be poisoned");
207 let mut stats = self.stats.write().expect("lock should not be poisoned");
208
209 if let Some(durations_vec) = durations.get(&operation_type) {
210 if let Some(op_stats) = stats.get_mut(&operation_type) {
211 let mut sorted_durations = durations_vec.clone();
212 sorted_durations.sort();
213
214 if !sorted_durations.is_empty() {
215 let p95_index = (sorted_durations.len() as f64 * 0.95) as usize;
216 let p99_index = (sorted_durations.len() as f64 * 0.99) as usize;
217
218 op_stats.percentile_95 =
219 sorted_durations[p95_index.min(sorted_durations.len() - 1)];
220 op_stats.percentile_99 =
221 sorted_durations[p99_index.min(sorted_durations.len() - 1)];
222 }
223
224 return Some(op_stats.clone());
225 }
226 }
227
228 None
229 }
230
231 pub fn reset(&self) {
233 let mut stats = self.stats.write().expect("lock should not be poisoned");
234 let mut durations = self
235 .durations_buffer
236 .write()
237 .expect("lock should not be poisoned");
238 stats.clear();
239 durations.clear();
240 }
241
242 pub fn generate_report(&self) -> PerformanceReport {
244 let stats = self.get_all_stats();
245
246 let total_operations: u64 = stats.values().map(|s| s.total_count).sum();
247 let total_errors: u64 = stats.values().map(|s| s.error_count).sum();
248 let total_duration: Duration = stats.values().map(|s| s.total_duration).sum();
249
250 PerformanceReport {
251 total_operations,
252 total_errors,
253 total_duration,
254 overall_success_rate: if total_operations > 0 {
255 ((total_operations - total_errors) as f64 / total_operations as f64) * 100.0
256 } else {
257 0.0
258 },
259 operation_stats: stats,
260 }
261 }
262
263 pub fn export_json(&self) -> Result<String> {
265 let report = self.generate_report();
266 serde_json::to_string_pretty(&report)
267 .map_err(|e| anyhow::anyhow!("Failed to serialize report: {}", e))
268 }
269}
270
271pub struct OperationTimer {
273 operation_type: OperationType,
274 start_time: Instant,
275 profiler: PerformanceProfiler,
276 recorded: bool,
277}
278
279impl OperationTimer {
280 fn new(operation_type: OperationType, profiler: PerformanceProfiler) -> Self {
281 Self {
282 operation_type,
283 start_time: Instant::now(),
284 profiler,
285 recorded: false,
286 }
287 }
288
289 pub fn stop(mut self) {
291 self.record(false);
292 }
293
294 pub fn stop_with_error(mut self) {
296 self.record(true);
297 }
298
299 fn record(&mut self, is_error: bool) {
300 if !self.recorded {
301 let duration = self.start_time.elapsed();
302 self.profiler
303 .record_operation(self.operation_type.clone(), duration, is_error);
304 self.recorded = true;
305 }
306 }
307}
308
309impl Drop for OperationTimer {
310 fn drop(&mut self) {
311 if !self.recorded {
313 self.record(false);
314 }
315 }
316}
317
/// Snapshot of all profiler statistics, produced by
/// `PerformanceProfiler::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceReport {
    /// Sum of `total_count` over all operation types.
    pub total_operations: u64,
    /// Sum of `error_count` over all operation types.
    pub total_errors: u64,
    /// Sum of `total_duration` over all operation types.
    pub total_duration: Duration,
    /// Percentage (0–100) of operations that did not fail; `0.0` when no
    /// operations were recorded.
    pub overall_success_rate: f64,
    /// Per-operation statistics at the time the report was generated.
    pub operation_stats: HashMap<OperationType, OperationStats>,
}
327
328impl PerformanceReport {
329 pub fn summary(&self) -> String {
331 let mut output = String::new();
332 output.push_str("╔════════════════════════════════════════════════════════════════════╗\n");
333 output.push_str("║ Embedding Performance Profiling Report ║\n");
334 output
335 .push_str("╚════════════════════════════════════════════════════════════════════╝\n\n");
336
337 output.push_str(&format!("Total Operations: {}\n", self.total_operations));
338 output.push_str(&format!("Total Errors: {}\n", self.total_errors));
339 output.push_str(&format!(
340 "Overall Success Rate: {:.2}%\n",
341 self.overall_success_rate
342 ));
343 output.push_str(&format!(
344 "Total Duration: {:.2}s\n\n",
345 self.total_duration.as_secs_f64()
346 ));
347
348 output.push_str("Operation Statistics:\n");
349 output.push_str("─────────────────────────────────────────────────────────────────────\n");
350
351 let mut sorted_ops: Vec<_> = self.operation_stats.iter().collect();
352 sorted_ops.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_count));
353
354 for (_, stats) in sorted_ops {
355 output.push_str(&format!("\n{} Operations:\n", stats.operation_type));
356 output.push_str(&format!(" Count: {}\n", stats.total_count));
357 output.push_str(&format!(" Success Rate: {:.2}%\n", stats.success_rate()));
358 output.push_str(&format!(
359 " Average Duration: {:.2}ms\n",
360 stats.average_duration.as_secs_f64() * 1000.0
361 ));
362 output.push_str(&format!(
363 " Min Duration: {:.2}ms\n",
364 stats.min_duration.as_secs_f64() * 1000.0
365 ));
366 output.push_str(&format!(
367 " Max Duration: {:.2}ms\n",
368 stats.max_duration.as_secs_f64() * 1000.0
369 ));
370 output.push_str(&format!(
371 " P95 Duration: {:.2}ms\n",
372 stats.percentile_95.as_secs_f64() * 1000.0
373 ));
374 output.push_str(&format!(
375 " P99 Duration: {:.2}ms\n",
376 stats.percentile_99.as_secs_f64() * 1000.0
377 ));
378 output.push_str(&format!(
379 " Throughput: {:.2} ops/sec\n",
380 stats.throughput()
381 ));
382 }
383
384 output
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a millisecond `Duration`; keeps the fixtures below terse.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    #[test]
    fn test_profiler_creation() {
        assert!(PerformanceProfiler::new().is_enabled());
    }

    #[test]
    fn test_operation_recording() {
        let profiler = PerformanceProfiler::new();
        let samples: [(u64, bool); 3] = [(100, false), (150, false), (120, true)];
        for &(millis, failed) in samples.iter() {
            profiler.record_operation(OperationType::Training, ms(millis), failed);
        }

        let stats = profiler
            .get_stats(OperationType::Training)
            .expect("Training samples were recorded");
        assert_eq!(stats.total_count, 3);
        assert_eq!(stats.error_count, 1);
        // Two of three succeeded: roughly 66.67%.
        assert!((stats.success_rate() - 66.67).abs() < 0.1);
    }

    #[test]
    fn test_operation_timer() {
        let profiler = PerformanceProfiler::new();

        let timer = profiler.start_operation(OperationType::Inference);
        thread::sleep(ms(50));
        // Dropping without calling `stop` must still record the sample.
        drop(timer);

        let stats = profiler
            .get_stats(OperationType::Inference)
            .expect("timer records on drop");
        assert_eq!(stats.total_count, 1);
        assert!(stats.total_duration >= ms(50));
    }

    #[test]
    fn test_multiple_operation_types() {
        let profiler = PerformanceProfiler::new();
        let samples = vec![
            (OperationType::Training, 100),
            (OperationType::Inference, 50),
            (OperationType::SimilarityComputation, 25),
        ];
        for (operation_type, millis) in samples {
            profiler.record_operation(operation_type, ms(millis), false);
        }

        assert_eq!(profiler.get_all_stats().len(), 3);
    }

    #[test]
    fn test_profiler_reset() {
        let profiler = PerformanceProfiler::new();
        profiler.record_operation(OperationType::Training, ms(100), false);
        assert_eq!(profiler.get_all_stats().len(), 1);

        profiler.reset();
        assert!(profiler.get_all_stats().is_empty());
    }

    #[test]
    fn test_performance_report_generation() {
        let profiler = PerformanceProfiler::new();
        profiler.record_operation(OperationType::Training, ms(100), false);
        profiler.record_operation(OperationType::Inference, ms(50), false);

        let report = profiler.generate_report();
        assert_eq!(report.total_operations, 2);
        assert_eq!(report.total_errors, 0);
        assert_eq!(report.overall_success_rate, 100.0);
        assert!(report.summary().contains("Total Operations: 2"));
    }

    #[test]
    fn test_percentile_calculation() {
        let profiler = PerformanceProfiler::new();
        (1..=100).for_each(|millis| {
            profiler.record_operation(OperationType::Inference, ms(millis), false)
        });

        let stats = profiler
            .calculate_percentiles(OperationType::Inference)
            .expect("Inference samples were recorded");
        assert!(stats.percentile_95 >= ms(90));
        assert!(stats.percentile_99 >= ms(95));
    }

    #[test]
    fn test_profiler_disable() {
        let mut profiler = PerformanceProfiler::new();
        profiler.disable();
        profiler.record_operation(OperationType::Training, ms(100), false);

        assert!(profiler.get_all_stats().is_empty());
    }

    #[test]
    fn test_json_export() {
        let profiler = PerformanceProfiler::new();
        profiler.record_operation(OperationType::Training, ms(100), false);

        let json = profiler.export_json().expect("report serializes");
        for needle in &["total_operations", "Training"] {
            assert!(json.contains(*needle));
        }
    }
}