Skip to main content

oxibonsai_runtime/
profiler.rs

//! Inference profiler: per-layer timing, memory, and FLOP accounting.
//!
//! The profiler measures wall-clock time with `std::time::Instant` and
//! provides per-layer and per-phase breakdowns suitable for performance
//! analysis.
//!
//! ## Usage
//!
//! ```rust
//! use oxibonsai_runtime::profiler::{Profiler, flop_counter};
//!
//! let mut prof = Profiler::new();
//! prof.begin_trace();
//!
//! let result = prof.profile("attention.layer0", flop_counter::attention(512, 64, 8), || {
//!     42u32
//! });
//! assert_eq!(result, 42);
//!
//! let trace = prof.end_trace().expect("trace should exist");
//! println!("{}", trace.summary());
//! ```

22use std::collections::HashMap;
23use std::fmt::Write as FmtWrite;
24use std::time::{Duration, Instant};
25
26// ─── ProfileEvent ────────────────────────────────────────────────────────────
27
/// One timed unit of work inside a trace, typically a single layer or phase.
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// Label identifying the event, for example `"attention.layer3"`.
    pub name: String,
    /// Measured wall-clock time spent in the event.
    pub duration: Duration,
    /// Change in memory usage in bytes: positive when memory was allocated,
    /// negative when it was released.
    pub memory_delta_bytes: i64,
    /// Estimated number of floating-point operations the event performed.
    pub flops: u64,
    /// Free-form string annotations attached to the event.
    pub metadata: HashMap<String, String>,
}

impl ProfileEvent {
    /// Build an event named `name` with every measurement zeroed and no
    /// metadata.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            metadata: HashMap::new(),
            flops: 0,
            memory_delta_bytes: 0,
            duration: Duration::ZERO,
        }
    }

    /// Builder-style setter: replace the FLOP estimate and return the event.
    pub fn with_flops(self, flops: u64) -> Self {
        Self { flops, ..self }
    }

    /// Builder-style setter: insert (or overwrite) one metadata entry and
    /// return the event.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.metadata.insert(k, v);
        self
    }

    /// The event's duration expressed as fractional milliseconds.
    pub fn duration_ms(&self) -> f64 {
        1_000.0 * self.duration.as_secs_f64()
    }

    /// Throughput of this event in billions of FLOPs per second.
    ///
    /// Yields `0.0` when either the duration or the FLOP count is zero,
    /// so no division by zero ever happens.
    pub fn gflops_per_second(&self) -> f64 {
        let seconds = self.duration.as_secs_f64();
        let measurable = seconds > 0.0 && self.flops != 0;
        if measurable {
            self.flops as f64 / seconds / 1e9
        } else {
            0.0
        }
    }
}
83
84// ─── ProfileGuard ─────────────────────────────────────────────────────────────
85
86/// RAII guard that measures wall-clock time for a scope and appends a
87/// [`ProfileEvent`] to the owning [`Profiler`] when dropped.
88///
89/// Obtain via [`Profiler::begin_event`] is the manual pair; for a scoped
90/// version use [`Profiler::profile`].  `ProfileGuard` is exposed so callers
91/// can adjust FLOP counts mid-scope via [`Self::set_flops`].
92pub struct ProfileGuard<'a> {
93    profiler: &'a mut Profiler,
94    name: String,
95    start: Instant,
96    flops: u64,
97}
98
99impl<'a> ProfileGuard<'a> {
100    /// Update the estimated FLOP count before the guard is dropped.
101    pub fn set_flops(&mut self, flops: u64) {
102        self.flops = flops;
103    }
104}
105
106impl<'a> Drop for ProfileGuard<'a> {
107    fn drop(&mut self) {
108        let elapsed = self.start.elapsed();
109        if !self.profiler.enabled {
110            return;
111        }
112        if let Some(trace) = self.profiler.current_trace.as_mut() {
113            let event = ProfileEvent {
114                name: self.name.clone(),
115                duration: elapsed,
116                memory_delta_bytes: 0,
117                flops: self.flops,
118                metadata: HashMap::new(),
119            };
120            trace.total_flops = trace.total_flops.saturating_add(event.flops);
121            trace.events.push(event);
122        }
123    }
124}
125
126// ─── ProfileTrace ─────────────────────────────────────────────────────────────
127
128/// Complete record of one inference pass.
129#[derive(Debug, Clone, Default)]
130pub struct ProfileTrace {
131    /// Ordered list of events that occurred during the pass.
132    pub events: Vec<ProfileEvent>,
133    /// Total wall-clock duration of the entire trace.
134    pub total_duration: Duration,
135    /// Peak resident memory observed during the trace (best-effort).
136    pub peak_memory_bytes: usize,
137    /// Sum of estimated FLOPs across all events.
138    pub total_flops: u64,
139}
140
141impl ProfileTrace {
142    /// Return the `n` events with the longest duration, sorted descending.
143    pub fn top_events(&self, n: usize) -> Vec<&ProfileEvent> {
144        let mut refs: Vec<&ProfileEvent> = self.events.iter().collect();
145        refs.sort_by_key(|b| std::cmp::Reverse(b.duration));
146        refs.into_iter().take(n).collect()
147    }
148
149    /// Sum the durations of all events whose names start with `prefix`.
150    pub fn duration_for_prefix(&self, prefix: &str) -> Duration {
151        self.events
152            .iter()
153            .filter(|e| e.name.starts_with(prefix))
154            .map(|e| e.duration)
155            .fold(Duration::ZERO, |acc, d| acc + d)
156    }
157
158    /// Average duration of events whose names start with `prefix`.
159    ///
160    /// Returns `None` if no events match.
161    pub fn avg_duration_for_prefix(&self, prefix: &str) -> Option<Duration> {
162        let matching: Vec<Duration> = self
163            .events
164            .iter()
165            .filter(|e| e.name.starts_with(prefix))
166            .map(|e| e.duration)
167            .collect();
168
169        if matching.is_empty() {
170            return None;
171        }
172
173        let total_nanos: u128 = matching.iter().map(|d| d.as_nanos()).sum();
174        let avg_nanos = total_nanos / matching.len() as u128;
175        Some(Duration::from_nanos(avg_nanos as u64))
176    }
177
178    /// Human-readable summary of the trace.
179    pub fn summary(&self) -> String {
180        let mut out = String::with_capacity(512);
181        let _ = writeln!(
182            out,
183            "=== ProfileTrace: {:.3} ms total, {} events, {:.2} GFLOPs ===",
184            self.total_duration.as_secs_f64() * 1_000.0,
185            self.events.len(),
186            self.aggregate_gflops(),
187        );
188        let _ = writeln!(out, "  peak_memory: {} bytes", self.peak_memory_bytes);
189
190        let top = self.top_events(10);
191        if !top.is_empty() {
192            let _ = writeln!(out, "  Top events by duration:");
193            for ev in top {
194                let _ = writeln!(
195                    out,
196                    "    {:40} {:8.3} ms  {:6.2} GFLOPs/s",
197                    ev.name,
198                    ev.duration_ms(),
199                    ev.gflops_per_second(),
200                );
201            }
202        }
203
204        out
205    }
206
207    /// Overall GFLOPs/s: total_flops / total_duration.
208    ///
209    /// Returns `0.0` if duration is zero.
210    pub fn aggregate_gflops(&self) -> f64 {
211        let secs = self.total_duration.as_secs_f64();
212        if secs <= 0.0 || self.total_flops == 0 {
213            return 0.0;
214        }
215        (self.total_flops as f64) / secs / 1e9
216    }
217
218    /// Map from event name to duration in milliseconds.
219    ///
220    /// If multiple events share the same name, their durations are summed.
221    pub fn layer_breakdown(&self) -> HashMap<String, f64> {
222        let mut map: HashMap<String, f64> = HashMap::new();
223        for ev in &self.events {
224            *map.entry(ev.name.clone()).or_insert(0.0) += ev.duration_ms();
225        }
226        map
227    }
228}
229
230// ─── Profiler ─────────────────────────────────────────────────────────────────
231
232/// Main inference profiler.
233///
234/// Maintains a stack of completed [`ProfileTrace`]s and an optional
235/// in-progress trace.  Use [`Self::begin_trace`] / [`Self::end_trace`] to
236/// bracket an inference pass, and [`Self::profile`] (or the
237/// `begin_event`/`end_event` pair) to record individual operations.
238pub struct Profiler {
239    /// All completed traces.
240    traces: Vec<ProfileTrace>,
241    /// The trace currently being built, if any.
242    current_trace: Option<ProfileTrace>,
243    /// Wall-clock start of the current trace.
244    current_trace_start: Option<Instant>,
245    /// When `false`, `profile()` still runs closures but records nothing.
246    enabled: bool,
247    /// RSS at profiler construction (reserved for future memory-delta tracking).
248    #[allow(dead_code)]
249    memory_baseline: usize,
250}
251
252impl Profiler {
253    /// Create an enabled profiler.
254    pub fn new() -> Self {
255        Self {
256            traces: Vec::new(),
257            current_trace: None,
258            current_trace_start: None,
259            enabled: true,
260            memory_baseline: crate::memory::get_rss_bytes() as usize,
261        }
262    }
263
264    /// Create a profiler with an explicit enabled/disabled flag.
265    pub fn enabled(enabled: bool) -> Self {
266        Self {
267            enabled,
268            ..Self::new()
269        }
270    }
271
272    /// Whether the profiler is currently recording events.
273    pub fn is_enabled(&self) -> bool {
274        self.enabled
275    }
276
277    /// Begin a new trace.
278    ///
279    /// Any previous in-progress trace is discarded; call [`Self::end_trace`]
280    /// first if you want to keep it.
281    pub fn begin_trace(&mut self) {
282        if !self.enabled {
283            return;
284        }
285        self.current_trace = Some(ProfileTrace::default());
286        self.current_trace_start = Some(Instant::now());
287    }
288
289    /// Finalise the current trace, push it to the completed list, and return
290    /// a clone.
291    ///
292    /// Returns `None` if no trace is in progress.
293    pub fn end_trace(&mut self) -> Option<ProfileTrace> {
294        let trace_start = self.current_trace_start.take()?;
295        let mut trace = self.current_trace.take()?;
296        trace.total_duration = trace_start.elapsed();
297        trace.peak_memory_bytes = crate::memory::get_rss_bytes() as usize;
298        self.traces.push(trace.clone());
299        Some(trace)
300    }
301
302    /// Record the start of an event and return the `Instant`.
303    ///
304    /// Pair with [`Self::end_event`].
305    pub fn begin_event(&mut self, _name: impl Into<String>) -> Instant {
306        Instant::now()
307    }
308
309    /// Complete an event started at `start_time` and record it in the active
310    /// trace (if any).
311    pub fn end_event(&mut self, name: impl Into<String>, start_time: Instant, flops: u64) {
312        if !self.enabled {
313            return;
314        }
315        let elapsed = start_time.elapsed();
316        if let Some(trace) = self.current_trace.as_mut() {
317            let event = ProfileEvent {
318                name: name.into(),
319                duration: elapsed,
320                memory_delta_bytes: 0,
321                flops,
322                metadata: HashMap::new(),
323            };
324            trace.total_flops = trace.total_flops.saturating_add(event.flops);
325            trace.events.push(event);
326        }
327    }
328
329    /// Time `f`, record the event as `name` with `flops` estimated FLOPs, and
330    /// return whatever `f` returns.
331    ///
332    /// When the profiler is disabled the closure is still executed; only
333    /// recording is skipped.
334    pub fn profile<F, R>(&mut self, name: impl Into<String>, flops: u64, f: F) -> R
335    where
336        F: FnOnce() -> R,
337    {
338        if !self.enabled {
339            return f();
340        }
341        let name_str: String = name.into();
342        let start = Instant::now();
343        let result = f();
344        let elapsed = start.elapsed();
345        if let Some(trace) = self.current_trace.as_mut() {
346            let event = ProfileEvent {
347                name: name_str,
348                duration: elapsed,
349                memory_delta_bytes: 0,
350                flops,
351                metadata: HashMap::new(),
352            };
353            trace.total_flops = trace.total_flops.saturating_add(event.flops);
354            trace.events.push(event);
355        }
356        result
357    }
358
359    /// Return a scoped guard that records an event when dropped.
360    ///
361    /// This allows recording events that span a `?`-early-return path without
362    /// an explicit `end_event` call.
363    pub fn scoped<'a>(&'a mut self, name: impl Into<String>) -> ProfileGuard<'a> {
364        ProfileGuard {
365            profiler: self,
366            name: name.into(),
367            start: Instant::now(),
368            flops: 0,
369        }
370    }
371
372    /// All completed traces (oldest first).
373    pub fn traces(&self) -> &[ProfileTrace] {
374        &self.traces
375    }
376
377    /// The most recently completed trace, if any.
378    pub fn last_trace(&self) -> Option<&ProfileTrace> {
379        self.traces.last()
380    }
381
382    /// Aggregate statistics across all completed traces.
383    pub fn aggregate_stats(&self) -> AggregateStats {
384        let num_traces = self.traces.len();
385        if num_traces == 0 {
386            return AggregateStats {
387                num_traces: 0,
388                total_duration: Duration::ZERO,
389                avg_duration: Duration::ZERO,
390                p50_duration: Duration::ZERO,
391                p99_duration: Duration::ZERO,
392                total_flops: 0,
393                avg_tokens_per_second: 0.0,
394            };
395        }
396
397        let total_duration: Duration = self
398            .traces
399            .iter()
400            .map(|t| t.total_duration)
401            .fold(Duration::ZERO, |acc, d| acc + d);
402
403        let avg_nanos = total_duration.as_nanos() / num_traces as u128;
404        let avg_duration = Duration::from_nanos(avg_nanos as u64);
405
406        let total_flops: u64 = self
407            .traces
408            .iter()
409            .map(|t| t.total_flops)
410            .fold(0u64, |acc, f| acc.saturating_add(f));
411
412        // Percentile computation on sorted durations
413        let mut sorted_nanos: Vec<u128> = self
414            .traces
415            .iter()
416            .map(|t| t.total_duration.as_nanos())
417            .collect();
418        sorted_nanos.sort_unstable();
419
420        let p50_idx = (num_traces as f64 * 0.50) as usize;
421        let p99_idx = ((num_traces as f64 * 0.99) as usize).min(num_traces - 1);
422
423        let p50_nanos = sorted_nanos.get(p50_idx).copied().unwrap_or(0);
424        let p99_nanos = sorted_nanos.get(p99_idx).copied().unwrap_or(0);
425
426        let p50_duration = Duration::from_nanos(p50_nanos as u64);
427        let p99_duration = Duration::from_nanos(p99_nanos as u64);
428
429        // avg tokens/s approximation: assume 1 "token" per trace for now
430        let avg_tokens_per_second = if avg_duration.as_secs_f64() > 0.0 {
431            1.0 / avg_duration.as_secs_f64()
432        } else {
433            0.0
434        };
435
436        AggregateStats {
437            num_traces,
438            total_duration,
439            avg_duration,
440            p50_duration,
441            p99_duration,
442            total_flops,
443            avg_tokens_per_second,
444        }
445    }
446}
447
448impl Default for Profiler {
449    fn default() -> Self {
450        Self::new()
451    }
452}
453
454// ─── AggregateStats ───────────────────────────────────────────────────────────
455
/// Summary statistics over a set of completed [`ProfileTrace`]s.
#[derive(Debug, Clone)]
pub struct AggregateStats {
    /// How many traces were aggregated.
    pub num_traces: usize,
    /// Combined duration of every trace.
    pub total_duration: Duration,
    /// Arithmetic mean of the trace durations.
    pub avg_duration: Duration,
    /// Median (50th-percentile) trace duration.
    pub p50_duration: Duration,
    /// 99th-percentile trace duration.
    pub p99_duration: Duration,
    /// Combined FLOP estimate of every trace.
    pub total_flops: u64,
    /// Rough tokens-per-second figure, assuming each trace produced one token.
    pub avg_tokens_per_second: f64,
}

impl AggregateStats {
    /// Render the statistics as a multi-line report, one metric per line.
    pub fn summary(&self) -> String {
        // Durations are reported as fractional milliseconds.
        let ms = |d: Duration| d.as_secs_f64() * 1_000.0;

        // Writing into a String cannot fail, so each fmt::Result is ignored.
        let mut report = String::with_capacity(256);
        let _ = writeln!(report, "=== AggregateStats ({} traces) ===", self.num_traces);
        let _ = writeln!(report, "  total_duration : {:.3} ms", ms(self.total_duration));
        let _ = writeln!(report, "  avg_duration   : {:.3} ms", ms(self.avg_duration));
        let _ = writeln!(report, "  p50_duration   : {:.3} ms", ms(self.p50_duration));
        let _ = writeln!(report, "  p99_duration   : {:.3} ms", ms(self.p99_duration));
        let _ = writeln!(report, "  total_flops    : {}", self.total_flops);
        let _ = writeln!(report, "  avg_tok/s      : {:.2}", self.avg_tokens_per_second);
        report
    }
}
505
506// ─── flop_counter ─────────────────────────────────────────────────────────────
507
/// FLOP estimation helpers for common transformer operations.
///
/// Unless a function's documentation says otherwise, a multiply-add pair is
/// counted as **2** FLOPs, the standard "operations" convention in most ML
/// literature. All arithmetic saturates at `u64::MAX` instead of overflowing.
pub mod flop_counter {
    /// FLOPs for a general matrix multiplication A\[m,k\] × B\[k,n\].
    ///
    /// Formula: `2 * m * k * n`
    pub fn matmul(m: usize, k: usize, n: usize) -> u64 {
        2u64.saturating_mul(m as u64)
            .saturating_mul(k as u64)
            .saturating_mul(n as u64)
    }

    /// FLOPs for a linear (fully-connected) layer without bias.
    ///
    /// Equivalent to `matmul(batch, in_features, out_features)`.
    pub fn linear(batch: usize, in_features: usize, out_features: usize) -> u64 {
        matmul(batch, in_features, out_features)
    }

    /// FLOP estimate for the core of scaled dot-product attention.
    ///
    /// Formula: `2 * seq_len^2 * head_dim * num_heads`
    ///
    /// Under the module's 2-FLOPs-per-multiply-add convention, `QK^T` and
    /// `softmax(QK^T)·V` each cost `2 * seq_len^2 * head_dim` per head. So
    /// this value equals the cost of *one* of those two matmuls; equivalently,
    /// it is the multiply-add count of both. Double it for the full
    /// convention-consistent figure. Projection layers and the softmax
    /// itself are not included.
    pub fn attention(seq_len: usize, head_dim: usize, num_heads: usize) -> u64 {
        2u64.saturating_mul(seq_len as u64)
            .saturating_mul(seq_len as u64)
            .saturating_mul(head_dim as u64)
            .saturating_mul(num_heads as u64)
    }

    /// FLOPs for RMSNorm over a sequence.
    ///
    /// Formula: `5 * seq_len * hidden` (square, sum, rsqrt, scale, multiply).
    pub fn rms_norm(seq_len: usize, hidden: usize) -> u64 {
        5u64.saturating_mul(seq_len as u64)
            .saturating_mul(hidden as u64)
    }

    /// FLOPs for a SwiGLU feed-forward network over `seq_len` tokens.
    ///
    /// Counts three linear projections, each via [`matmul`]:
    /// - gate projection: `seq_len × hidden × intermediate`
    /// - up projection:   `seq_len × hidden × intermediate`
    /// - down projection: `seq_len × intermediate × hidden`
    ///
    /// Plus the element-wise SiLU gate (2 ops per element):
    /// `2 * seq_len * intermediate`
    ///
    /// Total: `2 * seq_len * intermediate * (3 * hidden + 1)`
    pub fn swiglu_ffn(seq_len: usize, hidden: usize, intermediate: usize) -> u64 {
        // One hidden→intermediate projection; the gate and up projections
        // each cost this much.
        let in_proj = matmul(seq_len, hidden, intermediate);
        // Intermediate→hidden projection.
        let down = matmul(seq_len, intermediate, hidden);
        // SiLU element-wise gate (2 ops per element).
        let silu = 2u64
            .saturating_mul(seq_len as u64)
            .saturating_mul(intermediate as u64);

        in_proj
            .saturating_add(in_proj)
            .saturating_add(down)
            .saturating_add(silu)
    }
}
583
584// ─── Unit tests ──────────────────────────────────────────────────────────────
585
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn profile_event_new() {
        let event = ProfileEvent::new("test.layer");
        assert_eq!(event.name, "test.layer");
        assert_eq!(event.flops, 0);
        assert_eq!(event.duration, Duration::ZERO);
    }

    #[test]
    fn profile_event_builders() {
        let event = ProfileEvent::new("layer")
            .with_flops(1_000_000)
            .with_metadata("dtype", "f16");
        assert_eq!(event.flops, 1_000_000);
        assert_eq!(event.metadata.get("dtype").map(String::as_str), Some("f16"));
    }

    #[test]
    fn profile_event_duration_ms() {
        let event = ProfileEvent {
            duration: Duration::from_millis(250),
            ..ProfileEvent::new("x")
        };
        let delta = (event.duration_ms() - 250.0).abs();
        assert!(delta < 1e-6, "expected ~250 ms, got {}", event.duration_ms());
    }

    #[test]
    fn profile_event_gflops_zero_duration() {
        // FLOPs recorded without any elapsed time must not divide by zero.
        let event = ProfileEvent::new("x").with_flops(1_000_000_000);
        assert_eq!(event.gflops_per_second(), 0.0);
    }

    #[test]
    fn flop_counter_matmul_formula() {
        // 2 * m * k * n = 2 * 2 * 3 * 4
        assert_eq!(flop_counter::matmul(2, 3, 4), 48);
    }

    #[test]
    fn flop_counter_linear_formula() {
        // Same as matmul(1, 4, 8) = 2 * 1 * 4 * 8
        assert_eq!(flop_counter::linear(1, 4, 8), 64);
    }

    #[test]
    fn flop_counter_attention_formula() {
        // 2 * seq_len^2 * head_dim * num_heads = 2 * 4 * 4 * 8 * 2
        assert_eq!(flop_counter::attention(4, 8, 2), 512);
    }
}