use crate::host::RawShape;
use crate::host::SliceInfo;
use anyhow::anyhow;
use bitvec::prelude::*;
use ndarray::{prelude::*, RemoveAxis};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// An n-dimensional array of booleans stored as one bit per element.
///
/// The bits in `data` are interpreted in row-major (C) order against the shape in `dim`
/// (this is how `into_array` reshapes them). Both fields are reference-counted, so
/// cloning a `BitArrayRepr` does not copy the bits.
#[derive(Serialize, Deserialize, Hash, Clone, PartialEq)]
pub struct BitArrayRepr {
/// Element bits in row-major order; expected to hold `dim.size()` bits.
pub data: Arc<BitVec<u8, Lsb0>>,
/// Shape of the array: one length per axis.
pub dim: Arc<IxDyn>,
}
/// Collapses `axis` of `dims` to length 1 and returns the element offset of `index` along it.
///
/// The returned offset is `index * strides[axis]`, where `strides` holds one stride (in
/// elements) per axis of `dims`. After the call `dims[axis] == 1`.
///
/// # Panics
///
/// Panics if `axis` is out of range for `dims` or `strides`, or if `index` is not less than
/// the length of `axis`.
fn do_collapse_axis<D: Dimension>(dims: &mut D, strides: &D, axis: usize, index: usize) -> isize {
    let dim = dims.slice()[axis];
    let stride = strides.slice()[axis];
    assert!(
        index < dim,
        "collapse_axis: Index {} must be less than axis length {} for \
         array with shape {:?}",
        index,
        dim,
        *dims
    );
    dims.slice_mut()[axis] = 1;
    index as isize * stride as isize
}
impl BitArrayRepr {
pub fn new_with_shape(dim: Arc<IxDyn>) -> Self {
let data = BitVec::repeat(false, dim.size());
BitArrayRepr {
data: Arc::new(data),
dim,
}
}
pub fn from_raw(data: BitVec<u8, Lsb0>, dim: IxDyn) -> Self {
BitArrayRepr {
data: Arc::new(data),
dim: Arc::new(dim),
}
}
pub fn from_vec(vec: Vec<u8>, shape: &RawShape) -> Self {
let data: BitVec<u8, Lsb0> = vec.iter().map(|&ai| ai != 0).collect();
let dim = IxDyn(&shape.0);
BitArrayRepr {
data: Arc::new(data),
dim: Arc::new(dim),
}
}
pub fn from_elem(shape: &RawShape, elem: u8) -> Self {
let dim = IxDyn(&shape.0);
let data = BitVec::repeat(elem != 0, dim.size());
BitArrayRepr {
data: Arc::new(data),
dim: Arc::new(dim),
}
}
pub fn ndim(&self) -> usize {
self.dim.ndim()
}
pub fn shape(&self) -> &[usize] {
self.dim.slice()
}
pub fn into_array<T: From<u8>>(&self) -> anyhow::Result<ArrayD<T>> {
Array::from_iter(
self.data
.iter()
.map(|item| if *item { T::from(1) } else { T::from(0) }),
)
.into_shape(IxDyn(self.shape()))
.map_err(|e| anyhow!("Invalid shape {}", e))
}
pub fn index_axis(&self, axis: usize, index: usize) -> BitArrayRepr {
let mut dim = IxDyn(self.dim.slice());
let strides = dim.default_strides();
let offset = do_collapse_axis(&mut dim, &strides, axis, index) as usize;
let new_dim = self.dim.remove_axis(Axis(axis));
let new_ptr = self.data.as_bitslice();
let data = {
if new_dim.size() > 0 {
Arc::new(BitVec::from_bitslice(
&new_ptr[offset..offset + new_dim.size()],
))
} else {
Arc::new(BitVec::from_bitslice(&new_ptr[offset..offset + 1]))
}
};
BitArrayRepr {
data,
dim: Arc::new(new_dim),
}
}
pub fn into_diag(&self) -> BitArrayRepr {
let len = self.dim.slice().iter().cloned().min().unwrap_or(1);
let mut data: BitVec<u8, Lsb0> = BitVec::EMPTY;
match len {
1 => data.push(self.data[0]),
2 => {
data.push(self.data[0]);
let pos =
IxDyn::stride_offset(&IxDyn(&[1, 1]), &self.dim.default_strides()) as usize;
data.push(self.data[pos])
}
_ => todo!(),
};
BitArrayRepr {
data: Arc::new(data),
dim: Arc::new(IxDyn(&[len])),
}
}
pub(crate) fn slice(&self, _info: SliceInfo) -> anyhow::Result<BitArrayRepr> {
Err(anyhow::anyhow!("slicing not implemented for BitArray yet"))
}
pub(crate) fn reversed_axes(&self) -> anyhow::Result<BitArrayRepr> {
let mut dim = IxDyn(self.dim.slice());
let mut new_data = (*self.data).clone();
let default_strides = dim.default_strides();
let fortran_strides = dim.fortran_strides();
match dim.ndim() {
0 => (),
1 => (),
2 => {
for i in 0..dim[0] {
for j in 0..dim[1] {
new_data.set(
j * fortran_strides[1] + i * fortran_strides[0],
(*self.data)[i * default_strides[0] + j * default_strides[1]],
);
}
}
}
3 => {
for i in 0..dim[0] {
for j in 0..dim[1] {
for k in 0..dim[2] {
new_data.set(
k * fortran_strides[2]
+ j * fortran_strides[1]
+ i * fortran_strides[0],
(*self.data)[i * default_strides[0]
+ j * default_strides[1]
+ k * default_strides[2]],
)
}
}
}
}
_ => {
return Err(anyhow::anyhow!(
"tranposing not implemented for 4D tensors or larger yet"
))
}
}
dim.slice_mut().reverse();
Ok(BitArrayRepr {
data: Arc::new(new_data),
dim: Arc::new(dim),
})
}
}
impl std::ops::BitXor for &BitArrayRepr {
type Output = BitArrayRepr;
fn bitxor(self, rhs: Self) -> Self::Output {
let mut data = (*self.data).clone();
data ^= Arc::as_ref(&rhs.data);
BitArrayRepr {
data: Arc::new(data),
dim: self.dim.clone(),
}
}
}
impl std::ops::Not for &BitArrayRepr {
type Output = BitArrayRepr;
fn not(self) -> Self::Output {
let data = !(*self.data).clone();
BitArrayRepr {
data: Arc::new(data),
dim: self.dim.clone(),
}
}
}
impl std::ops::BitAnd for &BitArrayRepr {
type Output = BitArrayRepr;
fn bitand(self, rhs: Self) -> Self::Output {
let mut data = (*self.data).clone();
data &= Arc::as_ref(&rhs.data);
BitArrayRepr {
data: Arc::new(data),
dim: self.dim.clone(),
}
}
}
impl std::ops::BitOr for &BitArrayRepr {
type Output = BitArrayRepr;
fn bitor(self, rhs: Self) -> Self::Output {
let mut data = (*self.data).clone();
data |= Arc::as_ref(&rhs.data);
BitArrayRepr {
data: Arc::new(data),
dim: self.dim.clone(),
}
}
}