Skip to main content

nautilus_dialect/
postgres.rs

1//! PostgreSQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// Dialect renderer targeting PostgreSQL.
///
/// Bound values are referenced with 1-based numbered placeholders (`$1`, `$2`, ...)
/// and identifiers are wrapped in double quotes. PostgreSQL-specific output includes
/// `RETURNING`, `DISTINCT ON`, the array operators `@>` / `<@` / `&&`, explicit type
/// casts on UUID parameters, and aggregate `FILTER (WHERE ...)` clauses.
#[derive(Clone, Copy, Debug)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15    fn render_select(&self, select: &Select) -> Result<Sql> {
16        let mut ctx = RenderContext::new();
17        render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, true, false);
18        Ok(Sql {
19            text: ctx.sql,
20            params: ctx.params,
21        })
22    }
23
24    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
25        let mut ctx = RenderContext::new();
26        render_insert_body!(&mut ctx, insert, quote_identifier, true, true);
27        Ok(Sql {
28            text: ctx.sql,
29            params: ctx.params,
30        })
31    }
32
33    fn render_update(&self, update: &Update) -> Result<Sql> {
34        let mut ctx = RenderContext::new();
35        render_update_body!(&mut ctx, update, quote_identifier, render_expr, true, true);
36        Ok(Sql {
37            text: ctx.sql,
38            params: ctx.params,
39        })
40    }
41
42    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
43        let mut ctx = RenderContext::new();
44        render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, true);
45        Ok(Sql {
46            text: ctx.sql,
47            params: ctx.params,
48        })
49    }
50}
51
/// Quotes `name` as a PostgreSQL identifier.
///
/// Delegates to the crate-wide `double_quote_identifier` helper. As exercised by
/// `test_quote_identifier`, the result is wrapped in `"` and any embedded `"` is
/// escaped by doubling it (`foo"bar` -> `"foo""bar"`).
fn quote_identifier(name: &str) -> String {
    crate::double_quote_identifier(name)
}
55
56struct RenderContext {
57    sql: String,
58    params: Vec<Value>,
59}
60
61impl RenderContext {
62    fn new() -> Self {
63        Self {
64            sql: String::new(),
65            params: Vec::new(),
66        }
67    }
68
69    fn push_param(&mut self, value: Value) -> String {
70        self.params.push(value);
71        format!("${}", self.params.len())
72    }
73}
74
75fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
76    render_select_body_core!(ctx, select, quote_identifier, render_expr, true, false);
77}
78
79fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
80    render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
81        Expr::Param(value) => {
82            // NULL is emitted literally; PostgreSQL cannot implicitly resolve a
83            // typed NULL sent as an unknown OID via the binary protocol.
84            if matches!(value, Value::Null) {
85                ctx.sql.push_str("NULL");
86            } else {
87                let placeholder = ctx.push_param(value.clone());
88                ctx.sql.push_str(&placeholder);
89                // PostgreSQL needs an explicit cast when the driver sends an unknown OID.
90                if matches!(value, Value::Uuid(_)) {
91                    ctx.sql.push_str("::uuid");
92                } else if matches!(value, Value::Json(_)) {
93                    ctx.sql.push_str("::json");
94                } else if matches!(value, Value::Vector(_)) {
95                    ctx.sql.push_str("::vector");
96                } else if matches!(value, Value::Geometry(_)) {
97                    ctx.sql.push_str("::geometry");
98                } else if matches!(value, Value::Geography(_)) {
99                    ctx.sql.push_str("::geography");
100                } else if is_homogeneous_geometry_array(value) {
101                    ctx.sql.push_str("::geometry[]");
102                } else if is_homogeneous_geography_array(value) {
103                    ctx.sql.push_str("::geography[]");
104                } else if let Value::Enum { type_name, .. } = value {
105                    ctx.sql.push_str("::");
106                    ctx.sql.push_str(type_name);
107                }
108            }
109        }
110        Expr::Binary { left, op, right } => {
111            if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
112                ctx.sql.push('(');
113                render_expr(ctx, left);
114                ctx.sql.push(' ');
115                ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
116                ctx.sql.push_str(" (");
117                if let Expr::List(exprs) = right.as_ref() {
118                    for (i, e) in exprs.iter().enumerate() {
119                        if i > 0 { ctx.sql.push_str(", "); }
120                        render_expr(ctx, e);
121                    }
122                } else {
123                    render_expr(ctx, right);
124                }
125                ctx.sql.push(')');
126                ctx.sql.push(')');
127            } else {
128                ctx.sql.push('(');
129                render_expr(ctx, left);
130                ctx.sql.push(' ');
131                ctx.sql.push_str(match op {
132                    BinaryOp::ArrayContains => "@>",
133                    BinaryOp::ArrayContainedBy => "<@",
134                    BinaryOp::ArrayOverlaps => "&&",
135                    _ => crate::binary_op_sql(op),
136                });
137                ctx.sql.push(' ');
138                render_expr(ctx, right);
139                ctx.sql.push(')');
140            }
141        }
142        Expr::FunctionCall { name, args } => {
143            if args.len() == 2 {
144                let op = match name.as_str() {
145                    nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
146                    nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
147                    nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
148                    _ => None,
149                };
150                if let Some(op) = op {
151                    ctx.sql.push('(');
152                    render_expr(ctx, &args[0]);
153                    ctx.sql.push(' ');
154                    ctx.sql.push_str(op);
155                    ctx.sql.push(' ');
156                    render_expr(ctx, &args[1]);
157                    ctx.sql.push(')');
158                    return;
159                }
160            }
161            ctx.sql.push_str(name);
162            ctx.sql.push('(');
163            for (i, arg) in args.iter().enumerate() {
164                if i > 0 { ctx.sql.push_str(", "); }
165                render_expr(ctx, arg);
166            }
167            ctx.sql.push(')');
168        }
169        Expr::Filter { expr, predicate } => {
170            // Native PostgreSQL FILTER clause (supported since pg 9.4).
171            render_expr(ctx, expr);
172            ctx.sql.push_str(" FILTER (WHERE ");
173            render_expr(ctx, predicate);
174            ctx.sql.push(')');
175        }
176    });
177}
178
179fn is_homogeneous_geometry_array(value: &Value) -> bool {
180    matches!(
181        value,
182        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
183    )
184}
185
186fn is_homogeneous_geography_array(value: &Value) -> bool {
187    matches!(
188        value,
189        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
190    )
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `select` with the PostgreSQL dialect, panicking on failure.
    fn render(select: &Select) -> Sql {
        PostgresDialect.render_select(select).unwrap()
    }

    /// Renders `SELECT * FROM "posts"` filtered by `<column> <op> <array param>`.
    fn render_posts_array_filter(column: &str, op: BinaryOp, items: Vec<Value>) -> Sql {
        let filter = Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(items))),
        };
        let select = Select::from_table("posts").filter(filter).build().unwrap();
        render(&select)
    }

    /// Asserts that `sql` binds exactly one parameter: an array equal to `expected`.
    fn assert_single_array_param(sql: &Sql, expected: &[Value]) {
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(actual) => assert_eq!(&actual[..], expected),
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("email"), "\"email\"");
        assert_eq!(quote_identifier("foo\"bar"), "\"foo\"\"bar\"");
        assert_eq!(quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
    }

    #[test]
    fn test_array_contains_operator() {
        let tags = vec![Value::String("rust".to_string())];
        let sql = render_posts_array_filter("posts__tags", BinaryOp::ArrayContains, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("go".to_string()),
        ];
        let sql =
            render_posts_array_filter("posts__tags", BinaryOp::ArrayContainedBy, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("python".to_string()),
        ];
        let sql = render_posts_array_filter("posts__tags", BinaryOp::ArrayOverlaps, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let scores = vec![Value::I32(100), Value::I32(200)];
        let sql =
            render_posts_array_filter("posts__scores", BinaryOp::ArrayContains, scores.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_single_array_param(&sql, &scores);
    }

    #[test]
    fn vector_params_are_cast_to_pgvector_type() {
        let query = Expr::param(Value::Vector(vec![1.0, 2.0, 3.0]));
        let select = Select::from_table("embeddings")
            .filter(Expr::column("embeddings__vector").eq(query))
            .build()
            .unwrap();
        let sql = render(&select);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
        );
        assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
    }

    #[test]
    fn postgis_params_are_cast_to_spatial_types() {
        let cases = vec![
            (
                "places__geom",
                Value::Geometry("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)",
            ),
            (
                "places__geog",
                Value::Geography("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)",
            ),
        ];

        for (column, value, expected_sql) in cases {
            let select = Select::from_table("places")
                .filter(Expr::column(column).eq(Expr::param(value.clone())))
                .build()
                .unwrap();
            let sql = render(&select);

            assert_eq!(sql.text, expected_sql);
            assert_eq!(sql.params, vec![value]);
        }
    }

    #[test]
    fn vector_distance_ordering_uses_pgvector_operator() {
        let distance = Expr::vector_distance(
            nautilus_core::VectorMetric::Cosine,
            Expr::column("embeddings__vector"),
            Expr::param(Value::Vector(vec![1.0, 2.0, 3.0])),
        );
        let select = Select::from_table("embeddings")
            .order_by_expr(distance, nautilus_core::OrderDir::Asc)
            .take(5)
            .build()
            .unwrap();
        let sql = render(&select);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
        );
        assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
    }
}