Skip to main content

allowthem_core/
audit.rs

1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4
5use crate::db::Db;
6use crate::error::AuthError;
7use crate::types::{AuditEntryId, UserId};
8
/// Every type of authentication event that can be recorded.
///
/// In the database each variant is stored as snake_case text (via the
/// `#[sqlx(rename_all = "snake_case")]` attribute below, e.g. `LoginFailed`
/// is stored as `login_failed`). Serde has no matching rename attribute, so
/// JSON output uses the Rust variant name as-is (e.g. `"LoginFailed"`).
/// `event_to_slug` mirrors the sqlx representation for dynamic SQL.
///
/// NOTE(review): serde's derived enum encoding relies on variant order for
/// non-self-describing formats, so append new variants rather than reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[sqlx(rename_all = "snake_case")]
pub enum AuditEvent {
    // Sign-in / account lifecycle.
    Login,
    LoginFailed,
    Logout,
    Register,
    PasswordChange,
    PasswordReset,
    // Role and permission grants.
    RoleAssigned,
    RoleUnassigned,
    PermissionAssigned,
    PermissionUnassigned,
    // Sessions and user records.
    SessionCreated,
    SessionExpired,
    UserUpdated,
    UserDeleted,
    // Multi-factor authentication.
    MfaEnabled,
    MfaDisabled,
    MfaChallengeSuccess,
    MfaChallengeFailed,
    // Organisations and their membership.
    OrgCreated,
    OrgUpdated,
    OrgDeleted,
    OrgMemberAdded,
    OrgMemberRemoved,
    OrgMemberRoleChanged,
    OrgOwnershipTransferred,
    // Teams and their membership.
    TeamCreated,
    TeamUpdated,
    TeamDeleted,
    TeamMemberAdded,
    TeamMemberRemoved,
    TeamMemberRoleChanged,
    // Organisation invitations.
    OrgInvitationCreated,
    OrgInvitationAccepted,
    OrgInvitationDeclined,
    OrgInvitationRevoked,
}
49
/// A single record in the audit log.
///
/// Row type for queries that select the `allowthem_audit_log` columns
/// directly, without joining user data.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct AuditEntry {
    /// Unique id of this log row (generated by `AuditEntryId::new()` on insert).
    pub id: AuditEntryId,
    /// What happened.
    pub event_type: AuditEvent,
    /// The user the event is recorded against; `None` when no authenticated
    /// user was involved (e.g. a failed login for an unknown email).
    pub user_id: Option<UserId>,
    /// Optional identifier of the object acted upon.
    /// NOTE(review): meaning varies by event type — confirm against callers.
    pub target_id: Option<String>,
    /// Client IP address as supplied by the caller, if any.
    pub ip_address: Option<String>,
    /// Client User-Agent as supplied by the caller, if any.
    pub user_agent: Option<String>,
    /// Optional free-form detail text supplied by the caller.
    pub detail: Option<String>,
    /// When the row was recorded.
    /// NOTE(review): not bound by `log_audit`, so presumably a column default.
    pub created_at: DateTime<Utc>,
}
62
63/// Map an `AuditEvent` variant to its snake_case database value.
64///
65/// `AuditEvent` has `#[sqlx(rename_all = "snake_case")]` but no matching
66/// serde attribute. This function provides the canonical snake_case string
67/// for use in dynamic SQL bind values.
68fn event_to_slug(event: &AuditEvent) -> &'static str {
69    match event {
70        AuditEvent::Login => "login",
71        AuditEvent::LoginFailed => "login_failed",
72        AuditEvent::Logout => "logout",
73        AuditEvent::Register => "register",
74        AuditEvent::PasswordChange => "password_change",
75        AuditEvent::PasswordReset => "password_reset",
76        AuditEvent::RoleAssigned => "role_assigned",
77        AuditEvent::RoleUnassigned => "role_unassigned",
78        AuditEvent::PermissionAssigned => "permission_assigned",
79        AuditEvent::PermissionUnassigned => "permission_unassigned",
80        AuditEvent::SessionCreated => "session_created",
81        AuditEvent::SessionExpired => "session_expired",
82        AuditEvent::UserUpdated => "user_updated",
83        AuditEvent::UserDeleted => "user_deleted",
84        AuditEvent::MfaEnabled => "mfa_enabled",
85        AuditEvent::MfaDisabled => "mfa_disabled",
86        AuditEvent::MfaChallengeSuccess => "mfa_challenge_success",
87        AuditEvent::MfaChallengeFailed => "mfa_challenge_failed",
88        AuditEvent::OrgCreated => "org_created",
89        AuditEvent::OrgUpdated => "org_updated",
90        AuditEvent::OrgDeleted => "org_deleted",
91        AuditEvent::OrgMemberAdded => "org_member_added",
92        AuditEvent::OrgMemberRemoved => "org_member_removed",
93        AuditEvent::OrgMemberRoleChanged => "org_member_role_changed",
94        AuditEvent::OrgOwnershipTransferred => "org_ownership_transferred",
95        AuditEvent::TeamCreated => "team_created",
96        AuditEvent::TeamUpdated => "team_updated",
97        AuditEvent::TeamDeleted => "team_deleted",
98        AuditEvent::TeamMemberAdded => "team_member_added",
99        AuditEvent::TeamMemberRemoved => "team_member_removed",
100        AuditEvent::TeamMemberRoleChanged => "team_member_role_changed",
101        AuditEvent::OrgInvitationCreated => "org_invitation_created",
102        AuditEvent::OrgInvitationAccepted => "org_invitation_accepted",
103        AuditEvent::OrgInvitationDeclined => "org_invitation_declined",
104        AuditEvent::OrgInvitationRevoked => "org_invitation_revoked",
105    }
106}
107
108/// Parameters for searching/filtering audit log entries.
109pub struct SearchAuditParams<'a> {
110    pub user_id: Option<UserId>,
111    pub event_type: Option<&'a AuditEvent>,
112    pub is_success: Option<bool>,
113    pub from: Option<DateTime<Utc>>,
114    pub to: Option<DateTime<Utc>>,
115    pub limit: u32,
116    pub offset: u32,
117}
118
/// An audit log entry with the user's email resolved via LEFT JOIN.
/// Used for admin list display — avoids showing raw UUIDs.
///
/// Same columns as `AuditEntry` plus `user_email`.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct AuditListEntry {
    /// Unique id of this log row.
    pub id: AuditEntryId,
    /// What happened.
    pub event_type: AuditEvent,
    /// The user the event is recorded against, if any.
    pub user_id: Option<UserId>,
    /// `allowthem_users.email` for `user_id`; `None` when `user_id` is `None`
    /// or no matching user row exists (LEFT JOIN miss).
    pub user_email: Option<String>,
    /// Optional identifier of the object acted upon.
    /// NOTE(review): meaning varies by event type — confirm against callers.
    pub target_id: Option<String>,
    /// Client IP address as supplied by the caller, if any.
    pub ip_address: Option<String>,
    /// Client User-Agent as supplied by the caller, if any.
    pub user_agent: Option<String>,
    /// Optional free-form detail text supplied by the caller.
    pub detail: Option<String>,
    /// When the row was recorded.
    pub created_at: DateTime<Utc>,
}
133
134/// Result of a paginated audit log search.
135pub struct SearchAuditResult {
136    pub entries: Vec<AuditListEntry>,
137    pub total: u32,
138}
139
140/// Opaque keyset cursor for paginating `list_audit_paginated`.
141///
142/// Encodes `(created_at, id)` as a base64url-encoded JSON blob.
143pub struct AuditCursor {
144    pub created_at: DateTime<Utc>,
145    pub id: AuditEntryId,
146}
147
/// Wire form of `AuditCursor`: the JSON object that is base64url-encoded.
///
/// Field names are deliberately short because they become the JSON keys
/// inside every encoded cursor; renaming them invalidates outstanding cursors.
#[derive(Serialize, Deserialize)]
struct RawAuditCursor {
    /// `created_at` rendered with `DateTime::to_rfc3339`.
    ca: String,
    /// The entry id rendered with its `Display` impl; parsed back as a UUID.
    id: String,
}
153
154impl AuditCursor {
155    pub fn from_entry(entry: &AuditListEntry) -> Self {
156        Self {
157            created_at: entry.created_at,
158            id: entry.id,
159        }
160    }
161
162    pub fn encode(&self) -> String {
163        let raw = RawAuditCursor {
164            ca: self.created_at.to_rfc3339(),
165            id: self.id.to_string(),
166        };
167        let json = serde_json::to_string(&raw).expect("RawAuditCursor serializes");
168        Base64UrlUnpadded::encode_string(json.as_bytes())
169    }
170
171    pub fn decode(s: &str) -> Option<Self> {
172        let bytes = Base64UrlUnpadded::decode_vec(s).ok()?;
173        let raw: RawAuditCursor = serde_json::from_slice(&bytes).ok()?;
174        let created_at = chrono::DateTime::parse_from_rfc3339(&raw.ca)
175            .ok()?
176            .with_timezone(&Utc);
177        let id = raw
178            .id
179            .parse::<uuid::Uuid>()
180            .ok()
181            .map(AuditEntryId::from_uuid)?;
182        Some(Self { created_at, id })
183    }
184}
185
186impl Db {
187    /// Record an audit event.
188    ///
189    /// `user_id` may be `None` for events where no authenticated user is
190    /// involved (e.g. a failed login attempt against an unknown email).
191    pub async fn log_audit(
192        &self,
193        event_type: AuditEvent,
194        user_id: Option<&UserId>,
195        target_id: Option<&str>,
196        ip_address: Option<&str>,
197        user_agent: Option<&str>,
198        detail: Option<&str>,
199    ) -> Result<(), AuthError> {
200        let id = AuditEntryId::new();
201        sqlx::query(
202            "INSERT INTO allowthem_audit_log
203             (id, event_type, user_id, target_id, ip_address, user_agent, detail)
204             VALUES (?, ?, ?, ?, ?, ?, ?)",
205        )
206        .bind(id)
207        .bind(event_type)
208        .bind(user_id.copied())
209        .bind(target_id)
210        .bind(ip_address)
211        .bind(user_agent)
212        .bind(detail)
213        .execute(self.pool())
214        .await
215        .map_err(AuthError::Database)?;
216        Ok(())
217    }
218
219    /// Retrieve audit log entries, optionally filtered by user.
220    ///
221    /// Results are ordered by `created_at` descending (newest first).
222    pub async fn get_audit_log(
223        &self,
224        user_id: Option<&UserId>,
225        limit: u32,
226        offset: u32,
227    ) -> Result<Vec<AuditEntry>, AuthError> {
228        match user_id {
229            Some(uid) => {
230                sqlx::query_as::<_, AuditEntry>(
231                    "SELECT id, event_type, user_id, target_id, ip_address, user_agent, detail, created_at
232                     FROM allowthem_audit_log
233                     WHERE user_id = ?
234                     ORDER BY created_at DESC
235                     LIMIT ? OFFSET ?",
236                )
237                .bind(*uid)
238                .bind(limit)
239                .bind(offset)
240                .fetch_all(self.pool())
241                .await
242                .map_err(AuthError::Database)
243            }
244            None => {
245                sqlx::query_as::<_, AuditEntry>(
246                    "SELECT id, event_type, user_id, target_id, ip_address, user_agent, detail, created_at
247                     FROM allowthem_audit_log
248                     ORDER BY created_at DESC
249                     LIMIT ? OFFSET ?",
250                )
251                .bind(limit)
252                .bind(offset)
253                .fetch_all(self.pool())
254                .await
255                .map_err(AuthError::Database)
256            }
257        }
258    }
259
260    /// Retrieve audit log entries filtered by event type.
261    ///
262    /// Results are ordered by `created_at` descending (newest first).
263    pub async fn get_audit_log_by_event(
264        &self,
265        event_type: AuditEvent,
266        limit: u32,
267        offset: u32,
268    ) -> Result<Vec<AuditEntry>, AuthError> {
269        sqlx::query_as::<_, AuditEntry>(
270            "SELECT id, event_type, user_id, target_id, ip_address, user_agent, detail, created_at
271             FROM allowthem_audit_log
272             WHERE event_type = ?
273             ORDER BY created_at DESC
274             LIMIT ? OFFSET ?",
275        )
276        .bind(event_type)
277        .bind(limit)
278        .bind(offset)
279        .fetch_all(self.pool())
280        .await
281        .map_err(AuthError::Database)
282    }
283
284    /// Get the most recent login timestamp for a user, if any.
285    ///
286    /// Returns `None` if the user has never logged in (no audit entry
287    /// with event_type = 'login' for this user_id).
288    pub async fn last_login_at(&self, user_id: UserId) -> Result<Option<DateTime<Utc>>, AuthError> {
289        sqlx::query_scalar(
290            "SELECT MAX(created_at) FROM allowthem_audit_log \
291             WHERE user_id = ? AND event_type = 'login'",
292        )
293        .bind(user_id)
294        .fetch_one(self.pool())
295        .await
296        .map_err(AuthError::Database)
297    }
298
299    /// Paginated list of audit entries using a `(created_at, id)` keyset cursor.
300    ///
301    /// Ordered newest-first. Pass `None` for cursor to start from the beginning.
302    /// Limits are capped at 200.
303    pub async fn list_audit_paginated(
304        &self,
305        limit: u32,
306        cursor: Option<&AuditCursor>,
307    ) -> Result<Vec<AuditListEntry>, AuthError> {
308        let limit = (limit as i64).min(200);
309        match cursor {
310            None => sqlx::query_as::<_, AuditListEntry>(
311                "SELECT a.id, a.event_type, a.user_id, u.email AS user_email, \
312                 a.target_id, a.ip_address, a.user_agent, a.detail, a.created_at \
313                 FROM allowthem_audit_log a \
314                 LEFT JOIN allowthem_users u ON a.user_id = u.id \
315                 ORDER BY a.created_at DESC, a.id DESC \
316                 LIMIT ?",
317            )
318            .bind(limit)
319            .fetch_all(self.pool())
320            .await
321            .map_err(AuthError::Database),
322            Some(c) => {
323                let ca = c.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
324                sqlx::query_as::<_, AuditListEntry>(
325                    "SELECT a.id, a.event_type, a.user_id, u.email AS user_email, \
326                     a.target_id, a.ip_address, a.user_agent, a.detail, a.created_at \
327                     FROM allowthem_audit_log a \
328                     LEFT JOIN allowthem_users u ON a.user_id = u.id \
329                     WHERE (a.created_at < ?1 OR (a.created_at = ?1 AND a.id < ?2)) \
330                     ORDER BY a.created_at DESC, a.id DESC \
331                     LIMIT ?3",
332                )
333                .bind(&ca)
334                .bind(c.id)
335                .bind(limit)
336                .fetch_all(self.pool())
337                .await
338                .map_err(AuthError::Database)
339            }
340        }
341    }
342
343    /// Search and filter audit log entries with pagination.
344    ///
345    /// Builds a dynamic query with optional filters for user, event type,
346    /// outcome, and date range. LEFT JOINs `allowthem_users` for email
347    /// resolution. Follows the same dynamic-SQL pattern as `search_users`.
348    pub async fn search_audit_log(
349        &self,
350        params: SearchAuditParams<'_>,
351    ) -> Result<SearchAuditResult, AuthError> {
352        // Build WHERE clauses and bind values. user_id is bound separately
353        // because it is a UUID (not a String), and sqlx needs the correct
354        // type for TEXT column comparison in SQLite.
355        let mut where_clauses: Vec<String> = Vec::new();
356        let mut string_binds: Vec<String> = Vec::new();
357
358        if params.user_id.is_some() {
359            where_clauses.push("a.user_id = ?".into());
360            // Bound separately below — position tracked by clause order
361        }
362
363        if let Some(event) = params.event_type {
364            where_clauses.push("a.event_type = ?".into());
365            string_binds.push(event_to_slug(event).to_string());
366        }
367
368        match params.is_success {
369            Some(true) => {
370                where_clauses.push("a.event_type != 'login_failed'".into());
371            }
372            Some(false) => {
373                where_clauses.push("a.event_type = 'login_failed'".into());
374            }
375            None => {}
376        }
377
378        if let Some(from) = params.from {
379            where_clauses.push("a.created_at >= ?".into());
380            string_binds.push(from.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string());
381        }
382
383        if let Some(to) = params.to {
384            where_clauses.push("a.created_at < ?".into());
385            string_binds.push(to.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string());
386        }
387
388        let where_sql = if where_clauses.is_empty() {
389            String::new()
390        } else {
391            format!("WHERE {}", where_clauses.join(" AND "))
392        };
393
394        // Count query
395        let count_sql: &'static str = Box::leak(
396            format!("SELECT COUNT(*) FROM allowthem_audit_log a {where_sql}").into_boxed_str(),
397        );
398        let mut count_query = sqlx::query_scalar::<_, i64>(count_sql);
399        // Bind user_id first (it's the first WHERE clause if present)
400        if let Some(uid) = params.user_id {
401            count_query = count_query.bind(uid);
402        }
403        for val in &string_binds {
404            count_query = count_query.bind(val);
405        }
406        let total = count_query
407            .fetch_one(self.pool())
408            .await
409            .map_err(AuthError::Database)? as u32;
410
411        // Data query with LEFT JOIN for user email
412        let data_sql: &'static str = Box::leak(
413            format!(
414                "SELECT a.id, a.event_type, a.user_id, u.email AS user_email, \
415                 a.target_id, a.ip_address, a.user_agent, a.detail, a.created_at \
416                 FROM allowthem_audit_log a \
417                 LEFT JOIN allowthem_users u ON a.user_id = u.id \
418                 {where_sql} \
419                 ORDER BY a.created_at DESC \
420                 LIMIT ? OFFSET ?"
421            )
422            .into_boxed_str(),
423        );
424        let mut data_query = sqlx::query_as::<_, AuditListEntry>(data_sql);
425        if let Some(uid) = params.user_id {
426            data_query = data_query.bind(uid);
427        }
428        for val in &string_binds {
429            data_query = data_query.bind(val);
430        }
431        data_query = data_query.bind(params.limit).bind(params.offset);
432
433        let entries = data_query
434            .fetch_all(self.pool())
435            .await
436            .map_err(AuthError::Database)?;
437
438        Ok(SearchAuditResult { entries, total })
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::{AllowThem, AllowThemBuilder};

    async fn setup() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap()
    }

    /// Log a `Login` event with no user and a per-call `target_id`.
    async fn log_event(db: &Db, tag: u32) {
        db.log_audit(
            AuditEvent::Login,
            None,
            Some(&format!("target-{tag}")),
            None,
            None,
            None,
        )
        .await
        .unwrap();
    }

    /// Log a `LoginFailed` event with no user.
    async fn log_failed_login(db: &Db) {
        db.log_audit(
            AuditEvent::LoginFailed,
            None,
            Some("target-failed"),
            None,
            None,
            None,
        )
        .await
        .unwrap();
    }

    /// Search parameters with no user/date filters and a generous page size.
    fn search_params(
        event_type: Option<&AuditEvent>,
        is_success: Option<bool>,
    ) -> SearchAuditParams<'_> {
        SearchAuditParams {
            user_id: None,
            event_type,
            is_success,
            from: None,
            to: None,
            limit: 50,
            offset: 0,
        }
    }

    #[tokio::test]
    async fn audit_cursor_encode_decode_roundtrip() {
        let ath = setup().await;
        let db = ath.db();
        log_event(db, 1).await;
        let entries = db.list_audit_paginated(10, None).await.unwrap();
        assert_eq!(entries.len(), 1);
        let cursor = AuditCursor::from_entry(&entries[0]);
        let encoded = cursor.encode();
        let decoded = AuditCursor::decode(&encoded).unwrap();
        assert_eq!(decoded.id, entries[0].id);
        // RFC 3339 keeps full sub-second precision, so the timestamp must
        // survive the roundtrip exactly.
        assert_eq!(decoded.created_at, entries[0].created_at);
    }

    #[test]
    fn audit_cursor_decode_rejects_garbage() {
        assert!(AuditCursor::decode("").is_none());
        assert!(AuditCursor::decode("not a cursor!").is_none());
        // Valid base64url, but the JSON lacks the expected fields.
        assert!(AuditCursor::decode(&Base64UrlUnpadded::encode_string(b"{}")).is_none());
    }

    #[tokio::test]
    async fn list_audit_paginated_returns_first_page() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            log_event(db, i).await;
        }
        let page = db.list_audit_paginated(3, None).await.unwrap();
        assert_eq!(page.len(), 3);
    }

    #[tokio::test]
    async fn list_audit_paginated_cursor_advances() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            log_event(db, i + 10).await;
        }
        let page1 = db.list_audit_paginated(3, None).await.unwrap();
        assert_eq!(page1.len(), 3);
        let cursor = AuditCursor::from_entry(page1.last().unwrap());
        let page2 = db.list_audit_paginated(3, Some(&cursor)).await.unwrap();
        assert_eq!(page2.len(), 2);
        assert!(!page2.iter().any(|e| page1.iter().any(|f| f.id == e.id)));
    }

    #[tokio::test]
    async fn search_audit_log_filters_by_event_type() {
        let ath = setup().await;
        let db = ath.db();
        log_event(db, 1).await;
        log_event(db, 2).await;
        log_failed_login(db).await;

        let failed = db
            .search_audit_log(search_params(Some(&AuditEvent::LoginFailed), None))
            .await
            .unwrap();
        assert_eq!(failed.total, 1);
        assert_eq!(failed.entries.len(), 1);
        assert_eq!(failed.entries[0].event_type, AuditEvent::LoginFailed);

        let logins = db
            .search_audit_log(search_params(Some(&AuditEvent::Login), None))
            .await
            .unwrap();
        assert_eq!(logins.total, 2);
        assert_eq!(logins.entries.len(), 2);
    }

    #[tokio::test]
    async fn search_audit_log_is_success_splits_outcomes() {
        let ath = setup().await;
        let db = ath.db();
        log_event(db, 1).await;
        log_failed_login(db).await;

        let ok = db.search_audit_log(search_params(None, Some(true))).await.unwrap();
        assert_eq!(ok.total, 1);
        assert_eq!(ok.entries[0].event_type, AuditEvent::Login);

        let bad = db.search_audit_log(search_params(None, Some(false))).await.unwrap();
        assert_eq!(bad.total, 1);
        assert_eq!(bad.entries[0].event_type, AuditEvent::LoginFailed);
    }

    #[tokio::test]
    async fn search_audit_log_total_ignores_pagination() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            log_event(db, i + 20).await;
        }
        let page = db
            .search_audit_log(SearchAuditParams {
                limit: 2,
                offset: 1,
                ..search_params(None, None)
            })
            .await
            .unwrap();
        assert_eq!(page.total, 5);
        assert_eq!(page.entries.len(), 2);
    }
}