Skip to main content

openauth_cli/
db.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use openauth_core::db::{DbAdapter, DbSchema, SchemaMigrationPlan, SchemaMigrationWarning};
5use openauth_core::error::OpenAuthError;
6use openauth_sqlx::{MySqlAdapter, PostgresAdapter, SqliteAdapter};
7use serde::Serialize;
8use sha2::{Digest, Sha256};
9use time::format_description::well_known::Rfc3339;
10use time::OffsetDateTime;
11
12use crate::config::CliConfig;
13use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};
14
15#[derive(Debug, thiserror::Error)]
16pub enum DbCliError {
17    #[error("database provider is not configured")]
18    MissingProvider,
19    #[error("database URL environment variable {0} is not set; add it to .env/.env.local or export it before running this command")]
20    MissingDatabaseUrl(String),
21    #[error("unsupported database adapter `{0}`; CLI migrations currently support sqlx with sqlite, postgres, or mysql")]
22    UnsupportedAdapter(String),
23    #[error("unsupported database provider `{0}`")]
24    UnsupportedProvider(String),
25    #[error("migration has non-executable warnings; fix schema mismatches before applying")]
26    UnsafeMigration,
27    #[error("A migration for this plan already exists: {0}")]
28    DuplicateMigration(String),
29    #[error("database error: {0}")]
30    OpenAuth(#[from] OpenAuthError),
31    #[error("failed to write {path}: {source}")]
32    Write {
33        path: PathBuf,
34        source: std::io::Error,
35    },
36    #[error("failed to read {path}: {source}")]
37    Read {
38        path: PathBuf,
39        source: std::io::Error,
40    },
41    #[error("failed to create {path}: {source}")]
42    CreateDir {
43        path: PathBuf,
44        source: std::io::Error,
45    },
46    #[error("failed to format timestamp: {0}")]
47    TimeFormat(#[from] time::error::Format),
48}
49
50#[derive(Debug, Clone, Serialize)]
51pub struct PlanSummary {
52    pub provider: String,
53    pub tables_to_create: usize,
54    pub columns_to_add: usize,
55    pub indexes_to_create: usize,
56    pub warnings: Vec<SchemaMigrationWarning>,
57    pub statements: usize,
58    pub plan_hash: String,
59}
60
61#[derive(Debug, Clone)]
62pub struct PlannedMigration {
63    pub schema: DbSchema,
64    pub plan: SchemaMigrationPlan,
65    pub provider: String,
66}
67
68impl PlannedMigration {
69    pub fn summary(&self) -> PlanSummary {
70        PlanSummary {
71            provider: self.provider.clone(),
72            tables_to_create: self.plan.to_be_created.len(),
73            columns_to_add: self.plan.to_be_added.len(),
74            indexes_to_create: self.plan.indexes_to_be_created.len(),
75            warnings: self.plan.warnings.clone(),
76            statements: self.plan.statements.len(),
77            plan_hash: plan_hash(&self.plan),
78        }
79    }
80}
81
82pub async fn plan(config: &CliConfig, from_empty: bool) -> Result<PlannedMigration, DbCliError> {
83    plan_with_base(config, from_empty, None).await
84}
85
86pub async fn plan_with_base(
87    config: &CliConfig,
88    from_empty: bool,
89    cwd: Option<&Path>,
90) -> Result<PlannedMigration, DbCliError> {
91    validate_sql_adapter(config)?;
92    let schema = target_schema(config)?;
93    let provider = config
94        .database
95        .provider
96        .clone()
97        .ok_or(DbCliError::MissingProvider)?;
98
99    let plan = if from_empty {
100        let dialect = dialect_from_provider(&provider)
101            .ok_or_else(|| DbCliError::UnsupportedProvider(provider.clone()))?;
102        full_schema_plan(dialect, &schema)?
103    } else {
104        let database_url = database_url_with_base(config, cwd)?;
105        match provider.as_str() {
106            "sqlite" | "sqlite3" => {
107                ensure_sqlite_database(&database_url)?;
108                SqliteAdapter::connect_with_schema(&database_url, schema.clone())
109                    .await?
110                    .plan_migrations(&schema)
111                    .await?
112            }
113            "postgres" | "postgresql" | "pg" => {
114                PostgresAdapter::connect_with_schema(&database_url, schema.clone())
115                    .await?
116                    .plan_migrations(&schema)
117                    .await?
118            }
119            "mysql" => {
120                MySqlAdapter::connect_with_schema(&database_url, schema.clone())
121                    .await?
122                    .plan_migrations(&schema)
123                    .await?
124            }
125            _ => return Err(DbCliError::UnsupportedProvider(provider)),
126        }
127    };
128
129    Ok(PlannedMigration {
130        schema,
131        plan,
132        provider,
133    })
134}
135
136pub async fn migrate(config: &CliConfig) -> Result<PlannedMigration, DbCliError> {
137    migrate_with_base(config, None).await
138}
139
140pub async fn migrate_with_base(
141    config: &CliConfig,
142    cwd: Option<&Path>,
143) -> Result<PlannedMigration, DbCliError> {
144    let planned = plan_with_base(config, false, cwd).await?;
145    if !planned.plan.warnings.is_empty() {
146        return Err(DbCliError::UnsafeMigration);
147    }
148    let database_url = database_url_with_base(config, cwd)?;
149    match planned.provider.as_str() {
150        "sqlite" | "sqlite3" => {
151            ensure_sqlite_database(&database_url)?;
152            let adapter =
153                SqliteAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
154            adapter.run_migrations(&planned.schema).await?;
155        }
156        "postgres" | "postgresql" | "pg" => {
157            let adapter =
158                PostgresAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
159            adapter.run_migrations(&planned.schema).await?;
160        }
161        "mysql" => {
162            let adapter =
163                MySqlAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
164            adapter.run_migrations(&planned.schema).await?;
165        }
166        _ => return Err(DbCliError::UnsupportedProvider(planned.provider.clone())),
167    }
168    Ok(planned)
169}
170
171pub fn migration_sql(config: &CliConfig, planned: &PlannedMigration) -> Result<String, DbCliError> {
172    let dialect = dialect_from_provider(&planned.provider)
173        .ok_or_else(|| DbCliError::UnsupportedProvider(planned.provider.clone()))?;
174    let generated_at = OffsetDateTime::now_utc().format(&Rfc3339)?;
175    let schema_hash = schema_hash(&planned.schema)?;
176    let plan_hash = plan_hash(&planned.plan);
177    Ok(format!(
178        "-- OpenAuth migration\n-- dialect: {}\n-- generated_at: {}\n-- schema_hash: {}\n-- plan_hash: {}\n-- config_base_path: {}\n\n{}",
179        dialect_name(dialect),
180        generated_at,
181        schema_hash,
182        plan_hash,
183        config.project.base_path,
184        planned.plan.compile()
185    ))
186}
187
188pub fn write_migration(
189    config: &CliConfig,
190    planned: &PlannedMigration,
191    output: Option<&Path>,
192    force: bool,
193) -> Result<PathBuf, DbCliError> {
194    write_migration_output(
195        config,
196        planned,
197        output
198            .map(|path| MigrationOutput::Directory(path.to_path_buf()))
199            .unwrap_or(MigrationOutput::Default),
200        force,
201    )
202}
203
/// Where [`write_migration_output`] should put a generated migration file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationOutput {
    /// The configured `database.migrations_dir`, with a generated file name.
    Default,
    /// The given directory, with a generated file name.
    Directory(PathBuf),
    /// Exactly this file path.
    File(PathBuf),
}
209
210pub fn write_migration_output(
211    config: &CliConfig,
212    planned: &PlannedMigration,
213    output: MigrationOutput,
214    force: bool,
215) -> Result<PathBuf, DbCliError> {
216    if planned.plan.is_empty() {
217        return Ok(PathBuf::new());
218    }
219    let (dir, explicit_file) = match output {
220        MigrationOutput::Default => (PathBuf::from(&config.database.migrations_dir), None),
221        MigrationOutput::Directory(dir) => (dir, None),
222        MigrationOutput::File(path) => (
223            path.parent()
224                .map(Path::to_path_buf)
225                .unwrap_or_else(|| PathBuf::from(".")),
226            Some(path),
227        ),
228    };
229    let hash = plan_hash(&planned.plan);
230    if !force {
231        if let Some(existing) = find_existing_plan_hash(&dir, &hash)? {
232            return Err(DbCliError::DuplicateMigration(
233                existing.display().to_string(),
234            ));
235        }
236    }
237    fs::create_dir_all(&dir).map_err(|source| DbCliError::CreateDir {
238        path: dir.clone(),
239        source,
240    })?;
241    let path = explicit_file.unwrap_or_else(|| {
242        dir.join(format!(
243            "{}_{}_{}.sql",
244            filename_timestamp(),
245            normalized_provider(&planned.provider),
246            hash
247        ))
248    });
249    if path.exists() && !force {
250        return Err(DbCliError::DuplicateMigration(path.display().to_string()));
251    }
252    let sql = migration_sql(config, planned)?;
253    fs::write(&path, sql).map_err(|source| DbCliError::Write {
254        path: path.clone(),
255        source,
256    })?;
257    Ok(path)
258}
259
260pub fn schema_hash(schema: &DbSchema) -> Result<String, DbCliError> {
261    let payload = serde_json::to_vec(schema)
262        .map_err(|error| OpenAuthError::Adapter(format!("failed to serialize schema: {error}")))?;
263    Ok(short_hash(&payload))
264}
265
266pub fn plan_hash(plan: &SchemaMigrationPlan) -> String {
267    short_hash(plan.compile().as_bytes())
268}
269
270pub fn database_url(config: &CliConfig) -> Result<String, DbCliError> {
271    database_url_with_base(config, None)
272}
273
274pub fn database_url_with_base(
275    config: &CliConfig,
276    cwd: Option<&Path>,
277) -> Result<String, DbCliError> {
278    std::env::var(&config.database.url_env)
279        .map(|url| normalize_database_url(config.database.provider.as_deref(), &url, cwd))
280        .map_err(|_| DbCliError::MissingDatabaseUrl(config.database.url_env.clone()))
281}
282
283pub fn supports_sql_migrations(config: &CliConfig) -> bool {
284    config.database.adapter == "sqlx"
285        && config
286            .database
287            .provider
288            .as_deref()
289            .is_some_and(|provider| dialect_from_provider(provider).is_some())
290}
291
292fn validate_sql_adapter(config: &CliConfig) -> Result<(), DbCliError> {
293    if config.database.adapter != "sqlx" {
294        return Err(DbCliError::UnsupportedAdapter(
295            config.database.adapter.clone(),
296        ));
297    }
298    Ok(())
299}
300
301fn normalize_database_url(provider: Option<&str>, url: &str, cwd: Option<&Path>) -> String {
302    if !matches!(provider, Some("sqlite" | "sqlite3")) {
303        return url.to_owned();
304    }
305    let Some(cwd) = cwd else {
306        return url.to_owned();
307    };
308    let Some(path) = sqlite_path(url) else {
309        return url.to_owned();
310    };
311    if path.as_os_str().is_empty() || path.is_absolute() {
312        return url.to_owned();
313    }
314    format!("sqlite://{}", cwd.join(path).display())
315}
316
317fn short_hash(input: &[u8]) -> String {
318    let digest = Sha256::digest(input);
319    hex::encode(&digest[..8])
320}
321
322fn find_existing_plan_hash(dir: &Path, hash: &str) -> Result<Option<PathBuf>, DbCliError> {
323    if !dir.exists() {
324        return Ok(None);
325    }
326    for entry in fs::read_dir(dir).map_err(|source| DbCliError::Read {
327        path: dir.to_path_buf(),
328        source,
329    })? {
330        let entry = entry.map_err(|source| DbCliError::Read {
331            path: dir.to_path_buf(),
332            source,
333        })?;
334        let path = entry.path();
335        if path.extension().and_then(|extension| extension.to_str()) != Some("sql") {
336            continue;
337        }
338        let content = fs::read_to_string(&path).map_err(|source| DbCliError::Read {
339            path: path.clone(),
340            source,
341        })?;
342        if content.contains(&format!("plan_hash: {hash}")) {
343            return Ok(Some(path));
344        }
345    }
346    Ok(None)
347}
348
349fn filename_timestamp() -> String {
350    let now = OffsetDateTime::now_utc();
351    format!(
352        "{:04}{:02}{:02}{:02}{:02}{:02}",
353        now.year(),
354        u8::from(now.month()),
355        now.day(),
356        now.hour(),
357        now.minute(),
358        now.second()
359    )
360}
361
/// Maps provider aliases to a canonical name for file naming: `postgresql`/`pg` become
/// `postgres` and `sqlite3` becomes `sqlite`. Any other value is returned unchanged.
fn normalized_provider(provider: &str) -> &str {
    const POSTGRES_ALIASES: [&str; 2] = ["postgresql", "pg"];
    if POSTGRES_ALIASES.contains(&provider) {
        "postgres"
    } else if provider == "sqlite3" {
        "sqlite"
    } else {
        provider
    }
}
369
370fn ensure_sqlite_database(database_url: &str) -> Result<(), DbCliError> {
371    let Some(path) = sqlite_path(database_url) else {
372        return Ok(());
373    };
374    if path.as_os_str().is_empty() || path.exists() {
375        return Ok(());
376    }
377    if let Some(parent) = path.parent() {
378        fs::create_dir_all(parent).map_err(|source| DbCliError::CreateDir {
379            path: parent.to_path_buf(),
380            source,
381        })?;
382    }
383    fs::File::create(&path)
384        .map(|_| ())
385        .map_err(|source| DbCliError::Write { path, source })
386}
387
/// Extracts the file location from a `sqlite://…` or `sqlite:…` URL.
///
/// Returns `None` for non-SQLite URLs and for in-memory databases (`:memory:`), including
/// in-memory URLs that carry options such as `sqlite::memory:?cache=shared`. For file
/// URLs, the returned path is everything after the scheme, including any `?query` suffix.
fn sqlite_path(database_url: &str) -> Option<PathBuf> {
    let raw = database_url
        .strip_prefix("sqlite://")
        .or_else(|| database_url.strip_prefix("sqlite:"))?;
    // Compare only the portion before connection options when detecting `:memory:`.
    let location = raw.split_once('?').map_or(raw, |(base, _)| base);
    if location == ":memory:" {
        return None;
    }
    Some(PathBuf::from(raw))
}
392    database_url
393        .strip_prefix("sqlite://")
394        .or_else(|| database_url.strip_prefix("sqlite:"))
395        .map(PathBuf::from)
396}