Skip to main content

panopticon_core/hooks/core/
profiler.rs

1use crate::imports::*;
2use std::io::{self, Write};
3use std::time::{Duration, Instant};
4
/// A built-in observer hook that measures per-step execution time.
///
/// Records a [`Duration`](std::time::Duration) for each step (keyed by
/// step name, with the iteration index appended for steps inside an
/// iteration) and, on pipeline completion, writes a sorted timings
/// report to a writer (stderr by default).
///
/// Converting the profiler into a hook consumes it, so call
/// [`timings`](Self::timings) first and keep the returned handle if the
/// measurements need to be inspected programmatically.
pub struct Profiler {
    /// Name the hook is registered under.
    name: String,
    /// Destination of the on-completion report.
    writer: Arc<Mutex<Box<dyn Write + Send>>>,
    /// Recorded duration per timing key; shared with every handle
    /// returned by [`timings`](Self::timings).
    timings: Arc<Mutex<HashMap<String, Duration>>>,
}

impl Default for Profiler {
    /// Equivalent to [`Profiler::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Profiler {
    /// Constructs a profiler with the default name (`"profiler"`), stderr
    /// as the report writer, and an empty timings map.
    #[must_use]
    pub fn new() -> Self {
        Self {
            name: String::from("profiler"),
            writer: Arc::new(Mutex::new(Box::new(io::stderr()))),
            timings: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Overrides the hook name.
    ///
    /// Builder-style: consumes the profiler and returns it with the new
    /// name applied.
    #[must_use]
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Redirects the on-completion report to a user-supplied writer.
    ///
    /// Builder-style: consumes the profiler and returns it with the
    /// previous writer replaced.
    #[must_use]
    pub fn writer(mut self, writer: impl Write + Send + 'static) -> Self {
        self.writer = Arc::new(Mutex::new(Box::new(writer)));
        self
    }

    /// Returns a cloned handle to the timings map. Retain this before
    /// moving the `Profiler` into [`Pipeline::hook`] so the data can be
    /// inspected programmatically once the pipeline finishes.
    ///
    /// The handle shares storage with the profiler: it is a new
    /// reference to the same map, not a snapshot of it.
    #[must_use]
    pub fn timings(&self) -> Arc<Mutex<HashMap<String, Duration>>> {
        Arc::clone(&self.timings)
    }
}
55
56impl From<Profiler> for Hook {
57    fn from(profiler: Profiler) -> Hook {
58        let timings = profiler.timings;
59        let writer = profiler.writer;
60        let pending: Arc<Mutex<HashMap<String, Instant>>> = Arc::new(Mutex::new(HashMap::new()));
61
62        Hook::observer(profiler.name, move |event, _store| {
63            // Build a timing key that includes iteration index when present
64            let timing_key = |step_name: &str, iter_context: &Option<&IterContext>| -> String {
65                match iter_context {
66                    Some(ctx) => format!("{}[{}]", step_name, ctx.index),
67                    None => step_name.to_string(),
68                }
69            };
70
71            #[allow(unreachable_patterns)]
72            match event {
73                HookEvent::BeforeStep {
74                    step_name,
75                    iter_context,
76                    ..
77                } => {
78                    let key = timing_key(step_name, iter_context);
79                    let mut p = pending.lock().unwrap();
80                    p.insert(key, Instant::now());
81                }
82                HookEvent::AfterStep {
83                    step_name,
84                    iter_context,
85                    ..
86                } => {
87                    let key = timing_key(step_name, iter_context);
88                    let mut p = pending.lock().unwrap();
89                    if let Some(start) = p.remove(&key) {
90                        let elapsed = start.elapsed();
91                        let mut t = timings.lock().unwrap();
92                        t.insert(key, elapsed);
93                    }
94                }
95                HookEvent::Complete => {
96                    let t = timings.lock().unwrap();
97                    let mut w = writer.lock().unwrap();
98                    let _ = writeln!(w, "[profiler] Step timings:");
99                    let mut entries: Vec<_> = t.iter().collect();
100                    entries.sort_by_key(|(name, _)| (*name).clone());
101                    for (name, duration) in entries {
102                        let _ = writeln!(w, "  {}: {:.3}ms", name, duration.as_secs_f64() * 1000.0);
103                    }
104                }
105                _ => {}
106            }
107        })
108    }
109}