1use std::sync::Mutex;
8
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use uuid::Uuid;
12
13use cognis_core::{Event, Observer, Result};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AuditEntry {
18 pub run_id: Uuid,
20 pub step: u64,
22 pub node: String,
24 pub kind: AuditKind,
26 pub ts: u64,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(rename_all = "snake_case")]
33pub enum AuditKind {
34 Start,
36 End,
38 Error,
40}
41
42#[async_trait]
44pub trait AuditLog: Send + Sync {
45 async fn record(&self, entry: AuditEntry) -> Result<()>;
47 async fn entries(&self, run_id: Uuid) -> Result<Vec<AuditEntry>>;
49}
50
51#[derive(Default)]
53pub struct InMemoryAuditLog {
54 inner: Mutex<Vec<AuditEntry>>,
55}
56
57impl InMemoryAuditLog {
58 pub fn new() -> Self {
60 Self::default()
61 }
62}
63
64#[async_trait]
65impl AuditLog for InMemoryAuditLog {
66 async fn record(&self, entry: AuditEntry) -> Result<()> {
67 let mut v = self
68 .inner
69 .lock()
70 .map_err(|e| cognis_core::CognisError::Internal(format!("audit mutex: {e}")))?;
71 v.push(entry);
72 Ok(())
73 }
74 async fn entries(&self, run_id: Uuid) -> Result<Vec<AuditEntry>> {
75 let v = self
76 .inner
77 .lock()
78 .map_err(|e| cognis_core::CognisError::Internal(format!("audit mutex: {e}")))?;
79 Ok(v.iter().filter(|e| e.run_id == run_id).cloned().collect())
80 }
81}
82
83pub struct AuditLogObserver {
85 log: std::sync::Arc<dyn AuditLog>,
86}
87
88impl AuditLogObserver {
89 pub fn new(log: std::sync::Arc<dyn AuditLog>) -> Self {
91 Self { log }
92 }
93}
94
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch, rather
/// than failing.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
101
102impl Observer for AuditLogObserver {
103 fn on_event(&self, event: &Event) {
104 let entry = match event {
105 Event::OnNodeStart { node, step, run_id } => AuditEntry {
106 run_id: *run_id,
107 step: *step,
108 node: node.clone(),
109 kind: AuditKind::Start,
110 ts: now_secs(),
111 },
112 Event::OnNodeEnd {
113 node, step, run_id, ..
114 } => AuditEntry {
115 run_id: *run_id,
116 step: *step,
117 node: node.clone(),
118 kind: AuditKind::End,
119 ts: now_secs(),
120 },
121 Event::OnError { error, run_id } => AuditEntry {
122 run_id: *run_id,
123 step: 0,
124 node: error.clone(),
125 kind: AuditKind::Error,
126 ts: now_secs(),
127 },
128 _ => return,
129 };
130 let log = self.log.clone();
131 tokio::spawn(async move {
132 let _ = log.record(entry).await;
133 });
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    /// Upper bound on how long a test waits for detached observer writes.
    const WAIT_LIMIT: Duration = Duration::from_secs(2);

    /// Polls `log` until it holds at least `want` entries for `run_id`, or until
    /// [`WAIT_LIMIT`] elapses, then returns whatever is there.
    ///
    /// The observer records on a spawned task, so a fixed sleep races that task
    /// on a loaded machine. Polling normally returns after the first yield and
    /// only waits for the full limit when the test is going to fail anyway.
    async fn wait_for_entries(log: &dyn AuditLog, run_id: Uuid, want: usize) -> Vec<AuditEntry> {
        let deadline = tokio::time::Instant::now() + WAIT_LIMIT;
        loop {
            let got = log.entries(run_id).await.unwrap();
            if got.len() >= want || tokio::time::Instant::now() >= deadline {
                return got;
            }
            tokio::time::sleep(Duration::from_millis(1)).await;
        }
    }

    /// Builds a step-0 `Start` entry for the given run, node and timestamp.
    fn start_entry(run_id: Uuid, node: &str, ts: u64) -> AuditEntry {
        AuditEntry {
            run_id,
            step: 0,
            node: node.into(),
            kind: AuditKind::Start,
            ts,
        }
    }

    #[tokio::test]
    async fn record_and_query() {
        let log = InMemoryAuditLog::new();
        let id = Uuid::new_v4();
        log.record(start_entry(id, "a", 1000)).await.unwrap();
        log.record(start_entry(Uuid::new_v4(), "x", 1001))
            .await
            .unwrap();
        let got = log.entries(id).await.unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].node, "a");
        assert_eq!(got[0].ts, 1000);
    }

    #[tokio::test]
    async fn entries_for_unknown_run_is_empty() {
        let log = InMemoryAuditLog::new();
        log.record(start_entry(Uuid::new_v4(), "a", 1))
            .await
            .unwrap();
        assert!(log.entries(Uuid::new_v4()).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn observer_bridges_events() {
        let log: Arc<dyn AuditLog> = Arc::new(InMemoryAuditLog::new());
        let obs = AuditLogObserver::new(log.clone());
        let id = Uuid::new_v4();
        obs.on_event(&Event::OnNodeStart {
            node: "n".into(),
            step: 1,
            run_id: id,
        });
        let entries = wait_for_entries(&*log, id, 1).await;
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].kind, AuditKind::Start);
        assert_eq!(entries[0].node, "n");
        assert_eq!(entries[0].step, 1);
    }

    #[tokio::test]
    async fn observer_records_error_message_as_node() {
        let log: Arc<dyn AuditLog> = Arc::new(InMemoryAuditLog::new());
        let obs = AuditLogObserver::new(log.clone());
        let id = Uuid::new_v4();
        obs.on_event(&Event::OnError {
            error: "boom".into(),
            run_id: id,
        });
        let entries = wait_for_entries(&*log, id, 1).await;
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].kind, AuditKind::Error);
        assert_eq!(entries[0].node, "boom");
        assert_eq!(entries[0].step, 0);
    }
}