1use crate::execution::chunk::DataChunk;
7use crate::execution::operators::OperatorError;
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// An operator whose per-worker partial states can be combined into one
/// final state after parallel execution.
pub trait MergeableOperator: Send + Sync {
    /// Folds `other`'s accumulated state into `self`, consuming `other`.
    fn merge_from(&mut self, other: Self)
    where
        Self: Sized;

    /// Whether this operator's partial results may be merged from parallel
    /// workers. Defaults to `true`.
    fn supports_parallel_merge(&self) -> bool {
        true
    }
}
27
/// Partial aggregate state that can be built per worker and merged.
#[derive(Debug, Clone)]
pub struct MergeableAccumulator {
    /// Number of non-null values accumulated.
    pub count: i64,
    /// Running sum over numeric (`Int64`/`Float64`) values only.
    pub sum: f64,
    /// Smallest value seen so far, if any.
    pub min: Option<Value>,
    /// Largest value seen so far, if any.
    pub max: Option<Value>,
    /// First non-null value seen.
    pub first: Option<Value>,
    /// Running sum of squared numeric values — presumably for a future
    /// variance/stddev finalizer; no finalizer reads it here (TODO confirm).
    pub sum_squared: f64,
}
46
47impl MergeableAccumulator {
48 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 count: 0,
53 sum: 0.0,
54 min: None,
55 max: None,
56 first: None,
57 sum_squared: 0.0,
58 }
59 }
60
61 pub fn add(&mut self, value: &Value) {
63 if matches!(value, Value::Null) {
64 return;
65 }
66
67 self.count += 1;
68
69 if let Some(n) = value_to_f64(value) {
70 self.sum += n;
71 self.sum_squared += n * n;
72 }
73
74 if self.min.is_none() || compare_for_min(&self.min, value) {
76 self.min = Some(value.clone());
77 }
78
79 if self.max.is_none() || compare_for_max(&self.max, value) {
81 self.max = Some(value.clone());
82 }
83
84 if self.first.is_none() {
86 self.first = Some(value.clone());
87 }
88 }
89
90 pub fn merge(&mut self, other: &MergeableAccumulator) {
92 self.count += other.count;
93 self.sum += other.sum;
94 self.sum_squared += other.sum_squared;
95
96 if let Some(ref other_min) = other.min
98 && compare_for_min(&self.min, other_min)
99 {
100 self.min = Some(other_min.clone());
101 }
102
103 if let Some(ref other_max) = other.max
105 && compare_for_max(&self.max, other_max)
106 {
107 self.max = Some(other_max.clone());
108 }
109
110 if self.first.is_none() {
113 self.first.clone_from(&other.first);
114 }
115 }
116
117 #[must_use]
119 pub fn finalize_count(&self) -> Value {
120 Value::Int64(self.count)
121 }
122
123 #[must_use]
125 pub fn finalize_sum(&self) -> Value {
126 if self.count == 0 {
127 Value::Null
128 } else {
129 Value::Float64(self.sum)
130 }
131 }
132
133 #[must_use]
135 pub fn finalize_min(&self) -> Value {
136 self.min.clone().unwrap_or(Value::Null)
137 }
138
139 #[must_use]
141 pub fn finalize_max(&self) -> Value {
142 self.max.clone().unwrap_or(Value::Null)
143 }
144
145 #[must_use]
147 pub fn finalize_avg(&self) -> Value {
148 if self.count == 0 {
149 Value::Null
150 } else {
151 Value::Float64(self.sum / self.count as f64)
152 }
153 }
154
155 #[must_use]
157 pub fn finalize_first(&self) -> Value {
158 self.first.clone().unwrap_or(Value::Null)
159 }
160}
161
162impl Default for MergeableAccumulator {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
168fn value_to_f64(value: &Value) -> Option<f64> {
169 match value {
170 Value::Int64(i) => Some(*i as f64),
171 Value::Float64(f) => Some(*f),
172 _ => None,
173 }
174}
175
176fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
177 match (current, new) {
178 (None, _) => true,
179 (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
180 (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
181 (Some(Value::String(a)), Value::String(b)) => b < a,
182 _ => false,
183 }
184}
185
186fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
187 match (current, new) {
188 (None, _) => true,
189 (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
190 (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
191 (Some(Value::String(a)), Value::String(b)) => b > a,
192 _ => false,
193 }
194}
195
/// One ORDER BY key: which column to sort on and how.
#[derive(Debug, Clone)]
pub struct SortKey {
    /// Index of the column within each row.
    pub column: usize,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
    /// Whether NULL values compare as smallest (before the direction
    /// reversal for descending keys is applied).
    pub nulls_first: bool,
}
206
207impl SortKey {
208 #[must_use]
210 pub fn ascending(column: usize) -> Self {
211 Self {
212 column,
213 ascending: true,
214 nulls_first: false,
215 }
216 }
217
218 #[must_use]
220 pub fn descending(column: usize) -> Self {
221 Self {
222 column,
223 ascending: false,
224 nulls_first: true,
225 }
226 }
227}
228
/// One candidate row inside the k-way merge heap.
struct MergeEntry {
    /// The row itself (owned so it can be pushed straight into the result).
    row: Vec<Value>,
    /// Which input run the row came from, so the next row of that run can
    /// be fetched after this one is popped.
    run_index: usize,
    /// Sort keys used for comparison. NOTE(review): cloned into every heap
    /// entry — a shared slice/Arc would avoid the per-row allocation.
    keys: Vec<SortKey>,
}
238
239impl MergeEntry {
240 fn compare_to(&self, other: &Self) -> Ordering {
241 for key in &self.keys {
242 let a = self.row.get(key.column);
243 let b = other.row.get(key.column);
244
245 let ordering = compare_values_for_sort(a, b, key.nulls_first);
246
247 let ordering = if key.ascending {
248 ordering
249 } else {
250 ordering.reverse()
251 };
252
253 if ordering != Ordering::Equal {
254 return ordering;
255 }
256 }
257 Ordering::Equal
258 }
259}
260
// Equality and ordering of heap entries are defined purely by the sort
// keys, not by full row contents: rows that tie on every key are "equal".
impl PartialEq for MergeEntry {
    fn eq(&self, other: &Self) -> bool {
        self.compare_to(other) == Ordering::Equal
    }
}

impl Eq for MergeEntry {}

impl PartialOrd for MergeEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Deliberately reversed (`other` vs `self`): std's BinaryHeap is a
        // max-heap, so flipping the comparison makes `pop()` return the
        // SMALLEST entry first — exactly what the k-way merge needs.
        other.compare_to(self)
    }
}
281
282fn compare_values_for_sort(a: Option<&Value>, b: Option<&Value>, nulls_first: bool) -> Ordering {
283 match (a, b) {
284 (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
285 (None, _) | (Some(Value::Null), _) => {
286 if nulls_first {
287 Ordering::Less
288 } else {
289 Ordering::Greater
290 }
291 }
292 (_, None) | (_, Some(Value::Null)) => {
293 if nulls_first {
294 Ordering::Greater
295 } else {
296 Ordering::Less
297 }
298 }
299 (Some(a), Some(b)) => compare_values(a, b),
300 }
301}
302
303fn compare_values(a: &Value, b: &Value) -> Ordering {
304 match (a, b) {
305 (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
306 (Value::Int64(a), Value::Int64(b)) => a.cmp(b),
307 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
308 (Value::String(a), Value::String(b)) => a.cmp(b),
309 (Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b),
310 (Value::Date(a), Value::Date(b)) => a.cmp(b),
311 (Value::Time(a), Value::Time(b)) => a.cmp(b),
312 _ => Ordering::Equal,
313 }
314}
315
316pub fn merge_sorted_runs(
320 runs: Vec<Vec<Vec<Value>>>,
321 keys: &[SortKey],
322) -> Result<Vec<Vec<Value>>, OperatorError> {
323 if runs.is_empty() {
324 return Ok(Vec::new());
325 }
326
327 if runs.len() == 1 {
328 return Ok(runs
330 .into_iter()
331 .next()
332 .expect("runs has exactly one element: checked on previous line"));
333 }
334
335 let total_rows: usize = runs.iter().map(|r| r.len()).sum();
337 let mut result = Vec::with_capacity(total_rows);
338
339 let mut positions: Vec<usize> = vec![0; runs.len()];
341
342 let mut heap = BinaryHeap::new();
344 for (run_index, run) in runs.iter().enumerate() {
345 if !run.is_empty() {
346 heap.push(MergeEntry {
347 row: run[0].clone(),
348 run_index,
349 keys: keys.to_vec(),
350 });
351 positions[run_index] = 1;
352 }
353 }
354
355 while let Some(entry) = heap.pop() {
357 result.push(entry.row);
358
359 let pos = positions[entry.run_index];
361 if pos < runs[entry.run_index].len() {
362 heap.push(MergeEntry {
363 row: runs[entry.run_index][pos].clone(),
364 run_index: entry.run_index,
365 keys: keys.to_vec(),
366 });
367 positions[entry.run_index] += 1;
368 }
369 }
370
371 Ok(result)
372}
373
374pub fn rows_to_chunks(
376 rows: Vec<Vec<Value>>,
377 chunk_size: usize,
378) -> Result<Vec<DataChunk>, OperatorError> {
379 if rows.is_empty() {
380 return Ok(Vec::new());
381 }
382
383 let num_columns = rows[0].len();
384 let num_chunks = (rows.len() + chunk_size - 1) / chunk_size;
385 let mut chunks = Vec::with_capacity(num_chunks);
386
387 for chunk_rows in rows.chunks(chunk_size) {
388 let mut columns: Vec<ValueVector> = (0..num_columns).map(|_| ValueVector::new()).collect();
389
390 for row in chunk_rows {
391 for (col_idx, col) in columns.iter_mut().enumerate() {
392 let val = row.get(col_idx).cloned().unwrap_or(Value::Null);
393 col.push(val);
394 }
395 }
396
397 chunks.push(DataChunk::new(columns));
398 }
399
400 Ok(chunks)
401}
402
403pub fn merge_sorted_chunks(
405 runs: Vec<Vec<DataChunk>>,
406 keys: &[SortKey],
407 chunk_size: usize,
408) -> Result<Vec<DataChunk>, OperatorError> {
409 let row_runs: Vec<Vec<Vec<Value>>> = runs.into_iter().map(chunks_to_rows).collect();
411
412 let merged_rows = merge_sorted_runs(row_runs, keys)?;
413 rows_to_chunks(merged_rows, chunk_size)
414}
415
416fn chunks_to_rows(chunks: Vec<DataChunk>) -> Vec<Vec<Value>> {
418 let mut rows = Vec::new();
419
420 for chunk in chunks {
421 let num_columns = chunk.num_columns();
422 for i in 0..chunk.len() {
423 let mut row = Vec::with_capacity(num_columns);
424 for col_idx in 0..num_columns {
425 let val = chunk
426 .column(col_idx)
427 .and_then(|c| c.get(i))
428 .unwrap_or(Value::Null);
429 row.push(val);
430 }
431 rows.push(row);
432 }
433 }
434
435 rows
436}
437
438pub fn concat_parallel_results(results: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
440 results.into_iter().flatten().collect()
441}
442
443pub fn merge_distinct_results(
445 results: Vec<Vec<DataChunk>>,
446) -> Result<Vec<DataChunk>, OperatorError> {
447 use std::collections::HashSet;
448
449 let mut seen: HashSet<u64> = HashSet::new();
451 let mut unique_rows: Vec<Vec<Value>> = Vec::new();
452
453 for chunks in results {
454 for chunk in chunks {
455 let num_columns = chunk.num_columns();
456 for i in 0..chunk.len() {
457 let mut row = Vec::with_capacity(num_columns);
458 for col_idx in 0..num_columns {
459 let val = chunk
460 .column(col_idx)
461 .and_then(|c| c.get(i))
462 .unwrap_or(Value::Null);
463 row.push(val);
464 }
465
466 let hash = hash_row(&row);
467 if seen.insert(hash) {
468 unique_rows.push(row);
469 }
470 }
471 }
472 }
473
474 rows_to_chunks(unique_rows, 2048)
475}
476
477fn hash_row(row: &[Value]) -> u64 {
478 use std::collections::hash_map::DefaultHasher;
479 use std::hash::{Hash, Hasher};
480
481 let mut hasher = DefaultHasher::new();
482 for value in row {
483 match value {
484 Value::Null => 0u8.hash(&mut hasher),
485 Value::Bool(b) => b.hash(&mut hasher),
486 Value::Int64(i) => i.hash(&mut hasher),
487 Value::Float64(f) => f.to_bits().hash(&mut hasher),
488 Value::String(s) => s.hash(&mut hasher),
489 _ => 0u8.hash(&mut hasher),
490 }
491 }
492 hasher.finish()
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mergeable_accumulator() {
        let mut acc1 = MergeableAccumulator::new();
        acc1.add(&Value::Int64(10));
        acc1.add(&Value::Int64(20));

        let mut acc2 = MergeableAccumulator::new();
        acc2.add(&Value::Int64(30));
        acc2.add(&Value::Int64(40));

        acc1.merge(&acc2);

        assert_eq!(acc1.count, 4);
        assert_eq!(acc1.sum, 100.0);
        assert_eq!(acc1.finalize_min(), Value::Int64(10));
        assert_eq!(acc1.finalize_max(), Value::Int64(40));
        assert_eq!(acc1.finalize_avg(), Value::Float64(25.0));
    }

    #[test]
    fn test_merge_sorted_runs_empty() {
        let runs: Vec<Vec<Vec<Value>>> = Vec::new();
        let result = merge_sorted_runs(runs, &[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_merge_sorted_runs_single() {
        let runs = vec![vec![
            vec![Value::Int64(1)],
            vec![Value::Int64(2)],
            vec![Value::Int64(3)],
        ]];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_merge_sorted_runs_multiple() {
        // Three interleaved runs; the merge must yield 1..=9 in order.
        let runs = vec![
            vec![
                vec![Value::Int64(1)],
                vec![Value::Int64(4)],
                vec![Value::Int64(7)],
            ],
            vec![
                vec![Value::Int64(2)],
                vec![Value::Int64(5)],
                vec![Value::Int64(8)],
            ],
            vec![
                vec![Value::Int64(3)],
                vec![Value::Int64(6)],
                vec![Value::Int64(9)],
            ],
        ];
        let keys = vec![SortKey::ascending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 9);

        for i in 0..9 {
            assert_eq!(result[i][0], Value::Int64((i + 1) as i64));
        }
    }

    #[test]
    fn test_merge_sorted_runs_descending() {
        let runs = vec![
            vec![
                vec![Value::Int64(7)],
                vec![Value::Int64(4)],
                vec![Value::Int64(1)],
            ],
            vec![
                vec![Value::Int64(8)],
                vec![Value::Int64(5)],
                vec![Value::Int64(2)],
            ],
        ];
        let keys = vec![SortKey::descending(0)];

        let result = merge_sorted_runs(runs, &keys).unwrap();
        assert_eq!(result.len(), 6);

        assert_eq!(result[0][0], Value::Int64(8));
        assert_eq!(result[1][0], Value::Int64(7));
        assert_eq!(result[5][0], Value::Int64(1));
    }

    #[test]
    fn test_rows_to_chunks() {
        let rows = (0..10).map(|i| vec![Value::Int64(i)]).collect();
        let chunks = rows_to_chunks(rows, 3).unwrap();

        // Ten rows at chunk size 3: three full chunks plus a one-row tail.
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].len(), 3);
        assert_eq!(chunks[1].len(), 3);
        assert_eq!(chunks[2].len(), 3);
        assert_eq!(chunks[3].len(), 1);
    }

    #[test]
    fn test_merge_distinct_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(1),
            Value::Int64(2),
            Value::Int64(3),
        ])]);

        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[
            Value::Int64(2),
            Value::Int64(3),
            Value::Int64(4),
        ])]);

        let results = vec![vec![chunk1], vec![chunk2]];
        let merged = merge_distinct_results(results).unwrap();

        // The union of {1,2,3} and {2,3,4} has four distinct values.
        let total_rows: usize = merged.iter().map(DataChunk::len).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_hash_row_with_non_primitive_values() {
        // Documents a known limitation: hash_row collapses every
        // non-primitive value (lists, bytes, ...) to the same placeholder
        // byte, so all three rows hash identically.
        let row1 = vec![Value::List(vec![Value::Int64(1)].into())];
        let row2 = vec![Value::List(vec![Value::Int64(2)].into())];
        let row3 = vec![Value::Bytes(vec![1, 2, 3].into())];

        let h1 = hash_row(&row1);
        let h2 = hash_row(&row2);
        let h3 = hash_row(&row3);

        assert_eq!(h1, h2);
        assert_eq!(h2, h3);
    }

    #[test]
    fn test_concat_parallel_results() {
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(1)])]);
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(2)])]);
        let chunk3 = DataChunk::new(vec![ValueVector::from_values(&[Value::Int64(3)])]);

        let results = vec![vec![chunk1], vec![chunk2, chunk3]];
        let concatenated = concat_parallel_results(results);

        assert_eq!(concatenated.len(), 3);
    }
}