use scirs2_core::array_protocol::{ArrayFunction, ArrayProtocol, NdarrayWrapper, NotImplemented};
use scirs2_core::ndarray_ext::Array2;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::marker::PhantomData;
/// Test double for an array type that lives outside the core crate.
///
/// It carries nothing more than an element buffer and a shape; the shape is
/// what the `ArrayProtocol` implementation in this file reports.
#[derive(Debug, Clone)]
// `data` is stored but not read by the visible tests, so silence the lint.
#[allow(dead_code)]
struct MockDistributedArray {
    /// Element buffer handed to the constructor.
    data: Vec<f64>,
    /// Shape handed to the constructor.
    shape: Vec<usize>,
}

impl MockDistributedArray {
    /// Builds a mock from an element buffer and a shape.
    ///
    /// The two arguments are stored as given; no check is made that the
    /// buffer length agrees with the shape.
    fn new(data: Vec<f64>, shape: Vec<usize>) -> Self {
        MockDistributedArray { data, shape }
    }
}
// Minimal protocol implementation: the mock reports its shape and can be
// cloned behind a box, but it declines every dispatched array function.
impl ArrayProtocol for MockDistributedArray {
    /// Always answers `Err(NotImplemented)`; all arguments are ignored.
    fn array_function(
        &self,
        _func: &ArrayFunction,
        _types: &[TypeId],
        _args: &[Box<dyn Any>],
        _kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        Err(NotImplemented)
    }

    /// Exposes the mock as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the shape supplied at construction time.
    fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Clones the mock and returns the copy as a boxed trait object.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        let duplicate: Self = self.clone();
        Box::new(duplicate)
    }
}
/// Test double for a device-resident array type.
///
/// Besides the element buffer and shape it records a free-form device label
/// (the tests in this file pass `"cuda:0"`). Nothing here talks to real hardware.
#[derive(Debug, Clone)]
// `data` and `device` are stored but not read by the visible tests.
#[allow(dead_code)]
struct MockGPUArray {
    /// Element buffer handed to the constructor.
    data: Vec<f64>,
    /// Shape handed to the constructor.
    shape: Vec<usize>,
    /// Device label handed to the constructor.
    device: String,
}

impl MockGPUArray {
    /// Builds a mock from an element buffer, a shape and a device label.
    ///
    /// All three arguments are stored as given, without validation.
    fn new(data: Vec<f64>, shape: Vec<usize>, device: String) -> Self {
        MockGPUArray { data, shape, device }
    }
}
// Minimal protocol implementation: the mock reports its shape and can be
// cloned behind a box, but it declines every dispatched array function.
impl ArrayProtocol for MockGPUArray {
    /// Always answers `Err(NotImplemented)`; all arguments are ignored.
    fn array_function(
        &self,
        _func: &ArrayFunction,
        _types: &[TypeId],
        _args: &[Box<dyn Any>],
        _kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        Err(NotImplemented)
    }

    /// Exposes the mock as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the shape supplied at construction time.
    fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Clones the mock (buffer, shape and device label) into a boxed trait object.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        let duplicate: Self = self.clone();
        Box::new(duplicate)
    }
}
/// Wrapper that pairs an inner array `A` with an element-type marker `T`.
///
/// The wrapper itself holds no data besides `inner`; `T` is recorded only at
/// the type level through `PhantomData`. The `ArrayProtocol` implementation in
/// this file forwards `array_function` and `shape` to `inner`.
///
/// No trait bounds are placed on the type parameters here: `#[derive(Clone)]`
/// already makes the `Clone` impl conditional on `T: Clone` and `A: Clone`,
/// and the `ArrayProtocol` impl states the bounds it needs itself.
#[derive(Debug, Clone)]
struct JITEnabledArray<T, A> {
    /// The wrapped array that protocol calls are delegated to.
    inner: A,
    /// Zero-sized marker tying the wrapper to element type `T`.
    phantom: PhantomData<T>,
}

impl<T, A> JITEnabledArray<T, A> {
    /// Wraps `inner` without modifying it.
    fn new(inner: A) -> Self {
        Self {
            inner,
            phantom: PhantomData,
        }
    }
}
// Protocol implementation that delegates to the wrapped array wherever it can.
impl<T, A> ArrayProtocol for JITEnabledArray<T, A>
where
    T: Clone + Send + Sync + 'static,
    A: ArrayProtocol + Clone + Send + Sync + 'static,
{
    /// Passes the call, with all arguments unchanged, to the wrapped array
    /// and returns whatever it returns.
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        let wrapped = &self.inner;
        wrapped.array_function(func, types, args, kwargs)
    }

    /// Exposes the wrapper (not the inner array) as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Reports the shape of the wrapped array.
    fn shape(&self) -> &[usize] {
        let wrapped = &self.inner;
        wrapped.shape()
    }

    /// Clones the whole wrapper, inner array included, into a boxed trait object.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        let duplicate: Self = self.clone();
        Box::new(duplicate)
    }
}
/// Cloning a `Box<dyn ArrayProtocol>` keeps the reported shape for each of the
/// three implementors exercised here.
#[test]
#[allow(dead_code)]
fn test_array_protocol_box_clone() {
    // ndarray-backed wrapper, 3x3.
    let ndarray_box: Box<dyn ArrayProtocol> =
        Box::new(NdarrayWrapper::new(Array2::<f64>::ones((3, 3))));
    let ndarray_copy = ndarray_box.clone();
    assert_eq!(ndarray_copy.shape(), &[3, 3]);

    // GPU mock, one dimension of length 3.
    let gpu_box: Box<dyn ArrayProtocol> = Box::new(MockGPUArray::new(
        vec![1.0, 2.0, 3.0],
        vec![3],
        "cuda:0".to_string(),
    ));
    let gpu_copy = gpu_box.clone();
    assert_eq!(gpu_copy.shape(), &[3]);

    // Distributed mock, one dimension of length 3.
    let distributed_box: Box<dyn ArrayProtocol> =
        Box::new(MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]));
    let distributed_copy = distributed_box.clone();
    assert_eq!(distributed_copy.shape(), &[3]);
}
/// A cloned boxed `NdarrayWrapper` keeps both its shape and its concrete type.
#[test]
#[allow(dead_code)]
fn test_ndarray_wrapper_box_clone() {
    let source: Box<dyn ArrayProtocol> =
        Box::new(NdarrayWrapper::new(Array2::<f64>::ones((3, 3))));
    let copy = source.clone();
    assert_eq!(copy.shape(), &[3, 3]);

    // The clone must still be the two-dimensional f64 wrapper underneath.
    let is_ndarray_wrapper = copy
        .as_any()
        .is::<NdarrayWrapper<f64, scirs2_core::ndarray::Ix2>>();
    assert!(is_ndarray_wrapper);
}
/// A cloned boxed `MockDistributedArray` keeps both its shape and its concrete type.
#[test]
#[allow(dead_code)]
fn test_mock_distributed_array_box_clone() {
    let source: Box<dyn ArrayProtocol> =
        Box::new(MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]));
    let copy = source.clone();
    assert_eq!(copy.shape(), &[3]);

    // The clone must still downcast to the mock type.
    let is_distributed_mock = copy.as_any().is::<MockDistributedArray>();
    assert!(is_distributed_mock);
}
/// A cloned boxed `MockGPUArray` keeps both its shape and its concrete type.
#[test]
#[allow(dead_code)]
fn test_mock_gpu_array_box_clone() {
    let source: Box<dyn ArrayProtocol> = Box::new(MockGPUArray::new(
        vec![1.0, 2.0, 3.0],
        vec![3],
        "cuda:0".to_string(),
    ));
    let copy = source.clone();
    assert_eq!(copy.shape(), &[3]);

    // The clone must still downcast to the mock type.
    let is_gpu_mock = copy.as_any().is::<MockGPUArray>();
    assert!(is_gpu_mock);
}
/// A cloned boxed `JITEnabledArray` reports the inner array's shape and keeps
/// its full concrete type, wrapper and inner wrapper alike.
#[test]
#[allow(dead_code)]
fn test_jit_array_box_clone() {
    let inner = NdarrayWrapper::new(Array2::<f64>::ones((10, 5)));
    let source: Box<dyn ArrayProtocol> = Box::new(JITEnabledArray::<f64, _>::new(inner));
    let copy = source.clone();
    assert_eq!(copy.shape(), &[10, 5]);

    // The clone must still be the JIT wrapper around the 2-D f64 ndarray wrapper.
    let is_jit_wrapper = copy
        .as_any()
        .is::<JITEnabledArray<f64, NdarrayWrapper<f64, scirs2_core::ndarray::Ix2>>>();
    assert!(is_jit_wrapper);
}
/// Cloning a whole `Vec<Box<dyn ArrayProtocol>>` yields element-wise matching
/// shapes, and a clone of one element can be pushed back into the same vector.
#[test]
#[allow(dead_code)]
fn test_chained_box_clone() {
    let mut boxed_arrays: Vec<Box<dyn ArrayProtocol>> = vec![
        Box::new(NdarrayWrapper::new(Array2::<f64>::ones((3, 3)))),
        Box::new(MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3])),
        Box::new(MockGPUArray::new(
            vec![1.0, 2.0, 3.0],
            vec![3],
            "cuda:0".to_string(),
        )),
    ];

    let cloned_arrays = boxed_arrays.clone();
    assert_eq!(boxed_arrays.len(), cloned_arrays.len());

    // Lengths were just checked equal, so pairing the two vectors visits every element.
    for (duplicate, original) in cloned_arrays.iter().zip(boxed_arrays.iter()) {
        assert_eq!(duplicate.shape(), original.shape());
    }

    // Clone first, push second: the clone reads from the vector being grown.
    let extra = boxed_arrays[0].clone();
    boxed_arrays.push(extra);
    assert_eq!(boxed_arrays.len(), 4);
}