use std::{
cell::RefCell,
collections::TryReserveError,
ffi::{CStr, c_char, c_int, c_void},
panic::{AssertUnwindSafe, catch_unwind},
path::PathBuf,
ptr,
sync::{
OnceLock,
atomic::{AtomicBool, Ordering},
},
};
use smallvec::SmallVec;
use smol_str::SmolStr;
use crate::Dtype;
/// Kind of filesystem operation attached to an I/O failure (see `FileIoPayload`).
///
/// The `Display` impl renders each variant as a lowercase word.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileOp {
/// Displayed as `create`.
Create,
/// Displayed as `write`.
Write,
/// Displayed as `flush`.
Flush,
/// Displayed as `read`.
Read,
/// Displayed as `open`.
Open,
/// Displayed as `stat`.
Stat,
/// Displayed as `copy`.
Copy,
/// Displayed as `remove`.
Remove,
/// Displayed as `rename`.
Rename,
/// Displayed as `fsync`.
Fsync,
/// Caller-supplied static label; displayed verbatim.
Other(&'static str),
}
/// Display adapter that renders `(name, length)` pairs as `name=len`,
/// separated by `", "`. An empty list renders as the empty string.
///
/// Borrows a slice rather than a `Vec` so any contiguous storage can be
/// formatted; a `&Vec<_>` argument still coerces at the construction site.
struct MultiLengthsFmt<'a>(&'a [(&'static str, usize)]);

impl std::fmt::Display for MultiLengthsFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (idx, (name, len)) in self.0.iter().enumerate() {
            // Separator goes between entries only, never before the first.
            if idx > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{name}={len}")?;
        }
        Ok(())
    }
}
impl std::fmt::Display for FileOp {
    /// Writes the operation's lowercase label; `Other` is written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &str = match *self {
            Self::Create => "create",
            Self::Write => "write",
            Self::Flush => "flush",
            Self::Read => "read",
            Self::Open => "open",
            Self::Stat => "stat",
            Self::Copy => "copy",
            Self::Remove => "remove",
            Self::Rename => "rename",
            Self::Fsync => "fsync",
            Self::Other(custom) => custom,
        };
        f.write_str(label)
    }
}
/// Payload recording a dtype mismatch: the dtype that was expected and the
/// one that was received.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DtypeMismatchPayload {
    expected: Dtype,
    got: Dtype,
}

impl DtypeMismatchPayload {
    /// Builds a payload from the expected dtype and the one received.
    pub fn new(expected: Dtype, got: Dtype) -> Self {
        Self { expected, got }
    }

    /// The dtype that was expected.
    #[inline(always)]
    pub const fn expected(&self) -> Dtype {
        self.expected
    }

    /// The dtype that was received.
    #[inline(always)]
    pub const fn got(&self) -> Dtype {
        self.got
    }
}

impl std::fmt::Display for DtypeMismatchPayload {
    /// Renders both dtypes using their `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { expected, got } = self;
        write!(f, "expected {:?}, got {:?}", expected, got)
    }
}

impl std::error::Error for DtypeMismatchPayload {}
/// Payload naming the FFI function that returned a NULL handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FfiNullHandlePayload {
    fn_name: &'static str,
}

impl FfiNullHandlePayload {
    /// Builds a payload for the given function name.
    pub fn new(fn_name: &'static str) -> Self {
        Self { fn_name }
    }

    /// Name of the function that returned NULL.
    #[inline(always)]
    pub const fn fn_name(&self) -> &'static str {
        self.fn_name
    }
}

impl std::fmt::Display for FfiNullHandlePayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.fn_name;
        write!(f, "FFI: {} returned NULL handle", name)
    }
}

impl std::error::Error for FfiNullHandlePayload {}
/// Payload identifying a required field that was absent on a named type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MissingFieldPayload {
    type_name: &'static str,
    field: &'static str,
}

impl MissingFieldPayload {
    /// Builds a payload from the owning type's name and the missing field.
    pub fn new(type_name: &'static str, field: &'static str) -> Self {
        Self { type_name, field }
    }

    /// Name of the type the field belongs to.
    #[inline(always)]
    pub const fn type_name(&self) -> &'static str {
        self.type_name
    }

    /// Name of the missing field.
    #[inline(always)]
    pub const fn field(&self) -> &'static str {
        self.field
    }
}

impl std::fmt::Display for MissingFieldPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { type_name, field } = *self;
        write!(f, "{}: missing required field `{}`", type_name, field)
    }
}

impl std::error::Error for MissingFieldPayload {}
/// Payload describing an arithmetic overflow: where it happened, which kind
/// of operation overflowed, and optionally the named operands involved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArithmeticOverflowPayload {
    context: &'static str,
    op_type: &'static str,
    operands: SmallVec<[(&'static str, u64); 2]>,
}

impl ArithmeticOverflowPayload {
    /// Builds a payload with no operands recorded.
    pub fn new(context: &'static str, op_type: &'static str) -> Self {
        let operands = SmallVec::new();
        Self {
            context,
            op_type,
            operands,
        }
    }

    /// Builds a payload that also records `(name, value)` operand pairs.
    pub fn with_operands(
        context: &'static str,
        op_type: &'static str,
        operands: impl IntoIterator<Item = (&'static str, u64)>,
    ) -> Self {
        let operands = operands.into_iter().collect();
        Self {
            context,
            op_type,
            operands,
        }
    }

    /// Where the overflow occurred.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Label of the operation that overflowed.
    #[inline(always)]
    pub const fn op_type(&self) -> &'static str {
        self.op_type
    }

    /// Recorded `(name, value)` operand pairs; empty when none were given.
    #[inline(always)]
    pub fn operands(&self) -> &[(&'static str, u64)] {
        &self.operands
    }
}

impl std::fmt::Display for ArithmeticOverflowPayload {
    /// Short form without operands; otherwise appends ` with operands ` and a
    /// comma-separated `name=value` list.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.operands.is_empty() {
            return write!(f, "{}: overflow ({})", self.context, self.op_type);
        }
        write!(
            f,
            "{}: overflow ({}) with operands",
            self.context, self.op_type
        )?;
        f.write_str(" ")?;
        for (idx, (name, value)) in self.operands.iter().enumerate() {
            if idx > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{name}={value}")?;
        }
        Ok(())
    }
}

impl std::error::Error for ArithmeticOverflowPayload {}
/// Payload for an input that was empty where at least one element is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmptyInputPayload {
    context: &'static str,
}

impl EmptyInputPayload {
    /// Builds a payload naming the empty input.
    pub fn new(context: &'static str) -> Self {
        Self { context }
    }

    /// Label of the input that was empty.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }
}

impl std::fmt::Display for EmptyInputPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let what = self.context;
        write!(f, "{} is empty (at least one element required)", what)
    }
}

impl std::error::Error for EmptyInputPayload {}
/// Payload for a violated invariant: a context label plus the requirement
/// text that did not hold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvariantViolationPayload {
    context: &'static str,
    requirement: &'static str,
}

impl InvariantViolationPayload {
    /// Builds a payload from a context label and the requirement text.
    pub fn new(context: &'static str, requirement: &'static str) -> Self {
        Self {
            context,
            requirement,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Requirement text.
    #[inline(always)]
    pub const fn requirement(&self) -> &'static str {
        self.requirement
    }
}

impl std::fmt::Display for InvariantViolationPayload {
    /// Renders as `backend: <context> <requirement>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            requirement,
        } = *self;
        write!(f, "backend: {} {}", context, requirement)
    }
}

impl std::error::Error for InvariantViolationPayload {}
/// Payload for a rank mismatch, carrying the observed rank and full shape.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RankMismatchPayload {
    context: &'static str,
    actual: u32,
    actual_shape: Vec<usize>,
}

impl RankMismatchPayload {
    /// Builds a payload from a context label, the observed rank and shape.
    pub fn new(context: &'static str, actual: u32, actual_shape: Vec<usize>) -> Self {
        Self {
            context,
            actual,
            actual_shape,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Observed rank.
    #[inline(always)]
    pub const fn actual(&self) -> u32 {
        self.actual
    }

    /// Observed shape.
    pub fn actual_shape(&self) -> &[usize] {
        self.actual_shape.as_slice()
    }
}

impl std::fmt::Display for RankMismatchPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            actual,
            actual_shape,
        } = self;
        write!(
            f,
            "shape mismatch: {}: got rank {} (shape {:?})",
            context, actual, actual_shape
        )
    }
}

impl std::error::Error for RankMismatchPayload {}
/// Payload for a length mismatch between an expected and an actual count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthMismatchPayload {
    context: &'static str,
    expected: usize,
    actual: usize,
}

impl LengthMismatchPayload {
    /// Builds a payload from a context label and the two lengths.
    pub fn new(context: &'static str, expected: usize, actual: usize) -> Self {
        Self {
            context,
            expected,
            actual,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Expected length.
    #[inline(always)]
    pub const fn expected(&self) -> usize {
        self.expected
    }

    /// Actual length.
    #[inline(always)]
    pub const fn actual(&self) -> usize {
        self.actual
    }
}

impl std::fmt::Display for LengthMismatchPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            expected,
            actual,
        } = *self;
        write!(
            f,
            "shape mismatch: {}: expected length {}, got {}",
            context, expected, actual
        )
    }
}

impl std::error::Error for LengthMismatchPayload {}
/// Payload for a value that fell outside a stated requirement. The offending
/// value is stored as text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutOfRangePayload {
    context: &'static str,
    requirement: &'static str,
    value: SmolStr,
}

impl OutOfRangePayload {
    /// Builds a payload from a context label, the requirement text and the
    /// textual form of the offending value.
    pub fn new(context: &'static str, requirement: &'static str, value: impl Into<SmolStr>) -> Self {
        let value = value.into();
        Self {
            context,
            requirement,
            value,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Requirement text.
    #[inline(always)]
    pub const fn requirement(&self) -> &'static str {
        self.requirement
    }

    /// Textual form of the offending value.
    pub fn value(&self) -> &str {
        self.value.as_str()
    }
}

impl std::fmt::Display for OutOfRangePayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            requirement,
            value,
        } = self;
        write!(f, "shape mismatch: {}: {}, got {}", context, requirement, value)
    }
}

impl std::error::Error for OutOfRangePayload {}
/// Payload for a failed filesystem operation: a context label, the operation
/// kind, the path involved and the underlying `std::io::Error`.
#[derive(Debug)]
pub struct FileIoPayload {
    context: &'static str,
    op: FileOp,
    path: PathBuf,
    inner: std::io::Error,
}

impl FileIoPayload {
    /// Builds a payload from its four parts.
    pub fn new(context: &'static str, op: FileOp, path: PathBuf, inner: std::io::Error) -> Self {
        Self {
            context,
            op,
            path,
            inner,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// The operation that failed.
    #[inline(always)]
    pub const fn op(&self) -> FileOp {
        self.op
    }

    /// The path the operation was applied to.
    pub fn path(&self) -> &std::path::Path {
        self.path.as_path()
    }

    /// The underlying I/O error.
    pub fn inner(&self) -> &std::io::Error {
        &self.inner
    }
}

impl std::fmt::Display for FileIoPayload {
    /// Renders as `io: <context>: <op> <path>: <io error>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shown_path = self.path.display();
        write!(
            f,
            "io: {}: {} {}: {}",
            self.context, self.op, shown_path, self.inner
        )
    }
}

impl std::error::Error for FileIoPayload {
    /// Exposes the wrapped I/O error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.inner;
        Some(cause)
    }
}
/// Payload for a mismatch among several named lengths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiLengthMismatchPayload {
    context: &'static str,
    lengths: Vec<(&'static str, usize)>,
}

impl MultiLengthMismatchPayload {
    /// Builds a payload from a context label and `(name, length)` pairs.
    pub fn new(context: &'static str, lengths: Vec<(&'static str, usize)>) -> Self {
        Self { context, lengths }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// The recorded `(name, length)` pairs.
    pub fn lengths(&self) -> &[(&'static str, usize)] {
        self.lengths.as_slice()
    }
}

impl std::fmt::Display for MultiLengthMismatchPayload {
    /// Renders the context followed by the pairs as formatted by
    /// `MultiLengthsFmt`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = MultiLengthsFmt(&self.lengths);
        write!(
            f,
            "shape mismatch: {}: length mismatch — {}",
            self.context, rendered
        )
    }
}

impl std::error::Error for MultiLengthMismatchPayload {}
/// Payload for a save whose durability fsync failed: carries the `committed`
/// flag and the underlying I/O error.
#[cfg(feature = "lm")]
#[cfg_attr(docsrs, doc(cfg(feature = "lm")))]
#[derive(Debug)]
pub struct DurabilityWarningPayload {
    committed: bool,
    source: std::io::Error,
}

#[cfg(feature = "lm")]
impl DurabilityWarningPayload {
    /// Builds a payload from the commit flag and the I/O error.
    pub fn new(committed: bool, source: std::io::Error) -> Self {
        Self { committed, source }
    }

    /// The stored `committed` flag.
    #[inline(always)]
    pub const fn committed(&self) -> bool {
        self.committed
    }

    /// Borrows the underlying I/O error.
    pub fn source(&self) -> &std::io::Error {
        &self.source
    }

    /// Consumes the payload and returns the underlying I/O error.
    pub fn into_source(self) -> std::io::Error {
        let Self { source, .. } = self;
        source
    }
}

#[cfg(feature = "lm")]
impl std::fmt::Display for DurabilityWarningPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { committed, source } = self;
        write!(
            f,
            "save committed but durability fsync failed (committed={}): {}",
            committed, source
        )
    }
}

#[cfg(feature = "lm")]
impl std::error::Error for DurabilityWarningPayload {
    /// Exposes the wrapped I/O error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.source;
        Some(cause)
    }
}
/// Payload for a convert whose post-save extras copy partially failed.
///
/// Carries the `committed` flag, an optional save-time I/O warning, and the
/// boxed error produced by the copy step.
#[cfg(feature = "lm")]
#[cfg_attr(docsrs, doc(cfg(feature = "lm")))]
#[derive(Debug)]
pub struct ConvertPostSavePartialPayload {
    committed: bool,
    save_warning: Option<std::io::Error>,
    copy_error: Box<Error>,
}

#[cfg(feature = "lm")]
impl ConvertPostSavePartialPayload {
    /// Builds a payload; `copy_error` is boxed internally.
    pub fn new(committed: bool, save_warning: Option<std::io::Error>, copy_error: Error) -> Self {
        let copy_error = Box::new(copy_error);
        Self {
            committed,
            save_warning,
            copy_error,
        }
    }

    /// The stored `committed` flag.
    #[inline(always)]
    pub const fn committed(&self) -> bool {
        self.committed
    }

    /// The save-time warning, if one was recorded.
    pub fn save_warning(&self) -> Option<&std::io::Error> {
        match &self.save_warning {
            Some(warning) => Some(warning),
            None => None,
        }
    }

    /// The error produced by the copy step.
    pub fn copy_error(&self) -> &Error {
        &*self.copy_error
    }

    /// Consumes the payload, returning `(committed, save_warning, copy_error)`.
    pub fn into_parts(self) -> (bool, Option<std::io::Error>, Error) {
        let Self {
            committed,
            save_warning,
            copy_error,
        } = self;
        (committed, save_warning, *copy_error)
    }
}

#[cfg(feature = "lm")]
impl std::fmt::Display for ConvertPostSavePartialPayload {
    /// The message includes only the `committed` flag; the copy error is
    /// reachable through `source()`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let committed = self.committed;
        write!(
            f,
            "convert: save committed but post-save extras copy partially failed (committed={}); \
             destination directory may be incomplete (missing tokenizer/extras files)",
            committed
        )
    }
}

#[cfg(feature = "lm")]
impl std::error::Error for ConvertPostSavePartialPayload {
    /// Exposes the copy error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &Error = &self.copy_error;
        Some(cause as &(dyn std::error::Error + 'static))
    }
}
/// Error type defined by this module.
///
/// Most variants wrap a dedicated payload struct and are marked
/// `#[error(transparent)]`, so their message comes from the payload's own
/// `Display` impl. The remaining variants carry their message template inline.
/// The enum is `#[non_exhaustive]`, so matches outside the crate need a
/// wildcard arm.
#[derive(Debug, thiserror::Error, derive_more::IsVariant)]
#[non_exhaustive]
pub enum Error {
/// Expected and received dtypes differ.
#[error("dtype mismatch: expected {:?}, got {:?}", .0.expected(), .0.got())]
DtypeMismatch(DtypeMismatchPayload),
/// A dtype outside the supported set was supplied.
#[error(transparent)]
UnsupportedDtype(UnsupportedDtypePayload),
/// A raw dtype value that could not be mapped; carries the raw `u32`.
#[error("unknown dtype value from mlx: {0}")]
UnknownDtype(u32),
/// Out of memory.
#[error("out of memory")]
OutOfMemory,
/// The array is not contiguous.
#[error("array is not contiguous; M2 will add .contiguous() to materialize")]
NonContiguous,
/// An MLX operation failed; carries the classified op kind and message.
#[error(transparent)]
MlxOp(MlxOpPayload),
/// Message rendered with an `mlx-c: ` prefix.
#[error("mlx-c: {0}")]
MlxC(SmolStr),
/// Message rendered with an `mlx backend: ` prefix.
#[error("mlx backend: {0}")]
Backend(String),
/// See `MalformedDataPayload`.
#[error(transparent)]
MalformedData(MalformedDataPayload),
/// An FFI function returned a NULL handle.
#[error(transparent)]
FfiNullHandle(FfiNullHandlePayload),
/// A required field was missing.
#[error(transparent)]
MissingField(MissingFieldPayload),
/// An arithmetic operation overflowed.
#[error(transparent)]
ArithmeticOverflow(ArithmeticOverflowPayload),
/// An input was empty where at least one element is required.
#[error(transparent)]
EmptyInput(EmptyInputPayload),
/// An invariant did not hold.
#[error(transparent)]
InvariantViolation(InvariantViolationPayload),
/// Unexpected rank.
#[error(transparent)]
RankMismatch(RankMismatchPayload),
/// Expected and actual lengths differ.
#[error(transparent)]
LengthMismatch(LengthMismatchPayload),
/// A value fell outside its stated requirement.
#[error(transparent)]
OutOfRange(OutOfRangePayload),
/// A filesystem operation failed.
#[error(transparent)]
FileIo(FileIoPayload),
/// Several named lengths disagree.
#[error(transparent)]
MultiLengthMismatch(MultiLengthMismatchPayload),
/// Expected and actual shapes differ.
#[error(transparent)]
ShapePairMismatch(ShapePairMismatchPayload),
/// A divisibility requirement was not met.
#[error(transparent)]
DivisibilityConstraint(DivisibilityConstraintPayload),
/// A scalar was NaN or infinite.
#[error(transparent)]
NonFiniteScalar(NonFiniteScalarPayload),
/// A key was not found.
#[error(transparent)]
MissingKey(MissingKeyPayload),
/// A value did not match any supported enum value.
#[error(transparent)]
UnknownEnumValue(UnknownEnumValuePayload),
/// A key collided.
#[error(transparent)]
KeyCollision(KeyCollisionPayload),
/// Bytes contained an interior NUL.
#[error(transparent)]
InteriorNul(InteriorNulPayload),
/// An observed quantity exceeded a named cap.
#[error(transparent)]
CapExceeded(CapExceededPayload),
/// A memory reservation failed.
#[error(transparent)]
AllocFailure(AllocFailurePayload),
/// Parsing failed; wraps the underlying error.
#[error(transparent)]
Parse(ParsePayload),
/// An external operation failed; wraps the underlying error.
#[error(transparent)]
ExternalOp(ExternalOpPayload),
/// A decoder produced more elements than its cap.
#[error(transparent)]
BoundedDecode(BoundedDecodePayload),
/// See `LayerKeyedPayload`.
#[error(transparent)]
LayerKeyed(LayerKeyedPayload),
/// Tokenizer failure message (feature `tokenizer`).
#[cfg(feature = "tokenizer")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer")))]
#[error("tokenizer: {0}")]
Tokenizer(SmolStr),
/// Shard path collision; carries the colliding path (feature `lm`).
#[cfg(feature = "lm")]
#[cfg_attr(docsrs, doc(cfg(feature = "lm")))]
#[error("shard path collision: {}", .0.display())]
ShardPathCollision(std::path::PathBuf),
/// Save committed but the durability fsync failed (feature `lm`).
#[cfg(feature = "lm")]
#[cfg_attr(docsrs, doc(cfg(feature = "lm")))]
#[error(transparent)]
DurabilityWarning(DurabilityWarningPayload),
/// Save committed but the post-save extras copy partially failed (feature `lm`).
#[cfg(feature = "lm")]
#[cfg_attr(docsrs, doc(cfg(feature = "lm")))]
#[error(transparent)]
ConvertPostSavePartial(ConvertPostSavePartialPayload),
/// Save committed with post-save durability warnings (feature `lm`).
/// `#[from]` also generates `From<ConvertDurabilityWarnings> for Error`.
#[cfg(feature = "lm")]
#[cfg_attr(docsrs, doc(cfg(feature = "lm")))]
#[error(transparent)]
ConvertDurabilityWarnings(#[from] ConvertDurabilityWarnings),
}
// Compile-time guard: the build fails if `Error` grows beyond 96 bytes.
const _: () = assert!(
core::mem::size_of::<Error>() <= 96,
"Error enum exceeded 96 bytes — box or shrink the offending variant payload (see the #257 M8/M10 size note)"
);
/// Post-save durability warnings collected during a convert.
///
/// Holds the `committed` flag and up to three optional I/O errors (`save`,
/// `post_copy_file`, `post_copy_dir`).
#[cfg(feature = "lm")]
#[cfg_attr(docsrs, doc(cfg(feature = "lm")))]
#[derive(Debug)]
pub struct ConvertDurabilityWarnings {
    pub(crate) committed: bool,
    pub(crate) save: Option<std::io::Error>,
    pub(crate) post_copy_file: Option<std::io::Error>,
    pub(crate) post_copy_dir: Option<std::io::Error>,
}

#[cfg(feature = "lm")]
impl std::fmt::Display for ConvertDurabilityWarnings {
    /// The message includes only the `committed` flag; individual warnings
    /// are available through the accessors.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let committed = self.committed;
        write!(
            f,
            "convert: save committed but post-save durability warnings (committed={}); \
             destination is on-disk and load-correct, but one or more fsync boundaries returned a warning",
            committed
        )
    }
}

#[cfg(feature = "lm")]
impl std::error::Error for ConvertDurabilityWarnings {
    /// The cause is the first recorded warning, if any (see `first_warning`).
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self.first_warning() {
            Some(warning) => Some(warning as &(dyn std::error::Error + 'static)),
            None => None,
        }
    }
}

#[cfg(feature = "lm")]
impl ConvertDurabilityWarnings {
    /// Builds the collection from the commit flag and the three optional
    /// warnings.
    pub fn new(
        committed: bool,
        save: Option<std::io::Error>,
        post_copy_file: Option<std::io::Error>,
        post_copy_dir: Option<std::io::Error>,
    ) -> Self {
        Self {
            committed,
            save,
            post_copy_file,
            post_copy_dir,
        }
    }

    /// The stored `committed` flag.
    #[inline(always)]
    pub fn committed(&self) -> bool {
        self.committed
    }

    /// The `save` warning, if recorded.
    pub fn save(&self) -> Option<&std::io::Error> {
        self.save.as_ref()
    }

    /// The `post_copy_file` warning, if recorded.
    pub fn post_copy_file(&self) -> Option<&std::io::Error> {
        self.post_copy_file.as_ref()
    }

    /// The `post_copy_dir` warning, if recorded.
    pub fn post_copy_dir(&self) -> Option<&std::io::Error> {
        self.post_copy_dir.as_ref()
    }

    /// Consumes the value, returning
    /// `(committed, save, post_copy_file, post_copy_dir)`.
    pub fn into_parts(
        self,
    ) -> (
        bool,
        Option<std::io::Error>,
        Option<std::io::Error>,
        Option<std::io::Error>,
    ) {
        let Self {
            committed,
            save,
            post_copy_file,
            post_copy_dir,
        } = self;
        (committed, save, post_copy_file, post_copy_dir)
    }

    /// First recorded warning in the order `save`, `post_copy_file`,
    /// `post_copy_dir`; `None` when all three are absent.
    pub fn first_warning(&self) -> Option<&std::io::Error> {
        let candidates = [
            self.save.as_ref(),
            self.post_copy_file.as_ref(),
            self.post_copy_dir.as_ref(),
        ];
        candidates.iter().copied().flatten().next()
    }

    /// Number of warnings that are present (0 to 3).
    pub fn count(&self) -> usize {
        let present = [
            self.save.is_some(),
            self.post_copy_file.is_some(),
            self.post_copy_dir.is_some(),
        ];
        present.iter().filter(|flag| **flag).count()
    }
}
#[cfg(feature = "tokenizer")]
impl Error {
    /// Converts `message` into a `SmolStr` and wraps it in
    /// `Error::Tokenizer`.
    pub(crate) fn tokenizer(message: impl Into<SmolStr>) -> Self {
        let text: SmolStr = message.into();
        Self::Tokenizer(text)
    }
}
/// Coarse category of an MLX operation, used to label `MlxOpPayload` errors.
///
/// `MlxOpKind::parse_prefix` derives a category from the leading `[name]` tag
/// of a message; tags it does not recognize are kept verbatim in `Other`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MlxOpKind {
Matmul,
Reshape,
Broadcast,
Shape,
Slice,
Concat,
Gather,
Scatter,
Take,
Fft,
Quantize,
Dequantize,
Conv,
// `parse_prefix` maps pooling and reduction tags (`max_pool`, `sum`, `mean`, …) here.
Pool,
Eval,
Sort,
ArgSort,
Norm,
Linalg,
Random,
Transform,
Elementwise,
System,
// NOTE(review): no arm of `parse_prefix` in this file produces `Positional`;
// presumably it is constructed elsewhere — confirm.
Positional,
Distributed,
Io,
/// Unrecognized tag, stored as written (original case).
Other(SmolStr),
}
impl std::fmt::Display for MlxOpKind {
    /// Writes the category's lowercase label; `Other` is written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &str = match self {
            Self::Matmul => "matmul",
            Self::Reshape => "reshape",
            Self::Broadcast => "broadcast",
            Self::Shape => "shape",
            Self::Slice => "slice",
            Self::Concat => "concat",
            Self::Gather => "gather",
            Self::Scatter => "scatter",
            Self::Take => "take",
            Self::Fft => "fft",
            Self::Quantize => "quantize",
            Self::Dequantize => "dequantize",
            Self::Conv => "conv",
            // NOTE(review): every other variant prints its own lowercase name,
            // but `Pool` prints "reduce". Preserved as-is; confirm it is intended.
            Self::Pool => "reduce",
            Self::Eval => "eval",
            Self::Sort => "sort",
            Self::ArgSort => "argsort",
            Self::Norm => "norm",
            Self::Linalg => "linalg",
            Self::Random => "random",
            Self::Transform => "transform",
            Self::Elementwise => "elementwise",
            Self::System => "system",
            Self::Positional => "positional",
            Self::Distributed => "distributed",
            Self::Io => "io",
            Self::Other(custom) => custom.as_str(),
        };
        f.write_str(label)
    }
}
impl MlxOpKind {
    /// Classifies a message by its leading `[tag]`.
    ///
    /// Returns `None` when `msg` does not start with `[` or contains no `]`
    /// after it. Otherwise the text between the brackets is classified:
    ///
    /// 1. Tags starting with a known namespace (`linalg::`, `Primitive::`,
    ///    `import_function::`, `gpu::`, `metal::`, `Metal::`, `Event::`,
    ///    `Fence::`) are classified by that namespace alone.
    /// 2. Otherwise any leading `fast::` prefixes are removed, the tag is cut
    ///    at the first `::` and then at the first `.`, lowercased, and looked
    ///    up in the table below.
    /// 3. Unrecognized tags yield `Other` holding the cut (not lowercased) tag.
    pub fn parse_prefix(msg: &str) -> Option<Self> {
        let rest = msg.strip_prefix('[')?;
        let end = rest.find(']')?;
        let prefix = &rest[..end];

        // Namespace checks are case-sensitive and mutually exclusive, so their
        // relative order does not affect the result.
        if prefix.starts_with("linalg::") {
            return Some(Self::Linalg);
        }
        if prefix.starts_with("Primitive::") {
            return Some(Self::Eval);
        }
        const SYSTEM_NAMESPACES: [&str; 6] = [
            "import_function::",
            "gpu::",
            "metal::",
            "Metal::",
            "Event::",
            "Fence::",
        ];
        if SYSTEM_NAMESPACES.iter().any(|ns| prefix.starts_with(*ns)) {
            return Some(Self::System);
        }

        // Repeatedly drops leading `fast::` segments.
        let stripped = prefix.trim_start_matches("fast::");
        // Keep only the part before the first `::`, then before the first `.`.
        let no_method = stripped.split("::").next().unwrap_or(stripped);
        let primary = no_method.split('.').next().unwrap_or(no_method);
        let lower = primary.to_lowercase();

        // The lookup key never contains `::` (it was cut above), so a
        // `name::method` pattern can never match here; e.g. `[rope::vjp]` is
        // looked up as `rope`.
        // NOTE(review): the table previously listed `"rope::vjp"` under
        // `System`; that pattern was unreachable and has been dropped. If
        // `[rope::vjp]` was meant to classify as `System`, it needs an
        // explicit check before the cut.
        Some(match lower.as_str() {
            "matmul" | "addmm" | "quantizedmatmul" | "quantized_matmul" | "block_masked_mm"
            | "blockmaskedmm" | "gather_mm" | "gathermm" | "gather_qmm" | "gatherqmm" | "qqmm"
            | "qqmatmul" | "segmented_mm" | "inner" | "tensordot" | "kron" | "gemm_and_bias" => {
                Self::Matmul
            }
            "reshape" | "unflatten" | "flatten" | "expand_dims" | "squeeze" | "transpose"
            | "swapaxes" | "moveaxis" | "view" => Self::Reshape,
            "broadcast" | "broadcast_shapes" | "broadcast_to" | "broadcast_arrays" => {
                Self::Broadcast
            }
            "shape" => Self::Shape,
            "slice"
            | "slice_update"
            | "sliceupdate"
            | "dynamicslice"
            | "dynamicsliceupdate"
            | "dynamic_slice"
            | "dynamic_slice_update"
            | "split"
            | "trace"
            | "diag"
            | "diagonal"
            | "tril"
            | "triu" => Self::Slice,
            "concatenate" | "stack" | "repeat" | "meshgrid" => Self::Concat,
            "gather" | "gather_axis" | "gatheraxis" => Self::Gather,
            "scatter" | "scatter_axis" | "scatter_add_axis" | "scatter_add" | "scatter_max"
            | "scatter_min" | "scatter_prod" | "scatteraxis" | "masked_scatter"
            | "maskedscatter" | "put_along_axis" => Self::Scatter,
            "take" | "take_along_axis" => Self::Take,
            "fft" | "ifft" | "rfft" | "irfft" | "fft2" | "ifft2" | "fftn" | "ifftn" | "fftfreq"
            | "rfftfreq" | "fftshift" | "ifftshift" | "hadamard" | "hadamard_transform" => {
                Self::Fft
            }
            "quantize" | "block_quantized" | "from_fp8" | "to_fp8" | "quantize_dequantize" => {
                Self::Quantize
            }
            "dequantize" => Self::Dequantize,
            "conv" | "conv1d" | "conv2d" | "conv3d" | "conv_transpose" | "convolution" => {
                Self::Conv
            }
            "max_pool" | "avg_pool" | "reduce" | "all" | "any" | "sum" | "prod" | "mean" | "var"
            | "max" | "min" | "median" | "logsumexp" | "logcumsumexp" | "cumsum" | "cumprod"
            | "cummax" | "cummin" | "number_of_elements" | "softmax" | "topk" => Self::Pool,
            "compiledarray" | "compiled" | "compile" | "eval" | "async_eval" | "copy"
            | "nanequal" => Self::Eval,
            "sort" | "partition" => Self::Sort,
            "argsort" | "argpartition" | "argmax" | "argmin" => Self::ArgSort,
            "mlx_norm"
            | "layer_norm"
            | "rms_norm"
            | "group_norm"
            | "rope"
            | "scaled_dot_product_attention"
            | "scale_dot_product_attention"
            | "vjp_layer_norm" => Self::Norm,
            "grad" | "vjp" | "jvp" | "vmap" | "pad" => Self::Transform,
            "arange"
            | "full"
            | "eye"
            | "linspace"
            | "uniform"
            | "normal"
            | "trunc_normal"
            | "bernoulli"
            | "categorical"
            | "multivariate_normal"
            | "laplace"
            | "randint"
            | "bits"
            | "finfo"
            | "iinfo"
            | "randombits" => Self::Random,
            "astype"
            | "nan_to_num"
            | "negative"
            | "floor"
            | "bitwise_invert"
            | "divmod"
            | "roll"
            | "abs"
            | "binary_float"
            | "binary_int"
            | "unary_fp"
            | "unary_int"
            | "unary_real"
            | "extract_tensor_data" => Self::Elementwise,
            "event"
            | "streamcontext"
            | "threadpool"
            | "set_default_device"
            | "set_default_stream"
            | "default_stream"
            | "deserialize_variant"
            | "export_function"
            | "import_function"
            | "metal"
            | "malloc"
            | "new_stream"
            | "metal_kernel"
            | "cuda_kernel"
            | "custom_kernel" => Self::System,
            "cholesky" | "eig" | "eigh" | "inverse" | "luf" | "qrf" | "svd" => Self::Linalg,
            "allgather" | "allreduce" | "reducescatter" | "send" | "recv" | "sum_scatter"
            | "distributed" | "mpi" | "nccl" | "jaccl" | "ring" => Self::Distributed,
            "load" | "save" | "load_safetensors" | "save_safetensors" | "load_gguf"
            | "save_gguf" | "safetensor" | "read" | "write" | "from_str" => Self::Io,
            _ => Self::Other(SmolStr::new(primary)),
        })
    }
}
/// Payload for a failed MLX operation: its category and the message text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MlxOpPayload {
    op: MlxOpKind,
    message: SmolStr,
}

impl MlxOpPayload {
    /// Builds a payload from an op category and a message.
    pub fn new(op: MlxOpKind, message: impl Into<SmolStr>) -> Self {
        let message = message.into();
        Self { op, message }
    }

    /// The op category.
    pub fn op(&self) -> &MlxOpKind {
        &self.op
    }

    /// The message text.
    #[inline(always)]
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl std::fmt::Display for MlxOpPayload {
    /// Renders as `mlx <op>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { op, message } = self;
        write!(f, "mlx {}: {}", op, message)
    }
}

impl std::error::Error for MlxOpPayload {}
/// Payload for a key that was not found.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingKeyPayload {
    context: &'static str,
    key: SmolStr,
}

impl MissingKeyPayload {
    /// Builds a payload from a context label and the missing key.
    pub fn new(context: &'static str, key: impl Into<SmolStr>) -> Self {
        let key = key.into();
        Self { context, key }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// The key that was not found.
    #[inline(always)]
    pub fn key(&self) -> &str {
        self.key.as_str()
    }
}

impl std::fmt::Display for MissingKeyPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { context, key } = self;
        write!(f, "{}: key `{}` not found", context, key)
    }
}

impl std::error::Error for MissingKeyPayload {}
/// Payload for a value that matched none of a type's supported values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownEnumValuePayload {
    type_name: &'static str,
    value: SmolStr,
    supported: &'static [&'static str],
}

impl UnknownEnumValuePayload {
    /// Builds a payload from the type name, the offending value and the list
    /// of supported values.
    pub fn new(
        type_name: &'static str,
        value: impl Into<SmolStr>,
        supported: &'static [&'static str],
    ) -> Self {
        let value = value.into();
        Self {
            type_name,
            value,
            supported,
        }
    }

    /// Name of the type being parsed.
    #[inline(always)]
    pub const fn type_name(&self) -> &'static str {
        self.type_name
    }

    /// The unrecognized value.
    #[inline(always)]
    pub fn value(&self) -> &str {
        self.value.as_str()
    }

    /// The supported values.
    #[inline(always)]
    pub const fn supported(&self) -> &'static [&'static str] {
        self.supported
    }
}

impl std::fmt::Display for UnknownEnumValuePayload {
    /// The supported list is rendered with its `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            type_name,
            value,
            supported,
        } = self;
        write!(
            f,
            "{}: unknown value `{}` (supported: {:?})",
            type_name, value, supported
        )
    }
}

impl std::error::Error for UnknownEnumValuePayload {}
/// Payload for a scalar reported as non-finite (NaN or infinite).
///
/// Only `PartialEq` is derived because the payload stores an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NonFiniteScalarPayload {
    context: &'static str,
    value: f64,
}

impl NonFiniteScalarPayload {
    /// Builds a payload from a context label and the offending value.
    pub fn new(context: &'static str, value: f64) -> Self {
        Self { context, value }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// The offending value.
    #[inline(always)]
    pub const fn value(&self) -> f64 {
        self.value
    }
}

impl std::fmt::Display for NonFiniteScalarPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { context, value } = *self;
        write!(f, "{}: value is non-finite (NaN or Inf): {}", context, value)
    }
}

impl std::error::Error for NonFiniteScalarPayload {}
/// Payload for a key collision.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyCollisionPayload {
    context: &'static str,
    key: SmolStr,
}

impl KeyCollisionPayload {
    /// Builds a payload from a context label and the colliding key.
    pub fn new(context: &'static str, key: impl Into<SmolStr>) -> Self {
        let key = key.into();
        Self { context, key }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// The colliding key.
    #[inline(always)]
    pub fn key(&self) -> &str {
        self.key.as_str()
    }
}

impl std::fmt::Display for KeyCollisionPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { context, key } = self;
        write!(f, "{}: key `{}` collides", context, key)
    }
}

impl std::error::Error for KeyCollisionPayload {}
/// Payload for bytes that contained an interior NUL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InteriorNulPayload {
    context: &'static str,
    bytes_kind: &'static str,
}

impl InteriorNulPayload {
    /// Builds a payload from a context label and a label for the bytes.
    pub fn new(context: &'static str, bytes_kind: &'static str) -> Self {
        Self {
            context,
            bytes_kind,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Label describing the bytes that contained the NUL.
    #[inline(always)]
    pub const fn bytes_kind(&self) -> &'static str {
        self.bytes_kind
    }
}

impl std::fmt::Display for InteriorNulPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            bytes_kind,
        } = *self;
        write!(f, "{}: {} contains an interior NUL byte", context, bytes_kind)
    }
}

impl std::error::Error for InteriorNulPayload {}
/// Payload for an observed quantity that exceeded a named cap.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CapExceededPayload {
    context: &'static str,
    cap_name: &'static str,
    cap: u64,
    observed: u64,
}

impl CapExceededPayload {
    /// Builds a payload from a context label, the cap's name and value, and
    /// the observed quantity.
    pub fn new(context: &'static str, cap_name: &'static str, cap: u64, observed: u64) -> Self {
        Self {
            context,
            cap_name,
            cap,
            observed,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Name of the cap.
    #[inline(always)]
    pub const fn cap_name(&self) -> &'static str {
        self.cap_name
    }

    /// Value of the cap.
    #[inline(always)]
    pub const fn cap(&self) -> u64 {
        self.cap
    }

    /// Observed quantity.
    #[inline(always)]
    pub const fn observed(&self) -> u64 {
        self.observed
    }
}

impl std::fmt::Display for CapExceededPayload {
    /// Renders as `<context>: observed <observed> exceeds cap <cap_name> (<cap>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            cap_name,
            cap,
            observed,
        } = *self;
        write!(
            f,
            "{}: observed {} exceeds cap {} ({})",
            context, observed, cap_name, cap
        )
    }
}

impl std::error::Error for CapExceededPayload {}
/// Payload for a mismatch between an expected and an actual shape.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShapePairMismatchPayload {
    context: &'static str,
    expected: SmallVec<[usize; 2]>,
    actual: SmallVec<[usize; 2]>,
}

impl ShapePairMismatchPayload {
    /// Builds a payload from a context label and the two shapes.
    pub fn new(
        context: &'static str,
        expected: impl Into<SmallVec<[usize; 2]>>,
        actual: impl Into<SmallVec<[usize; 2]>>,
    ) -> Self {
        let expected = expected.into();
        let actual = actual.into();
        Self {
            context,
            expected,
            actual,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Expected shape.
    #[inline(always)]
    pub fn expected(&self) -> &[usize] {
        self.expected.as_slice()
    }

    /// Actual shape.
    #[inline(always)]
    pub fn actual(&self) -> &[usize] {
        self.actual.as_slice()
    }
}

impl std::fmt::Display for ShapePairMismatchPayload {
    /// Both shapes are rendered as slices in `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let want: &[usize] = self.expected.as_slice();
        let have: &[usize] = self.actual.as_slice();
        write!(
            f,
            "shape mismatch: {}: expected {:?}, got {:?}",
            self.context, want, have
        )
    }
}

impl std::error::Error for ShapePairMismatchPayload {}
/// Payload for a violated divisibility requirement between two named
/// quantities.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DivisibilityConstraintPayload {
    context: &'static str,
    name_dividend: &'static str,
    name_divisor: &'static str,
    dividend: u64,
    divisor: u64,
}

impl DivisibilityConstraintPayload {
    /// Builds a payload. Arguments come as name/value pairs: dividend first,
    /// then divisor.
    pub fn new(
        context: &'static str,
        name_dividend: &'static str,
        dividend: u64,
        name_divisor: &'static str,
        divisor: u64,
    ) -> Self {
        Self {
            context,
            name_dividend,
            name_divisor,
            dividend,
            divisor,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Name of the dividend.
    #[inline(always)]
    pub const fn name_dividend(&self) -> &'static str {
        self.name_dividend
    }

    /// Name of the divisor.
    #[inline(always)]
    pub const fn name_divisor(&self) -> &'static str {
        self.name_divisor
    }

    /// Value of the dividend.
    #[inline(always)]
    pub const fn dividend(&self) -> u64 {
        self.dividend
    }

    /// Value of the divisor.
    #[inline(always)]
    pub const fn divisor(&self) -> u64 {
        self.divisor
    }
}

impl std::fmt::Display for DivisibilityConstraintPayload {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            name_dividend,
            name_divisor,
            dividend,
            divisor,
        } = *self;
        write!(
            f,
            "{}: {} ({}) must be divisible by {} ({})",
            context, name_dividend, dividend, name_divisor, divisor
        )
    }
}

impl std::error::Error for DivisibilityConstraintPayload {}
/// Payload for a dtype outside a supported set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UnsupportedDtypePayload {
    context: &'static str,
    dtype: Dtype,
    supported: &'static [Dtype],
}

impl UnsupportedDtypePayload {
    /// Builds a payload from a context label, the offending dtype and the
    /// supported set. Usable in `const` contexts.
    pub const fn new(context: &'static str, dtype: Dtype, supported: &'static [Dtype]) -> Self {
        Self {
            context,
            dtype,
            supported,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// The offending dtype.
    #[inline(always)]
    pub const fn dtype(&self) -> Dtype {
        self.dtype
    }

    /// The supported dtypes.
    #[inline(always)]
    pub const fn supported(&self) -> &'static [Dtype] {
        self.supported
    }
}

impl std::fmt::Display for UnsupportedDtypePayload {
    /// The dtype and the supported set are rendered in `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            dtype,
            supported,
        } = self;
        write!(
            f,
            "{}: unsupported dtype {:?} (supported: {:?})",
            context, dtype, supported
        )
    }
}

impl std::error::Error for UnsupportedDtypePayload {}
/// Payload for a failed memory reservation, wrapping the `TryReserveError`.
#[derive(Debug)]
pub struct AllocFailurePayload {
    context: &'static str,
    item: &'static str,
    count: u64,
    inner: TryReserveError,
}

impl AllocFailurePayload {
    /// Builds a payload from a context label, the item label, the requested
    /// count and the underlying reservation error.
    pub fn new(
        context: &'static str,
        item: &'static str,
        count: u64,
        inner: TryReserveError,
    ) -> Self {
        Self {
            context,
            item,
            count,
            inner,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Label of the item being reserved.
    #[inline(always)]
    pub const fn item(&self) -> &'static str {
        self.item
    }

    /// Requested count.
    #[inline(always)]
    pub const fn count(&self) -> u64 {
        self.count
    }

    /// The underlying reservation error.
    #[inline(always)]
    pub fn inner(&self) -> &TryReserveError {
        &self.inner
    }
}

impl std::fmt::Display for AllocFailurePayload {
    /// Renders as `<context>: reservation for <count> <item> failed: <inner>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            item,
            count,
            inner,
        } = self;
        write!(
            f,
            "{}: reservation for {} {} failed: {}",
            context, count, item, inner
        )
    }
}

impl std::error::Error for AllocFailurePayload {
    /// Exposes the reservation error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.inner;
        Some(cause)
    }
}
/// Payload for a parse failure, wrapping the underlying boxed error.
#[derive(Debug)]
pub struct ParsePayload {
    context: &'static str,
    input_kind: &'static str,
    inner: Box<dyn std::error::Error + Send + Sync>,
}

impl ParsePayload {
    /// Builds a payload from a context label, a label for the input being
    /// parsed, and anything convertible into a boxed error.
    pub fn new(
        context: &'static str,
        input_kind: &'static str,
        inner: impl Into<Box<dyn std::error::Error + Send + Sync>>,
    ) -> Self {
        let inner = inner.into();
        Self {
            context,
            input_kind,
            inner,
        }
    }

    /// Context label.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }

    /// Label of the input that was being parsed.
    #[inline(always)]
    pub const fn input_kind(&self) -> &'static str {
        self.input_kind
    }

    /// The underlying error.
    pub fn inner(&self) -> &(dyn std::error::Error + Send + Sync + 'static) {
        &*self.inner
    }
}

impl std::fmt::Display for ParsePayload {
    /// Renders as `<context>: parse <input_kind> failed: <inner>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            input_kind,
            inner,
        } = self;
        write!(f, "{}: parse {} failed: {}", context, input_kind, inner)
    }
}

impl std::error::Error for ParsePayload {
    /// Exposes the wrapped error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        Some(&*self.inner)
    }
}
/// Error payload for a failed external operation, carrying a boxed underlying error.
#[derive(Debug)]
pub struct ExternalOpPayload {
    /// Label printed first in the rendered message.
    context: &'static str,
    /// Name of the operation (printed after the word `external`).
    op_kind: &'static str,
    /// Boxed underlying error; also exposed through `source()`.
    inner: Box<dyn std::error::Error + Send + Sync>,
}
impl ExternalOpPayload {
    /// Builds a payload; `inner` may be anything convertible into a boxed
    /// `Error + Send + Sync`.
    pub fn new(
        context: &'static str,
        op_kind: &'static str,
        inner: impl Into<Box<dyn std::error::Error + Send + Sync>>,
    ) -> Self {
        let inner = inner.into();
        Self {
            context,
            op_kind,
            inner,
        }
    }
    /// Label supplied at construction.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }
    /// Operation name supplied at construction.
    #[inline(always)]
    pub const fn op_kind(&self) -> &'static str {
        self.op_kind
    }
    /// Borrows the boxed underlying error.
    pub fn inner(&self) -> &(dyn std::error::Error + Send + Sync + 'static) {
        &*self.inner
    }
}
impl std::fmt::Display for ExternalOpPayload {
    /// Writes `<context>: external <op_kind> failed: <inner>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            op_kind,
            inner,
        } = self;
        write!(f, "{context}: external {op_kind} failed: {inner}")
    }
}
impl std::error::Error for ExternalOpPayload {
    /// Always reports the boxed underlying error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &*self.inner;
        Some(cause)
    }
}
/// Error payload pairing a decoder's element cap with the count it actually observed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BoundedDecodePayload {
    /// Label printed first in the rendered message.
    context: &'static str,
    /// The configured cap.
    cap: u64,
    /// The element count that was observed.
    observed: u64,
}
impl BoundedDecodePayload {
    /// Builds a payload from its parts; no relation between `cap` and
    /// `observed` is checked here.
    pub fn new(context: &'static str, cap: u64, observed: u64) -> Self {
        Self {
            observed,
            cap,
            context,
        }
    }
    /// Label supplied at construction.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }
    /// The cap supplied at construction.
    #[inline(always)]
    pub const fn cap(&self) -> u64 {
        self.cap
    }
    /// The observed count supplied at construction.
    #[inline(always)]
    pub const fn observed(&self) -> u64 {
        self.observed
    }
}
impl std::fmt::Display for BoundedDecodePayload {
    /// Writes `<context>: decoder produced <observed> elements (cap <cap>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context,
            cap,
            observed,
        } = self;
        write!(
            f,
            "{context}: decoder produced {observed} elements (cap {cap})"
        )
    }
}
// No underlying cause is exposed, so the default `source()` (None) is kept.
impl std::error::Error for BoundedDecodePayload {}
/// Error payload that tags another crate [`Error`] with a layer name.
#[derive(Debug)]
pub struct LayerKeyedPayload {
    /// Layer name printed between backticks in the rendered message.
    layer: SmolStr,
    /// The wrapped error, boxed; also exposed through `source()`.
    inner: Box<Error>,
}
impl LayerKeyedPayload {
    /// Wraps `inner`, tagging it with `layer`.
    pub fn new(layer: impl Into<SmolStr>, inner: Error) -> Self {
        let layer = layer.into();
        let inner = Box::new(inner);
        Self { layer, inner }
    }
    /// The layer name as a string slice.
    #[inline(always)]
    pub fn layer(&self) -> &str {
        self.layer.as_str()
    }
    /// Borrows the wrapped error.
    #[inline(always)]
    pub fn inner(&self) -> &Error {
        &self.inner
    }
}
impl std::fmt::Display for LayerKeyedPayload {
    /// Writes ``layer `<layer>`: <inner>``.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { layer, inner } = self;
        write!(f, "layer `{layer}`: {inner}")
    }
}
impl std::error::Error for LayerKeyedPayload {
    /// Always reports the wrapped error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &*self.inner;
        Some(cause)
    }
}
/// Error payload for malformed input, described by two static strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MalformedDataPayload {
    /// Label printed right after the `malformed data:` prefix.
    context: &'static str,
    /// Description printed last in the rendered message.
    detail: &'static str,
}
impl MalformedDataPayload {
    /// Builds a payload; usable in `const` contexts.
    pub const fn new(context: &'static str, detail: &'static str) -> Self {
        Self { detail, context }
    }
    /// Label supplied at construction.
    #[inline(always)]
    pub const fn context(&self) -> &'static str {
        self.context
    }
    /// Detail text supplied at construction.
    #[inline(always)]
    pub const fn detail(&self) -> &'static str {
        self.detail
    }
}
impl std::fmt::Display for MalformedDataPayload {
    /// Writes `malformed data: <context>: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { context, detail } = self;
        write!(f, "malformed data: {context}: {detail}")
    }
}
// No underlying cause is exposed, so the default `source()` (None) is kept.
impl std::error::Error for MalformedDataPayload {}
/// Result alias used throughout this module, with the error type fixed to [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Creates an empty `Vec<T>` with room for exactly `cap` elements reserved fallibly.
///
/// # Errors
/// Returns [`Error::OutOfMemory`] when the reservation fails; the underlying
/// `TryReserveError` is discarded.
#[cfg(any(feature = "lm", feature = "embeddings"))]
pub(crate) fn try_with_capacity<T>(cap: usize) -> Result<Vec<T>> {
    let mut buf: Vec<T> = Vec::new();
    match buf.try_reserve_exact(cap) {
        Ok(()) => Ok(buf),
        Err(_) => Err(Error::OutOfMemory),
    }
}
/// Copies `slice` into a new `Vec<T>` whose storage is reserved fallibly up front.
///
/// # Errors
/// Propagates the error from `try_with_capacity` when the reservation fails.
#[cfg(feature = "vlm")]
pub(crate) fn try_to_vec<T: Clone>(slice: &[T]) -> Result<Vec<T>> {
    try_with_capacity(slice.len()).map(|mut copied| {
        copied.extend_from_slice(slice);
        copied
    })
}
/// Appends a clone of every element of `slice` to `v`, reserving the extra room fallibly first.
///
/// # Errors
/// Returns [`Error::OutOfMemory`] when the reservation fails; `v` is not
/// extended in that case.
#[cfg(feature = "lm")]
pub(crate) fn try_extend_from_slice<T: Clone>(v: &mut Vec<T>, slice: &[T]) -> Result<()> {
    if v.try_reserve(slice.len()).is_err() {
        return Err(Error::OutOfMemory);
    }
    v.extend_from_slice(slice);
    Ok(())
}
thread_local! {
/// Per-thread slot holding the most recently recorded [`Error`], if any.
/// Starts out as `None`.
pub(crate) static LAST: RefCell<Option<Error>> = const { RefCell::new(None) };
}
/// Removes and returns the error stored in this thread's [`LAST`] slot,
/// leaving `None` behind.
///
/// # Panics
/// Panics if the slot is already borrowed.
#[inline]
pub(crate) fn take_last() -> Option<Error> {
    LAST.with(|slot| slot.take())
}
/// Stores `err` in this thread's [`LAST`] slot, replacing any previous value.
///
/// Best effort: if the thread-local is inaccessible or the slot is currently
/// borrowed, `err` is silently dropped instead of panicking.
#[inline]
pub(crate) fn set_last(err: Error) {
    let _ = LAST.try_with(|slot| match slot.try_borrow_mut() {
        Ok(mut stored) => *stored = Some(err),
        Err(_) => {}
    });
}
/// Returns the `Display` rendering of the error in this thread's [`LAST`]
/// slot without removing it.
///
/// Yields `None` when the slot is empty, when it is currently mutably
/// borrowed, or when the thread-local cannot be accessed.
pub(crate) fn last_error_message() -> Option<String> {
    let rendered = LAST.try_with(|slot| {
        let stored = slot.try_borrow().ok()?;
        stored.as_ref().map(|err| err.to_string())
    });
    rendered.unwrap_or(None)
}
/// Flag that starts `false` and is set to `true` once the error handler has been
/// registered, either by `install_handler` or by the fallback in
/// `ensure_handler_installed_slow`.
pub(crate) static INIT_VIA_CTOR: AtomicBool = AtomicBool::new(false);
/// Error callback passed to `mlx_set_error_handler`; records the reported
/// message in this thread's [`LAST`] slot.
///
/// * `msg` — C string describing the failure. A null pointer is tolerated and
///   recorded as a placeholder message instead of being dereferenced.
/// * `_data` — user-data pointer; unused.
///
/// The message is classified with `MlxOpKind::parse_prefix`: a recognised
/// prefix yields `Error::MlxOp`, anything else `Error::MlxC`. A message that
/// starts with `mlx_closure` and contains `returned a non-zero value` does not
/// overwrite an error that is already stored.
///
/// Everything is best effort: an inaccessible thread-local or an already
/// borrowed slot is ignored, and the whole body runs inside `catch_unwind` so
/// a panic is swallowed rather than unwinding out of this `extern "C"` function.
extern "C" fn handler(msg: *const c_char, _data: *mut c_void) {
    let _ = catch_unwind(AssertUnwindSafe(|| {
        let s = if msg.is_null() {
            // `CStr::from_ptr` on a null pointer is undefined behaviour, which
            // `catch_unwind` cannot contain, so substitute a placeholder.
            std::borrow::Cow::Borrowed("mlx error handler received a null message")
        } else {
            // SAFETY: `msg` is non-null (checked above). It is assumed to point
            // to a NUL-terminated string that stays valid for this call, as
            // the caller of an error callback is expected to provide.
            unsafe { CStr::from_ptr(msg) }.to_string_lossy()
        };
        let is_generic_closure_wrapper =
            s.starts_with("mlx_closure") && s.contains("returned a non-zero value");
        let _ = LAST.try_with(|c| {
            if let Ok(mut g) = c.try_borrow_mut() {
                // Keep the more specific error already recorded rather than
                // replacing it with the generic closure-wrapper message.
                if is_generic_closure_wrapper && g.is_some() {
                    return;
                }
                let payload: SmolStr = s.as_ref().into();
                *g = Some(match MlxOpKind::parse_prefix(&payload) {
                    Some(op) => Error::MlxOp(MlxOpPayload::new(op, payload)),
                    None => Error::MlxC(payload),
                });
            }
        });
    }));
}
/// Constructor that registers [`handler`] as the mlx error handler and then
/// raises [`INIT_VIA_CTOR`].
///
/// Does nothing when the `MLXRS_DISABLE_CTOR_FOR_TEST` environment variable
/// is set to a valid Unicode value.
#[ctor::ctor(unsafe)]
fn install_handler() {
    if std::env::var("MLXRS_DISABLE_CTOR_FOR_TEST").is_ok() {
        return;
    }
    // SAFETY: `handler` is an `extern "C"` function taking the message and
    // user-data pointers; a null user-data pointer and no destructor are
    // passed, so nothing is left for the callee to free. (Assumes
    // `mlx_set_error_handler` has no further preconditions — confirm against mlx-c.)
    unsafe {
        mlxrs_sys::mlx_set_error_handler(Some(handler), ptr::null_mut(), None);
    }
    // Release pairs with the Acquire load in `ensure_handler_installed`, so a
    // thread that observes `true` also observes the registration above.
    INIT_VIA_CTOR.store(true, Ordering::Release);
}
/// Fast-path check that the error handler is registered, falling back to
/// [`ensure_handler_installed_slow`] when the flag is not yet set.
#[inline]
pub(crate) fn ensure_handler_installed() {
    // Acquire: seeing `true` must imply the handler registration performed by
    // whichever thread set the flag is visible before any mlx call is made.
    if INIT_VIA_CTOR.load(Ordering::Acquire) {
        return;
    }
    ensure_handler_installed_slow();
}
/// Slow path: registers [`handler`] at most once per process via a `OnceLock`
/// and raises [`INIT_VIA_CTOR`] so later calls take the fast path.
#[cold]
#[inline(never)]
fn ensure_handler_installed_slow() {
    static FALLBACK: OnceLock<()> = OnceLock::new();
    FALLBACK.get_or_init(|| {
        // SAFETY: same call as in `install_handler`: an `extern "C"` callback,
        // a null user-data pointer and no destructor.
        unsafe {
            mlxrs_sys::mlx_set_error_handler(Some(handler), ptr::null_mut(), None);
        }
        // Release publishes the registration to fast-path readers that never
        // synchronise through `FALLBACK` themselves.
        INIT_VIA_CTOR.store(true, Ordering::Release);
    });
}
#[inline]
pub(crate) fn check(rc: c_int) -> Result<()> {
if rc == 0 {
Ok(())
} else {
Err(LAST.with(|c| c.borrow_mut().take()).unwrap_or_else(|| {
Error::MlxC(smol_str::format_smolstr!(
"mlx returned {rc} with no message"
))
}))
}
}
#[inline]
pub(crate) fn check_handle(handle: mlxrs_sys::mlx_array) -> Result<crate::Array> {
if handle.ctx.is_null() {
Err(
LAST
.with(|c| c.borrow_mut().take())
.unwrap_or_else(|| Error::MlxC(SmolStr::new_static("mlx returned null handle"))),
)
} else {
Ok(crate::Array(handle))
}
}
#[inline]
pub(crate) fn check_vector_array_handle(handle: mlxrs_sys::mlx_vector_array) -> Result<()> {
if handle.ctx.is_null() {
Err(
LAST.with(|c| c.borrow_mut().take()).unwrap_or_else(|| {
Error::MlxC(SmolStr::new_static("mlx returned null vector_array handle"))
}),
)
} else {
Ok(())
}
}
// Smoke tests for error-handler installation and for classifying bracketed
// message prefixes into `MlxOpKind`.
#[cfg(test)]
mod init_smoke {
use super::*;
// The handler-installed flag must already be set by the time tests run.
#[test]
fn ctor_fired() {
assert!(
INIT_VIA_CTOR.load(Ordering::Relaxed),
"ctor install did not fire — likely symbol stripping or static-init ordering issue"
);
}
// A reshape of a 2x2 array to shape (3,) must come back as
// `Error::MlxOp` classified as `Reshape`.
#[test]
fn failing_op_returns_err_not_abort() {
// Start from an empty thread-local slot so a stale error cannot be picked up.
super::LAST.with(|c| *c.borrow_mut() = None);
let r = crate::Array::ones::<f32>(&(2, 2)).and_then(|a| a.reshape(&(3,)));
assert!(
matches!(
&r,
Err(crate::Error::MlxOp(p)) if matches!(p.op(), crate::error::MlxOpKind::Reshape)
),
"failing reshape did not surface as MlxOp(Reshape); \
got: {r:?}"
);
}
// Table-driven check: each bracketed prefix, followed by " some message",
// must parse to the paired `MlxOpKind`.
#[test]
fn mlx_op_kind_parses_real_vendor_prefixes() {
use super::MlxOpKind;
let cases: &[(&str, MlxOpKind)] = &[
("[matmul]", MlxOpKind::Matmul),
("[addmm]", MlxOpKind::Matmul),
("[block_masked_mm]", MlxOpKind::Matmul),
("[BlockMaskedMM]", MlxOpKind::Matmul),
("[gather_mm]", MlxOpKind::Matmul),
("[GatherMM]", MlxOpKind::Matmul),
("[gather_qmm]", MlxOpKind::Matmul),
("[GatherQMM::vjp]", MlxOpKind::Matmul),
("[QuantizedMatmul::vjp]", MlxOpKind::Matmul),
("[QuantizedMatmul::jvp]", MlxOpKind::Matmul),
("[QuantizedMatmul::vmap]", MlxOpKind::Matmul),
("[qqmm]", MlxOpKind::Matmul),
("[segmented_mm]", MlxOpKind::Matmul),
("[inner]", MlxOpKind::Matmul),
("[tensordot]", MlxOpKind::Matmul),
("[kron]", MlxOpKind::Matmul),
("[reshape]", MlxOpKind::Reshape),
("[unflatten]", MlxOpKind::Reshape),
("[Unflatten]", MlxOpKind::Reshape),
("[flatten]", MlxOpKind::Reshape),
("[expand_dims]", MlxOpKind::Reshape),
("[squeeze]", MlxOpKind::Reshape),
("[transpose]", MlxOpKind::Reshape),
("[swapaxes]", MlxOpKind::Reshape),
("[moveaxis]", MlxOpKind::Reshape),
("[view]", MlxOpKind::Reshape),
("[broadcast_shapes]", MlxOpKind::Broadcast),
("[broadcast_arrays]", MlxOpKind::Broadcast),
("[Broadcast]", MlxOpKind::Broadcast),
("[slice]", MlxOpKind::Slice),
("[slice_update]", MlxOpKind::Slice),
("[SliceUpdate]", MlxOpKind::Slice),
("[DynamicSlice::vjp]", MlxOpKind::Slice),
("[DynamicSlice::vmap]", MlxOpKind::Slice),
("[DynamicSliceUpdate::vjp]", MlxOpKind::Slice),
("[DynamicSliceUpdate::vmap]", MlxOpKind::Slice),
("[split]", MlxOpKind::Slice),
("[trace]", MlxOpKind::Slice),
("[diag]", MlxOpKind::Slice),
("[diagonal]", MlxOpKind::Slice),
("[tril]", MlxOpKind::Slice),
("[triu]", MlxOpKind::Slice),
("[concatenate]", MlxOpKind::Concat),
("[stack]", MlxOpKind::Concat),
("[repeat]", MlxOpKind::Concat),
("[meshgrid]", MlxOpKind::Concat),
("[gather]", MlxOpKind::Gather),
("[Gather]", MlxOpKind::Gather),
("[gather_axis]", MlxOpKind::Gather),
("[scatter]", MlxOpKind::Scatter),
("[scatter_axis]", MlxOpKind::Scatter),
("[scatter_add_axis]", MlxOpKind::Scatter),
("[masked_scatter]", MlxOpKind::Scatter),
("[put_along_axis]", MlxOpKind::Scatter),
("[take]", MlxOpKind::Take),
("[take_along_axis]", MlxOpKind::Take),
("[fftn]", MlxOpKind::Fft),
("[fftfreq]", MlxOpKind::Fft),
("[rfftfreq]", MlxOpKind::Fft),
("[fftshift]", MlxOpKind::Fft),
("[ifftshift]", MlxOpKind::Fft),
("[hadamard_transform]", MlxOpKind::Fft),
("[quantize]", MlxOpKind::Quantize),
("[quantized_matmul]", MlxOpKind::Matmul), ("[from_fp8]", MlxOpKind::Quantize),
("[to_fp8]", MlxOpKind::Quantize),
("[dequantize]", MlxOpKind::Dequantize),
("[conv]", MlxOpKind::Conv),
("[sum]", MlxOpKind::Pool),
("[max]", MlxOpKind::Pool),
("[min]", MlxOpKind::Pool),
("[mean]", MlxOpKind::Pool),
("[prod]", MlxOpKind::Pool),
("[median]", MlxOpKind::Pool),
("[logsumexp]", MlxOpKind::Pool),
("[logcumsumexp]", MlxOpKind::Pool),
("[cumsum]", MlxOpKind::Pool),
("[cumprod]", MlxOpKind::Pool),
("[cummax]", MlxOpKind::Pool),
("[cummin]", MlxOpKind::Pool),
("[number_of_elements]", MlxOpKind::Pool),
("[softmax]", MlxOpKind::Pool),
("[topk]", MlxOpKind::Pool),
("[eval]", MlxOpKind::Eval),
("[async_eval]", MlxOpKind::Eval),
("[Compiled]", MlxOpKind::Eval),
("[Primitive::vjp]", MlxOpKind::Eval),
("[Primitive::jvp]", MlxOpKind::Eval),
("[Primitive::vmap]", MlxOpKind::Eval),
("[Primitive::output_shapes]", MlxOpKind::Eval),
("[sort]", MlxOpKind::Sort),
("[partition]", MlxOpKind::Sort),
("[argsort]", MlxOpKind::ArgSort),
("[argpartition]", MlxOpKind::ArgSort),
("[argmax]", MlxOpKind::ArgSort),
("[argmin]", MlxOpKind::ArgSort),
("[layer_norm]", MlxOpKind::Norm),
("[rms_norm]", MlxOpKind::Norm),
("[rope]", MlxOpKind::Norm),
("[scaled_dot_product_attention]", MlxOpKind::Norm),
("[scale_dot_product_attention]", MlxOpKind::Norm),
("[linalg::cholesky]", MlxOpKind::Linalg),
("[linalg::cholesky_inv]", MlxOpKind::Linalg),
("[linalg::cross]", MlxOpKind::Linalg),
("[linalg::eig]", MlxOpKind::Linalg),
("[linalg::eigh]", MlxOpKind::Linalg),
("[linalg::eigvals]", MlxOpKind::Linalg),
("[linalg::eigvalsh]", MlxOpKind::Linalg),
("[linalg::inv]", MlxOpKind::Linalg),
("[linalg::lu]", MlxOpKind::Linalg),
("[linalg::lu_factor]", MlxOpKind::Linalg),
("[linalg::norm]", MlxOpKind::Linalg),
("[linalg::pinv]", MlxOpKind::Linalg),
("[linalg::qr]", MlxOpKind::Linalg),
("[linalg::solve]", MlxOpKind::Linalg),
("[linalg::solve_triangular]", MlxOpKind::Linalg),
("[linalg::svd]", MlxOpKind::Linalg),
("[grad]", MlxOpKind::Transform),
("[vjp]", MlxOpKind::Transform),
("[jvp]", MlxOpKind::Transform),
("[vmap]", MlxOpKind::Transform),
("[compile]", MlxOpKind::Eval),
("[Pad::vmap]", MlxOpKind::Transform),
("[arange]", MlxOpKind::Random),
("[full]", MlxOpKind::Random),
("[eye]", MlxOpKind::Random),
("[linspace]", MlxOpKind::Random),
("[uniform]", MlxOpKind::Random),
("[normal]", MlxOpKind::Random),
("[trunc_normal]", MlxOpKind::Random),
("[bernoulli]", MlxOpKind::Random),
("[categorical]", MlxOpKind::Random),
("[multivariate_normal]", MlxOpKind::Random),
("[laplace]", MlxOpKind::Random),
("[randint]", MlxOpKind::Random),
("[bits]", MlxOpKind::Random),
("[finfo]", MlxOpKind::Random),
("[iinfo]", MlxOpKind::Random),
("[astype]", MlxOpKind::Elementwise),
("[nan_to_num]", MlxOpKind::Elementwise),
("[negative]", MlxOpKind::Elementwise),
("[floor]", MlxOpKind::Elementwise),
("[bitwise_invert]", MlxOpKind::Elementwise),
("[divmod]", MlxOpKind::Elementwise),
("[roll]", MlxOpKind::Elementwise),
("[Event::stream]", MlxOpKind::System),
("[Event::Event]", MlxOpKind::System),
("[Event::wait]", MlxOpKind::System),
("[Fence::update]", MlxOpKind::System),
("[Fence::wait]", MlxOpKind::System),
("[StreamContext]", MlxOpKind::System),
("[ThreadPool::enqueue]", MlxOpKind::System),
("[set_default_device]", MlxOpKind::System),
("[set_default_stream]", MlxOpKind::System),
("[default_stream]", MlxOpKind::System),
("[deserialize_variant]", MlxOpKind::System),
("[export_function]", MlxOpKind::System),
("[import_function]", MlxOpKind::System),
("[import_function::call]", MlxOpKind::System),
("[gpu::eval]", MlxOpKind::System),
("[gpu::finalize]", MlxOpKind::System),
("[gpu::synchronize]", MlxOpKind::System),
("[metal::CommandEncoder]", MlxOpKind::System),
("[metal::Device]", MlxOpKind::System),
("[metal::device_info]", MlxOpKind::System),
("[metal::load_device]", MlxOpKind::System),
("[metal::malloc]", MlxOpKind::System),
("[metal::set_wired_limit]", MlxOpKind::System),
("[metal::start_capture]", MlxOpKind::System),
("[Metal::binary]", MlxOpKind::System),
("[Metal::compiled]", MlxOpKind::System),
("[Metal::copy]", MlxOpKind::System),
("[Metal::ternary]", MlxOpKind::System),
("[Metal::unary]", MlxOpKind::System),
("[METAL]", MlxOpKind::System),
("[malloc]", MlxOpKind::System),
("[new_stream]", MlxOpKind::System),
("[metal_kernel]", MlxOpKind::System),
("[cuda_kernel]", MlxOpKind::System),
("[custom_kernel]", MlxOpKind::System),
("[Matmul::eval_cpu]", MlxOpKind::Matmul),
("[QQMatmul]", MlxOpKind::Matmul),
("[QQMatmul::eval_gpu]", MlxOpKind::Matmul),
("[BlockMaskedMM::eval]", MlxOpKind::Matmul),
("[GatherMM::eval]", MlxOpKind::Matmul),
("[gemm_and_bias]", MlxOpKind::Matmul),
("[Gather::eval_cpu]", MlxOpKind::Gather),
("[Gather::eval_gpu]", MlxOpKind::Gather),
("[GatherAxis::eval_cpu]", MlxOpKind::Gather),
("[Scatter::eval_cpu]", MlxOpKind::Scatter),
("[Scatter::eval_gpu]", MlxOpKind::Scatter),
("[ScatterAxis::eval_cpu]", MlxOpKind::Scatter),
("[MaskedScatter::eval_cpu]", MlxOpKind::Scatter),
("[Sort::eval_gpu]", MlxOpKind::Sort),
("[Convolution::eval]", MlxOpKind::Conv),
("[Convolution::eval_gpu]", MlxOpKind::Conv),
("[Cholesky::eval_cpu]", MlxOpKind::Linalg),
("[Cholesky::eval_gpu]", MlxOpKind::Linalg),
("[Eig::eval_cpu]", MlxOpKind::Linalg),
("[Eig::eval_gpu]", MlxOpKind::Linalg),
("[Eigh::eval_cpu]", MlxOpKind::Linalg),
("[Eigh::eval_gpu]", MlxOpKind::Linalg),
("[Inverse::eval_cpu]", MlxOpKind::Linalg),
("[Inverse::eval_gpu]", MlxOpKind::Linalg),
("[LUF::eval_cpu]", MlxOpKind::Linalg),
("[LUF::eval_gpu]", MlxOpKind::Linalg),
("[QRF::eval_cpu]", MlxOpKind::Linalg),
("[QRF::eval_gpu]", MlxOpKind::Linalg),
("[SVD::eval_cpu]", MlxOpKind::Linalg),
("[SVD::eval_gpu]", MlxOpKind::Linalg),
("[Compile::eval_cpu]", MlxOpKind::Eval),
("[Compiled::eval_cpu]", MlxOpKind::Eval),
("[Copy::eval_gpu]", MlxOpKind::Eval),
("[NanEqual::eval_cpu]", MlxOpKind::Eval),
("[Quantize::eval_gpu]", MlxOpKind::Quantize),
("[fast::Quantize::eval_cpu]", MlxOpKind::Quantize),
("[quantize_dequantize]", MlxOpKind::Quantize),
("[Arange::eval_gpu]", MlxOpKind::Random),
("[RandomBits::eval_gpu]", MlxOpKind::Random),
("[FFT]", MlxOpKind::Fft),
("[hadamard]", MlxOpKind::Fft),
("[vjp_layer_norm]", MlxOpKind::Norm),
("[AllGather::eval_gpu]", MlxOpKind::Distributed),
("[AllReduce::eval_gpu]", MlxOpKind::Distributed),
("[ReduceScatter]", MlxOpKind::Distributed),
("[ReduceScatter::eval_gpu]", MlxOpKind::Distributed),
("[Recv::eval_gpu]", MlxOpKind::Distributed),
("[Send::eval_gpu]", MlxOpKind::Distributed),
("[sum_scatter]", MlxOpKind::Distributed),
("[distributed]", MlxOpKind::Distributed),
("[mpi]", MlxOpKind::Distributed),
("[nccl]", MlxOpKind::Distributed),
("[jaccl]", MlxOpKind::Distributed),
("[ring]", MlxOpKind::Distributed),
("[load]", MlxOpKind::Io),
("[save]", MlxOpKind::Io),
("[load_safetensors]", MlxOpKind::Io),
("[save_safetensors]", MlxOpKind::Io),
("[load_gguf]", MlxOpKind::Io),
("[save_gguf]", MlxOpKind::Io),
("[safetensor]", MlxOpKind::Io),
("[read]", MlxOpKind::Io),
("[write]", MlxOpKind::Io),
("[Load::eval_gpu]", MlxOpKind::Io),
("[from_str]", MlxOpKind::Io),
("[Abs]", MlxOpKind::Elementwise),
("[DivMod]", MlxOpKind::Elementwise),
("[binary_float]", MlxOpKind::Elementwise),
("[binary_int]", MlxOpKind::Elementwise),
("[unary_fp]", MlxOpKind::Elementwise),
("[unary_int]", MlxOpKind::Elementwise),
("[unary_real]", MlxOpKind::Elementwise),
("[extract_tensor_data]", MlxOpKind::Elementwise),
];
// Every table entry must classify (never `None`) and match its expected kind.
for (vendor_prefix, expected) in cases {
let parsed = MlxOpKind::parse_prefix(&format!("{vendor_prefix} some message"))
.unwrap_or_else(|| {
panic!("vendor prefix {vendor_prefix:?} should classify, not return None")
});
assert_eq!(
&parsed, expected,
"vendor prefix {vendor_prefix:?} should classify as {expected:?}, got {parsed:?}",
);
}
// An unknown bracketed name is preserved in `Other`.
let other = MlxOpKind::parse_prefix("[totally_invented_op] foo");
assert!(
matches!(other, Some(MlxOpKind::Other(ref s)) if s == "totally_invented_op"),
"got {other:?}",
);
// A `::vjp` suffix on an unknown name is dropped from the `Other` payload.
let other_method = MlxOpKind::parse_prefix("[totally_invented_op::vjp] foo");
assert!(
matches!(other_method, Some(MlxOpKind::Other(ref s)) if s == "totally_invented_op"),
"got {other_method:?}",
);
// No bracket, empty input and an unterminated bracket all yield `None`.
assert!(MlxOpKind::parse_prefix("plain message without prefix").is_none());
assert!(MlxOpKind::parse_prefix("").is_none());
assert!(MlxOpKind::parse_prefix("[unterminated bracket").is_none());
// 10,000 repeated `fast::` segments must still resolve to the inner op.
let deeply_nested = "[".to_string() + &"fast::".repeat(10_000) + "Quantize::eval_cpu]";
assert_eq!(
MlxOpKind::parse_prefix(&deeply_nested),
Some(MlxOpKind::Quantize),
"deeply-nested fast:: must reduce to the inner op without recursing",
);
}
}
#[cfg(test)]
mod pure_payload_tests {
use super::*;
use std::io::{Error as IoError, ErrorKind};
#[test]
fn file_op_display_all_arms() {
assert_eq!(FileOp::Create.to_string(), "create");
assert_eq!(FileOp::Write.to_string(), "write");
assert_eq!(FileOp::Flush.to_string(), "flush");
assert_eq!(FileOp::Read.to_string(), "read");
assert_eq!(FileOp::Open.to_string(), "open");
assert_eq!(FileOp::Stat.to_string(), "stat");
assert_eq!(FileOp::Copy.to_string(), "copy");
assert_eq!(FileOp::Remove.to_string(), "remove");
assert_eq!(FileOp::Rename.to_string(), "rename");
assert_eq!(FileOp::Fsync.to_string(), "fsync");
assert_eq!(FileOp::Other("symlink").to_string(), "symlink");
}
#[test]
fn dtype_mismatch_payload() {
let p = DtypeMismatchPayload::new(Dtype::F32, Dtype::I32);
assert_eq!(p.expected(), Dtype::F32);
assert_eq!(p.got(), Dtype::I32);
assert_eq!(p.to_string(), "expected F32, got I32");
let e = Error::DtypeMismatch(p);
assert_eq!(e.to_string(), "dtype mismatch: expected F32, got I32");
}
#[test]
fn ffi_null_handle_payload() {
let p = FfiNullHandlePayload::new("mlx_array_new_float32");
assert_eq!(p.fn_name(), "mlx_array_new_float32");
assert_eq!(
p.to_string(),
"FFI: mlx_array_new_float32 returned NULL handle"
);
assert_eq!(
Error::FfiNullHandle(p).to_string(),
"FFI: mlx_array_new_float32 returned NULL handle"
);
}
#[test]
fn missing_field_payload() {
let p = MissingFieldPayload::new("SentencePieceTokenizer", "model.unk_id");
assert_eq!(p.type_name(), "SentencePieceTokenizer");
assert_eq!(p.field(), "model.unk_id");
assert_eq!(
p.to_string(),
"SentencePieceTokenizer: missing required field `model.unk_id`"
);
}
#[test]
fn arithmetic_overflow_no_operands() {
let p = ArithmeticOverflowPayload::new("vocab_size_base + added", "u32");
assert_eq!(p.context(), "vocab_size_base + added");
assert_eq!(p.op_type(), "u32");
assert!(p.operands().is_empty());
assert_eq!(p.to_string(), "vocab_size_base + added: overflow (u32)");
}
#[test]
fn arithmetic_overflow_with_operands() {
let p =
ArithmeticOverflowPayload::with_operands("a * b", "usize", [("a", 1u64 << 32), ("b", 4u64)]);
assert_eq!(p.context(), "a * b");
assert_eq!(p.op_type(), "usize");
assert_eq!(p.operands(), &[("a", 1u64 << 32), ("b", 4u64)]);
assert_eq!(
p.to_string(),
"a * b: overflow (usize) with operands a=4294967296, b=4"
);
}
#[test]
fn empty_input_payload() {
let p = EmptyInputPayload::new("value_and_grad: argnums");
assert_eq!(p.context(), "value_and_grad: argnums");
assert_eq!(
p.to_string(),
"value_and_grad: argnums is empty (at least one element required)"
);
}
#[test]
fn invariant_violation_payload() {
let p = InvariantViolationPayload::new("train: steps_per_eval", "must be >= 1");
assert_eq!(p.context(), "train: steps_per_eval");
assert_eq!(p.requirement(), "must be >= 1");
assert_eq!(p.to_string(), "backend: train: steps_per_eval must be >= 1");
}
#[test]
fn rank_mismatch_payload() {
let p = RankMismatchPayload::new("token_embeddings must be rank-3", 2, vec![4, 8]);
assert_eq!(p.context(), "token_embeddings must be rank-3");
assert_eq!(p.actual(), 2);
assert_eq!(p.actual_shape(), &[4usize, 8]);
assert_eq!(
p.to_string(),
"shape mismatch: token_embeddings must be rank-3: got rank 2 (shape [4, 8])"
);
}
#[test]
fn length_mismatch_payload() {
let p = LengthMismatchPayload::new("pad: axes vs low/high", 3, 2);
assert_eq!(p.context(), "pad: axes vs low/high");
assert_eq!(p.expected(), 3);
assert_eq!(p.actual(), 2);
assert_eq!(
p.to_string(),
"shape mismatch: pad: axes vs low/high: expected length 3, got 2"
);
}
#[test]
fn out_of_range_payload() {
let p = OutOfRangePayload::new("top_k: parameter", "must be in (0, vocab_size)", "1024");
assert_eq!(p.context(), "top_k: parameter");
assert_eq!(p.requirement(), "must be in (0, vocab_size)");
assert_eq!(p.value(), "1024");
assert_eq!(
p.to_string(),
"shape mismatch: top_k: parameter: must be in (0, vocab_size), got 1024"
);
}
#[test]
fn file_io_payload() {
let inner = IoError::new(ErrorKind::NotFound, "no such file");
let p = FileIoPayload::new(
"load_audio",
FileOp::Read,
PathBuf::from("/tmp/a.wav"),
inner,
);
assert_eq!(p.context(), "load_audio");
assert_eq!(p.op(), FileOp::Read);
assert_eq!(p.path(), std::path::Path::new("/tmp/a.wav"));
assert_eq!(p.inner().to_string(), "no such file");
assert_eq!(
p.to_string(),
"io: load_audio: read /tmp/a.wav: no such file"
);
let src = std::error::Error::source(&p).expect("source present");
assert_eq!(src.to_string(), "no such file");
}
#[test]
fn multi_length_mismatch_payload() {
let lengths = vec![("axes", 3usize), ("low", 2), ("high", 3)];
let p = MultiLengthMismatchPayload::new("pad: axes/low/high", lengths.clone());
assert_eq!(p.context(), "pad: axes/low/high");
assert_eq!(p.lengths(), lengths.as_slice());
assert_eq!(
p.to_string(),
"shape mismatch: pad: axes/low/high: length mismatch — axes=3, low=2, high=3"
);
}
#[test]
fn multi_length_mismatch_single_entry_no_separator() {
let p = MultiLengthMismatchPayload::new("x", vec![("only", 7usize)]);
assert_eq!(p.to_string(), "shape mismatch: x: length mismatch — only=7");
}
#[test]
fn mlx_op_kind_display_all_arms() {
assert_eq!(MlxOpKind::Matmul.to_string(), "matmul");
assert_eq!(MlxOpKind::Reshape.to_string(), "reshape");
assert_eq!(MlxOpKind::Broadcast.to_string(), "broadcast");
assert_eq!(MlxOpKind::Shape.to_string(), "shape");
assert_eq!(MlxOpKind::Slice.to_string(), "slice");
assert_eq!(MlxOpKind::Concat.to_string(), "concat");
assert_eq!(MlxOpKind::Gather.to_string(), "gather");
assert_eq!(MlxOpKind::Scatter.to_string(), "scatter");
assert_eq!(MlxOpKind::Take.to_string(), "take");
assert_eq!(MlxOpKind::Fft.to_string(), "fft");
assert_eq!(MlxOpKind::Quantize.to_string(), "quantize");
assert_eq!(MlxOpKind::Dequantize.to_string(), "dequantize");
assert_eq!(MlxOpKind::Conv.to_string(), "conv");
assert_eq!(MlxOpKind::Pool.to_string(), "reduce");
assert_eq!(MlxOpKind::Eval.to_string(), "eval");
assert_eq!(MlxOpKind::Sort.to_string(), "sort");
assert_eq!(MlxOpKind::ArgSort.to_string(), "argsort");
assert_eq!(MlxOpKind::Norm.to_string(), "norm");
assert_eq!(MlxOpKind::Linalg.to_string(), "linalg");
assert_eq!(MlxOpKind::Random.to_string(), "random");
assert_eq!(MlxOpKind::Transform.to_string(), "transform");
assert_eq!(MlxOpKind::Elementwise.to_string(), "elementwise");
assert_eq!(MlxOpKind::System.to_string(), "system");
assert_eq!(MlxOpKind::Positional.to_string(), "positional");
assert_eq!(MlxOpKind::Distributed.to_string(), "distributed");
assert_eq!(MlxOpKind::Io.to_string(), "io");
assert_eq!(
MlxOpKind::Other(SmolStr::new("weird_op")).to_string(),
"weird_op"
);
}
#[test]
fn mlx_op_payload() {
let p = MlxOpPayload::new(MlxOpKind::Reshape, "[reshape] Cannot reshape array");
assert_eq!(p.op(), &MlxOpKind::Reshape);
assert_eq!(p.message(), "[reshape] Cannot reshape array");
assert_eq!(p.to_string(), "mlx reshape: [reshape] Cannot reshape array");
assert_eq!(
Error::MlxOp(p).to_string(),
"mlx reshape: [reshape] Cannot reshape array"
);
}
#[test]
fn missing_key_payload() {
let p = MissingKeyPayload::new("dequantize_weights: missing .weight", "layers.0");
assert_eq!(p.context(), "dequantize_weights: missing .weight");
assert_eq!(p.key(), "layers.0");
assert_eq!(
p.to_string(),
"dequantize_weights: missing .weight: key `layers.0` not found"
);
}
#[test]
fn unknown_enum_value_payload() {
let supported: &'static [&'static str] = &["mean", "max"];
let p = UnknownEnumValuePayload::new("PoolingStrategy", "median", supported);
assert_eq!(p.type_name(), "PoolingStrategy");
assert_eq!(p.value(), "median");
assert_eq!(p.supported(), supported);
assert_eq!(
p.to_string(),
r#"PoolingStrategy: unknown value `median` (supported: ["mean", "max"])"#
);
}
#[test]
fn non_finite_scalar_payload() {
let p = NonFiniteScalarPayload::new("LearningRate: resolved value at step 5", f64::INFINITY);
assert_eq!(p.context(), "LearningRate: resolved value at step 5");
assert!(p.value().is_infinite() && p.value() > 0.0);
assert_eq!(
p.to_string(),
"LearningRate: resolved value at step 5: value is non-finite (NaN or Inf): inf"
);
}
#[test]
fn key_collision_payload() {
let p = KeyCollisionPayload::new("awq checkpoint", "qweight");
assert_eq!(p.context(), "awq checkpoint");
assert_eq!(p.key(), "qweight");
assert_eq!(p.to_string(), "awq checkpoint: key `qweight` collides");
}
#[test]
fn interior_nul_payload() {
let p = InteriorNulPayload::new("save_safetensors", "array key");
assert_eq!(p.context(), "save_safetensors");
assert_eq!(p.bytes_kind(), "array key");
assert_eq!(
p.to_string(),
"save_safetensors: array key contains an interior NUL byte"
);
}
#[test]
fn cap_exceeded_payload() {
let p = CapExceededPayload::new("decode_audio", "MAX_DECODED_SAMPLES", 1000, 4096);
assert_eq!(p.context(), "decode_audio");
assert_eq!(p.cap_name(), "MAX_DECODED_SAMPLES");
assert_eq!(p.cap(), 1000);
assert_eq!(p.observed(), 4096);
assert_eq!(
p.to_string(),
"decode_audio: observed 4096 exceeds cap MAX_DECODED_SAMPLES (1000)"
);
}
#[test]
fn shape_pair_mismatch_payload() {
let p = ShapePairMismatchPayload::new("attn", vec![2usize, 8, 64], vec![2usize, 4, 64]);
assert_eq!(p.context(), "attn");
assert_eq!(p.expected(), &[2usize, 8, 64]);
assert_eq!(p.actual(), &[2usize, 4, 64]);
assert_eq!(
p.to_string(),
"shape mismatch: attn: expected [2, 8, 64], got [2, 4, 64]"
);
}
#[test]
fn divisibility_constraint_payload() {
let p = DivisibilityConstraintPayload::new("awq", "in_features", 100, "group_size", 32);
assert_eq!(p.context(), "awq");
assert_eq!(p.name_dividend(), "in_features");
assert_eq!(p.name_divisor(), "group_size");
assert_eq!(p.dividend(), 100);
assert_eq!(p.divisor(), 32);
assert_eq!(
p.to_string(),
"awq: in_features (100) must be divisible by group_size (32)"
);
}
#[test]
fn unsupported_dtype_payload() {
let supported: &'static [Dtype] = &[Dtype::F32, Dtype::F16];
let p = UnsupportedDtypePayload::new("Adam: weights must be floating", Dtype::I32, supported);
assert_eq!(p.context(), "Adam: weights must be floating");
assert_eq!(p.dtype(), Dtype::I32);
assert_eq!(p.supported(), supported);
assert_eq!(
p.to_string(),
"Adam: weights must be floating: unsupported dtype I32 (supported: [F32, F16])"
);
}
#[test]
fn alloc_failure_payload() {
let inner: TryReserveError = Vec::<u8>::new()
.try_reserve_exact(usize::MAX)
.expect_err("usize::MAX reservation must fail");
let inner_display = inner.to_string();
let p = AllocFailurePayload::new("decode_audio", "samples", 4096, inner);
assert_eq!(p.context(), "decode_audio");
assert_eq!(p.item(), "samples");
assert_eq!(p.count(), 4096);
let s = p.to_string();
let controlled = s
.strip_suffix(&inner_display)
.expect("inner Display must be the message suffix");
assert_eq!(
controlled,
"decode_audio: reservation for 4096 samples failed: "
);
let src = std::error::Error::source(&p).expect("source present");
assert_eq!(src.to_string(), inner_display);
assert_eq!(p.inner().to_string(), inner_display);
}
#[test]
fn parse_payload() {
let inner = IoError::new(ErrorKind::InvalidData, "bad json at line 3");
let p = ParsePayload::new("load_config", "JSON", inner);
assert_eq!(p.context(), "load_config");
assert_eq!(p.input_kind(), "JSON");
assert_eq!(p.inner().to_string(), "bad json at line 3");
assert_eq!(
p.to_string(),
"load_config: parse JSON failed: bad json at line 3"
);
let src = std::error::Error::source(&p).expect("source present");
assert_eq!(src.to_string(), "bad json at line 3");
}
#[test]
fn external_op_payload() {
let inner = IoError::other("device busy");
let p = ExternalOpPayload::new("play_audio", "cpal stream", inner);
assert_eq!(p.context(), "play_audio");
assert_eq!(p.op_kind(), "cpal stream");
assert_eq!(p.inner().to_string(), "device busy");
assert_eq!(
p.to_string(),
"play_audio: external cpal stream failed: device busy"
);
let src = std::error::Error::source(&p).expect("source present");
assert_eq!(src.to_string(), "device busy");
}
#[test]
fn bounded_decode_payload() {
let p = BoundedDecodePayload::new("decode_wav", 1024, 5000);
assert_eq!(p.context(), "decode_wav");
assert_eq!(p.cap(), 1024);
assert_eq!(p.observed(), 5000);
assert_eq!(
p.to_string(),
"decode_wav: decoder produced 5000 elements (cap 1024)"
);
}
#[test]
fn layer_keyed_payload() {
let inner = Error::OutOfMemory;
let p = LayerKeyedPayload::new("layers.3.attn", inner);
assert_eq!(p.layer(), "layers.3.attn");
assert_eq!(p.inner().to_string(), "out of memory");
assert_eq!(p.to_string(), "layer `layers.3.attn`: out of memory");
let src = std::error::Error::source(&p).expect("source present");
assert_eq!(src.to_string(), "out of memory");
}
#[test]
fn malformed_data_payload() {
let p = MalformedDataPayload::new("SentencePiece protobuf", "truncated length-delimited field");
assert_eq!(p.context(), "SentencePiece protobuf");
assert_eq!(p.detail(), "truncated length-delimited field");
assert_eq!(
p.to_string(),
"malformed data: SentencePiece protobuf: truncated length-delimited field"
);
}
#[test]
fn message_only_enum_variants_display() {
assert_eq!(
Error::UnknownDtype(99).to_string(),
"unknown dtype value from mlx: 99"
);
assert_eq!(Error::OutOfMemory.to_string(), "out of memory");
assert_eq!(
Error::NonContiguous.to_string(),
"array is not contiguous; M2 will add .contiguous() to materialize"
);
assert_eq!(Error::MlxC(SmolStr::new("boom")).to_string(), "mlx-c: boom");
assert_eq!(
Error::Backend("legacy".to_string()).to_string(),
"mlx backend: legacy"
);
}
#[test]
fn check_nonzero_rc_with_empty_tls_falls_back_to_mlxc() {
let _ = take_last();
let err = check(7).expect_err("non-zero rc must be an Err");
match err {
Error::MlxC(ref s) => assert_eq!(s.as_str(), "mlx returned 7 with no message"),
other => panic!("expected MlxC fallback, got {other:?}"),
}
assert!(check(0).is_ok());
}
#[test]
fn check_nonzero_rc_drains_tls_error() {
let _ = take_last();
set_last(Error::OutOfMemory);
let err = check(1).expect_err("non-zero rc must be an Err");
assert!(matches!(err, Error::OutOfMemory));
assert!(take_last().is_none());
}
/// An `mlx_array` handle whose `ctx` is null, with no stashed error, must be
/// rejected by `check_handle` with the generic `Error::MlxC` message.
#[test]
fn check_handle_null_ctx_falls_back() {
    // Discard any previously stashed error before exercising the fallback.
    let _ = take_last();

    let failure = check_handle(mlxrs_sys::mlx_array {
        ctx: std::ptr::null_mut(),
    })
    .expect_err("null ctx handle must be an Err");

    let Error::MlxC(message) = &failure else {
        panic!("expected MlxC fallback, got {failure:?}");
    };
    assert_eq!(message.as_str(), "mlx returned null handle");
}
/// An `mlx_vector_array` handle whose `ctx` is null, with no stashed error,
/// must be rejected by `check_vector_array_handle` with the generic
/// `Error::MlxC` message.
#[test]
fn check_vector_array_handle_null_ctx_falls_back() {
    // Discard any previously stashed error before exercising the fallback.
    let _ = take_last();

    let failure = check_vector_array_handle(mlxrs_sys::mlx_vector_array {
        ctx: std::ptr::null_mut(),
    })
    .expect_err("null ctx must be an Err");

    let Error::MlxC(message) = &failure else {
        panic!("expected MlxC fallback, got {failure:?}");
    };
    assert_eq!(message.as_str(), "mlx returned null vector_array handle");
}
/// Round-trips an error through `set_last` / `last_error_message` / `take_last`
/// and checks the stash is empty before and after.
#[test]
fn tls_set_take_and_last_message() {
    // Phase 1: nothing stashed once any leftover has been taken.
    let _ = take_last();
    assert!(last_error_message().is_none());

    // Phase 2: a stashed error is visible as its rendered message...
    set_last(Error::MlxC(SmolStr::new("stashed")));
    let peeked = last_error_message();
    assert_eq!(peeked.as_deref(), Some("mlx-c: stashed"));

    // ...and reading the message did not consume it: `take_last` still returns it.
    let reclaimed = take_last().expect("a stashed error");
    assert_eq!(reclaimed.to_string(), "mlx-c: stashed");

    // Phase 3: taking the error empties the stash for both accessors.
    assert!(take_last().is_none());
    assert!(last_error_message().is_none());
}
}
/// Payload tests compiled only for test builds with the `lm` feature enabled.
#[cfg(all(test, feature = "lm"))]
mod pure_lm_payload_tests {
    use super::*;
    use std::io::{Error as IoError, ErrorKind};

    /// `DurabilityWarningPayload`: accessors, `Display`, std error chain, and
    /// recovery of the owned I/O error via `into_source`.
    #[test]
    fn durability_warning_payload() {
        let payload = DurabilityWarningPayload::new(true, IoError::other("fsync failed"));

        assert!(payload.committed());
        assert_eq!(payload.source().to_string(), "fsync failed");
        assert_eq!(
            payload.to_string(),
            "save committed but durability fsync failed (committed=true): fsync failed"
        );

        // The trait-level `source()` must also be populated.
        let chained = std::error::Error::source(&payload)
            .map(|cause| cause.to_string())
            .expect("source present");
        assert_eq!(chained, "fsync failed");

        // Consuming the payload hands back the underlying I/O error.
        assert_eq!(payload.into_source().to_string(), "fsync failed");
    }

    /// `ConvertPostSavePartialPayload` built with both a save warning and a
    /// copy error: accessors, `Display`, std error chain, and `into_parts`.
    #[test]
    fn convert_post_save_partial_payload() {
        const WARNING_TEXT: &str = "dir fsync warned";
        const COPY_ERROR_TEXT: &str =
            "io: copy_tokenizer_and_extras: copy /dst/tokenizer.json: denied";

        let copy_failure = Error::FileIo(FileIoPayload::new(
            "copy_tokenizer_and_extras",
            FileOp::Copy,
            PathBuf::from("/dst/tokenizer.json"),
            IoError::new(ErrorKind::PermissionDenied, "denied"),
        ));
        let payload = ConvertPostSavePartialPayload::new(
            true,
            Some(IoError::other(WARNING_TEXT)),
            copy_failure,
        );

        // Borrowing accessors.
        assert!(payload.committed());
        let warning_text = payload.save_warning().map(|w| w.to_string());
        assert_eq!(warning_text.as_deref(), Some(WARNING_TEXT));
        assert_eq!(payload.copy_error().to_string(), COPY_ERROR_TEXT);

        // Display is a fixed summary; it does not embed the nested errors.
        assert_eq!(
            payload.to_string(),
            "convert: save committed but post-save extras copy partially failed (committed=true); \
             destination directory may be incomplete (missing tokenizer/extras files)"
        );

        // The std error chain points at the copy error.
        let chained = std::error::Error::source(&payload)
            .map(|cause| cause.to_string())
            .expect("source present");
        assert_eq!(chained, COPY_ERROR_TEXT);

        // Decomposition returns every component that went in.
        let (was_committed, owned_warning, owned_copy_error) = payload.into_parts();
        assert!(was_committed);
        let owned_warning_text = owned_warning.map(|w| w.to_string());
        assert_eq!(owned_warning_text.as_deref(), Some(WARNING_TEXT));
        assert_eq!(owned_copy_error.to_string(), COPY_ERROR_TEXT);
    }

    /// `ConvertPostSavePartialPayload` built without a save warning still
    /// exposes the copy error.
    #[test]
    fn convert_post_save_partial_no_save_warning() {
        let payload = ConvertPostSavePartialPayload::new(true, None, Error::OutOfMemory);

        assert!(payload.save_warning().is_none());
        assert_eq!(payload.copy_error().to_string(), "out of memory");
    }

    /// `ConvertDurabilityWarnings` with the save and post-copy-file slots
    /// filled and the post-copy-dir slot empty: accessors, `count`,
    /// `first_warning`, `Display`, std error chain, and `into_parts`.
    #[test]
    fn convert_durability_warnings_full() {
        const SAVE_TEXT: &str = "save warn";
        const FILE_TEXT: &str = "file warn";

        let payload = ConvertDurabilityWarnings::new(
            true,
            Some(IoError::other(SAVE_TEXT)),
            Some(IoError::other(FILE_TEXT)),
            None,
        );

        // Per-slot accessors.
        assert!(payload.committed());
        let save_text = payload.save().map(|w| w.to_string());
        assert_eq!(save_text.as_deref(), Some(SAVE_TEXT));
        let file_text = payload.post_copy_file().map(|w| w.to_string());
        assert_eq!(file_text.as_deref(), Some(FILE_TEXT));
        assert!(payload.post_copy_dir().is_none());

        // Two slots are populated; the save warning is reported as the first.
        assert_eq!(payload.count(), 2);
        let first_text = payload.first_warning().map(|w| w.to_string());
        assert_eq!(first_text.as_deref(), Some(SAVE_TEXT));

        // Display is a fixed summary; it does not embed the individual warnings.
        assert_eq!(
            payload.to_string(),
            "convert: save committed but post-save durability warnings (committed=true); \
             destination is on-disk and load-correct, but one or more fsync boundaries returned a warning"
        );

        // The std error chain points at the save warning here.
        let chained = std::error::Error::source(&payload)
            .map(|cause| cause.to_string())
            .expect("source present");
        assert_eq!(chained, SAVE_TEXT);

        // Decomposition returns every slot as it was supplied.
        let (was_committed, save_slot, file_slot, dir_slot) = payload.into_parts();
        assert!(was_committed);
        assert_eq!(save_slot.map(|w| w.to_string()).as_deref(), Some(SAVE_TEXT));
        assert_eq!(file_slot.map(|w| w.to_string()).as_deref(), Some(FILE_TEXT));
        assert!(dir_slot.is_none());
    }

    /// With only the post-copy-dir slot filled, `first_warning` skips the empty
    /// slots and returns it; the payload also converts into
    /// `Error::ConvertDurabilityWarnings`.
    #[test]
    fn convert_durability_warnings_first_warning_priority_skips_none() {
        const DIR_TEXT: &str = "dir warn";

        let payload =
            ConvertDurabilityWarnings::new(true, None, None, Some(IoError::other(DIR_TEXT)));

        assert_eq!(payload.count(), 1);
        let first_text = payload.first_warning().map(|w| w.to_string());
        assert_eq!(first_text.as_deref(), Some(DIR_TEXT));
        let dir_text = payload.post_copy_dir().map(|w| w.to_string());
        assert_eq!(dir_text.as_deref(), Some(DIR_TEXT));

        // Conversion into the crate error wraps the payload in its dedicated variant.
        let converted: Error = payload.into();
        assert!(matches!(converted, Error::ConvertDurabilityWarnings(_)));
    }
}