Skip to main content

reinhardt_db/migrations/
source.rs

1//! Migration source abstraction
2//!
3//! This module defines the `MigrationSource` trait, which abstracts where migrations come from.
4//! Multiple sources can be combined using the Composite pattern.
5
6pub mod composite;
7pub mod filesystem;
8pub mod registry;
9
10use super::{Migration, MigrationError, Result};
11use async_trait::async_trait;
12
13/// Trait for loading migrations from various sources
14///
15/// Implementations:
16/// - `RegistrySource`: Loads from compile-time registered migrations (linkme)
17/// - `FilesystemSource`: Loads from .rs files on disk
18/// - `CompositeSource`: Combines multiple sources
19/// - `TestMigrationSource`: In-memory source for testing
20#[async_trait]
21pub trait MigrationSource: Send + Sync {
22	/// Returns all migrations from this source
23	async fn all_migrations(&self) -> Result<Vec<Migration>>;
24
25	/// Returns migrations for a specific app
26	async fn migrations_for_app(&self, app_label: &str) -> Result<Vec<Migration>> {
27		let all = self.all_migrations().await?;
28		Ok(all
29			.into_iter()
30			.filter(|m| m.app_label == app_label)
31			.collect())
32	}
33
34	/// Returns a specific migration by app and name
35	async fn get_migration(&self, app_label: &str, name: &str) -> Result<Migration> {
36		let migrations = self.migrations_for_app(app_label).await?;
37		migrations
38			.into_iter()
39			.find(|m| m.name == name)
40			.ok_or_else(|| MigrationError::NotFound(format!("{}.{}", app_label, name)))
41	}
42}
43
44#[cfg(test)]
45mod tests {
46	use super::*;
47
48	/// Test helper to create a migration
49	fn create_test_migration(app_label: &str, name: &str) -> Migration {
50		Migration {
51			app_label: app_label.to_string(),
52			name: name.to_string(),
53			operations: vec![],
54			dependencies: vec![],
55			atomic: true,
56			initial: None,
57			replaces: vec![],
58			state_only: false,
59			database_only: false,
60			swappable_dependencies: vec![],
61			optional_dependencies: vec![],
62		}
63	}
64
65	/// Test MigrationSource implementation for unit tests
66	struct TestSource {
67		migrations: Vec<Migration>,
68	}
69
70	#[async_trait]
71	impl MigrationSource for TestSource {
72		async fn all_migrations(&self) -> Result<Vec<Migration>> {
73			Ok(self.migrations.clone())
74		}
75	}
76
77	#[tokio::test]
78	async fn test_all_migrations() {
79		let source = TestSource {
80			migrations: vec![
81				create_test_migration("polls", "0001_initial"),
82				create_test_migration("polls", "0002_add_field"),
83			],
84		};
85
86		let all = source.all_migrations().await.unwrap();
87		assert_eq!(all.len(), 2);
88		assert_eq!(all[0].app_label, "polls");
89		assert_eq!(all[0].name, "0001_initial");
90	}
91
92	#[tokio::test]
93	async fn test_migrations_for_app() {
94		let source = TestSource {
95			migrations: vec![
96				create_test_migration("polls", "0001_initial"),
97				create_test_migration("users", "0001_initial"),
98				create_test_migration("polls", "0002_add_field"),
99			],
100		};
101
102		let polls_migrations = source.migrations_for_app("polls").await.unwrap();
103		assert_eq!(polls_migrations.len(), 2);
104		assert!(polls_migrations.iter().all(|m| m.app_label == "polls"));
105
106		let users_migrations = source.migrations_for_app("users").await.unwrap();
107		assert_eq!(users_migrations.len(), 1);
108		assert_eq!(users_migrations[0].name, "0001_initial");
109	}
110
111	#[tokio::test]
112	async fn test_get_migration() {
113		let source = TestSource {
114			migrations: vec![
115				create_test_migration("polls", "0001_initial"),
116				create_test_migration("polls", "0002_add_field"),
117			],
118		};
119
120		let migration = source.get_migration("polls", "0001_initial").await.unwrap();
121		assert_eq!(migration.app_label, "polls");
122		assert_eq!(migration.name, "0001_initial");
123	}
124
125	#[tokio::test]
126	async fn test_get_migration_not_found() {
127		let source = TestSource {
128			migrations: vec![create_test_migration("polls", "0001_initial")],
129		};
130
131		let result = source.get_migration("polls", "0002_nonexistent").await;
132		assert!(result.is_err());
133		assert!(matches!(result.unwrap_err(), MigrationError::NotFound(_)));
134	}
135}