use std::cmp::Ordering;
use std::fmt;
use std::ops::{Add, Div, Mul, Rem, Sub};
use number_general as ng;
use safecast::CastFrom;
#[cfg(feature = "complex")]
pub use num_complex as complex;
pub use smallvec::smallvec as axes;
pub use smallvec::smallvec as coord;
pub use smallvec::smallvec as range;
pub use smallvec::smallvec as slice;
pub use smallvec::smallvec as shape;
pub use smallvec::smallvec as stackvec;
use smallvec::SmallVec;
pub use access::*;
pub use array::{
same_shape, MatrixDual, MatrixUnary, NDArray, NDArrayAbs, NDArrayBoolean, NDArrayBooleanScalar,
NDArrayCast, NDArrayCompare, NDArrayCompareScalar, NDArrayMath, NDArrayMathScalar,
NDArrayNumeric, NDArrayRead, NDArrayReduce, NDArrayReduceAll, NDArrayReduceBoolean,
NDArrayTransform, NDArrayTrig, NDArrayUnary, NDArrayUnaryBoolean, NDArrayWhere, NDArrayWrite,
};
#[cfg(feature = "complex")]
pub use array::{MatrixUnaryComplex, NDArrayComplex, NDArrayFourier};
pub use buffer::{Buffer, BufferConverter, BufferInstance, BufferMut};
pub use host::StackVec;
pub use platform::*;
mod access;
mod array;
mod buffer;
#[cfg(feature = "complex")]
pub mod fft;
pub mod host;
#[cfg(feature = "opencl")]
pub mod opencl;
pub mod ops;
mod platform;
/// Identity function: returns its argument unchanged.
///
/// Used by the `number!` and `real!` invocations wherever an operation is a no-op
/// (e.g. `abs` for unsigned integers and `round` for integers).
fn id<T>(value: T) -> T {
    value
}
/// Marker trait for the element types supported by this crate.
///
/// With the `opencl` feature enabled, implementors must additionally implement
/// `opencl::CLElement`.
#[cfg(feature = "opencl")]
pub trait CLType:
    opencl::CLElement + PartialEq + Copy + Send + Sync + fmt::Display + fmt::Debug + 'static
{
}
/// Marker trait for the element types supported by this crate (build without the
/// `opencl` feature, so no OpenCL-specific bound is required).
#[cfg(not(feature = "opencl"))]
pub trait CLType: PartialEq + Copy + Send + Sync + fmt::Display + fmt::Debug + 'static {}
impl CLType for f32 {}
impl CLType for f64 {}
impl CLType for i8 {}
impl CLType for i16 {}
impl CLType for i32 {}
impl CLType for i64 {}
impl CLType for u8 {}
impl CLType for u16 {}
impl CLType for u32 {}
impl CLType for u64 {}
#[cfg(feature = "complex")]
impl CLType for complex::Complex<f32> {}
#[cfg(feature = "complex")]
impl CLType for complex::Complex<f64> {}
/// A numeric element type supporting basic arithmetic.
///
/// The implementations in this crate are generated by the `number!` macro. For the
/// integer types, `add`, `sub` and `mul` wrap on overflow and `div` returns `0` when the
/// divisor is zero.
pub trait Number: CLType + Into<ng::Number> + CastFrom<ng::Number> + Default {
    /// The additive identity.
    const ZERO: Self;
    /// The multiplicative identity.
    const ONE: Self;
    /// The type returned by [`Number::abs`] (the component type for complex numbers).
    type Abs: Number;
    /// Absolute value (the magnitude, for complex numbers).
    fn abs(self) -> Self::Abs;
    /// Addition.
    fn add(self, other: Self) -> Self;
    /// Division.
    fn div(self, other: Self) -> Self;
    /// Multiplication.
    fn mul(self, other: Self) -> Self;
    /// Subtraction.
    fn sub(self, other: Self) -> Self;
    /// Raise `self` to the power `exp`.
    fn pow(self, exp: Self) -> Self;
}
/// Implements [`Number`] for a primitive type.
///
/// Arguments, in order: the implementing type, its `Abs` type, the values of `ONE` and
/// `ZERO`, then the callables used for `abs`, `add`, `div`, `mul`, `sub` and `pow`.
/// Each callable receives `self` (plus the second operand for binary operations) and its
/// result is returned as-is.
macro_rules! number {
    (
        $ty:ty,
        $abs_ty:ty,
        $one:expr,
        $zero:expr,
        $abs_fn:expr,
        $add_fn:expr,
        $div_fn:expr,
        $mul_fn:expr,
        $sub_fn:expr,
        $pow_fn:expr
    ) => {
        impl Number for $ty {
            const ZERO: Self = $zero;
            const ONE: Self = $one;

            type Abs = $abs_ty;

            fn abs(self) -> Self::Abs {
                ($abs_fn)(self)
            }

            fn add(self, other: Self) -> Self {
                ($add_fn)(self, other)
            }

            fn sub(self, other: Self) -> Self {
                ($sub_fn)(self, other)
            }

            fn mul(self, other: Self) -> Self {
                ($mul_fn)(self, other)
            }

            fn div(self, other: Self) -> Self {
                ($div_fn)(self, other)
            }

            fn pow(self, exp: Self) -> Self {
                ($pow_fn)(self, exp)
            }
        }
    };
}
// Complex numbers: `abs` returns the magnitude (`norm`) as the component float type,
// arithmetic uses the standard operator traits, and `pow` takes a complex exponent
// (`powc`).
#[cfg(feature = "complex")]
number!(
    complex::Complex32,
    f32,
    complex::Complex32::ONE,
    complex::Complex32::ZERO,
    complex::Complex32::norm,
    Add::add,
    Div::div,
    Mul::mul,
    Sub::sub,
    complex::Complex32::powc
);
#[cfg(feature = "complex")]
number!(
    complex::Complex64,
    f64,
    complex::Complex64::ONE,
    complex::Complex64::ZERO,
    complex::Complex64::norm,
    Add::add,
    Div::div,
    Mul::mul,
    Sub::sub,
    complex::Complex64::powc
);
// IEEE-754 floats: arithmetic goes through the standard operator traits, `abs` is the
// inherent absolute value, and `pow` takes a floating-point exponent.
number!(
    f32,
    f32,
    1.0,
    0.0,
    f32::abs,
    <f32 as Add>::add,
    <f32 as Div>::div,
    <f32 as Mul>::mul,
    <f32 as Sub>::sub,
    |base: f32, exp: f32| base.powf(exp)
);
number!(
    f64,
    f64,
    1.0,
    0.0,
    f64::abs,
    <f64 as Add>::add,
    <f64 as Div>::div,
    <f64 as Mul>::mul,
    <f64 as Sub>::sub,
    |base: f64, exp: f64| base.powf(exp)
);
// Signed integers: `abs`, `add`, `sub` and `mul` wrap on overflow and division by zero
// yields 0. `pow` saturates at the type's bounds on overflow; a negative exponent is
// evaluated in floating point and truncated toward zero when cast back (so the result is
// 0 unless the base is -1, 0 or 1).
number!(
    i8,
    Self,
    1,
    0,
    Self::wrapping_abs,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    |a, e| f32::powi(a as f32, e as i32) as i8
);
number!(
    i16,
    Self,
    1,
    0,
    Self::wrapping_abs,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    |a, e| f32::powi(a as f32, e as i32) as i16
);
number!(
    i32,
    Self,
    1,
    0,
    Self::wrapping_abs,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    |a: i32, e: i32| if e >= 0 {
        // `f32` has a 24-bit mantissa and would silently round large results (e.g. 7^11);
        // integer exponentiation is exact and saturates just like the float cast does.
        i32::saturating_pow(a, e as u32)
    } else {
        f32::powi(a as f32, e) as i32
    }
);
number!(
    i64,
    Self,
    1,
    0,
    Self::wrapping_abs,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    |a: i64, e: i64| if e >= 0 {
        // `f64` cannot represent every integer above 2^53, so use exact integer
        // exponentiation. Huge exponents are clamped while keeping their parity, which
        // only matters for a base of -1 (0 and 1 are fixed points, all others saturate).
        let e = u32::try_from(e).unwrap_or(if e % 2 == 0 { u32::MAX - 1 } else { u32::MAX });
        i64::saturating_pow(a, e)
    } else {
        f64::powi(a as f64, i32::try_from(e).unwrap_or(i32::MIN)) as i64
    }
);
// Unsigned integers: `abs` is the identity, `add`/`sub`/`mul`/`pow` wrap on overflow and
// division by zero yields 0. `pow` uses `wrapping_pow` so that, like the other
// operations, it cannot panic on overflow in debug builds (release behavior is unchanged).
number!(
    u8,
    Self,
    1,
    0,
    id,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    |a, e| u8::wrapping_pow(a, e as u32)
);
number!(
    u16,
    Self,
    1,
    0,
    id,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    |a, e| u16::wrapping_pow(a, e as u32)
);
number!(
    u32,
    Self,
    1,
    0,
    id,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    u32::wrapping_pow
);
number!(
    u64,
    Self,
    1,
    0,
    id,
    Self::wrapping_add,
    |l, r| if r == 0 { 0 } else { Self::wrapping_div(l, r) },
    Self::wrapping_mul,
    Self::wrapping_sub,
    // Exponents beyond `u32::MAX` are clamped (the primitive `pow` takes a `u32`).
    |a, e| u64::wrapping_pow(a, u32::try_from(e).unwrap_or(u32::MAX))
);
/// A totally-ordered [`Number`] with a remainder and rounding.
///
/// Implementations are generated by the `real!` macro.
#[cfg(not(feature = "opencl"))]
pub trait Real: Number + PartialOrd {
    /// The largest finite value of this type.
    const MAX: Self;
    /// The smallest (most negative) finite value of this type.
    const MIN: Self;
    /// Return the larger of `l` and `r`, or `l` when they compare equal.
    fn max(l: Self, r: Self) -> Self;
    /// Return the smaller of `l` and `r`, or `l` when they compare equal.
    fn min(l: Self, r: Self) -> Self;
    /// The remainder of dividing `self` by `other`.
    fn rem(self, other: Self) -> Self;
    /// Round to the nearest whole number (the identity for integers).
    fn round(self) -> Self;
}
/// A totally-ordered [`Number`] with a remainder and rounding; with the `opencl` feature
/// it must also implement `opencl::CLElementReal`.
#[cfg(feature = "opencl")]
pub trait Real: Number + PartialOrd + opencl::CLElementReal {
    /// The largest finite value of this type.
    const MAX: Self;
    /// The smallest (most negative) finite value of this type.
    const MIN: Self;
    /// Return the larger of `l` and `r`, or `l` when they compare equal.
    fn max(l: Self, r: Self) -> Self;
    /// Return the smaller of `l` and `r`, or `l` when they compare equal.
    fn min(l: Self, r: Self) -> Self;
    /// The remainder of dividing `self` by `other`.
    fn rem(self, other: Self) -> Self;
    /// Round to the nearest whole number (the identity for integers).
    fn round(self) -> Self;
}
/// Implements [`Real`] for a type.
///
/// `$rem_fn` computes the remainder, `$cmp_fn` maps `(&Self, &Self)` to an [`Ordering`]
/// used by `max`/`min`, and `$round_fn` rounds a value.
macro_rules! real {
    ($ty:ty, $rem_fn:expr, $cmp_fn:expr, $round_fn:expr) => {
        impl Real for $ty {
            const MIN: Self = <$ty>::MIN;
            const MAX: Self = <$ty>::MAX;

            // Both `max` and `min` keep the left operand when the two compare equal.
            fn max(l: Self, r: Self) -> $ty {
                if ($cmp_fn)(&l, &r) == Ordering::Less {
                    r
                } else {
                    l
                }
            }

            fn min(l: Self, r: Self) -> $ty {
                if ($cmp_fn)(&l, &r) == Ordering::Greater {
                    r
                } else {
                    l
                }
            }

            fn round(self) -> Self {
                ($round_fn)(self)
            }

            fn rem(self, other: Self) -> Self {
                ($rem_fn)(self, other)
            }
        }
    };
}
// Floats are ordered with `total_cmp` (a total order, so `max`/`min` are defined even for
// NaN and distinguish -0.0 from +0.0) and rounded with the inherent `round`. Integers use a
// wrapping remainder and are already whole, so `round` is the identity.
real!(f32, Rem::rem, f32::total_cmp, f32::round);
real!(f64, Rem::rem, f64::total_cmp, f64::round);
real!(i8, Self::wrapping_rem, Ord::cmp, id);
real!(i16, Self::wrapping_rem, Ord::cmp, id);
real!(i32, Self::wrapping_rem, Ord::cmp, id);
real!(i64, Self::wrapping_rem, Ord::cmp, id);
real!(u8, Self::wrapping_rem, Ord::cmp, id);
real!(u16, Self::wrapping_rem, Ord::cmp, id);
real!(u32, Self::wrapping_rem, Ord::cmp, id);
real!(u64, Self::wrapping_rem, Ord::cmp, id);
/// A [`Number`] supporting transcendental functions and NaN/infinity checks.
///
/// Implementations are generated by the `float_type!` macro.
#[cfg(not(feature = "opencl"))]
pub trait Float: Number {
    /// Return `true` if this value is infinite.
    fn is_inf(self) -> bool;
    /// Return `true` if this value is NaN.
    fn is_nan(self) -> bool;
    /// The exponential function `e^self`.
    fn exp(self) -> Self;
    /// The natural logarithm.
    fn ln(self) -> Self;
    /// The logarithm of `self` in the given `base`.
    fn log(self, base: Self) -> Self;
    /// Sine.
    fn sin(self) -> Self;
    /// Arcsine.
    fn asin(self) -> Self;
    /// Hyperbolic sine.
    fn sinh(self) -> Self;
    /// Cosine.
    fn cos(self) -> Self;
    /// Arccosine.
    fn acos(self) -> Self;
    /// Hyperbolic cosine.
    fn cosh(self) -> Self;
    /// Tangent.
    fn tan(self) -> Self;
    /// Arctangent.
    fn atan(self) -> Self;
    /// Hyperbolic tangent.
    fn tanh(self) -> Self;
}
/// A [`Number`] supporting transcendental functions and NaN/infinity checks; with the
/// `opencl` feature it must also implement `opencl::CLElementTrig`.
#[cfg(feature = "opencl")]
pub trait Float: Number + opencl::CLElementTrig {
    /// Return `true` if this value is infinite.
    fn is_inf(self) -> bool;
    /// Return `true` if this value is NaN.
    fn is_nan(self) -> bool;
    /// The exponential function `e^self`.
    fn exp(self) -> Self;
    /// The natural logarithm.
    fn ln(self) -> Self;
    /// The logarithm of `self` in the given `base`.
    fn log(self, base: Self) -> Self;
    /// Sine.
    fn sin(self) -> Self;
    /// Arcsine.
    fn asin(self) -> Self;
    /// Hyperbolic sine.
    fn sinh(self) -> Self;
    /// Cosine.
    fn cos(self) -> Self;
    /// Arccosine.
    fn acos(self) -> Self;
    /// Hyperbolic cosine.
    fn cosh(self) -> Self;
    /// Tangent.
    fn tan(self) -> Self;
    /// Arctangent.
    fn atan(self) -> Self;
    /// Hyperbolic tangent.
    fn tanh(self) -> Self;
}
/// Implements [`Float`] for a type whose inherent methods provide the transcendental
/// functions; `$is_inf` and `$is_nan` supply the classification predicates.
macro_rules! float_type {
    ($ty:ty, $is_inf:expr, $is_nan:expr) => {
        impl Float for $ty {
            fn is_inf(self) -> bool {
                ($is_inf)(self)
            }

            fn is_nan(self) -> bool {
                ($is_nan)(self)
            }

            fn exp(self) -> Self {
                <$ty>::exp(self)
            }

            fn ln(self) -> Self {
                <$ty>::ln(self)
            }

            // Change of base: log_b(x) = ln(x) / ln(b).
            fn log(self, base: Self) -> Self {
                let numerator = <$ty>::ln(self);
                let denominator = <$ty>::ln(base);
                numerator / denominator
            }

            fn sin(self) -> Self {
                <$ty>::sin(self)
            }

            fn cos(self) -> Self {
                <$ty>::cos(self)
            }

            fn tan(self) -> Self {
                <$ty>::tan(self)
            }

            fn asin(self) -> Self {
                <$ty>::asin(self)
            }

            fn acos(self) -> Self {
                <$ty>::acos(self)
            }

            fn atan(self) -> Self {
                <$ty>::atan(self)
            }

            fn sinh(self) -> Self {
                <$ty>::sinh(self)
            }

            fn cosh(self) -> Self {
                <$ty>::cosh(self)
            }

            fn tanh(self) -> Self {
                <$ty>::tanh(self)
            }
        }
    };
}
// Complex numbers delegate to num-complex's inherent `is_infinite`/`is_nan`, which
// inspect both the real and imaginary parts, so a value with a NaN or infinite component
// is reported as such.
#[cfg(feature = "complex")]
float_type!(
    complex::Complex32,
    complex::Complex32::is_infinite,
    complex::Complex32::is_nan
);
#[cfg(feature = "complex")]
float_type!(
    complex::Complex64,
    complex::Complex64::is_infinite,
    complex::Complex64::is_nan
);
float_type!(f32, f32::is_infinite, f32::is_nan);
float_type!(f64, f64::is_infinite, f64::is_nan);
/// A complex [`Float`] whose absolute value has its component type [`Complex::Real`].
#[cfg(all(feature = "complex", not(feature = "opencl")))]
pub trait Complex: Float<Abs = Self::Real> {
    /// The type of the real and imaginary components.
    type Real: Float + Real;
    /// The argument (phase angle) of this number.
    fn angle(self) -> Self::Real;
    /// The complex conjugate.
    fn conj(self) -> Self;
    /// The imaginary component.
    fn im(self) -> Self::Real;
    /// The real component.
    fn re(self) -> Self::Real;
}
/// A complex [`Float`] whose absolute value has its component type [`Complex::Real`];
/// with the `opencl` feature it must also implement `opencl::CLElementComplex`.
#[cfg(all(feature = "complex", feature = "opencl"))]
pub trait Complex: Float<Abs = Self::Real> + opencl::CLElementComplex {
    /// The type of the real and imaginary components.
    type Real: Float + Real;
    /// The argument (phase angle) of this number.
    fn angle(self) -> Self::Real;
    /// The complex conjugate.
    fn conj(self) -> Self;
    /// The imaginary component.
    fn im(self) -> Self::Real;
    /// The real component.
    fn re(self) -> Self::Real;
}
/// Implements [`Complex`] for `$t`, a `complex::Complex<$r>`, by delegating to its inherent
/// methods and public `re`/`im` fields.
#[cfg(feature = "complex")]
macro_rules! complex_type {
    ($t:ty, $r:ty) => {
        impl Complex for $t {
            type Real = $r;
            fn angle(self) -> $r {
                Self::arg(self)
            }
            fn conj(self) -> Self {
                // Called by path on purpose: `self.conj()` would resolve to this by-value
                // trait method (ahead of the inherent `&self` one) and recurse forever.
                complex::Complex::<$r>::conj(&self)
            }
            fn im(self) -> $r {
                self.im
            }
            fn re(self) -> $r {
                self.re
            }
        }
    };
}
// Single- and double-precision complex numbers.
#[cfg(feature = "complex")]
complex_type!(complex::Complex32, f32);
#[cfg(feature = "complex")]
complex_type!(complex::Complex64, f64);
/// The error type returned by this module's fallible operations (e.g. shape broadcasting).
pub enum Error {
    /// An index, range or shape is out of bounds or incompatible, such as two shapes that
    /// cannot be broadcast together.
    Bounds(String),
    /// The requested operation is not supported.
    Unsupported(String),
    /// An error reported by OpenCL, held in an `Arc` so that cloning an [`Error`] shares
    /// the underlying cause.
    #[cfg(feature = "opencl")]
    OCL(std::sync::Arc<ocl::Error>),
}
impl Error {
    /// Construct an [`Error::Bounds`] carrying `msg`.
    ///
    /// # Panics
    /// With the `debug_crash` feature enabled, panics with `msg` instead of returning.
    pub fn bounds(msg: String) -> Self {
        if cfg!(feature = "debug_crash") {
            panic!("{}", msg);
        }

        Self::Bounds(msg)
    }

    /// Construct an [`Error::Unsupported`] carrying `msg`.
    ///
    /// # Panics
    /// With the `debug_crash` feature enabled, panics with `msg` instead of returning.
    pub fn unsupported(msg: String) -> Self {
        if cfg!(feature = "debug_crash") {
            panic!("{}", msg);
        }

        Self::Unsupported(msg)
    }
}
impl Clone for Error {
    /// Clone the error; the OpenCL cause (if any) is shared through its `Arc`.
    fn clone(&self) -> Self {
        match *self {
            Self::Bounds(ref msg) => Self::Bounds(String::clone(msg)),
            Self::Unsupported(ref msg) => Self::Unsupported(String::clone(msg)),
            #[cfg(feature = "opencl")]
            Self::OCL(ref cause) => Self::OCL(std::sync::Arc::clone(cause)),
        }
    }
}
/// Wrap an OpenCL error as [`Error::OCL`].
#[cfg(feature = "opencl")]
impl From<ocl::Error> for Error {
    /// # Panics
    /// With the `debug_crash` feature enabled, panics with the cause's `Debug` output
    /// instead of returning.
    fn from(cause: ocl::Error) -> Self {
        #[cfg(feature = "debug_crash")]
        panic!("OpenCL error: {:?}", cause);
        #[cfg(not(feature = "debug_crash"))]
        Self::OCL(std::sync::Arc::new(cause))
    }
}
impl fmt::Debug for Error {
    /// Formats exactly like the [`fmt::Display`] implementation (the bare message, or the
    /// OpenCL cause).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
impl fmt::Display for Error {
    /// Writes the bare error message, or the underlying OpenCL cause.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Bounds(msg) | Self::Unsupported(msg) => f.write_str(msg),
            #[cfg(feature = "opencl")]
            Self::OCL(cause) => cause.fmt(f),
        }
    }
}
// Uses the default trait methods, so `source()` returns `None` for every variant.
impl std::error::Error for Error {}
/// A list of axes (presumably axis indices); stored inline for up to 8 entries.
pub type Axes = SmallVec<[usize; 8]>;
/// A selection with one [`AxisRange`] per axis; stored inline for up to 8 axes.
pub type Range = SmallVec<[AxisRange; 8]>;
/// The dimensions of an array; stored inline for up to 8 dimensions.
pub type Shape = SmallVec<[usize; 8]>;
/// The per-axis strides of an array; stored inline for up to 8 axes.
pub type Strides = SmallVec<[usize; 8]>;
/// `array::Array` on this crate's `Platform`, with element type `T` and access type `A`.
pub type Array<T, A> = array::Array<T, A, Platform>;
/// An array whose elements are accessed through a buffer of type `B`.
pub type ArrayBuf<T, B> = array::Array<T, AccessBuf<B>, Platform>;
/// An array whose elements are accessed through the operation `Op`.
pub type ArrayOp<T, Op> = array::Array<T, AccessOp<Op>, Platform>;
/// An array whose elements are accessed through an `Accessor` borrowed for `'a`.
pub type ArrayAccess<'a, T> = array::Array<T, Accessor<'a, T>, Platform>;
/// `access::AccessOp` on this crate's `Platform`.
pub type AccessOp<Op> = access::AccessOp<Op, Platform>;
/// A selection along a single axis.
#[derive(Clone, Eq, PartialEq, Hash)]
pub enum AxisRange {
    /// A single index; the axis does not appear in the resulting shape.
    At(usize),
    /// `In(start, stop, step)`: indices from `start` (inclusive) towards `stop`
    /// (exclusive), taking every `step`-th one.
    In(usize, usize, usize),
    /// An explicit list of indices.
    Of(SmallVec<[usize; 8]>),
}
impl AxisRange {
    /// Return `true` if this selects a single index ([`AxisRange::At`]).
    pub fn is_index(&self) -> bool {
        matches!(self, Self::At(_))
    }

    /// Return the number of elements this range selects along its axis, or `None` for a
    /// single index (which removes the axis instead of selecting a length-1 slice).
    ///
    /// For `In(start, stop, step)` this is the number of values yielded by
    /// `(start..stop).step_by(step)`, i.e. `ceil((stop - start) / step)`. An empty or
    /// reversed range (`stop <= start`) has size 0.
    ///
    /// # Panics
    /// Panics if `step` is zero.
    pub fn size(&self) -> Option<usize> {
        match self {
            Self::At(_) => None,
            // Round up: `0..5` with step 2 selects 0, 2 and 4, i.e. three elements.
            Self::In(start, stop, step) => Some(stop.saturating_sub(*start).div_ceil(*step)),
            Self::Of(indices) => Some(indices.len()),
        }
    }
}
impl From<usize> for AxisRange {
fn from(i: usize) -> Self {
Self::At(i)
}
}
impl From<std::ops::Range<usize>> for AxisRange {
fn from(range: std::ops::Range<usize>) -> Self {
Self::In(range.start, range.end, 1)
}
}
impl From<SmallVec<[usize; 8]>> for AxisRange {
fn from(indices: SmallVec<[usize; 8]>) -> Self {
Self::Of(indices)
}
}
impl fmt::Debug for AxisRange {
    /// Slice-like notation: `i`, `start:stop` (unit step), `start:stop:step`, or the
    /// index list.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::At(index) => write!(f, "{index}"),
            Self::In(start, stop, step) if *step == 1 => write!(f, "{start}:{stop}"),
            Self::In(start, stop, step) => write!(f, "{start}:{stop}:{step}"),
            Self::Of(indices) => write!(f, "{indices:?}"),
        }
    }
}
/// Compute the shape produced by broadcasting `left` against `right`.
///
/// The shapes are aligned at their trailing dimensions. A missing dimension, or one of
/// size 1, takes the other operand's size. The result has as many dimensions as the
/// longer input.
///
/// # Errors
/// Returns a bounds error if two aligned dimensions differ and neither is 1.
#[inline]
pub fn broadcast_shape(left: &[usize], right: &[usize]) -> Result<Shape, Error> {
    let mut lhs = left.iter().rev().copied();
    let mut rhs = right.iter().rev().copied();
    let mut shape = Shape::with_capacity(usize::max(left.len(), right.len()));

    // Dimensions are produced innermost-first, so collect them and then reverse.
    loop {
        match broadcast_dim(lhs.next(), rhs.next())? {
            Some(dim) => shape.push(dim),
            None => break,
        }
    }

    shape.reverse();
    Ok(shape)
}
#[inline]
pub fn broadcast_matmul_shape(left: &[usize], right: &[usize]) -> Result<(Shape, Shape), Error> {
let (left_ndim, right_ndim) = (left.len(), right.len());
let ndim = usize::max(left_ndim, right_ndim);
let mut left = left.iter().rev().copied();
let mut right = right.iter().rev().copied();
let k = right.next().unwrap_or(1);
let j = match (left.next(), right.next()) {
(Some(jl), Some(jr)) => match (jl, jr) {
(jl, jr) if jl == jr => Ok(jl),
(1, jr) => Ok(jr),
(jl, 1) => Ok(jl),
_ => Err(Error::bounds(format!(
"cannot matrix-multiply shapes {left:?} and {right:?}"
))),
},
(Some(jl), None) => Ok(jl),
(None, Some(jr)) => Ok(jr),
(None, None) => Ok(1),
}?;
let i = left.next().unwrap_or(1);
let mut broadcast_shape = Shape::with_capacity(ndim);
while let Some(dim) = broadcast_dim(left.next(), right.next())? {
broadcast_shape.push(dim);
}
broadcast_shape.reverse();
let left = broadcast_shape.iter().copied().chain([i, j]).collect();
let right = broadcast_shape.into_iter().chain([j, k]).collect();
Ok((left, right))
}
/// Broadcast one pair of aligned dimensions, where `None` means that shape has run out.
///
/// Returns `Ok(None)` once both shapes are exhausted. Otherwise it returns the broadcast
/// size: equal sizes are kept, and a missing dimension or a dimension of size 1 yields
/// the other one.
///
/// # Errors
/// Returns a bounds error if both dimensions are present, differ, and neither is 1.
#[inline]
fn broadcast_dim(left: Option<usize>, right: Option<usize>) -> Result<Option<usize>, Error> {
    match (left, right) {
        (None, None) => Ok(None),
        (Some(dim), None) | (None, Some(dim)) => Ok(Some(dim)),
        (Some(l), Some(r)) if l == r || r == 1 => Ok(Some(l)),
        (Some(1), Some(r)) => Ok(Some(r)),
        (l, r) => Err(Error::bounds(format!(
            "cannot broadcast dimensions {l:?} and {r:?}"
        ))),
    }
}
#[inline]
fn range_shape(source_shape: &[usize], range: &[AxisRange]) -> Shape {
debug_assert_eq!(source_shape.len(), range.len());
range.iter().filter_map(|ar| ar.size()).collect()
}
/// Compute row-major strides for `shape`, left-padded with zeros to `ndim` entries.
///
/// Each dimension's stride is the product of the dimensions after it. A dimension of
/// size 1 gets stride 0, so every index along it maps to the same offset. The leading
/// `ndim - shape.len()` strides are 0.
///
/// `ndim` must be at least `shape.len()` (checked only in debug builds).
#[inline]
pub fn strides_for<'a>(shape: &'a [usize], ndim: usize) -> impl Iterator<Item = usize> + 'a {
    debug_assert!(ndim >= shape.len());

    let padding = std::iter::repeat_n(0, ndim - shape.len());
    let trailing = (0..shape.len()).map(move |axis| match shape[axis] {
        1 => 0,
        _ => shape[axis + 1..].iter().product(),
    });

    padding.chain(trailing)
}