axonml_profile/
compute.rs1use std::collections::HashMap;
6use std::time::{Duration, Instant};
7use serde::{Serialize, Deserialize};
8
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
11pub struct OperationStats {
12 pub name: String,
14 pub call_count: usize,
16 pub total_time_ns: u64,
18 pub min_time_ns: u64,
20 pub max_time_ns: u64,
22 pub flops: Option<f64>,
24 pub bytes_processed: Option<usize>,
26}
27
28impl OperationStats {
29 pub fn avg_time(&self) -> Duration {
31 if self.call_count == 0 {
32 Duration::ZERO
33 } else {
34 Duration::from_nanos(self.total_time_ns / self.call_count as u64)
35 }
36 }
37
38 pub fn total_time(&self) -> Duration {
40 Duration::from_nanos(self.total_time_ns)
41 }
42
43 pub fn min_time(&self) -> Duration {
45 Duration::from_nanos(self.min_time_ns)
46 }
47
48 pub fn max_time(&self) -> Duration {
50 Duration::from_nanos(self.max_time_ns)
51 }
52
53 pub fn gflops(&self) -> Option<f64> {
55 self.flops.map(|f| f / 1e9)
56 }
57
58 pub fn bandwidth_gbps(&self) -> Option<f64> {
60 if let Some(bytes) = self.bytes_processed {
61 if self.total_time_ns > 0 {
62 let seconds = self.total_time_ns as f64 / 1e9;
63 Some(bytes as f64 / seconds / 1e9)
64 } else {
65 None
66 }
67 } else {
68 None
69 }
70 }
71}
72
/// An operation that has been started and is waiting to be stopped.
#[derive(Debug, Clone)]
pub struct ProfiledOp {
    /// Name the operation was started under.
    pub name: String,
    /// Instant captured when the operation was started.
    pub start: Instant,
    /// FLOP count attached at start time, if any.
    pub flops: Option<f64>,
    /// Byte count attached at start time, if any.
    pub bytes: Option<usize>,
}
85
86#[derive(Debug)]
88pub struct ComputeProfiler {
89 stats: HashMap<String, OperationStats>,
91 active: HashMap<String, Vec<ProfiledOp>>,
93}
94
95impl Default for ComputeProfiler {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl ComputeProfiler {
102 pub fn new() -> Self {
104 Self {
105 stats: HashMap::new(),
106 active: HashMap::new(),
107 }
108 }
109
110 pub fn start(&mut self, name: &str) {
112 let op = ProfiledOp {
113 name: name.to_string(),
114 start: Instant::now(),
115 flops: None,
116 bytes: None,
117 };
118
119 self.active
120 .entry(name.to_string())
121 .or_insert_with(Vec::new)
122 .push(op);
123 }
124
125 pub fn start_with_flops(&mut self, name: &str, flops: f64) {
127 let op = ProfiledOp {
128 name: name.to_string(),
129 start: Instant::now(),
130 flops: Some(flops),
131 bytes: None,
132 };
133
134 self.active
135 .entry(name.to_string())
136 .or_insert_with(Vec::new)
137 .push(op);
138 }
139
140 pub fn start_with_bytes(&mut self, name: &str, bytes: usize) {
142 let op = ProfiledOp {
143 name: name.to_string(),
144 start: Instant::now(),
145 flops: None,
146 bytes: Some(bytes),
147 };
148
149 self.active
150 .entry(name.to_string())
151 .or_insert_with(Vec::new)
152 .push(op);
153 }
154
155 pub fn stop(&mut self, name: &str) {
157 let _elapsed = if let Some(ops) = self.active.get_mut(name) {
158 if let Some(op) = ops.pop() {
159 let elapsed = op.start.elapsed();
160
161 let stats = self.stats.entry(name.to_string()).or_insert_with(|| {
163 OperationStats {
164 name: name.to_string(),
165 min_time_ns: u64::MAX,
166 ..Default::default()
167 }
168 });
169
170 let elapsed_ns = elapsed.as_nanos() as u64;
171 stats.call_count += 1;
172 stats.total_time_ns += elapsed_ns;
173 stats.min_time_ns = stats.min_time_ns.min(elapsed_ns);
174 stats.max_time_ns = stats.max_time_ns.max(elapsed_ns);
175
176 if let Some(flops) = op.flops {
177 stats.flops = Some(stats.flops.unwrap_or(0.0) + flops);
178 }
179 if let Some(bytes) = op.bytes {
180 stats.bytes_processed = Some(stats.bytes_processed.unwrap_or(0) + bytes);
181 }
182
183 Some(elapsed)
184 } else {
185 None
186 }
187 } else {
188 None
189 };
190 }
191
192 pub fn get_stats(&self, name: &str) -> Option<&OperationStats> {
194 self.stats.get(name)
195 }
196
197 pub fn all_stats(&self) -> HashMap<String, OperationStats> {
199 self.stats.clone()
200 }
201
202 pub fn total_time(&self, name: &str) -> Duration {
204 self.stats
205 .get(name)
206 .map(|s| s.total_time())
207 .unwrap_or(Duration::ZERO)
208 }
209
210 pub fn avg_time(&self, name: &str) -> Duration {
212 self.stats
213 .get(name)
214 .map(|s| s.avg_time())
215 .unwrap_or(Duration::ZERO)
216 }
217
218 pub fn top_by_time(&self, n: usize) -> Vec<&OperationStats> {
220 let mut sorted: Vec<_> = self.stats.values().collect();
221 sorted.sort_by(|a, b| b.total_time_ns.cmp(&a.total_time_ns));
222 sorted.into_iter().take(n).collect()
223 }
224
225 pub fn top_by_calls(&self, n: usize) -> Vec<&OperationStats> {
227 let mut sorted: Vec<_> = self.stats.values().collect();
228 sorted.sort_by(|a, b| b.call_count.cmp(&a.call_count));
229 sorted.into_iter().take(n).collect()
230 }
231
232 pub fn reset(&mut self) {
234 self.stats.clear();
235 self.active.clear();
236 }
237
238 pub fn format_duration(d: Duration) -> String {
240 let nanos = d.as_nanos();
241 if nanos >= 1_000_000_000 {
242 format!("{:.3} s", d.as_secs_f64())
243 } else if nanos >= 1_000_000 {
244 format!("{:.3} ms", nanos as f64 / 1_000_000.0)
245 } else if nanos >= 1_000 {
246 format!("{:.3} us", nanos as f64 / 1_000.0)
247 } else {
248 format!("{} ns", nanos)
249 }
250 }
251}
252
253pub struct TimingGuard<'a> {
255 profiler: &'a mut ComputeProfiler,
256 name: String,
257}
258
259impl<'a> TimingGuard<'a> {
260 pub fn new(profiler: &'a mut ComputeProfiler, name: &str) -> Self {
262 profiler.start(name);
263 Self {
264 profiler,
265 name: name.to_string(),
266 }
267 }
268}
269
270impl<'a> Drop for TimingGuard<'a> {
271 fn drop(&mut self) {
272 self.profiler.stop(&self.name);
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // One start/stop pair records one call whose time covers the sleep.
    #[test]
    fn test_basic_timing() {
        let mut profiler = ComputeProfiler::new();

        profiler.start("test_op");
        std::thread::sleep(Duration::from_millis(10));
        profiler.stop("test_op");

        let stats = profiler.get_stats("test_op").unwrap();
        assert_eq!(stats.call_count, 1);
        assert!(stats.total_time() >= Duration::from_millis(10));
    }

    // Repeated start/stop pairs under one name accumulate into one entry.
    #[test]
    fn test_multiple_calls() {
        let mut profiler = ComputeProfiler::new();

        for _ in 0..5 {
            profiler.start("multi_op");
            std::thread::sleep(Duration::from_millis(1));
            profiler.stop("multi_op");
        }

        let stats = profiler.get_stats("multi_op").unwrap();
        assert_eq!(stats.call_count, 5);
    }

    // An enclosing operation takes at least as long as the one nested in it.
    #[test]
    fn test_nested_operations() {
        let mut profiler = ComputeProfiler::new();

        profiler.start("outer");
        profiler.start("inner");
        std::thread::sleep(Duration::from_millis(5));
        profiler.stop("inner");
        std::thread::sleep(Duration::from_millis(5));
        profiler.stop("outer");

        let outer = profiler.get_stats("outer").unwrap();
        let inner = profiler.get_stats("inner").unwrap();

        assert!(outer.total_time() >= inner.total_time());
    }

    // Ranking by time puts the slower operation first.
    #[test]
    fn test_top_operations() {
        let mut profiler = ComputeProfiler::new();

        profiler.start("slow");
        std::thread::sleep(Duration::from_millis(20));
        profiler.stop("slow");

        profiler.start("fast");
        std::thread::sleep(Duration::from_millis(5));
        profiler.stop("fast");

        let top = profiler.top_by_time(2);
        assert_eq!(top[0].name, "slow");
        assert_eq!(top[1].name, "fast");
    }

    // Stopping a name that was never started must not create statistics.
    #[test]
    fn test_stop_without_start() {
        let mut profiler = ComputeProfiler::new();

        profiler.stop("missing");

        assert!(profiler.get_stats("missing").is_none());
        assert_eq!(profiler.total_time("missing"), Duration::ZERO);
    }

    // Exact strings are asserted: a substring check for "s" would also accept
    // "ns", "us" and "ms", leaving the seconds branch unverified.
    #[test]
    fn test_format_duration() {
        assert_eq!(
            ComputeProfiler::format_duration(Duration::from_nanos(500)),
            "500 ns"
        );
        assert_eq!(
            ComputeProfiler::format_duration(Duration::from_micros(500)),
            "500.000 us"
        );
        assert_eq!(
            ComputeProfiler::format_duration(Duration::from_millis(500)),
            "500.000 ms"
        );
        assert_eq!(
            ComputeProfiler::format_duration(Duration::from_secs(5)),
            "5.000 s"
        );
    }
}