// grafeo_core/execution/parallel/merge.rs
//! Merge utilities for parallel pipeline breakers.
//!
//! When parallel pipelines have pipeline breakers (Sort, Aggregate, Distinct),
//! each worker produces partial results that must be merged into final output.
use crate::execution::chunk::DataChunk;
use crate::execution::operators::OperatorError;
use crate::execution::vector::ValueVector;
use grafeo_common::types::Value;
use std::cmp::Ordering;
use std::collections::BinaryHeap;

/// Trait for operators that support parallel merge.
///
/// Pipeline breakers must implement this to enable parallel execution.
pub trait MergeableOperator: Send + Sync {
    /// Merges partial results from another operator instance.
    ///
    /// Consumes `other`; the `Self: Sized` bound keeps this method off
    /// trait objects, so it is only callable on concrete operator types.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Returns whether this operator supports parallel merge.
    ///
    /// Defaults to `true`; implementors opt out by overriding.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}

/// Accumulator state that supports merging.
///
/// Used by aggregate operators to merge partial aggregations: each worker
/// fills one accumulator, partials are combined via `merge`, and results
/// are read out with the `finalize_*` methods.
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    /// Count of values. Every non-null value added is counted, numeric or not.
    pub count: i64,
    /// Sum of values. Only numeric (Int64/Float64) inputs contribute.
    pub sum: f64,
    /// Minimum value.
    pub min: Option<Value>,
    /// Maximum value.
    pub max: Option<Value>,
    /// First value encountered.
    pub first: Option<Value>,
    /// For AVG: sum of squared values (for variance if needed).
    pub sum_squared: f64,
}

47impl MergeableAccumulator {
48    /// Creates a new empty accumulator.
49    #[must_use]
50    pub fn new() -> Self {
51        Self {
52            count: 0,
53            sum: 0.0,
54            min: None,
55            max: None,
56            first: None,
57            sum_squared: 0.0,
58        }
59    }
60
61    /// Adds a value to the accumulator.
62    pub fn add(&mut self, value: &Value) {
63        if matches!(value, Value::Null) {
64            return;
65        }
66
67        self.count += 1;
68
69        if let Some(n) = value_to_f64(value) {
70            self.sum += n;
71            self.sum_squared += n * n;
72        }
73
74        // Min
75        if self.min.is_none() || compare_for_min(&self.min, value) {
76            self.min = Some(value.clone());
77        }
78
79        // Max
80        if self.max.is_none() || compare_for_max(&self.max, value) {
81            self.max = Some(value.clone());
82        }
83
84        // First
85        if self.first.is_none() {
86            self.first = Some(value.clone());
87        }
88    }
89
90    /// Merges another accumulator into this one.
91    pub fn merge(&mut self, other: &MergeableAccumulator) {
92        self.count += other.count;
93        self.sum += other.sum;
94        self.sum_squared += other.sum_squared;
95
96        // Merge min
97        if let Some(ref other_min) = other.min {
98            if compare_for_min(&self.min, other_min) {
99                self.min = Some(other_min.clone());
100            }
101        }
102
103        // Merge max
104        if let Some(ref other_max) = other.max {
105            if compare_for_max(&self.max, other_max) {
106                self.max = Some(other_max.clone());
107            }
108        }
109
110        // Keep our first (we processed earlier)
111        // If we have no first, take theirs
112        if self.first.is_none() {
113            self.first = other.first.clone();
114        }
115    }
116
117    /// Finalizes COUNT aggregate.
118    #[must_use]
119    pub fn finalize_count(&self) -> Value {
120        Value::Int64(self.count)
121    }
122
123    /// Finalizes SUM aggregate.
124    #[must_use]
125    pub fn finalize_sum(&self) -> Value {
126        if self.count == 0 {
127            Value::Null
128        } else {
129            Value::Float64(self.sum)
130        }
131    }
132
133    /// Finalizes MIN aggregate.
134    #[must_use]
135    pub fn finalize_min(&self) -> Value {
136        self.min.clone().unwrap_or(Value::Null)
137    }
138
139    /// Finalizes MAX aggregate.
140    #[must_use]
141    pub fn finalize_max(&self) -> Value {
142        self.max.clone().unwrap_or(Value::Null)
143    }
144
145    /// Finalizes AVG aggregate.
146    #[must_use]
147    pub fn finalize_avg(&self) -> Value {
148        if self.count == 0 {
149            Value::Null
150        } else {
151            Value::Float64(self.sum / self.count as f64)
152        }
153    }
154
155    /// Finalizes FIRST aggregate.
156    #[must_use]
157    pub fn finalize_first(&self) -> Value {
158        self.first.clone().unwrap_or(Value::Null)
159    }
160}
161
// Delegates to `new()` so `Default` and the explicit constructor stay in sync.
impl Default for MergeableAccumulator {
    fn default() -> Self {
        Self::new()
    }
}

168fn value_to_f64(value: &Value) -> Option<f64> {
169    match value {
170        Value::Int64(i) => Some(*i as f64),
171        Value::Float64(f) => Some(*f),
172        _ => None,
173    }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177    match (current, new) {
178        (None, _) => true,
179        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181        (Some(Value::String(a)), Value::String(b)) => b < a,
182        _ => false,
183    }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187    match (current, new) {
188        (None, _) => true,
189        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191        (Some(Value::String(a)), Value::String(b)) => b > a,
192        _ => false,
193    }
194}
195
/// Sort key for k-way merge.
///
/// One key per ORDER BY column; keys earlier in the slice take precedence.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Column index to sort by.
    pub column: usize,
    /// Sort direction (ascending = true).
    pub ascending: bool,
    /// Nulls first (true) or last (false).
    pub nulls_first: bool,
}

207impl SortKey {
208    /// Creates an ascending sort key.
209    #[must_use]
210    pub fn ascending(column: usize) -> Self {
211        Self {
212            column,
213            ascending: true,
214            nulls_first: false,
215        }
216    }
217
218    /// Creates a descending sort key.
219    #[must_use]
220    pub fn descending(column: usize) -> Self {
221        Self {
222            column,
223            ascending: false,
224            nulls_first: true,
225        }
226    }
227}
228
/// Entry in the k-way merge heap.
///
/// Each entry carries its own copy of the sort keys so the `Ord` impl can
/// compare entries without external context.
struct MergeEntry {
    /// Row data.
    row: Vec<Value>,
    /// Source run index.
    run_index: usize,
    /// Row index within the run (for debugging/tracing).
    #[allow(dead_code)]
    row_index: usize,
    /// Sort keys for comparison.
    /// NOTE(review): cloned into every entry — cheap for few keys, but a
    /// shared slice/Arc would avoid the per-row allocation; confirm if hot.
    keys: Vec<SortKey>,
}

242impl MergeEntry {
243    fn compare_to(&self, other: &Self) -> Ordering {
244        for key in &self.keys {
245            let a = self.row.get(key.column);
246            let b = other.row.get(key.column);
247
248            let ordering = compare_values_for_sort(a, b, key.nulls_first);
249
250            let ordering = if key.ascending {
251                ordering
252            } else {
253                ordering.reverse()
254            };
255
256            if ordering != Ordering::Equal {
257                return ordering;
258            }
259        }
260        Ordering::Equal
261    }
262}
263
264impl PartialEq for MergeEntry {
265    fn eq(&self, other: &Self) -> bool {
266        self.compare_to(other) == Ordering::Equal
267    }
268}
269
// Marker impl required by the `Ord` bound on `BinaryHeap` elements.
impl Eq for MergeEntry {}

impl PartialOrd for MergeEntry {
    // Canonical delegation to `Ord`, keeping both orderings consistent.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse for min-heap behavior (we want smallest first):
        // `BinaryHeap` is a max-heap, so flipping the comparison makes
        // `pop()` yield the smallest remaining row.
        other.compare_to(self)
    }
}

285fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
286    match (a, b) {
287        (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
288        (None, _) | (Some(Value::Null), _) => {
289            if nulls_first {
290                Ordering::Less
291            } else {
292                Ordering::Greater
293            }
294        }
295        (_, None) | (_, Some(Value::Null)) => {
296            if nulls_first {
297                Ordering::Greater
298            } else {
299                Ordering::Less
300            }
301        }
302        (Some(a), Some(b)) => compare_values(a, b),
303    }
304}
305
306fn compare_values(a: &Value, b: &Value) -> Ordering {
307    match (a, b) {
308        (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
309        (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
310        (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
311        (Value::String(a), Value::String(b)) => a.cmp(b),
312        _ => Ordering::Equal,
313    }
314}
315
316/// Merges multiple sorted runs into a single sorted output.
317///
318/// Uses a min-heap for efficient k-way merge.
319pub fn merge_sorted_runs(
320    runs: Vec<Vec<Vec<Value>>>,
321    keys: &[SortKey],
322) -> Result<Vec<Vec<Value>>, OperatorError> {
323    if runs.is_empty() {
324        return Ok(Vec::new());
325    }
326
327    if runs.len() == 1 {
328        // Invariant: runs.len() == 1 guarantees exactly one element
329        return Ok(runs
330            .into_iter()
331            .next()
332            .expect("runs has exactly one element: checked on previous line"));
333    }
334
335    // Count total rows
336    let total_rows: usize = runs.iter().map(|r| r.len()).sum();
337    let mut result = Vec::with_capacity(total_rows);
338
339    // Track position in each run
340    let mut positions: Vec<usize> = vec![0; runs.len()];
341
342    // Initialize heap with first row from each non-empty run
343    let mut heap = BinaryHeap::new();
344    for (run_index, run) in runs.iter().enumerate() {
345        if !run.is_empty() {
346            heap.push(MergeEntry {
347                row: run[0].clone(),
348                run_index,
349                row_index: 0,
350                keys: keys.to_vec(),
351            });
352            positions[run_index] = 1;
353        }
354    }
355
356    // Extract rows in order
357    while let Some(entry) = heap.pop() {
358        result.push(entry.row);
359
360        // Add next row from same run if available
361        let pos = positions[entry.run_index];
362        if pos < runs[entry.run_index].len() {
363            heap.push(MergeEntry {
364                row: runs[entry.run_index][pos].clone(),
365                run_index: entry.run_index,
366                row_index: pos,
367                keys: keys.to_vec(),
368            });
369            positions[entry.run_index] += 1;
370        }
371    }
372
373    Ok(result)
374}
375
376/// Converts sorted rows to DataChunks.
377pub fn rows_to_chunks(
378    rows: Vec<Vec<Value>>,
379    chunk_size: usize,
380) -> Result<Vec<DataChunk>, OperatorError> {
381    if rows.is_empty() {
382        return Ok(Vec::new());
383    }
384
385    let num_columns = rows[0].len();
386    let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
387    let mut chunks = Vec::with_capacity(num_chunks);
388
389    for chunk_rows in rows.chunks(chunk_size) {
390        let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
391
392        for row in chunk_rows {
393            for (col_idx, col) in columns.iter_mut().enumerate() {
394                let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
395                col.push(val);
396            }
397        }
398
399        chunks.push(DataChunk::new(columns));
400    }
401
402    Ok(chunks)
403}
404
405/// Merges multiple sorted DataChunk streams into a single sorted stream.
406pub fn merge_sorted_chunks(
407    runs: Vec<Vec<DataChunk>>,
408    keys: &[SortKey],
409    chunk_size: usize,
410) -> Result<Vec<DataChunk>, OperatorError> {
411    // Convert chunks to row format for merging
412    let row_runs: Vec<Vec<Vec<Value>>> = runs
413        .into_iter()
414        .map(|chunks| chunks_to_rows(chunks))
415        .collect();
416
417    let merged_rows = merge_sorted_runs(row_runs, keys)?;
418    rows_to_chunks(merged_rows, chunk_size)
419}
420
421/// Converts DataChunks to row format.
422fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
423    let mut rows = Vec::new();
424
425    for chunk in chunks {
426        let num_columns = chunk.num_columns();
427        for i in 0..chunk.len() {
428            let mut row = Vec::with_capacity(num_columns);
429            for col_idx in 0..num_columns {
430                let val = chunk
431                    .column(col_idx)
432                    .and_then(|c| c.get(i))
433                    .unwrap_or(Value::Null);
434                row.push(val);
435            }
436            rows.push(row);
437        }
438    }
439
440    rows
441}
442
443/// Concatenates multiple DataChunk results (for non-sorted parallel results).
444pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
445    results.into_iter().flatten().collect()
446}
447
448/// Merges parallel DISTINCT results by deduplication.
449pub fn merge_distinct_results(
450    results: Vec<Vec<DataChunk>>,
451) -> Result<Vec<DataChunk>, OperatorError> {
452    use std::collections::HashSet;
453
454    // Simple row-based deduplication using hash
455    let mut seen: HashSet<u64> = HashSet::new();
456    let mut unique_rows: Vec<Vec<Value>> = Vec::new();
457
458    for chunks in results {
459        for chunk in chunks {
460            let num_columns = chunk.num_columns();
461            for i in 0..chunk.len() {
462                let mut row = Vec::with_capacity(num_columns);
463                for col_idx in 0..num_columns {
464                    let val = chunk
465                        .column(col_idx)
466                        .and_then(|c| c.get(i))
467                        .unwrap_or(Value::Null);
468                    row.push(val);
469                }
470
471                let hash = hash_row(&row);
472                if seen.insert(hash) {
473                    unique_rows.push(row);
474                }
475            }
476        }
477    }
478
479    rows_to_chunks(unique_rows, 2048)
480}
481
482fn hash_row(row: &[Value]) -> u64 {
483    use std::collections::hash_map::DefaultHasher;
484    use std::hash::{Hash, Hasher};
485
486    let mut hasher = DefaultHasher::new();
487    for value in row {
488        match value {
489            Value::Null => 0u8.hash(&mut hasher),
490            Value::Bool(b) => b.hash(&mut hasher),
491            Value::Int64(i) => i.hash(&mut hasher),
492            Value::Float64(f) => f.to_bits().hash(&mut hasher),
493            Value::String(s) => s.hash(&mut hasher),
494            _ => 0u8.hash(&mut hasher),
495        }
496    }
497    hasher.finish()
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    // Two partial accumulators must combine into the same totals as one
    // sequential pass over all four values.
    #[test]
    fn test_mergeable_accumulator() {
        let mut acc1 = MergeableAccumulator::new();
        acc1.add(&Value::Int64(10));
        acc1.add(&Value::Int64(20));

        let mut acc2 = MergeableAccumulator::new();
        acc2.add(&Value::Int64(30));
        acc2.add(&Value::Int64(40));

        acc1.merge(&acc2);

        assert_eq!(acc1.count, 4);
        assert_eq!(acc1.sum, 100.0);
        assert_eq!(acc1.finalize_min(), Value::Int64(10));
        assert_eq!(acc1.finalize_max(), Value::Int64(40));
        assert_eq!(acc1.finalize_avg(), Value::Float64(25.0));
    }

    // No runs at all: empty output, no error.
    #[test]
    fn test_merge_sorted_runs_empty() {
        let runs: Vec<Vec<Vec<Value>>> = Vec::new();
        let result = merge_sorted_runs(runs, &[]).unwrap();
        assert!(result.is_empty());
    }

    // A single run takes the fast path and is returned whole.
    #[test]
    fn test_merge_sorted_runs_single() {
        let runs = vec![vec![
            vec![Value::Int64(1)],
            vec![Value::Int64(2)],
            vec![Value::Int64(3)],
        ]];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 3);
    }

    // Interleaved runs must merge into one globally sorted sequence.
    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Run 1: [1, 4, 7]
        // Run 2: [2, 5, 8]
        // Run 3: [3, 6, 9]
        let runs = vec![
            vec![
                vec![Value::Int64(1)],
                vec![Value::Int64(4)],
                vec![Value::Int64(7)],
            ],
            vec![
                vec![Value::Int64(2)],
                vec![Value::Int64(5)],
                vec![Value::Int64(8)],
            ],
            vec![
                vec![Value::Int64(3)],
                vec![Value::Int64(6)],
                vec![Value::Int64(9)],
            ],
        ];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 9);

        // Verify sorted order
        for i in 0..9 {
            assert_eq!(result[i][0], Value::Int64((i + 1) as i64));
        }
    }

    // Descending keys merge descending-sorted runs into descending output.
    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![
            vec![
                vec![Value::Int64(7)],
                vec![Value::Int64(4)],
                vec![Value::Int64(1)],
            ],
            vec![
                vec![Value::Int64(8)],
                vec![Value::Int64(5)],
                vec![Value::Int64(2)],
            ],
        ];
        let keys = vec![SortKey::descending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 6);

        // Verify descending order
        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    // Chunking splits rows into ceil(10/3) = 4 chunks, last one partial.
    #[test]
    fn test_rows_to_chunks() {
        let rows = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3).unwrap();

        assert_eq!(chunks.len(), 4); // 10 rows / 3 = 4 chunks
        assert_eq!(chunks[0].len(), 3);
        assert_eq!(chunks[1].len(), 3);
        assert_eq!(chunks[2].len(), 3);
        assert_eq!(chunks[3].len(), 1);
    }

    // Overlapping worker outputs {1,2,3} and {2,3,4} dedup to 4 rows.
    #[test]
    fn test_merge_distinct_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);

        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let results = vec![vec![chunk1], vec![chunk2]];
        let merged = merge_distinct_results(results).unwrap();

        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4); // 1, 2, 3, 4 (no duplicates)
    }

    // Concatenation keeps every chunk from every worker, in order.
    #[test]
    fn test_concat_parallel_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let chunk3 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let results = vec![vec![chunk1], vec![chunk2, chunk3]];
        let concatenated = concat_parallel_results(results);

        assert_eq!(concatenated.len(), 3);
    }
}
645}