kyu_executor/operators/
hash_join.rs1use hashbrown::HashMap;
4use kyu_common::KyuResult;
5use kyu_expression::{BoundExpression, evaluate};
6use kyu_types::TypedValue;
7
8use crate::context::ExecutionContext;
9use crate::data_chunk::DataChunk;
10use crate::physical_plan::PhysicalOperator;
11
12pub struct HashJoinOp {
13 pub build: Box<PhysicalOperator>,
14 pub probe: Box<PhysicalOperator>,
15 pub build_keys: Vec<BoundExpression>,
16 pub probe_keys: Vec<BoundExpression>,
17 build_data: Option<BuildData>,
19}
20
21struct BuildData {
22 chunks: Vec<DataChunk>,
23 ht: HashMap<Vec<TypedValue>, Vec<(u32, u32)>>,
25 num_cols: usize,
26}
27
28impl HashJoinOp {
29 pub fn new(
30 build: PhysicalOperator,
31 probe: PhysicalOperator,
32 build_keys: Vec<BoundExpression>,
33 probe_keys: Vec<BoundExpression>,
34 ) -> Self {
35 Self {
36 build: Box::new(build),
37 probe: Box::new(probe),
38 build_keys,
39 probe_keys,
40 build_data: None,
41 }
42 }
43
44 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
45 if self.build_data.is_none() {
47 let mut chunks = Vec::new();
48 let mut ht: HashMap<Vec<TypedValue>, Vec<(u32, u32)>> = HashMap::new();
49 while let Some(chunk) = self.build.next(ctx)? {
50 let ci = chunks.len() as u32;
51 for row_idx in 0..chunk.num_rows() {
52 let row_ref = chunk.row_ref(row_idx);
53 let key = eval_keys(&self.build_keys, &row_ref)?;
54 ht.entry(key).or_default().push((ci, row_idx as u32));
55 }
56 chunks.push(chunk);
57 }
58 let num_cols = chunks.first().map_or(0, |c| c.num_columns());
59 self.build_data = Some(BuildData {
60 chunks,
61 ht,
62 num_cols,
63 });
64 }
65
66 let bd = self.build_data.as_ref().unwrap();
67
68 loop {
70 let chunk = match self.probe.next(ctx)? {
71 Some(c) => c,
72 None => return Ok(None),
73 };
74
75 let probe_ncols = chunk.num_columns();
76 let total_cols = bd.num_cols + probe_ncols;
77 let mut result = DataChunk::with_capacity(total_cols, chunk.num_rows());
78
79 for row_idx in 0..chunk.num_rows() {
80 let row_ref = chunk.row_ref(row_idx);
81 let key = eval_keys(&self.probe_keys, &row_ref)?;
82 if let Some(build_locs) = bd.ht.get(&key) {
83 for &(ci, ri) in build_locs {
84 let build_chunk = &bd.chunks[ci as usize];
85 let mut combined = build_chunk.get_row(ri as usize);
86 for col_idx in 0..probe_ncols {
87 combined.push(chunk.get_value(row_idx, col_idx));
88 }
89 result.append_row(&combined);
90 }
91 }
92 }
93
94 if !result.is_empty() {
95 return Ok(Some(result));
96 }
97 }
99 }
100}
101
102fn eval_keys<T: kyu_expression::Tuple + ?Sized>(
103 keys: &[BoundExpression],
104 tuple: &T,
105) -> KyuResult<Vec<TypedValue>> {
106 keys.iter().map(|k| evaluate(k, tuple)).collect()
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Builds a full-table scan over `table` in the mock storage.
    fn scan_table(table: kyu_common::id::TableId) -> PhysicalOperator {
        PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(table))
    }

    /// Key expression selecting column 0 as an `Int64`.
    fn first_column_int64() -> BoundExpression {
        BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::Int64,
        }
    }

    /// Joins table 0 (build) with table 1 (probe) on column 0 of each.
    fn join_on_first_column() -> HashJoinOp {
        HashJoinOp::new(
            scan_table(kyu_common::id::TableId(0)),
            scan_table(kyu_common::id::TableId(1)),
            vec![first_column_int64()],
            vec![first_column_int64()],
        )
    }

    #[test]
    fn hash_join_basic() {
        let mut storage = MockStorage::new();
        storage.insert_table(
            kyu_common::id::TableId(0),
            vec![
                vec![
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alice")),
                ],
                vec![
                    TypedValue::Int64(2),
                    TypedValue::String(SmolStr::new("Bob")),
                ],
            ],
        );
        storage.insert_table(
            kyu_common::id::TableId(1),
            vec![
                vec![TypedValue::Int64(1), TypedValue::Int64(100)],
                vec![TypedValue::Int64(2), TypedValue::Int64(200)],
                vec![TypedValue::Int64(3), TypedValue::Int64(300)],
            ],
        );
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let mut join = join_on_first_column();
        let chunk = join.next(&ctx).unwrap().unwrap();
        // Probe keys 1 and 2 match; key 3 has no build partner.
        assert_eq!(chunk.num_rows(), 2);
        // Two build columns followed by two probe columns.
        assert_eq!(chunk.num_columns(), 4);
        assert!(join.next(&ctx).unwrap().is_none());
    }

    #[test]
    fn hash_join_duplicate_build_keys_fan_out() {
        let mut storage = MockStorage::new();
        storage.insert_table(
            kyu_common::id::TableId(0),
            vec![
                vec![
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alice")),
                ],
                vec![
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alicia")),
                ],
                vec![
                    TypedValue::Int64(2),
                    TypedValue::String(SmolStr::new("Bob")),
                ],
            ],
        );
        storage.insert_table(
            kyu_common::id::TableId(1),
            vec![
                vec![TypedValue::Int64(1), TypedValue::Int64(100)],
                vec![TypedValue::Int64(3), TypedValue::Int64(300)],
            ],
        );
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let mut join = join_on_first_column();
        let chunk = join.next(&ctx).unwrap().unwrap();
        // The single probe row with key 1 matches both build rows carrying key 1.
        assert_eq!(chunk.num_rows(), 2);
        assert_eq!(chunk.num_columns(), 4);
        assert!(join.next(&ctx).unwrap().is_none());
    }

    #[test]
    fn hash_join_no_matching_keys() {
        let mut storage = MockStorage::new();
        storage.insert_table(
            kyu_common::id::TableId(0),
            vec![
                vec![
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alice")),
                ],
                vec![
                    TypedValue::Int64(2),
                    TypedValue::String(SmolStr::new("Bob")),
                ],
            ],
        );
        storage.insert_table(
            kyu_common::id::TableId(1),
            vec![
                vec![TypedValue::Int64(3), TypedValue::Int64(300)],
                vec![TypedValue::Int64(4), TypedValue::Int64(400)],
            ],
        );
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let mut join = join_on_first_column();
        // No probe key appears on the build side, so the join is empty.
        assert!(join.next(&ctx).unwrap().is_none());
    }
}