// openauth-cli 0.0.4
//
// Command-line tools for OpenAuth.
use std::fs;
use std::path::{Path, PathBuf};

use openauth_core::db::{DbAdapter, DbSchema, SchemaMigrationPlan, SchemaMigrationWarning};
use openauth_core::error::OpenAuthError;
use openauth_sqlx::{MySqlAdapter, PostgresAdapter, SqliteAdapter};
use serde::Serialize;
use sha2::{Digest, Sha256};
use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime;

use crate::config::CliConfig;
use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};

#[derive(Debug, thiserror::Error)]
pub enum DbCliError {
    #[error("database provider is not configured")]
    MissingProvider,
    #[error("database URL environment variable {0} is not set")]
    MissingDatabaseUrl(String),
    #[error("unsupported database provider `{0}`")]
    UnsupportedProvider(String),
    #[error("migration has non-executable warnings; fix schema mismatches before applying")]
    UnsafeMigration,
    #[error("A migration for this plan already exists: {0}")]
    DuplicateMigration(String),
    #[error("database error: {0}")]
    OpenAuth(#[from] OpenAuthError),
    #[error("failed to write {path}: {source}")]
    Write {
        path: PathBuf,
        source: std::io::Error,
    },
    #[error("failed to read {path}: {source}")]
    Read {
        path: PathBuf,
        source: std::io::Error,
    },
    #[error("failed to create {path}: {source}")]
    CreateDir {
        path: PathBuf,
        source: std::io::Error,
    },
    #[error("failed to format timestamp: {0}")]
    TimeFormat(#[from] time::error::Format),
}

/// Serializable overview of a [`PlannedMigration`], built by [`PlannedMigration::summary`].
///
/// Field order here is the order in which fields are serialized.
#[derive(Debug, Clone, Serialize)]
pub struct PlanSummary {
    /// Provider name exactly as configured (aliases are not normalized).
    pub provider: String,
    /// Length of the plan's `to_be_created` list.
    pub tables_to_create: usize,
    /// Length of the plan's `to_be_added` list.
    pub columns_to_add: usize,
    /// Length of the plan's `indexes_to_be_created` list.
    pub indexes_to_create: usize,
    /// Copy of the plan's warnings; a non-empty list makes `migrate` refuse to apply.
    pub warnings: Vec<SchemaMigrationWarning>,
    /// Length of the plan's `statements` list.
    pub statements: usize,
    /// 16-hex-character SHA-256 prefix of the compiled plan SQL (see [`plan_hash`]).
    pub plan_hash: String,
}

/// A computed migration plan together with the target schema it was built for.
#[derive(Debug, Clone)]
pub struct PlannedMigration {
    /// Target schema the plan migrates toward.
    pub schema: DbSchema,
    /// The migration plan itself (either a diff against a live database or a full schema).
    pub plan: SchemaMigrationPlan,
    /// Provider name as read from `config.database.provider` (not normalized).
    pub provider: String,
}

impl PlannedMigration {
    /// Condenses this migration into counts, warnings and a plan hash for reporting.
    pub fn summary(&self) -> PlanSummary {
        let migration_plan = &self.plan;
        let provider = self.provider.clone();
        let warnings = migration_plan.warnings.clone();
        PlanSummary {
            tables_to_create: migration_plan.to_be_created.len(),
            columns_to_add: migration_plan.to_be_added.len(),
            indexes_to_create: migration_plan.indexes_to_be_created.len(),
            statements: migration_plan.statements.len(),
            plan_hash: plan_hash(migration_plan),
            provider,
            warnings,
        }
    }
}

/// Computes the migration plan that brings the configured database to the target schema
/// returned by `target_schema(config)`.
///
/// With `from_empty` set, no connection is made: the plan is the full schema for the
/// provider's dialect, as built by `full_schema_plan`. Otherwise the database whose URL is
/// read from the configured environment variable is connected to and diffed against the
/// target schema. For SQLite, the database file is created first if it does not exist.
///
/// # Errors
/// - [`DbCliError::MissingProvider`] if no provider is configured.
/// - [`DbCliError::UnsupportedProvider`] if the provider is not recognised. This is decided
///   before the URL is read, so it is reported even when the URL variable is also unset.
/// - [`DbCliError::MissingDatabaseUrl`] when a connection is needed and no URL is available.
/// - Any schema, adapter or filesystem error encountered along the way.
pub async fn plan(config: &CliConfig, from_empty: bool) -> Result<PlannedMigration, DbCliError> {
    let schema = target_schema(config)?;
    let provider = config
        .database
        .provider
        .clone()
        .ok_or(DbCliError::MissingProvider)?;

    let plan = if from_empty {
        let dialect = dialect_from_provider(&provider)
            .ok_or_else(|| DbCliError::UnsupportedProvider(provider.clone()))?;
        full_schema_plan(dialect, &schema)?
    } else {
        // The URL is read inside each supported arm (not up front) so that an unsupported
        // provider is reported as such rather than as a missing URL variable.
        match provider.as_str() {
            "sqlite" | "sqlite3" => {
                let database_url = database_url(config)?;
                // The SQLite file must exist before connecting.
                ensure_sqlite_database(&database_url)?;
                SqliteAdapter::connect_with_schema(&database_url, schema.clone())
                    .await?
                    .plan_migrations(&schema)
                    .await?
            }
            "postgres" | "postgresql" | "pg" => {
                let database_url = database_url(config)?;
                PostgresAdapter::connect_with_schema(&database_url, schema.clone())
                    .await?
                    .plan_migrations(&schema)
                    .await?
            }
            "mysql" => {
                let database_url = database_url(config)?;
                MySqlAdapter::connect_with_schema(&database_url, schema.clone())
                    .await?
                    .plan_migrations(&schema)
                    .await?
            }
            _ => return Err(DbCliError::UnsupportedProvider(provider)),
        }
    };

    Ok(PlannedMigration {
        schema,
        plan,
        provider,
    })
}

/// Plans against the live database and, if the plan carries no warnings, applies it.
///
/// After a warning-free plan, a fresh adapter for the configured provider is connected and
/// `run_migrations` is called with the target schema. The checked plan is returned.
/// NOTE(review): `run_migrations` receives the schema, not the checked plan, so it presumably
/// re-plans internally; confirm the two cannot diverge.
///
/// # Errors
/// - [`DbCliError::UnsafeMigration`] if the plan has any warnings.
/// - Any error from `plan`, from reading the database URL, or from the adapter.
pub async fn migrate(config: &CliConfig) -> Result<PlannedMigration, DbCliError> {
    let planned = plan(config, false).await?;
    let has_warnings = !planned.plan.warnings.is_empty();
    if has_warnings {
        return Err(DbCliError::UnsafeMigration);
    }

    let url = database_url(config)?;
    let schema = &planned.schema;
    match planned.provider.as_str() {
        "sqlite" | "sqlite3" => {
            ensure_sqlite_database(&url)?;
            SqliteAdapter::connect_with_schema(&url, schema.clone())
                .await?
                .run_migrations(schema)
                .await?;
        }
        "postgres" | "postgresql" | "pg" => {
            PostgresAdapter::connect_with_schema(&url, schema.clone())
                .await?
                .run_migrations(schema)
                .await?;
        }
        "mysql" => {
            MySqlAdapter::connect_with_schema(&url, schema.clone())
                .await?
                .run_migrations(schema)
                .await?;
        }
        unsupported => return Err(DbCliError::UnsupportedProvider(unsupported.to_owned())),
    }
    Ok(planned)
}

/// Renders the planned migration as a SQL script preceded by an identifying comment header.
///
/// The header records the dialect, the UTC generation time (RFC 3339), the schema and plan
/// hashes, and `config.project.base_path`. The compiled plan follows after a blank line. The
/// `-- plan_hash:` line is what the duplicate-migration scan searches for.
///
/// # Errors
/// [`DbCliError::UnsupportedProvider`] if the provider maps to no dialect, or any error from
/// timestamp formatting or schema hashing.
pub fn migration_sql(config: &CliConfig, planned: &PlannedMigration) -> Result<String, DbCliError> {
    let Some(dialect) = dialect_from_provider(&planned.provider) else {
        return Err(DbCliError::UnsupportedProvider(planned.provider.clone()));
    };
    let generated_at = OffsetDateTime::now_utc().format(&Rfc3339)?;
    let schema_digest = schema_hash(&planned.schema)?;
    let plan_digest = plan_hash(&planned.plan);
    let dialect_label = dialect_name(dialect);
    let base_path = &config.project.base_path;
    let body = planned.plan.compile();
    Ok(format!(
        "-- OpenAuth migration\n\
         -- dialect: {dialect_label}\n\
         -- generated_at: {generated_at}\n\
         -- schema_hash: {schema_digest}\n\
         -- plan_hash: {plan_digest}\n\
         -- config_base_path: {base_path}\n\
         \n\
         {body}"
    ))
}

/// Writes the planned migration to a new `.sql` file and returns the file's path.
///
/// The file goes into `output` when given, otherwise into `config.database.migrations_dir`,
/// and the directory is created if missing. The file name is
/// `<YYYYMMDDhhmmss UTC>_<normalized provider>_<plan hash>.sql`.
///
/// If the plan is empty, nothing is written and an empty `PathBuf` is returned.
///
/// Unless `force` is set, writing is refused when a `.sql` file in the target directory
/// already records the same plan hash, or when the target file itself already exists.
/// With `force`, both guards are skipped and an existing target file is overwritten.
///
/// # Errors
/// [`DbCliError::DuplicateMigration`] for the guards above; otherwise read, render,
/// directory-creation or write errors.
pub fn write_migration(
    config: &CliConfig,
    planned: &PlannedMigration,
    output: Option<&Path>,
    force: bool,
) -> Result<PathBuf, DbCliError> {
    if planned.plan.is_empty() {
        // Empty path signals "nothing written" to callers.
        return Ok(PathBuf::new());
    }
    let dir = output
        .map(Path::to_path_buf)
        .unwrap_or_else(|| PathBuf::from(&config.database.migrations_dir));
    let hash = plan_hash(&planned.plan);

    // `force` must bypass the plan-hash guard; otherwise it could never take effect, since
    // any file colliding on name also carries the same plan hash.
    let duplicate = if force {
        None
    } else {
        find_existing_plan_hash(&dir, &hash)?
    };
    if let Some(existing) = duplicate {
        return Err(DbCliError::DuplicateMigration(
            existing.display().to_string(),
        ));
    }

    // Render before touching the filesystem so a render failure leaves no directory behind.
    let sql = migration_sql(config, planned)?;

    fs::create_dir_all(&dir).map_err(|source| DbCliError::CreateDir {
        path: dir.clone(),
        source,
    })?;
    let path = dir.join(format!(
        "{}_{}_{}.sql",
        filename_timestamp(),
        normalized_provider(&planned.provider),
        hash
    ));
    if path.exists() && !force {
        return Err(DbCliError::DuplicateMigration(path.display().to_string()));
    }
    fs::write(&path, sql).map_err(|source| DbCliError::Write {
        path: path.clone(),
        source,
    })?;
    Ok(path)
}

pub fn schema_hash(schema: &DbSchema) -> Result<String, DbCliError> {
    let payload = serde_json::to_vec(schema)
        .map_err(|error| OpenAuthError::Adapter(format!("failed to serialize schema: {error}")))?;
    Ok(short_hash(&payload))
}

/// Returns a short SHA-256 fingerprint (16 hex characters) of the plan's compiled SQL.
///
/// Two plans that compile to identical SQL share a hash.
pub fn plan_hash(plan: &SchemaMigrationPlan) -> String {
    let compiled_sql = plan.compile();
    short_hash(compiled_sql.as_bytes())
}

/// Reads the database connection URL from the environment variable named by
/// `config.database.url_env`.
///
/// The value is returned unmodified.
///
/// # Errors
/// [`DbCliError::MissingDatabaseUrl`] (carrying the variable name) when the variable is
/// unset, is not valid Unicode, or is empty / whitespace-only. A blank value would
/// otherwise surface later as an obscure connection error.
pub fn database_url(config: &CliConfig) -> Result<String, DbCliError> {
    let name = &config.database.url_env;
    match std::env::var(name) {
        Ok(value) if !value.trim().is_empty() => Ok(value),
        _ => Err(DbCliError::MissingDatabaseUrl(name.clone())),
    }
}

/// Hex-encodes the first 8 bytes of the SHA-256 digest of `input`, giving 16 lowercase hex
/// characters.
fn short_hash(input: &[u8]) -> String {
    let mut hasher = Sha256::new();
    hasher.update(input);
    let full_digest = hasher.finalize();
    let prefix = &full_digest[..8];
    hex::encode(prefix)
}

fn find_existing_plan_hash(dir: &Path, hash: &str) -> Result<Option<PathBuf>, DbCliError> {
    if !dir.exists() {
        return Ok(None);
    }
    for entry in fs::read_dir(dir).map_err(|source| DbCliError::Read {
        path: dir.to_path_buf(),
        source,
    })? {
        let entry = entry.map_err(|source| DbCliError::Read {
            path: dir.to_path_buf(),
            source,
        })?;
        let path = entry.path();
        if path.extension().and_then(|extension| extension.to_str()) != Some("sql") {
            continue;
        }
        let content = fs::read_to_string(&path).map_err(|source| DbCliError::Read {
            path: path.clone(),
            source,
        })?;
        if content.contains(&format!("plan_hash: {hash}")) {
            return Ok(Some(path));
        }
    }
    Ok(None)
}

/// Current UTC time as a compact `YYYYMMDDhhmmss` string, used as a migration file-name
/// prefix.
fn filename_timestamp() -> String {
    let now = OffsetDateTime::now_utc();
    let (year, month, day) = now.to_calendar_date();
    let (hour, minute, second) = now.to_hms();
    let month = u8::from(month);
    format!("{year:04}{month:02}{day:02}{hour:02}{minute:02}{second:02}")
}

/// Maps provider aliases to the canonical name used in migration file names.
///
/// `postgresql` and `pg` become `postgres`, and `sqlite3` becomes `sqlite`. Any other input,
/// including differently-capitalized names, is returned unchanged.
fn normalized_provider(provider: &str) -> &str {
    if matches!(provider, "postgresql" | "pg") {
        "postgres"
    } else if provider == "sqlite3" {
        "sqlite"
    } else {
        provider
    }
}

/// Ensures the SQLite database file named by `database_url` exists.
///
/// If missing, the file is created empty, along with any missing parent directories. Does
/// nothing when `sqlite_path` maps the URL to no file (not a SQLite URL, or in-memory) or to
/// an empty path. An existing file is never modified.
///
/// # Errors
/// - [`DbCliError::CreateDir`] if a parent directory cannot be created.
/// - [`DbCliError::Write`] if the file cannot be created for a reason other than it already
///   existing.
fn ensure_sqlite_database(database_url: &str) -> Result<(), DbCliError> {
    let Some(path) = sqlite_path(database_url) else {
        return Ok(());
    };
    if path.as_os_str().is_empty() || path.exists() {
        return Ok(());
    }
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent).map_err(|source| DbCliError::CreateDir {
            path: parent.to_path_buf(),
            source,
        })?;
    }
    // `create_new` is an atomic create-if-absent. Unlike `File::create`, it cannot truncate a
    // database that appeared between the `exists()` check above and this call.
    match fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&path)
    {
        Ok(_) => Ok(()),
        Err(source) if source.kind() == std::io::ErrorKind::AlreadyExists => Ok(()),
        Err(source) => Err(DbCliError::Write { path, source }),
    }
}

/// Extracts the on-disk file path from a SQLite connection URL.
///
/// Accepts `sqlite://<path>` and `sqlite:<path>`. Any `?query` suffix (for example
/// `?mode=rwc`) is stripped, because it is a connection option and not part of the file name.
///
/// Returns `None` in these cases:
/// - the URL is not a SQLite URL;
/// - the database is in-memory: the path is `:memory:`, or the query contains `mode=memory`.
fn sqlite_path(database_url: &str) -> Option<PathBuf> {
    let rest = database_url
        .strip_prefix("sqlite://")
        .or_else(|| database_url.strip_prefix("sqlite:"))?;
    let (path, query) = rest.split_once('?').unwrap_or((rest, ""));
    let in_memory_mode = query.split('&').any(|param| param == "mode=memory");
    if path == ":memory:" || in_memory_mode {
        return None;
    }
    Some(PathBuf::from(path))
}