mlxrs 0.1.0

Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and embeddings support
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
//! Safe `mlx_closure` wrapper.
//!
//! `mlx_closure` is mlx-c's callable handle: a function-pointer + opaque
//! `void* payload` pair that the autograd / custom-VJP / checkpoint / compile
//! transforms accept as their user-supplied function argument. The trampoline
//! pattern here mirrors `mlx-swift`'s
//! [`new_mlx_closure`](https://github.com/ml-explore/mlx-swift/blob/main/Source/MLX/Cmlx%2BUtil.swift)
//! (`Cmlx+Util.swift`) and the equivalent `pybind` shim on the Python side.
//!
//! ## Lifetime
//!
//! The Rust callable is boxed (`Box<Inner<F>>`) and `Box::into_raw`'d into a
//! stable `*mut c_void` payload pointer. A C destructor (`destroy_payload`)
//! reclaims the box via `Box::from_raw`. mlx-c runs that destructor exactly
//! once: `mlx_closure` is a shared_ptr-backed handle, so the destructor fires
//! when the last reference drops, which is not necessarily at the
//! `mlx_closure_free` call. The [`Closure`] wrapper owns *one* reference to
//! the handle and frees it in [`Drop`]; the payload box is *not* owned by
//! [`Closure`] directly — it is owned by the C++ shared destructor.
//!
//! ## Re-entrancy and panics
//!
//! The trampoline catches Rust panics via [`std::panic::catch_unwind`] and
//! converts them to a non-zero rc — unwinding across the `extern "C"` boundary
//! is undefined behavior. The user function is required `Fn + 'static` (not
//! `FnMut`); aliasing the captured state across re-entrant mlx-c calls is
//! safe because `Fn` mandates `&self` access.

use std::{
  ffi::c_void,
  os::raw::c_int,
  panic::{AssertUnwindSafe, catch_unwind},
  ptr,
};

use crate::{
  Array,
  error::{Error, ParsePayload, Result, ensure_handler_installed},
};

/// Type-erased Rust callable invoked by the mlx-c trampoline.
///
/// A `Box<dyn Fn(..)>` is a fat pointer (data + vtable) and therefore does
/// not fit in the single `*mut c_void` payload slot. [`Closure::new`] boxes
/// this alias once more (`Box<BoxedFn>`) to obtain a thin pointer; the
/// trampoline casts the payload back to `BoxedFn` to borrow the callable and
/// `destroy_payload` reclaims the outer box.
pub(crate) type BoxedFn = Box<dyn Fn(&[Array]) -> Result<Vec<Array>> + 'static>;

/// Safe RAII wrapper around an `mlx_closure` that keeps the captured Rust
/// callable alive for the entire lifetime of the C handle.
///
/// Construct via [`Closure::new`]; the returned value owns one reference to
/// the underlying `mlx_closure` and frees it on [`Drop`]. To pass the handle
/// into mlx-c transforms (`mlx_value_and_grad`, `mlx_vjp`, …) use
/// [`Closure::as_raw`], which borrows the handle without transferring
/// ownership. The Rust callable is held alive by the closure's mlx-c
/// destructor, *not* by this struct.
///
/// `Closure` is intentionally `!Send` + `!Sync`: the captured `F` may
/// reference [`crate::Array`] handles (themselves `!Send`), and the mlx-c
/// closure's payload destructor must run on the thread that built it.
pub struct Closure {
  // Raw mlx-c handle; released with `mlx_closure_free` when this value drops.
  inner: mlxrs_sys::mlx_closure,
}

impl Closure {
  /// Construct a closure from a Rust callable. Returns `Err` if the underlying
  /// `mlx_closure_new_func_payload` allocation fails.
  ///
  /// `F` is required `Fn + 'static` so the mlx-c side can invoke it across
  /// arbitrary later re-entries (including from within `mlx_eval`).
  pub fn new<F>(f: F) -> Result<Self>
  where
    F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
  {
    ensure_handler_installed();
    // Box the user closure on the heap, then re-box the resulting fat trait-
    // object pointer so the payload we hand to C is a thin `*mut c_void`.
    // SAFETY of pointer round-trip: we recover the same `Box<BoxedFn>` via
    // `Box::from_raw` exactly once, in `destroy_payload`. mlx-c invokes the
    // destructor exactly once when the underlying `shared_ptr` reaches
    // refcount 0.
    let boxed: Box<BoxedFn> = Box::new(Box::new(f));
    let payload_ptr: *mut c_void = Box::into_raw(boxed).cast();

    // SAFETY: `trampoline::<F>` and `destroy_payload` are both `extern "C"`
    // with the exact signatures mlx-c expects. `payload_ptr` is a freshly
    // boxed `Box<BoxedFn>` whose lifetime is transferred to mlx-c by this
    // call: mlx-c IMMEDIATELY wraps it in `std::shared_ptr<void>(payload,
    // dtor)` as the very first statement of its `try` block (see vendored
    // `mlx-c/mlx/c/closure.cpp::mlx_closure_new_func_payload`, line 70).
    // From that point on, the shared_ptr OWNS the payload — even if any
    // later allocation inside the same `try` throws (e.g. the lambda
    // capture or `mlx_closure_new_(cpp_closure)`), the shared_ptr's
    // destructor runs `destroy_payload(payload_ptr)` as part of stack
    // unwinding before the `catch` clause returns a NULL closure to us.
    // Therefore the NULL-ctx return path below MUST NOT reclaim the box
    // ourselves — that would double-free / UAF.
    // In production we call mlx-c directly. In `cfg(test)` builds we route
    // through a swappable function pointer (`test_seam::closure_new_fn`) so
    // unit tests can inject a NULL-returning stub that exercises the
    // `inner.ctx.is_null()` branch where a double-free could otherwise
    // occur — see the `tests::closure_new_returns_err_*` cases in this
    // file. The
    // `#[cfg(test)]` arm defaults to the same FFI symbol; test stubs satisfy
    // the same ABI + ownership contract (see `test_seam` docs), so the
    // unsafe contract is identical between the two arms.
    // SAFETY: `trampoline` and `destroy_payload` have the exact extern "C"
    // signatures mlx-c expects; `payload_ptr` is a freshly leaked
    // `Box<BoxedFn>` whose ownership transfers to mlx-c's shared_ptr per
    // the contract documented above. The `#[cfg(test)]` arm is functionally
    // identical (defaults to the same FFI symbol).
    let inner = unsafe { call_closure_new_ffi(payload_ptr) };
    if inner.ctx.is_null() {
      // mlx-c already owns `payload_ptr` via the
      // `std::shared_ptr<void>(payload, dtor)` it constructed at the top
      // of its `try` block. If the C++ ctor threw post-shared_ptr-
      // construction, the shared_ptr destructor has ALREADY released the
      // payload via `destroy_payload`. Reclaiming with `Box::from_raw`
      // here would be a double-free / UAF.
      //
      // We accept the (tiny) leak on the alternate path where mlx-c
      // returns NULL without ever constructing the shared_ptr (i.e. the
      // `mlx_closure_new_()` infallible sentinel constructor on the
      // catch arm somehow surfaced NULL — not currently observed in any
      // mlx-c codepath but a defensive consideration). Leak is strictly
      // preferable to UAF.
      return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
        crate::error::FfiNullHandlePayload::new("mlx_closure_new_func_payload"),
      )));
    }
    Ok(Self { inner })
  }

  /// Borrow the raw `mlx_closure` handle for a transient FFI call.
  ///
  /// The returned handle MUST NOT be retained past this `&self` borrow —
  /// `Drop` will free the underlying handle. mlx-c transforms that consume a
  /// closure by *value* internally take a shared_ptr copy, so passing
  /// `closure.as_raw()` into e.g. `mlx_value_and_grad` is sound.
  #[inline(always)]
  pub fn as_raw(&self) -> mlxrs_sys::mlx_closure {
    self.inner
  }
}

impl Drop for Closure {
  /// Release the single handle reference owned by this `Closure`.
  fn drop(&mut self) {
    let handle = self.inner;
    // SAFETY: `handle` is the reference this `Closure` owns, and it is freed
    // here exactly once. Releasing it lowers the C++ shared_ptr refcount;
    // once that reaches zero the registered payload destructor reclaims the
    // `Box<BoxedFn>`. Because this runs inside `Drop`, the return code is
    // discarded silently (crate `Drop` convention) rather than surfaced.
    let _rc = unsafe { mlxrs_sys::mlx_closure_free(handle) };
  }
}

// ──────────────────────── FFI call indirection ────────────────────────

/// Single call site for the `mlx_closure` constructor.
///
/// Outside `cfg(test)` the constructor is `mlxrs_sys::mlx_closure_new_func_payload`
/// itself; under `cfg(test)` it is whatever `test_seam::closure_new_fn()`
/// returns, so tests can substitute a stub. Either way it is invoked with
/// `trampoline`, `payload_ptr` and `destroy_payload`. See [`Closure::new`]
/// and the `test_seam` docs for the ownership contract on `payload_ptr`.
///
/// # Safety
/// `payload_ptr` must have been produced by `Box::into_raw` on a
/// `Box<BoxedFn>`, and the caller gives up ownership of it: from this call on
/// it belongs to mlx-c's `shared_ptr<void>(payload, destroy_payload)`. Stubs
/// installed through the test seam must uphold the identical ABI and
/// ownership contract.
#[inline]
unsafe fn call_closure_new_ffi(payload_ptr: *mut c_void) -> mlxrs_sys::mlx_closure {
  // Pick the constructor first so the unsafe call below is written once.
  #[cfg(not(test))]
  let ctor = mlxrs_sys::mlx_closure_new_func_payload;
  #[cfg(test)]
  let ctor = test_seam::closure_new_fn();

  // SAFETY: forwarded from the caller (see `# Safety`); both constructor
  // choices share the same ABI and ownership contract.
  unsafe { ctor(Some(trampoline), payload_ptr, Some(destroy_payload)) }
}

/// Single call site for the `mlx_closure_custom` constructor.
///
/// Outside `cfg(test)` the constructor is
/// `mlxrs_sys::mlx_closure_custom_new_func_payload`; under `cfg(test)` it is
/// whatever `test_seam::closure_custom_new_fn()` returns. It is invoked with
/// `trampoline_custom`, `payload_ptr` and `destroy_payload_3`. Same rationale
/// as [`call_closure_new_ffi`].
///
/// # Safety
/// `payload_ptr` must have been produced by `Box::into_raw` on a
/// `Box<BoxedFn3>`; ownership of it transfers to mlx-c's `shared_ptr`.
#[inline]
unsafe fn call_closure_custom_new_ffi(payload_ptr: *mut c_void) -> mlxrs_sys::mlx_closure_custom {
  // Pick the constructor first so the unsafe call below is written once.
  #[cfg(not(test))]
  let ctor = mlxrs_sys::mlx_closure_custom_new_func_payload;
  #[cfg(test)]
  let ctor = test_seam::closure_custom_new_fn();

  // SAFETY: forwarded from the caller (see `# Safety`); both constructor
  // choices share the same ABI and ownership contract.
  unsafe {
    ctor(
      Some(trampoline_custom),
      payload_ptr,
      Some(destroy_payload_3),
    )
  }
}

// ─────────────────────────── trampoline ───────────────────────────

/// `extern "C"` shim invoked by mlx-c whenever the closure is applied.
///
/// `outputs_out` is an out-parameter slot pre-allocated by the caller (NULL
/// `ctx`); we populate it via `mlx_vector_array_set_data`. `inputs` is owned
/// by mlx-c (we read it; we do NOT free it). `payload` is the `*mut c_void`
/// we registered.
///
/// Returns `0` on success, non-zero on user error or panic. On user error /
/// panic we leave `outputs_out` populated with an empty `mlx_vector_array`
/// (still a valid handle that mlx-c will free) and post a `Backend` message
/// into the TLS error slot so `crate::error::check(rc)` can drain it.
extern "C" fn trampoline(
  outputs_out: *mut mlxrs_sys::mlx_vector_array,
  inputs: mlxrs_sys::mlx_vector_array,
  payload: *mut c_void,
) -> c_int {
  // Wrap the entire body in `catch_unwind` — any panic across `extern "C"`
  // is UB. We restore the panic as a Backend error in the TLS slot.
  let result = catch_unwind(AssertUnwindSafe(|| {
    // SAFETY: `payload` is the `*mut c_void` we stored via `Box::into_raw`
    // (preserved by mlx-c across calls). We cast back to `*const BoxedFn` and
    // borrow — NOT take ownership; the box is reclaimed in `destroy_payload`.
    let f: &BoxedFn = unsafe { &*payload.cast::<BoxedFn>() };

    // Borrow the input handles WITHOUT taking ownership: we build a
    // `Vec<Array>` of fresh handles by copying each element via
    // `mlx_vector_array_get` (refcount bump) — the original `inputs` vector
    // is mlx-c's. We then call the user function with a `&[Array]` borrow.
    let inputs_vec = borrow_inputs(inputs)?;

    // Invoke user function.
    let outputs = f(&inputs_vec)?;

    // Marshal outputs back into the out-param `mlx_vector_array`. We use
    // `mlx_vector_array_set_data` which copies the array handles into the
    // existing vector slot (refcount bump on each).
    write_outputs(outputs_out, &outputs)?;
    Ok::<(), Error>(())
  }));
  match result {
    Ok(Ok(())) => 0,
    Ok(Err(e)) => {
      // Stash the user error in TLS so `check(rc)` drains it.
      crate::error::set_last(e);
      // Populate the out-param with an empty vector so mlx-c's later
      // `mlx_vector_array_free` is a defined no-op.
      // SAFETY: `outputs_out` is the caller-owned pre-allocated handle slot;
      // writing an empty vector handle is the safe way to leave it.
      unsafe {
        if !outputs_out.is_null() {
          *outputs_out = mlxrs_sys::mlx_vector_array_new();
        }
      }
      1
    }
    Err(panic_payload) => {
      let msg = if let Some(s) = panic_payload.downcast_ref::<&'static str>() {
        (*s).to_string()
      } else if let Some(s) = panic_payload.downcast_ref::<String>() {
        s.clone()
      } else {
        "panic in mlxrs::transforms closure trampoline".to_string()
      };
      crate::error::set_last(Error::Parse(ParsePayload::new(
        "transforms::closure trampoline: caught panic",
        "Rust closure panic payload",
        std::io::Error::other(msg),
      )));
      // SAFETY: same as above — leave the out-param holding an empty handle.
      unsafe {
        if !outputs_out.is_null() {
          *outputs_out = mlxrs_sys::mlx_vector_array_new();
        }
      }
      1
    }
  }
}

/// `extern "C"` destructor mlx-c invokes when the closure's last `shared_ptr`
/// copy drops. Reclaims the `Box<BoxedFn>` we leaked at construction.
extern "C" fn destroy_payload(payload: *mut c_void) {
  if payload.is_null() {
    return;
  }
  // SAFETY: `payload` is the `*mut c_void` produced by `Box::into_raw` on a
  // `Box<BoxedFn>` in `Closure::new`. mlx-c calls this destructor exactly
  // once per registration. Box ownership returns here and is dropped.
  // Wrap drop in `catch_unwind` so a panicking user closure-destructor
  // cannot unwind across the C++ boundary.
  let _ = catch_unwind(AssertUnwindSafe(|| {
    // SAFETY: see fn doc above — payload is a Box<BoxedFn> we created.
    let _: Box<BoxedFn> = unsafe { Box::from_raw(payload.cast::<BoxedFn>()) };
  }));
}

// ─────────────────────── vector_array marshalling ───────────────────────

/// Re-export of the shared FFI helpers (`crate::ffi`). `VectorArrayGuard`
/// (RAII free) and `drain_vector` (pure read) were previously defined
/// locally here and re-used by sibling transforms (`autograd`, `custom`,
/// `checkpoint`) via `transforms::closure::{...}`; they now live in the
/// shared module (audit issue #259) but stay reachable on this path so
/// those consumers — and `borrow_inputs` below — are unaffected.
pub(crate) use crate::ffi::{VectorArrayGuard, drain_vector};

/// View the elements of an mlx-c-owned `mlx_vector_array` as a `Vec<Array>`.
///
/// Thin alias for [`drain_vector`], named for its use inside the
/// trampolines, where the source `vec` belongs to mlx-c and we MUST NOT free
/// it.
///
/// NOTE(review): this relies on `drain_vector` being a pure read that hands
/// back independently owned (refcount-shared) handles — confirm in
/// `crate::ffi`.
fn borrow_inputs(vec: mlxrs_sys::mlx_vector_array) -> Result<Vec<Array>> {
  // Delegates unchanged; any error from `drain_vector` is propagated as-is.
  drain_vector(vec)
}

/// Pack a `&[Array]` into a freshly allocated `mlx_vector_array` and write
/// its handle into `out`. mlx-c copies refcount-shared array handles into
/// the new vector storage. The previous contents of `*out` are leaked — mlx-c
/// gives us a NULL-ctx slot on first entry, so this is a safe overwrite.
fn write_outputs(out: *mut mlxrs_sys::mlx_vector_array, outputs: &[Array]) -> Result<()> {
  // Collect raw handles into a contiguous `Vec<mlx_array>` for FFI.
  let raw: Vec<mlxrs_sys::mlx_array> = outputs.iter().map(|a| a.0).collect();
  let data_ptr = if raw.is_empty() {
    ptr::null()
  } else {
    raw.as_ptr()
  };
  // SAFETY: `out` is the trampoline's caller-owned out-param. Per mlx-c's
  // convention on entry it is a NULL-ctx handle; we replace it with a fresh
  // vector populated from `raw` (mlx-c copies the array handles, refcount-
  // bumping each).
  unsafe {
    *out = mlxrs_sys::mlx_vector_array_new_data(data_ptr, raw.len());
  }
  // SAFETY: post-write null-check — the constructor is fallible.
  if unsafe { (*out).ctx.is_null() } && !outputs.is_empty() {
    return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
      crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
    )));
  }
  Ok(())
}

// ─────────────────────── caller-side helpers ───────────────────────

/// Build a guarded `mlx_vector_array` from already-collected raw handles.
///
/// Shared tail of [`vector_array_from_borrow`] and
/// [`vector_array_from_slice`], kept in one place so the two entry points
/// cannot drift apart.
///
/// # Errors
///
/// Returns the pending mlx-c error (or `Error::FfiNullHandle` naming
/// `mlx_vector_array_new_data`) when the constructor yields a NULL-ctx
/// handle.
fn vector_array_from_handles(raw: &[mlxrs_sys::mlx_array]) -> Result<VectorArrayGuard> {
  // An empty slice's pointer is dangling, so hand mlx-c an explicit NULL.
  let data_ptr = if raw.is_empty() {
    ptr::null()
  } else {
    raw.as_ptr()
  };
  // SAFETY: `data_ptr` is either NULL (n==0, mlx-c builds an empty vector) or
  // a valid pointer to `raw.len()` borrowed handles live for this call (mlx-c
  // copies into the new vector, refcount-bumping each).
  let vec = unsafe { mlxrs_sys::mlx_vector_array_new_data(data_ptr, raw.len()) };
  if vec.ctx.is_null() {
    return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
      crate::error::FfiNullHandlePayload::new("mlx_vector_array_new_data"),
    )));
  }
  Ok(VectorArrayGuard(vec))
}

/// Pack a slice of array references (`&[&Array]`) into a fresh
/// `mlx_vector_array`. Returns the handle wrapped in a guard for RAII free.
///
/// # Errors
///
/// Fails when `mlx_vector_array_new_data` returns a NULL-ctx handle.
pub(crate) fn vector_array_from_borrow(arrays: &[&Array]) -> Result<VectorArrayGuard> {
  ensure_handler_installed();
  let raw: Vec<mlxrs_sys::mlx_array> = arrays.iter().map(|a| a.0).collect();
  vector_array_from_handles(&raw)
}

/// Same as [`vector_array_from_borrow`] but takes `&[Array]` (most-common
/// caller convenience).
///
/// # Errors
///
/// Fails when `mlx_vector_array_new_data` returns a NULL-ctx handle.
pub(crate) fn vector_array_from_slice(arrays: &[Array]) -> Result<VectorArrayGuard> {
  ensure_handler_installed();
  let raw: Vec<mlxrs_sys::mlx_array> = arrays.iter().map(|a| a.0).collect();
  vector_array_from_handles(&raw)
}

/// Owning guard for a temporary `mlx_closure_value_and_grad`: the wrapped
/// handle is released with `mlx_closure_value_and_grad_free` on drop.
pub(crate) struct ClosureValueAndGradGuard(pub(crate) mlxrs_sys::mlx_closure_value_and_grad);

impl ClosureValueAndGradGuard {
  /// Copy out the raw handle for a transient FFI call. The copy must not be
  /// used after `self` has been dropped.
  #[allow(dead_code)]
  #[inline(always)]
  pub(crate) const fn as_raw(&self) -> mlxrs_sys::mlx_closure_value_and_grad {
    let Self(handle) = self;
    *handle
  }
}

impl Drop for ClosureValueAndGradGuard {
  fn drop(&mut self) {
    let Self(handle) = self;
    // SAFETY: the guard is the single owner of this handle, so it is freed
    // exactly once; the return code is discarded, as in `VectorArrayGuard`.
    let _rc = unsafe { mlxrs_sys::mlx_closure_value_and_grad_free(*handle) };
  }
}

/// Owning guard for a temporary `mlx_closure_custom`: the wrapped handle is
/// released with `mlx_closure_custom_free` on drop.
pub(crate) struct ClosureCustomGuard(pub(crate) mlxrs_sys::mlx_closure_custom);

impl ClosureCustomGuard {
  /// Copy out the raw handle for a transient FFI call. The copy must not be
  /// used after `self` has been dropped.
  #[allow(dead_code)]
  #[inline(always)]
  pub(crate) const fn as_raw(&self) -> mlxrs_sys::mlx_closure_custom {
    let Self(handle) = self;
    *handle
  }
}

impl Drop for ClosureCustomGuard {
  fn drop(&mut self) {
    let Self(handle) = self;
    // SAFETY: the guard is the single owner of this handle, so it is freed
    // exactly once; the return code is discarded, as in `VectorArrayGuard`.
    let _rc = unsafe { mlxrs_sys::mlx_closure_custom_free(*handle) };
  }
}

/// Owning guard for a temporary `mlx_closure` that we own (e.g. the result of
/// `mlx_checkpoint` / `mlx_custom_function`): the wrapped handle is released
/// with `mlx_closure_free` on drop.
pub(crate) struct RawClosureGuard(pub(crate) mlxrs_sys::mlx_closure);

impl RawClosureGuard {
  /// Copy out the raw handle for a transient FFI call. The copy must not be
  /// used after `self` has been dropped.
  #[allow(dead_code)]
  #[inline(always)]
  pub(crate) const fn as_raw(&self) -> mlxrs_sys::mlx_closure {
    let Self(handle) = self;
    *handle
  }
}

impl Drop for RawClosureGuard {
  fn drop(&mut self) {
    let Self(handle) = self;
    // SAFETY: the guard is the single owner of this handle, so it is freed
    // exactly once; the return code is discarded, as in `VectorArrayGuard`.
    let _rc = unsafe { mlxrs_sys::mlx_closure_free(*handle) };
  }
}

/// Build a custom-VJP `mlx_closure_custom` from a Rust 3-input function.
///
/// The contract matches `mlx_custom_vjp`'s `fun_vjp` argument:
/// `(primals, cotangents, outputs) -> grads` — the same positional order
/// `mlx::core::CustomTransforms::vjp` invokes its `vjp_fun_` callback with
/// (`mlx/primitives.cpp::CustomTransforms::vjp`).
pub(crate) fn closure_custom_new<F>(f: F) -> Result<ClosureCustomGuard>
where
  F: Fn(&[Array], &[Array], &[Array]) -> Result<Vec<Array>> + 'static,
{
  ensure_handler_installed();
  let boxed: Box<BoxedFn3> = Box::new(Box::new(f));
  let payload_ptr: *mut c_void = Box::into_raw(boxed).cast();
  // SAFETY: trampoline + destructor have correct signatures. `payload_ptr` is
  // a freshly leaked `Box<BoxedFn3>` whose lifetime is transferred to mlx-c
  // by this call: mlx-c IMMEDIATELY wraps it in
  // `std::shared_ptr<void>(payload, dtor)` as the first statement of its
  // `try` block (see vendored
  // `mlx-c/mlx/c/closure.cpp::mlx_closure_custom_new_func_payload`,
  // line 471). From that point on the shared_ptr OWNS the payload — even
  // if any later allocation inside the same `try` throws, the shared_ptr
  // destructor runs `destroy_payload_3(payload_ptr)` during unwinding
  // before the `catch` clause returns NULL. Therefore the NULL-ctx return
  // path below MUST NOT reclaim the box — that would double-free / UAF.
  // Production: direct FFI; tests: route through the swappable seam so the
  // NULL-ctx branch (where a double-free could otherwise occur) is
  // exercised deterministically by `tests::closure_custom_new_returns_err_*`.
  // SAFETY: `payload_ptr` is a freshly leaked `Box<BoxedFn3>` whose
  // ownership transfers to mlx-c per the contract documented above.
  let inner = unsafe { call_closure_custom_new_ffi(payload_ptr) };
  if inner.ctx.is_null() {
    // mlx-c already owns `payload_ptr` via its `shared_ptr<void>`; the
    // shared_ptr destructor has run (or will run on the natural drop
    // path) and released the payload via `destroy_payload_3`. DO NOT
    // reclaim manually — that would be a double-free / UAF. Same
    // rationale as `Closure::new` above: accept a (tiny) leak on the
    // unobserved-NULL path over a deterministic UAF.
    return Err(crate::error::take_last().unwrap_or(Error::FfiNullHandle(
      crate::error::FfiNullHandlePayload::new("mlx_closure_custom_new_func_payload"),
    )));
  }
  Ok(ClosureCustomGuard(inner))
}

/// Type-erased three-input Rust callable used for custom-VJP closures.
///
/// `closure_custom_new` boxes this alias once more (`Box<BoxedFn3>`) so the
/// payload handed to mlx-c is a thin pointer; `trampoline_custom` borrows it
/// back through that pointer and `destroy_payload_3` reclaims the outer box.
pub(crate) type BoxedFn3 =
  Box<dyn Fn(&[Array], &[Array], &[Array]) -> Result<Vec<Array>> + 'static>;

// MLX core `CustomTransforms::vjp` invokes its `vjp_fun_` callback with the
// positional argument order `(primals, cotangents, outputs)` — see
// `mlx/primitives.cpp::CustomTransforms::vjp` upstream:
//
// ```cpp
// auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
// ```
//
// The Rust trampoline therefore names its second / third `mlx_vector_array`
// arguments `cotangents` / `outputs` to match — the user closure receives the
// triple in this same order via `f(&primals, &cotangents, &outputs)`.
extern "C" fn trampoline_custom(
  outputs_out: *mut mlxrs_sys::mlx_vector_array,
  primals: mlxrs_sys::mlx_vector_array,
  cotangents: mlxrs_sys::mlx_vector_array,
  outputs: mlxrs_sys::mlx_vector_array,
  payload: *mut c_void,
) -> c_int {
  let result = catch_unwind(AssertUnwindSafe(|| {
    // SAFETY: `payload` was produced by `Box::into_raw(Box<BoxedFn3>)` and
    // is preserved by mlx-c; borrow without taking ownership.
    let f: &BoxedFn3 = unsafe { &*payload.cast::<BoxedFn3>() };
    let p = borrow_inputs(primals)?;
    let c = borrow_inputs(cotangents)?;
    let o = borrow_inputs(outputs)?;
    let grads = f(&p, &c, &o)?;
    write_outputs(outputs_out, &grads)?;
    Ok::<(), Error>(())
  }));
  match result {
    Ok(Ok(())) => 0,
    Ok(Err(e)) => {
      crate::error::set_last(e);
      // SAFETY: leave out-param holding an empty vector handle.
      unsafe {
        if !outputs_out.is_null() {
          *outputs_out = mlxrs_sys::mlx_vector_array_new();
        }
      }
      1
    }
    Err(panic_payload) => {
      let msg = if let Some(s) = panic_payload.downcast_ref::<&'static str>() {
        (*s).to_string()
      } else if let Some(s) = panic_payload.downcast_ref::<String>() {
        s.clone()
      } else {
        "panic in mlxrs::transforms custom-VJP trampoline".to_string()
      };
      crate::error::set_last(Error::Parse(ParsePayload::new(
        "transforms::custom_vjp trampoline: caught panic",
        "Rust closure panic payload",
        std::io::Error::other(msg),
      )));
      // SAFETY: leave out-param holding an empty vector handle.
      unsafe {
        if !outputs_out.is_null() {
          *outputs_out = mlxrs_sys::mlx_vector_array_new();
        }
      }
      1
    }
  }
}

extern "C" fn destroy_payload_3(payload: *mut c_void) {
  if payload.is_null() {
    return;
  }
  let _ = catch_unwind(AssertUnwindSafe(|| {
    // SAFETY: payload is a Box<BoxedFn3> we created; reclaim ownership once.
    let _: Box<BoxedFn3> = unsafe { Box::from_raw(payload.cast::<BoxedFn3>()) };
  }));
}

// ─────────────────────────── test seam ───────────────────────────

/// Test-only function-pointer indirection over the mlx-c closure constructors.
///
/// Production builds (`#[cfg(not(test))]`) call
/// `mlxrs_sys::mlx_closure_*_new_func_payload` directly: zero indirection,
/// zero overhead. The compiler eliminates this module entirely.
///
/// In `#[cfg(test)]` builds the constructor call in `Closure::new` /
/// `closure_custom_new` routes through an [`AtomicPtr`]-backed function
/// pointer slot here, defaulting to the real mlx-c symbol. The unit tests
/// below swap in a deterministic stub that simulates mlx-c's
/// shared_ptr-then-throw failure mode (invokes the destructor we registered,
/// then returns NULL ctx) to exercise the `inner.ctx.is_null()` branch
/// where a double-free could otherwise occur. Without this seam the
/// NULL-ctx branch is unreachable from Rust (we cannot inject OOM into
/// mlx-c) and CI would be blind to a regression that re-introduced the
/// reclaim.
///
/// ## Serialization + non-reentrant install lock
///
/// Each per-constructor slot now has TWO collaborators:
///
/// * `*_slot()` — an `AtomicPtr<()>` holding the currently-installed fn
///   pointer. Read via lock-free `load(Acquire)` in `*_fn()`, which the
///   `call_*_ffi` helpers invoke synchronously during `Closure::new` /
///   `closure_custom_new`. Lock-free reads guarantee no deadlock if a test
///   has the install lock held: the install lock and the slot are
///   independent.
/// * `*_install_lock()` — a `Mutex<()>` held by the [`ScopedClosureCtor`] /
///   [`ScopedCustomCtor`] guard for its ENTIRE lifetime (install + use +
///   restore). This makes the install→use→restore sequence atomic w.r.t.
///   any other guard. Combined with the [`serial_guard`] mutex inside the
///   test module (which every seam test acquires as its first action),
///   the seam tests run strictly one-at-a-time even under default parallel
///   `cargo test`; the install lock is defense-in-depth in case future
///   tests forget to acquire `serial_guard`.
///
/// The earlier design held the slot mutex only across the fn-pointer
/// swap, then released it before the guarded `Closure::new` call. Two
/// parallel seam tests could install conflicting stubs between install and
/// use, and non-LIFO drops could restore an older stub atop a newer
/// guard's install. Both windows are closed by holding the install lock
/// for the entire guard lifetime.
#[cfg(test)]
pub(crate) mod test_seam {
  use std::sync::{
    Mutex, MutexGuard, OnceLock,
    atomic::{AtomicPtr, Ordering},
  };

  use super::*;

  /// Function-pointer type matching `mlx_closure_new_func_payload`'s ABI.
  ///
  /// This is the type that [`closure_new_fn`] returns and that
  /// [`ScopedClosureCtor::install`] accepts, so any test stub must have
  /// exactly this signature.
  ///
  /// * `fun` — optional C callback receiving a `*mut mlx_vector_array`, an
  ///   `mlx_vector_array` by value and a `*mut c_void`, returning a `c_int`.
  ///   (Presumably output slot / inputs / payload / status code — confirm
  ///   against the mlx-c header.)
  /// * `payload` — opaque pointer handed to the constructor alongside `fun`.
  /// * `dtor` — optional C callback receiving a `*mut c_void` (presumably
  ///   the destructor for `payload` — confirm against the mlx-c header).
  ///
  /// The call yields an `mlx_closure` by value.
  pub(crate) type ClosureNewFn = unsafe extern "C" fn(
    // Optional closure-body callback.
    fun: Option<
      unsafe extern "C" fn(
        *mut mlxrs_sys::mlx_vector_array,
        mlxrs_sys::mlx_vector_array,
        *mut c_void,
      ) -> c_int,
    >,
    // Opaque user data.
    payload: *mut c_void,
    // Optional callback taking a `*mut c_void`.
    dtor: Option<unsafe extern "C" fn(*mut c_void)>,
  ) -> mlxrs_sys::mlx_closure;

  /// Function-pointer type matching `mlx_closure_custom_new_func_payload`'s ABI.
  ///
  /// This is the type that [`closure_custom_new_fn`] returns and that
  /// [`ScopedCustomCtor::install`] accepts, so any test stub must have
  /// exactly this signature.
  ///
  /// * `fun` — optional C callback receiving a `*mut mlx_vector_array`,
  ///   three `mlx_vector_array` values and a `*mut c_void`, returning a
  ///   `c_int`. (The roles of the three by-value arrays are not visible
  ///   here — confirm against the mlx-c header.)
  /// * `payload` — opaque pointer handed to the constructor alongside `fun`.
  /// * `dtor` — optional C callback receiving a `*mut c_void` (presumably
  ///   the destructor for `payload` — confirm against the mlx-c header).
  ///
  /// The call yields an `mlx_closure_custom` by value.
  pub(crate) type ClosureCustomNewFn = unsafe extern "C" fn(
    // Optional closure-body callback.
    fun: Option<
      unsafe extern "C" fn(
        *mut mlxrs_sys::mlx_vector_array,
        mlxrs_sys::mlx_vector_array,
        mlxrs_sys::mlx_vector_array,
        mlxrs_sys::mlx_vector_array,
        *mut c_void,
      ) -> c_int,
    >,
    // Opaque user data.
    payload: *mut c_void,
    // Optional callback taking a `*mut c_void`.
    dtor: Option<unsafe extern "C" fn(*mut c_void)>,
  ) -> mlxrs_sys::mlx_closure_custom;

  /// Process-wide slot holding the `ClosureNewFn` that [`closure_new_fn`]
  /// hands out.
  ///
  /// The fn pointer is kept type-erased in an `AtomicPtr<()>` so readers
  /// can fetch it with a plain atomic load instead of taking a lock. That
  /// matters because the read happens inside `Closure::new`, potentially
  /// while the calling test already holds the install lock for its guard's
  /// lifetime — a locking read would deadlock there.
  fn closure_new_slot() -> &'static AtomicPtr<()> {
    static CTOR_SLOT: OnceLock<AtomicPtr<()>> = OnceLock::new();
    CTOR_SLOT.get_or_init(|| {
      // Seed the slot with the genuine mlx-c constructor on first access.
      let real_ctor = mlxrs_sys::mlx_closure_new_func_payload as *mut ();
      AtomicPtr::new(real_ctor)
    })
  }

  /// Mutex held by a `ScopedClosureCtor` guard for its entire lifetime to
  /// serialize install→use→restore against any other guard installation.
  ///
  /// It is a `Mutex<()>` (not `Mutex<FnPtr>`): the fn pointer lives in the
  /// separate atomic slot, so a held install lock never blocks the
  /// lock-free `closure_new_fn()` reads.
  ///
  /// Every call returns a reference to the same process-wide mutex.
  fn closure_new_install_lock() -> &'static Mutex<()> {
    // `Mutex::new` is a `const fn`, so the lock can live directly in a
    // `static`; no lazy-initialisation wrapper is required.
    static LOCK: Mutex<()> = Mutex::new(());
    &LOCK
  }

  /// Custom-VJP counterpart of [`closure_new_slot`]: the process-wide,
  /// type-erased slot holding the `ClosureCustomNewFn` that
  /// [`closure_custom_new_fn`] hands out.
  fn closure_custom_new_slot() -> &'static AtomicPtr<()> {
    static CTOR_SLOT: OnceLock<AtomicPtr<()>> = OnceLock::new();
    CTOR_SLOT.get_or_init(|| {
      // Seed the slot with the genuine mlx-c constructor on first access.
      let real_ctor = mlxrs_sys::mlx_closure_custom_new_func_payload as *mut ();
      AtomicPtr::new(real_ctor)
    })
  }

  /// Mirror of [`closure_new_install_lock`] for the custom-VJP seam: the
  /// mutex a [`ScopedCustomCtor`] guard holds for its entire lifetime.
  ///
  /// Every call returns a reference to the same process-wide mutex.
  fn closure_custom_new_install_lock() -> &'static Mutex<()> {
    // `Mutex::new` is a `const fn`, so the lock can live directly in a
    // `static`; no lazy-initialisation wrapper is required.
    static LOCK: Mutex<()> = Mutex::new(());
    &LOCK
  }

  /// Fetch the constructor currently installed in the seam (the real mlx-c
  /// symbol unless a [`ScopedClosureCtor`] is live).
  ///
  /// This is a single lock-free atomic load and never blocks. That is
  /// required: the `Closure::new` call that ends up here may be running
  /// inside a test that already holds the install lock through its
  /// [`ScopedClosureCtor`].
  pub(crate) fn closure_new_fn() -> ClosureNewFn {
    let erased: *mut () = closure_new_slot().load(Ordering::Acquire);
    // SAFETY: the slot is only ever written with
    //   (a) the address of the real mlx-c FFI symbol, by the one-time
    //       initialiser in `closure_new_slot`; or
    //   (b) a `ClosureNewFn` cast `as *mut ()` by
    //       `ScopedClosureCtor::install`, or an earlier slot value put back
    //       by that guard's `Drop`.
    // Every value is therefore the address of a function with the
    // `ClosureNewFn` ABI, and converting it back yields the original fn
    // pointer. `transmute` itself refuses to compile if `*mut ()` and the
    // fn-pointer type differ in size on the target.
    unsafe { std::mem::transmute::<*mut (), ClosureNewFn>(erased) }
  }

  /// Fetch the custom-VJP constructor currently installed in the seam.
  ///
  /// Like [`closure_new_fn`], this is a single lock-free atomic load.
  pub(crate) fn closure_custom_new_fn() -> ClosureCustomNewFn {
    let erased: *mut () = closure_custom_new_slot().load(Ordering::Acquire);
    // SAFETY: the slot is only ever written with the address of the real
    // mlx-c symbol (one-time initialiser) or with a `ClosureCustomNewFn`
    // cast `as *mut ()` by `ScopedCustomCtor::install` / restored by its
    // `Drop`. Converting back therefore yields a valid fn pointer of that
    // exact type; `transmute` enforces the size match at compile time.
    unsafe { std::mem::transmute::<*mut (), ClosureCustomNewFn>(erased) }
  }

  /// RAII guard: replace [`closure_new_fn`] with `stub` for the guard's
  /// lifetime, restore the previous symbol on drop.
  ///
  /// Holds [`closure_new_install_lock`] for the ENTIRE guard lifetime
  /// (install + test body + restore) so the swap→use→restore sequence is
  /// atomic with respect to any concurrent `ScopedClosureCtor` install on
  /// another thread. Test bodies don't read this lock directly — only
  /// other `install` calls block on it, which means a parallel seam test
  /// can't install a conflicting stub between this guard's install and
  /// the matching `Closure::new` call inside the test body.
  ///
  /// The guard must be bound to a named variable (`let _guard = …;`).
  /// Discarding the return value drops it on the spot, which restores the
  /// previous constructor before the test body ever runs; `#[must_use]`
  /// turns that mistake into a compiler warning.
  #[must_use = "the stub is uninstalled as soon as the guard is dropped; bind it to a named variable"]
  pub(crate) struct ScopedClosureCtor {
    // Holds the install lock for the entire guard lifetime, blocking any
    // other `install` call. Auto-released when this struct drops.
    _install_guard: MutexGuard<'static, ()>,
    // Type-erased slot value that was current before `install` swapped the
    // stub in; written back to the slot on drop.
    prev: *mut (),
  }

  impl ScopedClosureCtor {
    /// Install `stub` as the constructor returned by [`closure_new_fn`].
    ///
    /// Blocks until no other `ScopedClosureCtor` is alive, then swaps
    /// `stub` into the slot and remembers the value it displaced so that
    /// `Drop` can put it back.
    pub(crate) fn install(stub: ClosureNewFn) -> Self {
      // Take the install lock before touching the slot; it stays held
      // until the returned guard is dropped. A poisoned lock only means an
      // earlier seam test panicked while holding it, so recover the guard
      // rather than failing every later test.
      let install_guard = match closure_new_install_lock().lock() {
        Ok(held) => held,
        Err(poisoned) => poisoned.into_inner(),
      };
      let prev = closure_new_slot().swap(stub as *mut (), Ordering::AcqRel);
      Self {
        _install_guard: install_guard,
        prev,
      }
    }
  }

  impl Drop for ScopedClosureCtor {
    /// Put the displaced constructor back into the slot.
    fn drop(&mut self) {
      // `drop` also runs while unwinding, so a panicking test still
      // restores the previous symbol. This `Release` store undoes the swap
      // done by `install`. The install lock is released only afterwards,
      // when the `_install_guard` field is dropped once this body returns.
      let previous_ctor = self.prev;
      closure_new_slot().store(previous_ctor, Ordering::Release);
    }
  }

  /// Mirror of [`ScopedClosureCtor`] for the custom-VJP constructor seam.
  ///
  /// While alive, [`closure_custom_new_fn`] returns the installed stub and
  /// [`closure_custom_new_install_lock`] stays held; dropping the guard
  /// restores the previous constructor.
  ///
  /// The guard must be bound to a named variable (`let _guard = …;`).
  /// Discarding the return value drops it on the spot, which restores the
  /// previous constructor before the test body ever runs; `#[must_use]`
  /// turns that mistake into a compiler warning.
  #[must_use = "the stub is uninstalled as soon as the guard is dropped; bind it to a named variable"]
  pub(crate) struct ScopedCustomCtor {
    // Holds the custom-VJP install lock for the entire guard lifetime;
    // released automatically when this struct drops.
    _install_guard: MutexGuard<'static, ()>,
    // Type-erased slot value that was current before `install` swapped the
    // stub in; written back to the slot on drop.
    prev: *mut (),
  }

  impl ScopedCustomCtor {
    /// Install `stub` as the constructor returned by
    /// [`closure_custom_new_fn`].
    ///
    /// Blocks until no other `ScopedCustomCtor` is alive, then swaps `stub`
    /// into the slot and remembers the value it displaced so that `Drop`
    /// can put it back.
    pub(crate) fn install(stub: ClosureCustomNewFn) -> Self {
      // Take the install lock first and keep it for the guard's lifetime.
      // A poisoned lock only means an earlier seam test panicked while
      // holding it, so recover the guard instead of propagating the panic.
      let install_guard = match closure_custom_new_install_lock().lock() {
        Ok(held) => held,
        Err(poisoned) => poisoned.into_inner(),
      };
      let prev = closure_custom_new_slot().swap(stub as *mut (), Ordering::AcqRel);
      Self {
        _install_guard: install_guard,
        prev,
      }
    }
  }

  impl Drop for ScopedCustomCtor {
    /// Put the displaced custom-VJP constructor back into the slot.
    fn drop(&mut self) {
      // Runs on unwinding too, so a panicking test still restores the
      // previous symbol. The install lock is released afterwards, when the
      // `_install_guard` field is dropped once this body returns.
      let previous_ctor = self.prev;
      closure_custom_new_slot().store(previous_ctor, Ordering::Release);
    }
  }
}

#[cfg(test)]
mod tests;