Skip to main content

auth_framework/
maintenance.rs

1//! Maintenance utilities: snapshots, data export, and health checks.
2//!
3//! Provides tools for operational maintenance including:
4//!
5//! - **Snapshot & restore** — Serialise the entire storage state to a
6//!   versioned, checksummed snapshot file and restore from it.
7//! - **Data export** — Export users, sessions, tokens, and audit logs as
8//!   structured JSON for compliance or migration purposes.
9//! - **Health checks** — Verify storage connectivity, token validity, and
10//!   system integrity.
11//!
12//! Most operations are available through the
13//! [`MaintenanceOperations`](crate::auth::MaintenanceOperations) facade.
14
15use crate::auth::AuthFramework;
16use crate::auth_operations::UserListQuery;
17use crate::config::{StorageConfig, app_config::AppConfig};
18use crate::errors::{AuthError, Result};
19use crate::permissions::Role;
20use crate::storage::SessionData;
21use crate::tokens::AuthToken;
22use base64::Engine;
23use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
24use chrono::{DateTime, Utc};
25use serde::{Deserialize, Serialize};
26use serde_json::{Map, Value};
27use sha2::{Digest, Sha256};
28use std::collections::HashSet;
29use std::path::{Path, PathBuf};
30
/// Layout version stamped into every snapshot manifest by this module.
///
/// `validate_snapshot` rejects any snapshot whose manifest carries a different
/// version, so this must change whenever the serialized shape of
/// [`MaintenanceSnapshot`] changes incompatibly.
const SNAPSHOT_FORMAT_VERSION: u32 = 1;
32
/// Header of a [`MaintenanceSnapshot`] describing what the snapshot holds.
///
/// The counts are the lengths of the corresponding collections when the
/// snapshot was collected. They are informational only: restore verifies the
/// format version and the checksum, but does not cross-check the counts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotManifest {
    /// Snapshot layout version; restore rejects anything but the version this
    /// build writes.
    pub format_version: u32,
    /// UTC time at which the snapshot was collected.
    pub created_at: DateTime<Utc>,
    /// Short label of the configured storage backend at collection time
    /// (e.g. `"memory"`, `"postgres"`, `"custom"`).
    pub storage_backend: String,
    /// Number of entries in [`MaintenanceSnapshot::users`].
    pub user_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::roles`].
    pub role_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::tokens`].
    pub token_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::sessions`].
    pub session_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::kv_entries`].
    pub kv_entry_count: usize,
    /// Hex-encoded SHA-256 over the canonicalised JSON of the users, roles,
    /// tokens, sessions and KV entries. The manifest itself is not hashed.
    pub checksum_sha256: String,
}
45
/// Per-user summary stored in a snapshot.
///
/// Restore only uses `id` and `roles` (to re-assign roles); user records are
/// not recreated from this summary. NOTE(review): the full user records are
/// presumably brought back through the snapshot's KV entries — confirm for
/// each storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotUserSummary {
    /// User identifier; snapshots list users sorted by this field.
    pub id: String,
    /// Login name at snapshot time.
    pub username: String,
    /// E-mail address, if the user has one.
    pub email: Option<String>,
    /// Sorted union of the roles on the user record and the roles reported
    /// by the authorization component.
    pub roles: Vec<String>,
    /// Whether the account was active at snapshot time.
    pub active: bool,
}
54
/// A single raw key/value pair copied from the storage KV space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotKvEntry {
    /// Storage key, stored verbatim.
    pub key: String,
    /// Raw stored bytes, encoded with the standard (padded) base64 alphabet.
    pub value_base64: String,
}
60
/// Complete snapshot document.
///
/// `backup_to_file` writes it as pretty-printed JSON and `restore_from_file`
/// parses it back. Every collection is sorted by its identifier (users by
/// id, roles by name, tokens/sessions by id, KV entries by key) when
/// collected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaintenanceSnapshot {
    /// Version, counts and checksum for the collections below.
    pub manifest: SnapshotManifest,
    /// Summaries of every listed user.
    pub users: Vec<SnapshotUserSummary>,
    /// Every role known to the authorization component.
    pub roles: Vec<Role>,
    /// Tokens belonging to the listed users, de-duplicated by token id.
    pub tokens: Vec<AuthToken>,
    /// Sessions belonging to the listed users, de-duplicated by session id.
    pub sessions: Vec<SessionData>,
    /// Every KV entry present in storage.
    pub kv_entries: Vec<SnapshotKvEntry>,
}
70
/// Outcome of `backup_to_file`.
#[derive(Debug, Clone)]
pub struct BackupReport {
    /// Manifest of the collected snapshot (collected even on a dry run).
    pub manifest: SnapshotManifest,
    /// Destination path requested by the caller.
    pub output_path: PathBuf,
    /// `true` if nothing was written to `output_path`.
    pub dry_run: bool,
}
77
/// Outcome of `reset_runtime_data`.
///
/// All counts come from the enumeration done before any deletion. On a dry
/// run they describe what would have been removed.
#[derive(Debug, Clone)]
pub struct ResetReport {
    /// Number of users returned by the unfiltered user listing.
    pub users_deleted: usize,
    /// Number of roles listed. Roles are not deleted one by one; the
    /// authorization runtime is reset as a whole instead.
    pub roles_seen: usize,
    /// Number of distinct token ids belonging to the listed users.
    pub tokens_deleted: usize,
    /// Number of distinct session ids belonging to the listed users.
    pub sessions_deleted: usize,
    /// Number of distinct KV keys listed.
    pub kv_entries_deleted: usize,
    /// `true` if nothing was actually deleted.
    pub dry_run: bool,
}
87
/// Outcome of `restore_from_file`.
#[derive(Debug, Clone)]
pub struct RestoreReport {
    /// Manifest read from the snapshot file (after validation).
    pub manifest: SnapshotManifest,
    /// Path the snapshot was read from.
    pub input_path: PathBuf,
    /// Report of the reset performed (or simulated) before re-inserting data.
    pub reset_report: ResetReport,
    /// `true` if no data was deleted or written.
    pub dry_run: bool,
}
95
/// Location of a newly generated migration template.
#[derive(Debug, Clone)]
pub struct MigrationFileReport {
    /// Backend label, also used as the sub-directory under `migrations/`.
    pub backend: String,
    /// Path of the created file, relative to the working directory at
    /// creation time (`migrations/<backend>/<timestamp>_<name>.sql`).
    pub path: PathBuf,
}
101
/// Borrowed view of the parts of a snapshot covered by the checksum.
///
/// The field names become JSON object keys in the hashed data. Renaming a
/// field therefore changes every checksum and invalidates existing snapshots.
#[derive(Serialize)]
struct SnapshotChecksumPayload<'a> {
    users: &'a [SnapshotUserSummary],
    roles: &'a [Role],
    tokens: &'a [AuthToken],
    sessions: &'a [SessionData],
    kv_entries: &'a [SnapshotKvEntry],
}
110
111fn normalize_json_value(value: Value) -> Value {
112    match value {
113        Value::Array(values) => {
114            let mut normalized = values
115                .into_iter()
116                .map(normalize_json_value)
117                .collect::<Vec<_>>();
118            normalized.sort_by(|left, right| left.to_string().cmp(&right.to_string()));
119            Value::Array(normalized)
120        }
121        Value::Object(object) => {
122            let mut entries = object.into_iter().collect::<Vec<_>>();
123            entries.sort_by(|left, right| left.0.cmp(&right.0));
124
125            let normalized = entries
126                .into_iter()
127                .map(|(key, value)| (key, normalize_json_value(value)))
128                .collect::<Map<String, Value>>();
129            Value::Object(normalized)
130        }
131        other => other,
132    }
133}
134
135fn storage_backend_name(config: &StorageConfig) -> &'static str {
136    match config {
137        StorageConfig::Memory => "memory",
138        #[cfg(feature = "postgres-storage")]
139        StorageConfig::Postgres { .. } => "postgres",
140        #[cfg(feature = "redis-storage")]
141        StorageConfig::Redis { .. } => "redis",
142        #[cfg(feature = "mysql-storage")]
143        StorageConfig::MySQL { .. } => "mysql",
144        #[cfg(feature = "sqlite-storage")]
145        StorageConfig::Sqlite { .. } => "sqlite",
146        StorageConfig::Custom(_) => "custom",
147    }
148}
149
/// Infers a backend label from a database connection URL.
///
/// Matching ignores ASCII case and surrounding whitespace. Rules are checked
/// in this order, and the first match wins:
///
/// 1. `postgres://` or `postgresql://` gives `"postgres"`.
/// 2. `mysql://` gives `"mysql"`.
/// 3. A `sqlite:` prefix, or a path ending in `.db`, `.sqlite` or `.sqlite3`,
///    gives `"sqlite"`.
/// 4. `redis://` or `rediss://` gives `"redis"`.
/// 5. An empty (or all-whitespace) string gives `"memory"`.
/// 6. Anything else gives `"custom"`.
///
/// Because of this ordering, `postgres://host/app.db` is still `"postgres"`.
fn backend_name_from_database_url(database_url: &str) -> &'static str {
    // Bare file paths with these extensions are SQLite databases. Previously
    // only `.db` was recognised, so `auth.sqlite3` fell through to "custom".
    const SQLITE_FILE_EXTENSIONS: [&str; 3] = [".db", ".sqlite", ".sqlite3"];

    fn starts_with_any(text: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|prefix| text.starts_with(*prefix))
    }

    let url = database_url.trim().to_ascii_lowercase();

    if starts_with_any(&url, &["postgres://", "postgresql://"]) {
        "postgres"
    } else if starts_with_any(&url, &["mysql://"]) {
        "mysql"
    } else if starts_with_any(&url, &["sqlite:"])
        || SQLITE_FILE_EXTENSIONS.iter().any(|ext| url.ends_with(*ext))
    {
        "sqlite"
    } else if starts_with_any(&url, &["redis://", "rediss://"]) {
        "redis"
    } else if url.is_empty() {
        "memory"
    } else {
        "custom"
    }
}
167
168fn checksum_snapshot(
169    users: &[SnapshotUserSummary],
170    roles: &[Role],
171    tokens: &[AuthToken],
172    sessions: &[SessionData],
173    kv_entries: &[SnapshotKvEntry],
174) -> Result<String> {
175    let payload = SnapshotChecksumPayload {
176        users,
177        roles,
178        tokens,
179        sessions,
180        kv_entries,
181    };
182    let encoded = serde_json::to_value(&payload)
183        .map(normalize_json_value)
184        .and_then(|value| serde_json::to_vec(&value))
185        .map_err(|e| AuthError::internal(format!("Failed to serialize snapshot payload: {e}")))?;
186    let mut hasher = Sha256::new();
187    hasher.update(encoded);
188    Ok(hex::encode(hasher.finalize()))
189}
190
191fn sanitize_migration_name(name: &str) -> Result<String> {
192    let sanitized = name
193        .trim()
194        .chars()
195        .map(|character| {
196            if character.is_ascii_alphanumeric() {
197                character.to_ascii_lowercase()
198            } else {
199                '_'
200            }
201        })
202        .collect::<String>();
203
204    let collapsed = sanitized
205        .split('_')
206        .filter(|segment| !segment.is_empty())
207        .collect::<Vec<_>>()
208        .join("_");
209
210    if collapsed.is_empty() {
211        return Err(AuthError::validation(
212            "Migration name must contain at least one alphanumeric character",
213        ));
214    }
215
216    Ok(collapsed)
217}
218
219async fn collect_snapshot(framework: &AuthFramework) -> Result<MaintenanceSnapshot> {
220    let storage = framework.storage();
221
222    let mut users = framework
223        .users()
224        .list_with_query(UserListQuery::new())
225        .await?;
226    users.sort_by(|left, right| left.id.cmp(&right.id));
227
228    let mut snapshot_users = Vec::with_capacity(users.len());
229    for user in &users {
230        let mut roles: HashSet<String> = user.roles.iter().cloned().collect();
231        roles.extend(framework.authorization().roles_for_user(&user.id).await?);
232        let mut roles = roles.into_iter().collect::<Vec<_>>();
233        roles.sort();
234
235        snapshot_users.push(SnapshotUserSummary {
236            id: user.id.clone(),
237            username: user.username.clone(),
238            email: user.email.clone(),
239            roles,
240            active: user.active,
241        });
242    }
243
244    let mut roles = framework.authorization().list_roles().await;
245    roles.sort_by(|left, right| left.name.cmp(&right.name));
246
247    let mut tokens = Vec::new();
248    let mut seen_tokens = HashSet::new();
249    let mut sessions = Vec::new();
250    let mut seen_sessions = HashSet::new();
251
252    for user in &users {
253        for token in framework.tokens().list_for_user(&user.id).await? {
254            if seen_tokens.insert(token.token_id.clone()) {
255                tokens.push(token);
256            }
257        }
258
259        for session in framework.sessions().list_for_user(&user.id).await? {
260            if seen_sessions.insert(session.session_id.clone()) {
261                sessions.push(session);
262            }
263        }
264    }
265
266    tokens.sort_by(|left, right| left.token_id.cmp(&right.token_id));
267    sessions.sort_by(|left, right| left.session_id.cmp(&right.session_id));
268
269    let mut kv_keys = storage.list_kv_keys("").await?;
270    kv_keys.sort();
271    kv_keys.dedup();
272
273    let mut kv_entries = Vec::with_capacity(kv_keys.len());
274    for key in kv_keys {
275        if let Some(value) = storage.get_kv(&key).await? {
276            kv_entries.push(SnapshotKvEntry {
277                key,
278                value_base64: BASE64_STANDARD.encode(value),
279            });
280        }
281    }
282
283    let manifest = SnapshotManifest {
284        format_version: SNAPSHOT_FORMAT_VERSION,
285        created_at: Utc::now(),
286        storage_backend: storage_backend_name(&framework.config().storage).to_string(),
287        user_count: snapshot_users.len(),
288        role_count: roles.len(),
289        token_count: tokens.len(),
290        session_count: sessions.len(),
291        kv_entry_count: kv_entries.len(),
292        checksum_sha256: checksum_snapshot(
293            &snapshot_users,
294            &roles,
295            &tokens,
296            &sessions,
297            &kv_entries,
298        )?,
299    };
300
301    Ok(MaintenanceSnapshot {
302        manifest,
303        users: snapshot_users,
304        roles,
305        tokens,
306        sessions,
307        kv_entries,
308    })
309}
310
311fn validate_snapshot(snapshot: &MaintenanceSnapshot) -> Result<()> {
312    if snapshot.manifest.format_version != SNAPSHOT_FORMAT_VERSION {
313        return Err(AuthError::configuration(format!(
314            "Unsupported snapshot format version {}",
315            snapshot.manifest.format_version
316        )));
317    }
318
319    let expected_checksum = checksum_snapshot(
320        &snapshot.users,
321        &snapshot.roles,
322        &snapshot.tokens,
323        &snapshot.sessions,
324        &snapshot.kv_entries,
325    )?;
326
327    if expected_checksum != snapshot.manifest.checksum_sha256 {
328        return Err(AuthError::validation(
329            "Snapshot checksum validation failed; restore aborted",
330        ));
331    }
332
333    Ok(())
334}
335
336pub async fn backup_to_file(
337    framework: &AuthFramework,
338    output_path: impl AsRef<Path>,
339    dry_run: bool,
340) -> Result<BackupReport> {
341    let output_path = output_path.as_ref().to_path_buf();
342    let snapshot = collect_snapshot(framework).await?;
343
344    if !dry_run {
345        if let Some(parent) = output_path.parent() {
346            if !parent.as_os_str().is_empty() {
347                tokio::fs::create_dir_all(parent).await?;
348            }
349        }
350
351        let data = serde_json::to_vec_pretty(&snapshot).map_err(|e| {
352            AuthError::internal(format!("Failed to serialize maintenance snapshot: {e}"))
353        })?;
354        tokio::fs::write(&output_path, data).await?;
355    }
356
357    Ok(BackupReport {
358        manifest: snapshot.manifest,
359        output_path,
360        dry_run,
361    })
362}
363
364pub async fn reset_runtime_data(framework: &AuthFramework, dry_run: bool) -> Result<ResetReport> {
365    let storage = framework.storage();
366    let users = framework
367        .users()
368        .list_with_query(UserListQuery::new())
369        .await?;
370    let roles = framework.authorization().list_roles().await;
371
372    let mut token_ids = HashSet::new();
373    let mut session_ids = HashSet::new();
374    for user in &users {
375        for token in framework.tokens().list_for_user(&user.id).await? {
376            token_ids.insert(token.token_id);
377        }
378
379        for session in framework.sessions().list_for_user(&user.id).await? {
380            session_ids.insert(session.session_id);
381        }
382    }
383
384    let mut kv_keys = storage.list_kv_keys("").await?;
385    kv_keys.sort();
386    kv_keys.dedup();
387
388    if !dry_run {
389        for token_id in &token_ids {
390            storage.delete_token(token_id).await?;
391        }
392
393        for session_id in &session_ids {
394            storage.delete_session(session_id).await?;
395        }
396
397        for user in &users {
398            framework.users().delete_by_id(&user.id).await?;
399        }
400
401        for key in &kv_keys {
402            storage.delete_kv(key).await?;
403        }
404
405        framework.reset_authorization_runtime().await;
406    }
407
408    Ok(ResetReport {
409        users_deleted: users.len(),
410        roles_seen: roles.len(),
411        tokens_deleted: token_ids.len(),
412        sessions_deleted: session_ids.len(),
413        kv_entries_deleted: kv_keys.len(),
414        dry_run,
415    })
416}
417
418pub async fn restore_from_file(
419    framework: &AuthFramework,
420    input_path: impl AsRef<Path>,
421    dry_run: bool,
422) -> Result<RestoreReport> {
423    let input_path = input_path.as_ref().to_path_buf();
424    let data = tokio::fs::read(&input_path).await?;
425    let snapshot: MaintenanceSnapshot = serde_json::from_slice(&data)
426        .map_err(|e| AuthError::validation(format!("Failed to parse maintenance snapshot: {e}")))?;
427    validate_snapshot(&snapshot)?;
428
429    let reset_report = reset_runtime_data(framework, dry_run).await?;
430
431    if !dry_run {
432        let storage = framework.storage();
433
434        for entry in &snapshot.kv_entries {
435            let value = BASE64_STANDARD.decode(&entry.value_base64).map_err(|e| {
436                AuthError::validation(format!(
437                    "Snapshot KV entry '{}' is not valid base64: {e}",
438                    entry.key
439                ))
440            })?;
441            storage.store_kv(&entry.key, &value, None).await?;
442        }
443
444        for token in &snapshot.tokens {
445            storage.store_token(token).await?;
446        }
447
448        for session in &snapshot.sessions {
449            storage.store_session(&session.session_id, session).await?;
450        }
451
452        framework.reset_authorization_runtime().await;
453        for role in &snapshot.roles {
454            framework.authorization().create_role(role.clone()).await?;
455        }
456        for user in &snapshot.users {
457            for role_name in &user.roles {
458                framework
459                    .authorization()
460                    .assign_role(&user.id, role_name)
461                    .await?;
462            }
463        }
464    }
465
466    Ok(RestoreReport {
467        manifest: snapshot.manifest,
468        input_path,
469        reset_report,
470        dry_run,
471    })
472}
473
474fn build_migration_template(backend: &str, migration_name: &str, original_name: &str) -> String {
475    format!(
476        "-- AuthFramework migration template\n-- Backend: {backend}\n-- Name: {original_name}\n-- Generated at: {}\n\n-- Replace this placeholder with idempotent DDL for {migration_name}.\n-- Prefer CREATE TABLE IF NOT EXISTS / CREATE INDEX IF NOT EXISTS where supported.\n\nBEGIN;\n\n-- Add migration SQL here\n\nCOMMIT;\n",
477        Utc::now().to_rfc3339(),
478    )
479}
480
481pub async fn create_migration_file(config: &AppConfig, name: &str) -> Result<MigrationFileReport> {
482    let backend = backend_name_from_database_url(&config.database.url).to_string();
483    create_migration_template_for_backend(&backend, name).await
484}
485
486pub async fn create_migration_file_for_storage(
487    storage: &StorageConfig,
488    name: &str,
489) -> Result<MigrationFileReport> {
490    let backend = storage_backend_name(storage).to_string();
491    create_migration_template_for_backend(&backend, name).await
492}
493
494async fn create_migration_template_for_backend(
495    backend: &str,
496    name: &str,
497) -> Result<MigrationFileReport> {
498    let sanitized_name = sanitize_migration_name(name)?;
499    let directory = PathBuf::from("migrations").join(&backend);
500    tokio::fs::create_dir_all(&directory).await?;
501
502    let file_name = format!(
503        "{}_{}.sql",
504        Utc::now().format("%Y%m%d%H%M%S"),
505        sanitized_name
506    );
507    let path = directory.join(file_name);
508    let template = build_migration_template(backend, &sanitized_name, name);
509    tokio::fs::write(&path, template).await?;
510
511    Ok(MigrationFileReport {
512        backend: backend.to_string(),
513        path,
514    })
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AuthConfig;
    use crate::methods::{AuthMethodEnum, JwtMethod};
    use std::time::Duration;
    use tempfile::tempdir;

    /// Builds an initialised framework with a fixed secret and a JWT method.
    async fn create_framework() -> AuthFramework {
        let config = AuthConfig::new()
            .secret("0123456789abcdef0123456789abcdef")
            .token_lifetime(Duration::from_secs(3600));
        let mut framework = AuthFramework::new(config);
        framework.register_method("jwt", AuthMethodEnum::Jwt(JwtMethod::new()));
        framework.initialize().await.unwrap();
        framework
    }

    /// Restores the process working directory when dropped.
    ///
    /// A failing assertion therefore cannot leave later tests running inside
    /// a deleted temp dir. The working directory is process-global, so tests
    /// that change it must not rely on relative paths being stable in other,
    /// concurrently running tests.
    struct CwdGuard(PathBuf);

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Best effort: panicking in Drop while unwinding would abort.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    /// Parses the snapshot at `path`, lets `edit` tamper with it, and writes
    /// it back as JSON.
    fn rewrite_snapshot(path: &Path, edit: impl FnOnce(&mut MaintenanceSnapshot)) {
        let mut snapshot: MaintenanceSnapshot =
            serde_json::from_slice(&std::fs::read(path).unwrap()).unwrap();
        edit(&mut snapshot);
        std::fs::write(path, serde_json::to_vec_pretty(&snapshot).unwrap()).unwrap();
    }

    #[tokio::test]
    async fn backup_restore_roundtrip_preserves_runtime_state() {
        let framework = create_framework().await;
        let user_id = framework
            .users()
            .register("alice", "alice@example.com", "Password123!")
            .await
            .unwrap();
        framework
            .authorization()
            .create_role(Role::new("auditor"))
            .await
            .unwrap();
        framework
            .authorization()
            .assign_role(&user_id, "auditor")
            .await
            .unwrap();
        framework
            .tokens()
            .create(&user_id, &["read"], "jwt", None)
            .await
            .unwrap();
        framework
            .sessions()
            .create(
                &user_id,
                Duration::from_secs(900),
                Some("127.0.0.1".into()),
                None,
            )
            .await
            .unwrap();
        framework
            .storage()
            .store_kv("custom:test", b"value", None)
            .await
            .unwrap();

        let dir = tempdir().unwrap();
        let path = dir.path().join("snapshot.json");

        backup_to_file(&framework, &path, false).await.unwrap();
        reset_runtime_data(&framework, false).await.unwrap();
        assert!(
            framework
                .users()
                .list_with_query(UserListQuery::new())
                .await
                .unwrap()
                .is_empty()
        );

        restore_from_file(&framework, &path, false).await.unwrap();

        let restored_user = framework.users().get(&user_id).await.unwrap();
        assert_eq!(restored_user.username, "alice");
        assert!(
            framework
                .authorization()
                .has_role(&user_id, "auditor")
                .await
                .unwrap()
        );
        assert_eq!(
            framework
                .tokens()
                .list_for_user(&user_id)
                .await
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            framework
                .sessions()
                .list_for_user(&user_id)
                .await
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            framework
                .storage()
                .get_kv("custom:test")
                .await
                .unwrap()
                .unwrap(),
            b"value"
        );
    }

    #[tokio::test]
    async fn reset_dry_run_leaves_state_unchanged() {
        let framework = create_framework().await;
        let user_id = framework
            .users()
            .register("bob", "bob@example.com", "Password123!")
            .await
            .unwrap();
        framework
            .storage()
            .store_kv("custom:dry-run", b"present", None)
            .await
            .unwrap();

        let report = reset_runtime_data(&framework, true).await.unwrap();
        assert!(report.dry_run);
        assert_eq!(
            framework.users().get(&user_id).await.unwrap().username,
            "bob"
        );
        assert!(
            framework
                .storage()
                .get_kv("custom:dry-run")
                .await
                .unwrap()
                .is_some()
        );
    }

    #[tokio::test]
    async fn create_migration_file_uses_backend_directory_and_sanitized_name() {
        let dir = tempdir().unwrap();
        // Declared after `dir`, so it is dropped first. The old working
        // directory is restored before the temp dir is deleted, even if an
        // assertion below panics.
        let _cwd = CwdGuard(std::env::current_dir().unwrap());
        std::env::set_current_dir(dir.path()).unwrap();

        let mut config = AppConfig::default();
        config.database.url = "sqlite::memory:".to_string();
        let report = create_migration_file(&config, "Add Audit Table!")
            .await
            .unwrap();
        assert_eq!(report.backend, "sqlite");
        assert!(
            report
                .path
                .starts_with(Path::new("migrations").join("sqlite"))
        );
        assert!(
            report
                .path
                .file_name()
                .unwrap()
                .to_string_lossy()
                .contains("add_audit_table")
        );
    }

    #[tokio::test]
    async fn backup_dry_run_does_not_write_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("shouldnt_exist.json");
        let framework = create_framework().await;
        let report = backup_to_file(&framework, &path, true).await.unwrap();
        assert!(report.dry_run);
        assert!(!path.exists());
    }

    #[tokio::test]
    async fn backup_empty_framework() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("empty.json");
        let framework = create_framework().await;
        let report = backup_to_file(&framework, &path, false).await.unwrap();
        assert_eq!(report.manifest.user_count, 0);
        assert_eq!(report.manifest.token_count, 0);
        assert_eq!(report.manifest.session_count, 0);
        assert!(path.exists());
    }

    #[tokio::test]
    async fn reset_clears_all_data() {
        let framework = create_framework().await;
        framework
            .users()
            .register("clear_me", "clear@example.com", "Password123!")
            .await
            .unwrap();
        framework
            .storage()
            .store_kv("custom:keep", b"nope", None)
            .await
            .unwrap();

        let report = reset_runtime_data(&framework, false).await.unwrap();
        assert!(!report.dry_run);
        assert!(report.users_deleted >= 1);
        assert!(
            framework
                .users()
                .list_with_query(UserListQuery::new())
                .await
                .unwrap()
                .is_empty()
        );
        assert!(
            framework
                .storage()
                .get_kv("custom:keep")
                .await
                .unwrap()
                .is_none()
        );
    }

    #[tokio::test]
    async fn restore_nonexistent_file_fails() {
        let framework = create_framework().await;
        let result = restore_from_file(&framework, "/definitely/not/real.json", false).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn backup_manifest_has_checksum() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("checksummed.json");
        let framework = create_framework().await;
        framework
            .users()
            .register("chk_user", "chk@example.com", "Password123!")
            .await
            .unwrap();
        let report = backup_to_file(&framework, &path, false).await.unwrap();
        assert!(!report.manifest.checksum_sha256.is_empty());
        assert_eq!(report.manifest.format_version, SNAPSHOT_FORMAT_VERSION);
    }

    #[tokio::test]
    async fn restore_rejects_tampered_checksum_without_touching_state() {
        let framework = create_framework().await;
        let user_id = framework
            .users()
            .register("carol", "carol@example.com", "Password123!")
            .await
            .unwrap();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tampered.json");
        backup_to_file(&framework, &path, false).await.unwrap();

        rewrite_snapshot(&path, |snapshot| {
            snapshot.manifest.checksum_sha256 = "0".repeat(64);
        });

        assert!(restore_from_file(&framework, &path, false).await.is_err());
        // Validation happens before the reset, so existing data survives.
        assert_eq!(
            framework.users().get(&user_id).await.unwrap().username,
            "carol"
        );
    }

    #[tokio::test]
    async fn restore_rejects_unsupported_format_version() {
        let framework = create_framework().await;
        let dir = tempdir().unwrap();
        let path = dir.path().join("future.json");
        backup_to_file(&framework, &path, false).await.unwrap();

        rewrite_snapshot(&path, |snapshot| {
            snapshot.manifest.format_version = SNAPSHOT_FORMAT_VERSION + 1;
        });

        assert!(restore_from_file(&framework, &path, true).await.is_err());
    }

    #[test]
    fn sanitize_migration_name_collapses_separators() {
        assert_eq!(
            sanitize_migration_name("  Add Audit Table! ").unwrap(),
            "add_audit_table"
        );
        assert_eq!(sanitize_migration_name("__a--B__").unwrap(), "a_b");
        assert_eq!(sanitize_migration_name("v2").unwrap(), "v2");
        assert!(sanitize_migration_name("!!! ").is_err());
        assert!(sanitize_migration_name("").is_err());
    }

    #[test]
    fn backend_name_from_database_url_recognises_common_schemes() {
        assert_eq!(backend_name_from_database_url("postgresql://h/app"), "postgres");
        assert_eq!(backend_name_from_database_url("MYSQL://h/app"), "mysql");
        assert_eq!(backend_name_from_database_url("sqlite::memory:"), "sqlite");
        assert_eq!(backend_name_from_database_url("auth.db"), "sqlite");
        assert_eq!(backend_name_from_database_url("rediss://h"), "redis");
        assert_eq!(backend_name_from_database_url(""), "memory");
        assert_eq!(backend_name_from_database_url("mongodb://h"), "custom");
    }
}