// datafusion_physical_expr/expressions/lambda.rs
use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::{
24 ScalarFunctionExpr,
25 expressions::{Column, LambdaVariable},
26 physical_expr::PhysicalExpr,
27};
28use arrow::{
29 datatypes::{DataType, Schema},
30 record_batch::RecordBatch,
31};
32use datafusion_common::{
33 HashMap, plan_err,
34 tree_node::{Transformed, TreeNode, TreeNodeRecursion},
35};
36use datafusion_common::{HashSet, Result, internal_err};
37use datafusion_expr::ColumnarValue;
38
39#[derive(Debug, Eq, Clone)]
41pub struct LambdaExpr {
42 params: Vec<String>,
43 body: Arc<dyn PhysicalExpr>,
44 projected_body: Arc<dyn PhysicalExpr>,
45 projection: Vec<usize>,
46}
47
48impl PartialEq for LambdaExpr {
50 fn eq(&self, other: &Self) -> bool {
51 self.params.eq(&other.params) && self.body.eq(&other.body)
52 }
53}
54
55impl Hash for LambdaExpr {
56 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
57 self.params.hash(state);
58 self.body.hash(state);
59 }
60}
61
62impl LambdaExpr {
63 pub fn try_new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Result<Self> {
65 if !all_unique(¶ms) {
66 return plan_err!(
67 "lambda params must be unique, got ({})",
68 params.join(", ")
69 );
70 }
71
72 check_async_udf(&body)?;
73
74 Ok(Self::new(params, body))
75 }
76
77 fn new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Self {
78 let mut used_column_indices = HashSet::new();
79
80 body.apply(|node| {
81 if let Some(col) = node.downcast_ref::<Column>() {
82 used_column_indices.insert(col.index());
83 } else if let Some(var) = node.downcast_ref::<LambdaVariable>() {
84 used_column_indices.insert(var.index());
85 }
86
87 Ok(TreeNodeRecursion::Continue)
88 })
89 .expect("closure should be infallible");
90
91 let mut projection = used_column_indices.into_iter().collect::<Vec<_>>();
92
93 projection.sort();
94
95 let column_index_map = projection
96 .iter()
97 .enumerate()
98 .map(|(projected, original)| (*original, projected))
99 .collect::<HashMap<_, _>>();
100
101 let projected_body = Arc::clone(&body)
102 .transform_down(|e| {
103 if let Some(column) = e.downcast_ref::<Column>() {
104 let original = column.index();
105 let projected = *column_index_map.get(&original).unwrap();
106 if projected != original {
107 return Ok(Transformed::yes(Arc::new(Column::new(
108 column.name(),
109 projected,
110 ))));
111 }
112 } else if let Some(lambda_variable) = e.downcast_ref::<LambdaVariable>() {
113 let original = lambda_variable.index();
114 let projected = *column_index_map.get(&original).unwrap();
115 if projected != original {
116 return Ok(Transformed::yes(Arc::new(LambdaVariable::new(
117 projected,
118 Arc::clone(lambda_variable.field()),
119 ))));
120 }
121 }
122 Ok(Transformed::no(e))
123 })
124 .expect("closure should be infallible")
125 .data;
126
127 Self {
128 params,
129 body,
130 projected_body,
131 projection,
132 }
133 }
134
135 pub fn params(&self) -> &[String] {
137 &self.params
138 }
139
140 pub fn body(&self) -> &Arc<dyn PhysicalExpr> {
142 &self.body
143 }
144
145 pub(crate) fn projection(&self) -> &[usize] {
146 &self.projection
147 }
148
149 pub(crate) fn projected_body(&self) -> &Arc<dyn PhysicalExpr> {
150 &self.projected_body
151 }
152}
153
154impl std::fmt::Display for LambdaExpr {
155 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
156 write!(f, "({}) -> {}", self.params.join(", "), self.body)
157 }
158}
159
160impl PhysicalExpr for LambdaExpr {
161 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
162 Ok(DataType::Null)
163 }
164
165 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
166 Ok(true)
167 }
168
169 fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
170 internal_err!("LambdaExpr::evaluate() should not be called")
171 }
172
173 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
174 vec![&self.body]
175 }
176
177 fn with_new_children(
178 self: Arc<Self>,
179 children: Vec<Arc<dyn PhysicalExpr>>,
180 ) -> Result<Arc<dyn PhysicalExpr>> {
181 let [body] = children.as_slice() else {
182 return internal_err!(
183 "LambdaExpr expects exactly 1 child, got {}",
184 children.len()
185 );
186 };
187
188 check_async_udf(body)?;
189
190 Ok(Arc::new(Self::new(self.params.clone(), Arc::clone(body))))
191 }
192
193 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(f, "({}) -> {}", self.params.join(", "), self.body)
195 }
196}
197
198pub fn lambda(
200 params: impl IntoIterator<Item = impl Into<String>>,
201 body: Arc<dyn PhysicalExpr>,
202) -> Result<Arc<dyn PhysicalExpr>> {
203 Ok(Arc::new(LambdaExpr::try_new(
204 params.into_iter().map(Into::into).collect(),
205 body,
206 )?))
207}
208
/// Returns `true` when no name occurs more than once in `params`.
///
/// Slices of up to two names are decided without allocating; longer slices
/// are checked with a hash set of borrowed names.
fn all_unique(params: &[String]) -> bool {
    match params {
        [] | [_] => true,
        [first, second] => first != second,
        _ => {
            let mut seen = HashSet::with_capacity(params.len());

            // `insert` returns `false` for a name that is already present,
            // which stops `all` at the first duplicate.
            params.iter().all(|name| seen.insert(name.as_str()))
        }
    }
}
220
221fn check_async_udf(body: &Arc<dyn PhysicalExpr>) -> Result<()> {
222 if body.exists(|expr| {
223 Ok(expr
224 .downcast_ref::<ScalarFunctionExpr>()
225 .is_some_and(|udf| udf.fun().as_async().is_some()))
226 })? {
227 return plan_err!(
228 "Async functions in lambdas aren't supported, see https://github.com/apache/datafusion/issues/22091"
229 );
230 }
231
232 Ok(())
233}
234
#[cfg(test)]
mod tests {
    use crate::expressions::{NoOp, lambda::lambda};
    use arrow::{array::RecordBatch, datatypes::Schema};
    use std::sync::Arc;

    /// Evaluating a lambda directly must fail.
    #[test]
    fn test_lambda_evaluate() {
        let lambda = lambda(["a"], Arc::new(NoOp::new())).unwrap();
        let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
        assert!(lambda.evaluate(&batch).is_err());
    }

    /// Two identical parameter names are rejected.
    #[test]
    fn test_lambda_duplicate_name() {
        assert!(lambda(["a", "a"], Arc::new(NoOp::new())).is_err());
    }

    /// A duplicate among three or more names is rejected as well; this reaches
    /// the hash-set branch of the uniqueness check.
    #[test]
    fn test_lambda_duplicate_name_many_params() {
        assert!(lambda(["a", "b", "a"], Arc::new(NoOp::new())).is_err());
    }

    /// Distinct parameter names are accepted.
    #[test]
    fn test_lambda_unique_names() {
        assert!(lambda(["a", "b"], Arc::new(NoOp::new())).is_ok());
    }
}