Skip to main content

kyu_executor/operators/
hash_join.rs

1//! Hash join operator — build hash table from left side, probe from right.
2
3use hashbrown::HashMap;
4use kyu_common::KyuResult;
5use kyu_expression::{evaluate, BoundExpression};
6use kyu_types::TypedValue;
7
8use crate::context::ExecutionContext;
9use crate::data_chunk::DataChunk;
10use crate::physical_plan::PhysicalOperator;
11
12pub struct HashJoinOp {
13    pub build: Box<PhysicalOperator>,
14    pub probe: Box<PhysicalOperator>,
15    pub build_keys: Vec<BoundExpression>,
16    pub probe_keys: Vec<BoundExpression>,
17    /// Build-side data: stored chunks + index into them.
18    build_data: Option<BuildData>,
19}
20
21struct BuildData {
22    chunks: Vec<DataChunk>,
23    /// key values → list of (chunk_idx, row_idx) pairs.
24    ht: HashMap<Vec<TypedValue>, Vec<(u32, u32)>>,
25    num_cols: usize,
26}
27
28impl HashJoinOp {
29    pub fn new(
30        build: PhysicalOperator,
31        probe: PhysicalOperator,
32        build_keys: Vec<BoundExpression>,
33        probe_keys: Vec<BoundExpression>,
34    ) -> Self {
35        Self {
36            build: Box::new(build),
37            probe: Box::new(probe),
38            build_keys,
39            probe_keys,
40            build_data: None,
41        }
42    }
43
44    pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
45        // Build phase: drain build side on first call, store chunks + index.
46        if self.build_data.is_none() {
47            let mut chunks = Vec::new();
48            let mut ht: HashMap<Vec<TypedValue>, Vec<(u32, u32)>> = HashMap::new();
49            while let Some(chunk) = self.build.next(ctx)? {
50                let ci = chunks.len() as u32;
51                for row_idx in 0..chunk.num_rows() {
52                    let row_ref = chunk.row_ref(row_idx);
53                    let key = eval_keys(&self.build_keys, &row_ref)?;
54                    ht.entry(key).or_default().push((ci, row_idx as u32));
55                }
56                chunks.push(chunk);
57            }
58            let num_cols = chunks.first().map_or(0, |c| c.num_columns());
59            self.build_data = Some(BuildData {
60                chunks,
61                ht,
62                num_cols,
63            });
64        }
65
66        let bd = self.build_data.as_ref().unwrap();
67
68        // Probe phase: pull from probe side.
69        loop {
70            let chunk = match self.probe.next(ctx)? {
71                Some(c) => c,
72                None => return Ok(None),
73            };
74
75            let probe_ncols = chunk.num_columns();
76            let total_cols = bd.num_cols + probe_ncols;
77            let mut result = DataChunk::with_capacity(total_cols, chunk.num_rows());
78
79            for row_idx in 0..chunk.num_rows() {
80                let row_ref = chunk.row_ref(row_idx);
81                let key = eval_keys(&self.probe_keys, &row_ref)?;
82                if let Some(build_locs) = bd.ht.get(&key) {
83                    for &(ci, ri) in build_locs {
84                        let build_chunk = &bd.chunks[ci as usize];
85                        let mut combined = build_chunk.get_row(ri as usize);
86                        for col_idx in 0..probe_ncols {
87                            combined.push(chunk.get_value(row_idx, col_idx));
88                        }
89                        result.append_row(&combined);
90                    }
91                }
92            }
93
94            if !result.is_empty() {
95                return Ok(Some(result));
96            }
97            // No matches for this probe chunk, try next.
98        }
99    }
100}
101
102fn eval_keys<T: kyu_expression::Tuple + ?Sized>(
103    keys: &[BoundExpression],
104    tuple: &T,
105) -> KyuResult<Vec<TypedValue>> {
106    keys.iter().map(|k| evaluate(k, tuple)).collect()
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Scan operator over mock table `table`.
    fn scan(table: kyu_common::id::TableId) -> PhysicalOperator {
        PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(table))
    }

    /// Join key: the `Int64` value in column 0.
    fn first_column_key() -> BoundExpression {
        BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        }
    }

    /// Hash-joins mock table 0 (`build_rows`) with mock table 1 (`probe_rows`) on column 0 of
    /// each side and returns every chunk the operator produces until it reports exhaustion.
    fn run_join(
        build_rows: Vec<Vec<TypedValue>>,
        probe_rows: Vec<Vec<TypedValue>>,
    ) -> Vec<DataChunk> {
        let mut storage = MockStorage::new();
        storage.insert_table(kyu_common::id::TableId(0), build_rows);
        storage.insert_table(kyu_common::id::TableId(1), probe_rows);
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let mut join = HashJoinOp::new(
            scan(kyu_common::id::TableId(0)),
            scan(kyu_common::id::TableId(1)),
            vec![first_column_key()],
            vec![first_column_key()],
        );
        let mut chunks = Vec::new();
        while let Some(chunk) = join.next(&ctx).unwrap() {
            chunks.push(chunk);
        }
        chunks
    }

    /// Total number of rows across `chunks`.
    fn total_rows(chunks: &[DataChunk]) -> usize {
        chunks.iter().map(|c| c.num_rows()).sum()
    }

    /// Asserts each output row's build key (column 0) equals its probe key (column 2).
    /// Assumes a two-column build side, as in every test below.
    fn assert_keys_match(chunk: &DataChunk) {
        for row in 0..chunk.num_rows() {
            assert!(
                chunk.get_value(row, 0) == chunk.get_value(row, 2),
                "row {row}: build and probe join keys differ"
            );
        }
    }

    #[test]
    fn hash_join_basic() {
        let chunks = run_join(
            // Build: id, name
            vec![
                vec![TypedValue::Int64(1), TypedValue::String(SmolStr::new("Alice"))],
                vec![TypedValue::Int64(2), TypedValue::String(SmolStr::new("Bob"))],
            ],
            // Probe: id, score
            vec![
                vec![TypedValue::Int64(1), TypedValue::Int64(100)],
                vec![TypedValue::Int64(2), TypedValue::Int64(200)],
                vec![TypedValue::Int64(3), TypedValue::Int64(300)],
            ],
        );
        // id=1 and id=2 match → one chunk of 2 rows, each with 4 columns; then exhaustion.
        assert_eq!(chunks.len(), 1);
        let chunk = &chunks[0];
        assert_eq!(chunk.num_rows(), 2);
        assert_eq!(chunk.num_columns(), 4);
        assert_keys_match(chunk);
    }

    #[test]
    fn hash_join_duplicate_build_keys_fan_out() {
        let chunks = run_join(
            vec![
                vec![TypedValue::Int64(1), TypedValue::String(SmolStr::new("Alice"))],
                vec![TypedValue::Int64(1), TypedValue::String(SmolStr::new("Alicia"))],
                vec![TypedValue::Int64(2), TypedValue::String(SmolStr::new("Bob"))],
            ],
            vec![vec![TypedValue::Int64(1), TypedValue::Int64(100)]],
        );
        // The single probe row matches both build rows with id=1.
        assert_eq!(total_rows(&chunks), 2);
        for chunk in &chunks {
            assert_keys_match(chunk);
        }
    }

    #[test]
    fn hash_join_without_matches_yields_nothing() {
        let chunks = run_join(
            vec![vec![TypedValue::Int64(1), TypedValue::String(SmolStr::new("Alice"))]],
            vec![
                vec![TypedValue::Int64(3), TypedValue::Int64(300)],
                vec![TypedValue::Int64(4), TypedValue::Int64(400)],
            ],
        );
        assert!(chunks.is_empty());
    }
}