Skip to main content

nautilus_dialect/
postgres.rs

1//! PostgreSQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// SQL renderer for the PostgreSQL dialect.
///
/// Identifiers are wrapped in double quotes and bound values are emitted as
/// numbered placeholders (`$1`, `$2`, ...). PostgreSQL-specific output includes
/// `RETURNING`, `DISTINCT ON`, the array operators `@>` / `<@` / `&&`, explicit
/// casts on UUID (and other typed) parameters, and `FILTER (WHERE ...)` on
/// aggregates.
#[derive(Clone, Copy, Debug)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15    fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
16        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
17        render_select_body_core_mut!(&mut ctx, &mut select, '"', render_expr_owned, true, false);
18        Ok(Sql {
19            text: ctx.sql,
20            params: ctx.params,
21        })
22    }
23
24    fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
25        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
26        render_insert_body_mut!(&mut ctx, &mut insert, '"', true, true);
27        Ok(Sql {
28            text: ctx.sql,
29            params: ctx.params,
30        })
31    }
32
33    fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
34        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
35        render_update_body_mut!(&mut ctx, &mut update, '"', render_expr_owned, true, true);
36        Ok(Sql {
37            text: ctx.sql,
38            params: ctx.params,
39        })
40    }
41
42    fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
43        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
44        render_delete_body_mut!(&mut ctx, &mut delete, '"', render_expr_owned, true);
45        Ok(Sql {
46            text: ctx.sql,
47            params: ctx.params,
48        })
49    }
50}
51
52struct RenderContext {
53    sql: String,
54    params: Vec<Value>,
55}
56
57impl RenderContext {
58    fn with_estimate(estimate: crate::RenderEstimate) -> Self {
59        Self {
60            sql: String::with_capacity(estimate.sql_capacity),
61            params: Vec::with_capacity(estimate.params_capacity),
62        }
63    }
64
65    fn push_param(&mut self, value: Value) {
66        self.params.push(value);
67        self.sql.push('$');
68        crate::push_usize(&mut self.sql, self.params.len());
69    }
70
71    fn take_param(&mut self, value: &mut Value) {
72        self.push_param(std::mem::replace(value, Value::Null));
73    }
74}
75
76fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
77    render_select_body_core_mut!(ctx, select, '"', render_expr_owned, true, false);
78}
79
80fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
81    render_expr_common_mut!(ctx, expr, '"', render_expr_owned, render_select_body_owned, {
82        Expr::CompositeField {
83            table,
84            column,
85            field,
86            ..
87        } => {
88            crate::push_composite_field_reference(&mut ctx.sql, table, column, field, '"');
89        }
90        Expr::Param(value) => {
91            // NULL is emitted literally; PostgreSQL cannot implicitly resolve a
92            // typed NULL sent as an unknown OID via the binary protocol.
93            if matches!(value, Value::Null) {
94                ctx.sql.push_str("NULL");
95            } else {
96                let cast = postgres_param_cast(value);
97                ctx.take_param(value);
98                if let Some(cast) = cast {
99                    cast.push_sql(&mut ctx.sql);
100                }
101            }
102        }
103        Expr::Binary { left, op, right } => {
104            if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
105                ctx.sql.push('(');
106                render_expr_owned(ctx, left.as_mut());
107                ctx.sql.push(' ');
108                ctx.sql
109                    .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
110                ctx.sql.push_str(" (");
111                if let Expr::List(exprs) = right.as_mut() {
112                    for (i, e) in exprs.iter_mut().enumerate() {
113                        if i > 0 {
114                            ctx.sql.push_str(", ");
115                        }
116                        render_expr_owned(ctx, e);
117                    }
118                } else {
119                    render_expr_owned(ctx, right.as_mut());
120                }
121                ctx.sql.push(')');
122                ctx.sql.push(')');
123            } else {
124                ctx.sql.push('(');
125                render_expr_owned(ctx, left.as_mut());
126                ctx.sql.push(' ');
127                ctx.sql.push_str(match *op {
128                    BinaryOp::ArrayContains => "@>",
129                    BinaryOp::ArrayContainedBy => "<@",
130                    BinaryOp::ArrayOverlaps => "&&",
131                    _ => crate::binary_op_sql(op),
132                });
133                ctx.sql.push(' ');
134                render_expr_owned(ctx, right.as_mut());
135                ctx.sql.push(')');
136            }
137        }
138        Expr::FunctionCall { name, args } => {
139            if args.len() == 2 {
140                let op = match name.as_str() {
141                    nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
142                    nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
143                    nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
144                    _ => None,
145                };
146                if let Some(op) = op {
147                    ctx.sql.push('(');
148                    render_expr_owned(ctx, &mut args[0]);
149                    ctx.sql.push(' ');
150                    ctx.sql.push_str(op);
151                    ctx.sql.push(' ');
152                    render_expr_owned(ctx, &mut args[1]);
153                    ctx.sql.push(')');
154                    return;
155                }
156            }
157            ctx.sql.push_str(name);
158            ctx.sql.push('(');
159            for (i, arg) in args.iter_mut().enumerate() {
160                if i > 0 {
161                    ctx.sql.push_str(", ");
162                }
163                render_expr_owned(ctx, arg);
164            }
165            ctx.sql.push(')');
166        }
167        Expr::Filter { expr, predicate } => {
168            render_expr_owned(ctx, expr.as_mut());
169            ctx.sql.push_str(" FILTER (WHERE ");
170            render_expr_owned(ctx, predicate.as_mut());
171            ctx.sql.push(')');
172        }
173    });
174}
175
/// Type annotation appended after a bound parameter's placeholder (`$n::type`).
enum ParamCast {
    /// Built-in type name written verbatim, e.g. `uuid` or `geometry[]`.
    Static(&'static str),
    /// User-defined enum type; its name is written as a quoted identifier.
    Enum(String),
    /// User-defined composite type; its name is written as a quoted identifier.
    Composite(String),
}
181
182impl ParamCast {
183    fn push_sql(&self, sql: &mut String) {
184        match self {
185            Self::Static(name) => {
186                sql.push_str("::");
187                sql.push_str(name);
188            }
189            Self::Enum(type_name) | Self::Composite(type_name) => {
190                sql.push_str("::");
191                crate::push_quoted_identifier(sql, type_name, '"');
192            }
193        }
194    }
195}
196
197fn postgres_param_cast(value: &Value) -> Option<ParamCast> {
198    match value {
199        Value::Uuid(_) => Some(ParamCast::Static("uuid")),
200        Value::Json(_) => Some(ParamCast::Static("json")),
201        Value::Vector(_) => Some(ParamCast::Static("vector")),
202        Value::Geometry(_) => Some(ParamCast::Static("geometry")),
203        Value::Geography(_) => Some(ParamCast::Static("geography")),
204        value if is_homogeneous_geometry_array(value) => Some(ParamCast::Static("geometry[]")),
205        value if is_homogeneous_geography_array(value) => Some(ParamCast::Static("geography[]")),
206        Value::Enum { type_name, .. } => Some(ParamCast::Enum(type_name.clone())),
207        Value::Composite { type_name, .. } => Some(ParamCast::Composite(type_name.clone())),
208        _ => None,
209    }
210}
211
212fn is_homogeneous_geometry_array(value: &Value) -> bool {
213    matches!(
214        value,
215        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
216    )
217}
218
219fn is_homogeneous_geography_array(value: &Value) -> bool {
220    matches!(
221        value,
222        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
223    )
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Quotes `name` exactly as the PostgreSQL renderer quotes identifiers.
    fn quote_identifier(name: &str) -> String {
        let mut quoted = String::new();
        crate::push_quoted_identifier(&mut quoted, name, '"');
        quoted
    }

    /// Renders `SELECT * FROM <table> WHERE <filter>` with the PostgreSQL dialect.
    fn render_filtered(table: &'static str, filter: Expr) -> Sql {
        let select = Select::from_table(table).filter(filter).build().unwrap();
        PostgresDialect.render_select(&select).unwrap()
    }

    /// Builds `<column> <op> $param`, binding `items` as one array parameter.
    fn array_predicate(column: &'static str, op: BinaryOp, items: Vec<Value>) -> Expr {
        Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(items))),
        }
    }

    /// Wraps each string slice in a `Value::String`.
    fn strings(items: &[&str]) -> Vec<Value> {
        items
            .iter()
            .map(|item| Value::String((*item).to_string()))
            .collect()
    }

    /// The composite value shared by the composite-cast tests.
    fn champion_stats() -> Value {
        Value::Composite {
            type_name: "ChampionStatsT".to_string(),
            fields: vec![Value::I32(0), Value::I32(0)],
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "\"users\""),
            ("email", "\"email\""),
            ("foo\"bar", "\"foo\"\"bar\""),
            ("a\"b\"c", "\"a\"\"b\"\"c\""),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(quote_identifier(raw), expected);
        }
    }

    #[test]
    fn test_array_contains_operator() {
        let filter = array_predicate("posts__tags", BinaryOp::ArrayContains, strings(&["rust"]));
        let sql = render_filtered("posts", filter);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(strings(&["rust"]))]);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let filter = array_predicate(
            "posts__tags",
            BinaryOp::ArrayContainedBy,
            strings(&["rust", "go"]),
        );
        let sql = render_filtered("posts", filter);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(strings(&["rust", "go"]))]);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let filter = array_predicate(
            "posts__tags",
            BinaryOp::ArrayOverlaps,
            strings(&["rust", "python"]),
        );
        let sql = render_filtered("posts", filter);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(strings(&["rust", "python"]))]);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let scores = vec![Value::I32(100), Value::I32(200)];
        let filter = array_predicate("posts__scores", BinaryOp::ArrayContains, scores.clone());
        let sql = render_filtered("posts", filter);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(scores)]);
    }

    #[test]
    fn vector_params_are_cast_to_pgvector_type() {
        let vector = Value::Vector(vec![1.0, 2.0, 3.0]);
        let sql = render_filtered(
            "embeddings",
            Expr::column("embeddings__vector").eq(Expr::param(vector.clone())),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
        );
        assert_eq!(sql.params, vec![vector]);
    }

    #[test]
    fn postgis_params_are_cast_to_spatial_types() {
        let geometry = Value::Geometry("POINT(1 2)".to_string());
        let sql = render_filtered(
            "places",
            Expr::column("places__geom").eq(Expr::param(geometry.clone())),
        );
        assert_eq!(
            sql.text,
            "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)"
        );
        assert_eq!(sql.params, vec![geometry]);

        let geography = Value::Geography("POINT(1 2)".to_string());
        let sql = render_filtered(
            "places",
            Expr::column("places__geog").eq(Expr::param(geography.clone())),
        );
        assert_eq!(
            sql.text,
            "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)"
        );
        assert_eq!(sql.params, vec![geography]);
    }

    #[test]
    fn composite_params_are_cast_to_their_type_name() {
        let composite = champion_stats();
        let sql = render_filtered(
            "champions",
            Expr::column("champions__stats").eq(Expr::param(composite.clone())),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"champions\" WHERE (\"champions\".\"stats\" = $1::\"ChampionStatsT\")"
        );
        assert_eq!(sql.params, vec![composite]);
    }

    #[test]
    fn composite_insert_and_update_params_are_cast_to_their_type_name() {
        let composite = champion_stats();

        let insert = Insert::into_table("champions")
            .column(nautilus_core::ColumnMarker::new("champions", "stats"))
            .values(vec![composite.clone()])
            .build()
            .unwrap();
        let inserted = PostgresDialect.render_insert(&insert).unwrap();
        assert_eq!(
            inserted.text,
            "INSERT INTO \"champions\" (\"stats\") VALUES ($1::\"ChampionStatsT\")"
        );
        assert_eq!(inserted.params, vec![composite.clone()]);

        let update = Update::table("champions")
            .set(
                nautilus_core::ColumnMarker::new("champions", "stats"),
                composite.clone(),
            )
            .build()
            .unwrap();
        let updated = PostgresDialect.render_update(&update).unwrap();
        assert_eq!(
            updated.text,
            "UPDATE \"champions\" SET \"stats\" = $1::\"ChampionStatsT\""
        );
        assert_eq!(updated.params, vec![composite]);
    }

    #[test]
    fn composite_field_ordering_uses_native_attribute_syntax() {
        let eta = Expr::composite_field(
            "shipments",
            "delivery_snapshot",
            "eta_minutes",
            "etaMinutes",
            nautilus_core::JsonPathCast::Signed,
        );
        let select = Select::from_table("shipments")
            .order_by_expr(eta, nautilus_core::OrderDir::Asc)
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"shipments\" ORDER BY (\"shipments\".\"delivery_snapshot\").\"eta_minutes\" ASC"
        );
    }

    #[test]
    fn vector_distance_ordering_uses_pgvector_operator() {
        let distance = Expr::vector_distance(
            nautilus_core::VectorMetric::Cosine,
            Expr::column("embeddings__vector"),
            Expr::param(Value::Vector(vec![1.0, 2.0, 3.0])),
        );
        let select = Select::from_table("embeddings")
            .order_by_expr(distance, nautilus_core::OrderDir::Asc)
            .take(5)
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
        );
        assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
    }
}