use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::sync::{Arc, LazyLock, RwLock};
use std::time::{Duration, Instant};
use crate::error::{CoreError, CoreResult, ErrorContext};
mod distributed_impl;
mod gpu_impl;
mod jit_impl;
mod operations;
pub use crate::array_function_dispatch;
pub mod auto_device;
pub mod distributed_training;
pub mod grad;
pub mod mixed_precision;
pub mod ml_ops;
pub mod neural;
#[cfg(feature = "serialization")]
pub mod serialization;
pub mod training;
/// Protocol implemented by array-like types so that library functions can be
/// dispatched to type-specific implementations (modeled on NumPy's
/// `__array_function__` mechanism).
pub trait ArrayProtocol: Any + Send + Sync {
    /// Attempt to execute `func` for this array type.
    ///
    /// `types` carries the `TypeId`s of every protocol-implementing argument;
    /// `args` and `kwargs` are the type-erased positional and keyword
    /// arguments. Returning `Err(NotImplemented)` tells the dispatcher to try
    /// the next candidate implementation.
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented>;
    /// Upcast to `&dyn Any` so callers can downcast to the concrete type.
    #[must_use]
    fn as_any(&self) -> &dyn Any;
    /// Shape of the array; the default is an empty (scalar-like) shape.
    #[must_use]
    fn shape(&self) -> &[usize] {
        &[]
    }
    /// Element type; defaults to `f64` when not overridden.
    #[must_use]
    fn dtype(&self) -> TypeId {
        TypeId::of::<f64>()
    }
    /// Clone into a fresh boxed trait object (enables `Clone` for
    /// `Box<dyn ArrayProtocol>`).
    #[must_use]
    fn box_clone(&self) -> Box<dyn ArrayProtocol>;
}
impl Clone for Box<dyn ArrayProtocol> {
fn clone(&self) -> Self {
self.box_clone()
}
}
/// Sentinel returned by `ArrayProtocol::array_function` when a type does not
/// implement the requested function; the dispatcher then tries other
/// candidates.
#[derive(Debug, Clone, Copy)]
pub struct NotImplemented;

impl Display for NotImplemented {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NotImplemented")
    }
}
/// Signature shared by all registered implementations: type-erased positional
/// arguments plus keyword-style arguments, returning a type-erased result.
pub type ArrayFunctionImpl = dyn Fn(&[Box<dyn Any>], &HashMap<String, Box<dyn Any>>) -> CoreResult<Box<dyn Any>>
    + Send
    + Sync;
/// A named function that can be dispatched through the array protocol.
#[derive(Clone)]
pub struct ArrayFunction {
    /// Fully-qualified, globally unique name (used as the registry key).
    pub name: &'static str,
    /// Shared default/fallback implementation.
    pub implementation: Arc<ArrayFunctionImpl>,
}
impl Debug for ArrayFunction {
    // The implementation closure is not `Debug`; report only the name and
    // mark the output non-exhaustive (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArrayFunction")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
// Identity of an `ArrayFunction` is its name alone; the implementation
// closure is deliberately ignored by equality and hashing so that functions
// registered under the same name compare equal.
impl PartialEq for ArrayFunction {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}
impl Eq for ArrayFunction {}
impl std::hash::Hash for ArrayFunction {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state);
    }
}
impl ArrayFunction {
#[must_use]
pub fn new(name: &'static str) -> Self {
Self {
name,
implementation: Arc::new(|_args, _kwargs| {
Err(CoreError::NotImplementedError(ErrorContext::new(
"Function not implemented".to_string(),
)))
}),
}
}
}
/// A memoized dispatch decision for one `(function, argument types)` pair.
#[derive(Debug, Clone)]
pub struct DispatchCacheEntry {
    // Argument type signature this entry was computed for.
    #[allow(dead_code)]
    type_signature: Vec<TypeId>,
    // Concrete implementation type chosen by a previous dispatch.
    #[allow(dead_code)]
    preferred_impl_type: TypeId,
    // Creation time; entries older than the registry TTL are treated as stale.
    timestamp: Instant,
    // Number of cache hits; popular entries survive eviction longer.
    hit_count: u64,
}
/// Registry of all known `ArrayFunction`s plus a TTL-bounded dispatch cache.
#[derive(Debug)]
pub struct ArrayFunctionRegistry {
    // Functions keyed by their unique static name.
    functions: HashMap<&'static str, ArrayFunction>,
    // Cached dispatch decisions keyed by (function name, argument TypeIds).
    dispatch_cache: HashMap<(&'static str, Vec<TypeId>), DispatchCacheEntry>,
    // Soft cap on cache entries; exceeding it triggers cleanup.
    max_cache_size: usize,
    // Maximum age before a cached decision is considered expired.
    cache_ttl: Duration,
}
impl Default for ArrayFunctionRegistry {
    /// Empty registry; cache tuned to at most 1000 entries, each valid for
    /// five minutes.
    fn default() -> Self {
        let cache_ttl = Duration::from_secs(300);
        Self {
            functions: HashMap::new(),
            dispatch_cache: HashMap::new(),
            max_cache_size: 1000,
            cache_ttl,
        }
    }
}
impl ArrayFunctionRegistry {
    /// Process-wide registry instance behind an `RwLock`.
    #[must_use]
    pub fn global() -> &'static RwLock<Self> {
        static REGISTRY: LazyLock<RwLock<ArrayFunctionRegistry>> =
            LazyLock::new(|| RwLock::new(ArrayFunctionRegistry::default()));
        // BUG FIX: previously read `®ISTRY` (mojibake for `&REGISTRY`),
        // which is not valid Rust and did not compile.
        &REGISTRY
    }
    /// Register (or replace) a function under its static name.
    pub fn register(&mut self, func: ArrayFunction) {
        self.functions.insert(func.name, func);
    }
    /// Look up a registered function by name.
    #[must_use]
    #[allow(dead_code)]
    pub fn get(&self, name: &str) -> Option<&ArrayFunction> {
        self.functions.get(name)
    }
    /// All registered functions, in arbitrary (hash-map) order.
    #[must_use]
    pub fn all_functions(&self) -> Vec<&ArrayFunction> {
        self.functions.values().collect()
    }
    /// Fresh (non-expired) cached dispatch decision, if one exists.
    #[must_use]
    pub fn get_cached_dispatch(
        &self,
        funcname: &'static str,
        types: &[TypeId],
    ) -> Option<&DispatchCacheEntry> {
        let key = (funcname, types.to_vec());
        self.dispatch_cache
            .get(&key)
            .filter(|entry| entry.timestamp.elapsed() < self.cache_ttl)
    }
    /// Record a dispatch decision, evicting stale/cold entries when the
    /// cache is at capacity.
    pub fn cache_dispatch(
        &mut self,
        funcname: &'static str,
        types: Vec<TypeId>,
        impl_type: TypeId,
    ) {
        if self.dispatch_cache.len() >= self.max_cache_size {
            self.cleanup_cache();
        }
        let key = (funcname, types.clone());
        let entry = DispatchCacheEntry {
            type_signature: types,
            preferred_impl_type: impl_type,
            timestamp: Instant::now(),
            hit_count: 0,
        };
        self.dispatch_cache.insert(key, entry);
    }
    /// Bump the hit counter for a cached decision (informs eviction order).
    pub fn update_cache_hit(&mut self, funcname: &'static str, types: &[TypeId]) {
        let key = (funcname, types.to_vec());
        if let Some(entry) = self.dispatch_cache.get_mut(&key) {
            entry.hit_count += 1;
        }
    }
    /// Drop expired entries; if the cache is still full, evict the
    /// least-hit quarter of the remaining entries.
    fn cleanup_cache(&mut self) {
        let now = Instant::now();
        self.dispatch_cache
            .retain(|_, entry| now.duration_since(entry.timestamp) < self.cache_ttl);
        if self.dispatch_cache.len() >= self.max_cache_size {
            let mut entries: Vec<_> = self
                .dispatch_cache
                .iter()
                .map(|(k, v)| (k.clone(), v.hit_count))
                .collect();
            // Coldest entries first.
            entries.sort_by_key(|(_, hit_count)| *hit_count);
            let to_remove = self.dispatch_cache.len() / 4;
            for (key, _) in entries.into_iter().take(to_remove) {
                self.dispatch_cache.remove(&key);
            }
        }
    }
    /// Snapshot of cache metrics for monitoring.
    #[must_use]
    pub fn cache_stats(&self) -> HashMap<String, u64> {
        let mut stats = HashMap::new();
        stats.insert("cache_size".to_string(), self.dispatch_cache.len() as u64);
        stats.insert("max_cache_size".to_string(), self.max_cache_size as u64);
        let total_hits: u64 = self.dispatch_cache.values().map(|e| e.hit_count).sum();
        stats.insert("total_hits".to_string(), total_hits);
        stats
    }
}
/// Scan `args` for boxed `ArrayProtocol` objects and return each together
/// with its concrete `TypeId`, preserving argument order.
#[allow(dead_code)]
pub fn get_implementing_args(args: &[Box<dyn Any>]) -> Vec<(TypeId, &dyn ArrayProtocol)> {
    if args.is_empty() {
        return Vec::new();
    }
    let mut implementing_args = Vec::with_capacity(args.len());
    for arg in args {
        // Only arguments stored as `Box<Box<dyn ArrayProtocol>>` participate.
        if let Some(array_protocol_obj) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
            let type_id = (**array_protocol_obj).type_id();
            implementing_args.push((type_id, &**array_protocol_obj));
        }
    }
    // BUG FIX: the previous code sorted by a hash of the *constant*
    // `TypeId::of::<i32>()`, so every element had the same key and the stable
    // sort was a no-op. The dead sort is removed; candidates keep their
    // argument order, which is exactly what callers observed before.
    implementing_args
}
/// Dispatch `func` to the best-matching `ArrayProtocol` implementation among
/// `args`, falling back to the function's registered default implementation
/// when no argument implements the protocol.
///
/// Returns `CoreError::DispatchError` when every candidate declines.
#[allow(dead_code)]
pub fn array_function_dispatch(
    func: &ArrayFunction,
    args: &[Box<dyn Any>],
    kwargs: &HashMap<String, Box<dyn Any>>,
) -> CoreResult<Box<dyn Any>> {
    if args.is_empty() {
        return (func.implementation)(args, kwargs);
    }
    let implementing_args = get_implementing_args(args);
    if implementing_args.is_empty() {
        // No protocol-aware arguments: use the registered default.
        return (func.implementation)(args, kwargs);
    }
    if implementing_args.len() == 1 {
        // Fast path: a single candidate either handles the call or dispatch fails.
        let (type_id, array_protocol_obj) = implementing_args[0];
        let types = [type_id];
        match array_protocol_obj.array_function(func, &types, args, kwargs) {
            Ok(result) => return Ok(result),
            Err(NotImplemented) => {
                return Err(CoreError::DispatchError(ErrorContext::new(format!(
                    "No implementation found for {} with type {:?}",
                    func.name, type_id
                ))));
            }
        }
    }
    // Collect the distinct argument types, preserving first-seen order.
    let mut unique_types = Vec::with_capacity(implementing_args.len());
    let mut seen_types = std::collections::HashSet::with_capacity(implementing_args.len());
    for &(type_id, _) in &implementing_args {
        if seen_types.insert(type_id) {
            unique_types.push(type_id);
        }
    }
    // Try each candidate in argument order; the first success wins.
    for (_, array_protocol_obj) in implementing_args {
        if let Ok(result) = array_protocol_obj.array_function(func, &unique_types, args, kwargs) {
            return Ok(result);
        }
    }
    Err(CoreError::DispatchError(ErrorContext::new(format!(
        "No implementation found for {} with {} argument types: {:?}",
        func.name,
        unique_types.len(),
        unique_types
    ))))
}
/// Decorator-style helper: wraps a plain function so that `register` records
/// a placeholder `ArrayFunction` in the global registry and hands the
/// function back.
pub struct ArrayFunctionDecorator<F> {
    // The user function returned unchanged by `register`.
    function: F,
    // Static name under which the placeholder is registered.
    name: &'static str,
}
impl<F> ArrayFunctionDecorator<F>
where
    F: Send + Sync + 'static,
{
    /// Wrap `function` for later registration under `name`.
    #[must_use]
    pub fn new(function: F, name: &'static str) -> Self {
        Self { function, name }
    }
    /// Register a placeholder implementation in the global registry, then
    /// return the original function to the caller.
    ///
    /// NOTE(review): the registered closure always fails with
    /// `NotImplementedError` — the type-conversion bridge to the real
    /// function is still unimplemented here.
    pub fn register(self) -> F {
        let implementation = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Err(CoreError::NotImplementedError(ErrorContext::new(
                    "ArrayFunctionDecorator: Type conversion in array_function_dispatch is not implemented yet".to_string()
                )))
            },
        );
        let func = ArrayFunction {
            name: self.name,
            implementation,
        };
        let registry = ArrayFunctionRegistry::global();
        if let Ok(mut registry) = registry.write() {
            registry.register(func);
        } else {
            // A poisoned lock is logged rather than propagated:
            // registration is best-effort.
            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry, skipping function registration");
        }
        self.function
    }
}
/// Arrays that can move between CPU and GPU memory.
pub trait GPUArray: ArrayProtocol {
    /// Copy/transfer this array to the GPU.
    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>>;
    /// Copy/transfer this array back to host memory.
    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
    /// Whether the data currently resides on a GPU.
    #[must_use]
    fn is_on_gpu(&self) -> bool;
    /// Key/value description of the backing device.
    #[must_use]
    fn device_info(&self) -> HashMap<String, String>;
}
/// Arrays whose data is partitioned across multiple workers or chunks.
pub trait DistributedArray: ArrayProtocol {
    /// Key/value description of how the data is distributed.
    #[must_use]
    fn distribution_info(&self) -> HashMap<String, String>;
    /// Collect all chunks into a single local array.
    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
    /// Re-partition into `chunks` pieces.
    fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>>;
    /// Whether the data is currently distributed.
    #[must_use]
    fn is_distributed(&self) -> bool;
}
/// Arrays that can JIT-compile expressions over their elements.
pub trait JITArray: ArrayProtocol {
    /// Compile `expression` into a callable JIT function.
    fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>>;
    /// Whether JIT compilation is available for this array.
    #[must_use]
    fn supports_jit(&self) -> bool;
    /// Key/value description of the JIT backend.
    #[must_use]
    fn jit_info(&self) -> HashMap<String, String>;
}
/// A compiled, callable function produced by a JIT backend.
pub trait JITFunction: Send + Sync {
    /// Invoke the compiled function with type-erased arguments.
    fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>>;
    /// Source expression the function was compiled from.
    #[must_use]
    fn source(&self) -> String;
    /// Key/value description of the compilation (backend, flags, ...).
    #[must_use]
    fn compile_info(&self) -> HashMap<String, String>;
    /// Clone into a fresh boxed trait object (enables `Clone` for
    /// `Box<dyn JITFunction>`).
    #[must_use]
    fn clone_box(&self) -> Box<dyn JITFunction>;
}
/// Factory that builds `JITFunction`s for particular array types.
pub trait JITFunctionFactory: Send + Sync {
    /// Compile `expression` for arrays of type `array_typeid`.
    fn create_jit_function(
        &self,
        expression: &str,
        array_typeid: TypeId,
    ) -> CoreResult<Box<dyn JITFunction>>;
    /// Whether this factory can compile for the given array type.
    #[must_use]
    fn supports_array_type(&self, array_typeid: TypeId) -> bool;
}
/// Ordered collection of JIT factories; lookup returns the first factory
/// that claims support for a type.
#[derive(Default)]
pub struct JITFactoryRegistry {
    factories: Vec<Box<dyn JITFunctionFactory>>,
}
impl std::fmt::Debug for JITFactoryRegistry {
    /// Factories are trait objects without `Debug`; report only how many
    /// are registered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.factories.len();
        write!(f, "JITFactoryRegistry {{ factories: {count} }}")
    }
}
impl JITFactoryRegistry {
    /// Process-wide factory registry behind an `RwLock`.
    #[must_use]
    pub fn global() -> &'static RwLock<Self> {
        static REGISTRY: LazyLock<RwLock<JITFactoryRegistry>> = LazyLock::new(|| {
            RwLock::new(JITFactoryRegistry {
                factories: Vec::new(),
            })
        });
        // BUG FIX: previously read `®ISTRY` (mojibake for `&REGISTRY`),
        // which is not valid Rust and did not compile.
        &REGISTRY
    }
    /// Append a factory; earlier registrations take precedence during lookup.
    pub fn register(&mut self, factory: Box<dyn JITFunctionFactory>) {
        self.factories.push(factory);
    }
    /// First registered factory that supports `array_typeid`, if any.
    #[must_use]
    pub fn get_factory_for_array_type(
        &self,
        array_typeid: TypeId,
    ) -> Option<&dyn JITFunctionFactory> {
        self.factories
            .iter()
            .find(|factory| factory.supports_array_type(array_typeid))
            .map(|factory| &**factory)
    }
}
/// Adapter exposing an owned `ndarray::Array` through `ArrayProtocol`.
#[derive(Debug, Clone)]
pub struct NdarrayWrapper<T, D: crate::ndarray::Dimension> {
    // The wrapped array (owns its data).
    array: crate::ndarray::Array<T, D>,
    // NOTE(review): `array` already mentions both `T` and `D`, so this
    // PhantomData looks redundant — confirm before removing.
    phantom: PhantomData<(T, D)>,
}
impl<T, D> NdarrayWrapper<T, D>
where
T: Clone + 'static,
D: crate::ndarray::Dimension + 'static,
{
#[must_use]
pub fn new(array: crate::ndarray::Array<T, D>) -> Self {
Self {
array,
phantom: PhantomData,
}
}
#[must_use]
pub const fn as_array(&self) -> &crate::ndarray::Array<T, D> {
&self.array
}
#[must_use]
pub fn into_array(self) -> crate::ndarray::Array<T, D> {
self.array
}
pub fn array_2(&mut self, newarray: crate::ndarray::Array<T, D>) {
self.array = newarray;
}
}
impl<T, D> ArrayProtocol for NdarrayWrapper<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: crate::ndarray::Dimension + Send + Sync + 'static,
{
    /// Dispatch entry point. Supports a small set of built-in operations for
    /// `f64` and `f32` element types; anything else yields `NotImplemented`.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::add" => {
                // Elementwise addition; the right-hand operand is args[1].
                if args.len() < 2 {
                    return Err(NotImplemented);
                }
                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
                    if let (Some(a), Some(b)) = (
                        self.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
                        other.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
                    ) {
                        if TypeId::of::<T>() == TypeId::of::<f64>() {
                            // SAFETY: T == f64 was checked just above, so the
                            // cast only renames the erased type parameter;
                            // the layout is identical.
                            let a_f64 =
                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f64, D>) };
                            let b_f64 =
                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f64, D>) };
                            let result = a_f64.as_array() + b_f64.as_array();
                            return Ok(Box::new(NdarrayWrapper::new(result)));
                        } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                            // SAFETY: T == f32 checked above; same reasoning
                            // as the f64 arm.
                            let a_f32 =
                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f32, D>) };
                            let b_f32 =
                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f32, D>) };
                            let result = a_f32.as_array() + b_f32.as_array();
                            return Ok(Box::new(NdarrayWrapper::new(result)));
                        }
                    }
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                // 2-D matrix multiplication only.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }
                if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
                    return Err(NotImplemented);
                }
                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
                    if TypeId::of::<T>() == TypeId::of::<f64>() {
                        // SAFETY: T == f64 and D == Ix2 are both verified
                        // above, so reinterpreting the wrapper type is sound.
                        let a_f64 = unsafe {
                            &*(self as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
                        };
                        let b_f64 = unsafe {
                            &*(other as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
                        };
                        let ashape = a_f64.as_array().shape();
                        let bshape = b_f64.as_array().shape();
                        // Inner dimensions must agree for `dot`.
                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
                            return Err(NotImplemented);
                        }
                        let result = a_f64.as_array().dot(b_f64.as_array());
                        return Ok(Box::new(NdarrayWrapper::new(result)));
                    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                        // SAFETY: T == f32 and D == Ix2 verified above.
                        let a_f32 = unsafe {
                            &*(self as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
                        };
                        let b_f32 = unsafe {
                            &*(other as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
                        };
                        let ashape = a_f32.as_array().shape();
                        let bshape = b_f32.as_array().shape();
                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
                            return Err(NotImplemented);
                        }
                        let result = a_f32.as_array().dot(b_f32.as_array());
                        return Ok(Box::new(NdarrayWrapper::new(result)));
                    }
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                if TypeId::of::<T>() == TypeId::of::<f64>() {
                    // SAFETY: T == f64 checked above.
                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
                    let result = a_f64.as_array().t().to_owned();
                    return Ok(Box::new(NdarrayWrapper::new(result)));
                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                    // SAFETY: T == f32 checked above.
                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
                    let result = a_f32.as_array().t().to_owned();
                    return Ok(Box::new(NdarrayWrapper::new(result)));
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::sum" => {
                // NOTE(review): the `axis` kwarg is parsed but ignored — both
                // match arms below compute the full sum; axis-wise reduction
                // appears unimplemented. TODO confirm intended behavior.
                let axis_ref = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());
                if TypeId::of::<T>() == TypeId::of::<f64>() {
                    // SAFETY: T == f64 checked above.
                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
                    match axis_ref {
                        Some(&_ax) => {
                            let result = a_f64.as_array().sum();
                            return Ok(Box::new(result));
                        }
                        None => {
                            let result = a_f64.as_array().sum();
                            return Ok(Box::new(result));
                        }
                    }
                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                    // SAFETY: T == f32 checked above.
                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
                    match axis_ref {
                        Some(&_ax) => {
                            let result = a_f32.as_array().sum();
                            return Ok(Box::new(result));
                        }
                        None => {
                            let result = a_f32.as_array().sum();
                            return Ok(Box::new(result));
                        }
                    }
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::reshape" => {
                // Target shape comes in via the "shape" kwarg as Vec<usize>.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    if TypeId::of::<T>() == TypeId::of::<f64>() {
                        // SAFETY: T == f64 checked above.
                        let a_f64 =
                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
                        match a_f64
                            .as_array()
                            .clone()
                            .into_shape_with_order(shape.clone())
                        {
                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
                            Err(_) => return Err(NotImplemented),
                        }
                    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                        // SAFETY: T == f32 checked above.
                        let a_f32 =
                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
                        match a_f32
                            .as_array()
                            .clone()
                            .into_shape_with_order(shape.clone())
                        {
                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
                            Err(_) => return Err(NotImplemented),
                        }
                    }
                }
                Err(NotImplemented)
            }
            _ => Err(NotImplemented),
        }
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn shape(&self) -> &[usize] {
        self.array.shape()
    }
    fn dtype(&self) -> TypeId {
        TypeId::of::<T>()
    }
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}
/// In-memory stand-in for a distributed array, used by tests and examples.
#[derive(Debug, Clone)]
pub struct MockDistributedArray<T: Clone + 'static> {
    // Pre-split data chunks.
    chunks: Vec<T>,
    // Logical shape of the whole (un-split) array.
    shape: Vec<usize>,
}
impl<T: Clone + Send + Sync + 'static> MockDistributedArray<T> {
    /// Build a mock distributed array from pre-split chunks and a logical shape.
    #[must_use]
    pub fn new(chunks: Vec<T>, shape: Vec<usize>) -> Self {
        Self { chunks, shape }
    }
}
impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockDistributedArray<T> {
    /// Only `scirs2::mean` is recognised; as a mock it returns a clone of
    /// the first chunk instead of a real mean.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        _args: &[Box<dyn Any>],
        _kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        if func.name == "scirs2::mean" {
            Ok(Box::new(self.chunks[0].clone()))
        } else {
            Err(NotImplemented)
        }
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(Clone::clone(self))
    }
}
impl<T: Clone + Send + Sync + 'static> DistributedArray for MockDistributedArray<T> {
    fn distribution_info(&self) -> HashMap<String, String> {
        HashMap::from([
            ("type".to_string(), "mock_distributed".to_string()),
            ("chunks".to_string(), self.chunks.len().to_string()),
        ])
    }
    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Mock: gathering just produces an identical copy.
        Ok(Box::new(self.clone()))
    }
    fn scatter(&self, _numchunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
        // Mock: the requested chunk count is ignored.
        Ok(Box::new(self.clone()))
    }
    fn is_distributed(&self) -> bool {
        true
    }
}
/// In-memory stand-in for a GPU-resident array, used by tests and examples.
#[derive(Debug, Clone)]
pub struct MockGPUArray<T: Clone + 'static> {
    // Flat element buffer (host memory; nothing actually lives on a GPU).
    data: Vec<T>,
    // Logical shape of the array.
    shape: Vec<usize>,
    // Device label, e.g. "cuda:0".
    device: String,
}
impl<T: Clone + Send + Sync + 'static> MockGPUArray<T> {
    /// Build a mock GPU array from a flat buffer, shape, and device label.
    #[must_use]
    pub fn new(data: Vec<T>, shape: Vec<usize>, device: String) -> Self {
        Self {
            data,
            shape,
            device,
        }
    }
}
impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockGPUArray<T> {
    /// Only `scirs2::matmul` is recognised; as a mock it returns a copy of
    /// itself instead of performing a real multiplication.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        _args: &[Box<dyn Any>],
        _kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        if func.name != "scirs2::matmul" {
            return Err(NotImplemented);
        }
        let copy = MockGPUArray::new(self.data.clone(), self.shape.clone(), self.device.clone());
        Ok(Box::new(copy))
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(Clone::clone(self))
    }
}
impl<T: Clone + Send + Sync + 'static> GPUArray for MockGPUArray<T> {
    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
        // Mock: already "on the GPU", so hand back a copy.
        Ok(Box::new(self.clone()))
    }
    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Mock: data already lives in host memory.
        Ok(Box::new(self.clone()))
    }
    fn is_on_gpu(&self) -> bool {
        true
    }
    fn device_info(&self) -> HashMap<String, String> {
        HashMap::from([
            ("device".to_string(), self.device.clone()),
            ("type".to_string(), "mock_gpu".to_string()),
        ])
    }
}
/// Wrapper pairing a plain function with a static name so it can be
/// registered as an `ArrayFunction`.
#[derive(Debug)]
pub struct ArrayProtocolFunction<F> {
    // The wrapped user function, returned unchanged by `register`.
    func: F,
    // Static registry key.
    name: &'static str,
}
impl<F> ArrayProtocolFunction<F> {
    /// Pair `func` with the name it will be registered under.
    #[must_use]
    pub fn new(func: F, name: &'static str) -> Self {
        Self { func, name }
    }
}
impl<F> ArrayProtocolFunction<F>
where
    F: Clone + Send + Sync + 'static,
{
    /// Register a placeholder `ArrayFunction` under `self.name` and return
    /// the wrapped function unchanged.
    ///
    /// NOTE(review): like `ArrayFunctionDecorator::register`, the registered
    /// closure always fails with `NotImplementedError` — the bridge to the
    /// real function is still unimplemented.
    pub fn register(self) -> F {
        let implementation = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Err(CoreError::NotImplementedError(ErrorContext::new(
                    "ArrayProtocolFunction: Implementation for array protocol functions is not complete".to_string()
                )))
            },
        );
        let array_func = ArrayFunction {
            name: self.name,
            implementation,
        };
        if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
            registry.register(array_func);
        } else {
            // A poisoned lock is logged rather than propagated:
            // registration is best-effort.
            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry during array protocol building, skipping function registration");
        }
        self.func
    }
}
/// Defines a free function and immediately evaluates to it, so the result
/// can be passed to registration helpers. `$funcname` is accepted for
/// symmetry with registration call-sites but is not used by the expansion.
#[macro_export]
macro_rules! array_function_def {
    (fn $name:ident $(<$($gen:ident),*>)? ($($arg:ident : $arg_ty:ty),*) -> $ret:ty $body:block, $funcname:expr) => {
        {
            fn $name $(<$($gen),*>)? ($($arg : $arg_ty),*) -> $ret $body
            $name
        }
    };
}
pub use self::distributed_impl::{
ArrayChunk, DistributedBackend, DistributedConfig, DistributedNdarray, DistributionStrategy,
};
pub use self::gpu_impl::{
kernels as gpu_kernels, GPUArrayBuilder, GPUBackend, GPUConfig, GPUNdarray,
};
pub use self::jit_impl::{
CraneliftFunctionFactory, JITBackend, JITConfig, JITEnabledArray, JITFunctionImpl, JITManager,
LLVMFunctionFactory,
};
pub use self::operations::{
add, apply_elementwise, concatenate, inverse, matmul, multiply, reshape, subtract, sum, svd,
transpose, OperationError,
};
pub use self::ml_ops::{
activation, batch_norm, conv2d, cross_entropy, dropout, max_pool2d, self_attention,
ActivationFunc,
};
/// Initialize the array-protocol runtime by priming the global JIT manager.
///
/// # Panics
/// Panics if the global `JITManager` lock is poisoned.
#[allow(dead_code)]
pub fn init() {
    // FIX: the previous `expect("Operation failed")` message gave no clue
    // about what failed; the message now names the actual failure mode.
    let mut jit_manager = JITManager::global()
        .write()
        .expect("JITManager global RwLock poisoned");
    jit_manager.initialize();
}
/// Extension traits layered on top of `ArrayProtocol`.
pub mod traits {
    use super::*;
    /// Arrays exposing their memory layout.
    pub trait StridedArray: ArrayProtocol {
        /// Stride (in elements) for each dimension.
        #[must_use]
        fn strides(&self) -> Vec<usize>;
        /// Whether the data is C (row-major) contiguous.
        #[must_use]
        fn is_contiguous(&self) -> bool;
        /// Whether the data is Fortran (column-major) contiguous.
        #[must_use]
        fn is_fortran_contiguous(&self) -> bool;
    }
    /// Arrays supporting view semantics without copying data.
    pub trait ZeroCopyArray: ArrayProtocol {
        /// Immutable view over the same data.
        #[must_use]
        fn view(&self) -> Box<dyn ZeroCopyArray>;
        /// Mutable view over the same data.
        #[must_use]
        fn view_mut(&mut self) -> Box<dyn ZeroCopyArray>;
        /// Whether this object is itself a view over another array.
        #[must_use]
        fn is_view(&self) -> bool;
    }
    /// Arrays participating in automatic differentiation.
    pub trait DifferentiableArray: ArrayProtocol {
        /// Gradients of this value with respect to each of `variables`.
        fn gradient(
            &self,
            variables: &[Box<dyn DifferentiableArray>],
        ) -> Vec<Box<dyn DifferentiableArray>>;
        /// Enable or disable gradient tracking.
        fn set_requiresgrad(&mut self, requiresgrad: bool);
        /// Whether gradient tracking is enabled.
        #[must_use]
        fn requiresgrad(&self) -> bool;
        /// Accumulated gradient, if any.
        #[must_use]
        fn grad(&self) -> Option<Box<dyn DifferentiableArray>>;
    }
    /// Arrays supporting asynchronous operations.
    pub trait AsyncArray: ArrayProtocol {
        /// Run `op` against this array asynchronously.
        fn async_op<F, R>(&self, op: F) -> impl std::future::Future<Output = CoreResult<R>>
        where
            F: FnOnce(&Self) -> CoreResult<R> + Send + 'static,
            R: Send + 'static;
        /// Whether asynchronous execution is actually available.
        #[must_use]
        fn supports_async(&self) -> bool;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Registration round-trip through the global registry.
    #[test]
    fn test_array_protocol_registry() {
        let implementation = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            },
        );
        let func = ArrayFunction {
            name: "scirs2::test::test_func",
            implementation,
        };
        let registry = ArrayFunctionRegistry::global();
        {
            let mut reg = registry.write().expect("Operation failed");
            reg.register(func.clone());
        }
        {
            let reg = registry.read().expect("Operation failed");
            let registered_func = reg
                .get("scirs2::test::test_func")
                .expect("Operation failed");
            assert_eq!(registered_func.name, "scirs2::test::test_func");
        }
    }
    // The mock distributed array reports itself as distributed and exposes
    // its chunk count via `distribution_info`.
    #[test]
    fn test_mock_distributed_array() {
        let array = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(array.is_distributed());
        let info = array.distribution_info();
        assert_eq!(
            info.get("type").expect("Operation failed"),
            "mock_distributed"
        );
        assert_eq!(info.get("chunks").expect("Operation failed"), "3");
    }
    // The mock GPU array reports its device label and backend type.
    #[test]
    fn test_mock_gpu_array() {
        let array = MockGPUArray::new(vec![1.0, 2.0, 3.0], vec![3], "cuda:0".to_string());
        assert!(array.is_on_gpu());
        let info = array.device_info();
        assert_eq!(info.get("device").expect("Operation failed"), "cuda:0");
        assert_eq!(info.get("type").expect("Operation failed"), "mock_gpu");
    }
    // `Clone for Box<dyn ArrayProtocol>` preserves shape for both wrapper types.
    #[test]
    fn test_box_clone() {
        let array = crate::ndarray::Array2::<f64>::ones((3, 3));
        let wrapped = NdarrayWrapper::new(array);
        let boxed: Box<dyn ArrayProtocol> = Box::new(wrapped);
        let cloned = boxed.clone();
        assert_eq!(cloned.shape(), &[3, 3]);
        let array = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        let boxed: Box<dyn ArrayProtocol> = Box::new(array);
        let cloned = boxed.clone();
        assert_eq!(cloned.shape(), &[3]);
    }
}
// Worked examples compiled as tests to keep them honest.
#[cfg(test)]
mod examples {
    use super::*;
    use ::ndarray::Array2;
    use std::any::Any;
    use std::collections::HashMap;
    // Splitting an array into row-wise chunks and reassembling it.
    #[test]
    fn example_distributed_array() {
        let array = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);
        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);
        let result = dist_array.to_array().expect("Operation failed");
        assert_eq!(result.shape(), array.shape());
    }
    // Wrapping an ndarray in a GPU container and cloning it as a trait object.
    #[test]
    fn example_gpu_array() {
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(array.clone(), config);
        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());
        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "CUDA");
        let gpu_box: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let gpu_clone = gpu_box.clone();
        assert_eq!(gpu_clone.shape(), &[10, 5]);
    }
    // Compiling an expression through the JIT layer.
    #[test]
    fn example_jit_array() {
        init();
        let array = Array2::<f64>::ones((10, 5));
        let wrapped = NdarrayWrapper::new(array);
        let jitarray: JITEnabledArray<f64, NdarrayWrapper<f64, crate::ndarray::Ix2>> =
            JITEnabledArray::new(wrapped);
        assert!(jitarray.supports_jit());
        let expression = "x + y";
        let jit_function = jitarray.compile(expression).expect("Operation failed");
        assert_eq!(jit_function.source(), expression);
        let info = jitarray.jit_info();
        assert_eq!(info.get("supports_jit").expect("Operation failed"), "true");
        let jit_box: Box<dyn ArrayProtocol> = Box::new(jitarray);
        let jit_clone = jit_box.clone();
        assert_eq!(jit_clone.shape(), &[10, 5]);
    }
    // Boxed trait-object cloning works for both GPU and distributed arrays.
    #[test]
    fn example_cloning_array_protocol_objects() {
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();
        let gpu_array = GPUNdarray::new(array.clone(), config);
        let boxed: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let cloned = boxed.clone();
        assert_eq!(cloned.shape(), &[10, 5]);
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);
        let boxed: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let cloned = boxed.clone();
        assert_eq!(cloned.shape(), &[10, 5]);
    }
    // GPU and distributed wrappers can be handled uniformly behind the protocol.
    #[test]
    fn example_array_interoperability() {
        init();
        let cpu_array = Array2::<f64>::ones((5, 5));
        let gpu_config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: false,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(cpu_array.clone(), gpu_config);
        let dist_config = DistributedConfig {
            chunks: 2,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&cpu_array, dist_config);
        let gpu_wrapper: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let dist_wrapper: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let gpu_clone = gpu_wrapper.clone();
        let dist_clone = dist_wrapper.clone();
        assert_eq!(gpu_clone.shape(), &[5, 5]);
        assert_eq!(dist_clone.shape(), &[5, 5]);
    }
    // A user-defined type can join the protocol by implementing `ArrayProtocol`.
    #[test]
    fn example_custom_array_type() {
        use std::sync::Arc;
        struct MyCustomArray<T> {
            data: Vec<T>,
            shape: Vec<usize>,
        }
        impl<T: Clone + 'static> MyCustomArray<T> {
            fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
                Self { data, shape }
            }
        }
        impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MyCustomArray<T> {
            fn array_function(
                &self,
                func: &ArrayFunction,
                _types: &[TypeId],
                _args: &[Box<dyn Any>],
                _kwargs: &HashMap<String, Box<dyn Any>>,
            ) -> Result<Box<dyn Any>, NotImplemented> {
                if func.name == "scirs2::example::custom_sum" {
                    match std::any::TypeId::of::<T>() {
                        tid if tid == std::any::TypeId::of::<f64>() => {
                            // SAFETY: T == f64 was just checked, so the raw
                            // reinterpretation of the Vec<T> buffer as f64s
                            // matches the actual element type and length.
                            let f64_data = unsafe {
                                std::slice::from_raw_parts(
                                    self.data.as_ptr() as *const f64,
                                    self.data.len(),
                                )
                            };
                            let sum = f64_data.iter().sum::<f64>();
                            Ok(Box::new(sum))
                        }
                        tid if tid == std::any::TypeId::of::<f32>() => {
                            // SAFETY: T == f32 checked above; same reasoning
                            // as the f64 arm.
                            let f32_data = unsafe {
                                std::slice::from_raw_parts(
                                    self.data.as_ptr() as *const f32,
                                    self.data.len(),
                                )
                            };
                            let sum = f32_data.iter().sum::<f32>();
                            Ok(Box::new(sum))
                        }
                        _ => Err(NotImplemented),
                    }
                } else {
                    Err(NotImplemented)
                }
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn shape(&self) -> &[usize] {
                &self.shape
            }
            fn box_clone(&self) -> Box<dyn ArrayProtocol> {
                Box::new(MyCustomArray {
                    data: self.data.clone(),
                    shape: self.shape.clone(),
                })
            }
        }
        let custom_array = MyCustomArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let boxed: Box<dyn ArrayProtocol> = Box::new(custom_array);
        let cloned = boxed.clone();
        assert_eq!(cloned.shape(), &[2, 2]);
        // The registered implementation is ignored here: dispatch goes
        // straight to `MyCustomArray::array_function`, which sums the data.
        let func = ArrayFunction {
            name: "scirs2::example::custom_sum",
            implementation: Arc::new(move |_args, _kwargs| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            }),
        };
        let result = cloned.array_function(
            &func,
            &[std::any::TypeId::of::<f64>()],
            &[],
            &HashMap::new(),
        );
        assert!(result.is_ok());
        if let Ok(value) = result {
            let sum = *value.downcast_ref::<f64>().expect("Operation failed");
            assert_eq!(sum, 10.0);
        }
    }
}
impl Clone for Box<dyn JITFunction> {
fn clone(&self) -> Self {
self.clone_box()
}
}