1use std::any::Any;
17use std::collections::HashMap;
18use std::fmt::{self, Debug, Formatter};
19use std::sync::Arc;
20
21use parking_lot::RwLock;
22
23use std::collections::BTreeMap;
24
25use arrow::compute::take;
26use arrow::row::{RowConverter, SortField};
27use arrow_array::{RecordBatch, UInt32Array};
28use arrow_schema::{Schema, SchemaRef};
29use async_trait::async_trait;
30use datafusion::execution::{SendableRecordBatchStream, SessionState, TaskContext};
31use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode};
32use datafusion::physical_expr::{EquivalenceProperties, LexOrdering, Partitioning};
33use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
34use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
35use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
36use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
37use datafusion_common::{DataFusionError, Result};
38use datafusion_expr::Expr;
39use futures::StreamExt;
40use laminar_core::lookup::foyer_cache::FoyerMemoryCache;
41use laminar_core::lookup::source::LookupSourceDyn;
42use tokio::sync::Semaphore;
43
44use super::lookup_join::{LookupJoinNode, LookupJoinType};
45
/// Process-wide registry of lookup tables available to lookup joins,
/// keyed by lower-cased table name.
#[derive(Default)]
pub struct LookupTableRegistry {
    // Guarded by a parking_lot RwLock; lookups read, register/unregister write.
    tables: RwLock<HashMap<String, RegisteredLookup>>,
}
56
/// A registered lookup table, in one of three storage modes.
pub enum RegisteredLookup {
    /// Fully materialized, immutable snapshot batch.
    Snapshot(Arc<LookupSnapshot>),
    /// Partially cached table served from a Foyer cache, with an optional
    /// external source for misses.
    Partial(Arc<PartialLookupState>),
    /// Temporally versioned table supporting as-of-event-time lookups.
    Versioned(Arc<VersionedLookupState>),
}
67
/// An immutable, fully materialized lookup table.
pub struct LookupSnapshot {
    /// All rows of the lookup table in a single batch.
    pub batch: RecordBatch,
    /// Names of the join-key columns within `batch`.
    pub key_columns: Vec<String>,
}
75
/// State for a temporally versioned lookup table: all row versions plus a
/// temporal index for "latest version at or before event time" probes.
pub struct VersionedLookupState {
    /// All versions of all rows in a single batch.
    pub batch: RecordBatch,
    /// Index from encoded key + version timestamp to row indices in `batch`.
    pub index: Arc<VersionedIndex>,
    /// Names of the join-key columns within `batch`.
    pub key_columns: Vec<String>,
    /// Column of `batch` holding each row's version timestamp.
    pub version_column: String,
    /// Stream-side column holding the event time used for as-of lookups.
    pub stream_time_column: String,
}
94
/// State for a partially cached lookup table: hits are served from an
/// in-memory Foyer cache; misses may be fetched from `source` when present.
pub struct PartialLookupState {
    /// Cache of previously fetched lookup rows, keyed by encoded join key.
    pub foyer_cache: Arc<FoyerMemoryCache>,
    /// Schema of the lookup-side columns.
    pub schema: SchemaRef,
    /// Names of the join-key columns.
    pub key_columns: Vec<String>,
    /// Row-converter sort fields matching the key columns' data types.
    pub key_sort_fields: Vec<SortField>,
    /// Optional external source queried on cache misses.
    pub source: Option<Arc<dyn LookupSourceDyn>>,
    /// Bounds concurrent fallback fetches against `source`.
    pub fetch_semaphore: Arc<Semaphore>,
}
110
111impl LookupTableRegistry {
112 #[must_use]
114 pub fn new() -> Self {
115 Self::default()
116 }
117
118 pub fn register(&self, name: &str, snapshot: LookupSnapshot) {
124 self.tables.write().insert(
125 name.to_lowercase(),
126 RegisteredLookup::Snapshot(Arc::new(snapshot)),
127 );
128 }
129
130 pub fn register_partial(&self, name: &str, state: PartialLookupState) {
136 self.tables.write().insert(
137 name.to_lowercase(),
138 RegisteredLookup::Partial(Arc::new(state)),
139 );
140 }
141
142 pub fn register_versioned(&self, name: &str, state: VersionedLookupState) {
148 self.tables.write().insert(
149 name.to_lowercase(),
150 RegisteredLookup::Versioned(Arc::new(state)),
151 );
152 }
153
154 pub fn unregister(&self, name: &str) {
160 self.tables.write().remove(&name.to_lowercase());
161 }
162
163 #[must_use]
169 pub fn get(&self, name: &str) -> Option<Arc<LookupSnapshot>> {
170 let tables = self.tables.read();
171 match tables.get(&name.to_lowercase())? {
172 RegisteredLookup::Snapshot(s) => Some(Arc::clone(s)),
173 RegisteredLookup::Partial(_) | RegisteredLookup::Versioned(_) => None,
174 }
175 }
176
177 pub fn get_entry(&self, name: &str) -> Option<RegisteredLookup> {
183 let tables = self.tables.read();
184 tables.get(&name.to_lowercase()).map(|e| match e {
185 RegisteredLookup::Snapshot(s) => RegisteredLookup::Snapshot(Arc::clone(s)),
186 RegisteredLookup::Partial(p) => RegisteredLookup::Partial(Arc::clone(p)),
187 RegisteredLookup::Versioned(v) => RegisteredLookup::Versioned(Arc::clone(v)),
188 })
189 }
190}
191
/// Equality-join index: maps row-format-encoded join-key bytes to the
/// indices of all lookup-batch rows carrying that key.
struct HashIndex {
    map: HashMap<Box<[u8]>, Vec<u32>>,
}
198
199impl HashIndex {
200 fn build(batch: &RecordBatch, key_indices: &[usize]) -> Result<Self> {
205 if batch.num_rows() == 0 {
206 return Ok(Self {
207 map: HashMap::new(),
208 });
209 }
210
211 let sort_fields: Vec<SortField> = key_indices
212 .iter()
213 .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
214 .collect();
215 let converter = RowConverter::new(sort_fields)?;
216
217 let key_cols: Vec<_> = key_indices
218 .iter()
219 .map(|&i| batch.column(i).clone())
220 .collect();
221 let rows = converter.convert_columns(&key_cols)?;
222
223 let num_rows = batch.num_rows();
224 let mut map: HashMap<Box<[u8]>, Vec<u32>> = HashMap::with_capacity(num_rows);
225 #[allow(clippy::cast_possible_truncation)] for i in 0..num_rows {
227 map.entry(Box::from(rows.row(i).as_ref()))
228 .or_default()
229 .push(i as u32);
230 }
231
232 Ok(Self { map })
233 }
234
235 fn probe(&self, key: &[u8]) -> Option<&[u32]> {
236 self.map.get(key).map(Vec::as_slice)
237 }
238}
239
/// Temporal index: for each encoded join key, a version-timestamp-ordered
/// map of row indices, enabling "latest version at or before time t" probes.
#[derive(Default)]
pub struct VersionedIndex {
    // Outer key: row-format-encoded join key; inner BTreeMap key: version
    // timestamp (i64, milliseconds); value: batch row indices at that version.
    map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>>,
}
249
250impl VersionedIndex {
251 pub fn build(
261 batch: &RecordBatch,
262 key_indices: &[usize],
263 version_col_idx: usize,
264 ) -> Result<Self> {
265 if batch.num_rows() == 0 {
266 return Ok(Self {
267 map: HashMap::new(),
268 });
269 }
270
271 let sort_fields: Vec<SortField> = key_indices
272 .iter()
273 .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
274 .collect();
275 let converter = RowConverter::new(sort_fields)?;
276
277 let key_cols: Vec<_> = key_indices
278 .iter()
279 .map(|&i| batch.column(i).clone())
280 .collect();
281 let rows = converter.convert_columns(&key_cols)?;
282
283 let timestamps = extract_i64_timestamps(batch.column(version_col_idx))?;
284
285 let num_rows = batch.num_rows();
286 let mut map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>> = HashMap::with_capacity(num_rows);
287 #[allow(clippy::cast_possible_truncation)]
288 for (i, ts_opt) in timestamps.iter().enumerate() {
289 let Some(version_ts) = ts_opt else { continue };
291 if key_cols.iter().any(|c| c.is_null(i)) {
292 continue;
293 }
294 let key = Box::from(rows.row(i).as_ref());
295 map.entry(key)
296 .or_default()
297 .entry(*version_ts)
298 .or_default()
299 .push(i as u32);
300 }
301
302 Ok(Self { map })
303 }
304
305 fn probe_at_time(&self, key: &[u8], event_ts: i64) -> Option<u32> {
308 let versions = self.map.get(key)?;
309 let (_, indices) = versions.range(..=event_ts).next_back()?;
310 indices.last().copied()
311 }
312}
313
314fn extract_i64_timestamps(col: &dyn arrow_array::Array) -> Result<Vec<Option<i64>>> {
320 use arrow_array::{
321 Float64Array, Int64Array, TimestampMicrosecondArray, TimestampMillisecondArray,
322 TimestampNanosecondArray, TimestampSecondArray,
323 };
324 use arrow_schema::{DataType, TimeUnit};
325
326 let n = col.len();
327 let mut out = Vec::with_capacity(n);
328 macro_rules! extract_typed {
329 ($arr_type:ty, $scale:expr) => {{
330 let arr = col.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
331 DataFusionError::Internal(concat!("expected ", stringify!($arr_type)).into())
332 })?;
333 for i in 0..n {
334 out.push(if col.is_null(i) {
335 None
336 } else {
337 Some(arr.value(i) * $scale)
338 });
339 }
340 }};
341 }
342
343 match col.data_type() {
344 DataType::Int64 => extract_typed!(Int64Array, 1),
345 DataType::Timestamp(TimeUnit::Millisecond, _) => {
346 extract_typed!(TimestampMillisecondArray, 1);
347 }
348 DataType::Timestamp(TimeUnit::Microsecond, _) => {
349 let arr = col
350 .as_any()
351 .downcast_ref::<TimestampMicrosecondArray>()
352 .ok_or_else(|| {
353 DataFusionError::Internal("expected TimestampMicrosecondArray".into())
354 })?;
355 for i in 0..n {
356 out.push(if col.is_null(i) {
357 None
358 } else {
359 Some(arr.value(i) / 1000)
360 });
361 }
362 }
363 DataType::Timestamp(TimeUnit::Second, _) => {
364 extract_typed!(TimestampSecondArray, 1000);
365 }
366 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
367 let arr = col
368 .as_any()
369 .downcast_ref::<TimestampNanosecondArray>()
370 .ok_or_else(|| {
371 DataFusionError::Internal("expected TimestampNanosecondArray".into())
372 })?;
373 for i in 0..n {
374 out.push(if col.is_null(i) {
375 None
376 } else {
377 Some(arr.value(i) / 1_000_000)
378 });
379 }
380 }
381 DataType::Float64 => {
382 let arr = col
383 .as_any()
384 .downcast_ref::<Float64Array>()
385 .ok_or_else(|| DataFusionError::Internal("expected Float64Array".into()))?;
386 #[allow(clippy::cast_possible_truncation)]
387 for i in 0..n {
388 out.push(if col.is_null(i) {
389 None
390 } else {
391 Some(arr.value(i) as i64)
392 });
393 }
394 }
395 other => {
396 return Err(DataFusionError::Plan(format!(
397 "unsupported timestamp type for temporal join: {other:?}"
398 )));
399 }
400 }
401
402 Ok(out)
403}
404
/// Physical operator that hash-joins a streaming input against a fully
/// materialized, in-memory lookup batch.
pub struct LookupJoinExec {
    /// Streaming (probe-side) child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Prebuilt hash index over `lookup_batch`'s key columns.
    index: Arc<HashIndex>,
    /// Build-side rows gathered into the output on a match.
    lookup_batch: Arc<RecordBatch>,
    /// Stream-side column indices forming the join key.
    stream_key_indices: Vec<usize>,
    /// Inner or left-outer semantics.
    join_type: LookupJoinType,
    /// Output schema (lookup fields forced nullable for left-outer).
    schema: SchemaRef,
    properties: PlanProperties,
    /// Sort fields used to encode stream keys identically to index keys.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of the output schema.
    stream_field_count: usize,
}
421
422impl LookupJoinExec {
423 #[allow(clippy::needless_pass_by_value)] pub fn try_new(
433 input: Arc<dyn ExecutionPlan>,
434 lookup_batch: RecordBatch,
435 stream_key_indices: Vec<usize>,
436 lookup_key_indices: Vec<usize>,
437 join_type: LookupJoinType,
438 output_schema: SchemaRef,
439 ) -> Result<Self> {
440 let index = HashIndex::build(&lookup_batch, &lookup_key_indices)?;
441
442 let key_sort_fields: Vec<SortField> = lookup_key_indices
443 .iter()
444 .map(|&i| SortField::new(lookup_batch.schema().field(i).data_type().clone()))
445 .collect();
446
447 let output_schema = if join_type == LookupJoinType::LeftOuter {
450 let stream_count = input.schema().fields().len();
451 let mut fields = output_schema.fields().to_vec();
452 for f in &mut fields[stream_count..] {
453 if !f.is_nullable() {
454 *f = Arc::new(f.as_ref().clone().with_nullable(true));
455 }
456 }
457 Arc::new(Schema::new_with_metadata(
458 fields,
459 output_schema.metadata().clone(),
460 ))
461 } else {
462 output_schema
463 };
464
465 let properties = PlanProperties::new(
466 EquivalenceProperties::new(Arc::clone(&output_schema)),
467 Partitioning::UnknownPartitioning(1),
468 EmissionType::Incremental,
469 Boundedness::Unbounded {
470 requires_infinite_memory: false,
471 },
472 );
473
474 let stream_field_count = input.schema().fields().len();
475
476 Ok(Self {
477 input,
478 index: Arc::new(index),
479 lookup_batch: Arc::new(lookup_batch),
480 stream_key_indices,
481 join_type,
482 schema: output_schema,
483 properties,
484 key_sort_fields,
485 stream_field_count,
486 })
487 }
488}
489
impl Debug for LookupJoinExec {
    // Compact debug form: join type, probe-key indices, and lookup row
    // count; the remaining fields are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("LookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("lookup_rows", &self.lookup_batch.num_rows())
            .finish_non_exhaustive()
    }
}
499
impl DisplayAs for LookupJoinExec {
    // EXPLAIN rendering: tree mode gets the bare operator name; the other
    // modes include join type, key indices, and lookup cardinality.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "LookupJoinExec: type={}, stream_keys={:?}, lookup_rows={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.lookup_batch.num_rows(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "LookupJoinExec"),
        }
    }
}
516
impl ExecutionPlan for LookupJoinExec {
    fn name(&self) -> &'static str {
        "LookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a new stream child; the lookup side (index,
    // batch, schema, properties) is shared via Arc/clone, not rebuilt.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "LookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            lookup_batch: Arc::clone(&self.lookup_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    // Executes the child partition and maps each incoming batch through
    // `probe_batch` against the prebuilt hash index.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // The converter must use the same sort fields the index was built
        // with so stream keys encode identically to indexed lookup keys.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let lookup_batch = Arc::clone(&self.lookup_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            // Fast path: nothing to probe for an empty batch.
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_batch(
                &batch,
                &converter,
                &index,
                &lookup_batch,
                &stream_key_indices,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
597
// Property accessors mostly delegate to the cached `PlanProperties`.
// NOTE(review): `boundedness` and `pipeline_behavior` restate the literals
// baked into `self.properties` at construction; if construction ever
// changes, these must be kept in sync (or delegated instead).
impl datafusion::physical_plan::ExecutionPlanProperties for LookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    fn boundedness(&self) -> Boundedness {
        // Streaming join over an unbounded input; state is the fixed-size
        // lookup index, so no infinite memory is required.
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
621
622#[allow(clippy::too_many_arguments)]
627fn probe_batch(
628 stream_batch: &RecordBatch,
629 converter: &RowConverter,
630 index: &HashIndex,
631 lookup_batch: &RecordBatch,
632 stream_key_indices: &[usize],
633 join_type: LookupJoinType,
634 output_schema: &SchemaRef,
635 stream_field_count: usize,
636) -> Result<RecordBatch> {
637 let key_cols: Vec<_> = stream_key_indices
638 .iter()
639 .map(|&i| stream_batch.column(i).clone())
640 .collect();
641 let rows = converter.convert_columns(&key_cols)?;
642
643 let num_rows = stream_batch.num_rows();
644 let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
645 let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);
646
647 #[allow(clippy::cast_possible_truncation)] for row in 0..num_rows {
649 if key_cols.iter().any(|c| c.is_null(row)) {
651 if join_type == LookupJoinType::LeftOuter {
652 stream_indices.push(row as u32);
653 lookup_indices.push(None);
654 }
655 continue;
656 }
657
658 let key = rows.row(row);
659 match index.probe(key.as_ref()) {
660 Some(matches) => {
661 for &lookup_row in matches {
662 stream_indices.push(row as u32);
663 lookup_indices.push(Some(lookup_row));
664 }
665 }
666 None if join_type == LookupJoinType::LeftOuter => {
667 stream_indices.push(row as u32);
668 lookup_indices.push(None);
669 }
670 None => {}
671 }
672 }
673
674 if stream_indices.is_empty() {
675 return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
676 }
677
678 let take_stream = UInt32Array::from(stream_indices);
680 let mut columns = Vec::with_capacity(output_schema.fields().len());
681
682 for col in stream_batch.columns() {
683 columns.push(take(col.as_ref(), &take_stream, None)?);
684 }
685
686 let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
688 for col in lookup_batch.columns() {
689 columns.push(take(col.as_ref(), &take_lookup, None)?);
690 }
691
692 debug_assert_eq!(
693 columns.len(),
694 stream_field_count + lookup_batch.num_columns(),
695 "output column count mismatch"
696 );
697
698 Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
699}
700
/// Physical operator for temporal (as-of) joins: each stream row joins the
/// latest table version at or before the row's event time.
pub struct VersionedLookupJoinExec {
    /// Streaming (probe-side) child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Prebuilt temporal index over `table_batch`.
    index: Arc<VersionedIndex>,
    /// All versions of the table's rows.
    table_batch: Arc<RecordBatch>,
    /// Stream-side column indices forming the join key.
    stream_key_indices: Vec<usize>,
    /// Stream-side column holding the event time for as-of resolution.
    stream_time_col_idx: usize,
    /// Inner or left-outer semantics.
    join_type: LookupJoinType,
    /// Output schema (table fields forced nullable for left-outer).
    schema: SchemaRef,
    properties: PlanProperties,
    /// Sort fields used to encode stream keys identically to index keys.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of the output schema.
    stream_field_count: usize,
}
718
719impl VersionedLookupJoinExec {
720 #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
730 pub fn try_new(
731 input: Arc<dyn ExecutionPlan>,
732 table_batch: RecordBatch,
733 index: Arc<VersionedIndex>,
734 stream_key_indices: Vec<usize>,
735 stream_time_col_idx: usize,
736 join_type: LookupJoinType,
737 output_schema: SchemaRef,
738 key_sort_fields: Vec<SortField>,
739 ) -> Result<Self> {
740 let output_schema = if join_type == LookupJoinType::LeftOuter {
741 let stream_count = input.schema().fields().len();
742 let mut fields = output_schema.fields().to_vec();
743 for f in &mut fields[stream_count..] {
744 if !f.is_nullable() {
745 *f = Arc::new(f.as_ref().clone().with_nullable(true));
746 }
747 }
748 Arc::new(Schema::new_with_metadata(
749 fields,
750 output_schema.metadata().clone(),
751 ))
752 } else {
753 output_schema
754 };
755
756 let properties = PlanProperties::new(
757 EquivalenceProperties::new(Arc::clone(&output_schema)),
758 Partitioning::UnknownPartitioning(1),
759 EmissionType::Incremental,
760 Boundedness::Unbounded {
761 requires_infinite_memory: false,
762 },
763 );
764
765 let stream_field_count = input.schema().fields().len();
766
767 Ok(Self {
768 input,
769 index,
770 table_batch: Arc::new(table_batch),
771 stream_key_indices,
772 stream_time_col_idx,
773 join_type,
774 schema: output_schema,
775 properties,
776 key_sort_fields,
777 stream_field_count,
778 })
779 }
780}
781
impl Debug for VersionedLookupJoinExec {
    // Compact debug form: join type, probe-key indices, and table row count;
    // the remaining fields are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("VersionedLookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("table_rows", &self.table_batch.num_rows())
            .finish_non_exhaustive()
    }
}
791
impl DisplayAs for VersionedLookupJoinExec {
    // EXPLAIN rendering: tree mode gets the bare operator name; the other
    // modes include join type, key indices, and table cardinality.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "VersionedLookupJoinExec: type={}, stream_keys={:?}, table_rows={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.table_batch.num_rows(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "VersionedLookupJoinExec"),
        }
    }
}
808
impl ExecutionPlan for VersionedLookupJoinExec {
    fn name(&self) -> &'static str {
        "VersionedLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a new stream child; the table side (index,
    // batch, schema, properties) is shared via Arc/clone, not rebuilt.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "VersionedLookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            table_batch: Arc::clone(&self.table_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            stream_time_col_idx: self.stream_time_col_idx,
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    // Executes the child partition and maps each incoming batch through
    // `probe_versioned_batch` for as-of resolution against the index.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // The converter must use the same sort fields the index was built
        // with so stream keys encode identically to indexed table keys.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let table_batch = Arc::clone(&self.table_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let stream_time_col_idx = self.stream_time_col_idx;
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            // Fast path: nothing to probe for an empty batch.
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_versioned_batch(
                &batch,
                &converter,
                &index,
                &table_batch,
                &stream_key_indices,
                stream_time_col_idx,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
892
// Property accessors mostly delegate to the cached `PlanProperties`.
// NOTE(review): `boundedness` and `pipeline_behavior` restate the literals
// baked into `self.properties` at construction; if construction ever
// changes, these must be kept in sync (or delegated instead).
impl datafusion::physical_plan::ExecutionPlanProperties for VersionedLookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    fn boundedness(&self) -> Boundedness {
        // Streaming join over an unbounded input; state is the fixed-size
        // temporal index, so no infinite memory is required.
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
916
917#[allow(clippy::too_many_arguments)]
920fn probe_versioned_batch(
921 stream_batch: &RecordBatch,
922 converter: &RowConverter,
923 index: &VersionedIndex,
924 table_batch: &RecordBatch,
925 stream_key_indices: &[usize],
926 stream_time_col_idx: usize,
927 join_type: LookupJoinType,
928 output_schema: &SchemaRef,
929 stream_field_count: usize,
930) -> Result<RecordBatch> {
931 let key_cols: Vec<_> = stream_key_indices
932 .iter()
933 .map(|&i| stream_batch.column(i).clone())
934 .collect();
935 let rows = converter.convert_columns(&key_cols)?;
936 let event_timestamps =
937 extract_i64_timestamps(stream_batch.column(stream_time_col_idx).as_ref())?;
938
939 let num_rows = stream_batch.num_rows();
940 let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
941 let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);
942
943 #[allow(clippy::cast_possible_truncation)]
944 for (row, event_ts_opt) in event_timestamps.iter().enumerate() {
945 if key_cols.iter().any(|c| c.is_null(row)) || event_ts_opt.is_none() {
947 if join_type == LookupJoinType::LeftOuter {
948 stream_indices.push(row as u32);
949 lookup_indices.push(None);
950 }
951 continue;
952 }
953
954 let key = rows.row(row);
955 let event_ts = event_ts_opt.unwrap();
956 match index.probe_at_time(key.as_ref(), event_ts) {
957 Some(table_row_idx) => {
958 stream_indices.push(row as u32);
959 lookup_indices.push(Some(table_row_idx));
960 }
961 None if join_type == LookupJoinType::LeftOuter => {
962 stream_indices.push(row as u32);
963 lookup_indices.push(None);
964 }
965 None => {}
966 }
967 }
968
969 if stream_indices.is_empty() {
970 return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
971 }
972
973 let take_stream = UInt32Array::from(stream_indices);
974 let mut columns = Vec::with_capacity(output_schema.fields().len());
975
976 for col in stream_batch.columns() {
977 columns.push(take(col.as_ref(), &take_stream, None)?);
978 }
979
980 let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
981 for col in table_batch.columns() {
982 columns.push(take(col.as_ref(), &take_lookup, None)?);
983 }
984
985 debug_assert_eq!(
986 columns.len(),
987 stream_field_count + table_batch.num_columns(),
988 "output column count mismatch"
989 );
990
991 Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
992}
993
/// Physical operator that joins a streaming input against a partially
/// cached lookup table: hits come from the Foyer cache, misses optionally
/// from an external source.
pub struct PartialLookupJoinExec {
    /// Streaming (probe-side) child plan.
    input: Arc<dyn ExecutionPlan>,
    /// Cache of previously fetched lookup rows, keyed by encoded join key.
    foyer_cache: Arc<FoyerMemoryCache>,
    /// Stream-side column indices forming the join key.
    stream_key_indices: Vec<usize>,
    /// Inner or left-outer semantics.
    join_type: LookupJoinType,
    /// Output schema (lookup fields forced nullable for left-outer).
    schema: SchemaRef,
    properties: PlanProperties,
    /// Sort fields used to encode stream keys into cache keys.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of the output schema.
    stream_field_count: usize,
    /// Schema of the lookup-side columns.
    lookup_schema: SchemaRef,
    /// Optional external source queried on cache misses.
    source: Option<Arc<dyn LookupSourceDyn>>,
    /// Bounds concurrent fallback fetches against `source`.
    fetch_semaphore: Arc<Semaphore>,
}
1012
impl PartialLookupJoinExec {
    /// Creates a cache-only partial lookup join (no fallback source).
    ///
    /// Delegates to `try_new_with_source` with `source = None` and a default
    /// 64-permit semaphore (idle when there is no source to fetch from).
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        foyer_cache: Arc<FoyerMemoryCache>,
        stream_key_indices: Vec<usize>,
        key_sort_fields: Vec<SortField>,
        join_type: LookupJoinType,
        lookup_schema: SchemaRef,
        output_schema: SchemaRef,
    ) -> Result<Self> {
        Self::try_new_with_source(
            input,
            foyer_cache,
            stream_key_indices,
            key_sort_fields,
            join_type,
            lookup_schema,
            output_schema,
            None,
            // Default cap on concurrent source fetches.
            Arc::new(Semaphore::new(64)),
        )
    }

    /// Creates a partial lookup join that serves hits from `foyer_cache`
    /// and, when `source` is present, fetches misses through it (bounded by
    /// `fetch_semaphore`).
    ///
    /// For left-outer joins, lookup-side output fields are forced nullable
    /// because unmatched stream rows emit NULL lookup columns.
    #[allow(clippy::too_many_arguments)]
    pub fn try_new_with_source(
        input: Arc<dyn ExecutionPlan>,
        foyer_cache: Arc<FoyerMemoryCache>,
        stream_key_indices: Vec<usize>,
        key_sort_fields: Vec<SortField>,
        join_type: LookupJoinType,
        lookup_schema: SchemaRef,
        output_schema: SchemaRef,
        source: Option<Arc<dyn LookupSourceDyn>>,
        fetch_semaphore: Arc<Semaphore>,
    ) -> Result<Self> {
        // Left-outer joins produce NULLs for unmatched lookup columns, so
        // those fields (everything after the stream fields) must be nullable.
        let output_schema = if join_type == LookupJoinType::LeftOuter {
            let stream_count = input.schema().fields().len();
            let mut fields = output_schema.fields().to_vec();
            for f in &mut fields[stream_count..] {
                if !f.is_nullable() {
                    *f = Arc::new(f.as_ref().clone().with_nullable(true));
                }
            }
            Arc::new(Schema::new_with_metadata(
                fields,
                output_schema.metadata().clone(),
            ))
        } else {
            output_schema
        };

        let properties = PlanProperties::new(
            EquivalenceProperties::new(Arc::clone(&output_schema)),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Unbounded {
                requires_infinite_memory: false,
            },
        );

        let stream_field_count = input.schema().fields().len();

        Ok(Self {
            input,
            foyer_cache,
            stream_key_indices,
            join_type,
            schema: output_schema,
            properties,
            key_sort_fields,
            stream_field_count,
            lookup_schema,
            source,
            fetch_semaphore,
        })
    }
}
1100
impl Debug for PartialLookupJoinExec {
    // Compact debug form: join type, probe-key indices, and the cache's
    // table id; the remaining fields are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("PartialLookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("cache_table_id", &self.foyer_cache.table_id())
            .finish_non_exhaustive()
    }
}
1110
impl DisplayAs for PartialLookupJoinExec {
    // EXPLAIN rendering: tree mode gets the bare operator name; the other
    // modes include join type, key indices, and current cache entry count.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "PartialLookupJoinExec: type={}, stream_keys={:?}, cache_entries={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.foyer_cache.len(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "PartialLookupJoinExec"),
        }
    }
}
1127
impl ExecutionPlan for PartialLookupJoinExec {
    fn name(&self) -> &'static str {
        "PartialLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a new stream child; cache, schemas, source,
    // and semaphore handles are shared via Arc/clone, not rebuilt.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "PartialLookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            foyer_cache: Arc::clone(&self.foyer_cache),
            stream_key_indices: self.stream_key_indices.clone(),
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
            lookup_schema: Arc::clone(&self.lookup_schema),
            source: self.source.clone(),
            fetch_semaphore: Arc::clone(&self.fetch_semaphore),
        }))
    }

    // Executes the child partition; `then` (rather than `map`) awaits the
    // async cache/source probe for each batch in input order.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // Arc-wrapped so the converter can be cloned into each per-batch future.
        let converter = Arc::new(RowConverter::new(self.key_sort_fields.clone())?);
        let foyer_cache = Arc::clone(&self.foyer_cache);
        let stream_key_indices = self.stream_key_indices.clone();
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;
        let lookup_schema = Arc::clone(&self.lookup_schema);
        let source = self.source.clone();
        let fetch_semaphore = Arc::clone(&self.fetch_semaphore);

        let output = input_stream.then(move |result| {
            // Per-batch clones so the returned future owns its captures.
            let foyer_cache = Arc::clone(&foyer_cache);
            let converter = Arc::clone(&converter);
            let stream_key_indices = stream_key_indices.clone();
            let schema = Arc::clone(&schema);
            let lookup_schema = Arc::clone(&lookup_schema);
            let source = source.clone();
            let fetch_semaphore = Arc::clone(&fetch_semaphore);
            async move {
                let batch = result?;
                // Fast path: nothing to probe for an empty batch.
                if batch.num_rows() == 0 {
                    return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
                }
                probe_partial_batch_with_fallback(
                    &batch,
                    &converter,
                    &foyer_cache,
                    &stream_key_indices,
                    join_type,
                    &schema,
                    stream_field_count,
                    &lookup_schema,
                    source.as_deref(),
                    &fetch_semaphore,
                )
                .await
            }
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
1224
// Property accessors mostly delegate to the cached `PlanProperties`.
// NOTE(review): `boundedness` and `pipeline_behavior` restate the literals
// baked into `self.properties` at construction; if construction ever
// changes, these must be kept in sync (or delegated instead).
impl datafusion::physical_plan::ExecutionPlanProperties for PartialLookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    fn boundedness(&self) -> Boundedness {
        // Streaming join over an unbounded input; state is the bounded
        // lookup cache, so no infinite memory is required.
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
1248
/// Joins one stream batch against the partial lookup cache, optionally
/// falling back to the remote `source` for keys missing from the cache.
///
/// Output batches are laid out as stream columns followed by lookup columns,
/// matching `output_schema`. Inner joins drop stream rows with no match;
/// left-outer joins keep them with null lookup columns.
#[allow(clippy::too_many_arguments)]
async fn probe_partial_batch_with_fallback(
    stream_batch: &RecordBatch,
    converter: &RowConverter,
    foyer_cache: &FoyerMemoryCache,
    stream_key_indices: &[usize],
    join_type: LookupJoinType,
    output_schema: &SchemaRef,
    stream_field_count: usize,
    lookup_schema: &SchemaRef,
    source: Option<&dyn LookupSourceDyn>,
    fetch_semaphore: &Semaphore,
) -> Result<RecordBatch> {
    // Encode the stream-side key columns into the normalized row format that
    // serves as the cache key.
    let key_cols: Vec<_> = stream_key_indices
        .iter()
        .map(|&i| stream_batch.column(i).clone())
        .collect();
    let rows = converter.convert_columns(&key_cols)?;

    let num_rows = stream_batch.num_rows();
    let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
    let mut lookup_batches: Vec<Option<RecordBatch>> = Vec::with_capacity(num_rows);
    // Cache misses to retry via the source: (position in `stream_indices`, key bytes).
    let mut miss_keys: Vec<(usize, Vec<u8>)> = Vec::new();

    #[allow(clippy::cast_possible_truncation)]
    for row in 0..num_rows {
        // A NULL in any key column can never match; keep the row (with null
        // lookup columns) only for left-outer joins.
        if key_cols.iter().any(|c| c.is_null(row)) {
            if join_type == LookupJoinType::LeftOuter {
                stream_indices.push(row as u32);
                lookup_batches.push(None);
            }
            continue;
        }

        let key = rows.row(row);
        let result = foyer_cache.get_cached(key.as_ref());
        if let Some(batch) = result.into_batch() {
            stream_indices.push(row as u32);
            lookup_batches.push(Some(batch));
        } else {
            // Record the miss so the source fallback below can fill it in.
            let idx = stream_indices.len();
            stream_indices.push(row as u32);
            lookup_batches.push(None);
            miss_keys.push((idx, key.as_ref().to_vec()));
        }
    }

    if let Some(source) = source {
        if !miss_keys.is_empty() {
            // Limit concurrent source fetches via the shared semaphore; the
            // permit is held for the duration of the remote query.
            let _permit = fetch_semaphore
                .acquire()
                .await
                .map_err(|_| DataFusionError::Internal("fetch semaphore closed".into()))?;

            let key_refs: Vec<&[u8]> = miss_keys.iter().map(|(_, k)| k.as_slice()).collect();
            let source_results = source.query_batch(&key_refs, &[], &[]).await;

            match source_results {
                Ok(results) => {
                    for ((idx, key_bytes), maybe_batch) in miss_keys.iter().zip(results.into_iter())
                    {
                        if let Some(batch) = maybe_batch {
                            // Populate the cache so subsequent batches hit.
                            foyer_cache.insert(key_bytes, batch.clone());
                            lookup_batches[*idx] = Some(batch);
                        }
                    }
                }
                Err(e) => {
                    // Graceful degradation: a source failure downgrades to
                    // cache-only results instead of failing the query.
                    tracing::warn!(error = %e, "source fallback failed, serving cache-only results");
                }
            }
        }
    }

    if join_type == LookupJoinType::Inner {
        // Compact both vectors in place, dropping stream rows that still
        // have no matching lookup batch after the fallback.
        let mut write = 0;
        for read in 0..stream_indices.len() {
            if lookup_batches[read].is_some() {
                stream_indices[write] = stream_indices[read];
                lookup_batches.swap(write, read);
                write += 1;
            }
        }
        stream_indices.truncate(write);
        lookup_batches.truncate(write);
    }

    if stream_indices.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
    }

    // Gather the surviving stream rows in output order.
    let take_indices = UInt32Array::from(stream_indices);
    let mut columns = Vec::with_capacity(output_schema.fields().len());

    for col in stream_batch.columns() {
        columns.push(take(col.as_ref(), &take_indices, None)?);
    }

    // Assemble each lookup column by concatenating one value per stream row.
    // NOTE(review): the length-1 null array implies each cached batch is
    // assumed to hold exactly one row; a multi-row cached batch would make
    // column lengths disagree — confirm the cache upholds this invariant.
    let lookup_col_count = lookup_schema.fields().len();
    for col_idx in 0..lookup_col_count {
        let arrays: Vec<_> = lookup_batches
            .iter()
            .map(|opt| match opt {
                Some(b) => b.column(col_idx).clone(),
                None => arrow_array::new_null_array(lookup_schema.field(col_idx).data_type(), 1),
            })
            .collect();
        let refs: Vec<&dyn arrow_array::Array> = arrays.iter().map(AsRef::as_ref).collect();
        columns.push(arrow::compute::concat(&refs)?);
    }

    debug_assert_eq!(
        columns.len(),
        stream_field_count + lookup_col_count,
        "output column count mismatch"
    );

    Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
}
1374
/// [`ExtensionPlanner`] that lowers `LookupJoinNode` logical nodes into the
/// physical lookup-join operators defined in this module.
pub struct LookupJoinExtensionPlanner {
    // Shared registry consulted at planning time for registered lookup tables.
    registry: Arc<LookupTableRegistry>,
}
1383
impl LookupJoinExtensionPlanner {
    /// Creates a planner backed by the given lookup-table registry.
    pub fn new(registry: Arc<LookupTableRegistry>) -> Self {
        Self { registry }
    }
}
1390
1391#[async_trait]
1392impl ExtensionPlanner for LookupJoinExtensionPlanner {
1393 #[allow(clippy::too_many_lines)]
1394 async fn plan_extension(
1395 &self,
1396 _planner: &dyn PhysicalPlanner,
1397 node: &dyn UserDefinedLogicalNode,
1398 _logical_inputs: &[&LogicalPlan],
1399 physical_inputs: &[Arc<dyn ExecutionPlan>],
1400 session_state: &SessionState,
1401 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1402 let Some(lookup_node) = node.as_any().downcast_ref::<LookupJoinNode>() else {
1403 return Ok(None);
1404 };
1405
1406 let entry = self
1407 .registry
1408 .get_entry(lookup_node.lookup_table_name())
1409 .ok_or_else(|| {
1410 DataFusionError::Plan(format!(
1411 "lookup table '{}' not registered",
1412 lookup_node.lookup_table_name()
1413 ))
1414 })?;
1415
1416 let input = Arc::clone(&physical_inputs[0]);
1417 let stream_schema = input.schema();
1418
1419 match entry {
1420 RegisteredLookup::Partial(partial_state) => {
1421 let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1422
1423 let mut output_fields = stream_schema.fields().to_vec();
1424 output_fields.extend(partial_state.schema.fields().iter().cloned());
1425 let output_schema = Arc::new(Schema::new(output_fields));
1426
1427 let exec = PartialLookupJoinExec::try_new_with_source(
1428 input,
1429 Arc::clone(&partial_state.foyer_cache),
1430 stream_key_indices,
1431 partial_state.key_sort_fields.clone(),
1432 lookup_node.join_type(),
1433 Arc::clone(&partial_state.schema),
1434 output_schema,
1435 partial_state.source.clone(),
1436 Arc::clone(&partial_state.fetch_semaphore),
1437 )?;
1438 Ok(Some(Arc::new(exec)))
1439 }
1440 RegisteredLookup::Snapshot(snapshot) => {
1441 let lookup_schema = snapshot.batch.schema();
1442 let lookup_key_indices = resolve_lookup_keys(lookup_node, &lookup_schema)?;
1443
1444 let lookup_batch = if lookup_node.pushdown_predicates().is_empty()
1445 || snapshot.batch.num_rows() == 0
1446 {
1447 snapshot.batch.clone()
1448 } else {
1449 apply_pushdown_predicates(
1450 &snapshot.batch,
1451 lookup_node.pushdown_predicates(),
1452 session_state,
1453 )?
1454 };
1455
1456 let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1457
1458 for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
1460 let st = stream_schema.field(*si).data_type();
1461 let lt = lookup_schema.field(*li).data_type();
1462 if st != lt {
1463 return Err(DataFusionError::Plan(format!(
1464 "Lookup join key type mismatch: stream '{}' is {st:?} \
1465 but lookup '{}' is {lt:?}",
1466 stream_schema.field(*si).name(),
1467 lookup_schema.field(*li).name(),
1468 )));
1469 }
1470 }
1471
1472 let mut output_fields = stream_schema.fields().to_vec();
1473 output_fields.extend(lookup_batch.schema().fields().iter().cloned());
1474 let output_schema = Arc::new(Schema::new(output_fields));
1475
1476 let exec = LookupJoinExec::try_new(
1477 input,
1478 lookup_batch,
1479 stream_key_indices,
1480 lookup_key_indices,
1481 lookup_node.join_type(),
1482 output_schema,
1483 )?;
1484
1485 Ok(Some(Arc::new(exec)))
1486 }
1487 RegisteredLookup::Versioned(versioned_state) => {
1488 let table_schema = versioned_state.batch.schema();
1489 let lookup_key_indices = resolve_lookup_keys(lookup_node, &table_schema)?;
1490 let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1491
1492 for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
1494 let st = stream_schema.field(*si).data_type();
1495 let lt = table_schema.field(*li).data_type();
1496 if st != lt {
1497 return Err(DataFusionError::Plan(format!(
1498 "Temporal join key type mismatch: stream '{}' is {st:?} \
1499 but table '{}' is {lt:?}",
1500 stream_schema.field(*si).name(),
1501 table_schema.field(*li).name(),
1502 )));
1503 }
1504 }
1505
1506 let stream_time_col_idx = stream_schema
1507 .index_of(&versioned_state.stream_time_column)
1508 .map_err(|_| {
1509 DataFusionError::Plan(format!(
1510 "stream time column '{}' not found in stream schema",
1511 versioned_state.stream_time_column
1512 ))
1513 })?;
1514
1515 let key_sort_fields: Vec<SortField> = lookup_key_indices
1516 .iter()
1517 .map(|&i| SortField::new(table_schema.field(i).data_type().clone()))
1518 .collect();
1519
1520 let mut output_fields = stream_schema.fields().to_vec();
1521 output_fields.extend(table_schema.fields().iter().cloned());
1522 let output_schema = Arc::new(Schema::new(output_fields));
1523
1524 let exec = VersionedLookupJoinExec::try_new(
1525 input,
1526 versioned_state.batch.clone(),
1527 Arc::clone(&versioned_state.index),
1528 stream_key_indices,
1529 stream_time_col_idx,
1530 lookup_node.join_type(),
1531 output_schema,
1532 key_sort_fields,
1533 )?;
1534
1535 Ok(Some(Arc::new(exec)))
1536 }
1537 }
1538 }
1539}
1540
1541fn apply_pushdown_predicates(
1544 batch: &RecordBatch,
1545 predicates: &[Expr],
1546 session_state: &SessionState,
1547) -> Result<RecordBatch> {
1548 use arrow::compute::filter_record_batch;
1549 use datafusion::physical_expr::create_physical_expr;
1550
1551 let schema = batch.schema();
1552 let df_schema = datafusion::common::DFSchema::try_from(schema.as_ref().clone())?;
1553
1554 let mut mask = None::<arrow_array::BooleanArray>;
1555 for pred in predicates {
1556 let phys_expr = create_physical_expr(pred, &df_schema, session_state.execution_props())?;
1557 let result = phys_expr.evaluate(batch)?;
1558 let bool_arr = result
1559 .into_array(batch.num_rows())?
1560 .as_any()
1561 .downcast_ref::<arrow_array::BooleanArray>()
1562 .ok_or_else(|| {
1563 DataFusionError::Internal("pushdown predicate did not evaluate to boolean".into())
1564 })?
1565 .clone();
1566 mask = Some(match mask {
1567 Some(existing) => arrow::compute::and(&existing, &bool_arr)?,
1568 None => bool_arr,
1569 });
1570 }
1571
1572 match mask {
1573 Some(m) => Ok(filter_record_batch(batch, &m)?),
1574 None => Ok(batch.clone()),
1575 }
1576}
1577
1578fn resolve_stream_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1579 node.join_keys()
1580 .iter()
1581 .map(|pair| match &pair.stream_expr {
1582 Expr::Column(col) => schema.index_of(&col.name).map_err(|_| {
1583 DataFusionError::Plan(format!(
1584 "stream key column '{}' not found in physical schema",
1585 col.name
1586 ))
1587 }),
1588 other => Err(DataFusionError::NotImplemented(format!(
1589 "lookup join requires column references as stream keys, got: {other}"
1590 ))),
1591 })
1592 .collect()
1593}
1594
1595fn resolve_lookup_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1596 node.join_keys()
1597 .iter()
1598 .map(|pair| {
1599 schema.index_of(&pair.lookup_column).map_err(|_| {
1600 DataFusionError::Plan(format!(
1601 "lookup key column '{}' not found in lookup table schema",
1602 pair.lookup_column
1603 ))
1604 })
1605 })
1606 .collect()
1607}
1608
1609#[cfg(test)]
1612mod tests {
1613 use super::*;
1614 use arrow_array::{Array, Float64Array, Int64Array, StringArray};
1615 use arrow_schema::{DataType, Field};
1616 use datafusion::physical_plan::stream::RecordBatchStreamAdapter as TestStreamAdapter;
1617 use futures::TryStreamExt;
1618
1619 fn batch_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
1621 let schema = batch.schema();
1622 let batches = vec![batch];
1623 let stream_schema = Arc::clone(&schema);
1624 Arc::new(StreamExecStub {
1625 schema,
1626 batches: std::sync::Mutex::new(Some(batches)),
1627 stream_schema,
1628 })
1629 }
1630
1631 struct StreamExecStub {
1633 schema: SchemaRef,
1634 batches: std::sync::Mutex<Option<Vec<RecordBatch>>>,
1635 stream_schema: SchemaRef,
1636 }
1637
1638 impl Debug for StreamExecStub {
1639 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1640 write!(f, "StreamExecStub")
1641 }
1642 }
1643
1644 impl DisplayAs for StreamExecStub {
1645 fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
1646 write!(f, "StreamExecStub")
1647 }
1648 }
1649
1650 impl ExecutionPlan for StreamExecStub {
1651 fn name(&self) -> &'static str {
1652 "StreamExecStub"
1653 }
1654 fn as_any(&self) -> &dyn Any {
1655 self
1656 }
1657 fn schema(&self) -> SchemaRef {
1658 Arc::clone(&self.schema)
1659 }
1660 fn properties(&self) -> &PlanProperties {
1661 Box::leak(Box::new(PlanProperties::new(
1663 EquivalenceProperties::new(Arc::clone(&self.schema)),
1664 Partitioning::UnknownPartitioning(1),
1665 EmissionType::Final,
1666 Boundedness::Bounded,
1667 )))
1668 }
1669 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1670 vec![]
1671 }
1672 fn with_new_children(
1673 self: Arc<Self>,
1674 _: Vec<Arc<dyn ExecutionPlan>>,
1675 ) -> Result<Arc<dyn ExecutionPlan>> {
1676 Ok(self)
1677 }
1678 fn execute(&self, _: usize, _: Arc<TaskContext>) -> Result<SendableRecordBatchStream> {
1679 let batches = self.batches.lock().unwrap().take().unwrap_or_default();
1680 let schema = Arc::clone(&self.stream_schema);
1681 let stream = futures::stream::iter(batches.into_iter().map(Ok));
1682 Ok(Box::pin(TestStreamAdapter::new(schema, stream)))
1683 }
1684 }
1685
1686 impl datafusion::physical_plan::ExecutionPlanProperties for StreamExecStub {
1687 fn output_partitioning(&self) -> &Partitioning {
1688 self.properties().output_partitioning()
1689 }
1690 fn output_ordering(&self) -> Option<&LexOrdering> {
1691 None
1692 }
1693 fn boundedness(&self) -> Boundedness {
1694 Boundedness::Bounded
1695 }
1696 fn pipeline_behavior(&self) -> EmissionType {
1697 EmissionType::Final
1698 }
1699 fn equivalence_properties(&self) -> &EquivalenceProperties {
1700 self.properties().equivalence_properties()
1701 }
1702 }
1703
1704 fn orders_schema() -> SchemaRef {
1705 Arc::new(Schema::new(vec![
1706 Field::new("order_id", DataType::Int64, false),
1707 Field::new("customer_id", DataType::Int64, false),
1708 Field::new("amount", DataType::Float64, false),
1709 ]))
1710 }
1711
1712 fn customers_schema() -> SchemaRef {
1713 Arc::new(Schema::new(vec![
1714 Field::new("id", DataType::Int64, false),
1715 Field::new("name", DataType::Utf8, true),
1716 ]))
1717 }
1718
1719 fn output_schema() -> SchemaRef {
1720 Arc::new(Schema::new(vec![
1721 Field::new("order_id", DataType::Int64, false),
1722 Field::new("customer_id", DataType::Int64, false),
1723 Field::new("amount", DataType::Float64, false),
1724 Field::new("id", DataType::Int64, false),
1725 Field::new("name", DataType::Utf8, true),
1726 ]))
1727 }
1728
1729 fn customers_batch() -> RecordBatch {
1730 RecordBatch::try_new(
1731 customers_schema(),
1732 vec![
1733 Arc::new(Int64Array::from(vec![1, 2, 3])),
1734 Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
1735 ],
1736 )
1737 .unwrap()
1738 }
1739
1740 fn orders_batch() -> RecordBatch {
1741 RecordBatch::try_new(
1742 orders_schema(),
1743 vec![
1744 Arc::new(Int64Array::from(vec![100, 101, 102, 103])),
1745 Arc::new(Int64Array::from(vec![1, 2, 99, 3])),
1746 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
1747 ],
1748 )
1749 .unwrap()
1750 }
1751
1752 fn make_exec(join_type: LookupJoinType) -> LookupJoinExec {
1753 let input = batch_exec(orders_batch());
1754 LookupJoinExec::try_new(
1755 input,
1756 customers_batch(),
1757 vec![1], vec![0], join_type,
1760 output_schema(),
1761 )
1762 .unwrap()
1763 }
1764
1765 #[tokio::test]
1766 async fn inner_join_filters_non_matches() {
1767 let exec = make_exec(LookupJoinType::Inner);
1768 let ctx = Arc::new(TaskContext::default());
1769 let stream = exec.execute(0, ctx).unwrap();
1770 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
1771
1772 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1773 assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");
1774
1775 let names = batches[0]
1776 .column(4)
1777 .as_any()
1778 .downcast_ref::<StringArray>()
1779 .unwrap();
1780 assert_eq!(names.value(0), "Alice");
1781 assert_eq!(names.value(1), "Bob");
1782 assert_eq!(names.value(2), "Charlie");
1783 }
1784
1785 #[tokio::test]
1786 async fn left_outer_preserves_non_matches() {
1787 let exec = make_exec(LookupJoinType::LeftOuter);
1788 let ctx = Arc::new(TaskContext::default());
1789 let stream = exec.execute(0, ctx).unwrap();
1790 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
1791
1792 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1793 assert_eq!(total, 4, "all 4 stream rows preserved in left outer");
1794
1795 let names = batches[0]
1796 .column(4)
1797 .as_any()
1798 .downcast_ref::<StringArray>()
1799 .unwrap();
1800 assert!(names.is_null(2));
1802 }
1803
1804 #[tokio::test]
1805 async fn empty_lookup_inner_produces_no_rows() {
1806 let empty = RecordBatch::new_empty(customers_schema());
1807 let input = batch_exec(orders_batch());
1808 let exec = LookupJoinExec::try_new(
1809 input,
1810 empty,
1811 vec![1],
1812 vec![0],
1813 LookupJoinType::Inner,
1814 output_schema(),
1815 )
1816 .unwrap();
1817
1818 let ctx = Arc::new(TaskContext::default());
1819 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1820 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1821 assert_eq!(total, 0);
1822 }
1823
1824 #[tokio::test]
1825 async fn empty_lookup_left_outer_preserves_all_stream_rows() {
1826 let empty = RecordBatch::new_empty(customers_schema());
1827 let input = batch_exec(orders_batch());
1828 let exec = LookupJoinExec::try_new(
1829 input,
1830 empty,
1831 vec![1],
1832 vec![0],
1833 LookupJoinType::LeftOuter,
1834 output_schema(),
1835 )
1836 .unwrap();
1837
1838 let ctx = Arc::new(TaskContext::default());
1839 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1840 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1841 assert_eq!(total, 4);
1842 }
1843
1844 #[tokio::test]
1845 async fn duplicate_keys_produce_multiple_rows() {
1846 let lookup = RecordBatch::try_new(
1847 customers_schema(),
1848 vec![
1849 Arc::new(Int64Array::from(vec![1, 1])),
1850 Arc::new(StringArray::from(vec!["Alice-A", "Alice-B"])),
1851 ],
1852 )
1853 .unwrap();
1854
1855 let stream = RecordBatch::try_new(
1856 orders_schema(),
1857 vec![
1858 Arc::new(Int64Array::from(vec![100])),
1859 Arc::new(Int64Array::from(vec![1])),
1860 Arc::new(Float64Array::from(vec![10.0])),
1861 ],
1862 )
1863 .unwrap();
1864
1865 let input = batch_exec(stream);
1866 let exec = LookupJoinExec::try_new(
1867 input,
1868 lookup,
1869 vec![1],
1870 vec![0],
1871 LookupJoinType::Inner,
1872 output_schema(),
1873 )
1874 .unwrap();
1875
1876 let ctx = Arc::new(TaskContext::default());
1877 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1878 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1879 assert_eq!(total, 2, "one stream row matched two lookup rows");
1880 }
1881
1882 #[test]
1883 fn with_new_children_preserves_state() {
1884 let exec = Arc::new(make_exec(LookupJoinType::Inner));
1885 let expected_schema = exec.schema();
1886 let children = exec.children().into_iter().cloned().collect();
1887 let rebuilt = exec.with_new_children(children).unwrap();
1888 assert_eq!(rebuilt.schema(), expected_schema);
1889 assert_eq!(rebuilt.name(), "LookupJoinExec");
1890 }
1891
1892 #[test]
1893 fn display_format() {
1894 let exec = make_exec(LookupJoinType::Inner);
1895 let s = format!("{exec:?}");
1896 assert!(s.contains("LookupJoinExec"));
1897 assert!(s.contains("lookup_rows: 3"));
1898 }
1899
1900 #[test]
1901 fn registry_crud() {
1902 let reg = LookupTableRegistry::new();
1903 assert!(reg.get("customers").is_none());
1904
1905 reg.register(
1906 "customers",
1907 LookupSnapshot {
1908 batch: customers_batch(),
1909 key_columns: vec!["id".into()],
1910 },
1911 );
1912 assert!(reg.get("customers").is_some());
1913 assert!(reg.get("CUSTOMERS").is_some(), "case-insensitive");
1914
1915 reg.unregister("customers");
1916 assert!(reg.get("customers").is_none());
1917 }
1918
1919 #[test]
1920 fn registry_update_replaces() {
1921 let reg = LookupTableRegistry::new();
1922 reg.register(
1923 "t",
1924 LookupSnapshot {
1925 batch: RecordBatch::new_empty(customers_schema()),
1926 key_columns: vec![],
1927 },
1928 );
1929 assert_eq!(reg.get("t").unwrap().batch.num_rows(), 0);
1930
1931 reg.register(
1932 "t",
1933 LookupSnapshot {
1934 batch: customers_batch(),
1935 key_columns: vec![],
1936 },
1937 );
1938 assert_eq!(reg.get("t").unwrap().batch.num_rows(), 3);
1939 }
1940
1941 #[test]
1942 fn pushdown_predicates_filter_snapshot() {
1943 use datafusion::logical_expr::{col, lit};
1944
1945 let batch = customers_batch(); let ctx = datafusion::prelude::SessionContext::new();
1947 let state = ctx.state();
1948
1949 let predicates = vec![col("id").gt(lit(1i64))];
1951 let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
1952 assert_eq!(filtered.num_rows(), 2);
1953
1954 let ids = filtered
1955 .column(0)
1956 .as_any()
1957 .downcast_ref::<Int64Array>()
1958 .unwrap();
1959 assert_eq!(ids.value(0), 2);
1960 assert_eq!(ids.value(1), 3);
1961 }
1962
1963 #[test]
1964 fn pushdown_predicates_empty_passes_all() {
1965 let batch = customers_batch();
1966 let ctx = datafusion::prelude::SessionContext::new();
1967 let state = ctx.state();
1968
1969 let filtered = apply_pushdown_predicates(&batch, &[], &state).unwrap();
1970 assert_eq!(filtered.num_rows(), 3);
1971 }
1972
1973 #[test]
1974 fn pushdown_predicates_multiple_and() {
1975 use datafusion::logical_expr::{col, lit};
1976
1977 let batch = customers_batch(); let ctx = datafusion::prelude::SessionContext::new();
1979 let state = ctx.state();
1980
1981 let predicates = vec![col("id").gt_eq(lit(2i64)), col("id").lt(lit(3i64))];
1983 let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
1984 assert_eq!(filtered.num_rows(), 1);
1985 }
1986
1987 use laminar_core::lookup::foyer_cache::FoyerMemoryCacheConfig;
1990
1991 fn make_foyer_cache() -> Arc<FoyerMemoryCache> {
1992 Arc::new(FoyerMemoryCache::new(
1993 1,
1994 FoyerMemoryCacheConfig {
1995 capacity: 64,
1996 shards: 4,
1997 },
1998 ))
1999 }
2000
2001 fn customer_row(id: i64, name: &str) -> RecordBatch {
2002 RecordBatch::try_new(
2003 customers_schema(),
2004 vec![
2005 Arc::new(Int64Array::from(vec![id])),
2006 Arc::new(StringArray::from(vec![name])),
2007 ],
2008 )
2009 .unwrap()
2010 }
2011
2012 fn warm_cache(cache: &FoyerMemoryCache) {
2013 let converter = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap();
2014
2015 for (id, name) in [(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
2016 let key_col = Int64Array::from(vec![id]);
2017 let rows = converter.convert_columns(&[Arc::new(key_col)]).unwrap();
2018 let key = rows.row(0);
2019 cache.insert(key.as_ref(), customer_row(id, name));
2020 }
2021 }
2022
2023 fn make_partial_exec(join_type: LookupJoinType) -> PartialLookupJoinExec {
2024 let cache = make_foyer_cache();
2025 warm_cache(&cache);
2026
2027 let input = batch_exec(orders_batch());
2028 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2029
2030 PartialLookupJoinExec::try_new(
2031 input,
2032 cache,
2033 vec![1], key_sort_fields,
2035 join_type,
2036 customers_schema(),
2037 output_schema(),
2038 )
2039 .unwrap()
2040 }
2041
2042 #[tokio::test]
2043 async fn partial_inner_join_filters_non_matches() {
2044 let exec = make_partial_exec(LookupJoinType::Inner);
2045 let ctx = Arc::new(TaskContext::default());
2046 let stream = exec.execute(0, ctx).unwrap();
2047 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2048
2049 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2050 assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");
2051
2052 let names = batches[0]
2053 .column(4)
2054 .as_any()
2055 .downcast_ref::<StringArray>()
2056 .unwrap();
2057 assert_eq!(names.value(0), "Alice");
2058 assert_eq!(names.value(1), "Bob");
2059 assert_eq!(names.value(2), "Charlie");
2060 }
2061
2062 #[tokio::test]
2063 async fn partial_left_outer_preserves_non_matches() {
2064 let exec = make_partial_exec(LookupJoinType::LeftOuter);
2065 let ctx = Arc::new(TaskContext::default());
2066 let stream = exec.execute(0, ctx).unwrap();
2067 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2068
2069 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2070 assert_eq!(total, 4, "all 4 stream rows preserved in left outer");
2071
2072 let names = batches[0]
2073 .column(4)
2074 .as_any()
2075 .downcast_ref::<StringArray>()
2076 .unwrap();
2077 assert!(names.is_null(2), "customer_id=99 should have null name");
2078 }
2079
2080 #[tokio::test]
2081 async fn partial_empty_cache_inner_produces_no_rows() {
2082 let cache = make_foyer_cache();
2083 let input = batch_exec(orders_batch());
2084 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2085
2086 let exec = PartialLookupJoinExec::try_new(
2087 input,
2088 cache,
2089 vec![1],
2090 key_sort_fields,
2091 LookupJoinType::Inner,
2092 customers_schema(),
2093 output_schema(),
2094 )
2095 .unwrap();
2096
2097 let ctx = Arc::new(TaskContext::default());
2098 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2099 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2100 assert_eq!(total, 0);
2101 }
2102
2103 #[tokio::test]
2104 async fn partial_empty_cache_left_outer_preserves_all() {
2105 let cache = make_foyer_cache();
2106 let input = batch_exec(orders_batch());
2107 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2108
2109 let exec = PartialLookupJoinExec::try_new(
2110 input,
2111 cache,
2112 vec![1],
2113 key_sort_fields,
2114 LookupJoinType::LeftOuter,
2115 customers_schema(),
2116 output_schema(),
2117 )
2118 .unwrap();
2119
2120 let ctx = Arc::new(TaskContext::default());
2121 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2122 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2123 assert_eq!(total, 4);
2124 }
2125
2126 #[test]
2127 fn partial_with_new_children_preserves_state() {
2128 let exec = Arc::new(make_partial_exec(LookupJoinType::Inner));
2129 let expected_schema = exec.schema();
2130 let children = exec.children().into_iter().cloned().collect();
2131 let rebuilt = exec.with_new_children(children).unwrap();
2132 assert_eq!(rebuilt.schema(), expected_schema);
2133 assert_eq!(rebuilt.name(), "PartialLookupJoinExec");
2134 }
2135
2136 #[test]
2137 fn partial_display_format() {
2138 let exec = make_partial_exec(LookupJoinType::Inner);
2139 let s = format!("{exec:?}");
2140 assert!(s.contains("PartialLookupJoinExec"));
2141 assert!(s.contains("cache_table_id: 1"));
2142 }
2143
2144 #[test]
2145 fn registry_partial_entry() {
2146 let reg = LookupTableRegistry::new();
2147 let cache = make_foyer_cache();
2148 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2149
2150 reg.register_partial(
2151 "customers",
2152 PartialLookupState {
2153 foyer_cache: cache,
2154 schema: customers_schema(),
2155 key_columns: vec!["id".into()],
2156 key_sort_fields,
2157 source: None,
2158 fetch_semaphore: Arc::new(Semaphore::new(64)),
2159 },
2160 );
2161
2162 assert!(reg.get("customers").is_none());
2163
2164 let entry = reg.get_entry("customers");
2165 assert!(entry.is_some());
2166 assert!(matches!(entry.unwrap(), RegisteredLookup::Partial(_)));
2167 }
2168
2169 #[tokio::test]
2170 async fn partial_source_fallback_on_miss() {
2171 use laminar_core::lookup::source::LookupError;
2172 use laminar_core::lookup::source::LookupSourceDyn;
2173
2174 struct TestSource;
2175
2176 #[async_trait]
2177 impl LookupSourceDyn for TestSource {
2178 async fn query_batch(
2179 &self,
2180 keys: &[&[u8]],
2181 _predicates: &[laminar_core::lookup::predicate::Predicate],
2182 _projection: &[laminar_core::lookup::source::ColumnId],
2183 ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
2184 Ok(keys
2185 .iter()
2186 .map(|_| Some(customer_row(99, "FromSource")))
2187 .collect())
2188 }
2189
2190 fn schema(&self) -> SchemaRef {
2191 customers_schema()
2192 }
2193 }
2194
2195 let cache = make_foyer_cache();
2196 warm_cache(&cache);
2198
2199 let orders = RecordBatch::try_new(
2200 orders_schema(),
2201 vec![
2202 Arc::new(Int64Array::from(vec![200])),
2203 Arc::new(Int64Array::from(vec![99])), Arc::new(Float64Array::from(vec![50.0])),
2205 ],
2206 )
2207 .unwrap();
2208
2209 let input = batch_exec(orders);
2210 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2211 let source: Arc<dyn LookupSourceDyn> = Arc::new(TestSource);
2212
2213 let exec = PartialLookupJoinExec::try_new_with_source(
2214 input,
2215 cache,
2216 vec![1],
2217 key_sort_fields,
2218 LookupJoinType::Inner,
2219 customers_schema(),
2220 output_schema(),
2221 Some(source),
2222 Arc::new(Semaphore::new(64)),
2223 )
2224 .unwrap();
2225
2226 let ctx = Arc::new(TaskContext::default());
2227 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2228 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2229 assert_eq!(total, 1, "source fallback should produce 1 row");
2230
2231 let names = batches[0]
2232 .column(4)
2233 .as_any()
2234 .downcast_ref::<StringArray>()
2235 .unwrap();
2236 assert_eq!(names.value(0), "FromSource");
2237 }
2238
2239 #[tokio::test]
2240 async fn partial_source_error_graceful_degradation() {
2241 use laminar_core::lookup::source::LookupError;
2242 use laminar_core::lookup::source::LookupSourceDyn;
2243
2244 struct FailingSource;
2245
2246 #[async_trait]
2247 impl LookupSourceDyn for FailingSource {
2248 async fn query_batch(
2249 &self,
2250 _keys: &[&[u8]],
2251 _predicates: &[laminar_core::lookup::predicate::Predicate],
2252 _projection: &[laminar_core::lookup::source::ColumnId],
2253 ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
2254 Err(LookupError::Internal("source unavailable".into()))
2255 }
2256
2257 fn schema(&self) -> SchemaRef {
2258 customers_schema()
2259 }
2260 }
2261
2262 let cache = make_foyer_cache();
2263 let input = batch_exec(orders_batch());
2264 let key_sort_fields = vec![SortField::new(DataType::Int64)];
2265 let source: Arc<dyn LookupSourceDyn> = Arc::new(FailingSource);
2266
2267 let exec = PartialLookupJoinExec::try_new_with_source(
2268 input,
2269 cache,
2270 vec![1],
2271 key_sort_fields,
2272 LookupJoinType::LeftOuter,
2273 customers_schema(),
2274 output_schema(),
2275 Some(source),
2276 Arc::new(Semaphore::new(64)),
2277 )
2278 .unwrap();
2279
2280 let ctx = Arc::new(TaskContext::default());
2281 let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2282 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2283 assert_eq!(total, 4);
2285 }
2286
2287 #[test]
2288 fn registry_snapshot_entry_via_get_entry() {
2289 let reg = LookupTableRegistry::new();
2290 reg.register(
2291 "t",
2292 LookupSnapshot {
2293 batch: customers_batch(),
2294 key_columns: vec!["id".into()],
2295 },
2296 );
2297
2298 let entry = reg.get_entry("t");
2299 assert!(matches!(entry.unwrap(), RegisteredLookup::Snapshot(_)));
2300 assert!(reg.get("t").is_some());
2301 }
2302
2303 fn nullable_orders_schema() -> SchemaRef {
2306 Arc::new(Schema::new(vec![
2307 Field::new("order_id", DataType::Int64, false),
2308 Field::new("customer_id", DataType::Int64, true), Field::new("amount", DataType::Float64, false),
2310 ]))
2311 }
2312
2313 fn nullable_output_schema(join_type: LookupJoinType) -> SchemaRef {
2314 let lookup_nullable = join_type == LookupJoinType::LeftOuter;
2315 Arc::new(Schema::new(vec![
2316 Field::new("order_id", DataType::Int64, false),
2317 Field::new("customer_id", DataType::Int64, true),
2318 Field::new("amount", DataType::Float64, false),
2319 Field::new("id", DataType::Int64, lookup_nullable),
2320 Field::new("name", DataType::Utf8, true),
2321 ]))
2322 }
2323
2324 #[tokio::test]
2325 async fn null_key_inner_join_no_match() {
2326 let stream_batch = RecordBatch::try_new(
2328 nullable_orders_schema(),
2329 vec![
2330 Arc::new(Int64Array::from(vec![100, 101, 102])),
2331 Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
2332 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
2333 ],
2334 )
2335 .unwrap();
2336
2337 let input = batch_exec(stream_batch);
2338 let exec = LookupJoinExec::try_new(
2339 input,
2340 customers_batch(),
2341 vec![1],
2342 vec![0],
2343 LookupJoinType::Inner,
2344 nullable_output_schema(LookupJoinType::Inner),
2345 )
2346 .unwrap();
2347
2348 let ctx = Arc::new(TaskContext::default());
2349 let stream = exec.execute(0, ctx).unwrap();
2350 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2351
2352 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2353 assert_eq!(total, 2, "NULL key row should not match in inner join");
2355 }
2356
2357 #[tokio::test]
2358 async fn null_key_left_outer_produces_nulls() {
2359 let stream_batch = RecordBatch::try_new(
2361 nullable_orders_schema(),
2362 vec![
2363 Arc::new(Int64Array::from(vec![100, 101, 102])),
2364 Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
2365 Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
2366 ],
2367 )
2368 .unwrap();
2369
2370 let input = batch_exec(stream_batch);
2371 let out_schema = nullable_output_schema(LookupJoinType::LeftOuter);
2372 let exec = LookupJoinExec::try_new(
2373 input,
2374 customers_batch(),
2375 vec![1],
2376 vec![0],
2377 LookupJoinType::LeftOuter,
2378 out_schema,
2379 )
2380 .unwrap();
2381
2382 let ctx = Arc::new(TaskContext::default());
2383 let stream = exec.execute(0, ctx).unwrap();
2384 let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2385
2386 let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2387 assert_eq!(total, 3, "all rows preserved in left outer");
2389
2390 let names = batches[0]
2391 .column(4)
2392 .as_any()
2393 .downcast_ref::<StringArray>()
2394 .unwrap();
2395 assert_eq!(names.value(0), "Alice");
2396 assert!(
2397 names.is_null(1),
2398 "NULL key row should have null lookup name"
2399 );
2400 assert_eq!(names.value(2), "Bob");
2401 }
2402
2403 fn versioned_table_batch() -> RecordBatch {
2406 let schema = Arc::new(Schema::new(vec![
2409 Field::new("currency", DataType::Utf8, false),
2410 Field::new("valid_from", DataType::Int64, false),
2411 Field::new("rate", DataType::Float64, false),
2412 ]));
2413 RecordBatch::try_new(
2414 schema,
2415 vec![
2416 Arc::new(StringArray::from(vec!["USD", "USD", "EUR", "EUR", "EUR"])),
2417 Arc::new(Int64Array::from(vec![100, 200, 100, 150, 300])),
2418 Arc::new(Float64Array::from(vec![1.0, 1.1, 0.85, 0.90, 0.88])),
2419 ],
2420 )
2421 .unwrap()
2422 }
2423
2424 fn stream_batch_with_time() -> RecordBatch {
2425 let schema = Arc::new(Schema::new(vec![
2426 Field::new("order_id", DataType::Int64, false),
2427 Field::new("currency", DataType::Utf8, false),
2428 Field::new("event_ts", DataType::Int64, false),
2429 ]));
2430 RecordBatch::try_new(
2431 schema,
2432 vec![
2433 Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
2434 Arc::new(StringArray::from(vec!["USD", "EUR", "USD", "EUR"])),
2435 Arc::new(Int64Array::from(vec![150, 160, 250, 50])),
2436 ],
2437 )
2438 .unwrap()
2439 }
2440
2441 #[test]
2442 fn test_versioned_index_build_and_probe() {
2443 let batch = versioned_table_batch();
2444 let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2445
2446 let key_sf = vec![SortField::new(DataType::Utf8)];
2449 let converter = RowConverter::new(key_sf).unwrap();
2450 let usd_col = Arc::new(StringArray::from(vec!["USD"]));
2451 let usd_rows = converter.convert_columns(&[usd_col]).unwrap();
2452 let usd_key = usd_rows.row(0);
2453
2454 let result = index.probe_at_time(usd_key.as_ref(), 150);
2455 assert!(result.is_some());
2456 assert_eq!(result.unwrap(), 0);
2458
2459 let result = index.probe_at_time(usd_key.as_ref(), 250);
2461 assert_eq!(result.unwrap(), 1);
2462 }
2463
2464 #[test]
2465 fn test_versioned_index_no_version_before_ts() {
2466 let batch = versioned_table_batch();
2467 let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2468
2469 let key_sf = vec![SortField::new(DataType::Utf8)];
2470 let converter = RowConverter::new(key_sf).unwrap();
2471 let eur_col = Arc::new(StringArray::from(vec!["EUR"]));
2472 let eur_rows = converter.convert_columns(&[eur_col]).unwrap();
2473 let eur_key = eur_rows.row(0);
2474
2475 let result = index.probe_at_time(eur_key.as_ref(), 50);
2477 assert!(result.is_none());
2478 }
2479
2480 fn build_versioned_exec(
2482 table: RecordBatch,
2483 stream: &RecordBatch,
2484 join_type: LookupJoinType,
2485 ) -> VersionedLookupJoinExec {
2486 let input = batch_exec(stream.clone());
2487 let index = Arc::new(VersionedIndex::build(&table, &[0], 1).unwrap());
2488 let key_sort_fields = vec![SortField::new(DataType::Utf8)];
2489 let mut output_fields = stream.schema().fields().to_vec();
2490 output_fields.extend(table.schema().fields().iter().cloned());
2491 let output_schema = Arc::new(Schema::new(output_fields));
2492 VersionedLookupJoinExec::try_new(
2493 input,
2494 table,
2495 index,
2496 vec![1], 2, join_type,
2499 output_schema,
2500 key_sort_fields,
2501 )
2502 .unwrap()
2503 }
2504
2505 #[tokio::test]
2506 async fn test_versioned_join_exec_inner() {
2507 let table = versioned_table_batch();
2508 let stream = stream_batch_with_time();
2509 let exec = build_versioned_exec(table, &stream, LookupJoinType::Inner);
2510
2511 let ctx = Arc::new(TaskContext::default());
2512 let stream_out = exec.execute(0, ctx).unwrap();
2513 let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2514
2515 assert_eq!(batches.len(), 1);
2516 let batch = &batches[0];
2517 assert_eq!(batch.num_rows(), 3);
2522
2523 let rates = batch
2524 .column(5) .as_any()
2526 .downcast_ref::<Float64Array>()
2527 .unwrap();
2528 assert!((rates.value(0) - 1.0).abs() < f64::EPSILON); assert!((rates.value(1) - 0.90).abs() < f64::EPSILON); assert!((rates.value(2) - 1.1).abs() < f64::EPSILON); }
2532
2533 #[tokio::test]
2534 async fn test_versioned_join_exec_left_outer() {
2535 let table = versioned_table_batch();
2536 let stream = stream_batch_with_time();
2537 let exec = build_versioned_exec(table, &stream, LookupJoinType::LeftOuter);
2538
2539 let ctx = Arc::new(TaskContext::default());
2540 let stream_out = exec.execute(0, ctx).unwrap();
2541 let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2542
2543 assert_eq!(batches.len(), 1);
2544 let batch = &batches[0];
2545 assert_eq!(batch.num_rows(), 4);
2547
2548 let rates = batch
2550 .column(5)
2551 .as_any()
2552 .downcast_ref::<Float64Array>()
2553 .unwrap();
2554 assert!(rates.is_null(3), "EUR@50 should have null rate");
2555 }
2556
2557 #[test]
2558 fn test_versioned_index_empty_batch() {
2559 let schema = Arc::new(Schema::new(vec![
2560 Field::new("k", DataType::Utf8, false),
2561 Field::new("v", DataType::Int64, false),
2562 ]));
2563 let batch = RecordBatch::new_empty(schema);
2564 let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2565 assert!(index.map.is_empty());
2566 }
2567
2568 #[test]
2569 fn test_versioned_lookup_registry() {
2570 let registry = LookupTableRegistry::new();
2571 let table = versioned_table_batch();
2572 let index = Arc::new(VersionedIndex::build(&table, &[0], 1).unwrap());
2573
2574 registry.register_versioned(
2575 "rates",
2576 VersionedLookupState {
2577 batch: table,
2578 index,
2579 key_columns: vec!["currency".to_string()],
2580 version_column: "valid_from".to_string(),
2581 stream_time_column: "event_ts".to_string(),
2582 },
2583 );
2584
2585 let entry = registry.get_entry("rates");
2586 assert!(entry.is_some());
2587 assert!(matches!(entry.unwrap(), RegisteredLookup::Versioned(_)));
2588
2589 assert!(registry.get("rates").is_none());
2591 }
2592}