datafusion_physical_expr/
scalar_subquery.rs1use std::fmt;
21use std::hash::Hash;
22use std::sync::Arc;
23
24use arrow::datatypes::{DataType, Field, FieldRef, Schema};
25use arrow::record_batch::RecordBatch;
26use datafusion_common::{Result, internal_datafusion_err};
27use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex};
28use datafusion_expr_common::columnar_value::ColumnarValue;
29use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32#[derive(Debug)]
38pub struct ScalarSubqueryExpr {
39 data_type: DataType,
40 nullable: bool,
41 index: SubqueryIndex,
43 results: ScalarSubqueryResults,
45}
46
47impl ScalarSubqueryExpr {
48 pub fn new(
49 data_type: DataType,
50 nullable: bool,
51 index: SubqueryIndex,
52 results: ScalarSubqueryResults,
53 ) -> Self {
54 Self {
55 data_type,
56 nullable,
57 index,
58 results,
59 }
60 }
61
62 pub fn data_type(&self) -> &DataType {
63 &self.data_type
64 }
65
66 pub fn nullable(&self) -> bool {
67 self.nullable
68 }
69
70 pub fn index(&self) -> SubqueryIndex {
72 self.index
73 }
74
75 pub fn results(&self) -> &ScalarSubqueryResults {
76 &self.results
77 }
78}
79
80impl fmt::Display for ScalarSubqueryExpr {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self.results.get(self.index) {
83 Some(v) => write!(f, "scalar_subquery({v})"),
84 None => write!(f, "scalar_subquery(<pending>)"),
85 }
86 }
87}
88
89impl Hash for ScalarSubqueryExpr {
92 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
93 self.results.hash(state);
94 self.index.hash(state);
95 }
96}
97
98impl PartialEq for ScalarSubqueryExpr {
99 fn eq(&self, other: &Self) -> bool {
100 self.results == other.results && self.index == other.index
101 }
102}
103
104impl Eq for ScalarSubqueryExpr {}
105
106impl PhysicalExpr for ScalarSubqueryExpr {
107 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
108 Ok(Arc::new(Field::new(
109 "scalar_subquery",
110 self.data_type.clone(),
111 self.nullable,
112 )))
113 }
114
115 fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
116 let value = self.results.get(self.index).ok_or_else(|| {
117 internal_datafusion_err!(
118 "ScalarSubqueryExpr evaluated before the subquery was executed"
119 )
120 })?;
121 Ok(ColumnarValue::Scalar(value))
122 }
123
124 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
125 vec![]
126 }
127
128 fn with_new_children(
129 self: Arc<Self>,
130 _children: Vec<Arc<dyn PhysicalExpr>>,
131 ) -> Result<Arc<dyn PhysicalExpr>> {
132 Ok(self)
133 }
134
135 fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
136 Ok(ExprProperties::new_unknown().with_order(SortProperties::Singleton))
137 }
138
139 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 write!(f, "(scalar subquery)")
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::Int32Array;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;

    /// Builds a results container with one slot per entry of `values`,
    /// setting only the slots whose entry is `Some`.
    fn make_results(values: Vec<Option<ScalarValue>>) -> ScalarSubqueryResults {
        let results = ScalarSubqueryResults::new(values.len());
        values
            .into_iter()
            .enumerate()
            .filter_map(|(slot, value)| value.map(|v| (slot, v)))
            .for_each(|(slot, v)| {
                results.set(SubqueryIndex::new(slot), v).unwrap();
            });
        results
    }

    /// Builds a record batch with a single nullable `Int32` column named `a`.
    fn int32_batch(
        values: Vec<i32>,
    ) -> std::result::Result<RecordBatch, arrow::error::ArrowError> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let column = Int32Array::from(values);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)])
    }

    /// Builds a non-nullable `Int32` expression reading `slot` of `results`.
    fn int32_expr(slot: usize, results: ScalarSubqueryResults) -> ScalarSubqueryExpr {
        ScalarSubqueryExpr::new(DataType::Int32, false, SubqueryIndex::new(slot), results)
    }

    #[test]
    fn test_evaluate_with_value() -> Result<()> {
        let batch = int32_batch(vec![1, 2, 3])?;
        let populated = make_results(vec![Some(ScalarValue::Int32(Some(42)))]);
        let expr = int32_expr(0, populated);

        match expr.evaluate(&batch)? {
            ColumnarValue::Scalar(ScalarValue::Int32(Some(42))) => {}
            other => panic!("Expected Scalar(Int32(42)), got {other:?}"),
        }
        Ok(())
    }

    #[test]
    fn test_evaluate_before_populated() {
        let batch = int32_batch(vec![1]).unwrap();

        // The single slot is allocated but never set.
        let empty = ScalarSubqueryResults::new(1);
        let expr = int32_expr(0, empty);

        let outcome = expr.evaluate(&batch);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_identity_equality() {
        let results = make_results(vec![None, None]);

        let e1a = int32_expr(0, results.clone());
        let e1b = int32_expr(0, results.clone());
        let e2 = int32_expr(1, results.clone());

        // Same container, same slot: equal.
        assert_eq!(e1a, e1b);
        // Same container, different slot: not equal.
        assert_ne!(e1a, e2);

        // A separately built container, same slot: not equal.
        let other_results = make_results(vec![None]);
        let e3 = int32_expr(0, other_results);
        assert_ne!(e1a, e3);
    }
}