1use std::{
7 collections::HashMap,
8 fmt::{self, Formatter},
9 sync::{Arc, Mutex, OnceLock},
10 time::Duration,
11};
12
13use chrono::{DateTime, Utc};
14
15use arrow_array::RecordBatch;
16use arrow_schema::Schema as ArrowSchema;
17use datafusion::physical_plan::metrics::MetricType;
18use datafusion::{
19 catalog::streaming::StreamingTable,
20 dataframe::DataFrame,
21 execution::{
22 TaskContext,
23 context::{SessionConfig, SessionContext},
24 disk_manager::DiskManagerBuilder,
25 memory_pool::FairSpillPool,
26 runtime_env::RuntimeEnvBuilder,
27 },
28 physical_plan::{
29 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
30 analyze::AnalyzeExec,
31 display::DisplayableExecutionPlan,
32 execution_plan::{Boundedness, CardinalityEffect, EmissionType},
33 metrics::MetricValue,
34 stream::RecordBatchStreamAdapter,
35 streaming::PartitionStream,
36 },
37};
38use datafusion_common::{DataFusionError, Statistics};
39use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
40
41use futures::{StreamExt, stream};
42use lance_arrow::SchemaExt;
43use lance_core::{
44 Error, Result,
45 utils::{
46 futures::FinallyStreamExt,
47 tracing::{EXECUTION_PLAN_RUN, StreamTracingExt, TRACE_EXECUTION},
48 },
49};
50use log::{debug, info, warn};
51use tracing::Span;
52
53use crate::udf::register_functions;
54use crate::{
55 chunker::StrictBatchSizeStream,
56 utils::{
57 BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC, IOPS_METRIC,
58 MetricsExt, PARTS_LOADED_METRIC, REQUESTS_METRIC,
59 },
60};
61
/// An `ExecutionPlan` leaf node that yields a pre-existing stream exactly once.
///
/// The stream is stored behind a `Mutex<Option<...>>` and is taken on the
/// first call to `execute`; a second call returns an error (see `execute`).
pub struct OneShotExec {
    // The wrapped stream; `None` once `execute` has consumed it.
    stream: Mutex<Option<SendableRecordBatchStream>>,
    // Schema captured from the stream at construction time.
    schema: Arc<ArrowSchema>,
    // Single round-robin partition, incremental emission, bounded.
    properties: PlanProperties,
}
76
77impl OneShotExec {
78 pub fn new(stream: SendableRecordBatchStream) -> Self {
80 let schema = stream.schema();
81 Self {
82 stream: Mutex::new(Some(stream)),
83 schema: schema.clone(),
84 properties: PlanProperties::new(
85 EquivalenceProperties::new(schema),
86 Partitioning::RoundRobinBatch(1),
87 EmissionType::Incremental,
88 Boundedness::Bounded,
89 ),
90 }
91 }
92
93 pub fn from_batch(batch: RecordBatch) -> Self {
94 let schema = batch.schema();
95 let stream = Box::pin(RecordBatchStreamAdapter::new(
96 schema,
97 stream::iter(vec![Ok(batch)]),
98 ));
99 Self::new(stream)
100 }
101}
102
103impl std::fmt::Debug for OneShotExec {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 let stream = self.stream.lock().unwrap();
106 f.debug_struct("OneShotExec")
107 .field("exhausted", &stream.is_none())
108 .field("schema", self.schema.as_ref())
109 .finish()
110 }
111}
112
113impl DisplayAs for OneShotExec {
114 fn fmt_as(
115 &self,
116 t: datafusion::physical_plan::DisplayFormatType,
117 f: &mut std::fmt::Formatter,
118 ) -> std::fmt::Result {
119 let stream = self.stream.lock().unwrap();
120 let exhausted = if stream.is_some() { "" } else { "EXHAUSTED" };
121 let columns = self
122 .schema
123 .field_names()
124 .iter()
125 .cloned()
126 .cloned()
127 .collect::<Vec<_>>();
128 match t {
129 DisplayFormatType::Default | DisplayFormatType::Verbose => {
130 write!(
131 f,
132 "OneShotStream: {}columns=[{}]",
133 exhausted,
134 columns.join(",")
135 )
136 }
137 DisplayFormatType::TreeRender => {
138 write!(
139 f,
140 "OneShotStream\nexhausted={}\ncolumns=[{}]",
141 exhausted,
142 columns.join(",")
143 )
144 }
145 }
146 }
147}
148
149impl ExecutionPlan for OneShotExec {
150 fn name(&self) -> &str {
151 "OneShotExec"
152 }
153
154 fn as_any(&self) -> &dyn std::any::Any {
155 self
156 }
157
158 fn schema(&self) -> arrow_schema::SchemaRef {
159 self.schema.clone()
160 }
161
162 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
163 vec![]
164 }
165
166 fn with_new_children(
167 self: Arc<Self>,
168 children: Vec<Arc<dyn ExecutionPlan>>,
169 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
170 if !children.is_empty() {
172 return Err(datafusion_common::DataFusionError::Internal(
173 "OneShotExec does not support children".to_string(),
174 ));
175 }
176 Ok(self)
177 }
178
179 fn execute(
180 &self,
181 _partition: usize,
182 _context: Arc<datafusion::execution::TaskContext>,
183 ) -> datafusion_common::Result<SendableRecordBatchStream> {
184 let stream = self
185 .stream
186 .lock()
187 .map_err(|err| DataFusionError::Execution(err.to_string()))?
188 .take();
189 if let Some(stream) = stream {
190 Ok(stream)
191 } else {
192 Err(DataFusionError::Execution(
193 "OneShotExec has already been executed".to_string(),
194 ))
195 }
196 }
197
198 fn statistics(&self) -> datafusion_common::Result<datafusion_common::Statistics> {
199 Ok(Statistics::new_unknown(&self.schema))
200 }
201
202 fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
203 &self.properties
204 }
205}
206
/// Wraps an execution plan so its stream is polled inside a given tracing span.
struct TracedExec {
    input: Arc<dyn ExecutionPlan>,
    // Mirrors the input's properties so DataFusion sees no change in shape.
    properties: PlanProperties,
    // Span entered during `execute` and attached to the resulting stream.
    span: Span,
}
212
213impl TracedExec {
214 pub fn new(input: Arc<dyn ExecutionPlan>, span: Span) -> Self {
215 Self {
216 properties: input.properties().clone(),
217 input,
218 span,
219 }
220 }
221}
222
223impl DisplayAs for TracedExec {
224 fn fmt_as(
225 &self,
226 t: datafusion::physical_plan::DisplayFormatType,
227 f: &mut std::fmt::Formatter,
228 ) -> std::fmt::Result {
229 match t {
230 DisplayFormatType::Default
231 | DisplayFormatType::Verbose
232 | DisplayFormatType::TreeRender => {
233 write!(f, "TracedExec")
234 }
235 }
236 }
237}
238
239impl std::fmt::Debug for TracedExec {
240 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
241 write!(f, "TracedExec")
242 }
243}
244impl ExecutionPlan for TracedExec {
245 fn name(&self) -> &str {
246 "TracedExec"
247 }
248
249 fn as_any(&self) -> &dyn std::any::Any {
250 self
251 }
252
253 fn properties(&self) -> &PlanProperties {
254 &self.properties
255 }
256
257 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
258 vec![&self.input]
259 }
260
261 fn with_new_children(
262 self: Arc<Self>,
263 children: Vec<Arc<dyn ExecutionPlan>>,
264 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
265 Ok(Arc::new(Self {
266 input: children[0].clone(),
267 properties: self.properties.clone(),
268 span: self.span.clone(),
269 }))
270 }
271
272 fn execute(
273 &self,
274 partition: usize,
275 context: Arc<TaskContext>,
276 ) -> datafusion_common::Result<SendableRecordBatchStream> {
277 let _guard = self.span.enter();
278 let stream = self.input.execute(partition, context)?;
279 let schema = stream.schema();
280 let stream = stream.stream_in_span(self.span.clone());
281 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
282 }
283}
284
/// Callback invoked with the aggregated execution metrics once a plan's
/// stream has been fully consumed (see `execute_plan`).
pub type ExecutionStatsCallback = Arc<dyn Fn(&ExecutionSummaryCounts) + Send + Sync>;
287
/// Options controlling how Lance executes a DataFusion plan.
#[derive(Default, Clone)]
pub struct LanceExecutionOptions {
    // Enable spill-to-disk; may still be bypassed via LANCE_BYPASS_SPILLING
    // (see `use_spilling()`).
    pub use_spilling: bool,
    // Memory pool size in bytes; falls back to the LANCE_MEM_POOL_SIZE env
    // var, then a built-in default (see `mem_pool_size()`).
    pub mem_pool_size: Option<u64>,
    // Cap on the spill directory size in bytes; env/default fallback as above.
    pub max_temp_directory_size: Option<u64>,
    // Overrides the session's execution batch size when set.
    pub batch_size: Option<usize>,
    // Overrides DataFusion's target partition count when set.
    pub target_partition: Option<usize>,
    // Invoked with summary metrics after the plan's stream is drained.
    pub execution_stats_callback: Option<ExecutionStatsCallback>,
    // Suppress plan logging and summary metric reporting.
    pub skip_logging: bool,
}
298
299impl std::fmt::Debug for LanceExecutionOptions {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 f.debug_struct("LanceExecutionOptions")
302 .field("use_spilling", &self.use_spilling)
303 .field("mem_pool_size", &self.mem_pool_size)
304 .field("max_temp_directory_size", &self.max_temp_directory_size)
305 .field("batch_size", &self.batch_size)
306 .field("target_partition", &self.target_partition)
307 .field("skip_logging", &self.skip_logging)
308 .field(
309 "execution_stats_callback",
310 &self.execution_stats_callback.is_some(),
311 )
312 .finish()
313 }
314}
315
316const DEFAULT_LANCE_MEM_POOL_SIZE: u64 = 100 * 1024 * 1024;
317const DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; impl LanceExecutionOptions {
320 pub fn mem_pool_size(&self) -> u64 {
321 self.mem_pool_size.unwrap_or_else(|| {
322 std::env::var("LANCE_MEM_POOL_SIZE")
323 .map(|s| match s.parse::<u64>() {
324 Ok(v) => v,
325 Err(e) => {
326 warn!("Failed to parse LANCE_MEM_POOL_SIZE: {}, using default", e);
327 DEFAULT_LANCE_MEM_POOL_SIZE
328 }
329 })
330 .unwrap_or(DEFAULT_LANCE_MEM_POOL_SIZE)
331 })
332 }
333
334 pub fn max_temp_directory_size(&self) -> u64 {
335 self.max_temp_directory_size.unwrap_or_else(|| {
336 std::env::var("LANCE_MAX_TEMP_DIRECTORY_SIZE")
337 .map(|s| match s.parse::<u64>() {
338 Ok(v) => v,
339 Err(e) => {
340 warn!(
341 "Failed to parse LANCE_MAX_TEMP_DIRECTORY_SIZE: {}, using default",
342 e
343 );
344 DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE
345 }
346 })
347 .unwrap_or(DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE)
348 })
349 }
350
351 pub fn use_spilling(&self) -> bool {
352 if !self.use_spilling {
353 return false;
354 }
355 std::env::var("LANCE_BYPASS_SPILLING")
356 .map(|_| {
357 info!("Bypassing spilling because LANCE_BYPASS_SPILLING is set");
358 false
359 })
360 .unwrap_or(true)
361 }
362}
363
364pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext {
365 let mut session_config = SessionConfig::new();
366 let mut runtime_env_builder = RuntimeEnvBuilder::new();
367 if let Some(target_partition) = options.target_partition {
368 session_config = session_config.with_target_partitions(target_partition);
369 }
370 if options.use_spilling() {
371 let disk_manager_builder = DiskManagerBuilder::default()
372 .with_max_temp_directory_size(options.max_temp_directory_size());
373 runtime_env_builder = runtime_env_builder
374 .with_disk_manager_builder(disk_manager_builder)
375 .with_memory_pool(Arc::new(FairSpillPool::new(
376 options.mem_pool_size() as usize
377 )));
378 }
379 let runtime_env = runtime_env_builder.build_arc().unwrap();
380
381 let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);
382 register_functions(&ctx);
383
384 ctx
385}
386
/// Cache key built from the *resolved* execution option values, so two
/// option sets that resolve to the same settings share a `SessionContext`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SessionContextCacheKey {
    mem_pool_size: u64,
    max_temp_directory_size: u64,
    target_partition: Option<usize>,
    use_spilling: bool,
}
395
396impl SessionContextCacheKey {
397 fn from_options(options: &LanceExecutionOptions) -> Self {
398 Self {
399 mem_pool_size: options.mem_pool_size(),
400 max_temp_directory_size: options.max_temp_directory_size(),
401 target_partition: options.target_partition,
402 use_spilling: options.use_spilling(),
403 }
404 }
405}
406
/// A cached `SessionContext` plus the last time it was handed out,
/// used for LRU eviction in `get_session_context`.
struct CachedSessionContext {
    context: SessionContext,
    last_access: std::time::Instant,
}
411
412fn get_session_cache() -> &'static Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>> {
413 static SESSION_CACHE: OnceLock<Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>>> =
414 OnceLock::new();
415 SESSION_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
416}
417
/// Maximum number of cached session contexts, read once from the
/// LANCE_SESSION_CACHE_SIZE env var (default: 4; unparseable values fall
/// back to the default).
fn get_max_cache_size() -> usize {
    const DEFAULT_CACHE_SIZE: usize = 4;
    static MAX_CACHE_SIZE: OnceLock<usize> = OnceLock::new();
    *MAX_CACHE_SIZE.get_or_init(|| match std::env::var("LANCE_SESSION_CACHE_SIZE") {
        Ok(v) => v.parse().unwrap_or(DEFAULT_CACHE_SIZE),
        Err(_) => DEFAULT_CACHE_SIZE,
    })
}
428
429pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext {
430 let key = SessionContextCacheKey::from_options(options);
431 let mut cache = get_session_cache()
432 .lock()
433 .unwrap_or_else(|e| e.into_inner());
434
435 if let Some(entry) = cache.get_mut(&key) {
437 entry.last_access = std::time::Instant::now();
438 return entry.context.clone();
439 }
440
441 if cache.len() >= get_max_cache_size()
443 && let Some(lru_key) = cache
444 .iter()
445 .min_by_key(|(_, v)| v.last_access)
446 .map(|(k, _)| k.clone())
447 {
448 cache.remove(&lru_key);
449 }
450
451 let context = new_session_context(options);
452 cache.insert(
453 key,
454 CachedSessionContext {
455 context: context.clone(),
456 last_access: std::time::Instant::now(),
457 },
458 );
459 context
460}
461
462fn get_task_context(
463 session_ctx: &SessionContext,
464 options: &LanceExecutionOptions,
465) -> Arc<TaskContext> {
466 let mut state = session_ctx.state();
467 if let Some(batch_size) = options.batch_size.as_ref() {
468 state.config_mut().options_mut().execution.batch_size = *batch_size;
469 }
470
471 state.task_ctx()
472}
473
/// Aggregated metrics harvested from an executed plan tree
/// (see `collect_execution_metrics`).
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ExecutionSummaryCounts {
    // Sum of the IOPS_METRIC counters/gauges across all nodes.
    pub iops: usize,
    // Sum of the REQUESTS_METRIC counters/gauges.
    pub requests: usize,
    // Sum of the BYTES_READ_METRIC counters/gauges.
    pub bytes_read: usize,
    // Sum of the INDICES_LOADED_METRIC counters.
    pub indices_loaded: usize,
    // Sum of the PARTS_LOADED_METRIC counters.
    pub parts_loaded: usize,
    // Sum of the INDEX_COMPARISONS_METRIC counters.
    pub index_comparisons: usize,
    // Every other named counter, keyed by metric name.
    pub all_counts: HashMap<String, usize>,
}
493
494pub fn collect_execution_metrics(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) {
495 if let Some(metrics) = node.metrics() {
496 for (metric_name, count) in metrics.iter_counts() {
497 match metric_name.as_ref() {
498 IOPS_METRIC => counts.iops += count.value(),
499 REQUESTS_METRIC => counts.requests += count.value(),
500 BYTES_READ_METRIC => counts.bytes_read += count.value(),
501 INDICES_LOADED_METRIC => counts.indices_loaded += count.value(),
502 PARTS_LOADED_METRIC => counts.parts_loaded += count.value(),
503 INDEX_COMPARISONS_METRIC => counts.index_comparisons += count.value(),
504 _ => {
505 let existing = counts
506 .all_counts
507 .entry(metric_name.as_ref().to_string())
508 .or_insert(0);
509 *existing += count.value();
510 }
511 }
512 }
513 for (metric_name, gauge) in metrics.iter_gauges() {
515 match metric_name.as_ref() {
516 IOPS_METRIC => counts.iops += gauge.value(),
517 REQUESTS_METRIC => counts.requests += gauge.value(),
518 BYTES_READ_METRIC => counts.bytes_read += gauge.value(),
519 _ => {}
520 }
521 }
522 }
523 for child in node.children() {
524 collect_execution_metrics(child.as_ref(), counts);
525 }
526}
527
/// Emit a one-line tracing event summarizing the executed plan and its
/// aggregated metrics, then invoke the caller-supplied callback, if any.
fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) {
    // Total output rows, or 0 when the root node recorded no metrics.
    let output_rows = plan
        .metrics()
        .map(|m| m.output_rows().unwrap_or(0))
        .unwrap_or(0);
    let mut counts = ExecutionSummaryCounts::default();
    collect_execution_metrics(plan, &mut counts);
    tracing::info!(
        target: TRACE_EXECUTION,
        r#type = EXECUTION_PLAN_RUN,
        plan_summary = display_plan_one_liner(plan),
        output_rows,
        iops = counts.iops,
        requests = counts.requests,
        bytes_read = counts.bytes_read,
        indices_loaded = counts.indices_loaded,
        parts_loaded = counts.parts_loaded,
        index_comparisons = counts.index_comparisons,
    );
    if let Some(callback) = options.execution_stats_callback.as_ref() {
        callback(&counts);
    }
}
551
552fn display_plan_one_liner(plan: &dyn ExecutionPlan) -> String {
559 let mut output = String::new();
560
561 display_plan_one_liner_impl(plan, &mut output);
562
563 output
564}
565
566fn display_plan_one_liner_impl(plan: &dyn ExecutionPlan, output: &mut String) {
567 let name = plan.name().trim_end_matches("Exec");
569 output.push_str(name);
570
571 let children = plan.children();
572 if !children.is_empty() {
573 output.push('(');
574 for (i, child) in children.iter().enumerate() {
575 if i > 0 {
576 output.push(',');
577 }
578 display_plan_one_liner_impl(child.as_ref(), output);
579 }
580 output.push(')');
581 }
582}
583
/// Execute a physical plan and return its output stream.
///
/// The plan must have exactly one output partition (asserted below).
/// Unless `options.skip_logging` is set, the plan is logged before
/// execution and summary metrics are reported once the stream finishes.
pub fn execute_plan(
    plan: Arc<dyn ExecutionPlan>,
    options: LanceExecutionOptions,
) -> Result<SendableRecordBatchStream> {
    if !options.skip_logging {
        debug!(
            "Executing plan:\n{}",
            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
        );
    }

    let session_ctx = get_session_context(&options);

    // Only the first partition is executed; a multi-partition plan would
    // silently drop data, so require a single partition up front.
    assert_eq!(plan.properties().partitioning.partition_count(), 1);
    let stream = plan.execute(0, get_task_context(&session_ctx, &options))?;

    let schema = stream.schema();
    // `finally` runs the closure when the stream is done, at which point
    // the plan's metrics are fully populated.
    let stream = stream.finally(move || {
        if !options.skip_logging {
            report_plan_summary_metrics(plan.as_ref(), &options);
        }
    });
    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
613
/// Run `plan` to completion under an `AnalyzeExec` wrapper and return the
/// formatted plan annotated with its runtime metrics.
///
/// # Errors
/// Returns an I/O error if executing the analyze plan fails.
pub async fn analyze_plan(
    plan: Arc<dyn ExecutionPlan>,
    options: LanceExecutionOptions,
) -> Result<String> {
    // Propagate the current tracing span into the plan's execution.
    let plan = Arc::new(TracedExec::new(plan, Span::current()));

    let schema = plan.schema();
    let analyze = Arc::new(AnalyzeExec::new(
        true,
        true,
        vec![MetricType::SUMMARY],
        plan,
        schema,
    ));

    let session_ctx = get_session_context(&options);
    // Only single-partition plans are supported here.
    assert_eq!(analyze.properties().partitioning.partition_count(), 1);
    let mut stream = analyze
        .execute(0, get_task_context(&session_ctx, &options))
        .map_err(|err| Error::io(format!("Failed to execute analyze plan: {}", err)))?;

    // Drain the stream so all metrics are populated before formatting.
    while (stream.next().await).is_some() {}

    let result = format_plan(analyze);
    Ok(result)
}
644
/// Format an executed plan tree, one node per line, annotating each node
/// with a subtree elapsed time (derived from start/end timestamp metrics)
/// and its aggregated metrics.
pub fn format_plan(plan: Arc<dyn ExecutionPlan>) -> String {
    // First pass: walk the tree in pre-order, assigning each node an index
    // and computing the [min start, max end] timestamp range of its subtree.
    struct CalculateVisitor {
        highest_index: usize,
        index_to_elapsed: HashMap<usize, Duration>,
    }

    // Timestamp range covering a node and all of its descendants.
    struct SubtreeMetrics {
        min_start: Option<DateTime<Utc>>,
        max_end: Option<DateTime<Utc>>,
    }

    impl CalculateVisitor {
        fn calculate_metrics(&mut self, plan: &Arc<dyn ExecutionPlan>) -> SubtreeMetrics {
            // Pre-order index; PrintVisitor walks in the same order so the
            // indices line up between the two passes.
            self.highest_index += 1;
            let plan_index = self.highest_index;

            let (mut min_start, mut max_end) = Self::node_timerange(plan);

            // Widen the range with every child's subtree range.
            for child in plan.children() {
                let child_metrics = self.calculate_metrics(child);
                min_start = Self::min_option(min_start, child_metrics.min_start);
                max_end = Self::max_option(max_end, child_metrics.max_end);
            }

            // Elapsed is only defined when both endpoints are known; a
            // negative delta (clock skew) collapses to zero.
            let elapsed = match (min_start, max_end) {
                (Some(start), Some(end)) => Some((end - start).to_std().unwrap_or_default()),
                _ => None,
            };

            if let Some(e) = elapsed {
                self.index_to_elapsed.insert(plan_index, e);
            }

            SubtreeMetrics { min_start, max_end }
        }

        // Earliest start / latest end timestamps recorded on this node alone.
        fn node_timerange(
            plan: &Arc<dyn ExecutionPlan>,
        ) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>) {
            let Some(metrics) = plan.metrics() else {
                return (None, None);
            };
            let min_start = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::StartTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .min();
            let max_end = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::EndTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .max();
            (min_start, max_end)
        }

        fn min_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().min()
        }

        fn max_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().max()
        }
    }

    // Second pass: print each node, splicing in the elapsed time computed
    // by CalculateVisitor for the matching pre-order index.
    struct PrintVisitor {
        highest_index: usize,
        indent: usize,
    }
    impl PrintVisitor {
        fn write_output(
            &mut self,
            plan: &Arc<dyn ExecutionPlan>,
            f: &mut Formatter,
            calcs: &CalculateVisitor,
        ) -> std::fmt::Result {
            self.highest_index += 1;
            write!(f, "{:indent$}", "", indent = self.indent * 2)?;

            let displayable =
                datafusion::physical_plan::display::DisplayableExecutionPlan::new(plan.as_ref());
            let plan_str = displayable.one_line().to_string();
            let plan_str = plan_str.trim();

            match calcs.index_to_elapsed.get(&self.highest_index) {
                // Insert `elapsed=` right after the node name (before the
                // first ": ") so it reads as the first annotation.
                Some(elapsed) => match plan_str.find(": ") {
                    Some(i) => write!(
                        f,
                        "{}: elapsed={elapsed:?}, {}",
                        &plan_str[..i],
                        &plan_str[i + 2..]
                    )?,
                    None => write!(f, "{plan_str}, elapsed={elapsed:?}")?,
                },
                None => write!(f, "{plan_str}")?,
            }

            if let Some(metrics) = plan.metrics() {
                // Timestamps were already folded into `elapsed`; drop them
                // from the displayed metric list.
                let metrics = metrics
                    .aggregate_by_name()
                    .sorted_for_display()
                    .timestamps_removed();

                write!(f, ", metrics=[{metrics}]")?;
            } else {
                write!(f, ", metrics=[]")?;
            }
            writeln!(f)?;
            self.indent += 1;
            for child in plan.children() {
                self.write_output(child, f, calcs)?;
            }
            self.indent -= 1;
            std::fmt::Result::Ok(())
        }
    }
    // Adapter so the two visitors can run inside a `Display` implementation.
    struct PrintWrapper {
        plan: Arc<dyn ExecutionPlan>,
    }
    impl fmt::Display for PrintWrapper {
        fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
            let mut calcs = CalculateVisitor {
                highest_index: 0,
                index_to_elapsed: HashMap::new(),
            };
            calcs.calculate_metrics(&self.plan);
            let mut prints = PrintVisitor {
                highest_index: 0,
                indent: 0,
            };
            prints.write_output(&self.plan, f, &calcs)
        }
    }
    let wrapper = PrintWrapper { plan };
    format!("{}", wrapper)
}
793
/// Lance extensions to DataFusion's `SessionContext`.
pub trait SessionContextExt {
    /// Register a single-use stream as a streaming table and return a
    /// `DataFrame` reading from it; the stream can only be consumed once.
    fn read_one_shot(
        &self,
        data: SendableRecordBatchStream,
    ) -> datafusion::common::Result<DataFrame>;
}
803
/// A single-use `PartitionStream` wrapping an existing record batch stream.
pub struct OneShotPartitionStream {
    // `None` after the stream has been consumed by `execute`.
    data: Arc<Mutex<Option<SendableRecordBatchStream>>>,
    schema: Arc<ArrowSchema>,
}
808
809impl std::fmt::Debug for OneShotPartitionStream {
810 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
811 let data = self.data.lock().unwrap();
812 f.debug_struct("OneShotPartitionStream")
813 .field("exhausted", &data.is_none())
814 .field("schema", self.schema.as_ref())
815 .finish()
816 }
817}
818
819impl OneShotPartitionStream {
820 pub fn new(data: SendableRecordBatchStream) -> Self {
821 let schema = data.schema();
822 Self {
823 data: Arc::new(Mutex::new(Some(data))),
824 schema,
825 }
826 }
827}
828
829impl PartitionStream for OneShotPartitionStream {
830 fn schema(&self) -> &arrow_schema::SchemaRef {
831 &self.schema
832 }
833
834 fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
835 let mut stream = self.data.lock().unwrap();
836 stream
837 .take()
838 .expect("Attempt to consume a one shot dataframe multiple times")
839 }
840}
841
842impl SessionContextExt for SessionContext {
843 fn read_one_shot(
844 &self,
845 data: SendableRecordBatchStream,
846 ) -> datafusion::common::Result<DataFrame> {
847 let schema = data.schema();
848 let part_stream = Arc::new(OneShotPartitionStream::new(data));
849 let provider = StreamingTable::try_new(schema, vec![part_stream])?;
850 self.read_table(Arc::new(provider))
851 }
852}
853
/// Wraps a plan so its output is re-chunked to a fixed batch size via
/// `StrictBatchSizeStream` (see the `chunker` module for exact semantics).
#[derive(Clone, Debug)]
pub struct StrictBatchSizeExec {
    input: Arc<dyn ExecutionPlan>,
    // Target rows per emitted batch (enforced by StrictBatchSizeStream).
    batch_size: usize,
}
859
860impl StrictBatchSizeExec {
861 pub fn new(input: Arc<dyn ExecutionPlan>, batch_size: usize) -> Self {
862 Self { input, batch_size }
863 }
864}
865
866impl DisplayAs for StrictBatchSizeExec {
867 fn fmt_as(
868 &self,
869 _t: datafusion::physical_plan::DisplayFormatType,
870 f: &mut std::fmt::Formatter,
871 ) -> std::fmt::Result {
872 write!(f, "StrictBatchSizeExec")
873 }
874}
875
876impl ExecutionPlan for StrictBatchSizeExec {
877 fn name(&self) -> &str {
878 "StrictBatchSizeExec"
879 }
880
881 fn as_any(&self) -> &dyn std::any::Any {
882 self
883 }
884
885 fn properties(&self) -> &PlanProperties {
886 self.input.properties()
887 }
888
889 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
890 vec![&self.input]
891 }
892
893 fn with_new_children(
894 self: Arc<Self>,
895 children: Vec<Arc<dyn ExecutionPlan>>,
896 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
897 Ok(Arc::new(Self {
898 input: children[0].clone(),
899 batch_size: self.batch_size,
900 }))
901 }
902
903 fn execute(
904 &self,
905 partition: usize,
906 context: Arc<TaskContext>,
907 ) -> datafusion_common::Result<SendableRecordBatchStream> {
908 let stream = self.input.execute(partition, context)?;
909 let schema = stream.schema();
910 let stream = StrictBatchSizeStream::new(stream, self.batch_size);
911 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
912 }
913
914 fn maintains_input_order(&self) -> Vec<bool> {
915 vec![true]
916 }
917
918 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
919 vec![false]
920 }
921
922 fn partition_statistics(
923 &self,
924 partition: Option<usize>,
925 ) -> datafusion_common::Result<Statistics> {
926 self.input.partition_statistics(partition)
927 }
928
929 fn cardinality_effect(&self) -> CardinalityEffect {
930 CardinalityEffect::Equal
931 }
932
933 fn supports_limit_pushdown(&self) -> bool {
934 true
935 }
936}
937
#[cfg(test)]
mod tests {
    use super::*;

    // The session cache is process-global; serialize the tests that mutate it.
    static CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    #[test]
    fn test_session_context_cache() {
        let _lock = CACHE_TEST_LOCK.lock().unwrap();
        let cache = get_session_cache();

        // Start from a clean cache.
        cache.lock().unwrap().clear();

        let opts1 = LanceExecutionOptions::default();
        let _ctx1 = get_session_context(&opts1);

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 1);
        }

        // Identical options must hit the cached entry, not create a new one.
        let _ctx1_again = get_session_context(&opts1);
        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 1);
        }

        // Options resolving to a different key create a second entry.
        let opts2 = LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        };
        let _ctx2 = get_session_context(&opts2);
        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 2);
        }
    }

    #[test]
    fn test_session_context_cache_lru_eviction() {
        let _lock = CACHE_TEST_LOCK.lock().unwrap();
        let cache = get_session_cache();

        cache.lock().unwrap().clear();

        // Fill the cache up to the default capacity (4 entries), each with a
        // distinct mem_pool_size so the keys differ.
        let configs: Vec<LanceExecutionOptions> = (0..4)
            .map(|i| LanceExecutionOptions {
                mem_pool_size: Some((i + 1) as u64 * 1024 * 1024),
                ..Default::default()
            })
            .collect();

        for config in &configs {
            let _ctx = get_session_context(config);
        }

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 4);
        }

        // Sleep so the refresh below gets a strictly newer access timestamp.
        std::thread::sleep(std::time::Duration::from_millis(1));
        let _ctx = get_session_context(&configs[0]);

        // Inserting a fifth config must evict the least-recently-used entry.
        let opts5 = LanceExecutionOptions {
            mem_pool_size: Some(5 * 1024 * 1024),
            ..Default::default()
        };
        let _ctx5 = get_session_context(&opts5);

        {
            let cache_guard = cache.lock().unwrap();
            assert_eq!(cache_guard.len(), 4);

            // config[0] was refreshed above, so it must survive.
            let key0 = SessionContextCacheKey::from_options(&configs[0]);
            assert!(
                cache_guard.contains_key(&key0),
                "config[0] should still be cached after recent access"
            );

            // config[1] became the oldest entry once config[0] was refreshed.
            let key1 = SessionContextCacheKey::from_options(&configs[1]);
            assert!(
                !cache_guard.contains_key(&key1),
                "config[1] should have been evicted"
            );

            let key5 = SessionContextCacheKey::from_options(&opts5);
            assert!(
                cache_guard.contains_key(&key5),
                "new config should be cached"
            );
        }
    }
}
1044}