1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// SQL dialect that renders statements for PostgreSQL.
///
/// Identifiers are wrapped in double quotes, and bound values are emitted as
/// positional `$1`, `$2`, ... placeholders whose values are returned alongside
/// the SQL text.
///
/// The type carries no state, so it is `Copy` and also implements `Default`,
/// `PartialEq`/`Eq` and `Hash`. Callers can therefore construct it generically
/// and store or compare it freely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15 fn render_select(&self, select: &Select) -> Result<Sql> {
16 let mut ctx = RenderContext::new();
17 render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, true, false);
18 Ok(Sql {
19 text: ctx.sql,
20 params: ctx.params,
21 })
22 }
23
24 fn render_insert(&self, insert: &Insert) -> Result<Sql> {
25 let mut ctx = RenderContext::new();
26 render_insert_body!(&mut ctx, insert, quote_identifier, true, true);
27 Ok(Sql {
28 text: ctx.sql,
29 params: ctx.params,
30 })
31 }
32
33 fn render_update(&self, update: &Update) -> Result<Sql> {
34 let mut ctx = RenderContext::new();
35 render_update_body!(&mut ctx, update, quote_identifier, render_expr, true, true);
36 Ok(Sql {
37 text: ctx.sql,
38 params: ctx.params,
39 })
40 }
41
42 fn render_delete(&self, delete: &Delete) -> Result<Sql> {
43 let mut ctx = RenderContext::new();
44 render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, true);
45 Ok(Sql {
46 text: ctx.sql,
47 params: ctx.params,
48 })
49 }
50}
51
/// Quotes `name` as a PostgreSQL identifier.
///
/// Delegates to the crate-wide `crate::double_quote_identifier`. The result is
/// wrapped in double quotes, and any embedded `"` is doubled, so `foo"bar`
/// becomes `"foo""bar"`.
fn quote_identifier(name: &str) -> String {
    crate::double_quote_identifier(name)
}
55
56struct RenderContext {
57 sql: String,
58 params: Vec<Value>,
59}
60
61impl RenderContext {
62 fn new() -> Self {
63 Self {
64 sql: String::new(),
65 params: Vec::new(),
66 }
67 }
68
69 fn push_param(&mut self, value: Value) -> String {
70 self.params.push(value);
71 format!("${}", self.params.len())
72 }
73}
74
75fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
76 render_select_body_core!(ctx, select, quote_identifier, render_expr, true, false);
77}
78
79fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
80 render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
81 Expr::Param(value) => {
82 if matches!(value, Value::Null) {
85 ctx.sql.push_str("NULL");
86 } else {
87 let placeholder = ctx.push_param(value.clone());
88 ctx.sql.push_str(&placeholder);
89 if matches!(value, Value::Uuid(_)) {
91 ctx.sql.push_str("::uuid");
92 } else if matches!(value, Value::Json(_)) {
93 ctx.sql.push_str("::json");
94 } else if let Value::Enum { type_name, .. } = value {
95 ctx.sql.push_str("::");
96 ctx.sql.push_str(type_name);
97 }
98 }
99 }
100 Expr::Binary { left, op, right } => {
101 if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
102 ctx.sql.push('(');
103 render_expr(ctx, left);
104 ctx.sql.push(' ');
105 ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
106 ctx.sql.push_str(" (");
107 if let Expr::List(exprs) = right.as_ref() {
108 for (i, e) in exprs.iter().enumerate() {
109 if i > 0 { ctx.sql.push_str(", "); }
110 render_expr(ctx, e);
111 }
112 } else {
113 render_expr(ctx, right);
114 }
115 ctx.sql.push(')');
116 ctx.sql.push(')');
117 } else {
118 ctx.sql.push('(');
119 render_expr(ctx, left);
120 ctx.sql.push(' ');
121 ctx.sql.push_str(match op {
122 BinaryOp::ArrayContains => "@>",
123 BinaryOp::ArrayContainedBy => "<@",
124 BinaryOp::ArrayOverlaps => "&&",
125 _ => crate::binary_op_sql(op),
126 });
127 ctx.sql.push(' ');
128 render_expr(ctx, right);
129 ctx.sql.push(')');
130 }
131 }
132 Expr::FunctionCall { name, args } => {
133 ctx.sql.push_str(name);
134 ctx.sql.push('(');
135 for (i, arg) in args.iter().enumerate() {
136 if i > 0 { ctx.sql.push_str(", "); }
137 render_expr(ctx, arg);
138 }
139 ctx.sql.push(')');
140 }
141 Expr::Filter { expr, predicate } => {
142 render_expr(ctx, expr);
144 ctx.sql.push_str(" FILTER (WHERE ");
145 render_expr(ctx, predicate);
146 ctx.sql.push(')');
147 }
148 });
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `SELECT * FROM "posts"` filtered by `<column> <op> <array param>`
    /// through the PostgreSQL dialect.
    fn render_array_filter(column: &'static str, op: BinaryOp, elements: Vec<Value>) -> Sql {
        let predicate = Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(elements))),
        };
        let query = Select::from_table("posts").filter(predicate).build().unwrap();
        PostgresDialect.render_select(&query).unwrap()
    }

    /// Checks that `rendered` binds exactly one parameter and that it is an
    /// array whose elements equal `expected`, in order.
    fn assert_single_array_param(rendered: &Sql, expected: &[Value]) {
        assert_eq!(rendered.params.len(), 1);
        match &rendered.params[0] {
            Value::Array(items) => {
                assert_eq!(items.len(), expected.len());
                for (actual, wanted) in items.iter().zip(expected) {
                    assert_eq!(actual, wanted);
                }
            }
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", r#""users""#),
            ("email", r#""email""#),
            (r#"foo"bar"#, r#""foo""bar""#),
            (r#"a"b"c"#, r#""a""b""c""#),
        ];
        for &(raw, quoted) in cases.iter() {
            assert_eq!(quote_identifier(raw), quoted);
        }
    }

    #[test]
    fn test_array_contains_operator() {
        let expected = vec![Value::String("rust".to_string())];
        let sql = render_array_filter("posts__tags", BinaryOp::ArrayContains, expected.clone());

        assert_eq!(sql.text, r#"SELECT * FROM "posts" WHERE ("posts"."tags" @> $1)"#);
        assert_single_array_param(&sql, &expected);
    }

    #[test]
    fn test_array_contained_by_operator() {
        let expected = vec![
            Value::String("rust".to_string()),
            Value::String("go".to_string()),
        ];
        let sql = render_array_filter("posts__tags", BinaryOp::ArrayContainedBy, expected.clone());

        assert_eq!(sql.text, r#"SELECT * FROM "posts" WHERE ("posts"."tags" <@ $1)"#);
        assert_single_array_param(&sql, &expected);
    }

    #[test]
    fn test_array_overlaps_operator() {
        let expected = vec![
            Value::String("rust".to_string()),
            Value::String("python".to_string()),
        ];
        let sql = render_array_filter("posts__tags", BinaryOp::ArrayOverlaps, expected.clone());

        assert_eq!(sql.text, r#"SELECT * FROM "posts" WHERE ("posts"."tags" && $1)"#);
        assert_single_array_param(&sql, &expected);
    }

    #[test]
    fn test_array_operators_with_integers() {
        let expected = vec![Value::I32(100), Value::I32(200)];
        let sql = render_array_filter("posts__scores", BinaryOp::ArrayContains, expected.clone());

        assert_eq!(sql.text, r#"SELECT * FROM "posts" WHERE ("posts"."scores" @> $1)"#);
        assert_single_array_param(&sql, &expected);
    }
}