Skip to main content

grafeo_core/execution/operators/
join.rs

//! Join operators for combining data from two sources.
//!
//! This module provides:
//! - `HashJoinOperator`: Efficient hash-based join for equality conditions
//! - `NestedLoopJoinOperator`: General-purpose join for any condition
6
7use std::cmp::Ordering;
8use std::collections::HashMap;
9
10use arcstr::ArcStr;
11use grafeo_common::types::{LogicalType, Value};
12
13use super::{Operator, OperatorError, OperatorResult};
14use crate::execution::chunk::DataChunkBuilder;
15use crate::execution::{DataChunk, ValueVector};
16
/// Selects which combinations of left and right rows a join emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum JoinType {
    /// Emit only pairs of rows that match on both sides.
    Inner,
    /// Keep every left row; right columns are null where no match exists.
    Left,
    /// Keep every right row; left columns are null where no match exists.
    Right,
    /// Keep every row from both sides, null-padding the side without a match.
    Full,
    /// Pair every left row with every right row (cartesian product).
    Cross,
    /// Emit left rows that have a match on the right.
    Semi,
    /// Emit left rows that have no match on the right.
    Anti,
}
36
37/// A hash key that can be hashed and compared for join operations.
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39#[non_exhaustive]
40pub enum HashKey {
41    /// Null key.
42    Null,
43    /// Boolean key.
44    Bool(bool),
45    /// Integer key.
46    Int64(i64),
47    /// String key (cheap clone via ArcStr refcount).
48    String(ArcStr),
49    /// Byte content key.
50    Bytes(Vec<u8>),
51    /// Composite key for multi-column joins.
52    Composite(Vec<HashKey>),
53}
54
55impl Ord for HashKey {
56    fn cmp(&self, other: &Self) -> Ordering {
57        match (self, other) {
58            (HashKey::Null, HashKey::Null) => Ordering::Equal,
59            (HashKey::Null, _) => Ordering::Less,
60            (_, HashKey::Null) => Ordering::Greater,
61            (HashKey::Bool(a), HashKey::Bool(b)) => a.cmp(b),
62            (HashKey::Bool(_), _) => Ordering::Less,
63            (_, HashKey::Bool(_)) => Ordering::Greater,
64            (HashKey::Int64(a), HashKey::Int64(b)) => a.cmp(b),
65            (HashKey::Int64(_), _) => Ordering::Less,
66            (_, HashKey::Int64(_)) => Ordering::Greater,
67            (HashKey::String(a), HashKey::String(b)) => a.cmp(b),
68            (HashKey::String(_), _) => Ordering::Less,
69            (_, HashKey::String(_)) => Ordering::Greater,
70            (HashKey::Bytes(a), HashKey::Bytes(b)) => a.cmp(b),
71            (HashKey::Bytes(_), _) => Ordering::Less,
72            (_, HashKey::Bytes(_)) => Ordering::Greater,
73            (HashKey::Composite(a), HashKey::Composite(b)) => a.cmp(b),
74        }
75    }
76}
77
78impl PartialOrd for HashKey {
79    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
80        Some(self.cmp(other))
81    }
82}
83
84impl HashKey {
85    /// Creates a hash key from a Value.
86    pub fn from_value(value: &Value) -> Self {
87        match value {
88            Value::Null => HashKey::Null,
89            Value::Bool(b) => HashKey::Bool(*b),
90            Value::Int64(i) => HashKey::Int64(*i),
91            Value::Float64(f) => {
92                // Convert float to bits for consistent hashing
93                HashKey::Int64(f.to_bits() as i64)
94            }
95            Value::String(s) => HashKey::String(s.clone()),
96            Value::Bytes(b) => HashKey::Bytes(b.to_vec()),
97            Value::Timestamp(t) => HashKey::Int64(t.as_micros()),
98            Value::Date(d) => HashKey::Int64(d.as_days() as i64),
99            Value::Time(t) => HashKey::Int64(t.as_nanos() as i64),
100            Value::Duration(d) => HashKey::Composite(vec![
101                HashKey::Int64(d.months()),
102                HashKey::Int64(d.days()),
103                HashKey::Int64(d.nanos()),
104            ]),
105            Value::ZonedDatetime(zdt) => HashKey::Int64(zdt.as_timestamp().as_micros()),
106            Value::List(items) => {
107                HashKey::Composite(items.iter().map(HashKey::from_value).collect())
108            }
109            Value::Map(map) => {
110                // BTreeMap::iter() visits entries in ascending key order, so no sort needed.
111                let keys: Vec<_> = map
112                    .iter()
113                    .map(|(k, v)| {
114                        HashKey::Composite(vec![
115                            HashKey::String(ArcStr::from(k.as_str())),
116                            HashKey::from_value(v),
117                        ])
118                    })
119                    .collect();
120                HashKey::Composite(keys)
121            }
122            Value::Vector(v) => {
123                // Hash vectors by converting each f32 to its bit representation
124                HashKey::Composite(
125                    v.iter()
126                        .map(|f| HashKey::Int64(f.to_bits() as i64))
127                        .collect(),
128                )
129            }
130            Value::Path { nodes, edges } => {
131                let mut parts: Vec<_> = nodes.iter().map(HashKey::from_value).collect();
132                parts.extend(edges.iter().map(HashKey::from_value));
133                HashKey::Composite(parts)
134            }
135            // CRDT counters are opaque keys; hash by total logical value.
136            Value::GCounter(counts) => {
137                HashKey::Int64(counts.values().copied().map(|v| v as i64).sum())
138            }
139            Value::OnCounter { pos, neg } => {
140                let p: i64 = pos.values().copied().map(|v| v as i64).sum();
141                let n: i64 = neg.values().copied().map(|v| v as i64).sum();
142                HashKey::Int64(p - n)
143            }
144            _ => HashKey::Null,
145        }
146    }
147
148    /// Creates a hash key from a column value at a given row.
149    pub fn from_column(column: &ValueVector, row: usize) -> Option<Self> {
150        column.get_value(row).map(|v| Self::from_value(&v))
151    }
152}
153
154/// Hash join operator.
155///
156/// Builds a hash table from the build side (right) and probes with the probe side (left).
157/// Efficient for equality joins on one or more columns.
158pub struct HashJoinOperator {
159    /// Left (probe) side operator.
160    probe_side: Box<dyn Operator>,
161    /// Right (build) side operator.
162    build_side: Box<dyn Operator>,
163    /// Column indices on the probe side for join keys.
164    probe_keys: Vec<usize>,
165    /// Column indices on the build side for join keys.
166    build_keys: Vec<usize>,
167    /// Join type.
168    join_type: JoinType,
169    /// Output schema (combined from both sides).
170    output_schema: Vec<LogicalType>,
171    /// Hash table: key -> list of (chunk_index, row_index).
172    hash_table: HashMap<HashKey, Vec<(usize, usize)>>,
173    /// Materialized build side chunks.
174    build_chunks: Vec<DataChunk>,
175    /// Whether the build phase is complete.
176    build_complete: bool,
177    /// Current probe chunk being processed.
178    current_probe_chunk: Option<DataChunk>,
179    /// Current row in the probe chunk.
180    current_probe_row: usize,
181    /// Current position in the hash table matches for the current probe row.
182    current_match_position: usize,
183    /// Current matches for the current probe row.
184    current_matches: Vec<(usize, usize)>,
185    /// For left/full outer joins: track which probe rows had matches.
186    probe_matched: Vec<bool>,
187    /// For right/full outer joins: track which build rows were matched.
188    build_matched: Vec<Vec<bool>>,
189    /// Whether we're in the emit unmatched phase (for outer joins).
190    emitting_unmatched: bool,
191    /// Current chunk index when emitting unmatched rows.
192    unmatched_chunk_idx: usize,
193    /// Current row index when emitting unmatched rows.
194    unmatched_row_idx: usize,
195}
196
197impl HashJoinOperator {
198    /// Creates a new hash join operator.
199    ///
200    /// # Arguments
201    /// * `probe_side` - Left side operator (will be probed).
202    /// * `build_side` - Right side operator (will build hash table).
203    /// * `probe_keys` - Column indices on probe side for join keys.
204    /// * `build_keys` - Column indices on build side for join keys.
205    /// * `join_type` - Type of join to perform.
206    /// * `output_schema` - Schema of the output (probe columns + build columns).
207    pub fn new(
208        probe_side: Box<dyn Operator>,
209        build_side: Box<dyn Operator>,
210        probe_keys: Vec<usize>,
211        build_keys: Vec<usize>,
212        join_type: JoinType,
213        output_schema: Vec<LogicalType>,
214    ) -> Self {
215        Self {
216            probe_side,
217            build_side,
218            probe_keys,
219            build_keys,
220            join_type,
221            output_schema,
222            hash_table: HashMap::new(),
223            build_chunks: Vec::new(),
224            build_complete: false,
225            current_probe_chunk: None,
226            current_probe_row: 0,
227            current_match_position: 0,
228            current_matches: Vec::new(),
229            probe_matched: Vec::new(),
230            build_matched: Vec::new(),
231            emitting_unmatched: false,
232            unmatched_chunk_idx: 0,
233            unmatched_row_idx: 0,
234        }
235    }
236
237    /// Builds the hash table from the build side.
238    fn build_hash_table(&mut self) -> Result<(), OperatorError> {
239        while let Some(chunk) = self.build_side.next()? {
240            let chunk_idx = self.build_chunks.len();
241
242            // Initialize match tracking for outer joins
243            if matches!(self.join_type, JoinType::Right | JoinType::Full) {
244                self.build_matched.push(vec![false; chunk.row_count()]);
245            }
246
247            // Add each row to the hash table
248            for row in chunk.selected_indices() {
249                let key = self.extract_key(&chunk, row, &self.build_keys)?;
250
251                // Skip null keys for inner/semi/anti joins
252                if matches!(key, HashKey::Null)
253                    && !matches!(
254                        self.join_type,
255                        JoinType::Left | JoinType::Right | JoinType::Full
256                    )
257                {
258                    continue;
259                }
260
261                self.hash_table
262                    .entry(key)
263                    .or_default()
264                    .push((chunk_idx, row));
265            }
266
267            self.build_chunks.push(chunk);
268        }
269
270        self.build_complete = true;
271        Ok(())
272    }
273
274    /// Extracts a hash key from a chunk row.
275    fn extract_key(
276        &self,
277        chunk: &DataChunk,
278        row: usize,
279        key_columns: &[usize],
280    ) -> Result<HashKey, OperatorError> {
281        if key_columns.len() == 1 {
282            let col = chunk.column(key_columns[0]).ok_or_else(|| {
283                OperatorError::ColumnNotFound(format!("column {}", key_columns[0]))
284            })?;
285            Ok(HashKey::from_column(col, row).unwrap_or(HashKey::Null))
286        } else {
287            let keys: Vec<HashKey> = key_columns
288                .iter()
289                .map(|&col_idx| {
290                    chunk
291                        .column(col_idx)
292                        .and_then(|col| HashKey::from_column(col, row))
293                        .unwrap_or(HashKey::Null)
294                })
295                .collect();
296            Ok(HashKey::Composite(keys))
297        }
298    }
299
300    /// Produces an output row from a probe row and build row.
301    fn produce_output_row(
302        &self,
303        builder: &mut DataChunkBuilder,
304        probe_chunk: &DataChunk,
305        probe_row: usize,
306        build_chunk: Option<&DataChunk>,
307        build_row: Option<usize>,
308    ) -> Result<(), OperatorError> {
309        let probe_col_count = probe_chunk.column_count();
310
311        // Copy probe side columns
312        for col_idx in 0..probe_col_count {
313            let src_col = probe_chunk
314                .column(col_idx)
315                .ok_or_else(|| OperatorError::ColumnNotFound(format!("probe column {col_idx}")))?;
316            let dst_col = builder
317                .column_mut(col_idx)
318                .ok_or_else(|| OperatorError::ColumnNotFound(format!("output column {col_idx}")))?;
319
320            if let Some(value) = src_col.get_value(probe_row) {
321                dst_col.push_value(value);
322            } else {
323                dst_col.push_value(Value::Null);
324            }
325        }
326
327        // Copy build side columns
328        match (build_chunk, build_row) {
329            (Some(chunk), Some(row)) => {
330                for col_idx in 0..chunk.column_count() {
331                    let src_col = chunk.column(col_idx).ok_or_else(|| {
332                        OperatorError::ColumnNotFound(format!("build column {col_idx}"))
333                    })?;
334                    let dst_col =
335                        builder
336                            .column_mut(probe_col_count + col_idx)
337                            .ok_or_else(|| {
338                                OperatorError::ColumnNotFound(format!(
339                                    "output column {}",
340                                    probe_col_count + col_idx
341                                ))
342                            })?;
343
344                    if let Some(value) = src_col.get_value(row) {
345                        dst_col.push_value(value);
346                    } else {
347                        dst_col.push_value(Value::Null);
348                    }
349                }
350            }
351            _ => {
352                // Emit nulls for build side (left outer join case)
353                if !self.build_chunks.is_empty() {
354                    let build_col_count = self.build_chunks[0].column_count();
355                    for col_idx in 0..build_col_count {
356                        let dst_col =
357                            builder
358                                .column_mut(probe_col_count + col_idx)
359                                .ok_or_else(|| {
360                                    OperatorError::ColumnNotFound(format!(
361                                        "output column {}",
362                                        probe_col_count + col_idx
363                                    ))
364                                })?;
365                        dst_col.push_value(Value::Null);
366                    }
367                }
368            }
369        }
370
371        builder.advance_row();
372        Ok(())
373    }
374
375    /// Gets the next probe chunk.
376    fn get_next_probe_chunk(&mut self) -> Result<bool, OperatorError> {
377        let chunk = self.probe_side.next()?;
378        if let Some(ref c) = chunk {
379            // Initialize match tracking for outer joins
380            if matches!(self.join_type, JoinType::Left | JoinType::Full) {
381                self.probe_matched = vec![false; c.row_count()];
382            }
383        }
384        let has_chunk = chunk.is_some();
385        self.current_probe_chunk = chunk;
386        self.current_probe_row = 0;
387        Ok(has_chunk)
388    }
389
390    /// Emits unmatched build rows for right/full outer joins.
391    fn emit_unmatched_build(&mut self) -> OperatorResult {
392        if self.build_matched.is_empty() {
393            return Ok(None);
394        }
395
396        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
397
398        // Determine probe column count from schema or first probe chunk
399        let probe_col_count = if !self.build_chunks.is_empty() {
400            self.output_schema.len() - self.build_chunks[0].column_count()
401        } else {
402            0
403        };
404
405        while self.unmatched_chunk_idx < self.build_chunks.len() {
406            let chunk = &self.build_chunks[self.unmatched_chunk_idx];
407            let matched = &self.build_matched[self.unmatched_chunk_idx];
408
409            while self.unmatched_row_idx < matched.len() {
410                if !matched[self.unmatched_row_idx] {
411                    // This row was not matched - emit with nulls on probe side
412
413                    // Emit nulls for probe side
414                    for col_idx in 0..probe_col_count {
415                        if let Some(dst_col) = builder.column_mut(col_idx) {
416                            dst_col.push_value(Value::Null);
417                        }
418                    }
419
420                    // Copy build side values
421                    for col_idx in 0..chunk.column_count() {
422                        if let (Some(src_col), Some(dst_col)) = (
423                            chunk.column(col_idx),
424                            builder.column_mut(probe_col_count + col_idx),
425                        ) {
426                            if let Some(value) = src_col.get_value(self.unmatched_row_idx) {
427                                dst_col.push_value(value);
428                            } else {
429                                dst_col.push_value(Value::Null);
430                            }
431                        }
432                    }
433
434                    builder.advance_row();
435
436                    if builder.is_full() {
437                        self.unmatched_row_idx += 1;
438                        return Ok(Some(builder.finish()));
439                    }
440                }
441
442                self.unmatched_row_idx += 1;
443            }
444
445            self.unmatched_chunk_idx += 1;
446            self.unmatched_row_idx = 0;
447        }
448
449        if builder.row_count() > 0 {
450            Ok(Some(builder.finish()))
451        } else {
452            Ok(None)
453        }
454    }
455}
456
457impl Operator for HashJoinOperator {
458    fn next(&mut self) -> OperatorResult {
459        // Phase 1: Build hash table
460        if !self.build_complete {
461            self.build_hash_table()?;
462        }
463
464        // Phase 3: Emit unmatched build rows (right/full outer join)
465        if self.emitting_unmatched {
466            return self.emit_unmatched_build();
467        }
468
469        // Phase 2: Probe
470        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
471
472        loop {
473            // Get current probe chunk or fetch new one
474            if self.current_probe_chunk.is_none() && !self.get_next_probe_chunk()? {
475                // No more probe data
476                if matches!(self.join_type, JoinType::Right | JoinType::Full) {
477                    self.emitting_unmatched = true;
478                    return self.emit_unmatched_build();
479                }
480                return if builder.row_count() > 0 {
481                    Ok(Some(builder.finish()))
482                } else {
483                    Ok(None)
484                };
485            }
486
487            // Invariant: current_probe_chunk is Some here - the guard at line 396 either
488            // populates it via get_next_probe_chunk() or returns from the function
489            let probe_chunk = self
490                .current_probe_chunk
491                .as_ref()
492                .expect("probe chunk is Some: guard at line 396 ensures this");
493            let probe_rows: Vec<usize> = probe_chunk.selected_indices().collect();
494
495            while self.current_probe_row < probe_rows.len() {
496                let probe_row = probe_rows[self.current_probe_row];
497
498                // If we don't have current matches, look them up
499                if self.current_matches.is_empty() && self.current_match_position == 0 {
500                    let key = self.extract_key(probe_chunk, probe_row, &self.probe_keys)?;
501
502                    // Handle semi/anti joins differently
503                    match self.join_type {
504                        JoinType::Semi => {
505                            if self.hash_table.contains_key(&key) {
506                                // Emit probe row only
507                                for col_idx in 0..probe_chunk.column_count() {
508                                    if let (Some(src_col), Some(dst_col)) =
509                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
510                                        && let Some(value) = src_col.get_value(probe_row)
511                                    {
512                                        dst_col.push_value(value);
513                                    }
514                                }
515                                builder.advance_row();
516                            }
517                            self.current_probe_row += 1;
518                            continue;
519                        }
520                        JoinType::Anti => {
521                            if !self.hash_table.contains_key(&key) {
522                                // Emit probe row only
523                                for col_idx in 0..probe_chunk.column_count() {
524                                    if let (Some(src_col), Some(dst_col)) =
525                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
526                                        && let Some(value) = src_col.get_value(probe_row)
527                                    {
528                                        dst_col.push_value(value);
529                                    }
530                                }
531                                builder.advance_row();
532                            }
533                            self.current_probe_row += 1;
534                            continue;
535                        }
536                        _ => {
537                            self.current_matches =
538                                self.hash_table.get(&key).cloned().unwrap_or_default();
539                        }
540                    }
541                }
542
543                // Process matches
544                if self.current_matches.is_empty() {
545                    // No matches - for left/full outer join, emit with nulls
546                    if matches!(self.join_type, JoinType::Left | JoinType::Full) {
547                        self.produce_output_row(&mut builder, probe_chunk, probe_row, None, None)?;
548                    }
549                    self.current_probe_row += 1;
550                    self.current_match_position = 0;
551                } else {
552                    // Process each match
553                    while self.current_match_position < self.current_matches.len() {
554                        let (build_chunk_idx, build_row) =
555                            self.current_matches[self.current_match_position];
556                        let build_chunk = &self.build_chunks[build_chunk_idx];
557
558                        // Mark as matched for outer joins
559                        if matches!(self.join_type, JoinType::Left | JoinType::Full)
560                            && probe_row < self.probe_matched.len()
561                        {
562                            self.probe_matched[probe_row] = true;
563                        }
564                        if matches!(self.join_type, JoinType::Right | JoinType::Full)
565                            && build_chunk_idx < self.build_matched.len()
566                            && build_row < self.build_matched[build_chunk_idx].len()
567                        {
568                            self.build_matched[build_chunk_idx][build_row] = true;
569                        }
570
571                        self.produce_output_row(
572                            &mut builder,
573                            probe_chunk,
574                            probe_row,
575                            Some(build_chunk),
576                            Some(build_row),
577                        )?;
578
579                        self.current_match_position += 1;
580
581                        if builder.is_full() {
582                            return Ok(Some(builder.finish()));
583                        }
584                    }
585
586                    // Done with this probe row
587                    self.current_probe_row += 1;
588                    self.current_matches.clear();
589                    self.current_match_position = 0;
590                }
591
592                if builder.is_full() {
593                    return Ok(Some(builder.finish()));
594                }
595            }
596
597            // Done with current probe chunk
598            self.current_probe_chunk = None;
599            self.current_probe_row = 0;
600
601            if builder.row_count() > 0 {
602                return Ok(Some(builder.finish()));
603            }
604        }
605    }
606
607    fn reset(&mut self) {
608        self.probe_side.reset();
609        self.build_side.reset();
610        self.hash_table.clear();
611        self.build_chunks.clear();
612        self.build_complete = false;
613        self.current_probe_chunk = None;
614        self.current_probe_row = 0;
615        self.current_match_position = 0;
616        self.current_matches.clear();
617        self.probe_matched.clear();
618        self.build_matched.clear();
619        self.emitting_unmatched = false;
620        self.unmatched_chunk_idx = 0;
621        self.unmatched_row_idx = 0;
622    }
623
624    fn name(&self) -> &'static str {
625        "HashJoin"
626    }
627}
628
629/// Nested loop join operator.
630///
631/// Performs a cartesian product of both sides, filtering by the join condition.
632/// Less efficient than hash join but supports any join condition.
633pub struct NestedLoopJoinOperator {
634    /// Left side operator.
635    left: Box<dyn Operator>,
636    /// Right side operator.
637    right: Box<dyn Operator>,
638    /// Join condition predicate (if any).
639    condition: Option<Box<dyn JoinCondition>>,
640    /// Join type.
641    join_type: JoinType,
642    /// Output schema.
643    output_schema: Vec<LogicalType>,
644    /// Materialized right side chunks.
645    right_chunks: Vec<DataChunk>,
646    /// Whether the right side is materialized.
647    right_materialized: bool,
648    /// Current left chunk.
649    current_left_chunk: Option<DataChunk>,
650    /// Current row in the left chunk.
651    current_left_row: usize,
652    /// Current chunk index in the right side.
653    current_right_chunk: usize,
654    /// Whether the current left row has been matched (for Left Join).
655    current_left_matched: bool,
656    /// Current row in the current right chunk.
657    current_right_row: usize,
658}
659
660/// Trait for join conditions.
661pub trait JoinCondition: Send + Sync {
662    /// Evaluates the condition for a pair of rows.
663    fn evaluate(
664        &self,
665        left_chunk: &DataChunk,
666        left_row: usize,
667        right_chunk: &DataChunk,
668        right_row: usize,
669    ) -> bool;
670}
671
/// Join predicate comparing one left column with one right column for equality
/// (used by nested loop joins).
pub struct EqualityCondition {
    /// Index of the compared column in the left chunk.
    left_column: usize,
    /// Index of the compared column in the right chunk.
    right_column: usize,
}

impl EqualityCondition {
    /// Builds a condition comparing `left_column` against `right_column`.
    pub fn new(left_column: usize, right_column: usize) -> Self {
        EqualityCondition {
            left_column,
            right_column,
        }
    }
}
689
690impl JoinCondition for EqualityCondition {
691    fn evaluate(
692        &self,
693        left_chunk: &DataChunk,
694        left_row: usize,
695        right_chunk: &DataChunk,
696        right_row: usize,
697    ) -> bool {
698        let left_val = left_chunk
699            .column(self.left_column)
700            .and_then(|c| c.get_value(left_row));
701        let right_val = right_chunk
702            .column(self.right_column)
703            .and_then(|c| c.get_value(right_row));
704
705        match (left_val, right_val) {
706            (Some(l), Some(r)) => l == r,
707            _ => false,
708        }
709    }
710}
711
712impl NestedLoopJoinOperator {
713    /// Creates a new nested loop join operator.
714    pub fn new(
715        left: Box<dyn Operator>,
716        right: Box<dyn Operator>,
717        condition: Option<Box<dyn JoinCondition>>,
718        join_type: JoinType,
719        output_schema: Vec<LogicalType>,
720    ) -> Self {
721        Self {
722            left,
723            right,
724            condition,
725            join_type,
726            output_schema,
727            right_chunks: Vec::new(),
728            right_materialized: false,
729            current_left_chunk: None,
730            current_left_row: 0,
731            current_right_chunk: 0,
732            current_right_row: 0,
733            current_left_matched: false,
734        }
735    }
736
737    /// Materializes the right side.
738    fn materialize_right(&mut self) -> Result<(), OperatorError> {
739        while let Some(chunk) = self.right.next()? {
740            self.right_chunks.push(chunk);
741        }
742        self.right_materialized = true;
743        Ok(())
744    }
745
746    /// Produces an output row.
747    fn produce_row(
748        &self,
749        builder: &mut DataChunkBuilder,
750        left_chunk: &DataChunk,
751        left_row: usize,
752        right_chunk: &DataChunk,
753        right_row: usize,
754    ) {
755        // Copy left columns
756        for col_idx in 0..left_chunk.column_count() {
757            if let (Some(src), Some(dst)) =
758                (left_chunk.column(col_idx), builder.column_mut(col_idx))
759            {
760                if let Some(val) = src.get_value(left_row) {
761                    dst.push_value(val);
762                } else {
763                    dst.push_value(Value::Null);
764                }
765            }
766        }
767
768        // Copy right columns
769        let left_col_count = left_chunk.column_count();
770        for col_idx in 0..right_chunk.column_count() {
771            if let (Some(src), Some(dst)) = (
772                right_chunk.column(col_idx),
773                builder.column_mut(left_col_count + col_idx),
774            ) {
775                if let Some(val) = src.get_value(right_row) {
776                    dst.push_value(val);
777                } else {
778                    dst.push_value(Value::Null);
779                }
780            }
781        }
782
783        builder.advance_row();
784    }
785
786    /// Produces an output row with NULLs for the right side (for unmatched left rows in Left Join).
787    fn produce_left_unmatched_row(
788        &self,
789        builder: &mut DataChunkBuilder,
790        left_chunk: &DataChunk,
791        left_row: usize,
792        right_col_count: usize,
793    ) {
794        // Copy left columns
795        for col_idx in 0..left_chunk.column_count() {
796            if let (Some(src), Some(dst)) =
797                (left_chunk.column(col_idx), builder.column_mut(col_idx))
798            {
799                if let Some(val) = src.get_value(left_row) {
800                    dst.push_value(val);
801                } else {
802                    dst.push_value(Value::Null);
803                }
804            }
805        }
806
807        // Fill right columns with NULLs
808        let left_col_count = left_chunk.column_count();
809        for col_idx in 0..right_col_count {
810            if let Some(dst) = builder.column_mut(left_col_count + col_idx) {
811                dst.push_value(Value::Null);
812            }
813        }
814
815        builder.advance_row();
816    }
817}
818
819impl Operator for NestedLoopJoinOperator {
820    fn next(&mut self) -> OperatorResult {
821        // Materialize right side
822        if !self.right_materialized {
823            self.materialize_right()?;
824        }
825
826        // If right side is empty and not a left outer join, return nothing
827        if self.right_chunks.is_empty() && !matches!(self.join_type, JoinType::Left) {
828            return Ok(None);
829        }
830
831        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
832
833        loop {
834            // Get current left chunk
835            if self.current_left_chunk.is_none() {
836                self.current_left_chunk = self.left.next()?;
837                self.current_left_row = 0;
838                self.current_right_chunk = 0;
839                self.current_right_row = 0;
840
841                if self.current_left_chunk.is_none() {
842                    // No more left data
843                    return if builder.row_count() > 0 {
844                        Ok(Some(builder.finish()))
845                    } else {
846                        Ok(None)
847                    };
848                }
849            }
850
851            let left_chunk = self
852                .current_left_chunk
853                .as_ref()
854                .expect("left chunk is Some: loaded in loop above");
855            let left_rows: Vec<usize> = left_chunk.selected_indices().collect();
856
857            // Calculate right column count for potential unmatched rows
858            let right_col_count = if !self.right_chunks.is_empty() {
859                self.right_chunks[0].column_count()
860            } else {
861                // Infer from output schema
862                self.output_schema
863                    .len()
864                    .saturating_sub(left_chunk.column_count())
865            };
866
867            // Process current left row against all right rows
868            while self.current_left_row < left_rows.len() {
869                let left_row = left_rows[self.current_left_row];
870
871                // Reset match tracking for this left row
872                if self.current_right_chunk == 0 && self.current_right_row == 0 {
873                    self.current_left_matched = false;
874                }
875
876                // Cross join or inner/other join
877                while self.current_right_chunk < self.right_chunks.len() {
878                    let right_chunk = &self.right_chunks[self.current_right_chunk];
879                    let right_rows: Vec<usize> = right_chunk.selected_indices().collect();
880
881                    while self.current_right_row < right_rows.len() {
882                        let right_row = right_rows[self.current_right_row];
883
884                        // Check condition
885                        let matches = match &self.condition {
886                            Some(cond) => {
887                                cond.evaluate(left_chunk, left_row, right_chunk, right_row)
888                            }
889                            None => true, // Cross join
890                        };
891
892                        if matches {
893                            self.current_left_matched = true;
894                            self.produce_row(
895                                &mut builder,
896                                left_chunk,
897                                left_row,
898                                right_chunk,
899                                right_row,
900                            );
901
902                            if builder.is_full() {
903                                self.current_right_row += 1;
904                                return Ok(Some(builder.finish()));
905                            }
906                        }
907
908                        self.current_right_row += 1;
909                    }
910
911                    self.current_right_chunk += 1;
912                    self.current_right_row = 0;
913                }
914
915                // Done processing all right rows for this left row
916                // For Left Join, emit unmatched left row with NULLs
917                if matches!(self.join_type, JoinType::Left) && !self.current_left_matched {
918                    self.produce_left_unmatched_row(
919                        &mut builder,
920                        left_chunk,
921                        left_row,
922                        right_col_count,
923                    );
924
925                    if builder.is_full() {
926                        self.current_left_row += 1;
927                        self.current_right_chunk = 0;
928                        self.current_right_row = 0;
929                        return Ok(Some(builder.finish()));
930                    }
931                }
932
933                // Move to next left row
934                self.current_left_row += 1;
935                self.current_right_chunk = 0;
936                self.current_right_row = 0;
937            }
938
939            // Done with current left chunk
940            self.current_left_chunk = None;
941
942            if builder.row_count() > 0 {
943                return Ok(Some(builder.finish()));
944            }
945        }
946    }
947
948    fn reset(&mut self) {
949        self.left.reset();
950        self.right.reset();
951        self.right_chunks.clear();
952        self.right_materialized = false;
953        self.current_left_chunk = None;
954        self.current_left_row = 0;
955        self.current_right_chunk = 0;
956        self.current_right_row = 0;
957        self.current_left_matched = false;
958    }
959
960    fn name(&self) -> &'static str {
961        "NestedLoopJoin"
962    }
963}
964
965#[cfg(test)]
966mod tests {
967    use super::*;
968    use crate::execution::chunk::DataChunkBuilder;
969
970    /// Mock operator for testing.
971    struct MockOperator {
972        chunks: Vec<DataChunk>,
973        position: usize,
974    }
975
976    impl MockOperator {
977        fn new(chunks: Vec<DataChunk>) -> Self {
978            Self {
979                chunks,
980                position: 0,
981            }
982        }
983    }
984
985    impl Operator for MockOperator {
986        fn next(&mut self) -> OperatorResult {
987            if self.position < self.chunks.len() {
988                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
989                self.position += 1;
990                Ok(Some(chunk))
991            } else {
992                Ok(None)
993            }
994        }
995
996        fn reset(&mut self) {
997            self.position = 0;
998        }
999
1000        fn name(&self) -> &'static str {
1001            "Mock"
1002        }
1003    }
1004
1005    fn create_int_chunk(values: &[i64]) -> DataChunk {
1006        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1007        for &v in values {
1008            builder.column_mut(0).unwrap().push_int64(v);
1009            builder.advance_row();
1010        }
1011        builder.finish()
1012    }
1013
#[test]
fn test_hash_join_inner() {
    // Keys 2, 3 and 4 occur on both sides; 1 (left only) and 5 (right only)
    // have no partner and must not show up in an inner join.
    let left_input = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
    let right_input = MockOperator::new(vec![create_int_chunk(&[2, 3, 4, 5])]);

    let mut join = HashJoinOperator::new(
        Box::new(left_input),
        Box::new(right_input),
        vec![0],
        vec![0],
        JoinType::Inner,
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    // Drain the operator, reading every output row as a (left, right) pair.
    let mut pairs = Vec::new();
    while let Some(chunk) = join.next().unwrap() {
        pairs.extend(chunk.selected_indices().map(|row| {
            (
                chunk.column(0).unwrap().get_int64(row).unwrap(),
                chunk.column(1).unwrap().get_int64(row).unwrap(),
            )
        }));
    }

    // Output order is not part of the contract, so compare in sorted order.
    pairs.sort_unstable();
    assert_eq!(pairs, vec![(2, 2), (3, 3), (4, 4)]);
}
1045
#[test]
fn test_hash_join_left_outer() {
    // Left key 1 has no partner on the right, so it must come back paired
    // with a null; keys 2 and 3 pair with themselves.
    let left_input = MockOperator::new(vec![create_int_chunk(&[1, 2, 3])]);
    let right_input = MockOperator::new(vec![create_int_chunk(&[2, 3])]);

    let mut join = HashJoinOperator::new(
        Box::new(left_input),
        Box::new(right_input),
        vec![0],
        vec![0],
        JoinType::Left,
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    // The right value stays an `Option` so a null right side reads as `None`.
    let mut rows = Vec::new();
    while let Some(chunk) = join.next().unwrap() {
        rows.extend(chunk.selected_indices().map(|row| {
            let left_val = chunk.column(0).unwrap().get_int64(row).unwrap();
            let right_val = chunk.column(1).unwrap().get_int64(row);
            (left_val, right_val)
        }));
    }

    rows.sort_by_key(|&(left_val, _)| left_val);
    assert_eq!(rows, vec![(1, None), (2, Some(2)), (3, Some(3))]);
}
1080
#[test]
fn test_nested_loop_cross_join() {
    // With no join condition every left row pairs with every right row:
    // {1, 2} x {10, 20} gives four combinations.
    let left_input = MockOperator::new(vec![create_int_chunk(&[1, 2])]);
    let right_input = MockOperator::new(vec![create_int_chunk(&[10, 20])]);

    let mut join = NestedLoopJoinOperator::new(
        Box::new(left_input),
        Box::new(right_input),
        None,
        JoinType::Cross,
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut combos = Vec::new();
    while let Some(chunk) = join.next().unwrap() {
        for row in chunk.selected_indices() {
            let read = |col: usize| chunk.column(col).unwrap().get_int64(row).unwrap();
            combos.push((read(0), read(1)));
        }
    }

    combos.sort_unstable();
    assert_eq!(combos, vec![(1, 10), (1, 20), (2, 10), (2, 20)]);
}
1111
#[test]
fn test_hash_join_semi() {
    // Of the left keys 1..=4 only 2 and 4 exist on the right, so only those
    // two left rows survive a semi join.
    let left_input = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
    let right_input = MockOperator::new(vec![create_int_chunk(&[2, 4])]);

    let mut join = HashJoinOperator::new(
        Box::new(left_input),
        Box::new(right_input),
        vec![0],
        vec![0],
        JoinType::Semi,
        // A semi join emits the probe (left) columns only.
        vec![LogicalType::Int64],
    );

    let mut kept = Vec::new();
    while let Some(chunk) = join.next().unwrap() {
        kept.extend(
            chunk
                .selected_indices()
                .map(|row| chunk.column(0).unwrap().get_int64(row).unwrap()),
        );
    }

    kept.sort_unstable();
    assert_eq!(kept, vec![2, 4]);
}
1143
#[test]
fn test_hash_join_anti() {
    // Left keys 2 and 4 have a partner on the right; an anti join keeps the
    // remaining left rows, 1 and 3.
    let left_input = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
    let right_input = MockOperator::new(vec![create_int_chunk(&[2, 4])]);

    let mut join = HashJoinOperator::new(
        Box::new(left_input),
        Box::new(right_input),
        vec![0],
        vec![0],
        JoinType::Anti,
        vec![LogicalType::Int64],
    );

    let mut survivors = Vec::new();
    while let Some(chunk) = join.next().unwrap() {
        for row in chunk.selected_indices() {
            let column = chunk.column(0).unwrap();
            survivors.push(column.get_int64(row).unwrap());
        }
    }

    survivors.sort_unstable();
    assert_eq!(survivors, vec![1, 3]);
}
1174
#[test]
fn test_hash_key_from_map() {
    use grafeo_common::types::{PropertyKey, Value};
    use std::collections::BTreeMap;
    use std::sync::Arc;

    // Builds a fresh, independent single-entry map value on every call.
    let single_entry_map = || {
        let mut entries = BTreeMap::new();
        entries.insert(PropertyKey::new("key"), Value::Int64(42));
        Value::Map(Arc::new(entries))
    };

    // A map value is expected to turn into a composite key.
    let first = single_entry_map();
    assert!(matches!(
        HashKey::from_value(&first),
        HashKey::Composite(_)
    ));

    // A separately built map with equal content must yield an equal key.
    let second = single_entry_map();
    assert_eq!(HashKey::from_value(&first), HashKey::from_value(&second));
}
1194
#[test]
fn test_hash_key_from_map_empty() {
    use grafeo_common::types::Value;
    use std::collections::BTreeMap;
    use std::sync::Arc;

    // A map without entries is expected to become a composite with no parts.
    let empty_map = Value::Map(Arc::new(BTreeMap::new()));
    assert_eq!(
        HashKey::from_value(&empty_map),
        HashKey::Composite(Vec::new())
    );
}
1205
#[test]
fn test_hash_key_from_gcounter() {
    use grafeo_common::types::Value;
    use std::collections::HashMap;
    use std::sync::Arc;

    let per_node = HashMap::from([
        ("node-a".to_string(), 5u64),
        ("node-b".to_string(), 3u64),
    ]);
    let counter = Value::GCounter(Arc::new(per_node));

    // The expected key is the total over all entries: 5 + 3 = 8.
    assert_eq!(HashKey::from_value(&counter), HashKey::Int64(8));
}
1219
#[test]
fn test_hash_key_from_gcounter_empty() {
    use grafeo_common::types::Value;
    use std::collections::HashMap;
    use std::sync::Arc;

    // A counter without any entries is expected to produce the key 0.
    let empty_counter = Value::GCounter(Arc::new(HashMap::new()));
    let key = HashKey::from_value(&empty_counter);
    assert_eq!(key, HashKey::Int64(0));
}
1229
#[test]
fn test_hash_key_from_oncounter() {
    use grafeo_common::types::Value;
    use std::collections::HashMap;
    use std::sync::Arc;

    // Single-entry tally for "node-a" holding the given count.
    let tally_for_node_a = |count: u64| Arc::new(HashMap::from([("node-a".to_string(), count)]));

    let counter = Value::OnCounter {
        pos: tally_for_node_a(10),
        neg: tally_for_node_a(3),
    };

    // The expected key is the positive total minus the negative total: 10 - 3 = 7.
    assert_eq!(HashKey::from_value(&counter), HashKey::Int64(7));
}
1247
#[test]
fn test_hash_key_from_oncounter_balanced() {
    use grafeo_common::types::Value;
    use std::collections::HashMap;
    use std::sync::Arc;

    // Positive and negative sides both hold 5 for replica "r", each in its own map.
    let five_for_r = || Arc::new(HashMap::from([("r".to_string(), 5u64)]));

    let counter = Value::OnCounter {
        pos: five_for_r(),
        neg: five_for_r(),
    };

    // Equal totals on both sides are expected to cancel out to the key 0.
    assert_eq!(HashKey::from_value(&counter), HashKey::Int64(0));
}
1264}