use crate::ast::Ast;
use crate::error::{RestSqlError, ValidationError};
use crate::mapper::FieldMapper;
use crate::parsing::parse;
use crate::{Constraint, Operator, Value};
/// A query held as an [`Ast`].
///
/// The constructors in this file (`new`, `new_for_fields`, `new_for`, `from_ast`)
/// all pass the AST through `validate_inner` before wrapping it. `map_fields`
/// builds a new instance from an already-wrapped AST without re-validating.
#[derive(Debug, Clone)]
pub struct RestSql(Ast);
impl RestSql {
    /// Parses `query` and validates it without any restriction on field names.
    ///
    /// # Errors
    ///
    /// Returns `RestSqlError::ParseError` when parsing fails and
    /// `RestSqlError::ValidationError` when validation reports any error.
    pub fn new(query: &str) -> Result<Self, RestSqlError> {
        Self::parse_and_validate(query, None)
    }

    /// Parses `query` and validates it, rejecting constraints whose field is
    /// not listed in `allowed`.
    ///
    /// # Errors
    ///
    /// Returns `RestSqlError::ParseError` when parsing fails and
    /// `RestSqlError::ValidationError` when validation reports any error.
    pub fn new_for_fields(query: &str, allowed: &[&str]) -> Result<Self, RestSqlError> {
        Self::parse_and_validate(query, Some(allowed))
    }

    /// Like [`RestSql::new_for_fields`], using the field names that
    /// `serde_fields::<T>()` reports for `T` as the allowed list.
    ///
    /// # Errors
    ///
    /// Same as [`RestSql::new_for_fields`].
    #[cfg(feature = "serde")]
    pub fn new_for<T>(query: &str) -> Result<Self, RestSqlError>
    where
        T: for<'de> serde::Deserialize<'de>,
    {
        let allowed = serde_fields::<T>();
        Self::new_for_fields(query, allowed)
    }

    /// Returns a copy of this query in which every constraint's field name has
    /// been passed through `mapper`. The resulting AST is not re-validated.
    pub fn map_fields(&self, mapper: &impl FieldMapper) -> Self {
        let remapped = apply_mapper(&self.0, mapper);
        Self(remapped)
    }

    /// Returns the sorted, de-duplicated field names used by the constraints
    /// of this query.
    pub fn fields(&self) -> Vec<&str> {
        fields(self.ast())
    }

    /// Validates an already-built `ast` (no field restriction) and wraps it.
    ///
    /// # Errors
    ///
    /// Returns `RestSqlError::ValidationError` when validation reports any error.
    pub fn from_ast(ast: Ast) -> Result<Self, RestSqlError> {
        Self::validated(&ast, None)
    }

    /// Borrows the underlying AST.
    pub fn ast(&self) -> &Ast {
        &self.0
    }

    /// Shared constructor path: parse the text, then validate and wrap.
    fn parse_and_validate(query: &str, allowed: Option<&[&str]>) -> Result<Self, RestSqlError> {
        let parsed = parse(query).map_err(RestSqlError::ParseError)?;
        Self::validated(&parsed, allowed)
    }

    /// Runs `validate_inner` on `ast` and wraps the validated copy.
    fn validated(ast: &Ast, allowed: Option<&[&str]>) -> Result<Self, RestSqlError> {
        validate_inner(ast, allowed)
            .map(Self)
            .map_err(RestSqlError::ValidationError)
    }
}
/// Rebuilds `ast` with every constraint's field name replaced by
/// `mapper.map(..)`; operators and values are cloned unchanged.
fn apply_mapper(ast: &Ast, mapper: &impl FieldMapper) -> Ast {
    match ast {
        Ast::Constraint(constraint) => {
            let field = mapper.map(&constraint.field).into_owned();
            Ast::Constraint(Constraint {
                field,
                operator: constraint.operator.clone(),
                value: constraint.value.clone(),
            })
        }
        Ast::And(branches) => {
            let mapped = branches
                .iter()
                .map(|branch| apply_mapper(branch, mapper))
                .collect();
            Ast::And(mapped)
        }
        Ast::Or(branches) => {
            let mapped = branches
                .iter()
                .map(|branch| apply_mapper(branch, mapper))
                .collect();
            Ast::Or(mapped)
        }
    }
}
/// Validates `ast` and returns a rebuilt copy of it, or every validation
/// error that was found.
///
/// When `allowed` is `Some`, constraints whose field is not in the slice are
/// reported as errors; `None` disables that check.
///
/// # Errors
///
/// Returns the full list of collected `ValidationError`s if there is at
/// least one.
///
/// # Panics
///
/// Panics only if `validate_node` returns `None` without having recorded an
/// error, which would be a bug in the validation helpers.
pub(crate) fn validate_inner(
    ast: &Ast,
    allowed: Option<&[&str]>,
) -> Result<Ast, Vec<ValidationError>> {
    let mut errors = Vec::new();
    let validated = validate_node(ast, allowed, &mut errors);
    if !errors.is_empty() {
        return Err(errors);
    }
    // Every `None` produced by the validators is preceded by a pushed error,
    // so an empty error list implies a rebuilt node is present.
    Ok(validated.expect("validate_node returned None without recording a validation error"))
}
/// Validates one AST node, appending any problems to `errors`.
///
/// Returns the rebuilt node, or `None` when this node or any of its
/// descendants failed validation.
fn validate_node(
    ast: &Ast,
    allowed: Option<&[&str]>,
    errors: &mut Vec<ValidationError>,
) -> Option<Ast> {
    match ast {
        Ast::Constraint(constraint) => validate_constraint(constraint, allowed, errors),
        Ast::And(children) => validate_children(children, allowed, errors).map(Ast::And),
        Ast::Or(children) => validate_children(children, allowed, errors).map(Ast::Or),
    }
}

/// Validates every child of a group node and returns the rebuilt children
/// only if all of them passed.
fn validate_children(
    children: &[Ast],
    allowed: Option<&[&str]>,
    errors: &mut Vec<ValidationError>,
) -> Option<Vec<Ast>> {
    let mut rebuilt = Vec::new();
    // Deliberately no early exit: every child is visited so that all errors
    // in the tree are reported, not just the first one.
    for child in children {
        if let Some(node) = validate_node(child, allowed, errors) {
            rebuilt.push(node);
        }
    }
    if rebuilt.len() == children.len() {
        Some(rebuilt)
    } else {
        None
    }
}
/// Returns the field names used by the constraints in `ast`, sorted and
/// with duplicates removed. The names borrow from `ast`.
pub fn fields(ast: &Ast) -> Vec<&str> {
    let mut names: Vec<&str> = Vec::new();
    collect_fields(ast, &mut names);
    // `dedup` only drops adjacent repeats, so the sort has to come first.
    names.sort();
    names.dedup();
    names
}
/// Appends the field name of every constraint under `ast` to `out`, in
/// depth-first order and including duplicates.
fn collect_fields<'a>(ast: &'a Ast, out: &mut Vec<&'a str>) {
    match ast {
        Ast::Constraint(constraint) => out.push(constraint.field.as_str()),
        Ast::And(children) | Ast::Or(children) => {
            for child in children {
                collect_fields(child, out);
            }
        }
    }
}
/// Validates a single constraint and, on success, returns a copy of it
/// wrapped in `Ast::Constraint`.
///
/// Checks, in order:
/// 1. If `allowed` is `Some`, the field must be in the slice, otherwise
///    `ValidationError::ForbiddenField` is recorded.
/// 2. The value shape must fit the operator:
///    - `In` / `Out` require a `Value::List` (`ExpectedList` otherwise);
///    - `Between` requires a `Value::List` (`ExpectedList` otherwise) of
///      exactly two items (`BetweenArity` otherwise);
///    - `Null` / `NotNull` accept any value;
///    - every other operator rejects a `Value::List` (`UnexpectedList`).
///
/// On failure exactly one error is pushed onto `errors` and `None` is
/// returned.
fn validate_constraint(
    c: &Constraint,
    allowed: Option<&[&str]>,
    errors: &mut Vec<ValidationError>,
) -> Option<Ast> {
    if allowed.is_some_and(|names| !names.contains(&c.field.as_str())) {
        errors.push(ValidationError::ForbiddenField(c.field.clone()));
        return None;
    }

    // Built lazily: the strings are only needed when an error is reported,
    // so valid constraints do not pay for the formatting/allocation.
    let field = || c.field.clone();
    let operator = || format!("{:?}", c.operator);

    let violation = match (&c.operator, &c.value) {
        (Operator::In | Operator::Out, Value::List(_)) => None,
        (Operator::In | Operator::Out, _) => Some(ValidationError::ExpectedList {
            field: field(),
            operator: operator(),
        }),
        (Operator::Between, Value::List(items)) if items.len() == 2 => None,
        (Operator::Between, Value::List(_)) => Some(ValidationError::BetweenArity {
            field: field(),
            operator: operator(),
        }),
        (Operator::Between, _) => Some(ValidationError::ExpectedList {
            field: field(),
            operator: operator(),
        }),
        (Operator::Null | Operator::NotNull, _) => None,
        (_, Value::List(_)) => Some(ValidationError::UnexpectedList {
            field: field(),
            operator: operator(),
        }),
        (_, _) => None,
    };

    match violation {
        Some(error) => {
            errors.push(error);
            None
        }
        None => Some(Ast::Constraint(Constraint {
            field: c.field.clone(),
            operator: c.operator.clone(),
            value: c.value.clone(),
        })),
    }
}
#[cfg(feature = "serde")]
mod serde_support {
    use serde::de::{self, Deserializer, Visitor};
    use std::fmt;

    /// Field-name list in the form serde hands to `deserialize_struct`.
    type FieldNames = &'static [&'static str];

    /// Deserializer that never yields a value: every request ends in an
    /// error, and the error carries whatever field names were seen.
    struct FieldExtractor;

    /// Error type doubling as the carrier for the captured field names.
    struct ExtractErr(FieldNames);

    impl ExtractErr {
        /// Error carrying no field names.
        fn empty() -> Self {
            ExtractErr(&[])
        }
    }

    impl fmt::Display for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("field extraction")
        }
    }

    impl fmt::Debug for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ExtractErr")
        }
    }

    impl std::error::Error for ExtractErr {}

    impl de::Error for ExtractErr {
        // Errors raised by the target type's own `Deserialize` impl carry no names.
        fn custom<T: fmt::Display>(_msg: T) -> Self {
            ExtractErr::empty()
        }
    }

    impl<'de> Deserializer<'de> for FieldExtractor {
        type Error = ExtractErr;

        /// Anything that is not a struct request yields an empty name list.
        fn deserialize_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
            Err(ExtractErr::empty())
        }

        /// Captures the `fields` argument and bails out through the error channel.
        fn deserialize_struct<V: Visitor<'de>>(
            self,
            _name: &'static str,
            fields: &'static [&'static str],
            _visitor: V,
        ) -> Result<V::Value, Self::Error> {
            Err(ExtractErr(fields))
        }

        serde::forward_to_deserialize_any! {
            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
            bytes byte_buf option unit unit_struct newtype_struct seq tuple
            tuple_struct map enum identifier ignored_any
        }
    }

    /// Returns the field names `T`'s `Deserialize` impl passes to
    /// `deserialize_struct`, or an empty slice when `T` deserializes through
    /// any other entry point.
    pub fn serde_fields<'de, T: serde::Deserialize<'de>>() -> &'static [&'static str] {
        match T::deserialize(FieldExtractor) {
            Err(ExtractErr(names)) => names,
            Ok(_) => &[],
        }
    }
}
#[cfg(feature = "serde")]
pub use serde_support::serde_fields;