use std::ops::{Mul, MulAssign};
use ariadnetor_core::Scalar;
use num_traits::{One, Zero};
use super::Tensor;
use crate::{DenseLayout, DenseStorage, DenseTensorData, TensorData};
impl<S> Tensor<DenseStorage<S>, DenseLayout> {
    /// Returns the raw element storage as a flat slice.
    ///
    /// Elements appear in the tensor's memory order (see `crate::flat_index`
    /// for how a multi-index maps into this slice).
    pub fn data_slice(&self) -> &[S] {
        self.data.storage().data()
    }

    /// Returns the raw element storage as a mutable flat slice.
    ///
    /// NOTE(review): the `S: Clone` bound is required by
    /// `DenseStorage::data_mut`; presumably so shared storage can be
    /// detached before mutation — confirm against `DenseStorage`.
    pub fn data_slice_mut(&mut self) -> &mut [S]
    where
        S: Clone,
    {
        self.data.storage_mut().data_mut()
    }

    /// Reinterprets the existing storage under `new_shape`, keeping the
    /// tensor's current memory order.
    ///
    /// No data is permuted: the flat storage is reused as-is. For a
    /// column-major tensor this therefore reshapes in column-major order;
    /// use `reshape_logical` for a row-major (logical) reshape.
    ///
    /// # Panics
    /// If the product of `new_shape` differs from the number of stored
    /// elements.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
        let len = self.data.storage().data().len();
        let new_len: usize = new_shape.iter().product();
        // Without this check the new layout could address past the end of
        // the storage (or ignore part of it), failing only later in get/set.
        assert_eq!(
            new_len, len,
            "reshape: new shape {new_shape:?} has {new_len} elements, but tensor has {len}",
        );
        let new_layout = DenseLayout::new(new_shape, self.data.layout().order());
        let new_storage = self.data.storage().clone();
        Self::from_data(TensorData::new(new_storage, new_layout))
    }
}
impl<S: Scalar> Tensor<DenseStorage<S>, DenseLayout> {
    /// Returns the memory order of the dense layout (e.g. `MemoryOrder::RowMajor`).
    pub fn order(&self) -> ariadnetor_core::backend::MemoryOrder {
        self.data.layout().order()
    }

    /// Returns a copy of the element at the multi-index `indices`.
    ///
    /// # Panics
    /// If `indices.len()` differs from the tensor rank, or if any index is
    /// out of bounds for its axis.
    pub fn get(&self, indices: impl AsRef<[usize]>) -> S {
        let flat = self.checked_flat_index(indices.as_ref(), "Tensor::get");
        self.data.storage().data()[flat]
    }

    /// Overwrites the element at the multi-index `indices` with `value`.
    ///
    /// # Panics
    /// If `indices.len()` differs from the tensor rank, or if any index is
    /// out of bounds for its axis.
    pub fn set(&mut self, indices: impl AsRef<[usize]>, value: S) {
        // The offset is computed (and the shared borrow released) before
        // taking the mutable storage borrow.
        let flat = self.checked_flat_index(indices.as_ref(), "Tensor::set");
        self.data.storage_mut().data_mut()[flat] = value;
    }

    /// Sets every stored element to `value`.
    pub fn fill(&mut self, value: S) {
        self.data.storage_mut().data_mut().fill(value);
    }

    /// Validates `indices` against the tensor shape and returns the
    /// corresponding flat storage offset.
    ///
    /// `op` prefixes the panic messages so callers report their own name.
    fn checked_flat_index(&self, indices: &[usize], op: &str) -> usize {
        let shape = self.shape();
        assert_eq!(
            indices.len(),
            shape.len(),
            "{op}: indices length {} doesn't match rank {}",
            indices.len(),
            shape.len(),
        );
        for (axis, (&idx, &dim)) in indices.iter().zip(shape).enumerate() {
            assert!(
                idx < dim,
                "{op}: index {idx} out of bounds for axis {axis} with size {dim}",
            );
        }
        crate::flat_index(indices, shape, self.order())
    }
}
impl<S: Clone> Tensor<DenseStorage<S>, DenseLayout> {
    /// Multiplies every element in place by `factor`, computing
    /// `element * factor` (the factor is on the right).
    pub fn scale<F>(&mut self, factor: F)
    where
        S: Mul<F, Output = S>,
        F: Clone,
    {
        self.data
            .storage_mut()
            .data_mut()
            .iter_mut()
            .for_each(|elem| *elem = elem.clone() * factor.clone());
    }

    /// Returns a new tensor whose elements are `element * factor`, with the
    /// same shape and memory order as `self`; `self` is left untouched.
    pub fn scaled<F>(&self, factor: F) -> Self
    where
        S: Mul<F, Output = S>,
        F: Clone,
    {
        let products = self
            .data
            .storage()
            .data()
            .iter()
            .cloned()
            .map(|elem| elem * factor.clone())
            .collect::<Vec<S>>();
        let dims = self.shape().to_vec();
        let mem_order = self.data.layout().order();
        Self::from_data(DenseTensorData::from_raw_parts(products, dims, mem_order))
    }
}
/// `tensor * s`: consumes the tensor and scales it in place, reusing its storage.
impl<S> Mul<S> for Tensor<DenseStorage<S>, DenseLayout>
where
    S: Clone + Mul<S, Output = S>,
{
    type Output = Self;

    fn mul(self, rhs: S) -> Self::Output {
        let mut product = self;
        product.scale(rhs);
        product
    }
}
impl<S> Mul<S> for &Tensor<DenseStorage<S>, DenseLayout>
where
S: Clone + Mul<Output = S>,
{
type Output = Tensor<DenseStorage<S>, DenseLayout>;
fn mul(self, rhs: S) -> Self::Output {
self.scaled(rhs)
}
}
/// `tensor *= s`: scales every element in place.
impl<S> MulAssign<S> for Tensor<DenseStorage<S>, DenseLayout>
where
    S: Clone + Mul<S, Output = S>,
{
    fn mul_assign(&mut self, rhs: S) {
        Self::scale(self, rhs)
    }
}
impl<S> Tensor<DenseStorage<S>, DenseLayout>
where
    S: Scalar,
{
    /// Returns the 2-norm of the flattened tensor: `sqrt(sum |x|^2)`.
    ///
    /// The common case sums the squares directly. If that sum is not a normal
    /// float (it overflowed to infinity, or is zero/subnormal), the norm is
    /// recomputed with a rescaled accumulation. This keeps very large entries
    /// from reporting `inf` and very small ones from reporting `0`.
    /// A NaN entry yields NaN; an infinite entry yields infinity.
    pub fn norm(&self) -> S::Real {
        let data = self.data.storage().data();
        let mut sq = S::Real::zero();
        for &x in data {
            let a = x.abs();
            sq = sq + a * a;
        }
        // Fast path: bit-identical to the naive formula whenever the sum is
        // representable without loss (or is NaN, which must propagate anyway).
        if <S::Real as num_traits::Float>::is_normal(sq)
            || <S::Real as num_traits::Float>::is_nan(sq)
        {
            return <S::Real as num_traits::Float>::sqrt(sq);
        }
        Self::rescaled_norm(data)
    }

    /// Overflow/underflow-safe 2-norm using a running scale (as in LAPACK `xNRM2`).
    ///
    /// Only reached when the naive sum of squares was zero, subnormal or
    /// infinite. NaN entries make that sum NaN, so none are present here.
    fn rescaled_norm(data: &[S]) -> S::Real {
        let zero = S::Real::zero();
        let one = S::Real::one();
        // Invariant: sum of |x|^2 over the entries seen so far == scale^2 * ssq.
        let mut scale = zero;
        let mut ssq = one;
        for &x in data {
            let a = x.abs();
            if <S::Real as num_traits::Float>::is_infinite(a) {
                // Matches the naive formula, and avoids inf/inf = NaN below.
                return a;
            }
            if a == zero {
                continue;
            }
            if scale < a {
                let r = scale / a;
                ssq = one + ssq * r * r;
                scale = a;
            } else {
                let r = a / scale;
                ssq = ssq + r * r;
            }
        }
        scale * <S::Real as num_traits::Float>::sqrt(ssq)
    }

    /// Divides the tensor in place by its norm and returns the norm it had
    /// before normalization.
    ///
    /// # Panics
    /// If the norm is exactly zero.
    pub fn normalize(&mut self) -> S::Real {
        let norm = self.norm();
        assert!(norm != S::Real::zero(), "Cannot normalize zero tensor");
        let inv_norm = S::Real::one() / norm;
        for slot in self.data.storage_mut().data_mut().iter_mut() {
            *slot = slot.scale_real(inv_norm);
        }
        norm
    }

    /// Returns a normalized copy together with the original norm; `self` is
    /// left untouched.
    ///
    /// # Panics
    /// If the norm is exactly zero (see `normalize`).
    pub fn normalized(&self) -> (Self, S::Real) {
        let mut clone = self.clone();
        let n = clone.normalize();
        (clone, n)
    }

    /// Returns a tensor built from the underlying data's `conj()`.
    ///
    /// NOTE(review): presumably the element-wise complex conjugate — confirm
    /// against `TensorData::conj`.
    pub fn conj(&self) -> Self {
        Self {
            data: self.data.conj(),
        }
    }

    /// Returns a tensor whose storage is rearranged into memory order `to`,
    /// via `crate::reorder::reorder_data`.
    pub fn reordered(&self, to: ariadnetor_core::backend::MemoryOrder) -> Self {
        let reordered = crate::reorder::reorder_data(&self.data, to);
        Self { data: reordered }
    }

    /// Reshapes in logical (row-major) element order, regardless of how the
    /// storage is laid out. The result keeps the original memory order.
    ///
    /// # Panics
    /// If the product of `new_shape` differs from the current element count.
    pub fn reshape_logical(&self, new_shape: Vec<usize>) -> Self {
        let old_len: usize = self.shape().iter().product();
        let new_len: usize = new_shape.iter().product();
        // Fail before paying for two full reorders.
        assert_eq!(
            new_len, old_len,
            "reshape_logical: new shape {new_shape:?} has {new_len} elements, but tensor has {old_len}",
        );
        let orig_order = self.order();
        // Move to row-major so a storage-level reshape matches logical order,
        // then restore the caller's memory order.
        self.reordered(ariadnetor_core::backend::MemoryOrder::RowMajor)
            .reshape(new_shape)
            .reordered(orig_order)
    }

    /// Fuses the consecutive axes in `range` into a single axis whose extent
    /// is the product of the fused extents (logical row-major fusion).
    ///
    /// # Panics
    /// If `range` is empty or extends past the tensor rank.
    pub fn fuse_legs(&self, range: std::ops::Range<usize>) -> Self {
        let shape = self.shape();
        let rank = shape.len();
        assert!(
            range.start < range.end && range.end <= rank,
            "fuse_legs: range {range:?} out of bounds for rank {rank}",
        );
        let fused: usize = shape[range.clone()].iter().product();
        let mut new_shape = shape[..range.start].to_vec();
        new_shape.push(fused);
        new_shape.extend_from_slice(&shape[range.end..]);
        self.reshape_logical(new_shape)
    }

    /// Splits axis `axis` into consecutive axes with extents `into`
    /// (logical row-major split). This is the inverse of `fuse_legs`.
    ///
    /// # Panics
    /// If `axis` is out of bounds, `into` is empty, or the product of `into`
    /// differs from the extent of `axis`.
    pub fn split_leg(&self, axis: usize, into: &[usize]) -> Self {
        let shape = self.shape();
        let rank = shape.len();
        assert!(
            axis < rank,
            "split_leg: axis {axis} out of bounds for rank {rank}",
        );
        assert!(!into.is_empty(), "split_leg: `into` must be non-empty");
        let prod: usize = into.iter().product();
        assert_eq!(
            prod, shape[axis],
            "split_leg: product of {into:?} != axis {axis} extent {}",
            shape[axis],
        );
        let mut new_shape = shape[..axis].to_vec();
        new_shape.extend_from_slice(into);
        new_shape.extend_from_slice(&shape[axis + 1..]);
        self.reshape_logical(new_shape)
    }
}