1use crate::ast::Ast;
2use crate::error::{RestSqlError, ValidationError};
3use crate::mapper::FieldMapper;
4use crate::parsing::parse;
5use crate::{Constraint, Operator, Value};
6
7#[derive(Debug, Clone)]
8pub struct RestSql(Ast);
9
10impl RestSql {
11 pub fn new(query: &str) -> Result<Self, RestSqlError> {
12 let ast = parse(query).map_err(RestSqlError::ParseError)?;
13 let ast = validate_inner(&ast, None).map_err(RestSqlError::ValidationError)?;
14 Ok(Self(ast))
15 }
16
17 pub fn new_for_fields(query: &str, allowed: &[&str]) -> Result<Self, RestSqlError> {
18 let ast = parse(query).map_err(RestSqlError::ParseError)?;
19 let ast = validate_inner(&ast, Some(allowed)).map_err(RestSqlError::ValidationError)?;
20 Ok(Self(ast))
21 }
22
23 #[cfg(feature = "serde")]
24 pub fn new_for<T>(query: &str) -> Result<Self, RestSqlError>
25 where
26 T: for<'de> serde::Deserialize<'de>,
27 {
28 Self::new_for_fields(query, serde_fields::<T>())
29 }
30
31 pub fn map_fields(&self, mapper: &impl FieldMapper) -> Self {
33 Self(apply_mapper(&self.0, mapper))
34 }
35
36 pub fn fields(&self) -> Vec<&str> {
38 fields(&self.0)
39 }
40
41 pub fn from_ast(ast: Ast) -> Result<Self, RestSqlError> {
47 let ast = validate_inner(&ast, None).map_err(RestSqlError::ValidationError)?;
48 Ok(Self(ast))
49 }
50
51 pub fn from_ast_for_fields(ast: Ast, allowed: &[&str]) -> Result<Self, RestSqlError> {
55 let ast = validate_inner(&ast, Some(allowed)).map_err(RestSqlError::ValidationError)?;
56 Ok(Self(ast))
57 }
58
59 #[cfg(feature = "serde")]
65 pub fn from_ast_for<T>(ast: Ast) -> Result<Self, RestSqlError>
66 where
67 T: for<'de> serde::Deserialize<'de>,
68 {
69 Self::from_ast_for_fields(ast, serde_fields::<T>())
70 }
71
72 pub fn ast(&self) -> &Ast {
74 &self.0
75 }
76}
77
78fn apply_mapper(ast: &Ast, mapper: &impl FieldMapper) -> Ast {
79 match ast {
80 Ast::And(children) => Ast::And(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
81 Ast::Or(children) => Ast::Or(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
82 Ast::Constraint(c) => Ast::Constraint(Constraint {
83 field: mapper.map(&c.field).into_owned(),
84 operator: c.operator.clone(),
85 value: c.value.clone(),
86 }),
87 }
88}
89
90pub(crate) fn validate_inner(
91 ast: &Ast,
92 allowed: Option<&[&str]>,
93) -> Result<Ast, Vec<ValidationError>> {
94 let mut errors = Vec::new();
95 let result = validate_node(ast, allowed, &mut errors);
96 if errors.is_empty() {
97 Ok(result.unwrap())
98 } else {
99 Err(errors)
100 }
101}
102
103fn validate_node(
104 ast: &Ast,
105 allowed: Option<&[&str]>,
106 errors: &mut Vec<ValidationError>,
107) -> Option<Ast> {
108 match ast {
109 Ast::And(children) => {
110 let nodes: Vec<_> = children
111 .iter()
112 .filter_map(|c| validate_node(c, allowed, errors))
113 .collect();
114 if nodes.len() == children.len() {
115 Some(Ast::And(nodes))
116 } else {
117 None
118 }
119 }
120 Ast::Or(children) => {
121 let nodes: Vec<_> = children
122 .iter()
123 .filter_map(|c| validate_node(c, allowed, errors))
124 .collect();
125 if nodes.len() == children.len() {
126 Some(Ast::Or(nodes))
127 } else {
128 None
129 }
130 }
131 Ast::Constraint(c) => validate_constraint(c, allowed, errors),
132 }
133}
134
135pub fn fields(ast: &Ast) -> Vec<&str> {
137 let mut out = Vec::new();
138 collect_fields(ast, &mut out);
139 out.sort();
140 out.dedup();
141 out
142}
143
144fn collect_fields<'a>(ast: &'a Ast, out: &mut Vec<&'a str>) {
145 match ast {
146 Ast::And(v) | Ast::Or(v) => v.iter().for_each(|n| collect_fields(n, out)),
147 Ast::Constraint(c) => out.push(&c.field),
148 }
149}
150
151fn validate_constraint(
152 c: &Constraint,
153 allowed: Option<&[&str]>,
154 errors: &mut Vec<ValidationError>,
155) -> Option<Ast> {
156 if let Some(allowed) = allowed
157 && !allowed.contains(&c.field.as_str())
158 {
159 errors.push(ValidationError::ForbiddenField(c.field.clone()));
160 return None;
161 }
162
163 let op_name = format!("{:?}", c.operator);
164
165 let value = match &c.operator {
166 Operator::In | Operator::Out => {
167 if !matches!(c.value, Value::List(_)) {
168 errors.push(ValidationError::ExpectedList {
169 field: c.field.clone(),
170 operator: op_name,
171 });
172 return None;
173 }
174 &c.value
175 }
176 Operator::Between => match &c.value {
177 Value::List(v) if v.len() == 2 => &c.value,
178 Value::List(_) => {
179 errors.push(ValidationError::BetweenArity {
180 field: c.field.clone(),
181 operator: op_name,
182 });
183 return None;
184 }
185 _ => {
186 errors.push(ValidationError::ExpectedList {
187 field: c.field.clone(),
188 operator: op_name,
189 });
190 return None;
191 }
192 },
193 Operator::Null | Operator::NotNull => &c.value,
194 _ => {
195 if matches!(c.value, Value::List(_)) {
196 errors.push(ValidationError::UnexpectedList {
197 field: c.field.clone(),
198 operator: op_name,
199 });
200 return None;
201 }
202 &c.value
203 }
204 };
205
206 Some(Ast::Constraint(Constraint {
207 field: c.field.clone(),
208 operator: c.operator.clone(),
209 value: value.clone(),
210 }))
211}
212
#[cfg(feature = "serde")]
mod serde_support {
    use serde::de::{self, Deserializer, Visitor};
    use std::fmt;

    /// Field list reported whenever no struct field names were captured.
    const NO_FIELDS: &[&str] = &[];

    /// Deserializer that never produces a value: every request fails with an
    /// [`ExtractErr`], and a struct request fails with the field list it was
    /// handed.
    struct FieldExtractor;

    /// Error type doubling as the channel that carries the captured field
    /// names out of `Deserialize::deserialize`.
    struct ExtractErr(&'static [&'static str]);

    impl fmt::Display for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("field extraction")
        }
    }

    impl fmt::Debug for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ExtractErr")
        }
    }

    impl std::error::Error for ExtractErr {}

    impl de::Error for ExtractErr {
        /// Custom errors carry no field names; the message is discarded.
        fn custom<T: fmt::Display>(_: T) -> Self {
            ExtractErr(NO_FIELDS)
        }
    }

    impl<'de> Deserializer<'de> for FieldExtractor {
        type Error = ExtractErr;

        /// Fallback for every non-struct request: fail with an empty list.
        fn deserialize_any<V: Visitor<'de>>(self, _: V) -> Result<V::Value, ExtractErr> {
            Err(ExtractErr(NO_FIELDS))
        }

        /// Fails with the `fields` slice supplied by the caller; the visitor
        /// is never invoked.
        fn deserialize_struct<V: Visitor<'de>>(
            self,
            _name: &'static str,
            fields: &'static [&'static str],
            _visitor: V,
        ) -> Result<V::Value, ExtractErr> {
            Err(ExtractErr(fields))
        }

        serde::forward_to_deserialize_any! {
            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
            bytes byte_buf option unit unit_struct newtype_struct seq tuple
            tuple_struct map enum identifier ignored_any
        }
    }

    /// Returns the field list that `T`'s `Deserialize` impl passes to
    /// `deserialize_struct`, or an empty slice if deserialization ends any
    /// other way (for example because `T` does not request a struct).
    pub fn serde_fields<'de, T: serde::Deserialize<'de>>() -> &'static [&'static str] {
        T::deserialize(FieldExtractor)
            .err()
            .map_or(NO_FIELDS, |ExtractErr(names)| names)
    }
}
278
279#[cfg(feature = "serde")]
280pub use serde_support::serde_fields;