use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::dtype::promotion::{PromoteTo, Promoted};
use ferray_core::error::{FerrayError, FerrayResult};
use num_traits::Float;
use crate::helpers::binary_elementwise_op;
use crate::ops::arithmetic::WrappingArith;
/// Describes how an input element type is promoted for float-valued
/// operations such as [`exp_promote`] or [`hypot_promote`].
///
/// Each element is first converted into [`PromoteFloat::Compute`], the
/// operation runs on arrays of that type, and every computed value is then
/// converted into [`PromoteFloat::Out`], the element type of the returned array.
pub trait PromoteFloat: Element + Copy {
    /// Float type the operation is evaluated in.
    type Compute: Element + Copy + Float;

    /// Element type of the array handed back to the caller.
    type Out: Element + Copy;

    /// Converts one input element into the compute type.
    fn to_compute(self) -> Self::Compute;

    /// Converts one computed value into the output type.
    fn narrow(value: Self::Compute) -> Self::Out;
}
// Implements `PromoteFloat` for source types whose compute type and output
// type are the same float, so `narrow` just returns its argument unchanged.
// Accepts a comma-separated list of `source => float` pairs.
macro_rules! impl_promote_float_same {
    ($($src:ty => $dst:ty),+ $(,)?) => {
        $(
            impl PromoteFloat for $src {
                type Compute = $dst;
                type Out = $dst;

                #[inline]
                fn to_compute(self) -> Self::Compute {
                    <$src as PromoteTo<$dst>>::promote(self)
                }

                #[inline]
                fn narrow(c: Self::Compute) -> Self::Out {
                    c
                }
            }
        )+
    };
}

// 16-bit integers are computed and returned as f32; 32- and 64-bit integers
// are computed and returned as f64.
impl_promote_float_same!(
    i16 => f32,
    u16 => f32,
    i32 => f64,
    i64 => f64,
    u32 => f64,
    u64 => f64,
);
// Implements `PromoteFloat` for source types that compute in f32 but return
// `half::f16`; each computed value is converted with `half::f16::from_f32`.
// Only available with the `f16` feature: without it, the listed types do not
// implement `PromoteFloat` at all.
#[cfg(feature = "f16")]
macro_rules! impl_promote_float_f16 {
    ($($src:ty),+ $(,)?) => {
        $(
            impl PromoteFloat for $src {
                type Compute = f32;
                type Out = half::f16;

                #[inline]
                fn to_compute(self) -> Self::Compute {
                    <$src as PromoteTo<f32>>::promote(self)
                }

                #[inline]
                fn narrow(c: Self::Compute) -> Self::Out {
                    half::f16::from_f32(c)
                }
            }
        )+
    };
}

#[cfg(feature = "f16")]
impl_promote_float_f16!(bool, i8, u8);
#[inline]
fn unary_promote_float<T, D>(
input: &Array<T, D>,
op: impl Fn(&Array<T::Compute, D>) -> FerrayResult<Array<T::Compute, D>>,
) -> FerrayResult<Array<<T as PromoteFloat>::Out, D>>
where
T: PromoteFloat,
D: Dimension,
{
let compute: Array<T::Compute, D> = if let Some(slice) = input.as_slice() {
let data: Vec<T::Compute> = slice.iter().map(|&x| x.to_compute()).collect();
Array::from_vec(input.dim().clone(), data)?
} else {
let data: Vec<T::Compute> = input.iter().map(|&x| x.to_compute()).collect();
Array::from_vec(input.dim().clone(), data)?
};
let result = op(&compute)?;
let out: Array<<T as PromoteFloat>::Out, D> = if let Some(slice) = result.as_slice() {
let data: Vec<<T as PromoteFloat>::Out> = slice.iter().map(|&c| T::narrow(c)).collect();
Array::from_vec(result.dim().clone(), data)?
} else {
let data: Vec<<T as PromoteFloat>::Out> = result.iter().map(|&c| T::narrow(c)).collect();
Array::from_vec(result.dim().clone(), data)?
};
Ok(out)
}
// Generates a public `$fn_name(input)` that runs `$kernel` through
// `unary_promote_float`. Any attributes written before the name (including
// `///` doc comments) are forwarded onto the generated function.
macro_rules! unary_promote_fn {
    (
        $(#[$meta:meta])*
        $fn_name:ident,
        $kernel:path
    ) => {
        $(#[$meta])*
        pub fn $fn_name<T, D>(
            input: &Array<T, D>,
        ) -> FerrayResult<Array<<T as PromoteFloat>::Out, D>>
        where
            D: Dimension,
            T: PromoteFloat,
            T::Compute: crate::cr_math::CrMath,
        {
            unary_promote_float(input, |widened| $kernel(widened))
        }
    };
}
// Promoting unary wrappers. Each doc comment below is forwarded onto the
// generated public function.
unary_promote_fn!(
    /// Applies [`crate::exp`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    exp_promote,
    crate::exp
);
unary_promote_fn!(
    /// Applies [`crate::exp2`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    exp2_promote,
    crate::exp2
);
unary_promote_fn!(
    /// Applies [`crate::expm1`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    expm1_promote,
    crate::expm1
);
unary_promote_fn!(
    /// Applies [`crate::log`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    log_promote,
    crate::log
);
unary_promote_fn!(
    /// Applies [`crate::log2`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    log2_promote,
    crate::log2
);
unary_promote_fn!(
    /// Applies [`crate::log10`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    log10_promote,
    crate::log10
);
unary_promote_fn!(
    /// Applies [`crate::log1p`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    log1p_promote,
    crate::log1p
);
unary_promote_fn!(
    /// Applies [`crate::sqrt`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    sqrt_promote,
    crate::sqrt
);
unary_promote_fn!(
    /// Applies [`crate::cbrt`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    cbrt_promote,
    crate::cbrt
);
unary_promote_fn!(
    /// Applies [`crate::fabs`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    fabs_promote,
    crate::fabs
);
unary_promote_fn!(
    /// Applies [`crate::rint`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    rint_promote,
    crate::rint
);
unary_promote_fn!(
    /// Applies [`crate::sin`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    sin_promote,
    crate::sin
);
unary_promote_fn!(
    /// Applies [`crate::cos`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    cos_promote,
    crate::cos
);
unary_promote_fn!(
    /// Applies [`crate::tan`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    tan_promote,
    crate::tan
);
unary_promote_fn!(
    /// Applies [`crate::arcsin`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    arcsin_promote,
    crate::arcsin
);
unary_promote_fn!(
    /// Applies [`crate::arccos`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    arccos_promote,
    crate::arccos
);
unary_promote_fn!(
    /// Applies [`crate::arctan`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    arctan_promote,
    crate::arctan
);
unary_promote_fn!(
    /// Applies [`crate::sinh`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    sinh_promote,
    crate::sinh
);
unary_promote_fn!(
    /// Applies [`crate::cosh`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    cosh_promote,
    crate::cosh
);
unary_promote_fn!(
    /// Applies [`crate::tanh`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    tanh_promote,
    crate::tanh
);
unary_promote_fn!(
    /// Applies [`crate::arcsinh`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    arcsinh_promote,
    crate::arcsinh
);
unary_promote_fn!(
    /// Applies [`crate::arccosh`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    arccosh_promote,
    crate::arccosh
);
unary_promote_fn!(
    /// Applies [`crate::arctanh`] to `input` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    arctanh_promote,
    crate::arctanh
);
/// Widens both operands to `T::Compute`, evaluates `op` on them, and narrows
/// every computed value to `T::Out`, keeping the dimensions of `op`'s output.
///
/// No shape check is done here; any shape handling is left to `op`.
///
/// # Errors
///
/// Propagates errors from widening either operand, from `op`, and from
/// rebuilding the narrowed array with `Array::from_vec`.
#[inline]
fn binary_promote_float<T, D>(
    a: &Array<T, D>,
    b: &Array<T, D>,
    op: impl Fn(&Array<T::Compute, D>, &Array<T::Compute, D>) -> FerrayResult<Array<T::Compute, D>>,
) -> FerrayResult<Array<<T as PromoteFloat>::Out, D>>
where
    T: PromoteFloat,
    D: Dimension,
{
    let lhs = cast_to_compute(a)?;
    let rhs = cast_to_compute(b)?;
    let computed = op(&lhs, &rhs)?;
    let narrowed: Vec<<T as PromoteFloat>::Out> = match computed.as_slice() {
        Some(values) => values.iter().map(|&v| T::narrow(v)).collect(),
        None => computed.iter().map(|&v| T::narrow(v)).collect(),
    };
    Array::from_vec(computed.dim().clone(), narrowed)
}
/// Copies `input` into a newly allocated array of `T::Compute` with the same
/// dimensions, converting each element with [`PromoteFloat::to_compute`].
///
/// # Errors
///
/// Propagates any error from `Array::from_vec`.
#[inline]
fn cast_to_compute<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T::Compute, D>>
where
    T: PromoteFloat,
    D: Dimension,
{
    // Use the contiguous slice when available, otherwise the element iterator.
    let widened: Vec<T::Compute> = match input.as_slice() {
        Some(values) => values.iter().copied().map(T::to_compute).collect(),
        None => input.iter().copied().map(T::to_compute).collect(),
    };
    Array::from_vec(input.dim().clone(), widened)
}
// Generates a public `$fn_name(a, b)` that runs `$kernel` through
// `binary_promote_float`. Any attributes written before the name (including
// `///` doc comments) are forwarded onto the generated function.
macro_rules! binary_promote_fn {
    (
        $(#[$meta:meta])*
        $fn_name:ident,
        $kernel:path
    ) => {
        $(#[$meta])*
        pub fn $fn_name<T, D>(
            a: &Array<T, D>,
            b: &Array<T, D>,
        ) -> FerrayResult<Array<<T as PromoteFloat>::Out, D>>
        where
            D: Dimension,
            T: PromoteFloat,
            T::Compute: crate::cr_math::CrMath,
        {
            binary_promote_float(a, b, |lhs, rhs| $kernel(lhs, rhs))
        }
    };
}
// Promoting binary wrappers. Each doc comment below is forwarded onto the
// generated public function.
binary_promote_fn!(
    /// Applies [`crate::hypot`] to `a` and `b` widened to `T::Compute`; the
    /// result is narrowed to `T::Out`.
    hypot_promote,
    crate::hypot
);
binary_promote_fn!(
    /// Applies [`crate::arctan2`] to `a` and `b` widened to `T::Compute`;
    /// the result is narrowed to `T::Out`.
    arctan2_promote,
    crate::arctan2
);
binary_promote_fn!(
    /// Applies [`crate::logaddexp`] to `a` and `b` widened to `T::Compute`;
    /// the result is narrowed to `T::Out`.
    logaddexp_promote,
    crate::logaddexp
);
binary_promote_fn!(
    /// Applies [`crate::logaddexp2`] to `a` and `b` widened to `T::Compute`;
    /// the result is narrowed to `T::Out`.
    logaddexp2_promote,
    crate::logaddexp2
);
binary_promote_fn!(
    /// Applies [`crate::copysign`] to `a` and `b` widened to `T::Compute`;
    /// the result is narrowed to `T::Out`.
    copysign_promote,
    crate::copysign
);
binary_promote_fn!(
    // NOTE(review): for the f16-output impls (Compute = f32, Out = half::f16)
    // the step is taken in f32 and then rounded by `half::f16::from_f32`. An
    // f32 ulp is much smaller than an f16 ulp, so the narrowed result can be
    // the starting value rather than its f16 neighbour — confirm whether
    // nextafter needs to run in the output type instead.
    /// Applies [`crate::nextafter`] to `a` and `b` widened to `T::Compute`;
    /// the result is narrowed to `T::Out`.
    nextafter_promote,
    crate::nextafter
);
/// Converts every element of `a` to `Out` via [`PromoteTo`], keeping the
/// dimensions of `a`.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec`.
#[inline]
fn cast_array<A, Out, D>(a: &Array<A, D>) -> FerrayResult<Array<Out, D>>
where
    A: Element + Copy + PromoteTo<Out>,
    Out: Element + Copy,
    D: Dimension,
{
    if let Some(slice) = a.as_slice() {
        let data: Vec<Out> = slice.iter().map(|&x| x.promote()).collect();
        Array::from_vec(a.dim().clone(), data)
    } else {
        let data: Vec<Out> = a.iter().map(|&x| x.promote()).collect();
        Array::from_vec(a.dim().clone(), data)
    }
}

/// Checks that `a` and `b` have identical shapes and casts both to their
/// common promoted type `<A as Promoted<B>>::Output`.
///
/// `op_name` prefixes the shape-mismatch message so each public entry point
/// reports errors under its own name.
///
/// # Errors
///
/// Returns a shape-mismatch error when `a.shape() != b.shape()`; no
/// broadcasting is attempted. Also propagates errors from [`cast_array`].
fn promote_operands<A, B, D>(
    op_name: &str,
    a: &Array<A, D>,
    b: &Array<B, D>,
) -> FerrayResult<(
    Array<<A as Promoted<B>>::Output, D>,
    Array<<A as Promoted<B>>::Output, D>,
)>
where
    A: Element + Copy + Promoted<B> + PromoteTo<<A as Promoted<B>>::Output>,
    B: Element + Copy + PromoteTo<<A as Promoted<B>>::Output>,
    <A as Promoted<B>>::Output: Element + Copy,
    D: Dimension,
{
    // Reject mismatched shapes before paying for either cast.
    if a.shape() != b.shape() {
        return Err(FerrayError::shape_mismatch(format!(
            "{}: shapes {:?} and {:?} differ",
            op_name,
            a.shape(),
            b.shape()
        )));
    }
    let a_cast = cast_array::<A, <A as Promoted<B>>::Output, D>(a)?;
    let b_cast = cast_array::<B, <A as Promoted<B>>::Output, D>(b)?;
    Ok((a_cast, b_cast))
}

/// Element-wise `a + b` after promoting both operands to their common type,
/// combined with [`WrappingArith::wadd`].
///
/// # Errors
///
/// Returns a shape-mismatch error if the shapes differ, and propagates errors
/// from casting and from `binary_elementwise_op`.
pub fn add_promoted<A, B, D>(
    a: &Array<A, D>,
    b: &Array<B, D>,
) -> FerrayResult<Array<<A as Promoted<B>>::Output, D>>
where
    A: Element + Copy + Promoted<B> + PromoteTo<<A as Promoted<B>>::Output>,
    B: Element + Copy + PromoteTo<<A as Promoted<B>>::Output>,
    <A as Promoted<B>>::Output: Element + Copy + WrappingArith,
    D: Dimension,
{
    let (x, y) = promote_operands("add_promoted", a, b)?;
    binary_elementwise_op(&x, &y, WrappingArith::wadd)
}

/// Element-wise `a - b` after promoting both operands to their common type,
/// combined with [`WrappingArith::wsub`].
///
/// # Errors
///
/// Returns a shape-mismatch error if the shapes differ, and propagates errors
/// from casting and from `binary_elementwise_op`.
pub fn subtract_promoted<A, B, D>(
    a: &Array<A, D>,
    b: &Array<B, D>,
) -> FerrayResult<Array<<A as Promoted<B>>::Output, D>>
where
    A: Element + Copy + Promoted<B> + PromoteTo<<A as Promoted<B>>::Output>,
    B: Element + Copy + PromoteTo<<A as Promoted<B>>::Output>,
    <A as Promoted<B>>::Output: Element + Copy + WrappingArith,
    D: Dimension,
{
    let (x, y) = promote_operands("subtract_promoted", a, b)?;
    binary_elementwise_op(&x, &y, WrappingArith::wsub)
}

/// Element-wise `a * b` after promoting both operands to their common type,
/// combined with [`WrappingArith::wmul`].
///
/// # Errors
///
/// Returns a shape-mismatch error if the shapes differ, and propagates errors
/// from casting and from `binary_elementwise_op`.
pub fn multiply_promoted<A, B, D>(
    a: &Array<A, D>,
    b: &Array<B, D>,
) -> FerrayResult<Array<<A as Promoted<B>>::Output, D>>
where
    A: Element + Copy + Promoted<B> + PromoteTo<<A as Promoted<B>>::Output>,
    B: Element + Copy + PromoteTo<<A as Promoted<B>>::Output>,
    <A as Promoted<B>>::Output: Element + Copy + WrappingArith,
    D: Dimension,
{
    let (x, y) = promote_operands("multiply_promoted", a, b)?;
    binary_elementwise_op(&x, &y, WrappingArith::wmul)
}

/// Element-wise `a / b` after promoting both operands to their common type.
///
/// The promoted type must be a float (`Output: Float`), so this is always true
/// division; integer pairs that promote to an integer type are rejected at
/// compile time.
///
/// # Errors
///
/// Returns a shape-mismatch error if the shapes differ, and propagates errors
/// from casting and from `binary_elementwise_op`.
pub fn divide_promoted<A, B, D>(
    a: &Array<A, D>,
    b: &Array<B, D>,
) -> FerrayResult<Array<<A as Promoted<B>>::Output, D>>
where
    A: Element + Copy + Promoted<B> + PromoteTo<<A as Promoted<B>>::Output>,
    B: Element + Copy + PromoteTo<<A as Promoted<B>>::Output>,
    <A as Promoted<B>>::Output: Element + Copy + Float,
    D: Dimension,
{
    let (x, y) = promote_operands("divide_promoted", a, b)?;
    binary_elementwise_op(&x, &y, |x, y| x / y)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    // ---- mixed-type arithmetic (`*_promoted`) ----

    #[test]
    fn add_i8_i64_promotes_to_i64() {
        let a = Array::<i8, Ix1>::from_vec(Ix1::new([2]), vec![1i8, 2]).unwrap();
        let b = Array::<i64, Ix1>::from_vec(Ix1::new([2]), vec![10i64, 20]).unwrap();
        let c = add_promoted(&a, &b).unwrap();
        let slice: &[i64] = c.as_slice().unwrap();
        assert_eq!(slice, &[11, 22]);
    }

    #[test]
    fn multiply_i32_i64_promotes_to_i64() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![3i32, 4]).unwrap();
        let b = Array::<i64, Ix1>::from_vec(Ix1::new([2]), vec![10i64, 20]).unwrap();
        let c = multiply_promoted(&a, &b).unwrap();
        let slice: &[i64] = c.as_slice().unwrap();
        assert_eq!(slice, &[30, 80]);
    }

    #[test]
    fn add_i32_f64_promotes_to_f64() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1i32, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.5, 1.5, 2.5]).unwrap();
        let c = add_promoted(&a, &b).unwrap();
        let slice: &[f64] = c.as_slice().unwrap();
        assert_eq!(slice, &[1.5, 3.5, 5.5]);
    }

    #[test]
    fn add_f32_f64_promotes_to_f64() {
        let a = Array::<f32, Ix1>::from_vec(Ix1::new([3]), vec![1.0f32, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.5, 0.5, 0.5]).unwrap();
        let c = add_promoted(&a, &b).unwrap();
        let slice: &[f64] = c.as_slice().unwrap();
        assert_eq!(slice, &[1.5, 2.5, 3.5]);
    }

    #[test]
    fn subtract_i16_f32_promotes_to_f32() {
        let a = Array::<i16, Ix1>::from_vec(Ix1::new([3]), vec![10i16, 20, 30]).unwrap();
        let b = Array::<f32, Ix1>::from_vec(Ix1::new([3]), vec![1.5f32, 2.5, 3.5]).unwrap();
        let c = subtract_promoted(&a, &b).unwrap();
        let slice: &[f32] = c.as_slice().unwrap();
        assert_eq!(slice, &[8.5, 17.5, 26.5]);
    }

    #[test]
    fn multiply_u8_f64_promotes_to_f64() {
        let a = Array::<u8, Ix1>::from_vec(Ix1::new([3]), vec![2u8, 3, 4]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.5, 0.5, 0.5]).unwrap();
        let c = multiply_promoted(&a, &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1.0, 1.5, 2.0]);
    }

    #[test]
    fn divide_f32_f64_promotes_to_f64() {
        let a = Array::<f32, Ix1>::from_vec(Ix1::new([3]), vec![10.0f32, 20.0, 30.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0, 4.0, 5.0]).unwrap();
        let c = divide_promoted(&a, &b).unwrap();
        let slice: &[f64] = c.as_slice().unwrap();
        assert_eq!(slice, &[5.0, 5.0, 6.0]);
    }

    #[test]
    fn same_type_path_is_identity() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![4.0, 5.0, 6.0]).unwrap();
        let c = add_promoted(&a, &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn promoted_2d_shape_preserved() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1i32, 2, 3, 4, 5, 6]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
            .unwrap();
        let c = add_promoted(&a, &b).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        assert_eq!(c.as_slice().unwrap(), &[1.5, 2.5, 3.5, 4.5, 5.5, 6.5]);
    }

    #[test]
    fn shape_mismatch_errors() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1i32, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(add_promoted(&a, &b).is_err());
    }

    #[test]
    fn shape_mismatch_errors_for_every_promoted_op() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1i32, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(subtract_promoted(&a, &b).is_err());
        assert!(multiply_promoted(&a, &b).is_err());
        assert!(divide_promoted(&a, &b).is_err());
    }

    // ---- integer -> float promotion (`*_promote`) ----

    #[test]
    fn sqrt_promote_i32_to_f64() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![4i32, 9, 16]).unwrap();
        let r = sqrt_promote(&a).unwrap();
        let slice: &[f64] = r.as_slice().unwrap();
        assert_eq!(slice, &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn sqrt_promote_u16_to_f32() {
        let a = Array::<u16, Ix1>::from_vec(Ix1::new([2]), vec![25u16, 36]).unwrap();
        let r = sqrt_promote(&a).unwrap();
        let slice: &[f32] = r.as_slice().unwrap();
        assert_eq!(slice, &[5.0, 6.0]);
    }

    #[test]
    fn sqrt_promote_2d_shape_preserved() {
        let a = Array::<i64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1i64, 4, 9, 16]).unwrap();
        let r = sqrt_promote(&a).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        let slice: &[f64] = r.as_slice().unwrap();
        assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn hypot_promote_i64_to_f64() {
        let a = Array::<i64, Ix1>::from_vec(Ix1::new([2]), vec![3i64, 5]).unwrap();
        let b = Array::<i64, Ix1>::from_vec(Ix1::new([2]), vec![4i64, 12]).unwrap();
        let r = hypot_promote(&a, &b).unwrap();
        let slice: &[f64] = r.as_slice().unwrap();
        assert_eq!(slice, &[5.0, 13.0]);
    }
}