// rest-sql 0.3.0
//
// RSQL/FIQL filter parser and validator for REST APIs: parse, validate, and
// compile filters to native DB queries.
// Documentation
use crate::ast::Ast;
use crate::error::{RestSqlError, ValidationError};
use crate::mapper::FieldMapper;
use crate::parsing::parse;
use crate::{Constraint, Operator, Value};

/// A parsed and validated RSQL/FIQL filter.
///
/// Every constructor (`new*`, `from_ast*`) runs validation and stores the
/// tree returned by the validator. [`RestSql::map_fields`] derives a renamed
/// copy from an existing instance without re-running validation.
#[derive(Debug, Clone)]
pub struct RestSql(Ast);

impl RestSql {
    /// Parses `query` and validates the result, accepting any field name.
    ///
    /// # Errors
    ///
    /// Returns `RestSqlError::ParseError` when `query` cannot be parsed and
    /// `RestSqlError::ValidationError` when the parsed tree fails validation.
    pub fn new(query: &str) -> Result<Self, RestSqlError> {
        let parsed = parse(query).map_err(RestSqlError::ParseError)?;
        Self::from_ast(parsed)
    }

    /// Parses `query` and validates it, rejecting any field not listed in
    /// `allowed`.
    ///
    /// # Errors
    ///
    /// Same as [`RestSql::new`]. A field outside `allowed` is reported as a
    /// `RestSqlError::ValidationError`.
    pub fn new_for_fields(query: &str, allowed: &[&str]) -> Result<Self, RestSqlError> {
        let parsed = parse(query).map_err(RestSqlError::ParseError)?;
        Self::from_ast_for_fields(parsed, allowed)
    }

    /// Parses `query`, restricting fields to the names `T`'s `Deserialize`
    /// impl declares (as reported by `serde_fields`).
    ///
    /// # Errors
    ///
    /// Same as [`RestSql::new_for_fields`].
    #[cfg(feature = "serde")]
    pub fn new_for<T>(query: &str) -> Result<Self, RestSqlError>
    where
        T: for<'de> serde::Deserialize<'de>,
    {
        let allowed = serde_fields::<T>();
        Self::new_for_fields(query, allowed)
    }

    /// Returns a copy of this filter with every field name passed through
    /// `mapper`; operators and values are copied unchanged.
    pub fn map_fields(&self, mapper: &impl FieldMapper) -> Self {
        let renamed = apply_mapper(self.ast(), mapper);
        Self(renamed)
    }

    /// Returns the distinct field names referenced in this filter, sorted.
    pub fn fields(&self) -> Vec<&str> {
        fields(self.ast())
    }

    /// Wraps a programmatically built AST after validating it.
    ///
    /// Applies the same operator/value checks as [`RestSql::new`] (list
    /// operands for `In`/`Out`, exactly two operands for `Between`, etc.) but
    /// no field allowlist. Use `from_ast_for_fields` or `from_ast_for::<T>()`
    /// when one is required.
    ///
    /// # Errors
    ///
    /// Returns `RestSqlError::ValidationError` listing every problem found.
    pub fn from_ast(ast: Ast) -> Result<Self, RestSqlError> {
        validate_inner(&ast, None)
            .map(RestSql)
            .map_err(RestSqlError::ValidationError)
    }

    /// Like [`RestSql::from_ast`], but additionally rejects any field that is
    /// not in `allowed`.
    ///
    /// # Errors
    ///
    /// Returns `RestSqlError::ValidationError` listing every problem found,
    /// including forbidden fields.
    pub fn from_ast_for_fields(ast: Ast, allowed: &[&str]) -> Result<Self, RestSqlError> {
        validate_inner(&ast, Some(allowed))
            .map(RestSql)
            .map_err(RestSqlError::ValidationError)
    }

    /// Like [`RestSql::from_ast`], with the allowlist taken from `T`'s
    /// `Deserialize` impl.
    ///
    /// This is the AST counterpart of `new_for::<T>()`. Fields injected
    /// through the DSL face the same allowlist as fields parsed from a
    /// user-supplied RSQL string.
    ///
    /// # Errors
    ///
    /// Same as [`RestSql::from_ast_for_fields`].
    #[cfg(feature = "serde")]
    pub fn from_ast_for<T>(ast: Ast) -> Result<Self, RestSqlError>
    where
        T: for<'de> serde::Deserialize<'de>,
    {
        let allowed = serde_fields::<T>();
        Self::from_ast_for_fields(ast, allowed)
    }

    /// Borrows the validated AST (intended for query drivers).
    pub fn ast(&self) -> &Ast {
        &self.0
    }
}

/// Rebuilds `ast` with every constraint's field renamed by `mapper`.
///
/// Tree shape, child order, operators and values are preserved; only the
/// `field` strings change.
fn apply_mapper(ast: &Ast, mapper: &impl FieldMapper) -> Ast {
    match ast {
        Ast::Constraint(Constraint {
            field,
            operator,
            value,
        }) => {
            let renamed = mapper.map(field).into_owned();
            Ast::Constraint(Constraint {
                field: renamed,
                operator: operator.clone(),
                value: value.clone(),
            })
        }
        Ast::And(kids) => Ast::And(kids.iter().map(|kid| apply_mapper(kid, mapper)).collect()),
        Ast::Or(kids) => Ast::Or(kids.iter().map(|kid| apply_mapper(kid, mapper)).collect()),
    }
}

/// Validates `ast`, optionally restricting field names to `allowed`.
///
/// The whole tree is visited, so a failure reports every problem found
/// rather than only the first. On success, returns the tree rebuilt by the
/// validator.
///
/// # Errors
///
/// Returns every [`ValidationError`] encountered when any constraint fails.
pub(crate) fn validate_inner(
    ast: &Ast,
    allowed: Option<&[&str]>,
) -> Result<Ast, Vec<ValidationError>> {
    let mut errors = Vec::new();
    let validated = validate_node(ast, allowed, &mut errors);
    match validated {
        // `validate_node` only yields `None` after recording at least one
        // error. Checking both conditions here replaces the old `unwrap()`,
        // so a broken invariant can never panic inside the library.
        Some(tree) if errors.is_empty() => Ok(tree),
        _ => Err(errors),
    }
}

/// Recursively validates `ast`, pushing any problems onto `errors`.
///
/// Returns the rebuilt node when it and all of its descendants are valid,
/// otherwise `None`.
fn validate_node(
    ast: &Ast,
    allowed: Option<&[&str]>,
    errors: &mut Vec<ValidationError>,
) -> Option<Ast> {
    match ast {
        Ast::Constraint(c) => validate_constraint(c, allowed, errors),
        Ast::And(children) => validate_children(children, allowed, errors).map(Ast::And),
        Ast::Or(children) => validate_children(children, allowed, errors).map(Ast::Or),
    }
}

/// Validates every child of a group node, returning them rebuilt only if all
/// passed.
///
/// Siblings are still visited after a failure, so every error in the
/// subtree lands in `errors` in a single pass.
fn validate_children(
    children: &[Ast],
    allowed: Option<&[&str]>,
    errors: &mut Vec<ValidationError>,
) -> Option<Vec<Ast>> {
    let mut rebuilt = Vec::with_capacity(children.len());
    let mut any_failed = false;
    for child in children {
        match validate_node(child, allowed, errors) {
            Some(node) => rebuilt.push(node),
            None => any_failed = true,
        }
    }
    (!any_failed).then_some(rebuilt)
}

/// Returns every distinct field name referenced by `ast`, in sorted order.
///
/// The returned names borrow from `ast`; no strings are copied.
pub fn fields(ast: &Ast) -> Vec<&str> {
    let mut names: Vec<&str> = Vec::new();
    collect_fields(ast, &mut names);
    // Equal `&str` values are indistinguishable, so an unstable sort gives
    // the same result as a stable one before deduplication.
    names.sort_unstable();
    names.dedup();
    names
}

/// Appends the field of every constraint in `ast` to `out`, in depth-first,
/// left-to-right order (duplicates included).
fn collect_fields<'a>(ast: &'a Ast, out: &mut Vec<&'a str>) {
    let mut pending: Vec<&'a Ast> = vec![ast];
    while let Some(node) = pending.pop() {
        match node {
            // Push in reverse so the leftmost child is popped first,
            // matching a recursive pre-order walk.
            Ast::And(children) | Ast::Or(children) => pending.extend(children.iter().rev()),
            Ast::Constraint(c) => out.push(c.field.as_str()),
        }
    }
}

/// Validates a single constraint.
///
/// The field allowlist (when `allowed` is `Some`) is checked first, then
/// whether the value's shape fits the operator. On success, returns a copy
/// of the constraint wrapped in [`Ast::Constraint`]. On failure, pushes
/// exactly one error onto `errors` and returns `None`.
fn validate_constraint(
    c: &Constraint,
    allowed: Option<&[&str]>,
    errors: &mut Vec<ValidationError>,
) -> Option<Ast> {
    if let Some(allowed) = allowed
        && !allowed.contains(&c.field.as_str())
    {
        errors.push(ValidationError::ForbiddenField(c.field.clone()));
        return None;
    }

    match check_value_shape(c) {
        Ok(()) => Some(Ast::Constraint(Constraint {
            field: c.field.clone(),
            operator: c.operator.clone(),
            value: c.value.clone(),
        })),
        Err(err) => {
            errors.push(err);
            None
        }
    }
}

/// Checks that `c.value` has the shape `c.operator` requires:
///
/// - `In` / `Out`: the value must be a list.
/// - `Between`: the value must be a list of exactly two elements.
/// - `Null` / `NotNull`: any value is accepted.
/// - any other operator: the value must not be a list.
///
/// The field and the operator's `Debug` name are only copied or formatted
/// when an error is actually built. Valid constraints therefore no longer
/// allocate a throw-away `String` here.
fn check_value_shape(c: &Constraint) -> Result<(), ValidationError> {
    let field = || c.field.clone();
    let operator = || format!("{:?}", c.operator);

    match (&c.operator, &c.value) {
        (Operator::In | Operator::Out, Value::List(_)) => Ok(()),
        (Operator::In | Operator::Out, _) => Err(ValidationError::ExpectedList {
            field: field(),
            operator: operator(),
        }),
        (Operator::Between, Value::List(items)) if items.len() == 2 => Ok(()),
        (Operator::Between, Value::List(_)) => Err(ValidationError::BetweenArity {
            field: field(),
            operator: operator(),
        }),
        (Operator::Between, _) => Err(ValidationError::ExpectedList {
            field: field(),
            operator: operator(),
        }),
        (Operator::Null | Operator::NotNull, _) => Ok(()),
        (_, Value::List(_)) => Err(ValidationError::UnexpectedList {
            field: field(),
            operator: operator(),
        }),
        _ => Ok(()),
    }
}

#[cfg(feature = "serde")]
mod serde_support {
    //! Allocation-free extraction of a type's serde field names.
    //!
    //! `T::deserialize` is handed a probe [`Deserializer`] that never yields
    //! data. When serde's derived code calls `deserialize_struct`, the probe
    //! fails at once with an "error" carrying the `fields` slice it was
    //! given. Every other entry point, which is forwarded to
    //! `deserialize_any`, fails with an empty slice.

    use serde::de::{self, Deserializer, Visitor};
    use std::fmt;

    /// Empty field list returned for anything that is not a plain struct.
    const NO_FIELDS: &[&str] = &[];

    /// Probe deserializer: records the requested field list, never succeeds.
    struct FieldExtractor;

    /// Error type used to carry the captured field list out of `deserialize`.
    struct ExtractErr(&'static [&'static str]);

    impl fmt::Display for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("field extraction")
        }
    }

    impl fmt::Debug for ExtractErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ExtractErr")
        }
    }

    impl std::error::Error for ExtractErr {}

    impl de::Error for ExtractErr {
        fn custom<T: fmt::Display>(_: T) -> Self {
            ExtractErr(NO_FIELDS)
        }
    }

    impl<'de> Deserializer<'de> for FieldExtractor {
        type Error = ExtractErr;

        fn deserialize_any<V: Visitor<'de>>(self, _: V) -> Result<V::Value, ExtractErr> {
            Err(ExtractErr(NO_FIELDS))
        }

        fn deserialize_struct<V: Visitor<'de>>(
            self,
            _name: &'static str,
            fields: &'static [&'static str],
            _visitor: V,
        ) -> Result<V::Value, ExtractErr> {
            // This is the only path that reports a non-empty field list.
            Err(ExtractErr(fields))
        }

        serde::forward_to_deserialize_any! {
            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
            bytes byte_buf option unit unit_struct newtype_struct seq tuple
            tuple_struct map enum identifier ignored_any
        }
    }

    /// Returns the serde field names of `T` without allocating.
    ///
    /// Intended for `#[derive(Deserialize)]` structs with named fields, whose
    /// derived impl calls `deserialize_struct`. Everything else yields `&[]`:
    /// enums, maps, tuples, unit and newtype structs, and any impl that
    /// returns `Ok` without calling the deserializer.
    ///
    /// NOTE(review): structs using `#[serde(flatten)]` are presumably
    /// deserialized through `deserialize_map` by serde_derive. They would
    /// then also yield `&[]`, making every field "forbidden" — verify.
    pub fn serde_fields<'de, T: serde::Deserialize<'de>>() -> &'static [&'static str] {
        match T::deserialize(FieldExtractor) {
            Ok(_) => NO_FIELDS,
            Err(ExtractErr(names)) => names,
        }
    }
}

#[cfg(feature = "serde")]
pub use serde_support::serde_fields;