1use aspect_core::{Aspect, AspectError, ProceedingJoinPoint};
4use parking_lot::Mutex;
5use std::any::Any;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10#[derive(Clone)]
30pub struct TimingAspect {
31 stats: Arc<Mutex<HashMap<String, FunctionStats>>>,
32 threshold_ms: Option<u64>,
33 print_on_complete: bool,
34}
35
/// Aggregated timing statistics for a single function.
#[derive(Debug, Clone)]
pub struct FunctionStats {
    /// Name of the function these statistics describe.
    pub name: String,
    /// Number of recorded calls.
    pub count: u64,
    /// Sum of all recorded call durations.
    pub total_duration: Duration,
    /// Shortest recorded duration (`Duration::MAX` until a call is recorded).
    pub min_duration: Duration,
    /// Longest recorded duration (`Duration::ZERO` until a call is recorded).
    pub max_duration: Duration,
}

impl FunctionStats {
    /// Creates an empty record for `name`. The min/max start at the extremes
    /// so the first recorded duration replaces both.
    fn new(name: String) -> Self {
        Self {
            name,
            count: 0,
            total_duration: Duration::ZERO,
            min_duration: Duration::MAX,
            max_duration: Duration::ZERO,
        }
    }

    /// Folds one observed call `duration` into the statistics.
    fn record(&mut self, duration: Duration) {
        self.count += 1;
        self.total_duration += duration;
        self.min_duration = self.min_duration.min(duration);
        self.max_duration = self.max_duration.max(duration);
    }

    /// Returns the mean call duration, or `Duration::ZERO` when no call has
    /// been recorded.
    #[must_use]
    pub fn average_duration(&self) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if self.count == 0 {
            return Duration::ZERO;
        }

        // Divide in u128 nanoseconds instead of `Duration / (count as u32)`:
        // that cast truncates, giving a wrong mean past u32::MAX calls and a
        // divide-by-zero panic whenever count is a multiple of 2^32.
        let mean_nanos = self.total_duration.as_nanos() / u128::from(self.count);

        // mean <= total_duration <= Duration::MAX, so the whole-second part
        // always fits in u64 and the remainder is < 1e9 (no carry in `new`).
        Duration::new(
            (mean_nanos / NANOS_PER_SEC) as u64,
            (mean_nanos % NANOS_PER_SEC) as u32,
        )
    }
}
78
79impl TimingAspect {
80 pub fn new() -> Self {
82 Self {
83 stats: Arc::new(Mutex::new(HashMap::new())),
84 threshold_ms: None,
85 print_on_complete: false,
86 }
87 }
88
89 pub fn with_threshold(mut self, threshold_ms: u64) -> Self {
91 self.threshold_ms = Some(threshold_ms);
92 self
93 }
94
95 pub fn print_on_complete(mut self) -> Self {
97 self.print_on_complete = true;
98 self
99 }
100
101 pub fn get_stats(&self, function_name: &str) -> Option<FunctionStats> {
103 self.stats.lock().get(function_name).cloned()
104 }
105
106 pub fn all_stats(&self) -> Vec<FunctionStats> {
108 self.stats.lock().values().cloned().collect()
109 }
110
111 pub fn print_stats(&self) {
113 let stats = self.stats.lock();
114 if stats.is_empty() {
115 println!("No timing data collected.");
116 return;
117 }
118
119 println!("\n=== Timing Statistics ===");
120 println!("{:<30} {:>10} {:>15} {:>15} {:>15} {:>15}",
121 "Function", "Calls", "Total", "Average", "Min", "Max");
122 println!("{:-<100}", "");
123
124 for stat in stats.values() {
125 println!(
126 "{:<30} {:>10} {:>15.3?} {:>15.3?} {:>15.3?} {:>15.3?}",
127 stat.name,
128 stat.count,
129 stat.total_duration,
130 stat.average_duration(),
131 stat.min_duration,
132 stat.max_duration
133 );
134 }
135 println!();
136 }
137
138 pub fn clear(&self) {
140 self.stats.lock().clear();
141 }
142
143 fn record_timing(&self, function_name: &str, duration: Duration) {
144 let mut stats = self.stats.lock();
145 stats
146 .entry(function_name.to_string())
147 .or_insert_with(|| FunctionStats::new(function_name.to_string()))
148 .record(duration);
149 }
150}
151
152impl Default for TimingAspect {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
158impl Aspect for TimingAspect {
159 fn around(&self, pjp: ProceedingJoinPoint) -> Result<Box<dyn Any>, AspectError> {
160 let function_name = pjp.context().function_name.to_string();
161 let start = Instant::now();
162
163 let result = pjp.proceed();
164
165 let duration = start.elapsed();
166 self.record_timing(&function_name, duration);
167
168 if let Some(threshold_ms) = self.threshold_ms {
170 if duration.as_millis() > threshold_ms as u128 {
171 println!(
172 "[SLOW] {} took {:?} (threshold: {}ms)",
173 function_name, duration, threshold_ms
174 );
175 }
176 }
177
178 if self.print_on_complete {
180 println!("[TIMING] {} took {:?}", function_name, duration);
181 }
182
183 result
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_aspect_creation() {
        let aspect = TimingAspect::new();
        assert!(aspect.threshold_ms.is_none());
        assert!(!aspect.print_on_complete);
        assert!(aspect.all_stats().is_empty());
    }

    #[test]
    fn test_timing_aspect_default_matches_new() {
        let aspect = TimingAspect::default();
        assert!(aspect.threshold_ms.is_none());
        assert!(!aspect.print_on_complete);
        assert!(aspect.all_stats().is_empty());
    }

    #[test]
    fn test_timing_aspect_builder() {
        let aspect = TimingAspect::new()
            .with_threshold(100)
            .print_on_complete();

        assert_eq!(aspect.threshold_ms, Some(100));
        assert!(aspect.print_on_complete);
    }

    #[test]
    fn test_function_stats() {
        let mut stats = FunctionStats::new("test_func".to_string());

        stats.record(Duration::from_millis(10));
        stats.record(Duration::from_millis(20));
        stats.record(Duration::from_millis(30));

        assert_eq!(stats.count, 3);
        assert_eq!(stats.min_duration, Duration::from_millis(10));
        assert_eq!(stats.max_duration, Duration::from_millis(30));
        assert_eq!(stats.average_duration(), Duration::from_millis(20));
    }

    #[test]
    fn test_timing_aspect_record() {
        let aspect = TimingAspect::new();

        aspect.record_timing("func1", Duration::from_millis(10));
        aspect.record_timing("func1", Duration::from_millis(20));
        aspect.record_timing("func2", Duration::from_millis(30));

        let stats1 = aspect.get_stats("func1").expect("func1 was recorded");
        assert_eq!(stats1.count, 2);
        assert_eq!(stats1.total_duration, Duration::from_millis(30));
        assert_eq!(stats1.min_duration, Duration::from_millis(10));
        assert_eq!(stats1.max_duration, Duration::from_millis(20));

        let stats2 = aspect.get_stats("func2").expect("func2 was recorded");
        assert_eq!(stats2.count, 1);

        assert_eq!(aspect.all_stats().len(), 2);
    }

    #[test]
    fn test_get_stats_unknown_function() {
        let aspect = TimingAspect::new();
        aspect.record_timing("known", Duration::from_millis(1));
        assert!(aspect.get_stats("unknown").is_none());
    }

    #[test]
    fn test_clear_removes_all_stats() {
        let aspect = TimingAspect::new();
        aspect.record_timing("func", Duration::from_millis(5));
        aspect.clear();
        assert!(aspect.all_stats().is_empty());
        assert!(aspect.get_stats("func").is_none());
    }

    #[test]
    fn test_clones_share_stats() {
        let aspect = TimingAspect::new();
        let clone = aspect.clone();
        clone.record_timing("shared", Duration::from_millis(3));
        assert_eq!(aspect.get_stats("shared").map(|s| s.count), Some(1));
    }
}