1use anyhow::{bail, Result};
2use rest_model::Condition;
3use serde_json::Value;
4use tokio_postgres::types::ToSql;
5
6pub fn cond_to_sql(
7 cond: &Condition,
8 bindings: &mut Vec<Box<dyn ToSql + Sync>>,
9 seq: &mut u32,
10) -> Result<String> {
11 match cond {
12 Condition::And(conds) => Ok(format!(
13 "({})",
14 conds
15 .iter()
16 .map(|c| cond_to_sql(c, bindings, seq))
17 .collect::<Result<Vec<_>>>()?
18 .join(" AND "),
19 )),
20 Condition::Or(conds) => Ok(format!(
21 "({})",
22 conds
23 .iter()
24 .map(|c| cond_to_sql(c, bindings, seq))
25 .collect::<Result<Vec<_>>>()?
26 .join(" OR "),
27 )),
28 Condition::Not(cond) => Ok(format!("(NOT ({}))", cond_to_sql(cond, bindings, seq)?)),
29 Condition::Regex(field, value) => {
30 check_invalid_chars(field)?;
31 let s = *seq;
32 *seq += 1;
33 match value {
34 Value::String(v) => {
35 bindings.push(Box::new(v.clone()));
36 Ok(format!("{} ~ ${}", field_to_key_t(&field), s))
37 }
38 _ => {
39 bail!("Invalid value for Regex")
40 }
41 }
42 }
43 Condition::Regexi(field, value) => {
44 check_invalid_chars(field)?;
45 let s = *seq;
46 *seq += 1;
47 match value {
48 Value::String(v) => {
49 bindings.push(Box::new(v.clone()));
50 Ok(format!("{} ~* ${}", field_to_key_t(&field), s))
51 }
52 _ => {
53 bail!("Invalid value for Regexi")
54 }
55 }
56 }
57 Condition::Eq(field, value) => {
58 check_invalid_chars(field)?;
59 normal_comparison(seq, bindings, field, "=", value)
60 }
61 Condition::Ne(field, value) => {
62 check_invalid_chars(field)?;
63 normal_comparison(seq, bindings, field, "!=", value)
64 }
65 Condition::Gt(field, value) => {
66 check_invalid_chars(field)?;
67 normal_comparison(seq, bindings, field, ">", value)
68 }
69 Condition::Gte(field, value) => {
70 check_invalid_chars(field)?;
71 normal_comparison(seq, bindings, field, ">=", value)
72 }
73 Condition::Lt(field, value) => {
74 check_invalid_chars(field)?;
75 normal_comparison(seq, bindings, field, "<", value)
76 }
77 Condition::Lte(field, value) => {
78 check_invalid_chars(field)?;
79 normal_comparison(seq, bindings, field, "<=", value)
80 }
81 Condition::In(field, value) => {
82 check_invalid_chars(field)?;
83 array_comparison(seq, bindings, field, value)
84 }
85 Condition::Nin(field, value) => {
86 check_invalid_chars(field)?;
87 Ok(format!(
88 "(NOT ({}))",
89 array_comparison(seq, bindings, field, value)?,
90 ))
91 }
92 }
93}
94
95pub fn sort_to_sql(sort_expr: &str) -> Result<String> {
96 check_invalid_chars(sort_expr)?;
97 let order_by_clauses: Vec<String> = sort_expr
98 .split(|c| c == '+' || c == '-') .filter(|s| !s.is_empty()) .map(|key| {
101 let order = if sort_expr.contains(&format!("+{}", key)) {
102 "ASC"
103 } else {
104 "DESC"
105 };
106 format!("{} {}", field_to_key_t(key), order) })
108 .collect();
109
110 if order_by_clauses.is_empty() {
111 Ok("_id ASC".to_string())
112 } else {
113 Ok(format!("{}", order_by_clauses.join(", ")))
114 }
115}
116
117fn field_to_key(field: &str) -> String {
118 _field_to_key(field, false)
119}
120
121fn field_to_key_t(field: &str) -> String {
122 _field_to_key(field, true)
123}
124
/// Converts a dotted field name (`a.b.c`) into a Postgres path lookup on the
/// `data` column.
///
/// Returns `data#>'{a,b,c}'` when `text` is false, or `data#>>'{a,b,c}'` when
/// `text` is true.
///
/// The field is interpolated without escaping. The callers in this file run
/// `check_invalid_chars` on user-supplied names first.
fn _field_to_key(field: &str, text: bool) -> String {
    // A Postgres path array is comma-separated: `a.b` -> `{a,b}`.
    let path = field.replace('.', ",");
    let op = if text { "#>>" } else { "#>" };
    format!("data{}'{{{}}}'", op, path)
}
134
135fn normal_comparison(
136 seq: &mut u32,
137 bindings: &mut Vec<Box<dyn ToSql + Sync>>,
138 field: &str,
139 op: &str,
140 value: &Value,
141) -> Result<String> {
142 let s = *seq;
143 *seq += 1;
144 match value {
145 Value::String(v) => match field {
146 "_id" => {
147 bindings.push(Box::new(v.clone()));
148 Ok(format!("{} {} ${}", field, op, s))
149 }
150 _ => {
151 bindings.push(Box::new(value.clone()));
152 Ok(format!("{} {} ${}", field_to_key(field), op, s))
153 }
154 },
155 Value::Number(v) => match field {
156 "_created_at" | "_updated_at" => {
157 bindings.push(Box::new(v.as_i64().unwrap()));
158 Ok(format!("{} {} ${}", field, op, s))
159 }
160 _ => {
161 bindings.push(Box::new(value.clone()));
162 Ok(format!("{} {} ${}", field_to_key(field), op, s))
163 }
164 },
165 _ => {
166 bail!("Invalid value for op {}", op)
167 }
168 }
169}
170
171fn array_comparison(
172 seq: &mut u32,
173 bindings: &mut Vec<Box<dyn ToSql + Sync>>,
174 field: &str,
175 value: &Value,
176) -> Result<String> {
177 match value {
178 Value::Array(arr) => {
179 if arr.is_empty() {
180 return Ok("FALSE".to_string());
181 } else {
182 let s = *seq;
183 *seq += 1;
184 match field {
185 "_id" => {
186 let arr = arr.iter().map(|v| v.as_str()).collect::<Vec<_>>();
187 if arr.contains(&None) {
188 bail!("Invalid value for field {}", field)
189 } else {
190 bindings.push(Box::new(
191 arr.into_iter()
192 .map(|v| v.unwrap().to_string())
193 .collect::<Vec<_>>(),
194 ));
195 Ok(format!("{} = ANY(${})", field, s))
196 }
197 }
198 "_created_at" | "_updated_at" => {
199 let arr = arr.iter().map(|v| v.as_i64()).collect::<Vec<_>>();
200 if arr.contains(&None) {
201 bail!("Invalid value for field {}", field)
202 } else {
203 bindings.push(Box::new(
204 arr.into_iter().map(|v| v.unwrap()).collect::<Vec<_>>(),
205 ));
206 Ok(format!("{} = ANY(${})", field, s))
207 }
208 }
209 _ => {
210 let array_type = match arr.first() {
211 Some(Value::Number(_)) => "FLOAT8",
212 Some(Value::String(_)) => "TEXT",
213 _ => bail!("Invalid value for op IN"),
214 };
215
216 if array_type == "FLOAT8" {
217 let arr = arr.iter().map(|v| v.as_f64()).collect::<Vec<_>>();
218 if arr.contains(&None) {
219 bail!("Invalid value for field {}", field)
220 } else {
221 bindings.push(Box::new(
222 arr.into_iter().map(|v| v.unwrap()).collect::<Vec<_>>(),
223 ));
224 }
225 Ok(format!(
226 "({})::{} = ANY(${})",
227 field_to_key(field),
228 array_type,
229 s
230 ))
231 } else {
232 let arr = arr.iter().map(|v| v.as_str()).collect::<Vec<_>>();
233 if arr.contains(&None) {
234 bail!("Invalid value for field {}", field)
235 } else {
236 bindings.push(Box::new(
237 arr.into_iter()
238 .map(|v| v.unwrap().to_string())
239 .collect::<Vec<_>>(),
240 ));
241 }
242 Ok(format!("{} = ANY(${})", field_to_key_t(field), s))
243 }
244 }
245 }
246 }
247 }
248 _ => bail!("Invalid value for op IN"),
249 }
250}
251
252fn check_invalid_chars(input: &str) -> Result<()> {
253 if input.contains('\'') || input.contains('(') || input.contains(')') || input.contains(';') {
254 bail!("Invalid characters in input");
255 }
256 Ok(())
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use rest_model::Condition;
    use serde_json::json;
    use tokio_postgres::types::ToSql;

    /// A placeholder counter starting at `$1` plus an empty binding list.
    fn fresh() -> (u32, Vec<Box<dyn ToSql + Sync>>) {
        (1, Vec::new())
    }

    #[test]
    fn test_sort_to_sql() {
        let cases = [
            ("+name", "data#>>'{name}' ASC"),
            ("-age", "data#>>'{age}' DESC"),
            ("+name-age", "data#>>'{name}' ASC, data#>>'{age}' DESC"),
            ("", "_id ASC"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(sort_to_sql(input).unwrap(), *expected, "sort expr {:?}", input);
        }
    }

    #[test]
    fn test_normal_comparison() {
        let (mut seq, mut bindings) = fresh();

        let id_sql = normal_comparison(&mut seq, &mut bindings, "_id", "=", &json!("123")).unwrap();
        assert_eq!(id_sql, "_id = $1");
        assert_eq!(bindings.len(), 1);

        let ts_sql =
            normal_comparison(&mut seq, &mut bindings, "_created_at", ">", &json!(123456)).unwrap();
        assert_eq!(ts_sql, "_created_at > $2");
        assert_eq!(bindings.len(), 2);

        let data_sql = normal_comparison(&mut seq, &mut bindings, "age", "<", &json!(30)).unwrap();
        assert_eq!(data_sql, "data#>'{age}' < $3");
        assert_eq!(bindings.len(), 3);
    }

    #[test]
    fn test_array_comparison() {
        let (mut seq, mut bindings) = fresh();

        let ids_sql =
            array_comparison(&mut seq, &mut bindings, "_id", &json!(["a1", "b2", "c3"])).unwrap();
        assert_eq!(ids_sql, "_id = ANY($1)");
        assert_eq!(bindings.len(), 1);

        let ages_sql = array_comparison(&mut seq, &mut bindings, "age", &json!([20, 25, 30])).unwrap();
        assert!(ages_sql.contains("ANY($2)"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_cond_to_sql() {
        let (mut seq, mut bindings) = fresh();

        let eq = Condition::Eq("_id".to_string(), json!("123"));
        assert_eq!(cond_to_sql(&eq, &mut bindings, &mut seq).unwrap(), "_id = $1");
        assert_eq!(bindings.len(), 1);

        let within = Condition::In("age".to_string(), json!([25, 30, 35]));
        assert!(cond_to_sql(&within, &mut bindings, &mut seq).unwrap().contains("ANY($2)"));
        assert_eq!(bindings.len(), 2);

        let range = Condition::And(vec![
            Box::new(Condition::Gt("score".to_string(), json!(80))),
            Box::new(Condition::Lt("score".to_string(), json!(100))),
        ]);
        let range_sql = cond_to_sql(&range, &mut bindings, &mut seq).unwrap();
        assert_eq!(&range_sql, "(data#>'{score}' > $3 AND data#>'{score}' < $4)");
        assert_eq!(bindings.len(), 4);
    }

    #[test]
    fn test_regex_conditions() {
        let (mut seq, mut bindings) = fresh();

        let sensitive = Condition::Regex("name".to_string(), json!("^J.*"));
        assert_eq!(
            cond_to_sql(&sensitive, &mut bindings, &mut seq).unwrap(),
            "data#>>'{name}' ~ $1"
        );
        assert_eq!(bindings.len(), 1);

        let insensitive = Condition::Regexi("name".to_string(), json!("^J.*"));
        assert_eq!(
            cond_to_sql(&insensitive, &mut bindings, &mut seq).unwrap(),
            "data#>>'{name}' ~* $2"
        );
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_invalid_chars() {
        for ok in ["valid_field"].iter() {
            assert!(check_invalid_chars(ok).is_ok());
        }
        for bad in ["invalid'field", "invalid(field)"].iter() {
            assert!(check_invalid_chars(bad).is_err());
        }
    }
}