1use crate::execution::chunk::DataChunk;
7use crate::execution::operators::OperatorError;
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// An operator whose per-worker partial state can be combined into a single
/// result after parallel execution.
pub trait MergeableOperator: Send + Sync {
    /// Consumes `other` and folds its accumulated state into `self`.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Whether this operator's partials may be merged after running in
    /// parallel. Defaults to `true`; implementations can opt out.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
27
/// Partial aggregate state that can be built up independently (e.g. one per
/// worker thread) and merged afterwards.
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    // Number of non-null values observed.
    pub count: i64,
    // Running sum of numeric (Int64/Float64) values.
    pub sum: f64,
    // Smallest value seen so far; None until a value arrives.
    pub min: Option<Value>,
    // Largest value seen so far; None until a value arrives.
    pub max: Option<Value>,
    // First non-null value observed.
    pub first: Option<Value>,
    // Running sum of squares of numeric values — presumably for a future
    // variance/stddev finalizer; no finalizer in this file uses it yet.
    pub sum_squared: f64,
}
46
47impl MergeableAccumulator {
48 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 count: 0,
53 sum: 0.0,
54 min: None,
55 max: None,
56 first: None,
57 sum_squared: 0.0,
58 }
59 }
60
61 pub fn add(&mut self, value: &Value) {
63 if matches!(value, Value::Null) {
64 return;
65 }
66
67 self.count += 1;
68
69 if let Some(n) = value_to_f64(value) {
70 self.sum += n;
71 self.sum_squared += n * n;
72 }
73
74 if self.min.is_none() || compare_for_min(&self.min, value) {
76 self.min = Some(value.clone());
77 }
78
79 if self.max.is_none() || compare_for_max(&self.max, value) {
81 self.max = Some(value.clone());
82 }
83
84 if self.first.is_none() {
86 self.first = Some(value.clone());
87 }
88 }
89
90 pub fn merge(&mut self, other: &MergeableAccumulator) {
92 self.count += other.count;
93 self.sum += other.sum;
94 self.sum_squared += other.sum_squared;
95
96 if let Some(ref other_min) = other.min
98 && compare_for_min(&self.min, other_min)
99 {
100 self.min = Some(other_min.clone());
101 }
102
103 if let Some(ref other_max) = other.max
105 && compare_for_max(&self.max, other_max)
106 {
107 self.max = Some(other_max.clone());
108 }
109
110 if self.first.is_none() {
113 self.first = other.first.clone();
114 }
115 }
116
117 #[must_use]
119 pub fn finalize_count(&self) -> Value {
120 Value::Int64(self.count)
121 }
122
123 #[must_use]
125 pub fn finalize_sum(&self) -> Value {
126 if self.count == 0 {
127 Value::Null
128 } else {
129 Value::Float64(self.sum)
130 }
131 }
132
133 #[must_use]
135 pub fn finalize_min(&self) -> Value {
136 self.min.clone().unwrap_or(Value::Null)
137 }
138
139 #[must_use]
141 pub fn finalize_max(&self) -> Value {
142 self.max.clone().unwrap_or(Value::Null)
143 }
144
145 #[must_use]
147 pub fn finalize_avg(&self) -> Value {
148 if self.count == 0 {
149 Value::Null
150 } else {
151 Value::Float64(self.sum / self.count as f64)
152 }
153 }
154
155 #[must_use]
157 pub fn finalize_first(&self) -> Value {
158 self.first.clone().unwrap_or(Value::Null)
159 }
160}
161
162impl Default for MergeableAccumulator {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
168fn value_to_f64(value: &Value) -> Option<f64> {
169 match value {
170 Value::Int64(i) => Some(*i as f64),
171 Value::Float64(f) => Some(*f),
172 _ => None,
173 }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177 match (current, new) {
178 (None, _) => true,
179 (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180 (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181 (Some(Value::String(a)), Value::String(b)) => b < a,
182 _ => false,
183 }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187 match (current, new) {
188 (None, _) => true,
189 (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190 (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191 (Some(Value::String(a)), Value::String(b)) => b > a,
192 _ => false,
193 }
194}
195
/// One ORDER BY key: which column to sort on, in which direction, and where
/// nulls are placed.
#[derive(Debug, Clone)]
pub struct SortKey {
    // Zero-based column index into each row.
    pub column: usize,
    // true = ascending, false = descending.
    pub ascending: bool,
    // true = nulls sort before non-nulls.
    pub nulls_first: bool,
}
206
207impl SortKey {
208 #[must_use]
210 pub fn ascending(column: usize) -> Self {
211 Self {
212 column,
213 ascending: true,
214 nulls_first: false,
215 }
216 }
217
218 #[must_use]
220 pub fn descending(column: usize) -> Self {
221 Self {
222 column,
223 ascending: false,
224 nulls_first: true,
225 }
226 }
227}
228
/// One candidate row in the k-way merge heap, tagged with the run it came
/// from so the next row of that run can be pushed after this one pops.
struct MergeEntry {
    // The row's column values.
    row: Vec<Value>,
    // Index of the source run in the `runs` vector.
    run_index: usize,
    // Position of this row within its run; currently unused by the merge
    // loop (positions are tracked externally), kept for debugging.
    #[allow(dead_code)]
    row_index: usize,
    // Sort keys cloned into every entry so Ord can be implemented without
    // external context (BinaryHeap's comparison has no side channel).
    keys: Vec<SortKey>,
}
241
242impl MergeEntry {
243 fn compare_to(&self, other: &Self) -> Ordering {
244 for key in &self.keys {
245 let a = self.row.get(key.column);
246 let b = other.row.get(key.column);
247
248 let ordering = compare_values_for_sort(a, b, key.nulls_first);
249
250 let ordering = if key.ascending {
251 ordering
252 } else {
253 ordering.reverse()
254 };
255
256 if ordering != Ordering::Equal {
257 return ordering;
258 }
259 }
260 Ordering::Equal
261 }
262}
263
impl PartialEq for MergeEntry {
    // Equality means "equal under every sort key", not structural equality
    // of the rows.
    fn eq(&self, other: &Self) -> bool {
        self.compare_to(other) == Ordering::Equal
    }
}

// Sound because compare_to is reflexive: comparing an entry to itself
// yields Equal for every key.
impl Eq for MergeEntry {}
271
impl PartialOrd for MergeEntry {
    // Delegates to `cmp` so PartialOrd and Ord can never disagree — the
    // consistency the std library requires between the two traits.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
277
impl Ord for MergeEntry {
    // Deliberately reversed (`other` compared to `self`): std's BinaryHeap
    // is a max-heap, so inverting the ordering makes pop() yield the entry
    // that sorts *first* — the min-heap behavior the k-way merge needs.
    // Do not "fix" this to self.compare_to(other).
    fn cmp(&self, other: &Self) -> Ordering {
        other.compare_to(self)
    }
}
284
285fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
286 match (a, b) {
287 (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
288 (None, _) | (Some(Value::Null), _) => {
289 if nulls_first {
290 Ordering::Less
291 } else {
292 Ordering::Greater
293 }
294 }
295 (_, None) | (_, Some(Value::Null)) => {
296 if nulls_first {
297 Ordering::Greater
298 } else {
299 Ordering::Less
300 }
301 }
302 (Some(a), Some(b)) => compare_values(a, b),
303 }
304}
305
306fn compare_values(a: &Value, b: &Value) -> Ordering {
307 match (a, b) {
308 (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
309 (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
310 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
311 (Value::String(a), Value::String(b)) => a.cmp(b),
312 _ => Ordering::Equal,
313 }
314}
315
316pub fn merge_sorted_runs(
320 runs: Vec<Vec<Vec<Value>>>,
321 keys: &[SortKey],
322) -> Result<Vec<Vec<Value>>, OperatorError> {
323 if runs.is_empty() {
324 return Ok(Vec::new());
325 }
326
327 if runs.len() == 1 {
328 return Ok(runs
330 .into_iter()
331 .next()
332 .expect("runs has exactly one element: checked on previous line"));
333 }
334
335 let total_rows: usize = runs.iter().map(|r| r.len()).sum();
337 let mut result = Vec::with_capacity(total_rows);
338
339 let mut positions: Vec<usize> = vec![0; runs.len()];
341
342 let mut heap = BinaryHeap::new();
344 for (run_index, run) in runs.iter().enumerate() {
345 if !run.is_empty() {
346 heap.push(MergeEntry {
347 row: run[0].clone(),
348 run_index,
349 row_index: 0,
350 keys: keys.to_vec(),
351 });
352 positions[run_index] = 1;
353 }
354 }
355
356 while let Some(entry) = heap.pop() {
358 result.push(entry.row);
359
360 let pos = positions[entry.run_index];
362 if pos < runs[entry.run_index].len() {
363 heap.push(MergeEntry {
364 row: runs[entry.run_index][pos].clone(),
365 run_index: entry.run_index,
366 row_index: pos,
367 keys: keys.to_vec(),
368 });
369 positions[entry.run_index] += 1;
370 }
371 }
372
373 Ok(result)
374}
375
376pub fn rows_to_chunks(
378 rows: Vec<Vec<Value>>,
379 chunk_size: usize,
380) -> Result<Vec<DataChunk>, OperatorError> {
381 if rows.is_empty() {
382 return Ok(Vec::new());
383 }
384
385 let num_columns = rows[0].len();
386 let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
387 let mut chunks = Vec::with_capacity(num_chunks);
388
389 for chunk_rows in rows.chunks(chunk_size) {
390 let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
391
392 for row in chunk_rows {
393 for (col_idx, col) in columns.iter_mut().enumerate() {
394 let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
395 col.push(val);
396 }
397 }
398
399 chunks.push(DataChunk::new(columns));
400 }
401
402 Ok(chunks)
403}
404
405pub fn merge_sorted_chunks(
407 runs: Vec<Vec<DataChunk>>,
408 keys: &[SortKey],
409 chunk_size: usize,
410) -> Result<Vec<DataChunk>, OperatorError> {
411 let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
413
414 let merged_rows = merge_sorted_runs(row_runs, keys)?;
415 rows_to_chunks(merged_rows, chunk_size)
416}
417
418fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
420 let mut rows = Vec::new();
421
422 for chunk in chunks {
423 let num_columns = chunk.num_columns();
424 for i in 0..chunk.len() {
425 let mut row = Vec::with_capacity(num_columns);
426 for col_idx in 0..num_columns {
427 let val = chunk
428 .column(col_idx)
429 .and_then(|c| c.get(i))
430 .unwrap_or(Value::Null);
431 row.push(val);
432 }
433 rows.push(row);
434 }
435 }
436
437 rows
438}
439
440pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
442 results.into_iter().flatten().collect()
443}
444
445pub fn merge_distinct_results(
447 results: Vec<Vec<DataChunk>>,
448) -> Result<Vec<DataChunk>, OperatorError> {
449 use std::collections::HashSet;
450
451 let mut seen: HashSet<u64> = HashSet::new();
453 let mut unique_rows: Vec<Vec<Value>> = Vec::new();
454
455 for chunks in results {
456 for chunk in chunks {
457 let num_columns = chunk.num_columns();
458 for i in 0..chunk.len() {
459 let mut row = Vec::with_capacity(num_columns);
460 for col_idx in 0..num_columns {
461 let val = chunk
462 .column(col_idx)
463 .and_then(|c| c.get(i))
464 .unwrap_or(Value::Null);
465 row.push(val);
466 }
467
468 let hash = hash_row(&row);
469 if seen.insert(hash) {
470 unique_rows.push(row);
471 }
472 }
473 }
474 }
475
476 rows_to_chunks(unique_rows, 2048)
477}
478
/// Hashes a row of values into a single u64 for duplicate detection.
///
/// NOTE(review): `Null` and every variant not explicitly listed below all
/// hash the same `0u8` marker, so distinct values of unlisted variants
/// collide. Callers must not treat equal hashes as proof of row equality.
fn hash_row(row: &[Value]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    for value in row {
        match value {
            Value::Null => 0u8.hash(&mut hasher),
            Value::Bool(b) => b.hash(&mut hasher),
            Value::Int64(i) => i.hash(&mut hasher),
            // to_bits gives a stable, hashable representation of f64
            // (f64 itself does not implement Hash).
            Value::Float64(f) => f.to_bits().hash(&mut hasher),
            Value::String(s) => s.hash(&mut hasher),
            _ => 0u8.hash(&mut hasher),
        }
    }
    hasher.finish()
}
496
#[cfg(test)]
mod tests {
    use super::*;

    // Merging two accumulators must behave as if all values were added to one.
    #[test]
    fn test_mergeable_accumulator() {
        let mut acc1 = MergeableAccumulator::new();
        acc1.add(&Value::Int64(10));
        acc1.add(&Value::Int64(20));

        let mut acc2 = MergeableAccumulator::new();
        acc2.add(&Value::Int64(30));
        acc2.add(&Value::Int64(40));

        acc1.merge(&acc2);

        assert_eq!(acc1.count, 4);
        assert_eq!(acc1.sum, 100.0);
        assert_eq!(acc1.finalize_min(), Value::Int64(10));
        assert_eq!(acc1.finalize_max(), Value::Int64(40));
        assert_eq!(acc1.finalize_avg(), Value::Float64(25.0));
    }

    // Zero runs merge to an empty result.
    #[test]
    fn test_merge_sorted_runs_empty() {
        let runs: Vec<Vec<Vec<Value>>> = Vec::new();
        let result = merge_sorted_runs(runs, &[]).unwrap();
        assert!(result.is_empty());
    }

    // A single run passes through unchanged (fast path, no heap).
    #[test]
    fn test_merge_sorted_runs_single() {
        let runs = vec![vec![
            vec![Value::Int64(1)],
            vec![Value::Int64(2)],
            vec![Value::Int64(3)],
        ]];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 3);
    }

    // Three interleaved runs must merge into the full sorted sequence 1..=9.
    #[test]
    fn test_merge_sorted_runs_multiple() {
        let runs = vec![
            vec![
                vec![Value::Int64(1)],
                vec![Value::Int64(4)],
                vec![Value::Int64(7)],
            ],
            vec![
                vec![Value::Int64(2)],
                vec![Value::Int64(5)],
                vec![Value::Int64(8)],
            ],
            vec![
                vec![Value::Int64(3)],
                vec![Value::Int64(6)],
                vec![Value::Int64(9)],
            ],
        ];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 9);

        for i in 0..9 {
            assert_eq!(result[i][0], Value::Int64((i + 1) as i64));
        }
    }

    // Descending keys reverse the merge order: largest value pops first.
    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![
            vec![
                vec![Value::Int64(7)],
                vec![Value::Int64(4)],
                vec![Value::Int64(1)],
            ],
            vec![
                vec![Value::Int64(8)],
                vec![Value::Int64(5)],
                vec![Value::Int64(2)],
            ],
        ];
        let keys = vec![SortKey::descending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 6);

        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    // 10 rows at chunk_size 3 -> three full chunks plus a 1-row remainder.
    #[test]
    fn test_rows_to_chunks() {
        let rows = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3).unwrap();

        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].len(), 3);
        assert_eq!(chunks[1].len(), 3);
        assert_eq!(chunks[2].len(), 3);
        assert_eq!(chunks[3].len(), 1);
    }

    // {1,2,3} U {2,3,4} deduplicates to 4 distinct rows.
    #[test]
    fn test_merge_distinct_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);

        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let results = vec![vec![chunk1], vec![chunk2]];
        let merged = merge_distinct_results(results).unwrap();

        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4);
    }

    // Concatenation keeps every chunk, duplicates and all.
    #[test]
    fn test_concat_parallel_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let chunk3 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let results = vec![vec![chunk1], vec![chunk2, chunk3]];
        let concatenated = concat_parallel_results(results);

        assert_eq!(concatenated.len(), 3);
    }
}