1use crate::execution::chunk::DataChunk;
7use crate::execution::operators::OperatorError;
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// An operator whose per-worker partial states can be combined into a single
/// final state after parallel execution.
pub trait MergeableOperator: Send + Sync {
    /// Merges `other`'s accumulated state into `self`, consuming `other`.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Whether partial states may be merged in any order across workers.
    /// Defaults to `true`; implementors with order-sensitive state should
    /// override this to force a serial merge.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
27
/// Partial aggregate state (COUNT/SUM/MIN/MAX/FIRST) that each worker builds
/// independently and that can later be merged via [`MergeableAccumulator::merge`].
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    /// Number of non-null values seen.
    pub count: i64,
    /// Running sum over numeric (`Int64`/`Float64`) values only.
    pub sum: f64,
    /// Smallest value seen so far, if any.
    pub min: Option<Value>,
    /// Largest value seen so far, if any.
    pub max: Option<Value>,
    /// First non-null value seen, if any.
    pub first: Option<Value>,
    /// Running sum of squares of numeric values — presumably for
    /// variance/stddev finalizers; no finalizer here reads it yet (TODO confirm).
    pub sum_squared: f64,
}
46
47impl MergeableAccumulator {
48 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 count: 0,
53 sum: 0.0,
54 min: None,
55 max: None,
56 first: None,
57 sum_squared: 0.0,
58 }
59 }
60
61 pub fn add(&mut self, value: &Value) {
63 if matches!(value, Value::Null) {
64 return;
65 }
66
67 self.count += 1;
68
69 if let Some(n) = value_to_f64(value) {
70 self.sum += n;
71 self.sum_squared += n * n;
72 }
73
74 if self.min.is_none() || compare_for_min(&self.min, value) {
76 self.min = Some(value.clone());
77 }
78
79 if self.max.is_none() || compare_for_max(&self.max, value) {
81 self.max = Some(value.clone());
82 }
83
84 if self.first.is_none() {
86 self.first = Some(value.clone());
87 }
88 }
89
90 pub fn merge(&mut self, other: &MergeableAccumulator) {
92 self.count += other.count;
93 self.sum += other.sum;
94 self.sum_squared += other.sum_squared;
95
96 if let Some(ref other_min) = other.min {
98 if compare_for_min(&self.min, other_min) {
99 self.min = Some(other_min.clone());
100 }
101 }
102
103 if let Some(ref other_max) = other.max {
105 if compare_for_max(&self.max, other_max) {
106 self.max = Some(other_max.clone());
107 }
108 }
109
110 if self.first.is_none() {
113 self.first = other.first.clone();
114 }
115 }
116
117 #[must_use]
119 pub fn finalize_count(&self) -> Value {
120 Value::Int64(self.count)
121 }
122
123 #[must_use]
125 pub fn finalize_sum(&self) -> Value {
126 if self.count == 0 {
127 Value::Null
128 } else {
129 Value::Float64(self.sum)
130 }
131 }
132
133 #[must_use]
135 pub fn finalize_min(&self) -> Value {
136 self.min.clone().unwrap_or(Value::Null)
137 }
138
139 #[must_use]
141 pub fn finalize_max(&self) -> Value {
142 self.max.clone().unwrap_or(Value::Null)
143 }
144
145 #[must_use]
147 pub fn finalize_avg(&self) -> Value {
148 if self.count == 0 {
149 Value::Null
150 } else {
151 Value::Float64(self.sum / self.count as f64)
152 }
153 }
154
155 #[must_use]
157 pub fn finalize_first(&self) -> Value {
158 self.first.clone().unwrap_or(Value::Null)
159 }
160}
161
impl Default for MergeableAccumulator {
    /// Equivalent to [`MergeableAccumulator::new`]: an empty accumulator.
    fn default() -> Self {
        Self::new()
    }
}
167
168fn value_to_f64(value: &Value) -> Option<f64> {
169 match value {
170 Value::Int64(i) => Some(*i as f64),
171 Value::Float64(f) => Some(*f),
172 _ => None,
173 }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177 match (current, new) {
178 (None, _) => true,
179 (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180 (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181 (Some(Value::String(a)), Value::String(b)) => b < a,
182 _ => false,
183 }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187 match (current, new) {
188 (None, _) => true,
189 (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190 (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191 (Some(Value::String(a)), Value::String(b)) => b > a,
192 _ => false,
193 }
194}
195
/// One column of a multi-key sort order.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Index of the column to sort by.
    pub column: usize,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
    /// Whether NULL values sort before all non-NULL values.
    pub nulls_first: bool,
}
206
207impl SortKey {
208 #[must_use]
210 pub fn ascending(column: usize) -> Self {
211 Self {
212 column,
213 ascending: true,
214 nulls_first: false,
215 }
216 }
217
218 #[must_use]
220 pub fn descending(column: usize) -> Self {
221 Self {
222 column,
223 ascending: false,
224 nulls_first: true,
225 }
226 }
227}
228
/// One candidate row in the k-way merge heap.
struct MergeEntry {
    /// The row's cell values.
    row: Vec<Value>,
    /// Which run this row came from; used to refill the heap from that run.
    run_index: usize,
    /// Position of the row within its run; not read anywhere, kept for
    /// debugging/inspection only.
    #[allow(dead_code)]
    row_index: usize,
    /// Sort keys. Each entry carries its own copy so entries can compare
    /// without borrowing shared state (one small clone per heap push).
    keys: Vec<SortKey>,
}
241
242impl MergeEntry {
243 fn compare_to(&self, other: &Self) -> Ordering {
244 for key in &self.keys {
245 let a = self.row.get(key.column);
246 let b = other.row.get(key.column);
247
248 let ordering = compare_values_for_sort(a, b, key.nulls_first);
249
250 let ordering = if key.ascending {
251 ordering
252 } else {
253 ordering.reverse()
254 };
255
256 if ordering != Ordering::Equal {
257 return ordering;
258 }
259 }
260 Ordering::Equal
261 }
262}
263
// Equality and ordering for `MergeEntry` delegate to `compare_to`, so the
// heap orders entries purely by their sort-key columns.
impl PartialEq for MergeEntry {
    fn eq(&self, other: &Self) -> bool {
        self.compare_to(other) == Ordering::Equal
    }
}

// Sound even with float columns: `compare_to` never yields an incomparable
// result (NaN and mismatched variants compare as `Equal`).
impl Eq for MergeEntry {}

impl PartialOrd for MergeEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Intentionally reversed: `BinaryHeap` is a max-heap, so inverting
        // the comparison makes `pop()` yield the SMALLEST entry first.
        other.compare_to(self)
    }
}
284
285fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
286 match (a, b) {
287 (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
288 (None, _) | (Some(Value::Null), _) => {
289 if nulls_first {
290 Ordering::Less
291 } else {
292 Ordering::Greater
293 }
294 }
295 (_, None) | (_, Some(Value::Null)) => {
296 if nulls_first {
297 Ordering::Greater
298 } else {
299 Ordering::Less
300 }
301 }
302 (Some(a), Some(b)) => compare_values(a, b),
303 }
304}
305
306fn compare_values(a: &Value, b: &Value) -> Ordering {
307 match (a, b) {
308 (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
309 (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
310 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
311 (Value::String(a), Value::String(b)) => a.cmp(b),
312 _ => Ordering::Equal,
313 }
314}
315
316pub fn merge_sorted_runs(
320 runs: Vec<Vec<Vec<Value>>>,
321 keys: &[SortKey],
322) -> Result<Vec<Vec<Value>>, OperatorError> {
323 if runs.is_empty() {
324 return Ok(Vec::new());
325 }
326
327 if runs.len() == 1 {
328 return Ok(runs
330 .into_iter()
331 .next()
332 .expect("runs has exactly one element: checked on previous line"));
333 }
334
335 let total_rows: usize = runs.iter().map(|r| r.len()).sum();
337 let mut result = Vec::with_capacity(total_rows);
338
339 let mut positions: Vec<usize> = vec![0; runs.len()];
341
342 let mut heap = BinaryHeap::new();
344 for (run_index, run) in runs.iter().enumerate() {
345 if !run.is_empty() {
346 heap.push(MergeEntry {
347 row: run[0].clone(),
348 run_index,
349 row_index: 0,
350 keys: keys.to_vec(),
351 });
352 positions[run_index] = 1;
353 }
354 }
355
356 while let Some(entry) = heap.pop() {
358 result.push(entry.row);
359
360 let pos = positions[entry.run_index];
362 if pos < runs[entry.run_index].len() {
363 heap.push(MergeEntry {
364 row: runs[entry.run_index][pos].clone(),
365 run_index: entry.run_index,
366 row_index: pos,
367 keys: keys.to_vec(),
368 });
369 positions[entry.run_index] += 1;
370 }
371 }
372
373 Ok(result)
374}
375
376pub fn rows_to_chunks(
378 rows: Vec<Vec<Value>>,
379 chunk_size: usize,
380) -> Result<Vec<DataChunk>, OperatorError> {
381 if rows.is_empty() {
382 return Ok(Vec::new());
383 }
384
385 let num_columns = rows[0].len();
386 let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
387 let mut chunks = Vec::with_capacity(num_chunks);
388
389 for chunk_rows in rows.chunks(chunk_size) {
390 let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
391
392 for row in chunk_rows {
393 for (col_idx, col) in columns.iter_mut().enumerate() {
394 let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
395 col.push(val);
396 }
397 }
398
399 chunks.push(DataChunk::new(columns));
400 }
401
402 Ok(chunks)
403}
404
405pub fn merge_sorted_chunks(
407 runs: Vec<Vec<DataChunk>>,
408 keys: &[SortKey],
409 chunk_size: usize,
410) -> Result<Vec<DataChunk>, OperatorError> {
411 let row_runs: Vec<Vec<Vec<Value>>> = runs
413 .into_iter()
414 .map(|chunks| chunks_to_rows(chunks))
415 .collect();
416
417 let merged_rows = merge_sorted_runs(row_runs, keys)?;
418 rows_to_chunks(merged_rows, chunk_size)
419}
420
421fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
423 let mut rows = Vec::new();
424
425 for chunk in chunks {
426 let num_columns = chunk.num_columns();
427 for i in 0..chunk.len() {
428 let mut row = Vec::with_capacity(num_columns);
429 for col_idx in 0..num_columns {
430 let val = chunk
431 .column(col_idx)
432 .and_then(|c| c.get(i))
433 .unwrap_or(Value::Null);
434 row.push(val);
435 }
436 rows.push(row);
437 }
438 }
439
440 rows
441}
442
443pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
445 results.into_iter().flatten().collect()
446}
447
448pub fn merge_distinct_results(
450 results: Vec<Vec<DataChunk>>,
451) -> Result<Vec<DataChunk>, OperatorError> {
452 use std::collections::HashSet;
453
454 let mut seen: HashSet<u64> = HashSet::new();
456 let mut unique_rows: Vec<Vec<Value>> = Vec::new();
457
458 for chunks in results {
459 for chunk in chunks {
460 let num_columns = chunk.num_columns();
461 for i in 0..chunk.len() {
462 let mut row = Vec::with_capacity(num_columns);
463 for col_idx in 0..num_columns {
464 let val = chunk
465 .column(col_idx)
466 .and_then(|c| c.get(i))
467 .unwrap_or(Value::Null);
468 row.push(val);
469 }
470
471 let hash = hash_row(&row);
472 if seen.insert(hash) {
473 unique_rows.push(row);
474 }
475 }
476 }
477 }
478
479 rows_to_chunks(unique_rows, 2048)
480}
481
482fn hash_row(row: &[Value]) -> u64 {
483 use std::collections::hash_map::DefaultHasher;
484 use std::hash::{Hash, Hasher};
485
486 let mut hasher = DefaultHasher::new();
487 for value in row {
488 match value {
489 Value::Null => 0u8.hash(&mut hasher),
490 Value::Bool(b) => b.hash(&mut hasher),
491 Value::Int64(i) => i.hash(&mut hasher),
492 Value::Float64(f) => f.to_bits().hash(&mut hasher),
493 Value::String(s) => s.hash(&mut hasher),
494 _ => 0u8.hash(&mut hasher),
495 }
496 }
497 hasher.finish()
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mergeable_accumulator() {
        // Two partitions accumulated independently, then merged left <- right.
        let mut left = MergeableAccumulator::new();
        left.add(&Value::Int64(10));
        left.add(&Value::Int64(20));

        let mut right = MergeableAccumulator::new();
        right.add(&Value::Int64(30));
        right.add(&Value::Int64(40));

        left.merge(&right);

        assert_eq!(left.count, 4);
        assert_eq!(left.sum, 100.0);
        assert_eq!(left.finalize_min(), Value::Int64(10));
        assert_eq!(left.finalize_max(), Value::Int64(40));
        assert_eq!(left.finalize_avg(), Value::Float64(25.0));
    }

    #[test]
    fn test_merge_sorted_runs_empty() {
        let merged = merge_sorted_runs(Vec::new(), &[]).unwrap();
        assert!(merged.is_empty());
    }

    #[test]
    fn test_merge_sorted_runs_single() {
        // A lone run is returned as-is.
        let single_run: Vec<Vec<Vec<Value>>> =
            vec![(1..=3).map(|i| vec![Value::Int64(i)]).collect()];

        let merged = merge_sorted_runs(single_run, &[SortKey::ascending(0)]).unwrap();
        assert_eq!(merged.len(), 3);
    }

    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Three interleaved runs: [1,4,7], [2,5,8], [3,6,9].
        let runs: Vec<Vec<Vec<Value>>> = (0..3)
            .map(|offset| {
                (0..3)
                    .map(|step| vec![Value::Int64(offset + 3 * step + 1)])
                    .collect()
            })
            .collect();

        let merged = merge_sorted_runs(runs, &[SortKey::ascending(0)]).unwrap();
        assert_eq!(merged.len(), 9);

        // The merge must produce 1..=9 in order.
        for (i, row) in merged.iter().enumerate() {
            assert_eq!(row[0], Value::Int64(i as i64 + 1));
        }
    }

    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs: Vec<Vec<Vec<Value>>> = vec![
            [7, 4, 1].iter().map(|&n| vec![Value::Int64(n)]).collect(),
            [8, 5, 2].iter().map(|&n| vec![Value::Int64(n)]).collect(),
        ];

        let merged = merge_sorted_runs(runs, &[SortKey::descending(0)]).unwrap();
        assert_eq!(merged.len(), 6);

        assert_eq!(merged[0][0], Value::Int64(8));
        assert_eq!(merged[1][0], Value::Int64(7));
        assert_eq!(merged[5][0], Value::Int64(1));
    }

    #[test]
    fn test_rows_to_chunks() {
        let rows: Vec<Vec<Value>> = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3).unwrap();

        // 10 rows at chunk size 3 -> three full chunks plus a remainder of 1.
        let expected_lens = [3usize, 3, 3, 1];
        assert_eq!(chunks.len(), expected_lens.len());
        for (chunk, &expected) in chunks.iter().zip(expected_lens.iter()) {
            assert_eq!(chunk.len(), expected);
        }
    }

    #[test]
    fn test_merge_distinct_results() {
        let first = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);

        let second = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let merged = merge_distinct_results(vec![vec![first], vec![second]]).unwrap();

        // The union of {1,2,3} and {2,3,4} has four distinct rows.
        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_concat_parallel_results() {
        let a = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let b = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let c = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let concatenated = concat_parallel_results(vec![vec![a], vec![b, c]]);
        assert_eq!(concatenated.len(), 3);
    }
}