Skip to main content

axonml_profile/
compute.rs

1//! Compute Profiling Module
2//!
3//! Tracks operation execution times, FLOPS, and throughput.
4
5use std::collections::HashMap;
6use std::time::{Duration, Instant};
7use serde::{Serialize, Deserialize};
8
/// Aggregated timing and throughput statistics for one named operation.
///
/// [`ComputeProfiler::stop`] keeps one of these per operation name and folds
/// every completed start/stop pair into it. All times are stored as whole
/// nanoseconds.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OperationStats {
    /// Name of the operation these statistics belong to.
    pub name: String,
    /// Number of completed start/stop pairs recorded for this operation.
    pub call_count: usize,
    /// Sum of the durations of all recorded calls, in nanoseconds.
    pub total_time_ns: u64,
    /// Shortest recorded call, in nanoseconds.
    ///
    /// `Default` leaves this at 0; [`ComputeProfiler::stop`] seeds new entries
    /// with `u64::MAX` instead so the first sample becomes the minimum.
    pub min_time_ns: u64,
    /// Longest recorded call, in nanoseconds.
    pub max_time_ns: u64,
    /// Sum of the floating-point operation counts supplied through
    /// `start_with_flops`, or `None` if no recorded call supplied one.
    ///
    /// This is an accumulated amount of work, not a per-second rate.
    pub flops: Option<f64>,
    /// Sum of the byte counts supplied through `start_with_bytes` (used for
    /// the bandwidth calculation), or `None` if no recorded call supplied one.
    pub bytes_processed: Option<usize>,
}
27
28impl OperationStats {
29    /// Returns the average execution time.
30    pub fn avg_time(&self) -> Duration {
31        if self.call_count == 0 {
32            Duration::ZERO
33        } else {
34            Duration::from_nanos(self.total_time_ns / self.call_count as u64)
35        }
36    }
37
38    /// Returns the total execution time.
39    pub fn total_time(&self) -> Duration {
40        Duration::from_nanos(self.total_time_ns)
41    }
42
43    /// Returns the minimum execution time.
44    pub fn min_time(&self) -> Duration {
45        Duration::from_nanos(self.min_time_ns)
46    }
47
48    /// Returns the maximum execution time.
49    pub fn max_time(&self) -> Duration {
50        Duration::from_nanos(self.max_time_ns)
51    }
52
53    /// Returns GFLOPS if FLOPS is set.
54    pub fn gflops(&self) -> Option<f64> {
55        self.flops.map(|f| f / 1e9)
56    }
57
58    /// Returns bandwidth in GB/s if bytes_processed is set.
59    pub fn bandwidth_gbps(&self) -> Option<f64> {
60        if let Some(bytes) = self.bytes_processed {
61            if self.total_time_ns > 0 {
62                let seconds = self.total_time_ns as f64 / 1e9;
63                Some(bytes as f64 / seconds / 1e9)
64            } else {
65                None
66            }
67        } else {
68            None
69        }
70    }
71}
72
/// An in-flight operation: a name, the instant it was started, and the
/// optional work counters supplied when it was started.
///
/// [`ComputeProfiler`] creates one of these in each `start*` call and consumes
/// it in `stop`, where the elapsed time since `start` is measured.
#[derive(Debug, Clone)]
pub struct ProfiledOp {
    /// Operation name (the same string used as the key in the profiler).
    pub name: String,
    /// Instant captured when the operation was started.
    pub start: Instant,
    /// Floating-point operation count supplied by `start_with_flops`, if any.
    pub flops: Option<f64>,
    /// Byte count supplied by `start_with_bytes`, if any.
    pub bytes: Option<usize>,
}
85
/// Compute profiler for tracking operation execution times.
///
/// Operations are identified by name. `start*` records a start instant and
/// `stop` folds the elapsed time into the per-name [`OperationStats`].
#[derive(Debug)]
pub struct ComputeProfiler {
    /// Accumulated statistics, keyed by operation name.
    stats: HashMap<String, OperationStats>,
    /// Operations that were started but not yet stopped, keyed by name.
    /// Each value is used as a stack: `start*` pushes and `stop` pops, so
    /// nested starts under one name are closed in last-in, first-out order.
    active: HashMap<String, Vec<ProfiledOp>>,
}
94
95impl Default for ComputeProfiler {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101impl ComputeProfiler {
102    /// Creates a new compute profiler.
103    pub fn new() -> Self {
104        Self {
105            stats: HashMap::new(),
106            active: HashMap::new(),
107        }
108    }
109
110    /// Starts profiling an operation.
111    pub fn start(&mut self, name: &str) {
112        let op = ProfiledOp {
113            name: name.to_string(),
114            start: Instant::now(),
115            flops: None,
116            bytes: None,
117        };
118
119        self.active
120            .entry(name.to_string())
121            .or_insert_with(Vec::new)
122            .push(op);
123    }
124
125    /// Starts profiling an operation with FLOPS count.
126    pub fn start_with_flops(&mut self, name: &str, flops: f64) {
127        let op = ProfiledOp {
128            name: name.to_string(),
129            start: Instant::now(),
130            flops: Some(flops),
131            bytes: None,
132        };
133
134        self.active
135            .entry(name.to_string())
136            .or_insert_with(Vec::new)
137            .push(op);
138    }
139
140    /// Starts profiling an operation with bytes processed.
141    pub fn start_with_bytes(&mut self, name: &str, bytes: usize) {
142        let op = ProfiledOp {
143            name: name.to_string(),
144            start: Instant::now(),
145            flops: None,
146            bytes: Some(bytes),
147        };
148
149        self.active
150            .entry(name.to_string())
151            .or_insert_with(Vec::new)
152            .push(op);
153    }
154
155    /// Stops profiling an operation and records its duration.
156    pub fn stop(&mut self, name: &str) {
157        let _elapsed = if let Some(ops) = self.active.get_mut(name) {
158            if let Some(op) = ops.pop() {
159                let elapsed = op.start.elapsed();
160
161                // Update stats
162                let stats = self.stats.entry(name.to_string()).or_insert_with(|| {
163                    OperationStats {
164                        name: name.to_string(),
165                        min_time_ns: u64::MAX,
166                        ..Default::default()
167                    }
168                });
169
170                let elapsed_ns = elapsed.as_nanos() as u64;
171                stats.call_count += 1;
172                stats.total_time_ns += elapsed_ns;
173                stats.min_time_ns = stats.min_time_ns.min(elapsed_ns);
174                stats.max_time_ns = stats.max_time_ns.max(elapsed_ns);
175
176                if let Some(flops) = op.flops {
177                    stats.flops = Some(stats.flops.unwrap_or(0.0) + flops);
178                }
179                if let Some(bytes) = op.bytes {
180                    stats.bytes_processed = Some(stats.bytes_processed.unwrap_or(0) + bytes);
181                }
182
183                Some(elapsed)
184            } else {
185                None
186            }
187        } else {
188            None
189        };
190    }
191
192    /// Gets statistics for a specific operation.
193    pub fn get_stats(&self, name: &str) -> Option<&OperationStats> {
194        self.stats.get(name)
195    }
196
197    /// Gets all operation statistics.
198    pub fn all_stats(&self) -> HashMap<String, OperationStats> {
199        self.stats.clone()
200    }
201
202    /// Gets total time for an operation.
203    pub fn total_time(&self, name: &str) -> Duration {
204        self.stats
205            .get(name)
206            .map(|s| s.total_time())
207            .unwrap_or(Duration::ZERO)
208    }
209
210    /// Gets average time for an operation.
211    pub fn avg_time(&self, name: &str) -> Duration {
212        self.stats
213            .get(name)
214            .map(|s| s.avg_time())
215            .unwrap_or(Duration::ZERO)
216    }
217
218    /// Gets the top N operations by total time.
219    pub fn top_by_time(&self, n: usize) -> Vec<&OperationStats> {
220        let mut sorted: Vec<_> = self.stats.values().collect();
221        sorted.sort_by(|a, b| b.total_time_ns.cmp(&a.total_time_ns));
222        sorted.into_iter().take(n).collect()
223    }
224
225    /// Gets the top N operations by call count.
226    pub fn top_by_calls(&self, n: usize) -> Vec<&OperationStats> {
227        let mut sorted: Vec<_> = self.stats.values().collect();
228        sorted.sort_by(|a, b| b.call_count.cmp(&a.call_count));
229        sorted.into_iter().take(n).collect()
230    }
231
232    /// Resets all statistics.
233    pub fn reset(&mut self) {
234        self.stats.clear();
235        self.active.clear();
236    }
237
238    /// Formats a duration for display.
239    pub fn format_duration(d: Duration) -> String {
240        let nanos = d.as_nanos();
241        if nanos >= 1_000_000_000 {
242            format!("{:.3} s", d.as_secs_f64())
243        } else if nanos >= 1_000_000 {
244            format!("{:.3} ms", nanos as f64 / 1_000_000.0)
245        } else if nanos >= 1_000 {
246            format!("{:.3} us", nanos as f64 / 1_000.0)
247        } else {
248            format!("{} ns", nanos)
249        }
250    }
251}
252
253/// RAII guard for automatic operation timing.
254pub struct TimingGuard<'a> {
255    profiler: &'a mut ComputeProfiler,
256    name: String,
257}
258
259impl<'a> TimingGuard<'a> {
260    /// Creates a new timing guard.
261    pub fn new(profiler: &'a mut ComputeProfiler, name: &str) -> Self {
262        profiler.start(name);
263        Self {
264            profiler,
265            name: name.to_string(),
266        }
267    }
268}
269
270impl<'a> Drop for TimingGuard<'a> {
271    fn drop(&mut self) {
272        self.profiler.stop(&self.name);
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// A single start/stop pair is counted once and covers the sleep.
    #[test]
    fn test_basic_timing() {
        let mut profiler = ComputeProfiler::new();

        profiler.start("test_op");
        std::thread::sleep(Duration::from_millis(10));
        profiler.stop("test_op");

        let stats = profiler.get_stats("test_op").unwrap();
        assert_eq!(stats.call_count, 1);
        assert!(stats.total_time() >= Duration::from_millis(10));
    }

    /// Repeated start/stop pairs under one name accumulate into one entry.
    #[test]
    fn test_multiple_calls() {
        let mut profiler = ComputeProfiler::new();

        for _ in 0..5 {
            profiler.start("multi_op");
            std::thread::sleep(Duration::from_millis(1));
            profiler.stop("multi_op");
        }

        let stats = profiler.get_stats("multi_op").unwrap();
        assert_eq!(stats.call_count, 5);
        assert!(stats.min_time_ns <= stats.max_time_ns);
    }

    /// An outer operation spans at least as long as the one nested in it.
    #[test]
    fn test_nested_operations() {
        let mut profiler = ComputeProfiler::new();

        profiler.start("outer");
        profiler.start("inner");
        std::thread::sleep(Duration::from_millis(5));
        profiler.stop("inner");
        std::thread::sleep(Duration::from_millis(5));
        profiler.stop("outer");

        let outer = profiler.get_stats("outer").unwrap();
        let inner = profiler.get_stats("inner").unwrap();

        assert!(outer.total_time() >= inner.total_time());
    }

    /// `top_by_time` lists the slower operation first.
    #[test]
    fn test_top_operations() {
        let mut profiler = ComputeProfiler::new();

        profiler.start("slow");
        std::thread::sleep(Duration::from_millis(20));
        profiler.stop("slow");

        profiler.start("fast");
        std::thread::sleep(Duration::from_millis(5));
        profiler.stop("fast");

        let top = profiler.top_by_time(2);
        assert_eq!(top[0].name, "slow");
        assert_eq!(top[1].name, "fast");
    }

    /// Each unit is checked with an exact string: a substring check for "s"
    /// would also match "ns", "us" and "ms" and so could never fail.
    #[test]
    fn test_format_duration() {
        assert_eq!(ComputeProfiler::format_duration(Duration::from_nanos(500)), "500 ns");
        assert_eq!(ComputeProfiler::format_duration(Duration::from_micros(500)), "500.000 us");
        assert_eq!(ComputeProfiler::format_duration(Duration::from_millis(500)), "500.000 ms");
        assert_eq!(ComputeProfiler::format_duration(Duration::from_secs(5)), "5.000 s");
    }

    /// The unit switches exactly at 1 us, 1 ms and 1 s.
    #[test]
    fn test_format_duration_boundaries() {
        assert_eq!(ComputeProfiler::format_duration(Duration::ZERO), "0 ns");
        assert_eq!(ComputeProfiler::format_duration(Duration::from_nanos(999)), "999 ns");
        assert_eq!(ComputeProfiler::format_duration(Duration::from_nanos(1_000)), "1.000 us");
        assert_eq!(ComputeProfiler::format_duration(Duration::from_nanos(1_000_000)), "1.000 ms");
        assert_eq!(
            ComputeProfiler::format_duration(Duration::from_nanos(1_000_000_000)),
            "1.000 s"
        );
    }

    /// FLOP and byte counts attached at start are summed across calls.
    #[test]
    fn test_flops_and_bytes_accumulate() {
        let mut profiler = ComputeProfiler::new();

        profiler.start_with_flops("matmul", 1.5);
        profiler.stop("matmul");
        profiler.start_with_flops("matmul", 2.5);
        profiler.stop("matmul");

        profiler.start_with_bytes("copy", 100);
        profiler.stop("copy");
        profiler.start_with_bytes("copy", 28);
        profiler.stop("copy");

        let matmul = profiler.get_stats("matmul").unwrap();
        assert_eq!(matmul.flops, Some(4.0));
        assert_eq!(matmul.bytes_processed, None);

        let copy = profiler.get_stats("copy").unwrap();
        assert_eq!(copy.bytes_processed, Some(128));
        assert_eq!(copy.flops, None);
    }

    /// Stopping a name that was never started records nothing.
    #[test]
    fn test_stop_without_start_is_ignored() {
        let mut profiler = ComputeProfiler::new();

        profiler.stop("missing");

        assert!(profiler.get_stats("missing").is_none());
        assert!(profiler.all_stats().is_empty());
        assert_eq!(profiler.total_time("missing"), Duration::ZERO);
        assert_eq!(profiler.avg_time("missing"), Duration::ZERO);
    }

    /// Dropping a `TimingGuard` records exactly one call.
    #[test]
    fn test_timing_guard_records_on_drop() {
        let mut profiler = ComputeProfiler::new();

        {
            let _guard = TimingGuard::new(&mut profiler, "guarded");
        }

        assert_eq!(profiler.get_stats("guarded").unwrap().call_count, 1);
    }
}