use crate::config::{Config, Dialect as CliDialect};
use crate::error::CliError;
use crate::output;
use drizzle_migrations::upgrade::upgrade_to_latest;
use drizzle_migrations::version::{is_supported_version, snapshot_version};
use drizzle_types::Dialect;
use std::fs;
use std::path::{Path, PathBuf};
/// Options for the `upgrade` command, which rewrites migration snapshots to the
/// latest snapshot format version.
#[derive(clap::Args, Debug, Clone, Default)]
pub struct UpgradeOptions {
    /// Dialect to upgrade snapshots for (defaults to the configured database's dialect)
    #[arg(long)]
    pub dialect: Option<CliDialect>,
    /// Migrations directory to scan (defaults to the configured database's migrations folder)
    #[arg(long)]
    pub out: Option<PathBuf>,
}
/// Entry point for the `upgrade` command.
///
/// Resolves the target dialect and migrations directory (command-line flags take
/// precedence over the database configuration), then rewrites every snapshot found
/// there that is not already at the latest snapshot version and prints a summary.
/// A missing migrations directory is reported as a warning, not an error.
///
/// # Errors
/// Propagates errors from resolving the database configuration and from reading,
/// parsing, or writing snapshot files.
pub fn run(config: &Config, db_name: Option<&str>, opts: &UpgradeOptions) -> Result<(), CliError> {
    let db = config.database(db_name)?;
    // An explicit `--dialect` wins over the dialect configured for the database.
    let dialect = opts.dialect.unwrap_or(db.dialect).to_base();
    let out_dir = opts.out.as_deref().unwrap_or_else(|| db.migrations_dir());

    let heading = format!("Checking for snapshots to upgrade in {}", out_dir.display());
    println!("{}", output::heading(&heading));

    // Nothing to upgrade when the folder is absent; this is not treated as a failure.
    if !out_dir.exists() {
        let warning = format!("No migrations folder found at {}", out_dir.display());
        println!("{}", output::warning(&warning));
        return Ok(());
    }

    let upgraded = upgrade_snapshots(out_dir, dialect)?;
    let latest = snapshot_version(dialect);
    let summary = match upgraded {
        0 => format!("All snapshots are already at the latest version ({})", latest),
        count => format!("Upgraded {} snapshot(s) to version {}", count, latest),
    };
    println!("{}", output::success(&summary));
    Ok(())
}
fn upgrade_snapshots(out_dir: &Path, dialect: Dialect) -> Result<usize, CliError> {
let mut upgraded_count = 0;
let v3_snapshots = find_v3_snapshots(out_dir)?;
for snapshot_path in v3_snapshots {
if upgrade_snapshot_file(&snapshot_path, dialect)? {
upgraded_count += 1;
}
}
let meta_folder = out_dir.join("meta");
if meta_folder.exists() {
let legacy_snapshots = find_legacy_snapshots(&meta_folder)?;
for snapshot_path in legacy_snapshots {
if upgrade_snapshot_file(&snapshot_path, dialect)? {
upgraded_count += 1;
}
}
}
Ok(upgraded_count)
}
fn find_v3_snapshots(out_dir: &Path) -> Result<Vec<std::path::PathBuf>, CliError> {
let mut snapshots = Vec::new();
if !out_dir.exists() {
return Ok(snapshots);
}
for entry in fs::read_dir(out_dir).map_err(|e| CliError::IoError(e.to_string()))? {
let entry = entry.map_err(|e| CliError::IoError(e.to_string()))?;
let path = entry.path();
if path.is_dir() {
let snapshot_path = path.join("snapshot.json");
if snapshot_path.exists() {
snapshots.push(snapshot_path);
}
}
}
Ok(snapshots)
}
fn find_legacy_snapshots(meta_folder: &Path) -> Result<Vec<std::path::PathBuf>, CliError> {
let mut snapshots = Vec::new();
if !meta_folder.exists() {
return Ok(snapshots);
}
for entry in fs::read_dir(meta_folder).map_err(|e| CliError::IoError(e.to_string()))? {
let entry = entry.map_err(|e| CliError::IoError(e.to_string()))?;
let path = entry.path();
if path.is_file()
&& let Some(name) = path.file_name().and_then(|n| n.to_str())
&& name.ends_with("_snapshot.json")
{
snapshots.push(path);
}
}
Ok(snapshots)
}
fn upgrade_snapshot_file(path: &Path, dialect: Dialect) -> Result<bool, CliError> {
let contents = fs::read_to_string(path).map_err(|e| CliError::IoError(e.to_string()))?;
let json: serde_json::Value = serde_json::from_str(&contents)
.map_err(|e| CliError::Other(format!("Invalid JSON in {}: {}", path.display(), e)))?;
let version = json
.get("version")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let latest_version = snapshot_version(dialect);
if version == latest_version {
return Ok(false);
}
let version_num: u32 = version.parse().unwrap_or(0);
if !is_supported_version(dialect, version) && version_num > 0 {
println!(
"{}",
output::warning(&format!(
"Skipping {}: version {} is not supported for upgrade",
path.display(),
version
))
);
return Ok(false);
}
println!(
"{}",
output::info(&format!(
"Upgrading {} from version {} to {}",
path.display(),
version,
latest_version
))
);
let upgraded = upgrade_to_latest(json, dialect);
let upgraded_json = serde_json::to_string_pretty(&upgraded)
.map_err(|e| CliError::Other(format!("Failed to serialize upgraded snapshot: {e}")))?;
fs::write(path, upgraded_json).map_err(|e| CliError::IoError(e.to_string()))?;
Ok(true)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_find_v3_snapshots() {
        let temp_dir = TempDir::new().unwrap();
        let migration_folder = temp_dir.path().join("20231220_initial");
        fs::create_dir_all(&migration_folder).unwrap();
        fs::write(migration_folder.join("snapshot.json"), "{}").unwrap();
        fs::write(migration_folder.join("migration.sql"), "").unwrap();
        // A migration folder without a snapshot and a loose top-level file are ignored.
        fs::create_dir_all(temp_dir.path().join("20231221_no_snapshot")).unwrap();
        fs::write(temp_dir.path().join("snapshot.json"), "{}").unwrap();

        let snapshots = find_v3_snapshots(temp_dir.path()).unwrap();
        assert_eq!(snapshots, vec![migration_folder.join("snapshot.json")]);
    }

    #[test]
    fn test_find_v3_snapshots_missing_dir() {
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("does_not_exist");
        assert!(find_v3_snapshots(&missing).unwrap().is_empty());
    }

    #[test]
    fn test_find_legacy_snapshots() {
        let temp_dir = TempDir::new().unwrap();
        let meta_folder = temp_dir.path().join("meta");
        fs::create_dir_all(&meta_folder).unwrap();
        fs::write(meta_folder.join("0000_initial_snapshot.json"), "{}").unwrap();
        fs::write(meta_folder.join("0001_add_users_snapshot.json"), "{}").unwrap();
        fs::write(meta_folder.join("_journal.json"), "{}").unwrap();

        let mut snapshots = find_legacy_snapshots(&meta_folder).unwrap();
        // Sort locally so the assertion does not depend on directory iteration order.
        snapshots.sort();
        assert_eq!(
            snapshots,
            vec![
                meta_folder.join("0000_initial_snapshot.json"),
                meta_folder.join("0001_add_users_snapshot.json"),
            ]
        );
    }

    #[test]
    fn test_find_legacy_snapshots_missing_dir() {
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("meta");
        assert!(find_legacy_snapshots(&missing).unwrap().is_empty());
    }
}