Skip to main content

yauth_migration/
diff.rs

//! Schema diff engine.
//!
//! Compares two `YAuthSchema` snapshots (the previously applied plugin set versus
//! the current one) and produces the incremental SQL operations between them.
5
6use crate::collector::YAuthSchema;
7use crate::mysql::{mysql_default, mysql_type};
8use crate::postgres::pg_type;
9use crate::sqlite::{sqlite_default, sqlite_type};
10use crate::types::TableDef;
11
12/// A single schema change operation.
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum SchemaChange {
15    /// A new table needs to be created (includes all columns).
16    CreateTable(TableDef),
17    /// An existing table needs to be dropped.
18    DropTable(TableDef),
19    /// A new column needs to be added to an existing table.
20    AddColumn {
21        table_name: String,
22        column: crate::types::ColumnDef,
23    },
24    /// A column needs to be removed from an existing table.
25    DropColumn {
26        table_name: String,
27        column_name: String,
28    },
29}
30
31/// Compute the diff between two schemas.
32///
33/// `from` is the previous schema state, `to` is the desired state.
34/// Returns a list of changes needed to go from `from` to `to`.
35pub fn schema_diff(from: &YAuthSchema, to: &YAuthSchema) -> Vec<SchemaChange> {
36    let mut changes = Vec::new();
37
38    let from_tables: std::collections::HashMap<&str, &TableDef> =
39        from.tables.iter().map(|t| (t.name.as_str(), t)).collect();
40    let to_tables: std::collections::HashMap<&str, &TableDef> =
41        to.tables.iter().map(|t| (t.name.as_str(), t)).collect();
42
43    // Tables to create (in `to` but not in `from`) -- preserve topological order
44    for table in &to.tables {
45        if !from_tables.contains_key(table.name.as_str()) {
46            changes.push(SchemaChange::CreateTable(table.clone()));
47        }
48    }
49
50    // Tables to drop (in `from` but not in `to`) -- reverse topological order
51    for table in from.tables.iter().rev() {
52        if !to_tables.contains_key(table.name.as_str()) {
53            changes.push(SchemaChange::DropTable(table.clone()));
54        }
55    }
56
57    // Column-level changes for tables that exist in both
58    for table in &to.tables {
59        if let Some(from_table) = from_tables.get(table.name.as_str()) {
60            let from_cols: std::collections::HashSet<&str> =
61                from_table.columns.iter().map(|c| c.name.as_str()).collect();
62            let to_cols: std::collections::HashSet<&str> =
63                table.columns.iter().map(|c| c.name.as_str()).collect();
64
65            // New columns
66            for col in &table.columns {
67                if !from_cols.contains(col.name.as_str()) {
68                    changes.push(SchemaChange::AddColumn {
69                        table_name: table.name.clone(),
70                        column: col.clone(),
71                    });
72                }
73            }
74
75            // Dropped columns
76            for col in &from_table.columns {
77                if !to_cols.contains(col.name.as_str()) {
78                    changes.push(SchemaChange::DropColumn {
79                        table_name: table.name.clone(),
80                        column_name: col.name.clone(),
81                    });
82                }
83            }
84        }
85    }
86
87    changes
88}
89
90/// Render schema changes as SQL for the given dialect.
91pub fn render_changes_sql(changes: &[SchemaChange], dialect: crate::Dialect) -> (String, String) {
92    let mut up_sql = String::new();
93    let mut down_sql = String::new();
94
95    for change in changes {
96        match change {
97            SchemaChange::CreateTable(table) => {
98                let schema = YAuthSchema {
99                    tables: vec![table.clone()],
100                };
101                let create = match dialect {
102                    crate::Dialect::Postgres => crate::generate_postgres_ddl(&schema),
103                    crate::Dialect::Sqlite => {
104                        // Don't include PRAGMA for individual table creates
105                        generate_single_table_sqlite(table)
106                    }
107                    crate::Dialect::Mysql => crate::generate_mysql_ddl(&schema),
108                };
109                up_sql.push_str(&create);
110                up_sql.push('\n');
111
112                // Down: drop
113                let drop = match dialect {
114                    crate::Dialect::Postgres => crate::generate_postgres_drop(table),
115                    crate::Dialect::Sqlite => crate::generate_sqlite_drop(table),
116                    crate::Dialect::Mysql => crate::generate_mysql_drop(table),
117                };
118                down_sql.push_str(&drop);
119                down_sql.push('\n');
120            }
121            SchemaChange::DropTable(table) => {
122                let drop = match dialect {
123                    crate::Dialect::Postgres => crate::generate_postgres_drop(table),
124                    crate::Dialect::Sqlite => crate::generate_sqlite_drop(table),
125                    crate::Dialect::Mysql => crate::generate_mysql_drop(table),
126                };
127                up_sql.push_str(&drop);
128                up_sql.push('\n');
129
130                // Down: recreate
131                let schema = YAuthSchema {
132                    tables: vec![table.clone()],
133                };
134                let create = match dialect {
135                    crate::Dialect::Postgres => crate::generate_postgres_ddl(&schema),
136                    crate::Dialect::Sqlite => generate_single_table_sqlite(table),
137                    crate::Dialect::Mysql => crate::generate_mysql_ddl(&schema),
138                };
139                down_sql.push_str(&create);
140                down_sql.push('\n');
141            }
142            SchemaChange::AddColumn { table_name, column } => {
143                let stmt = render_add_column(table_name, column, dialect);
144                up_sql.push_str(&stmt);
145                up_sql.push('\n');
146
147                let drop_stmt = render_drop_column(table_name, &column.name, dialect);
148                down_sql.push_str(&drop_stmt);
149                down_sql.push('\n');
150            }
151            SchemaChange::DropColumn {
152                table_name,
153                column_name,
154            } => {
155                let stmt = render_drop_column(table_name, column_name, dialect);
156                up_sql.push_str(&stmt);
157                up_sql.push('\n');
158                // Down for drop column is hard without the original column def,
159                // so we add a comment.
160                down_sql.push_str(&format!(
161                    "-- TODO: Re-add column {column_name} to {table_name}\n\n"
162                ));
163            }
164        }
165    }
166
167    (up_sql, down_sql)
168}
169
170fn render_add_column(
171    table_name: &str,
172    column: &crate::types::ColumnDef,
173    dialect: crate::Dialect,
174) -> String {
175    match dialect {
176        crate::Dialect::Postgres => {
177            let col_type = pg_type(&column.col_type);
178            let mut stmt = format!(
179                "ALTER TABLE {} ADD COLUMN {} {}",
180                table_name, column.name, col_type
181            );
182            if !column.nullable && column.default.is_none() {
183                // Can't add NOT NULL without a default to a table with existing rows
184                stmt.push_str(" NULL");
185            } else {
186                if !column.nullable {
187                    stmt.push_str(" NOT NULL");
188                }
189                if let Some(ref default) = column.default {
190                    stmt.push_str(&format!(" DEFAULT {}", default));
191                }
192            }
193            stmt.push_str(";\n");
194            stmt
195        }
196        crate::Dialect::Sqlite => {
197            let col_type = sqlite_type(&column.col_type);
198            let mut stmt = format!(
199                "ALTER TABLE {} ADD COLUMN {} {}",
200                table_name, column.name, col_type
201            );
202            if !column.nullable && column.default.is_none() {
203                stmt.push_str(" NULL");
204            } else {
205                if !column.nullable {
206                    stmt.push_str(" NOT NULL");
207                }
208                if let Some(ref default) = column.default
209                    && let Some(d) = sqlite_default(default)
210                {
211                    stmt.push_str(&format!(" DEFAULT {}", d));
212                }
213            }
214            stmt.push_str(";\n");
215            stmt
216        }
217        crate::Dialect::Mysql => {
218            let col_type = mysql_type(&column.col_type);
219            let mut stmt = format!(
220                "ALTER TABLE `{}` ADD COLUMN `{}` {}",
221                table_name, column.name, col_type
222            );
223            if !column.nullable && column.default.is_none() {
224                stmt.push_str(" NULL");
225            } else {
226                if !column.nullable {
227                    stmt.push_str(" NOT NULL");
228                }
229                if let Some(ref default) = column.default
230                    && let Some(d) = mysql_default(default)
231                {
232                    stmt.push_str(&format!(" DEFAULT {}", d));
233                }
234            }
235            stmt.push_str(";\n");
236            stmt
237        }
238    }
239}
240
241fn render_drop_column(table_name: &str, column_name: &str, dialect: crate::Dialect) -> String {
242    match dialect {
243        crate::Dialect::Postgres => {
244            format!(
245                "ALTER TABLE {} DROP COLUMN IF EXISTS {};\n",
246                table_name, column_name
247            )
248        }
249        crate::Dialect::Sqlite => {
250            format!("ALTER TABLE {} DROP COLUMN {};\n", table_name, column_name)
251        }
252        crate::Dialect::Mysql => {
253            format!(
254                "ALTER TABLE `{}` DROP COLUMN `{}`;\n",
255                table_name, column_name
256            )
257        }
258    }
259}
260
261fn generate_single_table_sqlite(table: &TableDef) -> String {
262    // Reuse the full generator but strip the PRAGMA
263    let schema = YAuthSchema {
264        tables: vec![table.clone()],
265    };
266    let full = crate::generate_sqlite_ddl(&schema);
267    // Strip PRAGMA line
268    full.lines()
269        .filter(|l| !l.starts_with("PRAGMA"))
270        .collect::<Vec<_>>()
271        .join("\n")
272        .trim_start_matches('\n')
273        .to_string()
274        + "\n"
275}
276
277/// Format a text diff of two SQL strings for display.
278pub fn format_sql_diff(old: &str, new: &str) -> String {
279    use similar::{ChangeTag, TextDiff};
280
281    let diff = TextDiff::from_lines(old, new);
282    let mut output = String::new();
283
284    for change in diff.iter_all_changes() {
285        let sign = match change.tag() {
286            ChangeTag::Delete => "-",
287            ChangeTag::Insert => "+",
288            ChangeTag::Equal => " ",
289        };
290        output.push_str(sign);
291        output.push_str(change.as_str().unwrap_or(""));
292        if !change.as_str().unwrap_or("").ends_with('\n') {
293            output.push('\n');
294        }
295    }
296
297    output
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{collect_schema, core_schema, plugin_schemas};

    /// True when `change` creates the table called `name`.
    fn creates(change: &SchemaChange, name: &str) -> bool {
        matches!(change, SchemaChange::CreateTable(t) if t.name == name)
    }

    /// True when `change` drops the table called `name`.
    fn drops(change: &SchemaChange, name: &str) -> bool {
        matches!(change, SchemaChange::DropTable(t) if t.name == name)
    }

    /// Diffing from nothing to the core schema creates exactly the core tables.
    #[test]
    fn diff_empty_to_core_creates_tables() {
        let empty = YAuthSchema { tables: vec![] };
        let core = collect_schema(vec![core_schema()]).unwrap();
        let changes = schema_diff(&empty, &core);

        assert_eq!(changes.len(), 6);
        let created: Vec<&str> = changes
            .iter()
            .filter_map(|change| match change {
                SchemaChange::CreateTable(table) => Some(table.name.as_str()),
                _ => None,
            })
            .collect();
        let expected_tables = [
            "yauth_users",
            "yauth_sessions",
            "yauth_audit_log",
            "yauth_challenges",
            "yauth_rate_limits",
            "yauth_revocations",
        ];
        for expected in &expected_tables {
            assert!(created.contains(expected), "Missing table: {expected}");
        }
    }

    /// Enabling MFA creates its two tables, in schema order.
    #[test]
    fn diff_add_plugin_creates_plugin_tables() {
        let before = collect_schema(vec![core_schema()]).unwrap();
        let after = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&before, &after);
        assert_eq!(changes.len(), 2);
        assert!(creates(&changes[0], "yauth_totp_secrets"));
        assert!(creates(&changes[1], "yauth_backup_codes"));
    }

    /// Disabling passkeys drops the passkey credential table only.
    #[test]
    fn diff_remove_plugin_drops_plugin_tables() {
        let before =
            collect_schema(vec![core_schema(), plugin_schemas::passkey_schema()]).unwrap();
        let after = collect_schema(vec![core_schema()]).unwrap();

        let changes = schema_diff(&before, &after);
        assert_eq!(changes.len(), 1);
        assert!(drops(&changes[0], "yauth_webauthn_credentials"));
    }

    /// A schema diffed against itself yields no changes.
    #[test]
    fn diff_no_changes() {
        let schema = collect_schema(vec![core_schema()]).unwrap();
        assert!(schema_diff(&schema, &schema).is_empty());
    }

    /// Postgres up/down SQL for adding MFA creates and drops both tables.
    #[test]
    fn diff_add_mfa_produces_valid_postgres_sql() {
        let before = collect_schema(vec![core_schema()]).unwrap();
        let after = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&before, &after);
        let (up, down) = render_changes_sql(&changes, crate::Dialect::Postgres);

        for needle in [
            "CREATE TABLE IF NOT EXISTS yauth_totp_secrets",
            "CREATE TABLE IF NOT EXISTS yauth_backup_codes",
        ] {
            assert!(up.contains(needle));
        }
        for needle in [
            "DROP TABLE IF EXISTS yauth_totp_secrets CASCADE",
            "DROP TABLE IF EXISTS yauth_backup_codes CASCADE",
        ] {
            assert!(down.contains(needle));
        }
    }

    /// SQLite per-table creates must not carry the schema-level PRAGMA.
    #[test]
    fn diff_add_mfa_produces_valid_sqlite_sql() {
        let before = collect_schema(vec![core_schema()]).unwrap();
        let after = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&before, &after);
        let (up, _down) = render_changes_sql(&changes, crate::Dialect::Sqlite);

        assert!(up.contains("CREATE TABLE IF NOT EXISTS yauth_totp_secrets"));
        assert!(!up.contains("PRAGMA"));
    }

    /// MySQL creates use backtick-quoted names and the InnoDB engine.
    #[test]
    fn diff_add_mfa_produces_valid_mysql_sql() {
        let before = collect_schema(vec![core_schema()]).unwrap();
        let after = collect_schema(vec![core_schema(), plugin_schemas::mfa_schema()]).unwrap();

        let changes = schema_diff(&before, &after);
        let (up, _down) = render_changes_sql(&changes, crate::Dialect::Mysql);

        assert!(up.contains("CREATE TABLE IF NOT EXISTS `yauth_totp_secrets`"));
        assert!(up.contains("ENGINE=InnoDB"));
    }

    /// Swapping passkey for MFA creates the MFA tables and drops the passkey table.
    #[test]
    fn diff_complex_add_and_remove() {
        let before = collect_schema(vec![
            core_schema(),
            plugin_schemas::email_password_schema(),
            plugin_schemas::passkey_schema(),
        ])
        .unwrap();
        let after = collect_schema(vec![
            core_schema(),
            plugin_schemas::email_password_schema(),
            plugin_schemas::mfa_schema(),
        ])
        .unwrap();

        let changes = schema_diff(&before, &after);

        let create_count = changes
            .iter()
            .filter(|c| matches!(c, SchemaChange::CreateTable(_)))
            .count();
        let drop_count = changes
            .iter()
            .filter(|c| matches!(c, SchemaChange::DropTable(_)))
            .count();

        assert_eq!(create_count, 2); // totp_secrets + backup_codes
        assert_eq!(drop_count, 1); // webauthn_credentials
    }

    /// An appended line shows up with a `+` marker.
    #[test]
    fn format_diff_shows_additions() {
        let diff = format_sql_diff("line1\nline2\n", "line1\nline2\nline3\n");
        assert!(diff.contains("+line3"));
    }
}