burn-mpsgraph 0.0.1

Apple MPSGraph backend for the Burn deep learning framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
//! Raw Objective-C FFI for Metal, MPS, MPSGraph, and Foundation.
//!
//! All Apple framework interactions go through this module.
//! Uses direct `objc_msgSend` calls — no objc2 / objc crate dependency.

#![allow(non_upper_case_globals, non_snake_case, unused)]

use std::ffi::{c_char, c_void, CString};
use std::ptr;
use std::sync::OnceLock;

// ─── ObjC runtime types ─────────────────────────────────────────────────────

/// Opaque ObjC object pointer (like `id` in ObjC).
pub type Id = *mut c_void;
/// ObjC class pointer.
pub type Class = *mut c_void;
/// ObjC selector pointer.
pub type Sel = *mut c_void;
/// Rust spelling of Foundation's `NSInteger`, declared here as `isize`.
pub type NSInteger = isize;
/// Rust spelling of Foundation's `NSUInteger`, declared here as `usize`.
pub type NSUInteger = usize;

/// The null object pointer, passed wherever ObjC code would pass `nil`.
pub const NIL: Id = ptr::null_mut();

// ─── ObjC runtime imports ───────────────────────────────────────────────────

// Symbols imported from the Objective-C runtime. Every call into them is
// `unsafe`; the helpers further down wrap the common ones.
extern "C" {
    /// Looks up a class by its NUL-terminated name (see `class`).
    pub fn objc_getClass(name: *const c_char) -> Class;
    /// Returns the selector for a NUL-terminated name (see `sel`).
    pub fn sel_registerName(name: *const c_char) -> Sel;
    /// Declared without arguments on purpose: it is never called directly,
    /// only cast to a concrete function-pointer type at each call site.
    pub fn objc_msgSend();           // Universal dispatch — transmute per call
    /// Allocates an instance of `cls`; used before `init...` selectors.
    pub fn objc_alloc(cls: Class) -> Id;
    /// Retains `obj` (wrapped by `retain`, which adds a null check).
    pub fn objc_retain(obj: Id) -> Id;
    /// Releases `obj` (wrapped by `release`, which adds a null check).
    pub fn objc_release(obj: Id);
    /// Pushes an autorelease pool and returns its context token.
    pub fn objc_autoreleasePoolPush() -> *mut c_void;
    /// Pops the autorelease pool identified by `ctx`.
    pub fn objc_autoreleasePoolPop(ctx: *mut c_void);
}

// Metal
extern "C" {
    /// Returns the system default Metal device pointer; `metal_device`
    /// treats a null result as "no device" and panics.
    pub fn MTLCreateSystemDefaultDevice() -> Id;
}

// ─── Helpers ────────────────────────────────────────────────────────────────

/// Looks up an Objective-C class by name through `objc_getClass`.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte, or if the runtime
/// returns a null class pointer for it.
#[inline]
pub fn class(name: &str) -> Class {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is NUL-terminated and outlives the call.
    let found = unsafe { objc_getClass(c_name.as_ptr()) };
    if found.is_null() {
        panic!("ObjC class '{}' not found", name);
    }
    found
}

/// Obtains the selector for `name` by calling `sel_registerName`.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte.
#[inline]
pub fn sel(name: &str) -> Sel {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is NUL-terminated and outlives the call.
    let selector = unsafe { sel_registerName(c_name.as_ptr()) };
    selector
}

/// Null-tolerant wrapper around `objc_retain`.
///
/// A null `obj` is handed back untouched; otherwise the result of
/// `objc_retain(obj)` is returned.
///
/// # Safety
/// A non-null `obj` must point to a live Objective-C object.
#[inline]
pub unsafe fn retain(obj: Id) -> Id {
    if obj.is_null() {
        return obj;
    }
    objc_retain(obj)
}

/// Null-tolerant wrapper around `objc_release`; does nothing for a null `obj`.
///
/// # Safety
/// A non-null `obj` must point to a live Objective-C object that the caller
/// is entitled to release.
#[inline]
pub unsafe fn release(obj: Id) {
    if obj.is_null() {
        return;
    }
    objc_release(obj);
}

// ─── msg_send! macro ────────────────────────────────────────────────────────
//
// Sends an Objective-C message by transmuting `objc_msgSend` to the correct
// function pointer type.  On arm64 Apple Silicon, integer/pointer arguments
// travel in x0-x7 and float arguments in d0-d7, so the transmuted type must
// match the actual argument types exactly.

// Usage: `msg_send!(<ret>; receiver, "selector:", (ArgTy, ...), arg, ...)`.
//
// The leading token before `;` picks the return type of the transmuted
// function pointer. The parenthesised type list must spell out the type of
// every argument, in order, because those types become the function-pointer
// signature that `objc_msgSend` is cast to.
//
// Arms with arguments exist only for `Id`, `usize` and `void` returns; the
// `isize`, `bool`, `f64` and `u32` arms accept argument-less selectors only.
//
// Every expansion calls `transmute` and an `unsafe extern "C"` function, so
// the macro must be used inside an `unsafe` context.
macro_rules! msg_send {
    // ── return Id ────────────────────────────────────────────────────────
    // No-argument form.
    (Id; $obj:expr, $sel:expr) => {{
        let f: unsafe extern "C" fn(Id, Sel) -> Id =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel))
    }};
    // Form with one or more typed arguments.
    (Id; $obj:expr, $sel:expr, ($($at:ty),+), $($av:expr),+) => {{
        let f: unsafe extern "C" fn(Id, Sel, $($at),+) -> Id =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel), $($av),+)
    }};

    // ── return NSUInteger (usize) ────────────────────────────────────────
    (usize; $obj:expr, $sel:expr) => {{
        let f: unsafe extern "C" fn(Id, Sel) -> usize =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel))
    }};
    (usize; $obj:expr, $sel:expr, ($($at:ty),+), $($av:expr),+) => {{
        let f: unsafe extern "C" fn(Id, Sel, $($at),+) -> usize =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel), $($av),+)
    }};

    // ── return NSInteger (isize) ─────────────────────────────────────────
    (isize; $obj:expr, $sel:expr) => {{
        let f: unsafe extern "C" fn(Id, Sel) -> isize =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel))
    }};

    // ── return bool ──────────────────────────────────────────────────────
    (bool; $obj:expr, $sel:expr) => {{
        let f: unsafe extern "C" fn(Id, Sel) -> bool =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel))
    }};

    // ── void return ──────────────────────────────────────────────────────
    (void; $obj:expr, $sel:expr) => {{
        let f: unsafe extern "C" fn(Id, Sel) =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel))
    }};
    (void; $obj:expr, $sel:expr, ($($at:ty),+), $($av:expr),+) => {{
        let f: unsafe extern "C" fn(Id, Sel, $($at),+) =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel), $($av),+)
    }};

    // ── return f64 ──────────────────────────────────────────────────────
    (f64; $obj:expr, $sel:expr) => {{
        let f: unsafe extern "C" fn(Id, Sel) -> f64 =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel))
    }};

    // ── return u32 (for MPSDataType) ────────────────────────────────────
    (u32; $obj:expr, $sel:expr) => {{
        let f: unsafe extern "C" fn(Id, Sel) -> u32 =
            core::mem::transmute($crate::ffi::objc_msgSend as *const ());
        f($obj, $crate::ffi::sel($sel))
    }};
}

pub(crate) use msg_send;

// ─── MPSDataType constants ──────────────────────────────────────────────────

/// `MPSDataType` values, built from a type-class bit OR-ed with the element
/// width in bits.
pub mod MPSDataType {
    /// Flag marking floating-point types.
    const FLOAT_BIT: u32 = 0x1000_0000;
    /// Flag marking signed integer types.
    const SIGNED_BIT: u32 = 0x2000_0000;
    /// Flag marking an alternate encoding (used for bfloat16 and bool).
    const ALTERNATE_BIT: u32 = 0x8000_0000;

    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    pub const FLOAT16: u32 = FLOAT_BIT | 16;
    /// Alternate-encoding bit on top of the float16 value.
    pub const BFLOAT16: u32 = ALTERNATE_BIT | FLOAT16;
    pub const INT64: u32 = SIGNED_BIT | 64;
    pub const INT32: u32 = SIGNED_BIT | 32;
    pub const INT16: u32 = SIGNED_BIT | 16;
    pub const INT8: u32 = SIGNED_BIT | 8;
    // Unsigned integers carry no flag: the value is just the bit width.
    pub const UINT64: u32 = 64;
    pub const UINT32: u32 = 32;
    pub const UINT16: u32 = 16;
    pub const UINT8: u32 = 8;
    pub const BOOL: u32 = ALTERNATE_BIT | 8;
}

// ─── MPSGraph enums ─────────────────────────────────────────────────────────

/// Padding-style value passed to the conv/pool descriptor constructors.
pub mod MPSGraphPaddingStyle {
    /// Explicit padding: the descriptor's left/right/top/bottom values are used.
    pub const EXPLICIT: usize = 0;
}

/// Named tensor layouts passed as `dataLayout:` / `weightsLayout:` / `layout:`
/// arguments by the descriptor and resize helpers in this module.
pub mod MPSGraphTensorNamedDataLayout {
    /// Data layout used by every helper in this file.
    pub const NCHW: usize = 0;
    pub const NHWC: usize = 1;
    /// Weights layout used by `conv2d_desc`.
    pub const OIHW: usize = 2;
    pub const HWIO: usize = 3;
}

/// Values of Apple's `MPSGraphScatterMode` enum, passed as the `mode:`
/// argument of `graph_scatter_along`.
pub mod MPSGraphScatterMode {
    /// `MPSGraphScatterModeAdd`. The enum is zero-based (Add = 0, Sub = 1,
    /// Mul = 2, ...), so the previous value of 1 selected subtraction.
    pub const ADD: isize = 0;
}

/// Return-indices mode set on pooling descriptors by
/// `pool_desc_set_return_indices`.
pub mod MPSGraphPoolingReturnIndicesMode {
    /// Global 2D-flattened indices mode (value 2).
    pub const GLOBAL_FLATTEN_2D: isize = 2;
}

/// Interpolation modes accepted as the `mode` argument of `graph_resize`
/// and `graph_resize_grad`.
pub mod MPSGraphResizeMode {
    /// Nearest-neighbour sampling.
    pub const NEAREST: usize = 0;
    /// Bilinear interpolation.
    pub const BILINEAR: usize = 1;
}

// ─── Foundation helpers ─────────────────────────────────────────────────────

/// Create NSNumber from isize.
pub unsafe fn ns_number_isize(v: isize) -> Id {
    msg_send!(Id; class("NSNumber"), "numberWithLong:", (isize), v)
}

/// Create NSNumber from usize.
pub unsafe fn ns_number_usize(v: usize) -> Id {
    msg_send!(Id; class("NSNumber"), "numberWithUnsignedLong:", (usize), v)
}

/// Reads an `NSNumber` back as `isize` by sending it `longValue`.
///
/// # Safety
/// `n` must be an object that responds to `longValue`.
pub unsafe fn ns_number_to_isize(n: Id) -> isize {
    const SELECTOR: &str = "longValue";
    msg_send!(isize; n, SELECTOR)
}

/// Reads an `NSNumber` back as `usize` by sending it `unsignedLongValue`.
///
/// # Safety
/// `n` must be an object that responds to `unsignedLongValue`.
pub unsafe fn ns_number_to_usize(n: Id) -> usize {
    const SELECTOR: &str = "unsignedLongValue";
    msg_send!(usize; n, SELECTOR)
}

/// Builds an `NSArray` from `objs` via `arrayWithObjects:count:`.
///
/// The slice's base pointer and length are passed straight through.
///
/// # Safety
/// Every element of `objs` must be a valid Objective-C object pointer.
pub unsafe fn ns_array(objs: &[Id]) -> Id {
    const SELECTOR: &str = "arrayWithObjects:count:";
    let first = objs.as_ptr() as Id;
    let count = objs.len();
    msg_send!(Id; class("NSArray"), SELECTOR, (Id, usize), first, count)
}

/// Returns the result of sending `count` to `arr`.
///
/// # Safety
/// `arr` must be an object that responds to `count`.
pub unsafe fn ns_array_count(arr: Id) -> usize {
    const SELECTOR: &str = "count";
    msg_send!(usize; arr, SELECTOR)
}

/// Returns the result of sending `objectAtIndex:` with `idx` to `arr`.
///
/// # Safety
/// `arr` must respond to `objectAtIndex:`; no bounds check is done here.
pub unsafe fn ns_array_get(arr: Id, idx: usize) -> Id {
    const SELECTOR: &str = "objectAtIndex:";
    msg_send!(Id; arr, SELECTOR, (usize), idx)
}

/// Create NSData from bytes (copies).
pub unsafe fn ns_data(bytes: *const u8, len: usize) -> Id {
    msg_send!(Id; class("NSData"), "dataWithBytes:length:",
              (Id, usize), bytes as Id, len)
}

/// Builds an `NSDictionary` from parallel slices via
/// `dictionaryWithObjects:forKeys:count:`.
///
/// Note the selector's order: values (`vals`) are passed first, keys second.
///
/// # Panics
/// Panics if `keys` and `vals` differ in length.
///
/// # Safety
/// Every element of both slices must be a valid Objective-C object pointer.
pub unsafe fn ns_dictionary(keys: &[Id], vals: &[Id]) -> Id {
    const SELECTOR: &str = "dictionaryWithObjects:forKeys:count:";
    assert_eq!(keys.len(), vals.len());
    let objects = vals.as_ptr() as Id;
    let key_ptr = keys.as_ptr() as Id;
    let entry_count = keys.len();
    msg_send!(Id; class("NSDictionary"), SELECTOR,
              (Id, Id, usize), objects, key_ptr, entry_count)
}

/// Returns the result of sending `objectForKey:` with `key` to `dict`.
///
/// # Safety
/// `dict` must respond to `objectForKey:` and `key` must be a valid object.
pub unsafe fn ns_dict_get(dict: Id, key: Id) -> Id {
    const SELECTOR: &str = "objectForKey:";
    msg_send!(Id; dict, SELECTOR, (Id), key)
}

/// Boxes each `isize` with `ns_number_isize` and wraps the results in an
/// `NSArray` through `ns_array`.
///
/// # Safety
/// Same requirements as `ns_number_isize` and `ns_array`.
pub unsafe fn ns_isize_array(vals: &[isize]) -> Id {
    let mut boxed: Vec<Id> = Vec::with_capacity(vals.len());
    for &value in vals {
        boxed.push(ns_number_isize(value));
    }
    ns_array(&boxed)
}

/// Boxes each `usize` with `ns_number_usize` and wraps the results in an
/// `NSArray` through `ns_array`.
///
/// # Safety
/// Same requirements as `ns_number_usize` and `ns_array`.
pub unsafe fn ns_usize_array(vals: &[usize]) -> Id {
    let mut boxed: Vec<Id> = Vec::with_capacity(vals.len());
    for &value in vals {
        boxed.push(ns_number_usize(value));
    }
    ns_array(&boxed)
}

// ─── Metal helpers ──────────────────────────────────────────────────────────

/// Returns the process-wide Metal device, creating it on first use.
///
/// The pointer is cached as a `usize` address inside a `OnceLock`, so
/// `MTLCreateSystemDefaultDevice` runs at most once.
///
/// # Panics
/// Panics if `MTLCreateSystemDefaultDevice` returns null.
pub fn metal_device() -> Id {
    static DEVICE_ADDR: OnceLock<usize> = OnceLock::new();
    let addr = *DEVICE_ADDR.get_or_init(|| {
        let device = unsafe { MTLCreateSystemDefaultDevice() };
        if device.is_null() {
            panic!("No Metal device found");
        }
        device as usize
    });
    addr as Id
}

/// Returns the process-wide Metal command queue, creating it on first use.
///
/// The queue is obtained by sending `newCommandQueue` to `metal_device()`
/// and cached as a `usize` address inside a `OnceLock`.
///
/// # Panics
/// Panics if the device returns a null queue (or if `metal_device` panics).
pub fn metal_queue() -> Id {
    static QUEUE_ADDR: OnceLock<usize> = OnceLock::new();
    let addr = *QUEUE_ADDR.get_or_init(|| {
        let queue = unsafe { msg_send!(Id; metal_device(), "newCommandQueue") };
        if queue.is_null() {
            panic!("Failed to create command queue");
        }
        queue as usize
    });
    addr as Id
}

/// Create an MTLBuffer with shared storage from bytes.
pub unsafe fn mtl_buffer_from_bytes(bytes: &[u8]) -> Id {
    // MTLResourceStorageModeShared = 0
    msg_send!(Id; metal_device(), "newBufferWithBytes:length:options:",
              (Id, usize, usize), bytes.as_ptr() as Id, bytes.len(), 0usize)
}

/// Creates an `MTLBuffer` of `len` bytes on the global device and fills it
/// with zeros.
///
/// The buffer is allocated with `newBufferWithLength:options:` (options 0,
/// i.e. shared storage) and then zeroed through its `contents` pointer.
///
/// If the allocation fails the returned pointer is null. In that case the
/// contents pointer is null as well, so the zero-fill is skipped instead of
/// writing through a null pointer.
///
/// # Safety
/// Sends raw Objective-C messages; callers must check the result for null
/// before using it as a buffer.
pub unsafe fn mtl_buffer_zeroed(len: usize) -> Id {
    let buf = msg_send!(Id; metal_device(), "newBufferWithLength:options:",
                        (usize, usize), len, 0usize);
    let contents = mtl_buffer_contents(buf);
    // `contents` is null when `buf` is nil (messages to nil return zero),
    // and `write_bytes` must never be handed a null destination.
    if !contents.is_null() {
        std::ptr::write_bytes(contents as *mut u8, 0, len);
    }
    buf
}

/// Sends `contents` to `buf` and returns the raw pointer it yields.
///
/// # Safety
/// `buf` must be nil or an object that responds to `contents`.
pub unsafe fn mtl_buffer_contents(buf: Id) -> *mut c_void {
    type ContentsFn = unsafe extern "C" fn(Id, Sel) -> *mut c_void;
    let send: ContentsFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("contents");
    send(buf, selector)
}

/// Returns the result of sending `length` to `buf`.
///
/// # Safety
/// `buf` must be nil or an object that responds to `length`.
pub unsafe fn mtl_buffer_length(buf: Id) -> usize {
    const SELECTOR: &str = "length";
    msg_send!(usize; buf, SELECTOR)
}

// ─── MPSGraph helpers ───────────────────────────────────────────────────────

/// Instantiates an `MPSGraph` by sending `new` to the class.
///
/// # Panics
/// Panics if the class is missing or the returned object is null.
///
/// # Safety
/// Sends a raw Objective-C message.
pub unsafe fn mpsgraph_new() -> Id {
    let graph = msg_send!(Id; class("MPSGraph"), "new");
    if graph.is_null() {
        panic!("Failed to create MPSGraph");
    }
    graph
}

/// MPSGraphDevice from MTLDevice.
pub unsafe fn mpsgraph_device() -> Id {
    msg_send!(Id; class("MPSGraphDevice"), "deviceWithMTLDevice:", (Id), metal_device())
}

/// Allocates an `MPSGraphTensorData` and initialises it with
/// `initWithMTLBuffer:shape:dataType:`.
///
/// # Safety
/// `buf` and `shape` must be valid objects of the types the initialiser
/// expects; `dtype` should be one of the `MPSDataType` values.
pub unsafe fn tensor_data_from_buffer(buf: Id, shape: Id, dtype: u32) -> Id {
    const SELECTOR: &str = "initWithMTLBuffer:shape:dataType:";
    let instance = objc_alloc(class("MPSGraphTensorData"));
    msg_send!(Id; instance, SELECTOR, (Id, Id, u32), buf, shape, dtype)
}

/// Allocates an `MPSGraphTensorData` and initialises it with
/// `initWithDevice:data:shape:dataType:`, using `mpsgraph_device()` as the
/// device.
///
/// # Safety
/// `nsdata` and `shape` must be valid objects of the types the initialiser
/// expects; `dtype` should be one of the `MPSDataType` values.
pub unsafe fn tensor_data_from_nsdata(nsdata: Id, shape: Id, dtype: u32) -> Id {
    const SELECTOR: &str = "initWithDevice:data:shape:dataType:";
    let instance = objc_alloc(class("MPSGraphTensorData"));
    msg_send!(Id; instance, SELECTOR,
              (Id, Id, Id, u32), mpsgraph_device(), nsdata, shape, dtype)
}

/// Returns the result of sending `shape` to `td`.
///
/// # Safety
/// `td` must be nil or an object that responds to `shape`.
pub unsafe fn tensor_data_shape(td: Id) -> Id {
    const SELECTOR: &str = "shape";
    msg_send!(Id; td, SELECTOR)
}

/// Returns the result of sending `dataType` to `td`, as a raw `u32`.
///
/// # Safety
/// `td` must be nil or an object that responds to `dataType`.
pub unsafe fn tensor_data_dtype(td: Id) -> u32 {
    const SELECTOR: &str = "dataType";
    msg_send!(u32; td, SELECTOR)
}

/// Copies tensor contents into `buf`.
///
/// Fetches the object returned by `mpsndarray` from `td`, then sends it
/// `readBytes:strideBytes:` with `buf`'s pointer and a null stride pointer.
///
/// # Safety
/// The length of `buf` is not communicated to the callee, so `buf` must be
/// at least as large as the data that will be written.
pub unsafe fn tensor_data_read_bytes(td: Id, buf: &mut [u8]) {
    type ReadFn = unsafe extern "C" fn(Id, Sel, *mut c_void, *mut isize);
    let ndarray = msg_send!(Id; td, "mpsndarray");
    let read: ReadFn = core::mem::transmute(objc_msgSend as *const ());
    let destination = buf.as_mut_ptr() as *mut c_void;
    let no_strides: *mut isize = ptr::null_mut();
    read(ndarray, sel("readBytes:strideBytes:"), destination, no_strides);
}

/// Adds a placeholder to `graph` via `placeholderWithShape:dataType:name:`
/// (name is nil) and returns the resulting tensor pointer.
///
/// # Safety
/// `graph` and `shape` must be valid objects of the types the selector expects.
pub unsafe fn graph_placeholder(graph: Id, shape: Id, dtype: u32) -> Id {
    const SELECTOR: &str = "placeholderWithShape:dataType:name:";
    msg_send!(Id; graph, SELECTOR, (Id, u32, Id), shape, dtype, NIL)
}

/// Adds a scalar constant to `graph` via `constantWithScalar:dataType:`.
///
/// # Safety
/// `graph` must be a valid object that responds to the selector.
pub unsafe fn graph_constant_scalar(graph: Id, value: f64, dtype: u32) -> Id {
    const SELECTOR: &str = "constantWithScalar:dataType:";
    msg_send!(Id; graph, SELECTOR, (f64, u32), value, dtype)
}

/// Adds a shaped constant filled with `value` to `graph` via
/// `constantWithScalar:shape:dataType:`.
///
/// # Safety
/// `graph` and `shape` must be valid objects of the types the selector expects.
pub unsafe fn graph_constant_scalar_shape(graph: Id, value: f64, shape: Id, dtype: u32) -> Id {
    const SELECTOR: &str = "constantWithScalar:shape:dataType:";
    msg_send!(Id; graph, SELECTOR, (f64, Id, u32), value, shape, dtype)
}

/// Executes `graph` on the global command queue via
/// `runWithMTLCommandQueue:feeds:targetTensors:targetOperations:` with nil
/// target operations, and returns the object the call yields.
///
/// # Safety
/// `graph`, `feeds` and `targets` must be valid objects of the types noted
/// on the parameters.
pub unsafe fn graph_run(
    graph: Id,
    feeds: Id,       // NSDictionary<MPSGraphTensor, MPSGraphTensorData>
    targets: Id,      // NSArray<MPSGraphTensor>
) -> Id {
    const SELECTOR: &str = "runWithMTLCommandQueue:feeds:targetTensors:targetOperations:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id, Id), metal_queue(), feeds, targets, NIL)
}

// ─── MPSGraph op wrappers ───────────────────────────────────────────────────
// These wrap the ObjC selectors into typed Rust functions.
// Pattern: graph_<op>(graph, tensor(s)...) -> MPSGraphTensor*

/// Generic unary op: graph method taking (tensor, name) -> tensor.
pub unsafe fn graph_unary(graph: Id, sel_name: &str, t: Id) -> Id {
    msg_send!(Id; graph, sel_name, (Id, Id), t, NIL)
}

/// Generic binary op: graph method taking (tensor, tensor, name) -> tensor.
pub unsafe fn graph_binary(graph: Id, sel_name: &str, a: Id, b: Id) -> Id {
    msg_send!(Id; graph, sel_name, (Id, Id, Id), a, b, NIL)
}

/// Adds a cast of `t` to `dtype` via `castTensor:toType:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_cast(graph: Id, t: Id, dtype: u32) -> Id {
    const SELECTOR: &str = "castTensor:toType:name:";
    msg_send!(Id; graph, SELECTOR, (Id, u32, Id), t, dtype, NIL)
}

/// Adds a reshape of `t` to `shape` via `reshapeTensor:withShape:name:` (nil name).
///
/// # Safety
/// `graph`, `t` and `shape` must be valid objects of the types the selector expects.
pub unsafe fn graph_reshape(graph: Id, t: Id, shape: Id) -> Id {
    const SELECTOR: &str = "reshapeTensor:withShape:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), t, shape, NIL)
}

/// Swaps dimensions `dim1` and `dim2` of `t` via
/// `transposeTensor:dimension:withDimension:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_transpose(graph: Id, t: Id, dim1: usize, dim2: usize) -> Id {
    const SELECTOR: &str = "transposeTensor:dimension:withDimension:name:";
    msg_send!(Id; graph, SELECTOR, (Id, usize, usize, Id), t, dim1, dim2, NIL)
}

/// Permutes the dimensions of `t` via `transposeTensor:permutation:name:`
/// (nil name).
///
/// # Safety
/// `graph`, `t` and `perm` must be valid objects of the types the selector expects.
pub unsafe fn graph_permute(graph: Id, t: Id, perm: Id) -> Id {
    const SELECTOR: &str = "transposeTensor:permutation:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), t, perm, NIL)
}

/// Broadcasts `t` to `shape` via `broadcastTensor:toShape:name:` (nil name).
///
/// # Safety
/// `graph`, `t` and `shape` must be valid objects of the types the selector expects.
pub unsafe fn graph_broadcast(graph: Id, t: Id, shape: Id) -> Id {
    const SELECTOR: &str = "broadcastTensor:toShape:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), t, shape, NIL)
}

/// Slices `t` via `sliceTensor:starts:ends:strides:name:` (nil name).
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_slice(graph: Id, t: Id, starts: Id, ends: Id, strides: Id) -> Id {
    const SELECTOR: &str = "sliceTensor:starts:ends:strides:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id, Id, Id), t, starts, ends, strides, NIL)
}

/// Slices `t` via
/// `sliceTensor:starts:ends:strides:startMask:endMask:squeezeMask:name:`
/// (nil name), forwarding the three `u32` masks unchanged.
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_slice_masked(
    graph: Id, t: Id, starts: Id, ends: Id, strides: Id,
    start_mask: u32, end_mask: u32, squeeze_mask: u32,
) -> Id {
    type SliceFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, u32, u32, u32, Id) -> Id;
    let send: SliceFn = core::mem::transmute(objc_msgSend as *const ());
    let selector =
        sel("sliceTensor:starts:ends:strides:startMask:endMask:squeezeMask:name:");
    send(graph, selector, t, starts, ends, strides, start_mask, end_mask, squeeze_mask, NIL)
}

/// Writes `update` into a slice of `data` via
/// `sliceUpdateDataTensor:updateTensor:starts:ends:strides:name:` (nil name).
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_slice_update(
    graph: Id, data: Id, update: Id, starts: Id, ends: Id, strides: Id,
) -> Id {
    type UpdateFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, Id, Id) -> Id;
    let send: UpdateFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("sliceUpdateDataTensor:updateTensor:starts:ends:strides:name:");
    send(graph, selector, data, update, starts, ends, strides, NIL)
}

/// Concatenates `tensors` along `dim` via `concatTensors:dimension:name:`
/// (nil name).
///
/// # Safety
/// `graph` and `tensors` must be valid objects of the types the selector expects.
pub unsafe fn graph_concat(graph: Id, tensors: Id, dim: isize) -> Id {
    const SELECTOR: &str = "concatTensors:dimension:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), tensors, dim, NIL)
}

/// Sum-reduces `t` over `axes` via `reductionSumWithTensor:axes:name:` (nil name).
///
/// # Safety
/// `graph`, `t` and `axes` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_sum(graph: Id, t: Id, axes: Id) -> Id {
    const SELECTOR: &str = "reductionSumWithTensor:axes:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), t, axes, NIL)
}

/// Sum-reduces `t` along one `axis` via `reductionSumWithTensor:axis:name:`
/// (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_sum_axis(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "reductionSumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Max-reduces `t` along one `axis` via
/// `reductionMaximumWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_max_axis(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "reductionMaximumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Min-reduces `t` along one `axis` via
/// `reductionMinimumWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_min_axis(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "reductionMinimumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Max-reduces `t` over `axes` via `reductionMaximumWithTensor:axes:name:`
/// (nil name).
///
/// # Safety
/// `graph`, `t` and `axes` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_max(graph: Id, t: Id, axes: Id) -> Id {
    const SELECTOR: &str = "reductionMaximumWithTensor:axes:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), t, axes, NIL)
}

/// Min-reduces `t` over `axes` via `reductionMinimumWithTensor:axes:name:`
/// (nil name).
///
/// # Safety
/// `graph`, `t` and `axes` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_min(graph: Id, t: Id, axes: Id) -> Id {
    const SELECTOR: &str = "reductionMinimumWithTensor:axes:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), t, axes, NIL)
}

/// Product-reduces `t` over `axes` via
/// `reductionProductWithTensor:axes:name:` (nil name).
///
/// # Safety
/// `graph`, `t` and `axes` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_prod(graph: Id, t: Id, axes: Id) -> Id {
    const SELECTOR: &str = "reductionProductWithTensor:axes:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), t, axes, NIL)
}

/// Product-reduces `t` along one `axis` via
/// `reductionProductWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_reduction_prod_axis(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "reductionProductWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Adds an arg-max along `axis` via
/// `reductionArgMaximumWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_argmax(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "reductionArgMaximumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Adds an arg-min along `axis` via
/// `reductionArgMinimumWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_argmin(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "reductionArgMinimumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Adds a cumulative sum along `axis` via
/// `cumulativeSumWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_cumsum(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "cumulativeSumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Adds a cumulative product along `axis` via
/// `cumulativeProductWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_cumprod(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "cumulativeProductWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Adds a cumulative minimum along `axis` via
/// `cumulativeMinimumWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_cummin(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "cumulativeMinimumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Adds a cumulative maximum along `axis` via
/// `cumulativeMaximumWithTensor:axis:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_cummax(graph: Id, t: Id, axis: isize) -> Id {
    const SELECTOR: &str = "cumulativeMaximumWithTensor:axis:name:";
    msg_send!(Id; graph, SELECTOR, (Id, isize, Id), t, axis, NIL)
}

/// Adds a gather via
/// `gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:` (nil name).
///
/// # Safety
/// `graph`, `updates` and `indices` must be valid objects of the types the
/// selector expects.
pub unsafe fn graph_gather(graph: Id, updates: Id, indices: Id, axis: usize, batch_dims: usize) -> Id {
    type GatherFn = unsafe extern "C" fn(Id, Sel, Id, Id, usize, usize, Id) -> Id;
    let send: GatherFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:");
    send(graph, selector, updates, indices, axis, batch_dims, NIL)
}

/// Adds a scatter along `axis` via
/// `scatterAlongAxis:withDataTensor:updatesTensor:indicesTensor:mode:name:`
/// (nil name). `mode` is forwarded as a raw integer.
///
/// # Safety
/// `graph`, `data`, `updates` and `indices` must be valid objects of the
/// types the selector expects.
pub unsafe fn graph_scatter_along(
    graph: Id, axis: isize, data: Id, updates: Id, indices: Id, mode: isize,
) -> Id {
    type ScatterFn = unsafe extern "C" fn(Id, Sel, isize, Id, Id, Id, isize, Id) -> Id;
    let send: ScatterFn = core::mem::transmute(objc_msgSend as *const ());
    let selector =
        sel("scatterAlongAxis:withDataTensor:updatesTensor:indicesTensor:mode:name:");
    send(graph, selector, axis, data, updates, indices, mode, NIL)
}

/// Adds an element-wise select (`pred ? true_t : false_t`) via
/// `selectWithPredicateTensor:truePredicateTensor:falsePredicateTensor:name:`
/// (nil name).
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_select(graph: Id, pred: Id, true_t: Id, false_t: Id) -> Id {
    type SelectFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id) -> Id;
    let send: SelectFn = core::mem::transmute(objc_msgSend as *const ());
    let selector =
        sel("selectWithPredicateTensor:truePredicateTensor:falsePredicateTensor:name:");
    send(graph, selector, pred, true_t, false_t, NIL)
}

/// Adds a sort of `t` along `axis` via `sortWithTensor:axis:descending:name:`
/// (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_sort(graph: Id, t: Id, axis: isize, descending: bool) -> Id {
    type SortFn = unsafe extern "C" fn(Id, Sel, Id, isize, bool, Id) -> Id;
    let send: SortFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("sortWithTensor:axis:descending:name:");
    send(graph, selector, t, axis, descending, NIL)
}

/// Adds an arg-sort of `t` along `axis` via
/// `argSortWithTensor:axis:descending:name:` (nil name).
///
/// # Safety
/// `graph` and `t` must be valid objects of the types the selector expects.
pub unsafe fn graph_argsort(graph: Id, t: Id, axis: isize, descending: bool) -> Id {
    type ArgSortFn = unsafe extern "C" fn(Id, Sel, Id, isize, bool, Id) -> Id;
    let send: ArgSortFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("argSortWithTensor:axis:descending:name:");
    send(graph, selector, t, axis, descending, NIL)
}

/// Adds a matrix multiplication of `a` by `b`, delegating to `graph_binary`
/// with `matrixMultiplicationWithPrimaryTensor:secondaryTensor:name:`.
///
/// # Safety
/// Same requirements as `graph_binary`.
pub unsafe fn graph_matmul(graph: Id, a: Id, b: Id) -> Id {
    const SELECTOR: &str = "matrixMultiplicationWithPrimaryTensor:secondaryTensor:name:";
    graph_binary(graph, SELECTOR, a, b)
}

/// Builds an `MPSGraphConvolution2DOpDescriptor`.
///
/// Strides (`sx`, `sy`), dilations (`dx`, `dy`), `groups` and the four
/// paddings (`pl`, `pr`, `pt`, `pb`) are forwarded as given; the padding
/// style is fixed to `EXPLICIT`, the data layout to `NCHW` and the weights
/// layout to `OIHW`.
///
/// # Panics
/// Panics if the class method returns a null descriptor.
///
/// # Safety
/// Sends a raw Objective-C message with twelve integer arguments.
pub unsafe fn conv2d_desc(
    sx: usize, sy: usize, dx: usize, dy: usize, groups: usize,
    pl: usize, pr: usize, pt: usize, pb: usize,
) -> Id {
    type DescFn = unsafe extern "C" fn(
        Id, Sel,
        usize, usize, usize, usize, usize, usize,
        usize, usize, usize, usize, usize, usize,
    ) -> Id;
    let send: DescFn = core::mem::transmute(objc_msgSend as *const ());
    let descriptor_class = class("MPSGraphConvolution2DOpDescriptor");
    let selector = sel("descriptorWithStrideInX:strideInY:dilationRateInX:dilationRateInY:groups:paddingLeft:paddingRight:paddingTop:paddingBottom:paddingStyle:dataLayout:weightsLayout:");
    let descriptor = send(
        descriptor_class,
        selector,
        sx, sy, dx, dy, groups, pl, pr, pt, pb,
        MPSGraphPaddingStyle::EXPLICIT,
        MPSGraphTensorNamedDataLayout::NCHW,
        MPSGraphTensorNamedDataLayout::OIHW,
    );
    assert!(!descriptor.is_null(), "Failed to create conv2d descriptor (sx={sx},sy={sy},dx={dx},dy={dy},g={groups})");
    descriptor
}

/// Adds a 2D convolution via
/// `convolution2DWithSourceTensor:weightsTensor:descriptor:name:` (nil name).
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_conv2d(graph: Id, src: Id, weights: Id, desc: Id) -> Id {
    type ConvFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id) -> Id;
    let send: ConvFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("convolution2DWithSourceTensor:weightsTensor:descriptor:name:");
    send(graph, selector, src, weights, desc, NIL)
}

/// Adds a transposed 2D convolution via
/// `convolutionTranspose2DWithSourceTensor:weightsTensor:outputShape:descriptor:name:`
/// (nil name).
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_conv_transpose2d(graph: Id, src: Id, weights: Id, out_shape: Id, desc: Id) -> Id {
    type ConvTransposeFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, Id) -> Id;
    let send: ConvTransposeFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel(
        "convolutionTranspose2DWithSourceTensor:weightsTensor:outputShape:descriptor:name:",
    );
    send(graph, selector, src, weights, out_shape, desc, NIL)
}

/// Builds an `MPSGraphPooling2DOpDescriptor`.
///
/// Kernel size (`kw`, `kh`), strides (`sx`, `sy`), dilations (`dx`, `dy`)
/// and the four paddings are forwarded as given; the padding style is fixed
/// to `EXPLICIT` and the data layout to `NCHW`.
///
/// Unlike `conv2d_desc`, the result is returned without a null check.
///
/// # Safety
/// Sends a raw Objective-C message with twelve integer arguments.
pub unsafe fn pool2d_desc(
    kw: usize, kh: usize, sx: usize, sy: usize,
    dx: usize, dy: usize,
    pl: usize, pr: usize, pt: usize, pb: usize,
) -> Id {
    type DescFn = unsafe extern "C" fn(
        Id, Sel,
        usize, usize, usize, usize, usize, usize,
        usize, usize, usize, usize, usize, usize,
    ) -> Id;
    let send: DescFn = core::mem::transmute(objc_msgSend as *const ());
    let descriptor_class = class("MPSGraphPooling2DOpDescriptor");
    let selector = sel("descriptorWithKernelWidth:kernelHeight:strideInX:strideInY:dilationRateInX:dilationRateInY:paddingLeft:paddingRight:paddingTop:paddingBottom:paddingStyle:dataLayout:");
    send(
        descriptor_class,
        selector,
        kw, kh, sx, sy, dx, dy, pl, pr, pt, pb,
        MPSGraphPaddingStyle::EXPLICIT,
        MPSGraphTensorNamedDataLayout::NCHW,
    )
}

/// Sends `setIncludeZeroPadToAverage:` with `val` to the pooling descriptor.
///
/// # Safety
/// `desc` must be nil or an object that responds to the setter.
pub unsafe fn pool_desc_set_include_zero_pad(desc: Id, val: bool) {
    const SELECTOR: &str = "setIncludeZeroPadToAverage:";
    msg_send!(void; desc, SELECTOR, (bool), val);
}

/// Set return indices mode & dtype on pool descriptor.
pub unsafe fn pool_desc_set_return_indices(desc: Id) {
    msg_send!(void; desc, "setReturnIndicesMode:", (isize),
              MPSGraphPoolingReturnIndicesMode::GLOBAL_FLATTEN_2D);
    msg_send!(void; desc, "setReturnIndicesDataType:", (u32), MPSDataType::INT32);
}

/// Adds a 2D average pooling via
/// `avgPooling2DWithSourceTensor:descriptor:name:` (nil name).
///
/// # Safety
/// `graph`, `src` and `desc` must be valid objects of the types the selector expects.
pub unsafe fn graph_avg_pool2d(graph: Id, src: Id, desc: Id) -> Id {
    const SELECTOR: &str = "avgPooling2DWithSourceTensor:descriptor:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), src, desc, NIL)
}

/// Adds the gradient of 2D average pooling via
/// `avgPooling2DGradientWithGradientTensor:sourceTensor:descriptor:name:`
/// (nil name).
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_avg_pool2d_grad(graph: Id, grad: Id, src: Id, desc: Id) -> Id {
    type GradFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id) -> Id;
    let send: GradFn = core::mem::transmute(objc_msgSend as *const ());
    let selector =
        sel("avgPooling2DGradientWithGradientTensor:sourceTensor:descriptor:name:");
    send(graph, selector, grad, src, desc, NIL)
}

/// Adds a 2D max pooling via `maxPooling2DWithSourceTensor:descriptor:name:`
/// (nil name).
///
/// # Safety
/// `graph`, `src` and `desc` must be valid objects of the types the selector expects.
pub unsafe fn graph_max_pool2d(graph: Id, src: Id, desc: Id) -> Id {
    const SELECTOR: &str = "maxPooling2DWithSourceTensor:descriptor:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), src, desc, NIL)
}

/// Adds a 2D max pooling that also yields indices, via
/// `maxPooling2DReturnIndicesWithSourceTensor:descriptor:name:` (nil name).
///
/// Per the original note, the result is an `NSArray` of two tensors:
/// `[values, indices]`.
///
/// # Safety
/// `graph`, `src` and `desc` must be valid objects of the types the selector expects.
pub unsafe fn graph_max_pool2d_return_indices(graph: Id, src: Id, desc: Id) -> Id {
    const SELECTOR: &str = "maxPooling2DReturnIndicesWithSourceTensor:descriptor:name:";
    msg_send!(Id; graph, SELECTOR, (Id, Id, Id), src, desc, NIL)
}

/// Adds the gradient of an indices-returning 2D max pooling (nil name).
///
/// NOTE(review): confirm that MPSGraph actually declares this selector. The
/// max-pool gradient selectors I know of are spelled
/// `maxPooling2DGradientWithGradientTensor:indicesTensor:outputShape:...`
/// and take an output shape rather than a source tensor — TODO verify
/// against the SDK headers.
///
/// # Safety
/// All pointer arguments must be valid objects of the types the selector expects.
pub unsafe fn graph_max_pool2d_indices_grad(graph: Id, grad: Id, indices: Id, src: Id, desc: Id) -> Id {
    type GradFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, Id) -> Id;
    let send: GradFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("maxPooling2DReturnIndicesGradientWithGradientTensor:indicesTensor:sourceTensor:descriptor:name:");
    send(graph, selector, grad, indices, src, desc, NIL)
}

/// Adds a resize of `t` to `size` via
/// `resizeTensor:size:mode:centerResult:alignCorners:layout:name:` (nil name).
///
/// `mode` is forwarded as a raw integer and the layout is fixed to `NCHW`.
///
/// # Safety
/// `graph`, `t` and `size` must be valid objects of the types the selector expects.
pub unsafe fn graph_resize(
    graph: Id, t: Id, size: Id, mode: usize,
    center: bool, align_corners: bool,
) -> Id {
    type ResizeFn = unsafe extern "C" fn(Id, Sel, Id, Id, usize, bool, bool, usize, Id) -> Id;
    let send: ResizeFn = core::mem::transmute(objc_msgSend as *const ());
    let selector = sel("resizeTensor:size:mode:centerResult:alignCorners:layout:name:");
    let layout = MPSGraphTensorNamedDataLayout::NCHW;
    send(graph, selector, t, size, mode, center, align_corners, layout, NIL)
}

/// Adds the gradient of a resize via
/// `resizeWithGradientTensor:input:mode:centerResult:alignCorners:layout:name:`
/// (nil name).
///
/// `mode` is forwarded as a raw integer and the layout is fixed to `NCHW`.
///
/// # Safety
/// `graph`, `grad` and `input` must be valid objects of the types the
/// selector expects.
pub unsafe fn graph_resize_grad(
    graph: Id, grad: Id, input: Id, mode: usize,
    center: bool, align_corners: bool,
) -> Id {
    type ResizeGradFn =
        unsafe extern "C" fn(Id, Sel, Id, Id, usize, bool, bool, usize, Id) -> Id;
    let send: ResizeGradFn = core::mem::transmute(objc_msgSend as *const ());
    let selector =
        sel("resizeWithGradientTensor:input:mode:centerResult:alignCorners:layout:name:");
    let layout = MPSGraphTensorNamedDataLayout::NCHW;
    send(graph, selector, grad, input, mode, center, align_corners, layout, NIL)
}