1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// Stateless marker type selecting the PostgreSQL SQL renderer.
///
/// The type has no fields, so copying it is free; all rendering behaviour
/// comes from its `Dialect` implementation.
#[derive(Clone, Copy, Debug)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15 fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
16 let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
17 render_select_body_core_mut!(&mut ctx, &mut select, '"', render_expr_owned, true, false);
18 Ok(Sql {
19 text: ctx.sql,
20 params: ctx.params,
21 })
22 }
23
24 fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
25 let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
26 render_insert_body_mut!(&mut ctx, &mut insert, '"', true, true);
27 Ok(Sql {
28 text: ctx.sql,
29 params: ctx.params,
30 })
31 }
32
33 fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
34 let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
35 render_update_body_mut!(&mut ctx, &mut update, '"', render_expr_owned, true, true);
36 Ok(Sql {
37 text: ctx.sql,
38 params: ctx.params,
39 })
40 }
41
42 fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
43 let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
44 render_delete_body_mut!(&mut ctx, &mut delete, '"', render_expr_owned, true);
45 Ok(Sql {
46 text: ctx.sql,
47 params: ctx.params,
48 })
49 }
50}
51
52struct RenderContext {
53 sql: String,
54 params: Vec<Value>,
55}
56
57impl RenderContext {
58 fn with_estimate(estimate: crate::RenderEstimate) -> Self {
59 Self {
60 sql: String::with_capacity(estimate.sql_capacity),
61 params: Vec::with_capacity(estimate.params_capacity),
62 }
63 }
64
65 fn push_param(&mut self, value: Value) {
66 self.params.push(value);
67 self.sql.push('$');
68 crate::push_usize(&mut self.sql, self.params.len());
69 }
70
71 fn take_param(&mut self, value: &mut Value) {
72 self.push_param(std::mem::replace(value, Value::Null));
73 }
74}
75
76fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
77 render_select_body_core_mut!(ctx, select, '"', render_expr_owned, true, false);
78}
79
80fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
81 render_expr_common_mut!(ctx, expr, '"', render_expr_owned, render_select_body_owned, {
82 Expr::CompositeField {
83 table,
84 column,
85 field,
86 ..
87 } => {
88 crate::push_composite_field_reference(&mut ctx.sql, table, column, field, '"');
89 }
90 Expr::Param(value) => {
91 if matches!(value, Value::Null) {
94 ctx.sql.push_str("NULL");
95 } else {
96 let cast = postgres_param_cast(value);
97 ctx.take_param(value);
98 if let Some(cast) = cast {
99 cast.push_sql(&mut ctx.sql);
100 }
101 }
102 }
103 Expr::Binary { left, op, right } => {
104 if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
105 ctx.sql.push('(');
106 render_expr_owned(ctx, left.as_mut());
107 ctx.sql.push(' ');
108 ctx.sql
109 .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
110 ctx.sql.push_str(" (");
111 if let Expr::List(exprs) = right.as_mut() {
112 for (i, e) in exprs.iter_mut().enumerate() {
113 if i > 0 {
114 ctx.sql.push_str(", ");
115 }
116 render_expr_owned(ctx, e);
117 }
118 } else {
119 render_expr_owned(ctx, right.as_mut());
120 }
121 ctx.sql.push(')');
122 ctx.sql.push(')');
123 } else {
124 ctx.sql.push('(');
125 render_expr_owned(ctx, left.as_mut());
126 ctx.sql.push(' ');
127 ctx.sql.push_str(match *op {
128 BinaryOp::ArrayContains => "@>",
129 BinaryOp::ArrayContainedBy => "<@",
130 BinaryOp::ArrayOverlaps => "&&",
131 _ => crate::binary_op_sql(op),
132 });
133 ctx.sql.push(' ');
134 render_expr_owned(ctx, right.as_mut());
135 ctx.sql.push(')');
136 }
137 }
138 Expr::FunctionCall { name, args } => {
139 if args.len() == 2 {
140 let op = match name.as_str() {
141 nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
142 nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
143 nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
144 _ => None,
145 };
146 if let Some(op) = op {
147 ctx.sql.push('(');
148 render_expr_owned(ctx, &mut args[0]);
149 ctx.sql.push(' ');
150 ctx.sql.push_str(op);
151 ctx.sql.push(' ');
152 render_expr_owned(ctx, &mut args[1]);
153 ctx.sql.push(')');
154 return;
155 }
156 }
157 ctx.sql.push_str(name);
158 ctx.sql.push('(');
159 for (i, arg) in args.iter_mut().enumerate() {
160 if i > 0 {
161 ctx.sql.push_str(", ");
162 }
163 render_expr_owned(ctx, arg);
164 }
165 ctx.sql.push(')');
166 }
167 Expr::Filter { expr, predicate } => {
168 render_expr_owned(ctx, expr.as_mut());
169 ctx.sql.push_str(" FILTER (WHERE ");
170 render_expr_owned(ctx, predicate.as_mut());
171 ctx.sql.push(')');
172 }
173 });
174}
175
176enum ParamCast {
177 Static(&'static str),
178 Enum(String),
179 Composite(String),
180}
181
182impl ParamCast {
183 fn push_sql(&self, sql: &mut String) {
184 match self {
185 Self::Static(name) => {
186 sql.push_str("::");
187 sql.push_str(name);
188 }
189 Self::Enum(type_name) | Self::Composite(type_name) => {
190 sql.push_str("::");
191 crate::push_quoted_identifier(sql, type_name, '"');
192 }
193 }
194 }
195}
196
197fn postgres_param_cast(value: &Value) -> Option<ParamCast> {
198 match value {
199 Value::Uuid(_) => Some(ParamCast::Static("uuid")),
200 Value::Json(_) => Some(ParamCast::Static("json")),
201 Value::Vector(_) => Some(ParamCast::Static("vector")),
202 Value::Geometry(_) => Some(ParamCast::Static("geometry")),
203 Value::Geography(_) => Some(ParamCast::Static("geography")),
204 value if is_homogeneous_geometry_array(value) => Some(ParamCast::Static("geometry[]")),
205 value if is_homogeneous_geography_array(value) => Some(ParamCast::Static("geography[]")),
206 Value::Enum { type_name, .. } => Some(ParamCast::Enum(type_name.clone())),
207 Value::Composite { type_name, .. } => Some(ParamCast::Composite(type_name.clone())),
208 _ => None,
209 }
210}
211
212fn is_homogeneous_geometry_array(value: &Value) -> bool {
213 matches!(
214 value,
215 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
216 )
217}
218
219fn is_homogeneous_geography_array(value: &Value) -> bool {
220 matches!(
221 value,
222 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
223 )
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Quotes `name` with the crate's identifier helper, using `"` as the quote character.
    fn quote_identifier(name: &str) -> String {
        let mut quoted = String::new();
        crate::push_quoted_identifier(&mut quoted, name, '"');
        quoted
    }

    /// Builds a `posts` query filtered by `<column> <op> <array parameter>` and renders it.
    fn render_posts_array_filter(column: &'static str, op: BinaryOp, items: Vec<Value>) -> Sql {
        let predicate = Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(items))),
        };
        let select = Select::from_table("posts").filter(predicate).build().unwrap();
        PostgresDialect.render_select(&select).unwrap()
    }

    /// Asserts that `sql` binds exactly one parameter and that it is an array
    /// holding `expected`, element by element.
    fn assert_single_array_param(sql: &Sql, expected: &[Value]) {
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(actual) => {
                assert_eq!(actual.len(), expected.len());
                for (got, want) in actual.iter().zip(expected) {
                    assert_eq!(got, want);
                }
            }
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = vec![
            ("users", "\"users\""),
            ("email", "\"email\""),
            ("foo\"bar", "\"foo\"\"bar\""),
            ("a\"b\"c", "\"a\"\"b\"\"c\""),
        ];
        for (raw, quoted) in cases {
            assert_eq!(quote_identifier(raw), quoted);
        }
    }

    #[test]
    fn test_array_contains_operator() {
        let tags = vec![Value::String("rust".to_string())];
        let sql = render_posts_array_filter("posts__tags", BinaryOp::ArrayContains, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("go".to_string()),
        ];
        let sql =
            render_posts_array_filter("posts__tags", BinaryOp::ArrayContainedBy, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let tags = vec![
            Value::String("rust".to_string()),
            Value::String("python".to_string()),
        ];
        let sql = render_posts_array_filter("posts__tags", BinaryOp::ArrayOverlaps, tags.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_single_array_param(&sql, &tags);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let scores = vec![Value::I32(100), Value::I32(200)];
        let sql =
            render_posts_array_filter("posts__scores", BinaryOp::ArrayContains, scores.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_single_array_param(&sql, &scores);
    }

    #[test]
    fn vector_params_are_cast_to_pgvector_type() {
        let needle = Value::Vector(vec![1.0, 2.0, 3.0]);
        let select = Select::from_table("embeddings")
            .filter(Expr::column("embeddings__vector").eq(Expr::param(needle.clone())))
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
        );
        assert_eq!(sql.params, vec![needle]);
    }

    #[test]
    fn postgis_params_are_cast_to_spatial_types() {
        // One row per spatial type: column, bound value, expected SQL.
        let cases = vec![
            (
                "places__geom",
                Value::Geometry("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)",
            ),
            (
                "places__geog",
                Value::Geography("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)",
            ),
        ];

        for (column, value, expected_sql) in cases {
            let select = Select::from_table("places")
                .filter(Expr::column(column).eq(Expr::param(value.clone())))
                .build()
                .unwrap();
            let sql = PostgresDialect.render_select(&select).unwrap();

            assert_eq!(sql.text, expected_sql);
            assert_eq!(sql.params, vec![value]);
        }
    }

    #[test]
    fn composite_params_are_cast_to_their_type_name() {
        let stats = Value::Composite {
            type_name: "ChampionStatsT".to_string(),
            fields: vec![Value::I32(0), Value::I32(0)],
        };
        let select = Select::from_table("champions")
            .filter(Expr::column("champions__stats").eq(Expr::param(stats.clone())))
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"champions\" WHERE (\"champions\".\"stats\" = $1::\"ChampionStatsT\")"
        );
        assert_eq!(sql.params, vec![stats]);
    }

    #[test]
    fn composite_insert_and_update_params_are_cast_to_their_type_name() {
        let stats = Value::Composite {
            type_name: "ChampionStatsT".to_string(),
            fields: vec![Value::I32(0), Value::I32(0)],
        };

        // INSERT: the single bound value carries the composite cast.
        let insert = Insert::into_table("champions")
            .column(nautilus_core::ColumnMarker::new("champions", "stats"))
            .values(vec![stats.clone()])
            .build()
            .unwrap();
        let insert_sql = PostgresDialect.render_insert(&insert).unwrap();

        assert_eq!(
            insert_sql.text,
            "INSERT INTO \"champions\" (\"stats\") VALUES ($1::\"ChampionStatsT\")"
        );
        assert_eq!(insert_sql.params, vec![stats.clone()]);

        // UPDATE: the assigned value carries the same cast.
        let update = Update::table("champions")
            .set(
                nautilus_core::ColumnMarker::new("champions", "stats"),
                stats.clone(),
            )
            .build()
            .unwrap();
        let update_sql = PostgresDialect.render_update(&update).unwrap();

        assert_eq!(
            update_sql.text,
            "UPDATE \"champions\" SET \"stats\" = $1::\"ChampionStatsT\""
        );
        assert_eq!(update_sql.params, vec![stats]);
    }

    #[test]
    fn composite_field_ordering_uses_native_attribute_syntax() {
        let eta_field = Expr::composite_field(
            "shipments",
            "delivery_snapshot",
            "eta_minutes",
            "etaMinutes",
            nautilus_core::JsonPathCast::Signed,
        );
        let select = Select::from_table("shipments")
            .order_by_expr(eta_field, nautilus_core::OrderDir::Asc)
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"shipments\" ORDER BY (\"shipments\".\"delivery_snapshot\").\"eta_minutes\" ASC"
        );
    }

    #[test]
    fn vector_distance_ordering_uses_pgvector_operator() {
        let distance = Expr::vector_distance(
            nautilus_core::VectorMetric::Cosine,
            Expr::column("embeddings__vector"),
            Expr::param(Value::Vector(vec![1.0, 2.0, 3.0])),
        );
        let select = Select::from_table("embeddings")
            .order_by_expr(distance, nautilus_core::OrderDir::Asc)
            .take(5)
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
        );
        assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
    }
}