use ahash::AHashMap;
use rayon::iter::ParallelBridge;
use rayon::prelude::*;
use rayon::slice::ParallelSliceMut;
use std::ops::{Add, BitAnd, Div, Mul, Sub};
use super::dense::{Tensor as TensorDense, checked_num_elements};
use super::tensor_trait::TensorTrait;
use crate::math::scalar::{Scalar, ScalarCastError};
/// Sparse tensor stored as a map from row-major flat index to value.
///
/// Every index absent from `data` reads as an implicit `T::zero()`.
/// Intended invariant: no explicit zeros are stored (`set` removes on
/// zero, `from_flat_pairs` filters them) — though
/// `get_mut_or_insert_zero` can transiently materialize a stored zero.
#[derive(Clone, Debug)]
pub struct Tensor<T: Scalar> {
// Dense extents per axis; rank == shape.len().
shape: Vec<usize>,
// flat_index -> nonzero value (row-major flattening).
data: AHashMap<usize, T>, }
impl<T: Scalar> Tensor<T> {
    /// Element count of the equivalent dense tensor (product of all
    /// dimensions, overflow-checked).
    #[inline(always)]
    pub fn len_dense(&self) -> usize {
        checked_num_elements(&self.shape, "sparse tensor")
    }

    /// Number of axes.
    #[inline(always)]
    pub fn rank(&self) -> usize {
        self.shape.len()
    }

    /// Dense extents per axis.
    #[inline(always)]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Number of explicitly stored entries.
    #[inline(always)]
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// `true` when nothing is stored, i.e. every entry is implicit zero.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.nnz() == 0
    }
}
/// Maps a possibly-negative axis index onto `0..dim` by wrapping
/// (Python-style negative indexing, extended to arbitrary multiples).
///
/// # Panics
/// `dim` must be nonzero: checked in debug builds, and `rem_euclid`
/// panics on a zero divisor in all builds (as the original `%` did).
#[inline(always)]
fn wrap_axis_index(idx: isize, dim: usize) -> usize {
    debug_assert!(dim > 0);
    // `rem_euclid` yields the canonical non-negative remainder directly,
    // replacing the manual `%` plus conditional-add normalization.
    idx.rem_euclid(dim as isize) as usize
}
impl<T: Scalar> Tensor<T> {
    /// Converts a signed multi-axis index into a row-major flat index.
    /// Negative components wrap around their axis (Python-style).
    ///
    /// # Panics
    /// When `idx.len()` differs from the tensor rank.
    #[inline]
    pub fn index(&self, idx: &[isize]) -> usize {
        assert_eq!(idx.len(), self.shape.len(), "Index rank mismatch");
        // Accumulate (flat, stride) from the innermost axis outward.
        let (flat, _stride) = self
            .shape
            .iter()
            .rev()
            .zip(idx.iter().rev())
            .fold((0usize, 1usize), |(flat, stride), (&dim, &raw)| {
                (flat + wrap_axis_index(raw, dim) * stride, stride * dim)
            });
        flat
    }

    /// Borrow of the stored value at `idx`, or `None` when the entry is
    /// an implicit zero.
    #[inline]
    pub(crate) fn get_opt(&self, idx: &[isize]) -> Option<&T> {
        self.data.get(&self.index(idx))
    }

    /// Value at `idx`; implicit entries read as `T::zero()`.
    #[inline]
    pub fn get(&self, idx: &[isize]) -> T {
        self.get_opt(idx).map_or_else(T::zero, |v| *v)
    }

    /// Mutable access to the entry at `idx`, materializing a stored zero
    /// when absent. NOTE(review): if the caller never writes a nonzero,
    /// the explicit zero stays stored (inflates `nnz`, otherwise benign).
    #[inline]
    pub(crate) fn get_mut_or_insert_zero(&mut self, idx: &[isize]) -> &mut T {
        let key = self.index(idx);
        self.data.entry(key).or_insert_with(T::zero)
    }

    /// Writes `val` at `idx`; writing zero removes the entry, keeping
    /// the no-stored-zeros invariant.
    #[inline]
    pub fn set(&mut self, idx: &[isize], val: T) {
        let key = self.index(idx);
        if val == T::zero() {
            self.data.remove(&key);
        } else {
            self.data.insert(key, val);
        }
    }

    /// Iterates stored `(flat_index, value)` pairs in arbitrary order.
    #[inline]
    pub(crate) fn iter(&self) -> impl Iterator<Item = (&usize, &T)> {
        self.data.iter()
    }

    /// Builds a tensor from `(flat_index, value)` pairs: zero values are
    /// dropped, each index is bounds-checked against the dense size, and
    /// duplicate indices resolve last-wins.
    #[inline]
    fn from_flat_pairs(shape: Vec<usize>, pairs: Vec<(usize, T)>) -> Self {
        let size = checked_num_elements(&shape, "sparse tensor from flat pairs");
        let mut data = AHashMap::with_capacity(pairs.len());
        for (k, v) in pairs {
            assert!(
                k < size,
                "sparse flat index out of bounds: {k} >= dense size {size}"
            );
            if v != T::zero() {
                data.insert(k, v);
            }
        }
        Self { shape, data }
    }
}
// Element-wise binary operators (`+`, `-`, `*`, `/`) between two sparse
// tensors of identical shape.
//
// The operator is evaluated over the union of stored keys; a key missing
// on one side reads as `T::zero()`, and zero-valued results are dropped
// so the no-stored-zeros invariant holds.
//
// NOTE(review): for `Div`, a key stored only in `self` evaluates `a / 0`
// (panics for integer scalars, yields inf/NaN for floats), while keys
// stored on neither side are treated as implicit zero rather than `0 / 0`
// — confirm this matches the intended sparse-division semantics.
macro_rules! impl_sparse_binop {
($trait:ident, $method:ident, $op:tt) => {
impl<T> $trait for Tensor<T>
where
T: Scalar + $trait<Output = T> + Send + Sync,
{
type Output = Self;
#[inline]
fn $method(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Tensor shape mismatch");
// Union of both operands' stored keys, sorted and deduplicated.
let mut keys: Vec<usize> =
Vec::with_capacity(self.data.len() + rhs.data.len());
keys.extend(self.data.keys().copied());
keys.extend(rhs.data.keys().copied());
keys.par_sort_unstable();
keys.dedup();
// Apply the operator per key in parallel; drop zero results.
let out_pairs: Vec<(usize, T)> = keys
.into_par_iter()
.filter_map(|k| {
let a = self.data.get(&k).copied().unwrap_or_else(T::zero);
let b = rhs.data.get(&k).copied().unwrap_or_else(T::zero);
let r = a $op b;
if r == T::zero() {
None
} else {
Some((k, r))
}
})
.collect();
Self::from_flat_pairs(self.shape, out_pairs)
}
}
};
}
impl_sparse_binop!(Add, add, +);
impl_sparse_binop!(Sub, sub, -);
impl_sparse_binop!(Mul, mul, *);
impl_sparse_binop!(Div, div, /);
/// Element-wise bitwise AND over two same-shape sparse tensors.
///
/// Evaluated over the union of stored keys (an absent entry reads as
/// `T::zero()`); zero results are not stored.
impl<T> BitAnd for Tensor<T>
where
    T: Scalar + BitAnd<Output = T> + Send + Sync,
{
    type Output = Self;

    #[inline]
    fn bitand(self, rhs: Self) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Tensor shape mismatch");
        // Deduplicated, sorted union of both operands' stored keys.
        let mut keys: Vec<usize> = self
            .data
            .keys()
            .chain(rhs.data.keys())
            .copied()
            .collect();
        keys.par_sort_unstable();
        keys.dedup();
        // Evaluate `&` per key in parallel, keeping only nonzero results.
        let pairs: Vec<(usize, T)> = keys
            .into_par_iter()
            .filter_map(|key| {
                let lhs_v = self.data.get(&key).copied().unwrap_or_else(T::zero);
                let rhs_v = rhs.data.get(&key).copied().unwrap_or_else(T::zero);
                let out = lhs_v & rhs_v;
                (out != T::zero()).then_some((key, out))
            })
            .collect();
        Self::from_flat_pairs(self.shape, pairs)
    }
}
// Element-wise binary operators with a scalar right-hand side
// (`tensor + s`, `tensor - s`, `tensor * s`, `tensor / s`).
//
// The operator is applied ONLY to stored entries; implicit zeros stay
// implicit. For `Mul` (0 * s == 0) and `Div` with s != 0 (0 / s == 0)
// this matches dense semantics, but for `Add`/`Sub` with a nonzero
// scalar it differs: a dense tensor would also update every zero entry.
// NOTE(review): confirm this stored-entries-only behavior is intended.
macro_rules! impl_sparse_scalar_binop_rhs_scalar {
($trait:ident, $method:ident, $op:tt) => {
impl<T> $trait<T> for Tensor<T>
where
T: Scalar + $trait<Output = T> + Send + Sync,
{
type Output = Self;
#[inline]
fn $method(self, rhs: T) -> Self::Output {
// Map stored entries in parallel; drop results that become zero.
let out_pairs: Vec<(usize, T)> = self
.data
.into_iter()
.par_bridge()
.map(|(k, v)| (k, v $op rhs))
.filter(|&(_, v)| v != T::zero())
.collect();
Self::from_flat_pairs(self.shape, out_pairs)
}
}
};
}
impl_sparse_scalar_binop_rhs_scalar!(Add, add, +);
impl_sparse_scalar_binop_rhs_scalar!(Sub, sub, -);
impl_sparse_scalar_binop_rhs_scalar!(Mul, mul, *);
impl_sparse_scalar_binop_rhs_scalar!(Div, div, /);
impl<T: Scalar> Tensor<T> {
    /// Casts every stored component to `U`, failing on the first
    /// component that does not fit the target type. Components that cast
    /// to `U::zero()` are dropped from the result.
    pub fn try_cast_to<U: Scalar>(&self) -> Result<Tensor<U>, ScalarCastError> {
        // Cast, filter zeros, and short-circuit on error in one pass.
        let pairs = self
            .data
            .par_iter()
            .filter_map(|(&key, &val)| match val.try_cast::<U>() {
                Ok(u) if u == U::zero() => None,
                Ok(u) => Some(Ok((key, u))),
                Err(e) => Some(Err(e)),
            })
            .collect::<Result<Vec<(usize, U)>, ScalarCastError>>()?;
        Ok(Tensor::from_flat_pairs(self.shape.clone(), pairs))
    }

    /// Infallible cast; panics when any component is out of range.
    #[inline]
    pub fn cast_to<U: Scalar>(&self) -> Tensor<U> {
        self.try_cast_to::<U>()
            .expect("sparse tensor cast failed: component out of range for target type")
    }
}
impl<T: Scalar> Tensor<T> {
/// Builds a sparse tensor from `(multi_index, value)` triplets.
///
/// Zero-valued triplets are skipped; each index is bounds-checked per
/// axis (unsigned — no negative wrapping here, unlike `index`).
///
/// NOTE(review): duplicate indices resolve last-wins via `map.insert`,
/// and a *zero* triplet does not erase an earlier nonzero at the same
/// index (it is skipped before insertion) — confirm this is intended
/// rather than COO-style accumulation of duplicates.
pub fn from_triplets(
shape: Vec<usize>,
triplets: impl IntoIterator<Item = (Vec<usize>, T)>,
) -> Self {
// Row-major flattening with per-axis bounds checks.
fn index_of(shape: &[usize], idx: &[usize]) -> usize {
assert_eq!(idx.len(), shape.len(), "Triplet index rank mismatch");
let mut flat = 0usize;
let mut stride = 1usize;
for (&dim, &a) in shape.iter().rev().zip(idx.iter().rev()) {
assert!(a < dim, "Index out of bounds on an axis: {} >= {}", a, dim);
flat += a * stride;
stride *= dim;
}
flat
}
assert!(!shape.is_empty(), "Tensor rank must be >= 1");
assert!(
shape.iter().all(|&d| d > 0),
"All dimensions must be > 0; got {shape:?}"
);
// Validates that the dense element count does not overflow.
checked_num_elements(&shape, "sparse tensor from triplets");
let mut map = AHashMap::default();
for (idx, v) in triplets {
if v == T::zero() {
continue;
}
let k = index_of(&shape, &idx);
map.insert(k, v);
}
Self { shape, data: map }
}
/// Materializes the full dense tensor (implicit zeros included).
#[inline]
pub fn to_dense(&self) -> TensorDense<T> {
let size: usize = self.len_dense();
let mut out = vec![T::zero(); size];
for (&k, &v) in &self.data {
out[k] = v;
}
TensorDense::from_parts_unchecked(self.shape.clone(), out)
}
/// Builds a sparse tensor from a dense one, storing only nonzeros.
#[inline]
pub fn from_dense(dense: &TensorDense<T>) -> Self {
let shape = dense.shape().to_vec();
let size = checked_num_elements(&shape, "sparse tensor from dense");
assert_eq!(size, dense.data().len(), "Dense size/shape mismatch");
let pairs: Vec<(usize, T)> = dense
.data()
.iter()
.copied()
.enumerate()
.filter_map(|(k, v)| if v == T::zero() { None } else { Some((k, v)) })
.collect();
Self::from_flat_pairs(shape, pairs)
}
/// Prints a one-line summary plus up to 32 stored entries, ordered by
/// flat index for deterministic output.
pub fn print(&self) {
println!(
"Sparse tensor: shape={:?}, dense_size={}, nnz={}",
self.shape,
self.len_dense(),
self.nnz()
);
if self.data.is_empty() {
println!(" all entries are implicit zero");
return;
}
// Snapshot and sort the entries (the map itself is unordered).
let mut entries: Vec<(usize, T)> = self.data.iter().map(|(&k, &v)| (k, v)).collect();
entries.par_sort_unstable_by_key(|&(k, _)| k);
let shown = entries.len().min(32);
for (k, value) in entries.iter().take(shown) {
println!(
" [{:?}] flat={} value={}",
self.unravel_index(*k),
k,
value
);
}
if entries.len() > shown {
println!(" ... {} more stored entries", entries.len() - shown);
}
}
/// Inverse of row-major flattening: flat index -> per-axis indices.
fn unravel_index(&self, flat: usize) -> Vec<usize> {
let mut rem = flat;
let mut idx = vec![0usize; self.shape.len()];
for axis in (0..self.shape.len()).rev() {
let dim = self.shape[axis];
idx[axis] = rem % dim;
rem /= dim;
}
idx
}
}
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Scalar,
{
    type Repr<U: Scalar> = Tensor<U>;

    /// All-zero tensor of the given shape (no entries stored).
    #[inline]
    fn empty(shape: &[usize]) -> Self {
        // Validate the shape (overflow-checked element count) up front.
        checked_num_elements(shape, "sparse tensor");
        Self {
            shape: shape.to_vec(),
            data: AHashMap::default(),
        }
    }

    /// Sum of all elements; implicit zeros contribute nothing, so only
    /// stored values are reduced (in parallel).
    fn get_sum(&self) -> T {
        self.data
            .par_iter()
            .map(|(_, &x)| x)
            .reduce(|| T::zero(), |acc, x| acc + x)
    }

    #[inline]
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    #[inline(always)]
    fn index(&self, indices: &[isize]) -> usize {
        Tensor::<T>::index(self, indices)
    }

    #[inline(always)]
    fn get(&self, indices: &[isize]) -> T {
        Tensor::<T>::get(self, indices)
    }

    /// Mutable access; may materialize an explicit zero entry when the
    /// index was implicit (see `get_mut_or_insert_zero`).
    #[inline(always)]
    fn get_mut(&mut self, indices: &[isize]) -> &mut T {
        self.get_mut_or_insert_zero(indices)
    }

    #[inline(always)]
    fn set(&mut self, indices: &[isize], val: T) {
        Tensor::<T>::set(self, indices, val)
    }

    /// Overwrites every *stored* entry with `value`; filling with zero
    /// clears the tensor. Implicit zeros stay implicit, so this is not
    /// a dense fill.
    #[inline]
    fn par_fill(&mut self, value: T)
    where
        T: Copy + Send + Sync,
    {
        if value == T::zero() {
            self.data.clear();
            return;
        }
        // Fixed: overwrite values in place instead of collecting every
        // key and rebuilding a brand-new map — identical end state, no
        // intermediate allocations.
        for v in self.data.values_mut() {
            *v = value;
        }
    }

    /// Applies `f` to every stored value in parallel, dropping results
    /// that map to zero. `f` is never applied to implicit zeros.
    #[inline]
    fn par_map_in_place<F>(&mut self, f: F)
    where
        T: Copy + Send + Sync,
        F: Fn(T) -> T + Sync + Send,
    {
        // Snapshot entries so the map itself isn't borrowed during the
        // parallel pass.
        let pairs: Vec<(usize, T)> = self.iter().map(|(&k, &v)| (k, v)).collect();
        let mapped: Vec<(usize, T)> = pairs
            .into_par_iter()
            .map(|(k, v)| (k, f(v)))
            .filter(|&(_, v)| v != T::zero())
            .collect();
        self.data.clear();
        self.data.extend(mapped);
    }

    /// For every *stored* entry of `self`, replaces it with
    /// `f(self[k], other[k])`, dropping zero results.
    ///
    /// NOTE(review): keys where `self` is implicitly zero but `other` is
    /// not are never visited, so `f(0, b)` is silently assumed to be
    /// zero — confirm callers only pass such annihilating `f`.
    #[inline]
    fn par_zip_with_inplace<F, Rhs>(&mut self, other: &Rhs, f: F)
    where
        Rhs: TensorTrait<T>,
        T: Copy + Send + Sync,
        F: Fn(T, T) -> T + Sync + Send,
    {
        assert_eq!(self.shape(), other.shape(), "Tensor shape mismatch");
        let rank = self.shape.len();
        let dims = self.shape.clone();
        let pairs: Vec<(usize, T)> = self.iter().map(|(&k, &v)| (k, v)).collect();
        let zipped: Vec<(usize, T)> = pairs
            .into_par_iter()
            .map(|(k, a)| {
                // Unravel the flat key into a signed multi-index so the
                // generic `other.get(&[isize])` can be called.
                let mut rem = k;
                let mut idx = vec![0isize; rank];
                for ax in (0..rank).rev() {
                    let d = dims[ax];
                    idx[ax] = (rem % d) as isize;
                    rem /= d;
                }
                let b = other.get(&idx);
                (k, f(a, b))
            })
            .filter(|&(_, r)| r != T::zero())
            .collect();
        self.data.clear();
        self.data.extend(zipped);
    }

    /// Delegates to the inherent fallible cast.
    #[inline]
    fn try_cast_to<U: Scalar>(&self) -> Result<Self::Repr<U>, ScalarCastError>
    where
        T: Copy + Send + Sync,
    {
        Tensor::<T>::try_cast_to::<U>(self)
    }

    /// Delegates to the inherent pretty-printer.
    fn print(&self) {
        Tensor::<T>::print(self);
    }
}