use std::ops::Mul;
use std::sync::Arc;
use aligned_vec::{AVec, ConstAlign};
use num_traits::{Float, One, Zero};
use super::Align64;
use crate::{Storage, StorageFor};
/// Dense element storage held in a reference-counted, aligned buffer.
///
/// The buffer sits behind an [`Arc`], so values of this type can share one
/// allocation. The element type defaults to `f64`.
pub struct DenseStorage<T = f64> {
/// Shared aligned buffer holding the elements.
data: Arc<AVec<T, Align64>>,
}
/// Cloning a `DenseStorage` bumps the reference count of the shared buffer;
/// no elements are copied.
///
/// This manual impl carries no `T: Clone` bound (a `#[derive(Clone)]` would
/// add one).
impl<T> Clone for DenseStorage<T> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.data);
        Self { data: shared }
    }
}
impl<T> DenseStorage<T> {
pub fn new(data: Vec<T>) -> Self {
let len = data.len();
let mut aligned: AVec<T, ConstAlign<64>> = AVec::with_capacity(64, len);
for elem in data {
aligned.push(elem);
}
Self {
data: Arc::new(aligned),
}
}
pub(crate) fn from_aligned(data: AVec<T, ConstAlign<64>>) -> Self {
Self {
data: Arc::new(data),
}
}
pub fn data(&self) -> &[T] {
&self.data[..]
}
pub fn data_mut(&mut self) -> &mut [T]
where
T: Clone,
{
Arc::make_mut(&mut self.data).as_mut_slice()
}
pub fn iter(&self) -> std::slice::Iter<'_, T> {
self.data[..].iter()
}
pub fn as_ptr(&self) -> *const T {
self.data.as_ptr()
}
pub fn as_mut_ptr(&mut self) -> *mut T
where
T: Clone,
{
Arc::make_mut(&mut self.data).as_mut_ptr()
}
}
impl<T> Storage for DenseStorage<T> {
    type Element = T;

    /// Number of elements currently held in the buffer.
    fn flat_len(&self) -> usize {
        let elements: &[T] = &self.data;
        elements.len()
    }
}
/// Declares `DenseStorage<T>` as a storage for [`crate::DenseLayout`].
/// The impl body is empty: no items are defined or overridden here.
impl<T> StorageFor<crate::DenseLayout> for DenseStorage<T> {}
impl<T> DenseStorage<T>
where
    T: Clone,
{
    /// Overwrites every element with a clone of `value`.
    ///
    /// Uses [`Arc::make_mut`], so a buffer shared with another handle is
    /// cloned before being overwritten.
    pub(crate) fn fill(&mut self, value: T) {
        Arc::make_mut(&mut self.data).as_mut_slice().fill(value);
    }

    /// Replaces each element `x` with `f(&x)`, visiting elements in order.
    ///
    /// `f` may be any `FnMut`, so closures that keep state between calls
    /// (counters, running accumulators) are accepted as well as plain `Fn`
    /// closures and function pointers.
    ///
    /// Uses [`Arc::make_mut`], so a buffer shared with another handle is
    /// cloned before being modified.
    pub(crate) fn map_mut<F>(&mut self, mut f: F)
    where
        F: FnMut(&T) -> T,
    {
        for slot in Arc::make_mut(&mut self.data).as_mut_slice() {
            // The shared reborrow ends when `f` returns, before the write.
            *slot = f(&*slot);
        }
    }

    /// Multiplies every element by `factor`, storing `element * factor` back
    /// in place.
    ///
    /// Both the element and the factor are cloned for each multiplication
    /// because `Mul` consumes its operands.
    ///
    /// Uses [`Arc::make_mut`], so a buffer shared with another handle is
    /// cloned before being modified.
    pub(crate) fn scale<S>(&mut self, factor: S)
    where
        T: Mul<S, Output = T>,
        S: Clone,
    {
        for slot in Arc::make_mut(&mut self.data).as_mut_slice() {
            *slot = slot.clone() * factor.clone();
        }
    }
}
impl<T> DenseStorage<T>
where
    T: ariadnetor_core::Scalar,
{
    /// Sum over all elements of `abs(x) * abs(x)`, accumulated in element
    /// order starting from zero.
    fn norm_squared(&self) -> T::Real {
        let mut total = T::Real::zero();
        for &value in self.data.iter() {
            let magnitude = value.abs();
            total = total + magnitude * magnitude;
        }
        total
    }

    /// Frobenius norm: the square root of the sum of squared absolute values.
    pub(crate) fn norm_frobenius(&self) -> T::Real {
        let sum_of_squares = self.norm_squared();
        sum_of_squares.sqrt()
    }

    /// Alias for [`Self::norm_frobenius`].
    pub(crate) fn norm(&self) -> T::Real {
        self.norm_frobenius()
    }

    /// Scales every element by the reciprocal of the Frobenius norm and
    /// returns the norm measured before scaling.
    ///
    /// Uses [`Arc::make_mut`], so a buffer shared with another handle is
    /// cloned before being modified.
    ///
    /// # Panics
    ///
    /// Panics if the norm compares equal to zero.
    pub(crate) fn normalize(&mut self) -> T::Real {
        let magnitude = self.norm_frobenius();
        assert!(magnitude != T::Real::zero(), "Cannot normalize zero tensor");
        let reciprocal = T::Real::one() / magnitude;
        Arc::make_mut(&mut self.data)
            .iter_mut()
            .for_each(|entry| *entry = entry.scale_real(reciprocal));
        magnitude
    }
}