/// An ordered list of selected indices, stored compactly as `u16`.
///
/// Every constructor produces indices in strictly ascending order,
/// [`filter`](Self::filter) preserves that order, and
/// [`intersect`](Self::intersect) / [`union`](Self::union) preserve it as long
/// as both inputs are ascending. `intersect`, `union` and
/// [`contains`](Self::contains) rely on this ordering. [`push`](Self::push)
/// does not check it, so callers building a selection by hand must push
/// indices in strictly ascending order.
#[derive(Debug, Clone)]
pub struct SelectionVector {
    indices: Vec<u16>,
}

impl SelectionVector {
    /// Upper bound (equal to `u16::MAX`) on the `count` accepted by
    /// [`new_all`](Self::new_all) / [`from_predicate`](Self::from_predicate)
    /// and on the index accepted by [`push`](Self::push).
    pub const MAX_CAPACITY: usize = u16::MAX as usize;

    /// Narrows `value` to `u16` after checking it against [`Self::MAX_CAPACITY`].
    ///
    /// Every `usize -> u16` conversion goes through here so that no index can
    /// be silently truncated. `what` names the value in the panic message.
    ///
    /// # Panics
    ///
    /// Panics if `value` exceeds [`Self::MAX_CAPACITY`].
    // The assertion bounds `value` by `u16::MAX`, so the cast below is lossless.
    #[allow(clippy::cast_possible_truncation)]
    fn narrow(value: usize, what: &str) -> u16 {
        assert!(
            value <= Self::MAX_CAPACITY,
            "selection {what} {value} exceeds MAX_CAPACITY ({})",
            Self::MAX_CAPACITY
        );
        value as u16
    }

    /// Selects every index in `0..count`.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds [`Self::MAX_CAPACITY`].
    #[must_use]
    pub fn new_all(count: usize) -> Self {
        Self {
            indices: (0..Self::narrow(count, "count")).collect(),
        }
    }

    /// Creates an empty selection without allocating.
    #[must_use]
    pub fn new_empty() -> Self {
        Self {
            indices: Vec::new(),
        }
    }

    /// Creates an empty selection with room for `capacity` indices.
    ///
    /// The reserved capacity is clamped to [`Self::MAX_CAPACITY`].
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            indices: Vec::with_capacity(capacity.min(Self::MAX_CAPACITY)),
        }
    }

    /// Selects every index in `0..count` for which `predicate` returns `true`.
    ///
    /// `predicate` is called once per index, in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds [`Self::MAX_CAPACITY`].
    #[must_use]
    pub fn from_predicate<F>(count: usize, predicate: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        // Bound the range up front (as `new_all` does) so that indices can
        // never wrap around when stored as `u16`.
        let indices = (0..Self::narrow(count, "count"))
            .filter(|&i| predicate(usize::from(i)))
            .collect();
        Self { indices }
    }

    /// Number of selected indices.
    #[must_use]
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// Returns `true` if no index is selected.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }

    /// Returns the index stored at `position`, or `None` if `position` is out
    /// of bounds.
    #[must_use]
    pub fn get(&self, position: usize) -> Option<usize> {
        self.indices.get(position).copied().map(usize::from)
    }

    /// Appends `index` to the end of the selection.
    ///
    /// Ordering is not checked: push indices in strictly ascending order to
    /// keep [`intersect`](Self::intersect), [`union`](Self::union) and
    /// [`contains`](Self::contains) correct.
    ///
    /// # Panics
    ///
    /// Panics if `index` exceeds [`Self::MAX_CAPACITY`] (i.e. `u16::MAX`).
    pub fn push(&mut self, index: usize) {
        self.indices.push(Self::narrow(index, "index"));
    }

    /// Borrows the raw `u16` indices.
    #[must_use]
    pub fn as_slice(&self) -> &[u16] {
        &self.indices
    }

    /// Removes all indices, keeping the allocated capacity.
    pub fn clear(&mut self) {
        self.indices.clear();
    }

    /// Returns a new selection containing the indices for which `predicate`
    /// returns `true`, in their original order.
    #[must_use]
    pub fn filter<F>(&self, predicate: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        let indices = self
            .indices
            .iter()
            .copied()
            .filter(|&i| predicate(usize::from(i)))
            .collect();
        Self { indices }
    }

    /// Returns the indices present in both `self` and `other`.
    ///
    /// Linear-time merge; both inputs are assumed to be strictly ascending.
    #[must_use]
    pub fn intersect(&self, other: &Self) -> Self {
        let (lhs, rhs) = (&self.indices, &other.indices);
        // The intersection can never be larger than the smaller input.
        let mut result = Vec::with_capacity(lhs.len().min(rhs.len()));
        let (mut i, mut j) = (0, 0);
        while i < lhs.len() && j < rhs.len() {
            match lhs[i].cmp(&rhs[j]) {
                std::cmp::Ordering::Less => i += 1,
                std::cmp::Ordering::Greater => j += 1,
                std::cmp::Ordering::Equal => {
                    result.push(lhs[i]);
                    i += 1;
                    j += 1;
                }
            }
        }
        Self { indices: result }
    }

    /// Returns the indices present in `self`, `other`, or both; an index found
    /// in both inputs appears once.
    ///
    /// Linear-time merge; both inputs are assumed to be strictly ascending.
    #[must_use]
    pub fn union(&self, other: &Self) -> Self {
        let (lhs, rhs) = (&self.indices, &other.indices);
        // The union can never be larger than both inputs combined.
        let mut result = Vec::with_capacity(lhs.len() + rhs.len());
        let (mut i, mut j) = (0, 0);
        while i < lhs.len() && j < rhs.len() {
            match lhs[i].cmp(&rhs[j]) {
                std::cmp::Ordering::Less => {
                    result.push(lhs[i]);
                    i += 1;
                }
                std::cmp::Ordering::Greater => {
                    result.push(rhs[j]);
                    j += 1;
                }
                std::cmp::Ordering::Equal => {
                    // Present in both inputs: emit it once.
                    result.push(lhs[i]);
                    i += 1;
                    j += 1;
                }
            }
        }
        // At most one side has leftovers; with ascending inputs they are all
        // larger than anything emitted so far.
        result.extend_from_slice(&lhs[i..]);
        result.extend_from_slice(&rhs[j..]);
        Self { indices: result }
    }

    /// Iterates over the selected indices as `usize`, in storage order.
    pub fn iter(&self) -> impl Iterator<Item = usize> + '_ {
        self.indices.iter().copied().map(usize::from)
    }

    /// Returns `true` if `index` is selected.
    ///
    /// Uses binary search, so the selection must be ascending.
    #[must_use]
    pub fn contains(&self, index: usize) -> bool {
        // Comparing in `usize` avoids any narrowing cast: an `index` above
        // `u16::MAX` is greater than every stored value and is simply not found.
        self.indices
            .binary_search_by(|&probe| usize::from(probe).cmp(&index))
            .is_ok()
    }
}
impl Default for SelectionVector {
fn default() -> Self {
Self::new_empty()
}
}
impl IntoIterator for SelectionVector {
    type Item = usize;
    type IntoIter = std::iter::Map<std::vec::IntoIter<u16>, fn(u16) -> usize>;

    /// Consumes the selection and yields each stored index as a `usize`, in
    /// storage order.
    fn into_iter(self) -> Self::IntoIter {
        // Bound as a plain fn pointer so the adapter type matches `IntoIter`.
        let widen: fn(u16) -> usize = usize::from;
        self.indices.into_iter().map(widen)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_selection_all() {
        let sel = SelectionVector::new_all(10);
        assert_eq!(sel.len(), 10);
        for i in 0..10 {
            assert_eq!(sel.get(i), Some(i));
        }
    }

    #[test]
    fn test_selection_from_predicate() {
        let sel = SelectionVector::from_predicate(10, |i| i % 2 == 0);
        assert_eq!(sel.len(), 5);
        assert_eq!(sel.get(0), Some(0));
        assert_eq!(sel.get(1), Some(2));
        assert_eq!(sel.get(2), Some(4));
    }

    #[test]
    fn test_selection_filter() {
        let sel = SelectionVector::new_all(10);
        let filtered = sel.filter(|i| i >= 5);
        assert_eq!(filtered.len(), 5);
        assert_eq!(filtered.get(0), Some(5));
    }

    #[test]
    fn test_selection_intersect() {
        // Evens ∩ multiples of three below 10 = {0, 6}.
        let sel1 = SelectionVector::from_predicate(10, |i| i % 2 == 0);
        let sel2 = SelectionVector::from_predicate(10, |i| i % 3 == 0);
        let intersection = sel1.intersect(&sel2);
        assert_eq!(intersection.len(), 2);
        assert_eq!(intersection.get(0), Some(0));
        assert_eq!(intersection.get(1), Some(6));
    }

    #[test]
    fn test_selection_intersect_disjoint_is_empty() {
        let evens = SelectionVector::from_predicate(10, |i| i % 2 == 0);
        let odds = SelectionVector::from_predicate(10, |i| i % 2 == 1);
        assert!(evens.intersect(&odds).is_empty());
    }

    #[test]
    fn test_selection_union() {
        // {1, 3} ∪ {2, 3} = {1, 2, 3}; the shared 3 appears once.
        let sel1 = SelectionVector::from_predicate(10, |i| i == 1 || i == 3);
        let sel2 = SelectionVector::from_predicate(10, |i| i == 2 || i == 3);
        let union = sel1.union(&sel2);
        assert_eq!(union.len(), 3);
        assert_eq!(union.get(0), Some(1));
        assert_eq!(union.get(1), Some(2));
        assert_eq!(union.get(2), Some(3));
    }

    #[test]
    fn test_selection_union_keeps_tails() {
        // Leftover elements of the longer input must be appended after the merge.
        let sel1 = SelectionVector::from_predicate(10, |i| i == 1 || i == 5 || i == 9);
        let sel2 = SelectionVector::from_predicate(10, |i| i == 2);
        let collected: Vec<_> = sel1.union(&sel2).iter().collect();
        assert_eq!(collected, vec![1, 2, 5, 9]);
    }

    #[test]
    fn test_selection_iterator() {
        let sel = SelectionVector::from_predicate(5, |i| i < 3);
        let collected: Vec<_> = sel.iter().collect();
        assert_eq!(collected, vec![0, 1, 2]);
    }

    #[test]
    fn test_selection_contains() {
        let sel = SelectionVector::from_predicate(10, |i| i % 3 == 0);
        assert!(sel.contains(3));
        assert!(!sel.contains(4));
        // Values that cannot be stored as u16 are never contained.
        assert!(!sel.contains(100_000));
    }

    #[test]
    fn test_selection_push_get_clear() {
        let mut sel = SelectionVector::new_empty();
        sel.push(2);
        sel.push(7);
        assert_eq!(sel.len(), 2);
        assert_eq!(sel.get(1), Some(7));
        assert_eq!(sel.get(2), None);
        sel.clear();
        assert!(sel.is_empty());
    }

    #[test]
    fn test_selection_into_iter_and_default() {
        let sel = SelectionVector::from_predicate(6, |i| i > 3);
        let owned: Vec<usize> = sel.into_iter().collect();
        assert_eq!(owned, vec![4, 5]);
        assert!(SelectionVector::default().is_empty());
    }
}