Skip to main content

grafeo_core/execution/operators/
join.rs

1//! Join operators for combining data from two sources.
2//!
3//! This module provides:
4//! - `HashJoinOperator`: Efficient hash-based join for equality conditions
5//! - `NestedLoopJoinOperator`: General-purpose join for any condition
6
7use std::cmp::Ordering;
8use std::collections::HashMap;
9
10use arcstr::ArcStr;
11use grafeo_common::types::{LogicalType, Value};
12
13use super::{Operator, OperatorError, OperatorResult};
14use crate::execution::chunk::DataChunkBuilder;
15use crate::execution::{DataChunk, ValueVector};
16
/// Selects which row combinations a join operator emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum JoinType {
    /// Emit only the pairs where a left row and a right row match.
    Inner,
    /// Keep every left row; right columns are null when nothing matches.
    Left,
    /// Keep every right row; left columns are null when nothing matches.
    Right,
    /// Keep every row from both sides, null-padding the side without a match.
    Full,
    /// Pair every left row with every right row (cartesian product).
    Cross,
    /// Emit the left rows that have at least one match on the right.
    Semi,
    /// Emit the left rows that have no match on the right.
    Anti,
}
36
37/// A hash key that can be hashed and compared for join operations.
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39#[non_exhaustive]
40pub enum HashKey {
41    /// Null key.
42    Null,
43    /// Boolean key.
44    Bool(bool),
45    /// Integer key.
46    Int64(i64),
47    /// String key (cheap clone via ArcStr refcount).
48    String(ArcStr),
49    /// Byte content key.
50    Bytes(Vec<u8>),
51    /// Composite key for multi-column joins.
52    Composite(Vec<HashKey>),
53}
54
55impl Ord for HashKey {
56    fn cmp(&self, other: &Self) -> Ordering {
57        match (self, other) {
58            (HashKey::Null, HashKey::Null) => Ordering::Equal,
59            (HashKey::Null, _) => Ordering::Less,
60            (_, HashKey::Null) => Ordering::Greater,
61            (HashKey::Bool(a), HashKey::Bool(b)) => a.cmp(b),
62            (HashKey::Bool(_), _) => Ordering::Less,
63            (_, HashKey::Bool(_)) => Ordering::Greater,
64            (HashKey::Int64(a), HashKey::Int64(b)) => a.cmp(b),
65            (HashKey::Int64(_), _) => Ordering::Less,
66            (_, HashKey::Int64(_)) => Ordering::Greater,
67            (HashKey::String(a), HashKey::String(b)) => a.cmp(b),
68            (HashKey::String(_), _) => Ordering::Less,
69            (_, HashKey::String(_)) => Ordering::Greater,
70            (HashKey::Bytes(a), HashKey::Bytes(b)) => a.cmp(b),
71            (HashKey::Bytes(_), _) => Ordering::Less,
72            (_, HashKey::Bytes(_)) => Ordering::Greater,
73            (HashKey::Composite(a), HashKey::Composite(b)) => a.cmp(b),
74        }
75    }
76}
77
78impl PartialOrd for HashKey {
79    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
80        Some(self.cmp(other))
81    }
82}
83
84impl HashKey {
85    /// Creates a hash key from a Value.
86    pub fn from_value(value: &Value) -> Self {
87        match value {
88            Value::Null => HashKey::Null,
89            Value::Bool(b) => HashKey::Bool(*b),
90            Value::Int64(i) => HashKey::Int64(*i),
91            Value::Float64(f) => {
92                // Convert float to bits for consistent hashing
93                // reason: intentional bit-level reinterpretation for hashing
94                #[allow(clippy::cast_possible_wrap)]
95                HashKey::Int64(f.to_bits() as i64)
96            }
97            Value::String(s) => HashKey::String(s.clone()),
98            Value::Bytes(b) => HashKey::Bytes(b.to_vec()),
99            Value::Timestamp(t) => HashKey::Int64(t.as_micros()),
100            // reason: date days and time nanos fit i64
101            #[allow(clippy::cast_possible_wrap)]
102            Value::Date(d) => HashKey::Int64(d.as_days() as i64),
103            // reason: date days and time nanos are small, fit i64
104            #[allow(clippy::cast_possible_wrap)]
105            Value::Time(t) => HashKey::Int64(t.as_nanos() as i64),
106            Value::Duration(d) => HashKey::Composite(vec![
107                HashKey::Int64(d.months()),
108                HashKey::Int64(d.days()),
109                HashKey::Int64(d.nanos()),
110            ]),
111            Value::ZonedDatetime(zdt) => HashKey::Int64(zdt.as_timestamp().as_micros()),
112            Value::List(items) => {
113                HashKey::Composite(items.iter().map(HashKey::from_value).collect())
114            }
115            Value::Map(map) => {
116                // BTreeMap::iter() visits entries in ascending key order, so no sort needed.
117                let keys: Vec<_> = map
118                    .iter()
119                    .map(|(k, v)| {
120                        HashKey::Composite(vec![
121                            HashKey::String(ArcStr::from(k.as_str())),
122                            HashKey::from_value(v),
123                        ])
124                    })
125                    .collect();
126                HashKey::Composite(keys)
127            }
128            Value::Vector(v) => {
129                // Hash vectors by converting each f32 to its bit representation
130                HashKey::Composite(
131                    v.iter()
132                        .map(|f| HashKey::Int64(f.to_bits() as i64))
133                        .collect(),
134                )
135            }
136            Value::Path { nodes, edges } => {
137                let mut parts: Vec<_> = nodes.iter().map(HashKey::from_value).collect();
138                parts.extend(edges.iter().map(HashKey::from_value));
139                HashKey::Composite(parts)
140            }
141            // CRDT counters are opaque keys; hash by total logical value.
142            Value::GCounter(counts) => {
143                // reason: CRDT counter values are practically small, wrap is acceptable for hashing
144                #[allow(clippy::cast_possible_wrap)]
145                HashKey::Int64(counts.values().copied().map(|v| v as i64).sum())
146            }
147            Value::OnCounter { pos, neg } => {
148                // reason: CRDT counter values are practically small, wrap is acceptable for hashing
149                #[allow(clippy::cast_possible_wrap)]
150                // reason: GCounter values are small, sum fits i64
151                let p: i64 = pos.values().copied().map(|v| v as i64).sum();
152                // reason: GCounter values are small, sum fits i64
153                #[allow(clippy::cast_possible_wrap)]
154                let n: i64 = neg.values().copied().map(|v| v as i64).sum();
155                HashKey::Int64(p - n)
156            }
157            _ => HashKey::Null,
158        }
159    }
160
161    /// Creates a hash key from a column value at a given row.
162    pub fn from_column(column: &ValueVector, row: usize) -> Option<Self> {
163        column.get_value(row).map(|v| Self::from_value(&v))
164    }
165}
166
167/// Hash join operator.
168///
169/// Builds a hash table from the build side (right) and probes with the probe side (left).
170/// Efficient for equality joins on one or more columns.
171pub struct HashJoinOperator {
172    /// Left (probe) side operator.
173    probe_side: Box<dyn Operator>,
174    /// Right (build) side operator.
175    build_side: Box<dyn Operator>,
176    /// Column indices on the probe side for join keys.
177    probe_keys: Vec<usize>,
178    /// Column indices on the build side for join keys.
179    build_keys: Vec<usize>,
180    /// Join type.
181    join_type: JoinType,
182    /// Output schema (combined from both sides).
183    output_schema: Vec<LogicalType>,
184    /// Hash table: key -> list of (chunk_index, row_index).
185    hash_table: HashMap<HashKey, Vec<(usize, usize)>>,
186    /// Materialized build side chunks.
187    build_chunks: Vec<DataChunk>,
188    /// Whether the build phase is complete.
189    build_complete: bool,
190    /// Current probe chunk being processed.
191    current_probe_chunk: Option<DataChunk>,
192    /// Current row in the probe chunk.
193    current_probe_row: usize,
194    /// Current position in the hash table matches for the current probe row.
195    current_match_position: usize,
196    /// Current matches for the current probe row.
197    current_matches: Vec<(usize, usize)>,
198    /// For left/full outer joins: track which probe rows had matches.
199    probe_matched: Vec<bool>,
200    /// For right/full outer joins: track which build rows were matched.
201    build_matched: Vec<Vec<bool>>,
202    /// Whether we're in the emit unmatched phase (for outer joins).
203    emitting_unmatched: bool,
204    /// Current chunk index when emitting unmatched rows.
205    unmatched_chunk_idx: usize,
206    /// Current row index when emitting unmatched rows.
207    unmatched_row_idx: usize,
208}
209
210impl HashJoinOperator {
211    /// Creates a new hash join operator.
212    ///
213    /// # Arguments
214    /// * `probe_side` - Left side operator (will be probed).
215    /// * `build_side` - Right side operator (will build hash table).
216    /// * `probe_keys` - Column indices on probe side for join keys.
217    /// * `build_keys` - Column indices on build side for join keys.
218    /// * `join_type` - Type of join to perform.
219    /// * `output_schema` - Schema of the output (probe columns + build columns).
220    pub fn new(
221        probe_side: Box<dyn Operator>,
222        build_side: Box<dyn Operator>,
223        probe_keys: Vec<usize>,
224        build_keys: Vec<usize>,
225        join_type: JoinType,
226        output_schema: Vec<LogicalType>,
227    ) -> Self {
228        Self {
229            probe_side,
230            build_side,
231            probe_keys,
232            build_keys,
233            join_type,
234            output_schema,
235            hash_table: HashMap::new(),
236            build_chunks: Vec::new(),
237            build_complete: false,
238            current_probe_chunk: None,
239            current_probe_row: 0,
240            current_match_position: 0,
241            current_matches: Vec::new(),
242            probe_matched: Vec::new(),
243            build_matched: Vec::new(),
244            emitting_unmatched: false,
245            unmatched_chunk_idx: 0,
246            unmatched_row_idx: 0,
247        }
248    }
249
250    /// Builds the hash table from the build side.
251    fn build_hash_table(&mut self) -> Result<(), OperatorError> {
252        while let Some(chunk) = self.build_side.next()? {
253            let chunk_idx = self.build_chunks.len();
254
255            // Initialize match tracking for outer joins
256            if matches!(self.join_type, JoinType::Right | JoinType::Full) {
257                self.build_matched.push(vec![false; chunk.row_count()]);
258            }
259
260            // Add each row to the hash table
261            for row in chunk.selected_indices() {
262                let key = self.extract_key(&chunk, row, &self.build_keys)?;
263
264                // Skip null keys for inner/semi/anti joins
265                if matches!(key, HashKey::Null)
266                    && !matches!(
267                        self.join_type,
268                        JoinType::Left | JoinType::Right | JoinType::Full
269                    )
270                {
271                    continue;
272                }
273
274                self.hash_table
275                    .entry(key)
276                    .or_default()
277                    .push((chunk_idx, row));
278            }
279
280            self.build_chunks.push(chunk);
281        }
282
283        self.build_complete = true;
284        Ok(())
285    }
286
287    /// Extracts a hash key from a chunk row.
288    fn extract_key(
289        &self,
290        chunk: &DataChunk,
291        row: usize,
292        key_columns: &[usize],
293    ) -> Result<HashKey, OperatorError> {
294        if key_columns.len() == 1 {
295            let col = chunk.column(key_columns[0]).ok_or_else(|| {
296                OperatorError::ColumnNotFound(format!("column {}", key_columns[0]))
297            })?;
298            Ok(HashKey::from_column(col, row).unwrap_or(HashKey::Null))
299        } else {
300            let keys: Vec<HashKey> = key_columns
301                .iter()
302                .map(|&col_idx| {
303                    chunk
304                        .column(col_idx)
305                        .and_then(|col| HashKey::from_column(col, row))
306                        .unwrap_or(HashKey::Null)
307                })
308                .collect();
309            Ok(HashKey::Composite(keys))
310        }
311    }
312
313    /// Produces an output row from a probe row and build row.
314    fn produce_output_row(
315        &self,
316        builder: &mut DataChunkBuilder,
317        probe_chunk: &DataChunk,
318        probe_row: usize,
319        build_chunk: Option<&DataChunk>,
320        build_row: Option<usize>,
321    ) -> Result<(), OperatorError> {
322        let probe_col_count = probe_chunk.column_count();
323
324        // Copy probe side columns
325        for col_idx in 0..probe_col_count {
326            let src_col = probe_chunk
327                .column(col_idx)
328                .ok_or_else(|| OperatorError::ColumnNotFound(format!("probe column {col_idx}")))?;
329            let dst_col = builder
330                .column_mut(col_idx)
331                .ok_or_else(|| OperatorError::ColumnNotFound(format!("output column {col_idx}")))?;
332
333            if let Some(value) = src_col.get_value(probe_row) {
334                dst_col.push_value(value);
335            } else {
336                dst_col.push_value(Value::Null);
337            }
338        }
339
340        // Copy build side columns
341        match (build_chunk, build_row) {
342            (Some(chunk), Some(row)) => {
343                for col_idx in 0..chunk.column_count() {
344                    let src_col = chunk.column(col_idx).ok_or_else(|| {
345                        OperatorError::ColumnNotFound(format!("build column {col_idx}"))
346                    })?;
347                    let dst_col =
348                        builder
349                            .column_mut(probe_col_count + col_idx)
350                            .ok_or_else(|| {
351                                OperatorError::ColumnNotFound(format!(
352                                    "output column {}",
353                                    probe_col_count + col_idx
354                                ))
355                            })?;
356
357                    if let Some(value) = src_col.get_value(row) {
358                        dst_col.push_value(value);
359                    } else {
360                        dst_col.push_value(Value::Null);
361                    }
362                }
363            }
364            _ => {
365                // Emit nulls for build side (left outer join case)
366                if !self.build_chunks.is_empty() {
367                    let build_col_count = self.build_chunks[0].column_count();
368                    for col_idx in 0..build_col_count {
369                        let dst_col =
370                            builder
371                                .column_mut(probe_col_count + col_idx)
372                                .ok_or_else(|| {
373                                    OperatorError::ColumnNotFound(format!(
374                                        "output column {}",
375                                        probe_col_count + col_idx
376                                    ))
377                                })?;
378                        dst_col.push_value(Value::Null);
379                    }
380                }
381            }
382        }
383
384        builder.advance_row();
385        Ok(())
386    }
387
388    /// Gets the next probe chunk.
389    fn get_next_probe_chunk(&mut self) -> Result<bool, OperatorError> {
390        let chunk = self.probe_side.next()?;
391        if let Some(ref c) = chunk {
392            // Initialize match tracking for outer joins
393            if matches!(self.join_type, JoinType::Left | JoinType::Full) {
394                self.probe_matched = vec![false; c.row_count()];
395            }
396        }
397        let has_chunk = chunk.is_some();
398        self.current_probe_chunk = chunk;
399        self.current_probe_row = 0;
400        Ok(has_chunk)
401    }
402
403    /// Emits unmatched build rows for right/full outer joins.
404    fn emit_unmatched_build(&mut self) -> OperatorResult {
405        if self.build_matched.is_empty() {
406            return Ok(None);
407        }
408
409        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
410
411        // Determine probe column count from schema or first probe chunk
412        let probe_col_count = if !self.build_chunks.is_empty() {
413            self.output_schema.len() - self.build_chunks[0].column_count()
414        } else {
415            0
416        };
417
418        while self.unmatched_chunk_idx < self.build_chunks.len() {
419            let chunk = &self.build_chunks[self.unmatched_chunk_idx];
420            let matched = &self.build_matched[self.unmatched_chunk_idx];
421
422            while self.unmatched_row_idx < matched.len() {
423                if !matched[self.unmatched_row_idx] {
424                    // This row was not matched - emit with nulls on probe side
425
426                    // Emit nulls for probe side
427                    for col_idx in 0..probe_col_count {
428                        if let Some(dst_col) = builder.column_mut(col_idx) {
429                            dst_col.push_value(Value::Null);
430                        }
431                    }
432
433                    // Copy build side values
434                    for col_idx in 0..chunk.column_count() {
435                        if let (Some(src_col), Some(dst_col)) = (
436                            chunk.column(col_idx),
437                            builder.column_mut(probe_col_count + col_idx),
438                        ) {
439                            if let Some(value) = src_col.get_value(self.unmatched_row_idx) {
440                                dst_col.push_value(value);
441                            } else {
442                                dst_col.push_value(Value::Null);
443                            }
444                        }
445                    }
446
447                    builder.advance_row();
448
449                    if builder.is_full() {
450                        self.unmatched_row_idx += 1;
451                        return Ok(Some(builder.finish()));
452                    }
453                }
454
455                self.unmatched_row_idx += 1;
456            }
457
458            self.unmatched_chunk_idx += 1;
459            self.unmatched_row_idx = 0;
460        }
461
462        if builder.row_count() > 0 {
463            Ok(Some(builder.finish()))
464        } else {
465            Ok(None)
466        }
467    }
468}
469
470impl Operator for HashJoinOperator {
471    fn next(&mut self) -> OperatorResult {
472        // Phase 1: Build hash table
473        if !self.build_complete {
474            self.build_hash_table()?;
475        }
476
477        // Phase 3: Emit unmatched build rows (right/full outer join)
478        if self.emitting_unmatched {
479            return self.emit_unmatched_build();
480        }
481
482        // Phase 2: Probe
483        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
484
485        loop {
486            // Get current probe chunk or fetch new one
487            if self.current_probe_chunk.is_none() && !self.get_next_probe_chunk()? {
488                // No more probe data
489                if matches!(self.join_type, JoinType::Right | JoinType::Full) {
490                    self.emitting_unmatched = true;
491                    return self.emit_unmatched_build();
492                }
493                return if builder.row_count() > 0 {
494                    Ok(Some(builder.finish()))
495                } else {
496                    Ok(None)
497                };
498            }
499
500            // Invariant: current_probe_chunk is Some here - the guard at line 396 either
501            // populates it via get_next_probe_chunk() or returns from the function
502            let probe_chunk = self
503                .current_probe_chunk
504                .as_ref()
505                .expect("probe chunk is Some: guard at line 396 ensures this");
506            let probe_rows: Vec<usize> = probe_chunk.selected_indices().collect();
507
508            while self.current_probe_row < probe_rows.len() {
509                let probe_row = probe_rows[self.current_probe_row];
510
511                // If we don't have current matches, look them up
512                if self.current_matches.is_empty() && self.current_match_position == 0 {
513                    let key = self.extract_key(probe_chunk, probe_row, &self.probe_keys)?;
514
515                    // Handle semi/anti joins differently
516                    match self.join_type {
517                        JoinType::Semi => {
518                            if self.hash_table.contains_key(&key) {
519                                // Emit probe row only
520                                for col_idx in 0..probe_chunk.column_count() {
521                                    if let (Some(src_col), Some(dst_col)) =
522                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
523                                        && let Some(value) = src_col.get_value(probe_row)
524                                    {
525                                        dst_col.push_value(value);
526                                    }
527                                }
528                                builder.advance_row();
529                            }
530                            self.current_probe_row += 1;
531                            continue;
532                        }
533                        JoinType::Anti => {
534                            if !self.hash_table.contains_key(&key) {
535                                // Emit probe row only
536                                for col_idx in 0..probe_chunk.column_count() {
537                                    if let (Some(src_col), Some(dst_col)) =
538                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
539                                        && let Some(value) = src_col.get_value(probe_row)
540                                    {
541                                        dst_col.push_value(value);
542                                    }
543                                }
544                                builder.advance_row();
545                            }
546                            self.current_probe_row += 1;
547                            continue;
548                        }
549                        _ => {
550                            self.current_matches =
551                                self.hash_table.get(&key).cloned().unwrap_or_default();
552                        }
553                    }
554                }
555
556                // Process matches
557                if self.current_matches.is_empty() {
558                    // No matches - for left/full outer join, emit with nulls
559                    if matches!(self.join_type, JoinType::Left | JoinType::Full) {
560                        self.produce_output_row(&mut builder, probe_chunk, probe_row, None, None)?;
561                    }
562                    self.current_probe_row += 1;
563                    self.current_match_position = 0;
564                } else {
565                    // Process each match
566                    while self.current_match_position < self.current_matches.len() {
567                        let (build_chunk_idx, build_row) =
568                            self.current_matches[self.current_match_position];
569                        let build_chunk = &self.build_chunks[build_chunk_idx];
570
571                        // Mark as matched for outer joins
572                        if matches!(self.join_type, JoinType::Left | JoinType::Full)
573                            && probe_row < self.probe_matched.len()
574                        {
575                            self.probe_matched[probe_row] = true;
576                        }
577                        if matches!(self.join_type, JoinType::Right | JoinType::Full)
578                            && build_chunk_idx < self.build_matched.len()
579                            && build_row < self.build_matched[build_chunk_idx].len()
580                        {
581                            self.build_matched[build_chunk_idx][build_row] = true;
582                        }
583
584                        self.produce_output_row(
585                            &mut builder,
586                            probe_chunk,
587                            probe_row,
588                            Some(build_chunk),
589                            Some(build_row),
590                        )?;
591
592                        self.current_match_position += 1;
593
594                        if builder.is_full() {
595                            return Ok(Some(builder.finish()));
596                        }
597                    }
598
599                    // Done with this probe row
600                    self.current_probe_row += 1;
601                    self.current_matches.clear();
602                    self.current_match_position = 0;
603                }
604
605                if builder.is_full() {
606                    return Ok(Some(builder.finish()));
607                }
608            }
609
610            // Done with current probe chunk
611            self.current_probe_chunk = None;
612            self.current_probe_row = 0;
613
614            if builder.row_count() > 0 {
615                return Ok(Some(builder.finish()));
616            }
617        }
618    }
619
620    fn reset(&mut self) {
621        self.probe_side.reset();
622        self.build_side.reset();
623        self.hash_table.clear();
624        self.build_chunks.clear();
625        self.build_complete = false;
626        self.current_probe_chunk = None;
627        self.current_probe_row = 0;
628        self.current_match_position = 0;
629        self.current_matches.clear();
630        self.probe_matched.clear();
631        self.build_matched.clear();
632        self.emitting_unmatched = false;
633        self.unmatched_chunk_idx = 0;
634        self.unmatched_row_idx = 0;
635    }
636
637    fn name(&self) -> &'static str {
638        "HashJoin"
639    }
640
641    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
642        self
643    }
644}
645
646/// Nested loop join operator.
647///
648/// Performs a cartesian product of both sides, filtering by the join condition.
649/// Less efficient than hash join but supports any join condition.
650pub struct NestedLoopJoinOperator {
651    /// Left side operator.
652    left: Box<dyn Operator>,
653    /// Right side operator.
654    right: Box<dyn Operator>,
655    /// Join condition predicate (if any).
656    condition: Option<Box<dyn JoinCondition>>,
657    /// Join type.
658    join_type: JoinType,
659    /// Output schema.
660    output_schema: Vec<LogicalType>,
661    /// Materialized right side chunks.
662    right_chunks: Vec<DataChunk>,
663    /// Whether the right side is materialized.
664    right_materialized: bool,
665    /// Current left chunk.
666    current_left_chunk: Option<DataChunk>,
667    /// Current row in the left chunk.
668    current_left_row: usize,
669    /// Current chunk index in the right side.
670    current_right_chunk: usize,
671    /// Whether the current left row has been matched (for Left Join).
672    current_left_matched: bool,
673    /// Current row in the current right chunk.
674    current_right_row: usize,
675}
676
677/// Trait for join conditions.
678pub trait JoinCondition: Send + Sync {
679    /// Evaluates the condition for a pair of rows.
680    fn evaluate(
681        &self,
682        left_chunk: &DataChunk,
683        left_row: usize,
684        right_chunk: &DataChunk,
685        right_row: usize,
686    ) -> bool;
687}
688
/// Join predicate that holds when one left column equals one right column.
pub struct EqualityCondition {
    /// Index of the compared column in the left chunk.
    left_column: usize,
    /// Index of the compared column in the right chunk.
    right_column: usize,
}
696
697impl EqualityCondition {
698    /// Creates a new equality condition.
699    pub fn new(left_column: usize, right_column: usize) -> Self {
700        Self {
701            left_column,
702            right_column,
703        }
704    }
705}
706
707impl JoinCondition for EqualityCondition {
708    fn evaluate(
709        &self,
710        left_chunk: &DataChunk,
711        left_row: usize,
712        right_chunk: &DataChunk,
713        right_row: usize,
714    ) -> bool {
715        let left_val = left_chunk
716            .column(self.left_column)
717            .and_then(|c| c.get_value(left_row));
718        let right_val = right_chunk
719            .column(self.right_column)
720            .and_then(|c| c.get_value(right_row));
721
722        match (left_val, right_val) {
723            (Some(l), Some(r)) => l == r,
724            _ => false,
725        }
726    }
727}
728
729impl NestedLoopJoinOperator {
730    /// Creates a new nested loop join operator.
731    pub fn new(
732        left: Box<dyn Operator>,
733        right: Box<dyn Operator>,
734        condition: Option<Box<dyn JoinCondition>>,
735        join_type: JoinType,
736        output_schema: Vec<LogicalType>,
737    ) -> Self {
738        Self {
739            left,
740            right,
741            condition,
742            join_type,
743            output_schema,
744            right_chunks: Vec::new(),
745            right_materialized: false,
746            current_left_chunk: None,
747            current_left_row: 0,
748            current_right_chunk: 0,
749            current_right_row: 0,
750            current_left_matched: false,
751        }
752    }
753
754    /// Materializes the right side.
755    fn materialize_right(&mut self) -> Result<(), OperatorError> {
756        while let Some(chunk) = self.right.next()? {
757            self.right_chunks.push(chunk);
758        }
759        self.right_materialized = true;
760        Ok(())
761    }
762
763    /// Produces an output row.
764    fn produce_row(
765        &self,
766        builder: &mut DataChunkBuilder,
767        left_chunk: &DataChunk,
768        left_row: usize,
769        right_chunk: &DataChunk,
770        right_row: usize,
771    ) {
772        // Copy left columns
773        for col_idx in 0..left_chunk.column_count() {
774            if let (Some(src), Some(dst)) =
775                (left_chunk.column(col_idx), builder.column_mut(col_idx))
776            {
777                if let Some(val) = src.get_value(left_row) {
778                    dst.push_value(val);
779                } else {
780                    dst.push_value(Value::Null);
781                }
782            }
783        }
784
785        // Copy right columns
786        let left_col_count = left_chunk.column_count();
787        for col_idx in 0..right_chunk.column_count() {
788            if let (Some(src), Some(dst)) = (
789                right_chunk.column(col_idx),
790                builder.column_mut(left_col_count + col_idx),
791            ) {
792                if let Some(val) = src.get_value(right_row) {
793                    dst.push_value(val);
794                } else {
795                    dst.push_value(Value::Null);
796                }
797            }
798        }
799
800        builder.advance_row();
801    }
802
803    /// Produces an output row with NULLs for the right side (for unmatched left rows in Left Join).
804    fn produce_left_unmatched_row(
805        &self,
806        builder: &mut DataChunkBuilder,
807        left_chunk: &DataChunk,
808        left_row: usize,
809        right_col_count: usize,
810    ) {
811        // Copy left columns
812        for col_idx in 0..left_chunk.column_count() {
813            if let (Some(src), Some(dst)) =
814                (left_chunk.column(col_idx), builder.column_mut(col_idx))
815            {
816                if let Some(val) = src.get_value(left_row) {
817                    dst.push_value(val);
818                } else {
819                    dst.push_value(Value::Null);
820                }
821            }
822        }
823
824        // Fill right columns with NULLs
825        let left_col_count = left_chunk.column_count();
826        for col_idx in 0..right_col_count {
827            if let Some(dst) = builder.column_mut(left_col_count + col_idx) {
828                dst.push_value(Value::Null);
829            }
830        }
831
832        builder.advance_row();
833    }
834}
835
836impl Operator for NestedLoopJoinOperator {
837    fn next(&mut self) -> OperatorResult {
838        // Materialize right side
839        if !self.right_materialized {
840            self.materialize_right()?;
841        }
842
843        // If right side is empty and not a left outer join, return nothing
844        if self.right_chunks.is_empty() && !matches!(self.join_type, JoinType::Left) {
845            return Ok(None);
846        }
847
848        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
849
850        loop {
851            // Get current left chunk
852            if self.current_left_chunk.is_none() {
853                self.current_left_chunk = self.left.next()?;
854                self.current_left_row = 0;
855                self.current_right_chunk = 0;
856                self.current_right_row = 0;
857
858                if self.current_left_chunk.is_none() {
859                    // No more left data
860                    return if builder.row_count() > 0 {
861                        Ok(Some(builder.finish()))
862                    } else {
863                        Ok(None)
864                    };
865                }
866            }
867
868            let left_chunk = self
869                .current_left_chunk
870                .as_ref()
871                .expect("left chunk is Some: loaded in loop above");
872            let left_rows: Vec<usize> = left_chunk.selected_indices().collect();
873
874            // Calculate right column count for potential unmatched rows
875            let right_col_count = if !self.right_chunks.is_empty() {
876                self.right_chunks[0].column_count()
877            } else {
878                // Infer from output schema
879                self.output_schema
880                    .len()
881                    .saturating_sub(left_chunk.column_count())
882            };
883
884            // Process current left row against all right rows
885            while self.current_left_row < left_rows.len() {
886                let left_row = left_rows[self.current_left_row];
887
888                // Reset match tracking for this left row
889                if self.current_right_chunk == 0 && self.current_right_row == 0 {
890                    self.current_left_matched = false;
891                }
892
893                // Cross join or inner/other join
894                while self.current_right_chunk < self.right_chunks.len() {
895                    let right_chunk = &self.right_chunks[self.current_right_chunk];
896                    let right_rows: Vec<usize> = right_chunk.selected_indices().collect();
897
898                    while self.current_right_row < right_rows.len() {
899                        let right_row = right_rows[self.current_right_row];
900
901                        // Check condition
902                        let matches = match &self.condition {
903                            Some(cond) => {
904                                cond.evaluate(left_chunk, left_row, right_chunk, right_row)
905                            }
906                            None => true, // Cross join
907                        };
908
909                        if matches {
910                            self.current_left_matched = true;
911                            self.produce_row(
912                                &mut builder,
913                                left_chunk,
914                                left_row,
915                                right_chunk,
916                                right_row,
917                            );
918
919                            if builder.is_full() {
920                                self.current_right_row += 1;
921                                return Ok(Some(builder.finish()));
922                            }
923                        }
924
925                        self.current_right_row += 1;
926                    }
927
928                    self.current_right_chunk += 1;
929                    self.current_right_row = 0;
930                }
931
932                // Done processing all right rows for this left row
933                // For Left Join, emit unmatched left row with NULLs
934                if matches!(self.join_type, JoinType::Left) && !self.current_left_matched {
935                    self.produce_left_unmatched_row(
936                        &mut builder,
937                        left_chunk,
938                        left_row,
939                        right_col_count,
940                    );
941
942                    if builder.is_full() {
943                        self.current_left_row += 1;
944                        self.current_right_chunk = 0;
945                        self.current_right_row = 0;
946                        return Ok(Some(builder.finish()));
947                    }
948                }
949
950                // Move to next left row
951                self.current_left_row += 1;
952                self.current_right_chunk = 0;
953                self.current_right_row = 0;
954            }
955
956            // Done with current left chunk
957            self.current_left_chunk = None;
958
959            if builder.row_count() > 0 {
960                return Ok(Some(builder.finish()));
961            }
962        }
963    }
964
965    fn reset(&mut self) {
966        self.left.reset();
967        self.right.reset();
968        self.right_chunks.clear();
969        self.right_materialized = false;
970        self.current_left_chunk = None;
971        self.current_left_row = 0;
972        self.current_right_chunk = 0;
973        self.current_right_row = 0;
974        self.current_left_matched = false;
975    }
976
977    fn name(&self) -> &'static str {
978        "NestedLoopJoin"
979    }
980
981    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
982        self
983    }
984}
985
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test double that replays a fixed list of chunks.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        /// Hands out the chunk at `position`, leaving `DataChunk::empty()` in its slot.
        fn next(&mut self) -> OperatorResult {
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    self.position += 1;
                    Ok(Some(std::mem::replace(slot, DataChunk::empty())))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }

        fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
            self
        }
    }

    /// Builds a single-column Int64 chunk holding `values` in order.
    fn create_int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for value in values.iter().copied() {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Boxes a mock source that yields `values` as one single-column Int64 chunk.
    fn int_source(values: &[i64]) -> Box<MockOperator> {
        Box::new(MockOperator::new(vec![create_int_chunk(values)]))
    }

    /// Drains `op`, reading column 0 as a required Int64 and column 1 as a
    /// nullable Int64 for every selected row.
    fn drain_pairs(op: &mut impl Operator) -> Vec<(i64, Option<i64>)> {
        let mut rows = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                let left_val = chunk.column(0).unwrap().get_int64(row).unwrap();
                let right_val = chunk.column(1).unwrap().get_int64(row);
                rows.push((left_val, right_val));
            }
        }
        rows
    }

    /// Drains `op`, reading column 0 as a required Int64 for every selected row.
    fn drain_first_column(op: &mut impl Operator) -> Vec<i64> {
        let mut values = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                values.push(chunk.column(0).unwrap().get_int64(row).unwrap());
            }
        }
        values
    }

    #[test]
    fn test_hash_join_inner() {
        // Keys 2, 3 and 4 appear on both sides; 1 and 5 have no partner.
        let mut join = HashJoinOperator::new(
            int_source(&[1, 2, 3, 4]),
            int_source(&[2, 3, 4, 5]),
            vec![0],
            vec![0],
            JoinType::Inner,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut results: Vec<(i64, i64)> = drain_pairs(&mut join)
            .into_iter()
            .map(|(left_val, right_val)| (left_val, right_val.unwrap()))
            .collect();

        results.sort_unstable();
        assert_eq!(results, vec![(2, 2), (3, 3), (4, 4)]);
    }

    #[test]
    fn test_hash_join_left_outer() {
        // Left key 1 has no partner and must come back with a null right side.
        let mut join = HashJoinOperator::new(
            int_source(&[1, 2, 3]),
            int_source(&[2, 3]),
            vec![0],
            vec![0],
            JoinType::Left,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut results = drain_pairs(&mut join);

        results.sort_by_key(|&(left_val, _)| left_val);
        assert_eq!(results, vec![(1, None), (2, Some(2)), (3, Some(3))]);
    }

    #[test]
    fn test_nested_loop_cross_join() {
        // No condition: every left row pairs with every right row.
        let mut join = NestedLoopJoinOperator::new(
            int_source(&[1, 2]),
            int_source(&[10, 20]),
            None,
            JoinType::Cross,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut results: Vec<(i64, i64)> = drain_pairs(&mut join)
            .into_iter()
            .map(|(left_val, right_val)| (left_val, right_val.unwrap()))
            .collect();

        results.sort_unstable();
        assert_eq!(results, vec![(1, 10), (1, 20), (2, 10), (2, 20)]);
    }

    #[test]
    fn test_hash_join_semi() {
        // Only left rows with a partner survive; the output carries just the
        // probe (left) column.
        let mut join = HashJoinOperator::new(
            int_source(&[1, 2, 3, 4]),
            int_source(&[2, 4]),
            vec![0],
            vec![0],
            JoinType::Semi,
            vec![LogicalType::Int64],
        );

        let mut results = drain_first_column(&mut join);

        results.sort_unstable();
        assert_eq!(results, vec![2, 4]);
    }

    #[test]
    fn test_hash_join_anti() {
        // Only left rows without a partner survive.
        let mut join = HashJoinOperator::new(
            int_source(&[1, 2, 3, 4]),
            int_source(&[2, 4]),
            vec![0],
            vec![0],
            JoinType::Anti,
            vec![LogicalType::Int64],
        );

        let mut results = drain_first_column(&mut join);

        results.sort_unstable();
        assert_eq!(results, vec![1, 3]);
    }

    #[test]
    fn test_hash_key_from_map() {
        use grafeo_common::types::{PropertyKey, Value};
        use std::collections::BTreeMap;
        use std::sync::Arc;

        let single_entry_map = || {
            Value::Map(Arc::new(BTreeMap::from([(
                PropertyKey::new("key"),
                Value::Int64(42),
            )])))
        };

        let v = single_entry_map();
        // BTreeMap iterates in ascending key order, result is a Composite
        assert!(matches!(HashKey::from_value(&v), HashKey::Composite(_)));

        // Two maps with the same content produce the same hash key
        let v2 = single_entry_map();
        assert_eq!(HashKey::from_value(&v), HashKey::from_value(&v2));
    }

    #[test]
    fn test_hash_key_from_map_empty() {
        use grafeo_common::types::Value;
        use std::collections::BTreeMap;
        use std::sync::Arc;

        let v = Value::Map(Arc::new(BTreeMap::new()));
        assert_eq!(HashKey::from_value(&v), HashKey::Composite(vec![]));
    }

    #[test]
    fn test_hash_key_from_gcounter() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let counts = HashMap::from([
            ("node-a".to_string(), 5u64),
            ("node-b".to_string(), 3u64),
        ]);
        let v = Value::GCounter(Arc::new(counts));
        // GCounter hashes to sum of all values (5 + 3 = 8)
        assert_eq!(HashKey::from_value(&v), HashKey::Int64(8));
    }

    #[test]
    fn test_hash_key_from_gcounter_empty() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let v = Value::GCounter(Arc::new(HashMap::new()));
        assert_eq!(HashKey::from_value(&v), HashKey::Int64(0));
    }

    #[test]
    fn test_hash_key_from_oncounter() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let v = Value::OnCounter {
            pos: Arc::new(HashMap::from([("node-a".to_string(), 10u64)])),
            neg: Arc::new(HashMap::from([("node-a".to_string(), 3u64)])),
        };
        // OnCounter hashes to pos_sum - neg_sum = 10 - 3 = 7
        assert_eq!(HashKey::from_value(&v), HashKey::Int64(7));
    }

    #[test]
    fn test_hash_key_from_oncounter_balanced() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let v = Value::OnCounter {
            pos: Arc::new(HashMap::from([("r".to_string(), 5u64)])),
            neg: Arc::new(HashMap::from([("r".to_string(), 5u64)])),
        };
        assert_eq!(HashKey::from_value(&v), HashKey::Int64(0));
    }

    #[test]
    fn test_hash_join_into_any() {
        let op = HashJoinOperator::new(
            Box::new(MockOperator::new(vec![])),
            Box::new(MockOperator::new(vec![])),
            vec![0],
            vec![0],
            JoinType::Inner,
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        let any = Box::new(op).into_any();
        assert!(any.downcast::<HashJoinOperator>().is_ok());
    }

    #[test]
    fn test_nested_loop_join_into_any() {
        let op = NestedLoopJoinOperator::new(
            Box::new(MockOperator::new(vec![])),
            Box::new(MockOperator::new(vec![])),
            None,
            JoinType::Cross,
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        let any = Box::new(op).into_any();
        assert!(any.downcast::<NestedLoopJoinOperator>().is_ok());
    }

    #[test]
    fn test_hash_key_ord_same_variant() {
        use std::cmp::Ordering;

        let cases = [
            (HashKey::Null, HashKey::Null, Ordering::Equal),
            (HashKey::Bool(false), HashKey::Bool(true), Ordering::Less),
            (HashKey::Bool(true), HashKey::Bool(false), Ordering::Greater),
            (HashKey::Int64(1), HashKey::Int64(2), Ordering::Less),
            (HashKey::Int64(5), HashKey::Int64(5), Ordering::Equal),
            (
                HashKey::String(arcstr::literal!("a")),
                HashKey::String(arcstr::literal!("b")),
                Ordering::Less,
            ),
            (
                HashKey::Bytes(vec![1, 2]),
                HashKey::Bytes(vec![1, 3]),
                Ordering::Less,
            ),
            (
                HashKey::Composite(vec![HashKey::Int64(1)]),
                HashKey::Composite(vec![HashKey::Int64(2)]),
                Ordering::Less,
            ),
        ];

        for (lhs, rhs, expected) in cases {
            assert_eq!(lhs.cmp(&rhs), expected);
        }
    }

    #[test]
    fn test_hash_key_ord_cross_variant() {
        use std::cmp::Ordering;

        // Null < Bool < Int64 < String < Bytes < Composite
        let cases = [
            (HashKey::Null, HashKey::Bool(false), Ordering::Less),
            (HashKey::Bool(true), HashKey::Null, Ordering::Greater),
            (HashKey::Bool(false), HashKey::Int64(0), Ordering::Less),
            (HashKey::Int64(0), HashKey::Bool(false), Ordering::Greater),
            (
                HashKey::Int64(0),
                HashKey::String(arcstr::literal!("a")),
                Ordering::Less,
            ),
            (
                HashKey::String(arcstr::literal!("a")),
                HashKey::Int64(0),
                Ordering::Greater,
            ),
            (
                HashKey::String(arcstr::literal!("a")),
                HashKey::Bytes(vec![1]),
                Ordering::Less,
            ),
            (
                HashKey::Bytes(vec![1]),
                HashKey::String(arcstr::literal!("a")),
                Ordering::Greater,
            ),
            (
                HashKey::Bytes(vec![1]),
                HashKey::Composite(vec![]),
                Ordering::Less,
            ),
            (
                HashKey::Composite(vec![]),
                HashKey::Bytes(vec![1]),
                Ordering::Greater,
            ),
        ];

        for (lhs, rhs, expected) in cases {
            assert_eq!(lhs.cmp(&rhs), expected);
        }
    }

    #[test]
    fn test_hash_key_partial_ord() {
        // PartialOrd delegates to Ord, just verify it returns Some
        let pairs = [
            (HashKey::Null, HashKey::Int64(1)),
            (HashKey::Int64(1), HashKey::Int64(2)),
        ];

        for (lhs, rhs) in pairs {
            assert!(lhs.partial_cmp(&rhs).is_some());
        }
    }
}