datafusion_odata/
filter.rs1use chrono::{DateTime, Utc};
2use datafusion::{
3 logical_expr::{BinaryExpr, Operator, expr::InList},
4 prelude::*,
5 scalar::ScalarValue,
6};
7use odata_params::filters as odata_filters;
8
9use crate::error::*;
10
/// A parsed OData `$filter` expression, already translated into a DataFusion [`Expr`].
///
/// Obtain one by parsing a string (`s.parse::<ODataFilter>()`, via [`std::str::FromStr`]) or by
/// deserializing a string value with serde; take the expression out with `Expr::from(filter)`.
#[derive(Debug)]
pub struct ODataFilter(Expr);
15
16impl From<ODataFilter> for Expr {
17 fn from(value: ODataFilter) -> Self {
18 value.0
19 }
20}
21
22impl std::str::FromStr for ODataFilter {
25 type Err = ODataError;
26
27 fn from_str(s: &str) -> Result<Self, Self::Err> {
28 let odata_exprs = odata_params::filters::parse_str(s).map_err(ODataError::bad_request)?;
29 let df_exprs = odata_expr_to_df_expr(&odata_exprs)?;
30 Ok(ODataFilter(df_exprs))
31 }
32}
33
34impl<'de> serde::Deserialize<'de> for ODataFilter {
35 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
36 where
37 D: serde::Deserializer<'de>,
38 {
39 deserializer.deserialize_string(ODataFilterVisitor)
40 }
41}
42
43struct ODataFilterVisitor;
44
45impl serde::de::Visitor<'_> for ODataFilterVisitor {
46 type Value = ODataFilter;
47
48 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
49 write!(formatter, "an OData $filter string")
50 }
51
52 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
53 v.parse().map_err(serde::de::Error::custom)
54 }
55}
56
57fn odata_expr_to_df_expr(res: &odata_filters::Expr) -> Result<Expr, ODataError> {
60 match res {
61 odata_filters::Expr::Or(l, r) => Ok(Expr::BinaryExpr(BinaryExpr::new(
62 Box::new(odata_expr_to_df_expr(l)?),
63 Operator::Or,
64 Box::new(odata_expr_to_df_expr(r)?),
65 ))),
66 odata_filters::Expr::And(l, r) => Ok(Expr::BinaryExpr(BinaryExpr::new(
67 Box::new(odata_expr_to_df_expr(l)?),
68 Operator::And,
69 Box::new(odata_expr_to_df_expr(r)?),
70 ))),
71 odata_filters::Expr::Compare(l, op, r) => Ok(Expr::BinaryExpr(BinaryExpr::new(
72 Box::new(odata_expr_to_df_expr(l)?),
73 odata_op_to_df_op(op),
74 Box::new(odata_expr_to_df_expr(r)?),
75 ))),
76 odata_filters::Expr::Value(v) => Ok(Expr::Literal(odata_value_to_df_value(v)?, None)),
77 odata_filters::Expr::Not(e) => Ok(Expr::Not(Box::new(odata_expr_to_df_expr(e)?))),
78 odata_filters::Expr::In(i, l) => Ok(Expr::InList(InList::new(
79 Box::new(odata_expr_to_df_expr(i)?),
80 l.iter()
81 .map(odata_expr_to_df_expr)
82 .collect::<Result<Vec<Expr>, ODataError>>()?,
83 false,
84 ))),
85 odata_filters::Expr::Identifier(s) => Ok(Expr::Column(Column::new_unqualified(s))),
86 odata_filters::Expr::Function(..) => {
87 Err(UnsupportedFeature::new("Function within the filter is not supported").into())
88 }
89 }
90}
91
92fn odata_value_to_df_value(v: &odata_filters::Value) -> Result<ScalarValue, ODataError> {
93 match v {
94 odata_filters::Value::String(s) => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
95 odata_filters::Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
96 odata_filters::Value::Null => Ok(ScalarValue::Null),
97 odata_filters::Value::Number(d) => {
98 let d = d
99 .to_string()
100 .parse::<i64>()
101 .map_err(|_| BadRequest::new("Filter contains invalid number"))?;
102 Ok(ScalarValue::Int64(Some(d)))
103 }
104 odata_filters::Value::DateTime(d) => Ok(ScalarValue::Date64(Some(d.timestamp()))),
105 odata_filters::Value::Date(d) => {
106 let d = d
107 .and_hms_opt(0, 0, 0)
108 .ok_or(BadRequest::new("Filter contains invalid date"))?;
109 let timestamp = DateTime::<Utc>::from_naive_utc_and_offset(d, Utc).timestamp();
110 Ok(ScalarValue::Date64(Some(timestamp)))
111 }
112 odata_filters::Value::Uuid(u) => Ok(ScalarValue::LargeUtf8(Some(u.to_string()))),
113 odata_filters::Value::Time(_) => {
114 Err(UnsupportedFeature::new("Time value in filter is not supported").into())
115 }
116 }
117}
118
119fn odata_op_to_df_op(op: &odata_filters::CompareOperator) -> Operator {
120 match op {
121 odata_filters::CompareOperator::Equal => Operator::Eq,
122 odata_filters::CompareOperator::NotEqual => Operator::NotEq,
123 odata_filters::CompareOperator::LessThan => Operator::Lt,
124 odata_filters::CompareOperator::GreaterThan => Operator::Gt,
125 odata_filters::CompareOperator::LessOrEqual => Operator::LtEq,
126 odata_filters::CompareOperator::GreaterOrEqual => Operator::GtEq,
127 }
128}
129
130