1use std::fmt;
28use std::sync::Arc;
29
30use datafusion_common::{Result, ScalarValue, Statistics, exec_err, internal_err};
31use datafusion_execution::TaskContext;
32use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex};
33
34use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties};
35use crate::joins::utils::{OnceAsync, OnceFut};
36use crate::stream::RecordBatchStreamAdapter;
37use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream};
38
39use futures::StreamExt;
40use futures::TryStreamExt;
41
/// Pairs the physical plan of one scalar subquery with the slot its result
/// is written to.
#[derive(Debug, Clone)]
pub struct ScalarSubqueryLink {
    /// Physical plan that is executed to produce the subquery's single value.
    pub plan: Arc<dyn ExecutionPlan>,
    /// Index handed to `ScalarSubqueryResults::set` when the value computed
    /// from `plan` is stored.
    pub index: SubqueryIndex,
}
55
/// Execution plan node that evaluates a set of scalar subqueries and then
/// streams the output of its main input unchanged.
///
/// On `execute`, the subquery plans are run and their values are stored in
/// `results`; only after that work has completed is the requested partition
/// of `input` executed.
#[derive(Debug)]
pub struct ScalarSubqueryExec {
    /// Main input whose batches are forwarded as this node's output.
    input: Arc<dyn ExecutionPlan>,
    /// Subquery plans together with the result slot each one fills.
    subqueries: Vec<ScalarSubqueryLink>,
    /// Shared once-future driving subquery evaluation; `execute` obtains its
    /// handle through `try_once`, and `reset_state` replaces it with a fresh
    /// one.
    subquery_future: Arc<OnceAsync<()>>,
    /// Container the subquery values are written into (via `set`) and that
    /// `reset_state` clears.
    results: ScalarSubqueryResults,
    /// Plan properties, cloned from `input` when the node is constructed.
    cache: Arc<PlanProperties>,
}
92
93impl ScalarSubqueryExec {
94 pub fn new(
95 input: Arc<dyn ExecutionPlan>,
96 subqueries: Vec<ScalarSubqueryLink>,
97 results: ScalarSubqueryResults,
98 ) -> Self {
99 let cache = Arc::clone(input.properties());
100 Self {
101 input,
102 subqueries,
103 subquery_future: Arc::default(),
104 results,
105 cache,
106 }
107 }
108
109 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
110 &self.input
111 }
112
113 pub fn subqueries(&self) -> &[ScalarSubqueryLink] {
114 &self.subqueries
115 }
116
117 pub fn results(&self) -> &ScalarSubqueryResults {
118 &self.results
119 }
120
121 fn true_for_input_only(&self) -> Vec<bool> {
124 std::iter::once(true)
125 .chain(std::iter::repeat_n(false, self.subqueries.len()))
126 .collect()
127 }
128}
129
130impl DisplayAs for ScalarSubqueryExec {
131 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
132 match t {
133 DisplayFormatType::Default | DisplayFormatType::Verbose => {
134 write!(
135 f,
136 "ScalarSubqueryExec: subqueries={}",
137 self.subqueries.len()
138 )
139 }
140 DisplayFormatType::TreeRender => {
141 write!(f, "")
142 }
143 }
144 }
145}
146
147impl ExecutionPlan for ScalarSubqueryExec {
148 fn name(&self) -> &'static str {
149 "ScalarSubqueryExec"
150 }
151
152 fn properties(&self) -> &Arc<PlanProperties> {
153 &self.cache
154 }
155
156 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
157 let mut children = vec![&self.input];
158 for sq in &self.subqueries {
159 children.push(&sq.plan);
160 }
161 children
162 }
163
164 fn with_new_children(
165 self: Arc<Self>,
166 mut children: Vec<Arc<dyn ExecutionPlan>>,
167 ) -> Result<Arc<dyn ExecutionPlan>> {
168 let input = children.remove(0);
170 let subqueries = self
171 .subqueries
172 .iter()
173 .zip(children)
174 .map(|(sq, new_plan)| ScalarSubqueryLink {
175 plan: new_plan,
176 index: sq.index,
177 })
178 .collect();
179 Ok(Arc::new(ScalarSubqueryExec::new(
180 input,
181 subqueries,
182 self.results.clone(),
183 )))
184 }
185
186 fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
187 self.results.clear();
188 Ok(Arc::new(ScalarSubqueryExec {
189 input: Arc::clone(&self.input),
190 subqueries: self.subqueries.clone(),
191 subquery_future: Arc::default(),
192 results: self.results.clone(),
193 cache: Arc::clone(&self.cache),
194 }))
195 }
196
197 fn execute(
198 &self,
199 partition: usize,
200 context: Arc<TaskContext>,
201 ) -> Result<SendableRecordBatchStream> {
202 let subqueries = self.subqueries.clone();
203 let results = self.results.clone();
204 let subquery_ctx = Arc::clone(&context);
205 let mut subquery_future = self.subquery_future.try_once(move || {
206 Ok(async move { execute_subqueries(subqueries, results, subquery_ctx).await })
207 })?;
208 let input = Arc::clone(&self.input);
209 let schema = self.schema();
210
211 Ok(Box::pin(RecordBatchStreamAdapter::new(
212 schema,
213 futures::stream::once(async move {
214 wait_for_subqueries(&mut subquery_future).await?;
217
218 input.execute(partition, context)
221 })
222 .try_flatten(),
223 )))
224 }
225
226 fn maintains_input_order(&self) -> Vec<bool> {
227 self.true_for_input_only()
230 }
231
232 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
233 vec![false; self.subqueries.len() + 1]
236 }
237
238 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
239 self.input.partition_statistics(partition)
240 }
241
242 fn cardinality_effect(&self) -> CardinalityEffect {
243 CardinalityEffect::Equal
244 }
245}
246
247async fn wait_for_subqueries(fut: &mut OnceFut<()>) -> Result<()> {
249 std::future::poll_fn(|cx| fut.get_shared(cx)).await?;
250 Ok(())
251}
252
253async fn execute_subqueries(
254 subqueries: Vec<ScalarSubqueryLink>,
255 results: ScalarSubqueryResults,
256 context: Arc<TaskContext>,
257) -> Result<()> {
258 let futures = subqueries.iter().map(|sq| {
261 let plan = Arc::clone(&sq.plan);
262 let ctx = Arc::clone(&context);
263 let results = results.clone();
264 let index = sq.index;
265 async move {
266 let value = execute_scalar_subquery(plan, ctx).await?;
267 results.set(index, value)?;
268 Ok(()) as Result<()>
269 }
270 });
271 futures::future::try_join_all(futures).await?;
272 Ok(())
273}
274
275async fn execute_scalar_subquery(
279 plan: Arc<dyn ExecutionPlan>,
280 context: Arc<TaskContext>,
281) -> Result<ScalarValue> {
282 let schema = plan.schema();
283 if schema.fields().len() != 1 {
284 return internal_err!(
286 "Scalar subquery must return exactly one column, got {}",
287 schema.fields().len()
288 );
289 }
290
291 let mut stream = crate::execute_stream(plan, context)?;
292 let mut result: Option<ScalarValue> = None;
293
294 while let Some(batch) = stream.next().await.transpose()? {
295 if batch.num_rows() == 0 {
296 continue;
297 }
298 if result.is_some() || batch.num_rows() > 1 {
299 return exec_err!("Scalar subquery returned more than one row");
300 }
301 result = Some(ScalarValue::try_from_array(batch.column(0), 0)?);
302 }
303
304 match result {
306 Some(v) => Ok(v),
307 None => ScalarValue::try_from(schema.field(0).data_type()),
308 }
309}
310
/// Unit tests for scalar subquery evaluation and for the once-only / reset
/// behavior of `ScalarSubqueryExec`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{self, TestMemoryExec};
    use crate::{
        execution_plan::reset_plan_states,
        projection::{ProjectionExec, ProjectionExpr},
    };

    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::test::exec::ErrorExec;
    use arrow::array::{Int32Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr;

    /// Expected outcome of one table-driven case: either an exact value or a
    /// substring that the error message must contain.
    enum ExpectedSubqueryResult {
        Value(ScalarValue),
        Error(&'static str),
    }

    /// Test-only wrapper plan that counts how many times `execute` is called
    /// before delegating to `inner`.
    #[derive(Debug)]
    struct CountingExec {
        inner: Arc<dyn ExecutionPlan>,
        // Shared with the test body so it can observe the call count.
        execute_calls: Arc<AtomicUsize>,
    }

    impl CountingExec {
        /// Wraps `inner`, recording executions in `execute_calls`.
        fn new(inner: Arc<dyn ExecutionPlan>, execute_calls: Arc<AtomicUsize>) -> Self {
            Self {
                inner,
                execute_calls,
            }
        }
    }

    impl DisplayAs for CountingExec {
        fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "CountingExec")
                }
                DisplayFormatType::TreeRender => write!(f, ""),
            }
        }
    }

    impl ExecutionPlan for CountingExec {
        fn name(&self) -> &'static str {
            "CountingExec"
        }

        // Properties are taken straight from the wrapped plan.
        fn properties(&self) -> &Arc<PlanProperties> {
            self.inner.properties()
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![&self.inner]
        }

        // The rebuilt node keeps sharing the same counter, so executions are
        // still counted after the plan is rewritten or reset.
        fn with_new_children(
            self: Arc<Self>,
            mut children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(Arc::new(Self::new(
                children.remove(0),
                Arc::clone(&self.execute_calls),
            )))
        }

        fn execute(
            &self,
            partition: usize,
            context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            // Count the call first, then delegate to the wrapped plan.
            self.execute_calls.fetch_add(1, Ordering::SeqCst);
            self.inner.execute(partition, context)
        }
    }

    /// Builds a `TestMemoryExec` over `batches`, using the schema of the
    /// first batch. Panics if `batches` is empty.
    fn make_subquery_plan(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
        let schema = batches[0].schema();
        TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap()
    }

    /// Builds a batch with one non-nullable `Int32` column named `a` holding
    /// `values`.
    fn int32_batch(values: Vec<i32>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    /// Builds a zero-row batch with one nullable `Int64` column named `a`.
    fn empty_int64_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![] as Vec<i64>))])
            .unwrap()
    }

    /// Returns a `PlaceholderRowExec` over `test::aggr_test_schema()`, used
    /// as a main input for the operator under test.
    fn placeholder_input() -> Arc<dyn ExecutionPlan> {
        Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
            test::aggr_test_schema(),
        ))
    }

    /// Builds a `ScalarSubqueryExec` with a single subquery stored at
    /// index 0.
    fn single_subquery_exec(
        input: Arc<dyn ExecutionPlan>,
        subquery_plan: Arc<dyn ExecutionPlan>,
        results: ScalarSubqueryResults,
    ) -> ScalarSubqueryExec {
        ScalarSubqueryExec::new(
            input,
            vec![ScalarSubqueryLink {
                plan: subquery_plan,
                index: SubqueryIndex::new(0),
            }],
            results,
        )
    }

    /// Builds a projection over the placeholder input whose only expression,
    /// aliased `sq`, is a `ScalarSubqueryExpr` for index 0 backed by
    /// `results`.
    fn scalar_subquery_projection_input(
        results: ScalarSubqueryResults,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(ProjectionExec::try_new(
            vec![ProjectionExpr {
                expr: Arc::new(ScalarSubqueryExpr::new(
                    DataType::Int32,
                    // NOTE(review): presumably the nullability flag — confirm
                    // against `ScalarSubqueryExpr::new`.
                    false,
                    SubqueryIndex::new(0),
                    results,
                )),
                alias: "sq".to_string(),
            }],
            placeholder_input(),
        )?))
    }

    /// Asserts that `batches` holds exactly one batch whose first column is
    /// an `Int32Array` of length one, and returns that value.
    fn extract_single_int32_value(batches: &[RecordBatch]) -> i32 {
        assert_eq!(batches.len(), 1);
        let values = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(values.len(), 1);
        values.value(0)
    }

    /// Table-driven check of `execute_scalar_subquery`: one row yields that
    /// value, zero rows yield a typed null, several rows yield an error.
    #[tokio::test]
    async fn test_execute_scalar_subquery_row_count_semantics() -> Result<()> {
        for (name, plan, expected) in [
            (
                "single_row",
                make_subquery_plan(vec![int32_batch(vec![42])]),
                ExpectedSubqueryResult::Value(ScalarValue::Int32(Some(42))),
            ),
            (
                "zero_rows",
                make_subquery_plan(vec![empty_int64_batch()]),
                ExpectedSubqueryResult::Value(ScalarValue::Int64(None)),
            ),
            (
                "multiple_rows",
                make_subquery_plan(vec![int32_batch(vec![1, 2, 3])]),
                ExpectedSubqueryResult::Error("more than one row"),
            ),
        ] {
            let actual =
                execute_scalar_subquery(plan, Arc::new(TaskContext::default())).await;
            match expected {
                ExpectedSubqueryResult::Value(expected) => {
                    assert_eq!(actual?, expected, "{name}");
                }
                ExpectedSubqueryResult::Error(expected) => {
                    let err = actual.expect_err(name);
                    assert!(
                        err.to_string().contains(expected),
                        "{name}: expected error containing '{expected}', got {err}"
                    );
                }
            }
        }

        Ok(())
    }

    /// Executing the same node twice with a failing subquery must fail both
    /// times while running the subquery plan only once.
    #[tokio::test]
    async fn test_failed_subquery_is_not_retried() -> Result<()> {
        let execute_calls = Arc::new(AtomicUsize::new(0));
        let subquery_plan = Arc::new(CountingExec::new(
            Arc::new(ErrorExec::new()),
            Arc::clone(&execute_calls),
        ));
        let exec = single_subquery_exec(
            placeholder_input(),
            subquery_plan,
            ScalarSubqueryResults::new(1),
        );

        let ctx = Arc::new(TaskContext::default());
        let stream = exec.execute(0, Arc::clone(&ctx))?;
        assert!(crate::common::collect(stream).await.is_err());

        // Second execution of the same node: still an error, no re-run.
        let stream = exec.execute(0, ctx)?;
        assert!(crate::common::collect(stream).await.is_err());

        assert_eq!(execute_calls.load(Ordering::SeqCst), 1);
        Ok(())
    }

    /// After `reset_plan_states` the stored result is cleared, and executing
    /// the reset plan runs the subquery a second time and repopulates it.
    #[tokio::test]
    async fn test_reset_state_clears_results_and_reexecutes_subqueries() -> Result<()> {
        let execute_calls = Arc::new(AtomicUsize::new(0));
        let results = ScalarSubqueryResults::new(1);
        let subquery_plan = Arc::new(CountingExec::new(
            make_subquery_plan(vec![int32_batch(vec![42])]),
            Arc::clone(&execute_calls),
        ));
        let exec: Arc<dyn ExecutionPlan> = Arc::new(single_subquery_exec(
            scalar_subquery_projection_input(results.clone())?,
            subquery_plan,
            results.clone(),
        ));

        // First run: the projection reads the value the subquery produced.
        let batches =
            crate::common::collect(exec.execute(0, Arc::new(TaskContext::default()))?)
                .await?;
        assert_eq!(extract_single_int32_value(&batches), 42);
        assert_eq!(
            results.get(SubqueryIndex::new(0)),
            Some(ScalarValue::Int32(Some(42)))
        );

        // Resetting must wipe the shared slot before anything is re-executed.
        let reset_exec = reset_plan_states(Arc::clone(&exec))?;
        assert_eq!(results.get(SubqueryIndex::new(0)), None);

        // Second run on the reset plan: the subquery is evaluated again.
        let reset_batches = crate::common::collect(
            reset_exec.execute(0, Arc::new(TaskContext::default()))?,
        )
        .await?;
        assert_eq!(extract_single_int32_value(&reset_batches), 42);
        assert_eq!(
            results.get(SubqueryIndex::new(0)),
            Some(ScalarValue::Int32(Some(42)))
        );
        assert_eq!(execute_calls.load(Ordering::SeqCst), 2);

        Ok(())
    }
}