use crate::error::{validate_positions, IncompatibleDimensions, InvalidPositions};
use crate::BinNum;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::ops::{Add, Deref, Mul};
mod bitwise_operations;
use bitwise_operations::BitwiseZipIter;
/// A binary vector stored sparsely as the list of positions whose value is 1.
///
/// `T` is the storage used for the positions; the methods of this type are
/// available when `T` dereferences to `[usize]`.
///
/// NOTE(review): the derived `Deserialize` fills the fields directly and does
/// not appear to go through `try_new` — confirm that untrusted input is
/// validated elsewhere.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct SparseBinVecBase<T> {
/// Positions reported as 1 by `get`; any other position below `length` is 0.
positions: T,
/// Total number of entries in the vector, as returned by `len`.
length: usize,
}
/// Sparse binary vector that owns its positions in a `Vec<usize>`.
pub type SparseBinVec = SparseBinVecBase<Vec<usize>>;
/// Sparse binary vector that borrows its positions as a `&[usize]` slice.
pub type SparseBinSlice<'a> = SparseBinVecBase<&'a [usize]>;
impl SparseBinVec {
    /// Creates a vector of the given length in which no position is set.
    pub fn zeros(length: usize) -> Self {
        let positions = Vec::new();
        Self { positions, length }
    }

    /// Creates a vector of length 0 that holds no positions.
    pub fn empty() -> Self {
        Self::zeros(0)
    }

    /// Consumes the vector and returns the stored positions.
    pub fn to_positions_vec(self) -> Vec<usize> {
        let Self { positions, .. } = self;
        positions
    }
}
impl<T: Deref<Target = [usize]>> SparseBinVecBase<T> {
/// Creates a vector of the given length from a list of positions.
///
/// # Panics
///
/// Panics if `try_new` rejects the arguments.
pub fn new(length: usize, positions: T) -> Self {
    let checked = Self::try_new(length, positions);
    checked.unwrap()
}
/// Creates a vector of the given length from a list of positions, after
/// checking them with `validate_positions`.
///
/// # Errors
///
/// Returns the `InvalidPositions` error reported by `validate_positions`
/// when the positions are not acceptable for this length.
pub fn try_new(length: usize, positions: T) -> Result<Self, InvalidPositions> {
    match validate_positions(length, &positions) {
        Ok(_) => Ok(Self { length, positions }),
        Err(error) => Err(error),
    }
}
/// Builds a vector without validating `positions` against `length`.
///
/// Crate-internal: the caller is responsible for passing positions that
/// `try_new` would have accepted.
pub(crate) fn new_unchecked(length: usize, positions: T) -> Self {
    Self { length, positions }
}
pub fn len(&self) -> usize {
self.length
}
/// Returns the number of stored positions, that is the number of entries
/// with value 1.
pub fn weight(&self) -> usize {
    let support: &[usize] = &self.positions;
    support.len()
}
/// Returns `true` when the vector has length 0.
///
/// This is about the length, not about the values: see `is_zero` to test
/// whether no position is set.
pub fn is_empty(&self) -> bool {
    0 == self.len()
}
/// Returns `true` when no position is set, whatever the length.
pub fn is_zero(&self) -> bool {
    self.positions.is_empty()
}
/// Returns the value at `position`, or `None` when `position` is not
/// smaller than the length of the vector.
///
/// The value is 1 when `position` is among the stored positions and 0
/// otherwise. The lookup scans the stored positions linearly.
pub fn get(&self, position: usize) -> Option<BinNum> {
    if position >= self.len() {
        return None;
    }
    let is_set = self.positions.contains(&position);
    Some(if is_set { 1.into() } else { 0.into() })
}
/// Returns whether the value at `position` is 0, or `None` when `position`
/// is out of bound.
pub fn is_zero_at(&self, position: usize) -> Option<bool> {
    Some(self.get(position)? == 0.into())
}
/// Returns whether the value at `position` is 1, or `None` when `position`
/// is out of bound.
pub fn is_one_at(&self, position: usize) -> Option<bool> {
    Some(self.get(position)? == 1.into())
}
/// Returns an iterator over the stored positions (the entries with value
/// 1), in the order in which they are stored.
pub fn non_trivial_positions<'a>(&'a self) -> NonTrivialPositions<'a> {
    let positions: &'a [usize] = &self.positions;
    NonTrivialPositions {
        index: 0,
        positions,
    }
}
/// Returns an iterator over every entry of the vector, zeros included,
/// starting at position 0.
pub fn iter_dense<'a>(&'a self) -> IterDense<'a, T> {
    let start = 0;
    IterDense {
        index: start,
        vec: self,
    }
}
/// Returns the concatenation of `self` followed by `other`.
///
/// The result has length `self.len() + other.len()`; the positions of
/// `other` are shifted by `self.len()`. The result is built with
/// `new_unchecked`, so no validation is performed.
pub fn concat(&self, other: &Self) -> SparseBinVec {
    let offset = self.len();
    let mut positions: Vec<usize> = self.non_trivial_positions().collect();
    positions.extend(
        other
            .non_trivial_positions()
            .map(|position| position + offset),
    );
    SparseBinVec::new_unchecked(offset + other.len(), positions)
}
/// Returns a new vector made only of the entries at the given positions.
///
/// The result has length `positions.len()`, and its entry at index `i` is
/// 1 exactly when `positions[i]` is a stored position of `self`.
///
/// # Errors
///
/// Returns the `InvalidPositions` error reported by `validate_positions`
/// when `positions` is not acceptable for a vector of this length.
pub fn keep_only_positions(
    &self,
    positions: &[usize],
) -> Result<SparseBinVec, InvalidPositions> {
    validate_positions(self.length, positions)?;

    // Maps each kept position of `self` to its index in the new vector.
    let mut new_index_of = HashMap::with_capacity(positions.len());
    for (new, old) in positions.iter().enumerate() {
        new_index_of.insert(old, new);
    }

    let kept = self
        .non_trivial_positions()
        .filter_map(|position| new_index_of.get(&position).copied())
        .collect();
    Ok(SparseBinVec::new_unchecked(positions.len(), kept))
}
/// Returns a new vector in which the entries at the given positions have
/// been removed; the remaining entries keep their relative order.
///
/// Values in `positions` that are not smaller than `self.len()` simply match
/// nothing and are ignored.
///
/// # Errors
///
/// Forwards any error returned by `keep_only_positions` for the list of
/// remaining positions.
pub fn without_positions(&self, positions: &[usize]) -> Result<SparseBinVec, InvalidPositions> {
    // Hash the positions to drop once, so that each membership test below is
    // O(1) instead of a linear scan of `positions` per candidate.
    let to_remove: std::collections::HashSet<usize> = positions.iter().copied().collect();
    let to_keep: Vec<usize> = (0..self.len())
        .filter(|position| !to_remove.contains(position))
        .collect();
    self.keep_only_positions(&to_keep)
}
/// Returns a borrowed view of this vector, with the same length and the
/// same positions, without copying the positions.
pub fn as_view(&self) -> SparseBinSlice {
    let positions: &[usize] = &self.positions;
    SparseBinSlice {
        positions,
        length: self.length,
    }
}
/// Returns the stored positions as a slice.
pub fn as_slice(&self) -> &[usize] {
    &self.positions
}
/// Converts this vector into an owned `SparseBinVec` of the same length.
///
/// The positions are copied into a newly allocated `Vec`.
pub fn to_vec(self) -> SparseBinVec {
    let positions = self.positions.to_vec();
    SparseBinVec {
        positions,
        length: self.length,
    }
}
/// Returns the dot product of `self` and `other`.
///
/// The product is accumulated entry by entry from the pairs of values
/// yielded by `BitwiseZipIter`.
///
/// # Errors
///
/// Returns `IncompatibleDimensions` when the two vectors do not have the
/// same length.
pub fn dot_with<S: Deref<Target = [usize]>>(
    &self,
    other: &SparseBinVecBase<S>,
) -> Result<BinNum, IncompatibleDimensions<usize, usize>> {
    if self.len() == other.len() {
        let mut total: BinNum = 0.into();
        for entry in BitwiseZipIter::new(self.as_view(), other.as_view()) {
            total = total + entry.first_row_value * entry.second_row_value;
        }
        Ok(total)
    } else {
        Err(IncompatibleDimensions::new(self.len(), other.len()))
    }
}
/// Returns the entry-wise XOR of `self` and `other` as a new vector of the
/// same length.
///
/// A position is kept when the sum of the two values yielded for it by
/// `BitwiseZipIter` equals 1.
///
/// # Errors
///
/// Returns `IncompatibleDimensions` when the two vectors do not have the
/// same length.
pub fn bitwise_xor_with<S: Deref<Target = [usize]>>(
    &self,
    other: &SparseBinVecBase<S>,
) -> Result<SparseBinVec, IncompatibleDimensions<usize, usize>> {
    if self.len() == other.len() {
        let mut positions = Vec::new();
        for entry in BitwiseZipIter::new(self.as_view(), other.as_view()) {
            if entry.first_row_value + entry.second_row_value == 1.into() {
                positions.push(entry.position);
            }
        }
        Ok(SparseBinVec::new_unchecked(self.len(), positions))
    } else {
        Err(IncompatibleDimensions::new(self.len(), other.len()))
    }
}
pub fn as_json(&self) -> Result<String, serde_json::Error>
where
T: Serialize,
{
serde_json::to_string(self)
}
}
/// Addition of two sparse binary vectors, computed with `bitwise_xor_with`.
///
/// # Panics
///
/// Panics if the two vectors do not have the same length.
impl<S, T> Add<&SparseBinVecBase<S>> for &SparseBinVecBase<T>
where
    S: Deref<Target = [usize]>,
    T: Deref<Target = [usize]>,
{
    type Output = SparseBinVec;

    fn add(self, other: &SparseBinVecBase<S>) -> Self::Output {
        // Build the panic message lazily: the message is only formatted (and
        // allocated) on the failure path, not on every successful addition.
        // The text matches what `expect` would print, error included.
        self.bitwise_xor_with(other).unwrap_or_else(|error| {
            panic!(
                "vector of length {} can't be added to vector of length {}: {:?}",
                self.len(),
                other.len(),
                error
            )
        })
    }
}
/// Multiplication of two sparse binary vectors, computed with `dot_with`.
///
/// # Panics
///
/// Panics if the two vectors do not have the same length.
impl<S, T> Mul<&SparseBinVecBase<S>> for &SparseBinVecBase<T>
where
    S: Deref<Target = [usize]>,
    T: Deref<Target = [usize]>,
{
    type Output = BinNum;

    fn mul(self, other: &SparseBinVecBase<S>) -> Self::Output {
        // Build the panic message lazily: the message is only formatted (and
        // allocated) on the failure path, not on every successful product.
        // The text matches what `expect` would print, error included.
        self.dot_with(other).unwrap_or_else(|error| {
            panic!(
                "vector of length {} can't be dotted to vector of length {}: {:?}",
                self.len(),
                other.len(),
                error
            )
        })
    }
}
/// Displays the vector as the debug representation of its list of positions
/// (for example `[0, 2, 4]`); the length is not printed.
impl<T: Deref<Target = [usize]>> fmt::Display for SparseBinVecBase<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let positions: &[usize] = &self.positions;
        write!(f, "{:?}", positions)
    }
}
/// Iterator over the positions stored in a sparse binary vector, yielded by
/// value in the order in which they are stored.
#[derive(Debug, Clone)]
pub struct NonTrivialPositions<'vec> {
    /// Positions being traversed.
    positions: &'vec [usize],
    /// Index in `positions` of the next item to yield.
    index: usize,
}

impl<'vec> Iterator for NonTrivialPositions<'vec> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let position = *self.positions.get(self.index)?;
        // Only advance when an item was actually produced, so the index never
        // goes past the end of the slice.
        self.index += 1;
        Some(position)
    }

    /// Reports the exact number of remaining items, which lets consumers such
    /// as `collect` reserve the right capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.positions.len().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}
/// Iterator over every entry of a `SparseBinVecBase`, zeros included, as
/// created by `SparseBinVecBase::iter_dense`.
#[derive(Debug, Clone)]
pub struct IterDense<'vec, T> {
/// Vector being traversed.
vec: &'vec SparseBinVecBase<T>,
/// Position of the next entry to yield.
index: usize,
}
/// Yields the value at each position in turn, by calling `get` on the
/// underlying vector, and stops once the position reaches its length.
impl<'vec, T> Iterator for IterDense<'vec, T>
where
    T: Deref<Target = [usize]>,
{
    type Item = BinNum;

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.index;
        self.index += 1;
        self.vec.get(current)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Adding two vectors keeps the positions set in exactly one of them.
    #[test]
    fn addition() {
        let first_vector = SparseBinVec::new(6, vec![0, 2, 4]);
        let second_vector = SparseBinVec::new(6, vec![0, 1, 2]);
        let sum = SparseBinVec::new(6, vec![1, 4]);
        assert_eq!(&first_vector + &second_vector, sum);
    }

    /// Adding vectors of different lengths must panic.
    #[test]
    fn panics_on_addition_if_different_length() {
        let vector_6 = SparseBinVec::new(6, vec![0, 2, 4]);
        let vector_2 = SparseBinVec::new(2, vec![0]);
        let result = std::panic::catch_unwind(|| &vector_6 + &vector_2);
        assert!(result.is_err());
    }

    /// The JSON layout is stable and decoding it gives back the same vector.
    #[test]
    fn ser_de() {
        let vector = SparseBinVec::new(10, vec![0, 5, 7, 8]);
        let json = serde_json::to_string(&vector).unwrap();
        let expected = String::from("{\"positions\":[0,5,7,8],\"length\":10}");
        assert_eq!(json, expected);

        // Exercise the "de" half as well: the round trip must be lossless.
        let decoded: SparseBinVec = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, vector);
    }
}