1use crate::error::{DistributedError, Result};
7use crate::partition::{HashPartitioner, RangePartitioner};
8use crate::task::PartitionId;
9use arrow::array::{Array, ArrayRef, AsArray};
10use arrow::compute;
11use arrow::datatypes::*;
12use arrow::record_batch::RecordBatch;
13use std::collections::HashMap;
14
/// Names the strategy used to redistribute rows across partitions.
///
/// The enum is a plain tag: it carries no data and does not itself select
/// an implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShuffleType {
    /// Hash-based redistribution.
    Hash,
    /// Range-based redistribution.
    Range,
    /// Broadcast redistribution.
    Broadcast,
    /// Custom strategy; its semantics are not defined by this enum.
    Custom,
}
27
/// Describes what a shuffle is keyed on.
#[derive(Clone, Debug)]
pub enum ShuffleKey {
    /// A single string (presumably one column name).
    Column(String),
    /// A list of strings (presumably several column names).
    Columns(Vec<String>),
    /// A single string (presumably an expression; how it is interpreted is
    /// not defined here).
    Expression(String),
}
38
39#[derive(Debug, Clone)]
41pub struct ShuffleConfig {
42 pub shuffle_type: ShuffleType,
44 pub key: ShuffleKey,
46 pub num_partitions: usize,
48 pub buffer_size: usize,
50}
51
52impl ShuffleConfig {
53 pub fn new(shuffle_type: ShuffleType, key: ShuffleKey, num_partitions: usize) -> Result<Self> {
55 if num_partitions == 0 {
56 return Err(DistributedError::shuffle(
57 "Number of partitions must be greater than zero",
58 ));
59 }
60
61 Ok(Self {
62 shuffle_type,
63 key,
64 num_partitions,
65 buffer_size: 1024 * 1024, })
67 }
68
69 pub fn with_buffer_size(mut self, size: usize) -> Self {
71 self.buffer_size = size;
72 self
73 }
74}
75
/// Output container for a shuffle: per-partition record batches plus statistics.
pub struct ShuffleResult {
    /// Record batches grouped by partition id.
    pub partitions: HashMap<PartitionId, Vec<RecordBatch>>,
    /// Statistics accompanying the shuffle output.
    pub stats: ShuffleStats,
}
83
/// Counters describing a shuffle.
///
/// `Default` yields all-zero counters.
#[derive(Clone, Debug, Default)]
pub struct ShuffleStats {
    /// Row count (presumably summed over all partitions — TODO confirm).
    pub total_rows: u64,
    /// Byte count (presumably summed over all partitions — TODO confirm).
    pub total_bytes: u64,
    /// Number of partitions.
    pub num_partitions: usize,
    /// Duration in milliseconds, going by the field name.
    pub duration_ms: u64,
}
96
97pub struct HashShuffle {
99 partitioner: HashPartitioner,
101 column_name: String,
103}
104
105impl HashShuffle {
106 pub fn new(column_name: String, num_partitions: usize) -> Result<Self> {
108 let partitioner = HashPartitioner::new(num_partitions)?;
109 Ok(Self {
110 partitioner,
111 column_name,
112 })
113 }
114
115 pub fn shuffle(&self, batch: &RecordBatch) -> Result<HashMap<PartitionId, RecordBatch>> {
117 let schema = batch.schema();
118
119 let column_index = schema
121 .column_with_name(&self.column_name)
122 .map(|(idx, _)| idx)
123 .ok_or_else(|| {
124 DistributedError::shuffle(format!("Column {} not found", self.column_name))
125 })?;
126
127 let column = batch.column(column_index);
128
129 let partitions = self.compute_partitions(column)?;
131
132 let mut partition_indices: HashMap<PartitionId, Vec<usize>> = HashMap::new();
134 for (row_idx, &partition_id) in partitions.iter().enumerate() {
135 partition_indices
136 .entry(partition_id)
137 .or_default()
138 .push(row_idx);
139 }
140
141 let mut result = HashMap::new();
143 for (partition_id, indices) in partition_indices {
144 let partition_batch = self.create_partition_batch(batch, &indices)?;
145 result.insert(partition_id, partition_batch);
146 }
147
148 Ok(result)
149 }
150
151 fn compute_partitions(&self, column: &ArrayRef) -> Result<Vec<PartitionId>> {
153 let mut partitions = Vec::with_capacity(column.len());
154
155 match column.data_type() {
156 DataType::Int32 => {
157 let array = column.as_primitive::<Int32Type>();
158 for i in 0..array.len() {
159 if array.is_null(i) {
160 partitions.push(PartitionId(0));
161 } else {
162 let value = array.value(i);
163 let key = value.to_le_bytes();
164 partitions.push(self.partitioner.partition_for_key(&key));
165 }
166 }
167 }
168 DataType::Int64 => {
169 let array = column.as_primitive::<Int64Type>();
170 for i in 0..array.len() {
171 if array.is_null(i) {
172 partitions.push(PartitionId(0));
173 } else {
174 let value = array.value(i);
175 let key = value.to_le_bytes();
176 partitions.push(self.partitioner.partition_for_key(&key));
177 }
178 }
179 }
180 DataType::Utf8 => {
181 let array = column.as_string::<i32>();
182 for i in 0..array.len() {
183 if array.is_null(i) {
184 partitions.push(PartitionId(0));
185 } else {
186 let value = array.value(i);
187 let key = value.as_bytes();
188 partitions.push(self.partitioner.partition_for_key(key));
189 }
190 }
191 }
192 DataType::Float64 => {
193 let array = column.as_primitive::<Float64Type>();
194 for i in 0..array.len() {
195 if array.is_null(i) {
196 partitions.push(PartitionId(0));
197 } else {
198 let value = array.value(i);
199 let key = value.to_le_bytes();
200 partitions.push(self.partitioner.partition_for_key(&key));
201 }
202 }
203 }
204 _ => {
205 return Err(DistributedError::shuffle(format!(
206 "Unsupported column type for hash shuffle: {:?}",
207 column.data_type()
208 )));
209 }
210 }
211
212 Ok(partitions)
213 }
214
215 fn create_partition_batch(
217 &self,
218 batch: &RecordBatch,
219 indices: &[usize],
220 ) -> Result<RecordBatch> {
221 let indices_array =
223 arrow::array::Int32Array::from(indices.iter().map(|&i| i as i32).collect::<Vec<_>>());
224
225 let columns: Result<Vec<ArrayRef>> = batch
227 .columns()
228 .iter()
229 .map(|col| {
230 compute::take(col.as_ref(), &indices_array, None)
231 .map_err(|e| DistributedError::arrow(e.to_string()))
232 })
233 .collect();
234
235 RecordBatch::try_new(batch.schema(), columns?)
236 .map_err(|e| DistributedError::arrow(e.to_string()))
237 }
238}
239
240pub struct RangeShuffle {
242 partitioner: RangePartitioner,
244 column_name: String,
246}
247
248impl RangeShuffle {
249 pub fn new(column_name: String, boundaries: Vec<f64>) -> Result<Self> {
251 let partitioner = RangePartitioner::new(boundaries)?;
252 Ok(Self {
253 partitioner,
254 column_name,
255 })
256 }
257
258 pub fn shuffle(&self, batch: &RecordBatch) -> Result<HashMap<PartitionId, RecordBatch>> {
260 let schema = batch.schema();
261
262 let column_index = schema
264 .column_with_name(&self.column_name)
265 .map(|(idx, _)| idx)
266 .ok_or_else(|| {
267 DistributedError::shuffle(format!("Column {} not found", self.column_name))
268 })?;
269
270 let column = batch.column(column_index);
271
272 let partitions = self.compute_partitions(column)?;
274
275 let mut partition_indices: HashMap<PartitionId, Vec<usize>> = HashMap::new();
277 for (row_idx, &partition_id) in partitions.iter().enumerate() {
278 partition_indices
279 .entry(partition_id)
280 .or_default()
281 .push(row_idx);
282 }
283
284 let mut result = HashMap::new();
286 for (partition_id, indices) in partition_indices {
287 let partition_batch = self.create_partition_batch(batch, &indices)?;
288 result.insert(partition_id, partition_batch);
289 }
290
291 Ok(result)
292 }
293
294 fn compute_partitions(&self, column: &ArrayRef) -> Result<Vec<PartitionId>> {
296 let mut partitions = Vec::with_capacity(column.len());
297
298 match column.data_type() {
299 DataType::Float64 => {
300 let array = column.as_primitive::<Float64Type>();
301 for i in 0..array.len() {
302 if array.is_null(i) {
303 partitions.push(PartitionId(0));
304 } else {
305 let value = array.value(i);
306 partitions.push(self.partitioner.partition_for_value(value));
307 }
308 }
309 }
310 DataType::Int32 => {
311 let array = column.as_primitive::<Int32Type>();
312 for i in 0..array.len() {
313 if array.is_null(i) {
314 partitions.push(PartitionId(0));
315 } else {
316 let value = f64::from(array.value(i));
317 partitions.push(self.partitioner.partition_for_value(value));
318 }
319 }
320 }
321 DataType::Int64 => {
322 let array = column.as_primitive::<Int64Type>();
323 for i in 0..array.len() {
324 if array.is_null(i) {
325 partitions.push(PartitionId(0));
326 } else {
327 let value = array.value(i) as f64;
328 partitions.push(self.partitioner.partition_for_value(value));
329 }
330 }
331 }
332 _ => {
333 return Err(DistributedError::shuffle(format!(
334 "Unsupported column type for range shuffle: {:?}",
335 column.data_type()
336 )));
337 }
338 }
339
340 Ok(partitions)
341 }
342
343 fn create_partition_batch(
345 &self,
346 batch: &RecordBatch,
347 indices: &[usize],
348 ) -> Result<RecordBatch> {
349 let indices_array =
350 arrow::array::Int32Array::from(indices.iter().map(|&i| i as i32).collect::<Vec<_>>());
351
352 let columns: Result<Vec<ArrayRef>> = batch
353 .columns()
354 .iter()
355 .map(|col| {
356 compute::take(col.as_ref(), &indices_array, None)
357 .map_err(|e| DistributedError::arrow(e.to_string()))
358 })
359 .collect();
360
361 RecordBatch::try_new(batch.schema(), columns?)
362 .map_err(|e| DistributedError::arrow(e.to_string()))
363 }
364}
365
366pub struct BroadcastShuffle {
368 num_partitions: usize,
370}
371
372impl BroadcastShuffle {
373 pub fn new(num_partitions: usize) -> Result<Self> {
375 if num_partitions == 0 {
376 return Err(DistributedError::shuffle(
377 "Number of partitions must be greater than zero",
378 ));
379 }
380 Ok(Self { num_partitions })
381 }
382
383 pub fn shuffle(&self, batch: &RecordBatch) -> HashMap<PartitionId, RecordBatch> {
385 let mut result = HashMap::new();
386 for i in 0..self.num_partitions {
387 result.insert(PartitionId(i as u64), batch.clone());
388 }
389 result
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, StringArray};
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;

    /// Result type shared by the tests so `?` can propagate any error.
    type TestResult<T = ()> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Builds a five-row batch with non-nullable `id` (Int32), `value`
    /// (Float64) and `name` (Utf8) columns.
    fn create_test_batch() -> TestResult<RecordBatch> {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Float64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
            Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0, 50.0])),
            Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
        ];

        let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)?;
        Ok(batch)
    }

    /// Sums the row counts of all partition batches.
    fn row_count(partitions: &HashMap<PartitionId, RecordBatch>) -> usize {
        partitions.values().map(|b| b.num_rows()).sum()
    }

    #[test]
    fn test_hash_shuffle() -> TestResult {
        let batch = create_test_batch()?;
        let result = HashShuffle::new("id".to_string(), 2)?.shuffle(&batch)?;

        // At most two partitions, and no row is lost or duplicated.
        assert!(result.len() <= 2);
        assert_eq!(row_count(&result), batch.num_rows());
        Ok(())
    }

    #[test]
    fn test_range_shuffle() -> TestResult {
        let batch = create_test_batch()?;
        let result = RangeShuffle::new("id".to_string(), vec![2.5])?.shuffle(&batch)?;

        // One boundary gives at most two partitions; every row is kept.
        assert!(result.len() <= 2);
        assert_eq!(row_count(&result), batch.num_rows());
        Ok(())
    }

    #[test]
    fn test_broadcast_shuffle() -> TestResult {
        let batch = create_test_batch()?;
        let result = BroadcastShuffle::new(3)?.shuffle(&batch);

        assert_eq!(result.len(), 3);

        // Each partition must hold the complete batch.
        for partition_batch in result.values() {
            assert_eq!(partition_batch.num_rows(), batch.num_rows());
        }
        Ok(())
    }

    #[test]
    fn test_shuffle_config() -> TestResult {
        let key = ShuffleKey::Column("id".to_string());
        let config = ShuffleConfig::new(ShuffleType::Hash, key, 4)?;

        assert_eq!(config.shuffle_type, ShuffleType::Hash);
        assert_eq!(config.num_partitions, 4);
        assert_eq!(config.buffer_size, 1024 * 1024);
        Ok(())
    }
}