// fraiseql_core/security/audit.rs
//! Audit logging for GraphQL operations.
//!
//! Uses `PostgreSQL` `deadpool` for database operations. Supports optional
//! pluggable export sinks (syslog, webhook) for streaming audit entries to
//! external immutable stores.
7use std::{sync::Arc, time::SystemTime};
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use deadpool_postgres::Pool;
12use serde::{Deserialize, Serialize};
13use sha2::{Digest, Sha256};
14use tracing::error;
15
/// Errors that can occur during audit operations.
///
/// Marked `#[non_exhaustive]` so new variants can be added without a
/// breaking change; downstream `match`es must keep a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum AuditError {
    /// Database operation failed (wraps pool checkout errors).
    #[error("Database operation failed: {0}")]
    Database(#[from] deadpool_postgres::PoolError),

    /// SQL query execution failed.
    #[error("SQL query failed: {0}")]
    Sql(#[from] tokio_postgres::Error),

    /// Failed to serialize data to JSON.
    #[error("JSON serialization failed: {0}")]
    Serialization(#[from] serde_json::Error),

    /// Export to an external sink (syslog, webhook, …) failed.
    #[error("Audit export failed: {0}")]
    Export(String),
}
36
/// Audit log levels.
///
/// Persisted to the database as upper-case strings; see [`AuditLevel::as_str`]
/// and [`AuditLevel::parse`] for the round-trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum AuditLevel {
    /// Informational messages.
    INFO,
    /// Warnings.
    WARN,
    /// Errors.
    ERROR,
}
48
49impl AuditLevel {
50    /// Convert to string for database storage
51    #[must_use]
52    pub const fn as_str(&self) -> &'static str {
53        match self {
54            Self::INFO => "INFO",
55            Self::WARN => "WARN",
56            Self::ERROR => "ERROR",
57        }
58    }
59
60    /// Parse from string
61    #[must_use]
62    pub fn parse(s: &str) -> Self {
63        match s {
64            "WARN" => Self::WARN,
65            "ERROR" => Self::ERROR,
66            _ => Self::INFO, // Default to INFO for unknown strings
67        }
68    }
69}
70
/// Audit log entry with integrity protection.
///
/// Entries can form a hash chain: each entry may carry the SHA-256 hash of
/// its predecessor (`previous_hash`) plus its own hash (`integrity_hash`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
    /// Database row ID (`None` for entries not yet inserted).
    pub id:             Option<i64>,
    /// Time of the audited operation (UTC).
    pub timestamp:      DateTime<Utc>,
    /// Severity of the entry.
    pub level:          AuditLevel,
    /// ID of the user who performed the operation.
    pub user_id:        i64,
    /// ID of the tenant the operation ran under.
    pub tenant_id:      i64,
    /// Operation type (query, mutation).
    pub operation:      String,
    /// GraphQL query string.
    pub query:          String,
    /// Query variables (stored as JSONB).
    pub variables:      serde_json::Value,
    /// Client IP address.
    pub ip_address:     String,
    /// Client user agent.
    pub user_agent:     String,
    /// Error message, if the operation failed.
    pub error:          Option<String>,
    /// Query duration in milliseconds, if measured.
    pub duration_ms:    Option<i32>,
    /// SHA256 hash of previous entry (for integrity chain).
    pub previous_hash:  Option<String>,
    /// SHA256 hash of this entry (for integrity verification).
    pub integrity_hash: Option<String>,
}
103
104impl AuditEntry {
105    /// Calculate SHA256 hash for this entry (for integrity chain)
106    ///
107    /// Hashes: `user_id` | timestamp | operation | query to create a tamper-proof chain
108    #[must_use]
109    pub fn calculate_hash(&self) -> String {
110        let mut hasher = Sha256::new();
111
112        // Include all mutable fields in hash for tamper detection
113        hasher.update(self.user_id.to_string().as_bytes());
114        hasher.update(self.timestamp.to_rfc3339().as_bytes());
115        hasher.update(self.operation.as_bytes());
116        hasher.update(self.query.as_bytes());
117        hasher.update(self.level.as_str().as_bytes());
118
119        // Include previous hash if present (hash chain)
120        if let Some(ref prev) = self.previous_hash {
121            hasher.update(prev.as_bytes());
122        }
123
124        format!("{:x}", hasher.finalize())
125    }
126
127    /// Verify integrity of this entry against its stored hash
128    #[must_use]
129    pub fn verify_integrity(&self) -> bool {
130        if let Some(ref stored_hash) = self.integrity_hash {
131            let calculated = self.calculate_hash();
132            // Constant-time comparison to prevent timing attacks
133            constant_time_eq(stored_hash.as_bytes(), calculated.as_bytes())
134        } else {
135            false
136        }
137    }
138}
139
140/// Constant-time comparison to prevent timing attacks.
141///
142/// Uses [`subtle::ConstantTimeEq`] — the same primitive used elsewhere in this
143/// codebase — instead of a hand-rolled loop, which is brittle under optimiser
144/// changes.
145fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
146    use subtle::ConstantTimeEq;
147    a.ct_eq(b).into()
148}
149
150// ============================================================================
151// Pluggable export sinks
152// ============================================================================
153
/// Pluggable sink for streaming audit entries to external systems.
///
/// Implementations can send entries to syslog, webhooks, S3, or any other
/// external store. The caller decides whether to retry or drop entries on
/// failure; see `AuditLogger::log`, which exports best-effort.
#[async_trait]
pub trait AuditExporter: Send + Sync {
    /// Export a single audit entry to the external sink.
    ///
    /// # Errors
    ///
    /// Returns [`AuditError`] if the export fails (network, serialization, …).
    async fn export(&self, entry: &AuditEntry) -> Result<(), AuditError>;

    /// Flush any buffered entries to the external sink.
    ///
    /// # Errors
    ///
    /// Returns [`AuditError`] if the flush fails.
    async fn flush(&self) -> Result<(), AuditError>;
}
179
/// Configuration for audit log export sinks.
///
/// Both sinks are optional; an empty/default config disables export.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuditExportConfig {
    /// Syslog export configuration (requires `audit-syslog` feature).
    #[serde(default)]
    pub syslog:  Option<SyslogExportConfig>,
    /// Webhook export configuration (requires `audit-webhook` feature).
    #[serde(default)]
    pub webhook: Option<WebhookExportConfig>,
}
190
/// Configuration for the syslog audit exporter (RFC 5424).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyslogExportConfig {
    /// Syslog server hostname or IP.
    pub address:  String,
    /// Syslog server port (defaults to 514 when omitted).
    #[serde(default = "default_syslog_port")]
    pub port:     u16,
    /// Transport protocol: "tcp" or "udp" (defaults to "udp" when omitted).
    #[serde(default = "default_syslog_protocol")]
    pub protocol: String,
}
203
/// Configuration for the webhook audit exporter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookExportConfig {
    /// Webhook URL (must be HTTPS).
    pub url:                 String,
    /// Additional HTTP headers (e.g. `Authorization: Bearer ...`).
    #[serde(default)]
    pub headers:             std::collections::HashMap<String, String>,
    /// Number of entries to accumulate before flushing (defaults to 100).
    #[serde(default = "default_batch_size")]
    pub batch_size:          usize,
    /// Flush interval in seconds (defaults to 30).
    #[serde(default = "default_flush_interval_secs")]
    pub flush_interval_secs: u64,
}
219
/// Serde default: standard syslog port.
const fn default_syslog_port() -> u16 {
    514
}

/// Serde default: syslog transport protocol.
fn default_syslog_protocol() -> String {
    String::from("udp")
}

/// Serde default: webhook batch size.
const fn default_batch_size() -> usize {
    100
}

/// Serde default: webhook flush interval, in seconds.
const fn default_flush_interval_secs() -> u64 {
    30
}
235
/// Statistics about audit events.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuditStats {
    /// Total number of audit events recorded.
    pub total_events:  u64,
    /// Number of recent events (last 24 hours or recent window).
    pub recent_events: u64,
}
244
/// Audit logger with `PostgreSQL` backend and optional export sinks.
///
/// Cheap to clone: both the pool and the exporter list are shared via `Arc`.
#[derive(Clone)]
pub struct AuditLogger {
    // Shared connection pool used for all audit reads/writes.
    pool:      Arc<Pool>,
    // External sinks; entries are exported best-effort after the DB write.
    exporters: Arc<Vec<Box<dyn AuditExporter>>>,
}
251
252impl std::fmt::Debug for AuditLogger {
253    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        f.debug_struct("AuditLogger")
255            .field("pool", &self.pool)
256            .field("exporters_count", &self.exporters.len())
257            .finish()
258    }
259}
260
261impl AuditLogger {
262    /// Maximum byte length for the `query` and `variables` fields in an audit log entry.
263    ///
264    /// Limits audit table bloat and prevents excess allocation during serialization.
265    /// Queries that bypass `QueryValidator` (e.g. on error paths) may be arbitrarily large.
266    const MAX_AUDIT_FIELD_BYTES: usize = 64 * 1024;
267
268    /// Create a new audit logger with no export sinks.
269    #[must_use]
270    pub fn new(pool: Arc<Pool>) -> Self {
271        Self {
272            pool,
273            exporters: Arc::new(Vec::new()),
274        }
275    }
276
277    /// Create a new audit logger with export sinks.
278    ///
279    /// Entries are written to PostgreSQL first, then exported to each sink
280    /// on a best-effort basis (export failures are logged but do not fail the
281    /// primary write).
282    #[must_use]
283    pub fn with_exporters(pool: Arc<Pool>, exporters: Vec<Box<dyn AuditExporter>>) -> Self {
284        Self {
285            pool,
286            exporters: Arc::new(exporters),
287        }
288    }
289
290    // 64 KiB
291
292    /// Log an audit entry.
293    ///
294    /// Truncates `query` and `variables` to `MAX_AUDIT_FIELD_BYTES` before
295    /// storing to prevent audit table bloat.
296    ///
297    /// # Errors
298    ///
299    /// Returns error if database operation fails.
300    pub async fn log(&self, mut entry: AuditEntry) -> Result<i64, AuditError> {
301        // Truncate oversized query and variables before storing to prevent
302        // audit-table bloat from queries that bypass QueryValidator.
303        if entry.query.len() > Self::MAX_AUDIT_FIELD_BYTES {
304            entry.query.truncate(Self::MAX_AUDIT_FIELD_BYTES);
305            entry.query.push_str("…[truncated]");
306        }
307        let vars_serialized = serde_json::to_string(&entry.variables).unwrap_or_default();
308        if vars_serialized.len() > Self::MAX_AUDIT_FIELD_BYTES {
309            entry.variables = serde_json::json!({"_truncated": true});
310        }
311
312        let sql = r"
313            INSERT INTO fraiseql_audit_logs (
314                timestamp,
315                level,
316                user_id,
317                tenant_id,
318                operation,
319                query,
320                variables,
321                ip_address,
322                user_agent,
323                error,
324                duration_ms
325            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
326            RETURNING id
327        ";
328
329        let client = self.pool.get().await?;
330        let variables_json = serde_json::to_value(&entry.variables)?;
331
332        // Convert DateTime<Utc> to SystemTime for PostgreSQL
333        // timestamp() returns i64 seconds since epoch; audit events are always post-epoch.
334        let secs = u64::try_from(entry.timestamp.timestamp()).unwrap_or(0);
335        let timestamp_system = SystemTime::UNIX_EPOCH
336            + std::time::Duration::from_secs(secs)
337            + std::time::Duration::from_nanos(u64::from(entry.timestamp.timestamp_subsec_nanos()));
338
339        let row = client
340            .query_one(
341                sql,
342                &[
343                    &timestamp_system,
344                    &entry.level.as_str(),
345                    &entry.user_id,
346                    &entry.tenant_id,
347                    &entry.operation,
348                    &entry.query,
349                    &variables_json,
350                    &entry.ip_address,
351                    &entry.user_agent,
352                    &entry.error,
353                    &entry.duration_ms,
354                ],
355            )
356            .await?;
357
358        let id: i64 = row.get(0);
359
360        // Fire-and-forget export to external sinks. Failures are logged but
361        // do not affect the primary PostgreSQL write path.
362        if !self.exporters.is_empty() {
363            for exporter in self.exporters.iter() {
364                if let Err(e) = exporter.export(&entry).await {
365                    error!(error = %e, "Audit exporter failed");
366                }
367            }
368        }
369
370        Ok(id)
371    }
372
373    /// Flush all export sinks.
374    ///
375    /// Call this during graceful shutdown to ensure buffered entries are delivered.
376    ///
377    /// # Errors
378    ///
379    /// Returns the first flush error encountered; remaining sinks are still flushed.
380    pub async fn flush_exporters(&self) -> Result<(), AuditError> {
381        let mut first_err = None;
382        for exporter in self.exporters.iter() {
383            if let Err(e) = exporter.flush().await {
384                error!(error = %e, "Audit exporter flush failed");
385                if first_err.is_none() {
386                    first_err = Some(e);
387                }
388            }
389        }
390        match first_err {
391            Some(e) => Err(e),
392            None => Ok(()),
393        }
394    }
395
396    /// Get recent logs for a tenant
397    ///
398    /// # Errors
399    ///
400    /// Returns error if database operation fails
401    pub async fn get_recent_logs(
402        &self,
403        tenant_id: i64,
404        level: Option<AuditLevel>,
405        limit: i64,
406    ) -> Result<Vec<AuditEntry>, AuditError> {
407        let client = self.pool.get().await?;
408
409        let rows = if let Some(lvl) = level {
410            let sql = r"
411                SELECT id, timestamp, level, user_id, tenant_id, operation,
412                       query, variables, ip_address, user_agent, error, duration_ms
413                FROM fraiseql_audit_logs
414                WHERE tenant_id = $1 AND level = $2
415                ORDER BY timestamp DESC
416                LIMIT $3
417            ";
418            client.query(sql, &[&tenant_id, &lvl.as_str(), &limit]).await?
419        } else {
420            let sql = r"
421                SELECT id, timestamp, level, user_id, tenant_id, operation,
422                       query, variables, ip_address, user_agent, error, duration_ms
423                FROM fraiseql_audit_logs
424                WHERE tenant_id = $1
425                ORDER BY timestamp DESC
426                LIMIT $2
427            ";
428            client.query(sql, &[&tenant_id, &limit]).await?
429        };
430
431        let entries: Vec<AuditEntry> = rows
432            .into_iter()
433            .map(|row| {
434                let id: Option<i64> = row.get(0);
435                let timestamp_system: SystemTime = row.get(1);
436                let level_str: String = row.get(2);
437                let user_id: i64 = row.get(3);
438                let tenant_id: i64 = row.get(4);
439                let operation: String = row.get(5);
440                let query: String = row.get(6);
441                let variables: serde_json::Value = row.get(7);
442                let ip_address: String = row.get(8);
443                let user_agent: String = row.get(9);
444                let error: Option<String> = row.get(10);
445                let duration_ms: Option<i32> = row.get(11);
446
447                // Convert SystemTime to DateTime<Utc>
448                let timestamp = DateTime::<Utc>::from(timestamp_system);
449
450                AuditEntry {
451                    id,
452                    timestamp,
453                    level: AuditLevel::parse(&level_str),
454                    user_id,
455                    tenant_id,
456                    operation,
457                    query,
458                    variables,
459                    ip_address,
460                    user_agent,
461                    error,
462                    duration_ms,
463                    previous_hash: None,
464                    integrity_hash: None,
465                }
466            })
467            .collect();
468
469        Ok(entries)
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline entry; individual tests override fields via struct update.
    fn sample_entry() -> AuditEntry {
        AuditEntry {
            id: Some(1),
            timestamp: Utc::now(),
            level: AuditLevel::INFO,
            user_id: 123,
            tenant_id: 456,
            operation: "query".to_string(),
            query: "{ users { id name } }".to_string(),
            variables: serde_json::json!({}),
            ip_address: "192.168.1.1".to_string(),
            user_agent: "Mozilla/5.0".to_string(),
            error: None,
            duration_ms: Some(100),
            previous_hash: None,
            integrity_hash: None,
        }
    }

    #[test]
    fn test_audit_entry_integrity_hash() {
        let hash = sample_entry().calculate_hash();
        assert!(!hash.is_empty());
        assert_eq!(hash.len(), 64); // SHA256 hex is 64 chars
    }

    #[test]
    fn test_audit_integrity_verification() {
        let mut entry = sample_entry();

        // Store the calculated hash; verification must then pass.
        entry.integrity_hash = Some(entry.calculate_hash());
        assert!(entry.verify_integrity());

        // Tampering with hashed data must be detected.
        entry.user_id = 999;
        assert!(!entry.verify_integrity());
    }

    #[test]
    fn test_audit_hash_chain() {
        let timestamp = Utc::now();

        let mut entry1 = AuditEntry {
            timestamp,
            query: "{ users { id } }".to_string(),
            ..sample_entry()
        };
        let hash1 = entry1.calculate_hash();
        entry1.integrity_hash = Some(hash1.clone());

        // Second entry links back to the first via previous_hash.
        let mut entry2 = AuditEntry {
            id: Some(2),
            timestamp,
            query: "{ posts { id } }".to_string(),
            duration_ms: Some(50),
            previous_hash: Some(hash1),
            ..sample_entry()
        };
        entry2.integrity_hash = Some(entry2.calculate_hash());

        // Both should verify.
        assert!(entry1.verify_integrity());
        assert!(entry2.verify_integrity());

        // Breaking the chain should be detected.
        entry1.user_id = 999;
        assert!(!entry1.verify_integrity());
    }

    #[test]
    fn test_audit_level_parsing() {
        assert_eq!(AuditLevel::parse("WARN"), AuditLevel::WARN);
        assert_eq!(AuditLevel::parse("ERROR"), AuditLevel::ERROR);
        assert_eq!(AuditLevel::parse("INFO"), AuditLevel::INFO);
        // Unknown strings fall back to INFO.
        assert_eq!(AuditLevel::parse("UNKNOWN"), AuditLevel::INFO);
    }

    #[test]
    fn test_audit_export_config_deserialization() {
        let json = r#"{
            "syslog": { "address": "syslog.internal", "port": 514, "protocol": "tcp" },
            "webhook": { "url": "https://logs.example.com/ingest" }
        }"#;
        let config: AuditExportConfig =
            serde_json::from_str(json).expect("should deserialize AuditExportConfig");

        let syslog = config.syslog.expect("syslog should be Some");
        assert_eq!(syslog.address, "syslog.internal");
        assert_eq!(syslog.port, 514);
        assert_eq!(syslog.protocol, "tcp");

        // Omitted webhook fields must take their serde defaults.
        let webhook = config.webhook.expect("webhook should be Some");
        assert_eq!(webhook.url, "https://logs.example.com/ingest");
        assert_eq!(webhook.batch_size, 100);
        assert_eq!(webhook.flush_interval_secs, 30);
    }

    #[test]
    fn test_audit_export_config_empty() {
        let config: AuditExportConfig =
            serde_json::from_str("{}").expect("should deserialize empty config");
        assert!(config.syslog.is_none());
        assert!(config.webhook.is_none());
    }

    #[test]
    fn test_audit_error_export_variant() {
        let err = AuditError::Export("connection refused".to_string());
        assert!(err.to_string().contains("connection refused"));
    }
}