Skip to main content

oxide_sql_core/migrations/dialect/
postgres.rs

1//! PostgreSQL dialect for migrations.
2
3use super::MigrationDialect;
4use crate::ast::DataType;
5use crate::migrations::column_builder::{ColumnDefinition, DefaultValue};
6use crate::migrations::operation::{
7    AlterColumnChange, AlterColumnOp, DropIndexOp, RenameColumnOp, RenameTableOp,
8};
9
/// Migration SQL generator targeting PostgreSQL.
///
/// The type carries no state, so it is `Copy` and can be built in `const` contexts.
#[derive(Debug, Clone, Copy, Default)]
pub struct PostgresDialect;

impl PostgresDialect {
    /// Returns a PostgreSQL dialect value (equivalent to `PostgresDialect::default()`).
    #[must_use]
    pub const fn new() -> Self {
        PostgresDialect
    }
}
21
22impl MigrationDialect for PostgresDialect {
23    fn name(&self) -> &'static str {
24        "postgresql"
25    }
26
27    fn map_data_type(&self, dt: &DataType) -> String {
28        match dt {
29            DataType::Smallint => "SMALLINT".to_string(),
30            DataType::Integer => "INTEGER".to_string(),
31            DataType::Bigint => "BIGINT".to_string(),
32            DataType::Real => "REAL".to_string(),
33            DataType::Double => "DOUBLE PRECISION".to_string(),
34            DataType::Decimal { precision, scale } => match (precision, scale) {
35                (Some(p), Some(s)) => format!("DECIMAL({p}, {s})"),
36                (Some(p), None) => format!("DECIMAL({p})"),
37                _ => "DECIMAL".to_string(),
38            },
39            DataType::Numeric { precision, scale } => match (precision, scale) {
40                (Some(p), Some(s)) => format!("NUMERIC({p}, {s})"),
41                (Some(p), None) => format!("NUMERIC({p})"),
42                _ => "NUMERIC".to_string(),
43            },
44            DataType::Char(len) => match len {
45                Some(n) => format!("CHAR({n})"),
46                None => "CHAR".to_string(),
47            },
48            DataType::Varchar(len) => match len {
49                Some(n) => format!("VARCHAR({n})"),
50                None => "VARCHAR".to_string(),
51            },
52            DataType::Text => "TEXT".to_string(),
53            DataType::Blob => "BYTEA".to_string(), // PostgreSQL uses BYTEA
54            DataType::Binary(len) => match len {
55                Some(n) => format!("BIT({n})"),
56                None => "BYTEA".to_string(),
57            },
58            DataType::Varbinary(len) => match len {
59                Some(n) => format!("VARBIT({n})"),
60                None => "BYTEA".to_string(),
61            },
62            DataType::Date => "DATE".to_string(),
63            DataType::Time => "TIME".to_string(),
64            DataType::Timestamp => "TIMESTAMP".to_string(),
65            DataType::Datetime => "TIMESTAMP".to_string(), // PostgreSQL uses TIMESTAMP
66            DataType::Boolean => "BOOLEAN".to_string(),
67            DataType::Custom(name) => name.clone(),
68        }
69    }
70
71    fn autoincrement_keyword(&self) -> String {
72        // PostgreSQL uses SERIAL types instead of AUTOINCREMENT keyword
73        // However, when PRIMARY KEY is specified with BIGINT, we don't change the type
74        // The application should use SERIAL/BIGSERIAL types directly
75        String::new()
76    }
77
78    fn column_definition(&self, col: &ColumnDefinition) -> String {
79        // PostgreSQL uses SERIAL/BIGSERIAL for auto-increment
80        let data_type = if col.autoincrement && col.primary_key {
81            match col.data_type {
82                DataType::Integer | DataType::Smallint => "SERIAL".to_string(),
83                DataType::Bigint => "BIGSERIAL".to_string(),
84                _ => self.map_data_type(&col.data_type),
85            }
86        } else {
87            self.map_data_type(&col.data_type)
88        };
89
90        let mut sql = format!("{} {}", self.quote_identifier(&col.name), data_type);
91
92        if col.primary_key {
93            sql.push_str(" PRIMARY KEY");
94        } else {
95            if !col.nullable {
96                sql.push_str(" NOT NULL");
97            }
98            if col.unique {
99                sql.push_str(" UNIQUE");
100            }
101        }
102
103        if let Some(ref default) = col.default {
104            sql.push_str(" DEFAULT ");
105            sql.push_str(&self.render_default(default));
106        }
107
108        if let Some(ref fk) = col.references {
109            sql.push_str(" REFERENCES ");
110            sql.push_str(&self.quote_identifier(&fk.table));
111            sql.push_str(" (");
112            sql.push_str(&self.quote_identifier(&fk.column));
113            sql.push(')');
114            if let Some(action) = fk.on_delete {
115                sql.push_str(" ON DELETE ");
116                sql.push_str(action.as_sql());
117            }
118            if let Some(action) = fk.on_update {
119                sql.push_str(" ON UPDATE ");
120                sql.push_str(action.as_sql());
121            }
122        }
123
124        if let Some(ref check) = col.check {
125            sql.push_str(&format!(" CHECK ({})", check));
126        }
127
128        if let Some(ref collation) = col.collation {
129            sql.push_str(&format!(" COLLATE \"{}\"", collation));
130        }
131
132        sql
133    }
134
135    fn render_default(&self, default: &DefaultValue) -> String {
136        match default {
137            DefaultValue::Boolean(b) => {
138                if *b {
139                    "TRUE".to_string()
140                } else {
141                    "FALSE".to_string()
142                }
143            }
144            _ => default.to_sql(),
145        }
146    }
147
148    fn rename_table(&self, op: &RenameTableOp) -> String {
149        format!(
150            "ALTER TABLE {} RENAME TO {}",
151            self.quote_identifier(&op.old_name),
152            self.quote_identifier(&op.new_name)
153        )
154    }
155
156    fn rename_column(&self, op: &RenameColumnOp) -> String {
157        format!(
158            "ALTER TABLE {} RENAME COLUMN {} TO {}",
159            self.quote_identifier(&op.table),
160            self.quote_identifier(&op.old_name),
161            self.quote_identifier(&op.new_name)
162        )
163    }
164
165    fn alter_column(&self, op: &AlterColumnOp) -> String {
166        let table = self.quote_identifier(&op.table);
167        let column = self.quote_identifier(&op.column);
168
169        match &op.change {
170            AlterColumnChange::SetDataType(dt) => {
171                format!(
172                    "ALTER TABLE {} ALTER COLUMN {} TYPE {}",
173                    table,
174                    column,
175                    self.map_data_type(dt)
176                )
177            }
178            AlterColumnChange::SetNullable(nullable) => {
179                if *nullable {
180                    format!(
181                        "ALTER TABLE {} ALTER COLUMN {} DROP NOT NULL",
182                        table, column
183                    )
184                } else {
185                    format!("ALTER TABLE {} ALTER COLUMN {} SET NOT NULL", table, column)
186                }
187            }
188            AlterColumnChange::SetDefault(default) => {
189                format!(
190                    "ALTER TABLE {} ALTER COLUMN {} SET DEFAULT {}",
191                    table,
192                    column,
193                    self.render_default(default)
194                )
195            }
196            AlterColumnChange::DropDefault => {
197                format!("ALTER TABLE {} ALTER COLUMN {} DROP DEFAULT", table, column)
198            }
199        }
200    }
201
202    fn drop_index(&self, op: &DropIndexOp) -> String {
203        let mut sql = String::from("DROP INDEX ");
204        if op.if_exists {
205            sql.push_str("IF EXISTS ");
206        }
207        sql.push_str(&self.quote_identifier(&op.name));
208        sql
209    }
210
211    fn drop_foreign_key(&self, op: &super::super::operation::DropForeignKeyOp) -> String {
212        format!(
213            "ALTER TABLE {} DROP CONSTRAINT {}",
214            self.quote_identifier(&op.table),
215            self.quote_identifier(&op.name)
216        )
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::migrations::column_builder::{bigint, varchar};
    use crate::migrations::table_builder::CreateTableBuilder;

    /// Checks the generic-to-PostgreSQL type mapping for common types.
    #[test]
    fn test_postgres_data_types() {
        let pg = PostgresDialect::new();
        let expectations = vec![
            (DataType::Integer, "INTEGER"),
            (DataType::Bigint, "BIGINT"),
            (DataType::Text, "TEXT"),
            (DataType::Varchar(Some(255)), "VARCHAR(255)"),
            (DataType::Blob, "BYTEA"),
            (DataType::Boolean, "BOOLEAN"),
            (DataType::Timestamp, "TIMESTAMP"),
            (
                DataType::Decimal {
                    precision: Some(10),
                    scale: Some(2),
                },
                "DECIMAL(10, 2)",
            ),
        ];
        for (data_type, expected) in expectations {
            assert_eq!(pg.map_data_type(&data_type), expected);
        }
    }

    /// An auto-increment BIGINT primary key becomes BIGSERIAL.
    #[test]
    fn test_create_table_with_serial() {
        let pg = PostgresDialect::new();
        let users = CreateTableBuilder::new()
            .name("users")
            .column(bigint("id").primary_key().autoincrement().build())
            .column(varchar("username", 255).not_null().unique().build())
            .build();

        let rendered = pg.create_table(&users);
        for fragment in [
            "CREATE TABLE \"users\"",
            "\"id\" BIGSERIAL PRIMARY KEY",
            "\"username\" VARCHAR(255) NOT NULL UNIQUE",
        ] {
            assert!(rendered.contains(fragment));
        }
    }

    /// Each ALTER COLUMN change kind renders its PostgreSQL clause.
    #[test]
    fn test_alter_column_sql() {
        let pg = PostgresDialect::new();
        let cases = vec![
            (
                "email",
                AlterColumnChange::SetNullable(false),
                "ALTER TABLE \"users\" ALTER COLUMN \"email\" SET NOT NULL",
            ),
            (
                "email",
                AlterColumnChange::SetNullable(true),
                "ALTER TABLE \"users\" ALTER COLUMN \"email\" DROP NOT NULL",
            ),
            (
                "age",
                AlterColumnChange::SetDataType(DataType::Bigint),
                "ALTER TABLE \"users\" ALTER COLUMN \"age\" TYPE BIGINT",
            ),
        ];
        for (column, change, expected) in cases {
            let op = AlterColumnOp {
                table: "users".to_string(),
                column: column.to_string(),
                change,
            };
            assert_eq!(pg.alter_column(&op), expected);
        }
    }

    /// Dropping a foreign key drops the named constraint.
    #[test]
    fn test_drop_foreign_key() {
        let pg = PostgresDialect::new();
        let fk_drop = crate::migrations::operation::DropForeignKeyOp {
            table: "invoices".to_string(),
            name: "fk_invoices_user".to_string(),
        };
        let expected = "ALTER TABLE \"invoices\" DROP CONSTRAINT \"fk_invoices_user\"";
        assert_eq!(pg.drop_foreign_key(&fk_drop), expected);
    }
}