Skip to main content

rustauth_cli/
db.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use rustauth_core::db::{DbAdapter, DbSchema, SchemaMigrationPlan, SchemaMigrationWarning};
5use rustauth_core::error::RustAuthError;
6#[cfg(feature = "deadpool-postgres")]
7use rustauth_deadpool_postgres::DeadpoolPostgresAdapter;
8#[cfg(feature = "sqlx")]
9use rustauth_sqlx::{MySqlAdapter, PostgresAdapter, SqliteAdapter};
10#[cfg(feature = "tokio-postgres")]
11use rustauth_tokio_postgres::TokioPostgresAdapter;
12use serde::Serialize;
13use sha2::{Digest, Sha256};
14use time::format_description::well_known::Rfc3339;
15use time::OffsetDateTime;
16
17use crate::config::CliConfig;
18use crate::plugins::plugin_migrations_for_config;
19use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};
20
/// Returns `true` when `adapter` names a migration backend compiled into this CLI build.
///
/// Each backend is gated behind the Cargo feature of the same name, so a recognised
/// adapter whose feature is off still yields `false`.
pub fn is_cli_migration_adapter(adapter: &str) -> bool {
    (adapter == "sqlx" && cfg!(feature = "sqlx"))
        || (adapter == "tokio-postgres" && cfg!(feature = "tokio-postgres"))
        || (adapter == "deadpool-postgres" && cfg!(feature = "deadpool-postgres"))
}
29
/// Returns `true` when `adapter` is one of the migration backends this CLI knows about,
/// regardless of whether its Cargo feature was enabled for this build.
pub fn is_known_cli_migration_adapter(adapter: &str) -> bool {
    const KNOWN_ADAPTERS: [&str; 3] = ["sqlx", "tokio-postgres", "deadpool-postgres"];
    KNOWN_ADAPTERS.contains(&adapter)
}
33
34fn is_adapter_feature_disabled(adapter: &str) -> bool {
35    is_known_cli_migration_adapter(adapter) && !is_cli_migration_adapter(adapter)
36}
37
/// Lists the migration adapters compiled into this CLI build.
///
/// The order is fixed: `sqlx`, `tokio-postgres`, `deadpool-postgres`.
pub fn cli_migration_adapter_names() -> Vec<&'static str> {
    let candidates: [(&'static str, bool); 3] = [
        ("sqlx", cfg!(feature = "sqlx")),
        ("tokio-postgres", cfg!(feature = "tokio-postgres")),
        ("deadpool-postgres", cfg!(feature = "deadpool-postgres")),
    ];
    candidates
        .iter()
        .filter(|(_, enabled)| *enabled)
        .map(|(name, _)| *name)
        .collect()
}
51
/// Returns `true` for the provider spellings accepted as PostgreSQL.
fn is_postgres_provider(provider: &str) -> bool {
    ["postgres", "postgresql", "pg"].contains(&provider)
}
55
/// Errors produced by the CLI's database planning, migration and migration-file commands.
#[derive(Debug, thiserror::Error)]
pub enum DbCliError {
    /// `database.provider` is not set in the CLI config.
    #[error("database provider is not configured")]
    MissingProvider,
    /// The environment variable named by `database.url_env` (carried here) could not be
    /// read (unset or not valid Unicode).
    #[error("database URL environment variable {0} is not set; add it to .env/.env.local next to the project or config file, or export it before running this command")]
    MissingDatabaseUrl(String),
    /// `database.adapter` (carried here) cannot drive CLI migrations; the message lists the
    /// adapters enabled in this build.
    #[error(
        "unsupported database adapter `{adapter}`; {support}",
        adapter = .0,
        support = unsupported_adapter_support_suffix()
    )]
    UnsupportedAdapter(String),
    /// A known migration adapter (field 0) whose Cargo feature (field 1) is not compiled
    /// into this build.
    #[error(
        "database adapter `{0}` is not enabled in this CLI build; rebuild with the matching \
         Cargo feature ({1})"
    )]
    AdapterFeatureDisabled(String, String),
    /// The provider has no SQL dialect, or is not usable with the selected adapter
    /// (e.g. a non-Postgres provider with a Postgres-only adapter).
    #[error("unsupported database provider `{0}`")]
    UnsupportedProvider(String),
    /// The migration plan carries warnings, so applying it was refused.
    #[error("migration has non-executable warnings; fix schema mismatches before applying")]
    UnsafeMigration,
    /// A migration file recording the same plan hash, or the target file itself, already
    /// exists; carries that file's displayed path.
    #[error("A migration for this plan already exists: {0}")]
    DuplicateMigration(String),
    /// Error surfaced from `rustauth_core`, e.g. from adapter calls or schema serialization.
    #[error("database error: {0}")]
    RustAuth(#[from] RustAuthError),
    /// Writing a file (migration SQL or a new SQLite database file) failed.
    #[error("failed to write {path}: {source}")]
    Write {
        /// File that could not be written.
        path: PathBuf,
        /// Underlying I/O error.
        source: std::io::Error,
    },
    /// Listing the migrations directory or reading a `.sql` file in it failed.
    #[error("failed to read {path}: {source}")]
    Read {
        /// Directory or file that could not be read.
        path: PathBuf,
        /// Underlying I/O error.
        source: std::io::Error,
    },
    /// Creating a directory (migrations output or SQLite parent directory) failed.
    #[error("failed to create {path}: {source}")]
    CreateDir {
        /// Directory that could not be created.
        path: PathBuf,
        /// Underlying I/O error.
        source: std::io::Error,
    },
    /// Formatting the RFC 3339 `generated_at` timestamp for a migration header failed.
    #[error("failed to format timestamp: {0}")]
    TimeFormat(#[from] time::error::Format),
}
99
/// Serializable overview of a [`PlannedMigration`], as produced by
/// [`PlannedMigration::summary`].
#[derive(Debug, Clone, Serialize)]
pub struct PlanSummary {
    /// Database provider string exactly as configured (not normalized).
    pub provider: String,
    /// Number of entries in the plan's `to_be_created` table list.
    pub tables_to_create: usize,
    /// Number of entries in the plan's `to_be_added` column list.
    pub columns_to_add: usize,
    /// Number of entries in the plan's `indexes_to_be_created` list.
    pub indexes_to_create: usize,
    /// Warnings reported by the planner; `migrate` refuses to apply a plan that has any.
    pub warnings: Vec<SchemaMigrationWarning>,
    /// Number of SQL statements in the plan.
    pub statements: usize,
    /// Short SHA-256 hash of the compiled plan SQL, also written into migration file headers.
    pub plan_hash: String,
}
110
/// Result of planning: the target schema, the computed migration plan, and the provider
/// it was planned for.
#[derive(Debug, Clone)]
pub struct PlannedMigration {
    /// Target schema derived from the CLI config by `target_schema`.
    pub schema: DbSchema,
    /// Plan computed either from an empty database (`from_empty`) or by the configured
    /// adapter against the live database.
    pub plan: SchemaMigrationPlan,
    /// `database.provider` value from the config, as written.
    pub provider: String,
}
117
118impl PlannedMigration {
119    pub fn summary(&self) -> PlanSummary {
120        PlanSummary {
121            provider: self.provider.clone(),
122            tables_to_create: self.plan.to_be_created.len(),
123            columns_to_add: self.plan.to_be_added.len(),
124            indexes_to_create: self.plan.indexes_to_be_created.len(),
125            warnings: self.plan.warnings.clone(),
126            statements: self.plan.statements.len(),
127            plan_hash: plan_hash(&self.plan),
128        }
129    }
130}
131
132pub async fn plan(config: &CliConfig, from_empty: bool) -> Result<PlannedMigration, DbCliError> {
133    plan_with_base(config, from_empty, None).await
134}
135
136pub async fn plan_with_base(
137    config: &CliConfig,
138    from_empty: bool,
139    cwd: Option<&Path>,
140) -> Result<PlannedMigration, DbCliError> {
141    validate_cli_migration_adapter(config)?;
142    let schema = target_schema(config)?;
143    let provider = config
144        .database
145        .provider
146        .clone()
147        .ok_or(DbCliError::MissingProvider)?;
148
149    let plan = if from_empty {
150        let dialect = dialect_from_provider(&provider)
151            .ok_or_else(|| DbCliError::UnsupportedProvider(provider.clone()))?;
152        full_schema_plan(dialect, &schema)?
153    } else {
154        let database_url = database_url_with_base(config, cwd)?;
155        match config.database.adapter.as_str() {
156            #[cfg(feature = "sqlx")]
157            "sqlx" => plan_with_sqlx(&provider, &database_url, &schema).await?,
158            #[cfg(feature = "tokio-postgres")]
159            "tokio-postgres" => {
160                if !is_postgres_provider(&provider) {
161                    return Err(DbCliError::UnsupportedProvider(provider));
162                }
163                TokioPostgresAdapter::connect_with_schema(&database_url, schema.clone())
164                    .await?
165                    .plan_migrations(&schema)
166                    .await?
167            }
168            #[cfg(feature = "deadpool-postgres")]
169            "deadpool-postgres" => {
170                if !is_postgres_provider(&provider) {
171                    return Err(DbCliError::UnsupportedProvider(provider));
172                }
173                DeadpoolPostgresAdapter::builder()
174                    .database_url(database_url)
175                    .schema(schema.clone())
176                    .connect()
177                    .await?
178                    .plan_migrations(&schema)
179                    .await?
180            }
181            adapter => return Err(adapter_dispatch_error(adapter)),
182        }
183    };
184
185    Ok(PlannedMigration {
186        schema,
187        plan,
188        provider,
189    })
190}
191
192pub async fn migrate(config: &CliConfig) -> Result<PlannedMigration, DbCliError> {
193    migrate_with_base(config, None).await
194}
195
196pub async fn migrate_with_base(
197    config: &CliConfig,
198    cwd: Option<&Path>,
199) -> Result<PlannedMigration, DbCliError> {
200    let planned = plan_with_base(config, false, cwd).await?;
201    if !planned.plan.warnings.is_empty() {
202        return Err(DbCliError::UnsafeMigration);
203    }
204    let database_url = database_url_with_base(config, cwd)?;
205    let plugin_migrations = plugin_migrations_for_config(&config.plugins.enabled)?;
206    match config.database.adapter.as_str() {
207        #[cfg(feature = "sqlx")]
208        "sqlx" => {
209            run_migrations_with_sqlx(
210                &planned.provider,
211                &database_url,
212                &planned.schema,
213                &plugin_migrations,
214            )
215            .await?;
216        }
217        #[cfg(feature = "tokio-postgres")]
218        "tokio-postgres" => {
219            let adapter =
220                TokioPostgresAdapter::connect_with_schema(&database_url, planned.schema.clone())
221                    .await?;
222            adapter.run_migrations(&planned.schema).await?;
223            adapter.run_plugin_migrations(&plugin_migrations).await?;
224        }
225        #[cfg(feature = "deadpool-postgres")]
226        "deadpool-postgres" => {
227            let adapter = DeadpoolPostgresAdapter::builder()
228                .database_url(database_url)
229                .schema(planned.schema.clone())
230                .connect()
231                .await?;
232            adapter.run_migrations(&planned.schema).await?;
233            adapter.run_plugin_migrations(&plugin_migrations).await?;
234        }
235        adapter => return Err(adapter_dispatch_error(adapter)),
236    }
237    Ok(planned)
238}
239
/// Connect with the sqlx adapter that matches `provider` and plan migrations for `schema`.
///
/// For SQLite the database file is created first if it does not exist yet.
///
/// # Errors
/// `UnsupportedProvider` for providers other than sqlite/postgres/mysql spellings, plus any
/// connection or planning error from the adapter.
#[cfg(feature = "sqlx")]
async fn plan_with_sqlx(
    provider: &str,
    database_url: &str,
    schema: &DbSchema,
) -> Result<SchemaMigrationPlan, DbCliError> {
    let plan = match provider {
        "sqlite" | "sqlite3" => {
            ensure_sqlite_database(database_url)?;
            let sqlite = SqliteAdapter::connect_with_schema(database_url, schema.clone()).await?;
            sqlite.plan_migrations(schema).await?
        }
        "postgres" | "postgresql" | "pg" => {
            let postgres =
                PostgresAdapter::connect_with_schema(database_url, schema.clone()).await?;
            postgres.plan_migrations(schema).await?
        }
        "mysql" => {
            let mysql = MySqlAdapter::connect_with_schema(database_url, schema.clone()).await?;
            mysql.plan_migrations(schema).await?
        }
        unsupported => return Err(DbCliError::UnsupportedProvider(unsupported.to_owned())),
    };
    Ok(plan)
}
270
/// Apply core schema migrations and then plugin migrations through the sqlx adapter that
/// matches `provider`, using one adapter instance for both steps.
///
/// For SQLite the database file is created first if it does not exist yet.
///
/// # Errors
/// `UnsupportedProvider` for providers other than sqlite/postgres/mysql spellings, plus any
/// connection or migration error from the adapter.
#[cfg(feature = "sqlx")]
async fn run_migrations_with_sqlx(
    provider: &str,
    database_url: &str,
    schema: &DbSchema,
    plugin_migrations: &[rustauth_core::plugin::PluginMigration],
) -> Result<(), DbCliError> {
    match provider {
        "sqlite" | "sqlite3" => {
            ensure_sqlite_database(database_url)?;
            let sqlite = SqliteAdapter::connect_with_schema(database_url, schema.clone()).await?;
            sqlite.run_migrations(schema).await?;
            sqlite.run_plugin_migrations(plugin_migrations).await?;
            Ok(())
        }
        "postgres" | "postgresql" | "pg" => {
            let postgres =
                PostgresAdapter::connect_with_schema(database_url, schema.clone()).await?;
            postgres.run_migrations(schema).await?;
            postgres.run_plugin_migrations(plugin_migrations).await?;
            Ok(())
        }
        "mysql" => {
            let mysql = MySqlAdapter::connect_with_schema(database_url, schema.clone()).await?;
            mysql.run_migrations(schema).await?;
            mysql.run_plugin_migrations(plugin_migrations).await?;
            Ok(())
        }
        unsupported => Err(DbCliError::UnsupportedProvider(unsupported.to_owned())),
    }
}
300
301pub fn migration_sql(config: &CliConfig, planned: &PlannedMigration) -> Result<String, DbCliError> {
302    let dialect = dialect_from_provider(&planned.provider)
303        .ok_or_else(|| DbCliError::UnsupportedProvider(planned.provider.clone()))?;
304    let generated_at = OffsetDateTime::now_utc().format(&Rfc3339)?;
305    let schema_hash = schema_hash(&planned.schema)?;
306    let plan_hash = plan_hash(&planned.plan);
307    Ok(format!(
308        "-- RustAuth migration\n-- dialect: {}\n-- generated_at: {}\n-- schema_hash: {}\n-- plan_hash: {}\n-- config_base_path: {}\n\n{}",
309        dialect_name(dialect),
310        generated_at,
311        schema_hash,
312        plan_hash,
313        config.project.base_path,
314        planned.plan.compile()
315    ))
316}
317
318pub fn write_migration(
319    config: &CliConfig,
320    planned: &PlannedMigration,
321    output: Option<&Path>,
322    force: bool,
323) -> Result<PathBuf, DbCliError> {
324    write_migration_output(
325        config,
326        planned,
327        output
328            .map(|path| MigrationOutput::Directory(path.to_path_buf()))
329            .unwrap_or(MigrationOutput::Default),
330        force,
331    )
332}
333
/// Where [`write_migration_output`] should place a generated migration file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MigrationOutput {
    /// Use `database.migrations_dir` from the config, with a generated file name.
    Default,
    /// Write into this directory, with a generated file name.
    Directory(PathBuf),
    /// Write exactly this file; its parent directory is scanned for duplicate plans.
    File(PathBuf),
}
339
340pub fn write_migration_output(
341    config: &CliConfig,
342    planned: &PlannedMigration,
343    output: MigrationOutput,
344    force: bool,
345) -> Result<PathBuf, DbCliError> {
346    if planned.plan.is_empty() {
347        return Ok(PathBuf::new());
348    }
349    let (dir, explicit_file) = match output {
350        MigrationOutput::Default => (PathBuf::from(&config.database.migrations_dir), None),
351        MigrationOutput::Directory(dir) => (dir, None),
352        MigrationOutput::File(path) => (
353            path.parent()
354                .map(Path::to_path_buf)
355                .unwrap_or_else(|| PathBuf::from(".")),
356            Some(path),
357        ),
358    };
359    let hash = plan_hash(&planned.plan);
360    if !force {
361        if let Some(existing) = find_existing_plan_hash(&dir, &hash)? {
362            return Err(DbCliError::DuplicateMigration(
363                existing.display().to_string(),
364            ));
365        }
366    }
367    fs::create_dir_all(&dir).map_err(|source| DbCliError::CreateDir {
368        path: dir.clone(),
369        source,
370    })?;
371    let path = explicit_file.unwrap_or_else(|| {
372        dir.join(format!(
373            "{}_{}_{}.sql",
374            filename_timestamp(),
375            normalized_provider(&planned.provider),
376            hash
377        ))
378    });
379    if path.exists() && !force {
380        return Err(DbCliError::DuplicateMigration(path.display().to_string()));
381    }
382    let sql = migration_sql(config, planned)?;
383    fs::write(&path, sql).map_err(|source| DbCliError::Write {
384        path: path.clone(),
385        source,
386    })?;
387    Ok(path)
388}
389
390pub fn schema_hash(schema: &DbSchema) -> Result<String, DbCliError> {
391    let payload = serde_json::to_vec(schema)
392        .map_err(|error| RustAuthError::Adapter(format!("failed to serialize schema: {error}")))?;
393    Ok(short_hash(&payload))
394}
395
396pub fn plan_hash(plan: &SchemaMigrationPlan) -> String {
397    short_hash(plan.compile().as_bytes())
398}
399
400pub fn database_url(config: &CliConfig) -> Result<String, DbCliError> {
401    database_url_with_base(config, None)
402}
403
404pub fn database_url_with_base(
405    config: &CliConfig,
406    cwd: Option<&Path>,
407) -> Result<String, DbCliError> {
408    std::env::var(&config.database.url_env)
409        .map(|url| normalize_database_url(config.database.provider.as_deref(), &url, cwd))
410        .map_err(|_| DbCliError::MissingDatabaseUrl(config.database.url_env.clone()))
411}
412
413pub fn supports_sql_migrations(config: &CliConfig) -> bool {
414    if !is_cli_migration_adapter(&config.database.adapter) {
415        return false;
416    }
417    match config.database.adapter.as_str() {
418        "sqlx" if cfg!(feature = "sqlx") => config
419            .database
420            .provider
421            .as_deref()
422            .is_some_and(|provider| dialect_from_provider(provider).is_some()),
423        "tokio-postgres" if cfg!(feature = "tokio-postgres") => config
424            .database
425            .provider
426            .as_deref()
427            .is_some_and(is_postgres_provider),
428        "deadpool-postgres" if cfg!(feature = "deadpool-postgres") => config
429            .database
430            .provider
431            .as_deref()
432            .is_some_and(is_postgres_provider),
433        _ => false,
434    }
435}
436
/// Adapters that are valid in the ecosystem but not driven by `rustauth db migrate`.
///
/// For these the CLI prints guidance and exits successfully rather than failing
/// (Better Auth parity for Prisma/Drizzle).
pub fn unsupported_adapter_exits_successfully(adapter: &str) -> bool {
    const GUIDANCE_ONLY: [&str; 5] = ["prisma", "drizzle", "memory", "mongodb", "kysely"];
    GUIDANCE_ONLY.contains(&adapter)
}
446
447pub fn unsupported_adapter_guidance(adapter: &str, command: &str) -> String {
448    match adapter {
449        "prisma" => format!(
450            "The {command} command applies RustAuth SQL migrations through the sqlx adapter. \
451             With Prisma configured, run `rustauth db generate` to write `.sql` files, then apply \
452             them with `prisma migrate` or `prisma db push`."
453        ),
454        "drizzle" => format!(
455            "The {command} command applies RustAuth SQL migrations through the sqlx adapter. \
456             With Drizzle configured, run `rustauth db generate` to write `.sql` files, then apply \
457             them with your Drizzle migration workflow."
458        ),
459        "kysely" => format!(
460            "The {command} command uses the sqlx adapter in rustauth.toml. \
461             Set `database.adapter = \"sqlx\"` and configure `database.provider`, or run \
462             `rustauth db generate` and apply the SQL with your existing Kysely tooling."
463        ),
464        "memory" => format!(
465            "The {command} command does not apply migrations for the in-memory adapter. \
466             Use `database.adapter = \"sqlx\"` with a real provider for CLI migrations, or \
467             `rustauth schema print` to inspect the target schema."
468        ),
469        "mongodb" => format!(
470            "The {command} command does not support MongoDB. \
471             Use a SQL provider with {}",
472            enabled_adapter_guidance()
473        ),
474        other => format!(
475            "Unsupported database adapter `{other}` for {command}. \
476             RustAuth CLI migrations require {}",
477            enabled_adapter_guidance()
478        ),
479    }
480}
481
482fn validate_cli_migration_adapter(config: &CliConfig) -> Result<(), DbCliError> {
483    let adapter = config.database.adapter.as_str();
484    if is_adapter_feature_disabled(adapter) {
485        return Err(DbCliError::AdapterFeatureDisabled(
486            adapter.to_owned(),
487            adapter_cargo_feature(adapter).to_owned(),
488        ));
489    }
490    if !is_cli_migration_adapter(adapter) {
491        return Err(DbCliError::UnsupportedAdapter(
492            config.database.adapter.clone(),
493        ));
494    }
495    Ok(())
496}
497
498fn adapter_dispatch_error(adapter: &str) -> DbCliError {
499    if is_adapter_feature_disabled(adapter) {
500        DbCliError::AdapterFeatureDisabled(
501            adapter.to_owned(),
502            adapter_cargo_feature(adapter).to_owned(),
503        )
504    } else {
505        DbCliError::UnsupportedAdapter(adapter.to_owned())
506    }
507}
508
/// Cargo feature that compiles in `adapter`, or `"unknown"` for unrecognised names.
///
/// Each migration adapter's feature shares its name.
fn adapter_cargo_feature(adapter: &str) -> &'static str {
    const FEATURES: [&str; 3] = ["sqlx", "tokio-postgres", "deadpool-postgres"];
    FEATURES
        .iter()
        .copied()
        .find(|feature| *feature == adapter)
        .unwrap_or("unknown")
}
517
518fn unsupported_adapter_support_suffix() -> String {
519    format!("CLI migrations support {}", enabled_adapter_guidance())
520}
521
/// Comma-separated description of the migration adapters compiled into this build.
///
/// Each entry shows the `database.adapter` setting and the providers it supports. When no
/// adapter is compiled in, a fixed explanatory sentence is returned instead.
fn enabled_adapter_guidance() -> String {
    let candidates: [(bool, &str); 3] = [
        (
            cfg!(feature = "sqlx"),
            "`database.adapter = \"sqlx\"` (sqlite, postgres, mysql)",
        ),
        (
            cfg!(feature = "tokio-postgres"),
            "`database.adapter = \"tokio-postgres\"` (postgres only)",
        ),
        (
            cfg!(feature = "deadpool-postgres"),
            "`database.adapter = \"deadpool-postgres\"` (postgres only)",
        ),
    ];
    let enabled: Vec<&str> = candidates
        .iter()
        .filter(|(on, _)| *on)
        .map(|(_, text)| *text)
        .collect();
    if enabled.is_empty() {
        return "no database migration adapters in this CLI build".to_owned();
    }
    enabled.join(", ")
}
539
540fn normalize_database_url(provider: Option<&str>, url: &str, cwd: Option<&Path>) -> String {
541    if !matches!(provider, Some("sqlite" | "sqlite3")) {
542        return url.to_owned();
543    }
544    let Some(cwd) = cwd else {
545        return url.to_owned();
546    };
547    let Some(path) = sqlite_path(url) else {
548        return url.to_owned();
549    };
550    if path.as_os_str().is_empty() || path.is_absolute() {
551        return url.to_owned();
552    }
553    format!("sqlite://{}", cwd.join(path).display())
554}
555
556fn short_hash(input: &[u8]) -> String {
557    let digest = Sha256::digest(input);
558    hex::encode(&digest[..8])
559}
560
561fn find_existing_plan_hash(dir: &Path, hash: &str) -> Result<Option<PathBuf>, DbCliError> {
562    if !dir.exists() {
563        return Ok(None);
564    }
565    for entry in fs::read_dir(dir).map_err(|source| DbCliError::Read {
566        path: dir.to_path_buf(),
567        source,
568    })? {
569        let entry = entry.map_err(|source| DbCliError::Read {
570            path: dir.to_path_buf(),
571            source,
572        })?;
573        let path = entry.path();
574        if path.extension().and_then(|extension| extension.to_str()) != Some("sql") {
575            continue;
576        }
577        let content = fs::read_to_string(&path).map_err(|source| DbCliError::Read {
578            path: path.clone(),
579            source,
580        })?;
581        if content.contains(&format!("plan_hash: {hash}")) {
582            return Ok(Some(path));
583        }
584    }
585    Ok(None)
586}
587
588fn filename_timestamp() -> String {
589    let now = OffsetDateTime::now_utc();
590    format!(
591        "{:04}{:02}{:02}{:02}{:02}{:02}",
592        now.year(),
593        u8::from(now.month()),
594        now.day(),
595        now.hour(),
596        now.minute(),
597        now.second()
598    )
599}
600
/// Canonical provider name for migration file names.
///
/// `postgresql`/`pg` become `postgres` and `sqlite3` becomes `sqlite`; anything else is
/// passed through unchanged.
fn normalized_provider(provider: &str) -> &str {
    if matches!(provider, "postgresql" | "pg") {
        "postgres"
    } else if provider == "sqlite3" {
        "sqlite"
    } else {
        provider
    }
}
608
609fn ensure_sqlite_database(database_url: &str) -> Result<(), DbCliError> {
610    let Some(path) = sqlite_path(database_url) else {
611        return Ok(());
612    };
613    if path.as_os_str().is_empty() || path.exists() {
614        return Ok(());
615    }
616    if let Some(parent) = path.parent() {
617        fs::create_dir_all(parent).map_err(|source| DbCliError::CreateDir {
618            path: parent.to_path_buf(),
619            source,
620        })?;
621    }
622    fs::File::create(&path)
623        .map(|_| ())
624        .map_err(|source| DbCliError::Write { path, source })
625}
626
/// Extract the path portion of a `sqlite://` or `sqlite:` URL.
///
/// Returns `None` for non-SQLite URLs and for the two exact in-memory forms
/// (`sqlite::memory:`, `sqlite://:memory:`). Any `?query` suffix is kept as part of the
/// returned path.
fn sqlite_path(database_url: &str) -> Option<PathBuf> {
    const IN_MEMORY_URLS: [&str; 2] = ["sqlite::memory:", "sqlite://:memory:"];
    if IN_MEMORY_URLS.contains(&database_url) {
        return None;
    }
    // The longer `sqlite://` prefix must be tried first so `//` is not left on the path.
    let rest = match database_url.strip_prefix("sqlite://") {
        Some(rest) => rest,
        None => database_url.strip_prefix("sqlite:")?,
    };
    Some(PathBuf::from(rest))
}