// grafeo_core/execution/parallel/merge.rs
//! Merge utilities for parallel pipeline breakers.
//!
//! When parallel pipelines have pipeline breakers (Sort, Aggregate, Distinct),
//! each worker produces partial results that must be merged into final output.
6use crate::execution::chunk::DataChunk;
7use crate::execution::vector::ValueVector;
8use grafeo_common::types::Value;
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
/// Trait implemented by operators whose partial results can be combined.
///
/// Parallel pipeline breakers implement this so that per-worker operator
/// state can be folded into a single final instance.
pub trait MergeableOperator: Send + Sync {
    /// Merges partial results from another operator instance.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Returns whether this operator supports parallel merge.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
26
27/// Accumulator state that supports merging.
28///
29/// Used by aggregate operators to merge partial aggregations.
30#[derive(Debug, Clone)]
31pub struct MergeableAccumulator {
32    /// Count of values.
33    pub count: i64,
34    /// Sum of values.
35    pub sum: f64,
36    /// Minimum value.
37    pub min: Option<Value>,
38    /// Maximum value.
39    pub max: Option<Value>,
40    /// First value encountered.
41    pub first: Option<Value>,
42    /// For AVG: sum of squared values (for variance if needed).
43    pub sum_squared: f64,
44}
45
46impl MergeableAccumulator {
47    /// Creates a new empty accumulator.
48    #[must_use]
49    pub fn new() -> Self {
50        Self {
51            count: 0,
52            sum: 0.0,
53            min: None,
54            max: None,
55            first: None,
56            sum_squared: 0.0,
57        }
58    }
59
60    /// Adds a value to the accumulator.
61    pub fn add(&mut self, value: &Value) {
62        if matches!(value, Value::Null) {
63            return;
64        }
65
66        self.count += 1;
67
68        if let Some(n) = value_to_f64(value) {
69            self.sum += n;
70            self.sum_squared += n * n;
71        }
72
73        // Min
74        if self.min.is_none() || compare_for_min(&self.min, value) {
75            self.min = Some(value.clone());
76        }
77
78        // Max
79        if self.max.is_none() || compare_for_max(&self.max, value) {
80            self.max = Some(value.clone());
81        }
82
83        // First
84        if self.first.is_none() {
85            self.first = Some(value.clone());
86        }
87    }
88
89    /// Merges another accumulator into this one.
90    pub fn merge(&mut self, other: &MergeableAccumulator) {
91        self.count += other.count;
92        self.sum += other.sum;
93        self.sum_squared += other.sum_squared;
94
95        // Merge min
96        if let Some(ref other_min) = other.min
97            && compare_for_min(&self.min, other_min)
98        {
99            self.min = Some(other_min.clone());
100        }
101
102        // Merge max
103        if let Some(ref other_max) = other.max
104            && compare_for_max(&self.max, other_max)
105        {
106            self.max = Some(other_max.clone());
107        }
108
109        // Keep our first (we processed earlier)
110        // If we have no first, take theirs
111        if self.first.is_none() {
112            self.first.clone_from(&other.first);
113        }
114    }
115
116    /// Finalizes COUNT aggregate.
117    #[must_use]
118    pub fn finalize_count(&self) -> Value {
119        Value::Int64(self.count)
120    }
121
122    /// Finalizes SUM aggregate.
123    #[must_use]
124    pub fn finalize_sum(&self) -> Value {
125        if self.count == 0 {
126            Value::Null
127        } else {
128            Value::Float64(self.sum)
129        }
130    }
131
132    /// Finalizes MIN aggregate.
133    #[must_use]
134    pub fn finalize_min(&self) -> Value {
135        self.min.clone().unwrap_or(Value::Null)
136    }
137
138    /// Finalizes MAX aggregate.
139    #[must_use]
140    pub fn finalize_max(&self) -> Value {
141        self.max.clone().unwrap_or(Value::Null)
142    }
143
144    /// Finalizes AVG aggregate.
145    #[must_use]
146    pub fn finalize_avg(&self) -> Value {
147        if self.count == 0 {
148            Value::Null
149        } else {
150            Value::Float64(self.sum / self.count as f64)
151        }
152    }
153
154    /// Finalizes FIRST aggregate.
155    #[must_use]
156    pub fn finalize_first(&self) -> Value {
157        self.first.clone().unwrap_or(Value::Null)
158    }
159}
160
161impl Default for MergeableAccumulator {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167fn value_to_f64(value: &Value) -> Option<f64> {
168    match value {
169        Value::Int64(i) => Some(*i as f64),
170        Value::Float64(f) => Some(*f),
171        _ => None,
172    }
173}
174
175fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
176    match (current, new) {
177        (None, _) => true,
178        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
179        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
180        (Some(Value::String(a)), Value::String(b)) => b < a,
181        _ => false,
182    }
183}
184
185fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
186    match (current, new) {
187        (None, _) => true,
188        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
189        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
190        (Some(Value::String(a)), Value::String(b)) => b > a,
191        _ => false,
192    }
193}
194
/// Sort key for k-way merge.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Column index to sort by.
    pub column: usize,
    /// Sort direction (ascending = true).
    pub ascending: bool,
    /// Nulls first (true) or last (false).
    pub nulls_first: bool,
}

impl SortKey {
    /// Creates an ascending sort key; nulls sort last.
    #[must_use]
    pub fn ascending(column: usize) -> Self {
        Self {
            column,
            ascending: true,
            nulls_first: false,
        }
    }

    /// Creates a descending sort key; nulls sort first.
    #[must_use]
    pub fn descending(column: usize) -> Self {
        Self {
            column,
            ascending: false,
            nulls_first: true,
        }
    }
}
227
228/// Entry in the k-way merge heap.
229struct MergeEntry {
230    /// Row data.
231    row: Vec<Value>,
232    /// Source run index.
233    run_index: usize,
234    /// Sort keys for comparison.
235    keys: Vec<SortKey>,
236}
237
238impl MergeEntry {
239    fn compare_to(&self, other: &Self) -> Ordering {
240        for key in &self.keys {
241            let a = self.row.get(key.column);
242            let b = other.row.get(key.column);
243
244            let ordering = compare_values_for_sort(a, b, key.nulls_first);
245
246            let ordering = if key.ascending {
247                ordering
248            } else {
249                ordering.reverse()
250            };
251
252            if ordering != Ordering::Equal {
253                return ordering;
254            }
255        }
256        Ordering::Equal
257    }
258}
259
260impl PartialEq for MergeEntry {
261    fn eq(&self, other: &Self) -> bool {
262        self.compare_to(other) == Ordering::Equal
263    }
264}
265
266impl Eq for MergeEntry {}
267
268impl PartialOrd for MergeEntry {
269    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
270        Some(self.cmp(other))
271    }
272}
273
274impl Ord for MergeEntry {
275    fn cmp(&self, other: &Self) -> Ordering {
276        // Reverse for min-heap behavior (we want smallest first)
277        other.compare_to(self)
278    }
279}
280
281fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
282    match (a, b) {
283        (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
284        (None, _) | (Some(Value::Null), _) => {
285            if nulls_first {
286                Ordering::Less
287            } else {
288                Ordering::Greater
289            }
290        }
291        (_, None) | (_, Some(Value::Null)) => {
292            if nulls_first {
293                Ordering::Greater
294            } else {
295                Ordering::Less
296            }
297        }
298        (Some(a), Some(b)) => compare_values(a, b),
299    }
300}
301
302fn compare_values(a: &Value, b: &Value) -> Ordering {
303    match (a, b) {
304        (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
305        (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
306        (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
307        (Value::String(a), Value::String(b)) => a.cmp(b),
308        (Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b),
309        (Value::Date(a), Value::Date(b)) => a.cmp(b),
310        (Value::Time(a), Value::Time(b)) => a.cmp(b),
311        _ => Ordering::Equal,
312    }
313}
314
315/// Merges multiple sorted runs into a single sorted output.
316///
317/// Uses a min-heap for efficient k-way merge.
318pub fn merge_sorted_runs(runs: Vec<Vec<Vec<Value>>>, keys: &[SortKey]) -> Vec<Vec<Value>> {
319    if runs.is_empty() {
320        return Vec::new();
321    }
322
323    if runs.len() == 1 {
324        return runs.into_iter().next().unwrap_or_default();
325    }
326
327    // Count total rows
328    let total_rows: usize = runs.iter().map(|r| r.len()).sum();
329    let mut result = Vec::with_capacity(total_rows);
330
331    // Track position in each run
332    let mut positions: Vec<usize> = vec![0; runs.len()];
333
334    // Initialize heap with first row from each non-empty run
335    let mut heap = BinaryHeap::new();
336    for (run_index, run) in runs.iter().enumerate() {
337        if !run.is_empty() {
338            heap.push(MergeEntry {
339                row: run[0].clone(),
340                run_index,
341                keys: keys.to_vec(),
342            });
343            positions[run_index] = 1;
344        }
345    }
346
347    // Extract rows in order
348    while let Some(entry) = heap.pop() {
349        result.push(entry.row);
350
351        // Add next row from same run if available
352        let pos = positions[entry.run_index];
353        if pos < runs[entry.run_index].len() {
354            heap.push(MergeEntry {
355                row: runs[entry.run_index][pos].clone(),
356                run_index: entry.run_index,
357                keys: keys.to_vec(),
358            });
359            positions[entry.run_index] += 1;
360        }
361    }
362
363    result
364}
365
366/// Converts sorted rows to `DataChunk`s.
367pub fn rows_to_chunks(rows: Vec<Vec<Value>>, chunk_size: usize) -> Vec<DataChunk> {
368    if rows.is_empty() {
369        return Vec::new();
370    }
371
372    let num_columns = rows[0].len();
373    let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
374    let mut chunks = Vec::with_capacity(num_chunks);
375
376    for chunk_rows in rows.chunks(chunk_size) {
377        let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
378
379        for row in chunk_rows {
380            for (col_idx, col) in columns.iter_mut().enumerate() {
381                let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
382                col.push(val);
383            }
384        }
385
386        chunks.push(DataChunk::new(columns));
387    }
388
389    chunks
390}
391
392/// Merges multiple sorted `DataChunk` streams into a single sorted stream.
393pub fn merge_sorted_chunks(
394    runs: Vec<Vec<DataChunk>>,
395    keys: &[SortKey],
396    chunk_size: usize,
397) -> Vec<DataChunk> {
398    // Convert chunks to row format for merging
399    let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
400
401    let merged_rows = merge_sorted_runs(row_runs, keys);
402    rows_to_chunks(merged_rows, chunk_size)
403}
404
405/// Converts DataChunks to row format.
406fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
407    let mut rows = Vec::new();
408
409    for chunk in chunks {
410        let num_columns = chunk.num_columns();
411        for i in 0..chunk.len() {
412            let mut row = Vec::with_capacity(num_columns);
413            for col_idx in 0..num_columns {
414                let val = chunk
415                    .column(col_idx)
416                    .and_then(|c| c.get(i))
417                    .unwrap_or(Value::Null);
418                row.push(val);
419            }
420            rows.push(row);
421        }
422    }
423
424    rows
425}
426
427/// Concatenates multiple DataChunk results (for non-sorted parallel results).
428pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
429    results.into_iter().flatten().collect()
430}
431
432/// Merges parallel DISTINCT results by deduplication.
433pub fn merge_distinct_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
434    use std::collections::HashSet;
435
436    // Simple row-based deduplication using hash
437    let mut seen: HashSet<u64> = HashSet::new();
438    let mut unique_rows: Vec<Vec<Value>> = Vec::new();
439
440    for chunks in results {
441        for chunk in chunks {
442            let num_columns = chunk.num_columns();
443            for i in 0..chunk.len() {
444                let mut row = Vec::with_capacity(num_columns);
445                for col_idx in 0..num_columns {
446                    let val = chunk
447                        .column(col_idx)
448                        .and_then(|c| c.get(i))
449                        .unwrap_or(Value::Null);
450                    row.push(val);
451                }
452
453                let hash = hash_row(&row);
454                if seen.insert(hash) {
455                    unique_rows.push(row);
456                }
457            }
458        }
459    }
460
461    rows_to_chunks(unique_rows, 2048)
462}
463
464fn hash_row(row: &[Value]) -> u64 {
465    use std::collections::hash_map::DefaultHasher;
466    use std::hash::{Hash, Hasher};
467
468    let mut hasher = DefaultHasher::new();
469    for value in row {
470        match value {
471            Value::Null => 0u8.hash(&mut hasher),
472            Value::Bool(b) => b.hash(&mut hasher),
473            Value::Int64(i) => i.hash(&mut hasher),
474            Value::Float64(f) => f.to_bits().hash(&mut hasher),
475            Value::String(s) => s.hash(&mut hasher),
476            _ => 0u8.hash(&mut hasher),
477        }
478    }
479    hasher.finish()
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-column run of Int64 rows.
    fn int_run(values: &[i64]) -> Vec<Vec<Value>> {
        values.iter().map(|&v| vec![Value::Int64(v)]).collect()
    }

    #[test]
    fn test_mergeable_accumulator() {
        let mut left = MergeableAccumulator::new();
        left.add(&Value::Int64(10));
        left.add(&Value::Int64(20));

        let mut right = MergeableAccumulator::new();
        right.add(&Value::Int64(30));
        right.add(&Value::Int64(40));

        left.merge(&right);

        assert_eq!(left.count, 4);
        assert_eq!(left.sum, 100.0);
        assert_eq!(left.finalize_min(), Value::Int64(10));
        assert_eq!(left.finalize_max(), Value::Int64(40));
        assert_eq!(left.finalize_avg(), Value::Float64(25.0));
    }

    #[test]
    fn test_merge_sorted_runs_empty() {
        let runs: Vec<Vec<Vec<Value>>> = Vec::new();
        assert!(merge_sorted_runs(runs, &[]).is_empty());
    }

    #[test]
    fn test_merge_sorted_runs_single() {
        let runs = vec![int_run(&[1, 2, 3])];
        let result = merge_sorted_runs(runs, &[SortKey::ascending(0)]);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Three interleaved runs: [1,4,7], [2,5,8], [3,6,9].
        let runs = vec![int_run(&[1, 4, 7]), int_run(&[2, 5, 8]), int_run(&[3, 6, 9])];
        let result = merge_sorted_runs(runs, &[SortKey::ascending(0)]);
        assert_eq!(result.len(), 9);

        // The merge must produce 1..=9 in order.
        for (i, row) in result.iter().enumerate() {
            assert_eq!(row[0], Value::Int64((i + 1) as i64));
        }
    }

    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![int_run(&[7, 4, 1]), int_run(&[8, 5, 2])];
        let result = merge_sorted_runs(runs, &[SortKey::descending(0)]);
        assert_eq!(result.len(), 6);

        // Largest first, smallest last.
        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    #[test]
    fn test_rows_to_chunks() {
        let rows: Vec<Vec<Value>> = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3);

        // 10 rows at chunk size 3 -> chunks of 3, 3, 3, 1.
        let expected_lens = [3, 3, 3, 1];
        assert_eq!(chunks.len(), expected_lens.len());
        for (chunk, &len) in chunks.iter().zip(&expected_lens) {
            assert_eq!(chunk.len(), len);
        }
    }

    #[test]
    fn test_merge_distinct_results() {
        let first = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);
        let second = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let merged = merge_distinct_results(vec![vec![first], vec![second]]);

        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4); // 1, 2, 3, 4 (no duplicates)
    }

    #[test]
    fn test_hash_row_with_non_primitive_values() {
        // Exercises the catch-all branch in hash_row: every non-primitive
        // Value hashes through the same 0u8 tag.
        let list_one = vec![Value::List(vec![Value::Int64(1)].into())];
        let list_two = vec![Value::List(vec![Value::Int64(2)].into())];
        let bytes = vec![Value::Bytes(vec![1, 2, 3].into())];

        assert_eq!(hash_row(&list_one), hash_row(&list_two));
        assert_eq!(hash_row(&list_two), hash_row(&bytes));
    }

    #[test]
    fn test_concat_parallel_results() {
        let c1 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let c2 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let c3 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let concatenated = concat_parallel_results(vec![vec![c1], vec![c2, c3]]);
        assert_eq!(concatenated.len(), 3);
    }
}