Skip to main content

nautilus_dialect/
postgres.rs

//! PostgreSQL SQL dialect renderer.

3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// PostgreSQL SQL dialect renderer.
///
/// Uses `$1, $2, ...` numbered parameter placeholders and double-quoted identifiers.
/// Supports `RETURNING`, `DISTINCT ON`, the native PostgreSQL array operators
/// (`@>`, `<@`, `&&`), explicit casts on UUID / JSON / enum parameters, and
/// `FILTER (WHERE ...)` on aggregates.
///
/// This is a stateless unit struct: every `render_*` call builds its SQL in a
/// fresh rendering context, so one value can be reused for any number of queries.
#[derive(Debug, Clone, Copy)]
pub struct PostgresDialect;

14impl Dialect for PostgresDialect {
15    fn render_select(&self, select: &Select) -> Result<Sql> {
16        let mut ctx = RenderContext::new();
17        // DISTINCT ON (cols) is PostgreSQL-specific; mysql_limit_hack not needed.
18        render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, true, false);
19        Ok(Sql {
20            text: ctx.sql,
21            params: ctx.params,
22        })
23    }
24
25    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
26        let mut ctx = RenderContext::new();
27        render_insert_body!(&mut ctx, insert, quote_identifier, true, true);
28        Ok(Sql {
29            text: ctx.sql,
30            params: ctx.params,
31        })
32    }
33
34    fn render_update(&self, update: &Update) -> Result<Sql> {
35        let mut ctx = RenderContext::new();
36        render_update_body!(&mut ctx, update, quote_identifier, render_expr, true, true);
37        Ok(Sql {
38            text: ctx.sql,
39            params: ctx.params,
40        })
41    }
42
43    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
44        let mut ctx = RenderContext::new();
45        render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, true);
46        Ok(Sql {
47            text: ctx.sql,
48            params: ctx.params,
49        })
50    }
51}
52
/// Quote a SQL identifier with double quotes, as PostgreSQL expects.
///
/// Delegates to the crate-wide `double_quote_identifier` helper. Embedded `"`
/// characters are doubled, so `foo"bar` becomes `"foo""bar"`.
fn quote_identifier(name: &str) -> String {
    crate::double_quote_identifier(name)
}

/// Rendering context: accumulates SQL text and bound parameter values.
///
/// One context is created per top-level statement. Nested renderers (sub-expressions,
/// subqueries) append to that same context, so placeholder numbering stays
/// contiguous across the whole statement.
struct RenderContext {
    /// SQL text rendered so far.
    sql: String,
    /// Bound values in placeholder order: `params[i]` is referenced as `$(i + 1)`.
    params: Vec<Value>,
}

64impl RenderContext {
65    fn new() -> Self {
66        Self {
67            sql: String::new(),
68            params: Vec::new(),
69        }
70    }
71
72    /// Append a bound parameter and return its `$N` placeholder.
73    fn push_param(&mut self, value: Value) -> String {
74        self.params.push(value);
75        format!("${}", self.params.len())
76    }
77}
78
79/// Render a SELECT query body into an existing context.
80///
81/// Called for both top-level queries and subqueries inside `EXISTS` / `NOT EXISTS`.
82fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
83    render_select_body_core!(ctx, select, quote_identifier, render_expr, true, false);
84}
85
86/// Render an expression into SQL.
87fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
88    render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
89        Expr::Param(value) => {
90            // NULL is emitted literally; PostgreSQL cannot implicitly resolve a
91            // typed NULL sent as an unknown OID via the binary protocol.
92            if matches!(value, Value::Null) {
93                ctx.sql.push_str("NULL");
94            } else {
95                let placeholder = ctx.push_param(value.clone());
96                ctx.sql.push_str(&placeholder);
97                // PostgreSQL needs an explicit cast when the driver sends an unknown OID.
98                if matches!(value, Value::Uuid(_)) {
99                    ctx.sql.push_str("::uuid");
100                } else if matches!(value, Value::Json(_)) {
101                    ctx.sql.push_str("::json");
102                } else if let Value::Enum { type_name, .. } = value {
103                    ctx.sql.push_str("::");
104                    ctx.sql.push_str(type_name);
105                }
106            }
107        }
108        Expr::Binary { left, op, right } => {
109            if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
110                ctx.sql.push('(');
111                render_expr(ctx, left);
112                ctx.sql.push(' ');
113                ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
114                ctx.sql.push_str(" (");
115                if let Expr::List(exprs) = right.as_ref() {
116                    for (i, e) in exprs.iter().enumerate() {
117                        if i > 0 { ctx.sql.push_str(", "); }
118                        render_expr(ctx, e);
119                    }
120                } else {
121                    render_expr(ctx, right);
122                }
123                ctx.sql.push(')');
124                ctx.sql.push(')');
125            } else {
126                ctx.sql.push('(');
127                render_expr(ctx, left);
128                ctx.sql.push(' ');
129                ctx.sql.push_str(match op {
130                    BinaryOp::ArrayContains => "@>",
131                    BinaryOp::ArrayContainedBy => "<@",
132                    BinaryOp::ArrayOverlaps => "&&",
133                    _ => crate::binary_op_sql(op),
134                });
135                ctx.sql.push(' ');
136                render_expr(ctx, right);
137                ctx.sql.push(')');
138            }
139        }
140        Expr::FunctionCall { name, args } => {
141            ctx.sql.push_str(name);
142            ctx.sql.push('(');
143            for (i, arg) in args.iter().enumerate() {
144                if i > 0 { ctx.sql.push_str(", "); }
145                render_expr(ctx, arg);
146            }
147            ctx.sql.push(')');
148        }
149        Expr::Filter { expr, predicate } => {
150            // Native PostgreSQL FILTER clause (supported since pg 9.4).
151            render_expr(ctx, expr);
152            ctx.sql.push_str(" FILTER (WHERE ");
153            render_expr(ctx, predicate);
154            ctx.sql.push(')');
155        }
156    });
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a string parameter value.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    /// Render `SELECT * FROM "posts" WHERE (<column> <op> $1)`, where `$1` binds
    /// an array built from `elements`.
    fn render_array_filter(column: &'static str, op: BinaryOp, elements: Vec<Value>) -> Sql {
        let predicate = Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(elements))),
        };
        let query = Select::from_table("posts").filter(predicate).build().unwrap();
        PostgresDialect.render_select(&query).unwrap()
    }

    /// Check that exactly one parameter was bound and that it is an array equal,
    /// element by element, to `expected`.
    fn assert_single_array_param(sql: &Sql, expected: &[Value]) {
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(bound) => assert_eq!(&bound[..], expected),
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "\"users\""),
            ("email", "\"email\""),
            ("foo\"bar", "\"foo\"\"bar\""),
            ("a\"b\"c", "\"a\"\"b\"\"c\""),
        ];
        for &(raw, quoted) in cases.iter() {
            assert_eq!(quote_identifier(raw), quoted);
        }
    }

    // ——— PostgreSQL-specific: native array operators @>, <@, && —————————————————

    #[test]
    fn test_array_contains_operator() {
        let elements = vec![text("rust")];
        let sql = render_array_filter("posts__tags", BinaryOp::ArrayContains, elements.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_single_array_param(&sql, &elements);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let elements = vec![text("rust"), text("go")];
        let sql =
            render_array_filter("posts__tags", BinaryOp::ArrayContainedBy, elements.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_single_array_param(&sql, &elements);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let elements = vec![text("rust"), text("python")];
        let sql = render_array_filter("posts__tags", BinaryOp::ArrayOverlaps, elements.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_single_array_param(&sql, &elements);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let elements = vec![Value::I32(100), Value::I32(200)];
        let sql =
            render_array_filter("posts__scores", BinaryOp::ArrayContains, elements.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_single_array_param(&sql, &elements);
    }
}