1use std::{sync::Arc, time::SystemTime};
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use deadpool_postgres::Pool;
12use serde::{Deserialize, Serialize};
13use sha2::{Digest, Sha256};
14use tracing::error;
15
16#[derive(Debug, thiserror::Error)]
18#[non_exhaustive]
19pub enum AuditError {
20 #[error("Database operation failed: {0}")]
22 Database(#[from] deadpool_postgres::PoolError),
23
24 #[error("SQL query failed: {0}")]
26 Sql(#[from] tokio_postgres::Error),
27
28 #[error("JSON serialization failed: {0}")]
30 Serialization(#[from] serde_json::Error),
31
32 #[error("Audit export failed: {0}")]
34 Export(String),
35}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39#[non_exhaustive]
40pub enum AuditLevel {
41 INFO,
43 WARN,
45 ERROR,
47}
48
49impl AuditLevel {
50 #[must_use]
52 pub const fn as_str(&self) -> &'static str {
53 match self {
54 Self::INFO => "INFO",
55 Self::WARN => "WARN",
56 Self::ERROR => "ERROR",
57 }
58 }
59
60 #[must_use]
62 pub fn parse(s: &str) -> Self {
63 match s {
64 "WARN" => Self::WARN,
65 "ERROR" => Self::ERROR,
66 _ => Self::INFO, }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct AuditEntry {
74 pub id: Option<i64>,
76 pub timestamp: DateTime<Utc>,
78 pub level: AuditLevel,
80 pub user_id: i64,
82 pub tenant_id: i64,
84 pub operation: String,
86 pub query: String,
88 pub variables: serde_json::Value,
90 pub ip_address: String,
92 pub user_agent: String,
94 pub error: Option<String>,
96 pub duration_ms: Option<i32>,
98 pub previous_hash: Option<String>,
100 pub integrity_hash: Option<String>,
102}
103
104impl AuditEntry {
105 #[must_use]
109 pub fn calculate_hash(&self) -> String {
110 let mut hasher = Sha256::new();
111
112 hasher.update(self.user_id.to_string().as_bytes());
114 hasher.update(self.timestamp.to_rfc3339().as_bytes());
115 hasher.update(self.operation.as_bytes());
116 hasher.update(self.query.as_bytes());
117 hasher.update(self.level.as_str().as_bytes());
118
119 if let Some(ref prev) = self.previous_hash {
121 hasher.update(prev.as_bytes());
122 }
123
124 format!("{:x}", hasher.finalize())
125 }
126
127 #[must_use]
129 pub fn verify_integrity(&self) -> bool {
130 if let Some(ref stored_hash) = self.integrity_hash {
131 let calculated = self.calculate_hash();
132 constant_time_eq(stored_hash.as_bytes(), calculated.as_bytes())
134 } else {
135 false
136 }
137 }
138}
139
140fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
146 use subtle::ConstantTimeEq;
147 a.ct_eq(b).into()
148}
149
150#[async_trait]
164pub trait AuditExporter: Send + Sync {
165 async fn export(&self, entry: &AuditEntry) -> Result<(), AuditError>;
171
172 async fn flush(&self) -> Result<(), AuditError>;
178}
179
180#[derive(Debug, Clone, Default, Serialize, Deserialize)]
182pub struct AuditExportConfig {
183 #[serde(default)]
185 pub syslog: Option<SyslogExportConfig>,
186 #[serde(default)]
188 pub webhook: Option<WebhookExportConfig>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct SyslogExportConfig {
194 pub address: String,
196 #[serde(default = "default_syslog_port")]
198 pub port: u16,
199 #[serde(default = "default_syslog_protocol")]
201 pub protocol: String,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct WebhookExportConfig {
207 pub url: String,
209 #[serde(default)]
211 pub headers: std::collections::HashMap<String, String>,
212 #[serde(default = "default_batch_size")]
214 pub batch_size: usize,
215 #[serde(default = "default_flush_interval_secs")]
217 pub flush_interval_secs: u64,
218}
219
/// Serde default for `SyslogExportConfig::port`.
const fn default_syslog_port() -> u16 {
    const SYSLOG_PORT: u16 = 514;
    SYSLOG_PORT
}
223
/// Serde default for `SyslogExportConfig::protocol`.
fn default_syslog_protocol() -> String {
    String::from("udp")
}
227
/// Serde default for `WebhookExportConfig::batch_size`.
const fn default_batch_size() -> usize {
    const BATCH_SIZE: usize = 100;
    BATCH_SIZE
}
231
/// Serde default for `WebhookExportConfig::flush_interval_secs`.
const fn default_flush_interval_secs() -> u64 {
    const FLUSH_INTERVAL_SECS: u64 = 30;
    FLUSH_INTERVAL_SECS
}
235
236#[derive(Debug, Clone, Default, Serialize, Deserialize)]
238pub struct AuditStats {
239 pub total_events: u64,
241 pub recent_events: u64,
243}
244
245#[derive(Clone)]
247pub struct AuditLogger {
248 pool: Arc<Pool>,
249 exporters: Arc<Vec<Box<dyn AuditExporter>>>,
250}
251
252impl std::fmt::Debug for AuditLogger {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 f.debug_struct("AuditLogger")
255 .field("pool", &self.pool)
256 .field("exporters_count", &self.exporters.len())
257 .finish()
258 }
259}
260
261impl AuditLogger {
262 const MAX_AUDIT_FIELD_BYTES: usize = 64 * 1024;
267
268 #[must_use]
270 pub fn new(pool: Arc<Pool>) -> Self {
271 Self {
272 pool,
273 exporters: Arc::new(Vec::new()),
274 }
275 }
276
277 #[must_use]
283 pub fn with_exporters(pool: Arc<Pool>, exporters: Vec<Box<dyn AuditExporter>>) -> Self {
284 Self {
285 pool,
286 exporters: Arc::new(exporters),
287 }
288 }
289
290 pub async fn log(&self, mut entry: AuditEntry) -> Result<i64, AuditError> {
301 if entry.query.len() > Self::MAX_AUDIT_FIELD_BYTES {
304 entry.query.truncate(Self::MAX_AUDIT_FIELD_BYTES);
305 entry.query.push_str("…[truncated]");
306 }
307 let vars_serialized = serde_json::to_string(&entry.variables).unwrap_or_default();
308 if vars_serialized.len() > Self::MAX_AUDIT_FIELD_BYTES {
309 entry.variables = serde_json::json!({"_truncated": true});
310 }
311
312 let sql = r"
313 INSERT INTO fraiseql_audit_logs (
314 timestamp,
315 level,
316 user_id,
317 tenant_id,
318 operation,
319 query,
320 variables,
321 ip_address,
322 user_agent,
323 error,
324 duration_ms
325 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
326 RETURNING id
327 ";
328
329 let client = self.pool.get().await?;
330 let variables_json = serde_json::to_value(&entry.variables)?;
331
332 let secs = u64::try_from(entry.timestamp.timestamp()).unwrap_or(0);
335 let timestamp_system = SystemTime::UNIX_EPOCH
336 + std::time::Duration::from_secs(secs)
337 + std::time::Duration::from_nanos(u64::from(entry.timestamp.timestamp_subsec_nanos()));
338
339 let row = client
340 .query_one(
341 sql,
342 &[
343 ×tamp_system,
344 &entry.level.as_str(),
345 &entry.user_id,
346 &entry.tenant_id,
347 &entry.operation,
348 &entry.query,
349 &variables_json,
350 &entry.ip_address,
351 &entry.user_agent,
352 &entry.error,
353 &entry.duration_ms,
354 ],
355 )
356 .await?;
357
358 let id: i64 = row.get(0);
359
360 if !self.exporters.is_empty() {
363 for exporter in self.exporters.iter() {
364 if let Err(e) = exporter.export(&entry).await {
365 error!(error = %e, "Audit exporter failed");
366 }
367 }
368 }
369
370 Ok(id)
371 }
372
373 pub async fn flush_exporters(&self) -> Result<(), AuditError> {
381 let mut first_err = None;
382 for exporter in self.exporters.iter() {
383 if let Err(e) = exporter.flush().await {
384 error!(error = %e, "Audit exporter flush failed");
385 if first_err.is_none() {
386 first_err = Some(e);
387 }
388 }
389 }
390 match first_err {
391 Some(e) => Err(e),
392 None => Ok(()),
393 }
394 }
395
396 pub async fn get_recent_logs(
402 &self,
403 tenant_id: i64,
404 level: Option<AuditLevel>,
405 limit: i64,
406 ) -> Result<Vec<AuditEntry>, AuditError> {
407 let client = self.pool.get().await?;
408
409 let rows = if let Some(lvl) = level {
410 let sql = r"
411 SELECT id, timestamp, level, user_id, tenant_id, operation,
412 query, variables, ip_address, user_agent, error, duration_ms
413 FROM fraiseql_audit_logs
414 WHERE tenant_id = $1 AND level = $2
415 ORDER BY timestamp DESC
416 LIMIT $3
417 ";
418 client.query(sql, &[&tenant_id, &lvl.as_str(), &limit]).await?
419 } else {
420 let sql = r"
421 SELECT id, timestamp, level, user_id, tenant_id, operation,
422 query, variables, ip_address, user_agent, error, duration_ms
423 FROM fraiseql_audit_logs
424 WHERE tenant_id = $1
425 ORDER BY timestamp DESC
426 LIMIT $2
427 ";
428 client.query(sql, &[&tenant_id, &limit]).await?
429 };
430
431 let entries: Vec<AuditEntry> = rows
432 .into_iter()
433 .map(|row| {
434 let id: Option<i64> = row.get(0);
435 let timestamp_system: SystemTime = row.get(1);
436 let level_str: String = row.get(2);
437 let user_id: i64 = row.get(3);
438 let tenant_id: i64 = row.get(4);
439 let operation: String = row.get(5);
440 let query: String = row.get(6);
441 let variables: serde_json::Value = row.get(7);
442 let ip_address: String = row.get(8);
443 let user_agent: String = row.get(9);
444 let error: Option<String> = row.get(10);
445 let duration_ms: Option<i32> = row.get(11);
446
447 let timestamp = DateTime::<Utc>::from(timestamp_system);
449
450 AuditEntry {
451 id,
452 timestamp,
453 level: AuditLevel::parse(&level_str),
454 user_id,
455 tenant_id,
456 operation,
457 query,
458 variables,
459 ip_address,
460 user_agent,
461 error,
462 duration_ms,
463 previous_hash: None,
464 integrity_hash: None,
465 }
466 })
467 .collect();
468
469 Ok(entries)
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully populated entry with no hashes set. Tests that need
    /// other values override individual fields with struct-update syntax.
    fn sample_entry(query: &str) -> AuditEntry {
        AuditEntry {
            id: Some(1),
            timestamp: Utc::now(),
            level: AuditLevel::INFO,
            user_id: 123,
            tenant_id: 456,
            operation: "query".to_string(),
            query: query.to_string(),
            variables: serde_json::json!({}),
            ip_address: "192.168.1.1".to_string(),
            user_agent: "Mozilla/5.0".to_string(),
            error: None,
            duration_ms: Some(100),
            previous_hash: None,
            integrity_hash: None,
        }
    }

    #[test]
    fn test_audit_entry_integrity_hash() {
        let digest = sample_entry("{ users { id name } }").calculate_hash();

        assert!(!digest.is_empty());
        // SHA-256 rendered as hex is 64 characters long.
        assert_eq!(digest.len(), 64);
    }

    #[test]
    fn test_audit_integrity_verification() {
        let mut entry = sample_entry("{ users { id name } }");
        entry.integrity_hash = Some(entry.calculate_hash());

        // An untouched entry matches its stored hash.
        assert!(entry.verify_integrity());

        // Changing a hashed field must break verification.
        entry.user_id = 999;
        assert!(!entry.verify_integrity());
    }

    #[test]
    fn test_audit_hash_chain() {
        let timestamp = Utc::now();

        let mut entry1 = AuditEntry {
            timestamp,
            ..sample_entry("{ users { id } }")
        };
        let hash1 = entry1.calculate_hash();
        entry1.integrity_hash = Some(hash1.clone());

        // The second entry chains to the first through `previous_hash`.
        let mut entry2 = AuditEntry {
            id: Some(2),
            timestamp,
            duration_ms: Some(50),
            previous_hash: Some(hash1),
            ..sample_entry("{ posts { id } }")
        };
        let hash2 = entry2.calculate_hash();
        entry2.integrity_hash = Some(hash2);

        assert!(entry1.verify_integrity());
        assert!(entry2.verify_integrity());

        // Tampering with the first entry is detected.
        entry1.user_id = 999;
        assert!(!entry1.verify_integrity());
    }

    #[test]
    fn test_audit_level_parsing() {
        let cases = [
            ("WARN", AuditLevel::WARN),
            ("ERROR", AuditLevel::ERROR),
            ("INFO", AuditLevel::INFO),
            // Unknown names fall back to INFO.
            ("UNKNOWN", AuditLevel::INFO),
        ];

        for (input, expected) in cases {
            assert_eq!(AuditLevel::parse(input), expected);
        }
    }

    #[test]
    fn test_audit_export_config_deserialization() {
        let json = r#"{
            "syslog": { "address": "syslog.internal", "port": 514, "protocol": "tcp" },
            "webhook": { "url": "https://logs.example.com/ingest" }
        }"#;
        let config: AuditExportConfig =
            serde_json::from_str(json).expect("should deserialize AuditExportConfig");
        assert!(config.syslog.is_some());
        assert!(config.webhook.is_some());

        let syslog = config.syslog.expect("syslog should be Some");
        assert_eq!(syslog.address, "syslog.internal");
        assert_eq!(syslog.port, 514);
        assert_eq!(syslog.protocol, "tcp");

        // Omitted webhook fields pick up their serde defaults.
        let webhook = config.webhook.expect("webhook should be Some");
        assert_eq!(webhook.url, "https://logs.example.com/ingest");
        assert_eq!(webhook.batch_size, 100);
        assert_eq!(webhook.flush_interval_secs, 30);
    }

    #[test]
    fn test_audit_export_config_empty() {
        let config: AuditExportConfig =
            serde_json::from_str("{}").expect("should deserialize empty config");

        assert!(config.syslog.is_none());
        assert!(config.webhook.is_none());
    }

    #[test]
    fn test_audit_error_export_variant() {
        let message = AuditError::Export("connection refused".to_string()).to_string();

        assert!(message.contains("connection refused"));
    }
}