Skip to main content

lance_datafusion/
exec.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Utilities for working with datafusion execution plans
5
6use std::{
7    collections::HashMap,
8    fmt::{self, Formatter},
9    sync::{Arc, Mutex, OnceLock},
10    time::Duration,
11};
12
13use chrono::{DateTime, Utc};
14
15use arrow_array::RecordBatch;
16use arrow_schema::Schema as ArrowSchema;
17use datafusion::physical_plan::metrics::MetricType;
18use datafusion::{
19    catalog::streaming::StreamingTable,
20    dataframe::DataFrame,
21    execution::{
22        TaskContext,
23        context::{SessionConfig, SessionContext},
24        disk_manager::DiskManagerBuilder,
25        memory_pool::FairSpillPool,
26        runtime_env::RuntimeEnvBuilder,
27    },
28    physical_plan::{
29        DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
30        analyze::AnalyzeExec,
31        display::DisplayableExecutionPlan,
32        execution_plan::{Boundedness, CardinalityEffect, EmissionType},
33        metrics::MetricValue,
34        stream::RecordBatchStreamAdapter,
35        streaming::PartitionStream,
36    },
37};
38use datafusion_common::{DataFusionError, Statistics};
39use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
40
41use futures::{StreamExt, stream};
42use lance_arrow::SchemaExt;
43use lance_core::{
44    Error, Result,
45    utils::{
46        futures::FinallyStreamExt,
47        tracing::{EXECUTION_PLAN_RUN, StreamTracingExt, TRACE_EXECUTION},
48    },
49};
50use log::{debug, info, warn};
51use tracing::Span;
52
53use crate::udf::register_functions;
54use crate::{
55    chunker::StrictBatchSizeStream,
56    utils::{
57        BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC, IOPS_METRIC,
58        MetricsExt, PARTS_LOADED_METRIC, REQUESTS_METRIC,
59    },
60};
61
/// A source execution node created from an existing stream
///
/// It can only be used once, and will return the stream.  After that the node
/// is exhausted: further calls to `execute` return an error, while `schema`
/// and display formatting keep working.
///
/// Note: the stream should be finite, otherwise we will report datafusion properties
/// incorrectly (the node always declares itself `Boundedness::Bounded`).
pub struct OneShotExec {
    /// The wrapped stream; `None` once `execute` has handed it out.
    stream: Mutex<Option<SendableRecordBatchStream>>,
    // We save off a copy of the schema to speed up formatting and so ExecutionPlan::schema & display_as
    // can still function after exhausted
    schema: Arc<ArrowSchema>,
    /// Cached plan properties: one round-robin partition, incremental emission, bounded.
    properties: Arc<PlanProperties>,
}
76
77impl OneShotExec {
78    /// Create a new instance from a given stream
79    pub fn new(stream: SendableRecordBatchStream) -> Self {
80        let schema = stream.schema();
81        Self {
82            stream: Mutex::new(Some(stream)),
83            schema: schema.clone(),
84            properties: Arc::new(PlanProperties::new(
85                EquivalenceProperties::new(schema),
86                Partitioning::RoundRobinBatch(1),
87                EmissionType::Incremental,
88                Boundedness::Bounded,
89            )),
90        }
91    }
92
93    pub fn from_batch(batch: RecordBatch) -> Self {
94        let schema = batch.schema();
95        let stream = Box::pin(RecordBatchStreamAdapter::new(
96            schema,
97            stream::iter(vec![Ok(batch)]),
98        ));
99        Self::new(stream)
100    }
101}
102
103impl std::fmt::Debug for OneShotExec {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        let stream = self.stream.lock().unwrap();
106        f.debug_struct("OneShotExec")
107            .field("exhausted", &stream.is_none())
108            .field("schema", self.schema.as_ref())
109            .finish()
110    }
111}
112
113impl DisplayAs for OneShotExec {
114    fn fmt_as(
115        &self,
116        t: datafusion::physical_plan::DisplayFormatType,
117        f: &mut std::fmt::Formatter,
118    ) -> std::fmt::Result {
119        let stream = self.stream.lock().unwrap();
120        let exhausted = if stream.is_some() { "" } else { "EXHAUSTED" };
121        let columns = self
122            .schema
123            .field_names()
124            .iter()
125            .cloned()
126            .cloned()
127            .collect::<Vec<_>>();
128        match t {
129            DisplayFormatType::Default | DisplayFormatType::Verbose => {
130                write!(
131                    f,
132                    "OneShotStream: {}columns=[{}]",
133                    exhausted,
134                    columns.join(",")
135                )
136            }
137            DisplayFormatType::TreeRender => {
138                write!(
139                    f,
140                    "OneShotStream\nexhausted={}\ncolumns=[{}]",
141                    exhausted,
142                    columns.join(",")
143                )
144            }
145        }
146    }
147}
148
149impl ExecutionPlan for OneShotExec {
150    fn name(&self) -> &str {
151        "OneShotExec"
152    }
153
154    fn as_any(&self) -> &dyn std::any::Any {
155        self
156    }
157
158    fn schema(&self) -> arrow_schema::SchemaRef {
159        self.schema.clone()
160    }
161
162    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
163        vec![]
164    }
165
166    fn with_new_children(
167        self: Arc<Self>,
168        children: Vec<Arc<dyn ExecutionPlan>>,
169    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
170        // OneShotExec has no children, so this should only be called with an empty vector
171        if !children.is_empty() {
172            return Err(datafusion_common::DataFusionError::Internal(
173                "OneShotExec does not support children".to_string(),
174            ));
175        }
176        Ok(self)
177    }
178
179    fn execute(
180        &self,
181        _partition: usize,
182        _context: Arc<datafusion::execution::TaskContext>,
183    ) -> datafusion_common::Result<SendableRecordBatchStream> {
184        let stream = self
185            .stream
186            .lock()
187            .map_err(|err| DataFusionError::Execution(err.to_string()))?
188            .take();
189        if let Some(stream) = stream {
190            Ok(stream)
191        } else {
192            Err(DataFusionError::Execution(
193                "OneShotExec has already been executed".to_string(),
194            ))
195        }
196    }
197
198    fn properties(&self) -> &Arc<datafusion::physical_plan::PlanProperties> {
199        &self.properties
200    }
201}
202
203struct TracedExec {
204    input: Arc<dyn ExecutionPlan>,
205    properties: Arc<PlanProperties>,
206    span: Span,
207}
208
209impl TracedExec {
210    pub fn new(input: Arc<dyn ExecutionPlan>, span: Span) -> Self {
211        Self {
212            properties: input.properties().clone(),
213            input,
214            span,
215        }
216    }
217}
218
219impl DisplayAs for TracedExec {
220    fn fmt_as(
221        &self,
222        t: datafusion::physical_plan::DisplayFormatType,
223        f: &mut std::fmt::Formatter,
224    ) -> std::fmt::Result {
225        match t {
226            DisplayFormatType::Default
227            | DisplayFormatType::Verbose
228            | DisplayFormatType::TreeRender => {
229                write!(f, "TracedExec")
230            }
231        }
232    }
233}
234
235impl std::fmt::Debug for TracedExec {
236    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
237        write!(f, "TracedExec")
238    }
239}
240impl ExecutionPlan for TracedExec {
241    fn name(&self) -> &str {
242        "TracedExec"
243    }
244
245    fn as_any(&self) -> &dyn std::any::Any {
246        self
247    }
248
249    fn properties(&self) -> &Arc<PlanProperties> {
250        &self.properties
251    }
252
253    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
254        vec![&self.input]
255    }
256
257    fn with_new_children(
258        self: Arc<Self>,
259        children: Vec<Arc<dyn ExecutionPlan>>,
260    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
261        Ok(Arc::new(Self {
262            input: children[0].clone(),
263            properties: self.properties.clone(),
264            span: self.span.clone(),
265        }))
266    }
267
268    fn execute(
269        &self,
270        partition: usize,
271        context: Arc<TaskContext>,
272    ) -> datafusion_common::Result<SendableRecordBatchStream> {
273        let _guard = self.span.enter();
274        let stream = self.input.execute(partition, context)?;
275        let schema = stream.schema();
276        let stream = stream.stream_in_span(self.span.clone());
277        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
278    }
279}
280
/// Callback for reporting execution statistics.
///
/// `execute_plan` invokes it (from a `finally` hook on the stream it returns)
/// with the [`ExecutionSummaryCounts`] collected from the executed plan.
pub type ExecutionStatsCallback = Arc<dyn Fn(&ExecutionSummaryCounts) + Send + Sync>;

/// Options controlling how Lance builds DataFusion sessions and executes plans.
#[derive(Default, Clone)]
pub struct LanceExecutionOptions {
    /// Request a spill-capable memory pool and disk manager.  Ignored when the
    /// `LANCE_BYPASS_SPILLING` environment variable is set.
    pub use_spilling: bool,
    /// Memory pool size (bytes) used when spilling.  When `None`, falls back to
    /// `LANCE_MEM_POOL_SIZE` or 100 MiB per target partition.
    pub mem_pool_size: Option<u64>,
    /// Maximum size (bytes) of the temporary spill directory.  When `None`,
    /// falls back to `LANCE_MAX_TEMP_DIRECTORY_SIZE` or 100 GiB.
    pub max_temp_directory_size: Option<u64>,
    /// Overrides DataFusion's `execution.batch_size` in the task context.
    pub batch_size: Option<usize>,
    /// Session target partition count; also scales the default memory pool size.
    pub target_partition: Option<usize>,
    /// Called with summary counts after a plan run through `execute_plan`.
    pub execution_stats_callback: Option<ExecutionStatsCallback>,
    /// Suppress the plan debug log and the per-run summary tracing event.
    pub skip_logging: bool,
}
294
295impl std::fmt::Debug for LanceExecutionOptions {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        f.debug_struct("LanceExecutionOptions")
298            .field("use_spilling", &self.use_spilling)
299            .field("mem_pool_size", &self.mem_pool_size)
300            .field("max_temp_directory_size", &self.max_temp_directory_size)
301            .field("batch_size", &self.batch_size)
302            .field("target_partition", &self.target_partition)
303            .field("skip_logging", &self.skip_logging)
304            .field(
305                "execution_stats_callback",
306                &self.execution_stats_callback.is_some(),
307            )
308            .finish()
309    }
310}
311
/// Default memory pool size per partition used when spilling (100 MiB).
const DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION: u64 = 100 << 20;
/// Default cap on the temporary spill directory (100 GiB).
const DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 << 30;
314
315impl LanceExecutionOptions {
316    pub fn mem_pool_size(&self) -> u64 {
317        let num_partitions = self.target_partition.unwrap_or(1) as u64;
318        self.mem_pool_size.unwrap_or_else(|| {
319            std::env::var("LANCE_MEM_POOL_SIZE")
320                .map(|s| match s.parse::<u64>() {
321                    Ok(v) => v,
322                    Err(e) => {
323                        warn!("Failed to parse LANCE_MEM_POOL_SIZE: {}, using default", e);
324                        DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION * num_partitions
325                    }
326                })
327                .unwrap_or(DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION * num_partitions)
328        })
329    }
330
331    pub fn max_temp_directory_size(&self) -> u64 {
332        self.max_temp_directory_size.unwrap_or_else(|| {
333            std::env::var("LANCE_MAX_TEMP_DIRECTORY_SIZE")
334                .map(|s| match s.parse::<u64>() {
335                    Ok(v) => v,
336                    Err(e) => {
337                        warn!(
338                            "Failed to parse LANCE_MAX_TEMP_DIRECTORY_SIZE: {}, using default",
339                            e
340                        );
341                        DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE
342                    }
343                })
344                .unwrap_or(DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE)
345        })
346    }
347
348    pub fn use_spilling(&self) -> bool {
349        if !self.use_spilling {
350            return false;
351        }
352        std::env::var("LANCE_BYPASS_SPILLING")
353            .map(|_| {
354                info!("Bypassing spilling because LANCE_BYPASS_SPILLING is set");
355                false
356            })
357            .unwrap_or(true)
358    }
359}
360
361pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext {
362    let mut session_config = SessionConfig::new();
363    let mut runtime_env_builder = RuntimeEnvBuilder::new();
364    if let Some(target_partition) = options.target_partition {
365        session_config = session_config.with_target_partitions(target_partition);
366    }
367    if options.use_spilling() {
368        let disk_manager_builder = DiskManagerBuilder::default()
369            .with_max_temp_directory_size(options.max_temp_directory_size());
370        runtime_env_builder = runtime_env_builder
371            .with_disk_manager_builder(disk_manager_builder)
372            .with_memory_pool(Arc::new(FairSpillPool::new(
373                options.mem_pool_size() as usize
374            )));
375    }
376    let runtime_env = runtime_env_builder.build_arc().unwrap();
377
378    let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);
379    register_functions(&ctx);
380
381    ctx
382}
383
384/// Cache key for session contexts based on resolved configuration values.
385#[derive(Clone, Debug, PartialEq, Eq, Hash)]
386struct SessionContextCacheKey {
387    mem_pool_size: u64,
388    max_temp_directory_size: u64,
389    target_partition: Option<usize>,
390    use_spilling: bool,
391}
392
393impl SessionContextCacheKey {
394    fn from_options(options: &LanceExecutionOptions) -> Self {
395        Self {
396            mem_pool_size: options.mem_pool_size(),
397            max_temp_directory_size: options.max_temp_directory_size(),
398            target_partition: options.target_partition,
399            use_spilling: options.use_spilling(),
400        }
401    }
402}
403
404struct CachedSessionContext {
405    context: SessionContext,
406    last_access: std::time::Instant,
407}
408
409fn get_session_cache() -> &'static Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>> {
410    static SESSION_CACHE: OnceLock<Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>>> =
411        OnceLock::new();
412    SESSION_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
413}
414
/// Maximum number of session contexts kept in the cache.
///
/// Read once from `LANCE_SESSION_CACHE_SIZE` and memoised; a missing or
/// unparsable value falls back to 4.
fn get_max_cache_size() -> usize {
    const DEFAULT_CACHE_SIZE: usize = 4;
    static MAX_CACHE_SIZE: OnceLock<usize> = OnceLock::new();
    *MAX_CACHE_SIZE.get_or_init(|| match std::env::var("LANCE_SESSION_CACHE_SIZE") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_CACHE_SIZE),
        Err(_) => DEFAULT_CACHE_SIZE,
    })
}
425
426pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext {
427    let key = SessionContextCacheKey::from_options(options);
428    let mut cache = get_session_cache()
429        .lock()
430        .unwrap_or_else(|e| e.into_inner());
431
432    // If key exists, update access time and return
433    if let Some(entry) = cache.get_mut(&key) {
434        entry.last_access = std::time::Instant::now();
435        return entry.context.clone();
436    }
437
438    // Evict least recently used entry if cache is full
439    if cache.len() >= get_max_cache_size()
440        && let Some(lru_key) = cache
441            .iter()
442            .min_by_key(|(_, v)| v.last_access)
443            .map(|(k, _)| k.clone())
444    {
445        cache.remove(&lru_key);
446    }
447
448    let context = new_session_context(options);
449    cache.insert(
450        key,
451        CachedSessionContext {
452            context: context.clone(),
453            last_access: std::time::Instant::now(),
454        },
455    );
456    context
457}
458
459fn get_task_context(
460    session_ctx: &SessionContext,
461    options: &LanceExecutionOptions,
462) -> Arc<TaskContext> {
463    let mut state = session_ctx.state();
464    if let Some(batch_size) = options.batch_size.as_ref() {
465        state.config_mut().options_mut().execution.batch_size = *batch_size;
466    }
467
468    state.task_ctx()
469}
470
/// Summary of I/O and index statistics gathered from an executed plan.
///
/// Filled in by [`collect_execution_metrics`], which sums the metrics of every
/// node in the plan tree.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ExecutionSummaryCounts {
    /// The number of I/O operations performed
    pub iops: usize,
    /// The number of requests made to the storage layer (may be larger or smaller than iops
    /// depending on coalescing configuration)
    pub requests: usize,
    /// The number of bytes read during the execution of the plan
    pub bytes_read: usize,
    /// The number of top-level indices loaded
    pub indices_loaded: usize,
    /// The number of index partitions loaded
    pub parts_loaded: usize,
    /// The number of index comparisons performed (the exact meaning depends on the index type)
    pub index_comparisons: usize,
    /// Additional metrics for more detailed statistics.  These are subject to change in the future
    /// and should only be used for debugging purposes.
    pub all_counts: HashMap<String, usize>,
    /// Additional time metrics for more detailed statistics, stored in nanoseconds.
    /// These are subject to change in the future and should only be used for debugging purposes.
    pub all_times: HashMap<String, usize>,
}
493
494pub fn collect_execution_metrics(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) {
495    if let Some(metrics) = node.metrics() {
496        for (metric_name, count) in metrics.iter_counts() {
497            match metric_name.as_ref() {
498                IOPS_METRIC => counts.iops += count.value(),
499                REQUESTS_METRIC => counts.requests += count.value(),
500                BYTES_READ_METRIC => counts.bytes_read += count.value(),
501                INDICES_LOADED_METRIC => counts.indices_loaded += count.value(),
502                PARTS_LOADED_METRIC => counts.parts_loaded += count.value(),
503                INDEX_COMPARISONS_METRIC => counts.index_comparisons += count.value(),
504                _ => {
505                    let existing = counts
506                        .all_counts
507                        .entry(metric_name.as_ref().to_string())
508                        .or_insert(0);
509                    *existing += count.value();
510                }
511            }
512        }
513        for (metric_name, time) in metrics.iter_times() {
514            let existing = counts
515                .all_times
516                .entry(metric_name.as_ref().to_string())
517                .or_insert(0);
518            *existing += time.value();
519        }
520        // Include gauge-based I/O metrics (some nodes record I/O as gauges)
521        for (metric_name, gauge) in metrics.iter_gauges() {
522            match metric_name.as_ref() {
523                IOPS_METRIC => counts.iops += gauge.value(),
524                REQUESTS_METRIC => counts.requests += gauge.value(),
525                BYTES_READ_METRIC => counts.bytes_read += gauge.value(),
526                _ => {}
527            }
528        }
529    }
530    for child in node.children() {
531        collect_execution_metrics(child.as_ref(), counts);
532    }
533}
534
535fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) {
536    let output_rows = plan
537        .metrics()
538        .map(|m| m.output_rows().unwrap_or(0))
539        .unwrap_or(0);
540    let mut counts = ExecutionSummaryCounts::default();
541    collect_execution_metrics(plan, &mut counts);
542    if !options.skip_logging {
543        tracing::info!(
544            target: TRACE_EXECUTION,
545            r#type = EXECUTION_PLAN_RUN,
546            plan_summary = display_plan_one_liner(plan),
547            output_rows,
548            iops = counts.iops,
549            requests = counts.requests,
550            bytes_read = counts.bytes_read,
551            indices_loaded = counts.indices_loaded,
552            parts_loaded = counts.parts_loaded,
553            index_comparisons = counts.index_comparisons,
554        );
555    }
556    if let Some(callback) = options.execution_stats_callback.as_ref() {
557        callback(&counts);
558    }
559}
560
561/// Create a one-line rough summary of the given execution plan.
562///
563/// The summary just shows the name of the operators in the plan. It omits any
564/// details such as parameters or schema information.
565///
566/// Example: `Projection(Take(CoalesceBatches(Filter(LanceScan))))`
567fn display_plan_one_liner(plan: &dyn ExecutionPlan) -> String {
568    let mut output = String::new();
569
570    display_plan_one_liner_impl(plan, &mut output);
571
572    output
573}
574
575fn display_plan_one_liner_impl(plan: &dyn ExecutionPlan, output: &mut String) {
576    // Remove the "Exec" suffix from the plan name if present for brevity
577    let name = plan.name().trim_end_matches("Exec");
578    output.push_str(name);
579
580    let children = plan.children();
581    if !children.is_empty() {
582        output.push('(');
583        for (i, child) in children.iter().enumerate() {
584            if i > 0 {
585                output.push(',');
586            }
587            display_plan_one_liner_impl(child.as_ref(), output);
588        }
589        output.push(')');
590    }
591}
592
593/// Executes a plan using default session & runtime configuration
594///
595/// Only executes a single partition.  Panics if the plan has more than one partition.
596pub fn execute_plan(
597    plan: Arc<dyn ExecutionPlan>,
598    options: LanceExecutionOptions,
599) -> Result<SendableRecordBatchStream> {
600    if !options.skip_logging {
601        debug!(
602            "Executing plan:\n{}",
603            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
604        );
605    }
606
607    let session_ctx = get_session_context(&options);
608
609    // NOTE: we are only executing the first partition here. Therefore, if
610    // the plan has more than one partition, we will be missing data.
611    assert_eq!(plan.properties().partitioning.partition_count(), 1);
612    let stream = plan.execute(0, get_task_context(&session_ctx, &options))?;
613
614    let schema = stream.schema();
615    let stream = stream.finally(move || {
616        if !options.skip_logging || options.execution_stats_callback.is_some() {
617            report_plan_summary_metrics(plan.as_ref(), &options);
618        }
619    });
620    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
621}
622
623pub async fn analyze_plan(
624    plan: Arc<dyn ExecutionPlan>,
625    options: LanceExecutionOptions,
626) -> Result<String> {
627    // This is needed as AnalyzeExec launches a thread task per
628    // partition, and we want these to be connected to the parent span
629    let plan = Arc::new(TracedExec::new(plan, Span::current()));
630
631    let schema = plan.schema();
632    // TODO(tsaucer) I chose SUMMARY here but do we also want DEV?
633    let analyze = Arc::new(AnalyzeExec::new(
634        true,
635        true,
636        vec![MetricType::SUMMARY],
637        plan,
638        schema,
639    ));
640
641    let session_ctx = get_session_context(&options);
642    assert_eq!(analyze.properties().partitioning.partition_count(), 1);
643    let mut stream = analyze
644        .execute(0, get_task_context(&session_ctx, &options))
645        .map_err(|err| Error::io(format!("Failed to execute analyze plan: {}", err)))?;
646
647    // fully execute the plan
648    while (stream.next().await).is_some() {}
649
650    let result = format_plan(analyze);
651    Ok(result)
652}
653
654pub fn format_plan(plan: Arc<dyn ExecutionPlan>) -> String {
655    /// A visitor which calculates additional metrics for all the plans.
656    struct CalculateVisitor {
657        highest_index: usize,
658        index_to_elapsed: HashMap<usize, Duration>,
659    }
660
661    /// Result of calculating metrics for a subtree
662    struct SubtreeMetrics {
663        min_start: Option<DateTime<Utc>>,
664        max_end: Option<DateTime<Utc>>,
665    }
666
667    impl CalculateVisitor {
668        fn calculate_metrics(&mut self, plan: &Arc<dyn ExecutionPlan>) -> SubtreeMetrics {
669            self.highest_index += 1;
670            let plan_index = self.highest_index;
671
672            // Get timestamps for this node
673            let (mut min_start, mut max_end) = Self::node_timerange(plan);
674
675            // Accumulate from children
676            for child in plan.children() {
677                let child_metrics = self.calculate_metrics(child);
678                min_start = Self::min_option(min_start, child_metrics.min_start);
679                max_end = Self::max_option(max_end, child_metrics.max_end);
680            }
681
682            // Calculate wall clock duration for this subtree (only if we have timestamps)
683            let elapsed = match (min_start, max_end) {
684                (Some(start), Some(end)) => Some((end - start).to_std().unwrap_or_default()),
685                _ => None,
686            };
687
688            if let Some(e) = elapsed {
689                self.index_to_elapsed.insert(plan_index, e);
690            }
691
692            SubtreeMetrics { min_start, max_end }
693        }
694
695        fn node_timerange(
696            plan: &Arc<dyn ExecutionPlan>,
697        ) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>) {
698            let Some(metrics) = plan.metrics() else {
699                return (None, None);
700            };
701            let min_start = metrics
702                .iter()
703                .filter_map(|m| match m.value() {
704                    MetricValue::StartTimestamp(ts) => ts.value(),
705                    _ => None,
706                })
707                .min();
708            let max_end = metrics
709                .iter()
710                .filter_map(|m| match m.value() {
711                    MetricValue::EndTimestamp(ts) => ts.value(),
712                    _ => None,
713                })
714                .max();
715            (min_start, max_end)
716        }
717
718        fn min_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
719            [a, b].into_iter().flatten().min()
720        }
721
722        fn max_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
723            [a, b].into_iter().flatten().max()
724        }
725    }
726
727    /// A visitor which prints out all the plans.
728    struct PrintVisitor {
729        highest_index: usize,
730        indent: usize,
731    }
732    impl PrintVisitor {
733        fn write_output(
734            &mut self,
735            plan: &Arc<dyn ExecutionPlan>,
736            f: &mut Formatter,
737            calcs: &CalculateVisitor,
738        ) -> std::fmt::Result {
739            self.highest_index += 1;
740            write!(f, "{:indent$}", "", indent = self.indent * 2)?;
741
742            // Format the plan description
743            let displayable =
744                datafusion::physical_plan::display::DisplayableExecutionPlan::new(plan.as_ref());
745            let plan_str = displayable.one_line().to_string();
746            let plan_str = plan_str.trim();
747
748            // Write operator with elapsed time inserted after the name
749            match calcs.index_to_elapsed.get(&self.highest_index) {
750                Some(elapsed) => match plan_str.find(": ") {
751                    Some(i) => write!(
752                        f,
753                        "{}: elapsed={elapsed:?}, {}",
754                        &plan_str[..i],
755                        &plan_str[i + 2..]
756                    )?,
757                    None => write!(f, "{plan_str}, elapsed={elapsed:?}")?,
758                },
759                None => write!(f, "{plan_str}")?,
760            }
761
762            if let Some(metrics) = plan.metrics() {
763                let metrics = metrics
764                    .aggregate_by_name()
765                    .sorted_for_display()
766                    .timestamps_removed();
767
768                write!(f, ", metrics=[{metrics}]")?;
769            } else {
770                write!(f, ", metrics=[]")?;
771            }
772            writeln!(f)?;
773            self.indent += 1;
774            for child in plan.children() {
775                self.write_output(child, f, calcs)?;
776            }
777            self.indent -= 1;
778            std::fmt::Result::Ok(())
779        }
780    }
781    // A wrapper which prints out a plan.
782    struct PrintWrapper {
783        plan: Arc<dyn ExecutionPlan>,
784    }
785    impl fmt::Display for PrintWrapper {
786        fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
787            let mut calcs = CalculateVisitor {
788                highest_index: 0,
789                index_to_elapsed: HashMap::new(),
790            };
791            calcs.calculate_metrics(&self.plan);
792            let mut prints = PrintVisitor {
793                highest_index: 0,
794                indent: 0,
795            };
796            prints.write_output(&self.plan, f, &calcs)
797        }
798    }
799    let wrapper = PrintWrapper { plan };
800    format!("{}", wrapper)
801}
802
/// Lance-specific extension methods for DataFusion's [`SessionContext`].
pub trait SessionContextExt {
    /// Creates a DataFrame for reading a stream of data
    ///
    /// This dataframe may only be queried once, future queries will fail
    fn read_one_shot(
        &self,
        data: SendableRecordBatchStream,
    ) -> datafusion::common::Result<DataFrame>;
}
812
813pub struct OneShotPartitionStream {
814    data: Arc<Mutex<Option<SendableRecordBatchStream>>>,
815    schema: Arc<ArrowSchema>,
816}
817
818impl std::fmt::Debug for OneShotPartitionStream {
819    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
820        let data = self.data.lock().unwrap();
821        f.debug_struct("OneShotPartitionStream")
822            .field("exhausted", &data.is_none())
823            .field("schema", self.schema.as_ref())
824            .finish()
825    }
826}
827
828impl OneShotPartitionStream {
829    pub fn new(data: SendableRecordBatchStream) -> Self {
830        let schema = data.schema();
831        Self {
832            data: Arc::new(Mutex::new(Some(data))),
833            schema,
834        }
835    }
836}
837
838impl PartitionStream for OneShotPartitionStream {
839    fn schema(&self) -> &arrow_schema::SchemaRef {
840        &self.schema
841    }
842
843    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
844        let mut stream = self.data.lock().unwrap();
845        stream
846            .take()
847            .expect("Attempt to consume a one shot dataframe multiple times")
848    }
849}
850
851impl SessionContextExt for SessionContext {
852    fn read_one_shot(
853        &self,
854        data: SendableRecordBatchStream,
855    ) -> datafusion::common::Result<DataFrame> {
856        let schema = data.schema();
857        let part_stream = Arc::new(OneShotPartitionStream::new(data));
858        let provider = StreamingTable::try_new(schema, vec![part_stream])?;
859        self.read_table(Arc::new(provider))
860    }
861}
862
863#[derive(Clone, Debug)]
864pub struct StrictBatchSizeExec {
865    input: Arc<dyn ExecutionPlan>,
866    batch_size: usize,
867}
868
869impl StrictBatchSizeExec {
870    pub fn new(input: Arc<dyn ExecutionPlan>, batch_size: usize) -> Self {
871        Self { input, batch_size }
872    }
873}
874
875impl DisplayAs for StrictBatchSizeExec {
876    fn fmt_as(
877        &self,
878        _t: datafusion::physical_plan::DisplayFormatType,
879        f: &mut std::fmt::Formatter,
880    ) -> std::fmt::Result {
881        write!(f, "StrictBatchSizeExec")
882    }
883}
884
885impl ExecutionPlan for StrictBatchSizeExec {
886    fn name(&self) -> &str {
887        "StrictBatchSizeExec"
888    }
889
890    fn as_any(&self) -> &dyn std::any::Any {
891        self
892    }
893
894    fn properties(&self) -> &Arc<PlanProperties> {
895        self.input.properties()
896    }
897
898    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
899        vec![&self.input]
900    }
901
902    fn with_new_children(
903        self: Arc<Self>,
904        children: Vec<Arc<dyn ExecutionPlan>>,
905    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
906        Ok(Arc::new(Self {
907            input: children[0].clone(),
908            batch_size: self.batch_size,
909        }))
910    }
911
912    fn execute(
913        &self,
914        partition: usize,
915        context: Arc<TaskContext>,
916    ) -> datafusion_common::Result<SendableRecordBatchStream> {
917        let stream = self.input.execute(partition, context)?;
918        let schema = stream.schema();
919        let stream = StrictBatchSizeStream::new(stream, self.batch_size);
920        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
921    }
922
923    fn maintains_input_order(&self) -> Vec<bool> {
924        vec![true]
925    }
926
927    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
928        vec![false]
929    }
930
931    fn partition_statistics(
932        &self,
933        partition: Option<usize>,
934    ) -> datafusion_common::Result<Statistics> {
935        self.input.partition_statistics(partition)
936    }
937
938    fn cardinality_effect(&self) -> CardinalityEffect {
939        CardinalityEffect::Equal
940    }
941
942    fn supports_limit_pushdown(&self) -> bool {
943        true
944    }
945}
946
947/// Exec node that rechunks batches so no output batch exceeds `max_bytes`.
948///
949/// # Why this exists
950///
951/// DataFusion's sort operator cannot handle batches larger than the memory
952/// pool size.  When upstream operators produce very large batches this can
953/// cause the sort to fail.  This node caps batch sizes
954/// *before* the sort so the operation succeeds.  The trade-off is a
955/// potentially expensive deep copy of the batch data — see below — but that
956/// is preferable to failing the operation entirely.  This workaround may
957/// become unnecessary if a fix is upstreamed to DataFusion.
958///
959/// # Deep copy
960///
961/// After slicing a RecordBatch, `get_array_memory_size` still reports the
962/// size of the *original* backing buffers, not the slice.  To get accurate
963/// sizes the slices must be deep-copied.  This is a last resort and can be
964/// expensive for large batches, but the deep copy is only performed when a
965/// batch actually needs to be sliced — batches that are already within the
966/// target range pass through at zero cost.
967///
968/// If a single row exceeds `max_bytes`, execution fails with an error.
969#[derive(Clone, Debug)]
970pub struct HardCapBatchSizeExec {
971    input: Arc<dyn ExecutionPlan>,
972    max_bytes: usize,
973}
974
975impl HardCapBatchSizeExec {
976    pub fn new(input: Arc<dyn ExecutionPlan>, max_bytes: usize) -> Self {
977        Self { input, max_bytes }
978    }
979}
980
981impl DisplayAs for HardCapBatchSizeExec {
982    fn fmt_as(
983        &self,
984        _t: datafusion::physical_plan::DisplayFormatType,
985        f: &mut std::fmt::Formatter,
986    ) -> std::fmt::Result {
987        write!(f, "HardCapBatchSizeExec(max_bytes={})", self.max_bytes)
988    }
989}
990
991impl ExecutionPlan for HardCapBatchSizeExec {
992    fn name(&self) -> &str {
993        "HardCapBatchSizeExec"
994    }
995
996    fn as_any(&self) -> &dyn std::any::Any {
997        self
998    }
999
1000    fn properties(&self) -> &Arc<PlanProperties> {
1001        self.input.properties()
1002    }
1003
1004    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1005        vec![&self.input]
1006    }
1007
1008    fn with_new_children(
1009        self: Arc<Self>,
1010        children: Vec<Arc<dyn ExecutionPlan>>,
1011    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
1012        Ok(Arc::new(Self {
1013            input: children[0].clone(),
1014            max_bytes: self.max_bytes,
1015        }))
1016    }
1017
1018    fn execute(
1019        &self,
1020        partition: usize,
1021        context: Arc<TaskContext>,
1022    ) -> datafusion_common::Result<SendableRecordBatchStream> {
1023        let stream = self.input.execute(partition, context)?;
1024        let schema = stream.schema();
1025        let max_bytes = self.max_bytes;
1026        let rechunked = lance_arrow::stream::rechunk_stream_by_size_deep_copy(
1027            stream,
1028            schema.clone(),
1029            0,
1030            max_bytes,
1031        );
1032        // Check that no single-row batch exceeds the limit.
1033        let validated = rechunked.map(move |result| {
1034            let batch = result?;
1035            if batch.num_rows() == 1 && batch.get_array_memory_size() > max_bytes {
1036                return Err(DataFusionError::External(Box::new(Error::invalid_input(
1037                    format!(
1038                        "a single row is {} bytes which exceeds the maximum allowed batch \
1039                         size of {} bytes",
1040                        batch.get_array_memory_size(),
1041                        max_bytes,
1042                    ),
1043                ))));
1044            }
1045            Ok(batch)
1046        });
1047        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, validated)))
1048    }
1049
1050    fn maintains_input_order(&self) -> Vec<bool> {
1051        vec![true]
1052    }
1053
1054    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
1055        vec![false]
1056    }
1057
1058    fn partition_statistics(
1059        &self,
1060        partition: Option<usize>,
1061    ) -> datafusion_common::Result<Statistics> {
1062        self.input.partition_statistics(partition)
1063    }
1064
1065    fn cardinality_effect(&self) -> CardinalityEffect {
1066        CardinalityEffect::Equal
1067    }
1068
1069    fn supports_limit_pushdown(&self) -> bool {
1070        true
1071    }
1072}
1073
#[cfg(test)]
mod tests {
    use super::*;

    // Serialize cache tests since they share global state
    static CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires the cache-test lock, tolerating poisoning.
    ///
    /// A failed assertion in one cache test panics while holding the guard and
    /// poisons the mutex. With a plain `unwrap()`, every other cache test would
    /// then fail with `PoisonError`, masking the real failure. The guarded data
    /// is `()`, so recovering the guard is always safe.
    fn lock_cache_tests() -> std::sync::MutexGuard<'static, ()> {
        CACHE_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_session_context_cache() {
        let _lock = lock_cache_tests();
        let cache = get_session_cache();

        // Clear any existing entries from other tests
        cache.lock().unwrap().clear();

        // Create first session with default options
        let opts1 = LanceExecutionOptions::default();
        let _ctx1 = get_session_context(&opts1);

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 1);
        }

        // Same options should reuse cached session (no new entry)
        let _ctx1_again = get_session_context(&opts1);
        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 1);
        }

        // Different options should create new entry
        let opts2 = LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        };
        let _ctx2 = get_session_context(&opts2);
        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 2);
        }
    }

    #[test]
    fn test_session_context_cache_lru_eviction() {
        let _lock = lock_cache_tests();
        let cache = get_session_cache();

        // Clear any existing entries from other tests
        cache.lock().unwrap().clear();

        // Create 4 different configurations to fill the cache
        let configs: Vec<LanceExecutionOptions> = (0..4)
            .map(|i| LanceExecutionOptions {
                mem_pool_size: Some((i + 1) as u64 * 1024 * 1024),
                ..Default::default()
            })
            .collect();

        for config in &configs {
            let _ctx = get_session_context(config);
        }

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 4);
        }

        // Access config[0] to make it more recently used than config[1]
        // (config[0] was inserted first, so without this access it would be evicted).
        // NOTE(review): the sleep presumably ensures a distinct access timestamp
        // for the LRU ordering; confirm against the cache implementation.
        std::thread::sleep(std::time::Duration::from_millis(1));
        let _ctx = get_session_context(&configs[0]);

        // Add a 5th configuration - should evict config[1] (now least recently used)
        let opts5 = LanceExecutionOptions {
            mem_pool_size: Some(5 * 1024 * 1024),
            ..Default::default()
        };
        let _ctx5 = get_session_context(&opts5);

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 4);

            // config[0] should still be present (was accessed recently)
            let key0 = SessionContextCacheKey::from_options(&configs[0]);
            assert!(
                cache_guard.contains_key(&key0),
                "config[0] should still be cached after recent access"
            );

            // config[1] should be evicted (was least recently used)
            let key1 = SessionContextCacheKey::from_options(&configs[1]);
            assert!(
                !cache_guard.contains_key(&key1),
                "config[1] should have been evicted"
            );

            // New config should be present
            let key5 = SessionContextCacheKey::from_options(&opts5);
            assert!(
                cache_guard.contains_key(&key5),
                "new config should be cached"
            );
        }
    }

    #[test]
    fn test_mem_pool_size_scales_with_partitions() {
        let default_per_partition = DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION;

        // No partitions specified → defaults to 1 partition
        let opts = LanceExecutionOptions::default();
        assert_eq!(opts.mem_pool_size(), default_per_partition);

        // 4 partitions → 4x the per-partition size
        let opts = LanceExecutionOptions {
            target_partition: Some(4),
            ..Default::default()
        };
        assert_eq!(opts.mem_pool_size(), default_per_partition * 4);

        // 8 partitions → 8x the per-partition size
        let opts = LanceExecutionOptions {
            target_partition: Some(8),
            ..Default::default()
        };
        assert_eq!(opts.mem_pool_size(), default_per_partition * 8);

        // Explicit mem_pool_size is not scaled
        let opts = LanceExecutionOptions {
            mem_pool_size: Some(50 * 1024 * 1024),
            target_partition: Some(8),
            ..Default::default()
        };
        assert_eq!(opts.mem_pool_size(), 50 * 1024 * 1024);
    }
}