Skip to main content

datafusion_physical_plan/
scalar_subquery.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Execution plan for uncorrelated scalar subqueries.
//!
//! [`ScalarSubqueryExec`] wraps a main input plan and a set of subquery plans.
//! At execution time, it runs each subquery exactly once, extracts its scalar
//! result, and stores it in a shared [`ScalarSubqueryResults`] container.
//! [`ScalarSubqueryExpr`] instances hold the same container and read their
//! value from it by index.
//!
//! [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr
26
27use std::fmt;
28use std::sync::Arc;
29
30use datafusion_common::{Result, ScalarValue, Statistics, exec_err, internal_err};
31use datafusion_execution::TaskContext;
32use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex};
33
34use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties};
35use crate::joins::utils::{OnceAsync, OnceFut};
36use crate::stream::RecordBatchStreamAdapter;
37use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream};
38
39use futures::StreamExt;
40use futures::TryStreamExt;
41
42/// Links a scalar subquery's execution plan to its index in the shared results
43/// container. The [`ScalarSubqueryExec`] that owns these links populates
44/// `results[index]` at execution time, and [`ScalarSubqueryExpr`] instances
45/// with the same index read from it.
46///
47/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr
48#[derive(Debug, Clone)]
49pub struct ScalarSubqueryLink {
50    /// The physical plan for the subquery.
51    pub plan: Arc<dyn ExecutionPlan>,
52    /// Index into the shared results container.
53    pub index: SubqueryIndex,
54}
55
56/// Manages execution of uncorrelated scalar subqueries for a single plan
57/// level.
58///
59/// From a query-results perspective, this node is a pass-through: it yields
60/// the same batches as its main input and exists only to populate scalar
61/// subquery results as a side effect before those batches are produced.
62///
63/// The first child node is the **main input plan**, whose batches are passed
64/// through unchanged. The remaining children are **subquery plans**, each of
65/// which must produce exactly zero or one row. Before any batches from the main
66/// input are yielded, all subquery plans are executed and their scalar results
67/// are stored in a shared [`ScalarSubqueryResults`] container owned by this
68/// node. [`ScalarSubqueryExpr`] nodes embedded in the main input's expressions
69/// hold the same container and read from it by index.
70///
71/// All subqueries are evaluated eagerly when the first output partition is
72/// requested, before any rows from the main input are produced.
73///
74/// TODO: Consider overlapping computation of the subqueries with evaluating the
75/// main query.
76///
77/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr
78#[derive(Debug)]
79pub struct ScalarSubqueryExec {
80    /// The main input plan whose output is passed through.
81    input: Arc<dyn ExecutionPlan>,
82    /// Subquery plans and their result indexes.
83    subqueries: Vec<ScalarSubqueryLink>,
84    /// Shared one-time async computation of subquery results.
85    subquery_future: Arc<OnceAsync<()>>,
86    /// Shared results container; the corresponding `ScalarSubqueryExpr`
87    /// nodes in the input plan hold the same underlying container.
88    results: ScalarSubqueryResults,
89    /// Cached plan properties (copied from input).
90    cache: Arc<PlanProperties>,
91}
92
93impl ScalarSubqueryExec {
94    pub fn new(
95        input: Arc<dyn ExecutionPlan>,
96        subqueries: Vec<ScalarSubqueryLink>,
97        results: ScalarSubqueryResults,
98    ) -> Self {
99        let cache = Arc::clone(input.properties());
100        Self {
101            input,
102            subqueries,
103            subquery_future: Arc::default(),
104            results,
105            cache,
106        }
107    }
108
109    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
110        &self.input
111    }
112
113    pub fn subqueries(&self) -> &[ScalarSubqueryLink] {
114        &self.subqueries
115    }
116
117    pub fn results(&self) -> &ScalarSubqueryResults {
118        &self.results
119    }
120
121    /// Returns a per-child bool vec that is `true` for the main input
122    /// (child 0) and `false` for every subquery child.
123    fn true_for_input_only(&self) -> Vec<bool> {
124        std::iter::once(true)
125            .chain(std::iter::repeat_n(false, self.subqueries.len()))
126            .collect()
127    }
128}
129
130impl DisplayAs for ScalarSubqueryExec {
131    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
132        match t {
133            DisplayFormatType::Default | DisplayFormatType::Verbose => {
134                write!(
135                    f,
136                    "ScalarSubqueryExec: subqueries={}",
137                    self.subqueries.len()
138                )
139            }
140            DisplayFormatType::TreeRender => {
141                write!(f, "")
142            }
143        }
144    }
145}
146
147impl ExecutionPlan for ScalarSubqueryExec {
148    fn name(&self) -> &'static str {
149        "ScalarSubqueryExec"
150    }
151
152    fn properties(&self) -> &Arc<PlanProperties> {
153        &self.cache
154    }
155
156    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
157        let mut children = vec![&self.input];
158        for sq in &self.subqueries {
159            children.push(&sq.plan);
160        }
161        children
162    }
163
164    fn with_new_children(
165        self: Arc<Self>,
166        mut children: Vec<Arc<dyn ExecutionPlan>>,
167    ) -> Result<Arc<dyn ExecutionPlan>> {
168        // First child is the main input, the rest are subquery plans.
169        let input = children.remove(0);
170        let subqueries = self
171            .subqueries
172            .iter()
173            .zip(children)
174            .map(|(sq, new_plan)| ScalarSubqueryLink {
175                plan: new_plan,
176                index: sq.index,
177            })
178            .collect();
179        Ok(Arc::new(ScalarSubqueryExec::new(
180            input,
181            subqueries,
182            self.results.clone(),
183        )))
184    }
185
186    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
187        self.results.clear();
188        Ok(Arc::new(ScalarSubqueryExec {
189            input: Arc::clone(&self.input),
190            subqueries: self.subqueries.clone(),
191            subquery_future: Arc::default(),
192            results: self.results.clone(),
193            cache: Arc::clone(&self.cache),
194        }))
195    }
196
197    fn execute(
198        &self,
199        partition: usize,
200        context: Arc<TaskContext>,
201    ) -> Result<SendableRecordBatchStream> {
202        let subqueries = self.subqueries.clone();
203        let results = self.results.clone();
204        let subquery_ctx = Arc::clone(&context);
205        let mut subquery_future = self.subquery_future.try_once(move || {
206            Ok(async move { execute_subqueries(subqueries, results, subquery_ctx).await })
207        })?;
208        let input = Arc::clone(&self.input);
209        let schema = self.schema();
210
211        Ok(Box::pin(RecordBatchStreamAdapter::new(
212            schema,
213            futures::stream::once(async move {
214                // Execute all subqueries exactly once, even when multiple
215                // partitions call execute() concurrently.
216                wait_for_subqueries(&mut subquery_future).await?;
217
218                // Now that the subqueries have finished execution, we can
219                // safely execute the main input
220                input.execute(partition, context)
221            })
222            .try_flatten(),
223        )))
224    }
225
226    fn maintains_input_order(&self) -> Vec<bool> {
227        // Only the main input (first child); subquery children don't contribute
228        // to ordering.
229        self.true_for_input_only()
230    }
231
232    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
233        // ScalarSubqueryExec is a pass-through coordinator: it does not
234        // benefit from repartitioning any child directly below it.
235        vec![false; self.subqueries.len() + 1]
236    }
237
238    fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
239        self.input.partition_statistics(partition)
240    }
241
242    fn cardinality_effect(&self) -> CardinalityEffect {
243        CardinalityEffect::Equal
244    }
245}
246
247/// Wait for the subquery execution future to complete.
248async fn wait_for_subqueries(fut: &mut OnceFut<()>) -> Result<()> {
249    std::future::poll_fn(|cx| fut.get_shared(cx)).await?;
250    Ok(())
251}
252
253async fn execute_subqueries(
254    subqueries: Vec<ScalarSubqueryLink>,
255    results: ScalarSubqueryResults,
256    context: Arc<TaskContext>,
257) -> Result<()> {
258    // Evaluate subqueries in parallel; wait for them all to finish evaluation
259    // before returning.
260    let futures = subqueries.iter().map(|sq| {
261        let plan = Arc::clone(&sq.plan);
262        let ctx = Arc::clone(&context);
263        let results = results.clone();
264        let index = sq.index;
265        async move {
266            let value = execute_scalar_subquery(plan, ctx).await?;
267            results.set(index, value)?;
268            Ok(()) as Result<()>
269        }
270    });
271    futures::future::try_join_all(futures).await?;
272    Ok(())
273}
274
275/// Execute a single subquery plan and extract the scalar value.
276/// Returns NULL for 0 rows, the scalar value for exactly 1 row,
277/// or an error for >1 rows.
278async fn execute_scalar_subquery(
279    plan: Arc<dyn ExecutionPlan>,
280    context: Arc<TaskContext>,
281) -> Result<ScalarValue> {
282    let schema = plan.schema();
283    if schema.fields().len() != 1 {
284        // Should be enforced by the physical planner.
285        return internal_err!(
286            "Scalar subquery must return exactly one column, got {}",
287            schema.fields().len()
288        );
289    }
290
291    let mut stream = crate::execute_stream(plan, context)?;
292    let mut result: Option<ScalarValue> = None;
293
294    while let Some(batch) = stream.next().await.transpose()? {
295        if batch.num_rows() == 0 {
296            continue;
297        }
298        if result.is_some() || batch.num_rows() > 1 {
299            return exec_err!("Scalar subquery returned more than one row");
300        }
301        result = Some(ScalarValue::try_from_array(batch.column(0), 0)?);
302    }
303
304    // 0 rows → typed NULL per SQL semantics
305    match result {
306        Some(v) => Ok(v),
307        None => ScalarValue::try_from(schema.field(0).data_type()),
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution_plan::reset_plan_states;
    use crate::projection::{ProjectionExec, ProjectionExpr};
    use crate::test::exec::ErrorExec;
    use crate::test::{self, TestMemoryExec};

    use std::sync::atomic::{AtomicUsize, Ordering};

    use arrow::array::{Int32Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr;

    /// Outcome a single subquery execution is expected to produce.
    enum Expected {
        /// The subquery succeeds with exactly this scalar.
        Value(ScalarValue),
        /// The subquery fails with an error whose message contains this text.
        ErrorContaining(&'static str),
    }

    /// Pass-through plan that delegates to `inner` and counts how many times
    /// `execute` is called.
    #[derive(Debug)]
    struct CountingExec {
        inner: Arc<dyn ExecutionPlan>,
        execute_calls: Arc<AtomicUsize>,
    }

    impl CountingExec {
        fn new(inner: Arc<dyn ExecutionPlan>, execute_calls: Arc<AtomicUsize>) -> Self {
            Self {
                inner,
                execute_calls,
            }
        }
    }

    impl DisplayAs for CountingExec {
        fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
            match t {
                DisplayFormatType::TreeRender => write!(f, ""),
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "CountingExec")
                }
            }
        }
    }

    impl ExecutionPlan for CountingExec {
        fn name(&self) -> &'static str {
            "CountingExec"
        }

        fn properties(&self) -> &Arc<PlanProperties> {
            self.inner.properties()
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![&self.inner]
        }

        fn with_new_children(
            self: Arc<Self>,
            mut children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            let inner = children.remove(0);
            Ok(Arc::new(Self::new(inner, Arc::clone(&self.execute_calls))))
        }

        fn execute(
            &self,
            partition: usize,
            context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            // Count first so that even failing executions are recorded.
            self.execute_calls.fetch_add(1, Ordering::SeqCst);
            self.inner.execute(partition, context)
        }
    }

    /// In-memory single-partition plan over `batches`, using the first
    /// batch's schema.
    fn make_subquery_plan(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
        let schema = batches
            .first()
            .expect("subquery plan needs at least one batch")
            .schema();
        TestMemoryExec::try_new_exec(&[batches], schema, None)
            .expect("valid in-memory subquery plan")
    }

    /// One-column, non-nullable Int32 batch named `a`.
    fn int32_batch(values: Vec<i32>) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int32Array::from(values))])
            .expect("valid Int32 batch")
    }

    /// One-column, nullable Int64 batch named `a` with no rows.
    fn empty_int64_batch() -> RecordBatch {
        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let no_values: Vec<i64> = Vec::new();
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int64Array::from(no_values))])
            .expect("valid empty Int64 batch")
    }

    /// Main input that emits a single placeholder row.
    fn placeholder_input() -> Arc<dyn ExecutionPlan> {
        let schema = test::aggr_test_schema();
        Arc::new(crate::placeholder_row::PlaceholderRowExec::new(schema))
    }

    /// `ScalarSubqueryExec` with one subquery bound to result slot 0.
    fn single_subquery_exec(
        input: Arc<dyn ExecutionPlan>,
        subquery_plan: Arc<dyn ExecutionPlan>,
        results: ScalarSubqueryResults,
    ) -> ScalarSubqueryExec {
        let link = ScalarSubqueryLink {
            plan: subquery_plan,
            index: SubqueryIndex::new(0),
        };
        ScalarSubqueryExec::new(input, vec![link], results)
    }

    /// Projection that outputs the Int32 value stored in result slot 0 as
    /// column `sq`.
    fn scalar_subquery_projection_input(
        results: ScalarSubqueryResults,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let expr =
            ScalarSubqueryExpr::new(DataType::Int32, false, SubqueryIndex::new(0), results);
        let projection = ProjectionExec::try_new(
            vec![ProjectionExpr {
                expr: Arc::new(expr),
                alias: "sq".to_string(),
            }],
            placeholder_input(),
        )?;
        Ok(Arc::new(projection))
    }

    /// Asserts `batches` is one batch with one Int32 value and returns it.
    fn extract_single_int32_value(batches: &[RecordBatch]) -> i32 {
        let [batch] = batches else {
            panic!("expected exactly one batch, got {}", batches.len());
        };
        let column = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("first column should be Int32");
        assert_eq!(column.len(), 1);
        column.value(0)
    }

    #[tokio::test]
    async fn test_execute_scalar_subquery_row_count_semantics() -> Result<()> {
        let cases = vec![
            (
                "single_row",
                make_subquery_plan(vec![int32_batch(vec![42])]),
                Expected::Value(ScalarValue::Int32(Some(42))),
            ),
            (
                "zero_rows",
                make_subquery_plan(vec![empty_int64_batch()]),
                Expected::Value(ScalarValue::Int64(None)),
            ),
            (
                "multiple_rows",
                make_subquery_plan(vec![int32_batch(vec![1, 2, 3])]),
                Expected::ErrorContaining("more than one row"),
            ),
        ];

        for (name, plan, expected) in cases {
            let ctx = Arc::new(TaskContext::default());
            let outcome = execute_scalar_subquery(plan, ctx).await;
            match expected {
                Expected::Value(want) => assert_eq!(outcome?, want, "{name}"),
                Expected::ErrorContaining(fragment) => {
                    let err = outcome.expect_err(name);
                    assert!(
                        err.to_string().contains(fragment),
                        "{name}: expected error containing '{fragment}', got {err}"
                    );
                }
            }
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_failed_subquery_is_not_retried() -> Result<()> {
        let execute_calls = Arc::new(AtomicUsize::new(0));
        let failing_plan = Arc::new(CountingExec::new(
            Arc::new(ErrorExec::new()),
            Arc::clone(&execute_calls),
        ));
        let exec = single_subquery_exec(
            placeholder_input(),
            failing_plan,
            ScalarSubqueryResults::new(1),
        );

        // Both attempts must fail, yet the subquery plan may only be executed
        // once: the outcome of the shared one-time future is reused.
        let ctx = Arc::new(TaskContext::default());
        for _attempt in 0..2 {
            let stream = exec.execute(0, Arc::clone(&ctx))?;
            assert!(crate::common::collect(stream).await.is_err());
        }

        assert_eq!(execute_calls.load(Ordering::SeqCst), 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_reset_state_clears_results_and_reexecutes_subqueries() -> Result<()> {
        /// Executes partition 0 of `plan` with a fresh context and returns the
        /// single Int32 value it produces.
        async fn run_to_value(plan: &Arc<dyn ExecutionPlan>) -> Result<i32> {
            let stream = plan.execute(0, Arc::new(TaskContext::default()))?;
            let batches = crate::common::collect(stream).await?;
            Ok(extract_single_int32_value(&batches))
        }

        let execute_calls = Arc::new(AtomicUsize::new(0));
        let results = ScalarSubqueryResults::new(1);
        let slot = SubqueryIndex::new(0);
        let counted_subquery = Arc::new(CountingExec::new(
            make_subquery_plan(vec![int32_batch(vec![42])]),
            Arc::clone(&execute_calls),
        ));
        let exec: Arc<dyn ExecutionPlan> = Arc::new(single_subquery_exec(
            scalar_subquery_projection_input(results.clone())?,
            counted_subquery,
            results.clone(),
        ));

        // First run populates the shared slot.
        assert_eq!(run_to_value(&exec).await?, 42);
        assert_eq!(results.get(slot), Some(ScalarValue::Int32(Some(42))));

        // Resetting clears the slot through the shared container.
        let reset_exec = reset_plan_states(Arc::clone(&exec))?;
        assert_eq!(results.get(slot), None);

        // The reset plan re-runs the subquery and repopulates the slot.
        assert_eq!(run_to_value(&reset_exec).await?, 42);
        assert_eq!(results.get(slot), Some(ScalarValue::Int32(Some(42))));
        assert_eq!(execute_calls.load(Ordering::SeqCst), 2);

        Ok(())
    }
}