Skip to main content

grafeo_core/execution/operators/
join.rs

1//! Join operators for combining data from two sources.
2//!
3//! This module provides:
4//! - `HashJoinOperator`: Efficient hash-based join for equality conditions
5//! - `NestedLoopJoinOperator`: General-purpose join for any condition
6
7use std::cmp::Ordering;
8use std::collections::HashMap;
9
10use arcstr::ArcStr;
11use grafeo_common::types::{LogicalType, Value};
12
13use super::{Operator, OperatorError, OperatorResult};
14use crate::execution::chunk::DataChunkBuilder;
15use crate::execution::{DataChunk, ValueVector};
16
/// Selects which combination of matched and unmatched rows a join emits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Keeps only the row pairs that match on both sides.
    Inner,
    /// Keeps every left row; the right side is null where nothing matched.
    Left,
    /// Keeps every right row; the left side is null where nothing matched.
    Right,
    /// Keeps every row from both sides, null-padding whichever side is missing.
    Full,
    /// Cartesian product of the two sides.
    Cross,
    /// Keeps the left rows that have at least one match on the right.
    Semi,
    /// Keeps the left rows that have no match on the right.
    Anti,
}
35
36/// A hash key that can be hashed and compared for join operations.
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum HashKey {
39    /// Null key.
40    Null,
41    /// Boolean key.
42    Bool(bool),
43    /// Integer key.
44    Int64(i64),
45    /// String key (cheap clone via ArcStr refcount).
46    String(ArcStr),
47    /// Byte content key.
48    Bytes(Vec<u8>),
49    /// Composite key for multi-column joins.
50    Composite(Vec<HashKey>),
51}
52
53impl Ord for HashKey {
54    fn cmp(&self, other: &Self) -> Ordering {
55        match (self, other) {
56            (HashKey::Null, HashKey::Null) => Ordering::Equal,
57            (HashKey::Null, _) => Ordering::Less,
58            (_, HashKey::Null) => Ordering::Greater,
59            (HashKey::Bool(a), HashKey::Bool(b)) => a.cmp(b),
60            (HashKey::Bool(_), _) => Ordering::Less,
61            (_, HashKey::Bool(_)) => Ordering::Greater,
62            (HashKey::Int64(a), HashKey::Int64(b)) => a.cmp(b),
63            (HashKey::Int64(_), _) => Ordering::Less,
64            (_, HashKey::Int64(_)) => Ordering::Greater,
65            (HashKey::String(a), HashKey::String(b)) => a.cmp(b),
66            (HashKey::String(_), _) => Ordering::Less,
67            (_, HashKey::String(_)) => Ordering::Greater,
68            (HashKey::Bytes(a), HashKey::Bytes(b)) => a.cmp(b),
69            (HashKey::Bytes(_), _) => Ordering::Less,
70            (_, HashKey::Bytes(_)) => Ordering::Greater,
71            (HashKey::Composite(a), HashKey::Composite(b)) => a.cmp(b),
72        }
73    }
74}
75
76impl PartialOrd for HashKey {
77    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
78        Some(self.cmp(other))
79    }
80}
81
82impl HashKey {
83    /// Creates a hash key from a Value.
84    pub fn from_value(value: &Value) -> Self {
85        match value {
86            Value::Null => HashKey::Null,
87            Value::Bool(b) => HashKey::Bool(*b),
88            Value::Int64(i) => HashKey::Int64(*i),
89            Value::Float64(f) => {
90                // Convert float to bits for consistent hashing
91                HashKey::Int64(f.to_bits() as i64)
92            }
93            Value::String(s) => HashKey::String(s.clone()),
94            Value::Bytes(b) => HashKey::Bytes(b.to_vec()),
95            Value::Timestamp(t) => HashKey::Int64(t.as_micros()),
96            Value::Date(d) => HashKey::Int64(d.as_days() as i64),
97            Value::Time(t) => HashKey::Int64(t.as_nanos() as i64),
98            Value::Duration(d) => HashKey::Composite(vec![
99                HashKey::Int64(d.months()),
100                HashKey::Int64(d.days()),
101                HashKey::Int64(d.nanos()),
102            ]),
103            Value::ZonedDatetime(zdt) => HashKey::Int64(zdt.as_timestamp().as_micros()),
104            Value::List(items) => {
105                HashKey::Composite(items.iter().map(HashKey::from_value).collect())
106            }
107            Value::Map(map) => {
108                // BTreeMap::iter() visits entries in ascending key order, so no sort needed.
109                let keys: Vec<_> = map
110                    .iter()
111                    .map(|(k, v)| {
112                        HashKey::Composite(vec![
113                            HashKey::String(ArcStr::from(k.as_str())),
114                            HashKey::from_value(v),
115                        ])
116                    })
117                    .collect();
118                HashKey::Composite(keys)
119            }
120            Value::Vector(v) => {
121                // Hash vectors by converting each f32 to its bit representation
122                HashKey::Composite(
123                    v.iter()
124                        .map(|f| HashKey::Int64(f.to_bits() as i64))
125                        .collect(),
126                )
127            }
128            Value::Path { nodes, edges } => {
129                let mut parts: Vec<_> = nodes.iter().map(HashKey::from_value).collect();
130                parts.extend(edges.iter().map(HashKey::from_value));
131                HashKey::Composite(parts)
132            }
133            // CRDT counters are opaque keys; hash by total logical value.
134            Value::GCounter(counts) => {
135                HashKey::Int64(counts.values().copied().map(|v| v as i64).sum())
136            }
137            Value::OnCounter { pos, neg } => {
138                let p: i64 = pos.values().copied().map(|v| v as i64).sum();
139                let n: i64 = neg.values().copied().map(|v| v as i64).sum();
140                HashKey::Int64(p - n)
141            }
142        }
143    }
144
145    /// Creates a hash key from a column value at a given row.
146    pub fn from_column(column: &ValueVector, row: usize) -> Option<Self> {
147        column.get_value(row).map(|v| Self::from_value(&v))
148    }
149}
150
151/// Hash join operator.
152///
153/// Builds a hash table from the build side (right) and probes with the probe side (left).
154/// Efficient for equality joins on one or more columns.
155pub struct HashJoinOperator {
156    /// Left (probe) side operator.
157    probe_side: Box<dyn Operator>,
158    /// Right (build) side operator.
159    build_side: Box<dyn Operator>,
160    /// Column indices on the probe side for join keys.
161    probe_keys: Vec<usize>,
162    /// Column indices on the build side for join keys.
163    build_keys: Vec<usize>,
164    /// Join type.
165    join_type: JoinType,
166    /// Output schema (combined from both sides).
167    output_schema: Vec<LogicalType>,
168    /// Hash table: key -> list of (chunk_index, row_index).
169    hash_table: HashMap<HashKey, Vec<(usize, usize)>>,
170    /// Materialized build side chunks.
171    build_chunks: Vec<DataChunk>,
172    /// Whether the build phase is complete.
173    build_complete: bool,
174    /// Current probe chunk being processed.
175    current_probe_chunk: Option<DataChunk>,
176    /// Current row in the probe chunk.
177    current_probe_row: usize,
178    /// Current position in the hash table matches for the current probe row.
179    current_match_position: usize,
180    /// Current matches for the current probe row.
181    current_matches: Vec<(usize, usize)>,
182    /// For left/full outer joins: track which probe rows had matches.
183    probe_matched: Vec<bool>,
184    /// For right/full outer joins: track which build rows were matched.
185    build_matched: Vec<Vec<bool>>,
186    /// Whether we're in the emit unmatched phase (for outer joins).
187    emitting_unmatched: bool,
188    /// Current chunk index when emitting unmatched rows.
189    unmatched_chunk_idx: usize,
190    /// Current row index when emitting unmatched rows.
191    unmatched_row_idx: usize,
192}
193
194impl HashJoinOperator {
195    /// Creates a new hash join operator.
196    ///
197    /// # Arguments
198    /// * `probe_side` - Left side operator (will be probed).
199    /// * `build_side` - Right side operator (will build hash table).
200    /// * `probe_keys` - Column indices on probe side for join keys.
201    /// * `build_keys` - Column indices on build side for join keys.
202    /// * `join_type` - Type of join to perform.
203    /// * `output_schema` - Schema of the output (probe columns + build columns).
204    pub fn new(
205        probe_side: Box<dyn Operator>,
206        build_side: Box<dyn Operator>,
207        probe_keys: Vec<usize>,
208        build_keys: Vec<usize>,
209        join_type: JoinType,
210        output_schema: Vec<LogicalType>,
211    ) -> Self {
212        Self {
213            probe_side,
214            build_side,
215            probe_keys,
216            build_keys,
217            join_type,
218            output_schema,
219            hash_table: HashMap::new(),
220            build_chunks: Vec::new(),
221            build_complete: false,
222            current_probe_chunk: None,
223            current_probe_row: 0,
224            current_match_position: 0,
225            current_matches: Vec::new(),
226            probe_matched: Vec::new(),
227            build_matched: Vec::new(),
228            emitting_unmatched: false,
229            unmatched_chunk_idx: 0,
230            unmatched_row_idx: 0,
231        }
232    }
233
234    /// Builds the hash table from the build side.
235    fn build_hash_table(&mut self) -> Result<(), OperatorError> {
236        while let Some(chunk) = self.build_side.next()? {
237            let chunk_idx = self.build_chunks.len();
238
239            // Initialize match tracking for outer joins
240            if matches!(self.join_type, JoinType::Right | JoinType::Full) {
241                self.build_matched.push(vec![false; chunk.row_count()]);
242            }
243
244            // Add each row to the hash table
245            for row in chunk.selected_indices() {
246                let key = self.extract_key(&chunk, row, &self.build_keys)?;
247
248                // Skip null keys for inner/semi/anti joins
249                if matches!(key, HashKey::Null)
250                    && !matches!(
251                        self.join_type,
252                        JoinType::Left | JoinType::Right | JoinType::Full
253                    )
254                {
255                    continue;
256                }
257
258                self.hash_table
259                    .entry(key)
260                    .or_default()
261                    .push((chunk_idx, row));
262            }
263
264            self.build_chunks.push(chunk);
265        }
266
267        self.build_complete = true;
268        Ok(())
269    }
270
271    /// Extracts a hash key from a chunk row.
272    fn extract_key(
273        &self,
274        chunk: &DataChunk,
275        row: usize,
276        key_columns: &[usize],
277    ) -> Result<HashKey, OperatorError> {
278        if key_columns.len() == 1 {
279            let col = chunk.column(key_columns[0]).ok_or_else(|| {
280                OperatorError::ColumnNotFound(format!("column {}", key_columns[0]))
281            })?;
282            Ok(HashKey::from_column(col, row).unwrap_or(HashKey::Null))
283        } else {
284            let keys: Vec<HashKey> = key_columns
285                .iter()
286                .map(|&col_idx| {
287                    chunk
288                        .column(col_idx)
289                        .and_then(|col| HashKey::from_column(col, row))
290                        .unwrap_or(HashKey::Null)
291                })
292                .collect();
293            Ok(HashKey::Composite(keys))
294        }
295    }
296
297    /// Produces an output row from a probe row and build row.
298    fn produce_output_row(
299        &self,
300        builder: &mut DataChunkBuilder,
301        probe_chunk: &DataChunk,
302        probe_row: usize,
303        build_chunk: Option<&DataChunk>,
304        build_row: Option<usize>,
305    ) -> Result<(), OperatorError> {
306        let probe_col_count = probe_chunk.column_count();
307
308        // Copy probe side columns
309        for col_idx in 0..probe_col_count {
310            let src_col = probe_chunk
311                .column(col_idx)
312                .ok_or_else(|| OperatorError::ColumnNotFound(format!("probe column {col_idx}")))?;
313            let dst_col = builder
314                .column_mut(col_idx)
315                .ok_or_else(|| OperatorError::ColumnNotFound(format!("output column {col_idx}")))?;
316
317            if let Some(value) = src_col.get_value(probe_row) {
318                dst_col.push_value(value);
319            } else {
320                dst_col.push_value(Value::Null);
321            }
322        }
323
324        // Copy build side columns
325        match (build_chunk, build_row) {
326            (Some(chunk), Some(row)) => {
327                for col_idx in 0..chunk.column_count() {
328                    let src_col = chunk.column(col_idx).ok_or_else(|| {
329                        OperatorError::ColumnNotFound(format!("build column {col_idx}"))
330                    })?;
331                    let dst_col =
332                        builder
333                            .column_mut(probe_col_count + col_idx)
334                            .ok_or_else(|| {
335                                OperatorError::ColumnNotFound(format!(
336                                    "output column {}",
337                                    probe_col_count + col_idx
338                                ))
339                            })?;
340
341                    if let Some(value) = src_col.get_value(row) {
342                        dst_col.push_value(value);
343                    } else {
344                        dst_col.push_value(Value::Null);
345                    }
346                }
347            }
348            _ => {
349                // Emit nulls for build side (left outer join case)
350                if !self.build_chunks.is_empty() {
351                    let build_col_count = self.build_chunks[0].column_count();
352                    for col_idx in 0..build_col_count {
353                        let dst_col =
354                            builder
355                                .column_mut(probe_col_count + col_idx)
356                                .ok_or_else(|| {
357                                    OperatorError::ColumnNotFound(format!(
358                                        "output column {}",
359                                        probe_col_count + col_idx
360                                    ))
361                                })?;
362                        dst_col.push_value(Value::Null);
363                    }
364                }
365            }
366        }
367
368        builder.advance_row();
369        Ok(())
370    }
371
372    /// Gets the next probe chunk.
373    fn get_next_probe_chunk(&mut self) -> Result<bool, OperatorError> {
374        let chunk = self.probe_side.next()?;
375        if let Some(ref c) = chunk {
376            // Initialize match tracking for outer joins
377            if matches!(self.join_type, JoinType::Left | JoinType::Full) {
378                self.probe_matched = vec![false; c.row_count()];
379            }
380        }
381        let has_chunk = chunk.is_some();
382        self.current_probe_chunk = chunk;
383        self.current_probe_row = 0;
384        Ok(has_chunk)
385    }
386
387    /// Emits unmatched build rows for right/full outer joins.
388    fn emit_unmatched_build(&mut self) -> OperatorResult {
389        if self.build_matched.is_empty() {
390            return Ok(None);
391        }
392
393        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
394
395        // Determine probe column count from schema or first probe chunk
396        let probe_col_count = if !self.build_chunks.is_empty() {
397            self.output_schema.len() - self.build_chunks[0].column_count()
398        } else {
399            0
400        };
401
402        while self.unmatched_chunk_idx < self.build_chunks.len() {
403            let chunk = &self.build_chunks[self.unmatched_chunk_idx];
404            let matched = &self.build_matched[self.unmatched_chunk_idx];
405
406            while self.unmatched_row_idx < matched.len() {
407                if !matched[self.unmatched_row_idx] {
408                    // This row was not matched - emit with nulls on probe side
409
410                    // Emit nulls for probe side
411                    for col_idx in 0..probe_col_count {
412                        if let Some(dst_col) = builder.column_mut(col_idx) {
413                            dst_col.push_value(Value::Null);
414                        }
415                    }
416
417                    // Copy build side values
418                    for col_idx in 0..chunk.column_count() {
419                        if let (Some(src_col), Some(dst_col)) = (
420                            chunk.column(col_idx),
421                            builder.column_mut(probe_col_count + col_idx),
422                        ) {
423                            if let Some(value) = src_col.get_value(self.unmatched_row_idx) {
424                                dst_col.push_value(value);
425                            } else {
426                                dst_col.push_value(Value::Null);
427                            }
428                        }
429                    }
430
431                    builder.advance_row();
432
433                    if builder.is_full() {
434                        self.unmatched_row_idx += 1;
435                        return Ok(Some(builder.finish()));
436                    }
437                }
438
439                self.unmatched_row_idx += 1;
440            }
441
442            self.unmatched_chunk_idx += 1;
443            self.unmatched_row_idx = 0;
444        }
445
446        if builder.row_count() > 0 {
447            Ok(Some(builder.finish()))
448        } else {
449            Ok(None)
450        }
451    }
452}
453
454impl Operator for HashJoinOperator {
455    fn next(&mut self) -> OperatorResult {
456        // Phase 1: Build hash table
457        if !self.build_complete {
458            self.build_hash_table()?;
459        }
460
461        // Phase 3: Emit unmatched build rows (right/full outer join)
462        if self.emitting_unmatched {
463            return self.emit_unmatched_build();
464        }
465
466        // Phase 2: Probe
467        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
468
469        loop {
470            // Get current probe chunk or fetch new one
471            if self.current_probe_chunk.is_none() && !self.get_next_probe_chunk()? {
472                // No more probe data
473                if matches!(self.join_type, JoinType::Right | JoinType::Full) {
474                    self.emitting_unmatched = true;
475                    return self.emit_unmatched_build();
476                }
477                return if builder.row_count() > 0 {
478                    Ok(Some(builder.finish()))
479                } else {
480                    Ok(None)
481                };
482            }
483
484            // Invariant: current_probe_chunk is Some here - the guard at line 396 either
485            // populates it via get_next_probe_chunk() or returns from the function
486            let probe_chunk = self
487                .current_probe_chunk
488                .as_ref()
489                .expect("probe chunk is Some: guard at line 396 ensures this");
490            let probe_rows: Vec<usize> = probe_chunk.selected_indices().collect();
491
492            while self.current_probe_row < probe_rows.len() {
493                let probe_row = probe_rows[self.current_probe_row];
494
495                // If we don't have current matches, look them up
496                if self.current_matches.is_empty() && self.current_match_position == 0 {
497                    let key = self.extract_key(probe_chunk, probe_row, &self.probe_keys)?;
498
499                    // Handle semi/anti joins differently
500                    match self.join_type {
501                        JoinType::Semi => {
502                            if self.hash_table.contains_key(&key) {
503                                // Emit probe row only
504                                for col_idx in 0..probe_chunk.column_count() {
505                                    if let (Some(src_col), Some(dst_col)) =
506                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
507                                        && let Some(value) = src_col.get_value(probe_row)
508                                    {
509                                        dst_col.push_value(value);
510                                    }
511                                }
512                                builder.advance_row();
513                            }
514                            self.current_probe_row += 1;
515                            continue;
516                        }
517                        JoinType::Anti => {
518                            if !self.hash_table.contains_key(&key) {
519                                // Emit probe row only
520                                for col_idx in 0..probe_chunk.column_count() {
521                                    if let (Some(src_col), Some(dst_col)) =
522                                        (probe_chunk.column(col_idx), builder.column_mut(col_idx))
523                                        && let Some(value) = src_col.get_value(probe_row)
524                                    {
525                                        dst_col.push_value(value);
526                                    }
527                                }
528                                builder.advance_row();
529                            }
530                            self.current_probe_row += 1;
531                            continue;
532                        }
533                        _ => {
534                            self.current_matches =
535                                self.hash_table.get(&key).cloned().unwrap_or_default();
536                        }
537                    }
538                }
539
540                // Process matches
541                if self.current_matches.is_empty() {
542                    // No matches - for left/full outer join, emit with nulls
543                    if matches!(self.join_type, JoinType::Left | JoinType::Full) {
544                        self.produce_output_row(&mut builder, probe_chunk, probe_row, None, None)?;
545                    }
546                    self.current_probe_row += 1;
547                    self.current_match_position = 0;
548                } else {
549                    // Process each match
550                    while self.current_match_position < self.current_matches.len() {
551                        let (build_chunk_idx, build_row) =
552                            self.current_matches[self.current_match_position];
553                        let build_chunk = &self.build_chunks[build_chunk_idx];
554
555                        // Mark as matched for outer joins
556                        if matches!(self.join_type, JoinType::Left | JoinType::Full)
557                            && probe_row < self.probe_matched.len()
558                        {
559                            self.probe_matched[probe_row] = true;
560                        }
561                        if matches!(self.join_type, JoinType::Right | JoinType::Full)
562                            && build_chunk_idx < self.build_matched.len()
563                            && build_row < self.build_matched[build_chunk_idx].len()
564                        {
565                            self.build_matched[build_chunk_idx][build_row] = true;
566                        }
567
568                        self.produce_output_row(
569                            &mut builder,
570                            probe_chunk,
571                            probe_row,
572                            Some(build_chunk),
573                            Some(build_row),
574                        )?;
575
576                        self.current_match_position += 1;
577
578                        if builder.is_full() {
579                            return Ok(Some(builder.finish()));
580                        }
581                    }
582
583                    // Done with this probe row
584                    self.current_probe_row += 1;
585                    self.current_matches.clear();
586                    self.current_match_position = 0;
587                }
588
589                if builder.is_full() {
590                    return Ok(Some(builder.finish()));
591                }
592            }
593
594            // Done with current probe chunk
595            self.current_probe_chunk = None;
596            self.current_probe_row = 0;
597
598            if builder.row_count() > 0 {
599                return Ok(Some(builder.finish()));
600            }
601        }
602    }
603
604    fn reset(&mut self) {
605        self.probe_side.reset();
606        self.build_side.reset();
607        self.hash_table.clear();
608        self.build_chunks.clear();
609        self.build_complete = false;
610        self.current_probe_chunk = None;
611        self.current_probe_row = 0;
612        self.current_match_position = 0;
613        self.current_matches.clear();
614        self.probe_matched.clear();
615        self.build_matched.clear();
616        self.emitting_unmatched = false;
617        self.unmatched_chunk_idx = 0;
618        self.unmatched_row_idx = 0;
619    }
620
621    fn name(&self) -> &'static str {
622        "HashJoin"
623    }
624}
625
626/// Nested loop join operator.
627///
628/// Performs a cartesian product of both sides, filtering by the join condition.
629/// Less efficient than hash join but supports any join condition.
630pub struct NestedLoopJoinOperator {
631    /// Left side operator.
632    left: Box<dyn Operator>,
633    /// Right side operator.
634    right: Box<dyn Operator>,
635    /// Join condition predicate (if any).
636    condition: Option<Box<dyn JoinCondition>>,
637    /// Join type.
638    join_type: JoinType,
639    /// Output schema.
640    output_schema: Vec<LogicalType>,
641    /// Materialized right side chunks.
642    right_chunks: Vec<DataChunk>,
643    /// Whether the right side is materialized.
644    right_materialized: bool,
645    /// Current left chunk.
646    current_left_chunk: Option<DataChunk>,
647    /// Current row in the left chunk.
648    current_left_row: usize,
649    /// Current chunk index in the right side.
650    current_right_chunk: usize,
651    /// Whether the current left row has been matched (for Left Join).
652    current_left_matched: bool,
653    /// Current row in the current right chunk.
654    current_right_row: usize,
655}
656
657/// Trait for join conditions.
658pub trait JoinCondition: Send + Sync {
659    /// Evaluates the condition for a pair of rows.
660    fn evaluate(
661        &self,
662        left_chunk: &DataChunk,
663        left_row: usize,
664        right_chunk: &DataChunk,
665        right_row: usize,
666    ) -> bool;
667}
668
/// Join condition that holds when one left column equals one right column.
pub struct EqualityCondition {
    /// Index of the compared column in the left chunk.
    left_column: usize,
    /// Index of the compared column in the right chunk.
    right_column: usize,
}
676
677impl EqualityCondition {
678    /// Creates a new equality condition.
679    pub fn new(left_column: usize, right_column: usize) -> Self {
680        Self {
681            left_column,
682            right_column,
683        }
684    }
685}
686
687impl JoinCondition for EqualityCondition {
688    fn evaluate(
689        &self,
690        left_chunk: &DataChunk,
691        left_row: usize,
692        right_chunk: &DataChunk,
693        right_row: usize,
694    ) -> bool {
695        let left_val = left_chunk
696            .column(self.left_column)
697            .and_then(|c| c.get_value(left_row));
698        let right_val = right_chunk
699            .column(self.right_column)
700            .and_then(|c| c.get_value(right_row));
701
702        match (left_val, right_val) {
703            (Some(l), Some(r)) => l == r,
704            _ => false,
705        }
706    }
707}
708
709impl NestedLoopJoinOperator {
710    /// Creates a new nested loop join operator.
711    pub fn new(
712        left: Box<dyn Operator>,
713        right: Box<dyn Operator>,
714        condition: Option<Box<dyn JoinCondition>>,
715        join_type: JoinType,
716        output_schema: Vec<LogicalType>,
717    ) -> Self {
718        Self {
719            left,
720            right,
721            condition,
722            join_type,
723            output_schema,
724            right_chunks: Vec::new(),
725            right_materialized: false,
726            current_left_chunk: None,
727            current_left_row: 0,
728            current_right_chunk: 0,
729            current_right_row: 0,
730            current_left_matched: false,
731        }
732    }
733
734    /// Materializes the right side.
735    fn materialize_right(&mut self) -> Result<(), OperatorError> {
736        while let Some(chunk) = self.right.next()? {
737            self.right_chunks.push(chunk);
738        }
739        self.right_materialized = true;
740        Ok(())
741    }
742
743    /// Produces an output row.
744    fn produce_row(
745        &self,
746        builder: &mut DataChunkBuilder,
747        left_chunk: &DataChunk,
748        left_row: usize,
749        right_chunk: &DataChunk,
750        right_row: usize,
751    ) {
752        // Copy left columns
753        for col_idx in 0..left_chunk.column_count() {
754            if let (Some(src), Some(dst)) =
755                (left_chunk.column(col_idx), builder.column_mut(col_idx))
756            {
757                if let Some(val) = src.get_value(left_row) {
758                    dst.push_value(val);
759                } else {
760                    dst.push_value(Value::Null);
761                }
762            }
763        }
764
765        // Copy right columns
766        let left_col_count = left_chunk.column_count();
767        for col_idx in 0..right_chunk.column_count() {
768            if let (Some(src), Some(dst)) = (
769                right_chunk.column(col_idx),
770                builder.column_mut(left_col_count + col_idx),
771            ) {
772                if let Some(val) = src.get_value(right_row) {
773                    dst.push_value(val);
774                } else {
775                    dst.push_value(Value::Null);
776                }
777            }
778        }
779
780        builder.advance_row();
781    }
782
783    /// Produces an output row with NULLs for the right side (for unmatched left rows in Left Join).
784    fn produce_left_unmatched_row(
785        &self,
786        builder: &mut DataChunkBuilder,
787        left_chunk: &DataChunk,
788        left_row: usize,
789        right_col_count: usize,
790    ) {
791        // Copy left columns
792        for col_idx in 0..left_chunk.column_count() {
793            if let (Some(src), Some(dst)) =
794                (left_chunk.column(col_idx), builder.column_mut(col_idx))
795            {
796                if let Some(val) = src.get_value(left_row) {
797                    dst.push_value(val);
798                } else {
799                    dst.push_value(Value::Null);
800                }
801            }
802        }
803
804        // Fill right columns with NULLs
805        let left_col_count = left_chunk.column_count();
806        for col_idx in 0..right_col_count {
807            if let Some(dst) = builder.column_mut(left_col_count + col_idx) {
808                dst.push_value(Value::Null);
809            }
810        }
811
812        builder.advance_row();
813    }
814}
815
816impl Operator for NestedLoopJoinOperator {
817    fn next(&mut self) -> OperatorResult {
818        // Materialize right side
819        if !self.right_materialized {
820            self.materialize_right()?;
821        }
822
823        // If right side is empty and not a left outer join, return nothing
824        if self.right_chunks.is_empty() && !matches!(self.join_type, JoinType::Left) {
825            return Ok(None);
826        }
827
828        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
829
830        loop {
831            // Get current left chunk
832            if self.current_left_chunk.is_none() {
833                self.current_left_chunk = self.left.next()?;
834                self.current_left_row = 0;
835                self.current_right_chunk = 0;
836                self.current_right_row = 0;
837
838                if self.current_left_chunk.is_none() {
839                    // No more left data
840                    return if builder.row_count() > 0 {
841                        Ok(Some(builder.finish()))
842                    } else {
843                        Ok(None)
844                    };
845                }
846            }
847
848            let left_chunk = self
849                .current_left_chunk
850                .as_ref()
851                .expect("left chunk is Some: loaded in loop above");
852            let left_rows: Vec<usize> = left_chunk.selected_indices().collect();
853
854            // Calculate right column count for potential unmatched rows
855            let right_col_count = if !self.right_chunks.is_empty() {
856                self.right_chunks[0].column_count()
857            } else {
858                // Infer from output schema
859                self.output_schema
860                    .len()
861                    .saturating_sub(left_chunk.column_count())
862            };
863
864            // Process current left row against all right rows
865            while self.current_left_row < left_rows.len() {
866                let left_row = left_rows[self.current_left_row];
867
868                // Reset match tracking for this left row
869                if self.current_right_chunk == 0 && self.current_right_row == 0 {
870                    self.current_left_matched = false;
871                }
872
873                // Cross join or inner/other join
874                while self.current_right_chunk < self.right_chunks.len() {
875                    let right_chunk = &self.right_chunks[self.current_right_chunk];
876                    let right_rows: Vec<usize> = right_chunk.selected_indices().collect();
877
878                    while self.current_right_row < right_rows.len() {
879                        let right_row = right_rows[self.current_right_row];
880
881                        // Check condition
882                        let matches = match &self.condition {
883                            Some(cond) => {
884                                cond.evaluate(left_chunk, left_row, right_chunk, right_row)
885                            }
886                            None => true, // Cross join
887                        };
888
889                        if matches {
890                            self.current_left_matched = true;
891                            self.produce_row(
892                                &mut builder,
893                                left_chunk,
894                                left_row,
895                                right_chunk,
896                                right_row,
897                            );
898
899                            if builder.is_full() {
900                                self.current_right_row += 1;
901                                return Ok(Some(builder.finish()));
902                            }
903                        }
904
905                        self.current_right_row += 1;
906                    }
907
908                    self.current_right_chunk += 1;
909                    self.current_right_row = 0;
910                }
911
912                // Done processing all right rows for this left row
913                // For Left Join, emit unmatched left row with NULLs
914                if matches!(self.join_type, JoinType::Left) && !self.current_left_matched {
915                    self.produce_left_unmatched_row(
916                        &mut builder,
917                        left_chunk,
918                        left_row,
919                        right_col_count,
920                    );
921
922                    if builder.is_full() {
923                        self.current_left_row += 1;
924                        self.current_right_chunk = 0;
925                        self.current_right_row = 0;
926                        return Ok(Some(builder.finish()));
927                    }
928                }
929
930                // Move to next left row
931                self.current_left_row += 1;
932                self.current_right_chunk = 0;
933                self.current_right_row = 0;
934            }
935
936            // Done with current left chunk
937            self.current_left_chunk = None;
938
939            if builder.row_count() > 0 {
940                return Ok(Some(builder.finish()));
941            }
942        }
943    }
944
945    fn reset(&mut self) {
946        self.left.reset();
947        self.right.reset();
948        self.right_chunks.clear();
949        self.right_materialized = false;
950        self.current_left_chunk = None;
951        self.current_left_row = 0;
952        self.current_right_chunk = 0;
953        self.current_right_row = 0;
954        self.current_left_matched = false;
955    }
956
957    fn name(&self) -> &'static str {
958        "NestedLoopJoin"
959    }
960}
961
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test double that hands out a fixed list of chunks, one per `next` call.
    ///
    /// Each chunk is moved out and replaced by `DataChunk::empty()`, and `reset`
    /// only rewinds the cursor, so a replay after `reset` yields empty chunks.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    let chunk = std::mem::replace(slot, DataChunk::empty());
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds a single-column `Int64` chunk holding `values` in order.
    fn create_int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        values.iter().copied().for_each(|value| {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        });
        builder.finish()
    }

    /// Reads the `Int64` cell at (`col`, `row`); panics if the column is missing.
    fn int_at(chunk: &DataChunk, col: usize, row: usize) -> Option<i64> {
        chunk.column(col).unwrap().get_int64(row)
    }

    /// Runs `op` to exhaustion, mapping every selected row through `read`.
    fn drain_rows<O, T>(op: &mut O, read: impl Fn(&DataChunk, usize) -> T) -> Vec<T>
    where
        O: Operator,
    {
        let mut rows = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            rows.extend(chunk.selected_indices().map(|row| read(&chunk, row)));
        }
        rows
    }

    #[test]
    fn test_hash_join_inner() {
        // Left keys [1, 2, 3, 4] and right keys [2, 3, 4, 5] overlap on 2, 3 and 4,
        // so an inner join on column 0 yields exactly those three pairs.
        let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
        let right = MockOperator::new(vec![create_int_chunk(&[2, 3, 4, 5])]);

        let mut join = HashJoinOperator::new(
            Box::new(left),
            Box::new(right),
            vec![0],
            vec![0],
            JoinType::Inner,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut pairs = drain_rows(&mut join, |chunk, row| {
            (
                int_at(chunk, 0, row).unwrap(),
                int_at(chunk, 1, row).unwrap(),
            )
        });

        pairs.sort_unstable();
        assert_eq!(pairs, vec![(2, 2), (3, 3), (4, 4)]);
    }

    #[test]
    fn test_hash_join_left_outer() {
        // Left keys [1, 2, 3] against right keys [2, 3]: every left row survives,
        // and key 1 has no partner, so its right side reads back as `None`.
        let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3])]);
        let right = MockOperator::new(vec![create_int_chunk(&[2, 3])]);

        let mut join = HashJoinOperator::new(
            Box::new(left),
            Box::new(right),
            vec![0],
            vec![0],
            JoinType::Left,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut pairs = drain_rows(&mut join, |chunk, row| {
            (int_at(chunk, 0, row).unwrap(), int_at(chunk, 1, row))
        });

        pairs.sort_by_key(|&(left_key, _)| left_key);
        assert_eq!(pairs, vec![(1, None), (2, Some(2)), (3, Some(3))]);
    }

    #[test]
    fn test_nested_loop_cross_join() {
        // With no condition, [1, 2] x [10, 20] is the full cartesian product.
        let left = MockOperator::new(vec![create_int_chunk(&[1, 2])]);
        let right = MockOperator::new(vec![create_int_chunk(&[10, 20])]);

        let mut join = NestedLoopJoinOperator::new(
            Box::new(left),
            Box::new(right),
            None,
            JoinType::Cross,
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut pairs = drain_rows(&mut join, |chunk, row| {
            (
                int_at(chunk, 0, row).unwrap(),
                int_at(chunk, 1, row).unwrap(),
            )
        });

        pairs.sort_unstable();
        assert_eq!(pairs, vec![(1, 10), (1, 20), (2, 10), (2, 20)]);
    }

    #[test]
    fn test_hash_join_semi() {
        // Semi join keeps the left rows that have a partner: of [1, 2, 3, 4]
        // only 2 and 4 appear in the right input [2, 4].
        let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
        let right = MockOperator::new(vec![create_int_chunk(&[2, 4])]);

        // Only the probe (left) column is part of the output.
        let mut join = HashJoinOperator::new(
            Box::new(left),
            Box::new(right),
            vec![0],
            vec![0],
            JoinType::Semi,
            vec![LogicalType::Int64],
        );

        let mut keys = drain_rows(&mut join, |chunk, row| int_at(chunk, 0, row).unwrap());

        keys.sort_unstable();
        assert_eq!(keys, vec![2, 4]);
    }

    #[test]
    fn test_hash_join_anti() {
        // Anti join keeps the left rows without a partner: of [1, 2, 3, 4]
        // the keys 1 and 3 are absent from the right input [2, 4].
        let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
        let right = MockOperator::new(vec![create_int_chunk(&[2, 4])]);

        let mut join = HashJoinOperator::new(
            Box::new(left),
            Box::new(right),
            vec![0],
            vec![0],
            JoinType::Anti,
            vec![LogicalType::Int64],
        );

        let mut keys = drain_rows(&mut join, |chunk, row| int_at(chunk, 0, row).unwrap());

        keys.sort_unstable();
        assert_eq!(keys, vec![1, 3]);
    }

    #[test]
    fn test_hash_key_from_map() {
        use grafeo_common::types::{PropertyKey, Value};
        use std::collections::BTreeMap;
        use std::sync::Arc;

        // Two independently built maps with identical content.
        let single_entry_map = || {
            Value::Map(Arc::new(BTreeMap::from([(
                PropertyKey::new("key"),
                Value::Int64(42),
            )])))
        };
        let first = single_entry_map();
        let second = single_entry_map();

        // A map value turns into a composite key.
        assert!(matches!(
            HashKey::from_value(&first),
            HashKey::Composite(_)
        ));

        // Equal content must produce equal hash keys.
        assert_eq!(HashKey::from_value(&first), HashKey::from_value(&second));
    }

    #[test]
    fn test_hash_key_from_map_empty() {
        use grafeo_common::types::Value;
        use std::collections::BTreeMap;
        use std::sync::Arc;

        let empty_map = Value::Map(Arc::new(BTreeMap::new()));
        assert_eq!(
            HashKey::from_value(&empty_map),
            HashKey::Composite(vec![])
        );
    }

    #[test]
    fn test_hash_key_from_gcounter() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let counter = Value::GCounter(Arc::new(HashMap::from([
            ("node-a".to_string(), 5u64),
            ("node-b".to_string(), 3u64),
        ])));

        // The key is the sum of all per-node counts: 5 + 3 = 8.
        assert_eq!(HashKey::from_value(&counter), HashKey::Int64(8));
    }

    #[test]
    fn test_hash_key_from_gcounter_empty() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let counter = Value::GCounter(Arc::new(HashMap::new()));
        assert_eq!(HashKey::from_value(&counter), HashKey::Int64(0));
    }

    #[test]
    fn test_hash_key_from_oncounter() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let counter = Value::OnCounter {
            pos: Arc::new(HashMap::from([("node-a".to_string(), 10u64)])),
            neg: Arc::new(HashMap::from([("node-a".to_string(), 3u64)])),
        };

        // The key is pos_sum - neg_sum = 10 - 3 = 7.
        assert_eq!(HashKey::from_value(&counter), HashKey::Int64(7));
    }

    #[test]
    fn test_hash_key_from_oncounter_balanced() {
        use grafeo_common::types::Value;
        use std::collections::HashMap;
        use std::sync::Arc;

        let counter = Value::OnCounter {
            pos: Arc::new(HashMap::from([("r".to_string(), 5u64)])),
            neg: Arc::new(HashMap::from([("r".to_string(), 5u64)])),
        };

        // Equal increments and decrements cancel out to zero.
        assert_eq!(HashKey::from_value(&counter), HashKey::Int64(0));
    }
}