use chrono::{DateTime, Utc};
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

use crate::error::Result;
use crate::types::MemoryId;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AuditEntry {
16 pub id: i64,
17 pub timestamp: DateTime<Utc>,
18 pub user_id: Option<String>,
19 pub action: AuditAction,
20 pub memory_id: Option<MemoryId>,
21 pub changes: Option<serde_json::Value>,
22 pub ip_address: Option<String>,
23}
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum AuditAction {
29 Create,
30 Update,
31 Delete,
32 Link,
33 Unlink,
34 Search,
35 Export,
36 Import,
37 SyncPush,
38 SyncPull,
39 Login,
40 Logout,
41}
42
43impl AuditAction {
44 pub fn as_str(&self) -> &'static str {
45 match self {
46 AuditAction::Create => "create",
47 AuditAction::Update => "update",
48 AuditAction::Delete => "delete",
49 AuditAction::Link => "link",
50 AuditAction::Unlink => "unlink",
51 AuditAction::Search => "search",
52 AuditAction::Export => "export",
53 AuditAction::Import => "import",
54 AuditAction::SyncPush => "sync_push",
55 AuditAction::SyncPull => "sync_pull",
56 AuditAction::Login => "login",
57 AuditAction::Logout => "logout",
58 }
59 }
60}
61
62impl std::str::FromStr for AuditAction {
63 type Err = String;
64
65 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
66 match s {
67 "create" => Ok(AuditAction::Create),
68 "update" => Ok(AuditAction::Update),
69 "delete" => Ok(AuditAction::Delete),
70 "link" => Ok(AuditAction::Link),
71 "unlink" => Ok(AuditAction::Unlink),
72 "search" => Ok(AuditAction::Search),
73 "export" => Ok(AuditAction::Export),
74 "import" => Ok(AuditAction::Import),
75 "sync_push" => Ok(AuditAction::SyncPush),
76 "sync_pull" => Ok(AuditAction::SyncPull),
77 "login" => Ok(AuditAction::Login),
78 "logout" => Ok(AuditAction::Logout),
79 _ => Err(format!("Unknown audit action: {}", s)),
80 }
81 }
82}
83
84pub fn log_audit(
86 conn: &Connection,
87 action: AuditAction,
88 memory_id: Option<MemoryId>,
89 user_id: Option<&str>,
90 changes: Option<&serde_json::Value>,
91 ip_address: Option<&str>,
92) -> Result<i64> {
93 let now = Utc::now().to_rfc3339();
94 let changes_str = changes.map(|c| c.to_string());
95
96 conn.execute(
97 "INSERT INTO audit_log (timestamp, user_id, action, memory_id, changes, ip_address)
98 VALUES (?, ?, ?, ?, ?, ?)",
99 params![
100 now,
101 user_id,
102 action.as_str(),
103 memory_id,
104 changes_str,
105 ip_address,
106 ],
107 )?;
108
109 Ok(conn.last_insert_rowid())
110}
111
112pub fn calculate_diff(old: &serde_json::Value, new: &serde_json::Value) -> serde_json::Value {
114 let mut diff = serde_json::Map::new();
115
116 if let (Some(old_obj), Some(new_obj)) = (old.as_object(), new.as_object()) {
117 for (key, new_val) in new_obj {
119 match old_obj.get(key) {
120 Some(old_val) if old_val != new_val => {
121 diff.insert(
122 key.clone(),
123 serde_json::json!({
124 "old": old_val,
125 "new": new_val,
126 }),
127 );
128 }
129 None => {
130 diff.insert(
131 key.clone(),
132 serde_json::json!({
133 "old": null,
134 "new": new_val,
135 }),
136 );
137 }
138 _ => {}
139 }
140 }
141
142 for key in old_obj.keys() {
144 if !new_obj.contains_key(key) {
145 diff.insert(
146 key.clone(),
147 serde_json::json!({
148 "old": old_obj.get(key),
149 "new": null,
150 }),
151 );
152 }
153 }
154 }
155
156 serde_json::Value::Object(diff)
157}
158
159pub fn query_audit_log(conn: &Connection, filter: &AuditFilter) -> Result<Vec<AuditEntry>> {
161 let mut sql = String::from(
162 "SELECT id, timestamp, user_id, action, memory_id, changes, ip_address
163 FROM audit_log WHERE 1=1",
164 );
165 let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
166
167 if let Some(memory_id) = filter.memory_id {
168 sql.push_str(" AND memory_id = ?");
169 params_vec.push(Box::new(memory_id));
170 }
171
172 if let Some(ref user_id) = filter.user_id {
173 sql.push_str(" AND user_id = ?");
174 params_vec.push(Box::new(user_id.clone()));
175 }
176
177 if let Some(ref action) = filter.action {
178 sql.push_str(" AND action = ?");
179 params_vec.push(Box::new(action.as_str().to_string()));
180 }
181
182 if let Some(ref since) = filter.since {
183 sql.push_str(" AND timestamp >= ?");
184 params_vec.push(Box::new(since.to_rfc3339()));
185 }
186
187 if let Some(ref until) = filter.until {
188 sql.push_str(" AND timestamp <= ?");
189 params_vec.push(Box::new(until.to_rfc3339()));
190 }
191
192 sql.push_str(" ORDER BY timestamp DESC");
193
194 if let Some(limit) = filter.limit {
195 sql.push_str(&format!(" LIMIT {}", limit));
196 }
197
198 let params_ref: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|b| b.as_ref()).collect();
199 let mut stmt = conn.prepare(&sql)?;
200
201 let entries: Vec<AuditEntry> = stmt
202 .query_map(params_ref.as_slice(), |row| {
203 let timestamp_str: String = row.get("timestamp")?;
204 let action_str: String = row.get("action")?;
205 let changes_str: Option<String> = row.get("changes")?;
206
207 Ok(AuditEntry {
208 id: row.get("id")?,
209 timestamp: DateTime::parse_from_rfc3339(×tamp_str)
210 .map(|dt| dt.with_timezone(&Utc))
211 .unwrap_or_else(|_| Utc::now()),
212 user_id: row.get("user_id")?,
213 action: action_str.parse().unwrap_or(AuditAction::Update),
214 memory_id: row.get("memory_id")?,
215 changes: changes_str.and_then(|s| serde_json::from_str(&s).ok()),
216 ip_address: row.get("ip_address")?,
217 })
218 })?
219 .filter_map(|r| r.ok())
220 .collect();
221
222 Ok(entries)
223}
224
225#[derive(Debug, Clone, Default)]
227pub struct AuditFilter {
228 pub memory_id: Option<MemoryId>,
229 pub user_id: Option<String>,
230 pub action: Option<AuditAction>,
231 pub since: Option<DateTime<Utc>>,
232 pub until: Option<DateTime<Utc>>,
233 pub limit: Option<i64>,
234}
235
236pub fn get_memory_audit_summary(conn: &Connection, memory_id: MemoryId) -> Result<AuditSummary> {
238 let filter = AuditFilter {
239 memory_id: Some(memory_id),
240 limit: Some(1000),
241 ..Default::default()
242 };
243
244 let entries = query_audit_log(conn, &filter)?;
245
246 let total_changes = entries.len();
247 let unique_users: std::collections::HashSet<_> =
248 entries.iter().filter_map(|e| e.user_id.as_ref()).collect();
249 let first_action = entries.last().map(|e| e.timestamp);
250 let last_action = entries.first().map(|e| e.timestamp);
251
252 let mut action_counts: HashMap<String, i64> = HashMap::new();
253 for entry in &entries {
254 *action_counts
255 .entry(entry.action.as_str().to_string())
256 .or_insert(0) += 1;
257 }
258
259 Ok(AuditSummary {
260 memory_id,
261 total_changes,
262 unique_users: unique_users.len(),
263 first_action,
264 last_action,
265 action_counts,
266 })
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct AuditSummary {
272 pub memory_id: MemoryId,
273 pub total_changes: usize,
274 pub unique_users: usize,
275 pub first_action: Option<DateTime<Utc>>,
276 pub last_action: Option<DateTime<Utc>>,
277 pub action_counts: HashMap<String, i64>,
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every `AuditAction` variant.
    const ALL_ACTIONS: [AuditAction; 12] = [
        AuditAction::Create,
        AuditAction::Update,
        AuditAction::Delete,
        AuditAction::Link,
        AuditAction::Unlink,
        AuditAction::Search,
        AuditAction::Export,
        AuditAction::Import,
        AuditAction::SyncPush,
        AuditAction::SyncPull,
        AuditAction::Login,
        AuditAction::Logout,
    ];

    #[test]
    fn test_calculate_diff() {
        let old = json!({
            "content": "old content",
            "importance": 0.5,
            "removed_field": "value"
        });

        let new = json!({
            "content": "new content",
            "importance": 0.5,
            "new_field": "new value"
        });

        let diff = calculate_diff(&old, &new);
        let diff_obj = diff.as_object().unwrap();

        assert_eq!(diff_obj.len(), 3);
        assert!(!diff_obj.contains_key("importance"));
        assert_eq!(
            diff_obj["content"],
            json!({"old": "old content", "new": "new content"})
        );
        assert_eq!(
            diff_obj["removed_field"],
            json!({"old": "value", "new": null})
        );
        assert_eq!(
            diff_obj["new_field"],
            json!({"old": null, "new": "new value"})
        );
    }

    #[test]
    fn test_calculate_diff_identical_and_non_objects() {
        let value = json!({"a": 1, "b": [1, 2]});
        assert_eq!(calculate_diff(&value, &value), json!({}));
        // Non-object inputs are not diffed.
        assert_eq!(calculate_diff(&json!(1), &json!(2)), json!({}));
    }

    #[test]
    fn test_audit_action_roundtrip() {
        for action in ALL_ACTIONS {
            let s = action.as_str();
            let parsed: AuditAction = s.parse().unwrap();
            assert_eq!(action, parsed);
        }
    }

    #[test]
    fn test_audit_action_rejects_unknown() {
        let err = "bogus".parse::<AuditAction>().unwrap_err();
        assert_eq!(err, "Unknown audit action: bogus");
        assert!("Create".parse::<AuditAction>().is_err());
    }

    #[test]
    fn test_audit_action_serde_matches_as_str() {
        // The serde representation and the stored column text must agree.
        for action in ALL_ACTIONS {
            assert_eq!(
                serde_json::to_value(action).unwrap(),
                serde_json::Value::String(action.as_str().to_string())
            );
        }
    }
}