rest_model_postgres/
query.rs

use anyhow::{anyhow, bail, Result};
use rest_model::Condition;
use serde_json::Value;
use tokio_postgres::types::ToSql;

6pub fn cond_to_sql(
7    cond: &Condition,
8    bindings: &mut Vec<Box<dyn ToSql + Sync>>,
9    seq: &mut u32,
10) -> Result<String> {
11    match cond {
12        Condition::And(conds) => Ok(format!(
13            "({})",
14            conds
15                .iter()
16                .map(|c| cond_to_sql(c, bindings, seq))
17                .collect::<Result<Vec<_>>>()?
18                .join(" AND "),
19        )),
20        Condition::Or(conds) => Ok(format!(
21            "({})",
22            conds
23                .iter()
24                .map(|c| cond_to_sql(c, bindings, seq))
25                .collect::<Result<Vec<_>>>()?
26                .join(" OR "),
27        )),
28        Condition::Not(cond) => Ok(format!("(NOT ({}))", cond_to_sql(cond, bindings, seq)?)),
29        Condition::Regex(field, value) => {
30            check_invalid_chars(field)?;
31            let s = *seq;
32            *seq += 1;
33            match value {
34                Value::String(v) => {
35                    bindings.push(Box::new(v.clone()));
36                    Ok(format!("{} ~ ${}", field_to_key_t(&field), s))
37                }
38                _ => {
39                    bail!("Invalid value for Regex")
40                }
41            }
42        }
43        Condition::Regexi(field, value) => {
44            check_invalid_chars(field)?;
45            let s = *seq;
46            *seq += 1;
47            match value {
48                Value::String(v) => {
49                    bindings.push(Box::new(v.clone()));
50                    Ok(format!("{} ~* ${}", field_to_key_t(&field), s))
51                }
52                _ => {
53                    bail!("Invalid value for Regexi")
54                }
55            }
56        }
57        Condition::Eq(field, value) => {
58            check_invalid_chars(field)?;
59            normal_comparison(seq, bindings, field, "=", value)
60        }
61        Condition::Ne(field, value) => {
62            check_invalid_chars(field)?;
63            normal_comparison(seq, bindings, field, "!=", value)
64        }
65        Condition::Gt(field, value) => {
66            check_invalid_chars(field)?;
67            normal_comparison(seq, bindings, field, ">", value)
68        }
69        Condition::Gte(field, value) => {
70            check_invalid_chars(field)?;
71            normal_comparison(seq, bindings, field, ">=", value)
72        }
73        Condition::Lt(field, value) => {
74            check_invalid_chars(field)?;
75            normal_comparison(seq, bindings, field, "<", value)
76        }
77        Condition::Lte(field, value) => {
78            check_invalid_chars(field)?;
79            normal_comparison(seq, bindings, field, "<=", value)
80        }
81        Condition::In(field, value) => {
82            check_invalid_chars(field)?;
83            array_comparison(seq, bindings, field, value)
84        }
85        Condition::Nin(field, value) => {
86            check_invalid_chars(field)?;
87            Ok(format!(
88                "(NOT ({}))",
89                array_comparison(seq, bindings, field, value)?,
90            ))
91        }
92    }
93}
94
95pub fn sort_to_sql(sort_expr: &str) -> Result<String> {
96    check_invalid_chars(sort_expr)?;
97    let order_by_clauses: Vec<String> = sort_expr
98        .split(|c| c == '+' || c == '-') // 按 `+` 和 `-` 拆分
99        .filter(|s| !s.is_empty()) // 去除空字符串
100        .map(|key| {
101            let order = if sort_expr.contains(&format!("+{}", key)) {
102                "ASC"
103            } else {
104                "DESC"
105            };
106            format!("{} {}", field_to_key_t(key), order) // 组合成 `key ASC/DESC`
107        })
108        .collect();
109
110    if order_by_clauses.is_empty() {
111        Ok("_id ASC".to_string())
112    } else {
113        Ok(format!("{}", order_by_clauses.join(", ")))
114    }
115}
116
117fn field_to_key(field: &str) -> String {
118    _field_to_key(field, false)
119}
120
121fn field_to_key_t(field: &str) -> String {
122    _field_to_key(field, true)
123}
124
/// Shared implementation behind `field_to_key` and `field_to_key_t`.
///
/// Turns a dotted path such as `a.b.c` into `data#>'{a,b,c}'`, or `data#>>'{a,b,c}'` when
/// `text` is true. The field is not escaped here. In this file, `cond_to_sql` and
/// `sort_to_sql` screen field names with `check_invalid_chars` before they reach this point.
fn _field_to_key(field: &str, text: bool) -> String {
    let operator = if text { "#>>" } else { "#>" };
    let path = field.replace('.', ",");
    format!("data{}'{{{}}}'", operator, path)
}

135fn normal_comparison(
136    seq: &mut u32,
137    bindings: &mut Vec<Box<dyn ToSql + Sync>>,
138    field: &str,
139    op: &str,
140    value: &Value,
141) -> Result<String> {
142    let s = *seq;
143    *seq += 1;
144    match value {
145        Value::String(v) => match field {
146            "_id" => {
147                bindings.push(Box::new(v.clone()));
148                Ok(format!("{} {} ${}", field, op, s))
149            }
150            _ => {
151                bindings.push(Box::new(value.clone()));
152                Ok(format!("{} {} ${}", field_to_key(field), op, s))
153            }
154        },
155        Value::Number(v) => match field {
156            "_created_at" | "_updated_at" => {
157                bindings.push(Box::new(v.as_i64().unwrap()));
158                Ok(format!("{} {} ${}", field, op, s))
159            }
160            _ => {
161                bindings.push(Box::new(value.clone()));
162                Ok(format!("{} {} ${}", field_to_key(field), op, s))
163            }
164        },
165        _ => {
166            bail!("Invalid value for op {}", op)
167        }
168    }
169}
170
171fn array_comparison(
172    seq: &mut u32,
173    bindings: &mut Vec<Box<dyn ToSql + Sync>>,
174    field: &str,
175    value: &Value,
176) -> Result<String> {
177    match value {
178        Value::Array(arr) => {
179            if arr.is_empty() {
180                return Ok("FALSE".to_string());
181            } else {
182                let s = *seq;
183                *seq += 1;
184                match field {
185                    "_id" => {
186                        let arr = arr.iter().map(|v| v.as_str()).collect::<Vec<_>>();
187                        if arr.contains(&None) {
188                            bail!("Invalid value for field {}", field)
189                        } else {
190                            bindings.push(Box::new(
191                                arr.into_iter()
192                                    .map(|v| v.unwrap().to_string())
193                                    .collect::<Vec<_>>(),
194                            ));
195                            Ok(format!("{} = ANY(${})", field, s))
196                        }
197                    }
198                    "_created_at" | "_updated_at" => {
199                        let arr = arr.iter().map(|v| v.as_i64()).collect::<Vec<_>>();
200                        if arr.contains(&None) {
201                            bail!("Invalid value for field {}", field)
202                        } else {
203                            bindings.push(Box::new(
204                                arr.into_iter().map(|v| v.unwrap()).collect::<Vec<_>>(),
205                            ));
206                            Ok(format!("{} = ANY(${})", field, s))
207                        }
208                    }
209                    _ => {
210                        let array_type = match arr.first() {
211                            Some(Value::Number(_)) => "FLOAT8",
212                            Some(Value::String(_)) => "TEXT",
213                            _ => bail!("Invalid value for op IN"),
214                        };
215
216                        if array_type == "FLOAT8" {
217                            let arr = arr.iter().map(|v| v.as_f64()).collect::<Vec<_>>();
218                            if arr.contains(&None) {
219                                bail!("Invalid value for field {}", field)
220                            } else {
221                                bindings.push(Box::new(
222                                    arr.into_iter().map(|v| v.unwrap()).collect::<Vec<_>>(),
223                                ));
224                            }
225                            Ok(format!(
226                                "({})::{} = ANY(${})",
227                                field_to_key(field),
228                                array_type,
229                                s
230                            ))
231                        } else {
232                            let arr = arr.iter().map(|v| v.as_str()).collect::<Vec<_>>();
233                            if arr.contains(&None) {
234                                bail!("Invalid value for field {}", field)
235                            } else {
236                                bindings.push(Box::new(
237                                    arr.into_iter()
238                                        .map(|v| v.unwrap().to_string())
239                                        .collect::<Vec<_>>(),
240                                ));
241                            }
242                            Ok(format!("{} = ANY(${})", field_to_key_t(field), s))
243                        }
244                    }
245                }
246            }
247        }
248        _ => bail!("Invalid value for op IN"),
249    }
250}
251
252fn check_invalid_chars(input: &str) -> Result<()> {
253    if input.contains('\'') || input.contains('(') || input.contains(')') || input.contains(';') {
254        bail!("Invalid characters in input");
255    }
256    Ok(())
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use rest_model::Condition;
    use serde_json::json;
    use tokio_postgres::types::ToSql;

    /// A placeholder counter starting at `$1` and an empty parameter list.
    fn fresh_state() -> (u32, Vec<Box<dyn ToSql + Sync>>) {
        (1, Vec::new())
    }

    #[test]
    fn test_sort_to_sql() {
        let cases = [
            ("+name", "data#>>'{name}' ASC"),
            ("-age", "data#>>'{age}' DESC"),
            ("+name-age", "data#>>'{name}' ASC, data#>>'{age}' DESC"),
            ("", "_id ASC"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(sort_to_sql(input).unwrap(), expected, "sort expression {:?}", input);
        }
    }

    #[test]
    fn test_normal_comparison() {
        let (mut seq, mut bindings) = fresh_state();

        // Each call consumes the next placeholder and adds exactly one binding.
        let cases = [
            ("_id", "=", json!("123"), "_id = $1"),
            ("_created_at", ">", json!(123456), "_created_at > $2"),
            ("age", "<", json!(30), "data#>'{age}' < $3"),
        ];
        for (idx, (field, op, value, expected)) in cases.iter().enumerate() {
            let rendered = normal_comparison(&mut seq, &mut bindings, field, op, value).unwrap();
            assert_eq!(rendered, *expected);
            assert_eq!(bindings.len(), idx + 1);
        }
    }

    #[test]
    fn test_array_comparison() {
        let (mut seq, mut bindings) = fresh_state();

        let by_id =
            array_comparison(&mut seq, &mut bindings, "_id", &json!(["a1", "b2", "c3"])).unwrap();
        assert_eq!(by_id, "_id = ANY($1)");
        assert_eq!(bindings.len(), 1);

        let by_age =
            array_comparison(&mut seq, &mut bindings, "age", &json!([20, 25, 30])).unwrap();
        assert!(by_age.contains("ANY($2)"));
        assert_eq!(bindings.len(), 2);
    }

    #[test]
    fn test_cond_to_sql() {
        let (mut seq, mut bindings) = fresh_state();

        let eq_id = Condition::Eq("_id".to_string(), json!("123"));
        assert_eq!(cond_to_sql(&eq_id, &mut bindings, &mut seq).unwrap(), "_id = $1");
        assert_eq!(bindings.len(), 1);

        let in_age = Condition::In("age".to_string(), json!([25, 30, 35]));
        let rendered = cond_to_sql(&in_age, &mut bindings, &mut seq).unwrap();
        assert!(rendered.contains("ANY($2)"));
        assert_eq!(bindings.len(), 2);

        let score_range = Condition::And(vec![
            Box::new(Condition::Gt("score".to_string(), json!(80))),
            Box::new(Condition::Lt("score".to_string(), json!(100))),
        ]);
        let rendered = cond_to_sql(&score_range, &mut bindings, &mut seq).unwrap();
        assert_eq!(rendered, "(data#>'{score}' > $3 AND data#>'{score}' < $4)");
        assert_eq!(bindings.len(), 4);
    }

    #[test]
    fn test_regex_conditions() {
        let (mut seq, mut bindings) = fresh_state();

        let cases = [
            (
                Condition::Regex("name".to_string(), json!("^J.*")),
                "data#>>'{name}' ~ $1",
            ),
            (
                Condition::Regexi("name".to_string(), json!("^J.*")),
                "data#>>'{name}' ~* $2",
            ),
        ];
        for (idx, (cond, expected)) in cases.iter().enumerate() {
            assert_eq!(cond_to_sql(cond, &mut bindings, &mut seq).unwrap(), *expected);
            assert_eq!(bindings.len(), idx + 1);
        }
    }

    #[test]
    fn test_invalid_chars() {
        assert!(check_invalid_chars("valid_field").is_ok());
        for bad in ["invalid'field", "invalid(field)"].iter() {
            assert!(check_invalid_chars(bad).is_err(), "{:?} should be rejected", bad);
        }
    }
}