use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, NumCast, One, Zero};
use scirs2_core::ndarray::{
ArcArray, Array as NdArray, ArrayView as NdArrayView, ArrayViewMut as NdArrayViewMut, Axis,
Dimension, IxDyn, ShapeBuilder,
};
use std::fmt;
use std::ops::{Add, Div, Index, Mul, Sub};
use std::sync::Arc;
/// An n-dimensional array backed by an ndarray `ArcArray` (reference-counted,
/// copy-on-write storage) with dynamic rank.
///
/// Cloning a `SharedArray` clones the `ArcArray` handle; element-wise
/// operations and reshapes produce new arrays rather than mutating `self`.
#[derive(Clone)]
pub struct SharedArray<T> {
    // Backing storage; rank is dynamic (`IxDyn`).
    data: ArcArray<T, IxDyn>,
}
impl<T: fmt::Debug + Clone> fmt::Debug for SharedArray<T> {
    /// Renders `SharedArray { shape: [...], data: ... }` for diagnostics,
    /// delegating the element dump to the inner `ArcArray`'s `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shape = self.shape();
        let mut builder = f.debug_struct("SharedArray");
        builder.field("shape", &shape);
        builder.field("data", &self.data);
        builder.finish()
    }
}
impl<T: fmt::Display + Clone> fmt::Display for SharedArray<T> {
    /// Renders a compact summary `SharedArray(shape=[...])`.
    ///
    /// Only the shape is printed; element values are not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dims = self.shape();
        f.write_fmt(format_args!("SharedArray(shape={:?})", dims))
    }
}
impl<T: Clone> SharedArray<T> {
pub fn from_vec(data: Vec<T>) -> Self {
let len = data.len();
let nd_arr = NdArray::from_vec(data)
.into_shape_with_order(IxDyn(&[len]))
.expect("Failed to reshape 1D vector: length mismatch should be impossible");
Self {
data: nd_arr.into_shared(),
}
}
pub fn from_vec_with_shape(data: Vec<T>, shape: &[usize]) -> Result<Self> {
let expected_size: usize = shape.iter().product();
if data.len() != expected_size {
return Err(NumRs2Error::ShapeMismatch {
expected: vec![expected_size],
actual: vec![data.len()],
});
}
let nd_arr = NdArray::from_vec(data)
.into_shape_with_order(IxDyn(shape))
.map_err(|e| NumRs2Error::DimensionMismatch(format!("Failed to reshape: {}", e)))?;
Ok(Self {
data: nd_arr.into_shared(),
})
}
pub fn from_array(arr: Array<T>) -> Self {
Self {
data: arr.array().to_shared(),
}
}
pub fn from_arc_array(data: ArcArray<T, IxDyn>) -> Self {
Self { data }
}
pub fn zeros(shape: &[usize]) -> Self
where
T: Zero,
{
let nd_arr: NdArray<T, IxDyn> = NdArray::zeros(IxDyn(shape));
Self {
data: nd_arr.into_shared(),
}
}
pub fn ones(shape: &[usize]) -> Self
where
T: One,
{
let nd_arr: NdArray<T, IxDyn> = NdArray::ones(IxDyn(shape));
Self {
data: nd_arr.into_shared(),
}
}
pub fn full(shape: &[usize], value: T) -> Self {
let nd_arr: NdArray<T, IxDyn> = NdArray::from_elem(IxDyn(shape), value);
Self {
data: nd_arr.into_shared(),
}
}
pub fn shape(&self) -> Vec<usize> {
self.data.shape().to_vec()
}
pub fn ndim(&self) -> usize {
self.data.ndim()
}
pub fn size(&self) -> usize {
self.data.len()
}
pub fn ref_count(&self) -> usize {
1
}
pub fn is_unique(&self) -> bool {
true
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn strides(&self) -> Vec<isize> {
self.data.strides().to_vec()
}
pub fn get(&self, indices: &[usize]) -> Option<&T> {
self.data.get(IxDyn(indices))
}
pub fn get_mut(&mut self, indices: &[usize]) -> Option<&mut T> {
self.data.get_mut(IxDyn(indices))
}
pub fn set(&mut self, indices: &[usize], value: T) -> Result<()> {
if let Some(elem) = self.data.get_mut(IxDyn(indices)) {
*elem = value;
Ok(())
} else {
Err(NumRs2Error::IndexOutOfBounds(format!(
"Index {:?} out of bounds for shape {:?}",
indices,
self.shape()
)))
}
}
pub fn get_flat(&self, index: usize) -> Result<T> {
if index >= self.size() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Flat index {} out of bounds for array of size {}",
index,
self.size()
)));
}
let shape = self.shape();
let mut indices = Vec::with_capacity(shape.len());
let mut remainder = index;
for i in (0..shape.len()).rev() {
indices.push(remainder % shape[i]);
remainder /= shape[i];
}
indices.reverse();
self.data.get(IxDyn(&indices)).cloned().ok_or_else(|| {
NumRs2Error::IndexOutOfBounds(format!(
"Failed to access element at flat index {}",
index
))
})
}
pub fn to_owned_array(&self) -> Array<T> {
Array::from_ndarray(self.data.to_owned())
}
pub fn to_vec(&self) -> Vec<T> {
self.data.iter().cloned().collect()
}
pub fn as_arc_array(&self) -> &ArcArray<T, IxDyn> {
&self.data
}
pub fn view(&self) -> NdArrayView<'_, T, IxDyn> {
self.data.view()
}
pub fn reshape(&self, new_shape: &[usize]) -> Result<Self> {
let new_size: usize = new_shape.iter().product();
if new_size != self.size() {
return Err(NumRs2Error::ShapeMismatch {
expected: vec![self.size()],
actual: vec![new_size],
});
}
let owned = self.data.to_owned();
let reshaped = owned
.into_shape_with_order(IxDyn(new_shape))
.map_err(|e| NumRs2Error::DimensionMismatch(format!("Reshape failed: {}", e)))?;
Ok(Self {
data: reshaped.into_shared(),
})
}
pub fn flatten(&self) -> Self {
let flat = self.to_vec();
Self::from_vec(flat)
}
pub fn transpose(&self) -> Self {
let transposed = self.data.t().to_owned();
Self {
data: transposed.into_shared(),
}
}
pub fn add(&self, other: &Self) -> Result<Self>
where
T: Add<Output = T> + Copy,
{
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let result = &self.data + &other.data;
Ok(Self {
data: result.into_shared(),
})
}
pub fn sub(&self, other: &Self) -> Result<Self>
where
T: Sub<Output = T> + Copy,
{
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let result = &self.data - &other.data;
Ok(Self {
data: result.into_shared(),
})
}
pub fn mul(&self, other: &Self) -> Result<Self>
where
T: Mul<Output = T> + Copy,
{
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let result = &self.data * &other.data;
Ok(Self {
data: result.into_shared(),
})
}
pub fn div(&self, other: &Self) -> Result<Self>
where
T: Div<Output = T> + Copy,
{
if self.shape() != other.shape() {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape(),
actual: other.shape(),
});
}
let result = &self.data / &other.data;
Ok(Self {
data: result.into_shared(),
})
}
pub fn sum(&self) -> T
where
T: Zero + Add<Output = T> + Copy,
{
self.data.iter().copied().fold(T::zero(), |acc, x| acc + x)
}
pub fn mean(&self) -> Option<T>
where
T: Float + NumCast,
{
if self.is_empty() {
return None;
}
let sum: T = self.data.iter().copied().fold(T::zero(), |acc, x| acc + x);
let count = T::from(self.size())?;
Some(sum / count)
}
pub fn min(&self) -> Option<T>
where
T: PartialOrd + Copy,
{
self.data
.iter()
.copied()
.reduce(|a, b| if a < b { a } else { b })
}
pub fn max(&self) -> Option<T>
where
T: PartialOrd + Copy,
{
self.data
.iter()
.copied()
.reduce(|a, b| if a > b { a } else { b })
}
}
/// A rectangular window onto a [`SharedArray`], described by a per-axis
/// starting `offset` and a per-axis extent `view_shape`.
#[derive(Clone)]
pub struct SharedArrayView<T> {
    // The array being viewed, held by value (no borrow lifetime).
    source: SharedArray<T>,
    // Per-axis start position of the window within `source`.
    offset: Vec<usize>,
    // Per-axis extent of the window.
    view_shape: Vec<usize>,
}
impl<T: Clone> SharedArrayView<T> {
    /// Creates a view covering the whole of `source`.
    pub fn new(source: SharedArray<T>) -> Self {
        let shape = source.shape();
        Self {
            source,
            offset: vec![0; shape.len()],
            view_shape: shape,
        }
    }

    /// Creates a view of the `shape`-sized window of `source` starting at `offset`.
    ///
    /// The window is not validated here. `offset` and `shape` should each have
    /// one entry per axis of `source`, and `offset[k] + shape[k]` should not
    /// exceed the source extent along axis `k`. Positions that cannot be
    /// resolved are reported as missing by [`get`](Self::get).
    pub fn slice(source: SharedArray<T>, offset: Vec<usize>, shape: Vec<usize>) -> Self {
        Self {
            source,
            offset,
            view_shape: shape,
        }
    }

    /// Returns the per-axis extent of the view.
    pub fn shape(&self) -> &[usize] {
        &self.view_shape
    }

    /// Returns the element at `indices`, expressed relative to the view.
    ///
    /// Returns `None` in any of these cases:
    /// - `indices` (or the view's offset) does not have one entry per view axis.
    /// - Any index lies outside the view's extent along its axis.
    /// - The translated position lies outside the source array.
    pub fn get(&self, indices: &[usize]) -> Option<&T> {
        // Zipping mismatched lengths would silently drop axes, so reject them.
        if indices.len() != self.view_shape.len() || self.offset.len() != indices.len() {
            return None;
        }
        let adjusted = indices
            .iter()
            .zip(&self.view_shape)
            .zip(&self.offset)
            .map(|((&i, &extent), &o)| {
                // Bound by the view, not just the source, so a view never
                // exposes elements outside its window.
                if i < extent {
                    i.checked_add(o)
                } else {
                    None
                }
            })
            .collect::<Option<Vec<usize>>>()?;
        self.source.get(&adjusted)
    }

    /// Materialises the view as a [`SharedArray`] with shape `view_shape`.
    ///
    /// A view covering the whole source returns a clone of the source handle.
    /// Otherwise the elements are copied in row-major order. If some window
    /// positions cannot be resolved in the source (a malformed view), the
    /// elements that were found are returned as a 1-D array.
    pub fn to_shared_array(&self) -> SharedArray<T> {
        let covers_source =
            self.offset.iter().all(|&o| o == 0) && self.view_shape == self.source.shape();
        if covers_source {
            return self.source.clone();
        }

        let total: usize = self.view_shape.iter().product();
        let mut elements = Vec::with_capacity(total);
        let mut position = vec![0usize; self.view_shape.len()];
        for flat in 0..total {
            // Unravel `flat` into a view-relative multi-index, last axis fastest.
            let mut rest = flat;
            for (slot, &extent) in position.iter_mut().zip(&self.view_shape).rev() {
                *slot = rest % extent;
                rest /= extent;
            }
            if let Some(value) = self.get(&position) {
                elements.push(value.clone());
            }
        }

        if elements.len() != total {
            // Best-effort fallback for malformed views (kept from the original design).
            return SharedArray::from_vec(elements);
        }
        SharedArray::from_vec_with_shape(elements, &self.view_shape)
            .expect("element count equals the product of the view shape")
    }
}
impl<T: Clone> SharedArray<T> {
    /// Returns a [`SharedArrayView`] spanning this entire array.
    ///
    /// The view stores its own clone of this array handle rather than a
    /// borrow, so it is not tied to the lifetime of `&self`.
    pub fn shared_view(&self) -> SharedArrayView<T> {
        let whole = Self::clone(self);
        SharedArrayView::new(whole)
    }
}
impl<T: Clone> Index<&[usize]> for SharedArray<T> {
    type Output = T;

    /// Returns a reference to the element addressed by `indices`.
    ///
    /// # Panics
    ///
    /// Panics with "Index out of bounds" when `indices` does not address an element.
    fn index(&self, indices: &[usize]) -> &Self::Output {
        match self.get(indices) {
            Some(elem) => elem,
            None => panic!("Index out of bounds"),
        }
    }
}
impl<T: Clone> From<Array<T>> for SharedArray<T> {
fn from(arr: Array<T>) -> Self {
SharedArray::from_array(arr)
}
}
impl<T: Clone> From<Vec<T>> for SharedArray<T> {
fn from(vec: Vec<T>) -> Self {
SharedArray::from_vec(vec)
}
}
impl<T: Clone> From<SharedArray<T>> for Array<T> {
fn from(shared: SharedArray<T>) -> Self {
shared.to_owned_array()
}
}
impl<T: Clone + PartialEq> PartialEq for SharedArray<T> {
    /// Two arrays are equal when they have the same shape and the same elements
    /// in logical (row-major) order; memory layout and strides are ignored.
    fn eq(&self, other: &Self) -> bool {
        // Compare in place rather than cloning both arrays into temporary
        // `Vec`s (the previous implementation allocated two full copies).
        self.data.shape() == other.data.shape() && self.data.iter().eq(other.data.iter())
    }
}
impl<T> Add for SharedArray<T>
where
    T: Clone + Add<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self + rhs`, consuming both operands.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn add(self, rhs: Self) -> Self::Output {
        match SharedArray::add(&self, &rhs) {
            Ok(sum) => sum,
            Err(err) => panic!("Shape mismatch in addition: {:?}", err),
        }
    }
}
impl<T> Add for &SharedArray<T>
where
    T: Clone + Add<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self + rhs` on borrowed operands; neither is consumed.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn add(self, rhs: Self) -> Self::Output {
        SharedArray::add(self, rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in addition: {:?}", err))
    }
}
impl<T> Add<&SharedArray<T>> for SharedArray<T>
where
    T: Clone + Add<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self + rhs`, consuming the left operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn add(self, rhs: &SharedArray<T>) -> Self::Output {
        match SharedArray::add(&self, rhs) {
            Ok(sum) => sum,
            Err(err) => panic!("Shape mismatch in addition: {:?}", err),
        }
    }
}
impl<T> Add<SharedArray<T>> for &SharedArray<T>
where
    T: Clone + Add<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self + rhs`, consuming the right operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn add(self, rhs: SharedArray<T>) -> Self::Output {
        SharedArray::add(self, &rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in addition: {:?}", err))
    }
}
impl<T> Sub for SharedArray<T>
where
    T: Clone + Sub<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self - rhs`, consuming both operands.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn sub(self, rhs: Self) -> Self::Output {
        match SharedArray::sub(&self, &rhs) {
            Ok(diff) => diff,
            Err(err) => panic!("Shape mismatch in subtraction: {:?}", err),
        }
    }
}
impl<T> Sub for &SharedArray<T>
where
    T: Clone + Sub<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self - rhs` on borrowed operands; neither is consumed.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn sub(self, rhs: Self) -> Self::Output {
        SharedArray::sub(self, rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in subtraction: {:?}", err))
    }
}
impl<T> Sub<&SharedArray<T>> for SharedArray<T>
where
    T: Clone + Sub<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self - rhs`, consuming the left operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn sub(self, rhs: &SharedArray<T>) -> Self::Output {
        match SharedArray::sub(&self, rhs) {
            Ok(diff) => diff,
            Err(err) => panic!("Shape mismatch in subtraction: {:?}", err),
        }
    }
}
impl<T> Sub<SharedArray<T>> for &SharedArray<T>
where
    T: Clone + Sub<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self - rhs`, consuming the right operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn sub(self, rhs: SharedArray<T>) -> Self::Output {
        SharedArray::sub(self, &rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in subtraction: {:?}", err))
    }
}
impl<T> Mul for SharedArray<T>
where
    T: Clone + Mul<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self * rhs`, consuming both operands.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn mul(self, rhs: Self) -> Self::Output {
        match SharedArray::mul(&self, &rhs) {
            Ok(product) => product,
            Err(err) => panic!("Shape mismatch in multiplication: {:?}", err),
        }
    }
}
impl<T> Mul for &SharedArray<T>
where
    T: Clone + Mul<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self * rhs` on borrowed operands; neither is consumed.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn mul(self, rhs: Self) -> Self::Output {
        SharedArray::mul(self, rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in multiplication: {:?}", err))
    }
}
impl<T> Mul<&SharedArray<T>> for SharedArray<T>
where
    T: Clone + Mul<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self * rhs`, consuming the left operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn mul(self, rhs: &SharedArray<T>) -> Self::Output {
        match SharedArray::mul(&self, rhs) {
            Ok(product) => product,
            Err(err) => panic!("Shape mismatch in multiplication: {:?}", err),
        }
    }
}
impl<T> Mul<SharedArray<T>> for &SharedArray<T>
where
    T: Clone + Mul<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self * rhs`, consuming the right operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn mul(self, rhs: SharedArray<T>) -> Self::Output {
        SharedArray::mul(self, &rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in multiplication: {:?}", err))
    }
}
impl<T> Div for SharedArray<T>
where
    T: Clone + Div<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self / rhs`, consuming both operands.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn div(self, rhs: Self) -> Self::Output {
        match SharedArray::div(&self, &rhs) {
            Ok(quotient) => quotient,
            Err(err) => panic!("Shape mismatch in division: {:?}", err),
        }
    }
}
impl<T> Div for &SharedArray<T>
where
    T: Clone + Div<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self / rhs` on borrowed operands; neither is consumed.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn div(self, rhs: Self) -> Self::Output {
        SharedArray::div(self, rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in division: {:?}", err))
    }
}
impl<T> Div<&SharedArray<T>> for SharedArray<T>
where
    T: Clone + Div<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self / rhs`, consuming the left operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn div(self, rhs: &SharedArray<T>) -> Self::Output {
        match SharedArray::div(&self, rhs) {
            Ok(quotient) => quotient,
            Err(err) => panic!("Shape mismatch in division: {:?}", err),
        }
    }
}
impl<T> Div<SharedArray<T>> for &SharedArray<T>
where
    T: Clone + Div<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Element-wise `self / rhs`, consuming the right operand only.
    ///
    /// # Panics
    ///
    /// Panics if the two shapes differ.
    fn div(self, rhs: SharedArray<T>) -> Self::Output {
        SharedArray::div(self, &rhs)
            .unwrap_or_else(|err| panic!("Shape mismatch in division: {:?}", err))
    }
}
impl<T: Clone> SharedArray<T> {
    /// Applies `op` to every element in logical (row-major) order and rebuilds
    /// a row-major array with the same shape as `self`.
    fn map_each<F>(&self, op: F) -> Self
    where
        T: Copy,
        F: Fn(T) -> T,
    {
        let mapped: Vec<T> = self.data.iter().copied().map(op).collect();
        SharedArray::from_vec_with_shape(mapped, &self.shape()).expect("Shape should be valid")
    }

    /// Returns a new array with `scalar` added to every element.
    pub fn add_scalar(&self, scalar: T) -> Self
    where
        T: Add<Output = T> + Copy,
    {
        self.map_each(|x| x + scalar)
    }

    /// Returns a new array with `scalar` subtracted from every element.
    pub fn sub_scalar(&self, scalar: T) -> Self
    where
        T: Sub<Output = T> + Copy,
    {
        self.map_each(|x| x - scalar)
    }

    /// Returns a new array with every element multiplied by `scalar`.
    pub fn mul_scalar(&self, scalar: T) -> Self
    where
        T: Mul<Output = T> + Copy,
    {
        self.map_each(|x| x * scalar)
    }

    /// Returns a new array with every element divided by `scalar`.
    pub fn div_scalar(&self, scalar: T) -> Self
    where
        T: Div<Output = T> + Copy,
    {
        self.map_each(|x| x / scalar)
    }

    /// Returns a new array with every element negated.
    pub fn neg(&self) -> Self
    where
        T: std::ops::Neg<Output = T> + Copy,
    {
        self.map_each(|x| -x)
    }
}
impl<T> Add<T> for SharedArray<T>
where
T: Clone + Add<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn add(self, scalar: T) -> Self::Output {
self.add_scalar(scalar)
}
}
impl<T> Add<T> for &SharedArray<T>
where
T: Clone + Add<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn add(self, scalar: T) -> Self::Output {
self.add_scalar(scalar)
}
}
impl<T> Sub<T> for SharedArray<T>
where
T: Clone + Sub<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn sub(self, scalar: T) -> Self::Output {
self.sub_scalar(scalar)
}
}
impl<T> Sub<T> for &SharedArray<T>
where
T: Clone + Sub<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn sub(self, scalar: T) -> Self::Output {
self.sub_scalar(scalar)
}
}
impl<T> Mul<T> for SharedArray<T>
where
T: Clone + Mul<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn mul(self, scalar: T) -> Self::Output {
self.mul_scalar(scalar)
}
}
impl<T> Mul<T> for &SharedArray<T>
where
T: Clone + Mul<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn mul(self, scalar: T) -> Self::Output {
self.mul_scalar(scalar)
}
}
impl<T> Div<T> for SharedArray<T>
where
T: Clone + Div<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn div(self, scalar: T) -> Self::Output {
self.div_scalar(scalar)
}
}
impl<T> Div<T> for &SharedArray<T>
where
T: Clone + Div<Output = T> + Copy,
{
type Output = SharedArray<T>;
fn div(self, scalar: T) -> Self::Output {
self.div_scalar(scalar)
}
}
impl<T> std::ops::Neg for SharedArray<T>
where
    T: Clone + std::ops::Neg<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Negates every element, consuming the array.
    fn neg(self) -> Self::Output {
        SharedArray::<T>::neg(&self)
    }
}
impl<T> std::ops::Neg for &SharedArray<T>
where
    T: Clone + std::ops::Neg<Output = T> + Copy,
{
    type Output = SharedArray<T>;

    /// Negates every element of a borrowed array, leaving it untouched.
    fn neg(self) -> Self::Output {
        SharedArray::<T>::neg(self)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SharedArray`, `SharedArrayView`, and the operator impls.
    use super::*;

    #[test]
    fn test_from_vec() {
        let arr = SharedArray::from_vec(vec![1, 2, 3, 4]);
        assert_eq!(arr.shape(), vec![4]);
        assert_eq!(arr.size(), 4);
        assert_eq!(arr.to_vec(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_from_vec_with_shape() {
        let arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4, 5, 6], &[2, 3])
            .expect("from_vec_with_shape should succeed for valid shape");
        assert_eq!(arr.shape(), vec![2, 3]);
        assert_eq!(arr.ndim(), 2);
    }

    #[test]
    fn test_from_vec_with_shape_mismatch() {
        assert!(SharedArray::from_vec_with_shape(vec![1, 2, 3], &[2, 2]).is_err());
    }

    #[test]
    fn test_zeros_ones() {
        let zeros: SharedArray<f64> = SharedArray::zeros(&[3, 3]);
        assert_eq!(zeros.shape(), vec![3, 3]);
        assert!(zeros.to_vec().iter().all(|&x| x == 0.0));
        let ones: SharedArray<f64> = SharedArray::ones(&[2, 2]);
        assert!(ones.to_vec().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_clone_shares_data() {
        let arr1 = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let arr2 = arr1.clone();
        assert_eq!(arr1.to_vec(), arr2.to_vec());
        assert!(arr1.ref_count() >= 1);
    }

    #[test]
    fn test_element_access() {
        let arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4], &[2, 2])
            .expect("from_vec_with_shape should succeed for 2x2");
        assert_eq!(arr.get(&[0, 0]), Some(&1));
        assert_eq!(arr.get(&[0, 1]), Some(&2));
        assert_eq!(arr.get(&[1, 0]), Some(&3));
        assert_eq!(arr.get(&[1, 1]), Some(&4));
        assert_eq!(arr.get(&[2, 0]), None);
    }

    #[test]
    fn test_get_flat() {
        let arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4, 5, 6], &[2, 3])
            .expect("from_vec_with_shape should succeed for 2x3");
        assert_eq!(arr.get_flat(0).expect("flat index 0 is in bounds"), 1);
        assert_eq!(arr.get_flat(4).expect("flat index 4 is in bounds"), 5);
        assert!(arr.get_flat(6).is_err());
    }

    #[test]
    fn test_set() {
        let mut arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4], &[2, 2])
            .expect("from_vec_with_shape should succeed for 2x2");
        arr.set(&[0, 0], 10).expect("set should succeed for valid index");
        assert_eq!(arr.get(&[0, 0]), Some(&10));
        // Out-of-bounds writes are reported as errors, not panics.
        assert!(arr.set(&[2, 0], 0).is_err());
    }

    #[test]
    fn test_set_does_not_affect_clones() {
        let mut a = SharedArray::from_vec(vec![1, 2, 3]);
        let b = a.clone();
        a.set(&[0], 10).expect("set should succeed for valid index");
        assert_eq!(a.to_vec(), vec![10, 2, 3]);
        assert_eq!(b.to_vec(), vec![1, 2, 3]);
    }

    #[test]
    fn test_reshape() {
        let arr = SharedArray::from_vec(vec![1, 2, 3, 4, 5, 6]);
        let reshaped = arr.reshape(&[2, 3]).expect("reshape to 2x3 should succeed");
        assert_eq!(reshaped.shape(), vec![2, 3]);
        assert!(arr.reshape(&[2, 2]).is_err());
    }

    #[test]
    fn test_flatten() {
        let arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4], &[2, 2])
            .expect("from_vec_with_shape should succeed for 2x2");
        let flat = arr.flatten();
        assert_eq!(flat.shape(), vec![4]);
        assert_eq!(flat.to_vec(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_transpose() {
        let arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4, 5, 6], &[2, 3])
            .expect("from_vec_with_shape should succeed for 2x3");
        let transposed = arr.transpose();
        assert_eq!(transposed.shape(), vec![3, 2]);
        assert_eq!(transposed.get(&[2, 1]), Some(&6));
        assert_eq!(transposed.to_vec(), vec![1, 4, 2, 5, 3, 6]);
        assert_eq!(transposed.get_flat(1).expect("flat index 1 is in bounds"), 4);
    }

    #[test]
    fn test_arithmetic() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let b = SharedArray::from_vec(vec![4.0, 5.0, 6.0]);
        let sum = SharedArray::add(&a, &b).expect("add should succeed for same-shape arrays");
        assert_eq!(sum.to_vec(), vec![5.0, 7.0, 9.0]);
        let diff = SharedArray::sub(&b, &a).expect("sub should succeed for same-shape arrays");
        assert_eq!(diff.to_vec(), vec![3.0, 3.0, 3.0]);
        let prod = SharedArray::mul(&a, &b).expect("mul should succeed for same-shape arrays");
        assert_eq!(prod.to_vec(), vec![4.0, 10.0, 18.0]);
        let quot = SharedArray::div(&b, &a).expect("div should succeed for same-shape arrays");
        assert_eq!(quot.to_vec(), vec![4.0, 2.5, 2.0]);
    }

    #[test]
    fn test_arithmetic_shape_mismatch() {
        let a: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0]);
        let b: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(SharedArray::add(&a, &b).is_err());
        assert!(SharedArray::sub(&a, &b).is_err());
        assert!(SharedArray::mul(&a, &b).is_err());
        assert!(SharedArray::div(&a, &b).is_err());
    }

    #[test]
    fn test_aggregations() {
        let arr = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(arr.sum(), 15.0);
        assert_eq!(arr.mean(), Some(3.0));
        assert_eq!(arr.min(), Some(1.0));
        assert_eq!(arr.max(), Some(5.0));
    }

    #[test]
    fn test_empty_aggregations() {
        let empty: SharedArray<f64> = SharedArray::from_vec(Vec::new());
        assert!(empty.is_empty());
        assert_eq!(empty.sum(), 0.0);
        assert_eq!(empty.mean(), None);
        assert_eq!(empty.min(), None);
        assert_eq!(empty.max(), None);
    }

    #[test]
    fn test_from_array_conversion() {
        let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let shared = SharedArray::from_array(arr.clone());
        assert_eq!(shared.shape(), vec![2, 2]);
        assert_eq!(shared.to_vec(), arr.to_vec());
    }

    #[test]
    fn test_to_owned_array() {
        let shared = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let owned: Array<f64> = shared.to_owned_array();
        assert_eq!(owned.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_shared_view() {
        let arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4], &[2, 2])
            .expect("from_vec_with_shape should succeed for 2x2");
        let view = arr.shared_view();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.get(&[0, 0]), Some(&1));
        assert_eq!(view.get(&[1, 1]), Some(&4));
    }

    #[test]
    fn test_shared_view_to_array() {
        let arr = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let view = arr.shared_view();
        let shared2 = view.to_shared_array();
        assert_eq!(shared2.to_vec(), arr.to_vec());
    }

    #[test]
    fn test_from_trait_implementations() {
        let shared: SharedArray<i32> = vec![1, 2, 3].into();
        assert_eq!(shared.to_vec(), vec![1, 2, 3]);
        let arr = Array::from_vec(vec![4.0, 5.0, 6.0]);
        let shared: SharedArray<f64> = arr.into();
        assert_eq!(shared.to_vec(), vec![4.0, 5.0, 6.0]);
        let shared2 = SharedArray::from_vec(vec![7, 8, 9]);
        let arr2: Array<i32> = shared2.into();
        assert_eq!(arr2.to_vec(), vec![7, 8, 9]);
    }

    #[test]
    fn test_partial_eq() {
        let a = SharedArray::from_vec(vec![1, 2, 3]);
        let b = SharedArray::from_vec(vec![1, 2, 3]);
        let c = SharedArray::from_vec(vec![1, 2, 4]);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_index_trait() {
        let arr = SharedArray::from_vec_with_shape(vec![1, 2, 3, 4], &[2, 2])
            .expect("from_vec_with_shape should succeed for 2x2");
        assert_eq!(arr[&[0, 0][..]], 1);
        assert_eq!(arr[&[1, 1][..]], 4);
    }

    #[test]
    fn test_operator_add_owned() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let b = SharedArray::from_vec(vec![4.0, 5.0, 6.0]);
        let c = a + b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_operator_add_refs() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let b = SharedArray::from_vec(vec![4.0, 5.0, 6.0]);
        let c = &a + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);
        assert_eq!(a.to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(b.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "Shape mismatch in addition")]
    fn test_operator_shape_mismatch_panics() {
        let a: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0]);
        let b: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let _ = &a + &b;
    }

    #[test]
    fn test_operator_sub() {
        let a = SharedArray::from_vec(vec![5.0, 7.0, 9.0]);
        let b = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let c = &a - &b;
        assert_eq!(c.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_operator_mul() {
        let a = SharedArray::from_vec(vec![2.0, 3.0, 4.0]);
        let b = SharedArray::from_vec(vec![3.0, 4.0, 5.0]);
        let c = &a * &b;
        assert_eq!(c.to_vec(), vec![6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_operator_div() {
        let a = SharedArray::from_vec(vec![10.0, 12.0, 15.0]);
        let b = SharedArray::from_vec(vec![2.0, 3.0, 5.0]);
        let c = &a / &b;
        assert_eq!(c.to_vec(), vec![5.0, 4.0, 3.0]);
    }

    #[test]
    fn test_operator_scalar_add() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let b = a.add_scalar(10.0);
        assert_eq!(b.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_operator_scalar_mul() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let b = a.mul_scalar(2.0);
        assert_eq!(b.to_vec(), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scalar_operator_traits() {
        let a: SharedArray<f64> = SharedArray::from_vec(vec![2.0, 4.0, 6.0]);
        assert_eq!((&a + 1.0).to_vec(), vec![3.0, 5.0, 7.0]);
        assert_eq!((&a - 1.0).to_vec(), vec![1.0, 3.0, 5.0]);
        assert_eq!((&a * 0.5).to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!((a / 2.0).to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_operator_negation() {
        let a = SharedArray::from_vec(vec![1.0, -2.0, 3.0]);
        let b = -a;
        assert_eq!(b.to_vec(), vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_operator_negation_ref() {
        let a = SharedArray::from_vec(vec![1.0, -2.0, 3.0]);
        let b = -&a;
        assert_eq!(b.to_vec(), vec![-1.0, 2.0, -3.0]);
        assert_eq!(a.to_vec(), vec![1.0, -2.0, 3.0]);
    }

    #[test]
    fn test_operator_chaining() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let b = SharedArray::from_vec(vec![2.0, 3.0, 4.0]);
        let c = SharedArray::from_vec(vec![2.0, 2.0, 2.0]);
        let d = SharedArray::from_vec(vec![1.0, 1.0, 1.0]);
        let result = (&a + &b) * &c - d;
        assert_eq!(result.to_vec(), vec![5.0, 9.0, 13.0]);
    }

    #[test]
    fn test_mixed_ownership_operations() {
        let a = SharedArray::from_vec(vec![1.0, 2.0, 3.0]);
        let b = SharedArray::from_vec(vec![4.0, 5.0, 6.0]);
        let c = a.clone() + &b;
        assert_eq!(c.to_vec(), vec![5.0, 7.0, 9.0]);
        let d = &a + b.clone();
        assert_eq!(d.to_vec(), vec![5.0, 7.0, 9.0]);
    }
}