Skip to main content

claw_guard/audit/
mod.rs

1use std::io::Write;
2
3use chrono::{DateTime, Utc};
4use sqlx::{QueryBuilder, Row, Sqlite, SqlitePool};
5use tokio::sync::mpsc;
6use tokio::time::{interval, Duration, MissedTickBehavior};
7use uuid::Uuid;
8
9use crate::error::{GuardError, GuardResult};
10
11/// A persisted audit log entry.
12#[derive(Debug, Clone, PartialEq)]
13pub struct AuditEntry {
14    /// Audit entry identifier.
15    pub id: Uuid,
16    /// Session identifier, when available.
17    pub session_id: Option<Uuid>,
18    /// Workspace identifier for the audited action.
19    pub workspace_id: Uuid,
20    /// Agent identifier, when available.
21    pub agent_id: Option<Uuid>,
22    /// Action name.
23    pub action: String,
24    /// Resource name.
25    pub resource: String,
26    /// Resource identifier, when available.
27    pub resource_id: Option<String>,
28    /// Decision string: `Allow`, `Deny`, or `Mask`.
29    pub decision: String,
30    /// Optional decision reason.
31    pub reason: Option<String>,
32    /// Computed risk score.
33    pub risk_score: f64,
34    /// Additional structured metadata.
35    pub metadata: serde_json::Value,
36    /// Event timestamp.
37    pub ts: DateTime<Utc>,
38}
39
40/// Query filter for audit log lookups.
41#[derive(Debug, Clone, Default, PartialEq)]
42pub struct AuditFilter {
43    /// Workspace identifier to filter by.
44    pub workspace_id: Option<Uuid>,
45    /// Session identifier to filter by.
46    pub session_id: Option<Uuid>,
47    /// Decision name to filter by.
48    pub decision: Option<String>,
49    /// Inclusive start time.
50    pub start_time: Option<DateTime<Utc>>,
51    /// Inclusive end time.
52    pub end_time: Option<DateTime<Utc>>,
53    /// Resource name to filter by.
54    pub resource: Option<String>,
55    /// Result limit.
56    pub limit: Option<u32>,
57}
58
59/// Asynchronous audit writer that batches inserts in the background.
60#[derive(Clone)]
61pub struct AuditWriter {
62    tx: mpsc::Sender<AuditEntry>,
63}
64
65/// Read-only audit accessors.
66#[derive(Clone)]
67pub struct AuditReader {
68    pool: SqlitePool,
69}
70
71impl AuditWriter {
72    /// Creates a new background audit writer.
73    pub fn new(pool: SqlitePool, flush_interval: Duration, batch_size: usize) -> Self {
74        let batch_size = batch_size.max(1);
75        let (tx, mut rx) = mpsc::channel::<AuditEntry>(batch_size * 2);
76        tokio::spawn(async move {
77            let mut ticker = interval(flush_interval);
78            ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
79            let mut batch = Vec::with_capacity(batch_size);
80            loop {
81                tokio::select! {
82                    maybe_entry = rx.recv() => {
83                        match maybe_entry {
84                            Some(entry) => {
85                                batch.push(entry);
86                                if batch.len() >= batch_size {
87                                    let _ = flush_batch(&pool, &mut batch).await;
88                                }
89                            }
90                            None => {
91                                let _ = flush_batch(&pool, &mut batch).await;
92                                break;
93                            }
94                        }
95                    }
96                    _ = ticker.tick() => {
97                        let _ = flush_batch(&pool, &mut batch).await;
98                    }
99                }
100            }
101        });
102        Self { tx }
103    }
104
105    /// Queues an audit entry for asynchronous persistence.
106    pub async fn write(&self, entry: AuditEntry) -> GuardResult<()> {
107        self.tx
108            .send(entry)
109            .await
110            .map_err(|_| GuardError::AuditChannelClosed)
111    }
112}
113
114impl AuditReader {
115    /// Creates a new audit reader.
116    pub fn new(pool: SqlitePool) -> Self {
117        Self { pool }
118    }
119
120    /// Queries audit rows using the provided filter.
121    pub async fn query(&self, filter: AuditFilter) -> GuardResult<Vec<AuditEntry>> {
122        let mut builder = QueryBuilder::<Sqlite>::new(
123            "SELECT id, session_id, workspace_id, agent_id, action, resource, resource_id, decision, reason, risk_score, metadata, ts FROM audit_log WHERE 1 = 1",
124        );
125
126        if let Some(workspace_id) = filter.workspace_id {
127            builder
128                .push(" AND workspace_id = ")
129                .push_bind(workspace_id.to_string());
130        }
131        if let Some(session_id) = filter.session_id {
132            builder
133                .push(" AND session_id = ")
134                .push_bind(session_id.to_string());
135        }
136        if let Some(decision) = filter.decision {
137            builder.push(" AND decision = ").push_bind(decision);
138        }
139        if let Some(start_time) = filter.start_time {
140            builder
141                .push(" AND ts >= ")
142                .push_bind(start_time.timestamp_millis());
143        }
144        if let Some(end_time) = filter.end_time {
145            builder
146                .push(" AND ts <= ")
147                .push_bind(end_time.timestamp_millis());
148        }
149        if let Some(resource) = filter.resource {
150            builder.push(" AND resource = ").push_bind(resource);
151        }
152        builder.push(" ORDER BY ts DESC");
153        if let Some(limit) = filter.limit {
154            builder.push(" LIMIT ").push_bind(limit as i64);
155        }
156
157        let rows = builder.build().fetch_all(&self.pool).await?;
158        rows.iter().map(row_to_audit_entry).collect()
159    }
160
161    /// Exports filtered audit rows as CSV.
162    pub async fn export_csv(&self, filter: AuditFilter, mut writer: impl Write) -> GuardResult<()> {
163        writer.write_all(b"id,session_id,workspace_id,agent_id,action,resource,resource_id,decision,reason,risk_score,ts,metadata\n")?;
164        for entry in self.query(filter).await? {
165            writeln!(
166                writer,
167                "{},{},{},{},{},{},{},{},{},{:.4},{},{}",
168                csv_escape(&entry.id.to_string()),
169                csv_escape(
170                    &entry
171                        .session_id
172                        .map(|value| value.to_string())
173                        .unwrap_or_default()
174                ),
175                csv_escape(&entry.workspace_id.to_string()),
176                csv_escape(
177                    &entry
178                        .agent_id
179                        .map(|value| value.to_string())
180                        .unwrap_or_default()
181                ),
182                csv_escape(&entry.action),
183                csv_escape(&entry.resource),
184                csv_escape(&entry.resource_id.unwrap_or_default()),
185                csv_escape(&entry.decision),
186                csv_escape(&entry.reason.unwrap_or_default()),
187                entry.risk_score,
188                csv_escape(&entry.ts.to_rfc3339()),
189                csv_escape(&serde_json::to_string(&entry.metadata)?),
190            )?;
191        }
192        Ok(())
193    }
194}
195
196pub(crate) async fn write_direct(pool: &SqlitePool, entry: &AuditEntry) -> GuardResult<()> {
197    sqlx::query(
198        "INSERT INTO audit_log (id, session_id, workspace_id, agent_id, action, resource, resource_id, decision, reason, risk_score, metadata, ts)
199         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
200    )
201    .bind(entry.id.to_string())
202    .bind(entry.session_id.map(|value| value.to_string()))
203    .bind(entry.workspace_id.to_string())
204    .bind(entry.agent_id.map(|value| value.to_string()))
205    .bind(&entry.action)
206    .bind(&entry.resource)
207    .bind(&entry.resource_id)
208    .bind(&entry.decision)
209    .bind(&entry.reason)
210    .bind(entry.risk_score)
211    .bind(serde_json::to_string(&entry.metadata)?)
212    .bind(entry.ts.timestamp_millis())
213    .execute(pool)
214    .await?;
215    Ok(())
216}
217
218async fn flush_batch(pool: &SqlitePool, batch: &mut Vec<AuditEntry>) -> GuardResult<()> {
219    if batch.is_empty() {
220        return Ok(());
221    }
222
223    let pending = std::mem::take(batch);
224    let mut tx = pool.begin().await?;
225    for entry in &pending {
226        sqlx::query(
227            "INSERT INTO audit_log (id, session_id, workspace_id, agent_id, action, resource, resource_id, decision, reason, risk_score, metadata, ts)
228             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
229        )
230        .bind(entry.id.to_string())
231        .bind(entry.session_id.map(|value| value.to_string()))
232        .bind(entry.workspace_id.to_string())
233        .bind(entry.agent_id.map(|value| value.to_string()))
234        .bind(&entry.action)
235        .bind(&entry.resource)
236        .bind(&entry.resource_id)
237        .bind(&entry.decision)
238        .bind(&entry.reason)
239        .bind(entry.risk_score)
240        .bind(serde_json::to_string(&entry.metadata)?)
241        .bind(entry.ts.timestamp_millis())
242        .execute(&mut *tx)
243        .await?;
244    }
245    tx.commit().await?;
246    Ok(())
247}
248
249fn row_to_audit_entry(row: &sqlx::sqlite::SqliteRow) -> GuardResult<AuditEntry> {
250    Ok(AuditEntry {
251        id: Uuid::parse_str(&row.try_get::<String, _>("id")?)?,
252        session_id: row
253            .try_get::<Option<String>, _>("session_id")?
254            .map(|value| Uuid::parse_str(&value))
255            .transpose()?,
256        workspace_id: Uuid::parse_str(&row.try_get::<String, _>("workspace_id")?)?,
257        agent_id: row
258            .try_get::<Option<String>, _>("agent_id")?
259            .map(|value| Uuid::parse_str(&value))
260            .transpose()?,
261        action: row.try_get("action")?,
262        resource: row.try_get("resource")?,
263        resource_id: row.try_get("resource_id")?,
264        decision: row.try_get("decision")?,
265        reason: row.try_get("reason")?,
266        risk_score: row.try_get("risk_score")?,
267        metadata: serde_json::from_str(&row.try_get::<String, _>("metadata")?)?,
268        ts: from_ms(row.try_get("ts")?)?,
269    })
270}
271
272fn from_ms(value: i64) -> GuardResult<DateTime<Utc>> {
273    DateTime::from_timestamp_millis(value)
274        .ok_or_else(|| GuardError::ConfigError(format!("invalid timestamp millis: {value}")))
275}
276
/// Escapes a single field for inclusion in a CSV record.
///
/// A field containing a comma, a double quote, a line feed, or a carriage
/// return is wrapped in double quotes, with embedded double quotes doubled.
/// Any other field is returned unchanged.
fn csv_escape(value: &str) -> String {
    // A bare `\r` must be quoted too: many CSV readers treat it as a line
    // break, so leaving it unquoted could split one record into two.
    let needs_quoting = value.contains(|c: char| matches!(c, ',' | '"' | '\n' | '\r'));
    if needs_quoting {
        format!("\"{}\"", value.replace('"', "\"\""))
    } else {
        value.to_owned()
    }
}