use crate::base::if_rayon;
use alloc::{format, string::String, vec::Vec};
use core::cmp::Ordering;
use itertools::Itertools;
#[cfg(feature = "rayon")]
use rayon::prelude::ParallelSliceMut;
use snafu::Snafu;
/// Errors produced when building or applying a [`Permutation`].
#[derive(Snafu, Debug, PartialEq, Eq)]
pub enum PermutationError {
    /// The vector given to [`Permutation::try_new`] is not an ordering of `0..len`.
    #[snafu(display("Permutation is invalid {error}"))]
    InvalidPermutation {
        /// Human-readable explanation of why the vector was rejected.
        error: String,
    },
    /// [`Permutation::try_apply`] received a slice whose length differs from the
    /// permutation's size.
    #[snafu(display("Application of a permutation to a slice with a different length {permutation_size} != {slice_length}"))]
    PermutationSizeMismatch {
        /// Number of indexes held by the permutation.
        permutation_size: usize,
        /// Length of the slice the permutation was applied to.
        slice_length: usize,
    },
}
/// A reordering of the indexes `0..n`, used to rearrange slices of length `n`.
///
/// Invariant (established by the constructors): `permutation` contains every value in
/// `0..permutation.len()` exactly once.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Permutation {
    /// `permutation[k]` is the index of the source element that lands at position `k`
    /// when the permutation is applied to a slice.
    permutation: Vec<usize>,
}
impl Permutation {
    /// Builds the permutation that sorts the indexes `0..length` according to `cmp`.
    ///
    /// `cmp` receives pairs of *indexes* (not values), so callers typically close over the
    /// data being ordered. The sort is unstable: indexes that compare equal may end up in
    /// any relative order. `if_rayon!` chooses between the parallel (`rayon`) and the
    /// sequential sort; the `Sync` bound is what the parallel sort requires.
    ///
    /// No validation is performed. None is needed, because the result is a reordering of
    /// `0..length` by construction.
    #[expect(dead_code)]
    pub(crate) fn unchecked_new_from_cmp<F>(length: usize, cmp: F) -> Self
    where
        F: Fn(&usize, &usize) -> Ordering + Sync,
    {
        let mut order = (0..length).collect_vec();
        if_rayon!(
            order.par_sort_unstable_by(cmp),
            order.sort_unstable_by(cmp)
        );
        Self { permutation: order }
    }

    /// Validates `permutation` and wraps it.
    ///
    /// A vector of length `n` is accepted when it contains every index in `0..n` exactly
    /// once.
    ///
    /// # Errors
    ///
    /// Returns [`PermutationError::InvalidPermutation`] in either of two cases:
    /// - `permutation` contains a repeated value. This is checked first.
    /// - `permutation` contains a value `>= permutation.len()`.
    pub fn try_new(permutation: Vec<usize>) -> Result<Self, PermutationError> {
        let length = permutation.len();

        // Sort and deduplicate a scratch copy; if it shrank, some value occurred twice.
        let mut distinct = permutation.clone();
        distinct.sort_unstable();
        distinct.dedup();

        if distinct.len() < length {
            return Err(PermutationError::InvalidPermutation {
                error: format!("Permutation can not have duplicate elements: {permutation:?}"),
            });
        }

        // With no duplicates, `distinct` is exactly `permutation` in ascending order, so a
        // value is out of range iff the largest (last) one is.
        if distinct.last().is_some_and(|&largest| largest >= length) {
            return Err(PermutationError::InvalidPermutation {
                error: format!("Permutation can not have elements out of bounds: {permutation:?}"),
            });
        }

        Ok(Self { permutation })
    }

    /// Returns the number of indexes this permutation rearranges.
    pub fn size(&self) -> usize {
        self.permutation.len()
    }

    /// Returns a new vector whose `k`-th element is a clone of `slice[self.permutation[k]]`.
    ///
    /// # Errors
    ///
    /// Returns [`PermutationError::PermutationSizeMismatch`] if `slice.len()` differs from
    /// [`Self::size`].
    pub fn try_apply<T>(&self, slice: &[T]) -> Result<Vec<T>, PermutationError>
    where
        T: Clone,
    {
        let slice_length = slice.len();
        if slice_length != self.size() {
            return Err(PermutationError::PermutationSizeMismatch {
                permutation_size: self.size(),
                slice_length,
            });
        }
        // Relies on the constructors' invariant: every stored index is `< self.size()`,
        // which equals `slice.len()` here, so the indexing cannot go out of bounds.
        Ok(self
            .permutation
            .iter()
            .map(|&source| slice[source].clone())
            .collect())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_apply_permutation() {
        let permutation = Permutation::try_new(vec![1, 0, 2]).unwrap();
        assert_eq!(permutation.size(), 3);
        assert_eq!(
            permutation.try_apply(&["and", "Space", "Time"]).unwrap(),
            vec!["Space", "and", "Time"]
        );
    }

    /// The identity permutation leaves a slice unchanged. The empty permutation is valid
    /// and maps an empty slice to an empty vector.
    #[test]
    fn test_identity_and_empty_permutations() {
        let identity = Permutation::try_new(vec![0, 1, 2]).unwrap();
        assert_eq!(identity.try_apply(&[10, 20, 30]).unwrap(), vec![10, 20, 30]);

        let empty = Permutation::try_new(vec![]).unwrap();
        assert_eq!(empty.size(), 0);
        let nothing: [u8; 0] = [];
        assert!(empty.try_apply(&nothing).unwrap().is_empty());
    }

    #[test]
    fn test_invalid_permutation() {
        assert!(matches!(
            Permutation::try_new(vec![1, 0, 0]),
            Err(PermutationError::InvalidPermutation { .. })
        ));
        assert!(matches!(
            Permutation::try_new(vec![1, 0, 3]),
            Err(PermutationError::InvalidPermutation { .. })
        ));
    }

    /// The error text names the actual reason for rejection, and duplicates are reported
    /// even when out-of-bounds values are also present.
    #[test]
    fn test_invalid_permutation_reports_reason() {
        assert!(matches!(
            Permutation::try_new(vec![1, 0, 0]),
            Err(PermutationError::InvalidPermutation { error }) if error.contains("duplicate")
        ));
        assert!(matches!(
            Permutation::try_new(vec![1, 0, 3]),
            Err(PermutationError::InvalidPermutation { error }) if error.contains("out of bounds")
        ));
        assert!(matches!(
            Permutation::try_new(vec![3, 3, 0]),
            Err(PermutationError::InvalidPermutation { error }) if error.contains("duplicate")
        ));
    }

    /// Both shorter and longer slices are rejected with the two lengths reported.
    #[test]
    fn test_permutation_size_mismatch() {
        let permutation = Permutation::try_new(vec![1, 0, 2]).unwrap();
        assert_eq!(
            permutation.try_apply(&["Space", "Time"]),
            Err(PermutationError::PermutationSizeMismatch {
                permutation_size: 3,
                slice_length: 2
            })
        );
        assert_eq!(
            permutation.try_apply(&["a", "b", "c", "d"]),
            Err(PermutationError::PermutationSizeMismatch {
                permutation_size: 3,
                slice_length: 4
            })
        );
    }
}