1use std::{
7 collections::HashMap,
8 fmt::{self, Formatter},
9 sync::{Arc, Mutex, OnceLock},
10 time::Duration,
11};
12
13use chrono::{DateTime, Utc};
14
15use arrow_array::RecordBatch;
16use arrow_schema::Schema as ArrowSchema;
17use datafusion::physical_plan::metrics::MetricType;
18use datafusion::{
19 catalog::streaming::StreamingTable,
20 dataframe::DataFrame,
21 execution::{
22 context::{SessionConfig, SessionContext},
23 disk_manager::DiskManagerBuilder,
24 memory_pool::FairSpillPool,
25 runtime_env::RuntimeEnvBuilder,
26 TaskContext,
27 },
28 physical_plan::{
29 analyze::AnalyzeExec,
30 display::DisplayableExecutionPlan,
31 execution_plan::{Boundedness, CardinalityEffect, EmissionType},
32 metrics::MetricValue,
33 stream::RecordBatchStreamAdapter,
34 streaming::PartitionStream,
35 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
36 },
37};
38use datafusion_common::{DataFusionError, Statistics};
39use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
40
41use futures::{stream, StreamExt};
42use lance_arrow::SchemaExt;
43use lance_core::{
44 utils::{
45 futures::FinallyStreamExt,
46 tracing::{StreamTracingExt, EXECUTION_PLAN_RUN, TRACE_EXECUTION},
47 },
48 Error, Result,
49};
50use log::{debug, info, warn};
51use snafu::location;
52use tracing::Span;
53
54use crate::udf::register_functions;
55use crate::{
56 chunker::StrictBatchSizeStream,
57 utils::{
58 MetricsExt, BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC,
59 IOPS_METRIC, PARTS_LOADED_METRIC, REQUESTS_METRIC,
60 },
61};
62
/// A leaf [`ExecutionPlan`] that yields a pre-existing record batch stream exactly once.
///
/// `execute` only receives `&self`, so the stream is kept in a `Mutex<Option<_>>` and moved
/// out on the first call; afterwards the slot is `None` and further calls return an error.
pub struct OneShotExec {
    /// The wrapped stream; `None` once it has been handed out by `execute`.
    stream: Mutex<Option<SendableRecordBatchStream>>,
    /// Schema of the wrapped stream, captured at construction so it stays available after
    /// the stream has been consumed.
    schema: Arc<ArrowSchema>,
    /// Plan properties: one output partition, incremental emission, bounded input.
    properties: PlanProperties,
}
77
78impl OneShotExec {
79 pub fn new(stream: SendableRecordBatchStream) -> Self {
81 let schema = stream.schema();
82 Self {
83 stream: Mutex::new(Some(stream)),
84 schema: schema.clone(),
85 properties: PlanProperties::new(
86 EquivalenceProperties::new(schema),
87 Partitioning::RoundRobinBatch(1),
88 EmissionType::Incremental,
89 Boundedness::Bounded,
90 ),
91 }
92 }
93
94 pub fn from_batch(batch: RecordBatch) -> Self {
95 let schema = batch.schema();
96 let stream = Box::pin(RecordBatchStreamAdapter::new(
97 schema,
98 stream::iter(vec![Ok(batch)]),
99 ));
100 Self::new(stream)
101 }
102}
103
104impl std::fmt::Debug for OneShotExec {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 let stream = self.stream.lock().unwrap();
107 f.debug_struct("OneShotExec")
108 .field("exhausted", &stream.is_none())
109 .field("schema", self.schema.as_ref())
110 .finish()
111 }
112}
113
114impl DisplayAs for OneShotExec {
115 fn fmt_as(
116 &self,
117 t: datafusion::physical_plan::DisplayFormatType,
118 f: &mut std::fmt::Formatter,
119 ) -> std::fmt::Result {
120 let stream = self.stream.lock().unwrap();
121 let exhausted = if stream.is_some() { "" } else { "EXHAUSTED" };
122 let columns = self
123 .schema
124 .field_names()
125 .iter()
126 .cloned()
127 .cloned()
128 .collect::<Vec<_>>();
129 match t {
130 DisplayFormatType::Default | DisplayFormatType::Verbose => {
131 write!(
132 f,
133 "OneShotStream: {}columns=[{}]",
134 exhausted,
135 columns.join(",")
136 )
137 }
138 DisplayFormatType::TreeRender => {
139 write!(
140 f,
141 "OneShotStream\nexhausted={}\ncolumns=[{}]",
142 exhausted,
143 columns.join(",")
144 )
145 }
146 }
147 }
148}
149
150impl ExecutionPlan for OneShotExec {
151 fn name(&self) -> &str {
152 "OneShotExec"
153 }
154
155 fn as_any(&self) -> &dyn std::any::Any {
156 self
157 }
158
159 fn schema(&self) -> arrow_schema::SchemaRef {
160 self.schema.clone()
161 }
162
163 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
164 vec![]
165 }
166
167 fn with_new_children(
168 self: Arc<Self>,
169 children: Vec<Arc<dyn ExecutionPlan>>,
170 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
171 if !children.is_empty() {
173 return Err(datafusion_common::DataFusionError::Internal(
174 "OneShotExec does not support children".to_string(),
175 ));
176 }
177 Ok(self)
178 }
179
180 fn execute(
181 &self,
182 _partition: usize,
183 _context: Arc<datafusion::execution::TaskContext>,
184 ) -> datafusion_common::Result<SendableRecordBatchStream> {
185 let stream = self
186 .stream
187 .lock()
188 .map_err(|err| DataFusionError::Execution(err.to_string()))?
189 .take();
190 if let Some(stream) = stream {
191 Ok(stream)
192 } else {
193 Err(DataFusionError::Execution(
194 "OneShotExec has already been executed".to_string(),
195 ))
196 }
197 }
198
199 fn statistics(&self) -> datafusion_common::Result<datafusion_common::Statistics> {
200 Ok(Statistics::new_unknown(&self.schema))
201 }
202
203 fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
204 &self.properties
205 }
206}
207
208struct TracedExec {
209 input: Arc<dyn ExecutionPlan>,
210 properties: PlanProperties,
211 span: Span,
212}
213
214impl TracedExec {
215 pub fn new(input: Arc<dyn ExecutionPlan>, span: Span) -> Self {
216 Self {
217 properties: input.properties().clone(),
218 input,
219 span,
220 }
221 }
222}
223
224impl DisplayAs for TracedExec {
225 fn fmt_as(
226 &self,
227 t: datafusion::physical_plan::DisplayFormatType,
228 f: &mut std::fmt::Formatter,
229 ) -> std::fmt::Result {
230 match t {
231 DisplayFormatType::Default
232 | DisplayFormatType::Verbose
233 | DisplayFormatType::TreeRender => {
234 write!(f, "TracedExec")
235 }
236 }
237 }
238}
239
240impl std::fmt::Debug for TracedExec {
241 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
242 write!(f, "TracedExec")
243 }
244}
245impl ExecutionPlan for TracedExec {
246 fn name(&self) -> &str {
247 "TracedExec"
248 }
249
250 fn as_any(&self) -> &dyn std::any::Any {
251 self
252 }
253
254 fn properties(&self) -> &PlanProperties {
255 &self.properties
256 }
257
258 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
259 vec![&self.input]
260 }
261
262 fn with_new_children(
263 self: Arc<Self>,
264 children: Vec<Arc<dyn ExecutionPlan>>,
265 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
266 Ok(Arc::new(Self {
267 input: children[0].clone(),
268 properties: self.properties.clone(),
269 span: self.span.clone(),
270 }))
271 }
272
273 fn execute(
274 &self,
275 partition: usize,
276 context: Arc<TaskContext>,
277 ) -> datafusion_common::Result<SendableRecordBatchStream> {
278 let _guard = self.span.enter();
279 let stream = self.input.execute(partition, context)?;
280 let schema = stream.schema();
281 let stream = stream.stream_in_span(self.span.clone());
282 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
283 }
284}
285
/// Callback that receives the aggregated [`ExecutionSummaryCounts`] of a finished plan.
pub type ExecutionStatsCallback = Arc<dyn Fn(&ExecutionSummaryCounts) + Send + Sync>;

/// Options controlling how Lance executes DataFusion plans.
#[derive(Default, Clone)]
pub struct LanceExecutionOptions {
    /// Request a spill-capable runtime (disk manager + `FairSpillPool`). The effective value
    /// is forced to `false` when `LANCE_BYPASS_SPILLING` is set; see `use_spilling()`.
    pub use_spilling: bool,
    /// Spill memory pool size, passed to `FairSpillPool::new` (bytes). `None` falls back to
    /// the `LANCE_MEM_POOL_SIZE` environment variable, then to 100 MiB.
    pub mem_pool_size: Option<u64>,
    /// Cap on spill temp-directory usage, passed to the disk manager builder (bytes). `None`
    /// falls back to `LANCE_MAX_TEMP_DIRECTORY_SIZE`, then to 100 GiB.
    pub max_temp_directory_size: Option<u64>,
    /// Overrides DataFusion's `execution.batch_size` for the task context, if set.
    pub batch_size: Option<usize>,
    /// Overrides the session's target partition count, if set.
    pub target_partition: Option<usize>,
    /// Invoked with summary counts once a plan run through `execute_plan` finishes.
    /// NOTE: not invoked when `skip_logging` is true.
    pub execution_stats_callback: Option<ExecutionStatsCallback>,
    /// Suppresses the debug plan dump and the post-execution metrics report (including the
    /// stats callback) in `execute_plan`.
    pub skip_logging: bool,
}
299
300impl std::fmt::Debug for LanceExecutionOptions {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("LanceExecutionOptions")
303 .field("use_spilling", &self.use_spilling)
304 .field("mem_pool_size", &self.mem_pool_size)
305 .field("max_temp_directory_size", &self.max_temp_directory_size)
306 .field("batch_size", &self.batch_size)
307 .field("target_partition", &self.target_partition)
308 .field("skip_logging", &self.skip_logging)
309 .field(
310 "execution_stats_callback",
311 &self.execution_stats_callback.is_some(),
312 )
313 .finish()
314 }
315}
316
317const DEFAULT_LANCE_MEM_POOL_SIZE: u64 = 100 * 1024 * 1024;
318const DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; impl LanceExecutionOptions {
321 pub fn mem_pool_size(&self) -> u64 {
322 self.mem_pool_size.unwrap_or_else(|| {
323 std::env::var("LANCE_MEM_POOL_SIZE")
324 .map(|s| match s.parse::<u64>() {
325 Ok(v) => v,
326 Err(e) => {
327 warn!("Failed to parse LANCE_MEM_POOL_SIZE: {}, using default", e);
328 DEFAULT_LANCE_MEM_POOL_SIZE
329 }
330 })
331 .unwrap_or(DEFAULT_LANCE_MEM_POOL_SIZE)
332 })
333 }
334
335 pub fn max_temp_directory_size(&self) -> u64 {
336 self.max_temp_directory_size.unwrap_or_else(|| {
337 std::env::var("LANCE_MAX_TEMP_DIRECTORY_SIZE")
338 .map(|s| match s.parse::<u64>() {
339 Ok(v) => v,
340 Err(e) => {
341 warn!(
342 "Failed to parse LANCE_MAX_TEMP_DIRECTORY_SIZE: {}, using default",
343 e
344 );
345 DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE
346 }
347 })
348 .unwrap_or(DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE)
349 })
350 }
351
352 pub fn use_spilling(&self) -> bool {
353 if !self.use_spilling {
354 return false;
355 }
356 std::env::var("LANCE_BYPASS_SPILLING")
357 .map(|_| {
358 info!("Bypassing spilling because LANCE_BYPASS_SPILLING is set");
359 false
360 })
361 .unwrap_or(true)
362 }
363}
364
365pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext {
366 let mut session_config = SessionConfig::new();
367 let mut runtime_env_builder = RuntimeEnvBuilder::new();
368 if let Some(target_partition) = options.target_partition {
369 session_config = session_config.with_target_partitions(target_partition);
370 }
371 if options.use_spilling() {
372 let disk_manager_builder = DiskManagerBuilder::default()
373 .with_max_temp_directory_size(options.max_temp_directory_size());
374 runtime_env_builder = runtime_env_builder
375 .with_disk_manager_builder(disk_manager_builder)
376 .with_memory_pool(Arc::new(FairSpillPool::new(
377 options.mem_pool_size() as usize
378 )));
379 }
380 let runtime_env = runtime_env_builder.build_arc().unwrap();
381
382 let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);
383 register_functions(&ctx);
384
385 ctx
386}
387
388#[derive(Clone, Debug, PartialEq, Eq, Hash)]
390struct SessionContextCacheKey {
391 mem_pool_size: u64,
392 max_temp_directory_size: u64,
393 target_partition: Option<usize>,
394 use_spilling: bool,
395}
396
397impl SessionContextCacheKey {
398 fn from_options(options: &LanceExecutionOptions) -> Self {
399 Self {
400 mem_pool_size: options.mem_pool_size(),
401 max_temp_directory_size: options.max_temp_directory_size(),
402 target_partition: options.target_partition,
403 use_spilling: options.use_spilling(),
404 }
405 }
406}
407
/// A cached [`SessionContext`] plus the last time it was returned, used for LRU eviction.
struct CachedSessionContext {
    context: SessionContext,
    /// Updated on every cache hit and set on insertion.
    last_access: std::time::Instant,
}
412
413fn get_session_cache() -> &'static Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>> {
414 static SESSION_CACHE: OnceLock<Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>>> =
415 OnceLock::new();
416 SESSION_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
417}
418
/// Maximum number of cached session contexts.
///
/// Read once from `LANCE_SESSION_CACHE_SIZE`; an unset or unparsable value yields 4. The
/// result is memoized for the life of the process.
fn get_max_cache_size() -> usize {
    const DEFAULT_CACHE_SIZE: usize = 4;
    static MAX_CACHE_SIZE: OnceLock<usize> = OnceLock::new();
    let size = MAX_CACHE_SIZE.get_or_init(|| match std::env::var("LANCE_SESSION_CACHE_SIZE") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_CACHE_SIZE),
        Err(_) => DEFAULT_CACHE_SIZE,
    });
    *size
}
429
430pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext {
431 let key = SessionContextCacheKey::from_options(options);
432 let mut cache = get_session_cache()
433 .lock()
434 .unwrap_or_else(|e| e.into_inner());
435
436 if let Some(entry) = cache.get_mut(&key) {
438 entry.last_access = std::time::Instant::now();
439 return entry.context.clone();
440 }
441
442 if cache.len() >= get_max_cache_size() {
444 if let Some(lru_key) = cache
445 .iter()
446 .min_by_key(|(_, v)| v.last_access)
447 .map(|(k, _)| k.clone())
448 {
449 cache.remove(&lru_key);
450 }
451 }
452
453 let context = new_session_context(options);
454 cache.insert(
455 key,
456 CachedSessionContext {
457 context: context.clone(),
458 last_access: std::time::Instant::now(),
459 },
460 );
461 context
462}
463
464fn get_task_context(
465 session_ctx: &SessionContext,
466 options: &LanceExecutionOptions,
467) -> Arc<TaskContext> {
468 let mut state = session_ctx.state();
469 if let Some(batch_size) = options.batch_size.as_ref() {
470 state.config_mut().options_mut().execution.batch_size = *batch_size;
471 }
472
473 state.task_ctx()
474}
475
/// Metric totals summed over every node of an executed plan.
///
/// Built by walking the plan tree after execution. It is also handed to the
/// `execution_stats_callback` execution option when that is set and logging is not skipped.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ExecutionSummaryCounts {
    /// Sum of all `IOPS_METRIC` counters and gauges.
    pub iops: usize,
    /// Sum of all `REQUESTS_METRIC` counters and gauges.
    pub requests: usize,
    /// Sum of all `BYTES_READ_METRIC` counters and gauges.
    pub bytes_read: usize,
    /// Sum of all `INDICES_LOADED_METRIC` counters.
    pub indices_loaded: usize,
    /// Sum of all `PARTS_LOADED_METRIC` counters.
    pub parts_loaded: usize,
    /// Sum of all `INDEX_COMPARISONS_METRIC` counters.
    pub index_comparisons: usize,
    /// Sums of every other counter, keyed by metric name (gauges are not included here).
    pub all_counts: HashMap<String, usize>,
}
495
496fn visit_node(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) {
497 if let Some(metrics) = node.metrics() {
498 for (metric_name, count) in metrics.iter_counts() {
499 match metric_name.as_ref() {
500 IOPS_METRIC => counts.iops += count.value(),
501 REQUESTS_METRIC => counts.requests += count.value(),
502 BYTES_READ_METRIC => counts.bytes_read += count.value(),
503 INDICES_LOADED_METRIC => counts.indices_loaded += count.value(),
504 PARTS_LOADED_METRIC => counts.parts_loaded += count.value(),
505 INDEX_COMPARISONS_METRIC => counts.index_comparisons += count.value(),
506 _ => {
507 let existing = counts
508 .all_counts
509 .entry(metric_name.as_ref().to_string())
510 .or_insert(0);
511 *existing += count.value();
512 }
513 }
514 }
515 for (metric_name, gauge) in metrics.iter_gauges() {
517 match metric_name.as_ref() {
518 IOPS_METRIC => counts.iops += gauge.value(),
519 REQUESTS_METRIC => counts.requests += gauge.value(),
520 BYTES_READ_METRIC => counts.bytes_read += gauge.value(),
521 _ => {}
522 }
523 }
524 }
525 for child in node.children() {
526 visit_node(child.as_ref(), counts);
527 }
528}
529
530fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) {
531 let output_rows = plan
532 .metrics()
533 .map(|m| m.output_rows().unwrap_or(0))
534 .unwrap_or(0);
535 let mut counts = ExecutionSummaryCounts::default();
536 visit_node(plan, &mut counts);
537 tracing::info!(
538 target: TRACE_EXECUTION,
539 r#type = EXECUTION_PLAN_RUN,
540 plan_summary = display_plan_one_liner(plan),
541 output_rows,
542 iops = counts.iops,
543 requests = counts.requests,
544 bytes_read = counts.bytes_read,
545 indices_loaded = counts.indices_loaded,
546 parts_loaded = counts.parts_loaded,
547 index_comparisons = counts.index_comparisons,
548 );
549 if let Some(callback) = options.execution_stats_callback.as_ref() {
550 callback(&counts);
551 }
552}
553
554fn display_plan_one_liner(plan: &dyn ExecutionPlan) -> String {
561 let mut output = String::new();
562
563 display_plan_one_liner_impl(plan, &mut output);
564
565 output
566}
567
568fn display_plan_one_liner_impl(plan: &dyn ExecutionPlan, output: &mut String) {
569 let name = plan.name().trim_end_matches("Exec");
571 output.push_str(name);
572
573 let children = plan.children();
574 if !children.is_empty() {
575 output.push('(');
576 for (i, child) in children.iter().enumerate() {
577 if i > 0 {
578 output.push(',');
579 }
580 display_plan_one_liner_impl(child.as_ref(), output);
581 }
582 output.push(')');
583 }
584}
585
586pub fn execute_plan(
590 plan: Arc<dyn ExecutionPlan>,
591 options: LanceExecutionOptions,
592) -> Result<SendableRecordBatchStream> {
593 if !options.skip_logging {
594 debug!(
595 "Executing plan:\n{}",
596 DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
597 );
598 }
599
600 let session_ctx = get_session_context(&options);
601
602 assert_eq!(plan.properties().partitioning.partition_count(), 1);
605 let stream = plan.execute(0, get_task_context(&session_ctx, &options))?;
606
607 let schema = stream.schema();
608 let stream = stream.finally(move || {
609 if !options.skip_logging {
610 report_plan_summary_metrics(plan.as_ref(), &options);
611 }
612 });
613 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
614}
615
616pub async fn analyze_plan(
617 plan: Arc<dyn ExecutionPlan>,
618 options: LanceExecutionOptions,
619) -> Result<String> {
620 let plan = Arc::new(TracedExec::new(plan, Span::current()));
623
624 let schema = plan.schema();
625 let analyze = Arc::new(AnalyzeExec::new(
627 true,
628 true,
629 vec![MetricType::SUMMARY],
630 plan,
631 schema,
632 ));
633
634 let session_ctx = get_session_context(&options);
635 assert_eq!(analyze.properties().partitioning.partition_count(), 1);
636 let mut stream = analyze
637 .execute(0, get_task_context(&session_ctx, &options))
638 .map_err(|err| {
639 Error::io(
640 format!("Failed to execute analyze plan: {}", err),
641 location!(),
642 )
643 })?;
644
645 while (stream.next().await).is_some() {}
647
648 let result = format_plan(analyze);
649 Ok(result)
650}
651
/// Renders an (already executed) plan tree as indented text, one node per line.
///
/// Each line holds the node's one-line display, an `elapsed=` duration when timestamps are
/// available, and the node's metrics (aggregated by name, sorted, timestamps removed).
/// A node's elapsed time runs from the earliest start timestamp to the latest end timestamp
/// found anywhere in its subtree. Children are indented two spaces deeper than their parent.
pub fn format_plan(plan: Arc<dyn ExecutionPlan>) -> String {
    /// First pass: computes each node's subtree elapsed time, keyed by pre-order index.
    struct CalculateVisitor {
        /// Pre-order index of the most recently visited node (first node is 1).
        highest_index: usize,
        /// Elapsed time per node index; absent when the subtree had no start/end timestamps.
        index_to_elapsed: HashMap<usize, Duration>,
    }

    /// Earliest start and latest end timestamps seen within a subtree.
    struct SubtreeMetrics {
        min_start: Option<DateTime<Utc>>,
        max_end: Option<DateTime<Utc>>,
    }

    impl CalculateVisitor {
        fn calculate_metrics(&mut self, plan: &Arc<dyn ExecutionPlan>) -> SubtreeMetrics {
            // Number this node before recursing so indices are assigned in pre-order,
            // matching the order in which `PrintVisitor` walks the tree.
            self.highest_index += 1;
            let plan_index = self.highest_index;

            let (mut min_start, mut max_end) = Self::node_timerange(plan);

            // Widen this node's time range to cover every descendant.
            for child in plan.children() {
                let child_metrics = self.calculate_metrics(child);
                min_start = Self::min_option(min_start, child_metrics.min_start);
                max_end = Self::max_option(max_end, child_metrics.max_end);
            }

            // A negative span (end before start) cannot convert to `std::time::Duration`;
            // it falls back to zero via `unwrap_or_default`.
            let elapsed = match (min_start, max_end) {
                (Some(start), Some(end)) => Some((end - start).to_std().unwrap_or_default()),
                _ => None,
            };

            if let Some(e) = elapsed {
                self.index_to_elapsed.insert(plan_index, e);
            }

            SubtreeMetrics { min_start, max_end }
        }

        /// Earliest `StartTimestamp` and latest `EndTimestamp` among this node's own metrics.
        fn node_timerange(
            plan: &Arc<dyn ExecutionPlan>,
        ) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>) {
            let Some(metrics) = plan.metrics() else {
                return (None, None);
            };
            let min_start = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::StartTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .min();
            let max_end = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::EndTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .max();
            (min_start, max_end)
        }

        /// Minimum of the present values; `None` only if both are `None`.
        fn min_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().min()
        }

        /// Maximum of the present values; `None` only if both are `None`.
        fn max_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().max()
        }
    }

    /// Second pass: writes nodes in pre-order, two spaces of indentation per depth level.
    struct PrintVisitor {
        /// Pre-order index of the node being written; must track `CalculateVisitor`'s numbering.
        highest_index: usize,
        /// Current depth in the tree.
        indent: usize,
    }
    impl PrintVisitor {
        fn write_output(
            &mut self,
            plan: &Arc<dyn ExecutionPlan>,
            f: &mut Formatter,
            calcs: &CalculateVisitor,
        ) -> std::fmt::Result {
            self.highest_index += 1;
            write!(f, "{:indent$}", "", indent = self.indent * 2)?;

            let displayable =
                datafusion::physical_plan::display::DisplayableExecutionPlan::new(plan.as_ref());
            let plan_str = displayable.one_line().to_string();
            let plan_str = plan_str.trim();

            // Put `elapsed=` right after the node name (the text before the first ": ") so it
            // precedes the node's own details. Nodes without ": " get it appended instead.
            match calcs.index_to_elapsed.get(&self.highest_index) {
                Some(elapsed) => match plan_str.find(": ") {
                    Some(i) => write!(
                        f,
                        "{}: elapsed={elapsed:?}, {}",
                        &plan_str[..i],
                        &plan_str[i + 2..]
                    )?,
                    None => write!(f, "{plan_str}, elapsed={elapsed:?}")?,
                },
                None => write!(f, "{plan_str}")?,
            }

            if let Some(metrics) = plan.metrics() {
                let metrics = metrics
                    .aggregate_by_name()
                    .sorted_for_display()
                    .timestamps_removed();

                write!(f, ", metrics=[{metrics}]")?;
            } else {
                write!(f, ", metrics=[]")?;
            }
            writeln!(f)?;
            self.indent += 1;
            for child in plan.children() {
                self.write_output(child, f, calcs)?;
            }
            self.indent -= 1;
            std::fmt::Result::Ok(())
        }
    }
    /// Adapter that lets both passes run inside `Display::fmt` and write into its `Formatter`.
    struct PrintWrapper {
        plan: Arc<dyn ExecutionPlan>,
    }
    impl fmt::Display for PrintWrapper {
        fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
            let mut calcs = CalculateVisitor {
                highest_index: 0,
                index_to_elapsed: HashMap::new(),
            };
            calcs.calculate_metrics(&self.plan);
            let mut prints = PrintVisitor {
                highest_index: 0,
                indent: 0,
            };
            prints.write_output(&self.plan, f, &calcs)
        }
    }
    let wrapper = PrintWrapper { plan };
    format!("{}", wrapper)
}
800
/// Extension methods for [`SessionContext`].
pub trait SessionContextExt {
    /// Creates a [`DataFrame`] backed by `data`.
    ///
    /// The stream can only be taken once, so the resulting DataFrame supports a single
    /// execution.
    fn read_one_shot(
        &self,
        data: SendableRecordBatchStream,
    ) -> datafusion::common::Result<DataFrame>;
}
810
/// A [`PartitionStream`] that hands out a single pre-existing record batch stream once.
pub struct OneShotPartitionStream {
    /// The wrapped stream; `None` after it has been taken by `execute`.
    data: Arc<Mutex<Option<SendableRecordBatchStream>>>,
    /// Schema of the wrapped stream, captured up front.
    schema: Arc<ArrowSchema>,
}
815
816impl std::fmt::Debug for OneShotPartitionStream {
817 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
818 let data = self.data.lock().unwrap();
819 f.debug_struct("OneShotPartitionStream")
820 .field("exhausted", &data.is_none())
821 .field("schema", self.schema.as_ref())
822 .finish()
823 }
824}
825
826impl OneShotPartitionStream {
827 pub fn new(data: SendableRecordBatchStream) -> Self {
828 let schema = data.schema();
829 Self {
830 data: Arc::new(Mutex::new(Some(data))),
831 schema,
832 }
833 }
834}
835
836impl PartitionStream for OneShotPartitionStream {
837 fn schema(&self) -> &arrow_schema::SchemaRef {
838 &self.schema
839 }
840
841 fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
842 let mut stream = self.data.lock().unwrap();
843 stream
844 .take()
845 .expect("Attempt to consume a one shot dataframe multiple times")
846 }
847}
848
849impl SessionContextExt for SessionContext {
850 fn read_one_shot(
851 &self,
852 data: SendableRecordBatchStream,
853 ) -> datafusion::common::Result<DataFrame> {
854 let schema = data.schema();
855 let part_stream = Arc::new(OneShotPartitionStream::new(data));
856 let provider = StreamingTable::try_new(schema, vec![part_stream])?;
857 self.read_table(Arc::new(provider))
858 }
859}
860
861#[derive(Clone, Debug)]
862pub struct StrictBatchSizeExec {
863 input: Arc<dyn ExecutionPlan>,
864 batch_size: usize,
865}
866
867impl StrictBatchSizeExec {
868 pub fn new(input: Arc<dyn ExecutionPlan>, batch_size: usize) -> Self {
869 Self { input, batch_size }
870 }
871}
872
873impl DisplayAs for StrictBatchSizeExec {
874 fn fmt_as(
875 &self,
876 _t: datafusion::physical_plan::DisplayFormatType,
877 f: &mut std::fmt::Formatter,
878 ) -> std::fmt::Result {
879 write!(f, "StrictBatchSizeExec")
880 }
881}
882
883impl ExecutionPlan for StrictBatchSizeExec {
884 fn name(&self) -> &str {
885 "StrictBatchSizeExec"
886 }
887
888 fn as_any(&self) -> &dyn std::any::Any {
889 self
890 }
891
892 fn properties(&self) -> &PlanProperties {
893 self.input.properties()
894 }
895
896 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
897 vec![&self.input]
898 }
899
900 fn with_new_children(
901 self: Arc<Self>,
902 children: Vec<Arc<dyn ExecutionPlan>>,
903 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
904 Ok(Arc::new(Self {
905 input: children[0].clone(),
906 batch_size: self.batch_size,
907 }))
908 }
909
910 fn execute(
911 &self,
912 partition: usize,
913 context: Arc<TaskContext>,
914 ) -> datafusion_common::Result<SendableRecordBatchStream> {
915 let stream = self.input.execute(partition, context)?;
916 let schema = stream.schema();
917 let stream = StrictBatchSizeStream::new(stream, self.batch_size);
918 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
919 }
920
921 fn maintains_input_order(&self) -> Vec<bool> {
922 vec![true]
923 }
924
925 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
926 vec![false]
927 }
928
929 fn partition_statistics(
930 &self,
931 partition: Option<usize>,
932 ) -> datafusion_common::Result<Statistics> {
933 self.input.partition_statistics(partition)
934 }
935
936 fn cardinality_effect(&self) -> CardinalityEffect {
937 CardinalityEffect::Equal
938 }
939
940 fn supports_limit_pushdown(&self) -> bool {
941 true
942 }
943}
944
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that mutate the process-global session cache.
    static CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `CACHE_TEST_LOCK`, recovering from poisoning so one failing test does not
    /// make the other cache tests fail with an unrelated `PoisonError`.
    fn lock_cache_tests() -> std::sync::MutexGuard<'static, ()> {
        CACHE_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Locks the global session cache, recovering from poisoning left by a failed assertion.
    fn lock_session_cache() -> std::sync::MutexGuard<
        'static,
        std::collections::HashMap<SessionContextCacheKey, CachedSessionContext>,
    > {
        get_session_cache()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_session_context_cache() {
        let _lock = lock_cache_tests();
        lock_session_cache().clear();

        let opts1 = LanceExecutionOptions::default();
        let _ctx1 = get_session_context(&opts1);
        assert_eq!(lock_session_cache().len(), 1);

        // Same options again: the cached context is reused.
        let _ctx1_again = get_session_context(&opts1);
        assert_eq!(lock_session_cache().len(), 1);

        // Different effective options: a separate entry is created.
        let opts2 = LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        };
        let _ctx2 = get_session_context(&opts2);
        assert_eq!(lock_session_cache().len(), 2);
    }

    #[test]
    fn test_session_context_cache_lru_eviction() {
        let _lock = lock_cache_tests();
        lock_session_cache().clear();

        // Honour LANCE_SESSION_CACHE_SIZE instead of assuming the default capacity of 4.
        let capacity = get_max_cache_size();
        if capacity < 2 {
            eprintln!("skipping LRU eviction test: cache capacity {capacity} is too small");
            return;
        }

        let configs: Vec<LanceExecutionOptions> = (0..capacity)
            .map(|i| LanceExecutionOptions {
                mem_pool_size: Some((i + 1) as u64 * 1024 * 1024),
                ..Default::default()
            })
            .collect();

        for config in &configs {
            let _ctx = get_session_context(config);
            // Space insertions out so every entry gets a distinct `last_access`. Otherwise ties
            // on a coarse clock would make the evicted entry depend on HashMap iteration order.
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
        assert_eq!(lock_session_cache().len(), capacity);

        // Touch config[0] so config[1] becomes the least recently used entry.
        let _ctx = get_session_context(&configs[0]);

        let opts_new = LanceExecutionOptions {
            mem_pool_size: Some((capacity + 1) as u64 * 1024 * 1024),
            ..Default::default()
        };
        let _ctx_new = get_session_context(&opts_new);

        let cache_guard = lock_session_cache();
        assert_eq!(cache_guard.len(), capacity);

        let key0 = SessionContextCacheKey::from_options(&configs[0]);
        assert!(
            cache_guard.contains_key(&key0),
            "config[0] should still be cached after recent access"
        );

        let key1 = SessionContextCacheKey::from_options(&configs[1]);
        assert!(
            !cache_guard.contains_key(&key1),
            "config[1] should have been evicted"
        );

        let key_new = SessionContextCacheKey::from_options(&opts_new);
        assert!(
            cache_guard.contains_key(&key_new),
            "new config should be cached"
        );
    }
}