1use crate::diff::{
4 EnumAlterDiff, EnumDiff, FieldAlterDiff, FieldDiff, IndexDiff, ModelAlterDiff, ModelDiff,
5 SchemaDiff, ViewDiff,
6};
7
/// Generates PostgreSQL DDL statements from a schema diff.
///
/// The generator carries no state: use the unit value `PostgresSqlGenerator` directly or
/// `PostgresSqlGenerator::default()`; copies are free.
#[derive(Debug, Clone, Copy, Default)]
pub struct PostgresSqlGenerator;
10
11impl PostgresSqlGenerator {
12 pub fn generate(&self, diff: &SchemaDiff) -> MigrationSql {
14 let mut up = Vec::new();
15 let mut down = Vec::new();
16
17 for enum_diff in &diff.create_enums {
19 up.push(self.create_enum(enum_diff));
20 down.push(self.drop_enum(&enum_diff.name));
21 }
22
23 for name in &diff.drop_enums {
25 up.push(self.drop_enum(name));
26 }
28
29 for alter in &diff.alter_enums {
31 up.extend(self.alter_enum(alter));
32 }
34
35 for model in &diff.create_models {
37 up.push(self.create_table(model));
38 down.push(self.drop_table(&model.table_name));
39 }
40
41 for name in &diff.drop_models {
43 up.push(self.drop_table(name));
44 }
46
47 for alter in &diff.alter_models {
49 up.extend(self.alter_table(alter));
50 }
52
53 for index in &diff.create_indexes {
55 up.push(self.create_index(index));
56 down.push(self.drop_index(&index.name, &index.table_name));
57 }
58
59 for index in &diff.drop_indexes {
61 up.push(self.drop_index(&index.name, &index.table_name));
62 }
63
64 for view in &diff.create_views {
66 up.push(self.create_view(view));
67 down.push(self.drop_view(&view.view_name, view.is_materialized));
68 }
69
70 for name in &diff.drop_views {
72 up.push(self.drop_view(name, false));
74 }
75
76 for view in &diff.alter_views {
78 up.push(self.drop_view(&view.view_name, view.is_materialized));
80 up.push(self.create_view(view));
82 }
83
84 MigrationSql {
85 up: up.join("\n\n"),
86 down: down.join("\n\n"),
87 }
88 }
89
90 fn create_enum(&self, enum_diff: &EnumDiff) -> String {
92 let values: Vec<String> = enum_diff
93 .values
94 .iter()
95 .map(|v| format!("'{}'", v))
96 .collect();
97 format!(
98 "CREATE TYPE \"{}\" AS ENUM ({});",
99 enum_diff.name,
100 values.join(", ")
101 )
102 }
103
104 fn drop_enum(&self, name: &str) -> String {
106 format!("DROP TYPE IF EXISTS \"{}\";", name)
107 }
108
109 fn alter_enum(&self, alter: &EnumAlterDiff) -> Vec<String> {
111 let mut stmts = Vec::new();
112
113 for value in &alter.add_values {
114 stmts.push(format!(
115 "ALTER TYPE \"{}\" ADD VALUE IF NOT EXISTS '{}';",
116 alter.name, value
117 ));
118 }
119
120 stmts
124 }
125
126 fn create_table(&self, model: &ModelDiff) -> String {
128 let mut columns = Vec::new();
129
130 for field in &model.fields {
131 columns.push(self.column_definition(field));
132 }
133
134 if !model.primary_key.is_empty() {
136 let pk_cols: Vec<String> = model
137 .primary_key
138 .iter()
139 .map(|c| format!("\"{}\"", c))
140 .collect();
141 columns.push(format!("PRIMARY KEY ({})", pk_cols.join(", ")));
142 }
143
144 for uc in &model.unique_constraints {
146 let cols: Vec<String> = uc.columns.iter().map(|c| format!("\"{}\"", c)).collect();
147 let constraint = if let Some(name) = &uc.name {
148 format!("CONSTRAINT \"{}\" UNIQUE ({})", name, cols.join(", "))
149 } else {
150 format!("UNIQUE ({})", cols.join(", "))
151 };
152 columns.push(constraint);
153 }
154
155 format!(
156 "CREATE TABLE \"{}\" (\n {}\n);",
157 model.table_name,
158 columns.join(",\n ")
159 )
160 }
161
162 fn column_definition(&self, field: &FieldDiff) -> String {
164 let mut parts = vec![format!("\"{}\"", field.column_name), field.sql_type.clone()];
165
166 if field.is_auto_increment {
167 if field.sql_type == "INTEGER" {
169 parts[1] = "SERIAL".to_string();
170 } else if field.sql_type == "BIGINT" {
171 parts[1] = "BIGSERIAL".to_string();
172 }
173 }
174
175 if !field.nullable && !field.is_primary_key {
176 parts.push("NOT NULL".to_string());
177 }
178
179 if field.is_unique && !field.is_primary_key {
180 parts.push("UNIQUE".to_string());
181 }
182
183 if let Some(default) = &field.default {
184 parts.push(format!("DEFAULT {}", default));
185 }
186
187 parts.join(" ")
188 }
189
190 fn drop_table(&self, name: &str) -> String {
192 format!("DROP TABLE IF EXISTS \"{}\" CASCADE;", name)
193 }
194
195 fn alter_table(&self, alter: &ModelAlterDiff) -> Vec<String> {
197 let mut stmts = Vec::new();
198
199 for field in &alter.add_fields {
201 stmts.push(format!(
202 "ALTER TABLE \"{}\" ADD COLUMN {};",
203 alter.table_name,
204 self.column_definition(field)
205 ));
206 }
207
208 for name in &alter.drop_fields {
210 stmts.push(format!(
211 "ALTER TABLE \"{}\" DROP COLUMN IF EXISTS \"{}\";",
212 alter.table_name, name
213 ));
214 }
215
216 for field in &alter.alter_fields {
218 stmts.extend(self.alter_column(&alter.table_name, field));
219 }
220
221 for index in &alter.add_indexes {
223 stmts.push(self.create_index(index));
224 }
225
226 for name in &alter.drop_indexes {
228 stmts.push(format!("DROP INDEX IF EXISTS \"{}\";", name));
229 }
230
231 stmts
232 }
233
234 fn alter_column(&self, table: &str, field: &FieldAlterDiff) -> Vec<String> {
236 let mut stmts = Vec::new();
237
238 if let Some(new_type) = &field.new_type {
239 stmts.push(format!(
240 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" TYPE {} USING \"{}\"::{};",
241 table, field.column_name, new_type, field.column_name, new_type
242 ));
243 }
244
245 if let Some(new_nullable) = field.new_nullable {
246 if new_nullable {
247 stmts.push(format!(
248 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL;",
249 table, field.column_name
250 ));
251 } else {
252 stmts.push(format!(
253 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL;",
254 table, field.column_name
255 ));
256 }
257 }
258
259 if let Some(new_default) = &field.new_default {
260 stmts.push(format!(
261 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {};",
262 table, field.column_name, new_default
263 ));
264 }
265
266 stmts
267 }
268
269 fn create_index(&self, index: &IndexDiff) -> String {
271 let unique = if index.unique { "UNIQUE " } else { "" };
272 let cols: Vec<String> = index.columns.iter().map(|c| format!("\"{}\"", c)).collect();
273 format!(
274 "CREATE {}INDEX \"{}\" ON \"{}\" ({});",
275 unique,
276 index.name,
277 index.table_name,
278 cols.join(", ")
279 )
280 }
281
282 fn drop_index(&self, name: &str, _table: &str) -> String {
284 format!("DROP INDEX IF EXISTS \"{}\";", name)
285 }
286
287 fn create_view(&self, view: &ViewDiff) -> String {
289 if view.is_materialized {
290 format!(
291 "CREATE MATERIALIZED VIEW \"{}\" AS\n{};",
292 view.view_name, view.sql_query
293 )
294 } else {
295 format!(
296 "CREATE OR REPLACE VIEW \"{}\" AS\n{};",
297 view.view_name, view.sql_query
298 )
299 }
300 }
301
302 fn drop_view(&self, name: &str, is_materialized: bool) -> String {
304 if is_materialized {
305 format!("DROP MATERIALIZED VIEW IF EXISTS \"{}\" CASCADE;", name)
306 } else {
307 format!("DROP VIEW IF EXISTS \"{}\" CASCADE;", name)
308 }
309 }
310
311 #[allow(dead_code)]
313 fn refresh_materialized_view(&self, name: &str, concurrently: bool) -> String {
314 if concurrently {
315 format!("REFRESH MATERIALIZED VIEW CONCURRENTLY \"{}\";", name)
316 } else {
317 format!("REFRESH MATERIALIZED VIEW \"{}\";", name)
318 }
319 }
320}
321
/// A pair of migration scripts: `up` applies a schema change and `down` reverts it.
///
/// When produced by `PostgresSqlGenerator::generate`, each script is a sequence of
/// statements separated by blank lines.
#[derive(Debug, Clone)]
pub struct MigrationSql {
    /// SQL that applies the migration.
    pub up: String,
    /// SQL that reverts the migration.
    pub down: String,
}

impl MigrationSql {
    /// Reports whether there is nothing to apply: `true` when `up` consists solely of
    /// (Unicode) whitespace, including when it is empty. `down` is not inspected.
    pub fn is_empty(&self) -> bool {
        self.up.chars().all(char::is_whitespace)
    }
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-nullable, non-key, non-unique column with no default.
    fn column(name: &str, sql_type: &str) -> FieldDiff {
        FieldDiff {
            name: name.to_string(),
            column_name: name.to_string(),
            sql_type: sql_type.to_string(),
            nullable: false,
            default: None,
            is_primary_key: false,
            is_auto_increment: false,
            is_unique: false,
        }
    }

    /// Builds the `user_stats` view fixture shared by the view tests.
    fn user_stats_view(is_materialized: bool, refresh_interval: Option<String>) -> ViewDiff {
        ViewDiff {
            name: "UserStats".to_string(),
            view_name: "user_stats".to_string(),
            sql_query: "SELECT id, COUNT(*) as post_count FROM users GROUP BY id".to_string(),
            is_materialized,
            refresh_interval,
            fields: vec![],
        }
    }

    /// Asserts that `haystack` contains every entry of `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(haystack.contains(*needle));
        }
    }

    #[test]
    fn test_create_enum() {
        let enum_diff = EnumDiff {
            name: "Status".to_string(),
            values: ["PENDING", "ACTIVE"].iter().map(|v| v.to_string()).collect(),
        };

        let sql = PostgresSqlGenerator.create_enum(&enum_diff);
        assert_contains_all(&sql, &["CREATE TYPE", "Status", "PENDING", "ACTIVE"]);
    }

    #[test]
    fn test_create_table() {
        let model = ModelDiff {
            name: "User".to_string(),
            table_name: "users".to_string(),
            fields: vec![
                FieldDiff {
                    is_primary_key: true,
                    is_auto_increment: true,
                    ..column("id", "INTEGER")
                },
                FieldDiff {
                    is_unique: true,
                    ..column("email", "TEXT")
                },
            ],
            primary_key: vec!["id".to_string()],
            indexes: Vec::new(),
            unique_constraints: Vec::new(),
        };

        let sql = PostgresSqlGenerator.create_table(&model);
        assert_contains_all(
            &sql,
            &["CREATE TABLE", "users", "SERIAL", "email", "UNIQUE", "PRIMARY KEY"],
        );
    }

    #[test]
    fn test_create_index() {
        let index = IndexDiff {
            name: "idx_users_email".to_string(),
            table_name: "users".to_string(),
            columns: vec!["email".to_string()],
            unique: true,
        };

        let sql = PostgresSqlGenerator.create_index(&index);
        assert_contains_all(&sql, &["CREATE UNIQUE INDEX", "idx_users_email", "users"]);
    }

    #[test]
    fn test_alter_table_add_column() {
        let alter = ModelAlterDiff {
            name: "User".to_string(),
            table_name: "users".to_string(),
            add_fields: vec![FieldDiff {
                nullable: true,
                ..column("age", "INTEGER")
            }],
            drop_fields: Vec::new(),
            alter_fields: Vec::new(),
            add_indexes: Vec::new(),
            drop_indexes: Vec::new(),
        };

        let stmts = PostgresSqlGenerator.alter_table(&alter);
        assert_eq!(stmts.len(), 1);
        assert_contains_all(&stmts[0], &["ADD COLUMN", "age"]);
    }

    #[test]
    fn test_create_view() {
        let view = user_stats_view(false, None);

        let sql = PostgresSqlGenerator.create_view(&view);
        assert_contains_all(
            &sql,
            &["CREATE OR REPLACE VIEW", "user_stats", "SELECT id", "post_count"],
        );
    }

    #[test]
    fn test_create_materialized_view() {
        let view = user_stats_view(true, Some("1h".to_string()));

        let sql = PostgresSqlGenerator.create_view(&view);
        assert_contains_all(&sql, &["CREATE MATERIALIZED VIEW", "user_stats"]);
        assert!(!sql.contains("OR REPLACE"));
    }

    #[test]
    fn test_drop_view() {
        let plain = PostgresSqlGenerator.drop_view("user_stats", false);
        assert_contains_all(&plain, &["DROP VIEW", "user_stats", "CASCADE"]);

        let materialized = PostgresSqlGenerator.drop_view("user_stats", true);
        assert_contains_all(&materialized, &["DROP MATERIALIZED VIEW", "user_stats"]);
    }

    #[test]
    fn test_refresh_materialized_view() {
        let blocking = PostgresSqlGenerator.refresh_materialized_view("user_stats", false);
        assert_contains_all(&blocking, &["REFRESH MATERIALIZED VIEW", "user_stats"]);
        assert!(!blocking.contains("CONCURRENTLY"));

        let concurrent = PostgresSqlGenerator.refresh_materialized_view("user_stats", true);
        assert!(concurrent.contains("CONCURRENTLY"));
    }

    #[test]
    fn test_generate_with_views() {
        let mut diff = SchemaDiff::default();
        diff.create_views.push(ViewDiff {
            name: "ActiveUsers".to_string(),
            view_name: "active_users".to_string(),
            sql_query: "SELECT * FROM users WHERE active = true".to_string(),
            is_materialized: false,
            refresh_interval: None,
            fields: vec![],
        });

        let sql = PostgresSqlGenerator.generate(&diff);
        assert!(!sql.is_empty());
        assert_contains_all(&sql.up, &["CREATE OR REPLACE VIEW", "active_users"]);
        assert!(sql.down.contains("DROP VIEW"));
    }
}