Skip to main content

fraiseql_core/security/
audit.rs

1//! Audit logging for GraphQL operations
2//!
3//! Uses `PostgreSQL` `deadpool` for database operations
4
5use std::{sync::Arc, time::SystemTime};
6
7use chrono::{DateTime, Utc};
8use deadpool_postgres::Pool;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11
12/// Errors that can occur during audit operations.
13#[derive(Debug, thiserror::Error)]
14pub enum AuditError {
15    /// Database operation failed.
16    #[error("Database operation failed: {0}")]
17    Database(#[from] deadpool_postgres::PoolError),
18
19    /// SQL query execution failed.
20    #[error("SQL query failed: {0}")]
21    Sql(#[from] tokio_postgres::Error),
22
23    /// Failed to serialize data to JSON.
24    #[error("JSON serialization failed: {0}")]
25    Serialization(#[from] serde_json::Error),
26}
27
28/// Audit log levels
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30pub enum AuditLevel {
31    /// Informational messages
32    INFO,
33    /// Warnings
34    WARN,
35    /// Errors
36    ERROR,
37}
38
39impl AuditLevel {
40    /// Convert to string for database storage
41    #[must_use]
42    pub const fn as_str(&self) -> &'static str {
43        match self {
44            Self::INFO => "INFO",
45            Self::WARN => "WARN",
46            Self::ERROR => "ERROR",
47        }
48    }
49
50    /// Parse from string
51    #[must_use]
52    pub fn parse(s: &str) -> Self {
53        match s {
54            "WARN" => Self::WARN,
55            "ERROR" => Self::ERROR,
56            _ => Self::INFO, // Default to INFO for unknown strings
57        }
58    }
59}
60
61/// Audit log entry with integrity protection
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct AuditEntry {
64    /// Entry ID (None for new entries)
65    pub id:             Option<i64>,
66    /// Timestamp
67    pub timestamp:      DateTime<Utc>,
68    /// Log level
69    pub level:          AuditLevel,
70    /// User ID
71    pub user_id:        i64,
72    /// Tenant ID
73    pub tenant_id:      i64,
74    /// Operation type (query, mutation)
75    pub operation:      String,
76    /// GraphQL query string
77    pub query:          String,
78    /// Query variables (JSONB)
79    pub variables:      serde_json::Value,
80    /// Client IP address
81    pub ip_address:     String,
82    /// Client user agent
83    pub user_agent:     String,
84    /// Error message (if any)
85    pub error:          Option<String>,
86    /// Query duration in milliseconds (optional)
87    pub duration_ms:    Option<i32>,
88    /// SHA256 hash of previous entry (for integrity chain)
89    pub previous_hash:  Option<String>,
90    /// SHA256 hash of this entry (for integrity verification)
91    pub integrity_hash: Option<String>,
92}
93
94impl AuditEntry {
95    /// Calculate SHA256 hash for this entry (for integrity chain)
96    ///
97    /// Hashes: user_id | timestamp | operation | query to create a tamper-proof chain
98    #[must_use]
99    pub fn calculate_hash(&self) -> String {
100        let mut hasher = Sha256::new();
101
102        // Include all mutable fields in hash for tamper detection
103        hasher.update(self.user_id.to_string().as_bytes());
104        hasher.update(self.timestamp.to_rfc3339().as_bytes());
105        hasher.update(self.operation.as_bytes());
106        hasher.update(self.query.as_bytes());
107        hasher.update(self.level.as_str().as_bytes());
108
109        // Include previous hash if present (hash chain)
110        if let Some(ref prev) = self.previous_hash {
111            hasher.update(prev.as_bytes());
112        }
113
114        format!("{:x}", hasher.finalize())
115    }
116
117    /// Verify integrity of this entry against its stored hash
118    #[must_use]
119    pub fn verify_integrity(&self) -> bool {
120        if let Some(ref stored_hash) = self.integrity_hash {
121            let calculated = self.calculate_hash();
122            // Constant-time comparison to prevent timing attacks
123            constant_time_eq(stored_hash.as_bytes(), calculated.as_bytes())
124        } else {
125            false
126        }
127    }
128}
129
/// Compare two byte slices for equality, intended to take the same time
/// regardless of where (or whether) the contents differ.
///
/// Slices of different lengths compare unequal immediately, so only the
/// contents, not the lengths, are shielded from timing analysis.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together the XOR of every byte pair; any mismatch leaves a bit set.
        let diff = a.iter().zip(b).fold(0u8, |acc, (&x, &y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
143
144/// Statistics about audit events
145#[derive(Debug, Clone, Default, Serialize, Deserialize)]
146pub struct AuditStats {
147    /// Total number of audit events recorded
148    pub total_events:  u64,
149    /// Number of recent events (last 24 hours or recent window)
150    pub recent_events: u64,
151}
152
153/// Audit logger with `PostgreSQL` backend
154#[derive(Clone, Debug)]
155pub struct AuditLogger {
156    pool: Arc<Pool>,
157}
158
159impl AuditLogger {
160    /// Create a new audit logger
161    #[must_use]
162    pub const fn new(pool: Arc<Pool>) -> Self {
163        Self { pool }
164    }
165
166    /// Log an audit entry
167    ///
168    /// # Errors
169    ///
170    /// Returns error if database operation fails
171    pub async fn log(&self, entry: AuditEntry) -> Result<i64, AuditError> {
172        let sql = r"
173            INSERT INTO fraiseql_audit_logs (
174                timestamp,
175                level,
176                user_id,
177                tenant_id,
178                operation,
179                query,
180                variables,
181                ip_address,
182                user_agent,
183                error,
184                duration_ms
185            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
186            RETURNING id
187        ";
188
189        let client = self.pool.get().await?;
190        let variables_json = serde_json::to_value(&entry.variables)?;
191
192        // Convert DateTime<Utc> to SystemTime for PostgreSQL
193        let timestamp_system = SystemTime::UNIX_EPOCH
194            + std::time::Duration::from_secs(entry.timestamp.timestamp() as u64)
195            + std::time::Duration::from_nanos(u64::from(entry.timestamp.timestamp_subsec_nanos()));
196
197        let row = client
198            .query_one(
199                sql,
200                &[
201                    &timestamp_system,
202                    &entry.level.as_str(),
203                    &entry.user_id,
204                    &entry.tenant_id,
205                    &entry.operation,
206                    &entry.query,
207                    &variables_json,
208                    &entry.ip_address,
209                    &entry.user_agent,
210                    &entry.error,
211                    &entry.duration_ms,
212                ],
213            )
214            .await?;
215
216        let id: i64 = row.get(0);
217        Ok(id)
218    }
219
220    /// Get recent logs for a tenant
221    ///
222    /// # Errors
223    ///
224    /// Returns error if database operation fails
225    pub async fn get_recent_logs(
226        &self,
227        tenant_id: i64,
228        level: Option<AuditLevel>,
229        limit: i64,
230    ) -> Result<Vec<AuditEntry>, AuditError> {
231        let client = self.pool.get().await?;
232
233        let rows = if let Some(lvl) = level {
234            let sql = r"
235                SELECT id, timestamp, level, user_id, tenant_id, operation,
236                       query, variables, ip_address, user_agent, error, duration_ms
237                FROM fraiseql_audit_logs
238                WHERE tenant_id = $1 AND level = $2
239                ORDER BY timestamp DESC
240                LIMIT $3
241            ";
242            client.query(sql, &[&tenant_id, &lvl.as_str(), &limit]).await?
243        } else {
244            let sql = r"
245                SELECT id, timestamp, level, user_id, tenant_id, operation,
246                       query, variables, ip_address, user_agent, error, duration_ms
247                FROM fraiseql_audit_logs
248                WHERE tenant_id = $1
249                ORDER BY timestamp DESC
250                LIMIT $2
251            ";
252            client.query(sql, &[&tenant_id, &limit]).await?
253        };
254
255        let entries: Vec<AuditEntry> = rows
256            .into_iter()
257            .map(|row| {
258                let id: Option<i64> = row.get(0);
259                let timestamp_system: SystemTime = row.get(1);
260                let level_str: String = row.get(2);
261                let user_id: i64 = row.get(3);
262                let tenant_id: i64 = row.get(4);
263                let operation: String = row.get(5);
264                let query: String = row.get(6);
265                let variables: serde_json::Value = row.get(7);
266                let ip_address: String = row.get(8);
267                let user_agent: String = row.get(9);
268                let error: Option<String> = row.get(10);
269                let duration_ms: Option<i32> = row.get(11);
270
271                // Convert SystemTime to DateTime<Utc>
272                let timestamp = DateTime::<Utc>::from(timestamp_system);
273
274                AuditEntry {
275                    id,
276                    timestamp,
277                    level: AuditLevel::parse(&level_str),
278                    user_id,
279                    tenant_id,
280                    operation,
281                    query,
282                    variables,
283                    ip_address,
284                    user_agent,
285                    error,
286                    duration_ms,
287                    previous_hash: None,
288                    integrity_hash: None,
289                }
290            })
291            .collect();
292
293        Ok(entries)
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a representative entry; individual tests override fields as needed.
    fn sample_entry(query: &str, duration_ms: i32) -> AuditEntry {
        AuditEntry {
            id:             Some(1),
            timestamp:      Utc::now(),
            level:          AuditLevel::INFO,
            user_id:        123,
            tenant_id:      456,
            operation:      "query".to_string(),
            query:          query.to_string(),
            variables:      serde_json::json!({}),
            ip_address:     "192.168.1.1".to_string(),
            user_agent:     "Mozilla/5.0".to_string(),
            error:          None,
            duration_ms:    Some(duration_ms),
            previous_hash:  None,
            integrity_hash: None,
        }
    }

    #[test]
    fn test_audit_entry_integrity_hash() {
        let hash = sample_entry("{ users { id name } }", 100).calculate_hash();
        assert_eq!(hash.len(), 64); // SHA-256 hex digest is 64 chars
        assert!(hash.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_audit_hash_is_deterministic() {
        let entry = sample_entry("{ users { id } }", 100);
        assert_eq!(entry.calculate_hash(), entry.calculate_hash());
    }

    #[test]
    fn test_audit_integrity_verification() {
        let mut entry = sample_entry("{ users { id name } }", 100);

        // Without a stored hash there is nothing to verify against.
        assert!(!entry.verify_integrity());

        entry.integrity_hash = Some(entry.calculate_hash());
        assert!(entry.verify_integrity());

        // Tamper with data
        entry.user_id = 999;
        assert!(!entry.verify_integrity());
    }

    #[test]
    fn test_audit_hash_chain() {
        let timestamp = Utc::now();

        let mut entry1 = AuditEntry {
            timestamp,
            ..sample_entry("{ users { id } }", 100)
        };
        let hash1 = entry1.calculate_hash();
        entry1.integrity_hash = Some(hash1.clone());

        // Second entry linked to the first through `previous_hash`.
        let mut entry2 = AuditEntry {
            id: Some(2),
            timestamp,
            previous_hash: Some(hash1),
            ..sample_entry("{ posts { id } }", 50)
        };
        entry2.integrity_hash = Some(entry2.calculate_hash());

        assert!(entry1.verify_integrity());
        assert!(entry2.verify_integrity());

        // The link to the predecessor is part of the successor's hash.
        let unlinked = AuditEntry {
            previous_hash: None,
            ..entry2.clone()
        };
        assert_ne!(entry2.calculate_hash(), unlinked.calculate_hash());

        // Tampering with the predecessor is detected on the entry itself...
        entry1.user_id = 999;
        assert!(!entry1.verify_integrity());
        // ...and it no longer matches the link stored in its successor.
        assert_ne!(Some(entry1.calculate_hash()), entry2.previous_hash);
    }

    #[test]
    fn test_audit_level_parsing() {
        assert_eq!(AuditLevel::parse("WARN"), AuditLevel::WARN);
        assert_eq!(AuditLevel::parse("ERROR"), AuditLevel::ERROR);
        assert_eq!(AuditLevel::parse("INFO"), AuditLevel::INFO);
        assert_eq!(AuditLevel::parse("UNKNOWN"), AuditLevel::INFO);
    }

    #[test]
    fn test_audit_level_round_trip() {
        for level in [AuditLevel::INFO, AuditLevel::WARN, AuditLevel::ERROR] {
            assert_eq!(AuditLevel::parse(level.as_str()), level);
        }
    }

    #[test]
    fn test_constant_time_eq() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"ab"));
    }
}