1use std::{
7 collections::HashMap,
8 fmt::{self, Formatter},
9 sync::{Arc, Mutex, OnceLock},
10 time::Duration,
11};
12
13use chrono::{DateTime, Utc};
14
15use arrow_array::RecordBatch;
16use arrow_schema::Schema as ArrowSchema;
17use datafusion::physical_plan::metrics::MetricType;
18use datafusion::{
19 catalog::streaming::StreamingTable,
20 dataframe::DataFrame,
21 execution::{
22 TaskContext,
23 context::{SessionConfig, SessionContext},
24 disk_manager::DiskManagerBuilder,
25 memory_pool::FairSpillPool,
26 runtime_env::RuntimeEnvBuilder,
27 },
28 physical_plan::{
29 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
30 analyze::AnalyzeExec,
31 coalesce_partitions::CoalescePartitionsExec,
32 display::DisplayableExecutionPlan,
33 execution_plan::{Boundedness, CardinalityEffect, EmissionType},
34 metrics::MetricValue,
35 stream::RecordBatchStreamAdapter,
36 streaming::PartitionStream,
37 },
38};
39use datafusion_common::{DataFusionError, Statistics};
40use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
41
42use futures::{StreamExt, stream};
43use lance_arrow::SchemaExt;
44use lance_core::{
45 Error, Result,
46 utils::{
47 futures::FinallyStreamExt,
48 tracing::{EXECUTION_PLAN_RUN, StreamTracingExt, TRACE_EXECUTION},
49 },
50};
51use log::{debug, info, warn};
52use tracing::Span;
53
54use crate::udf::register_functions;
55use crate::{
56 chunker::StrictBatchSizeStream,
57 utils::{
58 BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC, IOPS_METRIC,
59 MetricsExt, PARTS_LOADED_METRIC, REQUESTS_METRIC,
60 },
61};
62
63pub struct OneShotExec {
71 stream: Mutex<Option<SendableRecordBatchStream>>,
72 schema: Arc<ArrowSchema>,
75 properties: Arc<PlanProperties>,
76}
77
78impl OneShotExec {
79 pub fn new(stream: SendableRecordBatchStream) -> Self {
81 let schema = stream.schema();
82 Self {
83 stream: Mutex::new(Some(stream)),
84 schema: schema.clone(),
85 properties: Arc::new(PlanProperties::new(
86 EquivalenceProperties::new(schema),
87 Partitioning::RoundRobinBatch(1),
88 EmissionType::Incremental,
89 Boundedness::Bounded,
90 )),
91 }
92 }
93
94 pub fn from_batch(batch: RecordBatch) -> Self {
95 let schema = batch.schema();
96 let stream = Box::pin(RecordBatchStreamAdapter::new(
97 schema,
98 stream::iter(vec![Ok(batch)]),
99 ));
100 Self::new(stream)
101 }
102}
103
104impl std::fmt::Debug for OneShotExec {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 let stream = self.stream.lock().unwrap();
107 f.debug_struct("OneShotExec")
108 .field("exhausted", &stream.is_none())
109 .field("schema", self.schema.as_ref())
110 .finish()
111 }
112}
113
114impl DisplayAs for OneShotExec {
115 fn fmt_as(
116 &self,
117 t: datafusion::physical_plan::DisplayFormatType,
118 f: &mut std::fmt::Formatter,
119 ) -> std::fmt::Result {
120 let stream = self.stream.lock().unwrap();
121 let exhausted = if stream.is_some() { "" } else { "EXHAUSTED" };
122 let columns = self
123 .schema
124 .field_names()
125 .iter()
126 .cloned()
127 .cloned()
128 .collect::<Vec<_>>();
129 match t {
130 DisplayFormatType::Default | DisplayFormatType::Verbose => {
131 write!(
132 f,
133 "OneShotStream: {}columns=[{}]",
134 exhausted,
135 columns.join(",")
136 )
137 }
138 DisplayFormatType::TreeRender => {
139 write!(
140 f,
141 "OneShotStream\nexhausted={}\ncolumns=[{}]",
142 exhausted,
143 columns.join(",")
144 )
145 }
146 }
147 }
148}
149
150impl ExecutionPlan for OneShotExec {
151 fn name(&self) -> &str {
152 "OneShotExec"
153 }
154
155 fn as_any(&self) -> &dyn std::any::Any {
156 self
157 }
158
159 fn schema(&self) -> arrow_schema::SchemaRef {
160 self.schema.clone()
161 }
162
163 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
164 vec![]
165 }
166
167 fn with_new_children(
168 self: Arc<Self>,
169 children: Vec<Arc<dyn ExecutionPlan>>,
170 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
171 if !children.is_empty() {
173 return Err(datafusion_common::DataFusionError::Internal(
174 "OneShotExec does not support children".to_string(),
175 ));
176 }
177 Ok(self)
178 }
179
180 fn execute(
181 &self,
182 _partition: usize,
183 _context: Arc<datafusion::execution::TaskContext>,
184 ) -> datafusion_common::Result<SendableRecordBatchStream> {
185 let stream = self
186 .stream
187 .lock()
188 .map_err(|err| DataFusionError::Execution(err.to_string()))?
189 .take();
190 if let Some(stream) = stream {
191 Ok(stream)
192 } else {
193 Err(DataFusionError::Execution(
194 "OneShotExec has already been executed".to_string(),
195 ))
196 }
197 }
198
199 fn properties(&self) -> &Arc<datafusion::physical_plan::PlanProperties> {
200 &self.properties
201 }
202}
203
204struct TracedExec {
205 input: Arc<dyn ExecutionPlan>,
206 properties: Arc<PlanProperties>,
207 span: Span,
208}
209
210impl TracedExec {
211 pub fn new(input: Arc<dyn ExecutionPlan>, span: Span) -> Self {
212 Self {
213 properties: input.properties().clone(),
214 input,
215 span,
216 }
217 }
218}
219
220impl DisplayAs for TracedExec {
221 fn fmt_as(
222 &self,
223 t: datafusion::physical_plan::DisplayFormatType,
224 f: &mut std::fmt::Formatter,
225 ) -> std::fmt::Result {
226 match t {
227 DisplayFormatType::Default
228 | DisplayFormatType::Verbose
229 | DisplayFormatType::TreeRender => {
230 write!(f, "TracedExec")
231 }
232 }
233 }
234}
235
236impl std::fmt::Debug for TracedExec {
237 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
238 write!(f, "TracedExec")
239 }
240}
241impl ExecutionPlan for TracedExec {
242 fn name(&self) -> &str {
243 "TracedExec"
244 }
245
246 fn as_any(&self) -> &dyn std::any::Any {
247 self
248 }
249
250 fn properties(&self) -> &Arc<PlanProperties> {
251 &self.properties
252 }
253
254 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
255 vec![&self.input]
256 }
257
258 fn with_new_children(
259 self: Arc<Self>,
260 children: Vec<Arc<dyn ExecutionPlan>>,
261 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
262 Ok(Arc::new(Self {
263 input: children[0].clone(),
264 properties: self.properties.clone(),
265 span: self.span.clone(),
266 }))
267 }
268
269 fn execute(
270 &self,
271 partition: usize,
272 context: Arc<TaskContext>,
273 ) -> datafusion_common::Result<SendableRecordBatchStream> {
274 let _guard = self.span.enter();
275 let stream = self.input.execute(partition, context)?;
276 let schema = stream.schema();
277 let stream = stream.stream_in_span(self.span.clone());
278 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
279 }
280}
281
282pub type ExecutionStatsCallback = Arc<dyn Fn(&ExecutionSummaryCounts) + Send + Sync>;
284
285#[derive(Default, Clone)]
286pub struct LanceExecutionOptions {
287 pub use_spilling: bool,
288 pub mem_pool_size: Option<u64>,
289 pub max_temp_directory_size: Option<u64>,
290 pub batch_size: Option<usize>,
291 pub target_partition: Option<usize>,
292 pub execution_stats_callback: Option<ExecutionStatsCallback>,
293 pub skip_logging: bool,
294}
295
296impl std::fmt::Debug for LanceExecutionOptions {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("LanceExecutionOptions")
299 .field("use_spilling", &self.use_spilling)
300 .field("mem_pool_size", &self.mem_pool_size)
301 .field("max_temp_directory_size", &self.max_temp_directory_size)
302 .field("batch_size", &self.batch_size)
303 .field("target_partition", &self.target_partition)
304 .field("skip_logging", &self.skip_logging)
305 .field(
306 "execution_stats_callback",
307 &self.execution_stats_callback.is_some(),
308 )
309 .finish()
310 }
311}
312
313const DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION: u64 = 100 * 1024 * 1024;
314const DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; impl LanceExecutionOptions {
317 pub fn mem_pool_size(&self) -> u64 {
318 let num_partitions = self.target_partition.unwrap_or(1) as u64;
319 self.mem_pool_size.unwrap_or_else(|| {
320 std::env::var("LANCE_MEM_POOL_SIZE")
321 .map(|s| match s.parse::<u64>() {
322 Ok(v) => v,
323 Err(e) => {
324 warn!("Failed to parse LANCE_MEM_POOL_SIZE: {}, using default", e);
325 DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION * num_partitions
326 }
327 })
328 .unwrap_or(DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION * num_partitions)
329 })
330 }
331
332 pub fn max_temp_directory_size(&self) -> u64 {
333 self.max_temp_directory_size.unwrap_or_else(|| {
334 std::env::var("LANCE_MAX_TEMP_DIRECTORY_SIZE")
335 .map(|s| match s.parse::<u64>() {
336 Ok(v) => v,
337 Err(e) => {
338 warn!(
339 "Failed to parse LANCE_MAX_TEMP_DIRECTORY_SIZE: {}, using default",
340 e
341 );
342 DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE
343 }
344 })
345 .unwrap_or(DEFAULT_LANCE_MAX_TEMP_DIRECTORY_SIZE)
346 })
347 }
348
349 pub fn use_spilling(&self) -> bool {
350 if !self.use_spilling {
351 return false;
352 }
353 std::env::var("LANCE_BYPASS_SPILLING")
354 .map(|_| {
355 info!("Bypassing spilling because LANCE_BYPASS_SPILLING is set");
356 false
357 })
358 .unwrap_or(true)
359 }
360}
361
362pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext {
363 let mut session_config = SessionConfig::new();
364 let mut runtime_env_builder = RuntimeEnvBuilder::new();
365 if let Some(target_partition) = options.target_partition {
366 session_config = session_config.with_target_partitions(target_partition);
367 }
368 if options.use_spilling() {
369 let disk_manager_builder = DiskManagerBuilder::default()
370 .with_max_temp_directory_size(options.max_temp_directory_size());
371 runtime_env_builder = runtime_env_builder
372 .with_disk_manager_builder(disk_manager_builder)
373 .with_memory_pool(Arc::new(FairSpillPool::new(
374 options.mem_pool_size() as usize
375 )));
376 }
377 let runtime_env = runtime_env_builder.build_arc().unwrap();
378
379 let ctx = SessionContext::new_with_config_rt(session_config, runtime_env);
380 register_functions(&ctx);
381
382 ctx
383}
384
385#[derive(Clone, Debug, PartialEq, Eq, Hash)]
387struct SessionContextCacheKey {
388 mem_pool_size: u64,
389 max_temp_directory_size: u64,
390 target_partition: Option<usize>,
391 use_spilling: bool,
392}
393
394impl SessionContextCacheKey {
395 fn from_options(options: &LanceExecutionOptions) -> Self {
396 Self {
397 mem_pool_size: options.mem_pool_size(),
398 max_temp_directory_size: options.max_temp_directory_size(),
399 target_partition: options.target_partition,
400 use_spilling: options.use_spilling(),
401 }
402 }
403}
404
/// A cached session context together with the time it was last handed out.
struct CachedSessionContext {
    /// The cached context; callers receive clones of it.
    context: SessionContext,
    /// Updated on every cache hit; the entry with the smallest value is
    /// evicted first.
    last_access: std::time::Instant,
}
409
410fn get_session_cache() -> &'static Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>> {
411 static SESSION_CACHE: OnceLock<Mutex<HashMap<SessionContextCacheKey, CachedSessionContext>>> =
412 OnceLock::new();
413 SESSION_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
414}
415
/// Returns the maximum number of cached session contexts.
///
/// The value is read once from `LANCE_SESSION_CACHE_SIZE` and then memoized;
/// a missing or unparsable variable yields the default of 4.
fn get_max_cache_size() -> usize {
    const DEFAULT_CACHE_SIZE: usize = 4;
    static MAX_CACHE_SIZE: OnceLock<usize> = OnceLock::new();
    *MAX_CACHE_SIZE.get_or_init(|| match std::env::var("LANCE_SESSION_CACHE_SIZE") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_CACHE_SIZE),
        Err(_) => DEFAULT_CACHE_SIZE,
    })
}
426
427pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext {
428 let key = SessionContextCacheKey::from_options(options);
429 let mut cache = get_session_cache()
430 .lock()
431 .unwrap_or_else(|e| e.into_inner());
432
433 if let Some(entry) = cache.get_mut(&key) {
435 entry.last_access = std::time::Instant::now();
436 return entry.context.clone();
437 }
438
439 if cache.len() >= get_max_cache_size()
441 && let Some(lru_key) = cache
442 .iter()
443 .min_by_key(|(_, v)| v.last_access)
444 .map(|(k, _)| k.clone())
445 {
446 cache.remove(&lru_key);
447 }
448
449 let context = new_session_context(options);
450 cache.insert(
451 key,
452 CachedSessionContext {
453 context: context.clone(),
454 last_access: std::time::Instant::now(),
455 },
456 );
457 context
458}
459
460fn get_task_context(
461 session_ctx: &SessionContext,
462 options: &LanceExecutionOptions,
463) -> Arc<TaskContext> {
464 let mut state = session_ctx.state();
465 if let Some(batch_size) = options.batch_size.as_ref() {
466 state.config_mut().options_mut().execution.batch_size = *batch_size;
467 }
468
469 state.task_ctx()
470}
471
/// Totals gathered from the metrics of every node in an executed plan.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ExecutionSummaryCounts {
    /// Sum of the `IOPS_METRIC` counters and gauges.
    pub iops: usize,
    /// Sum of the `REQUESTS_METRIC` counters and gauges.
    pub requests: usize,
    /// Sum of the `BYTES_READ_METRIC` counters and gauges.
    pub bytes_read: usize,
    /// Sum of the `INDICES_LOADED_METRIC` counters.
    pub indices_loaded: usize,
    /// Sum of the `PARTS_LOADED_METRIC` counters.
    pub parts_loaded: usize,
    /// Sum of the `INDEX_COMPARISONS_METRIC` counters.
    pub index_comparisons: usize,
    /// Every other counter, summed by metric name.
    pub all_counts: HashMap<String, usize>,
    /// Every time metric, summed by metric name.
    /// NOTE(review): the unit of the stored value is whatever `time.value()`
    /// returns — confirm (presumably nanoseconds).
    pub all_times: HashMap<String, usize>,
}
494
495pub fn collect_execution_metrics(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) {
496 if let Some(metrics) = node.metrics() {
497 for (metric_name, count) in metrics.iter_counts() {
498 match metric_name.as_ref() {
499 IOPS_METRIC => counts.iops += count.value(),
500 REQUESTS_METRIC => counts.requests += count.value(),
501 BYTES_READ_METRIC => counts.bytes_read += count.value(),
502 INDICES_LOADED_METRIC => counts.indices_loaded += count.value(),
503 PARTS_LOADED_METRIC => counts.parts_loaded += count.value(),
504 INDEX_COMPARISONS_METRIC => counts.index_comparisons += count.value(),
505 _ => {
506 let existing = counts
507 .all_counts
508 .entry(metric_name.as_ref().to_string())
509 .or_insert(0);
510 *existing += count.value();
511 }
512 }
513 }
514 for (metric_name, time) in metrics.iter_times() {
515 let existing = counts
516 .all_times
517 .entry(metric_name.as_ref().to_string())
518 .or_insert(0);
519 *existing += time.value();
520 }
521 for (metric_name, gauge) in metrics.iter_gauges() {
523 match metric_name.as_ref() {
524 IOPS_METRIC => counts.iops += gauge.value(),
525 REQUESTS_METRIC => counts.requests += gauge.value(),
526 BYTES_READ_METRIC => counts.bytes_read += gauge.value(),
527 _ => {}
528 }
529 }
530 }
531 for child in node.children() {
532 collect_execution_metrics(child.as_ref(), counts);
533 }
534}
535
536fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) {
537 let output_rows = plan
538 .metrics()
539 .map(|m| m.output_rows().unwrap_or(0))
540 .unwrap_or(0);
541 let mut counts = ExecutionSummaryCounts::default();
542 collect_execution_metrics(plan, &mut counts);
543 if !options.skip_logging {
544 tracing::info!(
545 target: TRACE_EXECUTION,
546 r#type = EXECUTION_PLAN_RUN,
547 plan_summary = display_plan_one_liner(plan),
548 output_rows,
549 iops = counts.iops,
550 requests = counts.requests,
551 bytes_read = counts.bytes_read,
552 indices_loaded = counts.indices_loaded,
553 parts_loaded = counts.parts_loaded,
554 index_comparisons = counts.index_comparisons,
555 );
556 }
557 if let Some(callback) = options.execution_stats_callback.as_ref() {
558 callback(&counts);
559 }
560}
561
562fn display_plan_one_liner(plan: &dyn ExecutionPlan) -> String {
569 let mut output = String::new();
570
571 display_plan_one_liner_impl(plan, &mut output);
572
573 output
574}
575
576fn display_plan_one_liner_impl(plan: &dyn ExecutionPlan, output: &mut String) {
577 let name = plan.name().trim_end_matches("Exec");
579 output.push_str(name);
580
581 let children = plan.children();
582 if !children.is_empty() {
583 output.push('(');
584 for (i, child) in children.iter().enumerate() {
585 if i > 0 {
586 output.push(',');
587 }
588 display_plan_one_liner_impl(child.as_ref(), output);
589 }
590 output.push(')');
591 }
592}
593
594pub fn execute_plan(
598 plan: Arc<dyn ExecutionPlan>,
599 options: LanceExecutionOptions,
600) -> Result<SendableRecordBatchStream> {
601 if !options.skip_logging {
602 debug!(
603 "Executing plan:\n{}",
604 DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
605 );
606 }
607
608 let session_ctx = get_session_context(&options);
609
610 let plan: Arc<dyn ExecutionPlan> = if plan.properties().partitioning.partition_count() == 1 {
614 plan
615 } else {
616 Arc::new(CoalescePartitionsExec::new(plan))
617 };
618
619 let stream = plan.execute(0, get_task_context(&session_ctx, &options))?;
620
621 let schema = stream.schema();
622 let stream = stream.finally(move || {
623 if !options.skip_logging || options.execution_stats_callback.is_some() {
624 report_plan_summary_metrics(plan.as_ref(), &options);
625 }
626 });
627 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
628}
629
630pub async fn analyze_plan(
631 plan: Arc<dyn ExecutionPlan>,
632 options: LanceExecutionOptions,
633) -> Result<String> {
634 let plan = Arc::new(TracedExec::new(plan, Span::current()));
637
638 let schema = plan.schema();
639 let analyze = Arc::new(AnalyzeExec::new(
641 true,
642 true,
643 vec![MetricType::SUMMARY],
644 plan,
645 schema,
646 ));
647
648 let session_ctx = get_session_context(&options);
649 assert_eq!(analyze.properties().partitioning.partition_count(), 1);
650 let mut stream = analyze
651 .execute(0, get_task_context(&session_ctx, &options))
652 .map_err(|err| Error::io(format!("Failed to execute analyze plan: {}", err)))?;
653
654 while (stream.next().await).is_some() {}
656
657 let result = format_plan(analyze);
658 Ok(result)
659}
660
/// Formats an executed plan as an indented tree, one node per line, with each
/// node's elapsed wall-clock time (when timestamps are available) and its
/// aggregated metrics.
///
/// The work is done in two passes that both walk the tree in pre-order and
/// number the nodes 1, 2, 3, ...: the first pass computes elapsed times keyed
/// by node number, the second prints each node and looks its time up by the
/// same number. The two traversals must therefore visit nodes in the same
/// order.
pub fn format_plan(plan: Arc<dyn ExecutionPlan>) -> String {
    /// First pass: computes the elapsed time of every subtree.
    struct CalculateVisitor {
        // Number assigned to the most recently visited node.
        highest_index: usize,
        // Elapsed time per node number; absent when no time range is known.
        index_to_elapsed: HashMap<usize, Duration>,
    }

    /// Earliest start and latest end timestamp seen within a subtree.
    struct SubtreeMetrics {
        min_start: Option<DateTime<Utc>>,
        max_end: Option<DateTime<Utc>>,
    }

    impl CalculateVisitor {
        /// Visits `plan` and its descendants, recording for each node the
        /// span between the earliest start and the latest end timestamp found
        /// anywhere in its subtree.
        fn calculate_metrics(&mut self, plan: &Arc<dyn ExecutionPlan>) -> SubtreeMetrics {
            // Number the node before descending so numbering is pre-order.
            self.highest_index += 1;
            let plan_index = self.highest_index;

            let (mut min_start, mut max_end) = Self::node_timerange(plan);

            // Widen this node's range with the ranges of its children.
            for child in plan.children() {
                let child_metrics = self.calculate_metrics(child);
                min_start = Self::min_option(min_start, child_metrics.min_start);
                max_end = Self::max_option(max_end, child_metrics.max_end);
            }

            let elapsed = match (min_start, max_end) {
                // `to_std` fails for a negative span; fall back to the default
                // (zero) duration in that case.
                (Some(start), Some(end)) => Some((end - start).to_std().unwrap_or_default()),
                _ => None,
            };

            if let Some(e) = elapsed {
                self.index_to_elapsed.insert(plan_index, e);
            }

            SubtreeMetrics { min_start, max_end }
        }

        /// Returns the earliest start timestamp and the latest end timestamp
        /// among this node's own metrics, or `(None, None)` when the node has
        /// no metrics.
        fn node_timerange(
            plan: &Arc<dyn ExecutionPlan>,
        ) -> (Option<DateTime<Utc>>, Option<DateTime<Utc>>) {
            let Some(metrics) = plan.metrics() else {
                return (None, None);
            };
            let min_start = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::StartTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .min();
            let max_end = metrics
                .iter()
                .filter_map(|m| match m.value() {
                    MetricValue::EndTimestamp(ts) => ts.value(),
                    _ => None,
                })
                .max();
            (min_start, max_end)
        }

        /// Smaller of the two timestamps that are present, if any.
        fn min_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().min()
        }

        /// Larger of the two timestamps that are present, if any.
        fn max_option(a: Option<DateTime<Utc>>, b: Option<DateTime<Utc>>) -> Option<DateTime<Utc>> {
            [a, b].into_iter().flatten().max()
        }
    }

    /// Second pass: prints the tree using the times from the first pass.
    struct PrintVisitor {
        // Number of the node being printed; mirrors `CalculateVisitor`.
        highest_index: usize,
        // Current depth; each level is indented by two spaces.
        indent: usize,
    }
    impl PrintVisitor {
        /// Writes the line for `plan`, then recurses into its children one
        /// indentation level deeper.
        fn write_output(
            &mut self,
            plan: &Arc<dyn ExecutionPlan>,
            f: &mut Formatter,
            calcs: &CalculateVisitor,
        ) -> std::fmt::Result {
            // Same pre-order numbering as `calculate_metrics`.
            self.highest_index += 1;
            write!(f, "{:indent$}", "", indent = self.indent * 2)?;

            let displayable =
                datafusion::physical_plan::display::DisplayableExecutionPlan::new(plan.as_ref());
            let plan_str = displayable.one_line().to_string();
            let plan_str = plan_str.trim();

            match calcs.index_to_elapsed.get(&self.highest_index) {
                // Insert the elapsed time right after the first "name: "
                // prefix when there is one, otherwise append it.
                Some(elapsed) => match plan_str.find(": ") {
                    Some(i) => write!(
                        f,
                        "{}: elapsed={elapsed:?}, {}",
                        &plan_str[..i],
                        &plan_str[i + 2..]
                    )?,
                    None => write!(f, "{plan_str}, elapsed={elapsed:?}")?,
                },
                None => write!(f, "{plan_str}")?,
            }

            if let Some(metrics) = plan.metrics() {
                // Timestamps are dropped here; they are represented by the
                // elapsed value written above.
                let metrics = metrics
                    .aggregate_by_name()
                    .sorted_for_display()
                    .timestamps_removed();

                write!(f, ", metrics=[{metrics}]")?;
            } else {
                write!(f, ", metrics=[]")?;
            }
            writeln!(f)?;
            self.indent += 1;
            for child in plan.children() {
                self.write_output(child, f, calcs)?;
            }
            self.indent -= 1;
            std::fmt::Result::Ok(())
        }
    }
    /// Adapter that gives the two passes access to a `Formatter`.
    struct PrintWrapper {
        plan: Arc<dyn ExecutionPlan>,
    }
    impl fmt::Display for PrintWrapper {
        fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
            let mut calcs = CalculateVisitor {
                highest_index: 0,
                index_to_elapsed: HashMap::new(),
            };
            calcs.calculate_metrics(&self.plan);
            let mut prints = PrintVisitor {
                highest_index: 0,
                indent: 0,
            };
            prints.write_output(&self.plan, f, &calcs)
        }
    }
    let wrapper = PrintWrapper { plan };
    format!("{}", wrapper)
}
809
/// Extension methods for a session context.
pub trait SessionContextExt {
    /// Creates a `DataFrame` backed by the given record batch stream.
    ///
    /// NOTE(review): judging by the name, the resulting frame is meant to be
    /// consumed only once — confirm against the implementation.
    fn read_one_shot(
        &self,
        data: SendableRecordBatchStream,
    ) -> datafusion::common::Result<DataFrame>;
}
819
/// A partition stream wrapping a single record batch stream that can be
/// taken out once.
pub struct OneShotPartitionStream {
    /// The wrapped stream, or `None` once it has been taken.
    data: Arc<Mutex<Option<SendableRecordBatchStream>>>,
    /// Schema captured from the stream at construction time.
    schema: Arc<ArrowSchema>,
}
824
825impl std::fmt::Debug for OneShotPartitionStream {
826 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
827 let data = self.data.lock().unwrap();
828 f.debug_struct("OneShotPartitionStream")
829 .field("exhausted", &data.is_none())
830 .field("schema", self.schema.as_ref())
831 .finish()
832 }
833}
834
835impl OneShotPartitionStream {
836 pub fn new(data: SendableRecordBatchStream) -> Self {
837 let schema = data.schema();
838 Self {
839 data: Arc::new(Mutex::new(Some(data))),
840 schema,
841 }
842 }
843}
844
845impl PartitionStream for OneShotPartitionStream {
846 fn schema(&self) -> &arrow_schema::SchemaRef {
847 &self.schema
848 }
849
850 fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
851 let mut stream = self.data.lock().unwrap();
852 stream
853 .take()
854 .expect("Attempt to consume a one shot dataframe multiple times")
855 }
856}
857
858impl SessionContextExt for SessionContext {
859 fn read_one_shot(
860 &self,
861 data: SendableRecordBatchStream,
862 ) -> datafusion::common::Result<DataFrame> {
863 let schema = data.schema();
864 let part_stream = Arc::new(OneShotPartitionStream::new(data));
865 let provider = StreamingTable::try_new(schema, vec![part_stream])?;
866 self.read_table(Arc::new(provider))
867 }
868}
869
/// Plan node that passes its input's output through a
/// `StrictBatchSizeStream` configured with `batch_size`.
#[derive(Clone, Debug)]
pub struct StrictBatchSizeExec {
    /// The wrapped plan.
    input: Arc<dyn ExecutionPlan>,
    /// Batch size handed to `StrictBatchSizeStream`.
    batch_size: usize,
}

impl StrictBatchSizeExec {
    /// Wraps `input` with the given `batch_size`.
    pub fn new(input: Arc<dyn ExecutionPlan>, batch_size: usize) -> Self {
        Self { input, batch_size }
    }
}
881
882impl DisplayAs for StrictBatchSizeExec {
883 fn fmt_as(
884 &self,
885 _t: datafusion::physical_plan::DisplayFormatType,
886 f: &mut std::fmt::Formatter,
887 ) -> std::fmt::Result {
888 write!(f, "StrictBatchSizeExec")
889 }
890}
891
892impl ExecutionPlan for StrictBatchSizeExec {
893 fn name(&self) -> &str {
894 "StrictBatchSizeExec"
895 }
896
897 fn as_any(&self) -> &dyn std::any::Any {
898 self
899 }
900
901 fn properties(&self) -> &Arc<PlanProperties> {
902 self.input.properties()
903 }
904
905 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
906 vec![&self.input]
907 }
908
909 fn with_new_children(
910 self: Arc<Self>,
911 children: Vec<Arc<dyn ExecutionPlan>>,
912 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
913 Ok(Arc::new(Self {
914 input: children[0].clone(),
915 batch_size: self.batch_size,
916 }))
917 }
918
919 fn execute(
920 &self,
921 partition: usize,
922 context: Arc<TaskContext>,
923 ) -> datafusion_common::Result<SendableRecordBatchStream> {
924 let stream = self.input.execute(partition, context)?;
925 let schema = stream.schema();
926 let stream = StrictBatchSizeStream::new(stream, self.batch_size);
927 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
928 }
929
930 fn maintains_input_order(&self) -> Vec<bool> {
931 vec![true]
932 }
933
934 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
935 vec![false]
936 }
937
938 fn partition_statistics(
939 &self,
940 partition: Option<usize>,
941 ) -> datafusion_common::Result<Statistics> {
942 self.input.partition_statistics(partition)
943 }
944
945 fn cardinality_effect(&self) -> CardinalityEffect {
946 CardinalityEffect::Equal
947 }
948
949 fn supports_limit_pushdown(&self) -> bool {
950 true
951 }
952}
953
/// Plan node that re-chunks its input by size and rejects any single-row
/// batch whose array memory size exceeds `max_bytes`.
#[derive(Clone, Debug)]
pub struct HardCapBatchSizeExec {
    /// The wrapped plan.
    input: Arc<dyn ExecutionPlan>,
    /// Byte limit passed to the re-chunking helper and used for validation.
    max_bytes: usize,
}

impl HardCapBatchSizeExec {
    /// Wraps `input` with the given byte limit.
    pub fn new(input: Arc<dyn ExecutionPlan>, max_bytes: usize) -> Self {
        Self { input, max_bytes }
    }
}
987
988impl DisplayAs for HardCapBatchSizeExec {
989 fn fmt_as(
990 &self,
991 _t: datafusion::physical_plan::DisplayFormatType,
992 f: &mut std::fmt::Formatter,
993 ) -> std::fmt::Result {
994 write!(f, "HardCapBatchSizeExec(max_bytes={})", self.max_bytes)
995 }
996}
997
998impl ExecutionPlan for HardCapBatchSizeExec {
999 fn name(&self) -> &str {
1000 "HardCapBatchSizeExec"
1001 }
1002
1003 fn as_any(&self) -> &dyn std::any::Any {
1004 self
1005 }
1006
1007 fn properties(&self) -> &Arc<PlanProperties> {
1008 self.input.properties()
1009 }
1010
1011 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1012 vec![&self.input]
1013 }
1014
1015 fn with_new_children(
1016 self: Arc<Self>,
1017 children: Vec<Arc<dyn ExecutionPlan>>,
1018 ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
1019 Ok(Arc::new(Self {
1020 input: children[0].clone(),
1021 max_bytes: self.max_bytes,
1022 }))
1023 }
1024
1025 fn execute(
1026 &self,
1027 partition: usize,
1028 context: Arc<TaskContext>,
1029 ) -> datafusion_common::Result<SendableRecordBatchStream> {
1030 let stream = self.input.execute(partition, context)?;
1031 let schema = stream.schema();
1032 let max_bytes = self.max_bytes;
1033 let rechunked = lance_arrow::stream::rechunk_stream_by_size_deep_copy(
1034 stream,
1035 schema.clone(),
1036 0,
1037 max_bytes,
1038 );
1039 let validated = rechunked.map(move |result| {
1041 let batch = result?;
1042 if batch.num_rows() == 1 && batch.get_array_memory_size() > max_bytes {
1043 return Err(DataFusionError::External(Box::new(Error::invalid_input(
1044 format!(
1045 "a single row is {} bytes which exceeds the maximum allowed batch \
1046 size of {} bytes",
1047 batch.get_array_memory_size(),
1048 max_bytes,
1049 ),
1050 ))));
1051 }
1052 Ok(batch)
1053 });
1054 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, validated)))
1055 }
1056
1057 fn maintains_input_order(&self) -> Vec<bool> {
1058 vec![true]
1059 }
1060
1061 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
1062 vec![false]
1063 }
1064
1065 fn partition_statistics(
1066 &self,
1067 partition: Option<usize>,
1068 ) -> datafusion_common::Result<Statistics> {
1069 self.input.partition_statistics(partition)
1070 }
1071
1072 fn cardinality_effect(&self) -> CardinalityEffect {
1073 CardinalityEffect::Equal
1074 }
1075
1076 fn supports_limit_pushdown(&self) -> bool {
1077 true
1078 }
1079}
1080
#[cfg(test)]
mod tests {
    use super::*;

    // Serializes the tests that mutate the process-wide session cache.
    static CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Number of entries currently held by the global session cache.
    fn cached_len() -> usize {
        get_session_cache().lock().unwrap().len()
    }

    /// Empties the global session cache.
    fn reset_cache() {
        get_session_cache().lock().unwrap().clear();
    }

    #[test]
    fn test_session_context_cache() {
        let _lock = CACHE_TEST_LOCK.lock().unwrap();
        reset_cache();

        // First request creates an entry.
        let opts1 = LanceExecutionOptions::default();
        let _ctx1 = get_session_context(&opts1);
        assert_eq!(cached_len(), 1);

        // Same options again reuse the existing entry.
        let _ctx1_again = get_session_context(&opts1);
        assert_eq!(cached_len(), 1);

        // Different options add a second entry.
        let opts2 = LanceExecutionOptions {
            use_spilling: true,
            ..Default::default()
        };
        let _ctx2 = get_session_context(&opts2);
        assert_eq!(cached_len(), 2);
    }

    #[test]
    fn test_session_context_cache_lru_eviction() {
        let _lock = CACHE_TEST_LOCK.lock().unwrap();
        reset_cache();

        // Fill the cache with four distinct configurations.
        let mut configs: Vec<LanceExecutionOptions> = Vec::new();
        for i in 0..4 {
            configs.push(LanceExecutionOptions {
                mem_pool_size: Some((i + 1) as u64 * 1024 * 1024),
                ..Default::default()
            });
        }
        for config in &configs {
            let _ctx = get_session_context(config);
        }
        assert_eq!(cached_len(), 4);

        // Touch the first configuration so it becomes the most recently used.
        std::thread::sleep(std::time::Duration::from_millis(1));
        let _ctx = get_session_context(&configs[0]);

        // A fifth configuration forces one eviction.
        let opts5 = LanceExecutionOptions {
            mem_pool_size: Some(5 * 1024 * 1024),
            ..Default::default()
        };
        let _ctx5 = get_session_context(&opts5);

        let key0 = SessionContextCacheKey::from_options(&configs[0]);
        let key1 = SessionContextCacheKey::from_options(&configs[1]);
        let key5 = SessionContextCacheKey::from_options(&opts5);

        let cache_guard = get_session_cache().lock().unwrap();
        assert_eq!(cache_guard.len(), 4);
        assert!(
            cache_guard.contains_key(&key0),
            "config[0] should still be cached after recent access"
        );
        assert!(
            !cache_guard.contains_key(&key1),
            "config[1] should have been evicted"
        );
        assert!(
            cache_guard.contains_key(&key5),
            "new config should be cached"
        );
    }

    #[test]
    fn test_mem_pool_size_scales_with_partitions() {
        let default_per_partition = DEFAULT_LANCE_MEM_POOL_SIZE_PER_PARTITION;

        // No partition count: one partition's worth.
        let opts = LanceExecutionOptions::default();
        assert_eq!(opts.mem_pool_size(), default_per_partition);

        // The default scales linearly with the partition count.
        for partitions in [4usize, 8] {
            let opts = LanceExecutionOptions {
                target_partition: Some(partitions),
                ..Default::default()
            };
            assert_eq!(
                opts.mem_pool_size(),
                default_per_partition * partitions as u64
            );
        }

        // An explicit size wins and is not scaled.
        let opts = LanceExecutionOptions {
            mem_pool_size: Some(50 * 1024 * 1024),
            target_partition: Some(8),
            ..Default::default()
        };
        assert_eq!(opts.mem_pool_size(), 50 * 1024 * 1024);
    }
}