use core::ops::{Index, IndexMut};
use crate::error::{Error, Result};
use crate::numeric::Float;
use crate::rand::SmallRng;
use crate::view2::{ArrayView2, ArrayViewMut2};
/// An owned, dense two-dimensional array stored contiguously in row-major order.
///
/// Element `(row, col)` lives at `data[row * cols + col]`, so the row stride is `cols` and the
/// column stride is `1`. Constructors are expected to maintain `data.len() == rows * cols`.
#[derive(Debug, Clone, PartialEq)]
pub struct Array2<T> {
    /// Backing storage: row 0, then row 1, and so on.
    data: Vec<T>,
    /// Number of rows (`shape()[0]`).
    rows: usize,
    /// Number of columns (`shape()[1]`); also the distance between consecutive rows in `data`.
    cols: usize,
}
impl<T> Array2<T> {
    /// Wraps `data` as a `shape[0] x shape[1]` array interpreted in row-major order.
    ///
    /// # Errors
    ///
    /// Returns `Error::DimensionTooLarge` if `shape[0] * shape[1]` overflows `usize`, and a
    /// shape error if `data.len()` differs from that product.
    pub fn from_vec(shape: [usize; 2], data: Vec<T>) -> Result<Self> {
        let expected = checked_len(shape)?;
        if data.len() != expected {
            return Err(Error::shape(vec![expected], vec![data.len()]));
        }
        Ok(Self {
            data,
            rows: shape[0],
            cols: shape[1],
        })
    }

    /// Builds an array by calling `f(row, col)` once per position, in row-major order.
    ///
    /// # Panics
    ///
    /// Panics if `shape[0] * shape[1]` overflows `usize`. Use [`Array2::try_from_fn`] to get an
    /// error instead.
    pub fn from_fn(shape: [usize; 2], mut f: impl FnMut(usize, usize) -> T) -> Self {
        // Checked multiply: with a plain `*`, release builds wrap silently. The capacity would
        // then be wrong and the loops below would try to produce the true, unrepresentable
        // element count.
        let len = shape[0]
            .checked_mul(shape[1])
            .expect("Array2::from_fn: shape[0] * shape[1] overflows usize");
        let mut data = Vec::with_capacity(len);
        for i in 0..shape[0] {
            for j in 0..shape[1] {
                data.push(f(i, j));
            }
        }
        Self {
            data,
            rows: shape[0],
            cols: shape[1],
        }
    }

    /// Fallible counterpart of [`Array2::from_fn`].
    ///
    /// # Errors
    ///
    /// Returns `Error::DimensionTooLarge` if `shape[0] * shape[1]` overflows `usize`, and
    /// `Error::AllocationFailed` if the storage cannot be reserved.
    pub fn try_from_fn(shape: [usize; 2], mut f: impl FnMut(usize, usize) -> T) -> Result<Self> {
        let len = checked_len(shape)?;
        let mut data = Vec::new();
        data.try_reserve_exact(len)
            .map_err(|_| Error::AllocationFailed)?;
        // Capacity is already reserved, so these pushes never reallocate.
        for i in 0..shape[0] {
            for j in 0..shape[1] {
                data.push(f(i, j));
            }
        }
        Ok(Self {
            data,
            rows: shape[0],
            cols: shape[1],
        })
    }

    /// Returns `[rows, cols]`.
    #[inline]
    pub fn shape(&self) -> [usize; 2] {
        [self.rows, self.cols]
    }

    /// Returns the number of rows.
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Returns the number of columns.
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Returns the element strides `[row_stride, col_stride]` of the row-major layout.
    #[inline]
    pub fn strides(&self) -> [isize; 2] {
        [self.cols as isize, 1]
    }

    /// Returns the distance, in elements, between consecutive rows (equal to `cols`).
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.cols as isize
    }

    /// Returns the distance, in elements, between consecutive columns (always `1`).
    #[inline]
    pub fn col_stride(&self) -> isize {
        1
    }

    /// Returns the leading dimension of the row-major storage (equal to `cols`).
    #[inline]
    pub fn leading_dimension(&self) -> isize {
        self.cols as isize
    }

    /// Returns the total number of stored elements.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the array stores no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Always `true`: an owned `Array2` is stored contiguously.
    #[inline]
    pub fn is_contiguous(&self) -> bool {
        true
    }

    /// Borrows the row-major storage as a flat slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Mutably borrows the row-major storage as a flat slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Consumes the array and returns its row-major storage.
    pub fn into_vec(self) -> Vec<T> {
        self.data
    }

    /// Returns an immutable view of the whole array.
    pub fn view(&self) -> ArrayView2<'_, T> {
        ArrayView2::from_raw_parts(&self.data, self.shape(), self.strides(), 0)
    }

    /// Returns a mutable view of the whole array.
    pub fn view_mut(&mut self) -> ArrayViewMut2<'_, T> {
        // Read shape/strides before taking `&mut self.data`. Calling `self.shape()` inside the
        // argument list would conflict with that mutable borrow.
        let shape = self.shape();
        let strides = self.strides();
        ArrayViewMut2::from_raw_parts(&mut self.data, shape, strides, 0)
    }

    /// Returns an immutable transposed view (delegates to the view's `transpose`).
    pub fn transpose_view(&self) -> ArrayView2<'_, T> {
        self.view().transpose()
    }

    /// Returns a reference to element `(row, col)`, or `None` if either index is out of range.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
        (row < self.rows && col < self.cols).then(|| &self.data[row * self.cols + col])
    }

    /// Returns a mutable reference to element `(row, col)`, or `None` if out of range.
    #[inline]
    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
        (row < self.rows && col < self.cols).then(|| &mut self.data[row * self.cols + col])
    }

    /// Returns a view of row `row` (delegates to `ArrayView2::row`).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `ArrayView2::row`.
    pub fn row(&self, row: usize) -> Result<ArrayView2<'_, T>> {
        self.view().row(row)
    }

    /// Borrows row `row` as a contiguous slice of length `cols`.
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexOutOfBounds` if `row >= rows`.
    pub fn row_slice(&self, row: usize) -> Result<&[T]> {
        if row >= self.rows {
            return Err(Error::IndexOutOfBounds);
        }
        let start = row * self.cols;
        Ok(&self.data[start..start + self.cols])
    }

    /// Returns a mutable `1 x cols` view of row `row`.
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexOutOfBounds` if `row >= rows`.
    pub fn row_mut(&mut self, row: usize) -> Result<ArrayViewMut2<'_, T>> {
        if row >= self.rows {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(ArrayViewMut2::from_raw_parts(
            &mut self.data,
            [1, self.cols],
            [self.cols as isize, 1],
            (row * self.cols) as isize,
        ))
    }

    /// Mutably borrows row `row` as a contiguous slice of length `cols`.
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexOutOfBounds` if `row >= rows`.
    pub fn row_slice_mut(&mut self, row: usize) -> Result<&mut [T]> {
        if row >= self.rows {
            return Err(Error::IndexOutOfBounds);
        }
        let start = row * self.cols;
        Ok(&mut self.data[start..start + self.cols])
    }

    /// Returns a view of column `col` (delegates to `ArrayView2::col`).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `ArrayView2::col`.
    pub fn col(&self, col: usize) -> Result<ArrayView2<'_, T>> {
        self.view().col(col)
    }

    /// Returns a mutable `rows x 1` view of column `col` (row stride `cols`, offset `col`).
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexOutOfBounds` if `col >= cols`.
    pub fn col_mut(&mut self, col: usize) -> Result<ArrayViewMut2<'_, T>> {
        if col >= self.cols {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(ArrayViewMut2::from_raw_parts(
            &mut self.data,
            [self.rows, 1],
            [self.cols as isize, 1],
            col as isize,
        ))
    }

    /// Returns a view of rows `start..end` (delegates to `ArrayView2::rows_range`).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `ArrayView2::rows_range`.
    pub fn rows_range(&self, start: usize, end: usize) -> Result<ArrayView2<'_, T>> {
        self.view().rows_range(start, end)
    }

    /// Returns a mutable view of rows `start..end`.
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexOutOfBounds` if `start > end` or `end > rows`.
    pub fn rows_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
        if start > end || end > self.rows {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(ArrayViewMut2::from_raw_parts(
            &mut self.data,
            [end - start, self.cols],
            [self.cols as isize, 1],
            (start * self.cols) as isize,
        ))
    }

    /// Returns a view of columns `start..end` (delegates to `ArrayView2::cols_range`).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `ArrayView2::cols_range`.
    pub fn cols_range(&self, start: usize, end: usize) -> Result<ArrayView2<'_, T>> {
        self.view().cols_range(start, end)
    }

    /// Returns a mutable view of columns `start..end` (row stride stays `cols`).
    ///
    /// # Errors
    ///
    /// Returns `Error::IndexOutOfBounds` if `start > end` or `end > cols`.
    pub fn cols_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
        if start > end || end > self.cols {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(ArrayViewMut2::from_raw_parts(
            &mut self.data,
            [self.rows, end - start],
            [self.cols as isize, 1],
            start as isize,
        ))
    }

    /// Reinterprets the same row-major storage with a new shape, without moving elements.
    ///
    /// # Errors
    ///
    /// Returns `Error::DimensionTooLarge` if `shape[0] * shape[1]` overflows `usize`, and a
    /// shape error if that product differs from the current element count.
    pub fn reshape(mut self, shape: [usize; 2]) -> Result<Self> {
        let expected = checked_len(shape)?;
        if expected != self.data.len() {
            return Err(Error::shape(vec![self.data.len()], vec![expected]));
        }
        self.rows = shape[0];
        self.cols = shape[1];
        Ok(self)
    }
}
impl<T: Clone> Array2<T> {
    /// Returns a `shape[0] x shape[1]` array with every element set to a clone of `value`.
    ///
    /// # Panics
    ///
    /// Panics if `shape[0] * shape[1]` overflows `usize`. Use [`Array2::try_filled`] to get an
    /// error instead.
    pub fn filled(shape: [usize; 2], value: T) -> Self {
        // Checked multiply: with a plain `*`, release builds wrap on overflow. That would store
        // a short `data` vector that no longer matches `rows * cols`.
        let len = shape[0]
            .checked_mul(shape[1])
            .expect("Array2::filled: shape[0] * shape[1] overflows usize");
        Self {
            data: vec![value; len],
            rows: shape[0],
            cols: shape[1],
        }
    }

    /// Fallible counterpart of [`Array2::filled`].
    ///
    /// # Errors
    ///
    /// Returns `Error::DimensionTooLarge` if `shape[0] * shape[1]` overflows `usize`, and
    /// `Error::AllocationFailed` if the storage cannot be reserved.
    pub fn try_filled(shape: [usize; 2], value: T) -> Result<Self> {
        let len = checked_len(shape)?;
        let mut data = Vec::new();
        data.try_reserve_exact(len)
            .map_err(|_| Error::AllocationFailed)?;
        data.resize(len, value);
        Ok(Self {
            data,
            rows: shape[0],
            cols: shape[1],
        })
    }

    /// Copies an arbitrary (possibly strided) view into a new contiguous row-major array.
    pub fn clone_contiguous(view: ArrayView2<'_, T>) -> Self {
        Self::from_fn(view.shape(), |i, j| view[(i, j)].clone())
    }

    /// Returns a row-major copy. Storage is already row-major, so this is a plain clone.
    pub fn to_row_major(&self) -> Self {
        self.clone()
    }

    /// Returns the elements in column-major order (delegates to the view's `to_col_major_vec`).
    pub fn to_col_major_vec(&self) -> Vec<T> {
        self.view().to_col_major_vec()
    }

    /// Overwrites every element of `self` with a clone of the matching element of `other`.
    ///
    /// # Errors
    ///
    /// Returns a shape error, and leaves `self` untouched, if the shapes differ.
    pub fn copy_from_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
        if self.shape() != other.shape() {
            return Err(Error::shape(self.shape(), other.shape()));
        }
        for i in 0..self.rows {
            for j in 0..self.cols {
                self[(i, j)] = other[(i, j)].clone();
            }
        }
        Ok(())
    }
}
impl<T: Float> Array2<T> {
pub fn zeros(shape: [usize; 2]) -> Self {
Self::filled(shape, T::zero())
}
pub fn try_zeros(shape: [usize; 2]) -> Result<Self> {
Self::try_filled(shape, T::zero())
}
pub fn ones(shape: [usize; 2]) -> Self {
Self::filled(shape, T::one())
}
pub fn try_ones(shape: [usize; 2]) -> Result<Self> {
Self::try_filled(shape, T::one())
}
pub fn zeros_like(&self) -> Self {
Self::zeros(self.shape())
}
pub fn scale(&mut self, alpha: T) {
for value in &mut self.data {
*value *= alpha;
}
}
pub fn scaled(&self, alpha: T) -> Self {
Self::from_fn(self.shape(), |i, j| self[(i, j)] * alpha)
}
pub fn scaled_into(&self, alpha: T, mut out: ArrayViewMut2<'_, T>) -> Result<()> {
if self.shape() != out.shape() {
return Err(Error::shape(self.shape(), out.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
out[(i, j)] = self[(i, j)] * alpha;
}
}
Ok(())
}
pub fn add(&self, other: ArrayView2<'_, T>) -> Result<Self> {
self.zip_map(other, |left, right| left + right)
}
pub fn add_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
self.zip_map_into(other, out, |left, right| left + right)
}
pub fn sub(&self, other: ArrayView2<'_, T>) -> Result<Self> {
self.zip_map(other, |left, right| left - right)
}
pub fn sub_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
self.zip_map_into(other, out, |left, right| left - right)
}
pub fn mul(&self, other: ArrayView2<'_, T>) -> Result<Self> {
self.zip_map(other, |left, right| left * right)
}
pub fn mul_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
self.zip_map_into(other, out, |left, right| left * right)
}
pub fn hadamard(&self, other: ArrayView2<'_, T>) -> Result<Self> {
self.mul(other)
}
pub fn hadamard_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
self.mul_into(other, out)
}
pub fn div(&self, other: ArrayView2<'_, T>) -> Result<Self> {
self.zip_map(other, |left, right| left / right)
}
pub fn div_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
self.zip_map_into(other, out, |left, right| left / right)
}
pub fn axpy_result(&self, alpha: T, x: ArrayView2<'_, T>) -> Result<Self> {
self.zip_map(x, |left, right| left + alpha * right)
}
pub fn axpy_into(
&self,
alpha: T,
x: ArrayView2<'_, T>,
out: ArrayViewMut2<'_, T>,
) -> Result<()> {
self.zip_map_into(x, out, |left, right| left + alpha * right)
}
pub fn matmul(&self, other: ArrayView2<'_, T>) -> Result<Self> {
crate::linalg::matmul(self.view(), other)
}
pub fn matmul_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
crate::linalg::gemm(T::one(), self.view(), false, other, false, T::zero(), out)
}
pub fn add_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
self[(i, j)] += other[(i, j)];
}
}
Ok(())
}
pub fn sub_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
self[(i, j)] -= other[(i, j)];
}
}
Ok(())
}
pub fn mul_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
self[(i, j)] *= other[(i, j)];
}
}
Ok(())
}
pub fn div_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
self[(i, j)] /= other[(i, j)];
}
}
Ok(())
}
pub fn axpy(&mut self, alpha: T, x: ArrayView2<'_, T>) -> Result<()> {
if self.shape() != x.shape() {
return Err(Error::shape(self.shape(), x.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
self[(i, j)] += alpha * x[(i, j)];
}
}
Ok(())
}
pub fn map_inplace(&mut self, mut f: impl FnMut(T) -> T) {
for value in &mut self.data {
*value = f(*value);
}
}
pub fn zip_map_inplace(
&mut self,
other: ArrayView2<'_, T>,
mut f: impl FnMut(T, T) -> T,
) -> Result<()> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
self[(i, j)] = f(self[(i, j)], other[(i, j)]);
}
}
Ok(())
}
pub fn zip_map(&self, other: ArrayView2<'_, T>, mut f: impl FnMut(T, T) -> T) -> Result<Self> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
Ok(Self::from_fn(self.shape(), |i, j| {
f(self[(i, j)], other[(i, j)])
}))
}
pub fn zip_map_into(
&self,
other: ArrayView2<'_, T>,
mut out: ArrayViewMut2<'_, T>,
mut f: impl FnMut(T, T) -> T,
) -> Result<()> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
if self.shape() != out.shape() {
return Err(Error::shape(self.shape(), out.shape()));
}
for i in 0..self.rows {
for j in 0..self.cols {
out[(i, j)] = f(self[(i, j)], other[(i, j)]);
}
}
Ok(())
}
pub fn fill_uniform(&mut self, seed: u64) {
let mut rng = SmallRng::new(seed);
for value in &mut self.data {
*value = rng.uniform();
}
}
pub fn fill_randn(&mut self, seed: u64) {
let mut rng = SmallRng::new(seed);
for value in &mut self.data {
*value = rng.normal();
}
}
pub fn sum(&self) -> T {
self.data.iter().copied().sum()
}
pub fn mean(&self) -> T {
if self.is_empty() {
T::zero()
} else {
self.sum() / T::from_f64(self.len() as f64)
}
}
pub fn norm_frobenius(&self) -> T {
self.data
.iter()
.copied()
.map(|value| value * value)
.sum::<T>()
.sqrt()
}
pub fn max_abs(&self) -> T {
self.data
.iter()
.copied()
.map(T::abs)
.fold(
T::zero(),
|best, value| if value > best { value } else { best },
)
}
pub fn dot(&self, other: ArrayView2<'_, T>) -> Result<T> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
let mut sum = T::zero();
for i in 0..self.rows {
for j in 0..self.cols {
sum += self[(i, j)] * other[(i, j)];
}
}
Ok(sum)
}
}
fn checked_len(shape: [usize; 2]) -> Result<usize> {
shape[0]
.checked_mul(shape[1])
.ok_or(Error::DimensionTooLarge)
}
impl<T> Index<(usize, usize)> for Array2<T> {
    type Output = T;

    /// Returns a reference to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics with "array index out of bounds" if `row >= rows` or `col >= cols`.
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        match self.get(row, col) {
            Some(element) => element,
            None => panic!("array index out of bounds"),
        }
    }
}
impl<T> IndexMut<(usize, usize)> for Array2<T> {
    /// Returns a mutable reference to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics with "array index out of bounds" if `row >= rows` or `col >= cols`.
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        match self.get_mut(row, col) {
            Some(element) => element,
            None => panic!("array index out of bounds"),
        }
    }
}