1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4
5use crate::db::Db;
6use crate::error::AuthError;
7use crate::types::{AuditEntryId, UserId};
8
/// Kind of event recorded in `allowthem_audit_log`.
///
/// In the database each variant is stored as snake_case text via
/// `#[sqlx(rename_all = "snake_case")]` (e.g. `LoginFailed` <-> `login_failed`).
/// [`event_to_slug`] spells out the same mapping by hand for dynamically built
/// queries, so the two must stay in sync when variants are added or renamed.
///
/// NOTE(review): there is no serde rename attribute, so JSON (de)serialization
/// uses the Rust variant names as-is (e.g. `"LoginFailed"`), unlike the DB form.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[sqlx(rename_all = "snake_case")]
pub enum AuditEvent {
    // Sign-in / account lifecycle.
    Login,
    LoginFailed,
    Logout,
    Register,
    PasswordChange,
    PasswordReset,
    // Role and permission assignment.
    RoleAssigned,
    RoleUnassigned,
    PermissionAssigned,
    PermissionUnassigned,
    // Sessions and user records.
    SessionCreated,
    SessionExpired,
    UserUpdated,
    UserDeleted,
    // Multi-factor authentication.
    MfaEnabled,
    MfaDisabled,
    MfaChallengeSuccess,
    MfaChallengeFailed,
    // Organizations and their membership.
    OrgCreated,
    OrgUpdated,
    OrgDeleted,
    OrgMemberAdded,
    OrgMemberRemoved,
    OrgMemberRoleChanged,
    OrgOwnershipTransferred,
    // Teams and their membership.
    TeamCreated,
    TeamUpdated,
    TeamDeleted,
    TeamMemberAdded,
    TeamMemberRemoved,
    TeamMemberRoleChanged,
    // Organization invitations.
    OrgInvitationCreated,
    OrgInvitationAccepted,
    OrgInvitationDeclined,
    OrgInvitationRevoked,
}
49
/// One row of `allowthem_audit_log`, as returned by [`Db::get_audit_log`] and
/// [`Db::get_audit_log_by_event`].
///
/// Field names match the selected column names (used by `sqlx::FromRow`).
/// Field order is also the key order of the `Serialize` output.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct AuditEntry {
    /// Row id, generated by [`Db::log_audit`].
    pub id: AuditEntryId,
    /// What happened.
    pub event_type: AuditEvent,
    /// User the event was recorded against, if any.
    pub user_id: Option<UserId>,
    /// Optional free-form target identifier, stored as text.
    pub target_id: Option<String>,
    /// Client IP address as supplied by the caller of `log_audit`, if any.
    pub ip_address: Option<String>,
    /// Client user-agent string as supplied by the caller of `log_audit`, if any.
    pub user_agent: Option<String>,
    /// Optional free-form detail text.
    pub detail: Option<String>,
    /// Row timestamp. `log_audit` does not bind it, so it presumably comes from
    /// a column default (schema not visible here).
    pub created_at: DateTime<Utc>,
}
62
63fn event_to_slug(event: &AuditEvent) -> &'static str {
69 match event {
70 AuditEvent::Login => "login",
71 AuditEvent::LoginFailed => "login_failed",
72 AuditEvent::Logout => "logout",
73 AuditEvent::Register => "register",
74 AuditEvent::PasswordChange => "password_change",
75 AuditEvent::PasswordReset => "password_reset",
76 AuditEvent::RoleAssigned => "role_assigned",
77 AuditEvent::RoleUnassigned => "role_unassigned",
78 AuditEvent::PermissionAssigned => "permission_assigned",
79 AuditEvent::PermissionUnassigned => "permission_unassigned",
80 AuditEvent::SessionCreated => "session_created",
81 AuditEvent::SessionExpired => "session_expired",
82 AuditEvent::UserUpdated => "user_updated",
83 AuditEvent::UserDeleted => "user_deleted",
84 AuditEvent::MfaEnabled => "mfa_enabled",
85 AuditEvent::MfaDisabled => "mfa_disabled",
86 AuditEvent::MfaChallengeSuccess => "mfa_challenge_success",
87 AuditEvent::MfaChallengeFailed => "mfa_challenge_failed",
88 AuditEvent::OrgCreated => "org_created",
89 AuditEvent::OrgUpdated => "org_updated",
90 AuditEvent::OrgDeleted => "org_deleted",
91 AuditEvent::OrgMemberAdded => "org_member_added",
92 AuditEvent::OrgMemberRemoved => "org_member_removed",
93 AuditEvent::OrgMemberRoleChanged => "org_member_role_changed",
94 AuditEvent::OrgOwnershipTransferred => "org_ownership_transferred",
95 AuditEvent::TeamCreated => "team_created",
96 AuditEvent::TeamUpdated => "team_updated",
97 AuditEvent::TeamDeleted => "team_deleted",
98 AuditEvent::TeamMemberAdded => "team_member_added",
99 AuditEvent::TeamMemberRemoved => "team_member_removed",
100 AuditEvent::TeamMemberRoleChanged => "team_member_role_changed",
101 AuditEvent::OrgInvitationCreated => "org_invitation_created",
102 AuditEvent::OrgInvitationAccepted => "org_invitation_accepted",
103 AuditEvent::OrgInvitationDeclined => "org_invitation_declined",
104 AuditEvent::OrgInvitationRevoked => "org_invitation_revoked",
105 }
106}
107
108pub struct SearchAuditParams<'a> {
110 pub user_id: Option<UserId>,
111 pub event_type: Option<&'a AuditEvent>,
112 pub is_success: Option<bool>,
113 pub from: Option<DateTime<Utc>>,
114 pub to: Option<DateTime<Utc>>,
115 pub limit: u32,
116 pub offset: u32,
117}
118
/// An audit-log row joined with the email of its user, as returned by
/// [`Db::list_audit_paginated`] and [`Db::search_audit_log`].
///
/// Field names match the selected column aliases (used by `sqlx::FromRow`).
/// Field order is also the key order of the `Serialize` output.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct AuditListEntry {
    /// Row id.
    pub id: AuditEntryId,
    /// What happened.
    pub event_type: AuditEvent,
    /// User the event was recorded against, if any.
    pub user_id: Option<UserId>,
    /// `allowthem_users.email` via a LEFT JOIN on `user_id`. It is `None` when
    /// `user_id` is `None` or no matching user row exists.
    pub user_email: Option<String>,
    /// Optional free-form target identifier, stored as text.
    pub target_id: Option<String>,
    /// Client IP address as supplied when the event was logged, if any.
    pub ip_address: Option<String>,
    /// Client user-agent string as supplied when the event was logged, if any.
    pub user_agent: Option<String>,
    /// Optional free-form detail text.
    pub detail: Option<String>,
    /// Row timestamp. It is the primary sort key (descending) of both list queries.
    pub created_at: DateTime<Utc>,
}
133
134pub struct SearchAuditResult {
136 pub entries: Vec<AuditListEntry>,
137 pub total: u32,
138}
139
140pub struct AuditCursor {
144 pub created_at: DateTime<Utc>,
145 pub id: AuditEntryId,
146}
147
/// Wire form of [`AuditCursor`]: a JSON object that `encode` base64url-encodes
/// (unpadded).
///
/// The short key names are part of the encoded format. Renaming them
/// invalidates every cursor already handed to clients.
#[derive(Serialize, Deserialize)]
struct RawAuditCursor {
    /// `created_at` as an RFC 3339 string.
    ca: String,
    /// The entry id as a string; `decode` parses it as a UUID.
    id: String,
}
153
154impl AuditCursor {
155 pub fn from_entry(entry: &AuditListEntry) -> Self {
156 Self {
157 created_at: entry.created_at,
158 id: entry.id,
159 }
160 }
161
162 pub fn encode(&self) -> String {
163 let raw = RawAuditCursor {
164 ca: self.created_at.to_rfc3339(),
165 id: self.id.to_string(),
166 };
167 let json = serde_json::to_string(&raw).expect("RawAuditCursor serializes");
168 Base64UrlUnpadded::encode_string(json.as_bytes())
169 }
170
171 pub fn decode(s: &str) -> Option<Self> {
172 let bytes = Base64UrlUnpadded::decode_vec(s).ok()?;
173 let raw: RawAuditCursor = serde_json::from_slice(&bytes).ok()?;
174 let created_at = chrono::DateTime::parse_from_rfc3339(&raw.ca)
175 .ok()?
176 .with_timezone(&Utc);
177 let id = raw
178 .id
179 .parse::<uuid::Uuid>()
180 .ok()
181 .map(AuditEntryId::from_uuid)?;
182 Some(Self { created_at, id })
183 }
184}
185
186impl Db {
187 pub async fn log_audit(
192 &self,
193 event_type: AuditEvent,
194 user_id: Option<&UserId>,
195 target_id: Option<&str>,
196 ip_address: Option<&str>,
197 user_agent: Option<&str>,
198 detail: Option<&str>,
199 ) -> Result<(), AuthError> {
200 let id = AuditEntryId::new();
201 sqlx::query(
202 "INSERT INTO allowthem_audit_log
203 (id, event_type, user_id, target_id, ip_address, user_agent, detail)
204 VALUES (?, ?, ?, ?, ?, ?, ?)",
205 )
206 .bind(id)
207 .bind(event_type)
208 .bind(user_id.copied())
209 .bind(target_id)
210 .bind(ip_address)
211 .bind(user_agent)
212 .bind(detail)
213 .execute(self.pool())
214 .await
215 .map_err(AuthError::Database)?;
216 Ok(())
217 }
218
219 pub async fn get_audit_log(
223 &self,
224 user_id: Option<&UserId>,
225 limit: u32,
226 offset: u32,
227 ) -> Result<Vec<AuditEntry>, AuthError> {
228 match user_id {
229 Some(uid) => {
230 sqlx::query_as::<_, AuditEntry>(
231 "SELECT id, event_type, user_id, target_id, ip_address, user_agent, detail, created_at
232 FROM allowthem_audit_log
233 WHERE user_id = ?
234 ORDER BY created_at DESC
235 LIMIT ? OFFSET ?",
236 )
237 .bind(*uid)
238 .bind(limit)
239 .bind(offset)
240 .fetch_all(self.pool())
241 .await
242 .map_err(AuthError::Database)
243 }
244 None => {
245 sqlx::query_as::<_, AuditEntry>(
246 "SELECT id, event_type, user_id, target_id, ip_address, user_agent, detail, created_at
247 FROM allowthem_audit_log
248 ORDER BY created_at DESC
249 LIMIT ? OFFSET ?",
250 )
251 .bind(limit)
252 .bind(offset)
253 .fetch_all(self.pool())
254 .await
255 .map_err(AuthError::Database)
256 }
257 }
258 }
259
260 pub async fn get_audit_log_by_event(
264 &self,
265 event_type: AuditEvent,
266 limit: u32,
267 offset: u32,
268 ) -> Result<Vec<AuditEntry>, AuthError> {
269 sqlx::query_as::<_, AuditEntry>(
270 "SELECT id, event_type, user_id, target_id, ip_address, user_agent, detail, created_at
271 FROM allowthem_audit_log
272 WHERE event_type = ?
273 ORDER BY created_at DESC
274 LIMIT ? OFFSET ?",
275 )
276 .bind(event_type)
277 .bind(limit)
278 .bind(offset)
279 .fetch_all(self.pool())
280 .await
281 .map_err(AuthError::Database)
282 }
283
284 pub async fn last_login_at(&self, user_id: UserId) -> Result<Option<DateTime<Utc>>, AuthError> {
289 sqlx::query_scalar(
290 "SELECT MAX(created_at) FROM allowthem_audit_log \
291 WHERE user_id = ? AND event_type = 'login'",
292 )
293 .bind(user_id)
294 .fetch_one(self.pool())
295 .await
296 .map_err(AuthError::Database)
297 }
298
299 pub async fn list_audit_paginated(
304 &self,
305 limit: u32,
306 cursor: Option<&AuditCursor>,
307 ) -> Result<Vec<AuditListEntry>, AuthError> {
308 let limit = (limit as i64).min(200);
309 match cursor {
310 None => sqlx::query_as::<_, AuditListEntry>(
311 "SELECT a.id, a.event_type, a.user_id, u.email AS user_email, \
312 a.target_id, a.ip_address, a.user_agent, a.detail, a.created_at \
313 FROM allowthem_audit_log a \
314 LEFT JOIN allowthem_users u ON a.user_id = u.id \
315 ORDER BY a.created_at DESC, a.id DESC \
316 LIMIT ?",
317 )
318 .bind(limit)
319 .fetch_all(self.pool())
320 .await
321 .map_err(AuthError::Database),
322 Some(c) => {
323 let ca = c.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
324 sqlx::query_as::<_, AuditListEntry>(
325 "SELECT a.id, a.event_type, a.user_id, u.email AS user_email, \
326 a.target_id, a.ip_address, a.user_agent, a.detail, a.created_at \
327 FROM allowthem_audit_log a \
328 LEFT JOIN allowthem_users u ON a.user_id = u.id \
329 WHERE (a.created_at < ?1 OR (a.created_at = ?1 AND a.id < ?2)) \
330 ORDER BY a.created_at DESC, a.id DESC \
331 LIMIT ?3",
332 )
333 .bind(&ca)
334 .bind(c.id)
335 .bind(limit)
336 .fetch_all(self.pool())
337 .await
338 .map_err(AuthError::Database)
339 }
340 }
341 }
342
343 pub async fn search_audit_log(
349 &self,
350 params: SearchAuditParams<'_>,
351 ) -> Result<SearchAuditResult, AuthError> {
352 let mut where_clauses: Vec<String> = Vec::new();
356 let mut string_binds: Vec<String> = Vec::new();
357
358 if params.user_id.is_some() {
359 where_clauses.push("a.user_id = ?".into());
360 }
362
363 if let Some(event) = params.event_type {
364 where_clauses.push("a.event_type = ?".into());
365 string_binds.push(event_to_slug(event).to_string());
366 }
367
368 match params.is_success {
369 Some(true) => {
370 where_clauses.push("a.event_type != 'login_failed'".into());
371 }
372 Some(false) => {
373 where_clauses.push("a.event_type = 'login_failed'".into());
374 }
375 None => {}
376 }
377
378 if let Some(from) = params.from {
379 where_clauses.push("a.created_at >= ?".into());
380 string_binds.push(from.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string());
381 }
382
383 if let Some(to) = params.to {
384 where_clauses.push("a.created_at < ?".into());
385 string_binds.push(to.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string());
386 }
387
388 let where_sql = if where_clauses.is_empty() {
389 String::new()
390 } else {
391 format!("WHERE {}", where_clauses.join(" AND "))
392 };
393
394 let count_sql: &'static str = Box::leak(
396 format!("SELECT COUNT(*) FROM allowthem_audit_log a {where_sql}").into_boxed_str(),
397 );
398 let mut count_query = sqlx::query_scalar::<_, i64>(count_sql);
399 if let Some(uid) = params.user_id {
401 count_query = count_query.bind(uid);
402 }
403 for val in &string_binds {
404 count_query = count_query.bind(val);
405 }
406 let total = count_query
407 .fetch_one(self.pool())
408 .await
409 .map_err(AuthError::Database)? as u32;
410
411 let data_sql: &'static str = Box::leak(
413 format!(
414 "SELECT a.id, a.event_type, a.user_id, u.email AS user_email, \
415 a.target_id, a.ip_address, a.user_agent, a.detail, a.created_at \
416 FROM allowthem_audit_log a \
417 LEFT JOIN allowthem_users u ON a.user_id = u.id \
418 {where_sql} \
419 ORDER BY a.created_at DESC \
420 LIMIT ? OFFSET ?"
421 )
422 .into_boxed_str(),
423 );
424 let mut data_query = sqlx::query_as::<_, AuditListEntry>(data_sql);
425 if let Some(uid) = params.user_id {
426 data_query = data_query.bind(uid);
427 }
428 for val in &string_binds {
429 data_query = data_query.bind(val);
430 }
431 data_query = data_query.bind(params.limit).bind(params.offset);
432
433 let entries = data_query
434 .fetch_all(self.pool())
435 .await
436 .map_err(AuthError::Database)?;
437
438 Ok(SearchAuditResult { entries, total })
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::{AllowThem, AllowThemBuilder};

    /// Builds a handle backed by a fresh `sqlite::memory:` database.
    async fn setup() -> AllowThem {
        AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap()
    }

    /// Records one `Login` event with no user and target `target-{tag}`.
    async fn log_event(db: &Db, tag: u32) {
        db.log_audit(
            AuditEvent::Login,
            None,
            Some(&format!("target-{tag}")),
            None,
            None,
            None,
        )
        .await
        .unwrap();
    }

    /// Search parameters with every filter disabled.
    fn unfiltered<'a>(limit: u32, offset: u32) -> SearchAuditParams<'a> {
        SearchAuditParams {
            user_id: None,
            event_type: None,
            is_success: None,
            from: None,
            to: None,
            limit,
            offset,
        }
    }

    #[tokio::test]
    async fn audit_cursor_encode_decode_roundtrip() {
        let ath = setup().await;
        let db = ath.db();
        log_event(db, 1).await;
        let entries = db.list_audit_paginated(10, None).await.unwrap();
        assert_eq!(entries.len(), 1);
        let cursor = AuditCursor::from_entry(&entries[0]);
        let encoded = cursor.encode();
        let decoded = AuditCursor::decode(&encoded).unwrap();
        assert_eq!(decoded.id, entries[0].id);
    }

    #[tokio::test]
    async fn list_audit_paginated_returns_first_page() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            log_event(db, i).await;
        }
        let page = db.list_audit_paginated(3, None).await.unwrap();
        assert_eq!(page.len(), 3);
    }

    #[tokio::test]
    async fn list_audit_paginated_cursor_advances() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..5 {
            log_event(db, i + 10).await;
        }
        let page1 = db.list_audit_paginated(3, None).await.unwrap();
        assert_eq!(page1.len(), 3);
        let cursor = AuditCursor::from_entry(page1.last().unwrap());
        let page2 = db.list_audit_paginated(3, Some(&cursor)).await.unwrap();
        assert_eq!(page2.len(), 2);
        assert!(!page2.iter().any(|e| page1.iter().any(|f| f.id == e.id)));
    }

    #[tokio::test]
    async fn search_audit_log_filters_and_counts() {
        let ath = setup().await;
        let db = ath.db();
        for i in 0..3 {
            log_event(db, i).await;
        }
        db.log_audit(
            AuditEvent::LoginFailed,
            None,
            Some("target-failed"),
            None,
            None,
            None,
        )
        .await
        .unwrap();

        // `total` counts every match even when the page is smaller.
        let all = db.search_audit_log(unfiltered(2, 0)).await.unwrap();
        assert_eq!(all.total, 4);
        assert_eq!(all.entries.len(), 2);

        // Exact event-type filter.
        let failed = AuditEvent::LoginFailed;
        let by_event = db
            .search_audit_log(SearchAuditParams {
                event_type: Some(&failed),
                ..unfiltered(10, 0)
            })
            .await
            .unwrap();
        assert_eq!(by_event.total, 1);
        assert_eq!(by_event.entries.len(), 1);
        assert_eq!(by_event.entries[0].event_type, AuditEvent::LoginFailed);

        // `is_success` splits on `login_failed`.
        let successes = db
            .search_audit_log(SearchAuditParams {
                is_success: Some(true),
                ..unfiltered(10, 0)
            })
            .await
            .unwrap();
        assert_eq!(successes.total, 3);
        assert!(successes
            .entries
            .iter()
            .all(|e| e.event_type == AuditEvent::Login));

        let failures = db
            .search_audit_log(SearchAuditParams {
                is_success: Some(false),
                ..unfiltered(10, 0)
            })
            .await
            .unwrap();
        assert_eq!(failures.total, 1);
        assert_eq!(failures.entries.len(), 1);
    }
}
506}