use num_traits::Zero;
/// Returns the row-major (C-order) strides for an array of the given `shape`.
///
/// The last axis has stride 1 and every earlier axis has a stride equal to the product of
/// all later extents. An empty `shape` yields an empty vector.
#[inline]
pub fn stride_from_shape(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    let mut running = 1;
    // Walk from the innermost axis outwards: the extent of axis `k + 1` feeds the stride of
    // axis `k`, so the outermost extent is never multiplied in.
    strides
        .iter_mut()
        .rev()
        .skip(1)
        .zip(shape.iter().rev())
        .for_each(|(stride, &extent)| {
            running *= extent;
            *stride = running;
        });
    strides
}
/// Splits `x` into its even-indexed elements (`evens`) and odd-indexed elements (`odds`).
///
/// # Panics
///
/// Panics unless `odds.len() == x.len() / 2` and `evens.len() == x.len().div_ceil(2)`.
#[inline]
#[track_caller]
pub fn deinterleave<T: Clone>(x: &[T], evens: &mut [T], odds: &mut [T]) {
    let (nx, n_e, n_o) = (x.len(), evens.len(), odds.len());
    let expected_odds = nx / 2;
    let expected_evens = nx.div_ceil(2);
    assert_eq!(
        expected_odds, n_o,
        "incorrect odd length, {n_o}, for slice deinterleave"
    );
    assert_eq!(
        expected_evens, n_e,
        "incorrect even length, {n_e}, for slice deinterleave"
    );
    deinterleave_unchecked(x, evens, odds);
}
/// Copies the even-indexed elements of `x` into `evens` and the odd-indexed ones into
/// `odds`, without validating any lengths.
///
/// Callers are expected to pass `evens.len() == x.len().div_ceil(2)` and
/// `odds.len() == x.len() / 2`. With other lengths the paired copy stops at the shortest
/// of the three slices, and an odd-length `x` with an empty `evens` panics.
#[inline(always)]
pub(crate) fn deinterleave_unchecked<T: Clone>(x: &[T], evens: &mut [T], odds: &mut [T]) {
    let pairs = x.chunks_exact(2);
    let leftover = pairs.remainder();
    for ((pair, even), odd) in pairs.zip(evens.iter_mut()).zip(odds.iter_mut()) {
        *even = pair[0].clone();
        *odd = pair[1].clone();
    }
    // An odd-length input leaves one unpaired element, which belongs at the end of `evens`.
    if let [last] = leftover {
        *evens.last_mut().unwrap() = last.clone();
    }
}
/// Deinterleaves a row-major 2-D array along both axes.
///
/// `input` is read as `shape[0]` rows of `shape[1]` elements each. Even-indexed rows are
/// written to the first `shape[0].div_ceil(2)` rows of `output` and odd-indexed rows to
/// the remaining ones. Within every row the elements are split the same way, as
/// [`deinterleave`] does. Shapes with a zero-length axis are accepted; there is nothing to
/// move.
///
/// # Panics
///
/// Panics if `input.len()` or `output.len()` is not `shape[0] * shape[1]`.
#[inline]
#[track_caller]
pub fn deinterleave_2d<T: Clone>(input: &[T], output: &mut [T], shape: &[usize; 2]) {
    let n_total: usize = shape.iter().product();
    assert_eq!(input.len(), n_total);
    assert_eq!(output.len(), n_total);
    if n_total == 0 {
        // Nothing to move. This also avoids `chunks_exact(0)` below, which panics for a
        // zero-width row.
        return;
    }
    let [n_rows, n_cols] = *shape;
    let n_row_pairs = n_rows / 2;
    let n_even_cols = n_cols.div_ceil(2);

    let (even_rows, odd_rows) = output.split_at_mut(n_rows.div_ceil(2) * n_cols);
    // When `n_rows` is odd, the final even row has no odd partner. Splitting it off up
    // front gives the paired loop three iterators of identical length. As a result, no
    // output row can be pulled by `zip` and then silently dropped, whatever iteration
    // strategy `zip` uses.
    let (paired_even_rows, last_even_row) = even_rows.split_at_mut(n_row_pairs * n_cols);
    let input_pairs = input.chunks_exact(2 * n_cols);
    let unpaired_input_row = input_pairs.remainder();

    for ((input_pair, even_row), odd_row) in input_pairs
        .zip(paired_even_rows.chunks_exact_mut(n_cols))
        .zip(odd_rows.chunks_exact_mut(n_cols))
    {
        let (input_even, input_odd) = input_pair.split_at(n_cols);
        let (lo, hi) = even_row.split_at_mut(n_even_cols);
        deinterleave(input_even, lo, hi);
        let (lo, hi) = odd_row.split_at_mut(n_even_cols);
        deinterleave(input_odd, lo, hi);
    }
    if !last_even_row.is_empty() {
        let (lo, hi) = last_even_row.split_at_mut(n_even_cols);
        deinterleave(unpaired_input_row, lo, hi);
    }
}
/// Deinterleaves a row-major N-D array along every axis.
///
/// Along each axis of length `len`, even indices move to the first `len.div_ceil(2)`
/// positions and odd indices to the remaining ones. One- and two-dimensional shapes are
/// handled by [`deinterleave`] and [`deinterleave_2d`]. An empty `shape` is a no-op, and
/// shapes with a zero-length axis have nothing to move.
///
/// # Panics
///
/// For a non-empty `shape`, panics if `input.len()` or `output.len()` differs from the
/// product of the extents in `shape`.
#[inline]
#[track_caller]
pub fn deinterleave_nd<T: Clone>(input: &[T], output: &mut [T], shape: &[usize]) {
    match shape.len() {
        0 => {}
        1 => {
            // Without these checks a 1-D `shape` would be ignored whenever `input` and
            // `output` merely agree with each other.
            assert_eq!(input.len(), shape[0]);
            assert_eq!(output.len(), shape[0]);
            let (f, s) = output.split_at_mut(shape[0].div_ceil(2));
            deinterleave(input, f, s);
        }
        2 => deinterleave_2d(
            input,
            output,
            shape
                .try_into()
                .expect("shape length was already checked to be 2"),
        ),
        _ => {
            let n_total: usize = shape.iter().product();
            assert_eq!(input.len(), n_total);
            assert_eq!(output.len(), n_total);
            if n_total == 0 {
                // A zero-length axis means there is nothing to move. The recursive worker
                // would otherwise reach `chunks_exact(0)`, which panics.
                return;
            }
            deinterleave_nd_unchecked(input, output, shape);
        }
    }
}
/// Recursive worker for [`deinterleave_nd`]; performs no length validation of its own.
///
/// For three or more axes, this splits the outermost axis into even and odd slabs. It then
/// recurses into each slab with the remaining axes, bottoming out in [`deinterleave_2d`].
/// Callers must pass `input` and `output` whose lengths equal the product of `shape`.
#[inline]
fn deinterleave_nd_unchecked<T: Clone>(input: &[T], output: &mut [T], shape: &[usize]) {
    match shape.len() {
        0 => {}
        1 => {
            let (f, s) = output.split_at_mut(shape[0].div_ceil(2));
            deinterleave(input, f, s);
        }
        2 => deinterleave_2d(
            input,
            output,
            shape
                .try_into()
                .expect("shape length was already checked to be 2"),
        ),
        _ => {
            let sub_shape = &shape[1..];
            let n_sub: usize = sub_shape.iter().product();
            if n_sub == 0 {
                // Empty slabs: nothing to move, and `chunks_exact(0)` would panic.
                return;
            }
            let n_pairs = shape[0] / 2;
            let (even_slabs, odd_slabs) = output.split_at_mut(shape[0].div_ceil(2) * n_sub);
            // Split off the unpaired final even slab (non-empty only for odd `shape[0]`).
            // The paired loop then zips iterators of identical length and can drop nothing.
            let (paired_even_slabs, last_even_slab) = even_slabs.split_at_mut(n_pairs * n_sub);
            let input_pairs = input.chunks_exact(2 * n_sub);
            let unpaired_input_slab = input_pairs.remainder();
            for ((input_pair, even_slab), odd_slab) in input_pairs
                .zip(paired_even_slabs.chunks_exact_mut(n_sub))
                .zip(odd_slabs.chunks_exact_mut(n_sub))
            {
                let (input_even, input_odd) = input_pair.split_at(n_sub);
                deinterleave_nd_unchecked(input_even, even_slab, sub_shape);
                deinterleave_nd_unchecked(input_odd, odd_slab, sub_shape);
            }
            if !last_even_slab.is_empty() {
                deinterleave_nd_unchecked(unpaired_input_slab, last_even_slab, sub_shape);
            }
        }
    }
}
/// Writes `first` to the front of `out` and `second` to its back, zeroing everything in
/// between.
///
/// # Panics
///
/// Panics if `first.len() + second.len()` exceeds `out.len()`.
#[inline]
#[track_caller]
pub fn stack<T: Clone + Zero>(first: &[T], second: &[T], out: &mut [T]) {
    let [nf, ns, n] = [first.len(), second.len(), out.len()];
    let fits = nf + ns <= n;
    assert!(
        fits,
        "invalid lengths for slice stack, first: {nf}, second: {ns}, third: {n}",
    );
    stack_unchecked(first, second, out);
}
/// Unchecked core of [`stack`]. It copies `first` to the start of `out` and `second` to
/// the end, and fills the gap between them with `T::zero()`.
///
/// Assumes `first.len() + second.len() <= out.len()`; otherwise it panics.
#[inline(always)]
pub(crate) fn stack_unchecked<T: Clone + Zero>(first: &[T], second: &[T], out: &mut [T]) {
    let (head, rest) = out.split_at_mut(first.len());
    let gap_len = rest.len() - second.len();
    let (gap, tail) = rest.split_at_mut(gap_len);
    for (dst, src) in head.iter_mut().zip(first) {
        *dst = src.clone();
    }
    gap.fill_with(T::zero);
    for (dst, src) in tail.iter_mut().zip(second) {
        *dst = src.clone();
    }
}
/// Copies the first `first.len()` elements of `x` into `first` and the last
/// `second.len()` elements into `second`. Any elements in between are skipped.
///
/// # Panics
///
/// Panics if `first.len() + second.len()` exceeds `x.len()`.
#[inline]
#[track_caller]
pub fn split<T: Clone>(x: &[T], first: &mut [T], second: &mut [T]) {
    let nf = first.len();
    let ns = second.len();
    let nx = x.len();
    assert!(
        nf + ns <= nx,
        "invalid lengths for slice split, first: {nf}, second: {ns}, third: {nx}"
    );
    split_unchecked(x, first, second);
}
/// Unchecked core of [`split`]: copies the head of `x` into `first` and the tail of `x`
/// into `second`, ignoring the middle.
///
/// Assumes `first.len() + second.len() <= x.len()`; otherwise it panics.
#[inline(always)]
pub(crate) fn split_unchecked<T: Clone>(x: &[T], first: &mut [T], second: &mut [T]) {
    let (head, rest) = x.split_at(first.len());
    let skip = rest.len() - second.len();
    let (_, tail) = rest.split_at(skip);
    for (dst, src) in first.iter_mut().zip(head) {
        *dst = src.clone();
    }
    for (dst, src) in second.iter_mut().zip(tail) {
        *dst = src.clone();
    }
}
/// Copies the first `sink.len()` elements of `source` into `sink`.
///
/// # Panics
///
/// Panics if `sink` is longer than `source`.
#[inline]
#[track_caller]
pub fn pour_into<T: Clone>(source: &[T], sink: &mut [T]) {
    let (n, no) = (source.len(), sink.len());
    if no > n {
        panic!("Output slice with length {no} too long for strided slice with length {n}.");
    }
    pour_into_unchecked(source, sink);
}
/// Clones elements of `source` into `sink` position by position, stopping at the end of
/// whichever slice is shorter.
#[inline(always)]
pub(crate) fn pour_into_unchecked<T: Clone>(source: &[T], sink: &mut [T]) {
    for (slot, item) in sink.iter_mut().zip(source) {
        *slot = item.clone();
    }
}
/// Copies `x` into the front of `sink` and sets every remaining element of `sink` to zero.
///
/// # Panics
///
/// Panics if `sink` is shorter than `x`.
#[inline]
#[track_caller]
pub fn fill_from<T: Clone + Zero>(x: &[T], sink: &mut [T]) {
    let n = x.len();
    let no = sink.len();
    // `fill_from_unchecked` splits `sink` at `x.len()` and zero-pads the rest. The
    // requirement is therefore that `sink` is at least as long as `x`, the opposite of
    // `pour_into`.
    assert!(
        n <= no,
        "Output slice with length {no} too short for source slice with length {n}."
    );
    fill_from_unchecked(x, sink);
}
/// Unchecked core of [`fill_from`]: clones `source` into the head of `sink` and writes
/// `T::zero()` into every remaining slot.
///
/// Assumes `source.len() <= sink.len()`; otherwise the split panics.
#[inline(always)]
pub(crate) fn fill_from_unchecked<T: Clone + Zero>(source: &[T], sink: &mut [T]) {
    let copy_len = source.len();
    let (copied, padding) = sink.split_at_mut(copy_len);
    copied
        .iter_mut()
        .zip(source)
        .for_each(|(dst, src)| *dst = src.clone());
    padding.fill_with(T::zero);
}
/// Interleaves `evens` and `odds` into `x`, so that `x[2 * i] = evens[i]` and
/// `x[2 * i + 1] = odds[i]`.
///
/// # Panics
///
/// Panics unless `odds.len() == x.len() / 2` and `evens.len() == x.len().div_ceil(2)`.
#[inline]
#[track_caller]
pub fn interleave<T: Clone>(evens: &[T], odds: &[T], x: &mut [T]) {
    let nx = x.len();
    let n_e = evens.len();
    let n_o = odds.len();
    assert_eq!(
        nx / 2,
        n_o,
        "incorrect odd length, {n_o}, for slice interleave"
    );
    assert_eq!(
        nx.div_ceil(2),
        n_e,
        "incorrect even length, {n_e}, for slice interleave"
    );
    interleave_unchecked(evens, odds, x);
}
/// Unchecked core of [`interleave`]: writes `evens[i]` to `x[2 * i]` and `odds[i]` to
/// `x[2 * i + 1]`.
///
/// The paired copy stops at the shortest of the inputs. If `x` has odd length, its final
/// slot is filled from the last element of `evens`, which panics when `evens` is empty.
#[inline(always)]
pub(crate) fn interleave_unchecked<T: Clone>(evens: &[T], odds: &[T], x: &mut [T]) {
    let mut slots = x.chunks_exact_mut(2);
    for (slot_pair, (even, odd)) in slots.by_ref().zip(evens.iter().zip(odds)) {
        let (e, o) = (even.clone(), odd.clone());
        slot_pair[0] = e;
        slot_pair[1] = o;
    }
    if let [tail] = slots.into_remainder() {
        *tail = evens.last().unwrap().clone();
    }
}
/// Interleaves, in place, a slice whose even-position elements sit at the front and whose
/// odd-position elements sit at the back.
///
/// On entry `x` is expected to hold `x.len().div_ceil(2)` "even" elements followed by
/// `x.len() / 2` "odd" elements, the layout [`deinterleave`] produces. On return it holds
/// `[e0, o0, e1, o1, ...]`. No auxiliary buffer is allocated: elements are moved by swaps
/// and rotations, with one clone per permutation cycle.
///
/// The core step is an in-shuffle (`[a0.., b0..] -> [b0, a0, b1, a1, ...]`) assembled from
/// `lookup`, `shift_n` and `perfect_shuffle`. It follows the cycle-leader in-shuffle
/// scheme of P. Jain ("A Simple In-Place Algorithm for In-Shuffle", 2004).
#[inline]
pub fn interleave_inplace<T: Clone>(x: &mut [T]) {
    let n = x.len();
    if n < 2 {
        // Nothing to reorder.
        return;
    } else if n == 3 {
        // `[e0, e1, o0]` only needs its last two elements swapped.
        x.swap(1, 2);
        return;
    }
    // Odd length: `x[0]` is already in place and `x[1..]` is `[e1.., o0..]` with two equal
    // halves. An in-shuffle of it yields `[o0, e1, o1, e2, ...]`, which is already correct.
    // Even length: the in-shuffle yields `[o0, e0, o1, e1, ...]`, so each pair is swapped
    // back at the end.
    let do_sub = n % 2 == 1;
    let x = match do_sub {
        true => &mut x[1..],
        false => x,
    };
    let n = x.len();
    // `m` marks the start of the not-yet-shuffled window `x[m..n]`. Its two halves each
    // have length `(n - m) / 2`.
    let mut m = 0;
    while m < n {
        // `i` is the largest power of three with `i <= n - m + 1`. The even prefix length
        // `i - 1` therefore has the form `3^k - 1`, which is what `perfect_shuffle`'s cycle
        // leaders 1, 3, 9, ... are built for.
        let i = lookup(n - m);
        // Take the span made of the tail of the first half plus the head of the second
        // half, and rotate it right by `(i - 1) / 2`. The first `(i - 1) / 2` elements of
        // each half then sit next to each other at the front of the window.
        let slice_start = m + (i - 1) / 2;
        let slice_len = (n - m) / 2;
        shift_n(&mut x[slice_start..slice_start + slice_len], (i - 1) / 2);
        // In-shuffle that prefix. The rest of the window keeps the same two-halves layout.
        perfect_shuffle(&mut x[m..m + i - 1]);
        m += i - 1;
    }
    if !do_sub {
        // Undo the odd-first order that the in-shuffle produces for even lengths.
        x.chunks_exact_mut(2).for_each(|x| x.reverse());
    }
}
/// Follows one cycle of the in-shuffle permutation `p -> 2p mod (len + 1)` (1-based
/// positions), starting from the leader position `start`.
///
/// Each element on the cycle moves to its successor position. A single clone of the
/// leader element is carried along and dropped once the cycle closes.
#[inline(always)]
fn cycle<T: Clone>(x: &mut [T], start: usize) {
    let doubled = start * 2;
    let modulus = x.len() + 1;
    let mut pos = doubled % modulus;
    let mut carried = x[start - 1].clone();
    loop {
        std::mem::swap(&mut x[pos - 1], &mut carried);
        if pos == start {
            break;
        }
        pos = (pos * 2) % modulus;
    }
}
/// Rotates `x` right by `n` positions, so that its last `n` elements move to the front.
///
/// # Panics
///
/// Panics if `n > x.len()`.
#[inline(always)]
fn shift_n<T>(x: &mut [T], n: usize) {
    assert!(n <= x.len());
    // `slice::rotate_right` is the standard-library form of the triple-reversal rotation.
    x.rotate_right(n);
}
/// Returns the largest power of three that is at most `n + 1`. The result is never less
/// than 3, so inputs below 2 also yield 3.
#[inline(always)]
fn lookup(n: usize) -> usize {
    let bound = n + 1;
    let mut power = 3;
    while power <= bound {
        power *= 3;
    }
    // `power` is now the first power of three above `bound`; step back once unless that
    // would drop below 3.
    if power > 3 {
        power / 3
    } else {
        power
    }
}
/// In-shuffles `x` by running `cycle` from every leader position 1, 3, 9, ... below
/// `x.len()`. A length-2 slice is handled with a single swap.
///
/// Intended for lengths of the form `3^k - 1`, as produced by `lookup`.
#[inline(always)]
fn perfect_shuffle<T: Clone>(x: &mut [T]) {
    let len = x.len();
    if len == 2 {
        x.swap(0, 1);
        return;
    }
    let mut leader = 1;
    while leader < len {
        cycle(x, leader);
        leader *= 3;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use rstest::rstest;

    /// Naive reference for `deinterleave_nd`. It moves every element straight to its
    /// destination, mapping each coordinate `j` of an axis with length `len` to `j / 2`
    /// (even `j`) or `len.div_ceil(2) + j / 2` (odd `j`).
    fn reference_deinterleave_nd(input: &[usize], shape: &[usize]) -> Vec<usize> {
        let strides = stride_from_shape(shape);
        let mut out = vec![0; input.len()];
        for (flat, &value) in input.iter().enumerate() {
            let dest: usize = shape
                .iter()
                .zip(&strides)
                .map(|(&len, &stride)| {
                    let j = (flat / stride) % len;
                    let mapped = if j % 2 == 0 {
                        j / 2
                    } else {
                        len.div_ceil(2) + j / 2
                    };
                    mapped * stride
                })
                .sum();
            out[dest] = value;
        }
        out
    }

    #[rstest]
    fn test_interleave_inplace(
        #[values(0, 1, 2, 3, 4, 5, 20, 21, 22, 32, 423, 553)] n: usize,
    ) {
        let evens = (0..n).step_by(2).collect_vec();
        let odds = (1..n).step_by(2).collect_vec();
        let mut x = evens.iter().chain(odds.iter()).cloned().collect::<Vec<_>>();
        interleave_inplace(&mut x);
        assert_eq!(x, (0..n).collect_vec());
    }

    #[rstest]
    fn test_interleave(#[values(0, 1, 2, 3, 20, 21, 22, 32, 423, 553)] n: usize) {
        let ns = n.div_ceil(2);
        let nd = n / 2;
        let s = (0..ns).collect_vec();
        let d = (ns..ns + nd).collect_vec();
        let mut out = vec![0; n];
        interleave(&s, &d, &mut out);
        let expected = (0..ns).interleave(ns..ns + nd).collect_vec();
        assert_eq!(out, expected);
    }

    #[rstest]
    fn test_deinterleave(#[values(0, 1, 2, 3, 20, 21, 22, 32, 423, 553)] n: usize) {
        let ns = n.div_ceil(2);
        let nd = n / 2;
        let mut s = vec![0; ns];
        let mut d = vec![0; nd];
        let inp = (0..ns).interleave(ns..ns + nd).collect_vec();
        deinterleave(&inp, &mut s, &mut d);
        assert_eq!(s, (0..ns).collect_vec());
        assert_eq!(d, (ns..ns + nd).collect_vec());
    }

    #[test]
    fn test_stride_from_shape() {
        assert_eq!(stride_from_shape(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(stride_from_shape(&[5]), vec![1]);
        assert!(stride_from_shape(&[]).is_empty());
    }

    #[test]
    fn test_deinterleave_nd_matches_reference() {
        // Odd extents exercise the unpaired trailing row/slab paths.
        let shapes: [&[usize]; 6] = [&[7], &[3, 5], &[4, 4], &[5, 2], &[3, 4, 5], &[2, 3, 2, 3]];
        for shape in shapes {
            let n: usize = shape.iter().product();
            let input = (0..n).collect_vec();
            let mut output = vec![0; n];
            deinterleave_nd(&input, &mut output, shape);
            assert_eq!(
                output,
                reference_deinterleave_nd(&input, shape),
                "shape {shape:?}"
            );
        }
    }

    #[test]
    fn test_stack_and_split() {
        let mut out = [9; 5];
        stack(&[1, 2], &[3], &mut out);
        assert_eq!(out, [1, 2, 0, 0, 3]);
        let mut first = [0; 2];
        let mut second = [0; 1];
        split(&out, &mut first, &mut second);
        assert_eq!(first, [1, 2]);
        assert_eq!(second, [3]);
    }
}