Skip to main content

datafusion_physical_expr/expressions/
lambda.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Physical lambda expression: [`LambdaExpr`]
19
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::{
24    ScalarFunctionExpr,
25    expressions::{Column, LambdaVariable},
26    physical_expr::PhysicalExpr,
27};
28use arrow::{
29    datatypes::{DataType, Schema},
30    record_batch::RecordBatch,
31};
32use datafusion_common::{
33    HashMap, plan_err,
34    tree_node::{Transformed, TreeNode, TreeNodeRecursion},
35};
36use datafusion_common::{HashSet, Result, internal_err};
37use datafusion_expr::ColumnarValue;
38
39/// Represents a lambda with the given parameters names and body
40#[derive(Debug, Eq, Clone)]
41pub struct LambdaExpr {
42    params: Vec<String>,
43    body: Arc<dyn PhysicalExpr>,
44    projected_body: Arc<dyn PhysicalExpr>,
45    projection: Vec<usize>,
46}
47
48// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196]
49impl PartialEq for LambdaExpr {
50    fn eq(&self, other: &Self) -> bool {
51        self.params.eq(&other.params) && self.body.eq(&other.body)
52    }
53}
54
55impl Hash for LambdaExpr {
56    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
57        self.params.hash(state);
58        self.body.hash(state);
59    }
60}
61
62impl LambdaExpr {
63    /// Create a new lambda expression with the given parameters and body
64    pub fn try_new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Result<Self> {
65        if !all_unique(&params) {
66            return plan_err!(
67                "lambda params must be unique, got ({})",
68                params.join(", ")
69            );
70        }
71
72        check_async_udf(&body)?;
73
74        Ok(Self::new(params, body))
75    }
76
77    fn new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Self {
78        let mut used_column_indices = HashSet::new();
79
80        body.apply(|node| {
81            if let Some(col) = node.downcast_ref::<Column>() {
82                used_column_indices.insert(col.index());
83            } else if let Some(var) = node.downcast_ref::<LambdaVariable>() {
84                used_column_indices.insert(var.index());
85            }
86
87            Ok(TreeNodeRecursion::Continue)
88        })
89        .expect("closure should be infallible");
90
91        let mut projection = used_column_indices.into_iter().collect::<Vec<_>>();
92
93        projection.sort();
94
95        let column_index_map = projection
96            .iter()
97            .enumerate()
98            .map(|(projected, original)| (*original, projected))
99            .collect::<HashMap<_, _>>();
100
101        let projected_body = Arc::clone(&body)
102            .transform_down(|e| {
103                if let Some(column) = e.downcast_ref::<Column>() {
104                    let original = column.index();
105                    let projected = *column_index_map.get(&original).unwrap();
106                    if projected != original {
107                        return Ok(Transformed::yes(Arc::new(Column::new(
108                            column.name(),
109                            projected,
110                        ))));
111                    }
112                } else if let Some(lambda_variable) = e.downcast_ref::<LambdaVariable>() {
113                    let original = lambda_variable.index();
114                    let projected = *column_index_map.get(&original).unwrap();
115                    if projected != original {
116                        return Ok(Transformed::yes(Arc::new(LambdaVariable::new(
117                            projected,
118                            Arc::clone(lambda_variable.field()),
119                        ))));
120                    }
121                }
122                Ok(Transformed::no(e))
123            })
124            .expect("closure should be infallible")
125            .data;
126
127        Self {
128            params,
129            body,
130            projected_body,
131            projection,
132        }
133    }
134
135    /// Get the lambda's params names
136    pub fn params(&self) -> &[String] {
137        &self.params
138    }
139
140    /// Get the lambda's body
141    pub fn body(&self) -> &Arc<dyn PhysicalExpr> {
142        &self.body
143    }
144
145    pub(crate) fn projection(&self) -> &[usize] {
146        &self.projection
147    }
148
149    pub(crate) fn projected_body(&self) -> &Arc<dyn PhysicalExpr> {
150        &self.projected_body
151    }
152}
153
154impl std::fmt::Display for LambdaExpr {
155    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
156        write!(f, "({}) -> {}", self.params.join(", "), self.body)
157    }
158}
159
160impl PhysicalExpr for LambdaExpr {
161    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
162        Ok(DataType::Null)
163    }
164
165    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
166        Ok(true)
167    }
168
169    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
170        internal_err!("LambdaExpr::evaluate() should not be called")
171    }
172
173    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
174        vec![&self.body]
175    }
176
177    fn with_new_children(
178        self: Arc<Self>,
179        children: Vec<Arc<dyn PhysicalExpr>>,
180    ) -> Result<Arc<dyn PhysicalExpr>> {
181        let [body] = children.as_slice() else {
182            return internal_err!(
183                "LambdaExpr expects exactly 1 child, got {}",
184                children.len()
185            );
186        };
187
188        check_async_udf(body)?;
189
190        Ok(Arc::new(Self::new(self.params.clone(), Arc::clone(body))))
191    }
192
193    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        write!(f, "({}) -> {}", self.params.join(", "), self.body)
195    }
196}
197
198/// Create a lambda expression
199pub fn lambda(
200    params: impl IntoIterator<Item = impl Into<String>>,
201    body: Arc<dyn PhysicalExpr>,
202) -> Result<Arc<dyn PhysicalExpr>> {
203    Ok(Arc::new(LambdaExpr::try_new(
204        params.into_iter().map(Into::into).collect(),
205        body,
206    )?))
207}
208
/// Returns `true` when no name occurs more than once in `params`.
///
/// Slices of up to two names are answered without allocating; longer slices
/// are checked with a hash set.
fn all_unique(params: &[String]) -> bool {
    match params {
        [] | [_] => true,
        [first, second] => first != second,
        many => {
            let mut seen = HashSet::with_capacity(many.len());

            // `insert` returns `false` for a name that is already present,
            // which stops `all` at the first duplicate.
            many.iter().all(|name| seen.insert(name.as_str()))
        }
    }
}
220
221fn check_async_udf(body: &Arc<dyn PhysicalExpr>) -> Result<()> {
222    if body.exists(|expr| {
223        Ok(expr
224            .downcast_ref::<ScalarFunctionExpr>()
225            .is_some_and(|udf| udf.fun().as_async().is_some()))
226    })? {
227        return plan_err!(
228            "Async functions in lambdas aren't supported, see https://github.com/apache/datafusion/issues/22091"
229        );
230    }
231
232    Ok(())
233}
234
#[cfg(test)]
mod tests {
    use crate::expressions::{NoOp, lambda::lambda};
    use arrow::{array::RecordBatch, datatypes::Schema};
    use std::sync::Arc;

    /// Evaluating a lambda expression directly must yield an error.
    #[test]
    fn test_lambda_evaluate() {
        let expr = lambda(["a"], Arc::new(NoOp::new())).unwrap();
        let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
        let outcome = expr.evaluate(&empty_batch);
        assert!(outcome.is_err());
    }

    /// Building a lambda with a repeated parameter name must fail.
    #[test]
    fn test_lambda_duplicate_name() {
        let outcome = lambda(["a", "a"], Arc::new(NoOp::new()));
        assert!(outcome.is_err());
    }
}