lance_datafusion/exec.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Utilities for working with datafusion execution plans

use std::{
    collections::HashMap,
    fmt::{self, Formatter},
    sync::{Arc, Mutex, OnceLock},
    time::Duration,
};

use chrono::{DateTime, Utc};

use arrow_array::RecordBatch;
use arrow_schema::Schema as ArrowSchema;
use datafusion::physical_plan::metrics::MetricType;
use datafusion::{
    catalog::streaming::StreamingTable,
    dataframe::DataFrame,
    execution::{
        context::{SessionConfig, SessionContext},
        disk_manager::DiskManagerBuilder,
        memory_pool::FairSpillPool,
        runtime_env::RuntimeEnvBuilder,
        TaskContext,
    },
    physical_plan::{
        analyze::AnalyzeExec,
        display::DisplayableExecutionPlan,
        execution_plan::{Boundedness, CardinalityEffect, EmissionType},
        metrics::MetricValue,
        stream::RecordBatchStreamAdapter,
        streaming::PartitionStream,
        DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
    },
};
use datafusion_common::{DataFusionError, Statistics};
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};

use futures::{stream, StreamExt};
use lance_arrow::SchemaExt;
use lance_core::{
    utils::{
        futures::FinallyStreamExt,
        tracing::{StreamTracingExt, EXECUTION_PLAN_RUN, TRACE_EXECUTION},
    },
    Error, Result,
};
use log::{debug, info, warn};
use snafu::location;
use tracing::Span;

use crate::udf::register_functions;
use crate::{
    chunker::StrictBatchSizeStream,
    utils::{
        MetricsExt, BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC,
        IOPS_METRIC, PARTS_LOADED_METRIC, REQUESTS_METRIC,
    },
};

/// A source execution node created from an existing stream
///
/// It can only be used once: the first call to `execute` returns the stream, and
/// after that the node is exhausted.
///
/// Note: the stream should be finite, otherwise we will report DataFusion properties
/// incorrectly.
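///
/// # Example
///
/// A minimal sketch (module path assumed; `stream` is any finite
/// `SendableRecordBatchStream` obtained elsewhere):
///
/// ```ignore
/// use std::sync::Arc;
/// use lance_datafusion::exec::{execute_plan, LanceExecutionOptions, OneShotExec};
///
/// let plan = Arc::new(OneShotExec::new(stream));
/// // The node can be executed exactly once; a second execute() returns an error.
/// let batches = execute_plan(plan, LanceExecutionOptions::default())?;
/// ```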
pub struct OneShotExec {
    stream: Mutex<Option<SendableRecordBatchStream>>,
    // We save off a copy of the schema to speed up formatting and so that ExecutionPlan::schema
    // and display_as can still function after the stream is exhausted
    schema: Arc<ArrowSchema>,
    properties: PlanProperties,
}

impl OneShotExec {
    /// Create a new instance from a given stream
    pub fn new(stream: SendableRecordBatchStream) -> Self {
        let schema = stream.schema();
        Self {
            stream: Mutex::new(Some(stream)),
            schema: schema.clone(),
            properties: PlanProperties::new(
                EquivalenceProperties::new(schema),
                Partitioning::RoundRobinBatch(1),
                EmissionType::Incremental,
                Boundedness::Bounded,
            ),
        }
    }

    pub fn from_batch(batch: RecordBatch) -> Self {
        let schema = batch.schema();
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            schema,
            stream::iter(vec![Ok(batch)]),
        ));
        Self::new(stream)
    }
}

impl std::fmt::Debug for OneShotExec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let stream = self.stream.lock().unwrap();
        f.debug_struct("OneShotExec")
            .field("exhausted", &stream.is_none())
            .field("schema", self.schema.as_ref())
            .finish()
    }
}

impl DisplayAs for OneShotExec {
    fn fmt_as(
        &self,
        t: datafusion::physical_plan::DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        let stream = self.stream.lock().unwrap();
        let exhausted = if stream.is_some() { "" } else { "EXHAUSTED " };
        let columns = self
            .schema
            .field_names()
            .iter()
            .cloned()
            .cloned()
            .collect::<Vec<_>>();
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "OneShotStream: {}columns=[{}]",
                    exhausted,
                    columns.join(",")
                )
            }
            DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "OneShotStream\nexhausted={}\ncolumns=[{}]",
                    exhausted,
                    columns.join(",")
                )
            }
        }
    }
}

impl ExecutionPlan for OneShotExec {
    fn name(&self) -> &str {
        "OneShotExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> arrow_schema::SchemaRef {
        self.schema.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        // OneShotExec has no children, so this should only be called with an empty vector
        if !children.is_empty() {
            return Err(datafusion_common::DataFusionError::Internal(
                "OneShotExec does not support children".to_string(),
            ));
        }
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<datafusion::execution::TaskContext>,
    ) -> datafusion_common::Result<SendableRecordBatchStream> {
        let stream = self
            .stream
            .lock()
            .map_err(|err| DataFusionError::Execution(err.to_string()))?
            .take();
        if let Some(stream) = stream {
            Ok(stream)
        } else {
            Err(DataFusionError::Execution(
                "OneShotExec has already been executed".to_string(),
            ))
        }
    }

    fn statistics(&self) -> datafusion_common::Result<datafusion_common::Statistics> {
        Ok(Statistics::new_unknown(&self.schema))
    }

    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
        &self.properties
    }
}

struct TracedExec {
    input: Arc<dyn ExecutionPlan>,
    properties: PlanProperties,
    span: Span,
}

impl TracedExec {
    pub fn new(input: Arc<dyn ExecutionPlan>, span: Span) -> Self {
        Self {
            properties: input.properties().clone(),
            input,
            span,
        }
    }
}

impl DisplayAs for TracedExec {
    fn fmt_as(
        &self,
        t: datafusion::physical_plan::DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default
            | DisplayFormatType::Verbose
            | DisplayFormatType::TreeRender => {
                write!(f, "TracedExec")
            }
        }
    }
}

impl std::fmt::Debug for TracedExec {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "TracedExec")
    }
}

impl ExecutionPlan for TracedExec {
    fn name(&self) -> &str {
        "TracedExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(Self {
            input: children[0].clone(),
            properties: self.properties.clone(),
            span: self.span.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<SendableRecordBatchStream> {
        let _guard = self.span.enter();
        let stream = self.input.execute(partition, context)?;
        let schema = stream.schema();
        let stream = stream.stream_in_span(self.span.clone());
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }
}

/// Callback for reporting statistics after a scan
pub type ExecutionStatsCallback = Arc<dyn Fn(&ExecutionSummaryCounts) + Send + Sync>;

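/// Options controlling how Lance executes DataFusion plans (spilling, memory pool size,
/// temp directory limit, batch size, target partitions, and stats reporting).
///
/// # Example
///
/// A minimal sketch of overriding a few fields and registering an
/// `execution_stats_callback` (module path and callback body are illustrative only):
///
/// ```ignore
/// use std::sync::Arc;
/// use lance_datafusion::exec::{ExecutionSummaryCounts, LanceExecutionOptions};
///
/// let options = LanceExecutionOptions {
///     use_spilling: true,
///     mem_pool_size: Some(256 * 1024 * 1024), // 256 MiB
///     execution_stats_callback: Some(Arc::new(|counts: &ExecutionSummaryCounts| {
///         // Invoked when the stream returned by execute_plan finishes.
///         println!("iops={} bytes_read={}", counts.iops, counts.bytes_read);
///     })),
///     ..Default::default()
/// };
/// ```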
#[derive(Default, Clone)]
pub struct LanceExecutionOptions {
    pub use_spilling: bool,
    pub mem_pool_size: Option<u64>,
    pub max_temp_directory_size: Option<u64>,
    pub batch_size: Option<usize>,
    pub target_partition: Option<usize>,
    pub execution_stats_callback: Option<ExecutionStatsCallback>,
    pub skip_logging: bool,
}

impl std::fmt::Debug for LanceExecutionOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanceExecutionOptions")
            .field("use_spilling", &self.use_spilling)
            .field("mem_pool_size", &self.mem_pool_size)
            .field("max_temp_directory_size", &self.max_temp_directory_size)
            .field("batch_size", &self.batch_size)
            .field("target_partition", &self.target_partition)
            .field("skip_logging", &self.skip_logging)
            .field(
                "execution_stats_callback",
                &self.execution_stats_callback.is_some(),
            )
            .finish()
    }
}

const DEFAULT_LANCE_MEM_POOL_SIZE: u64 = 100 * 1024 * 1024;
const DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; // 100GB

impl LanceExecutionOptions {
    pub fn mem_pool_size(&self) -> u64 {
        self.mem_pool_size.unwrap_or_else(|| {
            std::env::var("LANCE_MEM_POOL_SIZE")
                .map(|s| match s.parse::<u64>() {
                    Ok(v) => v,
                    Err(e) => {
                        warn!("Failed to parse LANCE_MEM_POOL_SIZE: {}, using default", e);
                        DEFAULT_LANCE_MEM_POOL_SIZE
                    }
                })
                .unwrap_or(DEFAULT_LANCE_MEM_POOL_SIZE)
        })
    }

    pub fn max_temp_directory_size(&self) -> u64 {
        self.max_temp_directory_size.unwrap_or_else(|| {
            std::env::var("LANCE_MAX_TEMP_DIRECTORY_SIZE")
                .map(|s| match s.parse::<u64>() {
                    Ok(v) => v,
                    Err(e) => {
                        warn!(
                            "Failed to parse LANCE_MAX_TEMP_DIRECTORY_SIZE: {}, using default",
                            e
                        );
                        DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE
                    }
                })
                .unwrap_or(DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE)
        })
    }

    pub fn use_spilling(&self) -> bool {
        if !self.use_spilling {
            return false;
        }
        std::env::var("LANCE_BYPASS_SPILLING")
            .map(|_| {
                info!("Bypassing spilling because LANCE_BYPASS_SPILLING is set");
                false
            })
            .unwrap_or(true)
    }
}

pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext {
    let mut session_config = SessionConfig::new();
    let mut runtime_env_builder = RuntimeEnvBuilder::new();
    if let Some(target_partition) = options.target_partition {
        session_config = session_config.with_target_partitions(target_partition);
    }
    if options.use_spilling() {
        let disk_manager_builder = DiskManagerBuilder::default()
            .with_max_temp_directory_size(options.max_temp_directory_size());
        runtime_env_builder = runtime_env_builder
            .with_disk_manager_builder(disk_manager_builder)
            .with_memory_pool(Arc::new(FairSpillPool::new(
                options.mem_pool_size() as usize
            )));
    }
    let runtime_env = runtime_env_builder.build_arc().unwrap();

    let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);
    register_functions(&ctx);

    ctx
}

/// Cache key for session contexts based on resolved configuration values.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SessionContextCacheKey {
    mem_pool_size: u64,
    max_temp_directory_size: u64,
    target_partition: Option<usize>,
    use_spilling: bool,
}

impl SessionContextCacheKey {
    fn from_options(options: &LanceExecutionOptions) -> Self {
        Self {
            mem_pool_size: options.mem_pool_size(),
            max_temp_directory_size: options.max_temp_directory_size(),
            target_partition: options.target_partition,
            use_spilling: options.use_spilling(),
        }
    }
}

struct CachedSessionContext {
    context: SessionContext,
    last_access: std::time::Instant,
}

fn get_session_cache() -> &'static Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>> {
    static SESSION_CACHE: OnceLock<Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>>> =
        OnceLock::new();
    SESSION_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

fn get_max_cache_size() -> usize {
    const DEFAULT_CACHE_SIZE: usize = 4;
    static MAX_CACHE_SIZE: OnceLock<usize> = OnceLock::new();
    *MAX_CACHE_SIZE.get_or_init(|| {
        std::env::var("LANCE_SESSION_CACHE_SIZE")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(DEFAULT_CACHE_SIZE)
    })
}

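/// Get a [`SessionContext`] for the given options, reusing a cached context when the
/// resolved configuration (memory pool size, temp directory limit, target partitions,
/// spilling) matches a previous call. The cache holds up to `LANCE_SESSION_CACHE_SIZE`
/// entries (default 4) and evicts the least recently used one.
///
/// A minimal sketch of the reuse behavior (module path assumed):
///
/// ```ignore
/// use lance_datafusion::exec::{get_session_context, LanceExecutionOptions};
///
/// let opts = LanceExecutionOptions::default();
/// let ctx_a = get_session_context(&opts);
/// // Same resolved configuration -> the cached SessionContext is returned again.
/// let ctx_b = get_session_context(&opts);
/// ```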
pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext {
    let key = SessionContextCacheKey::from_options(options);
    let mut cache = get_session_cache()
        .lock()
        .unwrap_or_else(|e| e.into_inner());

    // If key exists, update access time and return
    if let Some(entry) = cache.get_mut(&key) {
        entry.last_access = std::time::Instant::now();
        return entry.context.clone();
    }

    // Evict least recently used entry if cache is full
    if cache.len() >= get_max_cache_size() {
        if let Some(lru_key) = cache
            .iter()
            .min_by_key(|(_, v)| v.last_access)
            .map(|(k, _)| k.clone())
        {
            cache.remove(&lru_key);
        }
    }

    let context = new_session_context(options);
    cache.insert(
        key,
        CachedSessionContext {
            context: context.clone(),
            last_access: std::time::Instant::now(),
        },
    );
    context
}

fn get_task_context(
    session_ctx: &SessionContext,
    options: &LanceExecutionOptions,
) -> Arc<TaskContext> {
    let mut state = session_ctx.state();
    if let Some(batch_size) = options.batch_size.as_ref() {
        state.config_mut().options_mut().execution.batch_size = *batch_size;
    }

    state.task_ctx()
}

#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ExecutionSummaryCounts {
    /// The number of I/O operations performed
    pub iops: usize,
    /// The number of requests made to the storage layer (may be larger or smaller than iops
    /// depending on coalescing configuration)
    pub requests: usize,
    /// The number of bytes read during the execution of the plan
    pub bytes_read: usize,
    /// The number of top-level indices loaded
    pub indices_loaded: usize,
    /// The number of index partitions loaded
    pub parts_loaded: usize,
    /// The number of index comparisons performed (the exact meaning depends on the index type)
    pub index_comparisons: usize,
    /// Additional metrics for more detailed statistics.  These are subject to change in the future
    /// and should only be used for debugging purposes.
    pub all_counts: HashMap<String, usize>,
}

fn visit_node(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) {
    if let Some(metrics) = node.metrics() {
        for (metric_name, count) in metrics.iter_counts() {
            match metric_name.as_ref() {
                IOPS_METRIC => counts.iops += count.value(),
                REQUESTS_METRIC => counts.requests += count.value(),
                BYTES_READ_METRIC => counts.bytes_read += count.value(),
                INDICES_LOADED_METRIC => counts.indices_loaded += count.value(),
                PARTS_LOADED_METRIC => counts.parts_loaded += count.value(),
                INDEX_COMPARISONS_METRIC => counts.index_comparisons += count.value(),
                _ => {
                    let existing = counts
                        .all_counts
                        .entry(metric_name.as_ref().to_string())
                        .or_insert(0);
                    *existing += count.value();
                }
            }
        }
        // Include gauge-based I/O metrics (some nodes record I/O as gauges)
        for (metric_name, gauge) in metrics.iter_gauges() {
            match metric_name.as_ref() {
                IOPS_METRIC => counts.iops += gauge.value(),
                REQUESTS_METRIC => counts.requests += gauge.value(),
                BYTES_READ_METRIC => counts.bytes_read += gauge.value(),
                _ => {}
            }
        }
    }
    for child in node.children() {
        visit_node(child.as_ref(), counts);
    }
}

fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) {
    let output_rows = plan
        .metrics()
        .map(|m| m.output_rows().unwrap_or(0))
        .unwrap_or(0);
    let mut counts = ExecutionSummaryCounts::default();
    visit_node(plan, &mut counts);
    tracing::info!(
        target: TRACE_EXECUTION,
        r#type = EXECUTION_PLAN_RUN,
        plan_summary = display_plan_one_liner(plan),
        output_rows,
        iops = counts.iops,
        requests = counts.requests,
        bytes_read = counts.bytes_read,
        indices_loaded = counts.indices_loaded,
        parts_loaded = counts.parts_loaded,
        index_comparisons = counts.index_comparisons,
    );
    if let Some(callback) = options.execution_stats_callback.as_ref() {
        callback(&counts);
    }
}

/// Create a rough one-line summary of the given execution plan.
///
/// The summary only shows the names of the operators in the plan. It omits any
/// details such as parameters or schema information.
///
/// Example: `Projection(Take(CoalesceBatches(Filter(LanceScan))))`
fn display_plan_one_liner(plan: &dyn ExecutionPlan) -> String {
    let mut output = String::new();

    display_plan_one_liner_impl(plan, &mut output);

    output
}

fn display_plan_one_liner_impl(plan: &dyn ExecutionPlan, output: &mut String) {
    // Remove the "Exec" suffix from the plan name if present for brevity
    let name = plan.name().trim_end_matches("Exec");
    output.push_str(name);

    let children = plan.children();
    if !children.is_empty() {
        output.push('(');
        for (i, child) in children.iter().enumerate() {
            if i > 0 {
                output.push(',');
            }
            display_plan_one_liner_impl(child.as_ref(), output);
        }
        output.push(')');
    }
}

/// Executes a plan using default session & runtime configuration
///
/// Only executes a single partition.  Panics if the plan has more than one partition.
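///
/// # Example
///
/// A minimal sketch (module path assumed; `plan` is any single-partition
/// `ExecutionPlan`, e.g. a [`OneShotExec`]):
///
/// ```ignore
/// use futures::TryStreamExt;
/// use lance_datafusion::exec::{execute_plan, LanceExecutionOptions};
///
/// let stream = execute_plan(plan, LanceExecutionOptions::default())?;
/// // Drain the stream; summary metrics are reported when it finishes.
/// let batches: Vec<_> = stream.try_collect().await?;
/// ```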
pub fn execute_plan(
    plan: Arc<dyn ExecutionPlan>,
    options: LanceExecutionOptions,
) -> Result<SendableRecordBatchStream> {
    if !options.skip_logging {
        debug!(
            "Executing plan:\n{}",
            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
        );
    }

    let session_ctx = get_session_context(&options);

    // NOTE: we are only executing the first partition here. Therefore, if
    // the plan has more than one partition, we will be missing data.
    assert_eq!(plan.properties().partitioning.partition_count(), 1);
    let stream = plan.execute(0, get_task_context(&session_ctx, &options))?;

    let schema = stream.schema();
    let stream = stream.finally(move || {
        if !options.skip_logging {
            report_plan_summary_metrics(plan.as_ref(), &options);
        }
    });
    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}

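/// Run the plan to completion under an [`AnalyzeExec`] wrapper and return the
/// annotated plan description (operators, elapsed times, and metrics) as a string.
///
/// A minimal sketch (module path assumed; `plan` is any single-partition plan):
///
/// ```ignore
/// use lance_datafusion::exec::{analyze_plan, LanceExecutionOptions};
///
/// let report = analyze_plan(plan, LanceExecutionOptions::default()).await?;
/// println!("{report}");
/// ```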
pub async fn analyze_plan(
    plan: Arc<dyn ExecutionPlan>,
    options: LanceExecutionOptions,
) -> Result<String> {
    // This is needed because AnalyzeExec launches a task per partition, and we want
    // those tasks to be connected to the parent span
    let plan = Arc::new(TracedExec::new(plan, Span::current()));

    let schema = plan.schema();
    // TODO(tsaucer) I chose SUMMARY here but do we also want DEV?
    let analyze = Arc::new(AnalyzeExec::new(
        true,
        true,
        vec![MetricType::SUMMARY],
        plan,
        schema,
    ));

    let session_ctx = get_session_context(&options);
    assert_eq!(analyze.properties().partitioning.partition_count(), 1);
    let mut stream = analyze
        .execute(0, get_task_context(&session_ctx, &options))
        .map_err(|err| {
            Error::io(
                format!("Failed to execute analyze plan: {}", err),
                location!(),
            )
        })?;

    // fully execute the plan
    while (stream.next().await).is_some() {}

    let result = format_plan(analyze);
    Ok(result)
}

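/// Render an executed plan as an indented tree, one operator per line, with the
/// wall-clock elapsed time of each subtree and its aggregated metrics appended.
///
/// A sketch of the kind of output produced (operator names and numbers are
/// illustrative only):
///
/// ```text
/// ProjectionExec: elapsed=1.2ms, expr=[...], metrics=[output_rows=100, ...]
///   LanceScanExec: elapsed=1.1ms, ..., metrics=[iops=4, bytes_read=8192, ...]
/// ```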
pub fn format_plan(plan: Arc<dyn ExecutionPlan>) -> String {
    /// A visitor which calculates additional metrics for all the plans.
    struct CalculateVisitor {
        highest_index: usize,
        index_to_elapsed: HashMap<usize, Duration>,
    }

    /// Result of calculating metrics for a subtree
    struct SubtreeMetrics {
        min_start: Option<DateTime<Utc>>,
        max_end: Option<DateTime<Utc>>,
    }

    impl CalculateVisitor {
        fn calculate_metrics(&mut self, plan: &Arc<dyn ExecutionPlan>) -> SubtreeMetrics {
            self.highest_index += 1;
            let plan_index = self.highest_index;

            // Get timestamps for this node
            let (mut min_start, mut max_end) = Self::node_timerange(plan);

            // Accumulate from children
            for child in plan.children() {
                let child_metrics = self.calculate_metrics(child);
                min_start = Self::min_option(min_start, child_metrics.min_start);
                max_end = Self::max_option(max_end, child_metrics.max_end);
            }

            // Calculate wall clock duration for this subtree (only if we have timestamps)
            let elapsed = match (min_start, max_end) {
                (Some(start), Some(end)) => Some((end - start).to_std().unwrap_or_default()),
                _ => None,
            };

            if let Some(e) = elapsed {
                self.index_to_elapsed.insert(plan_index, e);
            }

            SubtreeMetrics { min_start, max_end }
        }

        fn node_timerange(
            plan: &Arc<dyn ExecutionPlan>,
        ) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>) {
            let Some(metrics) = plan.metrics() else {
                return (None, None);
            };
            let min_start = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::StartTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .min();
            let max_end = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::EndTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .max();
            (min_start, max_end)
        }

        fn min_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().min()
        }

        fn max_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().max()
        }
    }

    /// A visitor which prints out all the plans.
    struct PrintVisitor {
        highest_index: usize,
        indent: usize,
    }
    impl PrintVisitor {
        fn write_output(
            &mut self,
            plan: &Arc<dyn ExecutionPlan>,
            f: &mut Formatter,
            calcs: &CalculateVisitor,
        ) -> std::fmt::Result {
            self.highest_index += 1;
            write!(f, "{:indent$}", "", indent = self.indent * 2)?;

            // Format the plan description
            let displayable =
                datafusion::physical_plan::display::DisplayableExecutionPlan::new(plan.as_ref());
            let plan_str = displayable.one_line().to_string();
            let plan_str = plan_str.trim();

            // Write operator with elapsed time inserted after the name
            match calcs.index_to_elapsed.get(&self.highest_index) {
                Some(elapsed) => match plan_str.find(": ") {
                    Some(i) => write!(
                        f,
                        "{}: elapsed={elapsed:?}, {}",
                        &plan_str[..i],
                        &plan_str[i + 2..]
                    )?,
                    None => write!(f, "{plan_str}, elapsed={elapsed:?}")?,
                },
                None => write!(f, "{plan_str}")?,
            }

            if let Some(metrics) = plan.metrics() {
                let metrics = metrics
                    .aggregate_by_name()
                    .sorted_for_display()
                    .timestamps_removed();

                write!(f, ", metrics=[{metrics}]")?;
            } else {
                write!(f, ", metrics=[]")?;
            }
            writeln!(f)?;
            self.indent += 1;
            for child in plan.children() {
                self.write_output(child, f, calcs)?;
            }
            self.indent -= 1;
            std::fmt::Result::Ok(())
        }
    }
    // A wrapper which prints out a plan.
    struct PrintWrapper {
        plan: Arc<dyn ExecutionPlan>,
    }
    impl fmt::Display for PrintWrapper {
        fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
            let mut calcs = CalculateVisitor {
                highest_index: 0,
                index_to_elapsed: HashMap::new(),
            };
            calcs.calculate_metrics(&self.plan);
            let mut prints = PrintVisitor {
                highest_index: 0,
                indent: 0,
            };
            prints.write_output(&self.plan, f, &calcs)
        }
    }
    let wrapper = PrintWrapper { plan };
    format!("{}", wrapper)
}

pub trait SessionContextExt {
    /// Creates a DataFrame for reading a stream of data
    ///
    /// This dataframe may only be queried once; subsequent queries will fail
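    ///
    /// A minimal sketch (the stream construction is assumed):
    ///
    /// ```ignore
    /// use lance_datafusion::exec::SessionContextExt;
    ///
    /// let ctx = datafusion::execution::context::SessionContext::new();
    /// let df = ctx.read_one_shot(stream)?;
    /// let batches = df.collect().await?;
    /// ```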
    fn read_one_shot(
        &self,
        data: SendableRecordBatchStream,
    ) -> datafusion::common::Result<DataFrame>;
}

pub struct OneShotPartitionStream {
    data: Arc<Mutex<Option<SendableRecordBatchStream>>>,
    schema: Arc<ArrowSchema>,
}

impl std::fmt::Debug for OneShotPartitionStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data = self.data.lock().unwrap();
        f.debug_struct("OneShotPartitionStream")
            .field("exhausted", &data.is_none())
            .field("schema", self.schema.as_ref())
            .finish()
    }
}

impl OneShotPartitionStream {
    pub fn new(data: SendableRecordBatchStream) -> Self {
        let schema = data.schema();
        Self {
            data: Arc::new(Mutex::new(Some(data))),
            schema,
        }
    }
}

impl PartitionStream for OneShotPartitionStream {
    fn schema(&self) -> &arrow_schema::SchemaRef {
        &self.schema
    }

    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
        let mut stream = self.data.lock().unwrap();
        stream
            .take()
            .expect("Attempt to consume a one shot dataframe multiple times")
    }
}

impl SessionContextExt for SessionContext {
    fn read_one_shot(
        &self,
        data: SendableRecordBatchStream,
    ) -> datafusion::common::Result<DataFrame> {
        let schema = data.schema();
        let part_stream = Arc::new(OneShotPartitionStream::new(data));
        let provider = StreamingTable::try_new(schema, vec![part_stream])?;
        self.read_table(Arc::new(provider))
    }
}

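/// An execution node that re-chunks its input so that every emitted batch (except
/// possibly the last) contains exactly `batch_size` rows.
///
/// A minimal sketch of wrapping an existing plan (module path assumed):
///
/// ```ignore
/// use std::sync::Arc;
/// use lance_datafusion::exec::StrictBatchSizeExec;
///
/// // `inner` is any Arc<dyn ExecutionPlan>; its batches are re-chunked by
/// // StrictBatchSizeStream to hit the requested size.
/// let plan = Arc::new(StrictBatchSizeExec::new(inner, 1024));
/// ```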
#[derive(Clone, Debug)]
pub struct StrictBatchSizeExec {
    input: Arc<dyn ExecutionPlan>,
    batch_size: usize,
}

impl StrictBatchSizeExec {
    pub fn new(input: Arc<dyn ExecutionPlan>, batch_size: usize) -> Self {
        Self { input, batch_size }
    }
}

impl DisplayAs for StrictBatchSizeExec {
    fn fmt_as(
        &self,
        _t: datafusion::physical_plan::DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        write!(f, "StrictBatchSizeExec")
    }
}

impl ExecutionPlan for StrictBatchSizeExec {
    fn name(&self) -> &str {
        "StrictBatchSizeExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        self.input.properties()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(Self {
            input: children[0].clone(),
            batch_size: self.batch_size,
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> datafusion_common::Result<SendableRecordBatchStream> {
        let stream = self.input.execute(partition, context)?;
        let schema = stream.schema();
        let stream = StrictBatchSizeStream::new(stream, self.batch_size);
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn partition_statistics(
        &self,
        partition: Option<usize>,
    ) -> datafusion_common::Result<Statistics> {
        self.input.partition_statistics(partition)
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Serialize cache tests since they share global state
    static CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    #[test]
    fn test_session_context_cache() {
        let _lock = CACHE_TEST_LOCK.lock().unwrap();
        let cache = get_session_cache();

        // Clear any existing entries from other tests
        cache.lock().unwrap().clear();

        // Create first session with default options
        let opts1 = LanceExecutionOptions::default();
        let _ctx1 = get_session_context(&opts1);

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 1);
        }

        // Same options should reuse cached session (no new entry)
        let _ctx1_again = get_session_context(&opts1);
        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 1);
        }

        // Different options should create new entry
        let opts2 = LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        };
        let _ctx2 = get_session_context(&opts2);
        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 2);
        }
    }

    #[test]
    fn test_session_context_cache_lru_eviction() {
        let _lock = CACHE_TEST_LOCK.lock().unwrap();
        let cache = get_session_cache();

        // Clear any existing entries from other tests
        cache.lock().unwrap().clear();

        // Create 4 different configurations to fill the cache
        let configs: Vec<LanceExecutionOptions> = (0..4)
            .map(|i| LanceExecutionOptions {
                mem_pool_size: Some((i + 1) as u64 * 1024 * 1024),
                ..Default::default()
            })
            .collect();

        for config in &configs {
            let _ctx = get_session_context(config);
        }

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 4);
        }

        // Access config[0] to make it more recently used than config[1]
        // (config[0] was inserted first, so without this access it would be evicted)
        std::thread::sleep(std::time::Duration::from_millis(1));
        let _ctx = get_session_context(&configs[0]);

        // Add a 5th configuration - should evict config[1] (now least recently used)
        let opts5 = LanceExecutionOptions {
            mem_pool_size: Some(5 * 1024 * 1024),
            ..Default::default()
        };
        let _ctx5 = get_session_context(&opts5);

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 4);

            // config[0] should still be present (was accessed recently)
            let key0 = SessionContextCacheKey::from_options(&configs[0]);
            assert!(
                cache_guard.contains_key(&key0),
                "config[0] should still be cached after recent access"
            );

            // config[1] should be evicted (was least recently used)
            let key1 = SessionContextCacheKey::from_options(&configs[1]);
            assert!(
                !cache_guard.contains_key(&key1),
                "config[1] should have been evicted"
            );

            // New config should be present
            let key5 = SessionContextCacheKey::from_options(&opts5);
            assert!(
                cache_guard.contains_key(&key5),
                "new config should be cached"
            );
        }
    }
}