1use std::collections::HashMap;
4use std::sync::Mutex;
5use std::time::Instant;
6
7use uuid::Uuid;
8
9use cognis_core::{Event, Observer};
10
/// Aggregated execution counters collected by a metrics observer.
///
/// All counts start at zero (`Default`) and only ever increase.
#[derive(Debug, Default, Clone)]
pub struct GraphMetrics {
    // Completed executions per node name.
    pub node_executions: HashMap<String, u64>,
    // Occurrences per error message string.
    pub errors: HashMap<String, u64>,
    // Total node completions observed across all nodes.
    pub total_steps: u64,
}
21
/// Observer that counts node executions and errors as events arrive.
///
/// State sits behind a `Mutex` so the observer can be driven through
/// `&self` from multiple threads.
pub struct MetricsObserver {
    // Guarded counters; cloned out by `snapshot`.
    inner: Mutex<GraphMetrics>,
}
26
27impl Default for MetricsObserver {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl MetricsObserver {
34 pub fn new() -> Self {
36 Self {
37 inner: Mutex::new(GraphMetrics::default()),
38 }
39 }
40
41 pub fn snapshot(&self) -> GraphMetrics {
43 self.inner.lock().map(|g| g.clone()).unwrap_or_default()
44 }
45}
46
47impl Observer for MetricsObserver {
48 fn on_event(&self, event: &Event) {
49 let mut g = match self.inner.lock() {
50 Ok(g) => g,
51 Err(_) => return,
52 };
53 match event {
54 Event::OnNodeEnd { node, .. } => {
55 *g.node_executions.entry(node.clone()).or_insert(0) += 1;
56 g.total_steps += 1;
57 }
58 Event::OnError { error, .. } => {
59 *g.errors.entry(error.clone()).or_insert(0) += 1;
60 }
61 _ => {}
62 }
63 }
64}
65
/// Observer that measures wall-clock time per node execution.
///
/// A start timestamp is parked in `pending` (keyed by run id, step and
/// node name) and folded into `totals` when the matching end event arrives.
pub struct ProfilingObserver {
    // In-flight timers: (run_id, step, node) -> start instant.
    pending: Mutex<HashMap<(Uuid, u64, String), Instant>>,
    // Accumulated per-node timing statistics.
    totals: Mutex<HashMap<String, NodeTiming>>,
}
72
/// Accumulated wall-clock statistics for a single node.
#[derive(Debug, Default, Clone)]
pub struct NodeTiming {
    /// How many timed executions were recorded.
    pub count: u64,
    /// Sum of all recorded durations, in nanoseconds.
    pub total_ns: u128,
    /// Longest single recorded duration, in nanoseconds.
    pub max_ns: u128,
    /// Shortest single recorded duration, in nanoseconds.
    pub min_ns: u128,
}

impl NodeTiming {
    /// Mean duration in nanoseconds, or zero when nothing was recorded.
    pub fn mean_ns(&self) -> u128 {
        self.total_ns
            .checked_div(u128::from(self.count))
            .unwrap_or(0)
    }
}
96
97impl Default for ProfilingObserver {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl ProfilingObserver {
104 pub fn new() -> Self {
106 Self {
107 pending: Mutex::new(HashMap::new()),
108 totals: Mutex::new(HashMap::new()),
109 }
110 }
111
112 pub fn snapshot(&self) -> HashMap<String, NodeTiming> {
114 self.totals.lock().map(|m| m.clone()).unwrap_or_default()
115 }
116}
117
118impl Observer for ProfilingObserver {
119 fn on_event(&self, event: &Event) {
120 match event {
121 Event::OnNodeStart { node, step, run_id } => {
122 if let Ok(mut p) = self.pending.lock() {
123 p.insert((*run_id, *step, node.clone()), Instant::now());
124 }
125 }
126 Event::OnNodeEnd {
127 node, step, run_id, ..
128 } => {
129 let mut p = match self.pending.lock() {
130 Ok(p) => p,
131 Err(_) => return,
132 };
133 let key = (*run_id, *step, node.clone());
134 let started = match p.remove(&key) {
135 Some(t) => t,
136 None => return,
137 };
138 let elapsed_ns = started.elapsed().as_nanos();
139 drop(p);
140 let mut t = match self.totals.lock() {
141 Ok(t) => t,
142 Err(_) => return,
143 };
144 let e = t.entry(node.clone()).or_insert_with(|| NodeTiming {
145 min_ns: u128::MAX,
146 ..Default::default()
147 });
148 e.count += 1;
149 e.total_ns += elapsed_ns;
150 e.max_ns = e.max_ns.max(elapsed_ns);
151 e.min_ns = e.min_ns.min(elapsed_ns);
152 }
153 _ => {}
154 }
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `OnNodeEnd` for `node` at step 0 in the nil run.
    fn ev_node_end(node: &str) -> Event {
        Event::OnNodeEnd {
            node: node.into(),
            step: 0,
            output: serde_json::Value::Null,
            run_id: Uuid::nil(),
        }
    }

    #[test]
    fn metrics_count_executions() {
        let observer = MetricsObserver::new();
        for node in &["a", "a", "b"] {
            observer.on_event(&ev_node_end(node));
        }
        observer.on_event(&Event::OnError {
            error: "boom".into(),
            run_id: Uuid::nil(),
        });
        let snap = observer.snapshot();
        assert_eq!(snap.node_executions["a"], 2);
        assert_eq!(snap.node_executions["b"], 1);
        assert_eq!(snap.total_steps, 3);
        assert_eq!(snap.errors["boom"], 1);
    }

    #[test]
    fn profiler_pairs_start_and_end() {
        let profiler = ProfilingObserver::new();
        let run = Uuid::nil();
        profiler.on_event(&Event::OnNodeStart {
            node: "n".into(),
            step: 0,
            run_id: run,
        });
        std::thread::sleep(std::time::Duration::from_millis(2));
        profiler.on_event(&Event::OnNodeEnd {
            node: "n".into(),
            step: 0,
            output: serde_json::Value::Null,
            run_id: run,
        });
        let snap = profiler.snapshot();
        let timing = snap.get("n").unwrap();
        assert_eq!(timing.count, 1);
        assert!(timing.total_ns > 0);
    }
}
210
/// Callback fired when a node's elapsed time exceeds its configured
/// threshold; receives the node name and the elapsed nanoseconds.
pub type ThresholdCallback = std::sync::Arc<dyn Fn(&str, u128) + Send + Sync>;
221
/// A profiler like `ProfilingObserver` that additionally fires registered
/// callbacks when a node's execution time exceeds its per-node threshold.
pub struct ThresholdProfiler {
    // In-flight timers: (run_id, step, node) -> start instant.
    pending: Mutex<HashMap<(Uuid, u64, String), Instant>>,
    // Accumulated per-node timing statistics.
    totals: Mutex<HashMap<String, NodeTiming>>,
    // Per-node elapsed-time caps, in nanoseconds.
    thresholds: Mutex<HashMap<String, u128>>,
    // Callbacks fired on every threshold breach.
    callbacks: Mutex<Vec<ThresholdCallback>>,
}
236
237impl Default for ThresholdProfiler {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
243impl ThresholdProfiler {
244 pub fn new() -> Self {
246 Self {
247 pending: Mutex::new(HashMap::new()),
248 totals: Mutex::new(HashMap::new()),
249 thresholds: Mutex::new(HashMap::new()),
250 callbacks: Mutex::new(Vec::new()),
251 }
252 }
253
254 pub fn snapshot(&self) -> HashMap<String, NodeTiming> {
256 self.totals.lock().map(|m| m.clone()).unwrap_or_default()
257 }
258
259 pub fn with_threshold(self, node: impl Into<String>, max_ns: u128) -> Self {
263 if let Ok(mut t) = self.thresholds.lock() {
264 t.insert(node.into(), max_ns);
265 }
266 self
267 }
268
269 pub fn on_threshold_breached<F>(self, cb: F) -> Self
272 where
273 F: Fn(&str, u128) + Send + Sync + 'static,
274 {
275 if let Ok(mut c) = self.callbacks.lock() {
276 c.push(std::sync::Arc::new(cb));
277 }
278 self
279 }
280}
281
282impl Observer for ThresholdProfiler {
283 fn on_event(&self, event: &Event) {
284 match event {
285 Event::OnNodeStart { node, step, run_id } => {
286 if let Ok(mut p) = self.pending.lock() {
287 p.insert((*run_id, *step, node.clone()), Instant::now());
288 }
289 }
290 Event::OnNodeEnd {
291 node, step, run_id, ..
292 } => {
293 let mut p = match self.pending.lock() {
294 Ok(p) => p,
295 Err(_) => return,
296 };
297 let key = (*run_id, *step, node.clone());
298 let started = match p.remove(&key) {
299 Some(t) => t,
300 None => return,
301 };
302 let elapsed_ns = started.elapsed().as_nanos();
303 drop(p);
304 if let Ok(mut t) = self.totals.lock() {
305 let e = t.entry(node.clone()).or_insert_with(|| NodeTiming {
306 min_ns: u128::MAX,
307 ..Default::default()
308 });
309 e.count += 1;
310 e.total_ns += elapsed_ns;
311 e.max_ns = e.max_ns.max(elapsed_ns);
312 e.min_ns = e.min_ns.min(elapsed_ns);
313 }
314 let breached = self
315 .thresholds
316 .lock()
317 .ok()
318 .and_then(|m| m.get(node).copied())
319 .map(|cap| elapsed_ns > cap)
320 .unwrap_or(false);
321 if breached {
322 if let Ok(cbs) = self.callbacks.lock() {
323 for cb in cbs.iter() {
324 cb(node, elapsed_ns);
325 }
326 }
327 }
328 }
329 _ => {}
330 }
331 }
332}
333
#[cfg(test)]
mod threshold_tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use uuid::Uuid;

    /// `OnNodeEnd` for `node` at step 0 in run `run`.
    fn end(node: &str, run: Uuid) -> Event {
        Event::OnNodeEnd {
            node: node.into(),
            step: 0,
            run_id: run,
            output: serde_json::Value::Null,
        }
    }

    /// `OnNodeStart` for `node` at step 0 in run `run`.
    fn start(node: &str, run: Uuid) -> Event {
        Event::OnNodeStart {
            node: node.into(),
            step: 0,
            run_id: run,
        }
    }

    #[test]
    fn fires_callback_on_breach() {
        let hits = Arc::new(AtomicUsize::new(0));
        let profiler = {
            let hits = hits.clone();
            ThresholdProfiler::new()
                .with_threshold("slow", 1)
                .on_threshold_breached(move |_node, _elapsed| {
                    hits.fetch_add(1, Ordering::Relaxed);
                })
        };
        let run = Uuid::nil();
        profiler.on_event(&start("slow", run));
        std::thread::sleep(std::time::Duration::from_millis(2));
        profiler.on_event(&end("slow", run));
        assert_eq!(hits.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn does_not_fire_below_threshold() {
        let hits = Arc::new(AtomicUsize::new(0));
        let profiler = {
            let hits = hits.clone();
            ThresholdProfiler::new()
                .with_threshold("fast", u128::MAX)
                .on_threshold_breached(move |_, _| {
                    hits.fetch_add(1, Ordering::Relaxed);
                })
        };
        let run = Uuid::nil();
        profiler.on_event(&start("fast", run));
        profiler.on_event(&end("fast", run));
        assert_eq!(hits.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn snapshot_shape_matches_profiling_observer() {
        let profiler = ThresholdProfiler::new();
        let run = Uuid::nil();
        profiler.on_event(&start("n", run));
        profiler.on_event(&end("n", run));
        let snap = profiler.snapshot();
        assert_eq!(snap.get("n").unwrap().count, 1);
    }

    #[test]
    fn multiple_callbacks_all_fire() {
        let total = Arc::new(AtomicUsize::new(0));
        let (first, second) = (total.clone(), total.clone());
        let profiler = ThresholdProfiler::new()
            .with_threshold("n", 1)
            .on_threshold_breached(move |_, _| {
                first.fetch_add(1, Ordering::Relaxed);
            })
            .on_threshold_breached(move |_, _| {
                second.fetch_add(10, Ordering::Relaxed);
            });
        let run = Uuid::nil();
        profiler.on_event(&start("n", run));
        std::thread::sleep(std::time::Duration::from_millis(2));
        profiler.on_event(&end("n", run));
        assert_eq!(total.load(Ordering::Relaxed), 11);
    }
}