// grafeo_core/execution/parallel/merge.rs
//! Merge utilities for parallel pipeline breakers.
//!
//! When parallel pipelines have pipeline breakers (Sort, Aggregate, Distinct),
//! each worker produces partial results that must be merged into final output.
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::execution::chunk::DataChunk;
use crate::execution::operators::OperatorError;
use crate::execution::vector::ValueVector;
use grafeo_common::types::Value;
/// Trait for operators that support parallel merge.
///
/// Pipeline breakers (Sort, Aggregate, Distinct) must implement this to
/// enable parallel execution: each worker builds partial state, and the
/// engine combines those partials pairwise via [`MergeableOperator::merge_from`].
pub trait MergeableOperator: Send + Sync {
    /// Merges partial results from another operator instance.
    ///
    /// `other` is consumed; after the call `self` holds the combined state.
    /// `Self: Sized` keeps the trait object-safe for the remaining methods.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Returns whether this operator supports parallel merge.
    ///
    /// Defaults to `true`; implementors that cannot merge override this.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
27
/// Accumulator state that supports merging.
///
/// Used by aggregate operators to merge partial aggregations produced by
/// parallel workers. Null values are ignored entirely (see `add`), so
/// `count` is the non-null count.
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    /// Count of non-null values observed.
    pub count: i64,
    /// Sum of numeric (Int64/Float64) values; non-numeric values are skipped.
    pub sum: f64,
    /// Minimum value seen so far, if any.
    pub min: Option<Value>,
    /// Maximum value seen so far, if any.
    pub max: Option<Value>,
    /// First non-null value encountered.
    pub first: Option<Value>,
    /// Sum of squared numeric values (for variance if needed).
    pub sum_squared: f64,
}
46
47impl MergeableAccumulator {
48    /// Creates a new empty accumulator.
49    #[must_use]
50    pub fn new() -> Self {
51        Self {
52            count: 0,
53            sum: 0.0,
54            min: None,
55            max: None,
56            first: None,
57            sum_squared: 0.0,
58        }
59    }
60
61    /// Adds a value to the accumulator.
62    pub fn add(&mut self, value: &Value) {
63        if matches!(value, Value::Null) {
64            return;
65        }
66
67        self.count += 1;
68
69        if let Some(n) = value_to_f64(value) {
70            self.sum += n;
71            self.sum_squared += n * n;
72        }
73
74        // Min
75        if self.min.is_none() || compare_for_min(&self.min, value) {
76            self.min = Some(value.clone());
77        }
78
79        // Max
80        if self.max.is_none() || compare_for_max(&self.max, value) {
81            self.max = Some(value.clone());
82        }
83
84        // First
85        if self.first.is_none() {
86            self.first = Some(value.clone());
87        }
88    }
89
90    /// Merges another accumulator into this one.
91    pub fn merge(&mut self, other: &MergeableAccumulator) {
92        self.count += other.count;
93        self.sum += other.sum;
94        self.sum_squared += other.sum_squared;
95
96        // Merge min
97        if let Some(ref other_min) = other.min
98            && compare_for_min(&self.min, other_min)
99        {
100            self.min = Some(other_min.clone());
101        }
102
103        // Merge max
104        if let Some(ref other_max) = other.max
105            && compare_for_max(&self.max, other_max)
106        {
107            self.max = Some(other_max.clone());
108        }
109
110        // Keep our first (we processed earlier)
111        // If we have no first, take theirs
112        if self.first.is_none() {
113            self.first = other.first.clone();
114        }
115    }
116
117    /// Finalizes COUNT aggregate.
118    #[must_use]
119    pub fn finalize_count(&self) -> Value {
120        Value::Int64(self.count)
121    }
122
123    /// Finalizes SUM aggregate.
124    #[must_use]
125    pub fn finalize_sum(&self) -> Value {
126        if self.count == 0 {
127            Value::Null
128        } else {
129            Value::Float64(self.sum)
130        }
131    }
132
133    /// Finalizes MIN aggregate.
134    #[must_use]
135    pub fn finalize_min(&self) -> Value {
136        self.min.clone().unwrap_or(Value::Null)
137    }
138
139    /// Finalizes MAX aggregate.
140    #[must_use]
141    pub fn finalize_max(&self) -> Value {
142        self.max.clone().unwrap_or(Value::Null)
143    }
144
145    /// Finalizes AVG aggregate.
146    #[must_use]
147    pub fn finalize_avg(&self) -> Value {
148        if self.count == 0 {
149            Value::Null
150        } else {
151            Value::Float64(self.sum / self.count as f64)
152        }
153    }
154
155    /// Finalizes FIRST aggregate.
156    #[must_use]
157    pub fn finalize_first(&self) -> Value {
158        self.first.clone().unwrap_or(Value::Null)
159    }
160}
161
162impl Default for MergeableAccumulator {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168fn value_to_f64(value: &Value) -> Option<f64> {
169    match value {
170        Value::Int64(i) => Some(*i as f64),
171        Value::Float64(f) => Some(*f),
172        _ => None,
173    }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177    match (current, new) {
178        (None, _) => true,
179        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181        (Some(Value::String(a)), Value::String(b)) => b < a,
182        _ => false,
183    }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187    match (current, new) {
188        (None, _) => true,
189        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191        (Some(Value::String(a)), Value::String(b)) => b > a,
192        _ => false,
193    }
194}
195
/// Sort key for k-way merge.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Index of the row column this key compares.
    pub column: usize,
    /// Sort direction (`true` = ascending).
    pub ascending: bool,
    /// Whether nulls sort before non-nulls (`true`) or after (`false`).
    pub nulls_first: bool,
}
206
207impl SortKey {
208    /// Creates an ascending sort key.
209    #[must_use]
210    pub fn ascending(column: usize) -> Self {
211        Self {
212            column,
213            ascending: true,
214            nulls_first: false,
215        }
216    }
217
218    /// Creates a descending sort key.
219    #[must_use]
220    pub fn descending(column: usize) -> Self {
221        Self {
222            column,
223            ascending: false,
224            nulls_first: true,
225        }
226    }
227}
228
/// Entry in the k-way merge heap.
///
/// Each entry carries its own copy of the sort keys so that the `Ord`
/// implementation can compare two entries without external context.
struct MergeEntry {
    /// The row's values, one per column.
    row: Vec<Value>,
    /// Index of the run this row came from (used to refill the heap).
    run_index: usize,
    /// Row index within the run (for debugging/tracing).
    #[allow(dead_code)]
    row_index: usize,
    /// Sort keys for comparison.
    keys: Vec<SortKey>,
}
241
242impl MergeEntry {
243    fn compare_to(&self, other: &Self) -> Ordering {
244        for key in &self.keys {
245            let a = self.row.get(key.column);
246            let b = other.row.get(key.column);
247
248            let ordering = compare_values_for_sort(a, b, key.nulls_first);
249
250            let ordering = if key.ascending {
251                ordering
252            } else {
253                ordering.reverse()
254            };
255
256            if ordering != Ordering::Equal {
257                return ordering;
258            }
259        }
260        Ordering::Equal
261    }
262}
263
264impl PartialEq for MergeEntry {
265    fn eq(&self, other: &Self) -> bool {
266        self.compare_to(other) == Ordering::Equal
267    }
268}
269
270impl Eq for MergeEntry {}
271
impl PartialOrd for MergeEntry {
    // Canonical delegation to `Ord` so the two orderings can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
277
impl Ord for MergeEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose: `BinaryHeap` is a max-heap, so inverting the
        // comparison (compare `other` against `self`) makes `pop()` yield
        // the smallest remaining row — min-heap behavior for the k-way merge.
        other.compare_to(self)
    }
}
284
285fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
286    match (a, b) {
287        (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
288        (None, _) | (Some(Value::Null), _) => {
289            if nulls_first {
290                Ordering::Less
291            } else {
292                Ordering::Greater
293            }
294        }
295        (_, None) | (_, Some(Value::Null)) => {
296            if nulls_first {
297                Ordering::Greater
298            } else {
299                Ordering::Less
300            }
301        }
302        (Some(a), Some(b)) => compare_values(a, b),
303    }
304}
305
306fn compare_values(a: &Value, b: &Value) -> Ordering {
307    match (a, b) {
308        (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
309        (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
310        (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
311        (Value::String(a), Value::String(b)) => a.cmp(b),
312        _ => Ordering::Equal,
313    }
314}
315
316/// Merges multiple sorted runs into a single sorted output.
317///
318/// Uses a min-heap for efficient k-way merge.
319pub fn merge_sorted_runs(
320    runs: Vec<Vec<Vec<Value>>>,
321    keys: &[SortKey],
322) -> Result<Vec<Vec<Value>>, OperatorError> {
323    if runs.is_empty() {
324        return Ok(Vec::new());
325    }
326
327    if runs.len() == 1 {
328        // Invariant: runs.len() == 1 guarantees exactly one element
329        return Ok(runs
330            .into_iter()
331            .next()
332            .expect("runs has exactly one element: checked on previous line"));
333    }
334
335    // Count total rows
336    let total_rows: usize = runs.iter().map(|r| r.len()).sum();
337    let mut result = Vec::with_capacity(total_rows);
338
339    // Track position in each run
340    let mut positions: Vec<usize> = vec![0; runs.len()];
341
342    // Initialize heap with first row from each non-empty run
343    let mut heap = BinaryHeap::new();
344    for (run_index, run) in runs.iter().enumerate() {
345        if !run.is_empty() {
346            heap.push(MergeEntry {
347                row: run[0].clone(),
348                run_index,
349                row_index: 0,
350                keys: keys.to_vec(),
351            });
352            positions[run_index] = 1;
353        }
354    }
355
356    // Extract rows in order
357    while let Some(entry) = heap.pop() {
358        result.push(entry.row);
359
360        // Add next row from same run if available
361        let pos = positions[entry.run_index];
362        if pos < runs[entry.run_index].len() {
363            heap.push(MergeEntry {
364                row: runs[entry.run_index][pos].clone(),
365                run_index: entry.run_index,
366                row_index: pos,
367                keys: keys.to_vec(),
368            });
369            positions[entry.run_index] += 1;
370        }
371    }
372
373    Ok(result)
374}
375
376/// Converts sorted rows to DataChunks.
377pub fn rows_to_chunks(
378    rows: Vec<Vec<Value>>,
379    chunk_size: usize,
380) -> Result<Vec<DataChunk>, OperatorError> {
381    if rows.is_empty() {
382        return Ok(Vec::new());
383    }
384
385    let num_columns = rows[0].len();
386    let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
387    let mut chunks = Vec::with_capacity(num_chunks);
388
389    for chunk_rows in rows.chunks(chunk_size) {
390        let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
391
392        for row in chunk_rows {
393            for (col_idx, col) in columns.iter_mut().enumerate() {
394                let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
395                col.push(val);
396            }
397        }
398
399        chunks.push(DataChunk::new(columns));
400    }
401
402    Ok(chunks)
403}
404
405/// Merges multiple sorted DataChunk streams into a single sorted stream.
406pub fn merge_sorted_chunks(
407    runs: Vec<Vec<DataChunk>>,
408    keys: &[SortKey],
409    chunk_size: usize,
410) -> Result<Vec<DataChunk>, OperatorError> {
411    // Convert chunks to row format for merging
412    let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
413
414    let merged_rows = merge_sorted_runs(row_runs, keys)?;
415    rows_to_chunks(merged_rows, chunk_size)
416}
417
418/// Converts DataChunks to row format.
419fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
420    let mut rows = Vec::new();
421
422    for chunk in chunks {
423        let num_columns = chunk.num_columns();
424        for i in 0..chunk.len() {
425            let mut row = Vec::with_capacity(num_columns);
426            for col_idx in 0..num_columns {
427                let val = chunk
428                    .column(col_idx)
429                    .and_then(|c| c.get(i))
430                    .unwrap_or(Value::Null);
431                row.push(val);
432            }
433            rows.push(row);
434        }
435    }
436
437    rows
438}
439
440/// Concatenates multiple DataChunk results (for non-sorted parallel results).
441pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
442    results.into_iter().flatten().collect()
443}
444
445/// Merges parallel DISTINCT results by deduplication.
446pub fn merge_distinct_results(
447    results: Vec<Vec<DataChunk>>,
448) -> Result<Vec<DataChunk>, OperatorError> {
449    use std::collections::HashSet;
450
451    // Simple row-based deduplication using hash
452    let mut seen: HashSet<u64> = HashSet::new();
453    let mut unique_rows: Vec<Vec<Value>> = Vec::new();
454
455    for chunks in results {
456        for chunk in chunks {
457            let num_columns = chunk.num_columns();
458            for i in 0..chunk.len() {
459                let mut row = Vec::with_capacity(num_columns);
460                for col_idx in 0..num_columns {
461                    let val = chunk
462                        .column(col_idx)
463                        .and_then(|c| c.get(i))
464                        .unwrap_or(Value::Null);
465                    row.push(val);
466                }
467
468                let hash = hash_row(&row);
469                if seen.insert(hash) {
470                    unique_rows.push(row);
471                }
472            }
473        }
474    }
475
476    rows_to_chunks(unique_rows, 2048)
477}
478
479fn hash_row(row: &[Value]) -> u64 {
480    use std::collections::hash_map::DefaultHasher;
481    use std::hash::{Hash, Hasher};
482
483    let mut hasher = DefaultHasher::new();
484    for value in row {
485        match value {
486            Value::Null => 0u8.hash(&mut hasher),
487            Value::Bool(b) => b.hash(&mut hasher),
488            Value::Int64(i) => i.hash(&mut hasher),
489            Value::Float64(f) => f.to_bits().hash(&mut hasher),
490            Value::String(s) => s.hash(&mut hasher),
491            _ => 0u8.hash(&mut hasher),
492        }
493    }
494    hasher.finish()
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds single-column rows from integer literals.
    fn int_rows(values: &[i64]) -> Vec<Vec<Value>> {
        values.iter().map(|&v| vec![Value::Int64(v)]).collect()
    }

    /// Builds a one-column chunk from integer literals.
    fn int_chunk(values: &[i64]) -> DataChunk {
        let vals: Vec<Value> = values.iter().copied().map(Value::Int64).collect();
        DataChunk::new(vec![ValueVector::from_values(&vals)])
    }

    #[test]
    fn test_mergeable_accumulator() {
        let mut left = MergeableAccumulator::new();
        left.add(&Value::Int64(10));
        left.add(&Value::Int64(20));

        let mut right = MergeableAccumulator::new();
        right.add(&Value::Int64(30));
        right.add(&Value::Int64(40));

        left.merge(&right);

        assert_eq!(left.count, 4);
        assert_eq!(left.sum, 100.0);
        assert_eq!(left.finalize_min(), Value::Int64(10));
        assert_eq!(left.finalize_max(), Value::Int64(40));
        assert_eq!(left.finalize_avg(), Value::Float64(25.0));
    }

    #[test]
    fn test_merge_sorted_runs_empty() {
        assert!(merge_sorted_runs(Vec::new(), &[]).unwrap().is_empty());
    }

    #[test]
    fn test_merge_sorted_runs_single() {
        let runs = vec![int_rows(&[1, 2, 3])];
        let result = merge_sorted_runs(runs, &[SortKey::ascending(0)]).unwrap();
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Three interleaved runs covering 1..=9.
        let runs = vec![
            int_rows(&[1, 4, 7]),
            int_rows(&[2, 5, 8]),
            int_rows(&[3, 6, 9]),
        ];
        let result = merge_sorted_runs(runs, &[SortKey::ascending(0)]).unwrap();

        assert_eq!(result.len(), 9);
        for (i, row) in result.iter().enumerate() {
            assert_eq!(row[0], Value::Int64((i + 1) as i64));
        }
    }

    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![int_rows(&[7, 4, 1]), int_rows(&[8, 5, 2])];
        let result = merge_sorted_runs(runs, &[SortKey::descending(0)]).unwrap();

        assert_eq!(result.len(), 6);
        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    #[test]
    fn test_rows_to_chunks() {
        let ids: Vec<i64> = (0..10).collect();
        let chunks = rows_to_chunks(int_rows(&ids), 3).unwrap();

        // 10 rows at chunk size 3 -> sizes 3, 3, 3, 1.
        let sizes: Vec<usize> = chunks.iter().map(DataChunk::len).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_merge_distinct_results() {
        let results = vec![vec![int_chunk(&[1, 2, 3])], vec![int_chunk(&[2, 3, 4])]];
        let merged = merge_distinct_results(results).unwrap();

        // Overlap removed: 1, 2, 3, 4.
        let total: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total, 4);
    }

    #[test]
    fn test_concat_parallel_results() {
        let results = vec![vec![int_chunk(&[1])], vec![int_chunk(&[2]), int_chunk(&[3])]];
        assert_eq!(concat_parallel_results(results).len(), 3);
    }
}