1use std::collections::HashMap;
23use std::fmt::Write as FmtWrite;
24use std::time::{Duration, Instant};
25
/// One timed measurement: a name, how long it took, and optional counters.
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// Label identifying what was measured.
    pub name: String,
    /// Wall-clock time attributed to the event.
    pub duration: Duration,
    /// Signed memory change in bytes; [`ProfileEvent::new`] initialises it to zero.
    pub memory_delta_bytes: i64,
    /// Floating-point operation count attributed to the event.
    pub flops: u64,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}

impl ProfileEvent {
    /// Creates an event named `name` with zero duration, zero counters and no metadata.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        let metadata = HashMap::new();
        Self {
            name,
            duration: Duration::ZERO,
            memory_delta_bytes: 0,
            flops: 0,
            metadata,
        }
    }

    /// Builder-style setter: returns the event with `flops` replaced.
    pub fn with_flops(self, flops: u64) -> Self {
        Self { flops, ..self }
    }

    /// Builder-style setter: returns the event with `key` mapped to `value`,
    /// overwriting any previous value stored under the same key.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (owned_key, owned_value) = (key.into(), value.into());
        self.metadata.insert(owned_key, owned_value);
        self
    }

    /// Duration expressed in (fractional) milliseconds.
    pub fn duration_ms(&self) -> f64 {
        const MS_PER_SEC: f64 = 1_000.0;
        self.duration.as_secs_f64() * MS_PER_SEC
    }

    /// Throughput in units of 1e9 FLOPs per second.
    ///
    /// Returns `0.0` when the duration is zero or no FLOPs were recorded, so
    /// the division below never sees a zero denominator.
    pub fn gflops_per_second(&self) -> f64 {
        let elapsed_secs = self.duration.as_secs_f64();
        if self.flops != 0 && elapsed_secs > 0.0 {
            (self.flops as f64) / elapsed_secs / 1e9
        } else {
            0.0
        }
    }
}
83
84pub struct ProfileGuard<'a> {
93 profiler: &'a mut Profiler,
94 name: String,
95 start: Instant,
96 flops: u64,
97}
98
99impl<'a> ProfileGuard<'a> {
100 pub fn set_flops(&mut self, flops: u64) {
102 self.flops = flops;
103 }
104}
105
106impl<'a> Drop for ProfileGuard<'a> {
107 fn drop(&mut self) {
108 let elapsed = self.start.elapsed();
109 if !self.profiler.enabled {
110 return;
111 }
112 if let Some(trace) = self.profiler.current_trace.as_mut() {
113 let event = ProfileEvent {
114 name: self.name.clone(),
115 duration: elapsed,
116 memory_delta_bytes: 0,
117 flops: self.flops,
118 metadata: HashMap::new(),
119 };
120 trace.total_flops = trace.total_flops.saturating_add(event.flops);
121 trace.events.push(event);
122 }
123 }
124}
125
126#[derive(Debug, Clone, Default)]
130pub struct ProfileTrace {
131 pub events: Vec<ProfileEvent>,
133 pub total_duration: Duration,
135 pub peak_memory_bytes: usize,
137 pub total_flops: u64,
139}
140
141impl ProfileTrace {
142 pub fn top_events(&self, n: usize) -> Vec<&ProfileEvent> {
144 let mut refs: Vec<&ProfileEvent> = self.events.iter().collect();
145 refs.sort_by_key(|b| std::cmp::Reverse(b.duration));
146 refs.into_iter().take(n).collect()
147 }
148
149 pub fn duration_for_prefix(&self, prefix: &str) -> Duration {
151 self.events
152 .iter()
153 .filter(|e| e.name.starts_with(prefix))
154 .map(|e| e.duration)
155 .fold(Duration::ZERO, |acc, d| acc + d)
156 }
157
158 pub fn avg_duration_for_prefix(&self, prefix: &str) -> Option<Duration> {
162 let matching: Vec<Duration> = self
163 .events
164 .iter()
165 .filter(|e| e.name.starts_with(prefix))
166 .map(|e| e.duration)
167 .collect();
168
169 if matching.is_empty() {
170 return None;
171 }
172
173 let total_nanos: u128 = matching.iter().map(|d| d.as_nanos()).sum();
174 let avg_nanos = total_nanos / matching.len() as u128;
175 Some(Duration::from_nanos(avg_nanos as u64))
176 }
177
178 pub fn summary(&self) -> String {
180 let mut out = String::with_capacity(512);
181 let _ = writeln!(
182 out,
183 "=== ProfileTrace: {:.3} ms total, {} events, {:.2} GFLOPs ===",
184 self.total_duration.as_secs_f64() * 1_000.0,
185 self.events.len(),
186 self.aggregate_gflops(),
187 );
188 let _ = writeln!(out, " peak_memory: {} bytes", self.peak_memory_bytes);
189
190 let top = self.top_events(10);
191 if !top.is_empty() {
192 let _ = writeln!(out, " Top events by duration:");
193 for ev in top {
194 let _ = writeln!(
195 out,
196 " {:40} {:8.3} ms {:6.2} GFLOPs/s",
197 ev.name,
198 ev.duration_ms(),
199 ev.gflops_per_second(),
200 );
201 }
202 }
203
204 out
205 }
206
207 pub fn aggregate_gflops(&self) -> f64 {
211 let secs = self.total_duration.as_secs_f64();
212 if secs <= 0.0 || self.total_flops == 0 {
213 return 0.0;
214 }
215 (self.total_flops as f64) / secs / 1e9
216 }
217
218 pub fn layer_breakdown(&self) -> HashMap<String, f64> {
222 let mut map: HashMap<String, f64> = HashMap::new();
223 for ev in &self.events {
224 *map.entry(ev.name.clone()).or_insert(0.0) += ev.duration_ms();
225 }
226 map
227 }
228}
229
230pub struct Profiler {
239 traces: Vec<ProfileTrace>,
241 current_trace: Option<ProfileTrace>,
243 current_trace_start: Option<Instant>,
245 enabled: bool,
247 #[allow(dead_code)]
249 memory_baseline: usize,
250}
251
252impl Profiler {
253 pub fn new() -> Self {
255 Self {
256 traces: Vec::new(),
257 current_trace: None,
258 current_trace_start: None,
259 enabled: true,
260 memory_baseline: crate::memory::get_rss_bytes() as usize,
261 }
262 }
263
264 pub fn enabled(enabled: bool) -> Self {
266 Self {
267 enabled,
268 ..Self::new()
269 }
270 }
271
272 pub fn is_enabled(&self) -> bool {
274 self.enabled
275 }
276
277 pub fn begin_trace(&mut self) {
282 if !self.enabled {
283 return;
284 }
285 self.current_trace = Some(ProfileTrace::default());
286 self.current_trace_start = Some(Instant::now());
287 }
288
289 pub fn end_trace(&mut self) -> Option<ProfileTrace> {
294 let trace_start = self.current_trace_start.take()?;
295 let mut trace = self.current_trace.take()?;
296 trace.total_duration = trace_start.elapsed();
297 trace.peak_memory_bytes = crate::memory::get_rss_bytes() as usize;
298 self.traces.push(trace.clone());
299 Some(trace)
300 }
301
302 pub fn begin_event(&mut self, _name: impl Into<String>) -> Instant {
306 Instant::now()
307 }
308
309 pub fn end_event(&mut self, name: impl Into<String>, start_time: Instant, flops: u64) {
312 if !self.enabled {
313 return;
314 }
315 let elapsed = start_time.elapsed();
316 if let Some(trace) = self.current_trace.as_mut() {
317 let event = ProfileEvent {
318 name: name.into(),
319 duration: elapsed,
320 memory_delta_bytes: 0,
321 flops,
322 metadata: HashMap::new(),
323 };
324 trace.total_flops = trace.total_flops.saturating_add(event.flops);
325 trace.events.push(event);
326 }
327 }
328
329 pub fn profile<F, R>(&mut self, name: impl Into<String>, flops: u64, f: F) -> R
335 where
336 F: FnOnce() -> R,
337 {
338 if !self.enabled {
339 return f();
340 }
341 let name_str: String = name.into();
342 let start = Instant::now();
343 let result = f();
344 let elapsed = start.elapsed();
345 if let Some(trace) = self.current_trace.as_mut() {
346 let event = ProfileEvent {
347 name: name_str,
348 duration: elapsed,
349 memory_delta_bytes: 0,
350 flops,
351 metadata: HashMap::new(),
352 };
353 trace.total_flops = trace.total_flops.saturating_add(event.flops);
354 trace.events.push(event);
355 }
356 result
357 }
358
359 pub fn scoped<'a>(&'a mut self, name: impl Into<String>) -> ProfileGuard<'a> {
364 ProfileGuard {
365 profiler: self,
366 name: name.into(),
367 start: Instant::now(),
368 flops: 0,
369 }
370 }
371
372 pub fn traces(&self) -> &[ProfileTrace] {
374 &self.traces
375 }
376
377 pub fn last_trace(&self) -> Option<&ProfileTrace> {
379 self.traces.last()
380 }
381
382 pub fn aggregate_stats(&self) -> AggregateStats {
384 let num_traces = self.traces.len();
385 if num_traces == 0 {
386 return AggregateStats {
387 num_traces: 0,
388 total_duration: Duration::ZERO,
389 avg_duration: Duration::ZERO,
390 p50_duration: Duration::ZERO,
391 p99_duration: Duration::ZERO,
392 total_flops: 0,
393 avg_tokens_per_second: 0.0,
394 };
395 }
396
397 let total_duration: Duration = self
398 .traces
399 .iter()
400 .map(|t| t.total_duration)
401 .fold(Duration::ZERO, |acc, d| acc + d);
402
403 let avg_nanos = total_duration.as_nanos() / num_traces as u128;
404 let avg_duration = Duration::from_nanos(avg_nanos as u64);
405
406 let total_flops: u64 = self
407 .traces
408 .iter()
409 .map(|t| t.total_flops)
410 .fold(0u64, |acc, f| acc.saturating_add(f));
411
412 let mut sorted_nanos: Vec<u128> = self
414 .traces
415 .iter()
416 .map(|t| t.total_duration.as_nanos())
417 .collect();
418 sorted_nanos.sort_unstable();
419
420 let p50_idx = (num_traces as f64 * 0.50) as usize;
421 let p99_idx = ((num_traces as f64 * 0.99) as usize).min(num_traces - 1);
422
423 let p50_nanos = sorted_nanos.get(p50_idx).copied().unwrap_or(0);
424 let p99_nanos = sorted_nanos.get(p99_idx).copied().unwrap_or(0);
425
426 let p50_duration = Duration::from_nanos(p50_nanos as u64);
427 let p99_duration = Duration::from_nanos(p99_nanos as u64);
428
429 let avg_tokens_per_second = if avg_duration.as_secs_f64() > 0.0 {
431 1.0 / avg_duration.as_secs_f64()
432 } else {
433 0.0
434 };
435
436 AggregateStats {
437 num_traces,
438 total_duration,
439 avg_duration,
440 p50_duration,
441 p99_duration,
442 total_flops,
443 avg_tokens_per_second,
444 }
445 }
446}
447
448impl Default for Profiler {
449 fn default() -> Self {
450 Self::new()
451 }
452}
453
/// Statistics computed across a set of traces.
#[derive(Debug, Clone)]
pub struct AggregateStats {
    /// Number of traces the statistics were computed from.
    pub num_traces: usize,
    /// Sum of the trace durations.
    pub total_duration: Duration,
    /// Mean trace duration.
    pub avg_duration: Duration,
    /// Median (50th percentile) trace duration.
    pub p50_duration: Duration,
    /// 99th percentile trace duration.
    pub p99_duration: Duration,
    /// Sum of the FLOP counts of all traces.
    pub total_flops: u64,
    /// Average tokens-per-second figure, as supplied by whoever built the value.
    pub avg_tokens_per_second: f64,
}

impl AggregateStats {
    /// Renders the statistics as a multi-line report: a header followed by
    /// one line per field, with durations shown in milliseconds to three
    /// decimal places.
    pub fn summary(&self) -> String {
        let to_ms = |d: Duration| d.as_secs_f64() * 1_000.0;

        // Writing into a `String` cannot fail, so the results are discarded.
        let mut report = String::with_capacity(256);
        let _ = writeln!(report, "=== AggregateStats ({} traces) ===", self.num_traces);
        let _ = writeln!(report, " total_duration : {:.3} ms", to_ms(self.total_duration));
        let _ = writeln!(report, " avg_duration : {:.3} ms", to_ms(self.avg_duration));
        let _ = writeln!(report, " p50_duration : {:.3} ms", to_ms(self.p50_duration));
        let _ = writeln!(report, " p99_duration : {:.3} ms", to_ms(self.p99_duration));
        let _ = writeln!(report, " total_flops : {}", self.total_flops);
        let _ = writeln!(report, " avg_tok/s : {:.2}", self.avg_tokens_per_second);
        report
    }
}
505
/// Closed-form FLOP estimates for common operations.
///
/// Every formula uses saturating arithmetic, so very large dimensions clamp
/// the result at `u64::MAX` instead of wrapping.
pub mod flop_counter {
    /// Multiplies `coeff` by every entry of `dims`, saturating at `u64::MAX`.
    fn scaled_product(coeff: u64, dims: &[usize]) -> u64 {
        dims.iter()
            .fold(coeff, |acc, &dim| acc.saturating_mul(dim as u64))
    }

    /// Estimate for a matrix multiply with dimensions `m`, `k`, `n`:
    /// `2 * m * k * n`.
    pub fn matmul(m: usize, k: usize, n: usize) -> u64 {
        scaled_product(2, &[m, k, n])
    }

    /// Estimate for a linear layer; identical to
    /// `matmul(batch, in_features, out_features)`.
    pub fn linear(batch: usize, in_features: usize, out_features: usize) -> u64 {
        matmul(batch, in_features, out_features)
    }

    /// Estimate for attention: `2 * seq_len^2 * head_dim * num_heads`.
    pub fn attention(seq_len: usize, head_dim: usize, num_heads: usize) -> u64 {
        scaled_product(2, &[seq_len, seq_len, head_dim, num_heads])
    }

    /// Estimate for RMS normalisation: `5 * seq_len * hidden`.
    pub fn rms_norm(seq_len: usize, hidden: usize) -> u64 {
        scaled_product(5, &[seq_len, hidden])
    }

    /// Estimate for a SwiGLU feed-forward block.
    ///
    /// Adds up two `2 * seq_len * hidden * intermediate` terms (presumably the
    /// gate and up projections), one `2 * seq_len * intermediate * hidden`
    /// term for the down projection, and `2 * seq_len * intermediate` for the
    /// activation.
    pub fn swiglu_ffn(seq_len: usize, hidden: usize, intermediate: usize) -> u64 {
        let projection = scaled_product(2, &[seq_len, hidden, intermediate]);
        let down = scaled_product(2, &[seq_len, intermediate, hidden]);
        let activation = scaled_product(2, &[seq_len, intermediate]);

        // The projection term is counted twice on purpose.
        [projection, projection, down, activation]
            .iter()
            .fold(0u64, |acc, &term| acc.saturating_add(term))
    }
}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built event keeps its name and starts with zeroed counters.
    #[test]
    fn profile_event_new() {
        let fresh = ProfileEvent::new("test.layer");
        assert_eq!(fresh.name, "test.layer");
        assert_eq!(fresh.flops, 0);
        assert_eq!(fresh.duration, Duration::ZERO);
    }

    /// The builder methods set the FLOP count and store metadata.
    #[test]
    fn profile_event_builders() {
        let built = ProfileEvent::new("layer")
            .with_flops(1_000_000)
            .with_metadata("dtype", "f16");
        assert_eq!(built.flops, 1_000_000);
        assert_eq!(built.metadata["dtype"], "f16");
    }

    /// 250 ms reads back as 250.0 from `duration_ms`.
    #[test]
    fn profile_event_duration_ms() {
        let mut timed = ProfileEvent::new("x");
        timed.duration = Duration::from_millis(250);
        let error = (timed.duration_ms() - 250.0).abs();
        assert!(error < 1e-6);
    }

    /// With zero duration the throughput is reported as 0 rather than dividing by zero.
    #[test]
    fn profile_event_gflops_zero_duration() {
        let mut untimed = ProfileEvent::new("x");
        untimed.flops = 1_000_000_000;
        assert_eq!(untimed.gflops_per_second(), 0.0);
    }

    /// 2 * 2 * 3 * 4 = 48.
    #[test]
    fn flop_counter_matmul_formula() {
        let flops = flop_counter::matmul(2, 3, 4);
        assert_eq!(flops, 48);
    }

    /// 2 * 1 * 4 * 8 = 64.
    #[test]
    fn flop_counter_linear_formula() {
        let flops = flop_counter::linear(1, 4, 8);
        assert_eq!(flops, 64);
    }

    /// 2 * 4 * 4 * 8 * 2 = 512.
    #[test]
    fn flop_counter_attention_formula() {
        let flops = flop_counter::attention(4, 8, 2);
        assert_eq!(flops, 512);
    }
}