use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::sync::{LazyLock, RwLock};
use ::ndarray::{Array, Dimension};
use num_traits::{cast as num_cast, Float};
use crate::array_protocol::gpu_impl::GPUNdarray;
use crate::array_protocol::{
ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
};
use crate::error::{CoreError, CoreResult, ErrorContext};
/// Numeric precision levels understood by the mixed-precision array support.
///
/// The [`fmt::Display`] implementation renders each level as a lowercase word
/// (`half`, `single`, `double`, `mixed`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Half precision (rendered as `half`).
    Half,
    /// Single precision (rendered as `single`).
    Single,
    /// Double precision (rendered as `double`).
    Double,
    /// A combination of precisions rather than one single level (rendered as `mixed`).
    Mixed,
}

impl fmt::Display for Precision {
    /// Writes the lowercase name of the precision level.
    ///
    /// The name is emitted through [`fmt::Formatter::pad`], so width, fill and
    /// alignment flags such as `{:>8}` are honoured instead of being ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Precision::Half => "half",
            Precision::Single => "single",
            Precision::Double => "double",
            Precision::Mixed => "mixed",
        };
        f.pad(label)
    }
}
/// Settings that steer mixed-precision behaviour.
///
/// A process-wide instance lives in `MIXED_PRECISION_CONFIG`.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
/// Precision returned by `determine_optimal_precision` when `auto_precision` is `false`.
pub storage_precision: Precision,
/// Precision intended for computation.
/// NOTE(review): this config field is not read by the code visible in this file; confirm its consumers.
pub computeprecision: Precision,
/// When `true`, `determine_optimal_precision` chooses a precision from the array size
/// instead of returning `storage_precision`.
pub auto_precision: bool,
/// Element count at or above which `determine_optimal_precision` selects single precision
/// (only consulted when `auto_precision` is `true`).
pub downcast_threshold: usize,
/// NOTE(review): presumably requests double-precision accumulation; not read by the code
/// visible in this file, so confirm against its consumers.
pub double_precision_accumulation: bool,
}
impl Default for MixedPrecisionConfig {
    /// Builds the baseline configuration: single-precision storage,
    /// double-precision compute, automatic precision selection with a
    /// threshold of ten million elements, and double-precision accumulation.
    fn default() -> Self {
        let storage_precision = Precision::Single;
        let computeprecision = Precision::Double;
        Self {
            storage_precision,
            computeprecision,
            auto_precision: true,
            downcast_threshold: 10_000_000,
            double_precision_accumulation: true,
        }
    }
}
pub static MIXED_PRECISION_CONFIG: LazyLock<RwLock<MixedPrecisionConfig>> = LazyLock::new(|| {
RwLock::new(MixedPrecisionConfig {
storage_precision: Precision::Single,
computeprecision: Precision::Double,
auto_precision: true,
downcast_threshold: 10_000_000, double_precision_accumulation: true,
})
});
/// Replaces the global mixed-precision configuration with `config`.
///
/// If the lock was poisoned by a panic in another thread, the guard is
/// recovered and the poison flag is cleared instead of silently dropping the
/// update: the stored value is overwritten in full, so no partially updated
/// state can survive the write.
#[allow(dead_code)]
pub fn set_mixed_precision_config(config: MixedPrecisionConfig) {
    let mut guard = MIXED_PRECISION_CONFIG.write().unwrap_or_else(|poisoned| {
        // The whole value is about to be replaced, so the lock is healthy again.
        MIXED_PRECISION_CONFIG.clear_poison();
        poisoned.into_inner()
    });
    *guard = config;
}
/// Returns a snapshot (clone) of the global mixed-precision configuration.
///
/// A poisoned lock is recovered so that the configuration actually stored is
/// returned; falling back to the defaults would silently discard whatever the
/// caller configured earlier.
#[allow(dead_code)]
pub fn get_mixed_precision_config() -> MixedPrecisionConfig {
    MIXED_PRECISION_CONFIG
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .clone()
}
/// Picks a precision for `array` based on the global configuration.
///
/// With `auto_precision` disabled the configured `storage_precision` is
/// returned as is. Otherwise arrays with at least `downcast_threshold`
/// elements map to [`Precision::Single`] and smaller ones to
/// [`Precision::Double`]. Only the element count is inspected; the element
/// type `T` does not influence the result.
#[allow(dead_code)]
pub fn determine_optimal_precision<T, D>(array: &Array<T, D>) -> Precision
where
    T: Clone + 'static,
    D: Dimension,
{
    let settings = get_mixed_precision_config();
    if !settings.auto_precision {
        return settings.storage_precision;
    }

    let element_count = array.len();
    if element_count < settings.downcast_threshold {
        Precision::Double
    } else {
        Precision::Single
    }
}
/// An owned ndarray paired with the precision it is stored in and the
/// precision that dispatched operations should use.
#[derive(Debug, Clone)]
pub struct MixedPrecisionArray<T, D>
where
T: Clone + 'static,
D: Dimension,
{
/// The wrapped data, kept in its original element type `T`.
array: Array<T, D>,
/// Precision derived by the constructors from the byte size of `T`.
storage_precision: Precision,
/// Precision used by `array_function` when no `"precision"` keyword argument is supplied.
computeprecision: Precision,
}
impl<T, D> MixedPrecisionArray<T, D>
where
    T: Clone + Float + 'static,
    D: Dimension,
{
    /// Maps the byte width of `T` to a precision level: 2 bytes is half,
    /// 4 bytes single, 8 bytes double, anything else [`Precision::Mixed`].
    fn native_precision() -> Precision {
        match std::mem::size_of::<T>() {
            2 => Precision::Half,
            4 => Precision::Single,
            8 => Precision::Double,
            _ => Precision::Mixed,
        }
    }

    /// Wraps `array`, using the precision implied by the size of `T` for both
    /// storage and computation.
    pub fn new(array: Array<T, D>) -> Self {
        Self::with_computeprecision(array, Self::native_precision())
    }

    /// Wraps `data` with an explicit compute precision; the storage precision
    /// is still derived from the size of `T`.
    pub fn with_computeprecision(data: Array<T, D>, computeprecision: Precision) -> Self {
        Self {
            array: data,
            storage_precision: Self::native_precision(),
            computeprecision,
        }
    }

    /// Returns a copy of the data with every element cast to `U`.
    ///
    /// The result has the same dimensions as the wrapped array.
    ///
    /// # Errors
    ///
    /// * `CoreError::ComputationError` if any element cannot be cast to `U`.
    /// * `CoreError::ShapeError` if the converted elements cannot be assembled
    ///   into an array of the original shape.
    pub fn at_precision<U>(&self) -> CoreResult<Array<U, D>>
    where
        U: Clone + Float + 'static,
    {
        // Collecting into `Option` stops at the first element that fails to cast.
        let converted = self
            .array
            .iter()
            .map(|&value| num_cast::<T, U>(value))
            .collect::<Option<Vec<U>>>()
            .ok_or_else(|| {
                CoreError::ComputationError(ErrorContext::new(format!(
                    "at_precision: failed to cast element to target precision (source size \
                     {} bytes, target size {} bytes)",
                    std::mem::size_of::<T>(),
                    std::mem::size_of::<U>(),
                )))
            })?;

        Array::from_shape_vec(self.array.raw_dim(), converted).map_err(|e| {
            CoreError::ShapeError(ErrorContext::new(format!(
                "at_precision: failed to reconstruct array from converted elements: {e}"
            )))
        })
    }

    /// Returns the precision derived from the element type at construction.
    pub fn storage_precision(&self) -> Precision {
        self.storage_precision
    }

    /// Borrows the wrapped array.
    pub const fn array(&self) -> &Array<T, D> {
        &self.array
    }
}
/// Array types that can report their numeric precision and be re-targeted to another one.
pub trait MixedPrecisionSupport: ArrayProtocol {
/// Returns a boxed array configured for `precision`, or an error when that
/// conversion is not available for the implementing type.
fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>>;
/// Reports the precision the array currently operates at.
fn precision(&self) -> Precision;
/// Returns whether `precision` may be requested; `execute_with_precision`
/// checks this for every input before calling `to_precision`.
fn supports_precision(&self, precision: Precision) -> bool;
/// Views the value as a plain `ArrayProtocol` trait object.
fn as_array_protocol(&self) -> &dyn ArrayProtocol;
}
/// Tries to recover an owned `Array<T, D>` from a type-erased argument.
///
/// Recognised inputs are a `MixedPrecisionArray<T, D>` or an
/// `NdarrayWrapper<T, D>`, either directly or behind one level of
/// `Box<dyn ArrayProtocol>`. The returned array is a clone of the wrapped
/// data; `None` is returned for anything else (including a matching wrapper
/// with a different element type or dimension).
fn extract_inner_ndarray<T, D>(arg: &dyn Any) -> Option<Array<T, D>>
where
    T: Clone + Float + Send + Sync + 'static,
    D: Dimension + Send + Sync + 'static,
{
    // Look through a boxed protocol object first; the explicit deref makes
    // sure `as_any` is called on the boxed value rather than on the box.
    let target: &dyn Any = match arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
        Some(boxed) => {
            let inner: &dyn ArrayProtocol = &**boxed;
            inner.as_any()
        }
        None => arg,
    };

    target
        .downcast_ref::<MixedPrecisionArray<T, D>>()
        .map(|mixed| mixed.array.clone())
        .or_else(|| {
            target
                .downcast_ref::<NdarrayWrapper<T, D>>()
                .map(|wrapper| wrapper.as_array().clone())
        })
}
/// Normalises a dispatch result so that callers can downcast it to
/// `Box<dyn ArrayProtocol>`.
///
/// * A value that already is a `Box<dyn ArrayProtocol>` is returned untouched.
/// * An `NdarrayWrapper<T, _>` with dimension `Ix2`, `Ix1` or `IxDyn` (tried in
///   that order) is re-boxed as `Box<dyn ArrayProtocol>`.
/// * Any other value is handed back unchanged.
fn rewrap_result_as_array_protocol<T>(result: Box<dyn Any>) -> Box<dyn Any>
where
    T: Clone + Float + Send + Sync + 'static,
{
    use crate::ndarray::{Ix1, Ix2, IxDyn};

    /// Re-boxes `value` as `Box<dyn ArrayProtocol>` when it is a `W`;
    /// otherwise gives the original box back through `Err`.
    fn promote<W>(value: Box<dyn Any>) -> Result<Box<dyn Any>, Box<dyn Any>>
    where
        W: ArrayProtocol + 'static,
    {
        value.downcast::<W>().map(|wrapper| {
            let boxed: Box<dyn ArrayProtocol> = wrapper;
            Box::new(boxed) as Box<dyn Any>
        })
    }

    if result.is::<Box<dyn ArrayProtocol>>() {
        return result;
    }

    promote::<NdarrayWrapper<T, Ix2>>(result)
        .or_else(promote::<NdarrayWrapper<T, Ix1>>)
        .or_else(promote::<NdarrayWrapper<T, IxDyn>>)
        .unwrap_or_else(|unchanged| unchanged)
}
impl<T, D> ArrayProtocol for MixedPrecisionArray<T, D>
where
    T: Clone + Float + Send + Sync + 'static,
    D: Dimension + Send + Sync + 'static,
{
    /// Dispatches `func` by delegating to an `NdarrayWrapper` built from a
    /// clone of the wrapped array.
    ///
    /// * Binary operations (`matmul`, `add`, `subtract`, `multiply`) need a
    ///   second argument that `extract_inner_ndarray` can unwrap, and are
    ///   refused with `NotImplemented` when the effective precision is
    ///   [`Precision::Half`]. The effective precision is the `"precision"`
    ///   keyword argument if present, otherwise the array's compute precision.
    /// * `transpose`, `reshape` and `sum` are forwarded with the wrapped self
    ///   as the only positional argument.
    /// * Every other function is forwarded with the caller's arguments as is,
    ///   and its result is returned without re-wrapping.
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        let requested = kwargs
            .get("precision")
            .and_then(|value| value.downcast_ref::<Precision>())
            .copied()
            .unwrap_or(self.computeprecision);
        let wrap_self = || NdarrayWrapper::new(self.array.clone());

        match func.name {
            "scirs2::array_protocol::operations::matmul"
            | "scirs2::array_protocol::operations::add"
            | "scirs2::array_protocol::operations::subtract"
            | "scirs2::array_protocol::operations::multiply" => {
                // Every rejection yields the same `NotImplemented`, so the
                // cheap checks run before any array is cloned.
                if args.len() < 2 || matches!(requested, Precision::Half) {
                    return Err(NotImplemented);
                }
                let rhs = match extract_inner_ndarray::<T, D>(args[1].as_ref()) {
                    Some(array) => NdarrayWrapper::new(array),
                    None => return Err(NotImplemented),
                };
                let lhs = wrap_self();
                let forwarded: Vec<Box<dyn Any>> = vec![Box::new(lhs.clone()), Box::new(rhs)];
                lhs.array_function(func, types, &forwarded, kwargs)
                    .map(rewrap_result_as_array_protocol::<T>)
            }
            "scirs2::array_protocol::operations::transpose"
            | "scirs2::array_protocol::operations::reshape"
            | "scirs2::array_protocol::operations::sum" => {
                // NOTE(review): positional arguments after the first are not
                // forwarded here; confirm the delegate takes the reshape
                // target / sum axis from `kwargs`.
                let operand = wrap_self();
                let forwarded: Vec<Box<dyn Any>> = vec![Box::new(operand.clone())];
                operand
                    .array_function(func, types, &forwarded, kwargs)
                    .map(rewrap_result_as_array_protocol::<T>)
            }
            _ => wrap_self().array_function(func, types, args, kwargs),
        }
    }

    /// Exposes the value for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Shape of the wrapped array.
    fn shape(&self) -> &[usize] {
        self.array.shape()
    }

    /// Clones the array together with both precision settings into a new box.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}
impl<T, D> MixedPrecisionSupport for MixedPrecisionArray<T, D>
where
    T: Clone + Float + Send + Sync + 'static,
    D: Dimension + Send + Sync + 'static,
{
    /// Returns a copy of this array whose compute precision matches the request.
    ///
    /// The element type `T` (and therefore the storage precision) is never
    /// changed; only the compute precision is re-targeted:
    ///
    /// * `Single` / `Double`: compute at exactly that precision.
    /// * `Mixed`: keep the native storage and compute in double precision.
    ///
    /// # Errors
    ///
    /// `CoreError::NotImplementedError` for `Half`, which is not implemented.
    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
        // Exhaustive on purpose: a new `Precision` variant must be handled here explicitly.
        let compute = match precision {
            Precision::Single | Precision::Double => precision,
            Precision::Mixed => Precision::Double,
            Precision::Half => {
                return Err(CoreError::NotImplementedError(ErrorContext::new(format!(
                    "Conversion to {precision} precision not implemented"
                ))))
            }
        };
        // The constructor re-derives the storage precision from `T`, so this
        // also covers the case where the array already has the requested precision.
        Ok(Box::new(MixedPrecisionArray::with_computeprecision(
            self.array.clone(),
            compute,
        )))
    }

    /// Reports `Mixed` when storage and compute precision differ, otherwise
    /// the shared precision.
    fn precision(&self) -> Precision {
        if self.storage_precision == self.computeprecision {
            self.storage_precision
        } else {
            Precision::Mixed
        }
    }

    /// Returns `true` for exactly the precisions that `to_precision` accepts
    /// (`Single`, `Double` and `Mixed`), so the two methods cannot disagree.
    fn supports_precision(&self, precision: Precision) -> bool {
        match precision {
            Precision::Single | Precision::Double | Precision::Mixed => true,
            Precision::Half => false,
        }
    }

    /// Views the value as a plain `ArrayProtocol` trait object.
    fn as_array_protocol(&self) -> &dyn ArrayProtocol {
        self
    }
}
impl<T, D> MixedPrecisionSupport for GPUNdarray<T, D>
where
    T: Clone + Float + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    /// Rebuilds the GPU array from its CPU copy with the `mixed_precision`
    /// flag of the configuration set to `precision == Precision::Mixed`.
    ///
    /// # Errors
    ///
    /// `CoreError::NotImplementedError` when the CPU copy cannot be obtained
    /// or is not an `NdarrayWrapper<T, D>`.
    /// NOTE(review): a failure of `to_cpu` is reported with the same
    /// "not implemented" message; its own error is discarded.
    fn to_precision(&self, precision: Precision) -> CoreResult<Box<dyn MixedPrecisionSupport>> {
        let mut config = self.config().clone();
        config.mixed_precision = matches!(precision, Precision::Mixed);

        self.to_cpu()
            .ok()
            .and_then(|cpu_array| {
                let wrapper = cpu_array
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<T, D>>()?;
                Some(GPUNdarray::new(wrapper.as_array().clone(), config))
            })
            .map(|rebuilt| Box::new(rebuilt) as Box<dyn MixedPrecisionSupport>)
            .ok_or_else(|| {
                CoreError::NotImplementedError(ErrorContext::new(format!(
                    "Conversion to {precision} precision not implemented for GPU arrays"
                )))
            })
    }

    /// Reports `Mixed` when the configuration enables mixed precision;
    /// otherwise derives the level from the size of `T` (4 bytes is single,
    /// 8 bytes double, anything else mixed).
    fn precision(&self) -> Precision {
        if self.config().mixed_precision {
            return Precision::Mixed;
        }
        match std::mem::size_of::<T>() {
            4 => Precision::Single,
            8 => Precision::Double,
            _ => Precision::Mixed,
        }
    }

    /// Accepts every precision; the argument is not inspected.
    fn supports_precision(&self, _precision: Precision) -> bool {
        true
    }

    /// Views the value as a plain `ArrayProtocol` trait object.
    fn as_array_protocol(&self) -> &dyn ArrayProtocol {
        self
    }
}
/// Converts every input to `precision` and runs `executor` on the converted arrays.
///
/// All inputs are checked with `supports_precision` before any conversion is
/// attempted. The converted arrays are handed to `executor` in input order as
/// `&dyn ArrayProtocol` references, and its result is returned unchanged.
///
/// # Errors
///
/// * `CoreError::InvalidArgument` if any input does not support `precision`.
/// * The first error returned by `to_precision`.
/// * Whatever error `executor` returns.
#[allow(dead_code)]
pub fn execute_with_precision<F, R>(
    arrays: &[&dyn MixedPrecisionSupport],
    precision: Precision,
    executor: F,
) -> CoreResult<R>
where
    F: FnOnce(&[&dyn ArrayProtocol]) -> CoreResult<R>,
    R: 'static,
{
    let any_unsupported = arrays
        .iter()
        .any(|candidate| !candidate.supports_precision(precision));
    if any_unsupported {
        return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
            "One or more arrays do not support {precision} precision"
        ))));
    }

    // Collecting into `CoreResult` stops at the first failed conversion.
    let converted = arrays
        .iter()
        .map(|source| source.to_precision(precision))
        .collect::<CoreResult<Vec<Box<dyn MixedPrecisionSupport>>>>()?;

    let views: Vec<&dyn ArrayProtocol> = converted
        .iter()
        .map(|owned| owned.as_array_protocol())
        .collect();
    executor(&views)
}
/// Binary array operations that first bring both operands to a requested precision.
pub mod ops {
    use super::*;
    use crate::array_protocol::operations as array_ops;

    /// Matrix product of `a` and `b` after converting both to `precision`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `execute_with_precision`; a failure of the
    /// underlying operation is reported as `CoreError::NotImplementedError`
    /// carrying the original message.
    pub fn matmul(
        a: &dyn MixedPrecisionSupport,
        b: &dyn MixedPrecisionSupport,
        precision: Precision,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        execute_with_precision(&[a, b], precision, |arrays| {
            array_ops::matmul(arrays[0], arrays[1])
                .map_err(|e| CoreError::NotImplementedError(ErrorContext::new(e.to_string())))
        })
    }

    /// Element-wise sum of `a` and `b` after converting both to `precision`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `execute_with_precision`; a failure of the
    /// underlying operation is reported as `CoreError::NotImplementedError`
    /// carrying the original message.
    pub fn add(
        a: &dyn MixedPrecisionSupport,
        b: &dyn MixedPrecisionSupport,
        precision: Precision,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        execute_with_precision(&[a, b], precision, |arrays| {
            array_ops::add(arrays[0], arrays[1])
                .map_err(|e| CoreError::NotImplementedError(ErrorContext::new(e.to_string())))
        })
    }

    /// Element-wise product of `a` and `b` after converting both to `precision`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `execute_with_precision`; a failure of the
    /// underlying operation is reported as `CoreError::NotImplementedError`
    /// carrying the original message.
    pub fn multiply(
        a: &dyn MixedPrecisionSupport,
        b: &dyn MixedPrecisionSupport,
        precision: Precision,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        execute_with_precision(&[a, b], precision, |arrays| {
            array_ops::multiply(arrays[0], arrays[1])
                .map_err(|e| CoreError::NotImplementedError(ErrorContext::new(e.to_string())))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::arr2;

    /// `new` derives the storage precision from the element type, and the
    /// wrapper stays reachable through `as_any`.
    #[test]
    fn test_mixed_precision_array() {
        let wrapped = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        assert_eq!(wrapped.storage_precision(), Precision::Double);

        let as_protocol: &dyn ArrayProtocol = &wrapped;
        let erased = as_protocol.as_any();
        assert!(erased.is::<MixedPrecisionArray<f64, crate::ndarray::Ix2>>());
    }

    /// An `f64` array reports double precision and accepts single and double.
    #[test]
    fn test_mixed_precision_support() {
        crate::array_protocol::init();
        let wrapped = MixedPrecisionArray::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        let support: &dyn MixedPrecisionSupport = &wrapped;

        assert_eq!(support.precision(), Precision::Double);
        for accepted in [Precision::Single, Precision::Double] {
            assert!(support.supports_precision(accepted));
        }
    }

    /// Narrowing `f64` data to `f32` keeps exactly representable values.
    #[test]
    fn test_at_precision_f64_to_f32() {
        use ::ndarray::array;
        let source = array![1.0_f64, 2.5_f64, -1.75_f64].into_dyn();
        let wrapped = MixedPrecisionArray::new(source);
        let as_f32: crate::ndarray::ArrayD<f32> = wrapped
            .at_precision()
            .expect("f64 → f32 precision conversion should succeed");

        for (idx, want) in [1.0_f32, 2.5_f32, -1.75_f32].iter().enumerate() {
            assert!((as_f32[idx] - *want).abs() < 1e-6);
        }
    }

    /// Widening `f32` data to `f64` keeps the values.
    #[test]
    fn test_at_precision_f32_to_f64() {
        use ::ndarray::array;
        let source = array![0.5_f32, 1.25_f32, -2.0_f32].into_dyn();
        let wrapped = MixedPrecisionArray::new(source);
        let as_f64: crate::ndarray::ArrayD<f64> = wrapped
            .at_precision()
            .expect("f32 → f64 precision conversion should succeed");

        for (idx, want) in [0.5_f64, 1.25_f64, -2.0_f64].iter().enumerate() {
            assert!((as_f64[idx] - *want).abs() < 1e-12);
        }
    }

    /// Converting to the same element type must leave every value untouched.
    #[test]
    fn test_at_precision_same_type_is_identity() {
        use ::ndarray::array;
        let source = array![42.0_f64, -7.5_f64].into_dyn();
        let wrapped = MixedPrecisionArray::new(source.clone());
        let round_trip: crate::ndarray::ArrayD<f64> = wrapped
            .at_precision()
            .expect("f64 → f64 precision conversion should succeed");

        source
            .iter()
            .zip(round_trip.iter())
            .for_each(|(before, after)| {
                assert_eq!(*before, *after, "Identity conversion must not change values");
            });
    }

    /// The converted array keeps the dimensions of the source.
    #[test]
    fn test_at_precision_preserves_shape() {
        let wrapped = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let as_f32: crate::ndarray::Array<f32, crate::ndarray::Ix2> = wrapped
            .at_precision()
            .expect("2D f64 → f32 conversion should succeed");

        assert_eq!(as_f32.shape(), &[2, 2]);
        for ((row, col), want) in [((0, 0), 1.0_f32), ((1, 1), 4.0_f32)] {
            assert!((as_f32[[row, col]] - want).abs() < 1e-6);
        }
    }

    /// `ops::matmul` at single precision yields the expected 2x2 product.
    #[test]
    fn test_execute_with_precision_matmul_single() {
        crate::array_protocol::init();
        let lhs = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let rhs = MixedPrecisionArray::new(arr2(&[[5.0_f64, 6.0], [7.0, 8.0]]));

        let product = ops::matmul(&lhs, &rhs, Precision::Single)
            .expect("mixed-precision matmul should succeed on stable Rust");
        let wrapper = product
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
            .expect("matmul result should be an NdarrayWrapper<f64, Ix2>");
        let out = wrapper.as_array();

        assert_eq!(out.shape(), &[2, 2]);
        let expected = [((0, 0), 19.0), ((0, 1), 22.0), ((1, 0), 43.0), ((1, 1), 50.0)];
        for ((row, col), want) in expected {
            assert!((out[[row, col]] - want).abs() < 1e-9);
        }
    }

    /// `ops::add` at single precision yields the expected element-wise sum.
    #[test]
    fn test_execute_with_precision_add_single() {
        crate::array_protocol::init();
        let lhs = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let rhs = MixedPrecisionArray::new(arr2(&[[10.0_f64, 20.0], [30.0, 40.0]]));

        let sum = ops::add(&lhs, &rhs, Precision::Single)
            .expect("mixed-precision add should succeed on stable Rust");
        let wrapper = sum
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
            .expect("add result should be an NdarrayWrapper<f64, Ix2>");
        let out = wrapper.as_array();

        assert_eq!(out.shape(), &[2, 2]);
        let expected = [((0, 0), 11.0), ((0, 1), 22.0), ((1, 0), 33.0), ((1, 1), 44.0)];
        for ((row, col), want) in expected {
            assert!((out[[row, col]] - want).abs() < 1e-9);
        }
    }

    /// Requesting half precision must fail rather than compute anything.
    #[test]
    fn test_execute_with_precision_half_is_rejected() {
        crate::array_protocol::init();
        let lhs = MixedPrecisionArray::new(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]));
        let rhs = MixedPrecisionArray::new(arr2(&[[5.0_f64, 6.0], [7.0, 8.0]]));

        let outcome = ops::matmul(&lhs, &rhs, Precision::Half);
        assert!(
            outcome.is_err(),
            "Half precision matmul must return an error"
        );
    }
}