// oxigdal_distributed/shuffle.rs

//! Data shuffle operations for distributed processing.
//!
//! This module provides shuffle operations for redistributing data across
//! worker nodes, supporting operations like group-by, sort, and join.

use crate::error::{DistributedError, Result};
use crate::partition::{HashPartitioner, RangePartitioner};
use crate::task::PartitionId;
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::compute;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
use std::collections::HashMap;
/// Kind of data redistribution performed by a shuffle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShuffleType {
    /// Hash-based shuffle for group-by operations.
    Hash,
    /// Range-based shuffle for sorting.
    Range,
    /// Broadcast shuffle (send same data to all workers).
    Broadcast,
    /// Custom shuffle with user-defined logic.
    Custom,
}

/// Specification of which data determines a row's destination partition.
#[derive(Debug, Clone)]
pub enum ShuffleKey {
    /// Shuffle by a single column.
    Column(String),
    /// Shuffle by multiple columns.
    Columns(Vec<String>),
    /// Shuffle by a computed expression.
    Expression(String),
}

39/// Configuration for shuffle operations.
40#[derive(Debug, Clone)]
41pub struct ShuffleConfig {
42    /// Type of shuffle.
43    pub shuffle_type: ShuffleType,
44    /// Key to shuffle by.
45    pub key: ShuffleKey,
46    /// Number of target partitions.
47    pub num_partitions: usize,
48    /// Buffer size for shuffle writes.
49    pub buffer_size: usize,
50}
51
52impl ShuffleConfig {
53    /// Create a new shuffle configuration.
54    pub fn new(shuffle_type: ShuffleType, key: ShuffleKey, num_partitions: usize) -> Result<Self> {
55        if num_partitions == 0 {
56            return Err(DistributedError::shuffle(
57                "Number of partitions must be greater than zero",
58            ));
59        }
60
61        Ok(Self {
62            shuffle_type,
63            key,
64            num_partitions,
65            buffer_size: 1024 * 1024, // 1 MB default
66        })
67    }
68
69    /// Set the buffer size.
70    pub fn with_buffer_size(mut self, size: usize) -> Self {
71        self.buffer_size = size;
72        self
73    }
74}
75
76/// Result of a shuffle operation.
77pub struct ShuffleResult {
78    /// Partitioned data, keyed by partition ID.
79    pub partitions: HashMap<PartitionId, Vec<RecordBatch>>,
80    /// Statistics about the shuffle.
81    pub stats: ShuffleStats,
82}
83
/// Summary statistics describing a completed shuffle operation.
#[derive(Debug, Clone, Default)]
pub struct ShuffleStats {
    /// Total number of rows shuffled.
    pub total_rows: u64,
    /// Total bytes shuffled.
    pub total_bytes: u64,
    /// Number of output partitions.
    pub num_partitions: usize,
    /// Time taken in milliseconds.
    pub duration_ms: u64,
}

97/// Hash-based shuffle implementation.
98pub struct HashShuffle {
99    /// Partitioner for hash-based distribution.
100    partitioner: HashPartitioner,
101    /// Column name to hash.
102    column_name: String,
103}
104
105impl HashShuffle {
106    /// Create a new hash shuffle.
107    pub fn new(column_name: String, num_partitions: usize) -> Result<Self> {
108        let partitioner = HashPartitioner::new(num_partitions)?;
109        Ok(Self {
110            partitioner,
111            column_name,
112        })
113    }
114
115    /// Shuffle a record batch.
116    pub fn shuffle(&self, batch: &RecordBatch) -> Result<HashMap<PartitionId, RecordBatch>> {
117        let schema = batch.schema();
118
119        // Find the column to hash
120        let column_index = schema
121            .column_with_name(&self.column_name)
122            .map(|(idx, _)| idx)
123            .ok_or_else(|| {
124                DistributedError::shuffle(format!("Column {} not found", self.column_name))
125            })?;
126
127        let column = batch.column(column_index);
128
129        // Compute partition for each row
130        let partitions = self.compute_partitions(column)?;
131
132        // Group rows by partition
133        let mut partition_indices: HashMap<PartitionId, Vec<usize>> = HashMap::new();
134        for (row_idx, &partition_id) in partitions.iter().enumerate() {
135            partition_indices
136                .entry(partition_id)
137                .or_default()
138                .push(row_idx);
139        }
140
141        // Create a record batch for each partition
142        let mut result = HashMap::new();
143        for (partition_id, indices) in partition_indices {
144            let partition_batch = self.create_partition_batch(batch, &indices)?;
145            result.insert(partition_id, partition_batch);
146        }
147
148        Ok(result)
149    }
150
151    /// Compute partition for each value in a column.
152    fn compute_partitions(&self, column: &ArrayRef) -> Result<Vec<PartitionId>> {
153        let mut partitions = Vec::with_capacity(column.len());
154
155        match column.data_type() {
156            DataType::Int32 => {
157                let array = column.as_primitive::<Int32Type>();
158                for i in 0..array.len() {
159                    if array.is_null(i) {
160                        partitions.push(PartitionId(0));
161                    } else {
162                        let value = array.value(i);
163                        let key = value.to_le_bytes();
164                        partitions.push(self.partitioner.partition_for_key(&key));
165                    }
166                }
167            }
168            DataType::Int64 => {
169                let array = column.as_primitive::<Int64Type>();
170                for i in 0..array.len() {
171                    if array.is_null(i) {
172                        partitions.push(PartitionId(0));
173                    } else {
174                        let value = array.value(i);
175                        let key = value.to_le_bytes();
176                        partitions.push(self.partitioner.partition_for_key(&key));
177                    }
178                }
179            }
180            DataType::Utf8 => {
181                let array = column.as_string::<i32>();
182                for i in 0..array.len() {
183                    if array.is_null(i) {
184                        partitions.push(PartitionId(0));
185                    } else {
186                        let value = array.value(i);
187                        let key = value.as_bytes();
188                        partitions.push(self.partitioner.partition_for_key(key));
189                    }
190                }
191            }
192            DataType::Float64 => {
193                let array = column.as_primitive::<Float64Type>();
194                for i in 0..array.len() {
195                    if array.is_null(i) {
196                        partitions.push(PartitionId(0));
197                    } else {
198                        let value = array.value(i);
199                        let key = value.to_le_bytes();
200                        partitions.push(self.partitioner.partition_for_key(&key));
201                    }
202                }
203            }
204            _ => {
205                return Err(DistributedError::shuffle(format!(
206                    "Unsupported column type for hash shuffle: {:?}",
207                    column.data_type()
208                )));
209            }
210        }
211
212        Ok(partitions)
213    }
214
215    /// Create a record batch from selected indices.
216    fn create_partition_batch(
217        &self,
218        batch: &RecordBatch,
219        indices: &[usize],
220    ) -> Result<RecordBatch> {
221        // Convert indices to Int32Array for use with take kernel
222        let indices_array =
223            arrow::array::Int32Array::from(indices.iter().map(|&i| i as i32).collect::<Vec<_>>());
224
225        // Use Arrow's take kernel to extract rows
226        let columns: Result<Vec<ArrayRef>> = batch
227            .columns()
228            .iter()
229            .map(|col| {
230                compute::take(col.as_ref(), &indices_array, None)
231                    .map_err(|e| DistributedError::arrow(e.to_string()))
232            })
233            .collect();
234
235        RecordBatch::try_new(batch.schema(), columns?)
236            .map_err(|e| DistributedError::arrow(e.to_string()))
237    }
238}
239
240/// Range-based shuffle implementation for sorting.
241pub struct RangeShuffle {
242    /// Partitioner for range-based distribution.
243    partitioner: RangePartitioner,
244    /// Column name to partition by.
245    column_name: String,
246}
247
248impl RangeShuffle {
249    /// Create a new range shuffle.
250    pub fn new(column_name: String, boundaries: Vec<f64>) -> Result<Self> {
251        let partitioner = RangePartitioner::new(boundaries)?;
252        Ok(Self {
253            partitioner,
254            column_name,
255        })
256    }
257
258    /// Shuffle a record batch.
259    pub fn shuffle(&self, batch: &RecordBatch) -> Result<HashMap<PartitionId, RecordBatch>> {
260        let schema = batch.schema();
261
262        // Find the column
263        let column_index = schema
264            .column_with_name(&self.column_name)
265            .map(|(idx, _)| idx)
266            .ok_or_else(|| {
267                DistributedError::shuffle(format!("Column {} not found", self.column_name))
268            })?;
269
270        let column = batch.column(column_index);
271
272        // Compute partition for each row
273        let partitions = self.compute_partitions(column)?;
274
275        // Group rows by partition
276        let mut partition_indices: HashMap<PartitionId, Vec<usize>> = HashMap::new();
277        for (row_idx, &partition_id) in partitions.iter().enumerate() {
278            partition_indices
279                .entry(partition_id)
280                .or_default()
281                .push(row_idx);
282        }
283
284        // Create a record batch for each partition
285        let mut result = HashMap::new();
286        for (partition_id, indices) in partition_indices {
287            let partition_batch = self.create_partition_batch(batch, &indices)?;
288            result.insert(partition_id, partition_batch);
289        }
290
291        Ok(result)
292    }
293
294    /// Compute partition for each value in a column.
295    fn compute_partitions(&self, column: &ArrayRef) -> Result<Vec<PartitionId>> {
296        let mut partitions = Vec::with_capacity(column.len());
297
298        match column.data_type() {
299            DataType::Float64 => {
300                let array = column.as_primitive::<Float64Type>();
301                for i in 0..array.len() {
302                    if array.is_null(i) {
303                        partitions.push(PartitionId(0));
304                    } else {
305                        let value = array.value(i);
306                        partitions.push(self.partitioner.partition_for_value(value));
307                    }
308                }
309            }
310            DataType::Int32 => {
311                let array = column.as_primitive::<Int32Type>();
312                for i in 0..array.len() {
313                    if array.is_null(i) {
314                        partitions.push(PartitionId(0));
315                    } else {
316                        let value = f64::from(array.value(i));
317                        partitions.push(self.partitioner.partition_for_value(value));
318                    }
319                }
320            }
321            DataType::Int64 => {
322                let array = column.as_primitive::<Int64Type>();
323                for i in 0..array.len() {
324                    if array.is_null(i) {
325                        partitions.push(PartitionId(0));
326                    } else {
327                        let value = array.value(i) as f64;
328                        partitions.push(self.partitioner.partition_for_value(value));
329                    }
330                }
331            }
332            _ => {
333                return Err(DistributedError::shuffle(format!(
334                    "Unsupported column type for range shuffle: {:?}",
335                    column.data_type()
336                )));
337            }
338        }
339
340        Ok(partitions)
341    }
342
343    /// Create a record batch from selected indices.
344    fn create_partition_batch(
345        &self,
346        batch: &RecordBatch,
347        indices: &[usize],
348    ) -> Result<RecordBatch> {
349        let indices_array =
350            arrow::array::Int32Array::from(indices.iter().map(|&i| i as i32).collect::<Vec<_>>());
351
352        let columns: Result<Vec<ArrayRef>> = batch
353            .columns()
354            .iter()
355            .map(|col| {
356                compute::take(col.as_ref(), &indices_array, None)
357                    .map_err(|e| DistributedError::arrow(e.to_string()))
358            })
359            .collect();
360
361        RecordBatch::try_new(batch.schema(), columns?)
362            .map_err(|e| DistributedError::arrow(e.to_string()))
363    }
364}
365
/// Broadcast shuffle that replicates the full input to every partition.
pub struct BroadcastShuffle {
    /// Number of target partitions to replicate to.
    num_partitions: usize,
}

372impl BroadcastShuffle {
373    /// Create a new broadcast shuffle.
374    pub fn new(num_partitions: usize) -> Result<Self> {
375        if num_partitions == 0 {
376            return Err(DistributedError::shuffle(
377                "Number of partitions must be greater than zero",
378            ));
379        }
380        Ok(Self { num_partitions })
381    }
382
383    /// Shuffle a record batch (broadcast to all partitions).
384    pub fn shuffle(&self, batch: &RecordBatch) -> HashMap<PartitionId, RecordBatch> {
385        let mut result = HashMap::new();
386        for i in 0..self.num_partitions {
387            result.insert(PartitionId(i as u64), batch.clone());
388        }
389        result
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, StringArray};
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Build a 5-row batch with int, float, and string columns.
    fn create_test_batch() -> std::result::Result<RecordBatch, Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Float64, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
            Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0, 50.0])),
            Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
        ];

        Ok(RecordBatch::try_new(schema, columns)?)
    }

    #[test]
    fn test_hash_shuffle() -> TestResult {
        let batch = create_test_batch()?;
        let shuffled = HashShuffle::new("id".to_string(), 2)?.shuffle(&batch)?;

        // Never more partitions than requested.
        assert!(shuffled.len() <= 2);

        // Every input row lands in exactly one output partition.
        let total: usize = shuffled.values().map(RecordBatch::num_rows).sum();
        assert_eq!(total, batch.num_rows());
        Ok(())
    }

    #[test]
    fn test_range_shuffle() -> TestResult {
        let batch = create_test_batch()?;
        let shuffled = RangeShuffle::new("id".to_string(), vec![2.5])?.shuffle(&batch)?;

        // One boundary yields at most two partitions.
        assert!(shuffled.len() <= 2);

        // Every input row lands in exactly one output partition.
        let total: usize = shuffled.values().map(RecordBatch::num_rows).sum();
        assert_eq!(total, batch.num_rows());
        Ok(())
    }

    #[test]
    fn test_broadcast_shuffle() -> TestResult {
        let batch = create_test_batch()?;
        let shuffled = BroadcastShuffle::new(3)?.shuffle(&batch);

        // Broadcast produces exactly one copy per partition.
        assert_eq!(shuffled.len(), 3);
        for partition_batch in shuffled.values() {
            assert_eq!(partition_batch.num_rows(), batch.num_rows());
        }
        Ok(())
    }

    #[test]
    fn test_shuffle_config() -> TestResult {
        let config =
            ShuffleConfig::new(ShuffleType::Hash, ShuffleKey::Column("id".to_string()), 4)?;

        assert_eq!(config.shuffle_type, ShuffleType::Hash);
        assert_eq!(config.num_partitions, 4);
        assert_eq!(config.buffer_size, 1024 * 1024);
        Ok(())
    }
}