1use std::any::Any;
17use std::collections::HashMap;
18use std::fmt::{self, Debug, Formatter};
19use std::sync::Arc;
20
21use parking_lot::RwLock;
22
23use std::collections::BTreeMap;
24
25use arrow::compute::take;
26use arrow::row::{RowConverter, SortField};
27use arrow_array::{RecordBatch, UInt32Array};
28use arrow_schema::{Schema, SchemaRef};
29use async_trait::async_trait;
30use datafusion::execution::{SendableRecordBatchStream, SessionState, TaskContext};
31use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode};
32use datafusion::physical_expr::{EquivalenceProperties, LexOrdering, Partitioning};
33use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
34use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
35use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
36use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
37use datafusion_common::{DataFusionError, Result};
38use datafusion_expr::Expr;
39use futures::StreamExt;
40use laminar_core::lookup::foyer_cache::FoyerMemoryCache;
41use laminar_core::lookup::source::LookupSourceDyn;
42use tokio::sync::Semaphore;
43
44use super::lookup_join::{LookupJoinNode, LookupJoinType};
45
/// Thread-safe registry mapping lowercase table names to registered lookup
/// table state, shared between planning and execution of lookup joins.
#[derive(Default)]
pub struct LookupTableRegistry {
    // Keys are lowercased table names (normalized by the `register*` methods).
    tables: RwLock<HashMap<String, RegisteredLookup>>,
}
56
/// A registered lookup table in one of three caching strategies.
pub enum RegisteredLookup {
    /// Fully materialized, immutable snapshot of the table.
    Snapshot(Arc<LookupSnapshot>),
    /// Partially cached table backed by an in-memory cache, with an optional
    /// remote source consulted on cache misses.
    Partial(Arc<PartialLookupState>),
    /// Fully materialized table with per-key version history for temporal
    /// ("as of" event time) joins.
    Versioned(Arc<VersionedLookupState>),
}
67
/// An immutable, fully materialized lookup table.
pub struct LookupSnapshot {
    /// All rows of the lookup table in a single batch.
    pub batch: RecordBatch,
    /// Names of the columns forming the join key.
    pub key_columns: Vec<String>,
}
75
/// Materialized lookup table plus a version index for temporal joins.
pub struct VersionedLookupState {
    /// All retained versions of all rows in a single batch.
    pub batch: RecordBatch,
    /// Index from join key to (version timestamp -> row indices in `batch`).
    pub index: Arc<VersionedIndex>,
    /// Names of the columns forming the join key.
    pub key_columns: Vec<String>,
    /// Column of `batch` holding each row's version timestamp.
    pub version_column: String,
    /// Column of the stream batch holding the event time to join "as of".
    pub stream_time_column: String,
    /// Maximum number of versions retained per key (oldest dropped first).
    pub max_versions_per_key: usize,
}
96
/// State for a partially cached lookup table: an in-memory cache plus an
/// optional remote source queried on misses.
pub struct PartialLookupState {
    /// Cache keyed by row-encoded join-key bytes.
    pub foyer_cache: Arc<FoyerMemoryCache>,
    /// Schema of the lookup table rows.
    pub schema: SchemaRef,
    /// Names of the columns forming the join key.
    pub key_columns: Vec<String>,
    /// Sort fields used to row-encode join keys, matching the cache encoding.
    pub key_sort_fields: Vec<SortField>,
    /// Optional remote source for cache-miss fallback fetches.
    pub source: Option<Arc<dyn LookupSourceDyn>>,
    /// Bounds the number of concurrent fallback fetches against `source`.
    pub fetch_semaphore: Arc<Semaphore>,
}
112
113impl LookupTableRegistry {
114 #[must_use]
116 pub fn new() -> Self {
117 Self::default()
118 }
119
120 pub fn register(&self, name: &str, snapshot: LookupSnapshot) {
126 self.tables.write().insert(
127 name.to_lowercase(),
128 RegisteredLookup::Snapshot(Arc::new(snapshot)),
129 );
130 }
131
132 pub fn register_partial(&self, name: &str, state: PartialLookupState) {
138 self.tables.write().insert(
139 name.to_lowercase(),
140 RegisteredLookup::Partial(Arc::new(state)),
141 );
142 }
143
144 pub fn register_versioned(&self, name: &str, state: VersionedLookupState) {
150 self.tables.write().insert(
151 name.to_lowercase(),
152 RegisteredLookup::Versioned(Arc::new(state)),
153 );
154 }
155
156 pub fn unregister(&self, name: &str) {
162 self.tables.write().remove(&name.to_lowercase());
163 }
164
165 #[must_use]
171 pub fn get(&self, name: &str) -> Option<Arc<LookupSnapshot>> {
172 let tables = self.tables.read();
173 match tables.get(&name.to_lowercase())? {
174 RegisteredLookup::Snapshot(s) => Some(Arc::clone(s)),
175 RegisteredLookup::Partial(_) | RegisteredLookup::Versioned(_) => None,
176 }
177 }
178
179 pub fn get_entry(&self, name: &str) -> Option<RegisteredLookup> {
185 let tables = self.tables.read();
186 tables.get(&name.to_lowercase()).map(|e| match e {
187 RegisteredLookup::Snapshot(s) => RegisteredLookup::Snapshot(Arc::clone(s)),
188 RegisteredLookup::Partial(p) => RegisteredLookup::Partial(Arc::clone(p)),
189 RegisteredLookup::Versioned(v) => RegisteredLookup::Versioned(Arc::clone(v)),
190 })
191 }
192}
193
/// Hash index over a lookup batch: row-encoded join key -> matching row indices.
struct HashIndex {
    // Values are row indices into the indexed `RecordBatch`, in row order.
    map: HashMap<Box<[u8]>, Vec<u32>>,
}
200
201impl HashIndex {
202 fn build(batch: &RecordBatch, key_indices: &[usize]) -> Result<Self> {
207 if batch.num_rows() == 0 {
208 return Ok(Self {
209 map: HashMap::new(),
210 });
211 }
212
213 let sort_fields: Vec<SortField> = key_indices
214 .iter()
215 .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
216 .collect();
217 let converter = RowConverter::new(sort_fields)?;
218
219 let key_cols: Vec<_> = key_indices
220 .iter()
221 .map(|&i| batch.column(i).clone())
222 .collect();
223 let rows = converter.convert_columns(&key_cols)?;
224
225 let num_rows = batch.num_rows();
226 let mut map: HashMap<Box<[u8]>, Vec<u32>> = HashMap::with_capacity(num_rows);
227 #[allow(clippy::cast_possible_truncation)] for i in 0..num_rows {
229 map.entry(Box::from(rows.row(i).as_ref()))
230 .or_default()
231 .push(i as u32);
232 }
233
234 Ok(Self { map })
235 }
236
237 fn probe(&self, key: &[u8]) -> Option<&[u32]> {
238 self.map.get(key).map(Vec::as_slice)
239 }
240}
241
/// Temporal index: row-encoded join key -> (version timestamp -> row indices).
///
/// The inner `BTreeMap` keeps versions ordered so a probe can locate the
/// latest version at or before a given event time with a range query.
#[derive(Default)]
pub struct VersionedIndex {
    map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>>,
}
251
impl VersionedIndex {
    /// Builds the index over `batch`, keyed by the columns at `key_indices`
    /// and versioned by the timestamp-like column at `version_col_idx`.
    ///
    /// Rows with a NULL version timestamp or a NULL in any key column are
    /// skipped. After insertion, only the newest `max_versions_per_key`
    /// version timestamps are retained per key (oldest dropped first).
    ///
    /// # Errors
    /// Fails if the key columns cannot be row-encoded or the version column
    /// has a type unsupported by `extract_i64_timestamps`.
    pub fn build(
        batch: &RecordBatch,
        key_indices: &[usize],
        version_col_idx: usize,
        max_versions_per_key: usize,
    ) -> Result<Self> {
        if batch.num_rows() == 0 {
            return Ok(Self {
                map: HashMap::new(),
            });
        }

        let sort_fields: Vec<SortField> = key_indices
            .iter()
            .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
            .collect();
        let converter = RowConverter::new(sort_fields)?;

        let key_cols: Vec<_> = key_indices
            .iter()
            .map(|&i| batch.column(i).clone())
            .collect();
        // Row-encode keys so multi-column keys compare as raw bytes.
        let rows = converter.convert_columns(&key_cols)?;

        // Version timestamps normalized to milliseconds (None for NULL).
        let timestamps = extract_i64_timestamps(batch.column(version_col_idx))?;

        let num_rows = batch.num_rows();
        let mut map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>> = HashMap::with_capacity(num_rows);
        #[allow(clippy::cast_possible_truncation)]
        for (i, ts_opt) in timestamps.iter().enumerate() {
            // Rows without a version timestamp cannot participate in temporal lookup.
            let Some(version_ts) = ts_opt else { continue };
            // NULL keys can never be matched by a probe; skip them.
            if key_cols.iter().any(|c| c.is_null(i)) {
                continue;
            }
            let key = Box::from(rows.row(i).as_ref());
            map.entry(key)
                .or_default()
                .entry(*version_ts)
                .or_default()
                .push(i as u32);
        }

        // Trim each key's history to the newest `max_versions_per_key` entries.
        if max_versions_per_key < usize::MAX {
            for versions in map.values_mut() {
                while versions.len() > max_versions_per_key {
                    // BTreeMap is ordered by timestamp, so pop_first drops the oldest.
                    versions.pop_first();
                }
            }
        }

        Ok(Self { map })
    }

    /// Returns the row index of the latest version of `key` whose version
    /// timestamp is at or before `event_ts`, if any.
    fn probe_at_time(&self, key: &[u8], event_ts: i64) -> Option<u32> {
        let versions = self.map.get(key)?;
        // Latest version with timestamp <= event_ts.
        let (_, indices) = versions.range(..=event_ts).next_back()?;
        // Multiple rows may share a timestamp; take the last inserted.
        indices.last().copied()
    }
}
325
326fn extract_i64_timestamps(col: &dyn arrow_array::Array) -> Result<Vec<Option<i64>>> {
332 use arrow_array::{
333 Float64Array, Int64Array, TimestampMicrosecondArray, TimestampMillisecondArray,
334 TimestampNanosecondArray, TimestampSecondArray,
335 };
336 use arrow_schema::{DataType, TimeUnit};
337
338 let n = col.len();
339 let mut out = Vec::with_capacity(n);
340 macro_rules! extract_typed {
341 ($arr_type:ty, $scale:expr) => {{
342 let arr = col.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
343 DataFusionError::Internal(concat!("expected ", stringify!($arr_type)).into())
344 })?;
345 for i in 0..n {
346 out.push(if col.is_null(i) {
347 None
348 } else {
349 Some(arr.value(i) * $scale)
350 });
351 }
352 }};
353 }
354
355 match col.data_type() {
356 DataType::Int64 => extract_typed!(Int64Array, 1),
357 DataType::Timestamp(TimeUnit::Millisecond, _) => {
358 extract_typed!(TimestampMillisecondArray, 1);
359 }
360 DataType::Timestamp(TimeUnit::Microsecond, _) => {
361 let arr = col
362 .as_any()
363 .downcast_ref::<TimestampMicrosecondArray>()
364 .ok_or_else(|| {
365 DataFusionError::Internal("expected TimestampMicrosecondArray".into())
366 })?;
367 for i in 0..n {
368 out.push(if col.is_null(i) {
369 None
370 } else {
371 Some(arr.value(i) / 1000)
372 });
373 }
374 }
375 DataType::Timestamp(TimeUnit::Second, _) => {
376 extract_typed!(TimestampSecondArray, 1000);
377 }
378 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
379 let arr = col
380 .as_any()
381 .downcast_ref::<TimestampNanosecondArray>()
382 .ok_or_else(|| {
383 DataFusionError::Internal("expected TimestampNanosecondArray".into())
384 })?;
385 for i in 0..n {
386 out.push(if col.is_null(i) {
387 None
388 } else {
389 Some(arr.value(i) / 1_000_000)
390 });
391 }
392 }
393 DataType::Float64 => {
394 let arr = col
395 .as_any()
396 .downcast_ref::<Float64Array>()
397 .ok_or_else(|| DataFusionError::Internal("expected Float64Array".into()))?;
398 #[allow(clippy::cast_possible_truncation)]
399 for i in 0..n {
400 out.push(if col.is_null(i) {
401 None
402 } else {
403 Some(arr.value(i) as i64)
404 });
405 }
406 }
407 other => {
408 return Err(DataFusionError::Plan(format!(
409 "unsupported timestamp type for temporal join: {other:?}"
410 )));
411 }
412 }
413
414 Ok(out)
415}
416
/// Physical operator that hash-joins a streaming input against an in-memory
/// lookup table.
pub struct LookupJoinExec {
    /// Streaming probe-side child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Hash index over the key columns of `lookup_batch`.
    index: Arc<HashIndex>,
    /// Materialized build-side rows.
    lookup_batch: Arc<RecordBatch>,
    /// Indices of the join-key columns in the stream input's schema.
    stream_key_indices: Vec<usize>,
    join_type: LookupJoinType,
    /// Output schema: stream fields followed by lookup fields.
    schema: SchemaRef,
    properties: PlanProperties,
    /// Sort fields used to row-encode stream keys, matching the index encoding.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of `schema`.
    stream_field_count: usize,
}
433
434impl LookupJoinExec {
435 #[allow(clippy::needless_pass_by_value)] pub fn try_new(
445 input: Arc<dyn ExecutionPlan>,
446 lookup_batch: RecordBatch,
447 stream_key_indices: Vec<usize>,
448 lookup_key_indices: Vec<usize>,
449 join_type: LookupJoinType,
450 output_schema: SchemaRef,
451 ) -> Result<Self> {
452 let index = HashIndex::build(&lookup_batch, &lookup_key_indices)?;
453
454 let key_sort_fields: Vec<SortField> = lookup_key_indices
455 .iter()
456 .map(|&i| SortField::new(lookup_batch.schema().field(i).data_type().clone()))
457 .collect();
458
459 let output_schema = if join_type == LookupJoinType::LeftOuter {
462 let stream_count = input.schema().fields().len();
463 let mut fields = output_schema.fields().to_vec();
464 for f in &mut fields[stream_count..] {
465 if !f.is_nullable() {
466 *f = Arc::new(f.as_ref().clone().with_nullable(true));
467 }
468 }
469 Arc::new(Schema::new_with_metadata(
470 fields,
471 output_schema.metadata().clone(),
472 ))
473 } else {
474 output_schema
475 };
476
477 let properties = PlanProperties::new(
478 EquivalenceProperties::new(Arc::clone(&output_schema)),
479 Partitioning::UnknownPartitioning(1),
480 EmissionType::Incremental,
481 Boundedness::Unbounded {
482 requires_infinite_memory: false,
483 },
484 );
485
486 let stream_field_count = input.schema().fields().len();
487
488 Ok(Self {
489 input,
490 index: Arc::new(index),
491 lookup_batch: Arc::new(lookup_batch),
492 stream_key_indices,
493 join_type,
494 schema: output_schema,
495 properties,
496 key_sort_fields,
497 stream_field_count,
498 })
499 }
500}
501
502impl Debug for LookupJoinExec {
503 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
504 f.debug_struct("LookupJoinExec")
505 .field("join_type", &self.join_type)
506 .field("stream_keys", &self.stream_key_indices)
507 .field("lookup_rows", &self.lookup_batch.num_rows())
508 .finish_non_exhaustive()
509 }
510}
511
512impl DisplayAs for LookupJoinExec {
513 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
514 match t {
515 DisplayFormatType::Default | DisplayFormatType::Verbose => {
516 write!(
517 f,
518 "LookupJoinExec: type={}, stream_keys={:?}, lookup_rows={}",
519 self.join_type,
520 self.stream_key_indices,
521 self.lookup_batch.num_rows(),
522 )
523 }
524 DisplayFormatType::TreeRender => write!(f, "LookupJoinExec"),
525 }
526 }
527}
528
impl ExecutionPlan for LookupJoinExec {
    fn name(&self) -> &'static str {
        "LookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Single child: the streaming probe side.
        vec![&self.input]
    }

    /// Rebuilds the node with a new stream input; the lookup side (index,
    /// batch, schema, properties) is shared unchanged.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "LookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            lookup_batch: Arc::clone(&self.lookup_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    /// Streams the input partition, probing each incoming batch against the
    /// lookup index via `probe_batch`. Empty input batches short-circuit to
    /// empty output batches.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // Converter must use the lookup-side key encoding so probe bytes match.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let lookup_batch = Arc::clone(&self.lookup_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_batch(
                &batch,
                &converter,
                &index,
                &lookup_batch,
                &stream_key_indices,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
609
impl datafusion::physical_plan::ExecutionPlanProperties for LookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    /// Mirrors the boundedness baked into `properties` at construction.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    /// Mirrors the emission type baked into `properties` at construction.
    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
633
/// Joins one stream batch against the lookup index.
///
/// For each stream row: rows with NULL keys never match (emitted with NULL
/// lookup columns for `LeftOuter`, dropped for inner); matching rows are
/// emitted once per matching lookup row; non-matching rows follow the same
/// outer/inner rule. Output columns are the stream columns followed by the
/// lookup columns, gathered with the `take` kernel.
///
/// # Errors
/// Fails if key row-encoding, the take kernel, or output batch construction fails.
#[allow(clippy::too_many_arguments)]
fn probe_batch(
    stream_batch: &RecordBatch,
    converter: &RowConverter,
    index: &HashIndex,
    lookup_batch: &RecordBatch,
    stream_key_indices: &[usize],
    join_type: LookupJoinType,
    output_schema: &SchemaRef,
    stream_field_count: usize,
) -> Result<RecordBatch> {
    let key_cols: Vec<_> = stream_key_indices
        .iter()
        .map(|&i| stream_batch.column(i).clone())
        .collect();
    // Row-encode stream keys with the same encoding used to build the index.
    let rows = converter.convert_columns(&key_cols)?;

    let num_rows = stream_batch.num_rows();
    let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
    let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);

    #[allow(clippy::cast_possible_truncation)]
    for row in 0..num_rows {
        // SQL equality with NULL never matches; skip the probe entirely.
        if key_cols.iter().any(|c| c.is_null(row)) {
            if join_type == LookupJoinType::LeftOuter {
                stream_indices.push(row as u32);
                lookup_indices.push(None);
            }
            continue;
        }

        let key = rows.row(row);
        match index.probe(key.as_ref()) {
            Some(matches) => {
                // One output row per matching lookup row.
                for &lookup_row in matches {
                    stream_indices.push(row as u32);
                    lookup_indices.push(Some(lookup_row));
                }
            }
            None if join_type == LookupJoinType::LeftOuter => {
                stream_indices.push(row as u32);
                lookup_indices.push(None);
            }
            None => {}
        }
    }

    if stream_indices.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
    }

    // Gather stream-side columns, then lookup-side columns, in output order.
    let take_stream = UInt32Array::from(stream_indices);
    let mut columns = Vec::with_capacity(output_schema.fields().len());

    for col in stream_batch.columns() {
        columns.push(take(col.as_ref(), &take_stream, None)?);
    }

    // `None` entries become NULL indices, which `take` turns into NULL values.
    let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
    for col in lookup_batch.columns() {
        columns.push(take(col.as_ref(), &take_lookup, None)?);
    }

    debug_assert_eq!(
        columns.len(),
        stream_field_count + lookup_batch.num_columns(),
        "output column count mismatch"
    );

    Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
}
712
/// Physical operator for a temporal ("as of" event time) lookup join: each
/// stream row matches the latest table version at or before its event time.
pub struct VersionedLookupJoinExec {
    /// Streaming probe-side child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Version index over `table_batch`.
    index: Arc<VersionedIndex>,
    /// Materialized table rows (all retained versions).
    table_batch: Arc<RecordBatch>,
    /// Indices of the join-key columns in the stream input's schema.
    stream_key_indices: Vec<usize>,
    /// Index of the stream column carrying each row's event time.
    stream_time_col_idx: usize,
    join_type: LookupJoinType,
    /// Output schema: stream fields followed by table fields.
    schema: SchemaRef,
    properties: PlanProperties,
    /// Sort fields used to row-encode stream keys, matching the index encoding.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of `schema`.
    stream_field_count: usize,
}
730
731impl VersionedLookupJoinExec {
732 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
742 pub fn try_new(
743 input: Arc<dyn ExecutionPlan>,
744 table_batch: RecordBatch,
745 index: Arc<VersionedIndex>,
746 stream_key_indices: Vec<usize>,
747 stream_time_col_idx: usize,
748 join_type: LookupJoinType,
749 output_schema: SchemaRef,
750 key_sort_fields: Vec<SortField>,
751 ) -> Result<Self> {
752 let output_schema = if join_type == LookupJoinType::LeftOuter {
753 let stream_count = input.schema().fields().len();
754 let mut fields = output_schema.fields().to_vec();
755 for f in &mut fields[stream_count..] {
756 if !f.is_nullable() {
757 *f = Arc::new(f.as_ref().clone().with_nullable(true));
758 }
759 }
760 Arc::new(Schema::new_with_metadata(
761 fields,
762 output_schema.metadata().clone(),
763 ))
764 } else {
765 output_schema
766 };
767
768 let properties = PlanProperties::new(
769 EquivalenceProperties::new(Arc::clone(&output_schema)),
770 Partitioning::UnknownPartitioning(1),
771 EmissionType::Incremental,
772 Boundedness::Unbounded {
773 requires_infinite_memory: false,
774 },
775 );
776
777 let stream_field_count = input.schema().fields().len();
778
779 Ok(Self {
780 input,
781 index,
782 table_batch: Arc::new(table_batch),
783 stream_key_indices,
784 stream_time_col_idx,
785 join_type,
786 schema: output_schema,
787 properties,
788 key_sort_fields,
789 stream_field_count,
790 })
791 }
792}
793
794impl Debug for VersionedLookupJoinExec {
795 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
796 f.debug_struct("VersionedLookupJoinExec")
797 .field("join_type", &self.join_type)
798 .field("stream_keys", &self.stream_key_indices)
799 .field("table_rows", &self.table_batch.num_rows())
800 .finish_non_exhaustive()
801 }
802}
803
804impl DisplayAs for VersionedLookupJoinExec {
805 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
806 match t {
807 DisplayFormatType::Default | DisplayFormatType::Verbose => {
808 write!(
809 f,
810 "VersionedLookupJoinExec: type={}, stream_keys={:?}, table_rows={}",
811 self.join_type,
812 self.stream_key_indices,
813 self.table_batch.num_rows(),
814 )
815 }
816 DisplayFormatType::TreeRender => write!(f, "VersionedLookupJoinExec"),
817 }
818 }
819}
820
impl ExecutionPlan for VersionedLookupJoinExec {
    fn name(&self) -> &'static str {
        "VersionedLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Single child: the streaming probe side.
        vec![&self.input]
    }

    /// Rebuilds the node with a new stream input; the table side (index,
    /// batch, schema, properties) is shared unchanged.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "VersionedLookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            table_batch: Arc::clone(&self.table_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            stream_time_col_idx: self.stream_time_col_idx,
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    /// Streams the input partition, probing each incoming batch against the
    /// versioned index via `probe_versioned_batch`. Empty input batches
    /// short-circuit to empty output batches.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // Converter must use the table-side key encoding so probe bytes match.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let table_batch = Arc::clone(&self.table_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let stream_time_col_idx = self.stream_time_col_idx;
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_versioned_batch(
                &batch,
                &converter,
                &index,
                &table_batch,
                &stream_key_indices,
                stream_time_col_idx,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
904
impl datafusion::physical_plan::ExecutionPlanProperties for VersionedLookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    /// Mirrors the boundedness baked into `properties` at construction.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    /// Mirrors the emission type baked into `properties` at construction.
    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
928
929#[allow(clippy::too_many_arguments)]
932fn probe_versioned_batch(
933 stream_batch: &RecordBatch,
934 converter: &RowConverter,
935 index: &VersionedIndex,
936 table_batch: &RecordBatch,
937 stream_key_indices: &[usize],
938 stream_time_col_idx: usize,
939 join_type: LookupJoinType,
940 output_schema: &SchemaRef,
941 stream_field_count: usize,
942) -> Result<RecordBatch> {
943 let key_cols: Vec<_> = stream_key_indices
944 .iter()
945 .map(|&i| stream_batch.column(i).clone())
946 .collect();
947 let rows = converter.convert_columns(&key_cols)?;
948 let event_timestamps =
949 extract_i64_timestamps(stream_batch.column(stream_time_col_idx).as_ref())?;
950
951 let num_rows = stream_batch.num_rows();
952 let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
953 let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);
954
955 #[allow(clippy::cast_possible_truncation)]
956 for (row, event_ts_opt) in event_timestamps.iter().enumerate() {
957 if key_cols.iter().any(|c| c.is_null(row)) || event_ts_opt.is_none() {
959 if join_type == LookupJoinType::LeftOuter {
960 stream_indices.push(row as u32);
961 lookup_indices.push(None);
962 }
963 continue;
964 }
965
966 let key = rows.row(row);
967 let event_ts = event_ts_opt.unwrap();
968 match index.probe_at_time(key.as_ref(), event_ts) {
969 Some(table_row_idx) => {
970 stream_indices.push(row as u32);
971 lookup_indices.push(Some(table_row_idx));
972 }
973 None if join_type == LookupJoinType::LeftOuter => {
974 stream_indices.push(row as u32);
975 lookup_indices.push(None);
976 }
977 None => {}
978 }
979 }
980
981 if stream_indices.is_empty() {
982 return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
983 }
984
985 let take_stream = UInt32Array::from(stream_indices);
986 let mut columns = Vec::with_capacity(output_schema.fields().len());
987
988 for col in stream_batch.columns() {
989 columns.push(take(col.as_ref(), &take_stream, None)?);
990 }
991
992 let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
993 for col in table_batch.columns() {
994 columns.push(take(col.as_ref(), &take_lookup, None)?);
995 }
996
997 debug_assert_eq!(
998 columns.len(),
999 stream_field_count + table_batch.num_columns(),
1000 "output column count mismatch"
1001 );
1002
1003 Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
1004}
1005
/// Physical operator joining a stream against a partially cached lookup
/// table, with an optional async source fallback for cache misses.
pub struct PartialLookupJoinExec {
    /// Streaming probe-side child plan.
    input: Arc<dyn ExecutionPlan>,
    /// In-memory cache keyed by row-encoded join-key bytes.
    foyer_cache: Arc<FoyerMemoryCache>,
    /// Indices of the join-key columns in the stream input's schema.
    stream_key_indices: Vec<usize>,
    join_type: LookupJoinType,
    /// Output schema: stream fields followed by lookup fields.
    schema: SchemaRef,
    properties: PlanProperties,
    /// Sort fields used to row-encode stream keys, matching the cache encoding.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of `schema`.
    stream_field_count: usize,
    /// Schema of the cached/fetched lookup rows.
    lookup_schema: SchemaRef,
    /// Optional remote source queried on cache misses.
    source: Option<Arc<dyn LookupSourceDyn>>,
    /// Bounds the number of concurrent fallback fetches against `source`.
    fetch_semaphore: Arc<Semaphore>,
}
1024
1025impl PartialLookupJoinExec {
1026 pub fn try_new(
1032 input: Arc<dyn ExecutionPlan>,
1033 foyer_cache: Arc<FoyerMemoryCache>,
1034 stream_key_indices: Vec<usize>,
1035 key_sort_fields: Vec<SortField>,
1036 join_type: LookupJoinType,
1037 lookup_schema: SchemaRef,
1038 output_schema: SchemaRef,
1039 ) -> Result<Self> {
1040 Self::try_new_with_source(
1041 input,
1042 foyer_cache,
1043 stream_key_indices,
1044 key_sort_fields,
1045 join_type,
1046 lookup_schema,
1047 output_schema,
1048 None,
1049 Arc::new(Semaphore::new(64)),
1050 )
1051 }
1052
1053 #[allow(clippy::too_many_arguments)]
1059 pub fn try_new_with_source(
1060 input: Arc<dyn ExecutionPlan>,
1061 foyer_cache: Arc<FoyerMemoryCache>,
1062 stream_key_indices: Vec<usize>,
1063 key_sort_fields: Vec<SortField>,
1064 join_type: LookupJoinType,
1065 lookup_schema: SchemaRef,
1066 output_schema: SchemaRef,
1067 source: Option<Arc<dyn LookupSourceDyn>>,
1068 fetch_semaphore: Arc<Semaphore>,
1069 ) -> Result<Self> {
1070 let output_schema = if join_type == LookupJoinType::LeftOuter {
1071 let stream_count = input.schema().fields().len();
1072 let mut fields = output_schema.fields().to_vec();
1073 for f in &mut fields[stream_count..] {
1074 if !f.is_nullable() {
1075 *f = Arc::new(f.as_ref().clone().with_nullable(true));
1076 }
1077 }
1078 Arc::new(Schema::new_with_metadata(
1079 fields,
1080 output_schema.metadata().clone(),
1081 ))
1082 } else {
1083 output_schema
1084 };
1085
1086 let properties = PlanProperties::new(
1087 EquivalenceProperties::new(Arc::clone(&output_schema)),
1088 Partitioning::UnknownPartitioning(1),
1089 EmissionType::Incremental,
1090 Boundedness::Unbounded {
1091 requires_infinite_memory: false,
1092 },
1093 );
1094
1095 let stream_field_count = input.schema().fields().len();
1096
1097 Ok(Self {
1098 input,
1099 foyer_cache,
1100 stream_key_indices,
1101 join_type,
1102 schema: output_schema,
1103 properties,
1104 key_sort_fields,
1105 stream_field_count,
1106 lookup_schema,
1107 source,
1108 fetch_semaphore,
1109 })
1110 }
1111}
1112
1113impl Debug for PartialLookupJoinExec {
1114 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1115 f.debug_struct("PartialLookupJoinExec")
1116 .field("join_type", &self.join_type)
1117 .field("stream_keys", &self.stream_key_indices)
1118 .field("cache_table_id", &self.foyer_cache.table_id())
1119 .finish_non_exhaustive()
1120 }
1121}
1122
1123impl DisplayAs for PartialLookupJoinExec {
1124 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
1125 match t {
1126 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1127 write!(
1128 f,
1129 "PartialLookupJoinExec: type={}, stream_keys={:?}, cache_entries={}",
1130 self.join_type,
1131 self.stream_key_indices,
1132 self.foyer_cache.len(),
1133 )
1134 }
1135 DisplayFormatType::TreeRender => write!(f, "PartialLookupJoinExec"),
1136 }
1137 }
1138}
1139
impl ExecutionPlan for PartialLookupJoinExec {
    fn name(&self) -> &'static str {
        "PartialLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Single child: the streaming probe side.
        vec![&self.input]
    }

    /// Rebuilds the node with a new stream input; cache, source, schema, and
    /// properties are shared unchanged.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "PartialLookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            foyer_cache: Arc::clone(&self.foyer_cache),
            stream_key_indices: self.stream_key_indices.clone(),
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
            lookup_schema: Arc::clone(&self.lookup_schema),
            source: self.source.clone(),
            fetch_semaphore: Arc::clone(&self.fetch_semaphore),
        }))
    }

    /// Streams the input partition, probing each batch against the cache and
    /// (asynchronously) the fallback source. Uses `StreamExt::then` rather
    /// than `map` because probing may await source fetches.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // Arc'd so the converter can be re-shared into each per-batch future.
        let converter = Arc::new(RowConverter::new(self.key_sort_fields.clone())?);
        let foyer_cache = Arc::clone(&self.foyer_cache);
        let stream_key_indices = self.stream_key_indices.clone();
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;
        let lookup_schema = Arc::clone(&self.lookup_schema);
        let source = self.source.clone();
        let fetch_semaphore = Arc::clone(&self.fetch_semaphore);

        let output = input_stream.then(move |result| {
            // Clone per batch: the async block must own its captures ('static).
            let foyer_cache = Arc::clone(&foyer_cache);
            let converter = Arc::clone(&converter);
            let stream_key_indices = stream_key_indices.clone();
            let schema = Arc::clone(&schema);
            let lookup_schema = Arc::clone(&lookup_schema);
            let source = source.clone();
            let fetch_semaphore = Arc::clone(&fetch_semaphore);
            async move {
                let batch = result?;
                if batch.num_rows() == 0 {
                    return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
                }
                probe_partial_batch_with_fallback(
                    &batch,
                    &converter,
                    &foyer_cache,
                    &stream_key_indices,
                    join_type,
                    &schema,
                    stream_field_count,
                    &lookup_schema,
                    source.as_deref(),
                    &fetch_semaphore,
                )
                .await
            }
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
1236
impl datafusion::physical_plan::ExecutionPlanProperties for PartialLookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    /// Mirrors the boundedness baked into `properties` at construction.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    /// Mirrors the emission type baked into `properties` at construction.
    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
1260
/// Joins one stream batch against the partial (cached) lookup table,
/// falling back to the remote `source` for keys that miss the cache.
///
/// Per stream row:
/// - rows with a NULL join key never match (kept with NULL lookup columns
///   for `LeftOuter`, dropped for `Inner`);
/// - cache hits attach the cached lookup batch;
/// - cache misses are batched into a single `query_batch` call against
///   `source` (bounded by `fetch_semaphore`); fetched rows are written
///   back into the cache. A source error is logged and the join degrades
///   to cache-only results rather than failing the query.
///
/// NOTE(review): output assembly assumes each cached/fetched lookup batch
/// holds exactly one row — misses are padded with a length-1 null array
/// before `concat`, so a multi-row batch would make lookup column lengths
/// disagree with the taken stream columns. Confirm the cache upholds the
/// one-row-per-key invariant.
#[allow(clippy::too_many_arguments)]
async fn probe_partial_batch_with_fallback(
    stream_batch: &RecordBatch,
    converter: &RowConverter,
    foyer_cache: &FoyerMemoryCache,
    stream_key_indices: &[usize],
    join_type: LookupJoinType,
    output_schema: &SchemaRef,
    stream_field_count: usize,
    lookup_schema: &SchemaRef,
    source: Option<&dyn LookupSourceDyn>,
    fetch_semaphore: &Semaphore,
) -> Result<RecordBatch> {
    // Encode the stream-side key columns into comparable row bytes.
    let key_cols: Vec<_> = stream_key_indices
        .iter()
        .map(|&i| stream_batch.column(i).clone())
        .collect();
    let rows = converter.convert_columns(&key_cols)?;

    let num_rows = stream_batch.num_rows();
    // Parallel vectors: stream_indices[i] is the stream row emitted in output
    // slot i; lookup_batches[i] is its (optional) matching lookup batch.
    let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
    let mut lookup_batches: Vec<Option<RecordBatch>> = Vec::with_capacity(num_rows);
    // (output slot, encoded key) pairs that missed the cache.
    let mut miss_keys: Vec<(usize, Vec<u8>)> = Vec::new();

    #[allow(clippy::cast_possible_truncation)]
    for row in 0..num_rows {
        // SQL semantics: a NULL key never joins.
        if key_cols.iter().any(|c| c.is_null(row)) {
            if join_type == LookupJoinType::LeftOuter {
                stream_indices.push(row as u32);
                lookup_batches.push(None);
            }
            continue;
        }

        let key = rows.row(row);
        let result = foyer_cache.get_cached(key.as_ref());
        if let Some(batch) = result.into_batch() {
            stream_indices.push(row as u32);
            lookup_batches.push(Some(batch));
        } else {
            // Record the slot so the source fallback can fill it in later.
            let idx = stream_indices.len();
            stream_indices.push(row as u32);
            lookup_batches.push(None);
            miss_keys.push((idx, key.as_ref().to_vec()));
        }
    }

    // Source fallback: one batched remote lookup for all cache misses.
    if let Some(source) = source {
        if !miss_keys.is_empty() {
            // Bounds concurrent remote fetches; held only for the query.
            let _permit = fetch_semaphore
                .acquire()
                .await
                .map_err(|_| DataFusionError::Internal("fetch semaphore closed".into()))?;

            let key_refs: Vec<&[u8]> = miss_keys.iter().map(|(_, k)| k.as_slice()).collect();
            let source_results = source.query_batch(&key_refs, &[], &[]).await;

            match source_results {
                Ok(results) => {
                    // `results` is positionally aligned with `miss_keys`.
                    for ((idx, key_bytes), maybe_batch) in miss_keys.iter().zip(results.into_iter())
                    {
                        if let Some(batch) = maybe_batch {
                            // Warm the cache so later batches hit directly.
                            foyer_cache.insert(key_bytes, batch.clone());
                            lookup_batches[*idx] = Some(batch);
                        }
                    }
                }
                Err(e) => {
                    // Graceful degradation: unresolved misses stay None.
                    tracing::warn!(
                        error = %e,
                        missed_keys = miss_keys.len(),
                        "partial lookup: source fallback failed, serving cache-only results"
                    );
                }
            }
        }
    }

    // Inner join: compact away rows that still have no match (stable,
    // in-place two-pointer sweep).
    if join_type == LookupJoinType::Inner {
        let mut write = 0;
        for read in 0..stream_indices.len() {
            if lookup_batches[read].is_some() {
                stream_indices[write] = stream_indices[read];
                lookup_batches.swap(write, read);
                write += 1;
            }
        }
        stream_indices.truncate(write);
        lookup_batches.truncate(write);
    }

    if stream_indices.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
    }

    // Gather surviving stream rows, then append the lookup columns.
    let take_indices = UInt32Array::from(stream_indices);
    let mut columns = Vec::with_capacity(output_schema.fields().len());

    for col in stream_batch.columns() {
        columns.push(take(col.as_ref(), &take_indices, None)?);
    }

    let lookup_col_count = lookup_schema.fields().len();
    for col_idx in 0..lookup_col_count {
        // Unmatched rows contribute a single NULL per lookup column.
        let arrays: Vec<_> = lookup_batches
            .iter()
            .map(|opt| match opt {
                Some(b) => b.column(col_idx).clone(),
                None => arrow_array::new_null_array(lookup_schema.field(col_idx).data_type(), 1),
            })
            .collect();
        let refs: Vec<&dyn arrow_array::Array> = arrays.iter().map(AsRef::as_ref).collect();
        columns.push(arrow::compute::concat(&refs)?);
    }

    debug_assert_eq!(
        columns.len(),
        stream_field_count + lookup_col_count,
        "output column count mismatch"
    );

    Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
}
1390
/// [`ExtensionPlanner`] that lowers [`LookupJoinNode`] logical nodes into
/// the matching physical lookup-join operator.
pub struct LookupJoinExtensionPlanner {
    // Consulted at planning time to resolve the named lookup table.
    registry: Arc<LookupTableRegistry>,
}
1399
impl LookupJoinExtensionPlanner {
    /// Creates a planner backed by the given lookup-table registry.
    pub fn new(registry: Arc<LookupTableRegistry>) -> Self {
        Self { registry }
    }
}
1406
#[async_trait]
impl ExtensionPlanner for LookupJoinExtensionPlanner {
    /// Lowers a [`LookupJoinNode`] into the physical operator matching the
    /// registered lookup-table kind (snapshot, partial cache, or versioned).
    ///
    /// Returns `Ok(None)` when `node` is not a `LookupJoinNode` so other
    /// extension planners may handle it. Errors if the table is not
    /// registered or join-key columns/types do not line up.
    #[allow(clippy::too_many_lines)]
    async fn plan_extension(
        &self,
        _planner: &dyn PhysicalPlanner,
        node: &dyn UserDefinedLogicalNode,
        _logical_inputs: &[&LogicalPlan],
        physical_inputs: &[Arc<dyn ExecutionPlan>],
        session_state: &SessionState,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Not our node: decline so the default planner continues.
        let Some(lookup_node) = node.as_any().downcast_ref::<LookupJoinNode>() else {
            return Ok(None);
        };

        let entry = self
            .registry
            .get_entry(lookup_node.lookup_table_name())
            .ok_or_else(|| {
                DataFusionError::Plan(format!(
                    "lookup table '{}' not registered",
                    lookup_node.lookup_table_name()
                ))
            })?;

        // The single physical input is the stream side of the join.
        let input = Arc::clone(&physical_inputs[0]);
        let stream_schema = input.schema();

        match entry {
            // Cache-backed table with optional remote-source fallback.
            // NOTE(review): unlike the Snapshot/Versioned branches below,
            // this branch performs no stream/lookup key type check —
            // presumably mismatches surface at probe time; consider
            // validating here for symmetric plan-time errors.
            RegisteredLookup::Partial(partial_state) => {
                let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;

                // Output schema = stream columns followed by lookup columns.
                let mut output_fields = stream_schema.fields().to_vec();
                output_fields.extend(partial_state.schema.fields().iter().cloned());
                let output_schema = Arc::new(Schema::new(output_fields));

                let exec = PartialLookupJoinExec::try_new_with_source(
                    input,
                    Arc::clone(&partial_state.foyer_cache),
                    stream_key_indices,
                    partial_state.key_sort_fields.clone(),
                    lookup_node.join_type(),
                    Arc::clone(&partial_state.schema),
                    output_schema,
                    partial_state.source.clone(),
                    Arc::clone(&partial_state.fetch_semaphore),
                )?;
                Ok(Some(Arc::new(exec)))
            }
            // Fully materialized in-memory table.
            RegisteredLookup::Snapshot(snapshot) => {
                let lookup_schema = snapshot.batch.schema();
                let lookup_key_indices = resolve_lookup_keys(lookup_node, &lookup_schema)?;

                // Apply pushed-down predicates to the snapshot once, at
                // plan time (skipped when trivially unnecessary).
                let lookup_batch = if lookup_node.pushdown_predicates().is_empty()
                    || snapshot.batch.num_rows() == 0
                {
                    snapshot.batch.clone()
                } else {
                    apply_pushdown_predicates(
                        &snapshot.batch,
                        lookup_node.pushdown_predicates(),
                        session_state,
                    )?
                };

                let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;

                // Reject plans whose paired key columns disagree on type.
                for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
                    let st = stream_schema.field(*si).data_type();
                    let lt = lookup_schema.field(*li).data_type();
                    if st != lt {
                        return Err(DataFusionError::Plan(format!(
                            "Lookup join key type mismatch: stream '{}' is {st:?} \
                             but lookup '{}' is {lt:?}",
                            stream_schema.field(*si).name(),
                            lookup_schema.field(*li).name(),
                        )));
                    }
                }

                let mut output_fields = stream_schema.fields().to_vec();
                output_fields.extend(lookup_batch.schema().fields().iter().cloned());
                let output_schema = Arc::new(Schema::new(output_fields));

                let exec = LookupJoinExec::try_new(
                    input,
                    lookup_batch,
                    stream_key_indices,
                    lookup_key_indices,
                    lookup_node.join_type(),
                    output_schema,
                )?;

                Ok(Some(Arc::new(exec)))
            }
            // Temporal (event-time versioned) table.
            RegisteredLookup::Versioned(versioned_state) => {
                let table_schema = versioned_state.batch.schema();
                let lookup_key_indices = resolve_lookup_keys(lookup_node, &table_schema)?;
                let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;

                // Same plan-time key type validation as the snapshot case.
                for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
                    let st = stream_schema.field(*si).data_type();
                    let lt = table_schema.field(*li).data_type();
                    if st != lt {
                        return Err(DataFusionError::Plan(format!(
                            "Temporal join key type mismatch: stream '{}' is {st:?} \
                             but table '{}' is {lt:?}",
                            stream_schema.field(*si).name(),
                            table_schema.field(*li).name(),
                        )));
                    }
                }

                // The event-time column must exist on the stream side.
                let stream_time_col_idx = stream_schema
                    .index_of(&versioned_state.stream_time_column)
                    .map_err(|_| {
                        DataFusionError::Plan(format!(
                            "stream time column '{}' not found in stream schema",
                            versioned_state.stream_time_column
                        ))
                    })?;

                // Row-encoder fields for the versioned index probe keys.
                let key_sort_fields: Vec<SortField> = lookup_key_indices
                    .iter()
                    .map(|&i| SortField::new(table_schema.field(i).data_type().clone()))
                    .collect();

                let mut output_fields = stream_schema.fields().to_vec();
                output_fields.extend(table_schema.fields().iter().cloned());
                let output_schema = Arc::new(Schema::new(output_fields));

                let exec = VersionedLookupJoinExec::try_new(
                    input,
                    versioned_state.batch.clone(),
                    Arc::clone(&versioned_state.index),
                    stream_key_indices,
                    stream_time_col_idx,
                    lookup_node.join_type(),
                    output_schema,
                    key_sort_fields,
                )?;

                Ok(Some(Arc::new(exec)))
            }
        }
    }
}
1556
1557fn apply_pushdown_predicates(
1560 batch: &RecordBatch,
1561 predicates: &[Expr],
1562 session_state: &SessionState,
1563) -> Result<RecordBatch> {
1564 use arrow::compute::filter_record_batch;
1565 use datafusion::physical_expr::create_physical_expr;
1566
1567 let schema = batch.schema();
1568 let df_schema = datafusion::common::DFSchema::try_from(schema.as_ref().clone())?;
1569
1570 let mut mask = None::<arrow_array::BooleanArray>;
1571 for pred in predicates {
1572 let phys_expr = create_physical_expr(pred, &df_schema, session_state.execution_props())?;
1573 let result = phys_expr.evaluate(batch)?;
1574 let bool_arr = result
1575 .into_array(batch.num_rows())?
1576 .as_any()
1577 .downcast_ref::<arrow_array::BooleanArray>()
1578 .ok_or_else(|| {
1579 DataFusionError::Internal("pushdown predicate did not evaluate to boolean".into())
1580 })?
1581 .clone();
1582 mask = Some(match mask {
1583 Some(existing) => arrow::compute::and(&existing, &bool_arr)?,
1584 None => bool_arr,
1585 });
1586 }
1587
1588 match mask {
1589 Some(m) => Ok(filter_record_batch(batch, &m)?),
1590 None => Ok(batch.clone()),
1591 }
1592}
1593
1594fn resolve_stream_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1595 node.join_keys()
1596 .iter()
1597 .map(|pair| match &pair.stream_expr {
1598 Expr::Column(col) => schema.index_of(&col.name).map_err(|_| {
1599 DataFusionError::Plan(format!(
1600 "stream key column '{}' not found in physical schema",
1601 col.name
1602 ))
1603 }),
1604 other => Err(DataFusionError::NotImplemented(format!(
1605 "lookup join requires column references as stream keys, got: {other}"
1606 ))),
1607 })
1608 .collect()
1609}
1610
1611fn resolve_lookup_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1612 node.join_keys()
1613 .iter()
1614 .map(|pair| {
1615 schema.index_of(&pair.lookup_column).map_err(|_| {
1616 DataFusionError::Plan(format!(
1617 "lookup key column '{}' not found in lookup table schema",
1618 pair.lookup_column
1619 ))
1620 })
1621 })
1622 .collect()
1623}
1624
1625#[cfg(test)]
1628mod tests {
1629 use super::*;
1630 use arrow_array::{Array, Float64Array, Int64Array, StringArray};
1631 use arrow_schema::{DataType, Field};
1632 use datafusion::physical_plan::stream::RecordBatchStreamAdapter as TestStreamAdapter;
1633 use futures::TryStreamExt;
1634
    /// Wraps a single batch in a one-partition `ExecutionPlan` stub so it
    /// can serve as the stream input of a join under test.
    fn batch_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        let batches = vec![batch];
        let stream_schema = Arc::clone(&schema);
        Arc::new(StreamExecStub {
            schema,
            batches: std::sync::Mutex::new(Some(batches)),
            stream_schema,
        })
    }
1646
    /// Minimal `ExecutionPlan` stub that emits a fixed list of batches once.
    struct StreamExecStub {
        schema: SchemaRef,
        // Taken on the first execute(); later executes yield no batches.
        batches: std::sync::Mutex<Option<Vec<RecordBatch>>>,
        stream_schema: SchemaRef,
    }
1653
    // The stub carries no state worth printing; a fixed name suffices.
    impl Debug for StreamExecStub {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            write!(f, "StreamExecStub")
        }
    }
1659
    // Plan-display output; format type is irrelevant for the stub.
    impl DisplayAs for StreamExecStub {
        fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
            write!(f, "StreamExecStub")
        }
    }
1665
    impl ExecutionPlan for StreamExecStub {
        fn name(&self) -> &'static str {
            "StreamExecStub"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            // Box::leak satisfies the &-returning API without caching a
            // field on the struct; it leaks per call, which is acceptable
            // only because this is a short-lived test stub.
            Box::leak(Box::new(PlanProperties::new(
                EquivalenceProperties::new(Arc::clone(&self.schema)),
                Partitioning::UnknownPartitioning(1),
                EmissionType::Final,
                Boundedness::Bounded,
            )))
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(&self, _: usize, _: Arc<TaskContext>) -> Result<SendableRecordBatchStream> {
            // take() makes the stub single-shot: re-execution is empty.
            let batches = self.batches.lock().unwrap().take().unwrap_or_default();
            let schema = Arc::clone(&self.stream_schema);
            let stream = futures::stream::iter(batches.into_iter().map(Ok));
            Ok(Box::pin(TestStreamAdapter::new(schema, stream)))
        }
    }
1701
    // Property accessors mirroring the bounded, single-shot stub semantics.
    impl datafusion::physical_plan::ExecutionPlanProperties for StreamExecStub {
        fn output_partitioning(&self) -> &Partitioning {
            self.properties().output_partitioning()
        }
        fn output_ordering(&self) -> Option<&LexOrdering> {
            None
        }
        fn boundedness(&self) -> Boundedness {
            Boundedness::Bounded
        }
        fn pipeline_behavior(&self) -> EmissionType {
            EmissionType::Final
        }
        fn equivalence_properties(&self) -> &EquivalenceProperties {
            self.properties().equivalence_properties()
        }
    }
1719
    /// Stream-side schema: orders keyed by customer_id.
    fn orders_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]))
    }
1727
    /// Lookup-side schema: customers keyed by id.
    fn customers_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }
1734
    /// Joined schema: stream columns followed by lookup columns.
    fn output_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }
1744
    /// Three customers with ids 1..=3.
    fn customers_batch() -> RecordBatch {
        RecordBatch::try_new(
            customers_schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
            ],
        )
        .unwrap()
    }
1755
    /// Four orders; customer_id=99 has no matching customer.
    fn orders_batch() -> RecordBatch {
        RecordBatch::try_new(
            orders_schema(),
            vec![
                Arc::new(Int64Array::from(vec![100, 101, 102, 103])),
                Arc::new(Int64Array::from(vec![1, 2, 99, 3])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
            ],
        )
        .unwrap()
    }
1767
1768 fn make_exec(join_type: LookupJoinType) -> LookupJoinExec {
1769 let input = batch_exec(orders_batch());
1770 LookupJoinExec::try_new(
1771 input,
1772 customers_batch(),
1773 vec![1], vec![0], join_type,
1776 output_schema(),
1777 )
1778 .unwrap()
1779 }
1780
    /// Inner join drops the stream row whose key has no lookup match.
    #[tokio::test]
    async fn inner_join_filters_non_matches() {
        let exec = make_exec(LookupJoinType::Inner);
        let ctx = Arc::new(TaskContext::default());
        let stream = exec.execute(0, ctx).unwrap();
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();

        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");

        // Column 4 is the joined customer name.
        let names = batches[0]
            .column(4)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
        assert_eq!(names.value(2), "Charlie");
    }
1800
    /// Left outer join keeps unmatched stream rows with NULL lookup columns.
    #[tokio::test]
    async fn left_outer_preserves_non_matches() {
        let exec = make_exec(LookupJoinType::LeftOuter);
        let ctx = Arc::new(TaskContext::default());
        let stream = exec.execute(0, ctx).unwrap();
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();

        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 4, "all 4 stream rows preserved in left outer");

        let names = batches[0]
            .column(4)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        // Row 2 is customer_id=99, which has no match.
        assert!(names.is_null(2));
    }
1819
    /// Inner join against an empty lookup table yields zero rows.
    #[tokio::test]
    async fn empty_lookup_inner_produces_no_rows() {
        let empty = RecordBatch::new_empty(customers_schema());
        let input = batch_exec(orders_batch());
        let exec = LookupJoinExec::try_new(
            input,
            empty,
            vec![1],
            vec![0],
            LookupJoinType::Inner,
            output_schema(),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 0);
    }
1839
    /// Left outer join against an empty lookup table keeps every stream row.
    #[tokio::test]
    async fn empty_lookup_left_outer_preserves_all_stream_rows() {
        let empty = RecordBatch::new_empty(customers_schema());
        let input = batch_exec(orders_batch());
        let exec = LookupJoinExec::try_new(
            input,
            empty,
            vec![1],
            vec![0],
            LookupJoinType::LeftOuter,
            output_schema(),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 4);
    }
1859
    /// A stream row matching two lookup rows fans out to two output rows.
    #[tokio::test]
    async fn duplicate_keys_produce_multiple_rows() {
        // Two lookup rows share key id=1.
        let lookup = RecordBatch::try_new(
            customers_schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 1])),
                Arc::new(StringArray::from(vec!["Alice-A", "Alice-B"])),
            ],
        )
        .unwrap();

        let stream = RecordBatch::try_new(
            orders_schema(),
            vec![
                Arc::new(Int64Array::from(vec![100])),
                Arc::new(Int64Array::from(vec![1])),
                Arc::new(Float64Array::from(vec![10.0])),
            ],
        )
        .unwrap();

        let input = batch_exec(stream);
        let exec = LookupJoinExec::try_new(
            input,
            lookup,
            vec![1],
            vec![0],
            LookupJoinType::Inner,
            output_schema(),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 2, "one stream row matched two lookup rows");
    }
1897
    /// Rebuilding the exec with its own children keeps schema and identity.
    #[test]
    fn with_new_children_preserves_state() {
        let exec = Arc::new(make_exec(LookupJoinType::Inner));
        let expected_schema = exec.schema();
        let children = exec.children().into_iter().cloned().collect();
        let rebuilt = exec.with_new_children(children).unwrap();
        assert_eq!(rebuilt.schema(), expected_schema);
        assert_eq!(rebuilt.name(), "LookupJoinExec");
    }
1907
    /// Debug output names the operator and reports the lookup row count.
    #[test]
    fn display_format() {
        let exec = make_exec(LookupJoinType::Inner);
        let s = format!("{exec:?}");
        assert!(s.contains("LookupJoinExec"));
        assert!(s.contains("lookup_rows: 3"));
    }
1915
    /// Register/get/unregister round-trip; lookups are case-insensitive.
    #[test]
    fn registry_crud() {
        let reg = LookupTableRegistry::new();
        assert!(reg.get("customers").is_none());

        reg.register(
            "customers",
            LookupSnapshot {
                batch: customers_batch(),
                key_columns: vec!["id".into()],
            },
        );
        assert!(reg.get("customers").is_some());
        assert!(reg.get("CUSTOMERS").is_some(), "case-insensitive");

        reg.unregister("customers");
        assert!(reg.get("customers").is_none());
    }
1934
    /// Re-registering the same name replaces the previous snapshot.
    #[test]
    fn registry_update_replaces() {
        let reg = LookupTableRegistry::new();
        reg.register(
            "t",
            LookupSnapshot {
                batch: RecordBatch::new_empty(customers_schema()),
                key_columns: vec![],
            },
        );
        assert_eq!(reg.get("t").unwrap().batch.num_rows(), 0);

        reg.register(
            "t",
            LookupSnapshot {
                batch: customers_batch(),
                key_columns: vec![],
            },
        );
        assert_eq!(reg.get("t").unwrap().batch.num_rows(), 3);
    }
1956
1957 #[test]
1958 fn pushdown_predicates_filter_snapshot() {
1959 use datafusion::logical_expr::{col, lit};
1960
1961 let batch = customers_batch(); let ctx = datafusion::prelude::SessionContext::new();
1963 let state = ctx.state();
1964
1965 let predicates = vec![col("id").gt(lit(1i64))];
1967 let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
1968 assert_eq!(filtered.num_rows(), 2);
1969
1970 let ids = filtered
1971 .column(0)
1972 .as_any()
1973 .downcast_ref::<Int64Array>()
1974 .unwrap();
1975 assert_eq!(ids.value(0), 2);
1976 assert_eq!(ids.value(1), 3);
1977 }
1978
    /// No predicates means no filtering: all rows pass through.
    #[test]
    fn pushdown_predicates_empty_passes_all() {
        let batch = customers_batch();
        let ctx = datafusion::prelude::SessionContext::new();
        let state = ctx.state();

        let filtered = apply_pushdown_predicates(&batch, &[], &state).unwrap();
        assert_eq!(filtered.num_rows(), 3);
    }
1988
1989 #[test]
1990 fn pushdown_predicates_multiple_and() {
1991 use datafusion::logical_expr::{col, lit};
1992
1993 let batch = customers_batch(); let ctx = datafusion::prelude::SessionContext::new();
1995 let state = ctx.state();
1996
1997 let predicates = vec![col("id").gt_eq(lit(2i64)), col("id").lt(lit(3i64))];
1999 let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
2000 assert_eq!(filtered.num_rows(), 1);
2001 }
2002
2003 use laminar_core::lookup::foyer_cache::FoyerMemoryCacheConfig;
2006
    /// Small in-memory foyer cache (table id 1) for partial-join tests.
    fn make_foyer_cache() -> Arc<FoyerMemoryCache> {
        Arc::new(FoyerMemoryCache::new(
            1,
            FoyerMemoryCacheConfig {
                capacity: 64,
                shards: 4,
            },
        ))
    }
2016
    /// Single-row customer batch, matching the one-row-per-key cache shape.
    fn customer_row(id: i64, name: &str) -> RecordBatch {
        RecordBatch::try_new(
            customers_schema(),
            vec![
                Arc::new(Int64Array::from(vec![id])),
                Arc::new(StringArray::from(vec![name])),
            ],
        )
        .unwrap()
    }
2027
    /// Seeds the cache with customers 1..=3 under row-encoded Int64 keys,
    /// matching the encoding the join uses at probe time.
    fn warm_cache(cache: &FoyerMemoryCache) {
        let converter = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap();

        for (id, name) in [(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
            let key_col = Int64Array::from(vec![id]);
            let rows = converter.convert_columns(&[Arc::new(key_col)]).unwrap();
            let key = rows.row(0);
            cache.insert(key.as_ref(), customer_row(id, name));
        }
    }
2038
2039 fn make_partial_exec(join_type: LookupJoinType) -> PartialLookupJoinExec {
2040 let cache = make_foyer_cache();
2041 warm_cache(&cache);
2042
2043 let input = batch_exec(orders_batch());
2044 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2045
2046 PartialLookupJoinExec::try_new(
2047 input,
2048 cache,
2049 vec![1], key_sort_fields,
2051 join_type,
2052 customers_schema(),
2053 output_schema(),
2054 )
2055 .unwrap()
2056 }
2057
    /// Partial inner join drops the row whose key misses the cache.
    #[tokio::test]
    async fn partial_inner_join_filters_non_matches() {
        let exec = make_partial_exec(LookupJoinType::Inner);
        let ctx = Arc::new(TaskContext::default());
        let stream = exec.execute(0, ctx).unwrap();
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();

        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");

        let names = batches[0]
            .column(4)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
        assert_eq!(names.value(2), "Charlie");
    }
2077
    /// Partial left outer join keeps cache misses with NULL lookup columns.
    #[tokio::test]
    async fn partial_left_outer_preserves_non_matches() {
        let exec = make_partial_exec(LookupJoinType::LeftOuter);
        let ctx = Arc::new(TaskContext::default());
        let stream = exec.execute(0, ctx).unwrap();
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();

        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 4, "all 4 stream rows preserved in left outer");

        let names = batches[0]
            .column(4)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert!(names.is_null(2), "customer_id=99 should have null name");
    }
2095
    /// With an empty cache and no source, inner join emits nothing.
    #[tokio::test]
    async fn partial_empty_cache_inner_produces_no_rows() {
        let cache = make_foyer_cache();
        let input = batch_exec(orders_batch());
        let key_sort_fields = vec![SortField::new(DataType::Int64)];

        let exec = PartialLookupJoinExec::try_new(
            input,
            cache,
            vec![1],
            key_sort_fields,
            LookupJoinType::Inner,
            customers_schema(),
            output_schema(),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 0);
    }
2118
    /// With an empty cache, left outer still preserves all stream rows.
    #[tokio::test]
    async fn partial_empty_cache_left_outer_preserves_all() {
        let cache = make_foyer_cache();
        let input = batch_exec(orders_batch());
        let key_sort_fields = vec![SortField::new(DataType::Int64)];

        let exec = PartialLookupJoinExec::try_new(
            input,
            cache,
            vec![1],
            key_sort_fields,
            LookupJoinType::LeftOuter,
            customers_schema(),
            output_schema(),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 4);
    }
2141
    /// Rebuilding the partial exec with its own children keeps its schema.
    #[test]
    fn partial_with_new_children_preserves_state() {
        let exec = Arc::new(make_partial_exec(LookupJoinType::Inner));
        let expected_schema = exec.schema();
        let children = exec.children().into_iter().cloned().collect();
        let rebuilt = exec.with_new_children(children).unwrap();
        assert_eq!(rebuilt.schema(), expected_schema);
        assert_eq!(rebuilt.name(), "PartialLookupJoinExec");
    }
2151
    /// Debug output names the operator and reports the cache table id.
    #[test]
    fn partial_display_format() {
        let exec = make_partial_exec(LookupJoinType::Inner);
        let s = format!("{exec:?}");
        assert!(s.contains("PartialLookupJoinExec"));
        assert!(s.contains("cache_table_id: 1"));
    }
2159
    /// Partial registrations are visible via get_entry() but not via the
    /// snapshot-only get() accessor.
    #[test]
    fn registry_partial_entry() {
        let reg = LookupTableRegistry::new();
        let cache = make_foyer_cache();
        let key_sort_fields = vec![SortField::new(DataType::Int64)];

        reg.register_partial(
            "customers",
            PartialLookupState {
                foyer_cache: cache,
                schema: customers_schema(),
                key_columns: vec!["id".into()],
                key_sort_fields,
                source: None,
                fetch_semaphore: Arc::new(Semaphore::new(64)),
            },
        );

        // get() only returns snapshot entries.
        assert!(reg.get("customers").is_none());

        let entry = reg.get_entry("customers");
        assert!(entry.is_some());
        assert!(matches!(entry.unwrap(), RegisteredLookup::Partial(_)));
    }
2184
2185 #[tokio::test]
2186 async fn partial_source_fallback_on_miss() {
2187 use laminar_core::lookup::source::LookupError;
2188 use laminar_core::lookup::source::LookupSourceDyn;
2189
2190 struct TestSource;
2191
2192 #[async_trait]
2193 impl LookupSourceDyn for TestSource {
2194 async fn query_batch(
2195 &self,
2196 keys: &[&[u8]],
2197 _predicates: &[laminar_core::lookup::predicate::Predicate],
2198 _projection: &[laminar_core::lookup::source::ColumnId],
2199 ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
2200 Ok(keys
2201 .iter()
2202 .map(|_| Some(customer_row(99, "FromSource")))
2203 .collect())
2204 }
2205
2206 fn schema(&self) -> SchemaRef {
2207 customers_schema()
2208 }
2209 }
2210
2211 let cache = make_foyer_cache();
2212 warm_cache(&cache);
2214
2215 let orders = RecordBatch::try_new(
2216 orders_schema(),
2217 vec![
2218 Arc::new(Int64Array::from(vec![200])),
2219 Arc::new(Int64Array::from(vec![99])), Arc::new(Float64Array::from(vec![50.0])),
2221 ],
2222 )
2223 .unwrap();
2224
2225 let input = batch_exec(orders);
2226 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2227 let source: Arc<dyn LookupSourceDyn> = Arc::new(TestSource);
2228
2229 let exec = PartialLookupJoinExec::try_new_with_source(
2230 input,
2231 cache,
2232 vec![1],
2233 key_sort_fields,
2234 LookupJoinType::Inner,
2235 customers_schema(),
2236 output_schema(),
2237 Some(source),
2238 Arc::new(Semaphore::new(64)),
2239 )
2240 .unwrap();
2241
2242 let ctx = Arc::new(TaskContext::default());
2243 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2244 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2245 assert_eq!(total, 1, "source fallback should produce 1 row");
2246
2247 let names = batches[0]
2248 .column(4)
2249 .as_any()
2250 .downcast_ref::<StringArray>()
2251 .unwrap();
2252 assert_eq!(names.value(0), "FromSource");
2253 }
2254
    /// A failing source degrades gracefully: left outer serves cache-only
    /// results (all NULL here) instead of erroring the query.
    #[tokio::test]
    async fn partial_source_error_graceful_degradation() {
        use laminar_core::lookup::source::LookupError;
        use laminar_core::lookup::source::LookupSourceDyn;

        // Source that always fails, simulating a remote outage.
        struct FailingSource;

        #[async_trait]
        impl LookupSourceDyn for FailingSource {
            async fn query_batch(
                &self,
                _keys: &[&[u8]],
                _predicates: &[laminar_core::lookup::predicate::Predicate],
                _projection: &[laminar_core::lookup::source::ColumnId],
            ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
                Err(LookupError::Internal("source unavailable".into()))
            }

            fn schema(&self) -> SchemaRef {
                customers_schema()
            }
        }

        // Cold cache: every key misses and hits the failing source.
        let cache = make_foyer_cache();
        let input = batch_exec(orders_batch());
        let key_sort_fields = vec![SortField::new(DataType::Int64)];
        let source: Arc<dyn LookupSourceDyn> = Arc::new(FailingSource);

        let exec = PartialLookupJoinExec::try_new_with_source(
            input,
            cache,
            vec![1],
            key_sort_fields,
            LookupJoinType::LeftOuter,
            customers_schema(),
            output_schema(),
            Some(source),
            Arc::new(Semaphore::new(64)),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 4);
    }
2302
    /// Snapshot registrations are visible through both get_entry() and get().
    #[test]
    fn registry_snapshot_entry_via_get_entry() {
        let reg = LookupTableRegistry::new();
        reg.register(
            "t",
            LookupSnapshot {
                batch: customers_batch(),
                key_columns: vec!["id".into()],
            },
        );

        let entry = reg.get_entry("t");
        assert!(matches!(entry.unwrap(), RegisteredLookup::Snapshot(_)));
        assert!(reg.get("t").is_some());
    }
2318
2319 fn nullable_orders_schema() -> SchemaRef {
2322 Arc::new(Schema::new(vec![
2323 Field::new("order_id", DataType::Int64, false),
2324 Field::new("customer_id", DataType::Int64, true), Field::new("amount", DataType::Float64, false),
2326 ]))
2327 }
2328
    /// Output schema for null-key tests; lookup columns become nullable
    /// under LeftOuter because unmatched rows get NULL fills.
    fn nullable_output_schema(join_type: LookupJoinType) -> SchemaRef {
        let lookup_nullable = join_type == LookupJoinType::LeftOuter;
        Arc::new(Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, true),
            Field::new("amount", DataType::Float64, false),
            Field::new("id", DataType::Int64, lookup_nullable),
            Field::new("name", DataType::Utf8, true),
        ]))
    }
2339
    /// SQL semantics: a NULL join key never matches under inner join.
    #[tokio::test]
    async fn null_key_inner_join_no_match() {
        let stream_batch = RecordBatch::try_new(
            nullable_orders_schema(),
            vec![
                Arc::new(Int64Array::from(vec![100, 101, 102])),
                Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )
        .unwrap();

        let input = batch_exec(stream_batch);
        let exec = LookupJoinExec::try_new(
            input,
            customers_batch(),
            vec![1],
            vec![0],
            LookupJoinType::Inner,
            nullable_output_schema(LookupJoinType::Inner),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let stream = exec.execute(0, ctx).unwrap();
        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();

        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 2, "NULL key row should not match in inner join");
    }
2372
2373 #[tokio::test]
2374 async fn null_key_left_outer_produces_nulls() {
2375 let stream_batch = RecordBatch::try_new(
2377 nullable_orders_schema(),
2378 vec![
2379 Arc::new(Int64Array::from(vec![100, 101, 102])),
2380 Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
2381 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
2382 ],
2383 )
2384 .unwrap();
2385
2386 let input = batch_exec(stream_batch);
2387 let out_schema = nullable_output_schema(LookupJoinType::LeftOuter);
2388 let exec = LookupJoinExec::try_new(
2389 input,
2390 customers_batch(),
2391 vec![1],
2392 vec![0],
2393 LookupJoinType::LeftOuter,
2394 out_schema,
2395 )
2396 .unwrap();
2397
2398 let ctx = Arc::new(TaskContext::default());
2399 let stream = exec.execute(0, ctx).unwrap();
2400 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2401
2402 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2403 assert_eq!(total, 3, "all rows preserved in left outer");
2405
2406 let names = batches[0]
2407 .column(4)
2408 .as_any()
2409 .downcast_ref::<StringArray>()
2410 .unwrap();
2411 assert_eq!(names.value(0), "Alice");
2412 assert!(
2413 names.is_null(1),
2414 "NULL key row should have null lookup name"
2415 );
2416 assert_eq!(names.value(2), "Bob");
2417 }
2418
2419 fn versioned_table_batch() -> RecordBatch {
2422 let schema = Arc::new(Schema::new(vec![
2425 Field::new("currency", DataType::Utf8, false),
2426 Field::new("valid_from", DataType::Int64, false),
2427 Field::new("rate", DataType::Float64, false),
2428 ]));
2429 RecordBatch::try_new(
2430 schema,
2431 vec![
2432 Arc::new(StringArray::from(vec!["USD", "USD", "EUR", "EUR", "EUR"])),
2433 Arc::new(Int64Array::from(vec![100, 200, 100, 150, 300])),
2434 Arc::new(Float64Array::from(vec![1.0, 1.1, 0.85, 0.90, 0.88])),
2435 ],
2436 )
2437 .unwrap()
2438 }
2439
2440 fn stream_batch_with_time() -> RecordBatch {
2441 let schema = Arc::new(Schema::new(vec![
2442 Field::new("order_id", DataType::Int64, false),
2443 Field::new("currency", DataType::Utf8, false),
2444 Field::new("event_ts", DataType::Int64, false),
2445 ]));
2446 RecordBatch::try_new(
2447 schema,
2448 vec![
2449 Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
2450 Arc::new(StringArray::from(vec!["USD", "EUR", "USD", "EUR"])),
2451 Arc::new(Int64Array::from(vec![150, 160, 250, 50])),
2452 ],
2453 )
2454 .unwrap()
2455 }
2456
2457 #[test]
2458 fn test_versioned_index_build_and_probe() {
2459 let batch = versioned_table_batch();
2460 let index = VersionedIndex::build(&batch, &[0], 1, usize::MAX).unwrap();
2461
2462 let key_sf = vec![SortField::new(DataType::Utf8)];
2465 let converter = RowConverter::new(key_sf).unwrap();
2466 let usd_col = Arc::new(StringArray::from(vec!["USD"]));
2467 let usd_rows = converter.convert_columns(&[usd_col]).unwrap();
2468 let usd_key = usd_rows.row(0);
2469
2470 let result = index.probe_at_time(usd_key.as_ref(), 150);
2471 assert!(result.is_some());
2472 assert_eq!(result.unwrap(), 0);
2474
2475 let result = index.probe_at_time(usd_key.as_ref(), 250);
2477 assert_eq!(result.unwrap(), 1);
2478 }
2479
2480 #[test]
2481 fn test_versioned_index_no_version_before_ts() {
2482 let batch = versioned_table_batch();
2483 let index = VersionedIndex::build(&batch, &[0], 1, usize::MAX).unwrap();
2484
2485 let key_sf = vec![SortField::new(DataType::Utf8)];
2486 let converter = RowConverter::new(key_sf).unwrap();
2487 let eur_col = Arc::new(StringArray::from(vec!["EUR"]));
2488 let eur_rows = converter.convert_columns(&[eur_col]).unwrap();
2489 let eur_key = eur_rows.row(0);
2490
2491 let result = index.probe_at_time(eur_key.as_ref(), 50);
2493 assert!(result.is_none());
2494 }
2495
2496 fn build_versioned_exec(
2498 table: RecordBatch,
2499 stream: &RecordBatch,
2500 join_type: LookupJoinType,
2501 ) -> VersionedLookupJoinExec {
2502 let input = batch_exec(stream.clone());
2503 let index = Arc::new(VersionedIndex::build(&table, &[0], 1, usize::MAX).unwrap());
2504 let key_sort_fields = vec![SortField::new(DataType::Utf8)];
2505 let mut output_fields = stream.schema().fields().to_vec();
2506 output_fields.extend(table.schema().fields().iter().cloned());
2507 let output_schema = Arc::new(Schema::new(output_fields));
2508 VersionedLookupJoinExec::try_new(
2509 input,
2510 table,
2511 index,
2512 vec![1], 2, join_type,
2515 output_schema,
2516 key_sort_fields,
2517 )
2518 .unwrap()
2519 }
2520
2521 #[tokio::test]
2522 async fn test_versioned_join_exec_inner() {
2523 let table = versioned_table_batch();
2524 let stream = stream_batch_with_time();
2525 let exec = build_versioned_exec(table, &stream, LookupJoinType::Inner);
2526
2527 let ctx = Arc::new(TaskContext::default());
2528 let stream_out = exec.execute(0, ctx).unwrap();
2529 let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2530
2531 assert_eq!(batches.len(), 1);
2532 let batch = &batches[0];
2533 assert_eq!(batch.num_rows(), 3);
2538
2539 let rates = batch
2540 .column(5) .as_any()
2542 .downcast_ref::<Float64Array>()
2543 .unwrap();
2544 assert!((rates.value(0) - 1.0).abs() < f64::EPSILON); assert!((rates.value(1) - 0.90).abs() < f64::EPSILON); assert!((rates.value(2) - 1.1).abs() < f64::EPSILON); }
2548
2549 #[tokio::test]
2550 async fn test_versioned_join_exec_left_outer() {
2551 let table = versioned_table_batch();
2552 let stream = stream_batch_with_time();
2553 let exec = build_versioned_exec(table, &stream, LookupJoinType::LeftOuter);
2554
2555 let ctx = Arc::new(TaskContext::default());
2556 let stream_out = exec.execute(0, ctx).unwrap();
2557 let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2558
2559 assert_eq!(batches.len(), 1);
2560 let batch = &batches[0];
2561 assert_eq!(batch.num_rows(), 4);
2563
2564 let rates = batch
2566 .column(5)
2567 .as_any()
2568 .downcast_ref::<Float64Array>()
2569 .unwrap();
2570 assert!(rates.is_null(3), "EUR@50 should have null rate");
2571 }
2572
2573 #[test]
2574 fn test_versioned_index_empty_batch() {
2575 let schema = Arc::new(Schema::new(vec![
2576 Field::new("k", DataType::Utf8, false),
2577 Field::new("v", DataType::Int64, false),
2578 ]));
2579 let batch = RecordBatch::new_empty(schema);
2580 let index = VersionedIndex::build(&batch, &[0], 1, usize::MAX).unwrap();
2581 assert!(index.map.is_empty());
2582 }
2583
2584 #[test]
2585 fn test_versioned_lookup_registry() {
2586 let registry = LookupTableRegistry::new();
2587 let table = versioned_table_batch();
2588 let index = Arc::new(VersionedIndex::build(&table, &[0], 1, usize::MAX).unwrap());
2589
2590 registry.register_versioned(
2591 "rates",
2592 VersionedLookupState {
2593 batch: table,
2594 index,
2595 key_columns: vec!["currency".to_string()],
2596 version_column: "valid_from".to_string(),
2597 stream_time_column: "event_ts".to_string(),
2598 max_versions_per_key: usize::MAX,
2599 },
2600 );
2601
2602 let entry = registry.get_entry("rates");
2603 assert!(entry.is_some());
2604 assert!(matches!(entry.unwrap(), RegisteredLookup::Versioned(_)));
2605
2606 assert!(registry.get("rates").is_none());
2608 }
2609}