Skip to main content

rest_sql/
dsl.rs

1use crate::ast::Ast;
2use crate::error::{RestSqlError, ValidationError};
3use crate::mapper::FieldMapper;
4use crate::parsing::parse;
5use crate::{Constraint, Operator, Value};
6
7#[derive(Debug, Clone)]
8pub struct RestSql(Ast);
9
10impl RestSql {
11    pub fn new(query: &str) -> Result<Self, RestSqlError> {
12        let ast = parse(query).map_err(RestSqlError::ParseError)?;
13        let ast = validate_inner(&ast, None).map_err(RestSqlError::ValidationError)?;
14        Ok(Self(ast))
15    }
16
17    pub fn new_for_fields(query: &str, allowed: &[&str]) -> Result<Self, RestSqlError> {
18        let ast = parse(query).map_err(RestSqlError::ParseError)?;
19        let ast = validate_inner(&ast, Some(allowed)).map_err(RestSqlError::ValidationError)?;
20        Ok(Self(ast))
21    }
22
23    #[cfg(feature = "serde")]
24    pub fn new_for<T>(query: &str) -> Result<Self, RestSqlError>
25    where
26        T: for<'de> serde::Deserialize<'de>,
27    {
28        Self::new_for_fields(query, serde_fields::<T>())
29    }
30
31    /// Returns a new `RestSql` with all field names transformed by `mapper`.
32    pub fn map_fields(&self, mapper: &impl FieldMapper) -> Self {
33        Self(apply_mapper(&self.0, mapper))
34    }
35
36    /// Returns the distinct field names referenced in this filter.
37    pub fn fields(&self) -> Vec<&str> {
38        fields(&self.0)
39    }
40
41    /// Exposes the validated AST — for use by drivers.
42    pub fn ast(&self) -> &Ast {
43        &self.0
44    }
45}
46
47fn apply_mapper(ast: &Ast, mapper: &impl FieldMapper) -> Ast {
48    match ast {
49        Ast::And(children) => Ast::And(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
50        Ast::Or(children) => Ast::Or(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
51        Ast::Constraint(c) => Ast::Constraint(Constraint {
52            field: mapper.map(&c.field).into_owned(),
53            operator: c.operator.clone(),
54            value: c.value.clone(),
55        }),
56    }
57}
58
59pub(crate) fn validate_inner(
60    ast: &Ast,
61    allowed: Option<&[&str]>,
62) -> Result<Ast, Vec<ValidationError>> {
63    let mut errors = Vec::new();
64    let result = validate_node(ast, allowed, &mut errors);
65    if errors.is_empty() {
66        Ok(result.unwrap())
67    } else {
68        Err(errors)
69    }
70}
71
72fn validate_node(
73    ast: &Ast,
74    allowed: Option<&[&str]>,
75    errors: &mut Vec<ValidationError>,
76) -> Option<Ast> {
77    match ast {
78        Ast::And(children) => {
79            let nodes: Vec<_> = children
80                .iter()
81                .filter_map(|c| validate_node(c, allowed, errors))
82                .collect();
83            if nodes.len() == children.len() {
84                Some(Ast::And(nodes))
85            } else {
86                None
87            }
88        }
89        Ast::Or(children) => {
90            let nodes: Vec<_> = children
91                .iter()
92                .filter_map(|c| validate_node(c, allowed, errors))
93                .collect();
94            if nodes.len() == children.len() {
95                Some(Ast::Or(nodes))
96            } else {
97                None
98            }
99        }
100        Ast::Constraint(c) => validate_constraint(c, allowed, errors),
101    }
102}
103
104/// Extract the list of distinct field names referenced in a DSL tree.
105pub fn fields(ast: &Ast) -> Vec<&str> {
106    let mut out = Vec::new();
107    collect_fields(ast, &mut out);
108    out.sort();
109    out.dedup();
110    out
111}
112
113fn collect_fields<'a>(ast: &'a Ast, out: &mut Vec<&'a str>) {
114    match ast {
115        Ast::And(v) | Ast::Or(v) => v.iter().for_each(|n| collect_fields(n, out)),
116        Ast::Constraint(c) => out.push(&c.field),
117    }
118}
119
120fn validate_constraint(
121    c: &Constraint,
122    allowed: Option<&[&str]>,
123    errors: &mut Vec<ValidationError>,
124) -> Option<Ast> {
125    if let Some(allowed) = allowed
126        && !allowed.contains(&c.field.as_str())
127    {
128        errors.push(ValidationError::ForbiddenField(c.field.clone()));
129        return None;
130    }
131
132    let op_name = format!("{:?}", c.operator);
133
134    let value = match &c.operator {
135        Operator::In | Operator::Out => {
136            if !matches!(c.value, Value::List(_)) {
137                errors.push(ValidationError::ExpectedList {
138                    field: c.field.clone(),
139                    operator: op_name,
140                });
141                return None;
142            }
143            &c.value
144        }
145        Operator::Between => match &c.value {
146            Value::List(v) if v.len() == 2 => &c.value,
147            Value::List(_) => {
148                errors.push(ValidationError::BetweenArity {
149                    field: c.field.clone(),
150                    operator: op_name,
151                });
152                return None;
153            }
154            _ => {
155                errors.push(ValidationError::ExpectedList {
156                    field: c.field.clone(),
157                    operator: op_name,
158                });
159                return None;
160            }
161        },
162        Operator::Null | Operator::NotNull => &c.value,
163        _ => {
164            if matches!(c.value, Value::List(_)) {
165                errors.push(ValidationError::UnexpectedList {
166                    field: c.field.clone(),
167                    operator: op_name,
168                });
169                return None;
170            }
171            &c.value
172        }
173    };
174
175    Some(Ast::Constraint(Constraint {
176        field: c.field.clone(),
177        operator: c.operator.clone(),
178        value: value.clone(),
179    }))
180}
181
#[cfg(feature = "serde")]
mod serde_support {
    use serde::de::{self, Deserializer, Visitor};
    use std::fmt;

    /// A deserializer that never produces a value: it fails every request,
    /// and for struct requests the failure carries the field-name list that
    /// was passed to `deserialize_struct`.
    struct FieldExtractor;

    /// Error type used to carry the captured field names out of
    /// `Deserialize::deserialize`.
    enum ExtractErr {
        Fields(&'static [&'static str]),
    }

    impl fmt::Display for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("field extraction")
        }
    }

    impl fmt::Debug for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ExtractErr")
        }
    }

    impl std::error::Error for ExtractErr {}

    impl de::Error for ExtractErr {
        /// Any custom error raised during deserialization carries no fields.
        fn custom<T: fmt::Display>(_msg: T) -> Self {
            Self::Fields(&[])
        }
    }

    impl<'de> Deserializer<'de> for FieldExtractor {
        type Error = ExtractErr;

        /// Fallback for every non-struct request: fail with an empty list.
        fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, ExtractErr>
        where
            V: Visitor<'de>,
        {
            Err(ExtractErr::Fields(&[]))
        }

        /// Struct request: fail with the `fields` list handed in by the caller.
        fn deserialize_struct<V>(
            self,
            _name: &'static str,
            fields: &'static [&'static str],
            _visitor: V,
        ) -> Result<V::Value, ExtractErr>
        where
            V: Visitor<'de>,
        {
            Err(ExtractErr::Fields(fields))
        }

        serde::forward_to_deserialize_any! {
            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
            bytes byte_buf option unit unit_struct newtype_struct seq tuple
            tuple_struct map enum identifier ignored_any
        }
    }

    /// Returns the serde field names of `T` without allocating.
    ///
    /// Works with any `#[derive(Deserialize)]` struct. Returns `&[]` for
    /// non-struct types (enums, maps, tuples), and also if `T::deserialize`
    /// unexpectedly succeeds.
    pub fn serde_fields<'de, T: serde::Deserialize<'de>>() -> &'static [&'static str] {
        match T::deserialize(FieldExtractor) {
            Err(ExtractErr::Fields(names)) => names,
            Ok(_) => &[],
        }
    }
}
247
248#[cfg(feature = "serde")]
249pub use serde_support::serde_fields;