Skip to main content

lance_datafusion/
exec.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Utilities for working with datafusion execution plans
5
6use std::{
7    collections::HashMap,
8    fmt::{self, Formatter},
9    sync::{Arc, Mutex, OnceLock},
10    time::Duration,
11};
12
13use arrow_array::RecordBatch;
14use arrow_schema::Schema as ArrowSchema;
15use datafusion::{
16    catalog::streaming::StreamingTable,
17    dataframe::DataFrame,
18    execution::{
19        context::{SessionConfig, SessionContext},
20        disk_manager::DiskManagerBuilder,
21        memory_pool::FairSpillPool,
22        runtime_env::RuntimeEnvBuilder,
23        TaskContext,
24    },
25    physical_plan::{
26        analyze::AnalyzeExec,
27        display::DisplayableExecutionPlan,
28        execution_plan::{Boundedness, CardinalityEffect, EmissionType},
29        stream::RecordBatchStreamAdapter,
30        streaming::PartitionStream,
31        DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
32    },
33};
34use datafusion_common::{DataFusionError, Statistics};
35use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
36
37use futures::{stream, StreamExt};
38use lance_arrow::SchemaExt;
39use lance_core::{
40    utils::{
41        futures::FinallyStreamExt,
42        tracing::{StreamTracingExt, EXECUTION_PLAN_RUN, TRACE_EXECUTION},
43    },
44    Error, Result,
45};
46use log::{debug, info, warn};
47use snafu::location;
48use tracing::Span;
49
50use crate::udf::register_functions;
51use crate::{
52    chunker::StrictBatchSizeStream,
53    utils::{
54        MetricsExt, BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC,
55        IOPS_METRIC, PARTS_LOADED_METRIC, REQUESTS_METRIC,
56    },
57};
58
59/// An source execution node created from an existing stream
60///
61/// It can only be used once, and will return the stream.  After that the node
62/// is exhausted.
63///
64/// Note: the stream should be finite, otherwise we will report datafusion properties
65/// incorrectly.
66pub struct OneShotExec {
67    stream: Mutex<Option<SendableRecordBatchStream>>,
68    // We save off a copy of the schema to speed up formatting and so ExecutionPlan::schema & display_as
69    // can still function after exhausted
70    schema: Arc<ArrowSchema>,
71    properties: PlanProperties,
72}
73
74impl OneShotExec {
75    /// Create a new instance from a given stream
76    pub fn new(stream: SendableRecordBatchStream) -> Self {
77        let schema = stream.schema();
78        Self {
79            stream: Mutex::new(Some(stream)),
80            schema: schema.clone(),
81            properties: PlanProperties::new(
82                EquivalenceProperties::new(schema),
83                Partitioning::RoundRobinBatch(1),
84                EmissionType::Incremental,
85                Boundedness::Bounded,
86            ),
87        }
88    }
89
90    pub fn from_batch(batch: RecordBatch) -> Self {
91        let schema = batch.schema();
92        let stream = Box::pin(RecordBatchStreamAdapter::new(
93            schema,
94            stream::iter(vec![Ok(batch)]),
95        ));
96        Self::new(stream)
97    }
98}
99
100impl std::fmt::Debug for OneShotExec {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        let stream = self.stream.lock().unwrap();
103        f.debug_struct("OneShotExec")
104            .field("exhausted", &stream.is_none())
105            .field("schema", self.schema.as_ref())
106            .finish()
107    }
108}
109
110impl DisplayAs for OneShotExec {
111    fn fmt_as(
112        &self,
113        t: datafusion::physical_plan::DisplayFormatType,
114        f: &mut std::fmt::Formatter,
115    ) -> std::fmt::Result {
116        let stream = self.stream.lock().unwrap();
117        let exhausted = if stream.is_some() { "" } else { "EXHAUSTED" };
118        let columns = self
119            .schema
120            .field_names()
121            .iter()
122            .cloned()
123            .cloned()
124            .collect::<Vec<_>>();
125        match t {
126            DisplayFormatType::Default | DisplayFormatType::Verbose => {
127                write!(
128                    f,
129                    "OneShotStream: {}columns=[{}]",
130                    exhausted,
131                    columns.join(",")
132                )
133            }
134            DisplayFormatType::TreeRender => {
135                write!(
136                    f,
137                    "OneShotStream\nexhausted={}\ncolumns=[{}]",
138                    exhausted,
139                    columns.join(",")
140                )
141            }
142        }
143    }
144}
145
146impl ExecutionPlan for OneShotExec {
147    fn name(&self) -> &str {
148        "OneShotExec"
149    }
150
151    fn as_any(&self) -> &dyn std::any::Any {
152        self
153    }
154
155    fn schema(&self) -> arrow_schema::SchemaRef {
156        self.schema.clone()
157    }
158
159    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
160        vec![]
161    }
162
163    fn with_new_children(
164        self: Arc<Self>,
165        children: Vec<Arc<dyn ExecutionPlan>>,
166    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
167        // OneShotExec has no children, so this should only be called with an empty vector
168        if !children.is_empty() {
169            return Err(datafusion_common::DataFusionError::Internal(
170                "OneShotExec does not support children".to_string(),
171            ));
172        }
173        Ok(self)
174    }
175
176    fn execute(
177        &self,
178        _partition: usize,
179        _context: Arc<datafusion::execution::TaskContext>,
180    ) -> datafusion_common::Result<SendableRecordBatchStream> {
181        let stream = self
182            .stream
183            .lock()
184            .map_err(|err| DataFusionError::Execution(err.to_string()))?
185            .take();
186        if let Some(stream) = stream {
187            Ok(stream)
188        } else {
189            Err(DataFusionError::Execution(
190                "OneShotExec has already been executed".to_string(),
191            ))
192        }
193    }
194
195    fn statistics(&self) -> datafusion_common::Result<datafusion_common::Statistics> {
196        Ok(Statistics::new_unknown(&self.schema))
197    }
198
199    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
200        &self.properties
201    }
202}
203
204struct TracedExec {
205    input: Arc<dyn ExecutionPlan>,
206    properties: PlanProperties,
207    span: Span,
208}
209
210impl TracedExec {
211    pub fn new(input: Arc<dyn ExecutionPlan>, span: Span) -> Self {
212        Self {
213            properties: input.properties().clone(),
214            input,
215            span,
216        }
217    }
218}
219
220impl DisplayAs for TracedExec {
221    fn fmt_as(
222        &self,
223        t: datafusion::physical_plan::DisplayFormatType,
224        f: &mut std::fmt::Formatter,
225    ) -> std::fmt::Result {
226        match t {
227            DisplayFormatType::Default
228            | DisplayFormatType::Verbose
229            | DisplayFormatType::TreeRender => {
230                write!(f, "TracedExec")
231            }
232        }
233    }
234}
235
236impl std::fmt::Debug for TracedExec {
237    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
238        write!(f, "TracedExec")
239    }
240}
241impl ExecutionPlan for TracedExec {
242    fn name(&self) -> &str {
243        "TracedExec"
244    }
245
246    fn as_any(&self) -> &dyn std::any::Any {
247        self
248    }
249
250    fn properties(&self) -> &PlanProperties {
251        &self.properties
252    }
253
254    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
255        vec![&self.input]
256    }
257
258    fn with_new_children(
259        self: Arc<Self>,
260        children: Vec<Arc<dyn ExecutionPlan>>,
261    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
262        Ok(Arc::new(Self {
263            input: children[0].clone(),
264            properties: self.properties.clone(),
265            span: self.span.clone(),
266        }))
267    }
268
269    fn execute(
270        &self,
271        partition: usize,
272        context: Arc<TaskContext>,
273    ) -> datafusion_common::Result<SendableRecordBatchStream> {
274        let _guard = self.span.enter();
275        let stream = self.input.execute(partition, context)?;
276        let schema = stream.schema();
277        let stream = stream.stream_in_span(self.span.clone());
278        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
279    }
280}
281
282/// Callback for reporting statistics after a scan
283pub type ExecutionStatsCallback = Arc<dyn Fn(&ExecutionSummaryCounts) + Send + Sync>;
284
285#[derive(Default, Clone)]
286pub struct LanceExecutionOptions {
287    pub use_spilling: bool,
288    pub mem_pool_size: Option<u64>,
289    pub batch_size: Option<usize>,
290    pub target_partition: Option<usize>,
291    pub execution_stats_callback: Option<ExecutionStatsCallback>,
292    pub skip_logging: bool,
293}
294
295impl std::fmt::Debug for LanceExecutionOptions {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        f.debug_struct("LanceExecutionOptions")
298            .field("use_spilling", &self.use_spilling)
299            .field("mem_pool_size", &self.mem_pool_size)
300            .field("batch_size", &self.batch_size)
301            .field("target_partition", &self.target_partition)
302            .field("skip_logging", &self.skip_logging)
303            .field(
304                "execution_stats_callback",
305                &self.execution_stats_callback.is_some(),
306            )
307            .finish()
308    }
309}
310
311const DEFAULT_LANCE_MEM_POOL_SIZE: u64 = 100 * 1024 * 1024;
312
313impl LanceExecutionOptions {
314    pub fn mem_pool_size(&self) -> u64 {
315        self.mem_pool_size.unwrap_or_else(|| {
316            std::env::var("LANCE_MEM_POOL_SIZE")
317                .map(|s| match s.parse::<u64>() {
318                    Ok(v) => v,
319                    Err(e) => {
320                        warn!("Failed to parse LANCE_MEM_POOL_SIZE: {}, using default", e);
321                        DEFAULT_LANCE_MEM_POOL_SIZE
322                    }
323                })
324                .unwrap_or(DEFAULT_LANCE_MEM_POOL_SIZE)
325        })
326    }
327
328    pub fn use_spilling(&self) -> bool {
329        if !self.use_spilling {
330            return false;
331        }
332        std::env::var("LANCE_BYPASS_SPILLING")
333            .map(|_| {
334                info!("Bypassing spilling because LANCE_BYPASS_SPILLING is set");
335                false
336            })
337            .unwrap_or(true)
338    }
339}
340
341pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext {
342    let mut session_config = SessionConfig::new();
343    let mut runtime_env_builder = RuntimeEnvBuilder::new();
344    if let Some(target_partition) = options.target_partition {
345        session_config = session_config.with_target_partitions(target_partition);
346    }
347    if options.use_spilling() {
348        runtime_env_builder = runtime_env_builder
349            .with_disk_manager_builder(DiskManagerBuilder::default())
350            .with_memory_pool(Arc::new(FairSpillPool::new(
351                options.mem_pool_size() as usize
352            )));
353    }
354    let runtime_env = runtime_env_builder.build_arc().unwrap();
355
356    let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);
357    register_functions(&ctx);
358
359    ctx
360}
361
362/// Cache key for session contexts based on resolved configuration values.
363#[derive(Clone, Debug, PartialEq, Eq, Hash)]
364struct SessionContextCacheKey {
365    mem_pool_size: u64,
366    target_partition: Option<usize>,
367    use_spilling: bool,
368}
369
370impl SessionContextCacheKey {
371    fn from_options(options: &LanceExecutionOptions) -> Self {
372        Self {
373            mem_pool_size: options.mem_pool_size(),
374            target_partition: options.target_partition,
375            use_spilling: options.use_spilling(),
376        }
377    }
378}
379
380struct CachedSessionContext {
381    context: SessionContext,
382    last_access: std::time::Instant,
383}
384
385fn get_session_cache() -> &'static Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>> {
386    static SESSION_CACHE: OnceLock<Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>>> =
387        OnceLock::new();
388    SESSION_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
389}
390
/// Maximum number of cached session contexts.
///
/// Read once from `LANCE_SESSION_CACHE_SIZE`. A missing or unparsable value falls
/// back to 4.
fn get_max_cache_size() -> usize {
    const DEFAULT_CACHE_SIZE: usize = 4;
    static MAX_CACHE_SIZE: OnceLock<usize> = OnceLock::new();
    *MAX_CACHE_SIZE.get_or_init(|| match std::env::var("LANCE_SESSION_CACHE_SIZE") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_CACHE_SIZE),
        Err(_) => DEFAULT_CACHE_SIZE,
    })
}
401
402pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext {
403    let key = SessionContextCacheKey::from_options(options);
404    let mut cache = get_session_cache()
405        .lock()
406        .unwrap_or_else(|e| e.into_inner());
407
408    // If key exists, update access time and return
409    if let Some(entry) = cache.get_mut(&key) {
410        entry.last_access = std::time::Instant::now();
411        return entry.context.clone();
412    }
413
414    // Evict least recently used entry if cache is full
415    if cache.len() >= get_max_cache_size() {
416        if let Some(lru_key) = cache
417            .iter()
418            .min_by_key(|(_, v)| v.last_access)
419            .map(|(k, _)| k.clone())
420        {
421            cache.remove(&lru_key);
422        }
423    }
424
425    let context = new_session_context(options);
426    cache.insert(
427        key,
428        CachedSessionContext {
429            context: context.clone(),
430            last_access: std::time::Instant::now(),
431        },
432    );
433    context
434}
435
436fn get_task_context(
437    session_ctx: &SessionContext,
438    options: &LanceExecutionOptions,
439) -> Arc<TaskContext> {
440    let mut state = session_ctx.state();
441    if let Some(batch_size) = options.batch_size.as_ref() {
442        state.config_mut().options_mut().execution.batch_size = *batch_size;
443    }
444
445    state.task_ctx()
446}
447
/// Counters aggregated over every node of an executed plan.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ExecutionSummaryCounts {
    /// The number of I/O operations performed
    pub iops: usize,
    /// The number of requests made to the storage layer (may be larger or smaller than iops
    /// depending on coalescing configuration)
    pub requests: usize,
    /// The number of bytes read during the execution of the plan
    pub bytes_read: usize,
    /// The number of top-level indices loaded
    pub indices_loaded: usize,
    /// The number of index partitions loaded
    pub parts_loaded: usize,
    /// The number of index comparisons performed (the exact meaning depends on the index type)
    pub index_comparisons: usize,
    /// Additional count metrics keyed by metric name, for more detailed statistics.
    /// These are subject to change in the future and should only be used for
    /// debugging purposes.
    pub all_counts: HashMap<String, usize>,
}
467
468fn visit_node(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) {
469    if let Some(metrics) = node.metrics() {
470        for (metric_name, count) in metrics.iter_counts() {
471            match metric_name.as_ref() {
472                IOPS_METRIC => counts.iops += count.value(),
473                REQUESTS_METRIC => counts.requests += count.value(),
474                BYTES_READ_METRIC => counts.bytes_read += count.value(),
475                INDICES_LOADED_METRIC => counts.indices_loaded += count.value(),
476                PARTS_LOADED_METRIC => counts.parts_loaded += count.value(),
477                INDEX_COMPARISONS_METRIC => counts.index_comparisons += count.value(),
478                _ => {
479                    let existing = counts
480                        .all_counts
481                        .entry(metric_name.as_ref().to_string())
482                        .or_insert(0);
483                    *existing += count.value();
484                }
485            }
486        }
487        // Include gauge-based I/O metrics (some nodes record I/O as gauges)
488        for (metric_name, gauge) in metrics.iter_gauges() {
489            match metric_name.as_ref() {
490                IOPS_METRIC => counts.iops += gauge.value(),
491                REQUESTS_METRIC => counts.requests += gauge.value(),
492                BYTES_READ_METRIC => counts.bytes_read += gauge.value(),
493                _ => {}
494            }
495        }
496    }
497    for child in node.children() {
498        visit_node(child.as_ref(), counts);
499    }
500}
501
502fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) {
503    let output_rows = plan
504        .metrics()
505        .map(|m| m.output_rows().unwrap_or(0))
506        .unwrap_or(0);
507    let mut counts = ExecutionSummaryCounts::default();
508    visit_node(plan, &mut counts);
509    tracing::info!(
510        target: TRACE_EXECUTION,
511        r#type = EXECUTION_PLAN_RUN,
512        plan_summary = display_plan_one_liner(plan),
513        output_rows,
514        iops = counts.iops,
515        requests = counts.requests,
516        bytes_read = counts.bytes_read,
517        indices_loaded = counts.indices_loaded,
518        parts_loaded = counts.parts_loaded,
519        index_comparisons = counts.index_comparisons,
520    );
521    if let Some(callback) = options.execution_stats_callback.as_ref() {
522        callback(&counts);
523    }
524}
525
526/// Create a one-line rough summary of the given execution plan.
527///
528/// The summary just shows the name of the operators in the plan. It omits any
529/// details such as parameters or schema information.
530///
531/// Example: `Projection(Take(CoalesceBatches(Filter(LanceScan))))`
532fn display_plan_one_liner(plan: &dyn ExecutionPlan) -> String {
533    let mut output = String::new();
534
535    display_plan_one_liner_impl(plan, &mut output);
536
537    output
538}
539
540fn display_plan_one_liner_impl(plan: &dyn ExecutionPlan, output: &mut String) {
541    // Remove the "Exec" suffix from the plan name if present for brevity
542    let name = plan.name().trim_end_matches("Exec");
543    output.push_str(name);
544
545    let children = plan.children();
546    if !children.is_empty() {
547        output.push('(');
548        for (i, child) in children.iter().enumerate() {
549            if i > 0 {
550                output.push(',');
551            }
552            display_plan_one_liner_impl(child.as_ref(), output);
553        }
554        output.push(')');
555    }
556}
557
558/// Executes a plan using default session & runtime configuration
559///
560/// Only executes a single partition.  Panics if the plan has more than one partition.
561pub fn execute_plan(
562    plan: Arc<dyn ExecutionPlan>,
563    options: LanceExecutionOptions,
564) -> Result<SendableRecordBatchStream> {
565    if !options.skip_logging {
566        debug!(
567            "Executing plan:\n{}",
568            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
569        );
570    }
571
572    let session_ctx = get_session_context(&options);
573
574    // NOTE: we are only executing the first partition here. Therefore, if
575    // the plan has more than one partition, we will be missing data.
576    assert_eq!(plan.properties().partitioning.partition_count(), 1);
577    let stream = plan.execute(0, get_task_context(&session_ctx, &options))?;
578
579    let schema = stream.schema();
580    let stream = stream.finally(move || {
581        if !options.skip_logging {
582            report_plan_summary_metrics(plan.as_ref(), &options);
583        }
584    });
585    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
586}
587
588pub async fn analyze_plan(
589    plan: Arc<dyn ExecutionPlan>,
590    options: LanceExecutionOptions,
591) -> Result<String> {
592    // This is needed as AnalyzeExec launches a thread task per
593    // partition, and we want these to be connected to the parent span
594    let plan = Arc::new(TracedExec::new(plan, Span::current()));
595
596    let schema = plan.schema();
597    let analyze = Arc::new(AnalyzeExec::new(true, true, plan, schema));
598
599    let session_ctx = get_session_context(&options);
600    assert_eq!(analyze.properties().partitioning.partition_count(), 1);
601    let mut stream = analyze
602        .execute(0, get_task_context(&session_ctx, &options))
603        .map_err(|err| {
604            Error::io(
605                format!("Failed to execute analyze plan: {}", err),
606                location!(),
607            )
608        })?;
609
610    // fully execute the plan
611    while (stream.next().await).is_some() {}
612
613    let result = format_plan(analyze);
614    Ok(result)
615}
616
617pub fn format_plan(plan: Arc<dyn ExecutionPlan>) -> String {
618    /// A visitor which calculates additional metrics for all the plans.
619    struct CalculateVisitor {
620        highest_index: usize,
621        index_to_cumulative_cpu: HashMap<usize, usize>,
622    }
623    impl CalculateVisitor {
624        fn calculate_cumulative_cpu(&mut self, plan: &Arc<dyn ExecutionPlan>) -> usize {
625            self.highest_index += 1;
626            let plan_index = self.highest_index;
627            let elapsed_cpu: usize = match plan.metrics() {
628                Some(metrics) => metrics.elapsed_compute().unwrap_or_default(),
629                None => 0,
630            };
631            let mut cumulative_cpu = elapsed_cpu;
632            for child in plan.children() {
633                cumulative_cpu += self.calculate_cumulative_cpu(child);
634            }
635            self.index_to_cumulative_cpu
636                .insert(plan_index, cumulative_cpu);
637            cumulative_cpu
638        }
639    }
640
641    /// A visitor which prints out all the plans.
642    struct PrintVisitor {
643        highest_index: usize,
644        indent: usize,
645    }
646    impl PrintVisitor {
647        fn write_output(
648            &mut self,
649            plan: &Arc<dyn ExecutionPlan>,
650            f: &mut Formatter,
651            calcs: &CalculateVisitor,
652        ) -> std::fmt::Result {
653            self.highest_index += 1;
654            write!(f, "{:indent$}", "", indent = self.indent * 2)?;
655            plan.fmt_as(datafusion::physical_plan::DisplayFormatType::Verbose, f)?;
656            if let Some(metrics) = plan.metrics() {
657                let metrics = metrics
658                    .aggregate_by_name()
659                    .sorted_for_display()
660                    .timestamps_removed();
661
662                write!(f, ", metrics=[{metrics}]")?;
663            } else {
664                write!(f, ", metrics=[]")?;
665            }
666            let cumulative_cpu = calcs
667                .index_to_cumulative_cpu
668                .get(&self.highest_index)
669                .unwrap();
670            let cumulative_cpu_duration = Duration::from_nanos((*cumulative_cpu) as u64);
671            write!(f, ", cumulative_cpu={cumulative_cpu_duration:?}")?;
672            writeln!(f)?;
673            self.indent += 1;
674            for child in plan.children() {
675                self.write_output(child, f, calcs)?;
676            }
677            self.indent -= 1;
678            std::fmt::Result::Ok(())
679        }
680    }
681    // A wrapper which prints out a plan.
682    struct PrintWrapper {
683        plan: Arc<dyn ExecutionPlan>,
684    }
685    impl fmt::Display for PrintWrapper {
686        fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
687            let mut calcs = CalculateVisitor {
688                highest_index: 0,
689                index_to_cumulative_cpu: HashMap::new(),
690            };
691            calcs.calculate_cumulative_cpu(&self.plan);
692            let mut prints = PrintVisitor {
693                highest_index: 0,
694                indent: 0,
695            };
696            prints.write_output(&self.plan, f, &calcs)
697        }
698    }
699    let wrapper = PrintWrapper { plan };
700    format!("{}", wrapper)
701}
702
703pub trait SessionContextExt {
704    /// Creates a DataFrame for reading a stream of data
705    ///
706    /// This dataframe may only be queried once, future queries will fail
707    fn read_one_shot(
708        &self,
709        data: SendableRecordBatchStream,
710    ) -> datafusion::common::Result<DataFrame>;
711}
712
713struct OneShotPartitionStream {
714    data: Arc<Mutex<Option<SendableRecordBatchStream>>>,
715    schema: Arc<ArrowSchema>,
716}
717
718impl std::fmt::Debug for OneShotPartitionStream {
719    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
720        let data = self.data.lock().unwrap();
721        f.debug_struct("OneShotPartitionStream")
722            .field("exhausted", &data.is_none())
723            .field("schema", self.schema.as_ref())
724            .finish()
725    }
726}
727
728impl OneShotPartitionStream {
729    fn new(data: SendableRecordBatchStream) -> Self {
730        let schema = data.schema();
731        Self {
732            data: Arc::new(Mutex::new(Some(data))),
733            schema,
734        }
735    }
736}
737
738impl PartitionStream for OneShotPartitionStream {
739    fn schema(&self) -> &arrow_schema::SchemaRef {
740        &self.schema
741    }
742
743    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
744        let mut stream = self.data.lock().unwrap();
745        stream
746            .take()
747            .expect("Attempt to consume a one shot dataframe multiple times")
748    }
749}
750
751impl SessionContextExt for SessionContext {
752    fn read_one_shot(
753        &self,
754        data: SendableRecordBatchStream,
755    ) -> datafusion::common::Result<DataFrame> {
756        let schema = data.schema();
757        let part_stream = Arc::new(OneShotPartitionStream::new(data));
758        let provider = StreamingTable::try_new(schema, vec![part_stream])?;
759        self.read_table(Arc::new(provider))
760    }
761}
762
763#[derive(Clone, Debug)]
764pub struct StrictBatchSizeExec {
765    input: Arc<dyn ExecutionPlan>,
766    batch_size: usize,
767}
768
769impl StrictBatchSizeExec {
770    pub fn new(input: Arc<dyn ExecutionPlan>, batch_size: usize) -> Self {
771        Self { input, batch_size }
772    }
773}
774
775impl DisplayAs for StrictBatchSizeExec {
776    fn fmt_as(
777        &self,
778        _t: datafusion::physical_plan::DisplayFormatType,
779        f: &mut std::fmt::Formatter,
780    ) -> std::fmt::Result {
781        write!(f, "StrictBatchSizeExec")
782    }
783}
784
785impl ExecutionPlan for StrictBatchSizeExec {
786    fn name(&self) -> &str {
787        "StrictBatchSizeExec"
788    }
789
790    fn as_any(&self) -> &dyn std::any::Any {
791        self
792    }
793
794    fn properties(&self) -> &PlanProperties {
795        self.input.properties()
796    }
797
798    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
799        vec![&self.input]
800    }
801
802    fn with_new_children(
803        self: Arc<Self>,
804        children: Vec<Arc<dyn ExecutionPlan>>,
805    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
806        Ok(Arc::new(Self {
807            input: children[0].clone(),
808            batch_size: self.batch_size,
809        }))
810    }
811
812    fn execute(
813        &self,
814        partition: usize,
815        context: Arc<TaskContext>,
816    ) -> datafusion_common::Result<SendableRecordBatchStream> {
817        let stream = self.input.execute(partition, context)?;
818        let schema = stream.schema();
819        let stream = StrictBatchSizeStream::new(stream, self.batch_size);
820        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
821    }
822
823    fn maintains_input_order(&self) -> Vec<bool> {
824        vec![true]
825    }
826
827    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
828        vec![false]
829    }
830
831    fn partition_statistics(
832        &self,
833        partition: Option<usize>,
834    ) -> datafusion_common::Result<Statistics> {
835        self.input.partition_statistics(partition)
836    }
837
838    fn cardinality_effect(&self) -> CardinalityEffect {
839        CardinalityEffect::Equal
840    }
841
842    fn supports_limit_pushdown(&self) -> bool {
843        true
844    }
845}
846
#[cfg(test)]
mod tests {
    use super::*;

    // Serialize cache tests since they share global state
    static CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Current number of entries in the process-wide session cache.
    fn cached_session_count() -> usize {
        get_session_cache().lock().unwrap().len()
    }

    /// Default options except for an explicit memory pool of `mib` MiB.
    fn with_pool_mib(mib: u64) -> LanceExecutionOptions {
        LanceExecutionOptions {
            mem_pool_size: Some(mib * 1024 * 1024),
            ..Default::default()
        }
    }

    #[test]
    fn test_session_context_cache() {
        let _serial = CACHE_TEST_LOCK.lock().unwrap();
        // Start from an empty cache regardless of what other tests left behind.
        get_session_cache().lock().unwrap().clear();

        let defaults = LanceExecutionOptions::default();
        let _first = get_session_context(&defaults);
        assert_eq!(cached_session_count(), 1);

        // Identical options are served from the cache without adding an entry.
        let _repeat = get_session_context(&defaults);
        assert_eq!(cached_session_count(), 1);

        // Enabling spilling yields a different key, hence a new entry.
        let spilling = LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        };
        let _second = get_session_context(&spilling);
        assert_eq!(cached_session_count(), 2);
    }

    #[test]
    fn test_session_context_cache_lru_eviction() {
        let _serial = CACHE_TEST_LOCK.lock().unwrap();
        // Start from an empty cache regardless of what other tests left behind.
        get_session_cache().lock().unwrap().clear();

        // Fill the cache (default capacity 4) with four distinct configurations.
        let configs: Vec<LanceExecutionOptions> = (1..=4).map(with_pool_mib).collect();
        for config in &configs {
            let _ctx = get_session_context(config);
        }
        assert_eq!(cached_session_count(), 4);

        // Touch configs[0] so configs[1] becomes the least recently used entry; the
        // sleep ensures the refreshed access time is strictly later than the inserts.
        std::thread::sleep(std::time::Duration::from_millis(1));
        let _ctx = get_session_context(&configs[0]);

        // A fifth configuration must evict configs[1].
        let fifth = with_pool_mib(5);
        let _ctx5 = get_session_context(&fifth);

        let cache_guard = get_session_cache().lock().unwrap();
        assert_eq!(cache_guard.len(), 4);

        let key0 = SessionContextCacheKey::from_options(&configs[0]);
        assert!(
            cache_guard.contains_key(&key0),
            "config[0] should still be cached after recent access"
        );

        let key1 = SessionContextCacheKey::from_options(&configs[1]);
        assert!(
            !cache_guard.contains_key(&key1),
            "config[1] should have been evicted"
        );

        let key5 = SessionContextCacheKey::from_options(&fifth);
        assert!(
            cache_guard.contains_key(&key5),
            "new config should be cached"
        );
    }
}