prax_migrate/
sql.rs

1//! SQL generation for migrations.
2
3use crate::diff::{
4    EnumAlterDiff, EnumDiff, FieldAlterDiff, FieldDiff, IndexDiff, ModelAlterDiff, ModelDiff,
5    SchemaDiff, ViewDiff,
6};
7
/// SQL generator for PostgreSQL.
///
/// The generator has no fields and therefore carries no configuration; construct it
/// with `PostgresSqlGenerator` or `PostgresSqlGenerator::default()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PostgresSqlGenerator;
10
11impl PostgresSqlGenerator {
12    /// Generate SQL for a schema diff.
13    pub fn generate(&self, diff: &SchemaDiff) -> MigrationSql {
14        let mut up = Vec::new();
15        let mut down = Vec::new();
16
17        // Create enums first (they might be used in tables)
18        for enum_diff in &diff.create_enums {
19            up.push(self.create_enum(enum_diff));
20            down.push(self.drop_enum(&enum_diff.name));
21        }
22
23        // Drop enums (in reverse order)
24        for name in &diff.drop_enums {
25            up.push(self.drop_enum(name));
26            // Can't easily recreate dropped enums without knowing values
27        }
28
29        // Alter enums
30        for alter in &diff.alter_enums {
31            up.extend(self.alter_enum(alter));
32            // Reversing enum alterations is complex
33        }
34
35        // Create models
36        for model in &diff.create_models {
37            up.push(self.create_table(model));
38            down.push(self.drop_table(&model.table_name));
39        }
40
41        // Drop models
42        for name in &diff.drop_models {
43            up.push(self.drop_table(name));
44            // Can't easily recreate dropped tables
45        }
46
47        // Alter models
48        for alter in &diff.alter_models {
49            up.extend(self.alter_table(alter));
50            // Reverse alterations could be generated but complex
51        }
52
53        // Create indexes
54        for index in &diff.create_indexes {
55            up.push(self.create_index(index));
56            down.push(self.drop_index(&index.name, &index.table_name));
57        }
58
59        // Drop indexes
60        for index in &diff.drop_indexes {
61            up.push(self.drop_index(&index.name, &index.table_name));
62        }
63
64        // Create views (after tables they depend on)
65        for view in &diff.create_views {
66            up.push(self.create_view(view));
67            down.push(self.drop_view(&view.view_name, view.is_materialized));
68        }
69
70        // Drop views
71        for name in &diff.drop_views {
72            // We don't know if it was materialized, so try both
73            up.push(self.drop_view(name, false));
74        }
75
76        // Alter views (drop and recreate)
77        for view in &diff.alter_views {
78            // Drop the old view first
79            up.push(self.drop_view(&view.view_name, view.is_materialized));
80            // Then create the new one
81            up.push(self.create_view(view));
82        }
83
84        MigrationSql {
85            up: up.join("\n\n"),
86            down: down.join("\n\n"),
87        }
88    }
89
90    /// Generate CREATE TYPE for enum.
91    fn create_enum(&self, enum_diff: &EnumDiff) -> String {
92        let values: Vec<String> = enum_diff
93            .values
94            .iter()
95            .map(|v| format!("'{}'", v))
96            .collect();
97        format!(
98            "CREATE TYPE \"{}\" AS ENUM ({});",
99            enum_diff.name,
100            values.join(", ")
101        )
102    }
103
104    /// Generate DROP TYPE.
105    fn drop_enum(&self, name: &str) -> String {
106        format!("DROP TYPE IF EXISTS \"{}\";", name)
107    }
108
109    /// Generate ALTER TYPE statements.
110    fn alter_enum(&self, alter: &EnumAlterDiff) -> Vec<String> {
111        let mut stmts = Vec::new();
112
113        for value in &alter.add_values {
114            stmts.push(format!(
115                "ALTER TYPE \"{}\" ADD VALUE IF NOT EXISTS '{}';",
116                alter.name, value
117            ));
118        }
119
120        // Note: PostgreSQL doesn't support removing enum values directly
121        // This would require recreating the type
122
123        stmts
124    }
125
126    /// Generate CREATE TABLE statement.
127    fn create_table(&self, model: &ModelDiff) -> String {
128        let mut columns = Vec::new();
129
130        for field in &model.fields {
131            columns.push(self.column_definition(field));
132        }
133
134        // Add primary key constraint
135        if !model.primary_key.is_empty() {
136            let pk_cols: Vec<String> = model
137                .primary_key
138                .iter()
139                .map(|c| format!("\"{}\"", c))
140                .collect();
141            columns.push(format!("PRIMARY KEY ({})", pk_cols.join(", ")));
142        }
143
144        // Add unique constraints
145        for uc in &model.unique_constraints {
146            let cols: Vec<String> = uc.columns.iter().map(|c| format!("\"{}\"", c)).collect();
147            let constraint = if let Some(name) = &uc.name {
148                format!("CONSTRAINT \"{}\" UNIQUE ({})", name, cols.join(", "))
149            } else {
150                format!("UNIQUE ({})", cols.join(", "))
151            };
152            columns.push(constraint);
153        }
154
155        format!(
156            "CREATE TABLE \"{}\" (\n    {}\n);",
157            model.table_name,
158            columns.join(",\n    ")
159        )
160    }
161
162    /// Generate column definition.
163    fn column_definition(&self, field: &FieldDiff) -> String {
164        let mut parts = vec![format!("\"{}\"", field.column_name), field.sql_type.clone()];
165
166        if field.is_auto_increment {
167            // Replace type with SERIAL variants
168            if field.sql_type == "INTEGER" {
169                parts[1] = "SERIAL".to_string();
170            } else if field.sql_type == "BIGINT" {
171                parts[1] = "BIGSERIAL".to_string();
172            }
173        }
174
175        if !field.nullable && !field.is_primary_key {
176            parts.push("NOT NULL".to_string());
177        }
178
179        if field.is_unique && !field.is_primary_key {
180            parts.push("UNIQUE".to_string());
181        }
182
183        if let Some(default) = &field.default {
184            parts.push(format!("DEFAULT {}", default));
185        }
186
187        parts.join(" ")
188    }
189
190    /// Generate DROP TABLE statement.
191    fn drop_table(&self, name: &str) -> String {
192        format!("DROP TABLE IF EXISTS \"{}\" CASCADE;", name)
193    }
194
195    /// Generate ALTER TABLE statements.
196    fn alter_table(&self, alter: &ModelAlterDiff) -> Vec<String> {
197        let mut stmts = Vec::new();
198
199        // Add columns
200        for field in &alter.add_fields {
201            stmts.push(format!(
202                "ALTER TABLE \"{}\" ADD COLUMN {};",
203                alter.table_name,
204                self.column_definition(field)
205            ));
206        }
207
208        // Drop columns
209        for name in &alter.drop_fields {
210            stmts.push(format!(
211                "ALTER TABLE \"{}\" DROP COLUMN IF EXISTS \"{}\";",
212                alter.table_name, name
213            ));
214        }
215
216        // Alter columns
217        for field in &alter.alter_fields {
218            stmts.extend(self.alter_column(&alter.table_name, field));
219        }
220
221        // Add indexes
222        for index in &alter.add_indexes {
223            stmts.push(self.create_index(index));
224        }
225
226        // Drop indexes
227        for name in &alter.drop_indexes {
228            stmts.push(format!("DROP INDEX IF EXISTS \"{}\";", name));
229        }
230
231        stmts
232    }
233
234    /// Generate ALTER COLUMN statements.
235    fn alter_column(&self, table: &str, field: &FieldAlterDiff) -> Vec<String> {
236        let mut stmts = Vec::new();
237
238        if let Some(new_type) = &field.new_type {
239            stmts.push(format!(
240                "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" TYPE {} USING \"{}\"::{};",
241                table, field.column_name, new_type, field.column_name, new_type
242            ));
243        }
244
245        if let Some(new_nullable) = field.new_nullable {
246            if new_nullable {
247                stmts.push(format!(
248                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL;",
249                    table, field.column_name
250                ));
251            } else {
252                stmts.push(format!(
253                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL;",
254                    table, field.column_name
255                ));
256            }
257        }
258
259        if let Some(new_default) = &field.new_default {
260            stmts.push(format!(
261                "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {};",
262                table, field.column_name, new_default
263            ));
264        }
265
266        stmts
267    }
268
269    /// Generate CREATE INDEX statement.
270    fn create_index(&self, index: &IndexDiff) -> String {
271        let unique = if index.unique { "UNIQUE " } else { "" };
272        let cols: Vec<String> = index.columns.iter().map(|c| format!("\"{}\"", c)).collect();
273        format!(
274            "CREATE {}INDEX \"{}\" ON \"{}\" ({});",
275            unique,
276            index.name,
277            index.table_name,
278            cols.join(", ")
279        )
280    }
281
282    /// Generate DROP INDEX statement.
283    fn drop_index(&self, name: &str, _table: &str) -> String {
284        format!("DROP INDEX IF EXISTS \"{}\";", name)
285    }
286
287    /// Generate CREATE VIEW statement.
288    fn create_view(&self, view: &ViewDiff) -> String {
289        if view.is_materialized {
290            format!(
291                "CREATE MATERIALIZED VIEW \"{}\" AS\n{};",
292                view.view_name, view.sql_query
293            )
294        } else {
295            format!(
296                "CREATE OR REPLACE VIEW \"{}\" AS\n{};",
297                view.view_name, view.sql_query
298            )
299        }
300    }
301
302    /// Generate DROP VIEW statement.
303    fn drop_view(&self, name: &str, is_materialized: bool) -> String {
304        if is_materialized {
305            format!("DROP MATERIALIZED VIEW IF EXISTS \"{}\" CASCADE;", name)
306        } else {
307            format!("DROP VIEW IF EXISTS \"{}\" CASCADE;", name)
308        }
309    }
310
311    /// Generate REFRESH MATERIALIZED VIEW statement.
312    #[allow(dead_code)]
313    fn refresh_materialized_view(&self, name: &str, concurrently: bool) -> String {
314        if concurrently {
315            format!("REFRESH MATERIALIZED VIEW CONCURRENTLY \"{}\";", name)
316        } else {
317            format!("REFRESH MATERIALIZED VIEW \"{}\";", name)
318        }
319    }
320}
321
/// Generated SQL for a migration, as produced by [`PostgresSqlGenerator::generate`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MigrationSql {
    /// SQL to apply the migration.
    pub up: String,
    /// SQL to roll back the migration (may be empty when nothing can be undone).
    pub down: String,
}
330
331impl MigrationSql {
332    /// Check if the migration is empty.
333    pub fn is_empty(&self) -> bool {
334        self.up.trim().is_empty()
335    }
336}
337
#[cfg(test)]
mod tests {
    // `super::*` also brings in the `crate::diff` types imported by the parent module.
    use super::*;

    #[test]
    fn test_create_enum() {
        let generator = PostgresSqlGenerator;
        let enum_diff = EnumDiff {
            name: "Status".to_string(),
            values: vec!["PENDING".to_string(), "ACTIVE".to_string()],
        };

        let sql = generator.create_enum(&enum_diff);
        assert_eq!(sql, "CREATE TYPE \"Status\" AS ENUM ('PENDING', 'ACTIVE');");
    }

    #[test]
    fn test_create_table() {
        let generator = PostgresSqlGenerator;
        let model = ModelDiff {
            name: "User".to_string(),
            table_name: "users".to_string(),
            fields: vec![
                FieldDiff {
                    name: "id".to_string(),
                    column_name: "id".to_string(),
                    sql_type: "INTEGER".to_string(),
                    nullable: false,
                    default: None,
                    is_primary_key: true,
                    is_auto_increment: true,
                    is_unique: false,
                },
                FieldDiff {
                    name: "email".to_string(),
                    column_name: "email".to_string(),
                    sql_type: "TEXT".to_string(),
                    nullable: false,
                    default: None,
                    is_primary_key: false,
                    is_auto_increment: false,
                    is_unique: true,
                },
            ],
            primary_key: vec!["id".to_string()],
            indexes: Vec::new(),
            unique_constraints: Vec::new(),
        };

        let sql = generator.create_table(&model);
        // The auto-increment INTEGER primary key becomes SERIAL with no NOT NULL/UNIQUE;
        // the primary key is declared as a table-level constraint.
        assert_eq!(
            sql,
            "CREATE TABLE \"users\" (\n    \"id\" SERIAL,\n    \"email\" TEXT NOT NULL UNIQUE,\n    PRIMARY KEY (\"id\")\n);"
        );
    }

    #[test]
    fn test_create_index() {
        let generator = PostgresSqlGenerator;
        let index = IndexDiff {
            name: "idx_users_email".to_string(),
            table_name: "users".to_string(),
            columns: vec!["email".to_string()],
            unique: true,
        };

        let sql = generator.create_index(&index);
        assert_eq!(
            sql,
            "CREATE UNIQUE INDEX \"idx_users_email\" ON \"users\" (\"email\");"
        );
    }

    #[test]
    fn test_alter_table_add_column() {
        let generator = PostgresSqlGenerator;
        let alter = ModelAlterDiff {
            name: "User".to_string(),
            table_name: "users".to_string(),
            add_fields: vec![FieldDiff {
                name: "age".to_string(),
                column_name: "age".to_string(),
                sql_type: "INTEGER".to_string(),
                nullable: true,
                default: None,
                is_primary_key: false,
                is_auto_increment: false,
                is_unique: false,
            }],
            drop_fields: Vec::new(),
            alter_fields: Vec::new(),
            add_indexes: Vec::new(),
            drop_indexes: Vec::new(),
        };

        let stmts = generator.alter_table(&alter);
        assert_eq!(stmts, vec!["ALTER TABLE \"users\" ADD COLUMN \"age\" INTEGER;".to_string()]);
    }

    #[test]
    fn test_create_view() {
        let generator = PostgresSqlGenerator;
        let view = ViewDiff {
            name: "UserStats".to_string(),
            view_name: "user_stats".to_string(),
            sql_query: "SELECT id, COUNT(*) as post_count FROM users GROUP BY id".to_string(),
            is_materialized: false,
            refresh_interval: None,
            fields: vec![],
        };

        let sql = generator.create_view(&view);
        assert_eq!(
            sql,
            "CREATE OR REPLACE VIEW \"user_stats\" AS\nSELECT id, COUNT(*) as post_count FROM users GROUP BY id;"
        );
    }

    #[test]
    fn test_create_materialized_view() {
        let generator = PostgresSqlGenerator;
        let view = ViewDiff {
            name: "UserStats".to_string(),
            view_name: "user_stats".to_string(),
            sql_query: "SELECT id, COUNT(*) as post_count FROM users GROUP BY id".to_string(),
            is_materialized: true,
            refresh_interval: Some("1h".to_string()),
            fields: vec![],
        };

        let sql = generator.create_view(&view);
        // Materialized views don't support OR REPLACE.
        assert_eq!(
            sql,
            "CREATE MATERIALIZED VIEW \"user_stats\" AS\nSELECT id, COUNT(*) as post_count FROM users GROUP BY id;"
        );
    }

    #[test]
    fn test_drop_view() {
        let generator = PostgresSqlGenerator;

        assert_eq!(
            generator.drop_view("user_stats", false),
            "DROP VIEW IF EXISTS \"user_stats\" CASCADE;"
        );
        assert_eq!(
            generator.drop_view("user_stats", true),
            "DROP MATERIALIZED VIEW IF EXISTS \"user_stats\" CASCADE;"
        );
    }

    #[test]
    fn test_refresh_materialized_view() {
        let generator = PostgresSqlGenerator;

        assert_eq!(
            generator.refresh_materialized_view("user_stats", false),
            "REFRESH MATERIALIZED VIEW \"user_stats\";"
        );
        assert_eq!(
            generator.refresh_materialized_view("user_stats", true),
            "REFRESH MATERIALIZED VIEW CONCURRENTLY \"user_stats\";"
        );
    }

    #[test]
    fn test_generate_with_views() {
        let generator = PostgresSqlGenerator;
        let mut diff = SchemaDiff::default();
        diff.create_views.push(ViewDiff {
            name: "ActiveUsers".to_string(),
            view_name: "active_users".to_string(),
            sql_query: "SELECT * FROM users WHERE active = true".to_string(),
            is_materialized: false,
            refresh_interval: None,
            fields: vec![],
        });

        let sql = generator.generate(&diff);
        assert!(!sql.is_empty());
        assert_eq!(
            sql.up,
            "CREATE OR REPLACE VIEW \"active_users\" AS\nSELECT * FROM users WHERE active = true;"
        );
        assert_eq!(sql.down, "DROP VIEW IF EXISTS \"active_users\" CASCADE;");
    }

    #[test]
    fn test_generate_empty_diff() {
        let generator = PostgresSqlGenerator;
        let sql = generator.generate(&SchemaDiff::default());
        assert!(sql.is_empty());
        assert_eq!(sql.up, "");
        assert_eq!(sql.down, "");
    }
}