Skip to main content

nautilus_dialect/
postgres.rs

//! PostgreSQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// SQL renderer targeting PostgreSQL.
///
/// Bound values are referenced through numbered `$1, $2, ...` placeholders and
/// identifiers are wrapped in double quotes. PostgreSQL-specific output includes
/// `RETURNING`, `DISTINCT ON`, the array operators (`@>`, `<@`, `&&`), explicit
/// casts for UUID and other typed parameters, and `FILTER (WHERE ...)` on
/// aggregates.
#[derive(Clone, Copy, Debug)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15    fn render_select(&self, select: &Select) -> Result<Sql> {
16        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(select));
17        render_select_body_core!(&mut ctx, select, '"', render_expr, true, false);
18        Ok(Sql {
19            text: ctx.sql,
20            params: ctx.params,
21        })
22    }
23
24    fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
25        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
26        render_select_body_core_mut!(&mut ctx, &mut select, '"', render_expr_owned, true, false);
27        Ok(Sql {
28            text: ctx.sql,
29            params: ctx.params,
30        })
31    }
32
33    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
34        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(insert));
35        render_insert_body!(&mut ctx, insert, '"', true, true);
36        Ok(Sql {
37            text: ctx.sql,
38            params: ctx.params,
39        })
40    }
41
42    fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
43        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
44        render_insert_body_mut!(&mut ctx, &mut insert, '"', true, true);
45        Ok(Sql {
46            text: ctx.sql,
47            params: ctx.params,
48        })
49    }
50
51    fn render_update(&self, update: &Update) -> Result<Sql> {
52        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(update));
53        render_update_body!(&mut ctx, update, '"', render_expr, true, true);
54        Ok(Sql {
55            text: ctx.sql,
56            params: ctx.params,
57        })
58    }
59
60    fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
61        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
62        render_update_body_mut!(&mut ctx, &mut update, '"', render_expr_owned, true, true);
63        Ok(Sql {
64            text: ctx.sql,
65            params: ctx.params,
66        })
67    }
68
69    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
70        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(delete));
71        render_delete_body!(&mut ctx, delete, '"', render_expr, true);
72        Ok(Sql {
73            text: ctx.sql,
74            params: ctx.params,
75        })
76    }
77
78    fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
79        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
80        render_delete_body_mut!(&mut ctx, &mut delete, '"', render_expr_owned, true);
81        Ok(Sql {
82            text: ctx.sql,
83            params: ctx.params,
84        })
85    }
86}
87
88struct RenderContext {
89    sql: String,
90    params: Vec<Value>,
91}
92
93impl RenderContext {
94    fn with_estimate(estimate: crate::RenderEstimate) -> Self {
95        Self {
96            sql: String::with_capacity(estimate.sql_capacity),
97            params: Vec::with_capacity(estimate.params_capacity),
98        }
99    }
100
101    fn push_param(&mut self, value: Value) {
102        self.params.push(value);
103        self.sql.push('$');
104        crate::push_usize(&mut self.sql, self.params.len());
105    }
106
107    fn take_param(&mut self, value: &mut Value) {
108        self.push_param(std::mem::replace(value, Value::Null));
109    }
110}
111
112fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
113    render_select_body_core!(ctx, select, '"', render_expr, true, false);
114}
115
116fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
117    render_select_body_core_mut!(ctx, select, '"', render_expr_owned, true, false);
118}
119
120fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
121    render_expr_common!(ctx, expr, '"', render_expr, render_select_body, {
122        Expr::Param(value) => {
123            // NULL is emitted literally; PostgreSQL cannot implicitly resolve a
124            // typed NULL sent as an unknown OID via the binary protocol.
125            if matches!(value, Value::Null) {
126                ctx.sql.push_str("NULL");
127            } else {
128                ctx.push_param(value.clone());
129                // PostgreSQL needs an explicit cast when the driver sends an unknown OID.
130                if matches!(value, Value::Uuid(_)) {
131                    ctx.sql.push_str("::uuid");
132                } else if matches!(value, Value::Json(_)) {
133                    ctx.sql.push_str("::json");
134                } else if matches!(value, Value::Vector(_)) {
135                    ctx.sql.push_str("::vector");
136                } else if matches!(value, Value::Geometry(_)) {
137                    ctx.sql.push_str("::geometry");
138                } else if matches!(value, Value::Geography(_)) {
139                    ctx.sql.push_str("::geography");
140                } else if is_homogeneous_geometry_array(value) {
141                    ctx.sql.push_str("::geometry[]");
142                } else if is_homogeneous_geography_array(value) {
143                    ctx.sql.push_str("::geography[]");
144                } else if let Value::Enum { type_name, .. } = value {
145                    ctx.sql.push_str("::");
146                    ctx.sql.push_str(type_name);
147                }
148            }
149        }
150        Expr::Binary { left, op, right } => {
151            if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
152                ctx.sql.push('(');
153                render_expr(ctx, left);
154                ctx.sql.push(' ');
155                ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
156                ctx.sql.push_str(" (");
157                if let Expr::List(exprs) = right.as_ref() {
158                    for (i, e) in exprs.iter().enumerate() {
159                        if i > 0 { ctx.sql.push_str(", "); }
160                        render_expr(ctx, e);
161                    }
162                } else {
163                    render_expr(ctx, right);
164                }
165                ctx.sql.push(')');
166                ctx.sql.push(')');
167            } else {
168                ctx.sql.push('(');
169                render_expr(ctx, left);
170                ctx.sql.push(' ');
171                ctx.sql.push_str(match op {
172                    BinaryOp::ArrayContains => "@>",
173                    BinaryOp::ArrayContainedBy => "<@",
174                    BinaryOp::ArrayOverlaps => "&&",
175                    _ => crate::binary_op_sql(op),
176                });
177                ctx.sql.push(' ');
178                render_expr(ctx, right);
179                ctx.sql.push(')');
180            }
181        }
182        Expr::FunctionCall { name, args } => {
183            if args.len() == 2 {
184                let op = match name.as_str() {
185                    nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
186                    nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
187                    nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
188                    _ => None,
189                };
190                if let Some(op) = op {
191                    ctx.sql.push('(');
192                    render_expr(ctx, &args[0]);
193                    ctx.sql.push(' ');
194                    ctx.sql.push_str(op);
195                    ctx.sql.push(' ');
196                    render_expr(ctx, &args[1]);
197                    ctx.sql.push(')');
198                    return;
199                }
200            }
201            ctx.sql.push_str(name);
202            ctx.sql.push('(');
203            for (i, arg) in args.iter().enumerate() {
204                if i > 0 { ctx.sql.push_str(", "); }
205                render_expr(ctx, arg);
206            }
207            ctx.sql.push(')');
208        }
209        Expr::Filter { expr, predicate } => {
210            // Native PostgreSQL FILTER clause (supported since pg 9.4).
211            render_expr(ctx, expr);
212            ctx.sql.push_str(" FILTER (WHERE ");
213            render_expr(ctx, predicate);
214            ctx.sql.push(')');
215        }
216    });
217}
218
219fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
220    render_expr_common_mut!(ctx, expr, '"', render_expr_owned, render_select_body_owned, {
221        Expr::Param(value) => {
222            // NULL is emitted literally; PostgreSQL cannot implicitly resolve a
223            // typed NULL sent as an unknown OID via the binary protocol.
224            if matches!(value, Value::Null) {
225                ctx.sql.push_str("NULL");
226            } else {
227                let cast = postgres_param_cast(value);
228                ctx.take_param(value);
229                if let Some(cast) = cast {
230                    cast.push_sql(&mut ctx.sql);
231                }
232            }
233        }
234        Expr::Binary { left, op, right } => {
235            if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
236                ctx.sql.push('(');
237                render_expr_owned(ctx, left.as_mut());
238                ctx.sql.push(' ');
239                ctx.sql
240                    .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
241                ctx.sql.push_str(" (");
242                if let Expr::List(exprs) = right.as_mut() {
243                    for (i, e) in exprs.iter_mut().enumerate() {
244                        if i > 0 {
245                            ctx.sql.push_str(", ");
246                        }
247                        render_expr_owned(ctx, e);
248                    }
249                } else {
250                    render_expr_owned(ctx, right.as_mut());
251                }
252                ctx.sql.push(')');
253                ctx.sql.push(')');
254            } else {
255                ctx.sql.push('(');
256                render_expr_owned(ctx, left.as_mut());
257                ctx.sql.push(' ');
258                ctx.sql.push_str(match *op {
259                    BinaryOp::ArrayContains => "@>",
260                    BinaryOp::ArrayContainedBy => "<@",
261                    BinaryOp::ArrayOverlaps => "&&",
262                    _ => crate::binary_op_sql(op),
263                });
264                ctx.sql.push(' ');
265                render_expr_owned(ctx, right.as_mut());
266                ctx.sql.push(')');
267            }
268        }
269        Expr::FunctionCall { name, args } => {
270            if args.len() == 2 {
271                let op = match name.as_str() {
272                    nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
273                    nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
274                    nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
275                    _ => None,
276                };
277                if let Some(op) = op {
278                    ctx.sql.push('(');
279                    render_expr_owned(ctx, &mut args[0]);
280                    ctx.sql.push(' ');
281                    ctx.sql.push_str(op);
282                    ctx.sql.push(' ');
283                    render_expr_owned(ctx, &mut args[1]);
284                    ctx.sql.push(')');
285                    return;
286                }
287            }
288            ctx.sql.push_str(name);
289            ctx.sql.push('(');
290            for (i, arg) in args.iter_mut().enumerate() {
291                if i > 0 {
292                    ctx.sql.push_str(", ");
293                }
294                render_expr_owned(ctx, arg);
295            }
296            ctx.sql.push(')');
297        }
298        Expr::Filter { expr, predicate } => {
299            render_expr_owned(ctx, expr.as_mut());
300            ctx.sql.push_str(" FILTER (WHERE ");
301            render_expr_owned(ctx, predicate.as_mut());
302            ctx.sql.push(')');
303        }
304    });
305}
306
/// An explicit PostgreSQL type cast appended after a bound parameter's
/// placeholder, e.g. the `::uuid` in `$1::uuid`.
enum ParamCast {
    /// A cast whose target type name is a compile-time constant.
    Static(&'static str),
    /// A cast to an enum type whose name is carried by the parameter value.
    Enum(String),
}

impl ParamCast {
    /// Appends `::<type name>` for this cast to `sql`.
    fn push_sql(&self, sql: &mut String) {
        let target: &str = match self {
            Self::Static(name) => *name,
            Self::Enum(name) => name.as_str(),
        };
        sql.push_str("::");
        sql.push_str(target);
    }
}
326
327fn postgres_param_cast(value: &Value) -> Option<ParamCast> {
328    match value {
329        Value::Uuid(_) => Some(ParamCast::Static("uuid")),
330        Value::Json(_) => Some(ParamCast::Static("json")),
331        Value::Vector(_) => Some(ParamCast::Static("vector")),
332        Value::Geometry(_) => Some(ParamCast::Static("geometry")),
333        Value::Geography(_) => Some(ParamCast::Static("geography")),
334        value if is_homogeneous_geometry_array(value) => Some(ParamCast::Static("geometry[]")),
335        value if is_homogeneous_geography_array(value) => Some(ParamCast::Static("geography[]")),
336        Value::Enum { type_name, .. } => Some(ParamCast::Enum(type_name.clone())),
337        _ => None,
338    }
339}
340
341fn is_homogeneous_geometry_array(value: &Value) -> bool {
342    matches!(
343        value,
344        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
345    )
346}
347
348fn is_homogeneous_geography_array(value: &Value) -> bool {
349    matches!(
350        value,
351        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
352    )
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    fn quote_identifier(name: &str) -> String {
        let mut sql = String::new();
        crate::push_quoted_identifier(&mut sql, name, '"');
        sql
    }

    /// Builds `SELECT * FROM <table> WHERE (<column> <op> <param>)`.
    fn select_with_binary_filter(
        table: &'static str,
        column: &'static str,
        op: BinaryOp,
        value: Value,
    ) -> Select {
        Select::from_table(table)
            .filter(Expr::Binary {
                left: Box::new(Expr::column(column)),
                op,
                right: Box::new(Expr::param(value)),
            })
            .build()
            .unwrap()
    }

    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("email"), "\"email\"");
        assert_eq!(quote_identifier("foo\"bar"), "\"foo\"\"bar\"");
        assert_eq!(quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
    }

    #[test]
    fn test_array_contains_operator() {
        let dialect = PostgresDialect;
        let expr = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayContains,
            right: Box::new(Expr::param(Value::Array(vec![Value::String(
                "rust".to_string(),
            )]))),
        };
        let select = Select::from_table("posts").filter(expr).build().unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(arr) => {
                assert_eq!(arr.len(), 1);
                assert_eq!(arr[0], Value::String("rust".to_string()));
            }
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_array_contained_by_operator() {
        let dialect = PostgresDialect;
        let expr = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayContainedBy,
            right: Box::new(Expr::param(Value::Array(vec![
                Value::String("rust".to_string()),
                Value::String("go".to_string()),
            ]))),
        };
        let select = Select::from_table("posts").filter(expr).build().unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(arr) => {
                assert_eq!(arr.len(), 2);
                assert_eq!(arr[0], Value::String("rust".to_string()));
                assert_eq!(arr[1], Value::String("go".to_string()));
            }
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_array_overlaps_operator() {
        let dialect = PostgresDialect;
        let expr = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(Expr::param(Value::Array(vec![
                Value::String("rust".to_string()),
                Value::String("python".to_string()),
            ]))),
        };
        let select = Select::from_table("posts").filter(expr).build().unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(arr) => {
                assert_eq!(arr.len(), 2);
                assert_eq!(arr[0], Value::String("rust".to_string()));
                assert_eq!(arr[1], Value::String("python".to_string()));
            }
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_array_operators_with_integers() {
        let dialect = PostgresDialect;
        let expr = Expr::Binary {
            left: Box::new(Expr::column("posts__scores")),
            op: BinaryOp::ArrayContains,
            right: Box::new(Expr::param(Value::Array(vec![
                Value::I32(100),
                Value::I32(200),
            ]))),
        };
        let select = Select::from_table("posts").filter(expr).build().unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(arr) => {
                assert_eq!(arr.len(), 2);
                assert_eq!(arr[0], Value::I32(100));
                assert_eq!(arr[1], Value::I32(200));
            }
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn vector_params_are_cast_to_pgvector_type() {
        let dialect = PostgresDialect;
        let select = Select::from_table("embeddings")
            .filter(
                Expr::column("embeddings__vector")
                    .eq(Expr::param(Value::Vector(vec![1.0, 2.0, 3.0]))),
            )
            .build()
            .unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
        );
        assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
    }

    #[test]
    fn postgis_params_are_cast_to_spatial_types() {
        let dialect = PostgresDialect;
        let select = Select::from_table("places")
            .filter(
                Expr::column("places__geom")
                    .eq(Expr::param(Value::Geometry("POINT(1 2)".to_string()))),
            )
            .build()
            .unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)"
        );
        assert_eq!(sql.params, vec![Value::Geometry("POINT(1 2)".to_string())]);

        let select = Select::from_table("places")
            .filter(
                Expr::column("places__geog")
                    .eq(Expr::param(Value::Geography("POINT(1 2)".to_string()))),
            )
            .build()
            .unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)"
        );
        assert_eq!(sql.params, vec![Value::Geography("POINT(1 2)".to_string())]);
    }

    #[test]
    fn vector_distance_ordering_uses_pgvector_operator() {
        let dialect = PostgresDialect;
        let select = Select::from_table("embeddings")
            .order_by_expr(
                Expr::vector_distance(
                    nautilus_core::VectorMetric::Cosine,
                    Expr::column("embeddings__vector"),
                    Expr::param(Value::Vector(vec![1.0, 2.0, 3.0])),
                ),
                nautilus_core::OrderDir::Asc,
            )
            .take(5)
            .build()
            .unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
        );
        assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
    }

    #[test]
    fn null_params_are_rendered_literally_and_not_bound() {
        let select = select_with_binary_filter(
            "posts",
            "posts__tags",
            BinaryOp::ArrayOverlaps,
            Value::Null,
        );
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && NULL)"
        );
        assert!(sql.params.is_empty());
    }

    #[test]
    fn homogeneous_geometry_arrays_are_cast_to_geometry_array() {
        let value = Value::Array(vec![
            Value::Geometry("POINT(1 2)".to_string()),
            Value::Geometry("POINT(3 4)".to_string()),
        ]);
        let select = select_with_binary_filter(
            "places",
            "places__geoms",
            BinaryOp::ArrayOverlaps,
            value.clone(),
        );
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"places\" WHERE (\"places\".\"geoms\" && $1::geometry[])"
        );
        assert_eq!(sql.params, vec![value]);
    }

    #[test]
    fn empty_or_mixed_spatial_arrays_are_not_cast() {
        for value in [
            Value::Array(vec![]),
            Value::Array(vec![
                Value::Geometry("POINT(1 2)".to_string()),
                Value::Geography("POINT(1 2)".to_string()),
            ]),
        ] {
            let select = select_with_binary_filter(
                "places",
                "places__geoms",
                BinaryOp::ArrayContains,
                value,
            );
            let sql = PostgresDialect.render_select(&select).unwrap();

            assert_eq!(
                sql.text,
                "SELECT * FROM \"places\" WHERE (\"places\".\"geoms\" @> $1)"
            );
        }
    }

    #[test]
    fn owned_rendering_matches_borrowed_rendering() {
        fn filtered() -> Select {
            Select::from_table("posts")
                .filter(Expr::Binary {
                    left: Box::new(Expr::column("posts__tags")),
                    op: BinaryOp::ArrayContains,
                    right: Box::new(Expr::param(Value::Array(vec![Value::String(
                        "rust".to_string(),
                    )]))),
                })
                .build()
                .unwrap()
        }

        fn ordered() -> Select {
            Select::from_table("embeddings")
                .order_by_expr(
                    Expr::vector_distance(
                        nautilus_core::VectorMetric::Cosine,
                        Expr::column("embeddings__vector"),
                        Expr::param(Value::Vector(vec![1.0, 2.0, 3.0])),
                    ),
                    nautilus_core::OrderDir::Asc,
                )
                .take(5)
                .build()
                .unwrap()
        }

        for build in [filtered as fn() -> Select, ordered] {
            let borrowed = PostgresDialect.render_select(&build()).unwrap();
            let owned = PostgresDialect.render_select_owned(build()).unwrap();
            assert_eq!(owned.text, borrowed.text);
            assert_eq!(owned.params, borrowed.params);
        }
    }
}