Skip to main content

openauth_core/db/sql/
migrations.rs

1use super::*;
2
3/// Executes a pure migration plan through any SQL executor.
4///
5/// Introspection and transaction ownership stay in the adapter crate; this
6/// helper only runs the already-planned SQL statements in order.
7pub async fn execute_schema_migration_plan<E>(
8    executor: &mut E,
9    plan: &SchemaMigrationPlan,
10) -> Result<(), OpenAuthError>
11where
12    E: SqlExecutor,
13{
14    for statement in &plan.statements {
15        executor
16            .execute(SqlStatement::new(statement.sql.clone()))
17            .await?;
18    }
19    Ok(())
20}
21
/// Additive schema changes planned for a live database.
///
/// Built by [`plan_schema_migration`]. The `to_be_created`, `to_be_added` and
/// `indexes_to_be_created` lists describe the planned changes structurally,
/// `statements` holds the SQL to run (see [`execute_schema_migration_plan`]),
/// and `warnings` lists differences that are reported but never fixed.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SchemaMigrationPlan {
    /// Tables absent from the snapshot that will be created.
    pub to_be_created: Vec<TableToCreate>,
    /// Columns absent from tables that already exist in the snapshot.
    pub to_be_added: Vec<ColumnToAdd>,
    /// Standalone indexes absent from the snapshot that will be created.
    pub indexes_to_be_created: Vec<IndexToCreate>,
    /// Mismatches found on existing columns; no SQL is emitted for these.
    pub warnings: Vec<SchemaMigrationWarning>,
    /// SQL statements in execution order. As built by `plan_schema_migration`,
    /// table/column statements come before all index statements.
    pub statements: Vec<MigrationStatement>,
}
31
32impl SchemaMigrationPlan {
33    pub fn is_empty(&self) -> bool {
34        self.statements.is_empty()
35    }
36
37    pub fn has_warnings(&self) -> bool {
38        !self.warnings.is_empty()
39    }
40
41    pub fn compile(&self) -> String {
42        if self.statements.is_empty() {
43            return ";".to_owned();
44        }
45
46        format!(
47            "{};",
48            self.statements
49                .iter()
50                .map(|statement| statement.sql.as_str())
51                .collect::<Vec<_>>()
52                .join(";\n\n")
53        )
54    }
55}
56
/// A table missing from the database and planned for creation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableToCreate {
    /// Key under which the table appears in the target `DbSchema`.
    pub logical_name: String,
    /// Physical table name, as checked against the snapshot and used in SQL.
    pub table_name: String,
}
63
/// A column missing from an existing table and planned for additive creation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColumnToAdd {
    /// Key of the owning table in the target `DbSchema`.
    pub table_logical_name: String,
    /// Physical name of the owning table.
    pub table_name: String,
    /// Key of the field within the owning table's field map.
    pub field_logical_name: String,
    /// Physical column name to add.
    pub column_name: String,
}
72
/// A standalone index missing from the database and planned for creation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IndexToCreate {
    /// Key of the owning table in the target `DbSchema`.
    pub table_logical_name: String,
    /// Physical name of the indexed table.
    pub table_name: String,
    /// Key of the indexed field within the owning table's field map.
    pub field_logical_name: String,
    /// Physical name of the indexed column.
    pub column_name: String,
    /// Index name. The planner builds it as `idx_`/`uidx_` + table name + field
    /// logical name and passes it through the dialect's identifier sanitizer.
    pub index_name: String,
    /// Whether the index enforces uniqueness.
    pub unique: bool,
}
83
/// Non-executable findings discovered while planning migrations.
///
/// Each variant describes a difference between the target schema and an
/// existing column that the additive planner reports but does not fix.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// Every variant deliberately ends in `Mismatch`, which `clippy::enum_variant_names` flags.
#[allow(clippy::enum_variant_names)]
pub enum SchemaMigrationWarning {
    /// The column's introspected type does not match the field's SQL type.
    ColumnTypeMismatch {
        table_name: String,
        column_name: String,
        /// SQL type the dialect would generate for the field.
        expected: String,
        /// Type reported by the snapshot.
        actual: String,
    },
    /// The column's nullability differs from what the field's `required` flag implies.
    ColumnNullabilityMismatch {
        table_name: String,
        column_name: String,
        expected_nullable: bool,
        actual_nullable: bool,
    },
    /// The id column is reported as not being a primary key.
    PrimaryKeyMismatch {
        table_name: String,
        column_name: String,
    },
    /// The id column's generation strategy differs from the field's.
    GeneratedIdMismatch {
        table_name: String,
        column_name: String,
        expected: IdGeneration,
        /// `None` when the snapshot reported no generation strategy.
        actual: Option<IdGeneration>,
    },
    /// The column's foreign key differs from (or is missing relative to) the field's.
    ForeignKeyMismatch {
        table_name: String,
        column_name: String,
        expected: ForeignKey,
        /// `None` when the snapshot reported no foreign key.
        actual: Option<ForeignKey>,
    },
}
117
/// A SQL statement emitted by a migration plan.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MigrationStatement {
    /// Which additive operation the statement performs.
    pub kind: MigrationStatementKind,
    /// Dialect-specific SQL text, without a trailing `;`
    /// (`SchemaMigrationPlan::compile` adds separators and the terminator).
    pub sql: String,
}
124
/// The additive operation represented by a migration statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationStatementKind {
    /// Creates a table missing from the snapshot (`SqlDialect::create_table_statement`).
    CreateTable,
    /// Adds a column missing from an existing table (`SqlDialect::add_column_statement`).
    AddColumn,
    /// Creates a standalone index (`SqlDialect::create_index_statement`).
    CreateIndex,
}
132
/// Introspected database schema used by the pure migration planner.
///
/// Populate it with the `with_*` builder methods; query it with
/// `table_exists`, `column`, `index_exists` and `unique_column_exists`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SqlSchemaSnapshot {
    /// Per-table metadata keyed by physical table name, in insertion order.
    tables: IndexMap<String, SqlTableSnapshot>,
}
138
139impl SqlSchemaSnapshot {
140    pub fn with_table(mut self, table: impl Into<String>) -> Self {
141        self.tables.entry(table.into()).or_default();
142        self
143    }
144
145    pub fn with_column(mut self, table: impl Into<String>, column: SqlColumnSnapshot) -> Self {
146        self.tables
147            .entry(table.into())
148            .or_default()
149            .columns
150            .insert(column.name.clone(), column);
151        self
152    }
153
154    pub fn with_index(mut self, table: impl Into<String>, index: impl Into<String>) -> Self {
155        self.tables
156            .entry(table.into())
157            .or_default()
158            .indexes
159            .insert(index.into());
160        self
161    }
162
163    pub fn with_unique_column(
164        mut self,
165        table: impl Into<String>,
166        column: impl Into<String>,
167    ) -> Self {
168        self.tables
169            .entry(table.into())
170            .or_default()
171            .unique_columns
172            .insert(column.into());
173        self
174    }
175
176    pub fn table_exists(&self, table: &str) -> bool {
177        self.tables.contains_key(table)
178    }
179
180    pub fn column_type(&self, table: &str, column: &str) -> Option<&str> {
181        self.column(table, column)
182            .map(|column| column.data_type.as_str())
183    }
184
185    pub fn column(&self, table: &str, column: &str) -> Option<&SqlColumnSnapshot> {
186        self.tables
187            .get(table)
188            .and_then(|table| table.columns.get(column))
189    }
190
191    pub fn index_exists(&self, table: &str, index: &str) -> bool {
192        self.tables
193            .get(table)
194            .is_some_and(|table| table.indexes.contains(index))
195            || self
196                .tables
197                .values()
198                .any(|table| table.indexes.contains(index))
199    }
200
201    pub fn unique_column_exists(&self, table: &str, column: &str) -> bool {
202        self.tables
203            .get(table)
204            .is_some_and(|table| table.unique_columns.contains(column))
205    }
206}
207
/// Introspected table metadata used by the pure migration planner.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SqlTableSnapshot {
    /// Columns keyed by column name (`SqlColumnSnapshot::name`).
    columns: IndexMap<String, SqlColumnSnapshot>,
    /// Names of indexes recorded on this table.
    indexes: IndexSet<String>,
    /// Names of columns recorded as carrying a unique constraint.
    unique_columns: IndexSet<String>,
}
215
/// Introspected column metadata used by the pure migration planner.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SqlColumnSnapshot {
    /// Column name; also the key under which the column is stored in its table.
    pub name: String,
    /// SQL type as reported by introspection.
    pub data_type: String,
    /// Nullability, if introspection reported it. `None` suppresses nullability warnings.
    pub nullable: Option<bool>,
    /// Primary-key status, if known. `None` suppresses primary-key warnings.
    pub primary_key: Option<bool>,
    /// Id generation strategy. Unlike `nullable`/`primary_key`, the planner treats
    /// `None` as "no generation", so it is reported as a mismatch when one is expected.
    pub generated_id: Option<IdGeneration>,
    /// Foreign key on the column. `None` is likewise compared as "no foreign key".
    pub foreign_key: Option<ForeignKey>,
}
226
227impl SqlColumnSnapshot {
228    pub fn new(name: impl Into<String>, data_type: impl Into<String>) -> Self {
229        Self {
230            name: name.into(),
231            data_type: data_type.into(),
232            nullable: None,
233            primary_key: None,
234            generated_id: None,
235            foreign_key: None,
236        }
237    }
238
239    pub fn nullable(mut self, nullable: bool) -> Self {
240        self.nullable = Some(nullable);
241        self
242    }
243
244    pub fn primary_key(mut self, primary_key: bool) -> Self {
245        self.primary_key = Some(primary_key);
246        self
247    }
248
249    pub fn generated_id(mut self, generated_id: Option<IdGeneration>) -> Self {
250        self.generated_id = generated_id;
251        self
252    }
253
254    pub fn references(mut self, foreign_key: ForeignKey) -> Self {
255        self.foreign_key = Some(foreign_key);
256        self
257    }
258}
259
260/// Compares a target OpenAuth schema with a SQL schema snapshot and emits an additive plan.
261pub fn plan_schema_migration(
262    dialect: SqlDialect,
263    schema: &DbSchema,
264    snapshot: &SqlSchemaSnapshot,
265) -> Result<SchemaMigrationPlan, OpenAuthError> {
266    let mut plan = SchemaMigrationPlan::default();
267    let mut tables = schema.tables().collect::<Vec<_>>();
268    tables.sort_by_key(|(_, table)| table.order.unwrap_or(u16::MAX));
269
270    for (table_logical_name, table) in &tables {
271        if snapshot.table_exists(&table.name) {
272            for (logical_name, field) in &table.fields {
273                if let Some(column) = snapshot.column(&table.name, &field.name) {
274                    if !dialect.type_matches(&column.data_type, field) {
275                        plan.warnings
276                            .push(SchemaMigrationWarning::ColumnTypeMismatch {
277                                table_name: table.name.clone(),
278                                column_name: field.name.clone(),
279                                expected: dialect.sql_type(logical_name, field),
280                                actual: column.data_type.clone(),
281                            });
282                    }
283                    push_constraint_warnings(&mut plan, table, logical_name, field, column);
284                } else {
285                    plan.to_be_added.push(ColumnToAdd {
286                        table_logical_name: (*table_logical_name).to_owned(),
287                        table_name: table.name.clone(),
288                        field_logical_name: logical_name.clone(),
289                        column_name: field.name.clone(),
290                    });
291                    plan.statements.push(MigrationStatement {
292                        kind: MigrationStatementKind::AddColumn,
293                        sql: dialect.add_column_statement(&table.name, logical_name, field)?,
294                    });
295                }
296            }
297        } else {
298            plan.to_be_created.push(TableToCreate {
299                logical_name: (*table_logical_name).to_owned(),
300                table_name: table.name.clone(),
301            });
302            plan.statements.push(MigrationStatement {
303                kind: MigrationStatementKind::CreateTable,
304                sql: dialect.create_table_statement(table)?,
305            });
306        }
307    }
308
309    for (table_logical_name, table) in tables {
310        let table_exists = snapshot.table_exists(&table.name);
311        for (logical_name, field) in &table.fields {
312            if field.index || field.unique {
313                if field.unique
314                    && (!table_exists || snapshot.unique_column_exists(&table.name, &field.name))
315                {
316                    continue;
317                }
318                let prefix = if field.unique { "uidx" } else { "idx" };
319                let index_name = dialect
320                    .sanitize_identifier(&format!("{prefix}_{}_{}", table.name, logical_name))?;
321                if !snapshot.index_exists(&table.name, &index_name) {
322                    plan.indexes_to_be_created.push(IndexToCreate {
323                        table_logical_name: table_logical_name.to_owned(),
324                        table_name: table.name.clone(),
325                        field_logical_name: logical_name.clone(),
326                        column_name: field.name.clone(),
327                        index_name: index_name.clone(),
328                        unique: field.unique,
329                    });
330                    plan.statements.push(MigrationStatement {
331                        kind: MigrationStatementKind::CreateIndex,
332                        sql: dialect.create_index_statement(
333                            &table.name,
334                            &field.name,
335                            &index_name,
336                            field.unique,
337                        )?,
338                    });
339                }
340            }
341        }
342    }
343
344    Ok(plan)
345}
346
347fn push_constraint_warnings(
348    plan: &mut SchemaMigrationPlan,
349    table: &DbTable,
350    logical_name: &str,
351    field: &DbField,
352    column: &SqlColumnSnapshot,
353) {
354    if logical_name == "id" || field.name == "id" {
355        if column.primary_key == Some(false) {
356            plan.warnings
357                .push(SchemaMigrationWarning::PrimaryKeyMismatch {
358                    table_name: table.name.clone(),
359                    column_name: field.name.clone(),
360                });
361        }
362    } else if let Some(actual_nullable) = column.nullable {
363        let expected_nullable = !field.required;
364        if expected_nullable != actual_nullable {
365            plan.warnings
366                .push(SchemaMigrationWarning::ColumnNullabilityMismatch {
367                    table_name: table.name.clone(),
368                    column_name: field.name.clone(),
369                    expected_nullable,
370                    actual_nullable,
371                });
372        }
373    }
374
375    if logical_name == "id" || field.name == "id" {
376        if let Some(expected) = field.generated_id {
377            if column.generated_id != Some(expected) {
378                plan.warnings
379                    .push(SchemaMigrationWarning::GeneratedIdMismatch {
380                        table_name: table.name.clone(),
381                        column_name: field.name.clone(),
382                        expected,
383                        actual: column.generated_id,
384                    });
385            }
386        }
387    }
388
389    if let Some(expected) = &field.foreign_key {
390        if column.foreign_key.as_ref() != Some(expected) {
391            plan.warnings
392                .push(SchemaMigrationWarning::ForeignKeyMismatch {
393                    table_name: table.name.clone(),
394                    column_name: field.name.clone(),
395                    expected: expected.clone(),
396                    actual: column.foreign_key.clone(),
397                });
398        }
399    }
400}