Skip to main content

rest_sql/
dsl.rs

1use crate::ast::Ast;
2use crate::error::{RestSqlError, ValidationError};
3use crate::mapper::FieldMapper;
4use crate::parsing::parse;
5use crate::{Constraint, Operator, Value};
6
7#[derive(Debug, Clone)]
8pub struct RestSql(Ast);
9
10impl RestSql {
11    pub fn new(query: &str) -> Result<Self, RestSqlError> {
12        let ast = parse(query).map_err(RestSqlError::ParseError)?;
13        let ast = validate_inner(&ast, None).map_err(RestSqlError::ValidationError)?;
14        Ok(Self(ast))
15    }
16
17    pub fn new_for_fields(query: &str, allowed: &[&str]) -> Result<Self, RestSqlError> {
18        let ast = parse(query).map_err(RestSqlError::ParseError)?;
19        let ast = validate_inner(&ast, Some(allowed)).map_err(RestSqlError::ValidationError)?;
20        Ok(Self(ast))
21    }
22
23    #[cfg(feature = "serde")]
24    pub fn new_for<T>(query: &str) -> Result<Self, RestSqlError>
25    where
26        T: for<'de> serde::Deserialize<'de>,
27    {
28        Self::new_for_fields(query, serde_fields::<T>())
29    }
30
31    /// Returns a new `RestSql` with all field names transformed by `mapper`.
32    pub fn map_fields(&self, mapper: &impl FieldMapper) -> Self {
33        Self(apply_mapper(&self.0, mapper))
34    }
35
36    /// Returns the distinct field names referenced in this filter.
37    pub fn fields(&self) -> Vec<&str> {
38        fields(&self.0)
39    }
40
41    /// Builds a `RestSql` from a programmatically-constructed AST.
42    ///
43    /// Runs the same validation as `new` (operator/value compatibility, list
44    /// arity for `Between`, etc.). Field allowlisting is skipped — call
45    /// `from_ast_for_fields` or `from_ast_for::<T>()` if you need it.
46    pub fn from_ast(ast: Ast) -> Result<Self, RestSqlError> {
47        let ast = validate_inner(&ast, None).map_err(RestSqlError::ValidationError)?;
48        Ok(Self(ast))
49    }
50
51    /// Like `from_ast`, but also enforces a field allowlist.
52    ///
53    /// Any field in the AST that is not in `allowed` causes a `ValidationError`.
54    pub fn from_ast_for_fields(ast: Ast, allowed: &[&str]) -> Result<Self, RestSqlError> {
55        let ast = validate_inner(&ast, Some(allowed)).map_err(RestSqlError::ValidationError)?;
56        Ok(Self(ast))
57    }
58
59    /// Like `from_ast`, but derives the field allowlist from `T`'s `Deserialize` impl.
60    ///
61    /// Mirrors `new_for::<T>()` for programmatically-built ASTs — ensures that
62    /// fields injected via the DSL are still subject to the same allowlist as
63    /// fields coming from a user-supplied RSQL string.
64    #[cfg(feature = "serde")]
65    pub fn from_ast_for<T>(ast: Ast) -> Result<Self, RestSqlError>
66    where
67        T: for<'de> serde::Deserialize<'de>,
68    {
69        Self::from_ast_for_fields(ast, serde_fields::<T>())
70    }
71
72    /// Exposes the validated AST — for use by drivers.
73    pub fn ast(&self) -> &Ast {
74        &self.0
75    }
76}
77
78fn apply_mapper(ast: &Ast, mapper: &impl FieldMapper) -> Ast {
79    match ast {
80        Ast::And(children) => Ast::And(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
81        Ast::Or(children) => Ast::Or(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
82        Ast::Constraint(c) => Ast::Constraint(Constraint {
83            field: mapper.map(&c.field).into_owned(),
84            operator: c.operator.clone(),
85            value: c.value.clone(),
86        }),
87    }
88}
89
90pub(crate) fn validate_inner(
91    ast: &Ast,
92    allowed: Option<&[&str]>,
93) -> Result<Ast, Vec<ValidationError>> {
94    let mut errors = Vec::new();
95    let result = validate_node(ast, allowed, &mut errors);
96    if errors.is_empty() {
97        Ok(result.unwrap())
98    } else {
99        Err(errors)
100    }
101}
102
103fn validate_node(
104    ast: &Ast,
105    allowed: Option<&[&str]>,
106    errors: &mut Vec<ValidationError>,
107) -> Option<Ast> {
108    match ast {
109        Ast::And(children) => {
110            let nodes: Vec<_> = children
111                .iter()
112                .filter_map(|c| validate_node(c, allowed, errors))
113                .collect();
114            if nodes.len() == children.len() {
115                Some(Ast::And(nodes))
116            } else {
117                None
118            }
119        }
120        Ast::Or(children) => {
121            let nodes: Vec<_> = children
122                .iter()
123                .filter_map(|c| validate_node(c, allowed, errors))
124                .collect();
125            if nodes.len() == children.len() {
126                Some(Ast::Or(nodes))
127            } else {
128                None
129            }
130        }
131        Ast::Constraint(c) => validate_constraint(c, allowed, errors),
132    }
133}
134
135/// Extract the list of distinct field names referenced in a DSL tree.
136pub fn fields(ast: &Ast) -> Vec<&str> {
137    let mut out = Vec::new();
138    collect_fields(ast, &mut out);
139    out.sort();
140    out.dedup();
141    out
142}
143
144fn collect_fields<'a>(ast: &'a Ast, out: &mut Vec<&'a str>) {
145    match ast {
146        Ast::And(v) | Ast::Or(v) => v.iter().for_each(|n| collect_fields(n, out)),
147        Ast::Constraint(c) => out.push(&c.field),
148    }
149}
150
151fn validate_constraint(
152    c: &Constraint,
153    allowed: Option<&[&str]>,
154    errors: &mut Vec<ValidationError>,
155) -> Option<Ast> {
156    if let Some(allowed) = allowed
157        && !allowed.contains(&c.field.as_str())
158    {
159        errors.push(ValidationError::ForbiddenField(c.field.clone()));
160        return None;
161    }
162
163    let op_name = format!("{:?}", c.operator);
164
165    let value = match &c.operator {
166        Operator::In | Operator::Out => {
167            if !matches!(c.value, Value::List(_)) {
168                errors.push(ValidationError::ExpectedList {
169                    field: c.field.clone(),
170                    operator: op_name,
171                });
172                return None;
173            }
174            &c.value
175        }
176        Operator::Between => match &c.value {
177            Value::List(v) if v.len() == 2 => &c.value,
178            Value::List(_) => {
179                errors.push(ValidationError::BetweenArity {
180                    field: c.field.clone(),
181                    operator: op_name,
182                });
183                return None;
184            }
185            _ => {
186                errors.push(ValidationError::ExpectedList {
187                    field: c.field.clone(),
188                    operator: op_name,
189                });
190                return None;
191            }
192        },
193        Operator::Null | Operator::NotNull => &c.value,
194        _ => {
195            if matches!(c.value, Value::List(_)) {
196                errors.push(ValidationError::UnexpectedList {
197                    field: c.field.clone(),
198                    operator: op_name,
199                });
200                return None;
201            }
202            &c.value
203        }
204    };
205
206    Some(Ast::Constraint(Constraint {
207        field: c.field.clone(),
208        operator: c.operator.clone(),
209        value: value.clone(),
210    }))
211}
212
#[cfg(feature = "serde")]
mod serde_support {
    use serde::de::{self, Deserializer, Visitor};
    use std::fmt;

    /// Probe `Deserializer` that never yields a value.
    ///
    /// It records which field list a `Deserialize` impl passes to
    /// `deserialize_struct` and returns it through the error channel.
    struct FieldExtractor;

    /// Error carrying the captured field names.
    ///
    /// An empty slice means no struct field list was observed.
    struct ExtractErr(&'static [&'static str]);

    impl ExtractErr {
        /// Sentinel for "nothing captured".
        const NOTHING: ExtractErr = ExtractErr(&[]);
    }

    impl fmt::Display for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("field extraction")
        }
    }

    impl fmt::Debug for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ExtractErr")
        }
    }

    impl std::error::Error for ExtractErr {}

    impl de::Error for ExtractErr {
        fn custom<T: fmt::Display>(_msg: T) -> Self {
            // The message is irrelevant: any error raised by the impl under
            // inspection just means no field list was captured.
            ExtractErr::NOTHING
        }
    }

    impl<'de> Deserializer<'de> for FieldExtractor {
        type Error = ExtractErr;

        fn deserialize_struct<V: Visitor<'de>>(
            self,
            _name: &'static str,
            fields: &'static [&'static str],
            _visitor: V,
        ) -> Result<V::Value, Self::Error> {
            // Hand the declared field names back instead of deserializing.
            Err(ExtractErr(fields))
        }

        fn deserialize_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
            Err(ExtractErr::NOTHING)
        }

        // Every other request falls through to `deserialize_any` and therefore
        // captures nothing.
        serde::forward_to_deserialize_any! {
            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
            bytes byte_buf option unit unit_struct newtype_struct seq tuple
            tuple_struct map enum identifier ignored_any
        }
    }

    /// Returns the serde field names of `T` without allocating.
    ///
    /// Works with `#[derive(Deserialize)]` structs. Returns `&[]` for
    /// non-struct types (enums, maps, tuples) — more precisely, for any `T`
    /// whose `Deserialize` impl requests something other than
    /// `deserialize_struct`, or fails before doing so.
    ///
    /// NOTE(review): serde's derive is believed to use `deserialize_map` for
    /// structs containing `#[serde(flatten)]`, so such types would get an
    /// empty list here — confirm if those types are used as allowlists.
    pub fn serde_fields<'de, T: serde::Deserialize<'de>>() -> &'static [&'static str] {
        let Err(ExtractErr(fields)) = T::deserialize(FieldExtractor) else {
            return &[];
        };
        fields
    }
}
278
279#[cfg(feature = "serde")]
280pub use serde_support::serde_fields;