Skip to main content

nautilus_dialect/
postgres.rs

1//! PostgreSQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// [`Dialect`] implementation targeting PostgreSQL.
///
/// Identifiers are wrapped in double quotes and bound values are emitted as
/// numbered `$1, $2, ...` placeholders. PostgreSQL-specific output includes
/// `RETURNING`, `DISTINCT ON`, the array operators (`@>`, `<@`, `&&`),
/// explicit `::type` casts on parameters (UUID, JSON, vector, spatial, enum
/// and composite values), and `FILTER (WHERE ...)` on aggregates.
#[derive(Clone, Copy, Debug)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15    fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
16        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
17        render_select_body_core_mut!(&mut ctx, &mut select, '"', render_expr_owned, true, false);
18        Ok(Sql {
19            text: ctx.sql,
20            params: ctx.params,
21        })
22    }
23
24    fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
25        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
26        render_insert_body_mut!(&mut ctx, &mut insert, '"', true, true);
27        Ok(Sql {
28            text: ctx.sql,
29            params: ctx.params,
30        })
31    }
32
33    fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
34        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
35        render_update_body_mut!(&mut ctx, &mut update, '"', render_expr_owned, true, true);
36        Ok(Sql {
37            text: ctx.sql,
38            params: ctx.params,
39        })
40    }
41
42    fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
43        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
44        render_delete_body_mut!(&mut ctx, &mut delete, '"', render_expr_owned, true);
45        Ok(Sql {
46            text: ctx.sql,
47            params: ctx.params,
48        })
49    }
50}
51
52struct RenderContext {
53    sql: String,
54    params: Vec<Value>,
55}
56
57impl RenderContext {
58    fn with_estimate(estimate: crate::RenderEstimate) -> Self {
59        Self {
60            sql: String::with_capacity(estimate.sql_capacity),
61            params: Vec::with_capacity(estimate.params_capacity),
62        }
63    }
64
65    fn push_param(&mut self, value: Value) {
66        self.params.push(value);
67        self.sql.push('$');
68        crate::push_usize(&mut self.sql, self.params.len());
69    }
70
71    fn take_param(&mut self, value: &mut Value) {
72        self.push_param(std::mem::replace(value, Value::Null));
73    }
74}
75
76fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
77    render_select_body_core_mut!(ctx, select, '"', render_expr_owned, true, false);
78}
79
80fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
81    render_expr_common_mut!(ctx, expr, '"', render_expr_owned, render_select_body_owned, {
82        Expr::Param(value) => {
83            // NULL is emitted literally; PostgreSQL cannot implicitly resolve a
84            // typed NULL sent as an unknown OID via the binary protocol.
85            if matches!(value, Value::Null) {
86                ctx.sql.push_str("NULL");
87            } else {
88                let cast = postgres_param_cast(value);
89                ctx.take_param(value);
90                if let Some(cast) = cast {
91                    cast.push_sql(&mut ctx.sql);
92                }
93            }
94        }
95        Expr::Binary { left, op, right } => {
96            if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
97                ctx.sql.push('(');
98                render_expr_owned(ctx, left.as_mut());
99                ctx.sql.push(' ');
100                ctx.sql
101                    .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
102                ctx.sql.push_str(" (");
103                if let Expr::List(exprs) = right.as_mut() {
104                    for (i, e) in exprs.iter_mut().enumerate() {
105                        if i > 0 {
106                            ctx.sql.push_str(", ");
107                        }
108                        render_expr_owned(ctx, e);
109                    }
110                } else {
111                    render_expr_owned(ctx, right.as_mut());
112                }
113                ctx.sql.push(')');
114                ctx.sql.push(')');
115            } else {
116                ctx.sql.push('(');
117                render_expr_owned(ctx, left.as_mut());
118                ctx.sql.push(' ');
119                ctx.sql.push_str(match *op {
120                    BinaryOp::ArrayContains => "@>",
121                    BinaryOp::ArrayContainedBy => "<@",
122                    BinaryOp::ArrayOverlaps => "&&",
123                    _ => crate::binary_op_sql(op),
124                });
125                ctx.sql.push(' ');
126                render_expr_owned(ctx, right.as_mut());
127                ctx.sql.push(')');
128            }
129        }
130        Expr::FunctionCall { name, args } => {
131            if args.len() == 2 {
132                let op = match name.as_str() {
133                    nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
134                    nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
135                    nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
136                    _ => None,
137                };
138                if let Some(op) = op {
139                    ctx.sql.push('(');
140                    render_expr_owned(ctx, &mut args[0]);
141                    ctx.sql.push(' ');
142                    ctx.sql.push_str(op);
143                    ctx.sql.push(' ');
144                    render_expr_owned(ctx, &mut args[1]);
145                    ctx.sql.push(')');
146                    return;
147                }
148            }
149            ctx.sql.push_str(name);
150            ctx.sql.push('(');
151            for (i, arg) in args.iter_mut().enumerate() {
152                if i > 0 {
153                    ctx.sql.push_str(", ");
154                }
155                render_expr_owned(ctx, arg);
156            }
157            ctx.sql.push(')');
158        }
159        Expr::Filter { expr, predicate } => {
160            render_expr_owned(ctx, expr.as_mut());
161            ctx.sql.push_str(" FILTER (WHERE ");
162            render_expr_owned(ctx, predicate.as_mut());
163            ctx.sql.push(')');
164        }
165    });
166}
167
/// Explicit cast appended after a `$n` placeholder so PostgreSQL binds the
/// parameter with the intended type.
enum ParamCast {
    /// A fixed type name emitted verbatim, e.g. `uuid` or `geometry[]`.
    Static(&'static str),
    /// The type name carried by a `Value::Enum`, emitted as a quoted identifier.
    Enum(String),
    /// The type name carried by a `Value::Composite`, emitted as a quoted identifier.
    Composite(String),
}
173
174impl ParamCast {
175    fn push_sql(&self, sql: &mut String) {
176        match self {
177            Self::Static(name) => {
178                sql.push_str("::");
179                sql.push_str(name);
180            }
181            Self::Enum(type_name) | Self::Composite(type_name) => {
182                sql.push_str("::");
183                crate::push_quoted_identifier(sql, type_name, '"');
184            }
185        }
186    }
187}
188
189fn postgres_param_cast(value: &Value) -> Option<ParamCast> {
190    match value {
191        Value::Uuid(_) => Some(ParamCast::Static("uuid")),
192        Value::Json(_) => Some(ParamCast::Static("json")),
193        Value::Vector(_) => Some(ParamCast::Static("vector")),
194        Value::Geometry(_) => Some(ParamCast::Static("geometry")),
195        Value::Geography(_) => Some(ParamCast::Static("geography")),
196        value if is_homogeneous_geometry_array(value) => Some(ParamCast::Static("geometry[]")),
197        value if is_homogeneous_geography_array(value) => Some(ParamCast::Static("geography[]")),
198        Value::Enum { type_name, .. } => Some(ParamCast::Enum(type_name.clone())),
199        Value::Composite { type_name, .. } => Some(ParamCast::Composite(type_name.clone())),
200        _ => None,
201    }
202}
203
204fn is_homogeneous_geometry_array(value: &Value) -> bool {
205    matches!(
206        value,
207        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
208    )
209}
210
211fn is_homogeneous_geography_array(value: &Value) -> bool {
212    matches!(
213        value,
214        Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
215    )
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Quotes `name` with the dialect's identifier rules: wrap it in double
    /// quotes and double any embedded double quote.
    fn quote_identifier(name: &str) -> String {
        let mut quoted = String::new();
        crate::push_quoted_identifier(&mut quoted, name, '"');
        quoted
    }

    /// Renders `SELECT * FROM <table> WHERE <filter>` with `PostgresDialect`.
    fn render_filtered(table: &'static str, filter: Expr) -> Sql {
        let select = Select::from_table(table).filter(filter).build().unwrap();
        PostgresDialect.render_select(&select).unwrap()
    }

    /// Builds `<column> <op> <array parameter>` from the given array elements.
    fn array_predicate(column: &'static str, op: BinaryOp, items: Vec<Value>) -> Expr {
        Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(items))),
        }
    }

    /// Asserts that `sql` binds exactly one parameter and that it is an array
    /// equal to `expected`.
    fn assert_single_array_param(sql: &Sql, expected: &[Value]) {
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(items) => assert_eq!(items.as_slice(), expected),
            _ => panic!("Expected Array value"),
        }
    }

    /// A `ChampionStatsT` composite value with two zeroed integer fields.
    fn champion_stats() -> Value {
        Value::Composite {
            type_name: "ChampionStatsT".to_string(),
            fields: vec![Value::I32(0), Value::I32(0)],
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "\"users\""),
            ("email", "\"email\""),
            ("foo\"bar", "\"foo\"\"bar\""),
            ("a\"b\"c", "\"a\"\"b\"\"c\""),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(quote_identifier(raw), expected);
        }
    }

    #[test]
    fn test_array_contains_operator() {
        let tags = vec![Value::String("rust".to_string())];
        let sql = render_filtered(
            "posts",
            array_predicate("posts__tags", BinaryOp::ArrayContains, tags.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("go".to_string()),
        ];
        let sql = render_filtered(
            "posts",
            array_predicate("posts__tags", BinaryOp::ArrayContainedBy, tags.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("python".to_string()),
        ];
        let sql = render_filtered(
            "posts",
            array_predicate("posts__tags", BinaryOp::ArrayOverlaps, tags.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let scores = vec![Value::I32(100), Value::I32(200)];
        let sql = render_filtered(
            "posts",
            array_predicate("posts__scores", BinaryOp::ArrayContains, scores.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_single_array_param(&sql, &scores);
    }

    #[test]
    fn vector_params_are_cast_to_pgvector_type() {
        let embedding = Value::Vector(vec![1.0, 2.0, 3.0]);
        let sql = render_filtered(
            "embeddings",
            Expr::column("embeddings__vector").eq(Expr::param(embedding.clone())),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
        );
        assert_eq!(sql.params, vec![embedding]);
    }

    #[test]
    fn postgis_params_are_cast_to_spatial_types() {
        let cases = [
            (
                "places__geom",
                Value::Geometry("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)",
            ),
            (
                "places__geog",
                Value::Geography("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)",
            ),
        ];
        for (column, value, expected_sql) in cases.iter() {
            let sql = render_filtered(
                "places",
                Expr::column(*column).eq(Expr::param(value.clone())),
            );
            assert_eq!(sql.text, *expected_sql);
            assert_eq!(sql.params, vec![value.clone()]);
        }
    }

    #[test]
    fn composite_params_are_cast_to_their_type_name() {
        let composite = champion_stats();
        let sql = render_filtered(
            "champions",
            Expr::column("champions__stats").eq(Expr::param(composite.clone())),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"champions\" WHERE (\"champions\".\"stats\" = $1::\"ChampionStatsT\")"
        );
        assert_eq!(sql.params, vec![composite]);
    }

    #[test]
    fn composite_insert_and_update_params_are_cast_to_their_type_name() {
        let dialect = PostgresDialect;
        let composite = champion_stats();

        let insert = Insert::into_table("champions")
            .column(nautilus_core::ColumnMarker::new("champions", "stats"))
            .values(vec![composite.clone()])
            .build()
            .unwrap();
        let insert_sql = dialect.render_insert(&insert).unwrap();

        assert_eq!(
            insert_sql.text,
            "INSERT INTO \"champions\" (\"stats\") VALUES ($1::\"ChampionStatsT\")"
        );
        assert_eq!(insert_sql.params, vec![composite.clone()]);

        let update = Update::table("champions")
            .set(
                nautilus_core::ColumnMarker::new("champions", "stats"),
                composite.clone(),
            )
            .build()
            .unwrap();
        let update_sql = dialect.render_update(&update).unwrap();

        assert_eq!(
            update_sql.text,
            "UPDATE \"champions\" SET \"stats\" = $1::\"ChampionStatsT\""
        );
        assert_eq!(update_sql.params, vec![composite]);
    }

    #[test]
    fn vector_distance_ordering_uses_pgvector_operator() {
        let query_vector = Value::Vector(vec![1.0, 2.0, 3.0]);
        let distance = Expr::vector_distance(
            nautilus_core::VectorMetric::Cosine,
            Expr::column("embeddings__vector"),
            Expr::param(query_vector.clone()),
        );
        let select = Select::from_table("embeddings")
            .order_by_expr(distance, nautilus_core::OrderDir::Asc)
            .take(5)
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
        );
        assert_eq!(sql.params, vec![query_vector]);
    }
}