1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
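//! Regression tests for the safe-FFI boundary. Among other things, they drive
//! the TLS error capture path across multiple threads, the contiguity guard on
//! buffer-extracting methods, shape validation (negative dims, product
//! overflow), slice/concatenate argument checks, and the zero-element sentinel.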
use std::thread;
#[test]
fn shape_error_returns_err_not_abort() {
    // A (2, 2) array holds 4 elements, so reshaping it to (3,) is impossible.
    // mlx C++ throws `[reshape] …`; instead of aborting the process, the
    // handler's `MlxOpKind::parse_prefix` recognises the `[reshape]` prefix,
    // maps it to `MlxOpKind::Reshape`, and the call returns `Error::MlxOp`.
    let result = mlxrs::Array::ones::<f32>(&(2, 2)).and_then(|arr| arr.reshape(&(3,)));
    let is_reshape_err = matches!(
        &result,
        Err(mlxrs::Error::MlxOp(payload)) if matches!(payload.op(), mlxrs::error::MlxOpKind::Reshape)
    );
    assert!(is_reshape_err, "expected Err(Error::MlxOp(Reshape)), got {result:?}");
}
#[test]
fn each_thread_has_independent_error_slot() {
    // Each thread should get its own TLS error capture, no cross-talk.
    // Source shape (2,2) has 4 elements; reshape targets must NOT equal 4.
    // Use {5, 6, 7, 8} so every thread's reshape is genuinely incompatible.
    // Each surfaces as `Error::MlxOp(Reshape)` via the `[reshape]` prefix.
    //
    // A barrier releases every worker at the same moment so the failing ops
    // actually overlap in time. Without it the short-lived threads may run one
    // after another and never exercise concurrent error capture.
    //
    // NOTE(review): every worker raises the same op kind. This verifies that each
    // thread's failure is captured and surfaced as a typed error, but a payload
    // leaking into a sibling thread would still look like `Reshape`. Telling the
    // payloads apart per thread would need a per-thread marker in the error
    // (e.g. the target shape). TODO if the payload exposes one.
    // The barrier count must equal the number of workers spawned below.
    let barrier = std::sync::Arc::new(std::sync::Barrier::new(4));
    let handles: Vec<_> = (0..4)
        .map(|tid| {
            let barrier = std::sync::Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                barrier.wait();
                let r = mlxrs::Array::ones::<f32>(&(2, 2)).and_then(|a| a.reshape(&(5 + tid,)));
                assert!(
                    matches!(
                        &r,
                        Err(mlxrs::Error::MlxOp(p)) if matches!(p.op(), mlxrs::error::MlxOpKind::Reshape)
                    ),
                    "thread {tid}: expected Err(Error::MlxOp(Reshape)), got {r:?}"
                );
            });
            (tid, handle)
        })
        .collect();
    for (tid, h) in handles {
        h.join()
            .unwrap_or_else(|_| panic!("worker thread {tid} panicked"));
    }
}
#[test]
fn to_vec_rejects_non_contiguous_view() {
    // Regression test for the UB pathway: a strided view has the
    // same `mlx_array_size` as its source but reordered strides, so
    // `from_raw_parts(ptr, size)` reads in the wrong layout (and for broadcast
    // views, can read past the allocation entirely). The contiguity guard must
    // convert this into Err(NonContiguous).
    //
    // We construct the view via FFI + from_raw because the safe wrapper doesn't
    // expose transpose/broadcast yet. Once it does, safe code will be able to
    // obtain exactly this kind of `Array`, so the guard must already reject it.
    use mlxrs_sys::{mlx_array, mlx_array_new, mlx_default_gpu_stream_new, mlx_transpose};
    let src = mlxrs::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();
    // SAFETY: `Array::into_raw`'s contract — `src` is a valid owned Array;
    // ownership of the raw handle transfers to the caller and `Drop` will not
    // run (the handle is freed manually below).
    let raw_src = unsafe { src.into_raw() };
    // SAFETY: returns this thread's default GPU stream handle, mirroring
    // `stream::default_stream`; the test's `#[ctor]`-installed handler is live,
    // so a failed init surfaces rather than `printf+exit`. The returned handle
    // is owned by this test and released with `mlx_stream_free` below.
    let stream = unsafe { mlx_default_gpu_stream_new() };
    // SAFETY: `mlx_array_new()` returns a fresh empty out-param handle (NULL
    // ctx) per the mlx-c convention; it is populated by the `mlx_transpose`
    // call below before any use.
    let mut out: mlx_array = unsafe { mlx_array_new() };
    // SAFETY: `raw_src` and `stream` are valid handles (not retained by mlx
    // past the call); `out` is the fresh out-param allocated above; the rc is
    // asserted once the input handles have been released.
    let rc = unsafe { mlx_transpose(&mut out, raw_src, stream) };
    // SAFETY: `stream` (from `mlx_default_gpu_stream_new`) and `raw_src` (from
    // `into_raw`) are handles this test owns. Each is freed exactly once here,
    // after its last use, and `mlx_transpose` retains neither. Releasing them
    // before checking `rc` means a failed transpose does not leak them.
    unsafe {
        let _ = mlxrs_sys::mlx_stream_free(stream);
        let _ = mlxrs_sys::mlx_array_free(raw_src);
    }
    assert_eq!(rc, 0, "mlx_transpose failed");
    // SAFETY: `Array::from_raw`'s contract — `out` is a valid handle freshly
    // produced by `mlx_transpose`, not aliased elsewhere; the safe `Array`
    // now owns it and frees it on `Drop`.
    let mut view = unsafe { mlxrs::Array::from_raw(out) };
    assert_eq!(view.shape(), vec![3, 2]);
    let r = view.to_vec::<f32>();
    assert!(
        matches!(r, Err(mlxrs::Error::NonContiguous)),
        "expected Err(NonContiguous), got {r:?}"
    );
    let r2 = view.as_slice::<f32>();
    assert!(
        matches!(r2, Err(mlxrs::Error::NonContiguous)),
        "expected Err(NonContiguous), got {r2:?}"
    );
}
#[test]
fn to_vec_works_on_contiguous_array() {
    // Happy path: a freshly built row-major array must pass the contiguity
    // guard and read back exactly the values it was built from.
    let expected = [1.0f32, 2.0, 3.0, 4.0];
    let mut arr = mlxrs::Array::from_slice::<f32>(&expected, &(2, 2)).unwrap();
    let got = arr.to_vec::<f32>().unwrap();
    assert_eq!(got, expected.to_vec());
}
#[test]
fn from_slice_rejects_negative_i32_dims() {
    // Why the guard exists: without the IntoShape negative-dim check,
    // `-1i32 as usize` becomes usize::MAX. The shape-product check could then
    // multiply that into a value equal to data.len(), handing mlx-c a buffer
    // smaller than the shape claims. The failure must come back as a typed
    // OutOfRange payload naming the offending dim index and value.
    let result = mlxrs::Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[-1i32, 3]);
    let payload = match result {
        Err(mlxrs::Error::OutOfRange(payload)) => payload,
        other => panic!("expected OutOfRange dim<0, got {other:?}"),
    };
    assert_eq!(payload.context(), "shape::validate_dims: dim");
    assert_eq!(payload.requirement(), "must be non-negative");
    assert_eq!(payload.value(), "dim[0]=-1");
}
#[test]
fn from_slice_rejects_negative_i32_slice_dims() {
    // Runtime-shape escape hatch: the &[i32] IntoShape path must apply the
    // same negative-dim guard. Here the second dim is the bad one.
    let dims: &[i32] = &[2, -3];
    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let result = mlxrs::Array::from_slice::<f32>(&data, &dims);
    let payload = match result {
        Err(mlxrs::Error::OutOfRange(payload)) => payload,
        other => panic!("expected OutOfRange dim<0, got {other:?}"),
    };
    assert_eq!(payload.context(), "shape::validate_dims: dim");
    assert_eq!(payload.requirement(), "must be non-negative");
    assert_eq!(payload.value(), "dim[1]=-3");
}
#[test]
fn from_slice_rejects_overflowing_shape_product() {
    // i32::MAX^3 ≈ 9.9e27 dwarfs usize::MAX ≈ 1.8e19. An unchecked usize
    // product of these three positive dims would therefore wrap in release
    // builds and could coincidentally equal data.len(). checked_mul must
    // report the overflow instead.
    let result = mlxrs::Array::from_slice::<f32>(&[1.0], &[i32::MAX, i32::MAX, i32::MAX]);
    let payload = match result {
        Err(mlxrs::Error::ArithmeticOverflow(payload)) => payload,
        other => panic!("expected ArithmeticOverflow on shape product, got {other:?}"),
    };
    let context = payload.context();
    assert!(
        context.contains("Array::from_slice: shape product"),
        "context should identify shape-product op, got {:?}",
        context
    );
    assert_eq!(payload.op_type(), "usize");
    let ops = payload.operands();
    let has_dim_index = ops.iter().any(|(name, _)| *name == "dim_index");
    assert!(
        has_dim_index,
        "operands must carry a `dim_index` entry identifying which multiply tripped, got {ops:?}"
    );
}
#[test]
fn slice_rejects_len_ne_ndim() {
    // Empty start/stop/strides agree with one another but not with the 2-D
    // input's ndim(), which is the "len != ndim" failure mode. It is distinct
    // from the dangling-pointer concern: dim_ptr's sentinel already handles
    // empty slices, so 0-D scalar slicing remains allowed.
    let arr = mlxrs::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let result = mlxrs::ops::indexing::slice(&arr, &[], &[], &[]);
    // Expect a typed LengthMismatch: ndim (2) expected, 0 entries supplied.
    let is_expected = match &result {
        Err(mlxrs::Error::LengthMismatch(p)) => {
            p.context() == "slice: start/stop/strides length"
                && p.expected() == 2
                && p.actual() == 0
        }
        _ => false,
    };
    assert!(is_expected, "expected Err(LengthMismatch) on len != ndim, got {result:?}");
}
#[test]
fn slice_rejects_mismatched_lengths() {
    // start has 2 entries, stop has 1 and strides has 2. The three argument
    // slices disagree with each other, not merely with ndim(), so `slice`
    // must report a MultiLengthMismatch naming all three lengths.
    let arr = mlxrs::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let result = mlxrs::ops::indexing::slice(&arr, &[0, 0], &[1], &[1, 1]);
    let is_multi_mismatch = matches!(
        &result,
        Err(mlxrs::Error::MultiLengthMismatch(p)) if p.context() == "slice: start/stop/strides"
    );
    assert!(
        is_multi_mismatch,
        "expected Err(MultiLengthMismatch) on length mismatch, got {result:?}"
    );
}
#[test]
fn slice_accepts_empty_for_zero_dim_scalar() {
    // A 0-D scalar has zero axes, so empty start/stop/strides are the correct
    // arguments (one entry per axis). Those empty slices go through dim_ptr's
    // sentinel instead of being rejected, and the scalar must come back intact.
    let no_dims: [i32; 0] = [];
    let scalar = mlxrs::Array::from_slice::<f32>(&[42.0], &no_dims).unwrap();
    assert_eq!(scalar.ndim(), 0);
    let mut sliced = mlxrs::ops::indexing::slice(&scalar, &[], &[], &[]).unwrap();
    assert_eq!(sliced.shape(), Vec::<usize>::new());
    assert_eq!(sliced.item::<f32>().unwrap(), 42.0);
}
#[test]
fn sum_axes_empty_returns_clone() {
    // Summing over an empty axis list reduces nothing, so the result equals
    // the input (numpy/mlx semantics). The implementation must short-circuit
    // to a clone rather than pass a dangling pointer across FFI.
    let mut input = mlxrs::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let mut summed = mlxrs::ops::reduction::sum_axes(&input, &[], false).unwrap();
    assert_eq!(summed.shape(), input.shape());
    let summed_vals = summed.to_vec::<f32>().unwrap();
    let input_vals = input.to_vec::<f32>().unwrap();
    assert_eq!(summed_vals, input_vals, "sum over no axes should be identity");
}
#[test]
fn concatenate_rejects_empty_input() {
    // Concatenating zero arrays has no defined result. Also, `as_ptr()` on an
    // empty Vec is dangling and must not reach FFI. `concatenate` therefore
    // has to reject the empty input before calling into mlx-c.
    let result = mlxrs::ops::shape::concatenate(&[], 0);
    let rejected = matches!(result, Err(mlxrs::Error::EmptyInput(_)));
    assert!(rejected, "expected Err(EmptyInput) on empty input, got {result:?}");
}
#[test]
fn from_slice_zero_element_uses_sentinel() {
    // numpy/mlx allow zero-element arrays, but the pointer returned by
    // `<&[T]>::as_ptr()` on an empty slice is dangling and must not be handed
    // to FFI. This drives the data_ptr sentinel helper for both 1-D and 2-D
    // zero-element shapes.
    let mut flat = mlxrs::Array::from_slice::<f32>(&[], &[0i32]).unwrap();
    assert_eq!(flat.shape(), vec![0]);
    assert_eq!(flat.size(), 0);

    let grid = mlxrs::Array::from_slice::<f32>(&[], &[2i32, 0]).unwrap();
    assert_eq!(grid.shape(), vec![2, 0]);
    assert_eq!(grid.size(), 0);

    // Reading back a zero-element contiguous array yields an empty Vec.
    assert_eq!(flat.to_vec::<f32>().unwrap(), Vec::<f32>::new());
}
#[test]
fn from_slice_zero_element_all_element_types() {
    // Every Element impl supplies its own typed sentinel. For each type, build
    // a zero-element array and confirm it has shape [0] and reads back as an
    // empty Vec, all without UB.
    macro_rules! check_empty {
        ($t:ty) => {{
            let mut arr = mlxrs::Array::from_slice::<$t>(&[], &[0i32]).unwrap();
            assert_eq!(arr.shape(), vec![0]);
            assert_eq!(arr.to_vec::<$t>().unwrap(), Vec::<$t>::new());
        }};
    }
    check_empty!(bool);
    check_empty!(i32);
    check_empty!(u32);
    check_empty!(half::f16);
}