kyu_executor/operators/
hash_join.rs1use hashbrown::HashMap;
4use kyu_common::KyuResult;
5use kyu_expression::{evaluate, BoundExpression};
6use kyu_types::TypedValue;
7
8use crate::context::ExecutionContext;
9use crate::data_chunk::DataChunk;
10use crate::physical_plan::PhysicalOperator;
11
12pub struct HashJoinOp {
13 pub build: Box<PhysicalOperator>,
14 pub probe: Box<PhysicalOperator>,
15 pub build_keys: Vec<BoundExpression>,
16 pub probe_keys: Vec<BoundExpression>,
17 build_data: Option<BuildData>,
19}
20
21struct BuildData {
22 chunks: Vec<DataChunk>,
23 ht: HashMap<Vec<TypedValue>, Vec<(u32, u32)>>,
25 num_cols: usize,
26}
27
28impl HashJoinOp {
29 pub fn new(
30 build: PhysicalOperator,
31 probe: PhysicalOperator,
32 build_keys: Vec<BoundExpression>,
33 probe_keys: Vec<BoundExpression>,
34 ) -> Self {
35 Self {
36 build: Box::new(build),
37 probe: Box::new(probe),
38 build_keys,
39 probe_keys,
40 build_data: None,
41 }
42 }
43
44 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
45 if self.build_data.is_none() {
47 let mut chunks = Vec::new();
48 let mut ht: HashMap<Vec<TypedValue>, Vec<(u32, u32)>> = HashMap::new();
49 while let Some(chunk) = self.build.next(ctx)? {
50 let ci = chunks.len() as u32;
51 for row_idx in 0..chunk.num_rows() {
52 let row_ref = chunk.row_ref(row_idx);
53 let key = eval_keys(&self.build_keys, &row_ref)?;
54 ht.entry(key).or_default().push((ci, row_idx as u32));
55 }
56 chunks.push(chunk);
57 }
58 let num_cols = chunks.first().map_or(0, |c| c.num_columns());
59 self.build_data = Some(BuildData {
60 chunks,
61 ht,
62 num_cols,
63 });
64 }
65
66 let bd = self.build_data.as_ref().unwrap();
67
68 loop {
70 let chunk = match self.probe.next(ctx)? {
71 Some(c) => c,
72 None => return Ok(None),
73 };
74
75 let probe_ncols = chunk.num_columns();
76 let total_cols = bd.num_cols + probe_ncols;
77 let mut result = DataChunk::with_capacity(total_cols, chunk.num_rows());
78
79 for row_idx in 0..chunk.num_rows() {
80 let row_ref = chunk.row_ref(row_idx);
81 let key = eval_keys(&self.probe_keys, &row_ref)?;
82 if let Some(build_locs) = bd.ht.get(&key) {
83 for &(ci, ri) in build_locs {
84 let build_chunk = &bd.chunks[ci as usize];
85 let mut combined = build_chunk.get_row(ri as usize);
86 for col_idx in 0..probe_ncols {
87 combined.push(chunk.get_value(row_idx, col_idx));
88 }
89 result.append_row(&combined);
90 }
91 }
92 }
93
94 if !result.is_empty() {
95 return Ok(Some(result));
96 }
97 }
99 }
100}
101
102fn eval_keys<T: kyu_expression::Tuple + ?Sized>(
103 keys: &[BoundExpression],
104 tuple: &T,
105) -> KyuResult<Vec<TypedValue>> {
106 keys.iter().map(|k| evaluate(k, tuple)).collect()
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Key expression reading column 0 as an `Int64`.
    fn first_column_key() -> BoundExpression {
        BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        }
    }

    /// Scan operator over the table with the given id.
    fn scan(table: kyu_common::id::TableId) -> PhysicalOperator {
        PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(table))
    }

    #[test]
    fn hash_join_basic() {
        // Build side: two rows keyed 1 and 2.
        let build_rows = vec![
            vec![TypedValue::Int64(1), TypedValue::String(SmolStr::new("Alice"))],
            vec![TypedValue::Int64(2), TypedValue::String(SmolStr::new("Bob"))],
        ];
        // Probe side: keys 1 and 2 have a partner, key 3 does not.
        let probe_rows = vec![
            vec![TypedValue::Int64(1), TypedValue::Int64(100)],
            vec![TypedValue::Int64(2), TypedValue::Int64(200)],
            vec![TypedValue::Int64(3), TypedValue::Int64(300)],
        ];

        let mut storage = MockStorage::new();
        storage.insert_table(kyu_common::id::TableId(0), build_rows);
        storage.insert_table(kyu_common::id::TableId(1), probe_rows);
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let mut join = HashJoinOp::new(
            scan(kyu_common::id::TableId(0)),
            scan(kyu_common::id::TableId(1)),
            vec![first_column_key()],
            vec![first_column_key()],
        );

        let joined = join.next(&ctx).unwrap().unwrap();
        // Two matching keys, each row carrying 2 build + 2 probe columns.
        assert_eq!(joined.num_rows(), 2);
        assert_eq!(joined.num_columns(), 4);
        // The probe side is exhausted after its single chunk.
        assert!(join.next(&ctx).unwrap().is_none());
    }
}