1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// SQL dialect for PostgreSQL.
///
/// Identifiers are quoted with double quotes and bound parameters are written as numbered
/// `$n` placeholders. Some parameter kinds (uuid, json, pgvector `vector`, PostGIS
/// `geometry`/`geography` and their homogeneous arrays, user-defined enum and composite values)
/// additionally receive an explicit `::type` cast after the placeholder.
///
/// The dialect is stateless, so it can be freely copied, defaulted, compared and hashed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15 fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
16 let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
17 render_select_body_core_mut!(&mut ctx, &mut select, '"', render_expr_owned, true, false);
18 Ok(Sql {
19 text: ctx.sql,
20 params: ctx.params,
21 })
22 }
23
24 fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
25 let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
26 render_insert_body_mut!(&mut ctx, &mut insert, '"', true, true);
27 Ok(Sql {
28 text: ctx.sql,
29 params: ctx.params,
30 })
31 }
32
33 fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
34 let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
35 render_update_body_mut!(&mut ctx, &mut update, '"', render_expr_owned, true, true);
36 Ok(Sql {
37 text: ctx.sql,
38 params: ctx.params,
39 })
40 }
41
42 fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
43 let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
44 render_delete_body_mut!(&mut ctx, &mut delete, '"', render_expr_owned, true);
45 Ok(Sql {
46 text: ctx.sql,
47 params: ctx.params,
48 })
49 }
50}
51
52struct RenderContext {
53 sql: String,
54 params: Vec<Value>,
55}
56
57impl RenderContext {
58 fn with_estimate(estimate: crate::RenderEstimate) -> Self {
59 Self {
60 sql: String::with_capacity(estimate.sql_capacity),
61 params: Vec::with_capacity(estimate.params_capacity),
62 }
63 }
64
65 fn push_param(&mut self, value: Value) {
66 self.params.push(value);
67 self.sql.push('$');
68 crate::push_usize(&mut self.sql, self.params.len());
69 }
70
71 fn take_param(&mut self, value: &mut Value) {
72 self.push_param(std::mem::replace(value, Value::Null));
73 }
74}
75
/// Renders a nested `SELECT` directly into the enclosing `ctx`.
///
/// It is handed to `render_expr_common_mut!` as its select renderer (presumably for subquery
/// expressions — confirm against the macro). Because it writes into the same `RenderContext`,
/// any parameters it binds continue the enclosing statement's `$n` numbering. The macro flags
/// match those used by `PostgresDialect::render_select_owned`.
fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
    render_select_body_core_mut!(ctx, select, '"', render_expr_owned, true, false);
}
79
80fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
81 render_expr_common_mut!(ctx, expr, '"', render_expr_owned, render_select_body_owned, {
82 Expr::Param(value) => {
83 if matches!(value, Value::Null) {
86 ctx.sql.push_str("NULL");
87 } else {
88 let cast = postgres_param_cast(value);
89 ctx.take_param(value);
90 if let Some(cast) = cast {
91 cast.push_sql(&mut ctx.sql);
92 }
93 }
94 }
95 Expr::Binary { left, op, right } => {
96 if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
97 ctx.sql.push('(');
98 render_expr_owned(ctx, left.as_mut());
99 ctx.sql.push(' ');
100 ctx.sql
101 .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
102 ctx.sql.push_str(" (");
103 if let Expr::List(exprs) = right.as_mut() {
104 for (i, e) in exprs.iter_mut().enumerate() {
105 if i > 0 {
106 ctx.sql.push_str(", ");
107 }
108 render_expr_owned(ctx, e);
109 }
110 } else {
111 render_expr_owned(ctx, right.as_mut());
112 }
113 ctx.sql.push(')');
114 ctx.sql.push(')');
115 } else {
116 ctx.sql.push('(');
117 render_expr_owned(ctx, left.as_mut());
118 ctx.sql.push(' ');
119 ctx.sql.push_str(match *op {
120 BinaryOp::ArrayContains => "@>",
121 BinaryOp::ArrayContainedBy => "<@",
122 BinaryOp::ArrayOverlaps => "&&",
123 _ => crate::binary_op_sql(op),
124 });
125 ctx.sql.push(' ');
126 render_expr_owned(ctx, right.as_mut());
127 ctx.sql.push(')');
128 }
129 }
130 Expr::FunctionCall { name, args } => {
131 if args.len() == 2 {
132 let op = match name.as_str() {
133 nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
134 nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
135 nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
136 _ => None,
137 };
138 if let Some(op) = op {
139 ctx.sql.push('(');
140 render_expr_owned(ctx, &mut args[0]);
141 ctx.sql.push(' ');
142 ctx.sql.push_str(op);
143 ctx.sql.push(' ');
144 render_expr_owned(ctx, &mut args[1]);
145 ctx.sql.push(')');
146 return;
147 }
148 }
149 ctx.sql.push_str(name);
150 ctx.sql.push('(');
151 for (i, arg) in args.iter_mut().enumerate() {
152 if i > 0 {
153 ctx.sql.push_str(", ");
154 }
155 render_expr_owned(ctx, arg);
156 }
157 ctx.sql.push(')');
158 }
159 Expr::Filter { expr, predicate } => {
160 render_expr_owned(ctx, expr.as_mut());
161 ctx.sql.push_str(" FILTER (WHERE ");
162 render_expr_owned(ctx, predicate.as_mut());
163 ctx.sql.push(')');
164 }
165 });
166}
167
/// An explicit `::type` cast written after a bound parameter's `$n` placeholder.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ParamCast {
    /// A built-in or extension type name written verbatim (e.g. `uuid`, `geometry[]`).
    Static(&'static str),
    /// The name of a user-defined enum type, written as a double-quoted identifier.
    Enum(String),
    /// The name of a user-defined composite type, written as a double-quoted identifier.
    Composite(String),
}
173
174impl ParamCast {
175 fn push_sql(&self, sql: &mut String) {
176 match self {
177 Self::Static(name) => {
178 sql.push_str("::");
179 sql.push_str(name);
180 }
181 Self::Enum(type_name) | Self::Composite(type_name) => {
182 sql.push_str("::");
183 crate::push_quoted_identifier(sql, type_name, '"');
184 }
185 }
186 }
187}
188
189fn postgres_param_cast(value: &Value) -> Option<ParamCast> {
190 match value {
191 Value::Uuid(_) => Some(ParamCast::Static("uuid")),
192 Value::Json(_) => Some(ParamCast::Static("json")),
193 Value::Vector(_) => Some(ParamCast::Static("vector")),
194 Value::Geometry(_) => Some(ParamCast::Static("geometry")),
195 Value::Geography(_) => Some(ParamCast::Static("geography")),
196 value if is_homogeneous_geometry_array(value) => Some(ParamCast::Static("geometry[]")),
197 value if is_homogeneous_geography_array(value) => Some(ParamCast::Static("geography[]")),
198 Value::Enum { type_name, .. } => Some(ParamCast::Enum(type_name.clone())),
199 Value::Composite { type_name, .. } => Some(ParamCast::Composite(type_name.clone())),
200 _ => None,
201 }
202}
203
204fn is_homogeneous_geometry_array(value: &Value) -> bool {
205 matches!(
206 value,
207 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
208 )
209}
210
211fn is_homogeneous_geography_array(value: &Value) -> bool {
212 matches!(
213 value,
214 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
215 )
216}
217
218#[cfg(test)]
219mod tests {
220 use super::*;
221
222 fn quote_identifier(name: &str) -> String {
223 let mut sql = String::new();
224 crate::push_quoted_identifier(&mut sql, name, '"');
225 sql
226 }
227
228 #[test]
229 fn test_quote_identifier() {
230 assert_eq!(quote_identifier("users"), "\"users\"");
231 assert_eq!(quote_identifier("email"), "\"email\"");
232 assert_eq!(quote_identifier("foo\"bar"), "\"foo\"\"bar\"");
233 assert_eq!(quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
234 }
235
236 #[test]
237 fn test_array_contains_operator() {
238 let dialect = PostgresDialect;
239 let expr = Expr::Binary {
240 left: Box::new(Expr::column("posts__tags")),
241 op: BinaryOp::ArrayContains,
242 right: Box::new(Expr::param(Value::Array(vec![Value::String(
243 "rust".to_string(),
244 )]))),
245 };
246 let select = Select::from_table("posts").filter(expr).build().unwrap();
247 let sql = dialect.render_select(&select).unwrap();
248
249 assert_eq!(
250 sql.text,
251 "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
252 );
253 assert_eq!(sql.params.len(), 1);
254 match &sql.params[0] {
255 Value::Array(arr) => {
256 assert_eq!(arr.len(), 1);
257 assert_eq!(arr[0], Value::String("rust".to_string()));
258 }
259 _ => panic!("Expected Array value"),
260 }
261 }
262
263 #[test]
264 fn test_array_contained_by_operator() {
265 let dialect = PostgresDialect;
266 let expr = Expr::Binary {
267 left: Box::new(Expr::column("posts__tags")),
268 op: BinaryOp::ArrayContainedBy,
269 right: Box::new(Expr::param(Value::Array(vec![
270 Value::String("rust".to_string()),
271 Value::String("go".to_string()),
272 ]))),
273 };
274 let select = Select::from_table("posts").filter(expr).build().unwrap();
275 let sql = dialect.render_select(&select).unwrap();
276
277 assert_eq!(
278 sql.text,
279 "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
280 );
281 assert_eq!(sql.params.len(), 1);
282 match &sql.params[0] {
283 Value::Array(arr) => {
284 assert_eq!(arr.len(), 2);
285 assert_eq!(arr[0], Value::String("rust".to_string()));
286 assert_eq!(arr[1], Value::String("go".to_string()));
287 }
288 _ => panic!("Expected Array value"),
289 }
290 }
291
292 #[test]
293 fn test_array_overlaps_operator() {
294 let dialect = PostgresDialect;
295 let expr = Expr::Binary {
296 left: Box::new(Expr::column("posts__tags")),
297 op: BinaryOp::ArrayOverlaps,
298 right: Box::new(Expr::param(Value::Array(vec![
299 Value::String("rust".to_string()),
300 Value::String("python".to_string()),
301 ]))),
302 };
303 let select = Select::from_table("posts").filter(expr).build().unwrap();
304 let sql = dialect.render_select(&select).unwrap();
305
306 assert_eq!(
307 sql.text,
308 "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
309 );
310 assert_eq!(sql.params.len(), 1);
311 match &sql.params[0] {
312 Value::Array(arr) => {
313 assert_eq!(arr.len(), 2);
314 assert_eq!(arr[0], Value::String("rust".to_string()));
315 assert_eq!(arr[1], Value::String("python".to_string()));
316 }
317 _ => panic!("Expected Array value"),
318 }
319 }
320
321 #[test]
322 fn test_array_operators_with_integers() {
323 let dialect = PostgresDialect;
324 let expr = Expr::Binary {
325 left: Box::new(Expr::column("posts__scores")),
326 op: BinaryOp::ArrayContains,
327 right: Box::new(Expr::param(Value::Array(vec![
328 Value::I32(100),
329 Value::I32(200),
330 ]))),
331 };
332 let select = Select::from_table("posts").filter(expr).build().unwrap();
333 let sql = dialect.render_select(&select).unwrap();
334
335 assert_eq!(
336 sql.text,
337 "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
338 );
339 assert_eq!(sql.params.len(), 1);
340 match &sql.params[0] {
341 Value::Array(arr) => {
342 assert_eq!(arr.len(), 2);
343 assert_eq!(arr[0], Value::I32(100));
344 assert_eq!(arr[1], Value::I32(200));
345 }
346 _ => panic!("Expected Array value"),
347 }
348 }
349
350 #[test]
351 fn vector_params_are_cast_to_pgvector_type() {
352 let dialect = PostgresDialect;
353 let select = Select::from_table("embeddings")
354 .filter(
355 Expr::column("embeddings__vector")
356 .eq(Expr::param(Value::Vector(vec![1.0, 2.0, 3.0]))),
357 )
358 .build()
359 .unwrap();
360 let sql = dialect.render_select(&select).unwrap();
361
362 assert_eq!(
363 sql.text,
364 "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
365 );
366 assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
367 }
368
369 #[test]
370 fn postgis_params_are_cast_to_spatial_types() {
371 let dialect = PostgresDialect;
372 let select = Select::from_table("places")
373 .filter(
374 Expr::column("places__geom")
375 .eq(Expr::param(Value::Geometry("POINT(1 2)".to_string()))),
376 )
377 .build()
378 .unwrap();
379 let sql = dialect.render_select(&select).unwrap();
380
381 assert_eq!(
382 sql.text,
383 "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)"
384 );
385 assert_eq!(sql.params, vec![Value::Geometry("POINT(1 2)".to_string())]);
386
387 let select = Select::from_table("places")
388 .filter(
389 Expr::column("places__geog")
390 .eq(Expr::param(Value::Geography("POINT(1 2)".to_string()))),
391 )
392 .build()
393 .unwrap();
394 let sql = dialect.render_select(&select).unwrap();
395
396 assert_eq!(
397 sql.text,
398 "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)"
399 );
400 assert_eq!(sql.params, vec![Value::Geography("POINT(1 2)".to_string())]);
401 }
402
403 #[test]
404 fn composite_params_are_cast_to_their_type_name() {
405 let dialect = PostgresDialect;
406 let composite = Value::Composite {
407 type_name: "ChampionStatsT".to_string(),
408 fields: vec![Value::I32(0), Value::I32(0)],
409 };
410 let select = Select::from_table("champions")
411 .filter(Expr::column("champions__stats").eq(Expr::param(composite.clone())))
412 .build()
413 .unwrap();
414 let sql = dialect.render_select(&select).unwrap();
415
416 assert_eq!(
417 sql.text,
418 "SELECT * FROM \"champions\" WHERE (\"champions\".\"stats\" = $1::\"ChampionStatsT\")"
419 );
420 assert_eq!(sql.params, vec![composite]);
421 }
422
423 #[test]
424 fn composite_insert_and_update_params_are_cast_to_their_type_name() {
425 let dialect = PostgresDialect;
426 let composite = Value::Composite {
427 type_name: "ChampionStatsT".to_string(),
428 fields: vec![Value::I32(0), Value::I32(0)],
429 };
430
431 let insert = Insert::into_table("champions")
432 .column(nautilus_core::ColumnMarker::new("champions", "stats"))
433 .values(vec![composite.clone()])
434 .build()
435 .unwrap();
436 let sql = dialect.render_insert(&insert).unwrap();
437
438 assert_eq!(
439 sql.text,
440 "INSERT INTO \"champions\" (\"stats\") VALUES ($1::\"ChampionStatsT\")"
441 );
442 assert_eq!(sql.params, vec![composite.clone()]);
443
444 let update = Update::table("champions")
445 .set(
446 nautilus_core::ColumnMarker::new("champions", "stats"),
447 composite.clone(),
448 )
449 .build()
450 .unwrap();
451 let sql = dialect.render_update(&update).unwrap();
452
453 assert_eq!(
454 sql.text,
455 "UPDATE \"champions\" SET \"stats\" = $1::\"ChampionStatsT\""
456 );
457 assert_eq!(sql.params, vec![composite]);
458 }
459
460 #[test]
461 fn vector_distance_ordering_uses_pgvector_operator() {
462 let dialect = PostgresDialect;
463 let select = Select::from_table("embeddings")
464 .order_by_expr(
465 Expr::vector_distance(
466 nautilus_core::VectorMetric::Cosine,
467 Expr::column("embeddings__vector"),
468 Expr::param(Value::Vector(vec![1.0, 2.0, 3.0])),
469 ),
470 nautilus_core::OrderDir::Asc,
471 )
472 .take(5)
473 .build()
474 .unwrap();
475 let sql = dialect.render_select(&select).unwrap();
476
477 assert_eq!(
478 sql.text,
479 "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
480 );
481 assert_eq!(sql.params, vec![Value::Vector(vec![1.0, 2.0, 3.0])]);
482 }
483}