Skip to main content

engram/storage/
audit.rs

//! Audit logging for all operations (RML-884).
//!
//! Append-only audit log recording who changed what, and when.
4
5use chrono::{DateTime, Utc};
6use rusqlite::{params, Connection};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use crate::error::Result;
11use crate::types::MemoryId;
12
13/// Audit log entry
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AuditEntry {
16    pub id: i64,
17    pub timestamp: DateTime<Utc>,
18    pub user_id: Option<String>,
19    pub action: AuditAction,
20    pub memory_id: Option<MemoryId>,
21    pub changes: Option<serde_json::Value>,
22    pub ip_address: Option<String>,
23}
24
25/// Types of auditable actions
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum AuditAction {
29    Create,
30    Update,
31    Delete,
32    Link,
33    Unlink,
34    Search,
35    Export,
36    Import,
37    SyncPush,
38    SyncPull,
39    Login,
40    Logout,
41}
42
43impl AuditAction {
44    pub fn as_str(&self) -> &'static str {
45        match self {
46            AuditAction::Create => "create",
47            AuditAction::Update => "update",
48            AuditAction::Delete => "delete",
49            AuditAction::Link => "link",
50            AuditAction::Unlink => "unlink",
51            AuditAction::Search => "search",
52            AuditAction::Export => "export",
53            AuditAction::Import => "import",
54            AuditAction::SyncPush => "sync_push",
55            AuditAction::SyncPull => "sync_pull",
56            AuditAction::Login => "login",
57            AuditAction::Logout => "logout",
58        }
59    }
60}
61
62impl std::str::FromStr for AuditAction {
63    type Err = String;
64
65    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
66        match s {
67            "create" => Ok(AuditAction::Create),
68            "update" => Ok(AuditAction::Update),
69            "delete" => Ok(AuditAction::Delete),
70            "link" => Ok(AuditAction::Link),
71            "unlink" => Ok(AuditAction::Unlink),
72            "search" => Ok(AuditAction::Search),
73            "export" => Ok(AuditAction::Export),
74            "import" => Ok(AuditAction::Import),
75            "sync_push" => Ok(AuditAction::SyncPush),
76            "sync_pull" => Ok(AuditAction::SyncPull),
77            "login" => Ok(AuditAction::Login),
78            "logout" => Ok(AuditAction::Logout),
79            _ => Err(format!("Unknown audit action: {}", s)),
80        }
81    }
82}
83
84/// Log an audit entry
85pub fn log_audit(
86    conn: &Connection,
87    action: AuditAction,
88    memory_id: Option<MemoryId>,
89    user_id: Option<&str>,
90    changes: Option<&serde_json::Value>,
91    ip_address: Option<&str>,
92) -> Result<i64> {
93    let now = Utc::now().to_rfc3339();
94    let changes_str = changes.map(|c| c.to_string());
95
96    conn.execute(
97        "INSERT INTO audit_log (timestamp, user_id, action, memory_id, changes, ip_address)
98         VALUES (?, ?, ?, ?, ?, ?)",
99        params![
100            now,
101            user_id,
102            action.as_str(),
103            memory_id,
104            changes_str,
105            ip_address,
106        ],
107    )?;
108
109    Ok(conn.last_insert_rowid())
110}
111
112/// Calculate a diff between two memory states
113pub fn calculate_diff(old: &serde_json::Value, new: &serde_json::Value) -> serde_json::Value {
114    let mut diff = serde_json::Map::new();
115
116    if let (Some(old_obj), Some(new_obj)) = (old.as_object(), new.as_object()) {
117        // Check for changed/added fields
118        for (key, new_val) in new_obj {
119            match old_obj.get(key) {
120                Some(old_val) if old_val != new_val => {
121                    diff.insert(
122                        key.clone(),
123                        serde_json::json!({
124                            "old": old_val,
125                            "new": new_val,
126                        }),
127                    );
128                }
129                None => {
130                    diff.insert(
131                        key.clone(),
132                        serde_json::json!({
133                            "old": null,
134                            "new": new_val,
135                        }),
136                    );
137                }
138                _ => {}
139            }
140        }
141
142        // Check for removed fields
143        for key in old_obj.keys() {
144            if !new_obj.contains_key(key) {
145                diff.insert(
146                    key.clone(),
147                    serde_json::json!({
148                        "old": old_obj.get(key),
149                        "new": null,
150                    }),
151                );
152            }
153        }
154    }
155
156    serde_json::Value::Object(diff)
157}
158
159/// Query audit log entries
160pub fn query_audit_log(conn: &Connection, filter: &AuditFilter) -> Result<Vec<AuditEntry>> {
161    let mut sql = String::from(
162        "SELECT id, timestamp, user_id, action, memory_id, changes, ip_address
163         FROM audit_log WHERE 1=1",
164    );
165    let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
166
167    if let Some(memory_id) = filter.memory_id {
168        sql.push_str(" AND memory_id = ?");
169        params_vec.push(Box::new(memory_id));
170    }
171
172    if let Some(ref user_id) = filter.user_id {
173        sql.push_str(" AND user_id = ?");
174        params_vec.push(Box::new(user_id.clone()));
175    }
176
177    if let Some(ref action) = filter.action {
178        sql.push_str(" AND action = ?");
179        params_vec.push(Box::new(action.as_str().to_string()));
180    }
181
182    if let Some(ref since) = filter.since {
183        sql.push_str(" AND timestamp >= ?");
184        params_vec.push(Box::new(since.to_rfc3339()));
185    }
186
187    if let Some(ref until) = filter.until {
188        sql.push_str(" AND timestamp <= ?");
189        params_vec.push(Box::new(until.to_rfc3339()));
190    }
191
192    sql.push_str(" ORDER BY timestamp DESC");
193
194    if let Some(limit) = filter.limit {
195        sql.push_str(&format!(" LIMIT {}", limit));
196    }
197
198    let params_ref: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|b| b.as_ref()).collect();
199    let mut stmt = conn.prepare(&sql)?;
200
201    let entries: Vec<AuditEntry> = stmt
202        .query_map(params_ref.as_slice(), |row| {
203            let timestamp_str: String = row.get("timestamp")?;
204            let action_str: String = row.get("action")?;
205            let changes_str: Option<String> = row.get("changes")?;
206
207            Ok(AuditEntry {
208                id: row.get("id")?,
209                timestamp: DateTime::parse_from_rfc3339(&timestamp_str)
210                    .map(|dt| dt.with_timezone(&Utc))
211                    .unwrap_or_else(|_| Utc::now()),
212                user_id: row.get("user_id")?,
213                action: action_str.parse().unwrap_or(AuditAction::Update),
214                memory_id: row.get("memory_id")?,
215                changes: changes_str.and_then(|s| serde_json::from_str(&s).ok()),
216                ip_address: row.get("ip_address")?,
217            })
218        })?
219        .filter_map(|r| r.ok())
220        .collect();
221
222    Ok(entries)
223}
224
225/// Filter for querying audit log
226#[derive(Debug, Clone, Default)]
227pub struct AuditFilter {
228    pub memory_id: Option<MemoryId>,
229    pub user_id: Option<String>,
230    pub action: Option<AuditAction>,
231    pub since: Option<DateTime<Utc>>,
232    pub until: Option<DateTime<Utc>>,
233    pub limit: Option<i64>,
234}
235
236/// Get audit summary for a memory
237pub fn get_memory_audit_summary(conn: &Connection, memory_id: MemoryId) -> Result<AuditSummary> {
238    let filter = AuditFilter {
239        memory_id: Some(memory_id),
240        limit: Some(1000),
241        ..Default::default()
242    };
243
244    let entries = query_audit_log(conn, &filter)?;
245
246    let total_changes = entries.len();
247    let unique_users: std::collections::HashSet<_> =
248        entries.iter().filter_map(|e| e.user_id.as_ref()).collect();
249    let first_action = entries.last().map(|e| e.timestamp);
250    let last_action = entries.first().map(|e| e.timestamp);
251
252    let mut action_counts: HashMap<String, i64> = HashMap::new();
253    for entry in &entries {
254        *action_counts
255            .entry(entry.action.as_str().to_string())
256            .or_insert(0) += 1;
257    }
258
259    Ok(AuditSummary {
260        memory_id,
261        total_changes,
262        unique_users: unique_users.len(),
263        first_action,
264        last_action,
265        action_counts,
266    })
267}
268
269/// Summary of audit activity for a memory
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct AuditSummary {
272    pub memory_id: MemoryId,
273    pub total_changes: usize,
274    pub unique_users: usize,
275    pub first_action: Option<DateTime<Utc>>,
276    pub last_action: Option<DateTime<Utc>>,
277    pub action_counts: HashMap<String, i64>,
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `AuditAction` variant. The round-trip tests iterate over this so that no
    /// variant's string mapping goes unchecked.
    const ALL_ACTIONS: [AuditAction; 12] = [
        AuditAction::Create,
        AuditAction::Update,
        AuditAction::Delete,
        AuditAction::Link,
        AuditAction::Unlink,
        AuditAction::Search,
        AuditAction::Export,
        AuditAction::Import,
        AuditAction::SyncPush,
        AuditAction::SyncPull,
        AuditAction::Login,
        AuditAction::Logout,
    ];

    #[test]
    fn test_calculate_diff() {
        let old = serde_json::json!({
            "content": "old content",
            "importance": 0.5,
            "removed_field": "value"
        });

        let new = serde_json::json!({
            "content": "new content",
            "importance": 0.5,
            "new_field": "new value"
        });

        let diff = calculate_diff(&old, &new);
        let diff_obj = diff.as_object().unwrap();

        assert_eq!(diff_obj.len(), 3);
        assert!(!diff_obj.contains_key("importance")); // unchanged
        assert_eq!(
            diff["content"],
            serde_json::json!({ "old": "old content", "new": "new content" })
        );
        assert_eq!(
            diff["removed_field"],
            serde_json::json!({ "old": "value", "new": null })
        );
        assert_eq!(
            diff["new_field"],
            serde_json::json!({ "old": null, "new": "new value" })
        );
    }

    #[test]
    fn test_audit_action_roundtrip() {
        for action in ALL_ACTIONS {
            let parsed: AuditAction = action.as_str().parse().unwrap();
            assert_eq!(action, parsed);
        }
    }

    #[test]
    fn test_audit_action_serde_matches_as_str() {
        // The serde representation and the stored column text must agree.
        for action in ALL_ACTIONS {
            let json = serde_json::to_value(action).unwrap();
            assert_eq!(json, serde_json::Value::String(action.as_str().to_string()));
            let back: AuditAction = serde_json::from_value(json).unwrap();
            assert_eq!(back, action);
        }
    }

    #[test]
    fn test_audit_action_rejects_unknown() {
        assert_eq!(
            "bogus".parse::<AuditAction>().unwrap_err(),
            "Unknown audit action: bogus"
        );
        assert!("".parse::<AuditAction>().is_err());
        assert!("Create".parse::<AuditAction>().is_err());
        assert!("sync-push".parse::<AuditAction>().is_err());
    }
}