use std::ffi::c_int;
use crate::{
array::Array,
dtype::Dtype,
error::{Error, FfiNullHandlePayload, Result, check},
stream::default_stream,
};
/// RAII owner for a temporary scalar `mlx_array` created inside this module.
///
/// Dropping the guard passes the wrapped handle to `mlx_array_free`. The status code
/// returned by that call is discarded, since `Drop` has no way to report an error.
struct ScalarGuard(mlxrs_sys::mlx_array);

impl Drop for ScalarGuard {
    fn drop(&mut self) {
        let handle = self.0;
        // SAFETY: `ScalarGuard` is not `Clone`, so each guard releases its handle exactly
        // once, here. NOTE(review): this also runs when the handle's `ctx` is null (the
        // error path of `checked_scalar_f32`); that relies on `mlx_array_free` accepting a
        // null ctx. Confirm against the linked mlx-c version.
        let _ = unsafe { mlxrs_sys::mlx_array_free(handle) };
    }
}
fn checked_scalar_f32(value: f32) -> Result<ScalarGuard> {
crate::error::ensure_handler_installed();
let raw = unsafe { mlxrs_sys::mlx_array_new_float32(value) };
let guard = ScalarGuard(raw);
if raw.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(FfiNullHandlePayload::new(
"mlx_array_new_float32",
))),
);
}
Ok(guard)
}
/// Indices of the maximum entries of `a`, computed on the default stream.
///
/// With `axis = Some(ax)` this calls `mlx_argmax_axis` for that axis. With `None` it
/// calls `mlx_argmax`, which takes no axis (presumably reducing over every element;
/// confirm against the mlx-c docs). `keepdims` is forwarded unchanged.
///
/// # Errors
/// Returns the error surfaced by `check` when the mlx-c call reports failure.
pub fn argmax(a: &Array, axis: Option<i32>, keepdims: bool) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` is a fresh handle owned by `result`. `a` stays borrowed for the
    // whole call.
    let status = match axis {
        None => unsafe { mlxrs_sys::mlx_argmax(&mut result.0, a.0, keepdims, default_stream()) },
        Some(ax) => unsafe {
            mlxrs_sys::mlx_argmax_axis(
                &mut result.0,
                a.0,
                ax as c_int,
                keepdims,
                default_stream(),
            )
        },
    };
    check(status)?;
    Ok(result)
}
/// Indices of the minimum entries of `a`, computed on the default stream.
///
/// `Some(ax)` selects `mlx_argmin_axis` for that axis. `None` selects `mlx_argmin`, which
/// takes no axis (presumably reducing over every element; confirm against the mlx-c
/// docs). `keepdims` is passed through as-is.
///
/// # Errors
/// Returns the error surfaced by `check` when the mlx-c call reports failure.
pub fn argmin(a: &Array, axis: Option<i32>, keepdims: bool) -> Result<Array> {
    let mut indices = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `indices.0` is a fresh handle owned by `indices`. `a` is borrowed for the
    // duration of the call.
    let status = unsafe {
        if let Some(ax) = axis {
            mlxrs_sys::mlx_argmin_axis(
                &mut indices.0,
                a.0,
                ax as c_int,
                keepdims,
                default_stream(),
            )
        } else {
            mlxrs_sys::mlx_argmin(&mut indices.0, a.0, keepdims, default_stream())
        }
    };
    check(status)?;
    Ok(indices)
}
/// Cumulative sum of `a` along `axis`, computed by `mlx_cumsum` on the default stream.
///
/// `reverse` and `inclusive` are forwarded unchanged. Presumably they mean "scan from the
/// end of the axis" and "include the current element in each partial sum"; confirm
/// against the mlx-c docs.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_cumsum` reports failure.
pub fn cumsum(a: &Array, axis: i32, reverse: bool, inclusive: bool) -> Result<Array> {
    let mut sums = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `sums.0` is a fresh handle owned by `sums`. `a` is borrowed for the call.
    let status = unsafe {
        mlxrs_sys::mlx_cumsum(
            &mut sums.0,
            a.0,
            axis as c_int,
            reverse,
            inclusive,
            default_stream(),
        )
    };
    check(status)?;
    Ok(sums)
}
/// Cumulative product of `a` along `axis`, computed by `mlx_cumprod` on the default stream.
///
/// `reverse` and `inclusive` are passed straight through to mlx-c. Their exact meaning is
/// defined there (presumably scan direction and whether the current element is included).
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_cumprod` reports failure.
pub fn cumprod(a: &Array, axis: i32, reverse: bool, inclusive: bool) -> Result<Array> {
    let mut products = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `products.0` is a fresh handle owned by `products`. `a` outlives the call.
    let status = unsafe {
        mlxrs_sys::mlx_cumprod(
            &mut products.0,
            a.0,
            axis as c_int,
            reverse,
            inclusive,
            default_stream(),
        )
    };
    check(status)?;
    Ok(products)
}
/// Running maximum of `a` along `axis`, computed by `mlx_cummax` on the default stream.
///
/// `reverse` and `inclusive` are forwarded as-is. See the mlx-c docs for their precise
/// meaning (presumably scan direction and inclusion of the current element).
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_cummax` reports failure.
pub fn cummax(a: &Array, axis: i32, reverse: bool, inclusive: bool) -> Result<Array> {
    let mut running = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `running.0` is a fresh handle owned by `running`. `a` is borrowed throughout.
    let status = unsafe {
        mlxrs_sys::mlx_cummax(
            &mut running.0,
            a.0,
            axis as c_int,
            reverse,
            inclusive,
            default_stream(),
        )
    };
    check(status)?;
    Ok(running)
}
/// Running minimum of `a` along `axis`, computed by `mlx_cummin` on the default stream.
///
/// `reverse` and `inclusive` are forwarded as-is. See the mlx-c docs for their precise
/// meaning (presumably scan direction and inclusion of the current element).
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_cummin` reports failure.
pub fn cummin(a: &Array, axis: i32, reverse: bool, inclusive: bool) -> Result<Array> {
    let mut running = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `running.0` is a fresh handle owned by `running`. `a` is borrowed throughout.
    let status = unsafe {
        mlxrs_sys::mlx_cummin(
            &mut running.0,
            a.0,
            axis as c_int,
            reverse,
            inclusive,
            default_stream(),
        )
    };
    check(status)?;
    Ok(running)
}
/// Sorted copy of `a`, produced by `mlx_sort` on the default stream.
///
/// No axis is passed. MLX documents this variant as sorting the flattened array; confirm
/// for the linked mlx-c version.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_sort` reports failure.
pub fn sort(a: &Array) -> Result<Array> {
    let mut sorted = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `sorted.0` is a fresh handle owned by `sorted`. `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_sort(&mut sorted.0, a.0, default_stream()) };
    check(status)?;
    Ok(sorted)
}
/// Sorted copy of `a` along `axis`, produced by `mlx_sort_axis` on the default stream.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_sort_axis` reports failure.
pub fn sort_axis(a: &Array, axis: i32) -> Result<Array> {
    let mut sorted = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `sorted.0` is a fresh handle owned by `sorted`. `a` is borrowed for the call.
    let status =
        unsafe { mlxrs_sys::mlx_sort_axis(&mut sorted.0, a.0, axis as c_int, default_stream()) };
    check(status)?;
    Ok(sorted)
}
/// Indices that would sort `a`, produced by `mlx_argsort` on the default stream.
///
/// No axis is passed. MLX documents this variant as operating on the flattened array;
/// confirm for the linked mlx-c version.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_argsort` reports failure.
pub fn argsort(a: &Array) -> Result<Array> {
    let mut order = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `order.0` is a fresh handle owned by `order`. `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_argsort(&mut order.0, a.0, default_stream()) };
    check(status)?;
    Ok(order)
}
/// Indices that would sort `a` along `axis`, produced by `mlx_argsort_axis` on the
/// default stream.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_argsort_axis` reports failure.
pub fn argsort_axis(a: &Array, axis: i32) -> Result<Array> {
    let mut order = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `order.0` is a fresh handle owned by `order`. `a` is borrowed for the call.
    let status = unsafe {
        mlxrs_sys::mlx_argsort_axis(&mut order.0, a.0, axis as c_int, default_stream())
    };
    check(status)?;
    Ok(order)
}
/// Calls `mlx_topk` with `k` on `a` (no axis argument) on the default stream.
///
/// Per MLX, this yields the `k` largest values, presumably over the flattened array;
/// confirm for the linked mlx-c version.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_topk` reports failure.
pub fn topk(a: &Array, k: i32) -> Result<Array> {
    let mut top = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `top.0` is a fresh handle owned by `top`. `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_topk(&mut top.0, a.0, k as c_int, default_stream()) };
    check(status)?;
    Ok(top)
}
/// Calls `mlx_topk_axis` with `k` along `axis` of `a` on the default stream.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_topk_axis` reports failure.
pub fn topk_axis(a: &Array, k: i32, axis: i32) -> Result<Array> {
    let mut top = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `top.0` is a fresh handle owned by `top`. `a` is borrowed for the call.
    let status = unsafe {
        mlxrs_sys::mlx_topk_axis(
            &mut top.0,
            a.0,
            k as c_int,
            axis as c_int,
            default_stream(),
        )
    };
    check(status)?;
    Ok(top)
}
/// Calls `mlx_partition` on `a` around index `kth` (no axis argument) on the default
/// stream.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_partition` reports failure.
pub fn partition(a: &Array, kth: i32) -> Result<Array> {
    let mut parted = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `parted.0` is a fresh handle owned by `parted`. `a` is borrowed for the call.
    let status =
        unsafe { mlxrs_sys::mlx_partition(&mut parted.0, a.0, kth as c_int, default_stream()) };
    check(status)?;
    Ok(parted)
}
/// Calls `mlx_partition_axis` on `a` around index `kth` along `axis`, on the default
/// stream.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_partition_axis` reports failure.
pub fn partition_axis(a: &Array, kth: i32, axis: i32) -> Result<Array> {
    let mut parted = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let (kth, axis) = (kth as c_int, axis as c_int);
    // SAFETY: `parted.0` is a fresh handle owned by `parted`. `a` is borrowed for the call.
    let status =
        unsafe { mlxrs_sys::mlx_partition_axis(&mut parted.0, a.0, kth, axis, default_stream()) };
    check(status)?;
    Ok(parted)
}
/// Calls `mlx_argpartition` on `a` around index `kth` (no axis argument) on the default
/// stream. Per MLX, the result holds indices rather than values.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_argpartition` reports failure.
pub fn argpartition(a: &Array, kth: i32) -> Result<Array> {
    let mut indices = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let kth = kth as c_int;
    // SAFETY: `indices.0` is a fresh handle owned by `indices`. `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_argpartition(&mut indices.0, a.0, kth, default_stream()) };
    check(status)?;
    Ok(indices)
}
/// Calls `mlx_argpartition_axis` on `a` around index `kth` along `axis`, on the default
/// stream. Per MLX, the result holds indices rather than values.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_argpartition_axis` reports failure.
pub fn argpartition_axis(a: &Array, kth: i32, axis: i32) -> Result<Array> {
    let mut indices = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let (kth, axis) = (kth as c_int, axis as c_int);
    // SAFETY: `indices.0` is a fresh handle owned by `indices`. `a` is borrowed for the call.
    let status = unsafe {
        mlxrs_sys::mlx_argpartition_axis(&mut indices.0, a.0, kth, axis, default_stream())
    };
    check(status)?;
    Ok(indices)
}
/// Softmax of `a` along `axis`, computed by `mlx_softmax_axis` on the default stream.
///
/// `precise` is forwarded unchanged. NOTE(review): in MLX it presumably requests
/// higher-precision accumulation; confirm against the mlx-c docs.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_softmax_axis` reports failure.
pub fn softmax_axis(a: &Array, axis: i32, precise: bool) -> Result<Array> {
    let mut probs = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `probs.0` is a fresh handle owned by `probs`. `a` is borrowed for the call.
    let status = unsafe {
        mlxrs_sys::mlx_softmax_axis(&mut probs.0, a.0, axis as c_int, precise, default_stream())
    };
    check(status)?;
    Ok(probs)
}
/// Clips `a` to the bounds given by the arrays `a_min` and `a_max`, using `mlx_clip` on
/// the default stream.
///
/// Broadcasting and dtype promotion between `a` and the bounds follow MLX's rules.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_clip` reports failure.
pub fn clip(a: &Array, a_min: &Array, a_max: &Array) -> Result<Array> {
    let mut clipped = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `clipped.0` is a fresh handle owned by `clipped`. All three inputs are
    // borrowed for the call.
    let status =
        unsafe { mlxrs_sys::mlx_clip(&mut clipped.0, a.0, a_min.0, a_max.0, default_stream()) };
    check(status)?;
    Ok(clipped)
}
/// Clips `a` to the scalar bounds `min` and `max` via `mlx_clip` on the default stream.
///
/// Both bounds are materialised as temporary 0-d `float32` arrays. These are freed when
/// this function returns, on both success and error.
///
/// NOTE(review): because the bounds are `float32` arrays rather than weakly-typed
/// scalars, MLX's type promotion may upcast lower-precision or integer inputs. Confirm
/// that is the intended output dtype.
///
/// # Errors
/// Fails if either bound cannot be created (see `checked_scalar_f32`), or if `mlx_clip`
/// reports failure through `check`.
pub fn clip_with_scalar(a: &Array, min: f32, max: f32) -> Result<Array> {
    let lower = checked_scalar_f32(min)?;
    let upper = checked_scalar_f32(max)?;
    let mut clipped = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `clipped.0` is a fresh handle owned by `clipped`. `lower` and `upper` keep
    // their handles alive until after the call returns.
    let status =
        unsafe { mlxrs_sys::mlx_clip(&mut clipped.0, a.0, lower.0, upper.0, default_stream()) };
    check(status)?;
    Ok(clipped)
}
/// Calls `mlx_ones_like` on `a` on the default stream.
///
/// Per MLX, the result matches `a`'s shape and dtype and is filled with ones.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_ones_like` reports failure.
pub fn ones_like(a: &Array) -> Result<Array> {
    let mut ones = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `ones.0` is a fresh handle owned by `ones`. `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_ones_like(&mut ones.0, a.0, default_stream()) };
    check(status)?;
    Ok(ones)
}
/// Calls `mlx_zeros_like` on `a` on the default stream.
///
/// Per MLX, the result matches `a`'s shape and dtype and is filled with zeros.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_zeros_like` reports failure.
pub fn zeros_like(a: &Array) -> Result<Array> {
    let mut zeros = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `zeros.0` is a fresh handle owned by `zeros`. `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_zeros_like(&mut zeros.0, a.0, default_stream()) };
    check(status)?;
    Ok(zeros)
}
/// Calls `mlx_full_like` to build an array like `a`, filled with `value`.
///
/// The dtype passed to mlx-c is `a`'s own dtype (read via `a.dtype()`). The fill value
/// is supplied as a temporary 0-d `float32` array, which is freed when this function
/// returns.
///
/// # Errors
/// Fails if `a.dtype()` fails, if the scalar cannot be created (see `checked_scalar_f32`),
/// or if `mlx_full_like` reports failure through `check`.
pub fn full_like(a: &Array, value: f32) -> Result<Array> {
    let target_dtype = mlxrs_sys::mlx_dtype::from(a.dtype()?);
    let fill = checked_scalar_f32(value)?;
    let mut filled = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `filled.0` is a fresh handle owned by `filled`. `fill` outlives the call.
    let status = unsafe {
        mlxrs_sys::mlx_full_like(&mut filled.0, a.0, fill.0, target_dtype, default_stream())
    };
    check(status)?;
    Ok(filled)
}
/// Converts `a` to `dtype` with `mlx_astype` on the default stream.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_astype` reports failure.
pub fn astype(a: &Array, dtype: Dtype) -> Result<Array> {
    let mut converted = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let target = mlxrs_sys::mlx_dtype::from(dtype);
    // SAFETY: `converted.0` is a fresh handle owned by `converted`. `a` is borrowed for
    // the call.
    let status =
        unsafe { mlxrs_sys::mlx_astype(&mut converted.0, a.0, target, default_stream()) };
    check(status)?;
    Ok(converted)
}
/// Calls `mlx_view` on `a` with target `dtype` on the default stream.
///
/// NOTE(review): in MLX this presumably reinterprets the underlying bytes as `dtype`
/// rather than converting values (contrast `astype`); confirm against the mlx-c docs.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_view` reports failure.
pub fn view(a: &Array, dtype: Dtype) -> Result<Array> {
    let mut viewed = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let target = mlxrs_sys::mlx_dtype::from(dtype);
    // SAFETY: `viewed.0` is a fresh handle owned by `viewed`. `a` is borrowed for the call.
    let status = unsafe { mlxrs_sys::mlx_view(&mut viewed.0, a.0, target, default_stream()) };
    check(status)?;
    Ok(viewed)
}
/// Calls `mlx_stop_gradient` on `a` on the default stream.
///
/// Per MLX, the result carries the same values, but gradients are not propagated
/// through it.
///
/// # Errors
/// Returns the error surfaced by `check` when `mlx_stop_gradient` reports failure.
pub fn stop_gradient(a: &Array) -> Result<Array> {
    let mut detached = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `detached.0` is a fresh handle owned by `detached`. `a` is borrowed for the
    // call.
    let status = unsafe { mlxrs_sys::mlx_stop_gradient(&mut detached.0, a.0, default_stream()) };
    check(status)?;
    Ok(detached)
}