// cognis_graph/metrics.rs
//! Per-node counters, simple timing aggregation, and threshold-based
//! alerting as Observers.

use std::collections::HashMap;
use std::sync::Mutex;
use std::time::Instant;

use uuid::Uuid;

use cognis_core::{Event, Observer};

/// Per-node execution counts and error counts.
#[derive(Debug, Default, Clone)]
pub struct GraphMetrics {
    /// Times each node has finished successfully, keyed by node name.
    pub node_executions: HashMap<String, u64>,
    /// Occurrence count per error, keyed by the error *message* — not the
    /// node name (the `OnError` event handled below carries only the
    /// error text and run id, no node).
    pub errors: HashMap<String, u64>,
    /// Total supersteps observed (`OnNodeEnd` events).
    pub total_steps: u64,
}

/// Observer that maintains a [`GraphMetrics`] under a `Mutex`.
pub struct MetricsObserver {
    /// Shared mutable counters; locked per event and per snapshot.
    inner: Mutex<GraphMetrics>,
}

27impl Default for MetricsObserver {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl MetricsObserver {
34    /// Empty observer.
35    pub fn new() -> Self {
36        Self {
37            inner: Mutex::new(GraphMetrics::default()),
38        }
39    }
40
41    /// Snapshot the current metrics.
42    pub fn snapshot(&self) -> GraphMetrics {
43        self.inner.lock().map(|g| g.clone()).unwrap_or_default()
44    }
45}
46
47impl Observer for MetricsObserver {
48    fn on_event(&self, event: &Event) {
49        let mut g = match self.inner.lock() {
50            Ok(g) => g,
51            Err(_) => return,
52        };
53        match event {
54            Event::OnNodeEnd { node, .. } => {
55                *g.node_executions.entry(node.clone()).or_insert(0) += 1;
56                g.total_steps += 1;
57            }
58            Event::OnError { error, .. } => {
59                *g.errors.entry(error.clone()).or_insert(0) += 1;
60            }
61            _ => {}
62        }
63    }
64}
65
/// Per-node timing aggregator. Pairs `OnNodeStart` / `OnNodeEnd` events
/// keyed by `(run_id, step, node)` to compute durations.
pub struct ProfilingObserver {
    /// Start timestamps of invocations that have not yet ended.
    pending: Mutex<HashMap<(Uuid, u64, String), Instant>>,
    /// Aggregated timings of finished invocations, keyed by node name.
    totals: Mutex<HashMap<String, NodeTiming>>,
}

/// One node's timing aggregate.
///
/// NOTE(review): `Default` yields `min_ns == 0`, which is misleading for
/// an empty aggregate; the observers in this module therefore seed new
/// entries with `min_ns: u128::MAX` so the first sample always wins the
/// `min` comparison.
#[derive(Debug, Default, Clone)]
pub struct NodeTiming {
    /// Number of finished invocations seen.
    pub count: u64,
    /// Total elapsed nanoseconds across invocations.
    pub total_ns: u128,
    /// Slowest single invocation.
    pub max_ns: u128,
    /// Fastest single invocation.
    pub min_ns: u128,
}

86impl NodeTiming {
87    /// Mean nanoseconds per invocation. Returns `0` for zero count.
88    pub fn mean_ns(&self) -> u128 {
89        if self.count == 0 {
90            0
91        } else {
92            self.total_ns / self.count as u128
93        }
94    }
95}
96
97impl Default for ProfilingObserver {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl ProfilingObserver {
104    /// Empty profiler.
105    pub fn new() -> Self {
106        Self {
107            pending: Mutex::new(HashMap::new()),
108            totals: Mutex::new(HashMap::new()),
109        }
110    }
111
112    /// Snapshot of per-node timings.
113    pub fn snapshot(&self) -> HashMap<String, NodeTiming> {
114        self.totals.lock().map(|m| m.clone()).unwrap_or_default()
115    }
116}
117
impl Observer for ProfilingObserver {
    /// Record start timestamps; on end, fold the elapsed time into totals.
    ///
    /// Lock discipline: `pending` is explicitly dropped before `totals`
    /// is taken, so the two locks are never held at once. Poisoned locks
    /// and unmatched `OnNodeEnd` events are ignored.
    fn on_event(&self, event: &Event) {
        match event {
            Event::OnNodeStart { node, step, run_id } => {
                // Stamp the start of this invocation. A duplicate start
                // for the same key simply resets the timestamp.
                if let Ok(mut p) = self.pending.lock() {
                    p.insert((*run_id, *step, node.clone()), Instant::now());
                }
            }
            Event::OnNodeEnd {
                node, step, run_id, ..
            } => {
                let mut p = match self.pending.lock() {
                    Ok(p) => p,
                    Err(_) => return,
                };
                let key = (*run_id, *step, node.clone());
                // An end without a matching start (e.g. observer attached
                // mid-run) is dropped rather than guessed at.
                let started = match p.remove(&key) {
                    Some(t) => t,
                    None => return,
                };
                let elapsed_ns = started.elapsed().as_nanos();
                // Release `pending` before taking `totals` — never hold
                // both locks simultaneously.
                drop(p);
                let mut t = match self.totals.lock() {
                    Ok(t) => t,
                    Err(_) => return,
                };
                // Seed `min_ns` with MAX so the first sample wins `min`.
                let e = t.entry(node.clone()).or_insert_with(|| NodeTiming {
                    min_ns: u128::MAX,
                    ..Default::default()
                });
                e.count += 1;
                e.total_ns += elapsed_ns;
                e.max_ns = e.max_ns.max(elapsed_ns);
                e.min_ns = e.min_ns.min(elapsed_ns);
            }
            _ => {}
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `OnNodeEnd` for `node` at step 0 of the nil run.
    fn finished(node: &str) -> Event {
        Event::OnNodeEnd {
            node: node.into(),
            step: 0,
            output: serde_json::Value::Null,
            run_id: Uuid::nil(),
        }
    }

    #[test]
    fn metrics_count_executions() {
        let observer = MetricsObserver::new();
        observer.on_event(&finished("a"));
        observer.on_event(&finished("a"));
        observer.on_event(&finished("b"));
        observer.on_event(&Event::OnError {
            error: "boom".into(),
            run_id: Uuid::nil(),
        });
        let snap = observer.snapshot();
        assert_eq!(snap.total_steps, 3);
        assert_eq!(snap.node_executions["a"], 2);
        assert_eq!(snap.node_executions["b"], 1);
        assert_eq!(snap.errors["boom"], 1);
    }

    #[test]
    fn profiler_pairs_start_and_end() {
        let profiler = ProfilingObserver::new();
        profiler.on_event(&Event::OnNodeStart {
            node: "n".into(),
            step: 0,
            run_id: Uuid::nil(),
        });
        std::thread::sleep(std::time::Duration::from_millis(2));
        profiler.on_event(&finished("n"));
        let timings = profiler.snapshot();
        let t = &timings["n"];
        assert_eq!(t.count, 1);
        assert!(t.total_ns > 0);
    }
}

// ────────────────────────────────────────────────────────────────────────
// ThresholdProfiler — same timing logic as ProfilingObserver, plus
// per-node duration thresholds that trigger a callback when breached.
// Use to wire SLO-style alerts ("if any 'embed' invocation runs longer
// than 5s, log/page/notify").
// ────────────────────────────────────────────────────────────────────────

/// Callback fired when a node's invocation exceeds its configured
/// threshold. Receives the node name and the actual elapsed nanoseconds.
/// `Arc`'d so it is cloneable and shareable across threads
/// (`Send + Sync`).
pub type ThresholdCallback = std::sync::Arc<dyn Fn(&str, u128) + Send + Sync>;

/// ProfilingObserver variant that also fires alerts on per-node duration
/// thresholds. Same timing snapshot via `snapshot()`; thresholds are
/// configured per node via `with_threshold(node, max_ns)`. On an
/// `OnNodeEnd` whose elapsed > the configured cap, every registered
/// callback runs (synchronously, on the observer's thread).
///
/// Callbacks should be cheap and non-blocking — they run inline on the
/// graph engine's thread.
pub struct ThresholdProfiler {
    /// Start timestamps of invocations that have not yet ended.
    pending: Mutex<HashMap<(Uuid, u64, String), Instant>>,
    /// Aggregated timings of finished invocations, keyed by node name.
    totals: Mutex<HashMap<String, NodeTiming>>,
    /// Per-node duration caps, in nanoseconds.
    thresholds: Mutex<HashMap<String, u128>>,
    /// Alert callbacks, fired in registration order on each breach.
    callbacks: Mutex<Vec<ThresholdCallback>>,
}

237impl Default for ThresholdProfiler {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
243impl ThresholdProfiler {
244    /// Empty profiler with no thresholds and no callbacks.
245    pub fn new() -> Self {
246        Self {
247            pending: Mutex::new(HashMap::new()),
248            totals: Mutex::new(HashMap::new()),
249            thresholds: Mutex::new(HashMap::new()),
250            callbacks: Mutex::new(Vec::new()),
251        }
252    }
253
254    /// Snapshot of per-node timings (same shape as `ProfilingObserver`).
255    pub fn snapshot(&self) -> HashMap<String, NodeTiming> {
256        self.totals.lock().map(|m| m.clone()).unwrap_or_default()
257    }
258
259    /// Register a duration cap (in nanoseconds) for `node`. Subsequent
260    /// invocations that exceed `max_ns` will fire every registered
261    /// callback. Replaces any prior threshold for the node.
262    pub fn with_threshold(self, node: impl Into<String>, max_ns: u128) -> Self {
263        if let Ok(mut t) = self.thresholds.lock() {
264            t.insert(node.into(), max_ns);
265        }
266        self
267    }
268
269    /// Add an alert callback. Multiple callbacks are supported and all
270    /// fire (in registration order) on each breach. Builder-style.
271    pub fn on_threshold_breached<F>(self, cb: F) -> Self
272    where
273        F: Fn(&str, u128) + Send + Sync + 'static,
274    {
275        if let Ok(mut c) = self.callbacks.lock() {
276            c.push(std::sync::Arc::new(cb));
277        }
278        self
279    }
280}
281
282impl Observer for ThresholdProfiler {
283    fn on_event(&self, event: &Event) {
284        match event {
285            Event::OnNodeStart { node, step, run_id } => {
286                if let Ok(mut p) = self.pending.lock() {
287                    p.insert((*run_id, *step, node.clone()), Instant::now());
288                }
289            }
290            Event::OnNodeEnd {
291                node, step, run_id, ..
292            } => {
293                let mut p = match self.pending.lock() {
294                    Ok(p) => p,
295                    Err(_) => return,
296                };
297                let key = (*run_id, *step, node.clone());
298                let started = match p.remove(&key) {
299                    Some(t) => t,
300                    None => return,
301                };
302                let elapsed_ns = started.elapsed().as_nanos();
303                drop(p);
304                if let Ok(mut t) = self.totals.lock() {
305                    let e = t.entry(node.clone()).or_insert_with(|| NodeTiming {
306                        min_ns: u128::MAX,
307                        ..Default::default()
308                    });
309                    e.count += 1;
310                    e.total_ns += elapsed_ns;
311                    e.max_ns = e.max_ns.max(elapsed_ns);
312                    e.min_ns = e.min_ns.min(elapsed_ns);
313                }
314                let breached = self
315                    .thresholds
316                    .lock()
317                    .ok()
318                    .and_then(|m| m.get(node).copied())
319                    .map(|cap| elapsed_ns > cap)
320                    .unwrap_or(false);
321                if breached {
322                    if let Ok(cbs) = self.callbacks.lock() {
323                        for cb in cbs.iter() {
324                            cb(node, elapsed_ns);
325                        }
326                    }
327                }
328            }
329            _ => {}
330        }
331    }
332}
333
#[cfg(test)]
mod threshold_tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use uuid::Uuid;

    /// `OnNodeStart` for `node` at step 0 of `run`.
    fn started(node: &str, run: Uuid) -> Event {
        Event::OnNodeStart {
            node: node.into(),
            step: 0,
            run_id: run,
        }
    }

    /// `OnNodeEnd` for `node` at step 0 of `run`.
    fn finished(node: &str, run: Uuid) -> Event {
        Event::OnNodeEnd {
            node: node.into(),
            step: 0,
            run_id: run,
            output: serde_json::Value::Null,
        }
    }

    #[test]
    fn fires_callback_on_breach() {
        let hits = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&hits);
        // A 1ns cap: any real invocation breaches it.
        let profiler = ThresholdProfiler::new()
            .with_threshold("slow", 1)
            .on_threshold_breached(move |_node, _elapsed| {
                sink.fetch_add(1, Ordering::Relaxed);
            });
        let run = Uuid::nil();
        profiler.on_event(&started("slow", run));
        std::thread::sleep(std::time::Duration::from_millis(2));
        profiler.on_event(&finished("slow", run));
        assert_eq!(hits.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn does_not_fire_below_threshold() {
        let hits = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&hits);
        // A u128::MAX cap: nothing real can breach it.
        let profiler = ThresholdProfiler::new()
            .with_threshold("fast", u128::MAX)
            .on_threshold_breached(move |_, _| {
                sink.fetch_add(1, Ordering::Relaxed);
            });
        let run = Uuid::nil();
        profiler.on_event(&started("fast", run));
        profiler.on_event(&finished("fast", run));
        assert_eq!(hits.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn snapshot_shape_matches_profiling_observer() {
        let profiler = ThresholdProfiler::new();
        let run = Uuid::nil();
        profiler.on_event(&started("n", run));
        profiler.on_event(&finished("n", run));
        assert_eq!(profiler.snapshot()["n"].count, 1);
    }

    #[test]
    fn multiple_callbacks_all_fire() {
        let total = Arc::new(AtomicUsize::new(0));
        let first = Arc::clone(&total);
        let second = Arc::clone(&total);
        let profiler = ThresholdProfiler::new()
            .with_threshold("n", 1)
            .on_threshold_breached(move |_, _| {
                first.fetch_add(1, Ordering::Relaxed);
            })
            .on_threshold_breached(move |_, _| {
                second.fetch_add(10, Ordering::Relaxed);
            });
        let run = Uuid::nil();
        profiler.on_event(&started("n", run));
        std::thread::sleep(std::time::Duration::from_millis(2));
        profiler.on_event(&finished("n", run));
        // 1 + 10 proves each callback fired exactly once.
        assert_eq!(total.load(Ordering::Relaxed), 11);
    }
}