use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::views::ArrayView;
use num_traits::{AsPrimitive, NumCast, Zero}; use scirs2_core::Complex;
use std::ops::{Add, Div, Mul, Sub};
/// Fallible conversion of a scalar value into the target type `T`.
///
/// Implementors return `Ok` with the converted value when `self` can be
/// represented as `T`, and an error otherwise.
pub trait ConvertibleTo<T>: Sized {
    /// Converts `self` into `T`.
    ///
    /// # Errors
    ///
    /// Returns an error when the value cannot be represented in `T`.
    fn convert_to(&self) -> Result<T>;
}

/// Blanket implementation covering every pair of `NumCast` types.
///
/// Delegates to [`NumCast::from`]. When that yields `None` (the value is not
/// representable in `T`), the failure is reported as a
/// [`NumRs2Error::TypeCastError`] whose message names both the source and
/// target type.
impl<S, T> ConvertibleTo<T> for S
where
    S: Clone + NumCast,
    T: Clone + NumCast,
{
    fn convert_to(&self) -> Result<T> {
        match <T as NumCast>::from(self.clone()) {
            Some(converted) => Ok(converted),
            None => {
                let source_name = std::any::type_name::<S>();
                let target_name = std::any::type_name::<T>();
                Err(NumRs2Error::TypeCastError(format!(
                    "Failed to convert from type {} to type {}",
                    source_name, target_name
                )))
            }
        }
    }
}
impl<T> Array<T> {
    /// Converts every element to `U` via [`ConvertibleTo`], keeping the shape of `self`.
    ///
    /// # Errors
    ///
    /// Returns the first conversion error encountered while walking the
    /// elements in `to_vec` order; no partially converted array is produced.
    pub fn astype<U>(&self) -> Result<Array<U>>
    where
        T: Clone + ConvertibleTo<U> + std::fmt::Debug,
        U: Clone,
    {
        let elements = self
            .to_vec()
            .into_iter()
            .map(|element| <T as ConvertibleTo<U>>::convert_to(&element))
            .collect::<Result<Vec<U>>>()?;
        Ok(Array::from_vec(elements).reshape(&self.shape()))
    }

    /// Casts every element to `U` with [`AsPrimitive::as_`], keeping the shape of `self`.
    ///
    /// This is an `as`-style primitive cast, so it never fails. Despite the
    /// name, it does not verify that `U` is wider than `T`. NOTE(review): lossy
    /// casts (for example float to small integer) are not detected here; use
    /// `downcast` for a checked conversion.
    pub fn upcast<U>(&self) -> Result<Array<U>>
    where
        T: Clone + AsPrimitive<U>,
        U: Clone + 'static + Copy,
    {
        let widened: Vec<U> = self
            .to_vec()
            .into_iter()
            .map(<T as AsPrimitive<U>>::as_)
            .collect();
        Ok(Array::from_vec(widened).reshape(&self.shape()))
    }

    /// Converts every element to `U` with a checked [`NumCast::from`], keeping the shape.
    ///
    /// # Errors
    ///
    /// Returns a `TypeCastError` quoting the `Debug` form of the first element
    /// that `NumCast::from` cannot represent in `U`.
    pub fn downcast<U>(&self) -> Result<Array<U>>
    where
        T: Clone + NumCast + std::fmt::Debug,
        U: Clone + NumCast + std::fmt::Debug,
    {
        let narrowed = self
            .to_vec()
            .into_iter()
            .map(|element| {
                // Clone because `element` is still needed for the error message.
                <U as NumCast>::from(element.clone()).ok_or_else(|| {
                    NumRs2Error::TypeCastError(format!(
                        "Value {:?} cannot be represented in target type",
                        element
                    ))
                })
            })
            .collect::<Result<Vec<U>>>()?;
        Ok(Array::from_vec(narrowed).reshape(&self.shape()))
    }

    /// Builds a complex array whose real parts are the elements cast to `U`.
    ///
    /// Each imaginary part is `U::zero()`. The shape of `self` is preserved.
    ///
    /// # Errors
    ///
    /// Returns a `TypeCastError` on the first element that `NumCast::from`
    /// cannot represent in `U`.
    pub fn to_complex<U>(&self) -> Result<Array<Complex<U>>>
    where
        T: Clone + NumCast,
        U: Clone + NumCast + Zero,
    {
        let complexes = self
            .to_vec()
            .into_iter()
            .map(|element| {
                <U as NumCast>::from(element)
                    .map(|re| Complex::new(re, U::zero()))
                    .ok_or_else(|| {
                        NumRs2Error::TypeCastError(
                            "Failed to convert real part to complex type".to_string(),
                        )
                    })
            })
            .collect::<Result<Vec<Complex<U>>>>()?;
        Ok(Array::from_vec(complexes).reshape(&self.shape()))
    }
}
impl<T> Array<T>
where
    T: Clone + NumCast,
{
    /// Adds two arrays of possibly different element types, producing `Array<V>`.
    ///
    /// Both operands are first converted to `V` with [`Array::astype`] (self
    /// first, then `other`). The results are then combined with
    /// `add_broadcast`, whose shape rules are defined elsewhere (presumably
    /// broadcasting; confirm there).
    ///
    /// # Errors
    ///
    /// Propagates the first conversion error, then any error from `add_broadcast`.
    pub fn add_mixed<V, U>(&self, other: &Array<U>) -> Result<Array<V>>
    where
        T: Clone + NumCast + std::fmt::Debug,
        U: Clone + NumCast + std::fmt::Debug,
        V: Clone + NumCast + Add<Output = V> + std::fmt::Debug,
    {
        let (lhs, rhs) = (self.astype::<V>()?, other.astype::<V>()?);
        lhs.add_broadcast(&rhs)
    }

    /// Subtracts `other` from `self` after converting both to `V`.
    ///
    /// Conversion uses [`Array::astype`] (self first, then `other`); the
    /// arithmetic is delegated to `subtract_broadcast`.
    ///
    /// # Errors
    ///
    /// Propagates the first conversion error, then any error from `subtract_broadcast`.
    pub fn subtract_mixed<V, U>(&self, other: &Array<U>) -> Result<Array<V>>
    where
        T: Clone + NumCast + std::fmt::Debug,
        U: Clone + NumCast + std::fmt::Debug,
        V: Clone + NumCast + Sub<Output = V> + std::fmt::Debug,
    {
        let (lhs, rhs) = (self.astype::<V>()?, other.astype::<V>()?);
        lhs.subtract_broadcast(&rhs)
    }

    /// Multiplies `self` by `other` after converting both to `V`.
    ///
    /// Conversion uses [`Array::astype`] (self first, then `other`); the
    /// arithmetic is delegated to `multiply_broadcast`.
    ///
    /// # Errors
    ///
    /// Propagates the first conversion error, then any error from `multiply_broadcast`.
    pub fn multiply_mixed<V, U>(&self, other: &Array<U>) -> Result<Array<V>>
    where
        T: Clone + NumCast + std::fmt::Debug,
        U: Clone + NumCast + std::fmt::Debug,
        V: Clone + NumCast + Mul<Output = V> + std::fmt::Debug,
    {
        let (lhs, rhs) = (self.astype::<V>()?, other.astype::<V>()?);
        lhs.multiply_broadcast(&rhs)
    }

    /// Divides `self` by `other` after converting both to `V`.
    ///
    /// Conversion uses [`Array::astype`] (self first, then `other`); the
    /// arithmetic, including any division-by-zero handling, is delegated to
    /// `divide_broadcast`.
    ///
    /// # Errors
    ///
    /// Propagates the first conversion error, then any error from `divide_broadcast`.
    pub fn divide_mixed<V, U>(&self, other: &Array<U>) -> Result<Array<V>>
    where
        T: Clone + NumCast + std::fmt::Debug,
        U: Clone + NumCast + std::fmt::Debug,
        V: Clone + NumCast + Div<Output = V> + std::fmt::Debug,
    {
        let (lhs, rhs) = (self.astype::<V>()?, other.astype::<V>()?);
        lhs.divide_broadcast(&rhs)
    }
}
impl<'a, T> ArrayView<'a, T>
where
    T: 'a + Clone + NumCast + std::fmt::Debug,
{
    /// Converts the viewed elements to `U`, returning a new owned array.
    ///
    /// The view is materialized with `to_owned` and the work is delegated to
    /// the owned array's `astype`.
    ///
    /// # Errors
    ///
    /// Propagates the first element conversion error.
    pub fn astype<U>(&self) -> Result<Array<U>>
    where
        U: Clone + NumCast,
    {
        let materialized = self.to_owned();
        materialized.astype::<U>()
    }

    /// Converts the viewed elements into complex numbers with zero imaginary part.
    ///
    /// The view is materialized with `to_owned` and the work is delegated to
    /// the owned array's `to_complex`.
    ///
    /// # Errors
    ///
    /// Propagates the first element conversion error.
    pub fn to_complex<U>(&self) -> Result<Array<Complex<U>>>
    where
        U: Clone + NumCast + Zero,
    {
        let materialized = self.to_owned();
        materialized.to_complex::<U>()
    }
}
/// Returns the `TypeId` of whichever of `T` and `U` has the higher precedence rank.
///
/// Ranks come from `type_precedence`; on a tie, `T` is chosen. The result is
/// always one of the two inputs. No wider common type is synthesized, so for
/// example `u32` and `i8` yield `u32` rather than a signed type able to hold both.
pub fn promote_types<T, U>() -> std::any::TypeId
where
    T: 'static,
    U: 'static,
{
    use std::any::TypeId;

    // Strictly-greater test on `U` keeps the original tie-breaking in favour of `T`.
    if type_precedence::<U>() > type_precedence::<T>() {
        TypeId::of::<U>()
    } else {
        TypeId::of::<T>()
    }
}
/// Ordinal rank used by `promote_types` to pick the "wider" of two types.
///
/// Higher ranks win. The order is `bool` (0) < `u8` < `i8` < `u16` < `i16` <
/// `u32` < `i32` < `u64` < `i64` < `f32` < `f64` < `Complex<f32>` <
/// `Complex<f64>` (12).
///
/// `usize`/`isize` share the rank of the fixed-width unsigned/signed integer
/// of the same size on the current target. Any other type (including
/// `u128`/`i128`) ranks `0`, the same as `bool`.
fn type_precedence<T: 'static>() -> u8 {
    use std::any::TypeId;

    // Without explicit entries, pointer-sized integers would fall through to
    // rank 0 and lose to every numeric type (e.g. `promote_types::<u8, usize>()`
    // would pick `u8`). Rank them like their same-width fixed-size counterparts.
    let (usize_rank, isize_rank) = match std::mem::size_of::<usize>() {
        2 => (3, 4),
        4 => (5, 6),
        _ => (7, 8),
    };

    let ranks: [(TypeId, u8); 15] = [
        (TypeId::of::<bool>(), 0),
        (TypeId::of::<u8>(), 1),
        (TypeId::of::<i8>(), 2),
        (TypeId::of::<u16>(), 3),
        (TypeId::of::<i16>(), 4),
        (TypeId::of::<u32>(), 5),
        (TypeId::of::<i32>(), 6),
        (TypeId::of::<u64>(), 7),
        (TypeId::of::<i64>(), 8),
        (TypeId::of::<f32>(), 9),
        (TypeId::of::<f64>(), 10),
        (TypeId::of::<Complex<f32>>(), 11),
        (TypeId::of::<Complex<f64>>(), 12),
        (TypeId::of::<usize>(), usize_rank),
        (TypeId::of::<isize>(), isize_rank),
    ];

    let id = TypeId::of::<T>();
    ranks
        .iter()
        .find(|(candidate, _)| *candidate == id)
        .map_or(0, |&(_, rank)| rank)
}