use smol_str::format_smolstr;
use crate::{
array::Array,
dtype::{Dtype, Element},
error::{
ArithmeticOverflowPayload, Error, LengthMismatchPayload, NonFiniteScalarPayload,
OutOfRangePayload, Result, UnsupportedDtypePayload, check, check_handle,
},
shape::{IntoShape, dim_ptr, validate_dims},
stream::default_stream,
};
/// Chooses the raw element pointer handed to mlx for `data`.
///
/// A non-empty slice contributes its own base address. An empty slice has no
/// storage of its own, so the element type's `sentinel_ptr()` is used instead.
#[inline]
fn data_ptr<T>(data: &[T]) -> *const T
where
    T: Element,
{
    match data {
        [] => T::sentinel_ptr(),
        populated => populated.as_ptr(),
    }
}
impl Array {
    /// Creates an array of shape `shape` with dtype `T::DTYPE` via `mlx_ones`
    /// on the default stream.
    ///
    /// # Errors
    ///
    /// Returns an error if `validate_dims` rejects the shape, or if `check`
    /// reports a failure from `mlx_ones`.
    pub fn ones<T>(shape: &impl IntoShape) -> Result<Self>
    where
        T: Element,
    {
        crate::error::ensure_handler_installed();
        shape.with_shape(|s| {
            validate_dims(s)?;
            // SAFETY: NOTE(review): assumes `mlx_array_new` has no preconditions and
            // returns a fresh handle whose ownership passes to `out` — confirm.
            let mut out = Self(unsafe { mlxrs_sys::mlx_array_new() });
            // SAFETY: `out.0` is a live handle owned by `out`. `dim_ptr(s)` is presumably
            // valid for `s.len()` dims while `s` is borrowed by this closure — confirm.
            check(unsafe {
                mlxrs_sys::mlx_ones(
                    &mut out.0,
                    dim_ptr(s),
                    s.len(),
                    mlxrs_sys::mlx_dtype::from(T::DTYPE),
                    default_stream(),
                )
            })?;
            Ok(out)
        })
    }

    /// Creates an array of shape `shape` with dtype `T::DTYPE` via `mlx_zeros`
    /// on the default stream.
    ///
    /// # Errors
    ///
    /// Returns an error if `validate_dims` rejects the shape, or if `check`
    /// reports a failure from `mlx_zeros`.
    pub fn zeros<T>(shape: &impl IntoShape) -> Result<Self>
    where
        T: Element,
    {
        crate::error::ensure_handler_installed();
        shape.with_shape(|s| {
            validate_dims(s)?;
            // SAFETY: NOTE(review): assumes `mlx_array_new` has no preconditions and
            // returns a fresh handle whose ownership passes to `out` — confirm.
            let mut out = Self(unsafe { mlxrs_sys::mlx_array_new() });
            // SAFETY: `out.0` is a live handle owned by `out`. `dim_ptr(s)` is presumably
            // valid for `s.len()` dims while `s` is borrowed by this closure — confirm.
            check(unsafe {
                mlxrs_sys::mlx_zeros(
                    &mut out.0,
                    dim_ptr(s),
                    s.len(),
                    mlxrs_sys::mlx_dtype::from(T::DTYPE),
                    default_stream(),
                )
            })?;
            Ok(out)
        })
    }

    /// Creates an array of shape `shape` and dtype `T::DTYPE` from the scalar
    /// `value` via `mlx_full` on the default stream.
    ///
    /// # Errors
    ///
    /// Returns an error if building the 0-d array holding `value` fails, if
    /// `validate_dims` rejects the shape, or if `check` reports a failure from
    /// `mlx_full`.
    pub fn full<T>(shape: &impl IntoShape, value: T) -> Result<Self>
    where
        T: Element,
    {
        // 0-d array holding `value`. The empty shape has element product 1,
        // which matches the one-element slice.
        let scalar = Self::from_slice(&[value], &[0i32; 0])?;
        crate::error::ensure_handler_installed();
        shape.with_shape(|s| {
            validate_dims(s)?;
            // SAFETY: NOTE(review): assumes `mlx_array_new` has no preconditions and
            // returns a fresh handle whose ownership passes to `out` — confirm.
            let mut out = Self(unsafe { mlxrs_sys::mlx_array_new() });
            // SAFETY: `out.0` is a live handle owned by `out`, and `scalar` stays alive
            // for the whole call (it is dropped when `full` returns). `dim_ptr(s)` is
            // presumably valid for `s.len()` dims while `s` is borrowed — confirm.
            // NOTE(review): presumes mlx keeps its own reference to `scalar`'s data.
            check(unsafe {
                mlxrs_sys::mlx_full(
                    &mut out.0,
                    dim_ptr(s),
                    s.len(),
                    scalar.0,
                    mlxrs_sys::mlx_dtype::from(T::DTYPE),
                    default_stream(),
                )
            })?;
            Ok(out)
        })
    }

    /// Builds an `mlx_eye` array of dtype `T::DTYPE` from `n`, `m` and `k`.
    /// `m` defaults to `n`, and `k` is forwarded unchanged (presumably the
    /// diagonal offset).
    ///
    /// # Errors
    ///
    /// - `OutOfRange` if `n` or `m` does not fit in `i32`.
    /// - `OutOfRange` if `k == i32::MIN`, because mlx evaluates `-k`, which
    ///   overflows for that value.
    /// - Any failure `check` reports from `mlx_eye`.
    pub fn eye<T>(n: usize, m: Option<usize>, k: i32) -> Result<Self>
    where
        T: Element,
    {
        crate::error::ensure_handler_installed();
        let m = m.unwrap_or(n);
        let n_i32 = i32::try_from(n).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "Array::eye: n",
                "must fit in i32",
                format_smolstr!("{n}"),
            ))
        })?;
        let m_i32 = i32::try_from(m).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "Array::eye: m",
                "must fit in i32",
                format_smolstr!("{m}"),
            ))
        })?;
        // Negating i32::MIN overflows, so reject it before it reaches mlx.
        if k == i32::MIN {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "Array::eye: k",
                "must be greater than i32::MIN (mlx evaluates -k, which overflows there)",
                format_smolstr!("{k}"),
            )));
        }
        // SAFETY: NOTE(review): assumes `mlx_array_new` has no preconditions and
        // returns a fresh handle whose ownership passes to `out` — confirm.
        let mut out = Self(unsafe { mlxrs_sys::mlx_array_new() });
        // SAFETY: `out.0` is a live handle owned by `out`. All other arguments are
        // plain values that were range-checked above.
        check(unsafe {
            mlxrs_sys::mlx_eye(
                &mut out.0,
                n_i32,
                m_i32,
                k,
                mlxrs_sys::mlx_dtype::from(T::DTYPE),
                default_stream(),
            )
        })?;
        Ok(out)
    }

    /// Creates a ramp of dtype `T::DTYPE` via `mlx_arange`, starting at
    /// `start` and moving by `step` towards `stop`.
    ///
    /// For a finite `step`, the validation below sizes the result as
    /// `ceil((stop - start) / step)` elements, so `stop` is treated as
    /// exclusive. An empty range returns a shape-`[0]` array without calling
    /// mlx, and so does an infinite `step` pointing away from `stop`.
    ///
    /// # Errors
    ///
    /// - `UnsupportedDtype` when `T::DTYPE` is `Dtype::Bool`.
    /// - `NonFiniteScalar` when any of `start`/`stop`/`step` is NaN (the first
    ///   NaN in that order is reported), or when `start`/`stop` is infinite.
    /// - `OutOfRange` when:
    ///   - `step` is zero;
    ///   - the element count is non-finite or exceeds `i32::MAX`;
    ///   - `start` (and `start + step` for a finite `step`) is not representable
    ///     in the output dtype;
    ///   - an `I32`/`I64` ramp would overflow its accumulation.
    /// - Any failure `check` reports from `mlx_arange`.
    pub fn arange<T>(
        start: impl Into<f64>,
        stop: impl Into<f64>,
        step: impl Into<f64>,
    ) -> Result<Self>
    where
        T: Element,
    {
        let start: f64 = start.into();
        let stop: f64 = stop.into();
        let step: f64 = step.into();
        if T::DTYPE == Dtype::Bool {
            return Err(Error::UnsupportedDtype(UnsupportedDtypePayload::new(
                "Array::arange",
                Dtype::Bool,
                ARANGE_SUPPORTED_DTYPES,
            )));
        }
        // NaN is checked before infinity, so a NaN anywhere always produces the
        // NaN error. The reported value is the first NaN among start, stop, step.
        if start.is_nan() || stop.is_nan() || step.is_nan() {
            let v = if start.is_nan() {
                start
            } else if stop.is_nan() {
                stop
            } else {
                step
            };
            return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
                "Array::arange: start/stop/step must not be NaN",
                v,
            )));
        }
        // Infinite endpoints are rejected. An infinite step is handled separately below.
        if start.is_infinite() || stop.is_infinite() {
            return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
                "Array::arange: start/stop must be finite",
                if start.is_infinite() { start } else { stop },
            )));
        }
        if step == 0.0 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "Array::arange: step",
                "must be non-zero",
                format_smolstr!("{step}"),
            )));
        }
        if step.is_infinite() {
            // An infinite step pointing away from `stop` yields an empty range.
            let correct_dir = (step > 0.0 && start < stop) || (step < 0.0 && start > stop);
            if !correct_dir {
                return Self::from_slice::<T>(&[], &[0i32]);
            }
            // Only `start` is validated here; presumably mlx returns just that one
            // element for an infinite step in the right direction — confirm.
            if !representable_in(start, T::DTYPE) {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "Array::arange: start",
                    "must be representable in the output dtype",
                    format_smolstr!("{start}"),
                )));
            }
        } else {
            // Element count, computed the same way for the bound and emptiness checks.
            let real_size = ((stop - start) / step).ceil();
            if !real_size.is_finite() || real_size > f64::from(i32::MAX) {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "Array::arange: range length",
                    "must be finite and not exceed i32::MAX",
                    format_smolstr!("{real_size}"),
                )));
            }
            if real_size <= 0.0 {
                return Self::from_slice::<T>(&[], &[0i32]);
            }
            // The first two ramp values must convert into the output dtype.
            // Presumably mlx derives its per-element delta from `start` and
            // `start + step` after conversion — confirm.
            if !representable_in(start, T::DTYPE) || !representable_in(start + step, T::DTYPE) {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "Array::arange: endpoint",
                    "must be representable in the output dtype",
                    format_smolstr!("start={start}, start+step={}", start + step),
                )));
            }
            // Simulate the signed-integer ramp exactly in i128. The converted
            // endpoints truncate toward zero, which is what `trunc()` models.
            // `delta` is the integer step and `post_last` is the value one step
            // past the final element. Both must fit the dtype, or the
            // accumulation overflows.
            if matches!(T::DTYPE, Dtype::I32 | Dtype::I64) {
                let first = start.trunc() as i128;
                let next = (start + step).trunc() as i128;
                let delta = next - first;
                let post_last = first + (real_size as i128) * delta;
                let (plo, phi) = if T::DTYPE == Dtype::I64 {
                    (i128::from(i64::MIN), i128::from(i64::MAX))
                } else {
                    (i128::from(i32::MIN), i128::from(i32::MAX))
                };
                if !(plo..=phi).contains(&delta) || !(plo..=phi).contains(&post_last) {
                    return Err(Error::OutOfRange(OutOfRangePayload::new(
                        "Array::arange: range",
                        "overflows the signed integer accumulation",
                        format_smolstr!("start={start}, step={step}, len={real_size}"),
                    )));
                }
            }
        }
        crate::error::ensure_handler_installed();
        // SAFETY: NOTE(review): assumes `mlx_array_new` has no preconditions and
        // returns a fresh handle whose ownership passes to `out` — confirm.
        let mut out = Self(unsafe { mlxrs_sys::mlx_array_new() });
        // SAFETY: `out.0` is a live handle owned by `out`. The scalar arguments
        // passed all the range checks above.
        check(unsafe {
            mlxrs_sys::mlx_arange(
                &mut out.0,
                start,
                stop,
                step,
                mlxrs_sys::mlx_dtype::from(T::DTYPE),
                default_stream(),
            )
        })?;
        Ok(out)
    }

    /// Creates a `num`-element array of dtype `T::DTYPE` spanning `start` and
    /// `stop` via `mlx_linspace`.
    ///
    /// `num == 0` returns a shape-`[0]` array without calling mlx.
    ///
    /// # Errors
    ///
    /// `OutOfRange` when:
    /// - `num` does not fit in `i32`;
    /// - `num == 1` and `start` is not representable in `f32`, or is not
    ///   representable in the output dtype after rounding to `f32`;
    /// - `num >= 2`, the output dtype is not `F64`, and either endpoint is not
    ///   representable in `f32`;
    /// - `num >= 2`, the output dtype is an integer or half-precision type, and
    ///   the `f32` endpoints widened by a small relative margin leave that
    ///   dtype's range.
    ///
    /// Also returns any failure `check` reports from `mlx_linspace`.
    pub fn linspace<T>(start: impl Into<f64>, stop: impl Into<f64>, num: usize) -> Result<Self>
    where
        T: Element,
    {
        let start: f64 = start.into();
        let stop: f64 = stop.into();
        if num == 0 {
            return Self::from_slice::<T>(&[], &[0i32]);
        }
        let n_i32 = i32::try_from(num).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "Array::linspace: num",
                "must fit in i32",
                format_smolstr!("{num}"),
            ))
        })?;
        if num == 1 {
            // Single sample. `start` must survive the trip through f32 and then
            // into the output dtype.
            if !representable_in(start, Dtype::F32)
                || !representable_in(f64::from(start as f32), T::DTYPE)
            {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "Array::linspace: start",
                    "must be representable in the output dtype",
                    format_smolstr!("{start}"),
                )));
            }
        } else if T::DTYPE != Dtype::F64 {
            // Both endpoints are fed through an f32 ramp (per the error text
            // below), so they must fit f32 first.
            if !representable_in(start, Dtype::F32) || !representable_in(stop, Dtype::F32) {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "Array::linspace: endpoint",
                    "must be representable in f32 (the ramp inner dtype)",
                    format_smolstr!("start={start}, stop={stop}"),
                )));
            }
            // For integer and half-precision outputs, widen the f32-rounded
            // endpoint interval by `32 * f32::EPSILON` times the larger endpoint
            // magnitude. This presumably allows for rounding error in the f32
            // ramp, so intermediate values cannot land outside the output range.
            if !matches!(
                T::DTYPE,
                Dtype::F32 | Dtype::F64 | Dtype::Bool | Dtype::Complex64
            ) {
                let a = f64::from(start as f32);
                let b = f64::from(stop as f32);
                let margin = a.abs().max(b.abs()) * f64::from(f32::EPSILON) * 32.0;
                if !representable_in(a.max(b) + margin, T::DTYPE)
                    || !representable_in(a.min(b) - margin, T::DTYPE)
                {
                    return Err(Error::OutOfRange(OutOfRangePayload::new(
                        "Array::linspace: range",
                        "the f32 ramp leaves the integer/half output dtype range",
                        format_smolstr!("start={start}, stop={stop}"),
                    )));
                }
            }
        }
        crate::error::ensure_handler_installed();
        // SAFETY: NOTE(review): assumes `mlx_array_new` has no preconditions and
        // returns a fresh handle whose ownership passes to `out` — confirm.
        let mut out = Self(unsafe { mlxrs_sys::mlx_array_new() });
        // SAFETY: `out.0` is a live handle owned by `out`. The scalar arguments
        // passed the checks above.
        check(unsafe {
            mlxrs_sys::mlx_linspace(
                &mut out.0,
                start,
                stop,
                n_i32,
                mlxrs_sys::mlx_dtype::from(T::DTYPE),
                default_stream(),
            )
        })?;
        Ok(out)
    }

    /// Creates an array of shape `shape` and dtype `T::DTYPE` from the
    /// elements of `data`, via `mlx_array_new_data`.
    ///
    /// An empty shape has element product 1, so it describes a 0-d array
    /// holding exactly one element.
    ///
    /// # Errors
    ///
    /// - Any error from `validate_dims`.
    /// - `ArithmeticOverflow` if the product of the dims overflows `usize`.
    /// - `LengthMismatch` if that product differs from `data.len()`.
    /// - `OutOfRange` if the number of dims does not fit in `i32`.
    /// - Whatever `check_handle` reports for the handle returned by mlx.
    pub fn from_slice<T>(data: &[T], shape: &impl IntoShape) -> Result<Self>
    where
        T: Element,
    {
        crate::error::ensure_handler_installed();
        shape.with_shape(|s| {
            validate_dims(s)?;
            // Checked product of all dims. NOTE(review): `d as usize` assumes
            // `validate_dims` already rejected negative dims — confirm. The
            // overflow payload records the running product, the offending dim
            // and its index.
            let total: usize = s.iter().enumerate().try_fold(1usize, |acc, (idx, &d)| {
                let d_usize = d as usize;
                acc.checked_mul(d_usize).ok_or_else(|| {
                    Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                        "Array::from_slice: shape product",
                        "usize",
                        [
                            ("acc", acc as u64),
                            ("dim", d_usize as u64),
                            ("dim_index", idx as u64),
                        ],
                    ))
                })
            })?;
            if total != data.len() {
                return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                    "Array::from_slice: shape product vs data.len()",
                    total,
                    data.len(),
                )));
            }
            let dim_i32 = i32::try_from(s.len()).map_err(|_| {
                Error::OutOfRange(OutOfRangePayload::new(
                    "Array::from_slice: ndim",
                    "must fit in i32",
                    format_smolstr!("{}", s.len()),
                ))
            })?;
            // SAFETY: `data_ptr(data)` points to `data.len()` elements of `T`
            // (equal to the shape product checked above), or to the type's
            // sentinel when `data` is empty. `dim_ptr(s)` is presumably valid for
            // `dim_i32` dims — confirm. NOTE(review): the returned array does not
            // borrow `data`, so this presumes mlx copies the buffer — confirm.
            let arr = unsafe {
                mlxrs_sys::mlx_array_new_data(
                    data_ptr(data).cast::<std::ffi::c_void>(),
                    dim_ptr(s),
                    dim_i32,
                    mlxrs_sys::mlx_dtype::from(T::DTYPE),
                )
            };
            check_handle(arr)
        })
    }
}
/// Output dtypes accepted by `Array::arange`.
///
/// The list is attached to the `UnsupportedDtype` error that `arange`
/// returns for `Dtype::Bool`. It contains every dtype that `arange` does not
/// reject.
const ARANGE_SUPPORTED_DTYPES: &[Dtype] = &[
    // Unsigned integers.
    Dtype::U8, Dtype::U16, Dtype::U32, Dtype::U64,
    // Signed integers.
    Dtype::I8, Dtype::I16, Dtype::I32, Dtype::I64,
    // Floating point and complex.
    Dtype::F16, Dtype::BF16, Dtype::F32, Dtype::F64, Dtype::Complex64,
];
/// Half-open `[lo, hi)` interval, expressed in `f64`, spanned by the integer
/// `dtype`. Returns `None` for every non-integer dtype.
///
/// `hi` is the type's `MAX` converted to `f64` plus one. For `U64` and `I64`
/// the `MAX as f64` conversion already rounds up to 2^64 and 2^63, so the
/// `+ 1.0` is absorbed and `hi` is still the exact exclusive bound.
fn integer_cast_bounds(dtype: Dtype) -> Option<(f64, f64)> {
    let (lo, max) = match dtype {
        Dtype::U8 => (0.0, f64::from(u8::MAX)),
        Dtype::U16 => (0.0, f64::from(u16::MAX)),
        Dtype::U32 => (0.0, f64::from(u32::MAX)),
        Dtype::U64 => (0.0, u64::MAX as f64),
        Dtype::I8 => (f64::from(i8::MIN), f64::from(i8::MAX)),
        Dtype::I16 => (f64::from(i16::MIN), f64::from(i16::MAX)),
        Dtype::I32 => (f64::from(i32::MIN), f64::from(i32::MAX)),
        Dtype::I64 => (i64::MIN as f64, i64::MAX as f64),
        _ => return None,
    };
    Some((lo, max + 1.0))
}
/// Reports whether `v` fits the value range of `dtype`.
///
/// - `F64` and `Bool` accept every input.
/// - The other floating-point dtypes (`F32`, `Complex64`, `F16`, `BF16`) bound
///   `|v|` by their largest finite value.
/// - Integer dtypes require `v.trunc()` to lie in the half-open interval from
///   `integer_cast_bounds`.
/// - Any dtype without such bounds is rejected.
///
/// A NaN fails every comparison, so it is rejected by all dtypes except
/// `F64` and `Bool`.
fn representable_in(v: f64, dtype: Dtype) -> bool {
    let float_max = match dtype {
        Dtype::F64 | Dtype::Bool => return true,
        Dtype::F32 | Dtype::Complex64 => f64::from(f32::MAX),
        Dtype::F16 => f64::from(half::f16::MAX),
        Dtype::BF16 => f64::from(half::bf16::MAX),
        _ => {
            return match integer_cast_bounds(dtype) {
                Some((lo, hi)) => {
                    let whole = v.trunc();
                    lo <= whole && whole < hi
                }
                None => false,
            };
        }
    };
    v.abs() <= float_max
}