1use crate::execution::chunk::DataChunk;
7use crate::execution::operators::OperatorError;
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// Operators whose per-thread partial states can be combined after a
/// parallel execution phase.
pub trait MergeableOperator: Send + Sync {
    /// Consumes `other` and folds its accumulated state into `self`.
    ///
    /// `Self: Sized` is required because `other` is taken by value.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Whether this operator's partial results may be merged in parallel.
    /// Defaults to `true`; implementors with order-sensitive state can
    /// override this.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
27
/// Partial aggregate state built per-thread and merged afterwards.
///
/// Holds enough running state to finalize COUNT / SUM / MIN / MAX / AVG /
/// FIRST via the `finalize_*` methods below.
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    // Number of non-null values added.
    pub count: i64,
    // Running sum of numeric (Int64/Float64) values.
    pub sum: f64,
    // Smallest comparable value seen so far, if any.
    pub min: Option<Value>,
    // Largest comparable value seen so far, if any.
    pub max: Option<Value>,
    // First non-null value added, if any.
    pub first: Option<Value>,
    // Running sum of squares of numeric values — presumably for a
    // variance/stddev finalizer defined elsewhere; none is visible here.
    pub sum_squared: f64,
}
46
47impl MergeableAccumulator {
48 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 count: 0,
53 sum: 0.0,
54 min: None,
55 max: None,
56 first: None,
57 sum_squared: 0.0,
58 }
59 }
60
61 pub fn add(&mut self, value: &Value) {
63 if matches!(value, Value::Null) {
64 return;
65 }
66
67 self.count += 1;
68
69 if let Some(n) = value_to_f64(value) {
70 self.sum += n;
71 self.sum_squared += n * n;
72 }
73
74 if self.min.is_none() || compare_for_min(&self.min, value) {
76 self.min = Some(value.clone());
77 }
78
79 if self.max.is_none() || compare_for_max(&self.max, value) {
81 self.max = Some(value.clone());
82 }
83
84 if self.first.is_none() {
86 self.first = Some(value.clone());
87 }
88 }
89
90 pub fn merge(&mut self, other: &MergeableAccumulator) {
92 self.count += other.count;
93 self.sum += other.sum;
94 self.sum_squared += other.sum_squared;
95
96 if let Some(ref other_min) = other.min
98 && compare_for_min(&self.min, other_min)
99 {
100 self.min = Some(other_min.clone());
101 }
102
103 if let Some(ref other_max) = other.max
105 && compare_for_max(&self.max, other_max)
106 {
107 self.max = Some(other_max.clone());
108 }
109
110 if self.first.is_none() {
113 self.first.clone_from(&other.first);
114 }
115 }
116
117 #[must_use]
119 pub fn finalize_count(&self) -> Value {
120 Value::Int64(self.count)
121 }
122
123 #[must_use]
125 pub fn finalize_sum(&self) -> Value {
126 if self.count == 0 {
127 Value::Null
128 } else {
129 Value::Float64(self.sum)
130 }
131 }
132
133 #[must_use]
135 pub fn finalize_min(&self) -> Value {
136 self.min.clone().unwrap_or(Value::Null)
137 }
138
139 #[must_use]
141 pub fn finalize_max(&self) -> Value {
142 self.max.clone().unwrap_or(Value::Null)
143 }
144
145 #[must_use]
147 pub fn finalize_avg(&self) -> Value {
148 if self.count == 0 {
149 Value::Null
150 } else {
151 Value::Float64(self.sum / self.count as f64)
152 }
153 }
154
155 #[must_use]
157 pub fn finalize_first(&self) -> Value {
158 self.first.clone().unwrap_or(Value::Null)
159 }
160}
161
162impl Default for MergeableAccumulator {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
168fn value_to_f64(value: &Value) -> Option<f64> {
169 match value {
170 Value::Int64(i) => Some(*i as f64),
171 Value::Float64(f) => Some(*f),
172 _ => None,
173 }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177 match (current, new) {
178 (None, _) => true,
179 (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180 (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181 (Some(Value::String(a)), Value::String(b)) => b < a,
182 _ => false,
183 }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187 match (current, new) {
188 (None, _) => true,
189 (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190 (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191 (Some(Value::String(a)), Value::String(b)) => b > a,
192 _ => false,
193 }
194}
195
/// One ORDER BY key: which column to compare, the direction, and where
/// NULLs are placed.
#[derive(Debug, Clone)]
pub struct SortKey {
    // Index of the column to compare within a row.
    pub column: usize,
    // True for ascending order; false reverses the comparison result.
    pub ascending: bool,
    // True when NULLs compare as smaller than every non-null value
    // (see compare_values_for_sort).
    pub nulls_first: bool,
}
206
207impl SortKey {
208 #[must_use]
210 pub fn ascending(column: usize) -> Self {
211 Self {
212 column,
213 ascending: true,
214 nulls_first: false,
215 }
216 }
217
218 #[must_use]
220 pub fn descending(column: usize) -> Self {
221 Self {
222 column,
223 ascending: false,
224 nulls_first: true,
225 }
226 }
227}
228
/// A heap entry for the k-way merge: one row plus where it came from.
struct MergeEntry {
    // The row's values, one per column.
    row: Vec<Value>,
    // Which input run the row was taken from, so the merge loop can refill
    // the heap from the same run after popping.
    run_index: usize,
    // Owned copy of the sort keys so the entry can order itself without
    // borrowing from the merge function.
    keys: Vec<SortKey>,
}
238
239impl MergeEntry {
240 fn compare_to(&self, other: &Self) -> Ordering {
241 for key in &self.keys {
242 let a = self.row.get(key.column);
243 let b = other.row.get(key.column);
244
245 let ordering = compare_values_for_sort(a, b, key.nulls_first);
246
247 let ordering = if key.ascending {
248 ordering
249 } else {
250 ordering.reverse()
251 };
252
253 if ordering != Ordering::Equal {
254 return ordering;
255 }
256 }
257 Ordering::Equal
258 }
259}
260
// Equality/ordering for the k-way merge heap. `std::collections::BinaryHeap`
// is a max-heap, so `Ord::cmp` below is deliberately REVERSED (it compares
// `other` against `self`): popping the heap then yields the entry that is
// smallest under the sort keys, which is exactly what the merge needs.

impl PartialEq for MergeEntry {
    fn eq(&self, other: &Self) -> bool {
        self.compare_to(other) == Ordering::Equal
    }
}

impl Eq for MergeEntry {}

impl PartialOrd for MergeEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `cmp` so PartialOrd and Ord can never disagree.
        Some(self.cmp(other))
    }
}

impl Ord for MergeEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose: turns the max-heap into a min-heap.
        other.compare_to(self)
    }
}
281
282fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
283 match (a, b) {
284 (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
285 (None, _) | (Some(Value::Null), _) => {
286 if nulls_first {
287 Ordering::Less
288 } else {
289 Ordering::Greater
290 }
291 }
292 (_, None) | (_, Some(Value::Null)) => {
293 if nulls_first {
294 Ordering::Greater
295 } else {
296 Ordering::Less
297 }
298 }
299 (Some(a), Some(b)) => compare_values(a, b),
300 }
301}
302
303fn compare_values(a: &Value, b: &Value) -> Ordering {
304 match (a, b) {
305 (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
306 (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
307 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
308 (Value::String(a), Value::String(b)) => a.cmp(b),
309 _ => Ordering::Equal,
310 }
311}
312
313pub fn merge_sorted_runs(
317 runs: Vec<Vec<Vec<Value>>>,
318 keys: &[SortKey],
319) -> Result<Vec<Vec<Value>>, OperatorError> {
320 if runs.is_empty() {
321 return Ok(Vec::new());
322 }
323
324 if runs.len() == 1 {
325 return Ok(runs
327 .into_iter()
328 .next()
329 .expect("runs has exactly one element: checked on previous line"));
330 }
331
332 let total_rows: usize = runs.iter().map(|r| r.len()).sum();
334 let mut result = Vec::with_capacity(total_rows);
335
336 let mut positions: Vec<usize> = vec![0; runs.len()];
338
339 let mut heap = BinaryHeap::new();
341 for (run_index, run) in runs.iter().enumerate() {
342 if !run.is_empty() {
343 heap.push(MergeEntry {
344 row: run[0].clone(),
345 run_index,
346 keys: keys.to_vec(),
347 });
348 positions[run_index] = 1;
349 }
350 }
351
352 while let Some(entry) = heap.pop() {
354 result.push(entry.row);
355
356 let pos = positions[entry.run_index];
358 if pos < runs[entry.run_index].len() {
359 heap.push(MergeEntry {
360 row: runs[entry.run_index][pos].clone(),
361 run_index: entry.run_index,
362 keys: keys.to_vec(),
363 });
364 positions[entry.run_index] += 1;
365 }
366 }
367
368 Ok(result)
369}
370
371pub fn rows_to_chunks(
373 rows: Vec<Vec<Value>>,
374 chunk_size: usize,
375) -> Result<Vec<DataChunk>, OperatorError> {
376 if rows.is_empty() {
377 return Ok(Vec::new());
378 }
379
380 let num_columns = rows[0].len();
381 let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
382 let mut chunks = Vec::with_capacity(num_chunks);
383
384 for chunk_rows in rows.chunks(chunk_size) {
385 let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
386
387 for row in chunk_rows {
388 for (col_idx, col) in columns.iter_mut().enumerate() {
389 let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
390 col.push(val);
391 }
392 }
393
394 chunks.push(DataChunk::new(columns));
395 }
396
397 Ok(chunks)
398}
399
400pub fn merge_sorted_chunks(
402 runs: Vec<Vec<DataChunk>>,
403 keys: &[SortKey],
404 chunk_size: usize,
405) -> Result<Vec<DataChunk>, OperatorError> {
406 let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
408
409 let merged_rows = merge_sorted_runs(row_runs, keys)?;
410 rows_to_chunks(merged_rows, chunk_size)
411}
412
413fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
415 let mut rows = Vec::new();
416
417 for chunk in chunks {
418 let num_columns = chunk.num_columns();
419 for i in 0..chunk.len() {
420 let mut row = Vec::with_capacity(num_columns);
421 for col_idx in 0..num_columns {
422 let val = chunk
423 .column(col_idx)
424 .and_then(|c| c.get(i))
425 .unwrap_or(Value::Null);
426 row.push(val);
427 }
428 rows.push(row);
429 }
430 }
431
432 rows
433}
434
435pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
437 results.into_iter().flatten().collect()
438}
439
440pub fn merge_distinct_results(
442 results: Vec<Vec<DataChunk>>,
443) -> Result<Vec<DataChunk>, OperatorError> {
444 use std::collections::HashSet;
445
446 let mut seen: HashSet<u64> = HashSet::new();
448 let mut unique_rows: Vec<Vec<Value>> = Vec::new();
449
450 for chunks in results {
451 for chunk in chunks {
452 let num_columns = chunk.num_columns();
453 for i in 0..chunk.len() {
454 let mut row = Vec::with_capacity(num_columns);
455 for col_idx in 0..num_columns {
456 let val = chunk
457 .column(col_idx)
458 .and_then(|c| c.get(i))
459 .unwrap_or(Value::Null);
460 row.push(val);
461 }
462
463 let hash = hash_row(&row);
464 if seen.insert(hash) {
465 unique_rows.push(row);
466 }
467 }
468 }
469 }
470
471 rows_to_chunks(unique_rows, 2048)
472}
473
/// Hashes a row for duplicate detection.
///
/// Only Null/Bool/Int64/Float64/String contribute distinguishing bits:
/// Null and every other variant (lists, bytes, ...) all hash the same
/// constant `0u8`, so rows differing only in non-primitive values collide
/// by design (see the test covering this). Callers must therefore never
/// treat an equal hash as proof of an equal row. Float64 is hashed via its
/// bit pattern, so NaN hashes deterministically and -0.0 differs from 0.0.
fn hash_row(row: &[Value]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    for value in row {
        match value {
            // Same tag as the catch-all arm below: a NULL and e.g. a List
            // are indistinguishable to this hash.
            Value::Null => 0u8.hash(&mut hasher),
            Value::Bool(b) => b.hash(&mut hasher),
            Value::Int64(i) => i.hash(&mut hasher),
            Value::Float64(f) => f.to_bits().hash(&mut hasher),
            Value::String(s) => s.hash(&mut hasher),
            _ => 0u8.hash(&mut hasher),
        }
    }
    hasher.finish()
}
491
#[cfg(test)]
mod tests {
    use super::*;

    // COUNT/SUM/MIN/MAX/AVG survive merging two independently-built
    // accumulators.
    #[test]
    fn test_mergeable_accumulator() {
        let mut acc1 = MergeableAccumulator::new();
        acc1.add(&Value::Int64(10));
        acc1.add(&Value::Int64(20));

        let mut acc2 = MergeableAccumulator::new();
        acc2.add(&Value::Int64(30));
        acc2.add(&Value::Int64(40));

        acc1.merge(&acc2);

        assert_eq!(acc1.count, 4);
        assert_eq!(acc1.sum, 100.0);
        assert_eq!(acc1.finalize_min(), Value::Int64(10));
        assert_eq!(acc1.finalize_max(), Value::Int64(40));
        assert_eq!(acc1.finalize_avg(), Value::Float64(25.0));
    }

    // No runs at all -> no rows.
    #[test]
    fn test_merge_sorted_runs_empty() {
        let runs: Vec<Vec<Vec<Value>>> = Vec::new();
        let result = merge_sorted_runs(runs, &[]).unwrap();
        assert!(result.is_empty());
    }

    // A single run takes the fast path and comes back unchanged in length.
    #[test]
    fn test_merge_sorted_runs_single() {
        let runs = vec![vec![
            vec![Value::Int64(1)],
            vec![Value::Int64(2)],
            vec![Value::Int64(3)],
        ]];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 3);
    }

    // Three interleaved runs merge into the fully-sorted sequence 1..=9.
    #[test]
    fn test_merge_sorted_runs_multiple() {
        let runs = vec![
            vec![
                vec![Value::Int64(1)],
                vec![Value::Int64(4)],
                vec![Value::Int64(7)],
            ],
            vec![
                vec![Value::Int64(2)],
                vec![Value::Int64(5)],
                vec![Value::Int64(8)],
            ],
            vec![
                vec![Value::Int64(3)],
                vec![Value::Int64(6)],
                vec![Value::Int64(9)],
            ],
        ];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 9);

        for i in 0..9 {
            assert_eq!(result[i][0], Value::Int64((i + 1) as i64));
        }
    }

    // Descending keys: runs sorted high-to-low merge into 8,7,...,1.
    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![
            vec![
                vec![Value::Int64(7)],
                vec![Value::Int64(4)],
                vec![Value::Int64(1)],
            ],
            vec![
                vec![Value::Int64(8)],
                vec![Value::Int64(5)],
                vec![Value::Int64(2)],
            ],
        ];
        let keys = vec![SortKey::descending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 6);

        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    // 10 rows at chunk_size 3 -> chunks of 3,3,3 and a final chunk of 1.
    #[test]
    fn test_rows_to_chunks() {
        let rows = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3).unwrap();

        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].len(), 3);
        assert_eq!(chunks[1].len(), 3);
        assert_eq!(chunks[2].len(), 3);
        assert_eq!(chunks[3].len(), 1);
    }

    // {1,2,3} union {2,3,4} deduplicates to 4 distinct rows.
    #[test]
    fn test_merge_distinct_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);

        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let results = vec![vec![chunk1], vec![chunk2]];
        let merged = merge_distinct_results(results).unwrap();

        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4);
    }

    // Documents hash_row's known weakness: all non-primitive values hash
    // to the same constant, so these distinct rows intentionally collide.
    #[test]
    fn test_hash_row_with_non_primitive_values() {
        let row1 = vec![Value::List(vec![Value::Int64(1)].into())];
        let row2 = vec![Value::List(vec![Value::Int64(2)].into())];
        let row3 = vec![Value::Bytes(vec![1, 2, 3].into())];

        let h1 = hash_row(&row1);
        let h2 = hash_row(&row2);
        let h3 = hash_row(&row3);

        assert_eq!(h1, h2);
        assert_eq!(h2, h3);
    }

    // Concatenation keeps every chunk from every thread, in order.
    #[test]
    fn test_concat_parallel_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let chunk3 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let results = vec![vec![chunk1], vec![chunk2, chunk3]];
        let concatenated = concat_parallel_results(results);

        assert_eq!(concatenated.len(), 3);
    }
}