1use crate::ast::Ast;
2use crate::error::{RestSqlError, ValidationError};
3use crate::mapper::FieldMapper;
4use crate::parsing::parse;
5use crate::{Constraint, Operator, Value};
6
7#[derive(Debug, Clone)]
8pub struct RestSql(Ast);
9
10impl RestSql {
11 pub fn new(query: &str) -> Result<Self, RestSqlError> {
12 let ast = parse(query).map_err(RestSqlError::ParseError)?;
13 let ast = validate_inner(&ast, None).map_err(RestSqlError::ValidationError)?;
14 Ok(Self(ast))
15 }
16
17 pub fn new_for_fields(query: &str, allowed: &[&str]) -> Result<Self, RestSqlError> {
18 let ast = parse(query).map_err(RestSqlError::ParseError)?;
19 let ast = validate_inner(&ast, Some(allowed)).map_err(RestSqlError::ValidationError)?;
20 Ok(Self(ast))
21 }
22
23 #[cfg(feature = "serde")]
24 pub fn new_for<T>(query: &str) -> Result<Self, RestSqlError>
25 where
26 T: for<'de> serde::Deserialize<'de>,
27 {
28 Self::new_for_fields(query, serde_fields::<T>())
29 }
30
31 pub fn map_fields(&self, mapper: &impl FieldMapper) -> Self {
33 Self(apply_mapper(&self.0, mapper))
34 }
35
36 pub fn fields(&self) -> Vec<&str> {
38 fields(&self.0)
39 }
40
41 pub fn ast(&self) -> &Ast {
43 &self.0
44 }
45}
46
47fn apply_mapper(ast: &Ast, mapper: &impl FieldMapper) -> Ast {
48 match ast {
49 Ast::And(children) => Ast::And(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
50 Ast::Or(children) => Ast::Or(children.iter().map(|c| apply_mapper(c, mapper)).collect()),
51 Ast::Constraint(c) => Ast::Constraint(Constraint {
52 field: mapper.map(&c.field).into_owned(),
53 operator: c.operator.clone(),
54 value: c.value.clone(),
55 }),
56 }
57}
58
59pub(crate) fn validate_inner(
60 ast: &Ast,
61 allowed: Option<&[&str]>,
62) -> Result<Ast, Vec<ValidationError>> {
63 let mut errors = Vec::new();
64 let result = validate_node(ast, allowed, &mut errors);
65 if errors.is_empty() {
66 Ok(result.unwrap())
67 } else {
68 Err(errors)
69 }
70}
71
72fn validate_node(
73 ast: &Ast,
74 allowed: Option<&[&str]>,
75 errors: &mut Vec<ValidationError>,
76) -> Option<Ast> {
77 match ast {
78 Ast::And(children) => {
79 let nodes: Vec<_> = children
80 .iter()
81 .filter_map(|c| validate_node(c, allowed, errors))
82 .collect();
83 if nodes.len() == children.len() {
84 Some(Ast::And(nodes))
85 } else {
86 None
87 }
88 }
89 Ast::Or(children) => {
90 let nodes: Vec<_> = children
91 .iter()
92 .filter_map(|c| validate_node(c, allowed, errors))
93 .collect();
94 if nodes.len() == children.len() {
95 Some(Ast::Or(nodes))
96 } else {
97 None
98 }
99 }
100 Ast::Constraint(c) => validate_constraint(c, allowed, errors),
101 }
102}
103
104pub fn fields(ast: &Ast) -> Vec<&str> {
106 let mut out = Vec::new();
107 collect_fields(ast, &mut out);
108 out.sort();
109 out.dedup();
110 out
111}
112
113fn collect_fields<'a>(ast: &'a Ast, out: &mut Vec<&'a str>) {
114 match ast {
115 Ast::And(v) | Ast::Or(v) => v.iter().for_each(|n| collect_fields(n, out)),
116 Ast::Constraint(c) => out.push(&c.field),
117 }
118}
119
120fn validate_constraint(
121 c: &Constraint,
122 allowed: Option<&[&str]>,
123 errors: &mut Vec<ValidationError>,
124) -> Option<Ast> {
125 if let Some(allowed) = allowed
126 && !allowed.contains(&c.field.as_str())
127 {
128 errors.push(ValidationError::ForbiddenField(c.field.clone()));
129 return None;
130 }
131
132 let op_name = format!("{:?}", c.operator);
133
134 let value = match &c.operator {
135 Operator::In | Operator::Out => {
136 if !matches!(c.value, Value::List(_)) {
137 errors.push(ValidationError::ExpectedList {
138 field: c.field.clone(),
139 operator: op_name,
140 });
141 return None;
142 }
143 &c.value
144 }
145 Operator::Between => match &c.value {
146 Value::List(v) if v.len() == 2 => &c.value,
147 Value::List(_) => {
148 errors.push(ValidationError::BetweenArity {
149 field: c.field.clone(),
150 operator: op_name,
151 });
152 return None;
153 }
154 _ => {
155 errors.push(ValidationError::ExpectedList {
156 field: c.field.clone(),
157 operator: op_name,
158 });
159 return None;
160 }
161 },
162 Operator::Null | Operator::NotNull => &c.value,
163 _ => {
164 if matches!(c.value, Value::List(_)) {
165 errors.push(ValidationError::UnexpectedList {
166 field: c.field.clone(),
167 operator: op_name,
168 });
169 return None;
170 }
171 &c.value
172 }
173 };
174
175 Some(Ast::Constraint(Constraint {
176 field: c.field.clone(),
177 operator: c.operator.clone(),
178 value: value.clone(),
179 }))
180}
181
#[cfg(feature = "serde")]
mod serde_support {
    //! Discovers the field names a serde `Deserialize` type declares, without
    //! needing any input data.
    //!
    //! `FieldExtractor` is a deserializer that never produces a value. When a
    //! type's `Deserialize` impl calls `deserialize_struct`, the `fields` slice it
    //! passes is captured and carried back out through the error value. Every
    //! other request fails with an empty list.

    use serde::de::{self, Deserializer, Visitor};
    use std::fmt;

    /// Input-less deserializer used only to observe `deserialize_struct`.
    struct FieldExtractor;

    /// Error type that doubles as the carrier for the captured field list.
    enum ExtractErr {
        /// Field names captured from `deserialize_struct` (empty if none were).
        Fields(&'static [&'static str]),
    }

    impl fmt::Display for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("field extraction")
        }
    }

    impl fmt::Debug for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ExtractErr")
        }
    }

    impl std::error::Error for ExtractErr {}

    impl de::Error for ExtractErr {
        // An error raised by the type's own impl means no field list was captured.
        fn custom<M: fmt::Display>(_message: M) -> Self {
            ExtractErr::Fields(&[])
        }
    }

    impl<'de> Deserializer<'de> for FieldExtractor {
        type Error = ExtractErr;

        fn deserialize_struct<V: Visitor<'de>>(
            self,
            _name: &'static str,
            fields: &'static [&'static str],
            _visitor: V,
        ) -> Result<V::Value, Self::Error> {
            // The whole point of this deserializer: hand the field list back.
            Err(ExtractErr::Fields(fields))
        }

        fn deserialize_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
            Err(ExtractErr::Fields(&[]))
        }

        serde::forward_to_deserialize_any! {
            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char
            str string bytes byte_buf option unit unit_struct newtype_struct
            seq tuple tuple_struct map enum identifier ignored_any
        }
    }

    /// Returns the field names `T` passes to `Deserializer::deserialize_struct`,
    /// or an empty slice if its `Deserialize` impl never makes that call.
    ///
    /// NOTE(review): serde-derived structs using `#[serde(flatten)]` are believed
    /// to request a map instead of a struct, which would yield an empty slice
    /// here. Confirm before relying on this for such types.
    pub fn serde_fields<'de, T: serde::Deserialize<'de>>() -> &'static [&'static str] {
        match T::deserialize(FieldExtractor) {
            Err(ExtractErr::Fields(names)) => names,
            Ok(_) => &[],
        }
    }
}
247
248#[cfg(feature = "serde")]
249pub use serde_support::serde_fields;