1use std::{sync::Arc, time::SystemTime};
6
7use chrono::{DateTime, Utc};
8use deadpool_postgres::Pool;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11
/// Errors returned by [`AuditLogger`] operations.
#[derive(Debug, thiserror::Error)]
pub enum AuditError {
    /// Wraps a `deadpool_postgres::PoolError` (in this file, raised when acquiring a
    /// pooled connection fails).
    #[error("Database operation failed: {0}")]
    Database(#[from] deadpool_postgres::PoolError),

    /// Wraps a `tokio_postgres::Error`, e.g. a failed query.
    #[error("SQL query failed: {0}")]
    Sql(#[from] tokio_postgres::Error),

    /// Wraps a `serde_json::Error`.
    #[error("JSON serialization failed: {0}")]
    Serialization(#[from] serde_json::Error),
}
27
/// Severity of an audit entry.
///
/// Note: serde (de)serializes these by variant name, which coincides with the labels
/// returned by `AuditLevel::as_str`; renaming a variant changes its serde form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuditLevel {
    /// Label `"INFO"`; also the fallback for labels `AuditLevel::parse` does not recognise.
    INFO,
    /// Label `"WARN"`.
    WARN,
    /// Label `"ERROR"`.
    ERROR,
}
38
39impl AuditLevel {
40 #[must_use]
42 pub const fn as_str(&self) -> &'static str {
43 match self {
44 Self::INFO => "INFO",
45 Self::WARN => "WARN",
46 Self::ERROR => "ERROR",
47 }
48 }
49
50 #[must_use]
52 pub fn parse(s: &str) -> Self {
53 match s {
54 "WARN" => Self::WARN,
55 "ERROR" => Self::ERROR,
56 _ => Self::INFO, }
58 }
59}
60
/// A single audit-log record as written to and read from `fraiseql_audit_logs`.
///
/// `previous_hash` and `integrity_hash` support tamper detection via
/// `AuditEntry::calculate_hash` / `AuditEntry::verify_integrity`. Only some fields
/// participate in that hash; each field's doc says whether it is hashed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
    /// Database row id.
    ///
    /// Not written by `AuditLogger::log`, which returns the id the database assigns.
    /// `AuditLogger::get_recent_logs` populates it from the `id` column.
    pub id: Option<i64>,
    /// Event timestamp in UTC. Hashed (as RFC 3339).
    pub timestamp: DateTime<Utc>,
    /// Severity; stored and hashed as its `AuditLevel::as_str` label.
    pub level: AuditLevel,
    /// Acting user's id. Hashed.
    pub user_id: i64,
    /// Owning tenant; `AuditLogger::get_recent_logs` filters on it. Not hashed.
    pub tenant_id: i64,
    /// Operation kind (`"query"` in this file's tests). Hashed.
    pub operation: String,
    /// Query text. Hashed.
    pub query: String,
    /// Query variables as JSON. Not hashed.
    pub variables: serde_json::Value,
    /// IP address as free text (not parsed or validated in this file). Not hashed.
    pub ip_address: String,
    /// User-agent string. Not hashed.
    pub user_agent: String,
    /// Optional error message. Not hashed.
    pub error: Option<String>,
    /// Optional duration (milliseconds, per the field name). Not hashed.
    pub duration_ms: Option<i32>,
    /// Hash of the preceding entry in a chain; mixed into `calculate_hash` when `Some`.
    ///
    /// Not persisted by `AuditLogger::log`, and always `None` from `get_recent_logs`.
    pub previous_hash: Option<String>,
    /// Expected hash checked by `verify_integrity`; `None` makes verification fail.
    ///
    /// Not persisted by `AuditLogger::log`, and always `None` from `get_recent_logs`.
    pub integrity_hash: Option<String>,
}
93
94impl AuditEntry {
95 #[must_use]
99 pub fn calculate_hash(&self) -> String {
100 let mut hasher = Sha256::new();
101
102 hasher.update(self.user_id.to_string().as_bytes());
104 hasher.update(self.timestamp.to_rfc3339().as_bytes());
105 hasher.update(self.operation.as_bytes());
106 hasher.update(self.query.as_bytes());
107 hasher.update(self.level.as_str().as_bytes());
108
109 if let Some(ref prev) = self.previous_hash {
111 hasher.update(prev.as_bytes());
112 }
113
114 format!("{:x}", hasher.finalize())
115 }
116
117 #[must_use]
119 pub fn verify_integrity(&self) -> bool {
120 if let Some(ref stored_hash) = self.integrity_hash {
121 let calculated = self.calculate_hash();
122 constant_time_eq(stored_hash.as_bytes(), calculated.as_bytes())
124 } else {
125 false
126 }
127 }
128}
129
/// Compares two byte slices without stopping at the first mismatch.
///
/// Slices of different length are rejected immediately, so length is not hidden.
/// For equal-length inputs, every byte pair is XOR-ed and the differences are OR-ed
/// together before the final comparison.
///
/// This is a best-effort mitigation: nothing here stops the compiler from introducing
/// data-dependent branches.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    } else {
        false
    }
}
143
/// Aggregate audit-event counters.
///
/// NOTE(review): nothing in this file constructs or populates this type. What counts as
/// "recent" is decided by whoever fills it, so confirm at the call site.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuditStats {
    /// Total number of audit events counted.
    pub total_events: u64,
    /// Number of "recent" audit events (the time window is not defined in this file).
    pub recent_events: u64,
}
152
/// Writes and queries audit entries in the `fraiseql_audit_logs` table.
///
/// The connection pool is held behind an `Arc`, so cloning a logger shares the same pool.
#[derive(Clone, Debug)]
pub struct AuditLogger {
    /// Shared `deadpool_postgres` connection pool used for every statement.
    pool: Arc<Pool>,
}
158
159impl AuditLogger {
160 #[must_use]
162 pub const fn new(pool: Arc<Pool>) -> Self {
163 Self { pool }
164 }
165
166 pub async fn log(&self, entry: AuditEntry) -> Result<i64, AuditError> {
172 let sql = r"
173 INSERT INTO fraiseql_audit_logs (
174 timestamp,
175 level,
176 user_id,
177 tenant_id,
178 operation,
179 query,
180 variables,
181 ip_address,
182 user_agent,
183 error,
184 duration_ms
185 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
186 RETURNING id
187 ";
188
189 let client = self.pool.get().await?;
190 let variables_json = serde_json::to_value(&entry.variables)?;
191
192 let timestamp_system = SystemTime::UNIX_EPOCH
194 + std::time::Duration::from_secs(entry.timestamp.timestamp() as u64)
195 + std::time::Duration::from_nanos(u64::from(entry.timestamp.timestamp_subsec_nanos()));
196
197 let row = client
198 .query_one(
199 sql,
200 &[
201 ×tamp_system,
202 &entry.level.as_str(),
203 &entry.user_id,
204 &entry.tenant_id,
205 &entry.operation,
206 &entry.query,
207 &variables_json,
208 &entry.ip_address,
209 &entry.user_agent,
210 &entry.error,
211 &entry.duration_ms,
212 ],
213 )
214 .await?;
215
216 let id: i64 = row.get(0);
217 Ok(id)
218 }
219
220 pub async fn get_recent_logs(
226 &self,
227 tenant_id: i64,
228 level: Option<AuditLevel>,
229 limit: i64,
230 ) -> Result<Vec<AuditEntry>, AuditError> {
231 let client = self.pool.get().await?;
232
233 let rows = if let Some(lvl) = level {
234 let sql = r"
235 SELECT id, timestamp, level, user_id, tenant_id, operation,
236 query, variables, ip_address, user_agent, error, duration_ms
237 FROM fraiseql_audit_logs
238 WHERE tenant_id = $1 AND level = $2
239 ORDER BY timestamp DESC
240 LIMIT $3
241 ";
242 client.query(sql, &[&tenant_id, &lvl.as_str(), &limit]).await?
243 } else {
244 let sql = r"
245 SELECT id, timestamp, level, user_id, tenant_id, operation,
246 query, variables, ip_address, user_agent, error, duration_ms
247 FROM fraiseql_audit_logs
248 WHERE tenant_id = $1
249 ORDER BY timestamp DESC
250 LIMIT $2
251 ";
252 client.query(sql, &[&tenant_id, &limit]).await?
253 };
254
255 let entries: Vec<AuditEntry> = rows
256 .into_iter()
257 .map(|row| {
258 let id: Option<i64> = row.get(0);
259 let timestamp_system: SystemTime = row.get(1);
260 let level_str: String = row.get(2);
261 let user_id: i64 = row.get(3);
262 let tenant_id: i64 = row.get(4);
263 let operation: String = row.get(5);
264 let query: String = row.get(6);
265 let variables: serde_json::Value = row.get(7);
266 let ip_address: String = row.get(8);
267 let user_agent: String = row.get(9);
268 let error: Option<String> = row.get(10);
269 let duration_ms: Option<i32> = row.get(11);
270
271 let timestamp = DateTime::<Utc>::from(timestamp_system);
273
274 AuditEntry {
275 id,
276 timestamp,
277 level: AuditLevel::parse(&level_str),
278 user_id,
279 tenant_id,
280 operation,
281 query,
282 variables,
283 ip_address,
284 user_agent,
285 error,
286 duration_ms,
287 previous_hash: None,
288 integrity_hash: None,
289 }
290 })
291 .collect();
292
293 Ok(entries)
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unhashed INFO entry at `timestamp`.
    ///
    /// Tests override the fields they care about with struct-update syntax.
    fn sample_entry(timestamp: DateTime<Utc>) -> AuditEntry {
        AuditEntry {
            id: Some(1),
            timestamp,
            level: AuditLevel::INFO,
            user_id: 123,
            tenant_id: 456,
            operation: "query".to_string(),
            query: "{ users { id name } }".to_string(),
            variables: serde_json::json!({}),
            ip_address: "192.168.1.1".to_string(),
            user_agent: "Mozilla/5.0".to_string(),
            error: None,
            duration_ms: Some(100),
            previous_hash: None,
            integrity_hash: None,
        }
    }

    #[test]
    fn test_audit_entry_integrity_hash() {
        let entry = sample_entry(Utc::now());

        let hash = entry.calculate_hash();
        // SHA-256 rendered as lowercase hex: 32 bytes -> 64 characters.
        assert_eq!(hash.len(), 64);
        assert!(hash
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
        // Verification only works if hashing is deterministic.
        assert_eq!(hash, entry.calculate_hash());
    }

    #[test]
    fn test_audit_integrity_verification() {
        let mut entry = sample_entry(Utc::now());

        // With no stored hash there is nothing to verify against.
        assert!(!entry.verify_integrity());

        entry.integrity_hash = Some(entry.calculate_hash());
        assert!(entry.verify_integrity());

        // Tampering with a hashed field must be detected.
        entry.user_id = 999;
        assert!(!entry.verify_integrity());
    }

    #[test]
    fn test_audit_hash_chain() {
        let timestamp = Utc::now();

        let mut entry1 = AuditEntry {
            query: "{ users { id } }".to_string(),
            ..sample_entry(timestamp)
        };
        let hash1 = entry1.calculate_hash();
        entry1.integrity_hash = Some(hash1.clone());

        let mut entry2 = AuditEntry {
            id: Some(2),
            query: "{ posts { id } }".to_string(),
            duration_ms: Some(50),
            previous_hash: Some(hash1),
            ..sample_entry(timestamp)
        };
        entry2.integrity_hash = Some(entry2.calculate_hash());

        assert!(entry1.verify_integrity());
        assert!(entry2.verify_integrity());

        // The predecessor's hash is part of the digest, so re-linking entry2 breaks it.
        entry2.previous_hash = Some("0".repeat(64));
        assert!(!entry2.verify_integrity());

        entry1.user_id = 999;
        assert!(!entry1.verify_integrity());
    }

    #[test]
    fn test_audit_level_parsing() {
        assert_eq!(AuditLevel::parse("WARN"), AuditLevel::WARN);
        assert_eq!(AuditLevel::parse("ERROR"), AuditLevel::ERROR);
        assert_eq!(AuditLevel::parse("INFO"), AuditLevel::INFO);
        assert_eq!(AuditLevel::parse("UNKNOWN"), AuditLevel::INFO);
        // Matching is exact: lower-case labels also fall back to INFO.
        assert_eq!(AuditLevel::parse("warn"), AuditLevel::INFO);

        // Every stored label must parse back to the level that produced it.
        for level in [AuditLevel::INFO, AuditLevel::WARN, AuditLevel::ERROR] {
            assert_eq!(AuditLevel::parse(level.as_str()), level);
        }
    }

    #[test]
    fn test_constant_time_eq() {
        assert!(constant_time_eq(b"", b""));
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"ab"));
    }
}