Skip to main content

datafusion_physical_expr/
scalar_subquery.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Physical expression for uncorrelated scalar subqueries.
19
20use std::fmt;
21use std::hash::Hash;
22use std::sync::Arc;
23
24use arrow::datatypes::{DataType, Field, FieldRef, Schema};
25use arrow::record_batch::RecordBatch;
26use datafusion_common::{Result, internal_datafusion_err};
27use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex};
28use datafusion_expr_common::columnar_value::ColumnarValue;
29use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32/// A physical expression whose value is provided by a scalar subquery.
33///
34/// Subquery execution is handled by `ScalarSubqueryExec`, which stores the
35/// result in a shared [`ScalarSubqueryResults`] container. This expression
36/// simply reads from that container at the appropriate index.
37#[derive(Debug)]
38pub struct ScalarSubqueryExpr {
39    data_type: DataType,
40    nullable: bool,
41    /// Index of this subquery in the shared results container.
42    index: SubqueryIndex,
43    /// Shared results container populated by `ScalarSubqueryExec`.
44    results: ScalarSubqueryResults,
45}
46
47impl ScalarSubqueryExpr {
48    pub fn new(
49        data_type: DataType,
50        nullable: bool,
51        index: SubqueryIndex,
52        results: ScalarSubqueryResults,
53    ) -> Self {
54        Self {
55            data_type,
56            nullable,
57            index,
58            results,
59        }
60    }
61
62    pub fn data_type(&self) -> &DataType {
63        &self.data_type
64    }
65
66    pub fn nullable(&self) -> bool {
67        self.nullable
68    }
69
70    /// Returns the index of this subquery in the shared results container.
71    pub fn index(&self) -> SubqueryIndex {
72        self.index
73    }
74
75    pub fn results(&self) -> &ScalarSubqueryResults {
76        &self.results
77    }
78}
79
80impl fmt::Display for ScalarSubqueryExpr {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        match self.results.get(self.index) {
83            Some(v) => write!(f, "scalar_subquery({v})"),
84            None => write!(f, "scalar_subquery(<pending>)"),
85        }
86    }
87}
88
89// Two ScalarSubqueryExprs are considered the "same" if they refer to the
90// same underlying shared results container and the same index within it.
91impl Hash for ScalarSubqueryExpr {
92    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
93        self.results.hash(state);
94        self.index.hash(state);
95    }
96}
97
98impl PartialEq for ScalarSubqueryExpr {
99    fn eq(&self, other: &Self) -> bool {
100        self.results == other.results && self.index == other.index
101    }
102}
103
104impl Eq for ScalarSubqueryExpr {}
105
106impl PhysicalExpr for ScalarSubqueryExpr {
107    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
108        Ok(Arc::new(Field::new(
109            "scalar_subquery",
110            self.data_type.clone(),
111            self.nullable,
112        )))
113    }
114
115    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
116        let value = self.results.get(self.index).ok_or_else(|| {
117            internal_datafusion_err!(
118                "ScalarSubqueryExpr evaluated before the subquery was executed"
119            )
120        })?;
121        Ok(ColumnarValue::Scalar(value))
122    }
123
124    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
125        vec![]
126    }
127
128    fn with_new_children(
129        self: Arc<Self>,
130        _children: Vec<Arc<dyn PhysicalExpr>>,
131    ) -> Result<Arc<dyn PhysicalExpr>> {
132        Ok(self)
133    }
134
135    fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
136        Ok(ExprProperties::new_unknown().with_order(SortProperties::Singleton))
137    }
138
139    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        write!(f, "(scalar subquery)")
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::Int32Array;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;

    /// Builds a results container with one slot per entry of `values`,
    /// filling only the slots whose entry is `Some`.
    fn make_results(values: Vec<Option<ScalarValue>>) -> ScalarSubqueryResults {
        let results = ScalarSubqueryResults::new(values.len());
        for (slot, entry) in values.into_iter().enumerate() {
            let Some(value) = entry else {
                continue;
            };
            results.set(SubqueryIndex::new(slot), value).unwrap();
        }
        results
    }

    /// Builds a non-nullable `Int32` expression reading `slot` of `results`.
    fn int32_expr(slot: usize, results: ScalarSubqueryResults) -> ScalarSubqueryExpr {
        ScalarSubqueryExpr::new(DataType::Int32, false, SubqueryIndex::new(slot), results)
    }

    #[test]
    fn test_evaluate_with_value() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let column = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_new(schema, vec![column])?;

        let populated = make_results(vec![Some(ScalarValue::Int32(Some(42)))]);
        let expr = int32_expr(0, populated);

        // The stored scalar must come back untouched, regardless of the batch.
        match expr.evaluate(&batch)? {
            ColumnarValue::Scalar(ScalarValue::Int32(Some(42))) => Ok(()),
            other => panic!("Expected Scalar(Int32(42)), got {other:?}"),
        }
    }

    #[test]
    fn test_evaluate_before_populated() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let column = Arc::new(Int32Array::from(vec![1]));
        let batch = RecordBatch::try_new(schema, vec![column]).unwrap();

        // A fresh container has an empty slot, so evaluation must fail.
        let expr = int32_expr(0, ScalarSubqueryResults::new(1));

        assert!(expr.evaluate(&batch).is_err());
    }

    #[test]
    fn test_identity_equality() {
        let shared = make_results(vec![None, None]);

        let first = int32_expr(0, shared.clone());
        let first_again = int32_expr(0, shared.clone());
        let second = int32_expr(1, shared.clone());

        // Same container and same index: equal.
        assert_eq!(first, first_again);
        // Same container but a different index: not equal.
        assert_ne!(first, second);

        // Same index but a different container: not equal.
        let unrelated = int32_expr(0, make_results(vec![None]));
        assert_ne!(first, unrelated);
    }
}