// grafeo_core/execution/parallel/merge.rs
//! Merge utilities for parallel pipeline breakers.
//!
//! When parallel pipelines have pipeline breakers (Sort, Aggregate, Distinct),
//! each worker produces partial results that must be merged into final output.

6use crate::execution::chunk::DataChunk;
7use crate::execution::operators::OperatorError;
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// Trait implemented by operators whose per-worker partial results can be
/// folded into a single final state.
///
/// Pipeline breakers must implement this to enable parallel execution.
pub trait MergeableOperator: Send + Sync {
    /// Consumes `other`, folding its partial results into `self`.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Indicates whether this operator can be merged in parallel.
    /// Defaults to `true`.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
27
28/// Accumulator state that supports merging.
29///
30/// Used by aggregate operators to merge partial aggregations.
31#[derive(Debug, Clone)]
32pub struct MergeableAccumulator {
33    /// Count of values.
34    pub count: i64,
35    /// Sum of values.
36    pub sum: f64,
37    /// Minimum value.
38    pub min: Option<Value>,
39    /// Maximum value.
40    pub max: Option<Value>,
41    /// First value encountered.
42    pub first: Option<Value>,
43    /// For AVG: sum of squared values (for variance if needed).
44    pub sum_squared: f64,
45}
46
47impl MergeableAccumulator {
48    /// Creates a new empty accumulator.
49    #[must_use]
50    pub fn new() -> Self {
51        Self {
52            count: 0,
53            sum: 0.0,
54            min: None,
55            max: None,
56            first: None,
57            sum_squared: 0.0,
58        }
59    }
60
61    /// Adds a value to the accumulator.
62    pub fn add(&mut self, value: &Value) {
63        if matches!(value, Value::Null) {
64            return;
65        }
66
67        self.count += 1;
68
69        if let Some(n) = value_to_f64(value) {
70            self.sum += n;
71            self.sum_squared += n * n;
72        }
73
74        // Min
75        if self.min.is_none() || compare_for_min(&self.min, value) {
76            self.min = Some(value.clone());
77        }
78
79        // Max
80        if self.max.is_none() || compare_for_max(&self.max, value) {
81            self.max = Some(value.clone());
82        }
83
84        // First
85        if self.first.is_none() {
86            self.first = Some(value.clone());
87        }
88    }
89
90    /// Merges another accumulator into this one.
91    pub fn merge(&mut self, other: &MergeableAccumulator) {
92        self.count += other.count;
93        self.sum += other.sum;
94        self.sum_squared += other.sum_squared;
95
96        // Merge min
97        if let Some(ref other_min) = other.min
98            && compare_for_min(&self.min, other_min)
99        {
100            self.min = Some(other_min.clone());
101        }
102
103        // Merge max
104        if let Some(ref other_max) = other.max
105            && compare_for_max(&self.max, other_max)
106        {
107            self.max = Some(other_max.clone());
108        }
109
110        // Keep our first (we processed earlier)
111        // If we have no first, take theirs
112        if self.first.is_none() {
113            self.first.clone_from(&other.first);
114        }
115    }
116
117    /// Finalizes COUNT aggregate.
118    #[must_use]
119    pub fn finalize_count(&self) -> Value {
120        Value::Int64(self.count)
121    }
122
123    /// Finalizes SUM aggregate.
124    #[must_use]
125    pub fn finalize_sum(&self) -> Value {
126        if self.count == 0 {
127            Value::Null
128        } else {
129            Value::Float64(self.sum)
130        }
131    }
132
133    /// Finalizes MIN aggregate.
134    #[must_use]
135    pub fn finalize_min(&self) -> Value {
136        self.min.clone().unwrap_or(Value::Null)
137    }
138
139    /// Finalizes MAX aggregate.
140    #[must_use]
141    pub fn finalize_max(&self) -> Value {
142        self.max.clone().unwrap_or(Value::Null)
143    }
144
145    /// Finalizes AVG aggregate.
146    #[must_use]
147    pub fn finalize_avg(&self) -> Value {
148        if self.count == 0 {
149            Value::Null
150        } else {
151            Value::Float64(self.sum / self.count as f64)
152        }
153    }
154
155    /// Finalizes FIRST aggregate.
156    #[must_use]
157    pub fn finalize_first(&self) -> Value {
158        self.first.clone().unwrap_or(Value::Null)
159    }
160}
161
162impl Default for MergeableAccumulator {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168fn value_to_f64(value: &Value) -> Option<f64> {
169    match value {
170        Value::Int64(i) => Some(*i as f64),
171        Value::Float64(f) => Some(*f),
172        _ => None,
173    }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177    match (current, new) {
178        (None, _) => true,
179        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181        (Some(Value::String(a)), Value::String(b)) => b < a,
182        _ => false,
183    }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187    match (current, new) {
188        (None, _) => true,
189        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191        (Some(Value::String(a)), Value::String(b)) => b > a,
192        _ => false,
193    }
194}
195
/// Describes one column of a multi-key sort for the k-way merge.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Index of the column this key compares.
    pub column: usize,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
    /// Whether null values compare before non-nulls.
    pub nulls_first: bool,
}

impl SortKey {
    /// Builds an ascending key with `nulls_first` disabled.
    #[must_use]
    pub fn ascending(column: usize) -> Self {
        Self {
            column,
            ascending: true,
            nulls_first: false,
        }
    }

    /// Builds a descending key with `nulls_first` enabled.
    #[must_use]
    pub fn descending(column: usize) -> Self {
        Self {
            column,
            ascending: false,
            nulls_first: true,
        }
    }
}
228
229/// Entry in the k-way merge heap.
230struct MergeEntry {
231    /// Row data.
232    row: Vec<Value>,
233    /// Source run index.
234    run_index: usize,
235    /// Sort keys for comparison.
236    keys: Vec<SortKey>,
237}
238
239impl MergeEntry {
240    fn compare_to(&self, other: &Self) -> Ordering {
241        for key in &self.keys {
242            let a = self.row.get(key.column);
243            let b = other.row.get(key.column);
244
245            let ordering = compare_values_for_sort(a, b, key.nulls_first);
246
247            let ordering = if key.ascending {
248                ordering
249            } else {
250                ordering.reverse()
251            };
252
253            if ordering != Ordering::Equal {
254                return ordering;
255            }
256        }
257        Ordering::Equal
258    }
259}
260
261impl PartialEq for MergeEntry {
262    fn eq(&self, other: &Self) -> bool {
263        self.compare_to(other) == Ordering::Equal
264    }
265}
266
267impl Eq for MergeEntry {}
268
269impl PartialOrd for MergeEntry {
270    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
271        Some(self.cmp(other))
272    }
273}
274
275impl Ord for MergeEntry {
276    fn cmp(&self, other: &Self) -> Ordering {
277        // Reverse for min-heap behavior (we want smallest first)
278        other.compare_to(self)
279    }
280}
281
282fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
283    match (a, b) {
284        (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
285        (None, _) | (Some(Value::Null), _) => {
286            if nulls_first {
287                Ordering::Less
288            } else {
289                Ordering::Greater
290            }
291        }
292        (_, None) | (_, Some(Value::Null)) => {
293            if nulls_first {
294                Ordering::Greater
295            } else {
296                Ordering::Less
297            }
298        }
299        (Some(a), Some(b)) => compare_values(a, b),
300    }
301}
302
303fn compare_values(a: &Value, b: &Value) -> Ordering {
304    match (a, b) {
305        (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
306        (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
307        (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
308        (Value::String(a), Value::String(b)) => a.cmp(b),
309        (Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b),
310        (Value::Date(a), Value::Date(b)) => a.cmp(b),
311        (Value::Time(a), Value::Time(b)) => a.cmp(b),
312        _ => Ordering::Equal,
313    }
314}
315
316/// Merges multiple sorted runs into a single sorted output.
317///
318/// Uses a min-heap for efficient k-way merge.
319pub fn merge_sorted_runs(
320    runs: Vec<Vec<Vec<Value>>>,
321    keys: &[SortKey],
322) -> Result<Vec<Vec<Value>>, OperatorError> {
323    if runs.is_empty() {
324        return Ok(Vec::new());
325    }
326
327    if runs.len() == 1 {
328        // Invariant: runs.len() == 1 guarantees exactly one element
329        return Ok(runs
330            .into_iter()
331            .next()
332            .expect("runs has exactly one element: checked on previous line"));
333    }
334
335    // Count total rows
336    let total_rows: usize = runs.iter().map(|r| r.len()).sum();
337    let mut result = Vec::with_capacity(total_rows);
338
339    // Track position in each run
340    let mut positions: Vec<usize> = vec![0; runs.len()];
341
342    // Initialize heap with first row from each non-empty run
343    let mut heap = BinaryHeap::new();
344    for (run_index, run) in runs.iter().enumerate() {
345        if !run.is_empty() {
346            heap.push(MergeEntry {
347                row: run[0].clone(),
348                run_index,
349                keys: keys.to_vec(),
350            });
351            positions[run_index] = 1;
352        }
353    }
354
355    // Extract rows in order
356    while let Some(entry) = heap.pop() {
357        result.push(entry.row);
358
359        // Add next row from same run if available
360        let pos = positions[entry.run_index];
361        if pos < runs[entry.run_index].len() {
362            heap.push(MergeEntry {
363                row: runs[entry.run_index][pos].clone(),
364                run_index: entry.run_index,
365                keys: keys.to_vec(),
366            });
367            positions[entry.run_index] += 1;
368        }
369    }
370
371    Ok(result)
372}
373
374/// Converts sorted rows to DataChunks.
375pub fn rows_to_chunks(
376    rows: Vec<Vec<Value>>,
377    chunk_size: usize,
378) -> Result<Vec<DataChunk>, OperatorError> {
379    if rows.is_empty() {
380        return Ok(Vec::new());
381    }
382
383    let num_columns = rows[0].len();
384    let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
385    let mut chunks = Vec::with_capacity(num_chunks);
386
387    for chunk_rows in rows.chunks(chunk_size) {
388        let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
389
390        for row in chunk_rows {
391            for (col_idx, col) in columns.iter_mut().enumerate() {
392                let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
393                col.push(val);
394            }
395        }
396
397        chunks.push(DataChunk::new(columns));
398    }
399
400    Ok(chunks)
401}
402
403/// Merges multiple sorted DataChunk streams into a single sorted stream.
404pub fn merge_sorted_chunks(
405    runs: Vec<Vec<DataChunk>>,
406    keys: &[SortKey],
407    chunk_size: usize,
408) -> Result<Vec<DataChunk>, OperatorError> {
409    // Convert chunks to row format for merging
410    let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
411
412    let merged_rows = merge_sorted_runs(row_runs, keys)?;
413    rows_to_chunks(merged_rows, chunk_size)
414}
415
416/// Converts DataChunks to row format.
417fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
418    let mut rows = Vec::new();
419
420    for chunk in chunks {
421        let num_columns = chunk.num_columns();
422        for i in 0..chunk.len() {
423            let mut row = Vec::with_capacity(num_columns);
424            for col_idx in 0..num_columns {
425                let val = chunk
426                    .column(col_idx)
427                    .and_then(|c| c.get(i))
428                    .unwrap_or(Value::Null);
429                row.push(val);
430            }
431            rows.push(row);
432        }
433    }
434
435    rows
436}
437
438/// Concatenates multiple DataChunk results (for non-sorted parallel results).
439pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
440    results.into_iter().flatten().collect()
441}
442
443/// Merges parallel DISTINCT results by deduplication.
444pub fn merge_distinct_results(
445    results: Vec<Vec<DataChunk>>,
446) -> Result<Vec<DataChunk>, OperatorError> {
447    use std::collections::HashSet;
448
449    // Simple row-based deduplication using hash
450    let mut seen: HashSet<u64> = HashSet::new();
451    let mut unique_rows: Vec<Vec<Value>> = Vec::new();
452
453    for chunks in results {
454        for chunk in chunks {
455            let num_columns = chunk.num_columns();
456            for i in 0..chunk.len() {
457                let mut row = Vec::with_capacity(num_columns);
458                for col_idx in 0..num_columns {
459                    let val = chunk
460                        .column(col_idx)
461                        .and_then(|c| c.get(i))
462                        .unwrap_or(Value::Null);
463                    row.push(val);
464                }
465
466                let hash = hash_row(&row);
467                if seen.insert(hash) {
468                    unique_rows.push(row);
469                }
470            }
471        }
472    }
473
474    rows_to_chunks(unique_rows, 2048)
475}
476
477fn hash_row(row: &[Value]) -> u64 {
478    use std::collections::hash_map::DefaultHasher;
479    use std::hash::{Hash, Hasher};
480
481    let mut hasher = DefaultHasher::new();
482    for value in row {
483        match value {
484            Value::Null => 0u8.hash(&mut hasher),
485            Value::Bool(b) => b.hash(&mut hasher),
486            Value::Int64(i) => i.hash(&mut hasher),
487            Value::Float64(f) => f.to_bits().hash(&mut hasher),
488            Value::String(s) => s.hash(&mut hasher),
489            _ => 0u8.hash(&mut hasher),
490        }
491    }
492    hasher.finish()
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mergeable_accumulator() {
        let mut left = MergeableAccumulator::new();
        left.add(&Value::Int64(10));
        left.add(&Value::Int64(20));

        let mut right = MergeableAccumulator::new();
        right.add(&Value::Int64(30));
        right.add(&Value::Int64(40));

        left.merge(&right);

        assert_eq!(left.count, 4);
        assert_eq!(left.sum, 100.0);
        assert_eq!(left.finalize_min(), Value::Int64(10));
        assert_eq!(left.finalize_max(), Value::Int64(40));
        assert_eq!(left.finalize_avg(), Value::Float64(25.0));
    }

    #[test]
    fn test_merge_sorted_runs_empty() {
        let merged = merge_sorted_runs(Vec::new(), &[]).unwrap();
        assert!(merged.is_empty());
    }

    #[test]
    fn test_merge_sorted_runs_single() {
        let run: Vec<Vec<Value>> = (1..=3).map(|i| vec![Value::Int64(i)]).collect();
        let keys = [SortKey::ascending(0)];

        let merged = merge_sorted_runs(vec![run], &keys).unwrap();
        assert_eq!(merged.len(), 3);
    }

    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Three interleaved runs: [1,4,7], [2,5,8], [3,6,9].
        let runs: Vec<Vec<Vec<Value>>> = (1..=3)
            .map(|start| {
                (0..3)
                    .map(|step| vec![Value::Int64(start + step * 3)])
                    .collect()
            })
            .collect();
        let keys = [SortKey::ascending(0)];

        let merged = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(merged.len(), 9);

        // Output must be the fully interleaved ascending sequence 1..=9.
        for (i, row) in merged.iter().enumerate() {
            assert_eq!(row[0], Value::Int64(i as i64 + 1));
        }
    }

    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![
            vec![
                vec![Value::Int64(7)],
                vec![Value::Int64(4)],
                vec![Value::Int64(1)],
            ],
            vec![
                vec![Value::Int64(8)],
                vec![Value::Int64(5)],
                vec![Value::Int64(2)],
            ],
        ];
        let keys = [SortKey::descending(0)];

        let merged = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(merged.len(), 6);

        assert_eq!(merged[0][0], Value::Int64(8));
        assert_eq!(merged[1][0], Value::Int64(7));
        assert_eq!(merged[5][0], Value::Int64(1));
    }

    #[test]
    fn test_rows_to_chunks() {
        let rows: Vec<Vec<Value>> = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3).unwrap();

        // ceil(10 / 3) = 4 chunks of sizes 3, 3, 3, 1.
        let sizes: Vec<usize> = chunks.iter().map(DataChunk::len).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_merge_distinct_results() {
        let first = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);
        let second = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let merged = merge_distinct_results(vec![vec![first], vec![second]]).unwrap();

        // Duplicates 2 and 3 collapse, leaving 1, 2, 3, 4.
        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_hash_row_with_non_primitive_values() {
        // The catch-all arm in hash_row folds every non-primitive variant
        // to the same byte, so all of these hash identically.
        let list_one = vec![Value::List(vec![Value::Int64(1)].into())];
        let list_two = vec![Value::List(vec![Value::Int64(2)].into())];
        let raw_bytes = vec![Value::Bytes(vec![1, 2, 3].into())];

        assert_eq!(hash_row(&list_one), hash_row(&list_two));
        assert_eq!(hash_row(&list_two), hash_row(&raw_bytes));
    }

    #[test]
    fn test_concat_parallel_results() {
        let a = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let b = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let c = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let concatenated = concat_parallel_results(vec![vec![a], vec![b, c]]);
        assert_eq!(concatenated.len(), 3);
    }
}