1use std::io::Write;
2
3use chrono::{DateTime, Utc};
4use sqlx::{QueryBuilder, Row, Sqlite, SqlitePool};
5use tokio::sync::mpsc;
6use tokio::time::{interval, Duration, MissedTickBehavior};
7use uuid::Uuid;
8
9use crate::error::{GuardError, GuardResult};
10
/// One record of the `audit_log` table, as written by [`AuditWriter`] / `write_direct`
/// and read back by [`AuditReader`].
#[derive(Debug, Clone, PartialEq)]
pub struct AuditEntry {
    /// Identifier of this entry; stored in the `id` column as a string.
    pub id: Uuid,
    /// Optional session id; stored as a string, or NULL when absent.
    pub session_id: Option<Uuid>,
    /// Workspace id; always present, stored as a string.
    pub workspace_id: Uuid,
    /// Optional agent id; stored as a string, or NULL when absent.
    pub agent_id: Option<Uuid>,
    /// Action name recorded for this entry.
    pub action: String,
    /// Resource name; `AuditFilter::resource` matches this column exactly.
    pub resource: String,
    /// Optional identifier of the specific resource instance.
    pub resource_id: Option<String>,
    /// Decision string; `AuditFilter::decision` matches this column exactly.
    /// NOTE(review): free-form string here — the set of allowed values is not defined in this file.
    pub decision: String,
    /// Optional human-readable reason for the decision.
    pub reason: Option<String>,
    /// Risk score; CSV export renders it with four decimal places.
    pub risk_score: f64,
    /// Arbitrary JSON metadata, stored serialized as text.
    pub metadata: serde_json::Value,
    /// Event time; stored as Unix milliseconds, so sub-millisecond precision is lost on round-trip.
    pub ts: DateTime<Utc>,
}
39
/// Criteria for [`AuditReader::query`] and [`AuditReader::export_csv`].
///
/// Every `Some` field narrows the result with an additional `AND` condition; the
/// `Default` value (all `None`) matches every row.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AuditFilter {
    /// Only entries for this workspace.
    pub workspace_id: Option<Uuid>,
    /// Only entries for this session.
    pub session_id: Option<Uuid>,
    /// Only entries whose `decision` equals this string exactly.
    pub decision: Option<String>,
    /// Inclusive lower bound on `ts` (compared at millisecond precision).
    pub start_time: Option<DateTime<Utc>>,
    /// Inclusive upper bound on `ts` (compared at millisecond precision).
    pub end_time: Option<DateTime<Utc>>,
    /// Only entries whose `resource` equals this string exactly.
    pub resource: Option<String>,
    /// Maximum number of rows to return; rows are ordered newest first.
    pub limit: Option<u32>,
}
58
/// Cloneable handle that queues [`AuditEntry`] values for a background task which
/// inserts them into `audit_log` in batches (the task is spawned by `AuditWriter::new`).
#[derive(Clone)]
pub struct AuditWriter {
    // Sending half of the bounded channel drained by the background flush task.
    tx: mpsc::Sender<AuditEntry>,
}
64
/// Read-side access to `audit_log`: filtered queries and CSV export.
#[derive(Clone)]
pub struct AuditReader {
    // Connection pool used for all read queries.
    pool: SqlitePool,
}
70
71impl AuditWriter {
72 pub fn new(pool: SqlitePool, flush_interval: Duration, batch_size: usize) -> Self {
74 let batch_size = batch_size.max(1);
75 let (tx, mut rx) = mpsc::channel::<AuditEntry>(batch_size * 2);
76 tokio::spawn(async move {
77 let mut ticker = interval(flush_interval);
78 ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
79 let mut batch = Vec::with_capacity(batch_size);
80 loop {
81 tokio::select! {
82 maybe_entry = rx.recv() => {
83 match maybe_entry {
84 Some(entry) => {
85 batch.push(entry);
86 if batch.len() >= batch_size {
87 let _ = flush_batch(&pool, &mut batch).await;
88 }
89 }
90 None => {
91 let _ = flush_batch(&pool, &mut batch).await;
92 break;
93 }
94 }
95 }
96 _ = ticker.tick() => {
97 let _ = flush_batch(&pool, &mut batch).await;
98 }
99 }
100 }
101 });
102 Self { tx }
103 }
104
105 pub async fn write(&self, entry: AuditEntry) -> GuardResult<()> {
107 self.tx
108 .send(entry)
109 .await
110 .map_err(|_| GuardError::AuditChannelClosed)
111 }
112}
113
114impl AuditReader {
115 pub fn new(pool: SqlitePool) -> Self {
117 Self { pool }
118 }
119
120 pub async fn query(&self, filter: AuditFilter) -> GuardResult<Vec<AuditEntry>> {
122 let mut builder = QueryBuilder::<Sqlite>::new(
123 "SELECT id, session_id, workspace_id, agent_id, action, resource, resource_id, decision, reason, risk_score, metadata, ts FROM audit_log WHERE 1 = 1",
124 );
125
126 if let Some(workspace_id) = filter.workspace_id {
127 builder
128 .push(" AND workspace_id = ")
129 .push_bind(workspace_id.to_string());
130 }
131 if let Some(session_id) = filter.session_id {
132 builder
133 .push(" AND session_id = ")
134 .push_bind(session_id.to_string());
135 }
136 if let Some(decision) = filter.decision {
137 builder.push(" AND decision = ").push_bind(decision);
138 }
139 if let Some(start_time) = filter.start_time {
140 builder
141 .push(" AND ts >= ")
142 .push_bind(start_time.timestamp_millis());
143 }
144 if let Some(end_time) = filter.end_time {
145 builder
146 .push(" AND ts <= ")
147 .push_bind(end_time.timestamp_millis());
148 }
149 if let Some(resource) = filter.resource {
150 builder.push(" AND resource = ").push_bind(resource);
151 }
152 builder.push(" ORDER BY ts DESC");
153 if let Some(limit) = filter.limit {
154 builder.push(" LIMIT ").push_bind(limit as i64);
155 }
156
157 let rows = builder.build().fetch_all(&self.pool).await?;
158 rows.iter().map(row_to_audit_entry).collect()
159 }
160
161 pub async fn export_csv(&self, filter: AuditFilter, mut writer: impl Write) -> GuardResult<()> {
163 writer.write_all(b"id,session_id,workspace_id,agent_id,action,resource,resource_id,decision,reason,risk_score,ts,metadata\n")?;
164 for entry in self.query(filter).await? {
165 writeln!(
166 writer,
167 "{},{},{},{},{},{},{},{},{},{:.4},{},{}",
168 csv_escape(&entry.id.to_string()),
169 csv_escape(
170 &entry
171 .session_id
172 .map(|value| value.to_string())
173 .unwrap_or_default()
174 ),
175 csv_escape(&entry.workspace_id.to_string()),
176 csv_escape(
177 &entry
178 .agent_id
179 .map(|value| value.to_string())
180 .unwrap_or_default()
181 ),
182 csv_escape(&entry.action),
183 csv_escape(&entry.resource),
184 csv_escape(&entry.resource_id.unwrap_or_default()),
185 csv_escape(&entry.decision),
186 csv_escape(&entry.reason.unwrap_or_default()),
187 entry.risk_score,
188 csv_escape(&entry.ts.to_rfc3339()),
189 csv_escape(&serde_json::to_string(&entry.metadata)?),
190 )?;
191 }
192 Ok(())
193 }
194}
195
196pub(crate) async fn write_direct(pool: &SqlitePool, entry: &AuditEntry) -> GuardResult<()> {
197 sqlx::query(
198 "INSERT INTO audit_log (id, session_id, workspace_id, agent_id, action, resource, resource_id, decision, reason, risk_score, metadata, ts)
199 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
200 )
201 .bind(entry.id.to_string())
202 .bind(entry.session_id.map(|value| value.to_string()))
203 .bind(entry.workspace_id.to_string())
204 .bind(entry.agent_id.map(|value| value.to_string()))
205 .bind(&entry.action)
206 .bind(&entry.resource)
207 .bind(&entry.resource_id)
208 .bind(&entry.decision)
209 .bind(&entry.reason)
210 .bind(entry.risk_score)
211 .bind(serde_json::to_string(&entry.metadata)?)
212 .bind(entry.ts.timestamp_millis())
213 .execute(pool)
214 .await?;
215 Ok(())
216}
217
218async fn flush_batch(pool: &SqlitePool, batch: &mut Vec<AuditEntry>) -> GuardResult<()> {
219 if batch.is_empty() {
220 return Ok(());
221 }
222
223 let pending = std::mem::take(batch);
224 let mut tx = pool.begin().await?;
225 for entry in &pending {
226 sqlx::query(
227 "INSERT INTO audit_log (id, session_id, workspace_id, agent_id, action, resource, resource_id, decision, reason, risk_score, metadata, ts)
228 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
229 )
230 .bind(entry.id.to_string())
231 .bind(entry.session_id.map(|value| value.to_string()))
232 .bind(entry.workspace_id.to_string())
233 .bind(entry.agent_id.map(|value| value.to_string()))
234 .bind(&entry.action)
235 .bind(&entry.resource)
236 .bind(&entry.resource_id)
237 .bind(&entry.decision)
238 .bind(&entry.reason)
239 .bind(entry.risk_score)
240 .bind(serde_json::to_string(&entry.metadata)?)
241 .bind(entry.ts.timestamp_millis())
242 .execute(&mut *tx)
243 .await?;
244 }
245 tx.commit().await?;
246 Ok(())
247}
248
249fn row_to_audit_entry(row: &sqlx::sqlite::SqliteRow) -> GuardResult<AuditEntry> {
250 Ok(AuditEntry {
251 id: Uuid::parse_str(&row.try_get::<String, _>("id")?)?,
252 session_id: row
253 .try_get::<Option<String>, _>("session_id")?
254 .map(|value| Uuid::parse_str(&value))
255 .transpose()?,
256 workspace_id: Uuid::parse_str(&row.try_get::<String, _>("workspace_id")?)?,
257 agent_id: row
258 .try_get::<Option<String>, _>("agent_id")?
259 .map(|value| Uuid::parse_str(&value))
260 .transpose()?,
261 action: row.try_get("action")?,
262 resource: row.try_get("resource")?,
263 resource_id: row.try_get("resource_id")?,
264 decision: row.try_get("decision")?,
265 reason: row.try_get("reason")?,
266 risk_score: row.try_get("risk_score")?,
267 metadata: serde_json::from_str(&row.try_get::<String, _>("metadata")?)?,
268 ts: from_ms(row.try_get("ts")?)?,
269 })
270}
271
272fn from_ms(value: i64) -> GuardResult<DateTime<Utc>> {
273 DateTime::from_timestamp_millis(value)
274 .ok_or_else(|| GuardError::ConfigError(format!("invalid timestamp millis: {value}")))
275}
276
/// Quotes `value` for use as a single CSV field (RFC 4180 style).
///
/// A field containing a comma, double quote, line feed or carriage return is wrapped in
/// double quotes, with each embedded double quote doubled. Any other value, including
/// the empty string, is returned unchanged.
fn csv_escape(value: &str) -> String {
    // `\r` must trigger quoting as well: an unquoted CR (e.g. from a CRLF inside a reason
    // string) is treated as a record terminator by RFC 4180 parsers and splits the row.
    if value.contains([',', '"', '\n', '\r']) {
        format!("\"{}\"", value.replace('"', "\"\""))
    } else {
        value.to_owned()
    }
}