use std::ffi::c_int;
use crate::{
array::Array,
error::{Error, LengthMismatchPayload, Result, check},
ffi::VectorArrayGuard,
stream::default_stream,
};
/// Converts an `Option<f32>` into MLX's C-side `mlx_optional_float`.
///
/// `Some(x)` becomes `{ value: x, has_value: true }`; `None` becomes
/// `{ value: 0.0, has_value: false }` (the `0.0` is only a placeholder).
#[inline(always)]
fn opt_float(v: Option<f32>) -> mlxrs_sys::mlx_optional_float {
    match v {
        Some(value) => mlxrs_sys::mlx_optional_float { value, has_value: true },
        None => mlxrs_sys::mlx_optional_float { value: 0.0, has_value: false },
    }
}
/// Element-wise sum `a + b`, computed by `mlx_add` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_add` reports a failure.
pub fn add(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_add(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise difference `a - b`, computed by `mlx_subtract` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_subtract` reports a failure.
pub fn subtract(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_subtract(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise product `a * b`, computed by `mlx_multiply` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_multiply` reports a failure.
pub fn multiply(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_multiply(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise quotient `a / b`, computed by `mlx_divide` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_divide` reports a failure.
pub fn divide(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_divide(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise maximum of `a` and `b`, computed by `mlx_maximum` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_maximum` reports a failure.
pub fn maximum(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_maximum(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise minimum of `a` and `b`, computed by `mlx_minimum` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_minimum` reports a failure.
pub fn minimum(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_minimum(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise `a` raised to the power `b`, computed by `mlx_power` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_power` reports a failure.
pub fn power(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_power(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise negation `-a`, computed by `mlx_negative` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_negative` reports a failure.
pub fn negative(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_negative(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise absolute value of `a`, computed by `mlx_abs` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_abs` reports a failure.
pub fn abs(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_abs(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise square root of `a`, computed by `mlx_sqrt` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_sqrt` reports a failure.
pub fn sqrt(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_sqrt(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise square of `a`, computed by `mlx_square` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_square` reports a failure.
pub fn square(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_square(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise exponential of `a`, computed by `mlx_exp` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_exp` reports a failure.
pub fn exp(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_exp(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise natural logarithm of `a`, computed by `mlx_log` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_log` reports a failure.
pub fn log(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_log(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise sine of `a`, computed by `mlx_sin` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_sin` reports a failure.
pub fn sin(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_sin(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise cosine of `a`, computed by `mlx_cos` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_cos` reports a failure.
pub fn cos(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_cos(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise tangent of `a`, computed by `mlx_tan` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_tan` reports a failure.
pub fn tan(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_tan(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise hyperbolic tangent of `a`, computed by `mlx_tanh` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_tanh` reports a failure.
pub fn tanh(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_tanh(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise base-10 logarithm of `a`, computed by `mlx_log10` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_log10` reports a failure.
pub fn log10(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_log10(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise base-2 logarithm of `a`, computed by `mlx_log2` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_log2` reports a failure.
pub fn log2(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_log2(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise `ln(1 + a)`, computed by `mlx_log1p` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_log1p` reports a failure.
pub fn log1p(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_log1p(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise `exp(a) - 1`, computed by `mlx_expm1` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_expm1` reports a failure.
pub fn expm1(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_expm1(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise error function of `a`, computed by `mlx_erf` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_erf` reports a failure.
pub fn erf(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_erf(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise inverse error function of `a`, computed by `mlx_erfinv` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_erfinv` reports a failure.
pub fn erfinv(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_erfinv(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise logistic sigmoid of `a`, computed by `mlx_sigmoid` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_sigmoid` reports a failure.
pub fn sigmoid(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_sigmoid(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise ceiling of `a`, computed by `mlx_ceil` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_ceil` reports a failure.
pub fn ceil(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_ceil(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise floor of `a`, computed by `mlx_floor` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_floor` reports a failure.
pub fn floor(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_floor(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Rounds `a` to `decimals` decimal places via `mlx_round` on the default stream.
///
/// `decimals` is forwarded to the C API as a `c_int`.
///
/// # Errors
/// Propagates the error from `check` when `mlx_round` reports a failure.
pub fn round(a: &Array, decimals: i32) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let places = decimals as c_int;
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_round(&mut result.0, a.0, places, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise sign of `a`, computed by `mlx_sign` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_sign` reports a failure.
pub fn sign(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_sign(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise reciprocal `1 / a`, computed by `mlx_reciprocal` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_reciprocal` reports a failure.
pub fn reciprocal(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_reciprocal(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise reciprocal square root `1 / sqrt(a)`, computed by `mlx_rsqrt` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_rsqrt` reports a failure.
pub fn rsqrt(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_rsqrt(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise complex conjugate of `a`, computed by `mlx_conjugate` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_conjugate` reports a failure.
pub fn conjugate(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_conjugate(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Real part of `a`, extracted by `mlx_real` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_real` reports a failure.
pub fn real(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_real(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Imaginary part of `a`, extracted by `mlx_imag` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_imag` reports a failure.
pub fn imag(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_imag(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Converts `a` from radians to degrees via `mlx_degrees` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_degrees` reports a failure.
pub fn degrees(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_degrees(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Converts `a` from degrees to radians via `mlx_radians` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_radians` reports a failure.
pub fn radians(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_radians(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise hyperbolic sine of `a`, computed by `mlx_sinh` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_sinh` reports a failure.
pub fn sinh(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_sinh(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise hyperbolic cosine of `a`, computed by `mlx_cosh` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_cosh` reports a failure.
pub fn cosh(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_cosh(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise inverse sine of `a`, computed by `mlx_arcsin` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_arcsin` reports a failure.
pub fn arcsin(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_arcsin(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise inverse cosine of `a`, computed by `mlx_arccos` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_arccos` reports a failure.
pub fn arccos(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_arccos(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise inverse tangent of `a`, computed by `mlx_arctan` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_arctan` reports a failure.
pub fn arctan(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_arctan(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise inverse hyperbolic sine of `a`, computed by `mlx_arcsinh` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_arcsinh` reports a failure.
pub fn arcsinh(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_arcsinh(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise inverse hyperbolic cosine of `a`, computed by `mlx_arccosh` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_arccosh` reports a failure.
pub fn arccosh(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_arccosh(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise inverse hyperbolic tangent of `a`, computed by `mlx_arctanh` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_arctanh` reports a failure.
pub fn arctanh(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_arctanh(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Replaces NaN and infinite entries of `a` via `mlx_nan_to_num` on the default stream.
///
/// * `nan` – value substituted for NaN entries.
/// * `posinf` / `neginf` – replacements for positive / negative infinity.
///   `None` is sent to MLX as an unset optional, so MLX chooses the replacement.
///
/// # Errors
/// Propagates the error from `check` when `mlx_nan_to_num` reports a failure.
pub fn nan_to_num(a: &Array, nan: f32, posinf: Option<f32>, neginf: Option<f32>) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let pos_replacement = opt_float(posinf);
    let neg_replacement = opt_float(neginf);
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe {
        mlxrs_sys::mlx_nan_to_num(
            &mut result.0,
            a.0,
            nan,
            pos_replacement,
            neg_replacement,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Element-wise bitwise NOT of `a`, computed by `mlx_bitwise_invert` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_bitwise_invert` reports a failure.
pub fn bitwise_invert(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_bitwise_invert(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise two-argument arctangent of `a` and `b`, computed by `mlx_arctan2` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_arctan2` reports a failure.
pub fn arctan2(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_arctan2(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise floor division of `a` by `b`, computed by `mlx_floor_divide` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_floor_divide` reports a failure.
pub fn floor_divide(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status =
        unsafe { mlxrs_sys::mlx_floor_divide(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise remainder of `a` divided by `b`, computed by `mlx_remainder` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_remainder` reports a failure.
pub fn remainder(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_remainder(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Computes quotient and remainder of `a` by `b` with a single `mlx_divmod` call on the
/// default stream.
///
/// Returns `(quotient, remainder)`, read from positions 0 and 1 of the vector that
/// `mlx_divmod` fills.
///
/// # Errors
/// - Propagates the error from `check` when `mlx_divmod` or `mlx_vector_array_get` fails.
/// - Returns `Error::LengthMismatch` if MLX does not produce exactly two arrays.
pub fn divmod(a: &Array, b: &Array) -> Result<(Array, Array)> {
    crate::error::ensure_handler_installed();
    // `mlx_divmod` writes directly into the guard's slot, so the handle the guard later
    // releases (assumed: it frees `.0` on drop) is always the one MLX last wrote. Previously
    // the guard held a copy taken *before* the call. Any handle MLX swapped into the separate
    // local, for example when the initial handle had a null context, was never freed.
    let mut vec_guard = VectorArrayGuard(unsafe { mlxrs_sys::mlx_vector_array_new() });
    check(unsafe { mlxrs_sys::mlx_divmod(&mut vec_guard.0, a.0, b.0, default_stream()) })?;
    let n = unsafe { mlxrs_sys::mlx_vector_array_size(vec_guard.0) };
    if n != 2 {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            "divmod: output count from mlx_divmod",
            2,
            n,
        )));
    }
    let mut quot = Array(unsafe { mlxrs_sys::mlx_array_new() });
    check(unsafe { mlxrs_sys::mlx_vector_array_get(&mut quot.0, vec_guard.0, 0) })?;
    let mut rem = Array(unsafe { mlxrs_sys::mlx_array_new() });
    check(unsafe { mlxrs_sys::mlx_vector_array_get(&mut rem.0, vec_guard.0, 1) })?;
    Ok((quot, rem))
}
/// Element-wise bitwise AND of `a` and `b`, computed by `mlx_bitwise_and` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_bitwise_and` reports a failure.
pub fn bitwise_and(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_bitwise_and(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise bitwise OR of `a` and `b`, computed by `mlx_bitwise_or` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_bitwise_or` reports a failure.
pub fn bitwise_or(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_bitwise_or(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise bitwise XOR of `a` and `b`, computed by `mlx_bitwise_xor` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_bitwise_xor` reports a failure.
pub fn bitwise_xor(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_bitwise_xor(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise left shift of `a` by `b`, computed by `mlx_left_shift` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_left_shift` reports a failure.
pub fn left_shift(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_left_shift(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Element-wise right shift of `a` by `b`, computed by `mlx_right_shift` on the default stream.
///
/// # Errors
/// Propagates the error from `check` when `mlx_right_shift` reports a failure.
pub fn right_shift(a: &Array, b: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh out-slot and the input handles are borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_right_shift(&mut result.0, a.0, b.0, default_stream()) };
    check(status)?;
    Ok(result)
}