1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// SQL dialect for PostgreSQL.
///
/// Identifiers are double-quoted, bound parameters use `$1`, `$2`, ... placeholders, and
/// UUID, JSON, pgvector, PostGIS and enum parameters are followed by an explicit `::type`
/// cast when rendered.
///
/// The type is stateless, so it also derives `Default`, equality and hashing; this lets it
/// be used in generic code, compared, or stored as a map key without wrapper types.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15 fn render_select(&self, select: &Select) -> Result<Sql> {
16 let mut ctx = RenderContext::new();
17 render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, true, false);
18 Ok(Sql {
19 text: ctx.sql,
20 params: ctx.params,
21 })
22 }
23
24 fn render_insert(&self, insert: &Insert) -> Result<Sql> {
25 let mut ctx = RenderContext::new();
26 render_insert_body!(&mut ctx, insert, quote_identifier, true, true);
27 Ok(Sql {
28 text: ctx.sql,
29 params: ctx.params,
30 })
31 }
32
33 fn render_update(&self, update: &Update) -> Result<Sql> {
34 let mut ctx = RenderContext::new();
35 render_update_body!(&mut ctx, update, quote_identifier, render_expr, true, true);
36 Ok(Sql {
37 text: ctx.sql,
38 params: ctx.params,
39 })
40 }
41
42 fn render_delete(&self, delete: &Delete) -> Result<Sql> {
43 let mut ctx = RenderContext::new();
44 render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, true);
45 Ok(Sql {
46 text: ctx.sql,
47 params: ctx.params,
48 })
49 }
50}
51
52fn quote_identifier(name: &str) -> String {
53 crate::double_quote_identifier(name)
54}
55
56struct RenderContext {
57 sql: String,
58 params: Vec<Value>,
59}
60
61impl RenderContext {
62 fn new() -> Self {
63 Self {
64 sql: String::new(),
65 params: Vec::new(),
66 }
67 }
68
69 fn push_param(&mut self, value: Value) -> String {
70 self.params.push(value);
71 format!("${}", self.params.len())
72 }
73}
74
75fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
76 render_select_body_core!(ctx, select, quote_identifier, render_expr, true, false);
77}
78
79fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
80 render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
81 Expr::Param(value) => {
82 if matches!(value, Value::Null) {
85 ctx.sql.push_str("NULL");
86 } else {
87 let placeholder = ctx.push_param(value.clone());
88 ctx.sql.push_str(&placeholder);
89 if matches!(value, Value::Uuid(_)) {
91 ctx.sql.push_str("::uuid");
92 } else if matches!(value, Value::Json(_)) {
93 ctx.sql.push_str("::json");
94 } else if matches!(value, Value::Vector(_)) {
95 ctx.sql.push_str("::vector");
96 } else if matches!(value, Value::Geometry(_)) {
97 ctx.sql.push_str("::geometry");
98 } else if matches!(value, Value::Geography(_)) {
99 ctx.sql.push_str("::geography");
100 } else if is_homogeneous_geometry_array(value) {
101 ctx.sql.push_str("::geometry[]");
102 } else if is_homogeneous_geography_array(value) {
103 ctx.sql.push_str("::geography[]");
104 } else if let Value::Enum { type_name, .. } = value {
105 ctx.sql.push_str("::");
106 ctx.sql.push_str(type_name);
107 }
108 }
109 }
110 Expr::Binary { left, op, right } => {
111 if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
112 ctx.sql.push('(');
113 render_expr(ctx, left);
114 ctx.sql.push(' ');
115 ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
116 ctx.sql.push_str(" (");
117 if let Expr::List(exprs) = right.as_ref() {
118 for (i, e) in exprs.iter().enumerate() {
119 if i > 0 { ctx.sql.push_str(", "); }
120 render_expr(ctx, e);
121 }
122 } else {
123 render_expr(ctx, right);
124 }
125 ctx.sql.push(')');
126 ctx.sql.push(')');
127 } else {
128 ctx.sql.push('(');
129 render_expr(ctx, left);
130 ctx.sql.push(' ');
131 ctx.sql.push_str(match op {
132 BinaryOp::ArrayContains => "@>",
133 BinaryOp::ArrayContainedBy => "<@",
134 BinaryOp::ArrayOverlaps => "&&",
135 _ => crate::binary_op_sql(op),
136 });
137 ctx.sql.push(' ');
138 render_expr(ctx, right);
139 ctx.sql.push(')');
140 }
141 }
142 Expr::FunctionCall { name, args } => {
143 if args.len() == 2 {
144 let op = match name.as_str() {
145 nautilus_core::expr::VECTOR_L2_DISTANCE_FUNCTION => Some("<->"),
146 nautilus_core::expr::VECTOR_INNER_PRODUCT_FUNCTION => Some("<#>"),
147 nautilus_core::expr::VECTOR_COSINE_DISTANCE_FUNCTION => Some("<=>"),
148 _ => None,
149 };
150 if let Some(op) = op {
151 ctx.sql.push('(');
152 render_expr(ctx, &args[0]);
153 ctx.sql.push(' ');
154 ctx.sql.push_str(op);
155 ctx.sql.push(' ');
156 render_expr(ctx, &args[1]);
157 ctx.sql.push(')');
158 return;
159 }
160 }
161 ctx.sql.push_str(name);
162 ctx.sql.push('(');
163 for (i, arg) in args.iter().enumerate() {
164 if i > 0 { ctx.sql.push_str(", "); }
165 render_expr(ctx, arg);
166 }
167 ctx.sql.push(')');
168 }
169 Expr::Filter { expr, predicate } => {
170 render_expr(ctx, expr);
172 ctx.sql.push_str(" FILTER (WHERE ");
173 render_expr(ctx, predicate);
174 ctx.sql.push(')');
175 }
176 });
177}
178
179fn is_homogeneous_geometry_array(value: &Value) -> bool {
180 matches!(
181 value,
182 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geometry(_)))
183 )
184}
185
186fn is_homogeneous_geography_array(value: &Value) -> bool {
187 matches!(
188 value,
189 Value::Array(items) if !items.is_empty() && items.iter().all(|item| matches!(item, Value::Geography(_)))
190 )
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `SELECT * FROM <table> WHERE <predicate>` with the Postgres dialect.
    fn render_filtered(table: &'static str, predicate: Expr) -> Sql {
        let select = Select::from_table(table).filter(predicate).build().unwrap();
        PostgresDialect.render_select(&select).unwrap()
    }

    /// Builds `<column> <op> <array parameter>`, the shape shared by the array-operator tests.
    fn array_predicate(column: &'static str, op: BinaryOp, items: Vec<Value>) -> Expr {
        Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(items))),
        }
    }

    /// Wraps each string slice in a `Value::String`.
    fn strings(items: &[&str]) -> Vec<Value> {
        items.iter().map(|s| Value::String(s.to_string())).collect()
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "\"users\""),
            ("email", "\"email\""),
            ("foo\"bar", "\"foo\"\"bar\""),
            ("a\"b\"c", "\"a\"\"b\"\"c\""),
        ];
        for &(raw, quoted) in &cases {
            assert_eq!(quote_identifier(raw), quoted);
        }
    }

    #[test]
    fn test_array_contains_operator() {
        let tags = strings(&["rust"]);
        let sql = render_filtered(
            "posts",
            array_predicate("posts__tags", BinaryOp::ArrayContains, tags.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(tags)]);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let tags = strings(&["rust", "go"]);
        let sql = render_filtered(
            "posts",
            array_predicate("posts__tags", BinaryOp::ArrayContainedBy, tags.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(tags)]);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let tags = strings(&["rust", "python"]);
        let sql = render_filtered(
            "posts",
            array_predicate("posts__tags", BinaryOp::ArrayOverlaps, tags.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(tags)]);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let scores = vec![Value::I32(100), Value::I32(200)];
        let sql = render_filtered(
            "posts",
            array_predicate("posts__scores", BinaryOp::ArrayContains, scores.clone()),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)"
        );
        assert_eq!(sql.params, vec![Value::Array(scores)]);
    }

    #[test]
    fn vector_params_are_cast_to_pgvector_type() {
        let vector = Value::Vector(vec![1.0, 2.0, 3.0]);
        let sql = render_filtered(
            "embeddings",
            Expr::column("embeddings__vector").eq(Expr::param(vector.clone())),
        );

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" WHERE (\"embeddings\".\"vector\" = $1::vector)"
        );
        assert_eq!(sql.params, vec![vector]);
    }

    #[test]
    fn postgis_params_are_cast_to_spatial_types() {
        let cases = vec![
            (
                "places__geom",
                Value::Geometry("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geom\" = $1::geometry)",
            ),
            (
                "places__geog",
                Value::Geography("POINT(1 2)".to_string()),
                "SELECT * FROM \"places\" WHERE (\"places\".\"geog\" = $1::geography)",
            ),
        ];

        for (column, value, expected_sql) in cases {
            let sql = render_filtered("places", Expr::column(column).eq(Expr::param(value.clone())));
            assert_eq!(sql.text, expected_sql);
            assert_eq!(sql.params, vec![value]);
        }
    }

    #[test]
    fn vector_distance_ordering_uses_pgvector_operator() {
        let vector = Value::Vector(vec![1.0, 2.0, 3.0]);
        let distance = Expr::vector_distance(
            nautilus_core::VectorMetric::Cosine,
            Expr::column("embeddings__vector"),
            Expr::param(vector.clone()),
        );
        let select = Select::from_table("embeddings")
            .order_by_expr(distance, nautilus_core::OrderDir::Asc)
            .take(5)
            .build()
            .unwrap();
        let sql = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM \"embeddings\" ORDER BY (\"embeddings\".\"vector\" <=> $1::vector) ASC LIMIT 5"
        );
        assert_eq!(sql.params, vec![vector]);
    }
}