Skip to main content

nautilus_dialect/
postgres.rs

1//! PostgreSQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
6/// PostgreSQL SQL dialect renderer.
7///
8/// Uses `$1, $2, ...` numbered parameter placeholders and double-quoted identifiers.
9/// Supports `RETURNING`, `DISTINCT ON`, PostgreSQL array operators, UUID type casts,
10/// and `FILTER (WHERE ...)` on aggregates.
11#[derive(Debug, Clone, Copy)]
12pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15    fn render_select(&self, select: &Select) -> Result<Sql> {
16        let mut ctx = RenderContext::new();
17        render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, true, false);
18        Ok(Sql {
19            text: ctx.sql,
20            params: ctx.params,
21        })
22    }
23
24    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
25        let mut ctx = RenderContext::new();
26        render_insert_body!(&mut ctx, insert, quote_identifier, true, true);
27        Ok(Sql {
28            text: ctx.sql,
29            params: ctx.params,
30        })
31    }
32
33    fn render_update(&self, update: &Update) -> Result<Sql> {
34        let mut ctx = RenderContext::new();
35        render_update_body!(&mut ctx, update, quote_identifier, render_expr, true, true);
36        Ok(Sql {
37            text: ctx.sql,
38            params: ctx.params,
39        })
40    }
41
42    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
43        let mut ctx = RenderContext::new();
44        render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, true);
45        Ok(Sql {
46            text: ctx.sql,
47            params: ctx.params,
48        })
49    }
50}
51
52fn quote_identifier(name: &str) -> String {
53    crate::double_quote_identifier(name)
54}
55
56struct RenderContext {
57    sql: String,
58    params: Vec<Value>,
59}
60
61impl RenderContext {
62    fn new() -> Self {
63        Self {
64            sql: String::new(),
65            params: Vec::new(),
66        }
67    }
68
69    fn push_param(&mut self, value: Value) -> String {
70        self.params.push(value);
71        format!("${}", self.params.len())
72    }
73}
74
75fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
76    render_select_body_core!(ctx, select, quote_identifier, render_expr, true, false);
77}
78
79fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
80    render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
81        Expr::Param(value) => {
82            // NULL is emitted literally; PostgreSQL cannot implicitly resolve a
83            // typed NULL sent as an unknown OID via the binary protocol.
84            if matches!(value, Value::Null) {
85                ctx.sql.push_str("NULL");
86            } else {
87                let placeholder = ctx.push_param(value.clone());
88                ctx.sql.push_str(&placeholder);
89                // PostgreSQL needs an explicit cast when the driver sends an unknown OID.
90                if matches!(value, Value::Uuid(_)) {
91                    ctx.sql.push_str("::uuid");
92                } else if matches!(value, Value::Json(_)) {
93                    ctx.sql.push_str("::json");
94                } else if let Value::Enum { type_name, .. } = value {
95                    ctx.sql.push_str("::");
96                    ctx.sql.push_str(type_name);
97                }
98            }
99        }
100        Expr::Binary { left, op, right } => {
101            if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
102                ctx.sql.push('(');
103                render_expr(ctx, left);
104                ctx.sql.push(' ');
105                ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
106                ctx.sql.push_str(" (");
107                if let Expr::List(exprs) = right.as_ref() {
108                    for (i, e) in exprs.iter().enumerate() {
109                        if i > 0 { ctx.sql.push_str(", "); }
110                        render_expr(ctx, e);
111                    }
112                } else {
113                    render_expr(ctx, right);
114                }
115                ctx.sql.push(')');
116                ctx.sql.push(')');
117            } else {
118                ctx.sql.push('(');
119                render_expr(ctx, left);
120                ctx.sql.push(' ');
121                ctx.sql.push_str(match op {
122                    BinaryOp::ArrayContains => "@>",
123                    BinaryOp::ArrayContainedBy => "<@",
124                    BinaryOp::ArrayOverlaps => "&&",
125                    _ => crate::binary_op_sql(op),
126                });
127                ctx.sql.push(' ');
128                render_expr(ctx, right);
129                ctx.sql.push(')');
130            }
131        }
132        Expr::FunctionCall { name, args } => {
133            ctx.sql.push_str(name);
134            ctx.sql.push('(');
135            for (i, arg) in args.iter().enumerate() {
136                if i > 0 { ctx.sql.push_str(", "); }
137                render_expr(ctx, arg);
138            }
139            ctx.sql.push(')');
140        }
141        Expr::Filter { expr, predicate } => {
142            // Native PostgreSQL FILTER clause (supported since pg 9.4).
143            render_expr(ctx, expr);
144            ctx.sql.push_str(" FILTER (WHERE ");
145            render_expr(ctx, predicate);
146            ctx.sql.push(')');
147        }
148    });
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds and renders a SELECT on `posts` filtered by
    /// `<column> <op> <array parameter holding items>`.
    fn render_posts_filter(column: &'static str, op: BinaryOp, items: Vec<Value>) -> Sql {
        let filter = Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(items))),
        };
        let select = Select::from_table("posts").filter(filter).build().unwrap();
        PostgresDialect.render_select(&select).unwrap()
    }

    /// Asserts that `sql` binds exactly one parameter: an array equal to `expected`.
    fn assert_single_array_param(sql: &Sql, expected: &[Value]) {
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(arr) => assert_eq!(arr.as_slice(), expected),
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("email"), "\"email\"");
        assert_eq!(quote_identifier("foo\"bar"), "\"foo\"\"bar\"");
        assert_eq!(quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
    }

    #[test]
    fn test_array_contains_operator() {
        let tags = vec![Value::String("rust".to_string())];
        let sql = render_posts_filter("posts__tags", BinaryOp::ArrayContains, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("go".to_string()),
        ];
        let sql = render_posts_filter("posts__tags", BinaryOp::ArrayContainedBy, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("python".to_string()),
        ];
        let sql = render_posts_filter("posts__tags", BinaryOp::ArrayOverlaps, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let scores = vec![Value::I32(100), Value::I32(200)];
        let sql = render_posts_filter("posts__scores", BinaryOp::ArrayContains, scores.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_single_array_param(&sql, &scores);
    }
}