// grafeo_core/execution/parallel/merge.rs

//! Merge utilities for parallel pipeline breakers.
//!
//! When parallel pipelines contain pipeline breakers (Sort, Aggregate, Distinct),
//! each worker produces partial results that must be merged into the final output.
6use crate::execution::chunk::DataChunk;
7use crate::execution::operators::OperatorError;
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// Trait for operators that support parallel merge.
///
/// Pipeline breakers (Sort, Aggregate, Distinct) must implement this so the
/// engine can run one operator instance per worker and combine the partial
/// states afterwards.
pub trait MergeableOperator: Send + Sync {
    /// Merges partial results from another operator instance.
    ///
    /// Consumes `other`; afterwards `self` holds the combined state.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Returns whether this operator supports parallel merge.
    ///
    /// Defaults to `true`; implementors may opt out for states that
    /// cannot be combined pairwise.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
27
/// Accumulator state that supports merging.
///
/// Used by aggregate operators to merge partial aggregations computed by
/// different workers. Tracks enough state for COUNT/SUM/MIN/MAX/AVG/FIRST.
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    /// Count of non-null values seen.
    pub count: i64,
    /// Sum of numeric values (non-numeric values contribute nothing).
    pub sum: f64,
    /// Minimum value seen so far, if any.
    pub min: Option<Value>,
    /// Maximum value seen so far, if any.
    pub max: Option<Value>,
    /// First non-null value encountered.
    pub first: Option<Value>,
    /// For AVG: sum of squared numeric values (usable for variance later).
    pub sum_squared: f64,
}
46
47impl MergeableAccumulator {
48    /// Creates a new empty accumulator.
49    #[must_use]
50    pub fn new() -> Self {
51        Self {
52            count: 0,
53            sum: 0.0,
54            min: None,
55            max: None,
56            first: None,
57            sum_squared: 0.0,
58        }
59    }
60
61    /// Adds a value to the accumulator.
62    pub fn add(&mut self, value: &Value) {
63        if matches!(value, Value::Null) {
64            return;
65        }
66
67        self.count += 1;
68
69        if let Some(n) = value_to_f64(value) {
70            self.sum += n;
71            self.sum_squared += n * n;
72        }
73
74        // Min
75        if self.min.is_none() || compare_for_min(&self.min, value) {
76            self.min = Some(value.clone());
77        }
78
79        // Max
80        if self.max.is_none() || compare_for_max(&self.max, value) {
81            self.max = Some(value.clone());
82        }
83
84        // First
85        if self.first.is_none() {
86            self.first = Some(value.clone());
87        }
88    }
89
90    /// Merges another accumulator into this one.
91    pub fn merge(&mut self, other: &MergeableAccumulator) {
92        self.count += other.count;
93        self.sum += other.sum;
94        self.sum_squared += other.sum_squared;
95
96        // Merge min
97        if let Some(ref other_min) = other.min
98            && compare_for_min(&self.min, other_min)
99        {
100            self.min = Some(other_min.clone());
101        }
102
103        // Merge max
104        if let Some(ref other_max) = other.max
105            && compare_for_max(&self.max, other_max)
106        {
107            self.max = Some(other_max.clone());
108        }
109
110        // Keep our first (we processed earlier)
111        // If we have no first, take theirs
112        if self.first.is_none() {
113            self.first.clone_from(&other.first);
114        }
115    }
116
117    /// Finalizes COUNT aggregate.
118    #[must_use]
119    pub fn finalize_count(&self) -> Value {
120        Value::Int64(self.count)
121    }
122
123    /// Finalizes SUM aggregate.
124    #[must_use]
125    pub fn finalize_sum(&self) -> Value {
126        if self.count == 0 {
127            Value::Null
128        } else {
129            Value::Float64(self.sum)
130        }
131    }
132
133    /// Finalizes MIN aggregate.
134    #[must_use]
135    pub fn finalize_min(&self) -> Value {
136        self.min.clone().unwrap_or(Value::Null)
137    }
138
139    /// Finalizes MAX aggregate.
140    #[must_use]
141    pub fn finalize_max(&self) -> Value {
142        self.max.clone().unwrap_or(Value::Null)
143    }
144
145    /// Finalizes AVG aggregate.
146    #[must_use]
147    pub fn finalize_avg(&self) -> Value {
148        if self.count == 0 {
149            Value::Null
150        } else {
151            Value::Float64(self.sum / self.count as f64)
152        }
153    }
154
155    /// Finalizes FIRST aggregate.
156    #[must_use]
157    pub fn finalize_first(&self) -> Value {
158        self.first.clone().unwrap_or(Value::Null)
159    }
160}
161
impl Default for MergeableAccumulator {
    /// Equivalent to [`MergeableAccumulator::new`]: an empty accumulator.
    fn default() -> Self {
        Self::new()
    }
}
167
168fn value_to_f64(value: &Value) -> Option<f64> {
169    match value {
170        Value::Int64(i) => Some(*i as f64),
171        Value::Float64(f) => Some(*f),
172        _ => None,
173    }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177    match (current, new) {
178        (None, _) => true,
179        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181        (Some(Value::String(a)), Value::String(b)) => b < a,
182        _ => false,
183    }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187    match (current, new) {
188        (None, _) => true,
189        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191        (Some(Value::String(a)), Value::String(b)) => b > a,
192        _ => false,
193    }
194}
195
/// Sort key for k-way merge.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Column index to sort by.
    pub column: usize,
    /// Sort direction: `true` for ascending, `false` for descending.
    pub ascending: bool,
    /// Whether nulls sort before (`true`) or after (`false`) non-null values.
    pub nulls_first: bool,
}
206
207impl SortKey {
208    /// Creates an ascending sort key.
209    #[must_use]
210    pub fn ascending(column: usize) -> Self {
211        Self {
212            column,
213            ascending: true,
214            nulls_first: false,
215        }
216    }
217
218    /// Creates a descending sort key.
219    #[must_use]
220    pub fn descending(column: usize) -> Self {
221        Self {
222            column,
223            ascending: false,
224            nulls_first: true,
225        }
226    }
227}
228
/// Entry in the k-way merge heap.
///
/// Each entry owns one row plus a copy of the sort keys, so entries can
/// compare themselves inside `BinaryHeap` without external context.
struct MergeEntry {
    /// Row data (one `Value` per column).
    row: Vec<Value>,
    /// Index of the source run this row came from.
    run_index: usize,
    /// Sort keys for comparison.
    /// NOTE(review): cloned into every entry via `keys.to_vec()` — could be
    /// borrowed to avoid a per-push allocation; confirm before changing.
    keys: Vec<SortKey>,
}
238
239impl MergeEntry {
240    fn compare_to(&self, other: &Self) -> Ordering {
241        for key in &self.keys {
242            let a = self.row.get(key.column);
243            let b = other.row.get(key.column);
244
245            let ordering = compare_values_for_sort(a, b, key.nulls_first);
246
247            let ordering = if key.ascending {
248                ordering
249            } else {
250                ordering.reverse()
251            };
252
253            if ordering != Ordering::Equal {
254                return ordering;
255            }
256        }
257        Ordering::Equal
258    }
259}
260
impl PartialEq for MergeEntry {
    // Equality is defined by the sort keys only: rows that compare equal on
    // every key are "equal" for heap ordering, even if other columns differ.
    fn eq(&self, other: &Self) -> bool {
        self.compare_to(other) == Ordering::Equal
    }
}
266
// `compare_to` always yields an `Ordering` (floats fall back to `Equal`),
// so the equality relation is total and `Eq` is sound.
impl Eq for MergeEntry {}
268
impl PartialOrd for MergeEntry {
    // Canonical delegation to `Ord` keeps the two orderings consistent.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
274
impl Ord for MergeEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // `BinaryHeap` is a max-heap; reversing the comparison (note the
        // swapped receiver) makes it behave as a min-heap so the smallest
        // row is popped first.
        other.compare_to(self)
    }
}
281
282fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
283    match (a, b) {
284        (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
285        (None, _) | (Some(Value::Null), _) => {
286            if nulls_first {
287                Ordering::Less
288            } else {
289                Ordering::Greater
290            }
291        }
292        (_, None) | (_, Some(Value::Null)) => {
293            if nulls_first {
294                Ordering::Greater
295            } else {
296                Ordering::Less
297            }
298        }
299        (Some(a), Some(b)) => compare_values(a, b),
300    }
301}
302
303fn compare_values(a: &Value, b: &Value) -> Ordering {
304    match (a, b) {
305        (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
306        (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
307        (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
308        (Value::String(a), Value::String(b)) => a.cmp(b),
309        _ => Ordering::Equal,
310    }
311}
312
313/// Merges multiple sorted runs into a single sorted output.
314///
315/// Uses a min-heap for efficient k-way merge.
316pub fn merge_sorted_runs(
317    runs: Vec<Vec<Vec<Value>>>,
318    keys: &[SortKey],
319) -> Result<Vec<Vec<Value>>, OperatorError> {
320    if runs.is_empty() {
321        return Ok(Vec::new());
322    }
323
324    if runs.len() == 1 {
325        // Invariant: runs.len() == 1 guarantees exactly one element
326        return Ok(runs
327            .into_iter()
328            .next()
329            .expect("runs has exactly one element: checked on previous line"));
330    }
331
332    // Count total rows
333    let total_rows: usize = runs.iter().map(|r| r.len()).sum();
334    let mut result = Vec::with_capacity(total_rows);
335
336    // Track position in each run
337    let mut positions: Vec<usize> = vec![0; runs.len()];
338
339    // Initialize heap with first row from each non-empty run
340    let mut heap = BinaryHeap::new();
341    for (run_index, run) in runs.iter().enumerate() {
342        if !run.is_empty() {
343            heap.push(MergeEntry {
344                row: run[0].clone(),
345                run_index,
346                keys: keys.to_vec(),
347            });
348            positions[run_index] = 1;
349        }
350    }
351
352    // Extract rows in order
353    while let Some(entry) = heap.pop() {
354        result.push(entry.row);
355
356        // Add next row from same run if available
357        let pos = positions[entry.run_index];
358        if pos < runs[entry.run_index].len() {
359            heap.push(MergeEntry {
360                row: runs[entry.run_index][pos].clone(),
361                run_index: entry.run_index,
362                keys: keys.to_vec(),
363            });
364            positions[entry.run_index] += 1;
365        }
366    }
367
368    Ok(result)
369}
370
371/// Converts sorted rows to DataChunks.
372pub fn rows_to_chunks(
373    rows: Vec<Vec<Value>>,
374    chunk_size: usize,
375) -> Result<Vec<DataChunk>, OperatorError> {
376    if rows.is_empty() {
377        return Ok(Vec::new());
378    }
379
380    let num_columns = rows[0].len();
381    let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
382    let mut chunks = Vec::with_capacity(num_chunks);
383
384    for chunk_rows in rows.chunks(chunk_size) {
385        let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
386
387        for row in chunk_rows {
388            for (col_idx, col) in columns.iter_mut().enumerate() {
389                let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
390                col.push(val);
391            }
392        }
393
394        chunks.push(DataChunk::new(columns));
395    }
396
397    Ok(chunks)
398}
399
400/// Merges multiple sorted DataChunk streams into a single sorted stream.
401pub fn merge_sorted_chunks(
402    runs: Vec<Vec<DataChunk>>,
403    keys: &[SortKey],
404    chunk_size: usize,
405) -> Result<Vec<DataChunk>, OperatorError> {
406    // Convert chunks to row format for merging
407    let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
408
409    let merged_rows = merge_sorted_runs(row_runs, keys)?;
410    rows_to_chunks(merged_rows, chunk_size)
411}
412
413/// Converts DataChunks to row format.
414fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
415    let mut rows = Vec::new();
416
417    for chunk in chunks {
418        let num_columns = chunk.num_columns();
419        for i in 0..chunk.len() {
420            let mut row = Vec::with_capacity(num_columns);
421            for col_idx in 0..num_columns {
422                let val = chunk
423                    .column(col_idx)
424                    .and_then(|c| c.get(i))
425                    .unwrap_or(Value::Null);
426                row.push(val);
427            }
428            rows.push(row);
429        }
430    }
431
432    rows
433}
434
435/// Concatenates multiple DataChunk results (for non-sorted parallel results).
436pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
437    results.into_iter().flatten().collect()
438}
439
440/// Merges parallel DISTINCT results by deduplication.
441pub fn merge_distinct_results(
442    results: Vec<Vec<DataChunk>>,
443) -> Result<Vec<DataChunk>, OperatorError> {
444    use std::collections::HashSet;
445
446    // Simple row-based deduplication using hash
447    let mut seen: HashSet<u64> = HashSet::new();
448    let mut unique_rows: Vec<Vec<Value>> = Vec::new();
449
450    for chunks in results {
451        for chunk in chunks {
452            let num_columns = chunk.num_columns();
453            for i in 0..chunk.len() {
454                let mut row = Vec::with_capacity(num_columns);
455                for col_idx in 0..num_columns {
456                    let val = chunk
457                        .column(col_idx)
458                        .and_then(|c| c.get(i))
459                        .unwrap_or(Value::Null);
460                    row.push(val);
461                }
462
463                let hash = hash_row(&row);
464                if seen.insert(hash) {
465                    unique_rows.push(row);
466                }
467            }
468        }
469    }
470
471    rows_to_chunks(unique_rows, 2048)
472}
473
/// Hashes a row of values into a single `u64` fingerprint.
///
/// Floats are hashed via their bit pattern, so values with identical bits
/// hash identically. `Null` and every non-primitive variant (lists, bytes,
/// ...) all hash the same `0u8` tag, so non-primitive values collide by
/// design — equal hashes therefore do NOT imply equal rows, and callers
/// needing exact dedup must confirm with a value comparison.
fn hash_row(row: &[Value]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    for value in row {
        match value {
            Value::Null => 0u8.hash(&mut hasher),
            Value::Bool(b) => b.hash(&mut hasher),
            Value::Int64(i) => i.hash(&mut hasher),
            Value::Float64(f) => f.to_bits().hash(&mut hasher),
            Value::String(s) => s.hash(&mut hasher),
            // Catch-all: non-primitive variants share the same tag as Null.
            _ => 0u8.hash(&mut hasher),
        }
    }
    hasher.finish()
}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mergeable_accumulator() {
        let mut acc1 = MergeableAccumulator::new();
        acc1.add(&Value::Int64(10));
        acc1.add(&Value::Int64(20));

        let mut acc2 = MergeableAccumulator::new();
        acc2.add(&Value::Int64(30));
        acc2.add(&Value::Int64(40));

        acc1.merge(&acc2);

        // Merged accumulator must look like one that saw all four values.
        assert_eq!(acc1.count, 4);
        assert_eq!(acc1.sum, 100.0);
        assert_eq!(acc1.finalize_min(), Value::Int64(10));
        assert_eq!(acc1.finalize_max(), Value::Int64(40));
        assert_eq!(acc1.finalize_avg(), Value::Float64(25.0));
    }

    #[test]
    fn test_merge_sorted_runs_empty() {
        let runs: Vec<Vec<Vec<Value>>> = Vec::new();
        let result = merge_sorted_runs(runs, &[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_merge_sorted_runs_single() {
        // A single run short-circuits: it is returned as-is without merging.
        let runs = vec![vec![
            vec![Value::Int64(1)],
            vec![Value::Int64(2)],
            vec![Value::Int64(3)],
        ]];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Run 1: [1, 4, 7]
        // Run 2: [2, 5, 8]
        // Run 3: [3, 6, 9]
        let runs = vec![
            vec![
                vec![Value::Int64(1)],
                vec![Value::Int64(4)],
                vec![Value::Int64(7)],
            ],
            vec![
                vec![Value::Int64(2)],
                vec![Value::Int64(5)],
                vec![Value::Int64(8)],
            ],
            vec![
                vec![Value::Int64(3)],
                vec![Value::Int64(6)],
                vec![Value::Int64(9)],
            ],
        ];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 9);

        // Verify fully interleaved ascending order: 1..=9.
        for i in 0..9 {
            assert_eq!(result[i][0], Value::Int64((i + 1) as i64));
        }
    }

    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![
            vec![
                vec![Value::Int64(7)],
                vec![Value::Int64(4)],
                vec![Value::Int64(1)],
            ],
            vec![
                vec![Value::Int64(8)],
                vec![Value::Int64(5)],
                vec![Value::Int64(2)],
            ],
        ];
        let keys = vec![SortKey::descending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 6);

        // Verify descending order
        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    #[test]
    fn test_rows_to_chunks() {
        let rows = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3).unwrap();

        assert_eq!(chunks.len(), 4); // ceil(10 / 3) = 4 chunks
        assert_eq!(chunks[0].len(), 3);
        assert_eq!(chunks[1].len(), 3);
        assert_eq!(chunks[2].len(), 3);
        assert_eq!(chunks[3].len(), 1); // last chunk holds the remainder
    }

    #[test]
    fn test_merge_distinct_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);

        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let results = vec![vec![chunk1], vec![chunk2]];
        let merged = merge_distinct_results(results).unwrap();

        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4); // 1, 2, 3, 4 (no duplicates)
    }

    #[test]
    fn test_hash_row_with_non_primitive_values() {
        // Exercises the catch-all branch in hash_row for non-primitive Value types
        let row1 = vec![Value::List(vec![Value::Int64(1)].into())];
        let row2 = vec![Value::List(vec![Value::Int64(2)].into())];
        let row3 = vec![Value::Bytes(vec![1, 2, 3].into())];

        // The catch-all hashes all non-primitive types to the same bucket (0u8)
        let h1 = hash_row(&row1);
        let h2 = hash_row(&row2);
        let h3 = hash_row(&row3);

        // All non-primitive types hash identically via the catch-all
        assert_eq!(h1, h2);
        assert_eq!(h2, h3);
    }

    #[test]
    fn test_concat_parallel_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let chunk3 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let results = vec![vec![chunk1], vec![chunk2, chunk3]];
        let concatenated = concat_parallel_results(results);

        assert_eq!(concatenated.len(), 3);
    }
}