Skip to main content

forge_runtime/migrations/
diff.rs

1use forge_core::schema::{FieldDef, TableDef};
2
/// Wrap `name` in double quotes so it can be spliced into SQL as an identifier.
///
/// Any double-quote character inside `name` is emitted twice, which is how the
/// SQL standard escapes a quote within a quoted identifier. This keeps a
/// malformed name from terminating the identifier early.
fn quote_ident(name: &str) -> String {
    // Two extra bytes for the surrounding quotes; embedded quotes may grow it further.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // Escape by doubling.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
8
9/// Represents the difference between two schemas.
10///
11/// The diff algorithm compares the Rust schema (source of truth) against the
12/// database schema (current state) to produce a set of changes. The comparison
13/// is done at two levels:
14///
15/// 1. **Table level**: Find tables that exist in Rust but not DB (CREATE),
16///    or exist in DB but not Rust (DROP, except forge_ internal tables).
17///
18/// 2. **Column level**: For tables in both, compare fields:
19///    - Field in Rust but not DB → ADD COLUMN
20///    - Field in DB but not Rust → DROP COLUMN
21///    - Field in both but different type → ALTER COLUMN TYPE
22///
23/// The algorithm is intentionally simple and doesn't handle:
24/// - Column renames (seen as DROP + ADD)
25/// - Index changes (handled separately)
26/// - Complex type migrations (require manual migration)
27///
28/// This is by design: automatic migrations are for development convenience.
29/// Production deployments should use explicit migration files.
30#[derive(Debug, Clone)]
31pub struct SchemaDiff {
32    /// Changes to be applied.
33    pub entries: Vec<DiffEntry>,
34}
35
36impl SchemaDiff {
37    /// Create an empty diff.
38    pub fn new() -> Self {
39        Self {
40            entries: Vec::new(),
41        }
42    }
43
44    /// Compare a Rust schema to a database schema.
45    pub fn from_comparison(rust_tables: &[TableDef], db_tables: &[DatabaseTable]) -> Self {
46        let mut entries = Vec::new();
47
48        // Find tables to add
49        for rust_table in rust_tables {
50            let db_table = db_tables.iter().find(|t| t.name == rust_table.name);
51
52            match db_table {
53                None => {
54                    // Table doesn't exist, create it
55                    // Note: Actual SQL should come from migrations, this is just for diff detection
56                    entries.push(DiffEntry {
57                        action: DiffAction::CreateTable,
58                        table_name: rust_table.name.clone(),
59                        details: format!("Create table {}", rust_table.name),
60                        sql: format!("-- Create table {} (see migrations)", rust_table.name),
61                    });
62                }
63                Some(db) => {
64                    // Compare columns
65                    for rust_field in &rust_table.fields {
66                        let db_column =
67                            db.columns.iter().find(|c| c.name == rust_field.column_name);
68
69                        match db_column {
70                            None => {
71                                // Column doesn't exist, add it
72                                entries.push(DiffEntry {
73                                    action: DiffAction::AddColumn,
74                                    table_name: rust_table.name.clone(),
75                                    details: format!("Add column {}", rust_field.column_name),
76                                    sql: Self::add_column_sql(&rust_table.name, rust_field),
77                                });
78                            }
79                            Some(db_col) => {
80                                // Check if column type changed
81                                let rust_type = rust_field.sql_type.to_sql();
82                                if db_col.data_type != rust_type {
83                                    entries.push(DiffEntry {
84                                        action: DiffAction::AlterColumn,
85                                        table_name: rust_table.name.clone(),
86                                        details: format!(
87                                            "Change column {} type from {} to {}",
88                                            rust_field.column_name, db_col.data_type, rust_type
89                                        ),
90                                        sql: format!(
91                                            "ALTER TABLE {} ALTER COLUMN {} TYPE {};",
92                                            quote_ident(&rust_table.name),
93                                            quote_ident(&rust_field.column_name),
94                                            rust_type
95                                        ),
96                                    });
97                                }
98                            }
99                        }
100                    }
101
102                    // Find columns to drop (exist in DB but not in Rust)
103                    for db_col in &db.columns {
104                        let exists_in_rust = rust_table
105                            .fields
106                            .iter()
107                            .any(|f| f.column_name == db_col.name);
108
109                        if !exists_in_rust {
110                            entries.push(DiffEntry {
111                                action: DiffAction::DropColumn,
112                                table_name: rust_table.name.clone(),
113                                details: format!("Drop column {}", db_col.name),
114                                sql: format!(
115                                    "ALTER TABLE {} DROP COLUMN {};",
116                                    quote_ident(&rust_table.name),
117                                    quote_ident(&db_col.name)
118                                ),
119                            });
120                        }
121                    }
122                }
123            }
124        }
125
126        // Find tables to drop (exist in DB but not in Rust)
127        for db_table in db_tables {
128            let exists_in_rust = rust_tables.iter().any(|t| t.name == db_table.name);
129
130            if !exists_in_rust && !db_table.name.starts_with("forge_") {
131                entries.push(DiffEntry {
132                    action: DiffAction::DropTable,
133                    table_name: db_table.name.clone(),
134                    details: format!("Drop table {}", db_table.name),
135                    sql: format!("DROP TABLE {};", quote_ident(&db_table.name)),
136                });
137            }
138        }
139
140        Self { entries }
141    }
142
143    fn add_column_sql(table_name: &str, field: &FieldDef) -> String {
144        let mut sql = format!(
145            "ALTER TABLE {} ADD COLUMN {} {}",
146            quote_ident(table_name),
147            quote_ident(&field.column_name),
148            field.sql_type.to_sql()
149        );
150
151        if !field.nullable {
152            // For non-nullable columns, provide a sensible default
153            let default_val = match field.sql_type {
154                forge_core::schema::SqlType::Varchar(_) | forge_core::schema::SqlType::Text => "''",
155                forge_core::schema::SqlType::Integer | forge_core::schema::SqlType::BigInt => "0",
156                forge_core::schema::SqlType::Boolean => "false",
157                forge_core::schema::SqlType::Timestamptz => "NOW()",
158                _ => "NULL",
159            };
160            sql.push_str(&format!(" NOT NULL DEFAULT {}", default_val));
161        }
162
163        sql.push(';');
164        sql
165    }
166
167    /// Check if there are any changes.
168    pub fn is_empty(&self) -> bool {
169        self.entries.is_empty()
170    }
171
172    /// Get all SQL statements.
173    pub fn to_sql(&self) -> Vec<String> {
174        self.entries.iter().map(|e| e.sql.clone()).collect()
175    }
176}
177
178impl Default for SchemaDiff {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184/// A single diff entry.
185#[derive(Debug, Clone)]
186pub struct DiffEntry {
187    /// Type of action.
188    pub action: DiffAction,
189    /// Affected table name.
190    pub table_name: String,
191    /// Human-readable description.
192    pub details: String,
193    /// SQL to apply.
194    pub sql: String,
195}
196
/// The kind of schema change a [`DiffEntry`] describes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DiffAction {
    /// A table needs to be created.
    CreateTable,
    /// A table needs to be dropped.
    DropTable,
    /// A column needs to be added to an existing table.
    AddColumn,
    /// A column needs to be dropped from an existing table.
    DropColumn,
    /// A column's type needs to change.
    AlterColumn,
    /// An index needs to be added.
    AddIndex,
    /// An index needs to be dropped.
    DropIndex,
    /// An enum type needs to be created.
    CreateEnum,
    /// An enum type needs to be altered.
    AlterEnum,
}
210
211/// Representation of a database table (from introspection).
212#[derive(Debug, Clone)]
213pub struct DatabaseTable {
214    pub name: String,
215    pub columns: Vec<DatabaseColumn>,
216}
217
/// A column as it currently exists in the database (obtained via introspection).
#[derive(Clone, Debug)]
pub struct DatabaseColumn {
    /// Column name; matched against the Rust field's column name by exact equality.
    pub name: String,
    /// SQL type name; compared verbatim with the Rust field's SQL type string.
    pub data_type: String,
    /// Whether the column accepts NULL.
    pub nullable: bool,
    /// Default expression of the column, if it has one.
    pub default: Option<String>,
}
226
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::schema::{FieldDef, RustType, TableDef};

    /// A freshly constructed diff reports no changes.
    #[test]
    fn test_empty_diff() {
        assert!(SchemaDiff::new().is_empty());
    }

    /// A table known only to the Rust schema yields a single `CreateTable` entry.
    #[test]
    fn test_create_table_diff() {
        let mut users = TableDef::new("users", "User");
        users.fields.push(FieldDef::new("id", RustType::Uuid));

        let diff = SchemaDiff::from_comparison(&[users], &[]);

        assert_eq!(diff.entries.len(), 1);
        let entry = &diff.entries[0];
        assert_eq!(entry.action, DiffAction::CreateTable);
    }

    /// A field missing from the database table yields a single `AddColumn` entry.
    #[test]
    fn test_add_column_diff() {
        // Rust side: `id` and `email`.
        let mut users = TableDef::new("users", "User");
        users.fields.push(FieldDef::new("id", RustType::Uuid));
        users.fields.push(FieldDef::new("email", RustType::String));

        // Database side: only `id`.
        let id_column = DatabaseColumn {
            name: "id".to_string(),
            data_type: "UUID".to_string(),
            nullable: false,
            default: None,
        };
        let existing = DatabaseTable {
            name: "users".to_string(),
            columns: vec![id_column],
        };

        let diff = SchemaDiff::from_comparison(&[users], &[existing]);

        assert_eq!(diff.entries.len(), 1);
        let entry = &diff.entries[0];
        assert_eq!(entry.action, DiffAction::AddColumn);
        assert!(entry.details.contains("email"));
    }
}
275}