1use std::any::Any;
17use std::collections::HashMap;
18use std::fmt::{self, Debug, Formatter};
19use std::sync::Arc;
20
21use parking_lot::RwLock;
22
23use std::collections::BTreeMap;
24
25use arrow::compute::take;
26use arrow::row::{RowConverter, SortField};
27use arrow_array::{RecordBatch, UInt32Array};
28use arrow_schema::{Schema, SchemaRef};
29use async_trait::async_trait;
30use datafusion::execution::{SendableRecordBatchStream, SessionState, TaskContext};
31use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode};
32use datafusion::physical_expr::{EquivalenceProperties, LexOrdering, Partitioning};
33use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
34use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
35use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
36use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
37use datafusion_common::{DataFusionError, Result};
38use datafusion_expr::Expr;
39use futures::StreamExt;
40use laminar_core::lookup::foyer_cache::FoyerMemoryCache;
41use laminar_core::lookup::source::LookupSourceDyn;
42use laminar_core::lookup::table::LookupTable;
43use tokio::sync::Semaphore;
44
45use super::lookup_join::{LookupJoinNode, LookupJoinType};
46
/// Thread-safe registry mapping lowercase table names to registered lookup
/// tables (snapshot, partial, or versioned).
#[derive(Default)]
pub struct LookupTableRegistry {
    // RwLock lets many readers resolve tables concurrently while
    // register/unregister briefly take the write lock.
    tables: RwLock<HashMap<String, RegisteredLookup>>,
}
57
/// A lookup-table entry held by [`LookupTableRegistry`]. Each variant wraps
/// its state in an `Arc` so `get`/`get_entry` can hand out cheap clones.
pub enum RegisteredLookup {
    /// Fully materialized, immutable point-in-time table.
    Snapshot(Arc<LookupSnapshot>),
    /// Cache-backed table that may fetch missing keys from a remote source.
    Partial(Arc<PartialLookupState>),
    /// Table with per-key version history for temporal (point-in-time) joins.
    Versioned(Arc<VersionedLookupState>),
}
68
/// Immutable snapshot of a lookup table: all rows in one batch plus the names
/// of its join-key columns.
pub struct LookupSnapshot {
    /// All rows of the lookup table.
    pub batch: RecordBatch,
    /// Names of the columns that form the join key.
    pub key_columns: Vec<String>,
}
76
/// State for a versioned (temporal) lookup table: the full history batch plus
/// an index from encoded key to version timestamp to row indices.
pub struct VersionedLookupState {
    /// All versions of all rows.
    pub batch: RecordBatch,
    /// Index built over `batch` for time-travel probes.
    pub index: Arc<VersionedIndex>,
    /// Names of the join-key columns.
    pub key_columns: Vec<String>,
    /// Column of `batch` holding each row's version timestamp.
    pub version_column: String,
    /// Stream-side column whose event time selects a version
    /// (NOTE(review): semantics inferred from the name — confirm with callers).
    pub stream_time_column: String,
}
95
/// State for a partial lookup table backed by an in-memory cache, with an
/// optional source used to fetch keys missing from the cache.
pub struct PartialLookupState {
    /// Cache of previously fetched lookup rows.
    pub foyer_cache: Arc<FoyerMemoryCache>,
    /// Schema of the lookup-side rows.
    pub schema: SchemaRef,
    /// Names of the join-key columns.
    pub key_columns: Vec<String>,
    /// Sort fields used to row-encode join keys consistently with the cache.
    pub key_sort_fields: Vec<SortField>,
    /// Optional remote source queried on cache misses.
    pub source: Option<Arc<dyn LookupSourceDyn>>,
    /// Bounds the number of concurrent source fetches.
    pub fetch_semaphore: Arc<Semaphore>,
}
111
112impl LookupTableRegistry {
113 #[must_use]
115 pub fn new() -> Self {
116 Self::default()
117 }
118
119 pub fn register(&self, name: &str, snapshot: LookupSnapshot) {
125 self.tables.write().insert(
126 name.to_lowercase(),
127 RegisteredLookup::Snapshot(Arc::new(snapshot)),
128 );
129 }
130
131 pub fn register_partial(&self, name: &str, state: PartialLookupState) {
137 self.tables.write().insert(
138 name.to_lowercase(),
139 RegisteredLookup::Partial(Arc::new(state)),
140 );
141 }
142
143 pub fn register_versioned(&self, name: &str, state: VersionedLookupState) {
149 self.tables.write().insert(
150 name.to_lowercase(),
151 RegisteredLookup::Versioned(Arc::new(state)),
152 );
153 }
154
155 pub fn unregister(&self, name: &str) {
161 self.tables.write().remove(&name.to_lowercase());
162 }
163
164 #[must_use]
170 pub fn get(&self, name: &str) -> Option<Arc<LookupSnapshot>> {
171 let tables = self.tables.read();
172 match tables.get(&name.to_lowercase())? {
173 RegisteredLookup::Snapshot(s) => Some(Arc::clone(s)),
174 RegisteredLookup::Partial(_) | RegisteredLookup::Versioned(_) => None,
175 }
176 }
177
178 pub fn get_entry(&self, name: &str) -> Option<RegisteredLookup> {
184 let tables = self.tables.read();
185 tables.get(&name.to_lowercase()).map(|e| match e {
186 RegisteredLookup::Snapshot(s) => RegisteredLookup::Snapshot(Arc::clone(s)),
187 RegisteredLookup::Partial(p) => RegisteredLookup::Partial(Arc::clone(p)),
188 RegisteredLookup::Versioned(v) => RegisteredLookup::Versioned(Arc::clone(v)),
189 })
190 }
191}
192
/// Hash index from row-encoded join-key bytes to the matching row indices of
/// the indexed lookup batch.
struct HashIndex {
    // Vec values because one key may match several lookup rows.
    map: HashMap<Box<[u8]>, Vec<u32>>,
}
199
200impl HashIndex {
201 fn build(batch: &RecordBatch, key_indices: &[usize]) -> Result<Self> {
206 if batch.num_rows() == 0 {
207 return Ok(Self {
208 map: HashMap::new(),
209 });
210 }
211
212 let sort_fields: Vec<SortField> = key_indices
213 .iter()
214 .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
215 .collect();
216 let converter = RowConverter::new(sort_fields)?;
217
218 let key_cols: Vec<_> = key_indices
219 .iter()
220 .map(|&i| batch.column(i).clone())
221 .collect();
222 let rows = converter.convert_columns(&key_cols)?;
223
224 let num_rows = batch.num_rows();
225 let mut map: HashMap<Box<[u8]>, Vec<u32>> = HashMap::with_capacity(num_rows);
226 #[allow(clippy::cast_possible_truncation)] for i in 0..num_rows {
228 map.entry(Box::from(rows.row(i).as_ref()))
229 .or_default()
230 .push(i as u32);
231 }
232
233 Ok(Self { map })
234 }
235
236 fn probe(&self, key: &[u8]) -> Option<&[u32]> {
237 self.map.get(key).map(Vec::as_slice)
238 }
239}
240
/// Index for temporal joins: encoded key bytes -> ordered map from version
/// timestamp to the row indices carrying that version.
#[derive(Default)]
pub struct VersionedIndex {
    // BTreeMap keeps versions sorted so `probe_at_time` can range-scan for
    // the latest version at or before an event timestamp.
    map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>>,
}
250
251impl VersionedIndex {
252 pub fn build(
262 batch: &RecordBatch,
263 key_indices: &[usize],
264 version_col_idx: usize,
265 ) -> Result<Self> {
266 if batch.num_rows() == 0 {
267 return Ok(Self {
268 map: HashMap::new(),
269 });
270 }
271
272 let sort_fields: Vec<SortField> = key_indices
273 .iter()
274 .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
275 .collect();
276 let converter = RowConverter::new(sort_fields)?;
277
278 let key_cols: Vec<_> = key_indices
279 .iter()
280 .map(|&i| batch.column(i).clone())
281 .collect();
282 let rows = converter.convert_columns(&key_cols)?;
283
284 let timestamps = extract_i64_timestamps(batch.column(version_col_idx))?;
285
286 let num_rows = batch.num_rows();
287 let mut map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>> = HashMap::with_capacity(num_rows);
288 #[allow(clippy::cast_possible_truncation)]
289 for (i, ts_opt) in timestamps.iter().enumerate() {
290 let Some(version_ts) = ts_opt else { continue };
292 if key_cols.iter().any(|c| c.is_null(i)) {
293 continue;
294 }
295 let key = Box::from(rows.row(i).as_ref());
296 map.entry(key)
297 .or_default()
298 .entry(*version_ts)
299 .or_default()
300 .push(i as u32);
301 }
302
303 Ok(Self { map })
304 }
305
306 fn probe_at_time(&self, key: &[u8], event_ts: i64) -> Option<u32> {
309 let versions = self.map.get(key)?;
310 let (_, indices) = versions.range(..=event_ts).next_back()?;
311 indices.last().copied()
312 }
313}
314
/// Extracts a timestamp-like column into per-row `Option<i64>` values,
/// rescaling Arrow `Timestamp` units to milliseconds (nulls become `None`).
///
/// Raw `Int64` and `Float64` values are passed through unscaled — they are
/// assumed to already be in the target unit (NOTE(review): confirm producers
/// supply milliseconds).
///
/// # Errors
/// Returns `DataFusionError::Plan` for unsupported column types, or an
/// internal error if a downcast fails after the type matched.
fn extract_i64_timestamps(col: &dyn arrow_array::Array) -> Result<Vec<Option<i64>>> {
    use arrow_array::{
        Float64Array, Int64Array, TimestampMicrosecondArray, TimestampMillisecondArray,
        TimestampNanosecondArray, TimestampSecondArray,
    };
    use arrow_schema::{DataType, TimeUnit};

    let n = col.len();
    let mut out = Vec::with_capacity(n);
    // Downcasts `col` to the concrete array type and pushes each value
    // multiplied by `$scale`; division cases below cannot use this macro.
    macro_rules! extract_typed {
        ($arr_type:ty, $scale:expr) => {{
            let arr = col.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
                DataFusionError::Internal(concat!("expected ", stringify!($arr_type)).into())
            })?;
            for i in 0..n {
                out.push(if col.is_null(i) {
                    None
                } else {
                    Some(arr.value(i) * $scale)
                });
            }
        }};
    }

    match col.data_type() {
        // Passed through unscaled (assumed pre-scaled by the producer).
        DataType::Int64 => extract_typed!(Int64Array, 1),
        // Already milliseconds.
        DataType::Timestamp(TimeUnit::Millisecond, _) => {
            extract_typed!(TimestampMillisecondArray, 1);
        }
        // Microseconds -> milliseconds (requires division, so no macro).
        DataType::Timestamp(TimeUnit::Microsecond, _) => {
            let arr = col
                .as_any()
                .downcast_ref::<TimestampMicrosecondArray>()
                .ok_or_else(|| {
                    DataFusionError::Internal("expected TimestampMicrosecondArray".into())
                })?;
            for i in 0..n {
                out.push(if col.is_null(i) {
                    None
                } else {
                    Some(arr.value(i) / 1000)
                });
            }
        }
        // Seconds -> milliseconds.
        DataType::Timestamp(TimeUnit::Second, _) => {
            extract_typed!(TimestampSecondArray, 1000);
        }
        // Nanoseconds -> milliseconds (requires division, so no macro).
        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
            let arr = col
                .as_any()
                .downcast_ref::<TimestampNanosecondArray>()
                .ok_or_else(|| {
                    DataFusionError::Internal("expected TimestampNanosecondArray".into())
                })?;
            for i in 0..n {
                out.push(if col.is_null(i) {
                    None
                } else {
                    Some(arr.value(i) / 1_000_000)
                });
            }
        }
        // Fractional values truncate toward zero via the `as i64` cast.
        DataType::Float64 => {
            let arr = col
                .as_any()
                .downcast_ref::<Float64Array>()
                .ok_or_else(|| DataFusionError::Internal("expected Float64Array".into()))?;
            #[allow(clippy::cast_possible_truncation)]
            for i in 0..n {
                out.push(if col.is_null(i) {
                    None
                } else {
                    Some(arr.value(i) as i64)
                });
            }
        }
        other => {
            return Err(DataFusionError::Plan(format!(
                "unsupported timestamp type for temporal join: {other:?}"
            )));
        }
    }

    Ok(out)
}
405
/// Physical operator that hash-joins a streaming input against a fully
/// materialized in-memory lookup batch.
pub struct LookupJoinExec {
    /// The streaming (probe-side) child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Hash index over the lookup batch's key columns.
    index: Arc<HashIndex>,
    /// The materialized lookup (build-side) rows.
    lookup_batch: Arc<RecordBatch>,
    /// Indices of the join-key columns in the stream schema.
    stream_key_indices: Vec<usize>,
    /// Inner or left-outer join semantics.
    join_type: LookupJoinType,
    /// Output schema (lookup fields forced nullable for left-outer).
    schema: SchemaRef,
    /// Cached plan properties (partitioning, emission, boundedness).
    properties: PlanProperties,
    /// Sort fields used to row-encode stream keys like the build side.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of the output schema.
    stream_field_count: usize,
}
422
423impl LookupJoinExec {
424 #[allow(clippy::needless_pass_by_value)] pub fn try_new(
434 input: Arc<dyn ExecutionPlan>,
435 lookup_batch: RecordBatch,
436 stream_key_indices: Vec<usize>,
437 lookup_key_indices: Vec<usize>,
438 join_type: LookupJoinType,
439 output_schema: SchemaRef,
440 ) -> Result<Self> {
441 let index = HashIndex::build(&lookup_batch, &lookup_key_indices)?;
442
443 let key_sort_fields: Vec<SortField> = lookup_key_indices
444 .iter()
445 .map(|&i| SortField::new(lookup_batch.schema().field(i).data_type().clone()))
446 .collect();
447
448 let output_schema = if join_type == LookupJoinType::LeftOuter {
451 let stream_count = input.schema().fields().len();
452 let mut fields = output_schema.fields().to_vec();
453 for f in &mut fields[stream_count..] {
454 if !f.is_nullable() {
455 *f = Arc::new(f.as_ref().clone().with_nullable(true));
456 }
457 }
458 Arc::new(Schema::new_with_metadata(
459 fields,
460 output_schema.metadata().clone(),
461 ))
462 } else {
463 output_schema
464 };
465
466 let properties = PlanProperties::new(
467 EquivalenceProperties::new(Arc::clone(&output_schema)),
468 Partitioning::UnknownPartitioning(1),
469 EmissionType::Incremental,
470 Boundedness::Unbounded {
471 requires_infinite_memory: false,
472 },
473 );
474
475 let stream_field_count = input.schema().fields().len();
476
477 Ok(Self {
478 input,
479 index: Arc::new(index),
480 lookup_batch: Arc::new(lookup_batch),
481 stream_key_indices,
482 join_type,
483 schema: output_schema,
484 properties,
485 key_sort_fields,
486 stream_field_count,
487 })
488 }
489}
490
impl Debug for LookupJoinExec {
    // Compact representation: join type, stream key indices, and lookup row
    // count; omits index and batch contents (finish_non_exhaustive).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("LookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("lookup_rows", &self.lookup_batch.num_rows())
            .finish_non_exhaustive()
    }
}
500
impl DisplayAs for LookupJoinExec {
    // EXPLAIN output: one descriptive line for Default/Verbose modes, just
    // the node name for tree rendering.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "LookupJoinExec: type={}, stream_keys={:?}, lookup_rows={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.lookup_batch.num_rows(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "LookupJoinExec"),
        }
    }
}
517
impl ExecutionPlan for LookupJoinExec {
    fn name(&self) -> &'static str {
        "LookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a replacement child; all join state (index,
    // batch, schema, properties) is shared via cheap Arc clones.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "LookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            lookup_batch: Arc::clone(&self.lookup_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    // Streams the child's batches, probing each one against the shared hash
    // index as it arrives. Empty input batches short-circuit to an empty
    // output batch with the join's output schema.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // An equivalent RowConverter is rebuilt from the stored sort fields
        // so stream keys encode identically to the build side.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let lookup_batch = Arc::clone(&self.lookup_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_batch(
                &batch,
                &converter,
                &index,
                &lookup_batch,
                &stream_key_indices,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
598
impl datafusion::physical_plan::ExecutionPlanProperties for LookupJoinExec {
    // Delegates to the cached PlanProperties built in try_new.
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    // Restates the constants baked into PlanProperties in try_new: the join
    // streams over an unbounded input without buffering it.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
622
623#[allow(clippy::too_many_arguments)]
628fn probe_batch(
629 stream_batch: &RecordBatch,
630 converter: &RowConverter,
631 index: &HashIndex,
632 lookup_batch: &RecordBatch,
633 stream_key_indices: &[usize],
634 join_type: LookupJoinType,
635 output_schema: &SchemaRef,
636 stream_field_count: usize,
637) -> Result<RecordBatch> {
638 let key_cols: Vec<_> = stream_key_indices
639 .iter()
640 .map(|&i| stream_batch.column(i).clone())
641 .collect();
642 let rows = converter.convert_columns(&key_cols)?;
643
644 let num_rows = stream_batch.num_rows();
645 let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
646 let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);
647
648 #[allow(clippy::cast_possible_truncation)] for row in 0..num_rows {
650 if key_cols.iter().any(|c| c.is_null(row)) {
652 if join_type == LookupJoinType::LeftOuter {
653 stream_indices.push(row as u32);
654 lookup_indices.push(None);
655 }
656 continue;
657 }
658
659 let key = rows.row(row);
660 match index.probe(key.as_ref()) {
661 Some(matches) => {
662 for &lookup_row in matches {
663 stream_indices.push(row as u32);
664 lookup_indices.push(Some(lookup_row));
665 }
666 }
667 None if join_type == LookupJoinType::LeftOuter => {
668 stream_indices.push(row as u32);
669 lookup_indices.push(None);
670 }
671 None => {}
672 }
673 }
674
675 if stream_indices.is_empty() {
676 return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
677 }
678
679 let take_stream = UInt32Array::from(stream_indices);
681 let mut columns = Vec::with_capacity(output_schema.fields().len());
682
683 for col in stream_batch.columns() {
684 columns.push(take(col.as_ref(), &take_stream, None)?);
685 }
686
687 let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
689 for col in lookup_batch.columns() {
690 columns.push(take(col.as_ref(), &take_lookup, None)?);
691 }
692
693 debug_assert_eq!(
694 columns.len(),
695 stream_field_count + lookup_batch.num_columns(),
696 "output column count mismatch"
697 );
698
699 Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
700}
701
/// Physical operator for temporal joins: each stream row is matched to the
/// lookup row version valid at the row's event time.
pub struct VersionedLookupJoinExec {
    /// The streaming (probe-side) child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Prebuilt key -> version -> rows index over `table_batch`.
    index: Arc<VersionedIndex>,
    /// All versions of all lookup rows.
    table_batch: Arc<RecordBatch>,
    /// Indices of the join-key columns in the stream schema.
    stream_key_indices: Vec<usize>,
    /// Index of the stream column carrying the event time.
    stream_time_col_idx: usize,
    /// Inner or left-outer join semantics.
    join_type: LookupJoinType,
    /// Output schema (lookup fields forced nullable for left-outer).
    schema: SchemaRef,
    /// Cached plan properties.
    properties: PlanProperties,
    /// Sort fields used to row-encode stream keys like the index.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of the output schema.
    stream_field_count: usize,
}
719
720impl VersionedLookupJoinExec {
721 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
731 pub fn try_new(
732 input: Arc<dyn ExecutionPlan>,
733 table_batch: RecordBatch,
734 index: Arc<VersionedIndex>,
735 stream_key_indices: Vec<usize>,
736 stream_time_col_idx: usize,
737 join_type: LookupJoinType,
738 output_schema: SchemaRef,
739 key_sort_fields: Vec<SortField>,
740 ) -> Result<Self> {
741 let output_schema = if join_type == LookupJoinType::LeftOuter {
742 let stream_count = input.schema().fields().len();
743 let mut fields = output_schema.fields().to_vec();
744 for f in &mut fields[stream_count..] {
745 if !f.is_nullable() {
746 *f = Arc::new(f.as_ref().clone().with_nullable(true));
747 }
748 }
749 Arc::new(Schema::new_with_metadata(
750 fields,
751 output_schema.metadata().clone(),
752 ))
753 } else {
754 output_schema
755 };
756
757 let properties = PlanProperties::new(
758 EquivalenceProperties::new(Arc::clone(&output_schema)),
759 Partitioning::UnknownPartitioning(1),
760 EmissionType::Incremental,
761 Boundedness::Unbounded {
762 requires_infinite_memory: false,
763 },
764 );
765
766 let stream_field_count = input.schema().fields().len();
767
768 Ok(Self {
769 input,
770 index,
771 table_batch: Arc::new(table_batch),
772 stream_key_indices,
773 stream_time_col_idx,
774 join_type,
775 schema: output_schema,
776 properties,
777 key_sort_fields,
778 stream_field_count,
779 })
780 }
781}
782
impl Debug for VersionedLookupJoinExec {
    // Compact representation: join type, stream key indices, and table row
    // count; omits index and batch contents (finish_non_exhaustive).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("VersionedLookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("table_rows", &self.table_batch.num_rows())
            .finish_non_exhaustive()
    }
}
792
impl DisplayAs for VersionedLookupJoinExec {
    // EXPLAIN output: one descriptive line for Default/Verbose modes, just
    // the node name for tree rendering.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "VersionedLookupJoinExec: type={}, stream_keys={:?}, table_rows={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.table_batch.num_rows(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "VersionedLookupJoinExec"),
        }
    }
}
809
impl ExecutionPlan for VersionedLookupJoinExec {
    fn name(&self) -> &'static str {
        "VersionedLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a replacement child; all join state is shared
    // via cheap Arc clones.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "VersionedLookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            table_batch: Arc::clone(&self.table_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            stream_time_col_idx: self.stream_time_col_idx,
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    // Streams the child's batches, resolving each row's version at its event
    // time against the shared index. Empty input batches short-circuit to an
    // empty output batch.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // An equivalent RowConverter is rebuilt from the stored sort fields
        // so stream keys encode identically to the index's keys.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let table_batch = Arc::clone(&self.table_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let stream_time_col_idx = self.stream_time_col_idx;
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_versioned_batch(
                &batch,
                &converter,
                &index,
                &table_batch,
                &stream_key_indices,
                stream_time_col_idx,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
893
impl datafusion::physical_plan::ExecutionPlanProperties for VersionedLookupJoinExec {
    // Delegates to the cached PlanProperties built in try_new.
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    // Restates the constants baked into PlanProperties in try_new.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
917
918#[allow(clippy::too_many_arguments)]
921fn probe_versioned_batch(
922 stream_batch: &RecordBatch,
923 converter: &RowConverter,
924 index: &VersionedIndex,
925 table_batch: &RecordBatch,
926 stream_key_indices: &[usize],
927 stream_time_col_idx: usize,
928 join_type: LookupJoinType,
929 output_schema: &SchemaRef,
930 stream_field_count: usize,
931) -> Result<RecordBatch> {
932 let key_cols: Vec<_> = stream_key_indices
933 .iter()
934 .map(|&i| stream_batch.column(i).clone())
935 .collect();
936 let rows = converter.convert_columns(&key_cols)?;
937 let event_timestamps =
938 extract_i64_timestamps(stream_batch.column(stream_time_col_idx).as_ref())?;
939
940 let num_rows = stream_batch.num_rows();
941 let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
942 let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);
943
944 #[allow(clippy::cast_possible_truncation)]
945 for (row, event_ts_opt) in event_timestamps.iter().enumerate() {
946 if key_cols.iter().any(|c| c.is_null(row)) || event_ts_opt.is_none() {
948 if join_type == LookupJoinType::LeftOuter {
949 stream_indices.push(row as u32);
950 lookup_indices.push(None);
951 }
952 continue;
953 }
954
955 let key = rows.row(row);
956 let event_ts = event_ts_opt.unwrap();
957 match index.probe_at_time(key.as_ref(), event_ts) {
958 Some(table_row_idx) => {
959 stream_indices.push(row as u32);
960 lookup_indices.push(Some(table_row_idx));
961 }
962 None if join_type == LookupJoinType::LeftOuter => {
963 stream_indices.push(row as u32);
964 lookup_indices.push(None);
965 }
966 None => {}
967 }
968 }
969
970 if stream_indices.is_empty() {
971 return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
972 }
973
974 let take_stream = UInt32Array::from(stream_indices);
975 let mut columns = Vec::with_capacity(output_schema.fields().len());
976
977 for col in stream_batch.columns() {
978 columns.push(take(col.as_ref(), &take_stream, None)?);
979 }
980
981 let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
982 for col in table_batch.columns() {
983 columns.push(take(col.as_ref(), &take_lookup, None)?);
984 }
985
986 debug_assert_eq!(
987 columns.len(),
988 stream_field_count + table_batch.num_columns(),
989 "output column count mismatch"
990 );
991
992 Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
993}
994
/// Physical operator joining a stream against a cache-backed partial lookup
/// table, optionally fetching cache misses from a remote source.
pub struct PartialLookupJoinExec {
    /// The streaming (probe-side) child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Cache of lookup rows keyed by encoded join key.
    foyer_cache: Arc<FoyerMemoryCache>,
    /// Indices of the join-key columns in the stream schema.
    stream_key_indices: Vec<usize>,
    /// Inner or left-outer join semantics.
    join_type: LookupJoinType,
    /// Output schema (lookup fields forced nullable for left-outer).
    schema: SchemaRef,
    /// Cached plan properties.
    properties: PlanProperties,
    /// Sort fields used to row-encode stream keys like the cache's keys.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of the output schema.
    stream_field_count: usize,
    /// Schema of the lookup-side rows.
    lookup_schema: SchemaRef,
    /// Optional remote source queried on cache misses.
    source: Option<Arc<dyn LookupSourceDyn>>,
    /// Bounds concurrent source fetches.
    fetch_semaphore: Arc<Semaphore>,
}
1013
1014impl PartialLookupJoinExec {
1015 pub fn try_new(
1021 input: Arc<dyn ExecutionPlan>,
1022 foyer_cache: Arc<FoyerMemoryCache>,
1023 stream_key_indices: Vec<usize>,
1024 key_sort_fields: Vec<SortField>,
1025 join_type: LookupJoinType,
1026 lookup_schema: SchemaRef,
1027 output_schema: SchemaRef,
1028 ) -> Result<Self> {
1029 Self::try_new_with_source(
1030 input,
1031 foyer_cache,
1032 stream_key_indices,
1033 key_sort_fields,
1034 join_type,
1035 lookup_schema,
1036 output_schema,
1037 None,
1038 Arc::new(Semaphore::new(64)),
1039 )
1040 }
1041
1042 #[allow(clippy::too_many_arguments)]
1048 pub fn try_new_with_source(
1049 input: Arc<dyn ExecutionPlan>,
1050 foyer_cache: Arc<FoyerMemoryCache>,
1051 stream_key_indices: Vec<usize>,
1052 key_sort_fields: Vec<SortField>,
1053 join_type: LookupJoinType,
1054 lookup_schema: SchemaRef,
1055 output_schema: SchemaRef,
1056 source: Option<Arc<dyn LookupSourceDyn>>,
1057 fetch_semaphore: Arc<Semaphore>,
1058 ) -> Result<Self> {
1059 let output_schema = if join_type == LookupJoinType::LeftOuter {
1060 let stream_count = input.schema().fields().len();
1061 let mut fields = output_schema.fields().to_vec();
1062 for f in &mut fields[stream_count..] {
1063 if !f.is_nullable() {
1064 *f = Arc::new(f.as_ref().clone().with_nullable(true));
1065 }
1066 }
1067 Arc::new(Schema::new_with_metadata(
1068 fields,
1069 output_schema.metadata().clone(),
1070 ))
1071 } else {
1072 output_schema
1073 };
1074
1075 let properties = PlanProperties::new(
1076 EquivalenceProperties::new(Arc::clone(&output_schema)),
1077 Partitioning::UnknownPartitioning(1),
1078 EmissionType::Incremental,
1079 Boundedness::Unbounded {
1080 requires_infinite_memory: false,
1081 },
1082 );
1083
1084 let stream_field_count = input.schema().fields().len();
1085
1086 Ok(Self {
1087 input,
1088 foyer_cache,
1089 stream_key_indices,
1090 join_type,
1091 schema: output_schema,
1092 properties,
1093 key_sort_fields,
1094 stream_field_count,
1095 lookup_schema,
1096 source,
1097 fetch_semaphore,
1098 })
1099 }
1100}
1101
impl Debug for PartialLookupJoinExec {
    // Compact representation: join type, stream key indices, and the cache's
    // table id; omits cache contents (finish_non_exhaustive).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("PartialLookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("cache_table_id", &self.foyer_cache.table_id())
            .finish_non_exhaustive()
    }
}
1111
impl DisplayAs for PartialLookupJoinExec {
    // EXPLAIN output: one descriptive line for Default/Verbose modes, just
    // the node name for tree rendering.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "PartialLookupJoinExec: type={}, stream_keys={:?}, cache_entries={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.foyer_cache.len(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "PartialLookupJoinExec"),
        }
    }
}
1128
impl ExecutionPlan for PartialLookupJoinExec {
    fn name(&self) -> &'static str {
        "PartialLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a replacement child; all join state is shared
    // via cheap Arc clones.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "PartialLookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            foyer_cache: Arc::clone(&self.foyer_cache),
            stream_key_indices: self.stream_key_indices.clone(),
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
            lookup_schema: Arc::clone(&self.lookup_schema),
            source: self.source.clone(),
            fetch_semaphore: Arc::clone(&self.fetch_semaphore),
        }))
    }

    // Processes batches one at a time and in order via `then` — each batch's
    // async probe (which may fetch cache misses from the source) completes
    // before the next batch is taken. Empty batches short-circuit.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // Arc-wrapped so the converter can be cloned into each per-batch
        // async closure below.
        let converter = Arc::new(RowConverter::new(self.key_sort_fields.clone())?);
        let foyer_cache = Arc::clone(&self.foyer_cache);
        let stream_key_indices = self.stream_key_indices.clone();
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;
        let lookup_schema = Arc::clone(&self.lookup_schema);
        let source = self.source.clone();
        let fetch_semaphore = Arc::clone(&self.fetch_semaphore);

        let output = input_stream.then(move |result| {
            // Per-batch clones: the async block must own its captures.
            let foyer_cache = Arc::clone(&foyer_cache);
            let converter = Arc::clone(&converter);
            let stream_key_indices = stream_key_indices.clone();
            let schema = Arc::clone(&schema);
            let lookup_schema = Arc::clone(&lookup_schema);
            let source = source.clone();
            let fetch_semaphore = Arc::clone(&fetch_semaphore);
            async move {
                let batch = result?;
                if batch.num_rows() == 0 {
                    return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
                }
                probe_partial_batch_with_fallback(
                    &batch,
                    &converter,
                    &foyer_cache,
                    &stream_key_indices,
                    join_type,
                    &schema,
                    stream_field_count,
                    &lookup_schema,
                    source.as_deref(),
                    &fetch_semaphore,
                )
                .await
            }
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
1225
impl datafusion::physical_plan::ExecutionPlanProperties for PartialLookupJoinExec {
    // Delegates to the cached PlanProperties built in try_new_with_source.
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    // Restates the constants baked into PlanProperties at construction.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
1249
/// Joins one stream batch against the partial (cached) lookup table, falling
/// back to the remote `source` for cache misses.
///
/// Processing steps, in order:
/// 1. Convert the stream key columns to row-encoded keys via `converter`.
/// 2. Probe `foyer_cache` per row; record hits and remember misses with the
///    position they occupy in the output (`miss_keys`).
/// 3. If a `source` is configured and there are misses, fetch them in one
///    `query_batch` call (bounded by `fetch_semaphore`), insert results into
///    the cache, and patch the remembered output slots. A source error is
///    logged and degraded to cache-only results rather than failing the query.
/// 4. For inner joins, compact away rows that still have no lookup match.
/// 5. Assemble the output batch: stream columns gathered via `take`, lookup
///    columns via per-row `concat` with null placeholders for misses.
///
/// Rows with a NULL join key never match; for `LeftOuter` they are kept with
/// null lookup columns, for `Inner` they are dropped.
///
/// # Errors
/// Returns an error if row conversion, semaphore acquisition, or Arrow
/// compute kernels fail.
#[allow(clippy::too_many_arguments)]
async fn probe_partial_batch_with_fallback(
    stream_batch: &RecordBatch,
    converter: &RowConverter,
    foyer_cache: &FoyerMemoryCache,
    stream_key_indices: &[usize],
    join_type: LookupJoinType,
    output_schema: &SchemaRef,
    stream_field_count: usize,
    lookup_schema: &SchemaRef,
    source: Option<&dyn LookupSourceDyn>,
    fetch_semaphore: &Semaphore,
) -> Result<RecordBatch> {
    // Row-encode the stream-side key columns so cache keys are comparable bytes.
    let key_cols: Vec<_> = stream_key_indices
        .iter()
        .map(|&i| stream_batch.column(i).clone())
        .collect();
    let rows = converter.convert_columns(&key_cols)?;

    let num_rows = stream_batch.num_rows();
    // Parallel arrays: stream_indices[i] is the stream row for output slot i,
    // lookup_batches[i] is its (optional) matched lookup batch.
    let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
    let mut lookup_batches: Vec<Option<RecordBatch>> = Vec::with_capacity(num_rows);
    // (output slot, encoded key) pairs for cache misses to retry against the source.
    let mut miss_keys: Vec<(usize, Vec<u8>)> = Vec::new();

    #[allow(clippy::cast_possible_truncation)]
    for row in 0..num_rows {
        if key_cols.iter().any(|c| c.is_null(row)) {
            // SQL semantics: NULL keys never match. Left outer keeps the row
            // with null lookup columns; inner drops it outright.
            if join_type == LookupJoinType::LeftOuter {
                stream_indices.push(row as u32);
                lookup_batches.push(None);
            }
            continue;
        }

        let key = rows.row(row);
        let result = foyer_cache.get_cached(key.as_ref());
        if let Some(batch) = result.into_batch() {
            stream_indices.push(row as u32);
            lookup_batches.push(Some(batch));
        } else {
            // Reserve the output slot now; a successful source fetch patches
            // lookup_batches[idx] below.
            let idx = stream_indices.len();
            stream_indices.push(row as u32);
            lookup_batches.push(None);
            miss_keys.push((idx, key.as_ref().to_vec()));
        }
    }

    if let Some(source) = source {
        if !miss_keys.is_empty() {
            // Bound concurrent source fetches across all running probes.
            let _permit = fetch_semaphore
                .acquire()
                .await
                .map_err(|_| DataFusionError::Internal("fetch semaphore closed".into()))?;

            let key_refs: Vec<&[u8]> = miss_keys.iter().map(|(_, k)| k.as_slice()).collect();
            let source_results = source.query_batch(&key_refs, &[], &[]).await;

            match source_results {
                Ok(results) => {
                    // query_batch returns results positionally aligned with key_refs.
                    for ((idx, key_bytes), maybe_batch) in miss_keys.iter().zip(results.into_iter())
                    {
                        if let Some(batch) = maybe_batch {
                            // Warm the cache so subsequent probes hit locally.
                            foyer_cache.insert(key_bytes, batch.clone());
                            lookup_batches[*idx] = Some(batch);
                        }
                    }
                }
                Err(e) => {
                    // Graceful degradation: unmatched rows stay None, which
                    // means nulls for left outer / dropped for inner below.
                    tracing::warn!(error = %e, "source fallback failed, serving cache-only results");
                }
            }
        }
    }

    if join_type == LookupJoinType::Inner {
        // In-place two-pointer compaction: keep only slots that found a match,
        // preserving relative order of both parallel arrays.
        let mut write = 0;
        for read in 0..stream_indices.len() {
            if lookup_batches[read].is_some() {
                stream_indices[write] = stream_indices[read];
                lookup_batches.swap(write, read);
                write += 1;
            }
        }
        stream_indices.truncate(write);
        lookup_batches.truncate(write);
    }

    if stream_indices.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
    }

    // Gather the surviving stream rows.
    let take_indices = UInt32Array::from(stream_indices);
    let mut columns = Vec::with_capacity(output_schema.fields().len());

    for col in stream_batch.columns() {
        columns.push(take(col.as_ref(), &take_indices, None)?);
    }

    // Stitch lookup columns: one array slice per output slot, concatenated.
    // NOTE(review): the null placeholder is 1 row, so this assumes every
    // cached/source batch holds exactly one row per key — a multi-row batch
    // would desync column lengths and fail in try_new below. TODO confirm.
    let lookup_col_count = lookup_schema.fields().len();
    for col_idx in 0..lookup_col_count {
        let arrays: Vec<_> = lookup_batches
            .iter()
            .map(|opt| match opt {
                Some(b) => b.column(col_idx).clone(),
                None => arrow_array::new_null_array(lookup_schema.field(col_idx).data_type(), 1),
            })
            .collect();
        let refs: Vec<&dyn arrow_array::Array> = arrays.iter().map(AsRef::as_ref).collect();
        columns.push(arrow::compute::concat(&refs)?);
    }

    debug_assert_eq!(
        columns.len(),
        stream_field_count + lookup_col_count,
        "output column count mismatch"
    );

    Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
}
1375
/// [`ExtensionPlanner`] that lowers [`LookupJoinNode`] logical nodes into the
/// matching physical lookup-join operator, resolving the referenced table
/// through a shared [`LookupTableRegistry`].
pub struct LookupJoinExtensionPlanner {
    // Shared registry consulted at planning time to find the lookup table
    // (snapshot, partial-cache, or versioned) named by the logical node.
    registry: Arc<LookupTableRegistry>,
}

impl LookupJoinExtensionPlanner {
    /// Creates a planner backed by the given lookup-table registry.
    pub fn new(registry: Arc<LookupTableRegistry>) -> Self {
        Self { registry }
    }
}
1391
1392#[async_trait]
1393impl ExtensionPlanner for LookupJoinExtensionPlanner {
1394 #[allow(clippy::too_many_lines)]
1395 async fn plan_extension(
1396 &self,
1397 _planner: &dyn PhysicalPlanner,
1398 node: &dyn UserDefinedLogicalNode,
1399 _logical_inputs: &[&LogicalPlan],
1400 physical_inputs: &[Arc<dyn ExecutionPlan>],
1401 session_state: &SessionState,
1402 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1403 let Some(lookup_node) = node.as_any().downcast_ref::<LookupJoinNode>() else {
1404 return Ok(None);
1405 };
1406
1407 let entry = self
1408 .registry
1409 .get_entry(lookup_node.lookup_table_name())
1410 .ok_or_else(|| {
1411 DataFusionError::Plan(format!(
1412 "lookup table '{}' not registered",
1413 lookup_node.lookup_table_name()
1414 ))
1415 })?;
1416
1417 let input = Arc::clone(&physical_inputs[0]);
1418 let stream_schema = input.schema();
1419
1420 match entry {
1421 RegisteredLookup::Partial(partial_state) => {
1422 let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1423
1424 let mut output_fields = stream_schema.fields().to_vec();
1425 output_fields.extend(partial_state.schema.fields().iter().cloned());
1426 let output_schema = Arc::new(Schema::new(output_fields));
1427
1428 let exec = PartialLookupJoinExec::try_new_with_source(
1429 input,
1430 Arc::clone(&partial_state.foyer_cache),
1431 stream_key_indices,
1432 partial_state.key_sort_fields.clone(),
1433 lookup_node.join_type(),
1434 Arc::clone(&partial_state.schema),
1435 output_schema,
1436 partial_state.source.clone(),
1437 Arc::clone(&partial_state.fetch_semaphore),
1438 )?;
1439 Ok(Some(Arc::new(exec)))
1440 }
1441 RegisteredLookup::Snapshot(snapshot) => {
1442 let lookup_schema = snapshot.batch.schema();
1443 let lookup_key_indices = resolve_lookup_keys(lookup_node, &lookup_schema)?;
1444
1445 let lookup_batch = if lookup_node.pushdown_predicates().is_empty()
1446 || snapshot.batch.num_rows() == 0
1447 {
1448 snapshot.batch.clone()
1449 } else {
1450 apply_pushdown_predicates(
1451 &snapshot.batch,
1452 lookup_node.pushdown_predicates(),
1453 session_state,
1454 )?
1455 };
1456
1457 let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1458
1459 for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
1461 let st = stream_schema.field(*si).data_type();
1462 let lt = lookup_schema.field(*li).data_type();
1463 if st != lt {
1464 return Err(DataFusionError::Plan(format!(
1465 "Lookup join key type mismatch: stream '{}' is {st:?} \
1466 but lookup '{}' is {lt:?}",
1467 stream_schema.field(*si).name(),
1468 lookup_schema.field(*li).name(),
1469 )));
1470 }
1471 }
1472
1473 let mut output_fields = stream_schema.fields().to_vec();
1474 output_fields.extend(lookup_batch.schema().fields().iter().cloned());
1475 let output_schema = Arc::new(Schema::new(output_fields));
1476
1477 let exec = LookupJoinExec::try_new(
1478 input,
1479 lookup_batch,
1480 stream_key_indices,
1481 lookup_key_indices,
1482 lookup_node.join_type(),
1483 output_schema,
1484 )?;
1485
1486 Ok(Some(Arc::new(exec)))
1487 }
1488 RegisteredLookup::Versioned(versioned_state) => {
1489 let table_schema = versioned_state.batch.schema();
1490 let lookup_key_indices = resolve_lookup_keys(lookup_node, &table_schema)?;
1491 let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1492
1493 for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
1495 let st = stream_schema.field(*si).data_type();
1496 let lt = table_schema.field(*li).data_type();
1497 if st != lt {
1498 return Err(DataFusionError::Plan(format!(
1499 "Temporal join key type mismatch: stream '{}' is {st:?} \
1500 but table '{}' is {lt:?}",
1501 stream_schema.field(*si).name(),
1502 table_schema.field(*li).name(),
1503 )));
1504 }
1505 }
1506
1507 let stream_time_col_idx = stream_schema
1508 .index_of(&versioned_state.stream_time_column)
1509 .map_err(|_| {
1510 DataFusionError::Plan(format!(
1511 "stream time column '{}' not found in stream schema",
1512 versioned_state.stream_time_column
1513 ))
1514 })?;
1515
1516 let key_sort_fields: Vec<SortField> = lookup_key_indices
1517 .iter()
1518 .map(|&i| SortField::new(table_schema.field(i).data_type().clone()))
1519 .collect();
1520
1521 let mut output_fields = stream_schema.fields().to_vec();
1522 output_fields.extend(table_schema.fields().iter().cloned());
1523 let output_schema = Arc::new(Schema::new(output_fields));
1524
1525 let exec = VersionedLookupJoinExec::try_new(
1526 input,
1527 versioned_state.batch.clone(),
1528 Arc::clone(&versioned_state.index),
1529 stream_key_indices,
1530 stream_time_col_idx,
1531 lookup_node.join_type(),
1532 output_schema,
1533 key_sort_fields,
1534 )?;
1535
1536 Ok(Some(Arc::new(exec)))
1537 }
1538 }
1539 }
1540}
1541
1542fn apply_pushdown_predicates(
1545 batch: &RecordBatch,
1546 predicates: &[Expr],
1547 session_state: &SessionState,
1548) -> Result<RecordBatch> {
1549 use arrow::compute::filter_record_batch;
1550 use datafusion::physical_expr::create_physical_expr;
1551
1552 let schema = batch.schema();
1553 let df_schema = datafusion::common::DFSchema::try_from(schema.as_ref().clone())?;
1554
1555 let mut mask = None::<arrow_array::BooleanArray>;
1556 for pred in predicates {
1557 let phys_expr = create_physical_expr(pred, &df_schema, session_state.execution_props())?;
1558 let result = phys_expr.evaluate(batch)?;
1559 let bool_arr = result
1560 .into_array(batch.num_rows())?
1561 .as_any()
1562 .downcast_ref::<arrow_array::BooleanArray>()
1563 .ok_or_else(|| {
1564 DataFusionError::Internal("pushdown predicate did not evaluate to boolean".into())
1565 })?
1566 .clone();
1567 mask = Some(match mask {
1568 Some(existing) => arrow::compute::and(&existing, &bool_arr)?,
1569 None => bool_arr,
1570 });
1571 }
1572
1573 match mask {
1574 Some(m) => Ok(filter_record_batch(batch, &m)?),
1575 None => Ok(batch.clone()),
1576 }
1577}
1578
1579fn resolve_stream_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1580 node.join_keys()
1581 .iter()
1582 .map(|pair| match &pair.stream_expr {
1583 Expr::Column(col) => schema.index_of(&col.name).map_err(|_| {
1584 DataFusionError::Plan(format!(
1585 "stream key column '{}' not found in physical schema",
1586 col.name
1587 ))
1588 }),
1589 other => Err(DataFusionError::NotImplemented(format!(
1590 "lookup join requires column references as stream keys, got: {other}"
1591 ))),
1592 })
1593 .collect()
1594}
1595
1596fn resolve_lookup_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1597 node.join_keys()
1598 .iter()
1599 .map(|pair| {
1600 schema.index_of(&pair.lookup_column).map_err(|_| {
1601 DataFusionError::Plan(format!(
1602 "lookup key column '{}' not found in lookup table schema",
1603 pair.lookup_column
1604 ))
1605 })
1606 })
1607 .collect()
1608}
1609
1610#[cfg(test)]
1613mod tests {
1614 use super::*;
1615 use arrow_array::{Array, Float64Array, Int64Array, StringArray};
1616 use arrow_schema::{DataType, Field};
1617 use datafusion::physical_plan::stream::RecordBatchStreamAdapter as TestStreamAdapter;
1618 use futures::TryStreamExt;
1619
1620 fn batch_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
1622 let schema = batch.schema();
1623 let batches = vec![batch];
1624 let stream_schema = Arc::clone(&schema);
1625 Arc::new(StreamExecStub {
1626 schema,
1627 batches: std::sync::Mutex::new(Some(batches)),
1628 stream_schema,
1629 })
1630 }
1631
1632 struct StreamExecStub {
1634 schema: SchemaRef,
1635 batches: std::sync::Mutex<Option<Vec<RecordBatch>>>,
1636 stream_schema: SchemaRef,
1637 }
1638
1639 impl Debug for StreamExecStub {
1640 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1641 write!(f, "StreamExecStub")
1642 }
1643 }
1644
1645 impl DisplayAs for StreamExecStub {
1646 fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
1647 write!(f, "StreamExecStub")
1648 }
1649 }
1650
1651 impl ExecutionPlan for StreamExecStub {
1652 fn name(&self) -> &'static str {
1653 "StreamExecStub"
1654 }
1655 fn as_any(&self) -> &dyn Any {
1656 self
1657 }
1658 fn schema(&self) -> SchemaRef {
1659 Arc::clone(&self.schema)
1660 }
1661 fn properties(&self) -> &PlanProperties {
1662 Box::leak(Box::new(PlanProperties::new(
1664 EquivalenceProperties::new(Arc::clone(&self.schema)),
1665 Partitioning::UnknownPartitioning(1),
1666 EmissionType::Final,
1667 Boundedness::Bounded,
1668 )))
1669 }
1670 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1671 vec![]
1672 }
1673 fn with_new_children(
1674 self: Arc<Self>,
1675 _: Vec<Arc<dyn ExecutionPlan>>,
1676 ) -> Result<Arc<dyn ExecutionPlan>> {
1677 Ok(self)
1678 }
1679 fn execute(&self, _: usize, _: Arc<TaskContext>) -> Result<SendableRecordBatchStream> {
1680 let batches = self.batches.lock().unwrap().take().unwrap_or_default();
1681 let schema = Arc::clone(&self.stream_schema);
1682 let stream = futures::stream::iter(batches.into_iter().map(Ok));
1683 Ok(Box::pin(TestStreamAdapter::new(schema, stream)))
1684 }
1685 }
1686
1687 impl datafusion::physical_plan::ExecutionPlanProperties for StreamExecStub {
1688 fn output_partitioning(&self) -> &Partitioning {
1689 self.properties().output_partitioning()
1690 }
1691 fn output_ordering(&self) -> Option<&LexOrdering> {
1692 None
1693 }
1694 fn boundedness(&self) -> Boundedness {
1695 Boundedness::Bounded
1696 }
1697 fn pipeline_behavior(&self) -> EmissionType {
1698 EmissionType::Final
1699 }
1700 fn equivalence_properties(&self) -> &EquivalenceProperties {
1701 self.properties().equivalence_properties()
1702 }
1703 }
1704
1705 fn orders_schema() -> SchemaRef {
1706 Arc::new(Schema::new(vec![
1707 Field::new("order_id", DataType::Int64, false),
1708 Field::new("customer_id", DataType::Int64, false),
1709 Field::new("amount", DataType::Float64, false),
1710 ]))
1711 }
1712
1713 fn customers_schema() -> SchemaRef {
1714 Arc::new(Schema::new(vec![
1715 Field::new("id", DataType::Int64, false),
1716 Field::new("name", DataType::Utf8, true),
1717 ]))
1718 }
1719
1720 fn output_schema() -> SchemaRef {
1721 Arc::new(Schema::new(vec![
1722 Field::new("order_id", DataType::Int64, false),
1723 Field::new("customer_id", DataType::Int64, false),
1724 Field::new("amount", DataType::Float64, false),
1725 Field::new("id", DataType::Int64, false),
1726 Field::new("name", DataType::Utf8, true),
1727 ]))
1728 }
1729
1730 fn customers_batch() -> RecordBatch {
1731 RecordBatch::try_new(
1732 customers_schema(),
1733 vec![
1734 Arc::new(Int64Array::from(vec![1, 2, 3])),
1735 Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
1736 ],
1737 )
1738 .unwrap()
1739 }
1740
1741 fn orders_batch() -> RecordBatch {
1742 RecordBatch::try_new(
1743 orders_schema(),
1744 vec![
1745 Arc::new(Int64Array::from(vec![100, 101, 102, 103])),
1746 Arc::new(Int64Array::from(vec![1, 2, 99, 3])),
1747 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
1748 ],
1749 )
1750 .unwrap()
1751 }
1752
1753 fn make_exec(join_type: LookupJoinType) -> LookupJoinExec {
1754 let input = batch_exec(orders_batch());
1755 LookupJoinExec::try_new(
1756 input,
1757 customers_batch(),
1758 vec![1], vec![0], join_type,
1761 output_schema(),
1762 )
1763 .unwrap()
1764 }
1765
1766 #[tokio::test]
1767 async fn inner_join_filters_non_matches() {
1768 let exec = make_exec(LookupJoinType::Inner);
1769 let ctx = Arc::new(TaskContext::default());
1770 let stream = exec.execute(0, ctx).unwrap();
1771 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
1772
1773 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1774 assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");
1775
1776 let names = batches[0]
1777 .column(4)
1778 .as_any()
1779 .downcast_ref::<StringArray>()
1780 .unwrap();
1781 assert_eq!(names.value(0), "Alice");
1782 assert_eq!(names.value(1), "Bob");
1783 assert_eq!(names.value(2), "Charlie");
1784 }
1785
1786 #[tokio::test]
1787 async fn left_outer_preserves_non_matches() {
1788 let exec = make_exec(LookupJoinType::LeftOuter);
1789 let ctx = Arc::new(TaskContext::default());
1790 let stream = exec.execute(0, ctx).unwrap();
1791 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
1792
1793 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1794 assert_eq!(total, 4, "all 4 stream rows preserved in left outer");
1795
1796 let names = batches[0]
1797 .column(4)
1798 .as_any()
1799 .downcast_ref::<StringArray>()
1800 .unwrap();
1801 assert!(names.is_null(2));
1803 }
1804
1805 #[tokio::test]
1806 async fn empty_lookup_inner_produces_no_rows() {
1807 let empty = RecordBatch::new_empty(customers_schema());
1808 let input = batch_exec(orders_batch());
1809 let exec = LookupJoinExec::try_new(
1810 input,
1811 empty,
1812 vec![1],
1813 vec![0],
1814 LookupJoinType::Inner,
1815 output_schema(),
1816 )
1817 .unwrap();
1818
1819 let ctx = Arc::new(TaskContext::default());
1820 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1821 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1822 assert_eq!(total, 0);
1823 }
1824
1825 #[tokio::test]
1826 async fn empty_lookup_left_outer_preserves_all_stream_rows() {
1827 let empty = RecordBatch::new_empty(customers_schema());
1828 let input = batch_exec(orders_batch());
1829 let exec = LookupJoinExec::try_new(
1830 input,
1831 empty,
1832 vec![1],
1833 vec![0],
1834 LookupJoinType::LeftOuter,
1835 output_schema(),
1836 )
1837 .unwrap();
1838
1839 let ctx = Arc::new(TaskContext::default());
1840 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1841 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1842 assert_eq!(total, 4);
1843 }
1844
1845 #[tokio::test]
1846 async fn duplicate_keys_produce_multiple_rows() {
1847 let lookup = RecordBatch::try_new(
1848 customers_schema(),
1849 vec![
1850 Arc::new(Int64Array::from(vec![1, 1])),
1851 Arc::new(StringArray::from(vec!["Alice-A", "Alice-B"])),
1852 ],
1853 )
1854 .unwrap();
1855
1856 let stream = RecordBatch::try_new(
1857 orders_schema(),
1858 vec![
1859 Arc::new(Int64Array::from(vec![100])),
1860 Arc::new(Int64Array::from(vec![1])),
1861 Arc::new(Float64Array::from(vec![10.0])),
1862 ],
1863 )
1864 .unwrap();
1865
1866 let input = batch_exec(stream);
1867 let exec = LookupJoinExec::try_new(
1868 input,
1869 lookup,
1870 vec![1],
1871 vec![0],
1872 LookupJoinType::Inner,
1873 output_schema(),
1874 )
1875 .unwrap();
1876
1877 let ctx = Arc::new(TaskContext::default());
1878 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1879 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1880 assert_eq!(total, 2, "one stream row matched two lookup rows");
1881 }
1882
1883 #[test]
1884 fn with_new_children_preserves_state() {
1885 let exec = Arc::new(make_exec(LookupJoinType::Inner));
1886 let expected_schema = exec.schema();
1887 let children = exec.children().into_iter().cloned().collect();
1888 let rebuilt = exec.with_new_children(children).unwrap();
1889 assert_eq!(rebuilt.schema(), expected_schema);
1890 assert_eq!(rebuilt.name(), "LookupJoinExec");
1891 }
1892
1893 #[test]
1894 fn display_format() {
1895 let exec = make_exec(LookupJoinType::Inner);
1896 let s = format!("{exec:?}");
1897 assert!(s.contains("LookupJoinExec"));
1898 assert!(s.contains("lookup_rows: 3"));
1899 }
1900
1901 #[test]
1902 fn registry_crud() {
1903 let reg = LookupTableRegistry::new();
1904 assert!(reg.get("customers").is_none());
1905
1906 reg.register(
1907 "customers",
1908 LookupSnapshot {
1909 batch: customers_batch(),
1910 key_columns: vec!["id".into()],
1911 },
1912 );
1913 assert!(reg.get("customers").is_some());
1914 assert!(reg.get("CUSTOMERS").is_some(), "case-insensitive");
1915
1916 reg.unregister("customers");
1917 assert!(reg.get("customers").is_none());
1918 }
1919
1920 #[test]
1921 fn registry_update_replaces() {
1922 let reg = LookupTableRegistry::new();
1923 reg.register(
1924 "t",
1925 LookupSnapshot {
1926 batch: RecordBatch::new_empty(customers_schema()),
1927 key_columns: vec![],
1928 },
1929 );
1930 assert_eq!(reg.get("t").unwrap().batch.num_rows(), 0);
1931
1932 reg.register(
1933 "t",
1934 LookupSnapshot {
1935 batch: customers_batch(),
1936 key_columns: vec![],
1937 },
1938 );
1939 assert_eq!(reg.get("t").unwrap().batch.num_rows(), 3);
1940 }
1941
1942 #[test]
1943 fn pushdown_predicates_filter_snapshot() {
1944 use datafusion::logical_expr::{col, lit};
1945
1946 let batch = customers_batch(); let ctx = datafusion::prelude::SessionContext::new();
1948 let state = ctx.state();
1949
1950 let predicates = vec![col("id").gt(lit(1i64))];
1952 let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
1953 assert_eq!(filtered.num_rows(), 2);
1954
1955 let ids = filtered
1956 .column(0)
1957 .as_any()
1958 .downcast_ref::<Int64Array>()
1959 .unwrap();
1960 assert_eq!(ids.value(0), 2);
1961 assert_eq!(ids.value(1), 3);
1962 }
1963
1964 #[test]
1965 fn pushdown_predicates_empty_passes_all() {
1966 let batch = customers_batch();
1967 let ctx = datafusion::prelude::SessionContext::new();
1968 let state = ctx.state();
1969
1970 let filtered = apply_pushdown_predicates(&batch, &[], &state).unwrap();
1971 assert_eq!(filtered.num_rows(), 3);
1972 }
1973
1974 #[test]
1975 fn pushdown_predicates_multiple_and() {
1976 use datafusion::logical_expr::{col, lit};
1977
1978 let batch = customers_batch(); let ctx = datafusion::prelude::SessionContext::new();
1980 let state = ctx.state();
1981
1982 let predicates = vec![col("id").gt_eq(lit(2i64)), col("id").lt(lit(3i64))];
1984 let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
1985 assert_eq!(filtered.num_rows(), 1);
1986 }
1987
1988 use laminar_core::lookup::foyer_cache::FoyerMemoryCacheConfig;
1991
1992 fn make_foyer_cache() -> Arc<FoyerMemoryCache> {
1993 Arc::new(FoyerMemoryCache::new(
1994 1,
1995 FoyerMemoryCacheConfig {
1996 capacity: 64,
1997 shards: 4,
1998 },
1999 ))
2000 }
2001
2002 fn customer_row(id: i64, name: &str) -> RecordBatch {
2003 RecordBatch::try_new(
2004 customers_schema(),
2005 vec![
2006 Arc::new(Int64Array::from(vec![id])),
2007 Arc::new(StringArray::from(vec![name])),
2008 ],
2009 )
2010 .unwrap()
2011 }
2012
2013 fn warm_cache(cache: &FoyerMemoryCache) {
2014 let converter = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap();
2015
2016 for (id, name) in [(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
2017 let key_col = Int64Array::from(vec![id]);
2018 let rows = converter.convert_columns(&[Arc::new(key_col)]).unwrap();
2019 let key = rows.row(0);
2020 cache.insert(key.as_ref(), customer_row(id, name));
2021 }
2022 }
2023
2024 fn make_partial_exec(join_type: LookupJoinType) -> PartialLookupJoinExec {
2025 let cache = make_foyer_cache();
2026 warm_cache(&cache);
2027
2028 let input = batch_exec(orders_batch());
2029 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2030
2031 PartialLookupJoinExec::try_new(
2032 input,
2033 cache,
2034 vec![1], key_sort_fields,
2036 join_type,
2037 customers_schema(),
2038 output_schema(),
2039 )
2040 .unwrap()
2041 }
2042
2043 #[tokio::test]
2044 async fn partial_inner_join_filters_non_matches() {
2045 let exec = make_partial_exec(LookupJoinType::Inner);
2046 let ctx = Arc::new(TaskContext::default());
2047 let stream = exec.execute(0, ctx).unwrap();
2048 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2049
2050 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2051 assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");
2052
2053 let names = batches[0]
2054 .column(4)
2055 .as_any()
2056 .downcast_ref::<StringArray>()
2057 .unwrap();
2058 assert_eq!(names.value(0), "Alice");
2059 assert_eq!(names.value(1), "Bob");
2060 assert_eq!(names.value(2), "Charlie");
2061 }
2062
2063 #[tokio::test]
2064 async fn partial_left_outer_preserves_non_matches() {
2065 let exec = make_partial_exec(LookupJoinType::LeftOuter);
2066 let ctx = Arc::new(TaskContext::default());
2067 let stream = exec.execute(0, ctx).unwrap();
2068 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2069
2070 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2071 assert_eq!(total, 4, "all 4 stream rows preserved in left outer");
2072
2073 let names = batches[0]
2074 .column(4)
2075 .as_any()
2076 .downcast_ref::<StringArray>()
2077 .unwrap();
2078 assert!(names.is_null(2), "customer_id=99 should have null name");
2079 }
2080
2081 #[tokio::test]
2082 async fn partial_empty_cache_inner_produces_no_rows() {
2083 let cache = make_foyer_cache();
2084 let input = batch_exec(orders_batch());
2085 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2086
2087 let exec = PartialLookupJoinExec::try_new(
2088 input,
2089 cache,
2090 vec![1],
2091 key_sort_fields,
2092 LookupJoinType::Inner,
2093 customers_schema(),
2094 output_schema(),
2095 )
2096 .unwrap();
2097
2098 let ctx = Arc::new(TaskContext::default());
2099 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2100 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2101 assert_eq!(total, 0);
2102 }
2103
2104 #[tokio::test]
2105 async fn partial_empty_cache_left_outer_preserves_all() {
2106 let cache = make_foyer_cache();
2107 let input = batch_exec(orders_batch());
2108 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2109
2110 let exec = PartialLookupJoinExec::try_new(
2111 input,
2112 cache,
2113 vec![1],
2114 key_sort_fields,
2115 LookupJoinType::LeftOuter,
2116 customers_schema(),
2117 output_schema(),
2118 )
2119 .unwrap();
2120
2121 let ctx = Arc::new(TaskContext::default());
2122 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2123 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2124 assert_eq!(total, 4);
2125 }
2126
2127 #[test]
2128 fn partial_with_new_children_preserves_state() {
2129 let exec = Arc::new(make_partial_exec(LookupJoinType::Inner));
2130 let expected_schema = exec.schema();
2131 let children = exec.children().into_iter().cloned().collect();
2132 let rebuilt = exec.with_new_children(children).unwrap();
2133 assert_eq!(rebuilt.schema(), expected_schema);
2134 assert_eq!(rebuilt.name(), "PartialLookupJoinExec");
2135 }
2136
2137 #[test]
2138 fn partial_display_format() {
2139 let exec = make_partial_exec(LookupJoinType::Inner);
2140 let s = format!("{exec:?}");
2141 assert!(s.contains("PartialLookupJoinExec"));
2142 assert!(s.contains("cache_table_id: 1"));
2143 }
2144
2145 #[test]
2146 fn registry_partial_entry() {
2147 let reg = LookupTableRegistry::new();
2148 let cache = make_foyer_cache();
2149 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2150
2151 reg.register_partial(
2152 "customers",
2153 PartialLookupState {
2154 foyer_cache: cache,
2155 schema: customers_schema(),
2156 key_columns: vec!["id".into()],
2157 key_sort_fields,
2158 source: None,
2159 fetch_semaphore: Arc::new(Semaphore::new(64)),
2160 },
2161 );
2162
2163 assert!(reg.get("customers").is_none());
2164
2165 let entry = reg.get_entry("customers");
2166 assert!(entry.is_some());
2167 assert!(matches!(entry.unwrap(), RegisteredLookup::Partial(_)));
2168 }
2169
2170 #[tokio::test]
2171 async fn partial_source_fallback_on_miss() {
2172 use laminar_core::lookup::source::LookupError;
2173 use laminar_core::lookup::source::LookupSourceDyn;
2174
2175 struct TestSource;
2176
2177 #[async_trait]
2178 impl LookupSourceDyn for TestSource {
2179 async fn query_batch(
2180 &self,
2181 keys: &[&[u8]],
2182 _predicates: &[laminar_core::lookup::predicate::Predicate],
2183 _projection: &[laminar_core::lookup::source::ColumnId],
2184 ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
2185 Ok(keys
2186 .iter()
2187 .map(|_| Some(customer_row(99, "FromSource")))
2188 .collect())
2189 }
2190
2191 fn schema(&self) -> SchemaRef {
2192 customers_schema()
2193 }
2194 }
2195
2196 let cache = make_foyer_cache();
2197 warm_cache(&cache);
2199
2200 let orders = RecordBatch::try_new(
2201 orders_schema(),
2202 vec![
2203 Arc::new(Int64Array::from(vec![200])),
2204 Arc::new(Int64Array::from(vec![99])), Arc::new(Float64Array::from(vec![50.0])),
2206 ],
2207 )
2208 .unwrap();
2209
2210 let input = batch_exec(orders);
2211 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2212 let source: Arc<dyn LookupSourceDyn> = Arc::new(TestSource);
2213
2214 let exec = PartialLookupJoinExec::try_new_with_source(
2215 input,
2216 cache,
2217 vec![1],
2218 key_sort_fields,
2219 LookupJoinType::Inner,
2220 customers_schema(),
2221 output_schema(),
2222 Some(source),
2223 Arc::new(Semaphore::new(64)),
2224 )
2225 .unwrap();
2226
2227 let ctx = Arc::new(TaskContext::default());
2228 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2229 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2230 assert_eq!(total, 1, "source fallback should produce 1 row");
2231
2232 let names = batches[0]
2233 .column(4)
2234 .as_any()
2235 .downcast_ref::<StringArray>()
2236 .unwrap();
2237 assert_eq!(names.value(0), "FromSource");
2238 }
2239
2240 #[tokio::test]
2241 async fn partial_source_error_graceful_degradation() {
2242 use laminar_core::lookup::source::LookupError;
2243 use laminar_core::lookup::source::LookupSourceDyn;
2244
2245 struct FailingSource;
2246
2247 #[async_trait]
2248 impl LookupSourceDyn for FailingSource {
2249 async fn query_batch(
2250 &self,
2251 _keys: &[&[u8]],
2252 _predicates: &[laminar_core::lookup::predicate::Predicate],
2253 _projection: &[laminar_core::lookup::source::ColumnId],
2254 ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
2255 Err(LookupError::Internal("source unavailable".into()))
2256 }
2257
2258 fn schema(&self) -> SchemaRef {
2259 customers_schema()
2260 }
2261 }
2262
2263 let cache = make_foyer_cache();
2264 let input = batch_exec(orders_batch());
2265 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2266 let source: Arc<dyn LookupSourceDyn> = Arc::new(FailingSource);
2267
2268 let exec = PartialLookupJoinExec::try_new_with_source(
2269 input,
2270 cache,
2271 vec![1],
2272 key_sort_fields,
2273 LookupJoinType::LeftOuter,
2274 customers_schema(),
2275 output_schema(),
2276 Some(source),
2277 Arc::new(Semaphore::new(64)),
2278 )
2279 .unwrap();
2280
2281 let ctx = Arc::new(TaskContext::default());
2282 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2283 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2284 assert_eq!(total, 4);
2286 }
2287
2288 #[test]
2289 fn registry_snapshot_entry_via_get_entry() {
2290 let reg = LookupTableRegistry::new();
2291 reg.register(
2292 "t",
2293 LookupSnapshot {
2294 batch: customers_batch(),
2295 key_columns: vec!["id".into()],
2296 },
2297 );
2298
2299 let entry = reg.get_entry("t");
2300 assert!(matches!(entry.unwrap(), RegisteredLookup::Snapshot(_)));
2301 assert!(reg.get("t").is_some());
2302 }
2303
2304 fn nullable_orders_schema() -> SchemaRef {
2307 Arc::new(Schema::new(vec![
2308 Field::new("order_id", DataType::Int64, false),
2309 Field::new("customer_id", DataType::Int64, true), Field::new("amount", DataType::Float64, false),
2311 ]))
2312 }
2313
2314 fn nullable_output_schema(join_type: LookupJoinType) -> SchemaRef {
2315 let lookup_nullable = join_type == LookupJoinType::LeftOuter;
2316 Arc::new(Schema::new(vec![
2317 Field::new("order_id", DataType::Int64, false),
2318 Field::new("customer_id", DataType::Int64, true),
2319 Field::new("amount", DataType::Float64, false),
2320 Field::new("id", DataType::Int64, lookup_nullable),
2321 Field::new("name", DataType::Utf8, true),
2322 ]))
2323 }
2324
2325 #[tokio::test]
2326 async fn null_key_inner_join_no_match() {
2327 let stream_batch = RecordBatch::try_new(
2329 nullable_orders_schema(),
2330 vec![
2331 Arc::new(Int64Array::from(vec![100, 101, 102])),
2332 Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
2333 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
2334 ],
2335 )
2336 .unwrap();
2337
2338 let input = batch_exec(stream_batch);
2339 let exec = LookupJoinExec::try_new(
2340 input,
2341 customers_batch(),
2342 vec![1],
2343 vec![0],
2344 LookupJoinType::Inner,
2345 nullable_output_schema(LookupJoinType::Inner),
2346 )
2347 .unwrap();
2348
2349 let ctx = Arc::new(TaskContext::default());
2350 let stream = exec.execute(0, ctx).unwrap();
2351 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2352
2353 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2354 assert_eq!(total, 2, "NULL key row should not match in inner join");
2356 }
2357
2358 #[tokio::test]
2359 async fn null_key_left_outer_produces_nulls() {
2360 let stream_batch = RecordBatch::try_new(
2362 nullable_orders_schema(),
2363 vec![
2364 Arc::new(Int64Array::from(vec![100, 101, 102])),
2365 Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
2366 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
2367 ],
2368 )
2369 .unwrap();
2370
2371 let input = batch_exec(stream_batch);
2372 let out_schema = nullable_output_schema(LookupJoinType::LeftOuter);
2373 let exec = LookupJoinExec::try_new(
2374 input,
2375 customers_batch(),
2376 vec![1],
2377 vec![0],
2378 LookupJoinType::LeftOuter,
2379 out_schema,
2380 )
2381 .unwrap();
2382
2383 let ctx = Arc::new(TaskContext::default());
2384 let stream = exec.execute(0, ctx).unwrap();
2385 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2386
2387 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2388 assert_eq!(total, 3, "all rows preserved in left outer");
2390
2391 let names = batches[0]
2392 .column(4)
2393 .as_any()
2394 .downcast_ref::<StringArray>()
2395 .unwrap();
2396 assert_eq!(names.value(0), "Alice");
2397 assert!(
2398 names.is_null(1),
2399 "NULL key row should have null lookup name"
2400 );
2401 assert_eq!(names.value(2), "Bob");
2402 }
2403
2404 fn versioned_table_batch() -> RecordBatch {
2407 let schema = Arc::new(Schema::new(vec![
2410 Field::new("currency", DataType::Utf8, false),
2411 Field::new("valid_from", DataType::Int64, false),
2412 Field::new("rate", DataType::Float64, false),
2413 ]));
2414 RecordBatch::try_new(
2415 schema,
2416 vec![
2417 Arc::new(StringArray::from(vec!["USD", "USD", "EUR", "EUR", "EUR"])),
2418 Arc::new(Int64Array::from(vec![100, 200, 100, 150, 300])),
2419 Arc::new(Float64Array::from(vec![1.0, 1.1, 0.85, 0.90, 0.88])),
2420 ],
2421 )
2422 .unwrap()
2423 }
2424
2425 fn stream_batch_with_time() -> RecordBatch {
2426 let schema = Arc::new(Schema::new(vec![
2427 Field::new("order_id", DataType::Int64, false),
2428 Field::new("currency", DataType::Utf8, false),
2429 Field::new("event_ts", DataType::Int64, false),
2430 ]));
2431 RecordBatch::try_new(
2432 schema,
2433 vec![
2434 Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
2435 Arc::new(StringArray::from(vec!["USD", "EUR", "USD", "EUR"])),
2436 Arc::new(Int64Array::from(vec![150, 160, 250, 50])),
2437 ],
2438 )
2439 .unwrap()
2440 }
2441
2442 #[test]
2443 fn test_versioned_index_build_and_probe() {
2444 let batch = versioned_table_batch();
2445 let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2446
2447 let key_sf = vec![SortField::new(DataType::Utf8)];
2450 let converter = RowConverter::new(key_sf).unwrap();
2451 let usd_col = Arc::new(StringArray::from(vec!["USD"]));
2452 let usd_rows = converter.convert_columns(&[usd_col]).unwrap();
2453 let usd_key = usd_rows.row(0);
2454
2455 let result = index.probe_at_time(usd_key.as_ref(), 150);
2456 assert!(result.is_some());
2457 assert_eq!(result.unwrap(), 0);
2459
2460 let result = index.probe_at_time(usd_key.as_ref(), 250);
2462 assert_eq!(result.unwrap(), 1);
2463 }
2464
2465 #[test]
2466 fn test_versioned_index_no_version_before_ts() {
2467 let batch = versioned_table_batch();
2468 let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2469
2470 let key_sf = vec![SortField::new(DataType::Utf8)];
2471 let converter = RowConverter::new(key_sf).unwrap();
2472 let eur_col = Arc::new(StringArray::from(vec!["EUR"]));
2473 let eur_rows = converter.convert_columns(&[eur_col]).unwrap();
2474 let eur_key = eur_rows.row(0);
2475
2476 let result = index.probe_at_time(eur_key.as_ref(), 50);
2478 assert!(result.is_none());
2479 }
2480
2481 fn build_versioned_exec(
2483 table: RecordBatch,
2484 stream: &RecordBatch,
2485 join_type: LookupJoinType,
2486 ) -> VersionedLookupJoinExec {
2487 let input = batch_exec(stream.clone());
2488 let index = Arc::new(VersionedIndex::build(&table, &[0], 1).unwrap());
2489 let key_sort_fields = vec![SortField::new(DataType::Utf8)];
2490 let mut output_fields = stream.schema().fields().to_vec();
2491 output_fields.extend(table.schema().fields().iter().cloned());
2492 let output_schema = Arc::new(Schema::new(output_fields));
2493 VersionedLookupJoinExec::try_new(
2494 input,
2495 table,
2496 index,
2497 vec![1], 2, join_type,
2500 output_schema,
2501 key_sort_fields,
2502 )
2503 .unwrap()
2504 }
2505
2506 #[tokio::test]
2507 async fn test_versioned_join_exec_inner() {
2508 let table = versioned_table_batch();
2509 let stream = stream_batch_with_time();
2510 let exec = build_versioned_exec(table, &stream, LookupJoinType::Inner);
2511
2512 let ctx = Arc::new(TaskContext::default());
2513 let stream_out = exec.execute(0, ctx).unwrap();
2514 let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2515
2516 assert_eq!(batches.len(), 1);
2517 let batch = &batches[0];
2518 assert_eq!(batch.num_rows(), 3);
2523
2524 let rates = batch
2525 .column(5) .as_any()
2527 .downcast_ref::<Float64Array>()
2528 .unwrap();
2529 assert!((rates.value(0) - 1.0).abs() < f64::EPSILON); assert!((rates.value(1) - 0.90).abs() < f64::EPSILON); assert!((rates.value(2) - 1.1).abs() < f64::EPSILON); }
2533
2534 #[tokio::test]
2535 async fn test_versioned_join_exec_left_outer() {
2536 let table = versioned_table_batch();
2537 let stream = stream_batch_with_time();
2538 let exec = build_versioned_exec(table, &stream, LookupJoinType::LeftOuter);
2539
2540 let ctx = Arc::new(TaskContext::default());
2541 let stream_out = exec.execute(0, ctx).unwrap();
2542 let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2543
2544 assert_eq!(batches.len(), 1);
2545 let batch = &batches[0];
2546 assert_eq!(batch.num_rows(), 4);
2548
2549 let rates = batch
2551 .column(5)
2552 .as_any()
2553 .downcast_ref::<Float64Array>()
2554 .unwrap();
2555 assert!(rates.is_null(3), "EUR@50 should have null rate");
2556 }
2557
2558 #[test]
2559 fn test_versioned_index_empty_batch() {
2560 let schema = Arc::new(Schema::new(vec![
2561 Field::new("k", DataType::Utf8, false),
2562 Field::new("v", DataType::Int64, false),
2563 ]));
2564 let batch = RecordBatch::new_empty(schema);
2565 let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2566 assert!(index.map.is_empty());
2567 }
2568
2569 #[test]
2570 fn test_versioned_lookup_registry() {
2571 let registry = LookupTableRegistry::new();
2572 let table = versioned_table_batch();
2573 let index = Arc::new(VersionedIndex::build(&table, &[0], 1).unwrap());
2574
2575 registry.register_versioned(
2576 "rates",
2577 VersionedLookupState {
2578 batch: table,
2579 index,
2580 key_columns: vec!["currency".to_string()],
2581 version_column: "valid_from".to_string(),
2582 stream_time_column: "event_ts".to_string(),
2583 },
2584 );
2585
2586 let entry = registry.get_entry("rates");
2587 assert!(entry.is_some());
2588 assert!(matches!(entry.unwrap(), RegisteredLookup::Versioned(_)));
2589
2590 assert!(registry.get("rates").is_none());
2592 }
2593}