datafusion_functions_json/
rewrite.rs1use std::sync::Arc;
2
3use datafusion::arrow::datatypes::DataType;
4use datafusion::common::config::ConfigOptions;
5use datafusion::common::tree_node::Transformed;
6use datafusion::common::Column;
7use datafusion::common::DFSchema;
8use datafusion::common::Result;
9use datafusion::logical_expr::expr::{Alias, Cast, Expr, ScalarFunction};
10use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
11use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
12use datafusion::logical_expr::sqlparser::ast::BinaryOperator;
13use datafusion::logical_expr::ScalarUDF;
14use datafusion::scalar::ScalarValue;
15
16#[derive(Debug)]
17pub(crate) struct JsonFunctionRewriter;
18
19impl FunctionRewrite for JsonFunctionRewriter {
20 fn name(&self) -> &'static str {
21 "JsonFunctionRewriter"
22 }
23
24 fn rewrite(&self, expr: Expr, _schema: &DFSchema, _config: &ConfigOptions) -> Result<Transformed<Expr>> {
25 let transform = match &expr {
26 Expr::Cast(cast) => optimise_json_get_cast(cast),
27 Expr::ScalarFunction(func) => unnest_json_calls(func),
28 _ => None,
29 };
30 Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
31 }
32}
33
34fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
37 let scalar_func = extract_scalar_function(&cast.expr)?;
38 if scalar_func.func.name() != "json_get" {
39 return None;
40 }
41 let func = match &cast.data_type {
42 DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
43 DataType::Float64 | DataType::Float32 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
44 crate::json_get_float::json_get_float_udf()
45 }
46 DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
47 DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
48 _ => return None,
49 };
50 Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
51 func,
52 args: scalar_func.args.clone(),
53 })))
54}
55
56fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
58 if !matches!(
59 func.func.name(),
60 "json_get"
61 | "json_get_bool"
62 | "json_get_float"
63 | "json_get_int"
64 | "json_get_json"
65 | "json_get_str"
66 | "json_as_text"
67 ) {
68 return None;
69 }
70 let mut outer_args_iter = func.args.iter();
71 let first_arg = outer_args_iter.next()?;
72 let inner_func = extract_scalar_function(first_arg)?;
73
74 if !matches!(inner_func.func.name(), "json_get" | "json_as_text") {
77 return None;
78 }
79
80 let mut args = inner_func.args.clone();
81 args.extend(outer_args_iter.cloned());
82 if args.iter().skip(1).all(|arg| matches!(arg, Expr::Literal(_))) {
84 Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
85 func: func.func.clone(),
86 args,
87 })))
88 } else {
89 None
90 }
91}
92
93fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
94 match expr {
95 Expr::ScalarFunction(func) => Some(func),
96 Expr::Alias(alias) => extract_scalar_function(&alias.expr),
97 _ => None,
98 }
99}
100
/// JSON-specific binary operators that the expression planner turns into function calls.
#[derive(Clone, Copy, Debug)]
enum JsonOperator {
    /// `->`, planned as `json_get`.
    Arrow,
    /// `->>`, planned as `json_as_text`.
    LongArrow,
    /// `?`, planned as `json_contains`.
    Question,
}
107
108impl TryFrom<&BinaryOperator> for JsonOperator {
109 type Error = ();
110
111 fn try_from(op: &BinaryOperator) -> Result<Self, Self::Error> {
112 match op {
113 BinaryOperator::Arrow => Ok(JsonOperator::Arrow),
114 BinaryOperator::LongArrow => Ok(JsonOperator::LongArrow),
115 BinaryOperator::Question => Ok(JsonOperator::Question),
116 _ => Err(()),
117 }
118 }
119}
120
121impl From<JsonOperator> for Arc<ScalarUDF> {
122 fn from(op: JsonOperator) -> Arc<ScalarUDF> {
123 match op {
124 JsonOperator::Arrow => crate::udfs::json_get_udf(),
125 JsonOperator::LongArrow => crate::udfs::json_as_text_udf(),
126 JsonOperator::Question => crate::udfs::json_contains_udf(),
127 }
128 }
129}
130
131impl std::fmt::Display for JsonOperator {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 match self {
134 JsonOperator::Arrow => write!(f, "->"),
135 JsonOperator::LongArrow => write!(f, "->>"),
136 JsonOperator::Question => write!(f, "?"),
137 }
138 }
139}
140
141fn expr_to_sql_repr(expr: &Expr) -> String {
143 match expr {
144 Expr::Column(Column {
145 name,
146 relation,
147 spans: _,
148 }) => relation
149 .as_ref()
150 .map_or_else(|| name.clone(), |r| format!("{r}.{name}")),
151 Expr::Alias(alias) => alias.name.clone(),
152 Expr::Literal(scalar) => match scalar {
153 ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => {
154 format!("'{v}'")
155 }
156 ScalarValue::UInt8(Some(v)) => v.to_string(),
157 ScalarValue::UInt16(Some(v)) => v.to_string(),
158 ScalarValue::UInt32(Some(v)) => v.to_string(),
159 ScalarValue::UInt64(Some(v)) => v.to_string(),
160 ScalarValue::Int8(Some(v)) => v.to_string(),
161 ScalarValue::Int16(Some(v)) => v.to_string(),
162 ScalarValue::Int32(Some(v)) => v.to_string(),
163 ScalarValue::Int64(Some(v)) => v.to_string(),
164 _ => scalar.to_string(),
165 },
166 Expr::Cast(cast) => expr_to_sql_repr(&cast.expr),
167 _ => expr.to_string(),
168 }
169}
170
171#[derive(Debug, Default)]
173pub struct JsonExprPlanner;
174
175impl ExprPlanner for JsonExprPlanner {
176 fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result<PlannerResult<RawBinaryExpr>> {
177 let Ok(op) = JsonOperator::try_from(&expr.op) else {
178 return Ok(PlannerResult::Original(expr));
179 };
180
181 let left_repr = expr_to_sql_repr(&expr.left);
182 let right_repr = expr_to_sql_repr(&expr.right);
183
184 let alias_name = format!("{left_repr} {op} {right_repr}");
185
186 Ok(PlannerResult::Planned(Expr::Alias(Alias::new(
188 Expr::ScalarFunction(ScalarFunction {
189 func: op.into(),
190 args: vec![expr.left, expr.right],
191 }),
192 None::<&str>,
193 alias_name,
194 ))))
195 }
196}