1use crate::auth::AuthFramework;
16use crate::auth_operations::UserListQuery;
17use crate::config::{StorageConfig, app_config::AppConfig};
18use crate::errors::{AuthError, Result};
19use crate::permissions::Role;
20use crate::storage::SessionData;
21use crate::tokens::AuthToken;
22use base64::Engine;
23use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
24use chrono::{DateTime, Utc};
25use serde::{Deserialize, Serialize};
26use serde_json::{Map, Value};
27use sha2::{Digest, Sha256};
28use std::collections::HashSet;
29use std::path::{Path, PathBuf};
30
/// Version tag written into [`SnapshotManifest::format_version`] by `collect_snapshot`.
/// `validate_snapshot` rejects any snapshot whose version differs from this value.
const SNAPSHOT_FORMAT_VERSION: u32 = 1;
32
/// Header of a [`MaintenanceSnapshot`]: format version, provenance, per-section
/// counts, and the integrity checksum over the snapshot sections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotManifest {
    /// Snapshot layout version; must equal the module's current format version to restore.
    pub format_version: u32,
    /// UTC time at which the snapshot was collected.
    pub created_at: DateTime<Utc>,
    /// Name of the configured storage backend (e.g. `"memory"`, `"postgres"`, `"custom"`).
    pub storage_backend: String,
    /// Number of entries in [`MaintenanceSnapshot::users`].
    pub user_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::roles`].
    pub role_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::tokens`].
    pub token_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::sessions`].
    pub session_count: usize,
    /// Number of entries in [`MaintenanceSnapshot::kv_entries`].
    pub kv_entry_count: usize,
    /// Hex-encoded SHA-256 over the normalized JSON of the five snapshot
    /// sections (users, roles, tokens, sessions, KV entries). The manifest
    /// itself is not part of the digest.
    pub checksum_sha256: String,
}
45
/// Per-user summary stored in a snapshot.
///
/// Only `id` and `roles` are read back by `restore_from_file` (to re-assign
/// roles). No credentials are stored here. Presumably the user records
/// themselves come back through the raw KV entries; confirm against the
/// storage backend in use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotUserSummary {
    /// User identifier.
    pub id: String,
    /// Login name.
    pub username: String,
    /// Email address, if the user has one.
    pub email: Option<String>,
    /// Union of the roles on the user record and those reported by the
    /// authorization layer, deduplicated and sorted.
    pub roles: Vec<String>,
    /// Whether the account was active when the snapshot was taken.
    pub active: bool,
}
54
/// One raw key/value pair from the storage backend's KV space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotKvEntry {
    /// Storage key, exactly as returned by `list_kv_keys`.
    pub key: String,
    /// Stored bytes, encoded with the standard (padded) base64 alphabet.
    pub value_base64: String,
}
60
/// Complete serialized snapshot of the framework's runtime state.
///
/// `backup_to_file` writes this as pretty-printed JSON and `restore_from_file`
/// reads it back. Each section is sorted by its identifier when collected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaintenanceSnapshot {
    /// Version, counts and checksum describing the sections below.
    pub manifest: SnapshotManifest,
    /// User summaries, sorted by user id.
    pub users: Vec<SnapshotUserSummary>,
    /// Role definitions, sorted by role name.
    pub roles: Vec<Role>,
    /// Tokens of the listed users, deduplicated and sorted by token id.
    pub tokens: Vec<AuthToken>,
    /// Sessions of the listed users, deduplicated and sorted by session id.
    pub sessions: Vec<SessionData>,
    /// Raw KV entries, sorted by key.
    pub kv_entries: Vec<SnapshotKvEntry>,
}
70
/// Outcome of `backup_to_file`.
#[derive(Debug, Clone)]
pub struct BackupReport {
    /// Manifest of the collected snapshot.
    pub manifest: SnapshotManifest,
    /// Destination path (not written when `dry_run` is true).
    pub output_path: PathBuf,
    /// `true` if the snapshot was only collected, not written.
    pub dry_run: bool,
}
77
/// Outcome of `reset_runtime_data`.
///
/// Counts reflect what was found before deletion. In a dry run they are what
/// would have been deleted.
#[derive(Debug, Clone)]
pub struct ResetReport {
    /// Number of users listed (and deleted unless `dry_run`).
    pub users_deleted: usize,
    /// Number of roles listed. Roles are not deleted one by one; the
    /// authorization runtime is reset instead.
    pub roles_seen: usize,
    /// Number of distinct token ids found through the listed users.
    pub tokens_deleted: usize,
    /// Number of distinct session ids found through the listed users.
    pub sessions_deleted: usize,
    /// Number of distinct KV keys found.
    pub kv_entries_deleted: usize,
    /// `true` if nothing was actually deleted.
    pub dry_run: bool,
}
87
/// Outcome of `restore_from_file`.
#[derive(Debug, Clone)]
pub struct RestoreReport {
    /// Manifest read from the snapshot file.
    pub manifest: SnapshotManifest,
    /// Path the snapshot was read from.
    pub input_path: PathBuf,
    /// Report of the reset performed before restoring.
    pub reset_report: ResetReport,
    /// `true` if no data was reset or written.
    pub dry_run: bool,
}
95
/// Describes a migration template file that was generated.
#[derive(Debug, Clone)]
pub struct MigrationFileReport {
    /// Backend label the template was generated for (also its subdirectory name).
    pub backend: String,
    /// Relative path of the generated file, under `migrations/<backend>/`.
    pub path: PathBuf,
}
101
/// Borrowed view of the snapshot sections that feed the checksum.
///
/// Serialized by `checksum_snapshot`. The manifest is not included because it
/// is where the resulting checksum is stored.
#[derive(Serialize)]
struct SnapshotChecksumPayload<'a> {
    users: &'a [SnapshotUserSummary],
    roles: &'a [Role],
    tokens: &'a [AuthToken],
    sessions: &'a [SessionData],
    kv_entries: &'a [SnapshotKvEntry],
}
110
111fn normalize_json_value(value: Value) -> Value {
112 match value {
113 Value::Array(values) => {
114 let mut normalized = values
115 .into_iter()
116 .map(normalize_json_value)
117 .collect::<Vec<_>>();
118 normalized.sort_by(|left, right| left.to_string().cmp(&right.to_string()));
119 Value::Array(normalized)
120 }
121 Value::Object(object) => {
122 let mut entries = object.into_iter().collect::<Vec<_>>();
123 entries.sort_by(|left, right| left.0.cmp(&right.0));
124
125 let normalized = entries
126 .into_iter()
127 .map(|(key, value)| (key, normalize_json_value(value)))
128 .collect::<Map<String, Value>>();
129 Value::Object(normalized)
130 }
131 other => other,
132 }
133}
134
135fn storage_backend_name(config: &StorageConfig) -> &'static str {
136 match config {
137 StorageConfig::Memory => "memory",
138 #[cfg(feature = "postgres-storage")]
139 StorageConfig::Postgres { .. } => "postgres",
140 #[cfg(feature = "redis-storage")]
141 StorageConfig::Redis { .. } => "redis",
142 #[cfg(feature = "mysql-storage")]
143 StorageConfig::MySQL { .. } => "mysql",
144 #[cfg(feature = "sqlite-storage")]
145 StorageConfig::Sqlite { .. } => "sqlite",
146 StorageConfig::Custom(_) => "custom",
147 }
148}
149
/// Infers a storage backend label from a database URL.
///
/// Matching is case-insensitive and ignores surrounding whitespace:
///
/// * empty → `"memory"`
/// * `postgres://` / `postgresql://` → `"postgres"`
/// * `mysql://` → `"mysql"`
/// * `redis://` / `rediss://` → `"redis"`
/// * `sqlite:` prefix, or a `.db` / `.sqlite` / `.sqlite3` suffix → `"sqlite"`
/// * anything else → `"custom"`
///
/// Explicit schemes are checked before the file-extension heuristic, so a
/// URL such as `redis://host/cache.db` is classified by its scheme.
fn backend_name_from_database_url(database_url: &str) -> &'static str {
    const SQLITE_FILE_SUFFIXES: [&str; 3] = [".db", ".sqlite", ".sqlite3"];

    let url = database_url.trim().to_ascii_lowercase();

    if url.is_empty() {
        "memory"
    } else if url.starts_with("postgres://") || url.starts_with("postgresql://") {
        "postgres"
    } else if url.starts_with("mysql://") {
        "mysql"
    } else if url.starts_with("redis://") || url.starts_with("rediss://") {
        "redis"
    } else if url.starts_with("sqlite:")
        || SQLITE_FILE_SUFFIXES
            .iter()
            .any(|suffix| url.ends_with(suffix))
    {
        "sqlite"
    } else {
        "custom"
    }
}
167
168fn checksum_snapshot(
169 users: &[SnapshotUserSummary],
170 roles: &[Role],
171 tokens: &[AuthToken],
172 sessions: &[SessionData],
173 kv_entries: &[SnapshotKvEntry],
174) -> Result<String> {
175 let payload = SnapshotChecksumPayload {
176 users,
177 roles,
178 tokens,
179 sessions,
180 kv_entries,
181 };
182 let encoded = serde_json::to_value(&payload)
183 .map(normalize_json_value)
184 .and_then(|value| serde_json::to_vec(&value))
185 .map_err(|e| AuthError::internal(format!("Failed to serialize snapshot payload: {e}")))?;
186 let mut hasher = Sha256::new();
187 hasher.update(encoded);
188 Ok(hex::encode(hasher.finalize()))
189}
190
191fn sanitize_migration_name(name: &str) -> Result<String> {
192 let sanitized = name
193 .trim()
194 .chars()
195 .map(|character| {
196 if character.is_ascii_alphanumeric() {
197 character.to_ascii_lowercase()
198 } else {
199 '_'
200 }
201 })
202 .collect::<String>();
203
204 let collapsed = sanitized
205 .split('_')
206 .filter(|segment| !segment.is_empty())
207 .collect::<Vec<_>>()
208 .join("_");
209
210 if collapsed.is_empty() {
211 return Err(AuthError::validation(
212 "Migration name must contain at least one alphanumeric character",
213 ));
214 }
215
216 Ok(collapsed)
217}
218
219async fn collect_snapshot(framework: &AuthFramework) -> Result<MaintenanceSnapshot> {
220 let storage = framework.storage();
221
222 let mut users = framework
223 .users()
224 .list_with_query(UserListQuery::new())
225 .await?;
226 users.sort_by(|left, right| left.id.cmp(&right.id));
227
228 let mut snapshot_users = Vec::with_capacity(users.len());
229 for user in &users {
230 let mut roles: HashSet<String> = user.roles.iter().cloned().collect();
231 roles.extend(framework.authorization().roles_for_user(&user.id).await?);
232 let mut roles = roles.into_iter().collect::<Vec<_>>();
233 roles.sort();
234
235 snapshot_users.push(SnapshotUserSummary {
236 id: user.id.clone(),
237 username: user.username.clone(),
238 email: user.email.clone(),
239 roles,
240 active: user.active,
241 });
242 }
243
244 let mut roles = framework.authorization().list_roles().await;
245 roles.sort_by(|left, right| left.name.cmp(&right.name));
246
247 let mut tokens = Vec::new();
248 let mut seen_tokens = HashSet::new();
249 let mut sessions = Vec::new();
250 let mut seen_sessions = HashSet::new();
251
252 for user in &users {
253 for token in framework.tokens().list_for_user(&user.id).await? {
254 if seen_tokens.insert(token.token_id.clone()) {
255 tokens.push(token);
256 }
257 }
258
259 for session in framework.sessions().list_for_user(&user.id).await? {
260 if seen_sessions.insert(session.session_id.clone()) {
261 sessions.push(session);
262 }
263 }
264 }
265
266 tokens.sort_by(|left, right| left.token_id.cmp(&right.token_id));
267 sessions.sort_by(|left, right| left.session_id.cmp(&right.session_id));
268
269 let mut kv_keys = storage.list_kv_keys("").await?;
270 kv_keys.sort();
271 kv_keys.dedup();
272
273 let mut kv_entries = Vec::with_capacity(kv_keys.len());
274 for key in kv_keys {
275 if let Some(value) = storage.get_kv(&key).await? {
276 kv_entries.push(SnapshotKvEntry {
277 key,
278 value_base64: BASE64_STANDARD.encode(value),
279 });
280 }
281 }
282
283 let manifest = SnapshotManifest {
284 format_version: SNAPSHOT_FORMAT_VERSION,
285 created_at: Utc::now(),
286 storage_backend: storage_backend_name(&framework.config().storage).to_string(),
287 user_count: snapshot_users.len(),
288 role_count: roles.len(),
289 token_count: tokens.len(),
290 session_count: sessions.len(),
291 kv_entry_count: kv_entries.len(),
292 checksum_sha256: checksum_snapshot(
293 &snapshot_users,
294 &roles,
295 &tokens,
296 &sessions,
297 &kv_entries,
298 )?,
299 };
300
301 Ok(MaintenanceSnapshot {
302 manifest,
303 users: snapshot_users,
304 roles,
305 tokens,
306 sessions,
307 kv_entries,
308 })
309}
310
311fn validate_snapshot(snapshot: &MaintenanceSnapshot) -> Result<()> {
312 if snapshot.manifest.format_version != SNAPSHOT_FORMAT_VERSION {
313 return Err(AuthError::configuration(format!(
314 "Unsupported snapshot format version {}",
315 snapshot.manifest.format_version
316 )));
317 }
318
319 let expected_checksum = checksum_snapshot(
320 &snapshot.users,
321 &snapshot.roles,
322 &snapshot.tokens,
323 &snapshot.sessions,
324 &snapshot.kv_entries,
325 )?;
326
327 if expected_checksum != snapshot.manifest.checksum_sha256 {
328 return Err(AuthError::validation(
329 "Snapshot checksum validation failed; restore aborted",
330 ));
331 }
332
333 Ok(())
334}
335
336pub async fn backup_to_file(
337 framework: &AuthFramework,
338 output_path: impl AsRef<Path>,
339 dry_run: bool,
340) -> Result<BackupReport> {
341 let output_path = output_path.as_ref().to_path_buf();
342 let snapshot = collect_snapshot(framework).await?;
343
344 if !dry_run {
345 if let Some(parent) = output_path.parent() {
346 if !parent.as_os_str().is_empty() {
347 tokio::fs::create_dir_all(parent).await?;
348 }
349 }
350
351 let data = serde_json::to_vec_pretty(&snapshot).map_err(|e| {
352 AuthError::internal(format!("Failed to serialize maintenance snapshot: {e}"))
353 })?;
354 tokio::fs::write(&output_path, data).await?;
355 }
356
357 Ok(BackupReport {
358 manifest: snapshot.manifest,
359 output_path,
360 dry_run,
361 })
362}
363
364pub async fn reset_runtime_data(framework: &AuthFramework, dry_run: bool) -> Result<ResetReport> {
365 let storage = framework.storage();
366 let users = framework
367 .users()
368 .list_with_query(UserListQuery::new())
369 .await?;
370 let roles = framework.authorization().list_roles().await;
371
372 let mut token_ids = HashSet::new();
373 let mut session_ids = HashSet::new();
374 for user in &users {
375 for token in framework.tokens().list_for_user(&user.id).await? {
376 token_ids.insert(token.token_id);
377 }
378
379 for session in framework.sessions().list_for_user(&user.id).await? {
380 session_ids.insert(session.session_id);
381 }
382 }
383
384 let mut kv_keys = storage.list_kv_keys("").await?;
385 kv_keys.sort();
386 kv_keys.dedup();
387
388 if !dry_run {
389 for token_id in &token_ids {
390 storage.delete_token(token_id).await?;
391 }
392
393 for session_id in &session_ids {
394 storage.delete_session(session_id).await?;
395 }
396
397 for user in &users {
398 framework.users().delete_by_id(&user.id).await?;
399 }
400
401 for key in &kv_keys {
402 storage.delete_kv(key).await?;
403 }
404
405 framework.reset_authorization_runtime().await;
406 }
407
408 Ok(ResetReport {
409 users_deleted: users.len(),
410 roles_seen: roles.len(),
411 tokens_deleted: token_ids.len(),
412 sessions_deleted: session_ids.len(),
413 kv_entries_deleted: kv_keys.len(),
414 dry_run,
415 })
416}
417
418pub async fn restore_from_file(
419 framework: &AuthFramework,
420 input_path: impl AsRef<Path>,
421 dry_run: bool,
422) -> Result<RestoreReport> {
423 let input_path = input_path.as_ref().to_path_buf();
424 let data = tokio::fs::read(&input_path).await?;
425 let snapshot: MaintenanceSnapshot = serde_json::from_slice(&data)
426 .map_err(|e| AuthError::validation(format!("Failed to parse maintenance snapshot: {e}")))?;
427 validate_snapshot(&snapshot)?;
428
429 let reset_report = reset_runtime_data(framework, dry_run).await?;
430
431 if !dry_run {
432 let storage = framework.storage();
433
434 for entry in &snapshot.kv_entries {
435 let value = BASE64_STANDARD.decode(&entry.value_base64).map_err(|e| {
436 AuthError::validation(format!(
437 "Snapshot KV entry '{}' is not valid base64: {e}",
438 entry.key
439 ))
440 })?;
441 storage.store_kv(&entry.key, &value, None).await?;
442 }
443
444 for token in &snapshot.tokens {
445 storage.store_token(token).await?;
446 }
447
448 for session in &snapshot.sessions {
449 storage.store_session(&session.session_id, session).await?;
450 }
451
452 framework.reset_authorization_runtime().await;
453 for role in &snapshot.roles {
454 framework.authorization().create_role(role.clone()).await?;
455 }
456 for user in &snapshot.users {
457 for role_name in &user.roles {
458 framework
459 .authorization()
460 .assign_role(&user.id, role_name)
461 .await?;
462 }
463 }
464 }
465
466 Ok(RestoreReport {
467 manifest: snapshot.manifest,
468 input_path,
469 reset_report,
470 dry_run,
471 })
472}
473
474fn build_migration_template(backend: &str, migration_name: &str, original_name: &str) -> String {
475 format!(
476 "-- AuthFramework migration template\n-- Backend: {backend}\n-- Name: {original_name}\n-- Generated at: {}\n\n-- Replace this placeholder with idempotent DDL for {migration_name}.\n-- Prefer CREATE TABLE IF NOT EXISTS / CREATE INDEX IF NOT EXISTS where supported.\n\nBEGIN;\n\n-- Add migration SQL here\n\nCOMMIT;\n",
477 Utc::now().to_rfc3339(),
478 )
479}
480
481pub async fn create_migration_file(config: &AppConfig, name: &str) -> Result<MigrationFileReport> {
482 let backend = backend_name_from_database_url(&config.database.url).to_string();
483 create_migration_template_for_backend(&backend, name).await
484}
485
486pub async fn create_migration_file_for_storage(
487 storage: &StorageConfig,
488 name: &str,
489) -> Result<MigrationFileReport> {
490 let backend = storage_backend_name(storage).to_string();
491 create_migration_template_for_backend(&backend, name).await
492}
493
494async fn create_migration_template_for_backend(
495 backend: &str,
496 name: &str,
497) -> Result<MigrationFileReport> {
498 let sanitized_name = sanitize_migration_name(name)?;
499 let directory = PathBuf::from("migrations").join(&backend);
500 tokio::fs::create_dir_all(&directory).await?;
501
502 let file_name = format!(
503 "{}_{}.sql",
504 Utc::now().format("%Y%m%d%H%M%S"),
505 sanitized_name
506 );
507 let path = directory.join(file_name);
508 let template = build_migration_template(backend, &sanitized_name, name);
509 tokio::fs::write(&path, template).await?;
510
511 Ok(MigrationFileReport {
512 backend: backend.to_string(),
513 path,
514 })
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AuthConfig;
    use crate::methods::{AuthMethodEnum, JwtMethod};
    use std::time::Duration;
    use tempfile::tempdir;

    /// Builds an initialized framework with a fixed test secret, a one-hour
    /// token lifetime, and the JWT method registered under `"jwt"`.
    async fn create_framework() -> AuthFramework {
        let config = AuthConfig::new()
            .secret("0123456789abcdef0123456789abcdef")
            .token_lifetime(Duration::from_secs(3600));
        let mut framework = AuthFramework::new(config);
        framework.register_method("jwt", AuthMethodEnum::Jwt(JwtMethod::new()));
        framework.initialize().await.unwrap();
        framework
    }

    /// Restores the process working directory when dropped. A failing
    /// assertion then cannot leave the rest of the test binary running inside
    /// a temporary directory that is about to be deleted.
    struct CurrentDirGuard(PathBuf);

    impl Drop for CurrentDirGuard {
        fn drop(&mut self) {
            // Never panic in Drop: it may run while the test is already unwinding.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    #[tokio::test]
    async fn backup_restore_roundtrip_preserves_runtime_state() {
        let framework = create_framework().await;
        let user_id = framework
            .users()
            .register("alice", "alice@example.com", "Password123!")
            .await
            .unwrap();
        framework
            .authorization()
            .create_role(Role::new("auditor"))
            .await
            .unwrap();
        framework
            .authorization()
            .assign_role(&user_id, "auditor")
            .await
            .unwrap();
        framework
            .tokens()
            .create(&user_id, &["read"], "jwt", None)
            .await
            .unwrap();
        framework
            .sessions()
            .create(
                &user_id,
                Duration::from_secs(900),
                Some("127.0.0.1".into()),
                None,
            )
            .await
            .unwrap();
        framework
            .storage()
            .store_kv("custom:test", b"value", None)
            .await
            .unwrap();

        let dir = tempdir().unwrap();
        let path = dir.path().join("snapshot.json");

        backup_to_file(&framework, &path, false).await.unwrap();
        reset_runtime_data(&framework, false).await.unwrap();
        assert!(
            framework
                .users()
                .list_with_query(UserListQuery::new())
                .await
                .unwrap()
                .is_empty()
        );

        restore_from_file(&framework, &path, false).await.unwrap();

        let restored_user = framework.users().get(&user_id).await.unwrap();
        assert_eq!(restored_user.username, "alice");
        assert!(
            framework
                .authorization()
                .has_role(&user_id, "auditor")
                .await
                .unwrap()
        );
        assert_eq!(
            framework
                .tokens()
                .list_for_user(&user_id)
                .await
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            framework
                .sessions()
                .list_for_user(&user_id)
                .await
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            framework
                .storage()
                .get_kv("custom:test")
                .await
                .unwrap()
                .unwrap(),
            b"value"
        );
    }

    #[tokio::test]
    async fn reset_dry_run_leaves_state_unchanged() {
        let framework = create_framework().await;
        let user_id = framework
            .users()
            .register("bob", "bob@example.com", "Password123!")
            .await
            .unwrap();
        framework
            .storage()
            .store_kv("custom:dry-run", b"present", None)
            .await
            .unwrap();

        let report = reset_runtime_data(&framework, true).await.unwrap();
        assert!(report.dry_run);
        assert_eq!(
            framework.users().get(&user_id).await.unwrap().username,
            "bob"
        );
        assert!(
            framework
                .storage()
                .get_kv("custom:dry-run")
                .await
                .unwrap()
                .is_some()
        );
    }

    #[tokio::test]
    async fn create_migration_file_uses_backend_directory_and_sanitized_name() {
        let dir = tempdir().unwrap();
        // Declared after `dir`, so it drops first: the working directory is
        // restored before the temporary directory is deleted, even on panic.
        // NOTE(review): the working directory is process-wide; other tests
        // relying on relative paths could observe this change while it lasts.
        let _cwd = CurrentDirGuard(std::env::current_dir().unwrap());
        std::env::set_current_dir(dir.path()).unwrap();

        let mut config = AppConfig::default();
        config.database.url = "sqlite::memory:".to_string();
        let report = create_migration_file(&config, "Add Audit Table!")
            .await
            .unwrap();
        assert_eq!(report.backend, "sqlite");
        assert!(
            report
                .path
                .starts_with(Path::new("migrations").join("sqlite"))
        );
        assert!(
            report
                .path
                .file_name()
                .unwrap()
                .to_string_lossy()
                .contains("add_audit_table")
        );
    }

    #[test]
    fn sanitize_migration_name_collapses_separators() {
        assert_eq!(
            sanitize_migration_name("  Add  Audit--Table!! ").unwrap(),
            "add_audit_table"
        );
        assert!(sanitize_migration_name("!!!").is_err());
        assert!(sanitize_migration_name("").is_err());
    }

    #[test]
    fn backend_name_detection_for_common_urls() {
        assert_eq!(
            backend_name_from_database_url("postgresql://localhost/db"),
            "postgres"
        );
        assert_eq!(backend_name_from_database_url("mysql://localhost/db"), "mysql");
        assert_eq!(backend_name_from_database_url("sqlite::memory:"), "sqlite");
        assert_eq!(backend_name_from_database_url("redis://localhost"), "redis");
        assert_eq!(backend_name_from_database_url("   "), "memory");
        assert_eq!(
            backend_name_from_database_url("mongodb://localhost"),
            "custom"
        );
    }

    #[tokio::test]
    async fn backup_dry_run_does_not_write_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("shouldnt_exist.json");
        let framework = create_framework().await;
        let report = backup_to_file(&framework, &path, true).await.unwrap();
        assert!(report.dry_run);
        assert!(!path.exists());
    }

    #[tokio::test]
    async fn backup_empty_framework() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("empty.json");
        let framework = create_framework().await;
        let report = backup_to_file(&framework, &path, false).await.unwrap();
        assert_eq!(report.manifest.user_count, 0);
        assert_eq!(report.manifest.token_count, 0);
        assert_eq!(report.manifest.session_count, 0);
        assert!(path.exists());
    }

    #[tokio::test]
    async fn reset_clears_all_data() {
        let framework = create_framework().await;
        framework
            .users()
            .register("clear_me", "clear@example.com", "Password123!")
            .await
            .unwrap();
        framework
            .storage()
            .store_kv("custom:keep", b"nope", None)
            .await
            .unwrap();

        let report = reset_runtime_data(&framework, false).await.unwrap();
        assert!(!report.dry_run);
        assert!(report.users_deleted >= 1);
        assert!(
            framework
                .users()
                .list_with_query(UserListQuery::new())
                .await
                .unwrap()
                .is_empty()
        );
        assert!(
            framework
                .storage()
                .get_kv("custom:keep")
                .await
                .unwrap()
                .is_none()
        );
    }

    #[tokio::test]
    async fn restore_nonexistent_file_fails() {
        let framework = create_framework().await;
        let result = restore_from_file(&framework, "/definitely/not/real.json", false).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn backup_manifest_has_checksum() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("checksummed.json");
        let framework = create_framework().await;
        framework
            .users()
            .register("chk_user", "chk@example.com", "Password123!")
            .await
            .unwrap();
        let report = backup_to_file(&framework, &path, false).await.unwrap();
        assert!(!report.manifest.checksum_sha256.is_empty());
        assert_eq!(report.manifest.format_version, SNAPSHOT_FORMAT_VERSION);
    }
}