use std::collections::HashMap;
use std::sync::Arc;
use crate::SzalError;
use crate::flow::FlowDef;
/// One forward upgrade step for a [`FlowDef`].
///
/// A migration converts a flow at [`source_version`](Self::source_version) into a
/// flow at [`target_version`](Self::target_version). Implementors are
/// `Send + Sync` so they can be shared as `Arc<dyn FlowMigration>`.
pub trait FlowMigration: Send + Sync {
    /// The flow version this step accepts as input.
    fn source_version(&self) -> u32;

    /// The flow version this step produces.
    ///
    /// `MigrationRegistry::register` rejects (panics on) a migration whose
    /// target is not strictly greater than its source.
    fn target_version(&self) -> u32;

    /// Transforms `flow` from the source version to the target version.
    ///
    /// When run through `MigrationRegistry`, the returned flow's `version`
    /// field is overwritten with [`target_version`](Self::target_version)
    /// afterwards, so implementations need not set it themselves.
    ///
    /// # Errors
    ///
    /// Returns an error if the flow cannot be migrated; the registry
    /// propagates it unchanged.
    fn migrate(&self, flow: FlowDef) -> Result<FlowDef, SzalError>;
}
/// Wraps a plain closure as a [`FlowMigration`] from version `from` to version `to`.
///
/// No validation of `from`/`to` happens here; the ordering check is performed
/// when the migration is registered with `MigrationRegistry`.
#[must_use]
pub fn fn_migration<F>(from: u32, to: u32, f: F) -> Arc<dyn FlowMigration>
where
    F: Fn(FlowDef) -> Result<FlowDef, SzalError> + Send + Sync + 'static,
{
    let wrapped = FnMigration { from, to, f };
    Arc::new(wrapped)
}
/// Closure-backed [`FlowMigration`] produced by [`fn_migration`].
struct FnMigration<F> {
    /// The migration body, invoked once per `migrate` call.
    f: F,
    /// Version reported by `source_version`.
    from: u32,
    /// Version reported by `target_version`.
    to: u32,
}
impl<F> FlowMigration for FnMigration<F>
where
    F: Fn(FlowDef) -> Result<FlowDef, SzalError> + Send + Sync,
{
    /// Returns the `from` version given to [`fn_migration`].
    fn source_version(&self) -> u32 {
        let Self { from, .. } = self;
        *from
    }

    /// Returns the `to` version given to [`fn_migration`].
    fn target_version(&self) -> u32 {
        let Self { to, .. } = self;
        *to
    }

    /// Delegates directly to the wrapped closure.
    fn migrate(&self, flow: FlowDef) -> Result<FlowDef, SzalError> {
        let apply = &self.f;
        apply(flow)
    }
}
/// Registry of [`FlowMigration`]s, keyed by the version each one migrates from.
///
/// Only one migration may be registered per source version, so the chain of
/// steps from any starting version is unambiguous.
#[derive(Clone, Default)]
pub struct MigrationRegistry {
    /// Source version -> the migration that upgrades flows at that version.
    migrations: HashMap<u32, Arc<dyn FlowMigration>>,
}
impl MigrationRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style form of [`register`](Self::register): registers
    /// `migration` and returns the registry.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`register`](Self::register).
    #[must_use]
    pub fn with_migration(mut self, migration: Arc<dyn FlowMigration>) -> Self {
        self.register(migration);
        self
    }

    /// Registers `migration`, keyed by its source version.
    ///
    /// # Panics
    ///
    /// Panics if the migration's target version is not strictly greater than
    /// its source version, or if a migration is already registered for the
    /// same source version.
    pub fn register(&mut self, migration: Arc<dyn FlowMigration>) {
        use std::collections::hash_map::Entry;

        let from = migration.source_version();
        let to = migration.target_version();
        assert!(
            to > from,
            "migration target_version ({to}) must be greater than source_version ({from})"
        );
        // Entry API: one hash lookup instead of `contains_key` followed by `insert`.
        match self.migrations.entry(from) {
            Entry::Occupied(_) => {
                panic!("a migration is already registered for version {from}")
            }
            Entry::Vacant(slot) => {
                slot.insert(migration);
            }
        }
    }

    /// Returns the highest target version of any registered migration, or
    /// `None` if the registry is empty.
    #[must_use]
    pub fn latest_version(&self) -> Option<u32> {
        self.migrations.values().map(|m| m.target_version()).max()
    }

    /// Migrates `flow` forward until its version equals `target`.
    ///
    /// Starting at `flow.version`, the migration registered for the current
    /// version is applied repeatedly. After each step the flow's `version` is
    /// set to that migration's target version, regardless of what the
    /// migration wrote into the field.
    ///
    /// The whole chain is resolved before any migration runs. A missing link
    /// or an overshooting step is therefore reported without invoking (or
    /// logging) any migration.
    ///
    /// # Errors
    ///
    /// Returns [`SzalError::MigrationFailed`] in any of these cases:
    /// - `flow.version` is greater than `target`;
    /// - no migration is registered for some version on the path;
    /// - a step would jump past `target`.
    ///
    /// Errors returned by the migrations themselves are propagated unchanged.
    pub fn migrate_to(&self, mut flow: FlowDef, target: u32) -> Result<FlowDef, SzalError> {
        if flow.version == target {
            return Ok(flow);
        }
        if flow.version > target {
            return Err(SzalError::MigrationFailed(format!(
                "flow '{}' is at version {} which is newer than target {target}; downgrades are not supported",
                flow.name, flow.version
            )));
        }
        let plan = self.plan_path(&flow, target)?;
        for (from, to, migration) in plan {
            // Captured before `migrate` consumes the flow, so the log names the input flow.
            let name = flow.name.clone();
            flow = migration.migrate(flow)?;
            // The registry, not the migration, is authoritative for the resulting version.
            flow.version = to;
            tracing::debug!(flow = %name, from, to, "applied flow migration");
        }
        Ok(flow)
    }

    /// Migrates `flow` to [`latest_version`](Self::latest_version).
    ///
    /// Returns `flow` unchanged if the registry is empty or the flow is
    /// already at or beyond the latest known version.
    ///
    /// # Errors
    ///
    /// Same as [`migrate_to`](Self::migrate_to).
    pub fn migrate_latest(&self, flow: FlowDef) -> Result<FlowDef, SzalError> {
        match self.latest_version() {
            Some(target) if target > flow.version => self.migrate_to(flow, target),
            _ => Ok(flow),
        }
    }

    /// Resolves the ordered `(from, to, migration)` steps that lead from
    /// `flow.version` to `target`, without running any of them.
    fn plan_path(
        &self,
        flow: &FlowDef,
        target: u32,
    ) -> Result<Vec<(u32, u32, &Arc<dyn FlowMigration>)>, SzalError> {
        let mut steps = Vec::new();
        let mut version = flow.version;
        while version < target {
            let migration = self.migrations.get(&version).ok_or_else(|| {
                SzalError::MigrationFailed(format!(
                    "no migration registered from version {version} (flow '{}', target {target})",
                    flow.name
                ))
            })?;
            let to = migration.target_version();
            if to > target {
                return Err(SzalError::MigrationFailed(format!(
                    "migration from version {version} jumps to {to}, overshooting target {target} (flow '{}')",
                    flow.name
                )));
            }
            // `register` asserted `to > from`, but re-check here. A trait impl whose
            // versions drift after registration must not make this loop run forever.
            if to <= version {
                return Err(SzalError::MigrationFailed(format!(
                    "migration from version {version} reports target {to}, which does not advance (flow '{}')",
                    flow.name
                )));
            }
            steps.push((version, to, migration));
            version = to;
        }
        Ok(steps)
    }
}
impl std::fmt::Debug for MigrationRegistry {
    /// Formats the registry as its sorted list of `(source, target)` version
    /// pairs; the migration objects themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut edges = Vec::with_capacity(self.migrations.len());
        for migration in self.migrations.values() {
            edges.push((migration.source_version(), migration.target_version()));
        }
        // HashMap iteration order is unspecified; sort for stable output.
        edges.sort_unstable();

        let mut builder = f.debug_struct("MigrationRegistry");
        builder.field("migrations", &edges);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flow::{FlowDef, FlowMode};
    use crate::step::StepDef;

    /// Builds a one-step sequential flow named "pipeline" at `version`.
    fn flow_v(version: u32) -> FlowDef {
        let mut flow = FlowDef::new("pipeline", FlowMode::Sequential).with_version(version);
        flow.add_step(StepDef::new("deploy"));
        flow
    }

    #[test]
    fn migrate_single_step() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 2, |mut flow| {
            flow.steps[0].name = "release".into();
            Ok(flow)
        }));
        let out = registry.migrate_to(flow_v(1), 2).unwrap();
        assert_eq!(out.version, 2);
        assert_eq!(out.steps[0].name, "release");
    }

    #[test]
    fn migrate_chains_multiple_versions() {
        let registry = MigrationRegistry::new()
            .with_migration(fn_migration(1, 2, |mut flow| {
                flow.description = "v2".into();
                Ok(flow)
            }))
            .with_migration(fn_migration(2, 3, |mut flow| {
                flow.add_step(StepDef::new("verify"));
                Ok(flow)
            }));
        let out = registry.migrate_latest(flow_v(1)).unwrap();
        assert_eq!(out.version, 3);
        assert_eq!(out.description, "v2");
        assert_eq!(out.steps.len(), 2);
        assert_eq!(out.steps[1].name, "verify");
    }

    #[test]
    fn already_at_target_is_noop() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 2, Ok));
        let out = registry.migrate_to(flow_v(2), 2).unwrap();
        assert_eq!(out.version, 2);
    }

    #[test]
    fn empty_registry_migrate_latest_noop() {
        let registry = MigrationRegistry::new();
        let out = registry.migrate_latest(flow_v(1)).unwrap();
        assert_eq!(out.version, 1);
        assert!(registry.latest_version().is_none());
    }

    #[test]
    fn latest_version_is_highest_target() {
        let registry = MigrationRegistry::new()
            .with_migration(fn_migration(2, 3, Ok))
            .with_migration(fn_migration(1, 2, Ok));
        assert_eq!(registry.latest_version(), Some(3));
    }

    #[test]
    fn migrate_latest_ahead_of_registry_is_noop() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 2, Ok));
        let out = registry.migrate_latest(flow_v(5)).unwrap();
        assert_eq!(out.version, 5);
    }

    #[test]
    fn downgrade_is_rejected() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 2, Ok));
        let err = registry.migrate_to(flow_v(3), 2).unwrap_err();
        assert!(matches!(err, SzalError::MigrationFailed(_)));
    }

    #[test]
    fn missing_migration_path_errors() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 2, Ok));
        let err = registry.migrate_to(flow_v(1), 3).unwrap_err();
        match err {
            SzalError::MigrationFailed(m) => assert!(m.contains("no migration registered")),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn gap_in_chain_errors() {
        let registry = MigrationRegistry::new()
            .with_migration(fn_migration(1, 2, Ok))
            .with_migration(fn_migration(3, 4, Ok));
        let err = registry.migrate_latest(flow_v(1)).unwrap_err();
        match err {
            SzalError::MigrationFailed(m) => assert!(m.contains("from version 2")),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn overshoot_target_errors() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 3, Ok));
        let err = registry.migrate_to(flow_v(1), 2).unwrap_err();
        assert!(matches!(err, SzalError::MigrationFailed(_)));
    }

    #[test]
    fn migration_error_propagates() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 2, |_flow| {
            Err(SzalError::MigrationFailed("boom".into()))
        }));
        assert!(registry.migrate_to(flow_v(1), 2).is_err());
    }

    #[test]
    fn registry_overrides_version_written_by_migration() {
        let registry = MigrationRegistry::new().with_migration(fn_migration(1, 2, |mut flow| {
            flow.version = 99;
            Ok(flow)
        }));
        let out = registry.migrate_to(flow_v(1), 2).unwrap();
        assert_eq!(out.version, 2);
    }

    #[test]
    fn debug_lists_sorted_edges() {
        let registry = MigrationRegistry::new()
            .with_migration(fn_migration(2, 3, Ok))
            .with_migration(fn_migration(1, 2, Ok));
        assert_eq!(
            format!("{registry:?}"),
            "MigrationRegistry { migrations: [(1, 2), (2, 3)] }"
        );
    }

    #[test]
    #[should_panic(expected = "already registered")]
    fn duplicate_source_version_panics() {
        let _ = MigrationRegistry::new()
            .with_migration(fn_migration(1, 2, Ok))
            .with_migration(fn_migration(1, 3, Ok));
    }

    #[test]
    #[should_panic(expected = "must be greater")]
    fn non_increasing_version_panics() {
        let _ = MigrationRegistry::new().with_migration(fn_migration(2, 2, Ok));
    }
}