1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Expr, Insert, Result, Select, Update, Value};
5
/// SQL dialect for SQLite.
///
/// Identifiers are wrapped in double quotes and bound values are emitted as
/// anonymous positional `?` placeholders.
#[derive(Clone, Copy, Debug)]
pub struct SqliteDialect;
9
10impl Dialect for SqliteDialect {
13 fn render_select(&self, select: &Select) -> Result<Sql> {
14 let mut ctx = RenderContext::new();
15 render_select_body_core!(
17 &mut ctx,
18 select,
19 quote_identifier,
20 render_expr,
21 false,
22 false
23 );
24 Ok(Sql {
25 text: ctx.sql,
26 params: ctx.params,
27 })
28 }
29
30 fn render_insert(&self, insert: &Insert) -> Result<Sql> {
31 let mut ctx = RenderContext::new();
32 render_insert_body!(&mut ctx, insert, quote_identifier, true, false);
33 Ok(Sql {
34 text: ctx.sql,
35 params: ctx.params,
36 })
37 }
38
39 fn render_update(&self, update: &Update) -> Result<Sql> {
40 let mut ctx = RenderContext::new();
41 render_update_body!(&mut ctx, update, quote_identifier, render_expr, true, false);
42 Ok(Sql {
43 text: ctx.sql,
44 params: ctx.params,
45 })
46 }
47
48 fn render_delete(&self, delete: &Delete) -> Result<Sql> {
49 let mut ctx = RenderContext::new();
50 render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, true);
51 Ok(Sql {
52 text: ctx.sql,
53 params: ctx.params,
54 })
55 }
56}
57
58fn quote_identifier(name: &str) -> String {
60 crate::double_quote_identifier(name)
61}
62
63struct RenderContext {
65 sql: String,
66 params: Vec<Value>,
67}
68
69impl RenderContext {
70 fn new() -> Self {
71 Self {
72 sql: String::new(),
73 params: Vec::new(),
74 }
75 }
76
77 fn push_param(&mut self, value: Value) -> String {
79 self.params.push(value);
80 "?".to_string()
81 }
82}
83
84fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
88 render_select_body_core!(ctx, select, quote_identifier, render_expr, false, false);
90}
91
92fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
94 render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
95 Expr::Param(value) => {
96 if matches!(value, Value::Null) {
98 ctx.sql.push_str("NULL");
99 } else {
100 let placeholder = ctx.push_param(value.clone());
101 ctx.sql.push_str(&placeholder);
102 }
103 }
104 Expr::Binary { left, op, right } => {
105 if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
106 ctx.sql.push('(');
107 render_expr(ctx, left);
108 ctx.sql.push(' ');
109 ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
110 ctx.sql.push_str(" (");
111 if let Expr::List(exprs) = right.as_ref() {
112 for (i, e) in exprs.iter().enumerate() {
113 if i > 0 { ctx.sql.push_str(", "); }
114 render_expr(ctx, e);
115 }
116 } else {
117 render_expr(ctx, right);
118 }
119 ctx.sql.push(')');
120 ctx.sql.push(')');
121 } else if matches!(op, BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps) {
122 match op {
125 BinaryOp::ArrayContains => {
126 ctx.sql.push_str("NOT EXISTS (SELECT 1 FROM json_each(");
128 render_expr(ctx, right);
129 ctx.sql.push_str(") AS _rhs WHERE _rhs.value NOT IN (SELECT _col.value FROM json_each(");
130 render_expr(ctx, left);
131 ctx.sql.push_str(") AS _col)))");
132 }
133 BinaryOp::ArrayContainedBy => {
134 ctx.sql.push_str("NOT EXISTS (SELECT 1 FROM json_each(");
136 render_expr(ctx, left);
137 ctx.sql.push_str(") AS _col WHERE _col.value NOT IN (SELECT _rhs.value FROM json_each(");
138 render_expr(ctx, right);
139 ctx.sql.push_str(") AS _rhs)))");
140 }
141 BinaryOp::ArrayOverlaps => {
142 ctx.sql.push_str("EXISTS (SELECT 1 FROM json_each(");
144 render_expr(ctx, left);
145 ctx.sql.push_str(") AS _col WHERE _col.value IN (SELECT _rhs.value FROM json_each(");
146 render_expr(ctx, right);
147 ctx.sql.push_str(") AS _rhs)))");
148 }
149 _ => unreachable!(),
150 }
151 } else {
152 ctx.sql.push('(');
153 render_expr(ctx, left);
154 ctx.sql.push(' ');
155 ctx.sql.push_str(crate::binary_op_sql(op));
156 ctx.sql.push(' ');
157 render_expr(ctx, right);
158 ctx.sql.push(')');
159 }
160 }
161 Expr::FunctionCall { name, args } => {
162 let sqlite_name = match name.as_str() {
164 "json_agg" => "json_group_array",
165 "json_build_object" => "json_object",
166 _ => name,
167 };
168 ctx.sql.push_str(sqlite_name);
169 ctx.sql.push('(');
170 for (i, arg) in args.iter().enumerate() {
171 if i > 0 { ctx.sql.push_str(", "); }
172 render_expr(ctx, arg);
173 }
174 ctx.sql.push(')');
175 }
176 Expr::Filter { expr, predicate } => {
177 render_expr(ctx, expr);
179 ctx.sql.push_str(" FILTER (WHERE ");
180 render_expr(ctx, predicate);
181 ctx.sql.push(')');
182 }
183 });
184}
185
186#[cfg(test)]
187mod tests {
188 use super::*;
189
190 #[test]
191 fn test_quote_identifier() {
192 assert_eq!(quote_identifier("users"), "\"users\"");
193 assert_eq!(quote_identifier("email"), "\"email\"");
194 assert_eq!(quote_identifier("foo\"bar"), "\"foo\"\"bar\"");
195 assert_eq!(quote_identifier("a\"b\"c"), "\"a\"\"b\"\"c\"");
196 }
197
198 #[test]
201 fn test_array_contains_operator() {
202 let dialect = SqliteDialect;
203 let expr = Expr::Binary {
204 left: Box::new(Expr::column("posts__tags")),
205 op: BinaryOp::ArrayContains,
206 right: Box::new(Expr::param(Value::Array(vec![Value::String(
207 "rust".to_string(),
208 )]))),
209 };
210 let select = Select::from_table("posts").filter(expr).build().unwrap();
211 let sql = dialect.render_select(&select).unwrap();
212
213 assert_eq!(
214 sql.text,
215 "SELECT * FROM \"posts\" WHERE NOT EXISTS (SELECT 1 FROM json_each(?) AS _rhs WHERE _rhs.value NOT IN (SELECT _col.value FROM json_each(\"posts\".\"tags\") AS _col)))"
216 );
217 assert_eq!(sql.params.len(), 1);
218 match &sql.params[0] {
219 Value::Array(arr) => {
220 assert_eq!(arr.len(), 1);
221 assert_eq!(arr[0], Value::String("rust".to_string()));
222 }
223 _ => panic!("Expected Array value"),
224 }
225 }
226
227 #[test]
228 fn test_array_contained_by_operator() {
229 let dialect = SqliteDialect;
230 let expr = Expr::Binary {
231 left: Box::new(Expr::column("posts__tags")),
232 op: BinaryOp::ArrayContainedBy,
233 right: Box::new(Expr::param(Value::Array(vec![
234 Value::String("rust".to_string()),
235 Value::String("go".to_string()),
236 ]))),
237 };
238 let select = Select::from_table("posts").filter(expr).build().unwrap();
239 let sql = dialect.render_select(&select).unwrap();
240
241 assert_eq!(
242 sql.text,
243 "SELECT * FROM \"posts\" WHERE NOT EXISTS (SELECT 1 FROM json_each(\"posts\".\"tags\") AS _col WHERE _col.value NOT IN (SELECT _rhs.value FROM json_each(?) AS _rhs)))"
244 );
245 assert_eq!(sql.params.len(), 1);
246 match &sql.params[0] {
247 Value::Array(arr) => {
248 assert_eq!(arr.len(), 2);
249 assert_eq!(arr[0], Value::String("rust".to_string()));
250 assert_eq!(arr[1], Value::String("go".to_string()));
251 }
252 _ => panic!("Expected Array value"),
253 }
254 }
255
256 #[test]
257 fn test_array_overlaps_operator() {
258 let dialect = SqliteDialect;
259 let expr = Expr::Binary {
260 left: Box::new(Expr::column("posts__tags")),
261 op: BinaryOp::ArrayOverlaps,
262 right: Box::new(Expr::param(Value::Array(vec![
263 Value::String("rust".to_string()),
264 Value::String("python".to_string()),
265 ]))),
266 };
267 let select = Select::from_table("posts").filter(expr).build().unwrap();
268 let sql = dialect.render_select(&select).unwrap();
269
270 assert_eq!(
271 sql.text,
272 "SELECT * FROM \"posts\" WHERE EXISTS (SELECT 1 FROM json_each(\"posts\".\"tags\") AS _col WHERE _col.value IN (SELECT _rhs.value FROM json_each(?) AS _rhs)))"
273 );
274 assert_eq!(sql.params.len(), 1);
275 match &sql.params[0] {
276 Value::Array(arr) => {
277 assert_eq!(arr.len(), 2);
278 assert_eq!(arr[0], Value::String("rust".to_string()));
279 assert_eq!(arr[1], Value::String("python".to_string()));
280 }
281 _ => panic!("Expected Array value"),
282 }
283 }
284}