// openauth_cli/db.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use openauth_core::db::{DbAdapter, DbSchema, SchemaMigrationPlan, SchemaMigrationWarning};
5use openauth_core::error::OpenAuthError;
6use openauth_sqlx::{MySqlAdapter, PostgresAdapter, SqliteAdapter};
7use serde::Serialize;
8use sha2::{Digest, Sha256};
9use time::format_description::well_known::Rfc3339;
10use time::OffsetDateTime;
11
12use crate::config::CliConfig;
13use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};
14
/// Errors returned by the CLI's database commands (planning, migrating, writing migration files).
///
/// NOTE(review): the `Write`, `Read` and `CreateDir` variants embed `{source}` in their message
/// *and* (via thiserror's handling of a field named `source`) expose it through
/// `Error::source()`, so reporters that walk the error chain will print the I/O error twice.
#[derive(Debug, thiserror::Error)]
pub enum DbCliError {
    /// `config.database.provider` is not set.
    #[error("database provider is not configured")]
    MissingProvider,
    /// The environment variable named by `config.database.url_env` could not be read
    /// (any `std::env::var` error is mapped here). Carries the variable name.
    #[error("database URL environment variable {0} is not set")]
    MissingDatabaseUrl(String),
    /// The configured provider string is not one this CLI knows how to handle.
    #[error("unsupported database provider `{0}`")]
    UnsupportedProvider(String),
    /// The migration plan contains warnings, so `migrate` refuses to apply it.
    #[error("migration has non-executable warnings; fix schema mismatches before applying")]
    UnsafeMigration,
    /// A migration file for the same plan already exists; carries that file's path.
    #[error("A migration for this plan already exists: {0}")]
    DuplicateMigration(String),
    /// Error surfaced by `openauth_core` or one of the database adapters.
    #[error("database error: {0}")]
    OpenAuth(#[from] OpenAuthError),
    /// Writing or creating a file failed.
    #[error("failed to write {path}: {source}")]
    Write {
        path: PathBuf,
        source: std::io::Error,
    },
    /// Reading a directory or file failed.
    #[error("failed to read {path}: {source}")]
    Read {
        path: PathBuf,
        source: std::io::Error,
    },
    /// Creating a directory failed.
    #[error("failed to create {path}: {source}")]
    CreateDir {
        path: PathBuf,
        source: std::io::Error,
    },
    /// Formatting the `generated_at` timestamp for a migration header failed.
    #[error("failed to format timestamp: {0}")]
    TimeFormat(#[from] time::error::Format),
}
47
/// Serializable overview of a [`PlannedMigration`], produced by [`PlannedMigration::summary`].
///
/// NOTE(review): presumably printed by the CLI (e.g. as JSON via `Serialize`) — confirm against
/// callers. Field order here determines serialized field order.
#[derive(Debug, Clone, Serialize)]
pub struct PlanSummary {
    /// Provider string copied from the planned migration (not normalized).
    pub provider: String,
    /// Number of entries in the plan's `to_be_created` list.
    pub tables_to_create: usize,
    /// Number of entries in the plan's `to_be_added` list.
    pub columns_to_add: usize,
    /// Number of entries in the plan's `indexes_to_be_created` list.
    pub indexes_to_create: usize,
    /// Warnings reported by the planner; a non-empty list makes `migrate` refuse to run.
    pub warnings: Vec<SchemaMigrationWarning>,
    /// Number of statements in the plan.
    pub statements: usize,
    /// Short hash of the compiled plan SQL, as computed by [`plan_hash`].
    pub plan_hash: String,
}
58
/// A migration plan bundled with the schema it targets and the provider it was planned for.
///
/// Returned by [`plan`] and [`migrate`]; rendered by [`migration_sql`] and [`write_migration`].
#[derive(Debug, Clone)]
pub struct PlannedMigration {
    /// Target schema the plan was computed against.
    pub schema: DbSchema,
    /// The changes needed to reach `schema`.
    pub plan: SchemaMigrationPlan,
    /// Provider name as configured (e.g. `"sqlite"`, `"pg"`); aliases are not normalized here.
    pub provider: String,
}
65
66impl PlannedMigration {
67    pub fn summary(&self) -> PlanSummary {
68        PlanSummary {
69            provider: self.provider.clone(),
70            tables_to_create: self.plan.to_be_created.len(),
71            columns_to_add: self.plan.to_be_added.len(),
72            indexes_to_create: self.plan.indexes_to_be_created.len(),
73            warnings: self.plan.warnings.clone(),
74            statements: self.plan.statements.len(),
75            plan_hash: plan_hash(&self.plan),
76        }
77    }
78}
79
80pub async fn plan(config: &CliConfig, from_empty: bool) -> Result<PlannedMigration, DbCliError> {
81    let schema = target_schema(config)?;
82    let provider = config
83        .database
84        .provider
85        .clone()
86        .ok_or(DbCliError::MissingProvider)?;
87
88    let plan = if from_empty {
89        let dialect = dialect_from_provider(&provider)
90            .ok_or_else(|| DbCliError::UnsupportedProvider(provider.clone()))?;
91        full_schema_plan(dialect, &schema)?
92    } else {
93        let database_url = database_url(config)?;
94        match provider.as_str() {
95            "sqlite" | "sqlite3" => {
96                ensure_sqlite_database(&database_url)?;
97                SqliteAdapter::connect_with_schema(&database_url, schema.clone())
98                    .await?
99                    .plan_migrations(&schema)
100                    .await?
101            }
102            "postgres" | "postgresql" | "pg" => {
103                PostgresAdapter::connect_with_schema(&database_url, schema.clone())
104                    .await?
105                    .plan_migrations(&schema)
106                    .await?
107            }
108            "mysql" => {
109                MySqlAdapter::connect_with_schema(&database_url, schema.clone())
110                    .await?
111                    .plan_migrations(&schema)
112                    .await?
113            }
114            _ => return Err(DbCliError::UnsupportedProvider(provider)),
115        }
116    };
117
118    Ok(PlannedMigration {
119        schema,
120        plan,
121        provider,
122    })
123}
124
125pub async fn migrate(config: &CliConfig) -> Result<PlannedMigration, DbCliError> {
126    let planned = plan(config, false).await?;
127    if !planned.plan.warnings.is_empty() {
128        return Err(DbCliError::UnsafeMigration);
129    }
130    let database_url = database_url(config)?;
131    match planned.provider.as_str() {
132        "sqlite" | "sqlite3" => {
133            ensure_sqlite_database(&database_url)?;
134            let adapter =
135                SqliteAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
136            adapter.run_migrations(&planned.schema).await?;
137        }
138        "postgres" | "postgresql" | "pg" => {
139            let adapter =
140                PostgresAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
141            adapter.run_migrations(&planned.schema).await?;
142        }
143        "mysql" => {
144            let adapter =
145                MySqlAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
146            adapter.run_migrations(&planned.schema).await?;
147        }
148        _ => return Err(DbCliError::UnsupportedProvider(planned.provider.clone())),
149    }
150    Ok(planned)
151}
152
153pub fn migration_sql(config: &CliConfig, planned: &PlannedMigration) -> Result<String, DbCliError> {
154    let dialect = dialect_from_provider(&planned.provider)
155        .ok_or_else(|| DbCliError::UnsupportedProvider(planned.provider.clone()))?;
156    let generated_at = OffsetDateTime::now_utc().format(&Rfc3339)?;
157    let schema_hash = schema_hash(&planned.schema)?;
158    let plan_hash = plan_hash(&planned.plan);
159    Ok(format!(
160        "-- OpenAuth migration\n-- dialect: {}\n-- generated_at: {}\n-- schema_hash: {}\n-- plan_hash: {}\n-- config_base_path: {}\n\n{}",
161        dialect_name(dialect),
162        generated_at,
163        schema_hash,
164        plan_hash,
165        config.project.base_path,
166        planned.plan.compile()
167    ))
168}
169
170pub fn write_migration(
171    config: &CliConfig,
172    planned: &PlannedMigration,
173    output: Option<&Path>,
174    force: bool,
175) -> Result<PathBuf, DbCliError> {
176    if planned.plan.is_empty() {
177        return Ok(PathBuf::new());
178    }
179    let dir = output
180        .map(Path::to_path_buf)
181        .unwrap_or_else(|| PathBuf::from(&config.database.migrations_dir));
182    let hash = plan_hash(&planned.plan);
183    if let Some(existing) = find_existing_plan_hash(&dir, &hash)? {
184        return Err(DbCliError::DuplicateMigration(
185            existing.display().to_string(),
186        ));
187    }
188    fs::create_dir_all(&dir).map_err(|source| DbCliError::CreateDir {
189        path: dir.clone(),
190        source,
191    })?;
192    let path = dir.join(format!(
193        "{}_{}_{}.sql",
194        filename_timestamp(),
195        normalized_provider(&planned.provider),
196        hash
197    ));
198    if path.exists() && !force {
199        return Err(DbCliError::DuplicateMigration(path.display().to_string()));
200    }
201    let sql = migration_sql(config, planned)?;
202    fs::write(&path, sql).map_err(|source| DbCliError::Write {
203        path: path.clone(),
204        source,
205    })?;
206    Ok(path)
207}
208
209pub fn schema_hash(schema: &DbSchema) -> Result<String, DbCliError> {
210    let payload = serde_json::to_vec(schema)
211        .map_err(|error| OpenAuthError::Adapter(format!("failed to serialize schema: {error}")))?;
212    Ok(short_hash(&payload))
213}
214
215pub fn plan_hash(plan: &SchemaMigrationPlan) -> String {
216    short_hash(plan.compile().as_bytes())
217}
218
219pub fn database_url(config: &CliConfig) -> Result<String, DbCliError> {
220    std::env::var(&config.database.url_env)
221        .map_err(|_| DbCliError::MissingDatabaseUrl(config.database.url_env.clone()))
222}
223
224fn short_hash(input: &[u8]) -> String {
225    let digest = Sha256::digest(input);
226    hex::encode(&digest[..8])
227}
228
229fn find_existing_plan_hash(dir: &Path, hash: &str) -> Result<Option<PathBuf>, DbCliError> {
230    if !dir.exists() {
231        return Ok(None);
232    }
233    for entry in fs::read_dir(dir).map_err(|source| DbCliError::Read {
234        path: dir.to_path_buf(),
235        source,
236    })? {
237        let entry = entry.map_err(|source| DbCliError::Read {
238            path: dir.to_path_buf(),
239            source,
240        })?;
241        let path = entry.path();
242        if path.extension().and_then(|extension| extension.to_str()) != Some("sql") {
243            continue;
244        }
245        let content = fs::read_to_string(&path).map_err(|source| DbCliError::Read {
246            path: path.clone(),
247            source,
248        })?;
249        if content.contains(&format!("plan_hash: {hash}")) {
250            return Ok(Some(path));
251        }
252    }
253    Ok(None)
254}
255
256fn filename_timestamp() -> String {
257    let now = OffsetDateTime::now_utc();
258    format!(
259        "{:04}{:02}{:02}{:02}{:02}{:02}",
260        now.year(),
261        u8::from(now.month()),
262        now.day(),
263        now.hour(),
264        now.minute(),
265        now.second()
266    )
267}
268
/// Maps provider aliases to one canonical name for filenames: `postgresql`/`pg` become
/// `postgres`, and `sqlite3` becomes `sqlite`. Any other value is returned unchanged.
fn normalized_provider(provider: &str) -> &str {
    if matches!(provider, "postgresql" | "pg") {
        "postgres"
    } else if provider == "sqlite3" {
        "sqlite"
    } else {
        provider
    }
}
276
277fn ensure_sqlite_database(database_url: &str) -> Result<(), DbCliError> {
278    let Some(path) = sqlite_path(database_url) else {
279        return Ok(());
280    };
281    if path.as_os_str().is_empty() || path.exists() {
282        return Ok(());
283    }
284    if let Some(parent) = path.parent() {
285        fs::create_dir_all(parent).map_err(|source| DbCliError::CreateDir {
286            path: parent.to_path_buf(),
287            source,
288        })?;
289    }
290    fs::File::create(&path)
291        .map(|_| ())
292        .map_err(|source| DbCliError::Write { path, source })
293}
294
/// Extracts the on-disk file path from a SQLite connection URL.
///
/// Accepts both `sqlite://<path>` and `sqlite:<path>`. Any `?query` suffix (e.g. `?mode=rwc`) is
/// a connection option and is stripped from the path.
///
/// Returns `None` in these cases:
/// - the URL is not a SQLite URL;
/// - the URL names an in-memory database: a `:memory:` path, or `mode=memory` in the query.
fn sqlite_path(database_url: &str) -> Option<PathBuf> {
    let location = database_url
        .strip_prefix("sqlite://")
        .or_else(|| database_url.strip_prefix("sqlite:"))?;
    let (file, query) = match location.split_once('?') {
        Some((file, query)) => (file, query),
        None => (location, ""),
    };
    let in_memory = file == ":memory:" || query.split('&').any(|option| option == "mode=memory");
    if in_memory {
        return None;
    }
    Some(PathBuf::from(file))
}