1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// PostgreSQL implementation of [`Dialect`].
///
/// Identifiers are quoted with double quotes, parameters are bound positionally as
/// `$1`, `$2`, …, and UUID, JSON, pgvector, PostGIS and enum parameters are followed by an
/// explicit `::type` cast.
#[derive(Clone, Copy, Debug)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15 fn render_select(&self, select: &Select) -> Result<Sql> {
16 let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(select));
17 render_select_body_core!(&mut ctx, select, '"', render_expr, true, false);
18 Ok(Sql {
19 text: ctx.sql,
20 params: ctx.params,
21 })
22 }
23
24 fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
25 let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
26 render_select_body_core_mut!(&mut ctx, &mut select, '"', render_expr_owned, true, false);
27 Ok(Sql {
28 text: ctx.sql,
29 params: ctx.params,
30 })
31 }
32
33 fn render_insert(&self, insert: &Insert) -> Result<Sql> {
34 let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(insert));
35 render_insert_body!(&mut ctx, insert, '"', true, true);
36 Ok(Sql {
37 text: ctx.sql,
38 params: ctx.params,
39 })
40 }
41
42 fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
43 let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
44 render_insert_body_mut!(&mut ctx, &mut insert, '"', true, true);
45 Ok(Sql {
46 text: ctx.sql,
47 params: ctx.params,
48 })
49 }
50
51 fn render_update(&self, update: &Update) -> Result<Sql> {
52 let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(update));
53 render_update_body!(&mut ctx, update, '"', render_expr, true, true);
54 Ok(Sql {
55 text: ctx.sql,
56 params: ctx.params,
57 })
58 }
59
60 fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
61 let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
62 render_update_body_mut!(&mut ctx, &mut update, '"', render_expr_owned, true, true);
63 Ok(Sql {
64 text: ctx.sql,
65 params: ctx.params,
66 })
67 }
68
69 fn render_delete(&self, delete: &Delete) -> Result<Sql> {
70 let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(delete));
71 render_delete_body!(&mut ctx, delete, '"', render_expr, true);
72 Ok(Sql {
73 text: ctx.sql,
74 params: ctx.params,
75 })
76 }
77
78 fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
79 let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
80 render_delete_body_mut!(&mut ctx, &mut delete, '"', render_expr_owned, true);
81 Ok(Sql {
82 text: ctx.sql,
83 params: ctx.params,
84 })
85 }
86}
87
/// Mutable state threaded through rendering: the SQL text built so far and the values
/// bound to its `$N` placeholders, in placeholder order.
struct RenderContext {
    /// SQL text accumulated so far.
    sql: String,
    /// Bound parameter values; entry `i` corresponds to placeholder `$(i + 1)`.
    params: Vec<Value>,
}
92
93impl RenderContext {
94 fn with_estimate(estimate: crate::RenderEstimate) -> Self {
95 Self {
96 sql: String::with_capacity(estimate.sql_capacity),
97 params: Vec::with_capacity(estimate.params_capacity),
98 }
99 }
100
101 fn push_param(&mut self, value: Value) {
102 self.params.push(value);
103 self.sql.push('$');
104 crate::push_usize(&mut self.sql, self.params.len());
105 }
106
107 fn take_param(&mut self, value: &mut Value) {
108 self.push_param(std::mem::replace(value, Value::Null));
109 }
110}
111
/// Renders a borrowed `SELECT` body into `ctx` with the same macro arguments used by
/// `PostgresDialect::render_select`.
///
/// Passed to `render_expr_common!` as its select renderer (presumably for nested selects such
/// as subqueries — the macro is defined elsewhere; confirm).
fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
    render_select_body_core!(ctx, select, '"', render_expr, true, false);
}
115
/// Renders a mutable `SELECT` body into `ctx`, moving parameter values out of it, with the
/// same macro arguments used by `PostgresDialect::render_select_owned`.
///
/// Passed to `render_expr_common_mut!` as its select renderer (presumably for nested selects
/// such as subqueries — the macro is defined elsewhere; confirm).
fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
    render_select_body_core_mut!(ctx, select, '"', render_expr_owned, true, false);
}
119
120fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
121 render_expr_common!(ctx, expr, '"', render_expr, render_select_body, {
122 Expr::Param(value) => {
123 if matches!(value, Value::Null) {
126 ctx.sql.push_str("NULL");
127 } else {
128 ctx.push_param(value.clone());
129 if matches!(value, Value::Uuid(_)) {
131 ctx.sql.push_str("::uuid");
132 } else if matches!(value, Value::Json(_)) {
133 ctx.sql.push_str("::json");
134 } else if matches!(value, Value::Vector(_)) {
135 ctx.sql.push_str("::vector");
136 } else if matches!(value, Value::Geometry(_)) {
137 ctx.sql.push_str("::geometry");
138 } else if matches!(value, Value::Geography(_)) {
139 ctx.sql.push_str("::geography");
140 } else if is_homogeneous_geometry_array(value) {
141 ctx.sql.push_str("::geometry[]");
142 } else if is_homogeneous_geography_array(value) {
143 ctx.sql.push_str("::geography[]");
144 } else if let Value::Enum { type_name, .. } = value {
145 ctx.sql.push_str("::");
146 ctx.sql.push_str(type_name);
147 }
148 }
149 }
150 Expr::Binary { left, op, right } => {
151 if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
152 ctx.sql.push('(');
153 render_expr(ctx, left);
154 ctx.sql.push(' ');
155 ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
156 ctx.sql.push_str(" (");
157 if let Expr::List(exprs) = right.as_ref() {
158 for (i, e) in exprs.iter().enumerate() {
159 if i > 0 { ctx.sql.push_str(", "); }
160 render_expr(ctx, e);
161 }
162 } else {
163 render_expr(ctx, right);
164 }
165 ctx.sql.push(')');
166 ctx.sql.push(')');
167 } else {
168 ctx.sql.push('(');
169 render_expr(ctx, left);
170 ctx.sql.push(' ');
171 ctx.sql.push_str(match op {
172 BinaryOp::ArrayContains => "@>",
173 BinaryOp::ArrayContainedBy => "<@",
174 BinaryOp::ArrayOverlaps => "&&",
175 _ => crate::binary_op_sql(op),
176 });
177 ctx.sql.push(' ');
178 render_expr(ctx, right);
179 ctx.sql.push(')');
180 }
181 }
182 Expr::FunctionCall { name, args } => {
183 if args.len() == 2 {
184 let op = match name.as_str() {
185 nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
186 nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
187 nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
188 _ => None,
189 };
190 if let Some(op) = op {
191 ctx.sql.push('(');
192 render_expr(ctx, &args[0]);
193 ctx.sql.push(' ');
194 ctx.sql.push_str(op);
195 ctx.sql.push(' ');
196 render_expr(ctx, &args[1]);
197 ctx.sql.push(')');
198 return;
199 }
200 }
201 ctx.sql.push_str(name);
202 ctx.sql.push('(');
203 for (i, arg) in args.iter().enumerate() {
204 if i > 0 { ctx.sql.push_str(", "); }
205 render_expr(ctx, arg);
206 }
207 ctx.sql.push(')');
208 }
209 Expr::Filter { expr, predicate } => {
210 render_expr(ctx, expr);
212 ctx.sql.push_str(" FILTER (WHERE ");
213 render_expr(ctx, predicate);
214 ctx.sql.push(')');
215 }
216 });
217}
218
219fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
220 render_expr_common_mut!(ctx, expr, '"', render_expr_owned, render_select_body_owned, {
221 Expr::Param(value) => {
222 if matches!(value, Value::Null) {
225 ctx.sql.push_str("NULL");
226 } else {
227 let cast = postgres_param_cast(value);
228 ctx.take_param(value);
229 if let Some(cast) = cast {
230 cast.push_sql(&mut ctx.sql);
231 }
232 }
233 }
234 Expr::Binary { left, op, right } => {
235 if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
236 ctx.sql.push('(');
237 render_expr_owned(ctx, left.as_mut());
238 ctx.sql.push(' ');
239 ctx.sql
240 .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
241 ctx.sql.push_str(" (");
242 if let Expr::List(exprs) = right.as_mut() {
243 for (i, e) in exprs.iter_mut().enumerate() {
244 if i > 0 {
245 ctx.sql.push_str(", ");
246 }
247 render_expr_owned(ctx, e);
248 }
249 } else {
250 render_expr_owned(ctx, right.as_mut());
251 }
252 ctx.sql.push(')');
253 ctx.sql.push(')');
254 } else {
255 ctx.sql.push('(');
256 render_expr_owned(ctx, left.as_mut());
257 ctx.sql.push(' ');
258 ctx.sql.push_str(match *op {
259 BinaryOp::ArrayContains => "@>",
260 BinaryOp::ArrayContainedBy => "<@",
261 BinaryOp::ArrayOverlaps => "&&",
262 _ => crate::binary_op_sql(op),
263 });
264 ctx.sql.push(' ');
265 render_expr_owned(ctx, right.as_mut());
266 ctx.sql.push(')');
267 }
268 }
269 Expr::FunctionCall { name, args } => {
270 if args.len() == 2 {
271 let op = match name.as_str() {
272 nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
273 nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
274 nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
275 _ => None,
276 };
277 if let Some(op) = op {
278 ctx.sql.push('(');
279 render_expr_owned(ctx, &mut args[0]);
280 ctx.sql.push(' ');
281 ctx.sql.push_str(op);
282 ctx.sql.push(' ');
283 render_expr_owned(ctx, &mut args[1]);
284 ctx.sql.push(')');
285 return;
286 }
287 }
288 ctx.sql.push_str(name);
289 ctx.sql.push('(');
290 for (i, arg) in args.iter_mut().enumerate() {
291 if i > 0 {
292 ctx.sql.push_str(", ");
293 }
294 render_expr_owned(ctx, arg);
295 }
296 ctx.sql.push(')');
297 }
298 Expr::Filter { expr, predicate } => {
299 render_expr_owned(ctx, expr.as_mut());
300 ctx.sql.push_str(" FILTER (WHERE ");
301 render_expr_owned(ctx, predicate.as_mut());
302 ctx.sql.push(')');
303 }
304 });
305}
306
/// An explicit PostgreSQL type cast appended after a bound parameter placeholder.
enum ParamCast {
    /// Cast to a type with a fixed name, e.g. `uuid` or `geometry[]`.
    Static(&'static str),
    /// Cast to an enum type whose name comes from the parameter value.
    Enum(String),
}

impl ParamCast {
    /// Appends `::<type name>` to `sql`.
    fn push_sql(&self, sql: &mut String) {
        let target = match self {
            Self::Static(name) => *name,
            Self::Enum(type_name) => type_name.as_str(),
        };
        sql.push_str("::");
        sql.push_str(target);
    }
}
326
327fn postgres_param_cast(value: &Value) -> Option<ParamCast> {
328 match value {
329 Value::Uuid(_) => Some(ParamCast::Static("uuid")),
330 Value::Json(_) => Some(ParamCast::Static("json")),
331 Value::Vector(_) => Some(ParamCast::Static("vector")),
332 Value::Geometry(_) => Some(ParamCast::Static("geometry")),
333 Value::Geography(_) => Some(ParamCast::Static("geography")),
334 value if is_homogeneous_geometry_array(value) => Some(ParamCast::Static("geometry[]")),
335 value if is_homogeneous_geography_array(value) => Some(ParamCast::Static("geography[]")),
336 Value::Enum { type_name, .. } => Some(ParamCast::Enum(type_name.clone())),
337 _ => None,
338 }
339}
340
341fn is_homogeneous_geometry_array(value: &Value) -> bool {
342 matches!(
343 value,
344 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
345 )
346}
347
348fn is_homogeneous_geography_array(value: &Value) -> bool {
349 matches!(
350 value,
351 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
352 )
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Quotes `name` with the same quote character the Postgres dialect uses.
    fn quote_identifier(name: &str) -> String {
        let mut quoted = String::new();
        crate::push_quoted_identifier(&mut quoted, name, '"');
        quoted
    }

    /// Renders `select` with [`PostgresDialect`], panicking if rendering fails.
    fn render(select: &Select) -> Sql {
        PostgresDialect.render_select(select).unwrap()
    }

    /// Renders `SELECT * FROM "posts" WHERE (<column> <op> $1)` with `items` bound as a
    /// single array parameter.
    fn render_posts_array_filter(column: &'static str, op: BinaryOp, items: Vec<Value>) -> Sql {
        let filter = Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(items))),
        };
        render(&Select::from_table("posts").filter(filter).build().unwrap())
    }

    /// Asserts that `sql` binds exactly one parameter, an array equal to `expected`.
    fn assert_single_array_param(sql: &Sql, expected: &[Value]) {
        assert_eq!(sql.params.len(), 1);
        match &sql.params[0] {
            Value::Array(actual) => assert_eq!(actual.as_slice(), expected),
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "\"users\""),
            ("email", "\"email\""),
            ("foo\"bar", "\"foo\"\"bar\""),
            ("a\"b\"c", "\"a\"\"b\"\"c\""),
        ];
        for (raw, expected) in cases.iter() {
            assert_eq!(quote_identifier(raw), *expected);
        }
    }

    #[test]
    fn test_array_contains_operator() {
        let items = vec![Value::String("rust".to_string())];
        let sql =
            render_posts_array_filter("posts__tags", BinaryOp::ArrayContains, items.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_single_array_param(&sql, &items);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let items = vec![
            Value::String("rust".to_string()),
            Value::String("go".to_string()),
        ];
        let sql =
            render_posts_array_filter("posts__tags", BinaryOp::ArrayContainedBy, items.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_single_array_param(&sql, &items);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let items = vec![
            Value::String("rust".to_string()),
            Value::String("python".to_string()),
        ];
        let sql =
            render_posts_array_filter("posts__tags", BinaryOp::ArrayOverlaps, items.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_single_array_param(&sql, &items);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let items = vec![Value::I32(100), Value::I32(200)];
        let sql =
            render_posts_array_filter("posts__scores", BinaryOp::ArrayContains, items.clone());

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_single_array_param(&sql, &items);
    }

    #[test]
    fn vector_params_are_cast_to_pgvector_type() {
        let vector = Value::Vector(vec![1.0, 2.0, 3.0]);
        let select = Select::from_table("embeddings")
            .filter(Expr::column("embeddings__vector").eq(Expr::param(vector.clone())))
            .build()
            .unwrap();
        let sql = render(&select);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
        );
        assert_eq!(sql.params, vec![vector]);
    }

    #[test]
    fn postgis_params_are_cast_to_spatial_types() {
        let cases = vec![
            (
                "places__geom",
                Value::Geometry("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)",
            ),
            (
                "places__geog",
                Value::Geography("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)",
            ),
        ];

        for (column, value, expected_text) in cases {
            let select = Select::from_table("places")
                .filter(Expr::column(column).eq(Expr::param(value.clone())))
                .build()
                .unwrap();
            let sql = render(&select);

            assert_eq!(sql.text, expected_text);
            assert_eq!(sql.params, vec![value]);
        }
    }

    #[test]
    fn vector_distance_ordering_uses_pgvector_operator() {
        let query_vector = Value::Vector(vec![1.0, 2.0, 3.0]);
        let distance = Expr::vector_distance(
            nautilus_core::VectorMetric::Cosine,
            Expr::column("embeddings__vector"),
            Expr::param(query_vector.clone()),
        );
        let select = Select::from_table("embeddings")
            .order_by_expr(distance, nautilus_core::OrderDir::Asc)
            .take(5)
            .build()
            .unwrap();
        let sql = render(&select);

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
        );
        assert_eq!(sql.params, vec![query_vector]);
    }
}