// lance_datafusion/exec.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Utilities for working with datafusion execution plans
5
6use std::{
7    collections::HashMap,
8    fmt::{self, Formatter},
9    sync::{Arc, Mutex, OnceLock},
10    time::Duration,
11};
12
13use chrono::{DateTime, Utc};
14
15use arrow_array::RecordBatch;
16use arrow_schema::Schema as ArrowSchema;
17use datafusion::physical_plan::metrics::MetricType;
18use datafusion::{
19    catalog::streaming::StreamingTable,
20    dataframe::DataFrame,
21    execution::{
22        TaskContext,
23        context::{SessionConfig, SessionContext},
24        disk_manager::DiskManagerBuilder,
25        memory_pool::FairSpillPool,
26        runtime_env::RuntimeEnvBuilder,
27    },
28    physical_plan::{
29        DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
30        analyze::AnalyzeExec,
31        display::DisplayableExecutionPlan,
32        execution_plan::{Boundedness, CardinalityEffect, EmissionType},
33        metrics::MetricValue,
34        stream::RecordBatchStreamAdapter,
35        streaming::PartitionStream,
36    },
37};
38use datafusion_common::{DataFusionError, Statistics};
39use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
40
41use futures::{StreamExt, stream};
42use lance_arrow::SchemaExt;
43use lance_core::{
44    Error, Result,
45    utils::{
46        futures::FinallyStreamExt,
47        tracing::{EXECUTION_PLAN_RUN, StreamTracingExt, TRACE_EXECUTION},
48    },
49};
50use log::{debug, info, warn};
51use tracing::Span;
52
53use crate::udf::register_functions;
54use crate::{
55    chunker::StrictBatchSizeStream,
56    utils::{
57        BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC, IOPS_METRIC,
58        MetricsExt, PARTS_LOADED_METRIC, REQUESTS_METRIC,
59    },
60};
61
/// A source execution node created from an existing stream
///
/// It can only be used once, and will return the stream.  After that the node
/// is exhausted.
///
/// Note: the stream should be finite, otherwise we will report datafusion properties
/// incorrectly.
pub struct OneShotExec {
    /// The wrapped stream; `None` once `execute` has taken it
    stream: Mutex<Option<SendableRecordBatchStream>>,
    // We save off a copy of the schema to speed up formatting and so ExecutionPlan::schema & display_as
    // can still function after exhausted
    schema: Arc<ArrowSchema>,
    /// Plan properties (single round-robin partition, bounded) fixed at construction
    properties: PlanProperties,
}
76
77impl OneShotExec {
78    /// Create a new instance from a given stream
79    pub fn new(stream: SendableRecordBatchStream) -> Self {
80        let schema = stream.schema();
81        Self {
82            stream: Mutex::new(Some(stream)),
83            schema: schema.clone(),
84            properties: PlanProperties::new(
85                EquivalenceProperties::new(schema),
86                Partitioning::RoundRobinBatch(1),
87                EmissionType::Incremental,
88                Boundedness::Bounded,
89            ),
90        }
91    }
92
93    pub fn from_batch(batch: RecordBatch) -> Self {
94        let schema = batch.schema();
95        let stream = Box::pin(RecordBatchStreamAdapter::new(
96            schema,
97            stream::iter(vec![Ok(batch)]),
98        ));
99        Self::new(stream)
100    }
101}
102
103impl std::fmt::Debug for OneShotExec {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        let stream = self.stream.lock().unwrap();
106        f.debug_struct("OneShotExec")
107            .field("exhausted", &stream.is_none())
108            .field("schema", self.schema.as_ref())
109            .finish()
110    }
111}
112
113impl DisplayAs for OneShotExec {
114    fn fmt_as(
115        &self,
116        t: datafusion::physical_plan::DisplayFormatType,
117        f: &mut std::fmt::Formatter,
118    ) -> std::fmt::Result {
119        let stream = self.stream.lock().unwrap();
120        let exhausted = if stream.is_some() { "" } else { "EXHAUSTED" };
121        let columns = self
122            .schema
123            .field_names()
124            .iter()
125            .cloned()
126            .cloned()
127            .collect::<Vec<_>>();
128        match t {
129            DisplayFormatType::Default | DisplayFormatType::Verbose => {
130                write!(
131                    f,
132                    "OneShotStream: {}columns=[{}]",
133                    exhausted,
134                    columns.join(",")
135                )
136            }
137            DisplayFormatType::TreeRender => {
138                write!(
139                    f,
140                    "OneShotStream\nexhausted={}\ncolumns=[{}]",
141                    exhausted,
142                    columns.join(",")
143                )
144            }
145        }
146    }
147}
148
impl ExecutionPlan for OneShotExec {
    fn name(&self) -> &str {
        "OneShotExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> arrow_schema::SchemaRef {
        // Served from the saved copy so it works even after the stream is consumed
        self.schema.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Leaf node: data comes from the wrapped stream, not a child plan
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        // OneShotExec has no children, so this should only be called with an empty vector
        if !children.is_empty() {
            return Err(datafusion_common::DataFusionError::Internal(
                "OneShotExec does not support children".to_string(),
            ));
        }
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<datafusion::execution::TaskContext>,
    ) -> datafusion_common::Result<SendableRecordBatchStream> {
        // Take the stream out of the mutex; a poisoned lock is surfaced as an
        // execution error instead of a panic
        let stream = self
            .stream
            .lock()
            .map_err(|err| DataFusionError::Execution(err.to_string()))?
            .take();
        if let Some(stream) = stream {
            Ok(stream)
        } else {
            // The stream was already taken by an earlier execute call
            Err(DataFusionError::Execution(
                "OneShotExec has already been executed".to_string(),
            ))
        }
    }

    fn statistics(&self) -> datafusion_common::Result<datafusion_common::Statistics> {
        // Nothing is known about the stream's contents ahead of time
        Ok(Statistics::new_unknown(&self.schema))
    }

    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
        &self.properties
    }
}
206
/// Wraps an execution plan so that its execution is attached to a tracing span.
struct TracedExec {
    /// The wrapped plan
    input: Arc<dyn ExecutionPlan>,
    /// Cached copy of the input's plan properties
    properties: PlanProperties,
    /// Span entered during execute and attached to the resulting stream
    span: Span,
}
212
213impl TracedExec {
214    pub fn new(input: Arc<dyn ExecutionPlan>, span: Span) -> Self {
215        Self {
216            properties: input.properties().clone(),
217            input,
218            span,
219        }
220    }
221}
222
223impl DisplayAs for TracedExec {
224    fn fmt_as(
225        &self,
226        t: datafusion::physical_plan::DisplayFormatType,
227        f: &mut std::fmt::Formatter,
228    ) -> std::fmt::Result {
229        match t {
230            DisplayFormatType::Default
231            | DisplayFormatType::Verbose
232            | DisplayFormatType::TreeRender => {
233                write!(f, "TracedExec")
234            }
235        }
236    }
237}
238
239impl std::fmt::Debug for TracedExec {
240    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
241        write!(f, "TracedExec")
242    }
243}
244impl ExecutionPlan for TracedExec {
245    fn name(&self) -> &str {
246        "TracedExec"
247    }
248
249    fn as_any(&self) -> &dyn std::any::Any {
250        self
251    }
252
253    fn properties(&self) -> &PlanProperties {
254        &self.properties
255    }
256
257    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
258        vec![&self.input]
259    }
260
261    fn with_new_children(
262        self: Arc<Self>,
263        children: Vec<Arc<dyn ExecutionPlan>>,
264    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
265        Ok(Arc::new(Self {
266            input: children[0].clone(),
267            properties: self.properties.clone(),
268            span: self.span.clone(),
269        }))
270    }
271
272    fn execute(
273        &self,
274        partition: usize,
275        context: Arc<TaskContext>,
276    ) -> datafusion_common::Result<SendableRecordBatchStream> {
277        let _guard = self.span.enter();
278        let stream = self.input.execute(partition, context)?;
279        let schema = stream.schema();
280        let stream = stream.stream_in_span(self.span.clone());
281        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
282    }
283}
284
/// Callback for reporting statistics after a scan
pub type ExecutionStatsCallback = Arc<dyn Fn(&ExecutionSummaryCounts) + Send + Sync>;

/// Options controlling how Lance executes a datafusion plan
#[derive(Default, Clone)]
pub struct LanceExecutionOptions {
    /// Whether spilling to disk is enabled (can still be bypassed via LANCE_BYPASS_SPILLING)
    pub use_spilling: bool,
    /// Memory pool size in bytes; falls back to LANCE_MEM_POOL_SIZE, then a built-in default
    pub mem_pool_size: Option<u64>,
    /// Max temp directory size in bytes; falls back to LANCE_MAX_TEMP_DIRECTORY_SIZE, then a default
    pub max_temp_directory_size: Option<u64>,
    /// Overrides the session's execution batch size when set
    pub batch_size: Option<usize>,
    /// Overrides the session's target partition count when set
    pub target_partition: Option<usize>,
    /// Invoked with the summary counts after the plan finishes executing
    pub execution_stats_callback: Option<ExecutionStatsCallback>,
    /// Suppresses plan and metrics logging when true
    pub skip_logging: bool,
}
298
299impl std::fmt::Debug for LanceExecutionOptions {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        f.debug_struct("LanceExecutionOptions")
302            .field("use_spilling", &self.use_spilling)
303            .field("mem_pool_size", &self.mem_pool_size)
304            .field("max_temp_directory_size", &self.max_temp_directory_size)
305            .field("batch_size", &self.batch_size)
306            .field("target_partition", &self.target_partition)
307            .field("skip_logging", &self.skip_logging)
308            .field(
309                "execution_stats_callback",
310                &self.execution_stats_callback.is_some(),
311            )
312            .finish()
313    }
314}
315
/// Default memory pool size (100MB) when neither the option nor the env var is set
const DEFAULT_LANCE_MEM_POOL_SIZE: u64 = 100 * 1024 * 1024;
/// Default max temp directory size when neither the option nor the env var is set
const DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; // 100GB
318
319impl LanceExecutionOptions {
320    pub fn mem_pool_size(&self) -> u64 {
321        self.mem_pool_size.unwrap_or_else(|| {
322            std::env::var("LANCE_MEM_POOL_SIZE")
323                .map(|s| match s.parse::<u64>() {
324                    Ok(v) => v,
325                    Err(e) => {
326                        warn!("Failed to parse LANCE_MEM_POOL_SIZE: {}, using default", e);
327                        DEFAULT_LANCE_MEM_POOL_SIZE
328                    }
329                })
330                .unwrap_or(DEFAULT_LANCE_MEM_POOL_SIZE)
331        })
332    }
333
334    pub fn max_temp_directory_size(&self) -> u64 {
335        self.max_temp_directory_size.unwrap_or_else(|| {
336            std::env::var("LANCE_MAX_TEMP_DIRECTORY_SIZE")
337                .map(|s| match s.parse::<u64>() {
338                    Ok(v) => v,
339                    Err(e) => {
340                        warn!(
341                            "Failed to parse LANCE_MAX_TEMP_DIRECTORY_SIZE: {}, using default",
342                            e
343                        );
344                        DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE
345                    }
346                })
347                .unwrap_or(DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE)
348        })
349    }
350
351    pub fn use_spilling(&self) -> bool {
352        if !self.use_spilling {
353            return false;
354        }
355        std::env::var("LANCE_BYPASS_SPILLING")
356            .map(|_| {
357                info!("Bypassing spilling because LANCE_BYPASS_SPILLING is set");
358                false
359            })
360            .unwrap_or(true)
361    }
362}
363
364pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext {
365    let mut session_config = SessionConfig::new();
366    let mut runtime_env_builder = RuntimeEnvBuilder::new();
367    if let Some(target_partition) = options.target_partition {
368        session_config = session_config.with_target_partitions(target_partition);
369    }
370    if options.use_spilling() {
371        let disk_manager_builder = DiskManagerBuilder::default()
372            .with_max_temp_directory_size(options.max_temp_directory_size());
373        runtime_env_builder = runtime_env_builder
374            .with_disk_manager_builder(disk_manager_builder)
375            .with_memory_pool(Arc::new(FairSpillPool::new(
376                options.mem_pool_size() as usize
377            )));
378    }
379    let runtime_env = runtime_env_builder.build_arc().unwrap();
380
381    let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);
382    register_functions(&ctx);
383
384    ctx
385}
386
/// Cache key for session contexts based on resolved configuration values.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SessionContextCacheKey {
    /// Resolved memory pool size (option/env/default already applied)
    mem_pool_size: u64,
    /// Resolved max temp directory size
    max_temp_directory_size: u64,
    /// Requested target partition count, if any
    target_partition: Option<usize>,
    /// Resolved spilling flag (env bypass already applied)
    use_spilling: bool,
}
395
396impl SessionContextCacheKey {
397    fn from_options(options: &LanceExecutionOptions) -> Self {
398        Self {
399            mem_pool_size: options.mem_pool_size(),
400            max_temp_directory_size: options.max_temp_directory_size(),
401            target_partition: options.target_partition,
402            use_spilling: options.use_spilling(),
403        }
404    }
405}
406
/// A cached session context together with its last access time (used for LRU eviction)
struct CachedSessionContext {
    /// The cached context
    context: SessionContext,
    /// Refreshed on every cache hit; the entry with the oldest value is evicted first
    last_access: std::time::Instant,
}
411
/// Returns the process-wide session context cache, initializing it on first use.
fn get_session_cache() -> &'static Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>> {
    static SESSION_CACHE: OnceLock<Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>>> =
        OnceLock::new();
    SESSION_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}
417
/// Maximum number of cached session contexts, configurable via the
/// LANCE_SESSION_CACHE_SIZE env var (read once per process, default 4).
fn get_max_cache_size() -> usize {
    const DEFAULT_CACHE_SIZE: usize = 4;
    static MAX_CACHE_SIZE: OnceLock<usize> = OnceLock::new();
    *MAX_CACHE_SIZE.get_or_init(|| match std::env::var("LANCE_SESSION_CACHE_SIZE") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_CACHE_SIZE),
        Err(_) => DEFAULT_CACHE_SIZE,
    })
}
428
429pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext {
430    let key = SessionContextCacheKey::from_options(options);
431    let mut cache = get_session_cache()
432        .lock()
433        .unwrap_or_else(|e| e.into_inner());
434
435    // If key exists, update access time and return
436    if let Some(entry) = cache.get_mut(&key) {
437        entry.last_access = std::time::Instant::now();
438        return entry.context.clone();
439    }
440
441    // Evict least recently used entry if cache is full
442    if cache.len() >= get_max_cache_size()
443        && let Some(lru_key) = cache
444            .iter()
445            .min_by_key(|(_, v)| v.last_access)
446            .map(|(k, _)| k.clone())
447    {
448        cache.remove(&lru_key);
449    }
450
451    let context = new_session_context(options);
452    cache.insert(
453        key,
454        CachedSessionContext {
455            context: context.clone(),
456            last_access: std::time::Instant::now(),
457        },
458    );
459    context
460}
461
462fn get_task_context(
463    session_ctx: &SessionContext,
464    options: &LanceExecutionOptions,
465) -> Arc<TaskContext> {
466    let mut state = session_ctx.state();
467    if let Some(batch_size) = options.batch_size.as_ref() {
468        state.config_mut().options_mut().execution.batch_size = *batch_size;
469    }
470
471    state.task_ctx()
472}
473
/// Summary counters accumulated from an executed plan's metrics.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ExecutionSummaryCounts {
    /// The number of I/O operations performed
    pub iops: usize,
    /// The number of requests made to the storage layer (may be larger or smaller than iops
    /// depending on coalescing configuration)
    pub requests: usize,
    /// The number of bytes read during the execution of the plan
    pub bytes_read: usize,
    /// The number of top-level indices loaded
    pub indices_loaded: usize,
    /// The number of index partitions loaded
    pub parts_loaded: usize,
    /// The number of index comparisons performed (the exact meaning depends on the index type)
    pub index_comparisons: usize,
    /// Additional metrics for more detailed statistics.  These are subject to change in the future
    /// and should only be used for debugging purposes.
    pub all_counts: HashMap<String, usize>,
}
493
494pub fn collect_execution_metrics(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) {
495    if let Some(metrics) = node.metrics() {
496        for (metric_name, count) in metrics.iter_counts() {
497            match metric_name.as_ref() {
498                IOPS_METRIC => counts.iops += count.value(),
499                REQUESTS_METRIC => counts.requests += count.value(),
500                BYTES_READ_METRIC => counts.bytes_read += count.value(),
501                INDICES_LOADED_METRIC => counts.indices_loaded += count.value(),
502                PARTS_LOADED_METRIC => counts.parts_loaded += count.value(),
503                INDEX_COMPARISONS_METRIC => counts.index_comparisons += count.value(),
504                _ => {
505                    let existing = counts
506                        .all_counts
507                        .entry(metric_name.as_ref().to_string())
508                        .or_insert(0);
509                    *existing += count.value();
510                }
511            }
512        }
513        // Include gauge-based I/O metrics (some nodes record I/O as gauges)
514        for (metric_name, gauge) in metrics.iter_gauges() {
515            match metric_name.as_ref() {
516                IOPS_METRIC => counts.iops += gauge.value(),
517                REQUESTS_METRIC => counts.requests += gauge.value(),
518                BYTES_READ_METRIC => counts.bytes_read += gauge.value(),
519                _ => {}
520            }
521        }
522    }
523    for child in node.children() {
524        collect_execution_metrics(child.as_ref(), counts);
525    }
526}
527
/// Logs a one-line summary of the executed plan (row count plus I/O and index
/// counters) and forwards the counts to the configured callback, if any.
fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) {
    // Missing metrics are reported as zero rows rather than omitted
    let output_rows = plan
        .metrics()
        .map(|m| m.output_rows().unwrap_or(0))
        .unwrap_or(0);
    let mut counts = ExecutionSummaryCounts::default();
    collect_execution_metrics(plan, &mut counts);
    tracing::info!(
        target: TRACE_EXECUTION,
        r#type = EXECUTION_PLAN_RUN,
        plan_summary = display_plan_one_liner(plan),
        output_rows,
        iops = counts.iops,
        requests = counts.requests,
        bytes_read = counts.bytes_read,
        indices_loaded = counts.indices_loaded,
        parts_loaded = counts.parts_loaded,
        index_comparisons = counts.index_comparisons,
    );
    if let Some(callback) = options.execution_stats_callback.as_ref() {
        callback(&counts);
    }
}
551
552/// Create a one-line rough summary of the given execution plan.
553///
554/// The summary just shows the name of the operators in the plan. It omits any
555/// details such as parameters or schema information.
556///
557/// Example: `Projection(Take(CoalesceBatches(Filter(LanceScan))))`
558fn display_plan_one_liner(plan: &dyn ExecutionPlan) -> String {
559    let mut output = String::new();
560
561    display_plan_one_liner_impl(plan, &mut output);
562
563    output
564}
565
566fn display_plan_one_liner_impl(plan: &dyn ExecutionPlan, output: &mut String) {
567    // Remove the "Exec" suffix from the plan name if present for brevity
568    let name = plan.name().trim_end_matches("Exec");
569    output.push_str(name);
570
571    let children = plan.children();
572    if !children.is_empty() {
573        output.push('(');
574        for (i, child) in children.iter().enumerate() {
575            if i > 0 {
576                output.push(',');
577            }
578            display_plan_one_liner_impl(child.as_ref(), output);
579        }
580        output.push(')');
581    }
582}
583
/// Executes a plan using default session & runtime configuration
///
/// Only executes a single partition.  Panics if the plan has more than one partition.
pub fn execute_plan(
    plan: Arc<dyn ExecutionPlan>,
    options: LanceExecutionOptions,
) -> Result<SendableRecordBatchStream> {
    if !options.skip_logging {
        debug!(
            "Executing plan:\n{}",
            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
        );
    }

    let session_ctx = get_session_context(&options);

    // NOTE: we are only executing the first partition here. Therefore, if
    // the plan has more than one partition, we will be missing data.
    assert_eq!(plan.properties().partitioning.partition_count(), 1);
    let stream = plan.execute(0, get_task_context(&session_ctx, &options))?;

    let schema = stream.schema();
    // Defer summary-metrics reporting until the stream finishes, since metrics
    // are only populated as the plan runs
    let stream = stream.finally(move || {
        if !options.skip_logging {
            report_plan_summary_metrics(plan.as_ref(), &options);
        }
    });
    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
613
/// Runs `plan` to completion under an AnalyzeExec wrapper and returns the
/// rendered plan text annotated with collected metrics.
pub async fn analyze_plan(
    plan: Arc<dyn ExecutionPlan>,
    options: LanceExecutionOptions,
) -> Result<String> {
    // This is needed as AnalyzeExec launches a thread task per
    // partition, and we want these to be connected to the parent span
    let plan = Arc::new(TracedExec::new(plan, Span::current()));

    let schema = plan.schema();
    // TODO(tsaucer) I chose SUMMARY here but do we also want DEV?
    let analyze = Arc::new(AnalyzeExec::new(
        true,
        true,
        vec![MetricType::SUMMARY],
        plan,
        schema,
    ));

    let session_ctx = get_session_context(&options);
    // We only execute partition 0; the assert guards against silently missing data
    assert_eq!(analyze.properties().partitioning.partition_count(), 1);
    let mut stream = analyze
        .execute(0, get_task_context(&session_ctx, &options))
        .map_err(|err| Error::io(format!("Failed to execute analyze plan: {}", err)))?;

    // fully execute the plan
    while (stream.next().await).is_some() {}

    // Metrics are now populated; render the annotated plan
    let result = format_plan(analyze);
    Ok(result)
}
644
/// Renders an executed plan as an indented tree, annotating each node with the
/// wall-clock elapsed time of its subtree (when timestamp metrics are present)
/// and its aggregated metrics.
pub fn format_plan(plan: Arc<dyn ExecutionPlan>) -> String {
    /// A visitor which calculates additional metrics for all the plans.
    struct CalculateVisitor {
        // Pre-order index of the most recently visited node; PrintVisitor walks
        // the tree in the same order, so the indices line up between passes
        highest_index: usize,
        // Wall-clock elapsed time per subtree, keyed by the node's pre-order index
        index_to_elapsed: HashMap<usize, Duration>,
    }

    /// Result of calculating metrics for a subtree
    struct SubtreeMetrics {
        min_start: Option<DateTime<Utc>>,
        max_end: Option<DateTime<Utc>>,
    }

    impl CalculateVisitor {
        // Computes the [earliest start, latest end] timestamp range across this
        // node and its descendants, recording the elapsed duration by index
        fn calculate_metrics(&mut self, plan: &Arc<dyn ExecutionPlan>) -> SubtreeMetrics {
            self.highest_index += 1;
            let plan_index = self.highest_index;

            // Get timestamps for this node
            let (mut min_start, mut max_end) = Self::node_timerange(plan);

            // Accumulate from children
            for child in plan.children() {
                let child_metrics = self.calculate_metrics(child);
                min_start = Self::min_option(min_start, child_metrics.min_start);
                max_end = Self::max_option(max_end, child_metrics.max_end);
            }

            // Calculate wall clock duration for this subtree (only if we have timestamps)
            let elapsed = match (min_start, max_end) {
                (Some(start), Some(end)) => Some((end - start).to_std().unwrap_or_default()),
                _ => None,
            };

            if let Some(e) = elapsed {
                self.index_to_elapsed.insert(plan_index, e);
            }

            SubtreeMetrics { min_start, max_end }
        }

        // Earliest start / latest end timestamp recorded in this node's own metrics
        fn node_timerange(
            plan: &Arc<dyn ExecutionPlan>,
        ) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>) {
            let Some(metrics) = plan.metrics() else {
                return (None, None);
            };
            let min_start = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::StartTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .min();
            let max_end = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::EndTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .max();
            (min_start, max_end)
        }

        // Minimum of two optional timestamps, ignoring `None`s
        fn min_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().min()
        }

        // Maximum of two optional timestamps, ignoring `None`s
        fn max_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().max()
        }
    }

    /// A visitor which prints out all the plans.
    struct PrintVisitor {
        // Pre-order index, kept in lockstep with CalculateVisitor's numbering
        highest_index: usize,
        // Current tree depth; each level adds two spaces of indentation
        indent: usize,
    }
    impl PrintVisitor {
        fn write_output(
            &mut self,
            plan: &Arc<dyn ExecutionPlan>,
            f: &mut Formatter,
            calcs: &CalculateVisitor,
        ) -> std::fmt::Result {
            self.highest_index += 1;
            write!(f, "{:indent$}", "", indent = self.indent * 2)?;

            // Format the plan description
            let displayable =
                datafusion::physical_plan::display::DisplayableExecutionPlan::new(plan.as_ref());
            let plan_str = displayable.one_line().to_string();
            let plan_str = plan_str.trim();

            // Write operator with elapsed time inserted after the name
            match calcs.index_to_elapsed.get(&self.highest_index) {
                Some(elapsed) => match plan_str.find(": ") {
                    // "Name: rest" becomes "Name: elapsed=..., rest"
                    Some(i) => write!(
                        f,
                        "{}: elapsed={elapsed:?}, {}",
                        &plan_str[..i],
                        &plan_str[i + 2..]
                    )?,
                    // No parameter list; append the elapsed time at the end
                    None => write!(f, "{plan_str}, elapsed={elapsed:?}")?,
                },
                None => write!(f, "{plan_str}")?,
            }

            if let Some(metrics) = plan.metrics() {
                // Timestamps were already folded into elapsed above, so drop them here
                let metrics = metrics
                    .aggregate_by_name()
                    .sorted_for_display()
                    .timestamps_removed();

                write!(f, ", metrics=[{metrics}]")?;
            } else {
                write!(f, ", metrics=[]")?;
            }
            writeln!(f)?;
            self.indent += 1;
            for child in plan.children() {
                self.write_output(child, f, calcs)?;
            }
            self.indent -= 1;
            std::fmt::Result::Ok(())
        }
    }
    // A wrapper which prints out a plan.
    struct PrintWrapper {
        plan: Arc<dyn ExecutionPlan>,
    }
    impl fmt::Display for PrintWrapper {
        fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
            // First pass computes elapsed times; second pass prints with them
            let mut calcs = CalculateVisitor {
                highest_index: 0,
                index_to_elapsed: HashMap::new(),
            };
            calcs.calculate_metrics(&self.plan);
            let mut prints = PrintVisitor {
                highest_index: 0,
                indent: 0,
            };
            prints.write_output(&self.plan, f, &calcs)
        }
    }
    let wrapper = PrintWrapper { plan };
    format!("{}", wrapper)
}
793
/// Lance extension methods for `SessionContext`
pub trait SessionContextExt {
    /// Creates a DataFrame for reading a stream of data
    ///
    /// This dataframe may only be queried once, future queries will fail
    fn read_one_shot(
        &self,
        data: SendableRecordBatchStream,
    ) -> datafusion::common::Result<DataFrame>;
}
803
/// A `PartitionStream` wrapping a stream that can be consumed only once
pub struct OneShotPartitionStream {
    /// The wrapped stream; `None` after it has been handed out
    data: Arc<Mutex<Option<SendableRecordBatchStream>>>,
    /// Saved schema, still available after the stream has been taken
    schema: Arc<ArrowSchema>,
}
808
809impl std::fmt::Debug for OneShotPartitionStream {
810    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
811        let data = self.data.lock().unwrap();
812        f.debug_struct("OneShotPartitionStream")
813            .field("exhausted", &data.is_none())
814            .field("schema", self.schema.as_ref())
815            .finish()
816    }
817}
818
819impl OneShotPartitionStream {
820    pub fn new(data: SendableRecordBatchStream) -> Self {
821        let schema = data.schema();
822        Self {
823            data: Arc::new(Mutex::new(Some(data))),
824            schema,
825        }
826    }
827}
828
829impl PartitionStream for OneShotPartitionStream {
830    fn schema(&self) -> &arrow_schema::SchemaRef {
831        &self.schema
832    }
833
834    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
835        let mut stream = self.data.lock().unwrap();
836        stream
837            .take()
838            .expect("Attempt to consume a one shot dataframe multiple times")
839    }
840}
841
842impl SessionContextExt for SessionContext {
843    fn read_one_shot(
844        &self,
845        data: SendableRecordBatchStream,
846    ) -> datafusion::common::Result<DataFrame> {
847        let schema = data.schema();
848        let part_stream = Arc::new(OneShotPartitionStream::new(data));
849        let provider = StreamingTable::try_new(schema, vec![part_stream])?;
850        self.read_table(Arc::new(provider))
851    }
852}
853
/// Wraps a plan so its output is re-chunked to a strict batch size
/// (via `StrictBatchSizeStream`)
#[derive(Clone, Debug)]
pub struct StrictBatchSizeExec {
    /// The wrapped plan
    input: Arc<dyn ExecutionPlan>,
    /// Target number of rows per emitted batch
    batch_size: usize,
}
859
impl StrictBatchSizeExec {
    /// Wraps `input` so its output is re-chunked to `batch_size` rows per batch
    pub fn new(input: Arc<dyn ExecutionPlan>, batch_size: usize) -> Self {
        Self { input, batch_size }
    }
}
865
866impl DisplayAs for StrictBatchSizeExec {
867    fn fmt_as(
868        &self,
869        _t: datafusion::physical_plan::DisplayFormatType,
870        f: &mut std::fmt::Formatter,
871    ) -> std::fmt::Result {
872        write!(f, "StrictBatchSizeExec")
873    }
874}
875
876impl ExecutionPlan for StrictBatchSizeExec {
877    fn name(&self) -> &str {
878        "StrictBatchSizeExec"
879    }
880
881    fn as_any(&self) -> &dyn std::any::Any {
882        self
883    }
884
885    fn properties(&self) -> &PlanProperties {
886        self.input.properties()
887    }
888
889    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
890        vec![&self.input]
891    }
892
893    fn with_new_children(
894        self: Arc<Self>,
895        children: Vec<Arc<dyn ExecutionPlan>>,
896    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
897        Ok(Arc::new(Self {
898            input: children[0].clone(),
899            batch_size: self.batch_size,
900        }))
901    }
902
903    fn execute(
904        &self,
905        partition: usize,
906        context: Arc<TaskContext>,
907    ) -> datafusion_common::Result<SendableRecordBatchStream> {
908        let stream = self.input.execute(partition, context)?;
909        let schema = stream.schema();
910        let stream = StrictBatchSizeStream::new(stream, self.batch_size);
911        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
912    }
913
914    fn maintains_input_order(&self) -> Vec<bool> {
915        vec![true]
916    }
917
918    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
919        vec![false]
920    }
921
922    fn partition_statistics(
923        &self,
924        partition: Option<usize>,
925    ) -> datafusion_common::Result<Statistics> {
926        self.input.partition_statistics(partition)
927    }
928
929    fn cardinality_effect(&self) -> CardinalityEffect {
930        CardinalityEffect::Equal
931    }
932
933    fn supports_limit_pushdown(&self) -> bool {
934        true
935    }
936}
937
#[cfg(test)]
mod tests {
    use super::*;

    // The session-context cache is process-global, so these tests must be
    // serialized to avoid interfering with each other.
    static CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    #[test]
    fn test_session_context_cache() {
        let _guard = CACHE_TEST_LOCK.lock().unwrap();
        let cache = get_session_cache();

        // Start from a clean slate; other tests share this global cache.
        cache.lock().unwrap().clear();

        // First lookup with default options populates a single entry.
        let default_opts = LanceExecutionOptions::default();
        let _first = get_session_context(&default_opts);
        assert_eq!(cache.lock().unwrap().len(), 1);

        // A second lookup with identical options reuses the cached session,
        // so no new entry appears.
        let _second = get_session_context(&default_opts);
        assert_eq!(cache.lock().unwrap().len(), 1);

        // Distinct options miss the cache and add a second entry.
        let spilling_opts = LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        };
        let _third = get_session_context(&spilling_opts);
        assert_eq!(cache.lock().unwrap().len(), 2);
    }

    #[test]
    fn test_session_context_cache_lru_eviction() {
        let _guard = CACHE_TEST_LOCK.lock().unwrap();
        let cache = get_session_cache();

        // Start from a clean slate; other tests share this global cache.
        cache.lock().unwrap().clear();

        // Fill the cache with four distinct configurations.
        let configs: Vec<LanceExecutionOptions> = (0..4)
            .map(|i| LanceExecutionOptions {
                mem_pool_size: Some((i + 1) as u64 * 1024 * 1024),
                ..Default::default()
            })
            .collect();
        for cfg in &configs {
            let _ctx = get_session_context(cfg);
        }
        assert_eq!(cache.lock().unwrap().len(), 4);

        // Touch configs[0] so that configs[1] becomes the least recently
        // used entry (configs[0] was inserted first and would otherwise be
        // evicted). The sleep ensures the access timestamp advances.
        std::thread::sleep(std::time::Duration::from_millis(1));
        let _refreshed = get_session_context(&configs[0]);

        // A fifth configuration overflows the cache and evicts the LRU entry.
        let extra = LanceExecutionOptions {
            mem_pool_size: Some(5 * 1024 * 1024),
            ..Default::default()
        };
        let _ctx5 = get_session_context(&extra);

        let guard = cache.lock().unwrap();
        assert_eq!(guard.len(), 4);
        assert!(
            guard.contains_key(&SessionContextCacheKey::from_options(&configs[0])),
            "config[0] should still be cached after recent access"
        );
        assert!(
            !guard.contains_key(&SessionContextCacheKey::from_options(&configs[1])),
            "config[1] should have been evicted"
        );
        assert!(
            guard.contains_key(&SessionContextCacheKey::from_options(&extra)),
            "new config should be cached"
        );
    }
}