1use crate::execution::chunk::DataChunk;
7use crate::execution::vector::ValueVector;
8use grafeo_common::types::Value;
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
/// An operator whose per-partition state can be combined into a single result.
///
/// Implementors run independently on disjoint portions of the input and are
/// then folded together pairwise via [`MergeableOperator::merge_from`].
pub trait MergeableOperator: Send + Sync {
    /// Folds `other`'s accumulated state into `self`, consuming `other`.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Whether this operator can be merged as part of a parallel plan.
    /// Defaults to `true`; implementors override to opt out.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
26
/// Partial aggregate state that can be filled per worker and merged afterwards.
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    /// Number of non-null values observed (includes non-numeric values).
    pub count: i64,
    /// Running sum of numeric (Int64/Float64) values only.
    pub sum: f64,
    /// Smallest value seen so far, if any (same-variant comparisons only).
    pub min: Option<Value>,
    /// Largest value seen so far, if any (same-variant comparisons only).
    pub max: Option<Value>,
    /// First non-null value added to this accumulator.
    pub first: Option<Value>,
    /// Running sum of squares of numeric values — presumably feeds a
    /// variance/stddev finalizer elsewhere; no finalizer here uses it.
    pub sum_squared: f64,
}
45
46impl MergeableAccumulator {
47 #[must_use]
49 pub fn new() -> Self {
50 Self {
51 count: 0,
52 sum: 0.0,
53 min: None,
54 max: None,
55 first: None,
56 sum_squared: 0.0,
57 }
58 }
59
60 pub fn add(&mut self, value: &Value) {
62 if matches!(value, Value::Null) {
63 return;
64 }
65
66 self.count += 1;
67
68 if let Some(n) = value_to_f64(value) {
69 self.sum += n;
70 self.sum_squared += n * n;
71 }
72
73 if self.min.is_none() || compare_for_min(&self.min, value) {
75 self.min = Some(value.clone());
76 }
77
78 if self.max.is_none() || compare_for_max(&self.max, value) {
80 self.max = Some(value.clone());
81 }
82
83 if self.first.is_none() {
85 self.first = Some(value.clone());
86 }
87 }
88
89 pub fn merge(&mut self, other: &MergeableAccumulator) {
91 self.count += other.count;
92 self.sum += other.sum;
93 self.sum_squared += other.sum_squared;
94
95 if let Some(ref other_min) = other.min
97 && compare_for_min(&self.min, other_min)
98 {
99 self.min = Some(other_min.clone());
100 }
101
102 if let Some(ref other_max) = other.max
104 && compare_for_max(&self.max, other_max)
105 {
106 self.max = Some(other_max.clone());
107 }
108
109 if self.first.is_none() {
112 self.first.clone_from(&other.first);
113 }
114 }
115
116 #[must_use]
118 pub fn finalize_count(&self) -> Value {
119 Value::Int64(self.count)
120 }
121
122 #[must_use]
124 pub fn finalize_sum(&self) -> Value {
125 if self.count == 0 {
126 Value::Null
127 } else {
128 Value::Float64(self.sum)
129 }
130 }
131
132 #[must_use]
134 pub fn finalize_min(&self) -> Value {
135 self.min.clone().unwrap_or(Value::Null)
136 }
137
138 #[must_use]
140 pub fn finalize_max(&self) -> Value {
141 self.max.clone().unwrap_or(Value::Null)
142 }
143
144 #[must_use]
146 pub fn finalize_avg(&self) -> Value {
147 if self.count == 0 {
148 Value::Null
149 } else {
150 Value::Float64(self.sum / self.count as f64)
151 }
152 }
153
154 #[must_use]
156 pub fn finalize_first(&self) -> Value {
157 self.first.clone().unwrap_or(Value::Null)
158 }
159}
160
impl Default for MergeableAccumulator {
    /// Same as [`MergeableAccumulator::new`]: an empty accumulator.
    fn default() -> Self {
        Self::new()
    }
}
166
167fn value_to_f64(value: &Value) -> Option<f64> {
168 match value {
169 Value::Int64(i) => Some(*i as f64),
170 Value::Float64(f) => Some(*f),
171 _ => None,
172 }
173}
174
175fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
176 match (current, new) {
177 (None, _) => true,
178 (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
179 (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
180 (Some(Value::String(a)), Value::String(b)) => b < a,
181 _ => false,
182 }
183}
184
185fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
186 match (current, new) {
187 (None, _) => true,
188 (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
189 (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
190 (Some(Value::String(a)), Value::String(b)) => b > a,
191 _ => false,
192 }
193}
194
/// A single ORDER BY term: which column, direction, and null placement.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Index of the column within each row.
    pub column: usize,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
    /// `true` to place null values before non-null values.
    pub nulls_first: bool,
}
205
206impl SortKey {
207 #[must_use]
209 pub fn ascending(column: usize) -> Self {
210 Self {
211 column,
212 ascending: true,
213 nulls_first: false,
214 }
215 }
216
217 #[must_use]
219 pub fn descending(column: usize) -> Self {
220 Self {
221 column,
222 ascending: false,
223 nulls_first: true,
224 }
225 }
226}
227
/// One candidate row sitting in the k-way merge heap.
struct MergeEntry {
    /// The materialized row this entry represents.
    row: Vec<Value>,
    /// Index of the source run, so the next row can be pulled from it.
    run_index: usize,
    /// Sort spec copied into each entry — `Ord` has no other way to see the
    /// keys. NOTE(review): this clones the keys per heap entry; fine for few
    /// keys, worth revisiting if key lists grow.
    keys: Vec<SortKey>,
}
237
238impl MergeEntry {
239 fn compare_to(&self, other: &Self) -> Ordering {
240 for key in &self.keys {
241 let a = self.row.get(key.column);
242 let b = other.row.get(key.column);
243
244 let ordering = compare_values_for_sort(a, b, key.nulls_first);
245
246 let ordering = if key.ascending {
247 ordering
248 } else {
249 ordering.reverse()
250 };
251
252 if ordering != Ordering::Equal {
253 return ordering;
254 }
255 }
256 Ordering::Equal
257 }
258}
259
impl PartialEq for MergeEntry {
    // Equality is defined by the sort keys only, not full row contents.
    fn eq(&self, other: &Self) -> bool {
        self.compare_to(other) == Ordering::Equal
    }
}

impl Eq for MergeEntry {}

impl PartialOrd for MergeEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeEntry {
    // Deliberately reversed (`other.compare_to(self)`): `BinaryHeap` is a
    // max-heap, so inverting the comparison turns it into a min-heap and
    // `pop()` yields the smallest entry first — exactly what the k-way
    // merge needs. Do not "fix" this to `self.compare_to(other)`.
    fn cmp(&self, other: &Self) -> Ordering {
        other.compare_to(self)
    }
}
280
281fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
282 match (a, b) {
283 (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
284 (None, _) | (Some(Value::Null), _) => {
285 if nulls_first {
286 Ordering::Less
287 } else {
288 Ordering::Greater
289 }
290 }
291 (_, None) | (_, Some(Value::Null)) => {
292 if nulls_first {
293 Ordering::Greater
294 } else {
295 Ordering::Less
296 }
297 }
298 (Some(a), Some(b)) => compare_values(a, b),
299 }
300}
301
302fn compare_values(a: &Value, b: &Value) -> Ordering {
303 match (a, b) {
304 (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
305 (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
306 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
307 (Value::String(a), Value::String(b)) => a.cmp(b),
308 (Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b),
309 (Value::Date(a), Value::Date(b)) => a.cmp(b),
310 (Value::Time(a), Value::Time(b)) => a.cmp(b),
311 _ => Ordering::Equal,
312 }
313}
314
315pub fn merge_sorted_runs(runs: Vec<Vec<Vec<Value>>>, keys: &[SortKey]) -> Vec<Vec<Value>> {
319 if runs.is_empty() {
320 return Vec::new();
321 }
322
323 if runs.len() == 1 {
324 return runs.into_iter().next().unwrap_or_default();
325 }
326
327 let total_rows: usize = runs.iter().map(|r| r.len()).sum();
329 let mut result = Vec::with_capacity(total_rows);
330
331 let mut positions: Vec<usize> = vec![0; runs.len()];
333
334 let mut heap = BinaryHeap::new();
336 for (run_index, run) in runs.iter().enumerate() {
337 if !run.is_empty() {
338 heap.push(MergeEntry {
339 row: run[0].clone(),
340 run_index,
341 keys: keys.to_vec(),
342 });
343 positions[run_index] = 1;
344 }
345 }
346
347 while let Some(entry) = heap.pop() {
349 result.push(entry.row);
350
351 let pos = positions[entry.run_index];
353 if pos < runs[entry.run_index].len() {
354 heap.push(MergeEntry {
355 row: runs[entry.run_index][pos].clone(),
356 run_index: entry.run_index,
357 keys: keys.to_vec(),
358 });
359 positions[entry.run_index] += 1;
360 }
361 }
362
363 result
364}
365
366pub fn rows_to_chunks(rows: Vec<Vec<Value>>, chunk_size: usize) -> Vec<DataChunk> {
368 if rows.is_empty() {
369 return Vec::new();
370 }
371
372 let num_columns = rows[0].len();
373 let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
374 let mut chunks = Vec::with_capacity(num_chunks);
375
376 for chunk_rows in rows.chunks(chunk_size) {
377 let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
378
379 for row in chunk_rows {
380 for (col_idx, col) in columns.iter_mut().enumerate() {
381 let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
382 col.push(val);
383 }
384 }
385
386 chunks.push(DataChunk::new(columns));
387 }
388
389 chunks
390}
391
392pub fn merge_sorted_chunks(
394 runs: Vec<Vec<DataChunk>>,
395 keys: &[SortKey],
396 chunk_size: usize,
397) -> Vec<DataChunk> {
398 let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
400
401 let merged_rows = merge_sorted_runs(row_runs, keys);
402 rows_to_chunks(merged_rows, chunk_size)
403}
404
405fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
407 let mut rows = Vec::new();
408
409 for chunk in chunks {
410 let num_columns = chunk.num_columns();
411 for i in 0..chunk.len() {
412 let mut row = Vec::with_capacity(num_columns);
413 for col_idx in 0..num_columns {
414 let val = chunk
415 .column(col_idx)
416 .and_then(|c| c.get(i))
417 .unwrap_or(Value::Null);
418 row.push(val);
419 }
420 rows.push(row);
421 }
422 }
423
424 rows
425}
426
427pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
429 results.into_iter().flatten().collect()
430}
431
432pub fn merge_distinct_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
434 use std::collections::HashSet;
435
436 let mut seen: HashSet<u64> = HashSet::new();
438 let mut unique_rows: Vec<Vec<Value>> = Vec::new();
439
440 for chunks in results {
441 for chunk in chunks {
442 let num_columns = chunk.num_columns();
443 for i in 0..chunk.len() {
444 let mut row = Vec::with_capacity(num_columns);
445 for col_idx in 0..num_columns {
446 let val = chunk
447 .column(col_idx)
448 .and_then(|c| c.get(i))
449 .unwrap_or(Value::Null);
450 row.push(val);
451 }
452
453 let hash = hash_row(&row);
454 if seen.insert(hash) {
455 unique_rows.push(row);
456 }
457 }
458 }
459 }
460
461 rows_to_chunks(unique_rows, 2048)
462}
463
/// Hashes a row for duplicate pre-filtering.
///
/// NOTE: this hash is *lossy* — every variant not listed below (lists,
/// bytes, timestamps, ...) hashes to the same sentinel byte, so unrelated
/// rows can share a hash. This behavior is pinned by
/// `test_hash_row_with_non_primitive_values`; callers must confirm equality
/// with a real comparison rather than trusting hash equality alone.
fn hash_row(row: &[Value]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    for value in row {
        match value {
            Value::Null => 0u8.hash(&mut hasher),
            Value::Bool(b) => b.hash(&mut hasher),
            Value::Int64(i) => i.hash(&mut hasher),
            // to_bits gives NaN a stable hash (f64 itself is not Hash).
            Value::Float64(f) => f.to_bits().hash(&mut hasher),
            Value::String(s) => s.hash(&mut hasher),
            // Catch-all sentinel: collides with Null and with every other
            // unlisted variant.
            _ => 0u8.hash(&mut hasher),
        }
    }
    hasher.finish()
}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mergeable_accumulator() {
        // Two independently-filled accumulators combine into one.
        let mut left = MergeableAccumulator::new();
        left.add(&Value::Int64(10));
        left.add(&Value::Int64(20));

        let mut right = MergeableAccumulator::new();
        right.add(&Value::Int64(30));
        right.add(&Value::Int64(40));

        left.merge(&right);

        assert_eq!(left.count, 4);
        assert_eq!(left.sum, 100.0);
        assert_eq!(left.finalize_min(), Value::Int64(10));
        assert_eq!(left.finalize_max(), Value::Int64(40));
        assert_eq!(left.finalize_avg(), Value::Float64(25.0));
    }

    #[test]
    fn test_merge_sorted_runs_empty() {
        // No runs at all yields no rows.
        let result = merge_sorted_runs(Vec::new(), &[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_merge_sorted_runs_single() {
        // A single run passes through unchanged.
        let only_run: Vec<Vec<Value>> = (1..=3).map(|i| vec![Value::Int64(i)]).collect();
        let result = merge_sorted_runs(vec![only_run], &[SortKey::ascending(0)]);
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Three interleaved runs: [1,4,7], [2,5,8], [3,6,9].
        let runs: Vec<Vec<Vec<Value>>> = (1..=3i64)
            .map(|start| {
                (0..3i64)
                    .map(|step| vec![Value::Int64(start + step * 3)])
                    .collect()
            })
            .collect();

        let result = merge_sorted_runs(runs, &[SortKey::ascending(0)]);
        assert_eq!(result.len(), 9);

        // The merge must produce 1..=9 in order.
        for (expected, row) in (1..=9i64).zip(&result) {
            assert_eq!(row[0], Value::Int64(expected));
        }
    }

    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![
            vec![
                vec![Value::Int64(7)],
                vec![Value::Int64(4)],
                vec![Value::Int64(1)],
            ],
            vec![
                vec![Value::Int64(8)],
                vec![Value::Int64(5)],
                vec![Value::Int64(2)],
            ],
        ];

        let result = merge_sorted_runs(runs, &[SortKey::descending(0)]);
        assert_eq!(result.len(), 6);

        // Largest first, smallest last.
        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    #[test]
    fn test_rows_to_chunks() {
        let rows: Vec<Vec<Value>> = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3);

        // 10 rows at chunk size 3 -> three full chunks plus a 1-row tail.
        let sizes: Vec<usize> = chunks.iter().map(DataChunk::len).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_merge_distinct_results() {
        let first = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);
        let second = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let merged = merge_distinct_results(vec![vec![first], vec![second]]);

        // {1,2,3} union {2,3,4} has exactly four distinct rows.
        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_hash_row_with_non_primitive_values() {
        // Pins hash_row's lossy behavior: every non-primitive variant folds
        // into the same sentinel byte, so distinct lists — and even a byte
        // blob — all share one hash. Callers must resolve equality themselves.
        let list_one = vec![Value::List(vec![Value::Int64(1)].into())];
        let list_two = vec![Value::List(vec![Value::Int64(2)].into())];
        let blob = vec![Value::Bytes(vec![1, 2, 3].into())];

        assert_eq!(hash_row(&list_one), hash_row(&list_two));
        assert_eq!(hash_row(&list_two), hash_row(&blob));
    }

    #[test]
    fn test_concat_parallel_results() {
        let single = |n: i64| DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(n)])]);

        let concatenated = concat_parallel_results(vec![vec![single(1)], vec![single(2), single(3)]]);
        assert_eq!(concatenated.len(), 3);
    }
}