use std::simd::cmp::SimdPartialEq;
use crate::simd_element::{SimdMaskOps, SortedSimdElement};
/// Removes consecutive duplicates from a sorted slice, writing the distinct
/// values into `out` and returning how many were written.
///
/// `input` is expected to be sorted (as the `SortedSimdElement` bound
/// suggests). The run-skipping shortcut only compares the first and last
/// element of each `T::LANES`-wide window and treats equal endpoints as "the
/// whole window holds one value". For unsorted input the result is
/// therefore not guaranteed to match a plain consecutive dedup.
///
/// Only `out[..n]` (where `n` is the returned count) is written; the rest of
/// `out` is left untouched.
///
/// # Panics
///
/// Panics if `out.len() < input.len()`.
pub fn deduplicate<T>(out: &mut [T], input: &[T]) -> usize
where
    T: SortedSimdElement,
    T::SimdVec: SimdPartialEq<Mask = T::SimdMask>,
{
    assert!(
        out.len() >= input.len(),
        "output slice must be at least as large as input"
    );
    let Some(&head) = input.first() else {
        return 0;
    };

    let width = T::LANES;
    // The first element is always unique; seeding `out` with it lets every
    // later comparison look back at `out[written - 1]` / `input[pos - 1]`.
    out[0] = head;
    let mut written = 1;
    let mut pos = 1;

    while pos + width <= input.len() {
        let window = &input[pos..pos + width];
        if window[0] == window[width - 1] {
            // Sorted input: equal endpoints mean the whole window is a single
            // value, so at most that one value needs emitting.
            if window[0] != out[written - 1] {
                out[written] = window[0];
                written += 1;
            }
        } else {
            // Compare every lane with its predecessor in `input`; a set lane
            // marks the start of a new run.
            let current = T::simd_from_slice(window);
            let previous = T::simd_from_slice(&input[pos - 1..pos + width - 1]);
            let changed = current.simd_ne(previous);
            if changed.all() {
                // Every lane starts a new run: bulk-copy the window.
                out[written..written + width].copy_from_slice(window);
                written += width;
            } else {
                for (lane, &value) in window.iter().enumerate() {
                    if changed.test(lane) {
                        out[written] = value;
                        written += 1;
                    }
                }
            }
        }
        pos += width;
    }

    // Scalar tail for the fewer-than-`width` elements that remain.
    for &value in &input[pos..] {
        if value != out[written - 1] {
            out[written] = value;
            written += 1;
        }
    }
    written
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deduplicate_empty() {
        let input: Vec<u64> = vec![];
        let mut out: Vec<u64> = vec![];
        let new_len = deduplicate(&mut out[..], &input[..]);
        assert_eq!(new_len, 0);
    }

    #[test]
    fn test_deduplicate_tail_dupe() {
        let input = vec![1u64, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3];
        let mut out = vec![0u64; input.len()];
        let new_len = deduplicate(&mut out[..], &input[..]);
        assert_eq!(new_len, 3);
        assert_eq!(out[0..new_len], [1, 2, 3]);
    }

    #[test]
    fn test_deduplicate_head_tail_dupe() {
        let input = vec![1u64, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3];
        let mut out = vec![0u64; input.len()];
        let new_len = deduplicate(&mut out[..], &input[..]);
        assert_eq!(new_len, 2);
        assert_eq!(out[0..new_len], [1, 3]);
    }

    #[test]
    fn test_deduplicate_one_early() {
        let input = vec![
            1u64, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
        ];
        let mut out = vec![0u64; input.len()];
        let new_len = deduplicate(&mut out[..], &input[..]);
        assert_eq!(new_len, 17);
        assert_eq!(
            out[0..new_len],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
        );
    }

    #[test]
    fn test_deduplicate_many() {
        let input = [1u64, 1, 2, 3, 4, 5, 5, 6];
        let mut out = vec![0u64; input.len()];
        let new_len = deduplicate(&mut out[..], &input[..]);
        assert_eq!(new_len, 6);
        assert_eq!(out[0..new_len], [1, 2, 3, 4, 5, 6]);
    }

    /// The documented precondition: `out` must be at least as long as `input`.
    #[test]
    #[should_panic(expected = "output slice must be at least as large as input")]
    fn test_deduplicate_short_output_panics() {
        let input = [1u64, 2, 3];
        let mut out = [0u64; 2];
        let _ = deduplicate(&mut out[..], &input[..]);
    }

    /// Only the returned prefix of `out` may be written; any extra capacity
    /// must be left as it was.
    #[test]
    fn test_deduplicate_larger_output_leaves_tail_untouched() {
        let input = [4u64, 4, 5, 6, 6, 6, 7];
        let mut out = [u64::MAX; 12];
        let new_len = deduplicate(&mut out[..], &input[..]);
        assert_eq!(new_len, 4);
        assert_eq!(out[..new_len], [4, 5, 6, 7]);
        assert!(out[new_len..].iter().all(|&v| v == u64::MAX));
    }

    /// Differential check against `Vec::dedup` over many lengths and run
    /// shapes, so that inputs straddle several SIMD widths and hit the
    /// run-skip, all-unique, mixed-mask and scalar-tail paths.
    #[test]
    fn test_deduplicate_matches_vec_dedup() {
        for len in 0..80u64 {
            for step in 1..5u64 {
                let patterns: [Vec<u64>; 2] = [
                    (0..len).map(|x| x / step).collect(),
                    // Non-decreasing with runs that shrink as x grows.
                    (0..len).map(|x| x * x / (7 * step)).collect(),
                ];
                for input in patterns {
                    let mut expected = input.clone();
                    expected.dedup();
                    let mut out = vec![0u64; input.len()];
                    let new_len = deduplicate(&mut out, &input);
                    assert_eq!(
                        &out[..new_len],
                        &expected[..],
                        "len={len}, step={step}, input={input:?}"
                    );
                }
            }
        }
    }

    macro_rules! test_deduplicate_type {
        ($name:ident, $t:ty) => {
            #[test]
            fn $name() {
                let input: Vec<$t> = vec![1, 1, 2, 3, 3, 4];
                let mut out: Vec<$t> = vec![0; input.len()];
                let new_len = deduplicate(&mut out, &input);
                assert_eq!(new_len, 4);
                assert_eq!(&out[..new_len], &[1, 2, 3, 4]);

                let input: Vec<$t> = vec![1, 2, 3, 4, 5];
                let mut out: Vec<$t> = vec![0; input.len()];
                let new_len = deduplicate(&mut out, &input);
                assert_eq!(new_len, 5);
                assert_eq!(&out[..new_len], &[1, 2, 3, 4, 5]);

                let input: Vec<$t> = vec![5, 5, 5, 5, 5];
                let mut out: Vec<$t> = vec![0; input.len()];
                let new_len = deduplicate(&mut out, &input);
                assert_eq!(new_len, 1);
                assert_eq!(&out[..new_len], &[5]);

                let input: Vec<$t> = (0..100).map(|x| (x / 2) as $t).collect();
                let mut out: Vec<$t> = vec![0; input.len()];
                let new_len = deduplicate(&mut out, &input);
                assert_eq!(new_len, 50);
                let expected: Vec<$t> = (0..50).map(|x| x as $t).collect();
                assert_eq!(&out[..new_len], &expected[..]);
            }
        };
    }

    test_deduplicate_type!(test_deduplicate_u8, u8);
    test_deduplicate_type!(test_deduplicate_u16, u16);
    test_deduplicate_type!(test_deduplicate_u32, u32);
    test_deduplicate_type!(test_deduplicate_u64, u64);
    test_deduplicate_type!(test_deduplicate_i8, i8);
    test_deduplicate_type!(test_deduplicate_i16, i16);
    test_deduplicate_type!(test_deduplicate_i32, i32);
    test_deduplicate_type!(test_deduplicate_i64, i64);
}