use std::ffi::c_int;
use smol_str::format_smolstr;
use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, CapExceededPayload, EmptyInputPayload, Error, LengthMismatchPayload,
MultiLengthMismatchPayload, OutOfRangePayload, Result, check,
},
ffi::VectorArrayGuard,
shape::{IntoShape, dim_ptr, stride_ptr, validate_dims},
stream::default_stream,
};
fn check_count(context: &'static str, cap_name: &'static str, len: usize) -> Result<()> {
const CAP: usize = i32::MAX as usize;
if len > CAP {
return Err(Error::CapExceeded(CapExceededPayload::new(
context, cap_name, CAP as u64, len as u64,
)));
}
Ok(())
}
/// Rejects `tile` inputs whose worst-case intermediate rank would exceed `i32::MAX`.
///
/// The bound is the aligned rank (the larger of `ndim` and `reps.len()`) plus one
/// for every repetition count that is not `1`. The addition saturates, so it can
/// never wrap.
fn check_tile_intermediate_rank(ndim: usize, reps: &[i32]) -> Result<()> {
    let non_unit_reps = reps.iter().copied().filter(|&rep| rep != 1).count();
    let worst_case_rank = ndim.max(reps.len()).saturating_add(non_unit_reps);
    check_count(
        "tile: max intermediate rank aligned_rank + count(reps != 1)",
        "i32::MAX",
        worst_case_rank,
    )
}
fn checked_total_shift(context: &'static str, shift: &[i32]) -> Result<i32> {
let mut total: i32 = 0;
for &s in shift {
total = total
.checked_add(s)
.ok_or_else(|| Error::ArithmeticOverflow(ArithmeticOverflowPayload::new(context, "i32")))?;
}
if total == i32::MIN {
return Err(Error::OutOfRange(OutOfRangePayload::new(
context,
"shift sum must not be i32::MIN (its negation is UB in MLX)",
format_smolstr!("{total}"),
)));
}
Ok(total)
}
/// Returns `a` reshaped to `shape`.
///
/// # Errors
/// Propagates any error from `validate_dims` on the target shape, or any MLX error
/// from `mlx_reshape`.
pub fn reshape(a: &Array, shape: &impl IntoShape) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let status = unsafe {
            mlxrs_sys::mlx_reshape(&mut result.0, a.0, dim_ptr(dims), dims.len(), default_stream())
        };
        check(status)?;
        Ok(result)
    })
}
pub fn concatenate(arrays: &[&Array], axis: i32) -> Result<Array> {
if arrays.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"concatenate: arrays slice",
)));
}
crate::error::ensure_handler_installed();
let raw: Vec<mlxrs_sys::mlx_array> = arrays.iter().map(|a| a.0).collect();
let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(raw.as_ptr(), raw.len()) };
let _vec_guard = VectorArrayGuard(vec);
if vec.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)),
);
}
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe { mlxrs_sys::mlx_concatenate_axis(&mut out.0, vec, axis, default_stream()) })?;
Ok(out)
}
/// Transposes `a` with MLX's default axis permutation (no explicit axes given).
///
/// # Errors
/// Propagates any MLX error from `mlx_transpose`.
pub fn transpose(a: &Array) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe { mlxrs_sys::mlx_transpose(&mut result.0, a.0, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Permutes the axes of `a` into the order given by `axes`.
///
/// # Errors
/// Propagates any MLX error from `mlx_transpose_axes`.
pub fn transpose_axes(a: &Array, axes: &[i32]) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_transpose_axes(
            &mut result.0,
            a.0,
            dim_ptr(axes),
            axes.len(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Inserts size-1 axes into `a` at each position listed in `axes`.
///
/// An empty `axes` slice short-circuits to `a.try_clone()` without calling MLX.
///
/// # Errors
/// Propagates any error from `try_clone` or from `mlx_expand_dims_axes`.
pub fn expand_dims_axes(a: &Array, axes: &[i32]) -> Result<Array> {
    if axes.is_empty() {
        return a.try_clone();
    }
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_expand_dims_axes(
            &mut result.0,
            a.0,
            dim_ptr(axes),
            axes.len(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Removes the axes listed in `axes` from `a`.
///
/// An empty `axes` slice short-circuits to `a.try_clone()` without calling MLX.
///
/// # Errors
/// Propagates any error from `try_clone` or from `mlx_squeeze_axes`.
pub fn squeeze_axes(a: &Array, axes: &[i32]) -> Result<Array> {
    if axes.is_empty() {
        return a.try_clone();
    }
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_squeeze_axes(
            &mut result.0,
            a.0,
            dim_ptr(axes),
            axes.len(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Broadcasts `a` to the target `shape`.
///
/// # Errors
/// Propagates any error from `validate_dims` on the target shape, or any MLX error
/// from `mlx_broadcast_to`.
pub fn broadcast_to(a: &Array, shape: &impl IntoShape) -> Result<Array> {
    shape.with_shape(|target| {
        validate_dims(target)?;
        let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let status = unsafe {
            mlxrs_sys::mlx_broadcast_to(
                &mut result.0,
                a.0,
                dim_ptr(target),
                target.len(),
                default_stream(),
            )
        };
        check(status)?;
        Ok(result)
    })
}
pub fn stack(arrays: &[&Array]) -> Result<Array> {
if arrays.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"stack: arrays slice",
)));
}
crate::error::ensure_handler_installed();
let raw: Vec<mlxrs_sys::mlx_array> = arrays.iter().map(|a| a.0).collect();
let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(raw.as_ptr(), raw.len()) };
let _vec_guard = VectorArrayGuard(vec);
if vec.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)),
);
}
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe { mlxrs_sys::mlx_stack(&mut out.0, vec, default_stream()) })?;
Ok(out)
}
pub fn stack_axis(arrays: &[&Array], axis: i32) -> Result<Array> {
if arrays.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"stack_axis: arrays slice",
)));
}
crate::error::ensure_handler_installed();
let raw: Vec<mlxrs_sys::mlx_array> = arrays.iter().map(|a| a.0).collect();
let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(raw.as_ptr(), raw.len()) };
let _vec_guard = VectorArrayGuard(vec);
if vec.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(
crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
)),
);
}
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe { mlxrs_sys::mlx_stack_axis(&mut out.0, vec, axis as c_int, default_stream()) })?;
Ok(out)
}
/// Splits `a` along `axis` at the given `indices` and returns the pieces in order.
///
/// # Errors
/// Propagates any MLX error from `mlx_split_sections` or from reading an element
/// out of the resulting vector with `mlx_vector_array_get`.
pub fn split_sections(a: &Array, indices: &[i32], axis: i32) -> Result<Vec<Array>> {
    crate::error::ensure_handler_installed();
    // The guard owns the *live* handle, and the FFI call writes through `&mut pieces.0`.
    // Whatever handle MLX leaves in that slot is therefore the one released when the guard
    // drops. A copy taken before the call would go stale if the callee replaced the handle,
    // freeing the wrong pointer and leaking the result.
    let mut pieces = VectorArrayGuard(unsafe { mlxrs_sys::mlx_vector_array_new() });
    check(unsafe {
        mlxrs_sys::mlx_split_sections(
            &mut pieces.0,
            a.0,
            dim_ptr(indices),
            indices.len(),
            axis as c_int,
            default_stream(),
        )
    })?;
    let n = unsafe { mlxrs_sys::mlx_vector_array_size(pieces.0) };
    let mut parts = Vec::with_capacity(n);
    for i in 0..n {
        let mut part = Array(unsafe { mlxrs_sys::mlx_array_new() });
        check(unsafe { mlxrs_sys::mlx_vector_array_get(&mut part.0, pieces.0, i) })?;
        parts.push(part);
    }
    Ok(parts)
}
/// Merges the axes of `a` from `start_axis` through `end_axis` into a single axis.
///
/// # Errors
/// Propagates any MLX error from `mlx_flatten`.
pub fn flatten(a: &Array, start_axis: i32, end_axis: i32) -> Result<Array> {
    let (first, last) = (start_axis as c_int, end_axis as c_int);
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status =
        unsafe { mlxrs_sys::mlx_flatten(&mut result.0, a.0, first, last, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Exchanges axes `axis1` and `axis2` of `a`.
///
/// # Errors
/// Propagates any MLX error from `mlx_swapaxes`.
pub fn swapaxes(a: &Array, axis1: i32, axis2: i32) -> Result<Array> {
    let (lhs, rhs) = (axis1 as c_int, axis2 as c_int);
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status =
        unsafe { mlxrs_sys::mlx_swapaxes(&mut result.0, a.0, lhs, rhs, default_stream()) };
    check(status)?;
    Ok(result)
}
pub fn pad(
a: &Array,
axes: &[i32],
low: &[i32],
high: &[i32],
pad_value: &Array,
mode: &std::ffi::CStr,
) -> Result<Array> {
if axes.len() != low.len() || axes.len() != high.len() {
return Err(Error::MultiLengthMismatch(MultiLengthMismatchPayload::new(
"pad: axes/low/high",
vec![
("axes", axes.len()),
("low", low.len()),
("high", high.len()),
],
)));
}
crate::shape::validate_dims(low)?;
crate::shape::validate_dims(high)?;
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe {
mlxrs_sys::mlx_pad(
&mut out.0,
a.0,
dim_ptr(axes),
axes.len(),
dim_ptr(low),
low.len(),
dim_ptr(high),
high.len(),
pad_value.0,
mode.as_ptr(),
default_stream(),
)
})?;
Ok(out)
}
/// Returns a contiguous version of `a`; `allow_col_major` is forwarded to
/// `mlx_contiguous` unchanged.
///
/// # Errors
/// Propagates any MLX error from `mlx_contiguous`.
pub fn contiguous(a: &Array, allow_col_major: bool) -> Result<Array> {
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_contiguous(&mut result.0, a.0, allow_col_major, default_stream())
    };
    check(status)?;
    Ok(result)
}
/// Reinterprets the storage of `a` with an explicit `shape`, `strides` and starting
/// `offset`.
///
/// # Safety
///
/// Neither this wrapper nor MLX bounds-checks the resulting view. MLX's own
/// `as_strided` documentation warns that the result can point to invalid memory. The
/// caller must guarantee two things:
/// - `offset`, and every position reachable from it through `shape` / `strides`,
///   addresses an element that exists in the storage backing `a`.
/// - That storage outlives every use of the returned array.
///
/// Otherwise evaluating or reading the result is undefined behaviour. (Units are
/// assumed to be elements rather than bytes, following MLX's stride convention;
/// confirm against the MLX docs.)
///
/// # Errors
/// - [`Error::LengthMismatch`] if `shape` and `strides` have different lengths.
/// - Any error from `validate_dims` on `shape`.
/// - Any MLX error from `mlx_as_strided`.
pub unsafe fn as_strided(
    a: &Array,
    shape: &impl IntoShape,
    strides: &[i64],
    offset: usize,
) -> Result<Array> {
    shape.with_shape(|s| {
        if s.len() != strides.len() {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                "as_strided: shape length vs strides length",
                s.len(),
                strides.len(),
            )));
        }
        validate_dims(s)?;
        // SAFETY: `mlx_array_new` takes no arguments and returns a fresh handle owned by `out`.
        let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
        // SAFETY: the pointer/length pairs come from live slices (`s`, `strides`) whose
        // lengths were checked equal above. Memory validity of the view itself is the
        // caller's obligation per this function's `# Safety` contract.
        check(unsafe {
            mlxrs_sys::mlx_as_strided(
                &mut out.0,
                a.0,
                dim_ptr(s),
                s.len(),
                stride_ptr(strides),
                strides.len(),
                offset,
                default_stream(),
            )
        })?;
        Ok(out)
    })
}
/// Moves axis `source` of `a` to position `destination`.
///
/// # Errors
/// Propagates any MLX error from `mlx_moveaxis`.
pub fn moveaxis(a: &Array, source: i32, destination: i32) -> Result<Array> {
    let (from, to) = (source as c_int, destination as c_int);
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status =
        unsafe { mlxrs_sys::mlx_moveaxis(&mut result.0, a.0, from, to, default_stream()) };
    check(status)?;
    Ok(result)
}
/// Rolls `a` with no axis specified, by the sum of all entries of `shift`.
///
/// The shifts are collapsed into one checked total by `checked_total_shift` and
/// passed to `mlx_roll` as a single-element shift.
///
/// # Errors
/// Propagates the error from `checked_total_shift`, or any MLX error from `mlx_roll`.
pub fn roll(a: &Array, shift: &[i32]) -> Result<Array> {
    let combined = [checked_total_shift("roll: shift sum", shift)?];
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_roll(
            &mut result.0,
            a.0,
            dim_ptr(&combined),
            combined.len(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Rolls `a` along `axis` by the sum of all entries of `shift`.
///
/// The shifts are collapsed into one checked total by `checked_total_shift` and
/// passed to `mlx_roll_axis` as a single-element shift.
///
/// # Errors
/// Propagates the error from `checked_total_shift`, or any MLX error from
/// `mlx_roll_axis`.
pub fn roll_axis(a: &Array, shift: &[i32], axis: i32) -> Result<Array> {
    let combined = [checked_total_shift("roll_axis: shift sum", shift)?];
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_roll_axis(
            &mut result.0,
            a.0,
            dim_ptr(&combined),
            combined.len(),
            axis as c_int,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Rolls `a` by `shift[i]` along each axis `axes[i]`.
///
/// # Errors
/// - [`Error::LengthMismatch`] if `shift` and `axes` have different lengths.
/// - [`Error::CapExceeded`] if the (common) length exceeds `i32::MAX`.
/// - [`Error::OutOfRange`] if any shift equals `i32::MIN` (its negation is UB in MLX).
/// - Any MLX error from `mlx_roll_axes`.
pub fn roll_axes(a: &Array, shift: &[i32], axes: &[i32]) -> Result<Array> {
    if shift.len() != axes.len() {
        // Operands are passed in the same order the context string names them
        // (shift first, then axes), matching `as_strided`'s use of this payload.
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            "roll_axes: shift.len() vs axes.len()",
            shift.len(),
            axes.len(),
        )));
    }
    // The lengths are equal here, so a single cap check covers both slices.
    check_count("roll_axes: axes.len()", "i32::MAX", axes.len())?;
    if let Some((i, &s)) = shift.iter().enumerate().find(|&(_, &s)| s == i32::MIN) {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "roll_axes: shift",
            "no shift may be i32::MIN (its negation is UB in MLX)",
            format_smolstr!("shift[{i}]={s}"),
        )));
    }
    let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
    check(unsafe {
        mlxrs_sys::mlx_roll_axes(
            &mut out.0,
            a.0,
            dim_ptr(shift),
            shift.len(),
            dim_ptr(axes),
            axes.len(),
            default_stream(),
        )
    })?;
    Ok(out)
}
/// Repeats `a` along each axis according to `reps`.
///
/// Before calling MLX, the inputs are validated in three steps:
/// 1. The worst-case intermediate rank is checked (`check_tile_intermediate_rank`).
/// 2. `reps` and the shape of `a` are right-aligned, with missing entries treated as `1`.
/// 3. For each aligned pair, the rep must be non-negative and `rep * dim` must fit in
///    `i32`. Pairs are visited from the last axis backwards, so the right-most offending
///    entry is the one reported.
///
/// # Errors
/// - [`Error::CapExceeded`] from the intermediate-rank check.
/// - [`Error::OutOfRange`] for a negative rep.
/// - [`Error::ArithmeticOverflow`] if an output dimension exceeds `i32::MAX`.
/// - Any MLX error from `mlx_tile`.
pub fn tile(a: &Array, reps: &[i32]) -> Result<Array> {
    check_tile_intermediate_rank(a.ndim(), reps)?;
    let shape = a.shape();
    // Yields right-aligned (rep, dim) pairs until both sequences are exhausted.
    let aligned = {
        let mut reps_rev = reps.iter().rev();
        let mut dims_rev = shape.iter().rev();
        std::iter::from_fn(move || match (reps_rev.next(), dims_rev.next()) {
            (None, None) => None,
            pair => Some(pair),
        })
    };
    for (rep_slot, dim_slot) in aligned {
        let rep: i64 = match rep_slot {
            None => 1,
            Some(&r) if r < 0 => {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "tile: reps",
                    "every rep must be non-negative",
                    format_smolstr!("{r}"),
                )));
            }
            Some(&r) => i64::from(r),
        };
        let dim: i64 = match dim_slot {
            Some(&d) => d as i64,
            None => 1,
        };
        // Both factors are at most i32::MAX, so the i64 product cannot overflow.
        if rep * dim > i64::from(i32::MAX) {
            return Err(Error::ArithmeticOverflow(
                ArithmeticOverflowPayload::with_operands(
                    "tile: reps[i] * shape[i] output dim",
                    "i32",
                    [("rep", rep as u64), ("dim", dim as u64)],
                ),
            ));
        }
    }
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_tile(&mut result.0, a.0, dim_ptr(reps), reps.len(), default_stream())
    };
    check(status)?;
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn check_count_boundary() {
        assert!(check_count("t", "i32::MAX", 0).is_ok());
        assert!(check_count("t", "i32::MAX", i32::MAX as usize).is_ok());
        let over = i32::MAX as usize + 1;
        match check_count("ctx", "i32::MAX", over) {
            Err(Error::CapExceeded(p)) => {
                assert_eq!(p.context(), "ctx");
                assert_eq!(p.cap_name(), "i32::MAX");
                assert_eq!(p.cap(), i32::MAX as u64);
                assert_eq!(p.observed(), over as u64);
            }
            other => panic!("expected Err(CapExceeded) one past the cap, got {other:?}"),
        }
    }

    #[test]
    fn tile_intermediate_rank_boundary() {
        let cap = i32::MAX as usize;
        match check_tile_intermediate_rank(cap, &[2]) {
            Err(Error::CapExceeded(p)) => {
                assert_eq!(p.cap(), cap as u64);
                assert_eq!(p.observed(), cap as u64 + 1);
            }
            other => panic!("expected CapExceeded for cap+non_unit_rep, got {other:?}"),
        }
        assert!(check_tile_intermediate_rank(cap, &[1]).is_ok());
        assert!(check_tile_intermediate_rank(1, &[2, 1, 3, 1, 4]).is_ok());
        match check_tile_intermediate_rank(cap - 2, &[2, 3, 4]) {
            Err(Error::CapExceeded(p)) => assert_eq!(p.observed(), cap as u64 + 1),
            other => panic!("expected CapExceeded for cap-2 + 3 non-unit reps, got {other:?}"),
        }
        assert!(check_tile_intermediate_rank(cap - 1, &[2]).is_ok());
        assert!(check_tile_intermediate_rank(0, &[]).is_ok());
        assert!(check_tile_intermediate_rank(3, &[]).is_ok());
    }

    #[test]
    fn checked_total_shift_boundaries() {
        // Empty input sums to zero.
        match checked_total_shift("t", &[]) {
            Ok(0) => {}
            other => panic!("expected Ok(0) for an empty shift, got {other:?}"),
        }
        // Mixed signs are summed normally.
        match checked_total_shift("t", &[-5, 3, 4]) {
            Ok(2) => {}
            other => panic!("expected Ok(2), got {other:?}"),
        }
        // The extremes that are allowed.
        match checked_total_shift("t", &[i32::MAX]) {
            Ok(v) if v == i32::MAX => {}
            other => panic!("expected Ok(i32::MAX), got {other:?}"),
        }
        match checked_total_shift("t", &[i32::MIN + 1]) {
            Ok(v) if v == i32::MIN + 1 => {}
            other => panic!("expected Ok(i32::MIN + 1), got {other:?}"),
        }
        // One past i32::MAX does not fit in i32.
        match checked_total_shift("t", &[i32::MAX, 1]) {
            Err(Error::ArithmeticOverflow(_)) => {}
            other => panic!("expected ArithmeticOverflow past i32::MAX, got {other:?}"),
        }
        // A total of exactly i32::MIN is representable but rejected (negation is UB in MLX).
        match checked_total_shift("t", &[i32::MIN + 1, -1]) {
            Err(Error::OutOfRange(_)) => {}
            other => panic!("expected OutOfRange for a total of i32::MIN, got {other:?}"),
        }
        match checked_total_shift("t", &[i32::MIN]) {
            Err(Error::OutOfRange(_)) => {}
            other => panic!("expected OutOfRange for a lone i32::MIN, got {other:?}"),
        }
    }
}