1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// The PostgreSQL SQL dialect.
///
/// Identifiers are wrapped in double quotes (embedded quotes are doubled) and
/// bound values are referenced with 1-based positional placeholders
/// (`$1`, `$2`, ...).
#[derive(Clone, Copy, Debug)]
pub struct PostgresDialect;
13
14impl Dialect for PostgresDialect {
15 fn render_select(&self, select: &Select) -> Result<Sql> {
16 let mut ctx = RenderContext::new();
17 render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, true, false);
19 Ok(Sql {
20 text: ctx.sql,
21 params: ctx.params,
22 })
23 }
24
25 fn render_insert(&self, insert: &Insert) -> Result<Sql> {
26 let mut ctx = RenderContext::new();
27 render_insert_body!(&mut ctx, insert, quote_identifier, true, true);
28 Ok(Sql {
29 text: ctx.sql,
30 params: ctx.params,
31 })
32 }
33
34 fn render_update(&self, update: &Update) -> Result<Sql> {
35 let mut ctx = RenderContext::new();
36 render_update_body!(&mut ctx, update, quote_identifier, render_expr, true, true);
37 Ok(Sql {
38 text: ctx.sql,
39 params: ctx.params,
40 })
41 }
42
43 fn render_delete(&self, delete: &Delete) -> Result<Sql> {
44 let mut ctx = RenderContext::new();
45 render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, true);
46 Ok(Sql {
47 text: ctx.sql,
48 params: ctx.params,
49 })
50 }
51}
52
53fn quote_identifier(name: &str) -> String {
55 crate::double_quote_identifier(name)
56}
57
58struct RenderContext {
60 sql: String,
61 params: Vec<Value>,
62}
63
64impl RenderContext {
65 fn new() -> Self {
66 Self {
67 sql: String::new(),
68 params: Vec::new(),
69 }
70 }
71
72 fn push_param(&mut self, value: Value) -> String {
74 self.params.push(value);
75 format!("${}", self.params.len())
76 }
77}
78
79fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
83 render_select_body_core!(ctx, select, quote_identifier, render_expr, true, false);
84}
85
86fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
88 render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
89 Expr::Param(value) => {
90 if matches!(value, Value::Null) {
93 ctx.sql.push_str("NULL");
94 } else {
95 let placeholder = ctx.push_param(value.clone());
96 ctx.sql.push_str(&placeholder);
97 if matches!(value, Value::Uuid(_)) {
99 ctx.sql.push_str("::uuid");
100 } else if matches!(value, Value::Json(_)) {
101 ctx.sql.push_str("::json");
102 } else if let Value::Enum { type_name, .. } = value {
103 ctx.sql.push_str("::");
104 ctx.sql.push_str(type_name);
105 }
106 }
107 }
108 Expr::Binary { left, op, right } => {
109 if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
110 ctx.sql.push('(');
111 render_expr(ctx, left);
112 ctx.sql.push(' ');
113 ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
114 ctx.sql.push_str(" (");
115 if let Expr::List(exprs) = right.as_ref() {
116 for (i, e) in exprs.iter().enumerate() {
117 if i > 0 { ctx.sql.push_str(", "); }
118 render_expr(ctx, e);
119 }
120 } else {
121 render_expr(ctx, right);
122 }
123 ctx.sql.push(')');
124 ctx.sql.push(')');
125 } else {
126 ctx.sql.push('(');
127 render_expr(ctx, left);
128 ctx.sql.push(' ');
129 ctx.sql.push_str(match op {
130 BinaryOp::ArrayContains => "@>",
131 BinaryOp::ArrayContainedBy => "<@",
132 BinaryOp::ArrayOverlaps => "&&",
133 _ => crate::binary_op_sql(op),
134 });
135 ctx.sql.push(' ');
136 render_expr(ctx, right);
137 ctx.sql.push(')');
138 }
139 }
140 Expr::FunctionCall { name, args } => {
141 ctx.sql.push_str(name);
142 ctx.sql.push('(');
143 for (i, arg) in args.iter().enumerate() {
144 if i > 0 { ctx.sql.push_str(", "); }
145 render_expr(ctx, arg);
146 }
147 ctx.sql.push(')');
148 }
149 Expr::Filter { expr, predicate } => {
150 render_expr(ctx, expr);
152 ctx.sql.push_str(" FILTER (WHERE ");
153 render_expr(ctx, predicate);
154 ctx.sql.push(')');
155 }
156 });
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `SELECT * FROM "posts" WHERE <column> <op> <array>` with the
    /// array bound as a single parameter, then renders it with `PostgresDialect`.
    /// Asserts on both the generated SQL text and the bound array contents.
    fn check_array_filter(
        column: &'static str,
        op: BinaryOp,
        elements: Vec<Value>,
        expected_sql: &str,
    ) {
        let filter = Expr::Binary {
            left: Box::new(Expr::column(column)),
            op,
            right: Box::new(Expr::param(Value::Array(elements.clone()))),
        };
        let select = Select::from_table("posts").filter(filter).build().unwrap();
        let rendered = PostgresDialect.render_select(&select).unwrap();

        assert_eq!(rendered.text, expected_sql);
        assert_eq!(rendered.params.len(), 1);
        match &rendered.params[0] {
            Value::Array(bound) => assert_eq!(*bound, elements),
            _ => panic!("Expected Array value"),
        }
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "\"users\""),
            ("email", "\"email\""),
            ("foo\"bar", "\"foo\"\"bar\""),
            ("a\"b\"c", "\"a\"\"b\"\"c\""),
        ];
        for (raw, quoted) in cases {
            assert_eq!(quote_identifier(raw), quoted);
        }
    }

    #[test]
    fn test_array_contains_operator() {
        check_array_filter(
            "posts__tags",
            BinaryOp::ArrayContains,
            vec![Value::String("rust".to_string())],
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" @> $1)",
        );
    }

    #[test]
    fn test_array_contained_by_operator() {
        check_array_filter(
            "posts__tags",
            BinaryOp::ArrayContainedBy,
            vec![
                Value::String("rust".to_string()),
                Value::String("go".to_string()),
            ],
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" <@ $1)",
        );
    }

    #[test]
    fn test_array_overlaps_operator() {
        check_array_filter(
            "posts__tags",
            BinaryOp::ArrayOverlaps,
            vec![
                Value::String("rust".to_string()),
                Value::String("python".to_string()),
            ],
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"tags\" && $1)",
        );
    }

    #[test]
    fn test_array_operators_with_integers() {
        check_array_filter(
            "posts__scores",
            BinaryOp::ArrayContains,
            vec![Value::I32(100), Value::I32(200)],
            "SELECT * FROM \"posts\" WHERE (\"posts\".\"scores\" @> $1)",
        );
    }
}