1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
24use std::sync::Arc;
25
26use async_trait::async_trait;
27use serde_json::json;
28use tokio::sync::RwLock;
29use tracing::{instrument, warn};
30
31use crate::metrics::METRICS;
32
33use oxidizedgraph::events::{
34 spawn_handler, CheckpointEvent, Event, EventBus, EventHandler, EventKind, GraphEvent,
35 NodeEvent, StateEvent,
36};
37
38use oxidized_state::storage_traits::{
39 ContentDigest, RunEvent, RunId, RunLedger, RunMetadata, RunSummary,
40};
41
42pub struct LedgerHandler<L: RunLedger> {
50 ledger: Arc<L>,
51 run_id: RwLock<Option<RunId>>,
52 seq: AtomicU64,
53 spec_digest: ContentDigest,
54 metadata: RunMetadata,
55 saw_error: AtomicBool,
56 start_time: RwLock<Option<std::time::Instant>>,
57}
58
59impl<L: RunLedger> LedgerHandler<L> {
60 pub fn new(ledger: Arc<L>, spec_digest: ContentDigest, metadata: RunMetadata) -> Self {
62 Self {
63 ledger,
64 run_id: RwLock::new(None),
65 seq: AtomicU64::new(1),
66 spec_digest,
67 metadata,
68 saw_error: AtomicBool::new(false),
69 start_time: RwLock::new(None),
70 }
71 }
72
73 pub async fn run_id(&self) -> Option<RunId> {
75 self.run_id.read().await.clone()
76 }
77
78 pub fn saw_error(&self) -> bool {
80 self.saw_error.load(Ordering::SeqCst)
81 }
82
83 pub fn seq(&self) -> u64 {
85 self.seq.load(Ordering::SeqCst)
86 }
87
88 fn next_seq(&self) -> u64 {
90 self.seq.fetch_add(1, Ordering::SeqCst)
91 }
92}
93
94fn map_event(event: &Event) -> (String, serde_json::Value) {
96 match &event.kind {
97 EventKind::Graph(g) => match g {
98 GraphEvent::Started {
99 graph_name,
100 entry_point,
101 } => (
102 "graph_started".into(),
103 json!({ "graph_name": graph_name, "entry_point": entry_point }),
104 ),
105 GraphEvent::Completed {
106 iterations,
107 duration_ms,
108 } => (
109 "graph_completed".into(),
110 json!({ "iterations": iterations, "duration_ms": duration_ms }),
111 ),
112 GraphEvent::Error { error } => ("graph_failed".into(), json!({ "error": error })),
113 GraphEvent::Interrupted { reason, node_id } => (
114 "graph_interrupted".into(),
115 json!({ "reason": reason, "node_id": node_id }),
116 ),
117 },
118 EventKind::Node(n) => match n {
119 NodeEvent::Entered { node_id, iteration } => (
120 "node_entered".into(),
121 json!({ "node_id": node_id, "iteration": iteration }),
122 ),
123 NodeEvent::Exited {
124 node_id,
125 next_node,
126 duration_ms,
127 } => (
128 "node_exited".into(),
129 json!({ "node_id": node_id, "next_node": next_node, "duration_ms": duration_ms }),
130 ),
131 NodeEvent::Error { node_id, error } => (
132 "node_failed".into(),
133 json!({ "node_id": node_id, "error": error }),
134 ),
135 NodeEvent::Retrying {
136 node_id,
137 attempt,
138 delay_ms,
139 } => (
140 "node_retrying".into(),
141 json!({ "node_id": node_id, "attempt": attempt, "delay_ms": delay_ms }),
142 ),
143 },
144 EventKind::Checkpoint(c) => match c {
145 CheckpointEvent::Saved {
146 checkpoint_id,
147 node_id,
148 } => (
149 "checkpoint_saved".into(),
150 json!({ "checkpoint_id": checkpoint_id, "node_id": node_id }),
151 ),
152 CheckpointEvent::Restored {
153 checkpoint_id,
154 node_id,
155 } => (
156 "checkpoint_restored".into(),
157 json!({ "checkpoint_id": checkpoint_id, "node_id": node_id }),
158 ),
159 CheckpointEvent::Deleted { checkpoint_id } => (
160 "checkpoint_deleted".into(),
161 json!({ "checkpoint_id": checkpoint_id }),
162 ),
163 },
164 EventKind::State(s) => match s {
165 StateEvent::Updated {
166 node_id,
167 keys_changed,
168 } => (
169 "state_updated".into(),
170 json!({ "node_id": node_id, "keys_changed": keys_changed }),
171 ),
172 StateEvent::MessageAdded {
173 role,
174 content_length,
175 } => (
176 "message_added".into(),
177 json!({ "role": role, "content_length": content_length }),
178 ),
179 },
180 EventKind::Custom { name, payload } => (format!("Custom:{name}"), payload.clone()),
181 }
182}
183
184#[async_trait]
185impl<L: RunLedger + 'static> EventHandler for LedgerHandler<L> {
186 #[instrument(skip(self), name = "ledger_handler_on_start")]
187 async fn on_start(&self) {
188 *self.start_time.write().await = Some(std::time::Instant::now());
189
190 match self
191 .ledger
192 .create_run(&self.spec_digest, self.metadata.clone())
193 .await
194 {
195 Ok(id) => {
196 *self.run_id.write().await = Some(id);
197 }
198 Err(e) => {
199 warn!(error = %e, "LedgerHandler: failed to create run");
200 }
201 }
202 }
203
204 #[instrument(skip(self, event), name = "ledger_handler_handle", level = "debug")]
205 async fn handle(&self, event: &Event) {
206 let run_id = {
207 let guard = self.run_id.read().await;
208 match guard.as_ref() {
209 Some(id) => id.clone(),
210 None => return,
211 }
212 };
213
214 METRICS.inc_events_processed();
215
216 let (kind, payload) = map_event(event);
217
218 if matches!(
220 &event.kind,
221 EventKind::Graph(GraphEvent::Error { .. }) | EventKind::Node(NodeEvent::Error { .. })
222 ) {
223 self.saw_error.store(true, Ordering::SeqCst);
224 }
225
226 let run_event = RunEvent {
227 seq: self.next_seq(),
228 kind,
229 payload,
230 timestamp: event.timestamp,
231 };
232
233 if let Err(e) = self.ledger.append_event(&run_id, run_event).await {
234 warn!(error = %e, run_id = %run_id, "LedgerHandler: failed to append event");
235 }
236 }
237
238 #[instrument(skip(self), name = "ledger_handler_on_stop")]
239 async fn on_stop(&self) {
240 let run_id = {
241 let guard = self.run_id.read().await;
242 match guard.as_ref() {
243 Some(id) => id.clone(),
244 None => return,
245 }
246 };
247
248 let total_events = self.seq.load(Ordering::SeqCst) - 1;
249 let duration_ms = self
250 .start_time
251 .read()
252 .await
253 .map(|t| t.elapsed().as_millis() as u64)
254 .unwrap_or(0);
255 let success = !self.saw_error.load(Ordering::SeqCst);
256
257 let summary = RunSummary {
258 total_events,
259 final_state_digest: None,
260 duration_ms,
261 success,
262 };
263
264 let result = if success {
265 self.ledger.complete_run(&run_id, summary).await
266 } else {
267 self.ledger.fail_run(&run_id, summary).await
268 };
269
270 if let Err(e) = result {
271 warn!(error = %e, run_id = %run_id, "LedgerHandler: failed to finalize run");
272 }
273 }
274}
275
276pub fn subscribe_ledger_to_bus<L: RunLedger + 'static>(
282 bus: &EventBus,
283 ledger: Arc<L>,
284 spec_digest: ContentDigest,
285 metadata: RunMetadata,
286) -> Arc<LedgerHandler<L>> {
287 let handler = Arc::new(LedgerHandler::new(ledger, spec_digest, metadata));
288 let receiver = bus.subscribe();
289 spawn_handler(handler.clone(), receiver);
290 handler
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use oxidized_state::fakes::MemoryRunLedger;
    use oxidized_state::storage_traits::RunStatus;
    use std::time::Duration;

    /// Fixed spec digest shared by the handler tests.
    fn test_digest() -> ContentDigest {
        ContentDigest::from_bytes(b"test-spec")
    }

    /// Minimal run metadata: no git SHA, a fixed agent name, empty tags.
    fn test_metadata() -> RunMetadata {
        RunMetadata {
            git_sha: None,
            agent_name: "test-agent".into(),
            tags: json!({}),
        }
    }

    /// Checks the kind string `map_event` produces for a table of events.
    ///
    /// NOTE(review): despite the name, the table has no case for
    /// `GraphEvent::Interrupted`, `NodeEvent::Retrying`,
    /// `CheckpointEvent::Deleted` or `StateEvent::MessageAdded`.
    #[tokio::test]
    async fn map_event_covers_all_variants() {
        let table: Vec<(Event, &str)> = vec![
            // Graph events.
            (
                Event::graph_started("t", Some("g".into()), "entry".into()),
                "graph_started",
            ),
            (
                Event::graph_completed("t", 5, Duration::from_millis(100)),
                "graph_completed",
            ),
            (Event::graph_error("t", "boom".into()), "graph_failed"),
            // Node events.
            (Event::node_entered("t", "n".into(), 1), "node_entered"),
            (
                Event::node_exited("t", "n".into(), Some("m".into()), Duration::from_millis(50)),
                "node_exited",
            ),
            (
                Event::node_error("t", "n".into(), "fail".into()),
                "node_failed",
            ),
            // Checkpoint events.
            (
                Event::checkpoint_saved("t", "cp1".into(), "n".into()),
                "checkpoint_saved",
            ),
            (
                Event::checkpoint_restored("t", "cp1".into(), "n".into()),
                "checkpoint_restored",
            ),
            // State events.
            (
                Event::state_updated("t", "n".into(), vec!["key".into()]),
                "state_updated",
            ),
        ];

        table.into_iter().for_each(|(event, expected_kind)| {
            let (kind, _payload) = map_event(&event);
            assert_eq!(kind, expected_kind, "wrong kind for {expected_kind}");
        });
    }

    /// A run with one non-error event ends up `Completed` with a successful summary.
    #[tokio::test]
    async fn handler_creates_and_completes_run() {
        let store = Arc::new(MemoryRunLedger::new());
        let handler = LedgerHandler::new(Arc::clone(&store), test_digest(), test_metadata());

        handler.on_start().await;
        let run_id = handler.run_id().await.expect("run_id should be set");

        // One handled event moves the counter from 1 to 2.
        let started = Event::graph_started("t", Some("g".into()), "entry".into());
        handler.handle(&started).await;
        assert_eq!(handler.seq(), 2);

        handler.on_stop().await;

        let record = store.get_run(&run_id).await.unwrap();
        assert_eq!(record.status, RunStatus::Completed);
        let summary = record.summary.as_ref().unwrap();
        assert!(summary.success);
    }

    /// A graph error event makes `on_stop` finalize the run as `Failed`.
    #[tokio::test]
    async fn handler_marks_run_failed_on_error_event() {
        let store = Arc::new(MemoryRunLedger::new());
        let handler = LedgerHandler::new(Arc::clone(&store), test_digest(), test_metadata());

        handler.on_start().await;
        let run_id = handler.run_id().await.unwrap();

        let failure = Event::graph_error("t", "kaboom".into());
        handler.handle(&failure).await;

        handler.on_stop().await;

        let record = store.get_run(&run_id).await.unwrap();
        assert_eq!(record.status, RunStatus::Failed);
        let summary = record.summary.as_ref().unwrap();
        assert!(!summary.success);
    }

    /// Custom events get a `Custom:` prefix and keep their payload untouched.
    #[tokio::test]
    async fn custom_event_mapping() {
        let custom = EventKind::Custom {
            name: "MyCustom".into(),
            payload: json!({"foo": "bar"}),
        };
        let event = Event::new("t", custom);

        let (kind, payload) = map_event(&event);

        assert_eq!(kind, "Custom:MyCustom");
        assert_eq!(payload, json!({"foo": "bar"}));
    }
}