Skip to main content

reifydb_engine/flow/compiler/operator/
join.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::{
5	common::JoinType::{self, Inner, Left},
6	interface::catalog::flow::FlowNodeId,
7};
8use reifydb_rql::{
9	expression::Expression,
10	flow::node::FlowNodeType,
11	nodes::{JoinInnerNode, JoinLeftNode},
12	query::QueryPlan,
13};
14use reifydb_transaction::transaction::admin::AdminTransaction;
15use reifydb_type::Result;
16
17use crate::flow::compiler::{CompileOperator, FlowCompiler};
18
19pub(crate) struct JoinCompiler {
20	pub join_type: JoinType,
21	pub left: Box<QueryPlan>,
22	pub right: Box<QueryPlan>,
23	pub on: Vec<Expression>,
24	pub alias: Option<String>,
25}
26
27impl From<JoinInnerNode> for JoinCompiler {
28	fn from(node: JoinInnerNode) -> Self {
29		Self {
30			join_type: Inner,
31			left: node.left,
32			right: node.right,
33			on: node.on,
34			alias: node.alias.map(|f| f.text().to_string()),
35		}
36	}
37}
38
39impl From<JoinLeftNode> for JoinCompiler {
40	fn from(node: JoinLeftNode) -> Self {
41		Self {
42			join_type: Left,
43			left: node.left,
44			right: node.right,
45			on: node.on,
46			alias: node.alias.map(|f| f.text().to_string()),
47		}
48	}
49}
50
51// Extract the source name from a query plan if it's a scan node
52fn extract_source_name(plan: &QueryPlan) -> Option<String> {
53	match plan {
54		QueryPlan::TableScan(node) => Some(node.source.def().name.clone()),
55		QueryPlan::ViewScan(node) => Some(node.source.def().name.clone()),
56		QueryPlan::RingBufferScan(node) => Some(node.source.def().name.clone()),
57		QueryPlan::DictionaryScan(node) => Some(node.source.def().name.clone()),
58		// For other node types, try to recursively find the source
59		QueryPlan::Filter(node) => extract_source_name(&node.input),
60		QueryPlan::Map(node) => node.input.as_ref().and_then(|p| extract_source_name(p)),
61		QueryPlan::Take(node) => extract_source_name(&node.input),
62		_ => None,
63	}
64}
65
66/// Recursively collect all Equal leaves from an And tree.
67fn collect_equal_conditions(expr: &Expression, out: &mut Vec<Expression>) {
68	match expr {
69		Expression::And(and) => {
70			collect_equal_conditions(&and.left, out);
71			collect_equal_conditions(&and.right, out);
72		}
73		other => out.push(other.clone()),
74	}
75}
76
77/// Extract left and right key expressions from join conditions.
78/// Handles multi-column joins where conditions are combined with And.
79fn extract_join_keys(conditions: &[Expression]) -> (Vec<Expression>, Vec<Expression>) {
80	let mut left_keys = Vec::new();
81	let mut right_keys = Vec::new();
82
83	// Flatten any And trees into individual conditions
84	let mut flat = Vec::new();
85	for condition in conditions {
86		collect_equal_conditions(condition, &mut flat);
87	}
88
89	for condition in flat {
90		match condition {
91			Expression::Equal(eq) => {
92				left_keys.push(*eq.left.clone());
93				right_keys.push(*eq.right.clone());
94			}
95			_ => {
96				// Non-equality condition: pass through to both sides (existing fallback)
97				left_keys.push(condition.clone());
98				right_keys.push(condition.clone());
99			}
100		}
101	}
102
103	(left_keys, right_keys)
104}
105
106impl CompileOperator for JoinCompiler {
107	fn compile(self, compiler: &mut FlowCompiler, txn: &mut AdminTransaction) -> Result<FlowNodeId> {
108		// Extract source name from right plan for fallback alias
109		let source_name = extract_source_name(&self.right);
110
111		let left_node = compiler.compile_plan(txn, *self.left)?;
112		let right_node = compiler.compile_plan(txn, *self.right)?;
113
114		let (left_keys, right_keys) = extract_join_keys(&self.on);
115
116		// Use explicit alias, or fall back to extracted source name, or use "other"
117		let effective_alias = self.alias.or(source_name).or_else(|| Some("other".to_string()));
118
119		let node_id = compiler.add_node(
120			txn,
121			FlowNodeType::Join {
122				join_type: self.join_type,
123				left: left_keys,
124				right: right_keys,
125				alias: effective_alias,
126			},
127		)?;
128
129		compiler.add_edge(txn, &left_node, &node_id)?;
130		compiler.add_edge(txn, &right_node, &node_id)?;
131
132		Ok(node_id)
133	}
134}