use super::{Migration, MigrationError, MigrationRepository, MigrationSource, Result};
use std::sync::Arc;
/// Coordinates loading migration definitions from a [`MigrationSource`] and
/// persisting them through a [`MigrationRepository`].
pub struct MigrationService {
    /// Provider that migration definitions are loaded from.
    source: Arc<dyn MigrationSource>,
    /// Store for saved migrations. It is wrapped in an async mutex because the
    /// repository's `save`/`delete` calls need exclusive (`&mut`) access and the
    /// guard is held across `.await`.
    repository: Arc<tokio::sync::Mutex<dyn MigrationRepository>>,
}

impl MigrationService {
    /// Creates a service backed by the given source and repository.
    pub fn new(
        source: Arc<dyn MigrationSource>,
        repository: Arc<tokio::sync::Mutex<dyn MigrationRepository>>,
    ) -> Self {
        Self { source, repository }
    }

    /// Loads every migration the source knows about.
    ///
    /// # Errors
    /// Propagates any error returned by the source.
    pub async fn load_all(&self) -> Result<Vec<Migration>> {
        self.source.all_migrations().await
    }

    /// Loads the migrations the source reports for `app_label`.
    ///
    /// # Errors
    /// Propagates any error returned by the source.
    pub async fn load_for_app(&self, app_label: &str) -> Result<Vec<Migration>> {
        self.source.migrations_for_app(app_label).await
    }

    /// Loads a single migration identified by `app_label` and `name`.
    ///
    /// # Errors
    /// Propagates any error returned by the source. Whether a missing migration
    /// is reported as an error depends on the source implementation.
    pub async fn load_migration(&self, app_label: &str, name: &str) -> Result<Migration> {
        self.source.get_migration(app_label, name).await
    }

    /// Persists `migration` in the repository.
    ///
    /// # Errors
    /// Propagates any error returned by the repository.
    pub async fn save_migration(&self, migration: &Migration) -> Result<()> {
        let mut repo = self.repository.lock().await;
        repo.save(migration).await
    }

    /// Reports whether the repository holds a migration with this app label and name.
    ///
    /// # Errors
    /// Propagates any error returned by the repository.
    pub async fn migration_exists(&self, app_label: &str, name: &str) -> Result<bool> {
        let repo = self.repository.lock().await;
        repo.exists(app_label, name).await
    }

    /// Lists the migrations saved in the repository for `app_label`.
    ///
    /// # Errors
    /// Propagates any error returned by the repository.
    pub async fn list_saved_migrations(&self, app_label: &str) -> Result<Vec<Migration>> {
        let repo = self.repository.lock().await;
        repo.list(app_label).await
    }

    /// Removes the saved migration identified by `app_label` and `name`.
    ///
    /// # Errors
    /// Propagates any error returned by the repository. Whether deleting an
    /// absent migration is an error depends on the repository implementation.
    pub async fn delete_migration(&self, app_label: &str, name: &str) -> Result<()> {
        let mut repo = self.repository.lock().await;
        repo.delete(app_label, name).await
    }

    /// Returns every source migration, ordered so that each one comes after all
    /// of its dependencies that are also provided by the source.
    ///
    /// Ties between migrations that are ready at the same time are broken by
    /// their position in the source (breadth-first). The result is therefore
    /// deterministic for a given source order.
    ///
    /// NOTE(review): a dependency naming a migration the source does not provide
    /// is treated as already satisfied. If such a dependency should be a hard
    /// error instead, surface it in `topological_order`.
    ///
    /// # Errors
    /// - Propagates any error from [`Self::load_all`].
    /// - Returns [`MigrationError::CircularDependency`] when the migrations
    ///   cannot be ordered because of a cycle (including a migration that
    ///   depends on itself). The message lists the migrations that could not
    ///   be placed.
    pub async fn build_dependency_graph(&self) -> Result<Vec<Migration>> {
        let migrations = self.load_all().await?;
        let order = Self::topological_order(&migrations)?;
        // Each index appears exactly once in `order`, so every slot is taken
        // exactly once; this moves migrations out instead of cloning them.
        let mut slots: Vec<Option<Migration>> = migrations.into_iter().map(Some).collect();
        Ok(order.into_iter().filter_map(|i| slots[i].take()).collect())
    }

    /// Computes a dependency-respecting order over `migrations` with Kahn's
    /// algorithm and returns the result as indices into `migrations`.
    fn topological_order(migrations: &[Migration]) -> Result<Vec<usize>> {
        // (app_label, name) -> position, so dependency lookups are O(1).
        let positions: std::collections::HashMap<(&str, &str), usize> = migrations
            .iter()
            .enumerate()
            .map(|(i, m)| ((m.app_label.as_str(), m.name.as_str()), i))
            .collect();

        // dependents[i]: migrations that must come after migrations[i].
        let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); migrations.len()];
        // pending[i]: in-set dependencies of migrations[i] not yet placed.
        let mut pending: Vec<usize> = vec![0; migrations.len()];
        for (i, migration) in migrations.iter().enumerate() {
            for (dep_app, dep_name) in &migration.dependencies {
                // Only dependencies present in this set constrain the order.
                // Counting unknown ones would block `i` forever and be
                // misreported as a cycle.
                if let Some(&dep) = positions.get(&(dep_app.as_str(), dep_name.as_str())) {
                    dependents[dep].push(i);
                    pending[i] += 1;
                }
            }
        }

        // Seed in source order and drain FIFO so the output is deterministic.
        let mut ready: std::collections::VecDeque<usize> =
            (0..migrations.len()).filter(|&i| pending[i] == 0).collect();
        let mut order = Vec::with_capacity(migrations.len());
        while let Some(current) = ready.pop_front() {
            order.push(current);
            for &next in &dependents[current] {
                pending[next] -= 1;
                if pending[next] == 0 {
                    ready.push_back(next);
                }
            }
        }

        if order.len() != migrations.len() {
            // Anything still pending is on a cycle or blocked behind one.
            let unresolved: Vec<String> = migrations
                .iter()
                .zip(&pending)
                .filter(|&(_, &count)| count > 0)
                .map(|(m, _)| format!("{}.{}", m.app_label, m.name))
                .collect();
            return Err(MigrationError::CircularDependency {
                cycle: format!(
                    "Circular dependency detected in migrations: {}",
                    unresolved.join(", ")
                ),
            });
        }
        Ok(order)
    }

    /// Returns the source migrations for `app_label` whose name is not among the
    /// migrations already saved in the repository for that app.
    ///
    /// # Errors
    /// Propagates errors from [`Self::load_for_app`] and
    /// [`Self::list_saved_migrations`].
    pub async fn detect_new_migrations(&self, app_label: &str) -> Result<Vec<Migration>> {
        let source_migrations = self.load_for_app(app_label).await?;
        let saved_migrations = self.list_saved_migrations(app_label).await?;
        // Both lists are already scoped to `app_label`, so comparing names suffices.
        let saved_names: std::collections::HashSet<&str> =
            saved_migrations.iter().map(|m| m.name.as_str()).collect();
        Ok(source_migrations
            .into_iter()
            .filter(|m| !saved_names.contains(m.name.as_str()))
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::source::MigrationSource;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory source serving a fixed list of migrations. Only
    /// `all_migrations` is implemented; the other trait methods use their
    /// default implementations.
    struct TestSource {
        migrations: Vec<Migration>,
    }

    #[async_trait]
    impl MigrationSource for TestSource {
        async fn all_migrations(&self) -> Result<Vec<Migration>> {
            Ok(self.migrations.clone())
        }
    }

    /// In-memory repository keyed by `(app_label, name)`.
    struct TestRepository {
        migrations: HashMap<(String, String), Migration>,
    }

    impl TestRepository {
        fn new() -> Self {
            Self {
                migrations: HashMap::new(),
            }
        }
    }

    #[async_trait]
    impl MigrationRepository for TestRepository {
        async fn save(&mut self, migration: &Migration) -> Result<()> {
            let key = (migration.app_label.to_string(), migration.name.to_string());
            self.migrations.insert(key, migration.clone());
            Ok(())
        }

        async fn get(&self, app_label: &str, name: &str) -> Result<Migration> {
            let key = (app_label.to_string(), name.to_string());
            self.migrations
                .get(&key)
                .cloned()
                .ok_or_else(|| MigrationError::NotFound(format!("{}.{}", app_label, name)))
        }

        async fn list(&self, app_label: &str) -> Result<Vec<Migration>> {
            Ok(self
                .migrations
                .values()
                .filter(|m| m.app_label == app_label)
                .cloned()
                .collect())
        }

        async fn exists(&self, app_label: &str, name: &str) -> Result<bool> {
            let key = (app_label.to_string(), name.to_string());
            Ok(self.migrations.contains_key(&key))
        }

        async fn delete(&mut self, app_label: &str, name: &str) -> Result<()> {
            let key = (app_label.to_string(), name.to_string());
            self.migrations
                .remove(&key)
                .ok_or_else(|| MigrationError::NotFound(format!("{}.{}", app_label, name)))?;
            Ok(())
        }
    }

    /// Builds a migration with no operations and no dependencies.
    fn create_test_migration(app_label: &str, name: &str) -> Migration {
        Migration {
            app_label: app_label.to_string(),
            name: name.to_string(),
            operations: vec![],
            dependencies: vec![],
            atomic: true,
            initial: None,
            replaces: vec![],
            state_only: false,
            database_only: false,
            swappable_dependencies: vec![],
            optional_dependencies: vec![],
        }
    }

    /// Builds a migration that depends on each `(app_label, name)` in `deps`.
    fn create_test_migration_with_deps(
        app_label: &str,
        name: &str,
        deps: &[(&str, &str)],
    ) -> Migration {
        Migration {
            dependencies: deps
                .iter()
                .map(|&(app, dep)| (app.to_string(), dep.to_string()))
                .collect(),
            ..create_test_migration(app_label, name)
        }
    }

    /// Builds a service over `migrations` with an empty in-memory repository.
    fn make_service(migrations: Vec<Migration>) -> MigrationService {
        let source = Arc::new(TestSource { migrations });
        let repository = Arc::new(Mutex::new(TestRepository::new()));
        MigrationService::new(source, repository)
    }

    #[tokio::test]
    async fn test_migration_service_load_all() {
        let service = make_service(vec![
            create_test_migration("polls", "0001_initial"),
            create_test_migration("users", "0001_initial"),
        ]);
        let migrations = service.load_all().await.unwrap();
        assert_eq!(migrations.len(), 2);
    }

    #[tokio::test]
    async fn test_migration_service_load_for_app() {
        let service = make_service(vec![
            create_test_migration("polls", "0001_initial"),
            create_test_migration("polls", "0002_add_field"),
            create_test_migration("users", "0001_initial"),
        ]);
        let polls_migrations = service.load_for_app("polls").await.unwrap();
        assert_eq!(polls_migrations.len(), 2);
    }

    #[tokio::test]
    async fn test_migration_service_save_and_load() {
        let service = make_service(vec![create_test_migration("polls", "0001_initial")]);
        let migration = create_test_migration("polls", "0001_initial");
        service.save_migration(&migration).await.unwrap();
        assert!(service
            .migration_exists("polls", "0001_initial")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_migration_service_dependency_graph() {
        let service = make_service(vec![
            create_test_migration("polls", "0001_initial"),
            create_test_migration_with_deps(
                "polls",
                "0002_add_field",
                &[("polls", "0001_initial")],
            ),
        ]);
        let sorted = service.build_dependency_graph().await.unwrap();
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].name, "0001_initial");
        assert_eq!(sorted[1].name, "0002_add_field");
    }

    #[tokio::test]
    async fn test_migration_service_dependency_graph_across_apps() {
        // Listed in reverse so the sort must reorder them.
        let service = make_service(vec![
            create_test_migration_with_deps(
                "polls",
                "0002_add_field",
                &[("polls", "0001_initial")],
            ),
            create_test_migration_with_deps(
                "polls",
                "0001_initial",
                &[("users", "0001_initial")],
            ),
            create_test_migration("users", "0001_initial"),
        ]);
        let sorted = service.build_dependency_graph().await.unwrap();
        let keys: Vec<(&str, &str)> = sorted
            .iter()
            .map(|m| (m.app_label.as_str(), m.name.as_str()))
            .collect();
        assert_eq!(
            keys,
            vec![
                ("users", "0001_initial"),
                ("polls", "0001_initial"),
                ("polls", "0002_add_field"),
            ]
        );
    }

    #[tokio::test]
    async fn test_migration_service_dependency_graph_detects_cycle() {
        let service = make_service(vec![
            create_test_migration_with_deps(
                "polls",
                "0001_initial",
                &[("polls", "0002_add_field")],
            ),
            create_test_migration_with_deps(
                "polls",
                "0002_add_field",
                &[("polls", "0001_initial")],
            ),
        ]);
        let result = service.build_dependency_graph().await;
        assert!(matches!(
            result,
            Err(MigrationError::CircularDependency { .. })
        ));
    }

    #[tokio::test]
    async fn test_migration_service_dependency_graph_detects_self_dependency() {
        let service = make_service(vec![create_test_migration_with_deps(
            "polls",
            "0001_initial",
            &[("polls", "0001_initial")],
        )]);
        let result = service.build_dependency_graph().await;
        assert!(matches!(
            result,
            Err(MigrationError::CircularDependency { .. })
        ));
    }

    #[tokio::test]
    async fn test_migration_service_detect_new_migrations() {
        let service = make_service(vec![
            create_test_migration("polls", "0001_initial"),
            create_test_migration("polls", "0002_add_field"),
        ]);
        service
            .save_migration(&create_test_migration("polls", "0001_initial"))
            .await
            .unwrap();
        let new_migrations = service.detect_new_migrations("polls").await.unwrap();
        assert_eq!(new_migrations.len(), 1);
        assert_eq!(new_migrations[0].name, "0002_add_field");
    }

    #[tokio::test]
    async fn test_migration_service_delete() {
        let service = make_service(vec![create_test_migration("polls", "0001_initial")]);
        let migration = create_test_migration("polls", "0001_initial");
        service.save_migration(&migration).await.unwrap();
        assert!(service
            .migration_exists("polls", "0001_initial")
            .await
            .unwrap());
        service
            .delete_migration("polls", "0001_initial")
            .await
            .unwrap();
        assert!(!service
            .migration_exists("polls", "0001_initial")
            .await
            .unwrap());
    }

    #[tokio::test]
    async fn test_migration_service_delete_missing_propagates_error() {
        let service = make_service(vec![]);
        assert!(service
            .delete_migration("polls", "0001_initial")
            .await
            .is_err());
    }
}