use std::ffi::c_int;
use crate::{
array::Array,
dtype::Dtype,
error::{
CapExceededPayload, EmptyInputPayload, Error, InvariantViolationPayload, LengthMismatchPayload,
MultiLengthMismatchPayload, OutOfRangePayload, Result, check,
},
ffi::VectorArrayGuard,
shape::dim_ptr,
stream::default_stream,
};
pub fn slice(a: &Array, start: &[i32], stop: &[i32], strides: &[i32]) -> Result<Array> {
if start.len() != stop.len() || start.len() != strides.len() {
return Err(Error::MultiLengthMismatch(MultiLengthMismatchPayload::new(
"slice: start/stop/strides",
vec![
("start", start.len()),
("stop", stop.len()),
("strides", strides.len()),
],
)));
}
if start.len() != a.ndim() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"slice: start/stop/strides length",
a.ndim(),
start.len(),
)));
}
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe {
mlxrs_sys::mlx_slice(
&mut out.0,
a.0,
dim_ptr(start),
start.len(),
dim_ptr(stop),
stop.len(),
dim_ptr(strides),
strides.len(),
default_stream(),
)
})?;
Ok(out)
}
/// Pick elements of `a` at `indices` with no axis given (mlx `take`; per the mlx
/// docs, `a` is treated as flattened in this form).
///
/// # Errors
///
/// Returns `InvariantViolation` if `indices` has `Bool` dtype, or any error
/// reported by `mlx_take`.
pub fn take(a: &Array, indices: &Array) -> Result<Array> {
    reject_bool_index("take: index dtype", indices)?;
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `a` and `indices` are borrowed for the call, so their handles stay
    // live; `result.0` is a fresh handle the callee writes into.
    let status =
        unsafe { mlxrs_sys::mlx_take(&mut result.0, a.0, indices.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Pick slices of `a` along `axis` at the positions in `indices` (mlx `take` with an axis).
///
/// # Errors
///
/// Returns `InvariantViolation` if `indices` has `Bool` dtype, or any error
/// reported by `mlx_take_axis`.
pub fn take_axis(a: &Array, indices: &Array, axis: i32) -> Result<Array> {
    reject_bool_index("take_axis: index dtype", indices)?;
    let axis_c = axis as c_int;
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: both inputs are borrowed for the duration of the call;
    // `result.0` is a fresh handle the callee writes into.
    let status = unsafe {
        mlxrs_sys::mlx_take_axis(&mut result.0, a.0, indices.0, axis_c, default_stream())
    };
    check(status)?;
    Ok(result)
}
/// Index `a` element-wise along `axis` using `indices` (mlx `take_along_axis`).
///
/// # Errors
///
/// Returns `InvariantViolation` if `indices` has `Bool` dtype, or any error
/// reported by `mlx_take_along_axis`.
pub fn take_along_axis(a: &Array, indices: &Array, axis: i32) -> Result<Array> {
    reject_bool_index("take_along_axis: index dtype", indices)?;
    let axis_c = axis as c_int;
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: both inputs are borrowed for the duration of the call;
    // `result.0` is a fresh handle the callee writes into.
    let status = unsafe {
        mlxrs_sys::mlx_take_along_axis(&mut result.0, a.0, indices.0, axis_c, default_stream())
    };
    check(status)?;
    Ok(result)
}
/// Return a copy of `a` with `values` written at `indices` along `axis`
/// (mlx `put_along_axis`).
///
/// # Errors
///
/// Returns `InvariantViolation` if `indices` has `Bool` dtype, or any error
/// reported by `mlx_put_along_axis`.
pub fn put_along_axis(a: &Array, indices: &Array, values: &Array, axis: i32) -> Result<Array> {
    reject_bool_index("put_along_axis: index dtype", indices)?;
    let axis_c = axis as c_int;
    let stream = default_stream();
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: all three inputs are borrowed for the duration of the call;
    // `result.0` is a fresh handle the callee writes into.
    let status = unsafe {
        mlxrs_sys::mlx_put_along_axis(&mut result.0, a.0, indices.0, values.0, axis_c, stream)
    };
    check(status)?;
    Ok(result)
}
/// Return a copy of `a` with `values` accumulated at `indices` along `axis`
/// (mlx `scatter_add_axis`).
///
/// # Errors
///
/// Returns `InvariantViolation` if `indices` has `Bool` dtype, or any error
/// reported by `mlx_scatter_add_axis`.
pub fn scatter_add_axis(a: &Array, indices: &Array, values: &Array, axis: i32) -> Result<Array> {
    reject_bool_index("scatter_add_axis: index dtype", indices)?;
    let axis_c = axis as c_int;
    let stream = default_stream();
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: all three inputs are borrowed for the duration of the call;
    // `result.0` is a fresh handle the callee writes into.
    let status = unsafe {
        mlxrs_sys::mlx_scatter_add_axis(&mut result.0, a.0, indices.0, values.0, axis_c, stream)
    };
    check(status)?;
    Ok(result)
}
pub fn gather(a: &Array, indices: &[&Array], axes: &[i32], slice_sizes: &[i32]) -> Result<Array> {
if indices.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"gather: indices slice",
)));
}
if indices.len() != axes.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"gather: indices.len() vs axes.len()",
axes.len(),
indices.len(),
)));
}
if slice_sizes.len() != a.ndim() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"gather: slice_sizes.len() vs a.ndim()",
a.ndim(),
slice_sizes.len(),
)));
}
crate::shape::validate_dims(slice_sizes)?;
for idx in indices {
reject_bool_index("gather: index dtype", idx)?;
}
crate::error::ensure_handler_installed();
let raw: Vec<mlxrs_sys::mlx_array> = indices.iter().map(|a| a.0).collect();
let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(raw.as_ptr(), raw.len()) };
let _vec_guard = VectorArrayGuard(vec);
if vec.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)),
);
}
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe {
mlxrs_sys::mlx_gather(
&mut out.0,
a.0,
vec,
dim_ptr(axes),
axes.len(),
dim_ptr(slice_sizes),
slice_sizes.len(),
default_stream(),
)
})?;
Ok(out)
}
fn reject_bool_index(context: &'static str, idx: &Array) -> Result<()> {
if idx.dtype()? == Dtype::Bool {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
context,
"indices must not be Bool (Bool indices are not supported by mlx indexing ops)",
)));
}
Ok(())
}
fn scatter_multi(
context: &'static str,
a: &Array,
indices: &[&Array],
updates: &Array,
axes: &[i32],
ffi: unsafe extern "C" fn(
*mut mlxrs_sys::mlx_array,
mlxrs_sys::mlx_array,
mlxrs_sys::mlx_vector_array,
mlxrs_sys::mlx_array,
*const c_int,
usize,
mlxrs_sys::mlx_stream,
) -> c_int,
) -> Result<Array> {
if indices.len() != axes.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
context,
axes.len(),
indices.len(),
)));
}
if indices.len() > a.ndim() {
return Err(Error::CapExceeded(CapExceededPayload::new(
"scatter: number of index arrays",
"a.ndim()",
a.ndim() as u64,
indices.len() as u64,
)));
}
for idx in indices {
reject_bool_index("scatter: index dtype", idx)?;
}
crate::error::ensure_handler_installed();
let raw: Vec<mlxrs_sys::mlx_array> = indices.iter().map(|a| a.0).collect();
let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(raw.as_ptr(), raw.len()) };
let _vec_guard = VectorArrayGuard(vec);
if vec.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)),
);
}
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe {
ffi(
&mut out.0,
a.0,
vec,
updates.0,
dim_ptr(axes),
axes.len(),
default_stream(),
)
})?;
Ok(out)
}
/// Scatter `updates` into a copy of `a` at the positions selected by `indices`
/// along `axes`, via `mlx_scatter`.
///
/// Validation and error reporting are described on `scatter_multi`.
pub fn scatter(a: &Array, indices: &[&Array], updates: &Array, axes: &[i32]) -> Result<Array> {
    let context = "scatter: indices.len() vs axes.len()";
    let op = mlxrs_sys::mlx_scatter;
    scatter_multi(context, a, indices, updates, axes, op)
}
/// Scatter `updates` into a copy of `a` with the `add` reduction, via
/// `mlx_scatter_add`.
///
/// Validation and error reporting are described on `scatter_multi`.
pub fn scatter_add(a: &Array, indices: &[&Array], updates: &Array, axes: &[i32]) -> Result<Array> {
    let context = "scatter_add: indices.len() vs axes.len()";
    let op = mlxrs_sys::mlx_scatter_add;
    scatter_multi(context, a, indices, updates, axes, op)
}
/// Scatter `updates` into a copy of `a` with the `max` reduction, via
/// `mlx_scatter_max`.
///
/// Validation and error reporting are described on `scatter_multi`.
pub fn scatter_max(a: &Array, indices: &[&Array], updates: &Array, axes: &[i32]) -> Result<Array> {
    let context = "scatter_max: indices.len() vs axes.len()";
    let op = mlxrs_sys::mlx_scatter_max;
    scatter_multi(context, a, indices, updates, axes, op)
}
/// Scatter `updates` into a copy of `a` with the `min` reduction, via
/// `mlx_scatter_min`.
///
/// Validation and error reporting are described on `scatter_multi`.
pub fn scatter_min(a: &Array, indices: &[&Array], updates: &Array, axes: &[i32]) -> Result<Array> {
    let context = "scatter_min: indices.len() vs axes.len()";
    let op = mlxrs_sys::mlx_scatter_min;
    scatter_multi(context, a, indices, updates, axes, op)
}
/// Scatter `updates` into a copy of `a` with the `prod` reduction, via
/// `mlx_scatter_prod`.
///
/// Validation and error reporting are described on `scatter_multi`.
pub fn scatter_prod(a: &Array, indices: &[&Array], updates: &Array, axes: &[i32]) -> Result<Array> {
    let context = "scatter_prod: indices.len() vs axes.len()";
    let op = mlxrs_sys::mlx_scatter_prod;
    scatter_multi(context, a, indices, updates, axes, op)
}
/// Shared driver for the single-index-array scatter variants (`*_single` FFI entry points).
///
/// It rejects `Bool` indices, then calls `ffi` with `a`, `indices`, `updates` and
/// `axis`. The exact mlx semantics of the `_single` form (presumably one index
/// array addressing `axis`) should be confirmed against mlx-c.
///
/// # Errors
///
/// `InvariantViolation` for `Bool` indices, or any error reported by `ffi`.
fn scatter_single(
    a: &Array,
    indices: &Array,
    updates: &Array,
    axis: i32,
    ffi: unsafe extern "C" fn(
        *mut mlxrs_sys::mlx_array,
        mlxrs_sys::mlx_array,
        mlxrs_sys::mlx_array,
        mlxrs_sys::mlx_array,
        c_int,
        mlxrs_sys::mlx_stream,
    ) -> c_int,
) -> Result<Array> {
    reject_bool_index("scatter: index dtype", indices)?;
    let axis_c = axis as c_int;
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: all three inputs are borrowed for the call; `result.0` is a fresh
    // handle the callee writes into.
    let status =
        unsafe { ffi(&mut result.0, a.0, indices.0, updates.0, axis_c, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Single-index-array scatter along `axis`, via `mlx_scatter_single`.
pub fn scatter_axis(a: &Array, indices: &Array, updates: &Array, axis: i32) -> Result<Array> {
    let op = mlxrs_sys::mlx_scatter_single;
    scatter_single(a, indices, updates, axis, op)
}
/// Single-index-array scatter with the `add` reduction, via `mlx_scatter_add_single`.
pub fn scatter_add_single(a: &Array, indices: &Array, updates: &Array, axis: i32) -> Result<Array> {
    let op = mlxrs_sys::mlx_scatter_add_single;
    scatter_single(a, indices, updates, axis, op)
}
/// Single-index-array scatter with the `max` reduction, via `mlx_scatter_max_single`.
pub fn scatter_max_single(a: &Array, indices: &Array, updates: &Array, axis: i32) -> Result<Array> {
    let op = mlxrs_sys::mlx_scatter_max_single;
    scatter_single(a, indices, updates, axis, op)
}
/// Single-index-array scatter with the `min` reduction, via `mlx_scatter_min_single`.
pub fn scatter_min_single(a: &Array, indices: &Array, updates: &Array, axis: i32) -> Result<Array> {
    let op = mlxrs_sys::mlx_scatter_min_single;
    scatter_single(a, indices, updates, axis, op)
}
/// Single-index-array scatter with the `prod` reduction, via `mlx_scatter_prod_single`.
pub fn scatter_prod_single(
    a: &Array,
    indices: &Array,
    updates: &Array,
    axis: i32,
) -> Result<Array> {
    let op = mlxrs_sys::mlx_scatter_prod_single;
    scatter_single(a, indices, updates, axis, op)
}
// Each public slice-update wrapper supplies two context strings plus its FFI
// entry point on top of the five data arguments, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
/// Shared driver for the `slice_update*` family.
///
/// It validates `start`/`stop`/`strides` against `src`, then calls `ffi` to
/// produce a copy of `src` with the strided region combined with `update`.
///
/// # Errors
///
/// - `MultiLengthMismatch` (tagged `context_multi`) if the three slices differ in length.
/// - `LengthMismatch` (tagged `context_len`) if that length is not `src.ndim()`.
/// - `OutOfRange` for a stride that is `0` or `i32::MIN`, or for one where
///   `axis_size + |stride| - 1` exceeds `i32::MAX`. mlx's `normalize_slice`
///   would otherwise divide by zero or overflow int32. Axes are checked in order,
///   and each axis's three conditions are tested in that order.
/// - Any error reported by `ffi`.
fn slice_update_impl(
    context_multi: &'static str,
    context_len: &'static str,
    src: &Array,
    update: &Array,
    start: &[i32],
    stop: &[i32],
    strides: &[i32],
    ffi: unsafe extern "C" fn(
        *mut mlxrs_sys::mlx_array,
        mlxrs_sys::mlx_array,
        mlxrs_sys::mlx_array,
        *const c_int,
        usize,
        *const c_int,
        usize,
        *const c_int,
        usize,
        mlxrs_sys::mlx_stream,
    ) -> c_int,
) -> Result<Array> {
    let rank = start.len();
    if rank != stop.len() || rank != strides.len() {
        let lengths = vec![
            ("start", rank),
            ("stop", stop.len()),
            ("strides", strides.len()),
        ];
        return Err(Error::MultiLengthMismatch(MultiLengthMismatchPayload::new(
            context_multi,
            lengths,
        )));
    }
    let ndim = src.ndim();
    if rank != ndim {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            context_len,
            ndim,
            rank,
        )));
    }

    let shape = src.shape();
    for (axis, &stride) in strides.iter().enumerate() {
        match stride {
            0 => {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "slice-update: stride",
                    "must be non-zero (mlx normalize_slice divides by it)",
                    "0",
                )));
            }
            i32::MIN => {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "slice-update: stride",
                    "must not be i32::MIN (its negation is UB in mlx normalize_slice, any axis size)",
                    "i32::MIN",
                )));
            }
            s if shape[axis] as i64 + i64::from(s).abs() - 1 > i64::from(i32::MAX) => {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "slice-update: stride",
                    "axis_size + abs(stride) - 1 must fit in i32 (mlx normalize_slice int32 overflow)",
                    "out-of-range magnitude",
                )));
            }
            _ => {}
        }
    }

    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `src` and `update` are borrowed for the call, each dim pointer is
    // paired with its slice's length, and `result.0` is a fresh output handle.
    let status = unsafe {
        ffi(
            &mut result.0,
            src.0,
            update.0,
            dim_ptr(start),
            rank,
            dim_ptr(stop),
            stop.len(),
            dim_ptr(strides),
            strides.len(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Return a copy of `src` whose strided region (`start`/`stop`/`strides`, one
/// entry per axis) is updated from `update`, via `mlx_slice_update`.
///
/// Validation and error reporting are described on `slice_update_impl`.
pub fn slice_update(
    src: &Array,
    update: &Array,
    start: &[i32],
    stop: &[i32],
    strides: &[i32],
) -> Result<Array> {
    const MULTI_CTX: &str = "slice_update: start/stop/strides";
    const LEN_CTX: &str = "slice_update: start/stop/strides length";
    let op = mlxrs_sys::mlx_slice_update;
    slice_update_impl(MULTI_CTX, LEN_CTX, src, update, start, stop, strides, op)
}
/// `add` variant of [`slice_update`], forwarding to `mlx_slice_update_add`.
///
/// Validation and error reporting are described on `slice_update_impl`.
pub fn slice_update_add(
    src: &Array,
    update: &Array,
    start: &[i32],
    stop: &[i32],
    strides: &[i32],
) -> Result<Array> {
    const MULTI_CTX: &str = "slice_update_add: start/stop/strides";
    const LEN_CTX: &str = "slice_update_add: start/stop/strides length";
    let op = mlxrs_sys::mlx_slice_update_add;
    slice_update_impl(MULTI_CTX, LEN_CTX, src, update, start, stop, strides, op)
}
/// `max` variant of [`slice_update`], forwarding to `mlx_slice_update_max`.
///
/// Validation and error reporting are described on `slice_update_impl`.
pub fn slice_update_max(
    src: &Array,
    update: &Array,
    start: &[i32],
    stop: &[i32],
    strides: &[i32],
) -> Result<Array> {
    const MULTI_CTX: &str = "slice_update_max: start/stop/strides";
    const LEN_CTX: &str = "slice_update_max: start/stop/strides length";
    let op = mlxrs_sys::mlx_slice_update_max;
    slice_update_impl(MULTI_CTX, LEN_CTX, src, update, start, stop, strides, op)
}
/// `min` variant of [`slice_update`], forwarding to `mlx_slice_update_min`.
///
/// Validation and error reporting are described on `slice_update_impl`.
pub fn slice_update_min(
    src: &Array,
    update: &Array,
    start: &[i32],
    stop: &[i32],
    strides: &[i32],
) -> Result<Array> {
    const MULTI_CTX: &str = "slice_update_min: start/stop/strides";
    const LEN_CTX: &str = "slice_update_min: start/stop/strides length";
    let op = mlxrs_sys::mlx_slice_update_min;
    slice_update_impl(MULTI_CTX, LEN_CTX, src, update, start, stop, strides, op)
}
/// `prod` variant of [`slice_update`], forwarding to `mlx_slice_update_prod`.
///
/// Validation and error reporting are described on `slice_update_impl`.
pub fn slice_update_prod(
    src: &Array,
    update: &Array,
    start: &[i32],
    stop: &[i32],
    strides: &[i32],
) -> Result<Array> {
    const MULTI_CTX: &str = "slice_update_prod: start/stop/strides";
    const LEN_CTX: &str = "slice_update_prod: start/stop/strides length";
    let op = mlxrs_sys::mlx_slice_update_prod;
    slice_update_impl(MULTI_CTX, LEN_CTX, src, update, start, stop, strides, op)
}
/// Slice update whose start offsets come from the array `start` rather than a
/// host slice, applied along `axes`, via `mlx_slice_update_dynamic`.
///
/// No validation is done on the Rust side; argument checking is left to mlx.
///
/// # Errors
///
/// Any error reported by `mlx_slice_update_dynamic`.
pub fn slice_update_dynamic(
    src: &Array,
    update: &Array,
    start: &Array,
    axes: &[i32],
) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: all array inputs are borrowed for the call, `dim_ptr(axes)` is
    // paired with `axes.len()`, and `result.0` is a fresh output handle.
    let status = unsafe {
        mlxrs_sys::mlx_slice_update_dynamic(
            &mut result.0,
            src.0,
            update.0,
            start.0,
            dim_ptr(axes),
            axes.len(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}