use crate::{Result, Error};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use uuid::Uuid;
/// Descriptive metadata for a single migration, as returned by `Migration::meta`.
///
/// Derives serde `Serialize`/`Deserialize`; field order here is the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationMeta {
/// Key under which `MigrationRegistry` stores the migration; also copied into
/// `MigrationResult::migration_id` by the runner.
pub id: Uuid,
/// Ordering key: the runner applies migrations in ascending version order and persists the
/// last applied version. The runner treats 0 as "nothing applied yet", so a migration with
/// version 0 is never run.
pub version: u64,
/// Human-readable name; copied into `MigrationResult::migration_name`.
pub name: String,
/// Free-form description; not read by the registry or runner.
pub description: String,
/// Creation timestamp in UTC.
pub created_at: chrono::DateTime<chrono::Utc>,
/// Presumably the names of the `MigrationContext::backends` entries this migration uses —
/// NOTE(review): not consulted by the registry or runner, confirm intended use.
pub backends: Vec<String>,
}
/// A reversible migration step.
///
/// The trait is declared through `#[async_trait]`, so implementations must also annotate their
/// `impl` block with `#[async_trait]`. Implementors must be `Send + Sync`.
#[async_trait]
pub trait Migration: Send + Sync {
/// Returns this migration's metadata.
///
/// Called more than once: the registry reads it on registration and the runner reads it again
/// per step. It should return the same `id` and `version` on every call.
fn meta(&self) -> MigrationMeta;
/// Applies the migration. An `Err` is recorded in the runner's `MigrationResult` and stops
/// `MigrationRunner::migrate_up` before the recorded version is advanced.
async fn up(&self, ctx: &mut MigrationContext) -> Result<()>;
/// Reverts the migration. An `Err` is recorded in the runner's `MigrationResult` and stops
/// `MigrationRunner::migrate_down` before the recorded version is lowered.
async fn down(&self, ctx: &mut MigrationContext) -> Result<()>;
}
/// Context handed to `Migration::up` / `Migration::down`.
pub struct MigrationContext {
/// Type-erased backend handles keyed by name; read them back with `get_backend` /
/// `get_backend_mut`, which downcast to the exact stored type.
pub backends: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
/// The runner's work directory (the same directory that holds its `current_version` file).
pub work_dir: PathBuf,
/// Free-form JSON values for use by migrations; this module never reads them.
pub data: HashMap<String, serde_json::Value>,
}
impl MigrationContext {
pub fn new(work_dir: PathBuf) -> Self {
Self {
backends: HashMap::new(),
work_dir,
data: HashMap::new(),
}
}
pub fn with_backend<T: 'static + Send + Sync>(
mut self,
name: &str,
backend: T,
) -> Self {
self.backends.insert(name.to_string(), Box::new(backend));
self
}
pub fn get_backend<T: 'static>(&self, name: &str) -> Result<&T> {
self.backends
.get(name)
.and_then(|b| b.downcast_ref::<T>())
.ok_or_else(|| Error::not_found(format!("Backend '{}' not found or wrong type", name)))
}
pub fn get_backend_mut<T: 'static>(&mut self, name: &str) -> Result<&mut T> {
self.backends
.get_mut(name)
.and_then(|b| b.downcast_mut::<T>())
.ok_or_else(|| Error::not_found(format!("Backend '{}' not found or wrong type", name)))
}
}
/// Collection of migrations, indexed both by id and by version.
pub struct MigrationRegistry {
/// Migrations keyed by the `id` their `meta()` reported at registration time.
migrations: HashMap<Uuid, Box<dyn Migration>>,
/// Version -> id index; `register` refuses a version that is already present.
versions: HashMap<u64, Uuid>,
}
impl MigrationRegistry {
pub fn new() -> Self {
Self {
migrations: HashMap::new(),
versions: HashMap::new(),
}
}
pub fn register<M: Migration + 'static>(&mut self, migration: M) -> Result<()> {
let meta = migration.meta();
let id = meta.id;
let version = meta.version;
if self.versions.contains_key(&version) {
return Err(Error::invalid_input(format!(
"Migration version {} already exists",
version
)));
}
self.migrations.insert(id, Box::new(migration));
self.versions.insert(version, id);
Ok(())
}
pub fn get(&self, id: &Uuid) -> Option<&dyn Migration> {
self.migrations.get(id).map(|m| m.as_ref())
}
pub fn all_sorted(&self) -> Vec<&dyn Migration> {
let mut migrations: Vec<&dyn Migration> = self.migrations.values().map(|m| m.as_ref()).collect();
migrations.sort_by_key(|m| m.meta().version);
migrations
}
pub fn between(&self, from_version: u64, to_version: u64) -> Vec<&dyn Migration> {
self.all_sorted()
.into_iter()
.filter(|m| {
let v = m.meta().version;
v >= from_version && v <= to_version
})
.collect()
}
}
impl Default for MigrationRegistry {
fn default() -> Self {
Self::new()
}
}
/// Applies and reverts migrations from a `MigrationRegistry`, persisting the last applied
/// version in a `current_version` file inside `work_dir`.
pub struct MigrationRunner {
/// Source of the migrations to run.
registry: MigrationRegistry,
/// Directory holding the `current_version` file; also given to each `MigrationContext`.
work_dir: PathBuf,
}
impl MigrationRunner {
pub fn new(registry: MigrationRegistry, work_dir: PathBuf) -> Self {
Self { registry, work_dir }
}
pub async fn migrate_up(
&self,
target_version: Option<u64>,
backends: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
) -> Result<Vec<MigrationResult>> {
let current_version = self.get_current_version().await?;
let target_version = target_version.unwrap_or(u64::MAX);
let migrations = self.registry.between(current_version + 1, target_version);
let mut results = Vec::new();
for migration in migrations {
let mut ctx = MigrationContext::new(self.work_dir.clone());
for (name, backend) in &backends {
ctx.backends.insert(name.clone(), backend.clone());
}
let start_time = std::time::Instant::now();
let result = migration.up(&mut ctx).await;
let duration = start_time.elapsed();
let migration_result = MigrationResult {
migration_id: migration.meta().id,
migration_name: migration.meta().name.clone(),
direction: MigrationDirection::Up,
success: result.is_ok(),
error: result.err(),
duration,
};
results.push(migration_result.clone());
if !migration_result.success {
break; }
self.set_current_version(migration.meta().version).await?;
}
Ok(results)
}
pub async fn migrate_down(
&self,
target_version: u64,
backends: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
) -> Result<Vec<MigrationResult>> {
let current_version = self.get_current_version().await?;
if current_version <= target_version {
return Ok(Vec::new());
}
let migrations = self.registry.between(target_version + 1, current_version);
let mut results = Vec::new();
for migration in migrations.into_iter().rev() {
let mut ctx = MigrationContext::new(self.work_dir.clone());
for (name, backend) in &backends {
ctx.backends.insert(name.clone(), backend.clone());
}
let start_time = std::time::Instant::now();
let result = migration.down(&mut ctx).await;
let duration = start_time.elapsed();
let migration_result = MigrationResult {
migration_id: migration.meta().id,
migration_name: migration.meta().name.clone(),
direction: MigrationDirection::Down,
success: result.is_ok(),
error: result.err(),
duration,
};
results.push(migration_result.clone());
if !migration_result.success {
break; }
self.set_current_version(migration.meta().version - 1).await?;
}
Ok(results)
}
async fn get_current_version(&self) -> Result<u64> {
let version_file = self.work_dir.join("current_version");
if !version_file.exists() {
return Ok(0);
}
let content = tokio::fs::read_to_string(&version_file).await
.map_err(|e| Error::io(format!("Failed to read version file: {}", e)))?;
content.trim().parse::<u64>()
.map_err(|e| Error::parse(format!("Invalid version format: {}", e)))
}
async fn set_current_version(&self, version: u64) -> Result<()> {
tokio::fs::create_dir_all(&self.work_dir).await
.map_err(|e| Error::io(format!("Failed to create work directory: {}", e)))?;
let version_file = self.work_dir.join("current_version");
tokio::fs::write(&version_file, version.to_string()).await
.map_err(|e| Error::io(format!("Failed to write version file: {}", e)))?;
Ok(())
}
}
/// Outcome of running one migration in one direction.
#[derive(Debug, Clone)]
pub struct MigrationResult {
/// The `id` from the migration's `meta()`.
pub migration_id: Uuid,
/// The `name` from the migration's `meta()`.
pub migration_name: String,
/// Whether `up` or `down` was run.
pub direction: MigrationDirection,
/// `true` iff `up`/`down` returned `Ok`; when `false`, `error` holds the returned error.
pub success: bool,
/// The error returned by `up`/`down`, if any.
pub error: Option<Error>,
/// Elapsed time measured (via `Instant`) around the `up`/`down` call.
pub duration: std::time::Duration,
}
/// Which way a migration was run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationDirection {
/// Applied via `Migration::up`.
Up,
/// Reverted via `Migration::down`.
Down,
}
/// Declares a unit-struct migration `$name` and implements `Migration` for it.
///
/// Arguments are:
/// - the struct name;
/// - its version (`u64`, must be non-zero to ever run);
/// - a description (anything with `to_string()`);
/// - a slice or array of backend names (e.g. `&["qdrant"]`).
///
/// The migration logic lives in two inherent async methods that the invoking code defines on
/// `$name`; the generated `up`/`down` simply delegate to them:
///
/// ```ignore
/// migration!(AddPayloadIndex, 2, "Add payload index", &["qdrant"]);
///
/// impl AddPayloadIndex {
///     async fn apply(&self, ctx: &mut MigrationContext) -> Result<()> { Ok(()) }
///     async fn revert(&self, ctx: &mut MigrationContext) -> Result<()> { Ok(()) }
/// }
/// ```
///
/// The generated impl uses `#[::async_trait::async_trait]`, so the invoking crate must depend
/// on `async_trait`, and the futures returned by `apply`/`revert` must be `Send`.
///
/// `meta()` draws a random v4 `id` and a `created_at` timestamp once per process and returns
/// the same pair on every later call. The registry's id index and the ids reported in
/// `MigrationResult` therefore agree.
#[macro_export]
macro_rules! migration {
    ($name:ident, $version:expr, $desc:expr, $backends:expr) => {
        #[derive(Debug)]
        pub struct $name;

        #[::async_trait::async_trait]
        impl $crate::migrations::Migration for $name {
            fn meta(&self) -> $crate::migrations::MigrationMeta {
                // One identity per generated type, fixed on first use, so repeated `meta()`
                // calls agree on `id` and `created_at`.
                static IDENTITY: ::std::sync::OnceLock<(
                    uuid::Uuid,
                    chrono::DateTime<chrono::Utc>,
                )> = ::std::sync::OnceLock::new();
                let (id, created_at) =
                    *IDENTITY.get_or_init(|| (uuid::Uuid::new_v4(), chrono::Utc::now()));
                $crate::migrations::MigrationMeta {
                    id,
                    version: $version,
                    name: stringify!($name).to_string(),
                    description: $desc.to_string(),
                    created_at,
                    backends: $backends.iter().map(|s| s.to_string()).collect(),
                }
            }

            async fn up(
                &self,
                ctx: &mut $crate::migrations::MigrationContext,
            ) -> $crate::Result<()> {
                $name::apply(self, ctx).await
            }

            async fn down(
                &self,
                ctx: &mut $crate::migrations::MigrationContext,
            ) -> $crate::Result<()> {
                $name::revert(self, ctx).await
            }
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal migration whose `up` and `down` always succeed.
    ///
    /// Implemented by hand (not via `migration!`) so the test does not depend on the macro's
    /// expansion; the fixed id keeps every `meta()` call consistent.
    #[derive(Debug)]
    struct TestMigration001;

    #[async_trait]
    impl Migration for TestMigration001 {
        fn meta(&self) -> MigrationMeta {
            MigrationMeta {
                id: Uuid::from_u128(1),
                version: 1,
                name: "TestMigration001".to_string(),
                description: "Test migration 1".to_string(),
                created_at: chrono::Utc::now(),
                backends: vec!["qdrant".to_string()],
            }
        }

        async fn up(&self, _ctx: &mut MigrationContext) -> Result<()> {
            Ok(())
        }

        async fn down(&self, _ctx: &mut MigrationContext) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_migration_registry() {
        let mut registry = MigrationRegistry::new();
        registry.register(TestMigration001).unwrap();
        let all = registry.all_sorted();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].meta().version, 1);
    }

    #[test]
    fn test_duplicate_version_rejected() {
        let mut registry = MigrationRegistry::new();
        registry.register(TestMigration001).unwrap();
        assert!(registry.register(TestMigration001).is_err());
        assert_eq!(registry.all_sorted().len(), 1);
    }

    #[tokio::test]
    async fn test_migration_runner() {
        let temp_dir = tempfile::tempdir().unwrap();
        let work_dir = temp_dir.path().to_path_buf();
        let mut registry = MigrationRegistry::new();
        registry.register(TestMigration001).unwrap();
        let runner = MigrationRunner::new(registry, work_dir.clone());

        assert_eq!(runner.get_current_version().await.unwrap(), 0);

        let results = runner.migrate_up(None, HashMap::new()).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].success);
        assert_eq!(results[0].direction, MigrationDirection::Up);
        assert_eq!(runner.get_current_version().await.unwrap(), 1);

        // Nothing left to apply: a second run is a no-op.
        let results = runner.migrate_up(None, HashMap::new()).await.unwrap();
        assert!(results.is_empty());

        let results = runner.migrate_down(0, HashMap::new()).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].success);
        assert_eq!(results[0].direction, MigrationDirection::Down);
        assert_eq!(runner.get_current_version().await.unwrap(), 0);
    }
}