datafusion_functions_json/
rewrite.rs

1use std::sync::Arc;
2
3use datafusion::arrow::datatypes::DataType;
4use datafusion::common::config::ConfigOptions;
5use datafusion::common::tree_node::Transformed;
6use datafusion::common::Column;
7use datafusion::common::DFSchema;
8use datafusion::common::Result;
9use datafusion::logical_expr::expr::{Alias, Cast, Expr, ScalarFunction};
10use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
11use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
12use datafusion::logical_expr::sqlparser::ast::BinaryOperator;
13use datafusion::logical_expr::ScalarUDF;
14use datafusion::scalar::ScalarValue;
15
/// Function rewriter for JSON expressions.
///
/// Its `FunctionRewrite` implementation folds casts of `json_get` into the matching typed getter and
/// flattens nested JSON lookup calls into a single call.
#[derive(Debug)]
pub(crate) struct JsonFunctionRewriter;
18
19impl FunctionRewrite for JsonFunctionRewriter {
20    fn name(&self) -> &'static str {
21        "JsonFunctionRewriter"
22    }
23
24    fn rewrite(&self, expr: Expr, _schema: &DFSchema, _config: &ConfigOptions) -> Result<Transformed<Expr>> {
25        let transform = match &expr {
26            Expr::Cast(cast) => optimise_json_get_cast(cast),
27            Expr::ScalarFunction(func) => unnest_json_calls(func),
28            _ => None,
29        };
30        Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
31    }
32}
33
34/// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of
35/// extracting the right value type from JSON without the need to materialize the JSON union.
36fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
37    let scalar_func = extract_scalar_function(&cast.expr)?;
38    if scalar_func.func.name() != "json_get" {
39        return None;
40    }
41    let func = match &cast.data_type {
42        DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
43        DataType::Float64 | DataType::Float32 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
44            crate::json_get_float::json_get_float_udf()
45        }
46        DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
47        DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
48        _ => return None,
49    };
50    Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
51        func,
52        args: scalar_func.args.clone(),
53    })))
54}
55
56// Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')`
57fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
58    if !matches!(
59        func.func.name(),
60        "json_get"
61            | "json_get_bool"
62            | "json_get_float"
63            | "json_get_int"
64            | "json_get_json"
65            | "json_get_str"
66            | "json_as_text"
67    ) {
68        return None;
69    }
70    let mut outer_args_iter = func.args.iter();
71    let first_arg = outer_args_iter.next()?;
72    let inner_func = extract_scalar_function(first_arg)?;
73
74    // both json_get and json_as_text would produce new JSON to be processed by the outer
75    // function so can be inlined
76    if !matches!(inner_func.func.name(), "json_get" | "json_as_text") {
77        return None;
78    }
79
80    let mut args = inner_func.args.clone();
81    args.extend(outer_args_iter.cloned());
82    // See #23, unnest only when all lookup arguments are literals
83    if args.iter().skip(1).all(|arg| matches!(arg, Expr::Literal(_))) {
84        Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
85            func: func.func.clone(),
86            args,
87        })))
88    } else {
89        None
90    }
91}
92
93fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
94    match expr {
95        Expr::ScalarFunction(func) => Some(func),
96        Expr::Alias(alias) => extract_scalar_function(&alias.expr),
97        _ => None,
98    }
99}
100
/// The SQL binary operators that the JSON planner translates into JSON UDF calls.
#[derive(Debug, Clone, Copy)]
enum JsonOperator {
    /// `->`, planned as a `json_get` call.
    Arrow,
    /// `->>`, planned as a `json_as_text` call.
    LongArrow,
    /// `?`, planned as a `json_contains` call.
    Question,
}
107
108impl TryFrom<&BinaryOperator> for JsonOperator {
109    type Error = ();
110
111    fn try_from(op: &BinaryOperator) -> Result<Self, Self::Error> {
112        match op {
113            BinaryOperator::Arrow => Ok(JsonOperator::Arrow),
114            BinaryOperator::LongArrow => Ok(JsonOperator::LongArrow),
115            BinaryOperator::Question => Ok(JsonOperator::Question),
116            _ => Err(()),
117        }
118    }
119}
120
121impl From<JsonOperator> for Arc<ScalarUDF> {
122    fn from(op: JsonOperator) -> Arc<ScalarUDF> {
123        match op {
124            JsonOperator::Arrow => crate::udfs::json_get_udf(),
125            JsonOperator::LongArrow => crate::udfs::json_as_text_udf(),
126            JsonOperator::Question => crate::udfs::json_contains_udf(),
127        }
128    }
129}
130
131impl std::fmt::Display for JsonOperator {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        match self {
134            JsonOperator::Arrow => write!(f, "->"),
135            JsonOperator::LongArrow => write!(f, "->>"),
136            JsonOperator::Question => write!(f, "?"),
137        }
138    }
139}
140
141/// Convert an Expr to a String representatiion for use in alias names.
142fn expr_to_sql_repr(expr: &Expr) -> String {
143    match expr {
144        Expr::Column(Column {
145            name,
146            relation,
147            spans: _,
148        }) => relation
149            .as_ref()
150            .map_or_else(|| name.clone(), |r| format!("{r}.{name}")),
151        Expr::Alias(alias) => alias.name.clone(),
152        Expr::Literal(scalar) => match scalar {
153            ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => {
154                format!("'{v}'")
155            }
156            ScalarValue::UInt8(Some(v)) => v.to_string(),
157            ScalarValue::UInt16(Some(v)) => v.to_string(),
158            ScalarValue::UInt32(Some(v)) => v.to_string(),
159            ScalarValue::UInt64(Some(v)) => v.to_string(),
160            ScalarValue::Int8(Some(v)) => v.to_string(),
161            ScalarValue::Int16(Some(v)) => v.to_string(),
162            ScalarValue::Int32(Some(v)) => v.to_string(),
163            ScalarValue::Int64(Some(v)) => v.to_string(),
164            _ => scalar.to_string(),
165        },
166        Expr::Cast(cast) => expr_to_sql_repr(&cast.expr),
167        _ => expr.to_string(),
168    }
169}
170
/// Custom SQL expression planner that replaces the Postgres JSON binary operators (`->`, `->>` and `?`)
/// with calls to this crate's JSON UDFs.
#[derive(Debug, Default)]
pub struct JsonExprPlanner;
174
175impl ExprPlanner for JsonExprPlanner {
176    fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result<PlannerResult<RawBinaryExpr>> {
177        let Ok(op) = JsonOperator::try_from(&expr.op) else {
178            return Ok(PlannerResult::Original(expr));
179        };
180
181        let left_repr = expr_to_sql_repr(&expr.left);
182        let right_repr = expr_to_sql_repr(&expr.right);
183
184        let alias_name = format!("{left_repr} {op} {right_repr}");
185
186        // we put the alias in so that default column titles are `foo -> bar` instead of `json_get(foo, bar)`
187        Ok(PlannerResult::Planned(Expr::Alias(Alias::new(
188            Expr::ScalarFunction(ScalarFunction {
189                func: op.into(),
190                args: vec![expr.left, expr.right],
191            }),
192            None::<&str>,
193            alias_name,
194        ))))
195    }
196}