use std::collections::HashMap;
use crate::{WeeDb, WeeDbRaw};
impl<T> WeeDb<T> {
    /// Brings the version stored in this database up to the target version of
    /// `migrations`, running each registered migration step with this handle.
    ///
    /// If the database has no stored version yet, the target version is
    /// recorded directly and no migration runs.
    ///
    /// # Errors
    ///
    /// Returns a [`MigrationError`] if the stored version cannot be read or
    /// written, is newer than the target, has no migration registered for it,
    /// or if a migration step itself fails.
    pub fn apply<P: VersionProvider>(
        &self,
        migrations: Migrations<P, Self>,
    ) -> Result<(), MigrationError> {
        apply_migrations::<P, Self>(self, migrations)
    }
}
impl WeeDbRaw {
    /// Brings the version stored in this database up to the target version of
    /// `migrations`, running each registered migration step with this raw handle.
    ///
    /// If the database has no stored version yet, the target version is
    /// recorded directly and no migration runs.
    ///
    /// # Errors
    ///
    /// Returns a [`MigrationError`] if the stored version cannot be read or
    /// written, is newer than the target, has no migration registered for it,
    /// or if a migration step itself fails.
    pub fn apply<P: VersionProvider>(
        &self,
        migrations: Migrations<P, Self>,
    ) -> Result<(), MigrationError> {
        apply_migrations::<P, Self>(self, migrations)
    }
}
/// Shared implementation of `WeeDb::apply` and `WeeDbRaw::apply`.
///
/// 1. If the provider reports no stored version, the database is treated as
///    fresh: the target version is written and no migration runs.
/// 2. Otherwise the stored version is read repeatedly and:
///    - equal to the target: stop with `Ok(())`;
///    - newer than the target: fail with [`MigrationError::IncompatibleDbVersion`];
///    - older: run the migration registered for exactly that version
///      ([`MigrationError::MigrationNotFound`] if none) and store the version
///      it returns, then repeat.
///
/// Reaching the same version twice means the registered migrations form a
/// cycle (for example a step registered with `to == from`). Without a check
/// the loop would re-run data migrations forever, so this is reported as
/// [`MigrationError::Custom`] instead.
fn apply_migrations<P: VersionProvider, D: AsRef<WeeDbRaw>>(
    db: &D,
    migrations: Migrations<P, D>,
) -> Result<(), MigrationError> {
    let raw = db.as_ref();
    let provider = &migrations.version_provider;
    let target = migrations.target_version;

    if provider.get_version(raw)?.is_none() {
        tracing::info!("starting with empty db");
        return provider.set_version(raw, target);
    }

    // Versions we have already started a migration from. Steps created by
    // `Migrations::register` always return their fixed `to` version, so seeing
    // a version again guarantees an endless loop.
    let mut visited = std::collections::HashSet::new();
    loop {
        let version: Semver = provider
            .get_version(raw)?
            .ok_or(MigrationError::VersionNotFound)?;

        match version.cmp(&target) {
            std::cmp::Ordering::Less => {}
            std::cmp::Ordering::Equal => {
                tracing::info!("stored DB version is compatible");
                return Ok(());
            }
            std::cmp::Ordering::Greater => {
                return Err(MigrationError::IncompatibleDbVersion {
                    version,
                    expected: target,
                });
            }
        }

        if !visited.insert(version) {
            return Err(MigrationError::Custom(
                format!("migration cycle detected at DB version {:?}", version).into(),
            ));
        }

        let migration = migrations
            .migrations
            .get(&version)
            .ok_or(MigrationError::MigrationNotFound(version))?;
        tracing::info!(?version, "applying migration");
        let next_version = migration(db)?;
        provider.set_version(raw, next_version)?;
    }
}
/// A set of database migrations plus the version they should bring the DB to.
///
/// Migrations are keyed by the version they start *from* and are added with
/// [`Migrations::register`]. `P` is the [`VersionProvider`] used to read and
/// store the DB version; `D` is the database handle type each migration
/// receives (defaults to [`WeeDbRaw`]).
pub struct Migrations<P, D = WeeDbRaw> {
/// Version the stored DB version must equal for migration to finish successfully.
target_version: Semver,
/// Registered migration steps, keyed by the version they upgrade from.
migrations: HashMap<Semver, Migration<D>>,
/// Reads and writes the version stored in the database.
version_provider: P,
}
impl<D> Migrations<DefaultVersionProvider, D> {
    /// Creates an empty migration set targeting `target_version` that stores
    /// the DB version using [`DefaultVersionProvider`].
    pub fn with_target_version(target_version: Semver) -> Self {
        Self::with_target_version_and_provider(target_version, DefaultVersionProvider)
    }
}
impl<P: VersionProvider, D> Migrations<P, D> {
pub fn with_target_version_and_provider(target_version: Semver, version_provider: P) -> Self {
Self {
target_version,
migrations: Default::default(),
version_provider,
}
}
pub fn register<F>(
&mut self,
from: Semver,
to: Semver,
migration: F,
) -> Result<(), MigrationError>
where
F: Fn(&D) -> Result<(), MigrationError> + 'static,
{
use std::collections::hash_map;
match self.migrations.entry(from) {
hash_map::Entry::Vacant(entry) => {
entry.insert(Box::new(move |db| {
migration(db)?;
Ok(to)
}));
Ok(())
}
hash_map::Entry::Occupied(entry) => {
Err(MigrationError::DuplicateMigration(*entry.key()))
}
}
}
}
/// Database version as three bytes. Versions are compared with the array's
/// lexicographic ordering, so the first byte is most significant (presumably
/// `[major, minor, patch]`).
pub type Semver = [u8; 3];
/// A type-erased migration step: runs against the database handle `D` and
/// returns the version the database is at afterwards.
type Migration<D = WeeDbRaw> = Box<dyn Fn(&D) -> Result<Semver, MigrationError>>;
/// Reads and writes the version stored in a database.
pub trait VersionProvider {
/// Returns the stored DB version.
///
/// `Ok(None)` means no version is stored. The migration runner treats that
/// as a fresh database and writes the target version without migrating.
fn get_version(&self, db: &WeeDbRaw) -> Result<Option<Semver>, MigrationError>;
/// Stores `version` as the DB version. The migration runner re-reads it with
/// `get_version` after each migration step, so it is expected to be returned
/// by the next `get_version` call.
fn set_version(&self, db: &WeeDbRaw, version: Semver) -> Result<(), MigrationError>;
}
/// The built-in version provider: keeps the DB version as a raw three-byte
/// value under a fixed key in the database.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultVersionProvider;

impl DefaultVersionProvider {
    /// Key under which the version bytes are stored.
    const DB_VERSION_KEY: &'static str = "weedb_version";
}
impl VersionProvider for DefaultVersionProvider {
fn get_version(&self, db: &WeeDbRaw) -> Result<Option<Semver>, MigrationError> {
match db.rocksdb().get(Self::DB_VERSION_KEY)? {
Some(version) => version
.try_into()
.map_err(|_| MigrationError::InvalidDbVersion)
.map(Some),
None => Ok(None),
}
}
fn set_version(&self, db: &WeeDbRaw, version: Semver) -> Result<(), MigrationError> {
db.rocksdb()
.put(Self::DB_VERSION_KEY, version)
.map_err(MigrationError::DbError)
}
}
/// Errors produced while registering or applying migrations.
#[derive(thiserror::Error, Debug)]
pub enum MigrationError {
/// The stored DB version is newer than the migration target.
#[error("incompatible DB version: {version:?}, expected {expected:?}")]
IncompatibleDbVersion { version: Semver, expected: Semver },
/// The version provider reported no stored version after migration had
/// already started.
#[error("existing DB version not found")]
VersionNotFound,
/// The stored version value could not be decoded into a [`Semver`].
#[error("invalid version")]
InvalidDbVersion,
/// No migration is registered for this (older than target) version.
#[error("migration not found: {0:?}")]
MigrationNotFound(Semver),
/// A migration starting from this version was already registered.
#[error("duplicate migration: {0:?}")]
DuplicateMigration(Semver),
/// An error returned by RocksDB.
#[error("db error: {0}")]
DbError(#[from] rocksdb::Error),
/// Any other error, boxed; its own message is displayed as-is. Migration
/// closures can use this to report their own failures.
#[error("{0}")]
Custom(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}