use thiserror::Error;
/// Errors reported by the byte shuffle/unshuffle routines in this file.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum ShuffleError {
/// `element_size` was zero, so the input cannot be split into elements.
#[error("element_size must not be zero")]
InvalidElementSize,
/// The input length is not a whole multiple of `element_size`.
#[error("data length {data_len} is not divisible by element_size {element_size}")]
Misaligned {
/// Length of the rejected input, in bytes.
data_len: usize,
/// Element size the input length was checked against, in bytes.
element_size: usize,
},
/// Reserving the output buffer failed. `bytes` is the number of bytes that were requested and
/// `reason` is the underlying reservation error rendered as text.
#[error("failed to reserve {bytes} bytes for shuffle output: {reason}")]
AllocationFailed { bytes: usize, reason: String },
}
/// Fallibly reserves room for at least `n` additional bytes in `out`.
///
/// Uses `Vec::try_reserve_exact`, so a capacity overflow or allocator failure is reported to the
/// caller instead of aborting the process.
///
/// # Errors
///
/// Returns [`ShuffleError::AllocationFailed`] carrying `n` and the stringified reservation error.
fn try_reserve_shuffle(out: &mut Vec<u8>, n: usize) -> Result<(), ShuffleError> {
    match out.try_reserve_exact(n) {
        Ok(()) => Ok(()),
        Err(cause) => Err(ShuffleError::AllocationFailed {
            bytes: n,
            reason: cause.to_string(),
        }),
    }
}
/// Copies `data` into a freshly allocated vector, reporting allocation failure as an error.
///
/// # Errors
///
/// Returns [`ShuffleError::AllocationFailed`] if space for `data.len()` bytes cannot be reserved.
fn try_clone(data: &[u8]) -> Result<Vec<u8>, ShuffleError> {
    let mut copy = Vec::new();
    // Reserve up front so the copy below never has to grow the buffer.
    try_reserve_shuffle(&mut copy, data.len())?;
    copy.extend(data.iter().copied());
    Ok(copy)
}
/// Minimum input length, in bytes (64 KiB), at which `shuffle_with_threads` and
/// `unshuffle_with_threads` will use their rayon code path. Shorter inputs always go through the
/// sequential loops. Only compiled when the `threads` feature is enabled.
#[cfg(feature = "threads")]
const PARALLEL_SHUFFLE_MIN_BYTES: usize = 64 * 1024;
/// Byte-shuffles `data`, treating it as consecutive elements of `element_size` bytes.
///
/// The result groups byte 0 of every element first, then byte 1 of every element, and so on.
/// This calls [`shuffle_with_threads`] with `threads = 0`, which never selects the parallel path.
///
/// # Errors
///
/// * [`ShuffleError::InvalidElementSize`] if `element_size` is zero.
/// * [`ShuffleError::Misaligned`] if `data.len()` is not a multiple of `element_size`.
/// * [`ShuffleError::AllocationFailed`] if the output buffer cannot be reserved.
pub fn shuffle(data: &[u8], element_size: usize) -> Result<Vec<u8>, ShuffleError> {
shuffle_with_threads(data, element_size, 0)
}
/// Byte-shuffles `data`: transposes it from element-major to byte-plane-major order.
///
/// The input is viewed as `data.len() / element_size` elements of `element_size` bytes each. The
/// output stores byte 0 of every element, then byte 1 of every element, and so on, i.e.
/// `output[b * num_elements + e] == data[e * element_size + b]`.
///
/// An `element_size` of 1 or an empty input yields a plain copy of `data`.
///
/// `threads` only selects the code path. With the `threads` feature enabled, a value of 2 or more
/// combined with an input of at least `PARALLEL_SHUFFLE_MIN_BYTES` bytes fills the planes through
/// rayon's parallel slice iterators; in every other case the sequential loop runs. The value is
/// not used for anything else in this function, and without the feature it is ignored.
///
/// # Errors
///
/// * [`ShuffleError::InvalidElementSize`] if `element_size` is zero.
/// * [`ShuffleError::Misaligned`] if `data.len()` is not a multiple of `element_size`.
/// * [`ShuffleError::AllocationFailed`] if the output buffer cannot be reserved.
pub fn shuffle_with_threads(
    data: &[u8],
    element_size: usize,
    #[allow(unused_variables)] threads: u32,
) -> Result<Vec<u8>, ShuffleError> {
    match element_size {
        0 => return Err(ShuffleError::InvalidElementSize),
        // One-byte elements are already in shuffled order.
        1 => return try_clone(data),
        _ => {}
    }
    if data.is_empty() {
        return try_clone(data);
    }

    let total_bytes = data.len();
    if !total_bytes.is_multiple_of(element_size) {
        return Err(ShuffleError::Misaligned {
            data_len: total_bytes,
            element_size,
        });
    }
    // Non-empty, aligned input means at least one element, so `plane_len` is never zero.
    let plane_len = total_bytes / element_size;

    let mut planes = Vec::new();
    try_reserve_shuffle(&mut planes, total_bytes)?;
    planes.resize(total_bytes, 0);

    // Plane `byte_idx` collects byte `byte_idx` of each element, in element order.
    let fill_plane = |byte_idx: usize, plane: &mut [u8]| {
        let strided = data.iter().skip(byte_idx).step_by(element_size);
        for (dst, &src) in plane.iter_mut().zip(strided) {
            *dst = src;
        }
    };

    #[cfg(feature = "threads")]
    {
        // `element_size >= 2` is already guaranteed by the early returns above.
        if threads >= 2 && total_bytes >= PARALLEL_SHUFFLE_MIN_BYTES {
            use rayon::prelude::*;
            planes
                .par_chunks_exact_mut(plane_len)
                .enumerate()
                .for_each(|(byte_idx, plane)| fill_plane(byte_idx, plane));
            return Ok(planes);
        }
    }

    planes
        .chunks_exact_mut(plane_len)
        .enumerate()
        .for_each(|(byte_idx, plane)| fill_plane(byte_idx, plane));
    Ok(planes)
}
/// Reverses a byte shuffle, rebuilding consecutive `element_size`-byte elements from `data`.
///
/// `data` is read as `element_size` byte planes of equal length; the result interleaves them so
/// that each element's bytes are contiguous again. This calls [`unshuffle_with_threads`] with
/// `threads = 0`, which never selects the parallel path.
///
/// # Errors
///
/// * [`ShuffleError::InvalidElementSize`] if `element_size` is zero.
/// * [`ShuffleError::Misaligned`] if `data.len()` is not a multiple of `element_size`.
/// * [`ShuffleError::AllocationFailed`] if the output buffer cannot be reserved.
pub fn unshuffle(data: &[u8], element_size: usize) -> Result<Vec<u8>, ShuffleError> {
unshuffle_with_threads(data, element_size, 0)
}
/// Reverses a byte shuffle: transposes `data` from byte-plane-major back to element-major order.
///
/// The input is viewed as `element_size` planes of `data.len() / element_size` bytes each. The
/// output interleaves them again, i.e.
/// `output[e * element_size + b] == data[b * num_elements + e]`.
///
/// An `element_size` of 1 or an empty input yields a plain copy of `data`.
///
/// `threads` only selects the code path. With the `threads` feature enabled, a value of 2 or more
/// combined with an input of at least `PARALLEL_SHUFFLE_MIN_BYTES` bytes rebuilds the output in
/// parallel chunks through rayon; in every other case the sequential loop runs. The value is not
/// used for anything else in this function, and without the feature it is ignored.
///
/// # Errors
///
/// * [`ShuffleError::InvalidElementSize`] if `element_size` is zero.
/// * [`ShuffleError::Misaligned`] if `data.len()` is not a multiple of `element_size`.
/// * [`ShuffleError::AllocationFailed`] if the output buffer cannot be reserved.
pub fn unshuffle_with_threads(
    data: &[u8],
    element_size: usize,
    #[allow(unused_variables)] threads: u32,
) -> Result<Vec<u8>, ShuffleError> {
    match element_size {
        0 => return Err(ShuffleError::InvalidElementSize),
        // One-byte elements are identical in shuffled and unshuffled form.
        1 => return try_clone(data),
        _ => {}
    }
    if data.is_empty() {
        return try_clone(data);
    }

    let total_bytes = data.len();
    if !total_bytes.is_multiple_of(element_size) {
        return Err(ShuffleError::Misaligned {
            data_len: total_bytes,
            element_size,
        });
    }
    // Non-empty, aligned input means at least one element, so `plane_len` is never zero.
    let plane_len = total_bytes / element_size;

    let mut elements = Vec::new();
    try_reserve_shuffle(&mut elements, total_bytes)?;
    elements.resize(total_bytes, 0);

    #[cfg(feature = "threads")]
    {
        // `element_size >= 2` is already guaranteed by the early returns above.
        if threads >= 2 && total_bytes >= PARALLEL_SHUFFLE_MIN_BYTES {
            use rayon::prelude::*;
            // Each task rebuilds a run of whole elements: about 4 KiB worth, but never fewer
            // than 64 elements.
            let elems_per_chunk = (4096 / element_size).max(64);
            elements
                .par_chunks_mut(elems_per_chunk * element_size)
                .enumerate()
                .for_each(|(chunk_idx, chunk)| {
                    let first_elem = chunk_idx * elems_per_chunk;
                    for (offset, element) in chunk.chunks_exact_mut(element_size).enumerate() {
                        let column = first_elem + offset;
                        for (byte_idx, slot) in element.iter_mut().enumerate() {
                            *slot = data[byte_idx * plane_len + column];
                        }
                    }
                });
            return Ok(elements);
        }
    }

    // Scatter each plane into every `element_size`-th output slot, starting at its byte index.
    for (byte_idx, plane) in data.chunks_exact(plane_len).enumerate() {
        let slots = elements.iter_mut().skip(byte_idx).step_by(element_size);
        for (slot, &byte) in slots.zip(plane) {
            *slot = byte;
        }
    }
    Ok(elements)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 4-byte elements: pins the exact plane layout and the round trip.
    #[test]
    fn test_shuffle_unshuffle_float32() {
        let data: Vec<u8> = (0..12).collect();
        let shuffled = shuffle(&data, 4).unwrap();
        assert_eq!(shuffled, vec![0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]);
        let unshuffled = unshuffle(&shuffled, 4).unwrap();
        assert_eq!(unshuffled, data);
    }

    /// 8-byte elements: pins the exact plane layout and the round trip.
    #[test]
    fn test_shuffle_unshuffle_float64() {
        let data: Vec<u8> = (0..16).collect();
        let shuffled = shuffle(&data, 8).unwrap();
        // Two elements, so each of the eight planes holds one byte from each element. Without
        // this check a wrong but self-inverse transform would still pass the round trip.
        assert_eq!(
            shuffled,
            vec![0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]
        );
        let unshuffled = unshuffle(&shuffled, 8).unwrap();
        assert_eq!(unshuffled, data);
    }

    /// One-byte elements are returned unchanged in both directions.
    #[test]
    fn test_shuffle_element_size_1() {
        let data = vec![1, 2, 3, 4];
        assert_eq!(shuffle(&data, 1).unwrap(), data);
        assert_eq!(unshuffle(&data, 1).unwrap(), data);
    }

    /// Empty input yields empty output in both directions.
    #[test]
    fn test_shuffle_empty() {
        let data: Vec<u8> = vec![];
        assert_eq!(shuffle(&data, 4).unwrap(), data);
        assert_eq!(unshuffle(&data, 4).unwrap(), data);
    }

    /// A zero element size is rejected before anything else is checked.
    #[test]
    fn test_shuffle_element_size_zero() {
        assert!(
            matches!(
                shuffle(&[1, 2, 3], 0),
                Err(ShuffleError::InvalidElementSize)
            ),
            "shuffle with element_size=0 must return Err(InvalidElementSize)"
        );
        assert!(
            matches!(
                unshuffle(&[1, 2, 3], 0),
                Err(ShuffleError::InvalidElementSize)
            ),
            "unshuffle with element_size=0 must return Err(InvalidElementSize)"
        );
    }

    /// The parallel shuffle must produce exactly the sequential output and round-trip.
    #[cfg(feature = "threads")]
    #[test]
    fn shuffle_threads_byte_identical() {
        let data: Vec<u8> = (0..256 * 1024).map(|i| (i % 256) as u8).collect();
        for element_size in [2usize, 4, 8, 16] {
            let seq = shuffle_with_threads(&data, element_size, 0).unwrap();
            for t in [1u32, 2, 4, 8] {
                let par = shuffle_with_threads(&data, element_size, t).unwrap();
                assert_eq!(seq, par, "shuffle threads={t} elem={element_size} mismatch");
                let rt = unshuffle_with_threads(&par, element_size, t).unwrap();
                assert_eq!(
                    rt, data,
                    "round-trip mismatch threads={t} elem={element_size}"
                );
            }
        }
    }

    /// The parallel unshuffle must produce exactly the sequential output.
    #[cfg(feature = "threads")]
    #[test]
    fn unshuffle_threads_byte_identical() {
        let data: Vec<u8> = (0..256 * 1024).map(|i| (i % 256) as u8).collect();
        for element_size in [2usize, 4, 8, 16] {
            let shuffled = shuffle(&data, element_size).unwrap();
            let seq = unshuffle_with_threads(&shuffled, element_size, 0).unwrap();
            for t in [1u32, 2, 4, 8] {
                let par = unshuffle_with_threads(&shuffled, element_size, t).unwrap();
                assert_eq!(
                    seq, par,
                    "unshuffle threads={t} elem={element_size} mismatch"
                );
            }
        }
    }

    /// Inputs under the size threshold give the same result whatever `threads` is.
    #[cfg(feature = "threads")]
    #[test]
    fn shuffle_below_threshold_uses_sequential_path() {
        let data: Vec<u8> = (0..64).collect();
        let seq = shuffle(&data, 4).unwrap();
        let par = shuffle_with_threads(&data, 4, 8).unwrap();
        assert_eq!(seq, par);
    }

    /// A length that is not a multiple of the element size is rejected with both values reported.
    #[test]
    fn test_shuffle_misaligned_data() {
        let result = shuffle(&[1, 2, 3], 2);
        assert!(
            matches!(
                result,
                Err(ShuffleError::Misaligned {
                    data_len: 3,
                    element_size: 2
                })
            ),
            "shuffle with misaligned data must return Err(Misaligned)"
        );
        let result2 = unshuffle(&[1, 2, 3], 2);
        assert!(
            matches!(
                result2,
                Err(ShuffleError::Misaligned {
                    data_len: 3,
                    element_size: 2
                })
            ),
            "unshuffle with misaligned data must return Err(Misaligned)"
        );
    }

    /// An impossible reservation surfaces as `AllocationFailed` rather than aborting.
    #[test]
    fn try_reserve_shuffle_rejects_pathological_capacity() {
        let mut v: Vec<u8> = Vec::new();
        let err = try_reserve_shuffle(&mut v, usize::MAX)
            .expect_err("reservation at usize::MAX must fail the capacity check");
        match err {
            ShuffleError::AllocationFailed { .. } => {}
            other => panic!("expected AllocationFailed, got {other:?}"),
        }
    }
}