// datafusion_odata/filter.rs

1use chrono::{DateTime, Utc};
2use datafusion::{
3    logical_expr::{BinaryExpr, Operator, expr::InList},
4    prelude::*,
5    scalar::ScalarValue,
6};
7use odata_params::filters as odata_filters;
8
9use crate::error::*;

///////////////////////////////////////////////////////////////////////////////

13#[derive(Debug)]
14pub struct ODataFilter(Expr);
15
16impl From<ODataFilter> for Expr {
17    fn from(value: ODataFilter) -> Self {
18        value.0
19    }
20}

///////////////////////////////////////////////////////////////////////////////

24impl std::str::FromStr for ODataFilter {
25    type Err = ODataError;
26
27    fn from_str(s: &str) -> Result<Self, Self::Err> {
28        let odata_exprs = odata_params::filters::parse_str(s).map_err(ODataError::bad_request)?;
29        let df_exprs = odata_expr_to_df_expr(&odata_exprs)?;
30        Ok(ODataFilter(df_exprs))
31    }
32}
34impl<'de> serde::Deserialize<'de> for ODataFilter {
35    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
36    where
37        D: serde::Deserializer<'de>,
38    {
39        deserializer.deserialize_string(ODataFilterVisitor)
40    }
41}
43struct ODataFilterVisitor;
44
45impl serde::de::Visitor<'_> for ODataFilterVisitor {
46    type Value = ODataFilter;
47
48    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
49        write!(formatter, "an OData $filter string")
50    }
51
52    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
53        v.parse().map_err(serde::de::Error::custom)
54    }
55}

///////////////////////////////////////////////////////////////////////////////

59fn odata_expr_to_df_expr(res: &odata_filters::Expr) -> Result<Expr, ODataError> {
60    match res {
61        odata_filters::Expr::Or(l, r) => Ok(Expr::BinaryExpr(BinaryExpr::new(
62            Box::new(odata_expr_to_df_expr(l)?),
63            Operator::Or,
64            Box::new(odata_expr_to_df_expr(r)?),
65        ))),
66        odata_filters::Expr::And(l, r) => Ok(Expr::BinaryExpr(BinaryExpr::new(
67            Box::new(odata_expr_to_df_expr(l)?),
68            Operator::And,
69            Box::new(odata_expr_to_df_expr(r)?),
70        ))),
71        odata_filters::Expr::Compare(l, op, r) => Ok(Expr::BinaryExpr(BinaryExpr::new(
72            Box::new(odata_expr_to_df_expr(l)?),
73            odata_op_to_df_op(op),
74            Box::new(odata_expr_to_df_expr(r)?),
75        ))),
76        odata_filters::Expr::Value(v) => Ok(Expr::Literal(odata_value_to_df_value(v)?, None)),
77        odata_filters::Expr::Not(e) => Ok(Expr::Not(Box::new(odata_expr_to_df_expr(e)?))),
78        odata_filters::Expr::In(i, l) => Ok(Expr::InList(InList::new(
79            Box::new(odata_expr_to_df_expr(i)?),
80            l.iter()
81                .map(odata_expr_to_df_expr)
82                .collect::<Result<Vec<Expr>, ODataError>>()?,
83            false,
84        ))),
85        odata_filters::Expr::Identifier(s) => Ok(Expr::Column(Column::new_unqualified(s))),
86        odata_filters::Expr::Function(..) => {
87            Err(UnsupportedFeature::new("Function within the filter is not supported").into())
88        }
89    }
90}
92fn odata_value_to_df_value(v: &odata_filters::Value) -> Result<ScalarValue, ODataError> {
93    match v {
94        odata_filters::Value::String(s) => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
95        odata_filters::Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
96        odata_filters::Value::Null => Ok(ScalarValue::Null),
97        odata_filters::Value::Number(d) => {
98            let d = d
99                .to_string()
100                .parse::<i64>()
101                .map_err(|_| BadRequest::new("Filter contains invalid number"))?;
102            Ok(ScalarValue::Int64(Some(d)))
103        }
104        odata_filters::Value::DateTime(d) => Ok(ScalarValue::Date64(Some(d.timestamp()))),
105        odata_filters::Value::Date(d) => {
106            let d = d
107                .and_hms_opt(0, 0, 0)
108                .ok_or(BadRequest::new("Filter contains invalid date"))?;
109            let timestamp = DateTime::<Utc>::from_naive_utc_and_offset(d, Utc).timestamp();
110            Ok(ScalarValue::Date64(Some(timestamp)))
111        }
112        odata_filters::Value::Uuid(u) => Ok(ScalarValue::LargeUtf8(Some(u.to_string()))),
113        odata_filters::Value::Time(_) => {
114            Err(UnsupportedFeature::new("Time value in filter is not supported").into())
115        }
116    }
117}
119fn odata_op_to_df_op(op: &odata_filters::CompareOperator) -> Operator {
120    match op {
121        odata_filters::CompareOperator::Equal => Operator::Eq,
122        odata_filters::CompareOperator::NotEqual => Operator::NotEq,
123        odata_filters::CompareOperator::LessThan => Operator::Lt,
124        odata_filters::CompareOperator::GreaterThan => Operator::Gt,
125        odata_filters::CompareOperator::LessOrEqual => Operator::LtEq,
126        odata_filters::CompareOperator::GreaterOrEqual => Operator::GtEq,
127    }
128}

///////////////////////////////////////////////////////////////////////////////