1use std::collections::HashSet;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10use std::sync::Arc;
11
12use datafusion::common::DFSchemaRef;
13use datafusion::logical_expr::logical_plan::LogicalPlan;
14use datafusion::logical_expr::{Expr, UserDefinedLogicalNodeCore};
15use datafusion_common::Result;
16
/// Join semantics supported by a lookup join.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum LookupJoinType {
    /// Inner join, rendered as `Inner`.
    Inner,
    /// Left outer join, rendered as `LeftOuter`.
    LeftOuter,
}

impl fmt::Display for LookupJoinType {
    /// Writes the variant name (`Inner` / `LeftOuter`).
    ///
    /// Uses [`fmt::Formatter::pad`] so that width, fill, alignment and
    /// precision flags (e.g. `{:>10}`) are honoured; plain `{}` output is
    /// exactly the variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Inner => "Inner",
            Self::LeftOuter => "LeftOuter",
        };
        f.pad(label)
    }
}
34
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub struct JoinKeyPair {
38 pub stream_expr: Expr,
40 pub lookup_column: String,
42}
43
/// Descriptive metadata attached to a lookup table.
///
/// All values are carried as plain strings. The derived ordering and hash
/// follow the field declaration order, so the fields must not be reordered.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LookupTableMetadata {
    /// Connector identifier for the lookup table.
    pub connector: String,
    /// Lookup strategy identifier.
    pub strategy: String,
    /// Predicate pushdown mode identifier.
    pub pushdown_mode: String,
    /// Names of the columns forming the lookup table's primary key.
    pub primary_key: Vec<String>,
}
56
57#[derive(Debug, Clone)]
63pub struct LookupJoinNode {
64 input: Arc<LogicalPlan>,
66 lookup_table: String,
68 lookup_schema: DFSchemaRef,
70 join_keys: Vec<JoinKeyPair>,
72 join_type: LookupJoinType,
74 pushdown_predicates: Vec<Expr>,
76 local_predicates: Vec<Expr>,
78 required_lookup_columns: HashSet<String>,
80 output_schema: DFSchemaRef,
82 metadata: LookupTableMetadata,
84 lookup_alias: Option<String>,
86 stream_alias: Option<String>,
88}
89
90impl PartialEq for LookupJoinNode {
91 fn eq(&self, other: &Self) -> bool {
92 self.lookup_table == other.lookup_table
93 && self.join_keys == other.join_keys
94 && self.join_type == other.join_type
95 && self.pushdown_predicates == other.pushdown_predicates
96 && self.local_predicates == other.local_predicates
97 && self.required_lookup_columns == other.required_lookup_columns
98 && self.metadata == other.metadata
99 }
100}
101
102impl Eq for LookupJoinNode {}
103
104impl Hash for LookupJoinNode {
105 fn hash<H: Hasher>(&self, state: &mut H) {
106 self.lookup_table.hash(state);
107 self.join_keys.hash(state);
108 self.join_type.hash(state);
109 self.pushdown_predicates.hash(state);
110 self.local_predicates.hash(state);
111 self.metadata.hash(state);
112 let mut cols: Vec<&String> = self.required_lookup_columns.iter().collect();
114 cols.sort();
115 cols.hash(state);
116 }
117}
118
119impl PartialOrd for LookupJoinNode {
120 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
121 self.lookup_table.partial_cmp(&other.lookup_table)
122 }
123}
124
125impl LookupJoinNode {
126 #[must_use]
128 #[allow(clippy::too_many_arguments)]
129 pub fn new(
130 input: LogicalPlan,
131 lookup_table: String,
132 lookup_schema: DFSchemaRef,
133 join_keys: Vec<JoinKeyPair>,
134 join_type: LookupJoinType,
135 pushdown_predicates: Vec<Expr>,
136 required_lookup_columns: HashSet<String>,
137 output_schema: DFSchemaRef,
138 metadata: LookupTableMetadata,
139 ) -> Self {
140 Self {
141 input: Arc::new(input),
142 lookup_table,
143 lookup_schema,
144 join_keys,
145 join_type,
146 pushdown_predicates,
147 local_predicates: vec![],
148 required_lookup_columns,
149 output_schema,
150 metadata,
151 lookup_alias: None,
152 stream_alias: None,
153 }
154 }
155
156 #[must_use]
158 pub fn with_local_predicates(mut self, predicates: Vec<Expr>) -> Self {
159 self.local_predicates = predicates;
160 self
161 }
162
163 #[must_use]
165 pub fn with_aliases(
166 mut self,
167 lookup_alias: Option<String>,
168 stream_alias: Option<String>,
169 ) -> Self {
170 self.lookup_alias = lookup_alias;
171 self.stream_alias = stream_alias;
172 self
173 }
174
175 #[must_use]
177 pub fn lookup_table_name(&self) -> &str {
178 &self.lookup_table
179 }
180
181 #[must_use]
183 pub fn join_keys(&self) -> &[JoinKeyPair] {
184 &self.join_keys
185 }
186
187 #[must_use]
189 pub fn join_type(&self) -> LookupJoinType {
190 self.join_type
191 }
192
193 #[must_use]
195 pub fn pushdown_predicates(&self) -> &[Expr] {
196 &self.pushdown_predicates
197 }
198
199 #[must_use]
201 pub fn required_lookup_columns(&self) -> &HashSet<String> {
202 &self.required_lookup_columns
203 }
204
205 #[must_use]
207 pub fn metadata(&self) -> &LookupTableMetadata {
208 &self.metadata
209 }
210
211 #[must_use]
213 pub fn lookup_schema(&self) -> &DFSchemaRef {
214 &self.lookup_schema
215 }
216
217 #[must_use]
219 pub fn local_predicates(&self) -> &[Expr] {
220 &self.local_predicates
221 }
222
223 #[must_use]
225 pub fn lookup_alias(&self) -> Option<&str> {
226 self.lookup_alias.as_deref()
227 }
228
229 #[must_use]
231 pub fn stream_alias(&self) -> Option<&str> {
232 self.stream_alias.as_deref()
233 }
234}
235
236impl UserDefinedLogicalNodeCore for LookupJoinNode {
237 fn name(&self) -> &'static str {
238 "LookupJoin"
239 }
240
241 fn inputs(&self) -> Vec<&LogicalPlan> {
242 vec![&self.input]
243 }
244
245 fn schema(&self) -> &DFSchemaRef {
246 &self.output_schema
247 }
248
249 fn expressions(&self) -> Vec<Expr> {
250 self.join_keys
251 .iter()
252 .map(|k| k.stream_expr.clone())
253 .chain(self.pushdown_predicates.clone())
254 .chain(self.local_predicates.clone())
255 .collect()
256 }
257
258 fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
259 let keys: Vec<String> = self
260 .join_keys
261 .iter()
262 .map(|k| format!("{}={}", k.stream_expr, k.lookup_column))
263 .collect();
264 write!(
265 f,
266 "LookupJoin: table={}, keys=[{}], type={}, pushdown={}, local={}",
267 self.lookup_table,
268 keys.join(", "),
269 self.join_type,
270 self.pushdown_predicates.len(),
271 self.local_predicates.len(),
272 )
273 }
274
275 fn with_exprs_and_inputs(
276 &self,
277 exprs: Vec<Expr>,
278 mut inputs: Vec<LogicalPlan>,
279 ) -> Result<Self> {
280 let input = inputs.swap_remove(0);
281
282 let num_keys = self.join_keys.len();
284 let num_pushdown = self.pushdown_predicates.len();
285 let (key_exprs, rest) = exprs.split_at(num_keys.min(exprs.len()));
286 let (pushdown_exprs, local_exprs) = rest.split_at(num_pushdown.min(rest.len()));
287
288 let join_keys: Vec<JoinKeyPair> = key_exprs
289 .iter()
290 .zip(self.join_keys.iter())
291 .map(|(expr, old)| JoinKeyPair {
292 stream_expr: expr.clone(),
293 lookup_column: old.lookup_column.clone(),
294 })
295 .collect();
296
297 Ok(Self {
298 input: Arc::new(input),
299 lookup_table: self.lookup_table.clone(),
300 lookup_schema: Arc::clone(&self.lookup_schema),
301 join_keys,
302 join_type: self.join_type,
303 pushdown_predicates: pushdown_exprs.to_vec(),
304 local_predicates: local_exprs.to_vec(),
305 required_lookup_columns: self.required_lookup_columns.clone(),
306 output_schema: Arc::clone(&self.output_schema),
307 metadata: self.metadata.clone(),
308 lookup_alias: self.lookup_alias.clone(),
309 stream_alias: self.stream_alias.clone(),
310 })
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Write;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::col;

    /// Adapter that exposes `fmt_for_explain` through `Display` so the
    /// explain text can be captured into a `String`.
    struct DisplayExplain<'a>(&'a LookupJoinNode);

    impl fmt::Display for DisplayExplain<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            UserDefinedLogicalNodeCore::fmt_for_explain(self.0, f)
        }
    }

    /// Wraps a list of Arrow fields into a shared DataFusion schema.
    fn schema_of(fields: Vec<Field>) -> DFSchemaRef {
        Arc::new(DFSchema::try_from(Schema::new(fields)).unwrap())
    }

    /// Columns of the stream (input) side.
    fn stream_fields() -> Vec<Field> {
        vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]
    }

    /// Columns of the lookup table.
    fn lookup_fields() -> Vec<Field> {
        vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("region", DataType::Utf8, true),
        ]
    }

    fn test_stream_schema() -> DFSchemaRef {
        schema_of(stream_fields())
    }

    fn test_lookup_schema() -> DFSchemaRef {
        schema_of(lookup_fields())
    }

    /// Stream columns followed by lookup columns.
    fn test_output_schema() -> DFSchemaRef {
        let mut fields = stream_fields();
        fields.extend(lookup_fields());
        schema_of(fields)
    }

    fn test_metadata() -> LookupTableMetadata {
        LookupTableMetadata {
            connector: "postgres-cdc".to_string(),
            strategy: "replicated".to_string(),
            pushdown_mode: "auto".to_string(),
            primary_key: vec!["id".to_string()],
        }
    }

    /// Builds a `customers` lookup join on `customer_id = id` over an empty
    /// relation with the stream schema.
    fn build_node(
        join_type: LookupJoinType,
        required_lookup_columns: HashSet<String>,
    ) -> LookupJoinNode {
        let input = LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
            produce_one_row: false,
            schema: test_stream_schema(),
        });
        let key = JoinKeyPair {
            stream_expr: col("customer_id"),
            lookup_column: "id".to_string(),
        };

        LookupJoinNode::new(
            input,
            "customers".to_string(),
            test_lookup_schema(),
            vec![key],
            join_type,
            vec![],
            required_lookup_columns,
            test_output_schema(),
            test_metadata(),
        )
    }

    fn test_node() -> LookupJoinNode {
        build_node(
            LookupJoinType::Inner,
            HashSet::from(["name".to_string(), "region".to_string()]),
        )
    }

    #[test]
    fn test_name() {
        assert_eq!(test_node().name(), "LookupJoin");
    }

    #[test]
    fn test_inputs() {
        assert_eq!(test_node().inputs().len(), 1);
    }

    #[test]
    fn test_schema() {
        assert_eq!(test_node().schema().fields().len(), 6);
    }

    #[test]
    fn test_expressions() {
        // Only the single join-key expression; no predicates were supplied.
        assert_eq!(test_node().expressions().len(), 1);
    }

    #[test]
    fn test_fmt_for_explain() {
        let node = test_node();

        let debug_text = format!("{node:?}");
        assert!(debug_text.contains("LookupJoin"));

        let mut explain_text = String::new();
        write!(explain_text, "{}", DisplayExplain(&node)).unwrap();
        assert!(explain_text.contains("LookupJoin: table=customers"));
        assert!(explain_text.contains("type=Inner"));
    }

    #[test]
    fn test_with_exprs_and_inputs_roundtrip() {
        let node = test_node();
        let exprs = node.expressions();
        let inputs: Vec<LogicalPlan> = node.inputs().into_iter().cloned().collect();

        let rebuilt = node.with_exprs_and_inputs(exprs, inputs).unwrap();

        assert_eq!(rebuilt.lookup_table, "customers");
        assert_eq!(rebuilt.join_keys.len(), 1);
        assert_eq!(rebuilt.join_type, LookupJoinType::Inner);
    }

    #[test]
    fn test_left_outer_join() {
        let node = build_node(LookupJoinType::LeftOuter, HashSet::new());
        assert_eq!(node.join_type(), LookupJoinType::LeftOuter);
    }
}