1use std::fs;
2use std::path::{Path, PathBuf};
3
4use openauth_core::db::{DbAdapter, DbSchema, SchemaMigrationPlan, SchemaMigrationWarning};
5use openauth_core::error::OpenAuthError;
6use openauth_sqlx::{MySqlAdapter, PostgresAdapter, SqliteAdapter};
7use serde::Serialize;
8use sha2::{Digest, Sha256};
9use time::format_description::well_known::Rfc3339;
10use time::OffsetDateTime;
11
12use crate::config::CliConfig;
13use crate::schema::{dialect_from_provider, dialect_name, full_schema_plan, target_schema};
14
/// Errors returned by the database CLI helpers in this module (planning, migrating,
/// and generating migration files).
#[derive(Debug, thiserror::Error)]
pub enum DbCliError {
    /// `database.provider` is absent from the CLI configuration.
    #[error("database provider is not configured")]
    MissingProvider,
    /// The environment variable named by `database.url_env` could not be read
    /// (`std::env::var` failed: unset, or not valid Unicode). Carries the variable name.
    #[error("database URL environment variable {0} is not set")]
    MissingDatabaseUrl(String),
    /// The configured provider string is not a recognised SQLite/Postgres/MySQL spelling.
    #[error("unsupported database provider `{0}`")]
    UnsupportedProvider(String),
    /// `migrate` refused to apply a plan that carries warnings.
    #[error("migration has non-executable warnings; fix schema mismatches before applying")]
    UnsafeMigration,
    /// A migration file for the same plan (or at the same target path) already exists;
    /// carries that file's path.
    #[error("A migration for this plan already exists: {0}")]
    DuplicateMigration(String),
    /// An error propagated from the OpenAuth core / SQL adapters.
    #[error("database error: {0}")]
    OpenAuth(#[from] OpenAuthError),
    /// Creating or writing a file failed.
    #[error("failed to write {path}: {source}")]
    Write {
        path: PathBuf,
        source: std::io::Error,
    },
    /// Reading a directory or file failed.
    #[error("failed to read {path}: {source}")]
    Read {
        path: PathBuf,
        source: std::io::Error,
    },
    /// Creating a directory (and its parents) failed.
    #[error("failed to create {path}: {source}")]
    CreateDir {
        path: PathBuf,
        source: std::io::Error,
    },
    /// Formatting the migration's RFC 3339 timestamp failed.
    #[error("failed to format timestamp: {0}")]
    TimeFormat(#[from] time::error::Format),
}
47
/// Serializable overview of a migration plan, as produced by [`PlannedMigration::summary`].
#[derive(Debug, Clone, Serialize)]
pub struct PlanSummary {
    /// Provider string exactly as configured (not normalized).
    pub provider: String,
    /// Number of tables the plan creates.
    pub tables_to_create: usize,
    /// Number of columns the plan adds to existing tables.
    pub columns_to_add: usize,
    /// Number of indexes the plan creates.
    pub indexes_to_create: usize,
    /// Warnings reported by the planner; `migrate` refuses to run when this is non-empty.
    pub warnings: Vec<SchemaMigrationWarning>,
    /// Number of SQL statements in the plan.
    pub statements: usize,
    /// Fingerprint of the compiled plan SQL (see `plan_hash`).
    pub plan_hash: String,
}
58
/// A migration plan together with the schema it targets and the provider it was planned for.
#[derive(Debug, Clone)]
pub struct PlannedMigration {
    /// Target schema resolved from the CLI configuration.
    pub schema: DbSchema,
    /// The statements and diagnostics needed to reach `schema`.
    pub plan: SchemaMigrationPlan,
    /// Provider string exactly as configured (e.g. `pg`, `postgresql`, `sqlite3`).
    pub provider: String,
}
65
66impl PlannedMigration {
67 pub fn summary(&self) -> PlanSummary {
68 PlanSummary {
69 provider: self.provider.clone(),
70 tables_to_create: self.plan.to_be_created.len(),
71 columns_to_add: self.plan.to_be_added.len(),
72 indexes_to_create: self.plan.indexes_to_be_created.len(),
73 warnings: self.plan.warnings.clone(),
74 statements: self.plan.statements.len(),
75 plan_hash: plan_hash(&self.plan),
76 }
77 }
78}
79
80pub async fn plan(config: &CliConfig, from_empty: bool) -> Result<PlannedMigration, DbCliError> {
81 let schema = target_schema(config)?;
82 let provider = config
83 .database
84 .provider
85 .clone()
86 .ok_or(DbCliError::MissingProvider)?;
87
88 let plan = if from_empty {
89 let dialect = dialect_from_provider(&provider)
90 .ok_or_else(|| DbCliError::UnsupportedProvider(provider.clone()))?;
91 full_schema_plan(dialect, &schema)?
92 } else {
93 let database_url = database_url(config)?;
94 match provider.as_str() {
95 "sqlite" | "sqlite3" => {
96 ensure_sqlite_database(&database_url)?;
97 SqliteAdapter::connect_with_schema(&database_url, schema.clone())
98 .await?
99 .plan_migrations(&schema)
100 .await?
101 }
102 "postgres" | "postgresql" | "pg" => {
103 PostgresAdapter::connect_with_schema(&database_url, schema.clone())
104 .await?
105 .plan_migrations(&schema)
106 .await?
107 }
108 "mysql" => {
109 MySqlAdapter::connect_with_schema(&database_url, schema.clone())
110 .await?
111 .plan_migrations(&schema)
112 .await?
113 }
114 _ => return Err(DbCliError::UnsupportedProvider(provider)),
115 }
116 };
117
118 Ok(PlannedMigration {
119 schema,
120 plan,
121 provider,
122 })
123}
124
125pub async fn migrate(config: &CliConfig) -> Result<PlannedMigration, DbCliError> {
126 let planned = plan(config, false).await?;
127 if !planned.plan.warnings.is_empty() {
128 return Err(DbCliError::UnsafeMigration);
129 }
130 let database_url = database_url(config)?;
131 match planned.provider.as_str() {
132 "sqlite" | "sqlite3" => {
133 ensure_sqlite_database(&database_url)?;
134 let adapter =
135 SqliteAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
136 adapter.run_migrations(&planned.schema).await?;
137 }
138 "postgres" | "postgresql" | "pg" => {
139 let adapter =
140 PostgresAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
141 adapter.run_migrations(&planned.schema).await?;
142 }
143 "mysql" => {
144 let adapter =
145 MySqlAdapter::connect_with_schema(&database_url, planned.schema.clone()).await?;
146 adapter.run_migrations(&planned.schema).await?;
147 }
148 _ => return Err(DbCliError::UnsupportedProvider(planned.provider.clone())),
149 }
150 Ok(planned)
151}
152
153pub fn migration_sql(config: &CliConfig, planned: &PlannedMigration) -> Result<String, DbCliError> {
154 let dialect = dialect_from_provider(&planned.provider)
155 .ok_or_else(|| DbCliError::UnsupportedProvider(planned.provider.clone()))?;
156 let generated_at = OffsetDateTime::now_utc().format(&Rfc3339)?;
157 let schema_hash = schema_hash(&planned.schema)?;
158 let plan_hash = plan_hash(&planned.plan);
159 Ok(format!(
160 "-- OpenAuth migration\n-- dialect: {}\n-- generated_at: {}\n-- schema_hash: {}\n-- plan_hash: {}\n-- config_base_path: {}\n\n{}",
161 dialect_name(dialect),
162 generated_at,
163 schema_hash,
164 plan_hash,
165 config.project.base_path,
166 planned.plan.compile()
167 ))
168}
169
170pub fn write_migration(
171 config: &CliConfig,
172 planned: &PlannedMigration,
173 output: Option<&Path>,
174 force: bool,
175) -> Result<PathBuf, DbCliError> {
176 if planned.plan.is_empty() {
177 return Ok(PathBuf::new());
178 }
179 let dir = output
180 .map(Path::to_path_buf)
181 .unwrap_or_else(|| PathBuf::from(&config.database.migrations_dir));
182 let hash = plan_hash(&planned.plan);
183 if let Some(existing) = find_existing_plan_hash(&dir, &hash)? {
184 return Err(DbCliError::DuplicateMigration(
185 existing.display().to_string(),
186 ));
187 }
188 fs::create_dir_all(&dir).map_err(|source| DbCliError::CreateDir {
189 path: dir.clone(),
190 source,
191 })?;
192 let path = dir.join(format!(
193 "{}_{}_{}.sql",
194 filename_timestamp(),
195 normalized_provider(&planned.provider),
196 hash
197 ));
198 if path.exists() && !force {
199 return Err(DbCliError::DuplicateMigration(path.display().to_string()));
200 }
201 let sql = migration_sql(config, planned)?;
202 fs::write(&path, sql).map_err(|source| DbCliError::Write {
203 path: path.clone(),
204 source,
205 })?;
206 Ok(path)
207}
208
209pub fn schema_hash(schema: &DbSchema) -> Result<String, DbCliError> {
210 let payload = serde_json::to_vec(schema)
211 .map_err(|error| OpenAuthError::Adapter(format!("failed to serialize schema: {error}")))?;
212 Ok(short_hash(&payload))
213}
214
215pub fn plan_hash(plan: &SchemaMigrationPlan) -> String {
216 short_hash(plan.compile().as_bytes())
217}
218
219pub fn database_url(config: &CliConfig) -> Result<String, DbCliError> {
220 std::env::var(&config.database.url_env)
221 .map_err(|_| DbCliError::MissingDatabaseUrl(config.database.url_env.clone()))
222}
223
224fn short_hash(input: &[u8]) -> String {
225 let digest = Sha256::digest(input);
226 hex::encode(&digest[..8])
227}
228
229fn find_existing_plan_hash(dir: &Path, hash: &str) -> Result<Option<PathBuf>, DbCliError> {
230 if !dir.exists() {
231 return Ok(None);
232 }
233 for entry in fs::read_dir(dir).map_err(|source| DbCliError::Read {
234 path: dir.to_path_buf(),
235 source,
236 })? {
237 let entry = entry.map_err(|source| DbCliError::Read {
238 path: dir.to_path_buf(),
239 source,
240 })?;
241 let path = entry.path();
242 if path.extension().and_then(|extension| extension.to_str()) != Some("sql") {
243 continue;
244 }
245 let content = fs::read_to_string(&path).map_err(|source| DbCliError::Read {
246 path: path.clone(),
247 source,
248 })?;
249 if content.contains(&format!("plan_hash: {hash}")) {
250 return Ok(Some(path));
251 }
252 }
253 Ok(None)
254}
255
256fn filename_timestamp() -> String {
257 let now = OffsetDateTime::now_utc();
258 format!(
259 "{:04}{:02}{:02}{:02}{:02}{:02}",
260 now.year(),
261 u8::from(now.month()),
262 now.day(),
263 now.hour(),
264 now.minute(),
265 now.second()
266 )
267}
268
/// Maps provider aliases to one canonical name for file naming.
///
/// `postgresql`/`pg` map to `postgres` and `sqlite3` maps to `sqlite`. Every other value
/// is returned unchanged.
fn normalized_provider(provider: &str) -> &str {
    if matches!(provider, "postgresql" | "pg") {
        "postgres"
    } else if provider == "sqlite3" {
        "sqlite"
    } else {
        provider
    }
}
276
277fn ensure_sqlite_database(database_url: &str) -> Result<(), DbCliError> {
278 let Some(path) = sqlite_path(database_url) else {
279 return Ok(());
280 };
281 if path.as_os_str().is_empty() || path.exists() {
282 return Ok(());
283 }
284 if let Some(parent) = path.parent() {
285 fs::create_dir_all(parent).map_err(|source| DbCliError::CreateDir {
286 path: parent.to_path_buf(),
287 source,
288 })?;
289 }
290 fs::File::create(&path)
291 .map(|_| ())
292 .map_err(|source| DbCliError::Write { path, source })
293}
294
/// Extracts the database file path from a SQLite connection URL.
///
/// Accepts both `sqlite://<path>` and `sqlite:<path>`. Any `?key=value` connection options
/// (e.g. `?mode=rwc`) are dropped, since they are not part of the file name.
///
/// Returns `None` for non-SQLite URLs and for in-memory databases (`:memory:`, with or
/// without options). An empty path yields `Some("")`, which callers treat as "nothing to
/// create".
fn sqlite_path(database_url: &str) -> Option<PathBuf> {
    let rest = database_url
        .strip_prefix("sqlite://")
        .or_else(|| database_url.strip_prefix("sqlite:"))?;
    // Connection options follow the first `?`.
    let path = rest.split_once('?').map_or(rest, |(path, _options)| path);
    if path == ":memory:" {
        return None;
    }
    Some(PathBuf::from(path))
}