use core::ops::{Index, IndexMut};
use crate::array2::Array2;
use crate::error::{Error, Result};
use crate::numeric::Float;
use crate::rand::SmallRng;
use crate::view2::{ArrayView2, ArrayViewMut2};
use crate::view3::{ArrayView3, ArrayViewMut3};
/// One of the three axes of an [`Array3`].
///
/// `Axis0`, `Axis1` and `Axis2` correspond to the first (`i`), second (`j`) and
/// third (`k`) component of an element index `(i, j, k)`, i.e. to `shape[0]`,
/// `shape[1]` and `shape[2]` respectively.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Axis3 {
    /// The first axis (index `i`, extent `shape[0]`).
    Axis0,
    /// The second axis (index `j`, extent `shape[1]`).
    Axis1,
    /// The third axis (index `k`, extent `shape[2]`).
    Axis2,
}

impl Axis3 {
    /// Position of this axis inside a `[_; 3]` shape or stride array (0, 1 or 2).
    pub(crate) const fn index(self) -> usize {
        match self {
            Self::Axis0 => 0,
            Self::Axis1 => 1,
            Self::Axis2 => 2,
        }
    }
}
/// An owned, contiguous three-dimensional array.
///
/// Invariant: `data.len()` equals `shape[0] * shape[1] * shape[2]`, and element
/// `(i, j, k)` is stored at `data[(i * shape[1] + j) * shape[2] + k]` (row-major,
/// last index fastest).
///
/// `Eq` and `Hash` are derived conditionally, so they are available whenever the
/// element type provides them (e.g. integer arrays), mirroring `Vec<T>`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Array3<T> {
    /// Row-major element storage.
    data: Vec<T>,
    /// Extents along axes 0, 1 and 2.
    shape: [usize; 3],
}
impl<T> Array3<T> {
    /// Wraps an existing row-major buffer as an array of the given `shape`.
    ///
    /// # Errors
    ///
    /// Returns `Error::DimensionTooLarge` if the product of `shape` overflows
    /// `usize`, or a shape error if `data.len()` differs from that product.
    pub fn from_vec(shape: [usize; 3], data: Vec<T>) -> Result<Self> {
        let expected = checked_len(shape)?;
        if data.len() != expected {
            return Err(Error::shape(vec![expected], vec![data.len()]));
        }
        Ok(Self { data, shape })
    }

    /// Builds an array by evaluating `f(i, j, k)` once per index, in row-major
    /// order (`k` varies fastest).
    pub fn from_fn(shape: [usize; 3], f: impl FnMut(usize, usize, usize) -> T) -> Self {
        let mut data = Vec::with_capacity(shape.iter().product());
        Self::push_row_major(&mut data, shape, f);
        Self { data, shape }
    }

    /// Like [`Array3::from_fn`], but reports size overflow and allocation failure
    /// as errors.
    ///
    /// # Errors
    ///
    /// Returns `Error::DimensionTooLarge` if the element count overflows `usize`,
    /// and `Error::AllocationFailed` if the buffer cannot be reserved.
    pub fn try_from_fn(
        shape: [usize; 3],
        f: impl FnMut(usize, usize, usize) -> T,
    ) -> Result<Self> {
        let len = checked_len(shape)?;
        let mut data = Vec::new();
        data.try_reserve_exact(len)
            .map_err(|_| Error::AllocationFailed)?;
        Self::push_row_major(&mut data, shape, f);
        Ok(Self { data, shape })
    }

    /// Appends `f(i, j, k)` to `data` for every index of `shape`, last index fastest.
    fn push_row_major(
        data: &mut Vec<T>,
        shape: [usize; 3],
        mut f: impl FnMut(usize, usize, usize) -> T,
    ) {
        for i in 0..shape[0] {
            for j in 0..shape[1] {
                for k in 0..shape[2] {
                    data.push(f(i, j, k));
                }
            }
        }
    }

    /// Extents along axes 0, 1 and 2.
    pub fn shape(&self) -> [usize; 3] {
        self.shape
    }

    /// Element strides of the row-major layout: `[shape[1] * shape[2], shape[2], 1]`.
    pub fn strides(&self) -> [isize; 3] {
        [
            (self.shape[1] * self.shape[2]) as isize,
            self.shape[2] as isize,
            1,
        ]
    }

    /// Total number of elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Whether the array holds no elements (some extent is zero).
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// The underlying row-major storage.
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// The underlying row-major storage, mutably.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Borrows the whole array as an [`ArrayView3`].
    pub fn view(&self) -> ArrayView3<'_, T> {
        ArrayView3::from_raw_parts(&self.data, self.shape, self.strides(), 0)
    }

    /// Mutably borrows the whole array as an [`ArrayViewMut3`].
    pub fn view_mut(&mut self) -> ArrayViewMut3<'_, T> {
        // Compute the strides before `self.data` is mutably borrowed.
        let strides = self.strides();
        ArrayViewMut3::from_raw_parts(&mut self.data, self.shape, strides, 0)
    }

    /// Returns element `(i, j, k)`, or `None` if any index is out of bounds.
    pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&T> {
        self.in_bounds(i, j, k)
            .then(|| &self.data[self.linear_index(i, j, k)])
    }

    /// Returns element `(i, j, k)` mutably, or `None` if any index is out of bounds.
    pub fn get_mut(&mut self, i: usize, j: usize, k: usize) -> Option<&mut T> {
        if !self.in_bounds(i, j, k) {
            return None;
        }
        let idx = self.linear_index(i, j, k);
        Some(&mut self.data[idx])
    }

    /// Returns the 2-D slice at position `index` along `axis`.
    ///
    /// Delegates to `ArrayView3::matrix_at`; errors are whatever that method reports.
    pub fn matrix_at(&self, axis: Axis3, index: usize) -> Result<ArrayView2<'_, T>> {
        self.view().matrix_at(axis.index(), index)
    }

    /// Calls `f(index, matrix)` for every 2-D slice along `axis`.
    ///
    /// Delegates to `ArrayView3::for_each_matrix`.
    pub fn for_each_matrix(
        &self,
        axis: Axis3,
        f: impl FnMut(usize, ArrayView2<'_, T>) -> Result<()>,
    ) -> Result<()> {
        self.view().for_each_matrix(axis.index(), f)
    }

    /// Mutably borrows the 2-D slice at position `index` along `axis`.
    ///
    /// The returned view spans the two remaining axes in ascending order: its rows
    /// follow the lower-numbered remaining axis and its columns the higher one.
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexOutOfBounds` if `index >= shape[axis]`, and propagates
    /// any error from `ArrayViewMut2::new`.
    pub fn matrix_at_mut(&mut self, axis: Axis3, index: usize) -> Result<ArrayViewMut2<'_, T>> {
        let (row_axis, col_axis) = match axis {
            Axis3::Axis0 => (1, 2),
            Axis3::Axis1 => (0, 2),
            Axis3::Axis2 => (0, 1),
        };
        let axis = axis.index();
        if index >= self.shape[axis] {
            return Err(Error::IndexOutOfBounds);
        }
        let strides = self.strides();
        ArrayViewMut2::new(
            &mut self.data,
            [self.shape[row_axis], self.shape[col_axis]],
            [strides[row_axis], strides[col_axis]],
            index as isize * strides[axis],
        )
    }

    /// Calls `f(index, matrix)` for every 2-D slice along `axis`, in increasing
    /// `index` order, stopping at (and returning) the first error.
    pub fn for_each_matrix_mut(
        &mut self,
        axis: Axis3,
        mut f: impl FnMut(usize, ArrayViewMut2<'_, T>) -> Result<()>,
    ) -> Result<()> {
        // `Axis3` only has three variants, so no out-of-range axis can reach here.
        for index in 0..self.shape[axis.index()] {
            f(index, self.matrix_at_mut(axis, index)?)?;
        }
        Ok(())
    }

    /// Whether `(i, j, k)` addresses an element of this array.
    fn in_bounds(&self, i: usize, j: usize, k: usize) -> bool {
        i < self.shape[0] && j < self.shape[1] && k < self.shape[2]
    }

    /// Row-major offset of `(i, j, k)` into `data`; callers check bounds first.
    fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
        (i * self.shape[1] + j) * self.shape[2] + k
    }
}
impl<T: Clone> Array3<T> {
pub fn filled(shape: [usize; 3], value: T) -> Self {
Self {
data: vec![value; shape.iter().product()],
shape,
}
}
pub fn try_filled(shape: [usize; 3], value: T) -> Result<Self> {
let len = checked_len(shape)?;
let mut data = Vec::new();
data.try_reserve_exact(len)
.map_err(|_| Error::AllocationFailed)?;
data.resize(len, value);
Ok(Self { data, shape })
}
pub fn clone_contiguous(view: ArrayView3<'_, T>) -> Self {
Self::from_fn(view.shape(), |i, j, k| view[(i, j, k)].clone())
}
pub fn unfold(&self, axis: Axis3) -> Array2<T> {
unfold_view(self.view(), axis)
}
pub fn fold(axis: Axis3, shape: [usize; 3], matrix: ArrayView2<'_, T>) -> Result<Self> {
fold_view(axis, shape, matrix)
}
}
/// Unfolds (matricizes) a 3-D view into a 2-D array along `axis`.
///
/// The chosen axis becomes the rows. The other two axes are flattened, in
/// row-major order, into the columns. For a view of shape `[d0, d1, d2]`:
///
/// - `Axis0`: result shape `[d0, d1 * d2]`, element `(i, j, k)` goes to `(i, j * d2 + k)`;
/// - `Axis1`: result shape `[d1, d0 * d2]`, element `(i, j, k)` goes to `(j, i * d2 + k)`;
/// - `Axis2`: result shape `[d2, d0 * d1]`, element `(i, j, k)` goes to `(k, i * d1 + j)`.
///
/// [`fold_view`] performs the inverse mapping.
pub fn unfold_view<T: Clone>(a: ArrayView3<'_, T>, axis: Axis3) -> Array2<T> {
    let [d0, d1, d2] = a.shape();
    match axis {
        Axis3::Axis0 => Array2::from_fn([d0, d1 * d2], |i, col| {
            let (j, k) = (col / d2, col % d2);
            a[(i, j, k)].clone()
        }),
        Axis3::Axis1 => Array2::from_fn([d1, d0 * d2], |j, col| {
            let (i, k) = (col / d2, col % d2);
            a[(i, j, k)].clone()
        }),
        Axis3::Axis2 => Array2::from_fn([d2, d0 * d1], |k, col| {
            let (i, j) = (col / d1, col % d1);
            a[(i, j, k)].clone()
        }),
    }
}
pub fn fold_view<T: Clone>(
axis: Axis3,
shape: [usize; 3],
matrix: ArrayView2<'_, T>,
) -> Result<Array3<T>> {
let expected = match axis {
Axis3::Axis0 => [shape[0], shape[1] * shape[2]],
Axis3::Axis1 => [shape[1], shape[0] * shape[2]],
Axis3::Axis2 => [shape[2], shape[0] * shape[1]],
};
if matrix.shape() != expected {
return Err(Error::shape(expected, matrix.shape()));
}
Ok(Array3::from_fn(shape, |i, j, k| match axis {
Axis3::Axis0 => matrix[(i, j * shape[2] + k)].clone(),
Axis3::Axis1 => matrix[(j, i * shape[2] + k)].clone(),
Axis3::Axis2 => matrix[(k, i * shape[1] + j)].clone(),
}))
}
impl<T: Float> Array3<T> {
    /// Creates an array of `shape` filled with `T::zero()`.
    pub fn zeros(shape: [usize; 3]) -> Self {
        Self::filled(shape, T::zero())
    }

    /// Fallible counterpart of [`Array3::zeros`]; errors as [`Array3::try_filled`].
    pub fn try_zeros(shape: [usize; 3]) -> Result<Self> {
        Self::try_filled(shape, T::zero())
    }

    /// Creates an array of `shape` filled with `T::one()`.
    pub fn ones(shape: [usize; 3]) -> Self {
        Self::filled(shape, T::one())
    }

    /// Fallible counterpart of [`Array3::ones`]; errors as [`Array3::try_filled`].
    pub fn try_ones(shape: [usize; 3]) -> Result<Self> {
        Self::try_filled(shape, T::one())
    }

    /// Creates a zero-filled array with the same shape as `self`.
    pub fn zeros_like(&self) -> Self {
        Self::zeros(self.shape)
    }

    /// Multiplies every element by `alpha` in place.
    pub fn scale(&mut self, alpha: T) {
        for value in &mut self.data {
            *value *= alpha;
        }
    }

    /// Returns a copy of `self` with every element multiplied by `alpha`.
    pub fn scaled(&self, alpha: T) -> Self {
        // `data` is already row-major, so mapping it directly yields the same layout
        // as rebuilding via `(i, j, k)` indexing, without per-element bounds checks.
        Self {
            data: self.data.iter().map(|&value| value * alpha).collect(),
            shape: self.shape,
        }
    }

    /// Writes `self * alpha` element-wise into `out`.
    ///
    /// # Errors
    ///
    /// Returns a shape error if `out` does not have the same shape as `self`.
    pub fn scaled_into(&self, alpha: T, mut out: ArrayViewMut3<'_, T>) -> Result<()> {
        if self.shape != out.shape() {
            return Err(Error::shape(self.shape, out.shape()));
        }
        for i in 0..self.shape[0] {
            for j in 0..self.shape[1] {
                for k in 0..self.shape[2] {
                    out[(i, j, k)] = self[(i, j, k)] * alpha;
                }
            }
        }
        Ok(())
    }

    /// Element-wise `self + other`.
    ///
    /// # Errors
    ///
    /// Returns a shape error if the shapes differ.
    pub fn add(&self, other: ArrayView3<'_, T>) -> Result<Self> {
        self.zip_map(other, |left, right| left + right)
    }

    /// Writes element-wise `self + other` into `out`; errors as [`Array3::zip_map_into`].
    pub fn add_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
        self.zip_map_into(other, out, |left, right| left + right)
    }

    /// Element-wise `self - other`; errors as [`Array3::zip_map`].
    pub fn sub(&self, other: ArrayView3<'_, T>) -> Result<Self> {
        self.zip_map(other, |left, right| left - right)
    }

    /// Writes element-wise `self - other` into `out`; errors as [`Array3::zip_map_into`].
    pub fn sub_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
        self.zip_map_into(other, out, |left, right| left - right)
    }

    /// Element-wise `self * other`; errors as [`Array3::zip_map`].
    pub fn mul(&self, other: ArrayView3<'_, T>) -> Result<Self> {
        self.zip_map(other, |left, right| left * right)
    }

    /// Writes element-wise `self * other` into `out`; errors as [`Array3::zip_map_into`].
    pub fn mul_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
        self.zip_map_into(other, out, |left, right| left * right)
    }

    /// Hadamard (element-wise) product; alias of [`Array3::mul`].
    pub fn hadamard(&self, other: ArrayView3<'_, T>) -> Result<Self> {
        self.mul(other)
    }

    /// Hadamard product written into `out`; alias of [`Array3::mul_into`].
    pub fn hadamard_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
        self.mul_into(other, out)
    }

    /// Element-wise `self / other`; errors as [`Array3::zip_map`].
    pub fn div(&self, other: ArrayView3<'_, T>) -> Result<Self> {
        self.zip_map(other, |left, right| left / right)
    }

    /// Writes element-wise `self / other` into `out`; errors as [`Array3::zip_map_into`].
    pub fn div_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
        self.zip_map_into(other, out, |left, right| left / right)
    }

    /// Returns `self + alpha * x` as a new array; errors as [`Array3::zip_map`].
    pub fn axpy_result(&self, alpha: T, x: ArrayView3<'_, T>) -> Result<Self> {
        self.zip_map(x, |left, right| left + alpha * right)
    }

    /// Writes `self + alpha * x` into `out`; errors as [`Array3::zip_map_into`].
    pub fn axpy_into(
        &self,
        alpha: T,
        x: ArrayView3<'_, T>,
        out: ArrayViewMut3<'_, T>,
    ) -> Result<()> {
        self.zip_map_into(x, out, |left, right| left + alpha * right)
    }

    /// In-place `self += other`; errors as [`Array3::zip_map_inplace`].
    pub fn add_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
        self.zip_map_inplace(other, |left, right| left + right)
    }

    /// In-place `self -= other`; errors as [`Array3::zip_map_inplace`].
    pub fn sub_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
        self.zip_map_inplace(other, |left, right| left - right)
    }

    /// In-place element-wise `self *= other`; errors as [`Array3::zip_map_inplace`].
    pub fn mul_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
        self.zip_map_inplace(other, |left, right| left * right)
    }

    /// In-place element-wise `self /= other`; errors as [`Array3::zip_map_inplace`].
    pub fn div_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
        self.zip_map_inplace(other, |left, right| left / right)
    }

    /// In-place `self = self + alpha * x`, element by element.
    ///
    /// # Errors
    ///
    /// Returns a shape error if `x` does not have the same shape as `self`; `self`
    /// is left untouched in that case.
    pub fn axpy(&mut self, alpha: T, x: ArrayView3<'_, T>) -> Result<()> {
        // Same shape check and traversal as the other in-place operations.
        self.zip_map_inplace(x, |left, right| left + alpha * right)
    }

    /// Returns a new array whose element `(i, j, k)` is `f(self[(i, j, k)], other[(i, j, k)])`.
    ///
    /// `f` is called in row-major index order.
    ///
    /// # Errors
    ///
    /// Returns a shape error if `other` does not have the same shape as `self`.
    pub fn zip_map(&self, other: ArrayView3<'_, T>, mut f: impl FnMut(T, T) -> T) -> Result<Self> {
        if self.shape != other.shape() {
            return Err(Error::shape(self.shape, other.shape()));
        }
        Ok(Self::from_fn(self.shape, |i, j, k| {
            f(self[(i, j, k)], other[(i, j, k)])
        }))
    }

    /// Writes `f(self[(i, j, k)], other[(i, j, k)])` into `out[(i, j, k)]` for every index.
    ///
    /// # Errors
    ///
    /// Returns a shape error if `other` or `out` does not have the same shape as
    /// `self`; nothing is written in that case.
    pub fn zip_map_into(
        &self,
        other: ArrayView3<'_, T>,
        mut out: ArrayViewMut3<'_, T>,
        mut f: impl FnMut(T, T) -> T,
    ) -> Result<()> {
        if self.shape != other.shape() {
            return Err(Error::shape(self.shape, other.shape()));
        }
        if self.shape != out.shape() {
            return Err(Error::shape(self.shape, out.shape()));
        }
        for i in 0..self.shape[0] {
            for j in 0..self.shape[1] {
                for k in 0..self.shape[2] {
                    out[(i, j, k)] = f(self[(i, j, k)], other[(i, j, k)]);
                }
            }
        }
        Ok(())
    }

    /// Replaces every element `v` with `f(v)`, in storage (row-major) order.
    pub fn map_inplace(&mut self, mut f: impl FnMut(T) -> T) {
        for value in &mut self.data {
            *value = f(*value);
        }
    }

    /// Replaces every element with `f(self[(i, j, k)], other[(i, j, k)])`.
    ///
    /// # Errors
    ///
    /// Returns a shape error if `other` does not have the same shape as `self`;
    /// `self` is left untouched in that case.
    pub fn zip_map_inplace(
        &mut self,
        other: ArrayView3<'_, T>,
        mut f: impl FnMut(T, T) -> T,
    ) -> Result<()> {
        if self.shape != other.shape() {
            return Err(Error::shape(self.shape, other.shape()));
        }
        for i in 0..self.shape[0] {
            for j in 0..self.shape[1] {
                for k in 0..self.shape[2] {
                    self[(i, j, k)] = f(self[(i, j, k)], other[(i, j, k)]);
                }
            }
        }
        Ok(())
    }

    /// Seeds a fresh `SmallRng` with `seed` and overwrites every element, in
    /// storage order, with successive `uniform()` draws.
    pub fn fill_uniform(&mut self, seed: u64) {
        let mut rng = SmallRng::new(seed);
        for value in &mut self.data {
            *value = rng.uniform();
        }
    }

    /// Seeds a fresh `SmallRng` with `seed` and overwrites every element, in
    /// storage order, with successive `normal()` draws.
    pub fn fill_randn(&mut self, seed: u64) {
        let mut rng = SmallRng::new(seed);
        for value in &mut self.data {
            *value = rng.normal();
        }
    }

    /// Sum of all elements (`T::zero()` for an empty array).
    pub fn sum(&self) -> T {
        self.data.iter().copied().sum()
    }

    /// Arithmetic mean of all elements; returns `T::zero()` for an empty array.
    pub fn mean(&self) -> T {
        if self.is_empty() {
            T::zero()
        } else {
            self.sum() / T::from_f64(self.len() as f64)
        }
    }

    /// Frobenius norm: the square root of the sum of squared elements.
    ///
    /// The squares are summed directly, without rescaling, so extremely large or
    /// small magnitudes can overflow or underflow the intermediate sum.
    pub fn norm_frobenius(&self) -> T {
        self.data
            .iter()
            .copied()
            .map(|value| value * value)
            .sum::<T>()
            .sqrt()
    }

    /// Largest absolute value among the elements, or `T::zero()` for an empty array.
    ///
    /// An element replaces the running best only if it compares strictly greater.
    /// Assuming `T` orders like IEEE-754 floats, NaN entries never do, so they are
    /// effectively skipped.
    pub fn max_abs(&self) -> T {
        self.data
            .iter()
            .copied()
            .map(T::abs)
            .fold(
                T::zero(),
                |best, value| if value > best { value } else { best },
            )
    }

    /// Inner product: the sum of element-wise products of `self` and `other`.
    ///
    /// # Errors
    ///
    /// Returns a shape error if `other` does not have the same shape as `self`.
    pub fn dot(&self, other: ArrayView3<'_, T>) -> Result<T> {
        if self.shape != other.shape() {
            return Err(Error::shape(self.shape, other.shape()));
        }
        let mut sum = T::zero();
        for i in 0..self.shape[0] {
            for j in 0..self.shape[1] {
                for k in 0..self.shape[2] {
                    sum += self[(i, j, k)] * other[(i, j, k)];
                }
            }
        }
        Ok(sum)
    }
}
/// `array[(i, j, k)]` read access.
impl<T> Index<(usize, usize, usize)> for Array3<T> {
    type Output = T;

    /// Returns element `(i, j, k)`.
    ///
    /// # Panics
    ///
    /// Panics with "array index out of bounds" if any index exceeds its extent.
    fn index(&self, (i, j, k): (usize, usize, usize)) -> &T {
        let element = self.get(i, j, k);
        element.expect("array index out of bounds")
    }
}
/// Number of elements described by `shape`, computed as `shape[0] * shape[1] * shape[2]`
/// from left to right with overflow checks.
///
/// Returns `Error::DimensionTooLarge` as soon as a partial product overflows `usize`.
fn checked_len(shape: [usize; 3]) -> Result<usize> {
    shape[0]
        .checked_mul(shape[1])
        .and_then(|partial| partial.checked_mul(shape[2]))
        .ok_or(Error::DimensionTooLarge)
}
/// `array[(i, j, k)]` write access.
impl<T> IndexMut<(usize, usize, usize)> for Array3<T> {
    /// Returns element `(i, j, k)` mutably.
    ///
    /// # Panics
    ///
    /// Panics with "array index out of bounds" if any index exceeds its extent.
    fn index_mut(&mut self, (i, j, k): (usize, usize, usize)) -> &mut T {
        let element = self.get_mut(i, j, k);
        element.expect("array index out of bounds")
    }
}