Skip to main content

grafeo_core/execution/operators/
join.rs

1//! Join operators for combining data from two sources.
2//!
3//! This module provides:
4//! - `HashJoinOperator`: Efficient hash-based join for equality conditions
5//! - `NestedLoopJoinOperator`: General-purpose join for any condition
6
7use std::cmp::Ordering;
8use std::collections::HashMap;
9
10use arcstr::ArcStr;
11use grafeo_common::types::{LogicalType, Value};
12
13use super::{Operator, OperatorError, OperatorResult};
14use crate::execution::chunk::DataChunkBuilder;
15use crate::execution::{DataChunk, ValueVector};
16
/// Selects which combinations of left and right rows a join emits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JoinType {
    /// Emit only pairs of rows that match on both sides.
    Inner,
    /// Emit every left row; the right side is NULL when nothing matches.
    Left,
    /// Emit every right row; the left side is NULL when nothing matches.
    Right,
    /// Emit every row from both sides, NULL-padding whichever side has no match.
    Full,
    /// Emit the cartesian product of the two sides.
    Cross,
    /// Emit left rows that have at least one match on the right.
    Semi,
    /// Emit left rows that have no match on the right.
    Anti,
}
35
36/// A hash key that can be hashed and compared for join operations.
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum HashKey {
39    /// Null key.
40    Null,
41    /// Boolean key.
42    Bool(bool),
43    /// Integer key.
44    Int64(i64),
45    /// String key (cheap clone via ArcStr refcount).
46    String(ArcStr),
47    /// Byte content key.
48    Bytes(Vec<u8>),
49    /// Composite key for multi-column joins.
50    Composite(Vec<HashKey>),
51}
52
53impl Ord for HashKey {
54    fn cmp(&self, other: &Self) -> Ordering {
55        match (self, other) {
56            (HashKey::Null, HashKey::Null) => Ordering::Equal,
57            (HashKey::Null, _) => Ordering::Less,
58            (_, HashKey::Null) => Ordering::Greater,
59            (HashKey::Bool(a), HashKey::Bool(b)) => a.cmp(b),
60            (HashKey::Bool(_), _) => Ordering::Less,
61            (_, HashKey::Bool(_)) => Ordering::Greater,
62            (HashKey::Int64(a), HashKey::Int64(b)) => a.cmp(b),
63            (HashKey::Int64(_), _) => Ordering::Less,
64            (_, HashKey::Int64(_)) => Ordering::Greater,
65            (HashKey::String(a), HashKey::String(b)) => a.cmp(b),
66            (HashKey::String(_), _) => Ordering::Less,
67            (_, HashKey::String(_)) => Ordering::Greater,
68            (HashKey::Bytes(a), HashKey::Bytes(b)) => a.cmp(b),
69            (HashKey::Bytes(_), _) => Ordering::Less,
70            (_, HashKey::Bytes(_)) => Ordering::Greater,
71            (HashKey::Composite(a), HashKey::Composite(b)) => a.cmp(b),
72        }
73    }
74}
75
76impl PartialOrd for HashKey {
77    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
78        Some(self.cmp(other))
79    }
80}
81
82impl HashKey {
83    /// Creates a hash key from a Value.
84    pub fn from_value(value: &Value) -> Self {
85        match value {
86            Value::Null => HashKey::Null,
87            Value::Bool(b) => HashKey::Bool(*b),
88            Value::Int64(i) => HashKey::Int64(*i),
89            Value::Float64(f) => {
90                // Convert float to bits for consistent hashing
91                HashKey::Int64(f.to_bits() as i64)
92            }
93            Value::String(s) => HashKey::String(s.clone()),
94            Value::Bytes(b) => HashKey::Bytes(b.to_vec()),
95            Value::Timestamp(t) => HashKey::Int64(t.as_micros()),
96            Value::Date(d) => HashKey::Int64(d.as_days() as i64),
97            Value::Time(t) => HashKey::Int64(t.as_nanos() as i64),
98            Value::Duration(d) => HashKey::Composite(vec![
99                HashKey::Int64(d.months()),
100                HashKey::Int64(d.days()),
101                HashKey::Int64(d.nanos()),
102            ]),
103            Value::ZonedDatetime(zdt) => HashKey::Int64(zdt.as_timestamp().as_micros()),
104            Value::List(items) => {
105                HashKey::Composite(items.iter().map(HashKey::from_value).collect())
106            }
107            Value::Map(map) => {
108                let mut keys: Vec<_> = map
109                    .iter()
110                    .map(|(k, v)| {
111                        HashKey::Composite(vec![
112                            HashKey::String(ArcStr::from(k.as_str())),
113                            HashKey::from_value(v),
114                        ])
115                    })
116                    .collect();
117                keys.sort();
118                HashKey::Composite(keys)
119            }
120            Value::Vector(v) => {
121                // Hash vectors by converting each f32 to its bit representation
122                HashKey::Composite(
123                    v.iter()
124                        .map(|f| HashKey::Int64(f.to_bits() as i64))
125                        .collect(),
126                )
127            }
128            Value::Path { nodes, edges } => {
129                let mut parts: Vec<_> = nodes.iter().map(HashKey::from_value).collect();
130                parts.extend(edges.iter().map(HashKey::from_value));
131                HashKey::Composite(parts)
132            }
133        }
134    }
135
136    /// Creates a hash key from a column value at a given row.
137    pub fn from_column(column: &ValueVector, row: usize) -> Option<Self> {
138        column.get_value(row).map(|v| Self::from_value(&v))
139    }
140}
141
142/// Hash join operator.
143///
144/// Builds a hash table from the build side (right) and probes with the probe side (left).
145/// Efficient for equality joins on one or more columns.
146pub struct HashJoinOperator {
147    /// Left (probe) side operator.
148    probe_side: Box<dyn Operator>,
149    /// Right (build) side operator.
150    build_side: Box<dyn Operator>,
151    /// Column indices on the probe side for join keys.
152    probe_keys: Vec<usize>,
153    /// Column indices on the build side for join keys.
154    build_keys: Vec<usize>,
155    /// Join type.
156    join_type: JoinType,
157    /// Output schema (combined from both sides).
158    output_schema: Vec<LogicalType>,
159    /// Hash table: key -> list of (chunk_index, row_index).
160    hash_table: HashMap<HashKey, Vec<(usize, usize)>>,
161    /// Materialized build side chunks.
162    build_chunks: Vec<DataChunk>,
163    /// Whether the build phase is complete.
164    build_complete: bool,
165    /// Current probe chunk being processed.
166    current_probe_chunk: Option<DataChunk>,
167    /// Current row in the probe chunk.
168    current_probe_row: usize,
169    /// Current position in the hash table matches for the current probe row.
170    current_match_position: usize,
171    /// Current matches for the current probe row.
172    current_matches: Vec<(usize, usize)>,
173    /// For left/full outer joins: track which probe rows had matches.
174    probe_matched: Vec<bool>,
175    /// For right/full outer joins: track which build rows were matched.
176    build_matched: Vec<Vec<bool>>,
177    /// Whether we're in the emit unmatched phase (for outer joins).
178    emitting_unmatched: bool,
179    /// Current chunk index when emitting unmatched rows.
180    unmatched_chunk_idx: usize,
181    /// Current row index when emitting unmatched rows.
182    unmatched_row_idx: usize,
183}
184
185impl HashJoinOperator {
186    /// Creates a new hash join operator.
187    ///
188    /// # Arguments
189    /// * `probe_side` - Left side operator (will be probed).
190    /// * `build_side` - Right side operator (will build hash table).
191    /// * `probe_keys` - Column indices on probe side for join keys.
192    /// * `build_keys` - Column indices on build side for join keys.
193    /// * `join_type` - Type of join to perform.
194    /// * `output_schema` - Schema of the output (probe columns + build columns).
195    pub fn new(
196        probe_side: Box<dyn Operator>,
197        build_side: Box<dyn Operator>,
198        probe_keys: Vec<usize>,
199        build_keys: Vec<usize>,
200        join_type: JoinType,
201        output_schema: Vec<LogicalType>,
202    ) -> Self {
203        Self {
204            probe_side,
205            build_side,
206            probe_keys,
207            build_keys,
208            join_type,
209            output_schema,
210            hash_table: HashMap::new(),
211            build_chunks: Vec::new(),
212            build_complete: false,
213            current_probe_chunk: None,
214            current_probe_row: 0,
215            current_match_position: 0,
216            current_matches: Vec::new(),
217            probe_matched: Vec::new(),
218            build_matched: Vec::new(),
219            emitting_unmatched: false,
220            unmatched_chunk_idx: 0,
221            unmatched_row_idx: 0,
222        }
223    }
224
225    /// Builds the hash table from the build side.
226    fn build_hash_table(&mut self) -> Result<(), OperatorError> {
227        while let Some(chunk) = self.build_side.next()? {
228            let chunk_idx = self.build_chunks.len();
229
230            // Initialize match tracking for outer joins
231            if matches!(self.join_type, JoinType::Right | JoinType::Full) {
232                self.build_matched.push(vec![false; chunk.row_count()]);
233            }
234
235            // Add each row to the hash table
236            for row in chunk.selected_indices() {
237                let key = self.extract_key(&chunk, row, &self.build_keys)?;
238
239                // Skip null keys for inner/semi/anti joins
240                if matches!(key, HashKey::Null)
241                    && !matches!(
242                        self.join_type,
243                        JoinType::Left | JoinType::Right | JoinType::Full
244                    )
245                {
246                    continue;
247                }
248
249                self.hash_table
250                    .entry(key)
251                    .or_default()
252                    .push((chunk_idx, row));
253            }
254
255            self.build_chunks.push(chunk);
256        }
257
258        self.build_complete = true;
259        Ok(())
260    }
261
262    /// Extracts a hash key from a chunk row.
263    fn extract_key(
264        &self,
265        chunk: &DataChunk,
266        row: usize,
267        key_columns: &[usize],
268    ) -> Result<HashKey, OperatorError> {
269        if key_columns.len() == 1 {
270            let col = chunk.column(key_columns[0]).ok_or_else(|| {
271                OperatorError::ColumnNotFound(format!("column {}", key_columns[0]))
272            })?;
273            Ok(HashKey::from_column(col, row).unwrap_or(HashKey::Null))
274        } else {
275            let keys: Vec<HashKey> = key_columns
276                .iter()
277                .map(|&col_idx| {
278                    chunk
279                        .column(col_idx)
280                        .and_then(|col| HashKey::from_column(col, row))
281                        .unwrap_or(HashKey::Null)
282                })
283                .collect();
284            Ok(HashKey::Composite(keys))
285        }
286    }
287
288    /// Produces an output row from a probe row and build row.
289    fn produce_output_row(
290        &self,
291        builder: &mut DataChunkBuilder,
292        probe_chunk: &DataChunk,
293        probe_row: usize,
294        build_chunk: Option<&DataChunk>,
295        build_row: Option<usize>,
296    ) -> Result<(), OperatorError> {
297        let probe_col_count = probe_chunk.column_count();
298
299        // Copy probe side columns
300        for col_idx in 0..probe_col_count {
301            let src_col = probe_chunk
302                .column(col_idx)
303                .ok_or_else(|| OperatorError::ColumnNotFound(format!("probe column {col_idx}")))?;
304            let dst_col = builder
305                .column_mut(col_idx)
306                .ok_or_else(|| OperatorError::ColumnNotFound(format!("output column {col_idx}")))?;
307
308            if let Some(value) = src_col.get_value(probe_row) {
309                dst_col.push_value(value);
310            } else {
311                dst_col.push_value(Value::Null);
312            }
313        }
314
315        // Copy build side columns
316        match (build_chunk, build_row) {
317            (Some(chunk), Some(row)) => {
318                for col_idx in 0..chunk.column_count() {
319                    let src_col = chunk.column(col_idx).ok_or_else(|| {
320                        OperatorError::ColumnNotFound(format!("build column {col_idx}"))
321                    })?;
322                    let dst_col =
323                        builder
324                            .column_mut(probe_col_count + col_idx)
325                            .ok_or_else(|| {
326                                OperatorError::ColumnNotFound(format!(
327                                    "output column {}",
328                                    probe_col_count + col_idx
329                                ))
330                            })?;
331
332                    if let Some(value) = src_col.get_value(row) {
333                        dst_col.push_value(value);
334                    } else {
335                        dst_col.push_value(Value::Null);
336                    }
337                }
338            }
339            _ => {
340                // Emit nulls for build side (left outer join case)
341                if !self.build_chunks.is_empty() {
342                    let build_col_count = self.build_chunks[0].column_count();
343                    for col_idx in 0..build_col_count {
344                        let dst_col =
345                            builder
346                                .column_mut(probe_col_count + col_idx)
347                                .ok_or_else(|| {
348                                    OperatorError::ColumnNotFound(format!(
349                                        "output column {}",
350                                        probe_col_count + col_idx
351                                    ))
352                                })?;
353                        dst_col.push_value(Value::Null);
354                    }
355                }
356            }
357        }
358
359        builder.advance_row();
360        Ok(())
361    }
362
363    /// Gets the next probe chunk.
364    fn get_next_probe_chunk(&mut self) -> Result<bool, OperatorError> {
365        let chunk = self.probe_side.next()?;
366        if let Some(ref c) = chunk {
367            // Initialize match tracking for outer joins
368            if matches!(self.join_type, JoinType::Left | JoinType::Full) {
369                self.probe_matched = vec![false; c.row_count()];
370            }
371        }
372        let has_chunk = chunk.is_some();
373        self.current_probe_chunk = chunk;
374        self.current_probe_row = 0;
375        Ok(has_chunk)
376    }
377
378    /// Emits unmatched build rows for right/full outer joins.
379    fn emit_unmatched_build(&mut self) -> OperatorResult {
380        if self.build_matched.is_empty() {
381            return Ok(None);
382        }
383
384        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
385
386        // Determine probe column count from schema or first probe chunk
387        let probe_col_count = if !self.build_chunks.is_empty() {
388            self.output_schema.len() - self.build_chunks[0].column_count()
389        } else {
390            0
391        };
392
393        while self.unmatched_chunk_idx < self.build_chunks.len() {
394            let chunk = &self.build_chunks[self.unmatched_chunk_idx];
395            let matched = &self.build_matched[self.unmatched_chunk_idx];
396
397            while self.unmatched_row_idx < matched.len() {
398                if !matched[self.unmatched_row_idx] {
399                    // This row was not matched - emit with nulls on probe side
400
401                    // Emit nulls for probe side
402                    for col_idx in 0..probe_col_count {
403                        if let Some(dst_col) = builder.column_mut(col_idx) {
404                            dst_col.push_value(Value::Null);
405                        }
406                    }
407
408                    // Copy build side values
409                    for col_idx in 0..chunk.column_count() {
410                        if let (Some(src_col), Some(dst_col)) = (
411                            chunk.column(col_idx),
412                            builder.column_mut(probe_col_count + col_idx),
413                        ) {
414                            if let Some(value) = src_col.get_value(self.unmatched_row_idx) {
415                                dst_col.push_value(value);
416                            } else {
417                                dst_col.push_value(Value::Null);
418                            }
419                        }
420                    }
421
422                    builder.advance_row();
423
424                    if builder.is_full() {
425                        self.unmatched_row_idx += 1;
426                        return Ok(Some(builder.finish()));
427                    }
428                }
429
430                self.unmatched_row_idx += 1;
431            }
432
433            self.unmatched_chunk_idx += 1;
434            self.unmatched_row_idx = 0;
435        }
436
437        if builder.row_count() > 0 {
438            Ok(Some(builder.finish()))
439        } else {
440            Ok(None)
441        }
442    }
443}
444
445impl Operator for HashJoinOperator {
446    fn next(&mut self) -> OperatorResult {
447        // Phase 1: Build hash table
448        if !self.build_complete {
449            self.build_hash_table()?;
450        }
451
452        // Phase 3: Emit unmatched build rows (right/full outer join)
453        if self.emitting_unmatched {
454            return self.emit_unmatched_build();
455        }
456
457        // Phase 2: Probe
458        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
459
460        loop {
461            // Get current probe chunk or fetch new one
462            if self.current_probe_chunk.is_none() && !self.get_next_probe_chunk()? {
463                // No more probe data
464                if matches!(self.join_type, JoinType::Right | JoinType::Full) {
465                    self.emitting_unmatched = true;
466                    return self.emit_unmatched_build();
467                }
468                return if builder.row_count() > 0 {
469                    Ok(Some(builder.finish()))
470                } else {
471                    Ok(None)
472                };
473            }
474
475            // Invariant: current_probe_chunk is Some here - the guard at line 396 either
476            // populates it via get_next_probe_chunk() or returns from the function
477            let probe_chunk = self
478                .current_probe_chunk
479                .as_ref()
480                .expect("probe chunk is Some: guard at line 396 ensures this");
481            let probe_rows: Vec<usize> = probe_chunk.selected_indices().collect();
482
483            while self.current_probe_row < probe_rows.len() {
484                let probe_row = probe_rows[self.current_probe_row];
485
486                // If we don't have current matches, look them up
487                if self.current_matches.is_empty() && self.current_match_position == 0 {
488                    let key = self.extract_key(probe_chunk, probe_row, &self.probe_keys)?;
489
490                    // Handle semi/anti joins differently
491                    match self.join_type {
492                        JoinType::Semi => {
493                            if self.hash_table.contains_key(&key) {
494                                // Emit probe row only
495                                for col_idx in 0..probe_chunk.column_count() {
496                                    if let (Some(src_col), Some(dst_col)) =
497                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
498                                        && let Some(value) = src_col.get_value(probe_row)
499                                    {
500                                        dst_col.push_value(value);
501                                    }
502                                }
503                                builder.advance_row();
504                            }
505                            self.current_probe_row += 1;
506                            continue;
507                        }
508                        JoinType::Anti => {
509                            if !self.hash_table.contains_key(&key) {
510                                // Emit probe row only
511                                for col_idx in 0..probe_chunk.column_count() {
512                                    if let (Some(src_col), Some(dst_col)) =
513                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
514                                        && let Some(value) = src_col.get_value(probe_row)
515                                    {
516                                        dst_col.push_value(value);
517                                    }
518                                }
519                                builder.advance_row();
520                            }
521                            self.current_probe_row += 1;
522                            continue;
523                        }
524                        _ => {
525                            self.current_matches =
526                                self.hash_table.get(&key).cloned().unwrap_or_default();
527                        }
528                    }
529                }
530
531                // Process matches
532                if self.current_matches.is_empty() {
533                    // No matches - for left/full outer join, emit with nulls
534                    if matches!(self.join_type, JoinType::Left | JoinType::Full) {
535                        self.produce_output_row(&mut builder, probe_chunk, probe_row, None, None)?;
536                    }
537                    self.current_probe_row += 1;
538                    self.current_match_position = 0;
539                } else {
540                    // Process each match
541                    while self.current_match_position < self.current_matches.len() {
542                        let (build_chunk_idx, build_row) =
543                            self.current_matches[self.current_match_position];
544                        let build_chunk = &self.build_chunks[build_chunk_idx];
545
546                        // Mark as matched for outer joins
547                        if matches!(self.join_type, JoinType::Left | JoinType::Full)
548                            && probe_row < self.probe_matched.len()
549                        {
550                            self.probe_matched[probe_row] = true;
551                        }
552                        if matches!(self.join_type, JoinType::Right | JoinType::Full)
553                            && build_chunk_idx < self.build_matched.len()
554                            && build_row < self.build_matched[build_chunk_idx].len()
555                        {
556                            self.build_matched[build_chunk_idx][build_row] = true;
557                        }
558
559                        self.produce_output_row(
560                            &mut builder,
561                            probe_chunk,
562                            probe_row,
563                            Some(build_chunk),
564                            Some(build_row),
565                        )?;
566
567                        self.current_match_position += 1;
568
569                        if builder.is_full() {
570                            return Ok(Some(builder.finish()));
571                        }
572                    }
573
574                    // Done with this probe row
575                    self.current_probe_row += 1;
576                    self.current_matches.clear();
577                    self.current_match_position = 0;
578                }
579
580                if builder.is_full() {
581                    return Ok(Some(builder.finish()));
582                }
583            }
584
585            // Done with current probe chunk
586            self.current_probe_chunk = None;
587            self.current_probe_row = 0;
588
589            if builder.row_count() > 0 {
590                return Ok(Some(builder.finish()));
591            }
592        }
593    }
594
595    fn reset(&mut self) {
596        self.probe_side.reset();
597        self.build_side.reset();
598        self.hash_table.clear();
599        self.build_chunks.clear();
600        self.build_complete = false;
601        self.current_probe_chunk = None;
602        self.current_probe_row = 0;
603        self.current_match_position = 0;
604        self.current_matches.clear();
605        self.probe_matched.clear();
606        self.build_matched.clear();
607        self.emitting_unmatched = false;
608        self.unmatched_chunk_idx = 0;
609        self.unmatched_row_idx = 0;
610    }
611
612    fn name(&self) -> &'static str {
613        "HashJoin"
614    }
615}
616
617/// Nested loop join operator.
618///
619/// Performs a cartesian product of both sides, filtering by the join condition.
620/// Less efficient than hash join but supports any join condition.
621pub struct NestedLoopJoinOperator {
622    /// Left side operator.
623    left: Box<dyn Operator>,
624    /// Right side operator.
625    right: Box<dyn Operator>,
626    /// Join condition predicate (if any).
627    condition: Option<Box<dyn JoinCondition>>,
628    /// Join type.
629    join_type: JoinType,
630    /// Output schema.
631    output_schema: Vec<LogicalType>,
632    /// Materialized right side chunks.
633    right_chunks: Vec<DataChunk>,
634    /// Whether the right side is materialized.
635    right_materialized: bool,
636    /// Current left chunk.
637    current_left_chunk: Option<DataChunk>,
638    /// Current row in the left chunk.
639    current_left_row: usize,
640    /// Current chunk index in the right side.
641    current_right_chunk: usize,
642    /// Whether the current left row has been matched (for Left Join).
643    current_left_matched: bool,
644    /// Current row in the current right chunk.
645    current_right_row: usize,
646}
647
648/// Trait for join conditions.
649pub trait JoinCondition: Send + Sync {
650    /// Evaluates the condition for a pair of rows.
651    fn evaluate(
652        &self,
653        left_chunk: &DataChunk,
654        left_row: usize,
655        right_chunk: &DataChunk,
656        right_row: usize,
657    ) -> bool;
658}
659
/// Join condition that holds when one left column equals one right column.
pub struct EqualityCondition {
    /// Index of the compared column in the left chunk.
    left_column: usize,
    /// Index of the compared column in the right chunk.
    right_column: usize,
}
667
668impl EqualityCondition {
669    /// Creates a new equality condition.
670    pub fn new(left_column: usize, right_column: usize) -> Self {
671        Self {
672            left_column,
673            right_column,
674        }
675    }
676}
677
678impl JoinCondition for EqualityCondition {
679    fn evaluate(
680        &self,
681        left_chunk: &DataChunk,
682        left_row: usize,
683        right_chunk: &DataChunk,
684        right_row: usize,
685    ) -> bool {
686        let left_val = left_chunk
687            .column(self.left_column)
688            .and_then(|c| c.get_value(left_row));
689        let right_val = right_chunk
690            .column(self.right_column)
691            .and_then(|c| c.get_value(right_row));
692
693        match (left_val, right_val) {
694            (Some(l), Some(r)) => l == r,
695            _ => false,
696        }
697    }
698}
699
700impl NestedLoopJoinOperator {
701    /// Creates a new nested loop join operator.
702    pub fn new(
703        left: Box<dyn Operator>,
704        right: Box<dyn Operator>,
705        condition: Option<Box<dyn JoinCondition>>,
706        join_type: JoinType,
707        output_schema: Vec<LogicalType>,
708    ) -> Self {
709        Self {
710            left,
711            right,
712            condition,
713            join_type,
714            output_schema,
715            right_chunks: Vec::new(),
716            right_materialized: false,
717            current_left_chunk: None,
718            current_left_row: 0,
719            current_right_chunk: 0,
720            current_right_row: 0,
721            current_left_matched: false,
722        }
723    }
724
725    /// Materializes the right side.
726    fn materialize_right(&mut self) -> Result<(), OperatorError> {
727        while let Some(chunk) = self.right.next()? {
728            self.right_chunks.push(chunk);
729        }
730        self.right_materialized = true;
731        Ok(())
732    }
733
734    /// Produces an output row.
735    fn produce_row(
736        &self,
737        builder: &mut DataChunkBuilder,
738        left_chunk: &DataChunk,
739        left_row: usize,
740        right_chunk: &DataChunk,
741        right_row: usize,
742    ) {
743        // Copy left columns
744        for col_idx in 0..left_chunk.column_count() {
745            if let (Some(src), Some(dst)) =
746                (left_chunk.column(col_idx), builder.column_mut(col_idx))
747            {
748                if let Some(val) = src.get_value(left_row) {
749                    dst.push_value(val);
750                } else {
751                    dst.push_value(Value::Null);
752                }
753            }
754        }
755
756        // Copy right columns
757        let left_col_count = left_chunk.column_count();
758        for col_idx in 0..right_chunk.column_count() {
759            if let (Some(src), Some(dst)) = (
760                right_chunk.column(col_idx),
761                builder.column_mut(left_col_count + col_idx),
762            ) {
763                if let Some(val) = src.get_value(right_row) {
764                    dst.push_value(val);
765                } else {
766                    dst.push_value(Value::Null);
767                }
768            }
769        }
770
771        builder.advance_row();
772    }
773
774    /// Produces an output row with NULLs for the right side (for unmatched left rows in Left Join).
775    fn produce_left_unmatched_row(
776        &self,
777        builder: &mut DataChunkBuilder,
778        left_chunk: &DataChunk,
779        left_row: usize,
780        right_col_count: usize,
781    ) {
782        // Copy left columns
783        for col_idx in 0..left_chunk.column_count() {
784            if let (Some(src), Some(dst)) =
785                (left_chunk.column(col_idx), builder.column_mut(col_idx))
786            {
787                if let Some(val) = src.get_value(left_row) {
788                    dst.push_value(val);
789                } else {
790                    dst.push_value(Value::Null);
791                }
792            }
793        }
794
795        // Fill right columns with NULLs
796        let left_col_count = left_chunk.column_count();
797        for col_idx in 0..right_col_count {
798            if let Some(dst) = builder.column_mut(left_col_count + col_idx) {
799                dst.push_value(Value::Null);
800            }
801        }
802
803        builder.advance_row();
804    }
805}
806
807impl Operator for NestedLoopJoinOperator {
808    fn next(&mut self) -> OperatorResult {
809        // Materialize right side
810        if !self.right_materialized {
811            self.materialize_right()?;
812        }
813
814        // If right side is empty and not a left outer join, return nothing
815        if self.right_chunks.is_empty() && !matches!(self.join_type, JoinType::Left) {
816            return Ok(None);
817        }
818
819        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
820
821        loop {
822            // Get current left chunk
823            if self.current_left_chunk.is_none() {
824                self.current_left_chunk = self.left.next()?;
825                self.current_left_row = 0;
826                self.current_right_chunk = 0;
827                self.current_right_row = 0;
828
829                if self.current_left_chunk.is_none() {
830                    // No more left data
831                    return if builder.row_count() > 0 {
832                        Ok(Some(builder.finish()))
833                    } else {
834                        Ok(None)
835                    };
836                }
837            }
838
839            let left_chunk = self
840                .current_left_chunk
841                .as_ref()
842                .expect("left chunk is Some: loaded in loop above");
843            let left_rows: Vec<usize> = left_chunk.selected_indices().collect();
844
845            // Calculate right column count for potential unmatched rows
846            let right_col_count = if !self.right_chunks.is_empty() {
847                self.right_chunks[0].column_count()
848            } else {
849                // Infer from output schema
850                self.output_schema
851                    .len()
852                    .saturating_sub(left_chunk.column_count())
853            };
854
855            // Process current left row against all right rows
856            while self.current_left_row < left_rows.len() {
857                let left_row = left_rows[self.current_left_row];
858
859                // Reset match tracking for this left row
860                if self.current_right_chunk == 0 && self.current_right_row == 0 {
861                    self.current_left_matched = false;
862                }
863
864                // Cross join or inner/other join
865                while self.current_right_chunk < self.right_chunks.len() {
866                    let right_chunk = &self.right_chunks[self.current_right_chunk];
867                    let right_rows: Vec<usize> = right_chunk.selected_indices().collect();
868
869                    while self.current_right_row < right_rows.len() {
870                        let right_row = right_rows[self.current_right_row];
871
872                        // Check condition
873                        let matches = match &self.condition {
874                            Some(cond) => {
875                                cond.evaluate(left_chunk, left_row, right_chunk, right_row)
876                            }
877                            None => true, // Cross join
878                        };
879
880                        if matches {
881                            self.current_left_matched = true;
882                            self.produce_row(
883                                &mut builder,
884                                left_chunk,
885                                left_row,
886                                right_chunk,
887                                right_row,
888                            );
889
890                            if builder.is_full() {
891                                self.current_right_row += 1;
892                                return Ok(Some(builder.finish()));
893                            }
894                        }
895
896                        self.current_right_row += 1;
897                    }
898
899                    self.current_right_chunk += 1;
900                    self.current_right_row = 0;
901                }
902
903                // Done processing all right rows for this left row
904                // For Left Join, emit unmatched left row with NULLs
905                if matches!(self.join_type, JoinType::Left) && !self.current_left_matched {
906                    self.produce_left_unmatched_row(
907                        &mut builder,
908                        left_chunk,
909                        left_row,
910                        right_col_count,
911                    );
912
913                    if builder.is_full() {
914                        self.current_left_row += 1;
915                        self.current_right_chunk = 0;
916                        self.current_right_row = 0;
917                        return Ok(Some(builder.finish()));
918                    }
919                }
920
921                // Move to next left row
922                self.current_left_row += 1;
923                self.current_right_chunk = 0;
924                self.current_right_row = 0;
925            }
926
927            // Done with current left chunk
928            self.current_left_chunk = None;
929
930            if builder.row_count() > 0 {
931                return Ok(Some(builder.finish()));
932            }
933        }
934    }
935
936    fn reset(&mut self) {
937        self.left.reset();
938        self.right.reset();
939        self.right_chunks.clear();
940        self.right_materialized = false;
941        self.current_left_chunk = None;
942        self.current_left_row = 0;
943        self.current_right_chunk = 0;
944        self.current_right_row = 0;
945        self.current_left_matched = false;
946    }
947
948    fn name(&self) -> &'static str {
949        "NestedLoopJoin"
950    }
951}
952
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test double that hands out a fixed list of chunks, one per `next()` call.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            // Each chunk is moved out of its slot and replaced by an empty placeholder.
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    let chunk = std::mem::replace(slot, DataChunk::empty());
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            // NOTE(review): chunks already handed out were swapped for
            // `DataChunk::empty()`, so a replay after reset yields those placeholders
            // rather than the original data. Confirm this is acceptable for new tests.
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds a one-column Int64 chunk holding `values` in order.
    fn create_int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for value in values.iter().copied() {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Wraps `values` in a mock operator that yields them as a single chunk.
    fn int_source(values: &[i64]) -> MockOperator {
        MockOperator::new(vec![create_int_chunk(values)])
    }

    /// Builds a hash join over two single-column Int64 inputs, keyed on column 0.
    fn int_hash_join(
        left: &[i64],
        right: &[i64],
        join_type: JoinType,
        output_schema: Vec<LogicalType>,
    ) -> HashJoinOperator {
        HashJoinOperator::new(
            Box::new(int_source(left)),
            Box::new(int_source(right)),
            vec![0],
            vec![0],
            join_type,
            output_schema,
        )
    }

    /// Drains `op` and returns `(column 0, column 1)` for every selected row.
    /// Column 0 must hold a value; column 1 is reported as an `Option`.
    fn drain_pairs(op: &mut impl Operator) -> Vec<(i64, Option<i64>)> {
        let mut pairs = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                let first = chunk.column(0).unwrap().get_int64(row).unwrap();
                let second = chunk.column(1).unwrap().get_int64(row);
                pairs.push((first, second));
            }
        }
        pairs
    }

    /// Like `drain_pairs`, but additionally requires column 1 to hold a value.
    fn drain_full_pairs(op: &mut impl Operator) -> Vec<(i64, i64)> {
        drain_pairs(op)
            .into_iter()
            .map(|(first, second)| (first, second.unwrap()))
            .collect()
    }

    /// Drains `op` and returns column 0 of every selected row.
    fn drain_column(op: &mut impl Operator) -> Vec<i64> {
        let mut values = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                values.push(chunk.column(0).unwrap().get_int64(row).unwrap());
            }
        }
        values
    }

    #[test]
    fn test_hash_join_inner() {
        // Left [1, 2, 3, 4] joined with right [2, 3, 4, 5] on column 0 keeps keys 2, 3, 4.
        let mut join = int_hash_join(
            &[1, 2, 3, 4],
            &[2, 3, 4, 5],
            JoinType::Inner,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut results = drain_full_pairs(&mut join);
        results.sort_unstable();

        assert_eq!(results, vec![(2, 2), (3, 3), (4, 4)]);
    }

    #[test]
    fn test_hash_join_left_outer() {
        // Left [1, 2, 3] with right [2, 3]: key 1 has no partner and gets a null right side.
        let mut join = int_hash_join(
            &[1, 2, 3],
            &[2, 3],
            JoinType::Left,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut results = drain_pairs(&mut join);
        results.sort_by_key(|&(left_val, _)| left_val);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0], (1, None)); // No match
        assert_eq!(results[1], (2, Some(2)));
        assert_eq!(results[2], (3, Some(3)));
    }

    #[test]
    fn test_nested_loop_cross_join() {
        // Left [1, 2] crossed with right [10, 20] yields all four combinations.
        let mut join = NestedLoopJoinOperator::new(
            Box::new(int_source(&[1, 2])),
            Box::new(int_source(&[10, 20])),
            None,
            JoinType::Cross,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut results = drain_full_pairs(&mut join);
        results.sort_unstable();

        assert_eq!(results, vec![(1, 10), (1, 20), (2, 10), (2, 20)]);
    }

    #[test]
    fn test_hash_join_semi() {
        // Left [1, 2, 3, 4] with right [2, 4]: only left rows that have a match survive.
        // A semi join outputs the probe (left) columns only, hence the one-column schema.
        let mut join = int_hash_join(
            &[1, 2, 3, 4],
            &[2, 4],
            JoinType::Semi,
            vec![LogicalType::Int64],
        );

        let mut results = drain_column(&mut join);
        results.sort_unstable();

        assert_eq!(results, vec![2, 4]);
    }

    #[test]
    fn test_hash_join_anti() {
        // Left [1, 2, 3, 4] with right [2, 4]: only left rows without a match survive.
        let mut join = int_hash_join(
            &[1, 2, 3, 4],
            &[2, 4],
            JoinType::Anti,
            vec![LogicalType::Int64],
        );

        let mut results = drain_column(&mut join);
        results.sort_unstable();

        assert_eq!(results, vec![1, 3]);
    }
}