Skip to main content

burn_mpsgraph/ops/
module.rs

1use burn_backend::ops::{
2    AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConvOptions,
3    DeformConv2dBackward, FloatTensorOps, InterpolateMode, InterpolateOptions,
4    MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
5    conv::{calculate_conv_output_size, calculate_conv_transpose_output_size},
6};
7use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
8use burn_backend::DType;
9use burn_std::{Shape, Slice};
10
11use crate::bridge::{self};
12use crate::ffi::{self};
13use crate::{MpsGraph, MpsGraphDevice};
14
/// Local shorthand for the backend type, keeping the fully-qualified trait calls below compact.
type F = MpsGraph; // shorthand
16
17/// Shorthand for calling float ops through the trait.
18fn reshape(t: FloatTensor<F>, s: Shape) -> FloatTensor<F> {
19    <F as FloatTensorOps<F>>::float_reshape(t, s)
20}
21fn zeros(shape: Shape, dev: &MpsGraphDevice, dtype: burn_std::FloatDType) -> FloatTensor<F> {
22    <F as FloatTensorOps<F>>::float_zeros(shape, dev, dtype)
23}
24fn add(a: FloatTensor<F>, b: FloatTensor<F>) -> FloatTensor<F> {
25    <F as FloatTensorOps<F>>::float_add(a, b)
26}
27fn slice_t(t: FloatTensor<F>, s: &[Slice]) -> FloatTensor<F> {
28    <F as FloatTensorOps<F>>::float_slice(t, s)
29}
30fn slice_assign(t: FloatTensor<F>, s: &[Slice], v: FloatTensor<F>) -> FloatTensor<F> {
31    <F as FloatTensorOps<F>>::float_slice_assign(t, s, v)
32}
33fn scatter_add(dim: usize, t: FloatTensor<F>, i: IntTensor<F>, v: FloatTensor<F>) -> FloatTensor<F> {
34    <F as FloatTensorOps<F>>::float_scatter_add(dim, t, i, v)
35}
36
impl ModuleOps<MpsGraph> for MpsGraph {
    /// Embedding lookup: gathers rows of `w` along axis 0 at the positions in `idx`.
    fn embedding(w: FloatTensor<F>, idx: IntTensor<F>) -> FloatTensor<F> {
        bridge::run_binary(&w,&idx, |g,pw,pi| unsafe { ffi::graph_gather(g,pw,pi,0,0) })
    }

    /// Gradient of `embedding` w.r.t. the weight table: scatter-adds `grad`
    /// rows into a zero tensor shaped like `w` (duplicate indices accumulate).
    fn embedding_backward(w: FloatTensor<F>, grad: FloatTensor<F>, idx: IntTensor<F>) -> FloatTensor<F> {
        scatter_add(0, zeros(w.shape.clone(), &w.device, w.dtype.into()), idx, grad)
    }

    // ── Conv1d via Conv2d ───────────────────────────────────────────────

    /// 1-D convolution implemented by inserting a unit height dimension,
    /// delegating to `conv2d` (stride 1 / padding 0 / dilation 1 on the unit
    /// axis), then squeezing that axis back out.
    fn conv1d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<1>) -> FloatTensor<F> {
        let x4 = reshape(x.clone(), Shape::new([x.shape[0],x.shape[1],1,x.shape[2]]));
        let w4 = reshape(w.clone(), Shape::new([w.shape[0],w.shape[1],1,w.shape[2]]));
        let r = Self::conv2d(x4, w4, b, ConvOptions::new([1,o.stride[0]],[0,o.padding[0]],[1,o.dilation[0]],o.groups));
        reshape(r.clone(), Shape::new([r.shape[0],r.shape[1],r.shape[3]]))
    }

    // ── Conv2d (native MPSGraph) ────────────────────────────────────────

    /// 2-D convolution on the MPSGraph backend. When a bias is supplied, all
    /// three tensors go through a single graph so the broadcast add is part of
    /// the same graph as the convolution.
    fn conv2d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<2>) -> FloatTensor<F> {
        if let Some(ref bt) = b {
            bridge::run_multi_ctx(&[&x,&w,bt], x.device, |g, phs| unsafe {
                // Width-axis ([1]) values are passed before height-axis ([0]) ones —
                // presumably conv2d_desc takes x-before-y; TODO confirm against ffi::conv2d_desc.
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                let conv = ffi::graph_conv2d(g, phs[0], phs[1], desc);
                // Reshape bias [C] -> [1,C,1,1] so the addition broadcasts over N/H/W.
                let bs = bridge::shape_to_ns(&Shape::new([1,bt.shape[0],1,1]));
                let br = ffi::graph_reshape(g, phs[2], bs);
                ffi::graph_binary(g, "additionWithPrimaryTensor:secondaryTensor:name:", conv, br)
            })
        } else {
            bridge::run_binary_ctx(&x, &w, |g, px, pw| unsafe {
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                ffi::graph_conv2d(g, px, pw, desc)
            })
        }
    }

    // ── Conv3d via loop over depth + Conv2d ─────────────────────────────
    // x: [N, C_in, D, H, W], weight: [C_out, C_in/g, kD, kH, kW]

    /// 3-D convolution decomposed into 2-D convolutions: for each output depth
    /// slice, sum `conv2d` results over the kernel's depth taps, then write the
    /// slice into the output. Bias is broadcast-added once at the end.
    fn conv3d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<3>) -> FloatTensor<F> {
        let (batch, c_in, d_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]);
        let (c_out, _, kd, kh, kw) = (w.shape[0], w.shape[1], w.shape[2], w.shape[3], w.shape[4]);
        let d_out = calculate_conv_output_size(kd, o.stride[0], o.padding[0], o.dilation[0], d_in);
        let h_out = calculate_conv_output_size(kh, o.stride[1], o.padding[1], o.dilation[1], h_in);
        let w_out = calculate_conv_output_size(kw, o.stride[2], o.padding[2], o.dilation[2], w_in);

        let dev = x.device;
        let dtype_f: burn_std::FloatDType = x.dtype.into();
        let mut output = zeros(Shape::new([batch, c_out, d_out, h_out, w_out]), &dev, dtype_f);

        // 2-D options reuse the H/W components ([1], [2]) of the 3-D options.
        let o2 = ConvOptions::new([o.stride[1], o.stride[2]], [o.padding[1], o.padding[2]], [o.dilation[1], o.dilation[2]], o.groups);

        for od in 0..d_out {
            // For each output depth position, accumulate over kernel depth
            let mut accum = zeros(Shape::new([batch, c_out, h_out, w_out]), &dev, dtype_f);
            for kd_i in 0..kd {
                let id = od * o.stride[0] + kd_i * o.dilation[0];
                // Skip taps landing in the depth padding; the first test also
                // guards the usize subtraction in the second (short-circuit ||).
                if id < o.padding[0] || id - o.padding[0] >= d_in { continue; }
                let id_actual = id - o.padding[0];

                // Slice x at depth id_actual: [N, C_in, H, W]
                let x_slice = slice_t(x.clone(), &[
                    Slice::new(0, Some(batch as isize), 1),
                    Slice::new(0, Some(c_in as isize), 1),
                    Slice::new(id_actual as isize, Some(id_actual as isize + 1), 1),
                    Slice::new(0, Some(h_in as isize), 1),
                    Slice::new(0, Some(w_in as isize), 1),
                ]);
                let x_2d = reshape(x_slice, Shape::new([batch, c_in, h_in, w_in]));

                // Slice weight at kernel depth kd_i: [C_out, C_in/g, kH, kW]
                let w_slice = slice_t(w.clone(), &[
                    Slice::new(0, Some(c_out as isize), 1),
                    Slice::new(0, Some(w.shape[1] as isize), 1),
                    Slice::new(kd_i as isize, Some(kd_i as isize + 1), 1),
                    Slice::new(0, Some(kh as isize), 1),
                    Slice::new(0, Some(kw as isize), 1),
                ]);
                let w_2d = reshape(w_slice, Shape::new([c_out, w.shape[1], kh, kw]));

                // Conv2d (no bias — we add bias at the end)
                let conv_result = Self::conv2d(x_2d, w_2d, None, o2.clone());
                accum = add(accum, conv_result);
            }
            // Assign into output[:, :, od, :, :]
            let accum_5d = reshape(accum, Shape::new([batch, c_out, 1, h_out, w_out]));
            output = slice_assign(output, &[
                Slice::new(0, Some(batch as isize), 1),
                Slice::new(0, Some(c_out as isize), 1),
                Slice::new(od as isize, Some(od as isize + 1), 1),
                Slice::new(0, Some(h_out as isize), 1),
                Slice::new(0, Some(w_out as isize), 1),
            ], accum_5d);
        }

        // Add bias
        if let Some(bias) = b {
            let bias_5d = reshape(bias, Shape::new([1, c_out, 1, 1, 1]));
            let bias_expanded = <F as FloatTensorOps<F>>::float_expand(bias_5d, output.shape.clone());
            output = add(output, bias_expanded);
        }

        output
    }

    // ── Deformable Conv2d (CPU fallback using bilinear interpolation) ───

    /// Deformable 2-D convolution computed entirely on the CPU: inputs are
    /// copied to host memory, each kernel tap is sampled with bilinear
    /// interpolation at its offset-shifted position, and the dense result is
    /// uploaded back as an F32 tensor.
    /// NOTE(review): the byte buffers are reinterpreted as `f32` — this assumes
    /// all inputs are contiguous F32; confirm `bridge::tensor_to_bytes`
    /// guarantees that.
    fn deform_conv2d(
        x: FloatTensor<F>, offset: FloatTensor<F>, weight: FloatTensor<F>,
        mask: Option<FloatTensor<F>>, bias: Option<FloatTensor<F>>,
        o: DeformConvOptions<2>,
    ) -> FloatTensor<F> {
        // Read all inputs to CPU
        let x_bytes = bridge::tensor_to_bytes(&x);
        let offset_bytes = bridge::tensor_to_bytes(&offset);
        let weight_bytes = bridge::tensor_to_bytes(&weight);
        let mask_bytes = mask.as_ref().map(|m| bridge::tensor_to_bytes(m));

        let x_f: &[f32] = unsafe { std::slice::from_raw_parts(x_bytes.as_ptr() as *const f32, x_bytes.len()/4) };
        let off_f: &[f32] = unsafe { std::slice::from_raw_parts(offset_bytes.as_ptr() as *const f32, offset_bytes.len()/4) };
        let w_f: &[f32] = unsafe { std::slice::from_raw_parts(weight_bytes.as_ptr() as *const f32, weight_bytes.len()/4) };

        let (batch, c_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3]);
        let (c_out, c_in_per_g, kh, kw) = (weight.shape[0], weight.shape[1], weight.shape[2], weight.shape[3]);
        let h_out = calculate_conv_output_size(kh, o.stride[0], o.padding[0], o.dilation[0], h_in);
        let w_out = calculate_conv_output_size(kw, o.stride[1], o.padding[1], o.dilation[1], w_in);
        let groups = o.weight_groups;
        let offset_groups = o.offset_groups;

        let mut output = vec![0.0f32; batch * c_out * h_out * w_out];

        for n in 0..batch {
            for g in 0..groups {
                // Channel ranges handled by weight group `g`.
                let c_out_start = g * (c_out / groups);
                let c_out_end = c_out_start + c_out / groups;
                let c_in_start = g * (c_in / groups);

                for oc in c_out_start..c_out_end {
                    for oh in 0..h_out {
                        for ow in 0..w_out {
                            let mut val = 0.0f32;
                            for ic in 0..(c_in / groups) {
                                let abs_ic = c_in_start + ic;
                                // Offset group this input channel belongs to.
                                let og = abs_ic / (c_in / offset_groups);
                                for ky in 0..kh {
                                    for kx in 0..kw {
                                        // Offset layout: [N, offset_groups*2*kh*kw, H_out, W_out];
                                        // dy sits in channel 2*(ky*kw+kx) of the group, dx in the
                                        // channel right after it (the + h_out*w_out step below).
                                        let off_idx = ((n * offset_groups + og) * 2 * kh * kw + (ky * kw + kx) * 2) * h_out * w_out + oh * w_out + ow;
                                        let dy = off_f[off_idx];
                                        let dx = off_f[off_idx + h_out * w_out];

                                        // Fractional sampling position = regular conv tap + learned offset.
                                        let y = oh as f32 * o.stride[0] as f32 + ky as f32 * o.dilation[0] as f32 - o.padding[0] as f32 + dy;
                                        let xx = ow as f32 * o.stride[1] as f32 + kx as f32 * o.dilation[1] as f32 - o.padding[1] as f32 + dx;

                                        let sample = bilinear_sample(x_f, n, abs_ic, h_in, w_in, y, xx, c_in);

                                        // NOTE(review): this re-derives the mask slice on every tap;
                                        // hoisting it out of the loops would avoid the repeated casts.
                                        let m = if let Some(ref mb) = mask_bytes {
                                            let mf: &[f32] = unsafe { std::slice::from_raw_parts(mb.as_ptr() as *const f32, mb.len()/4) };
                                            let midx = ((n * offset_groups + og) * kh * kw + ky * kw + kx) * h_out * w_out + oh * w_out + ow;
                                            mf[midx]
                                        } else { 1.0 };

                                        // Weight layout: [C_out, C_in/g, kH, kW].
                                        let w_idx = ((oc * c_in_per_g + ic) * kh + ky) * kw + kx;
                                        val += sample * w_f[w_idx] * m;
                                    }
                                }
                            }
                            output[((n * c_out + oc) * h_out + oh) * w_out + ow] = val;
                        }
                    }
                }
            }
        }

        // Add bias
        if let Some(ref bias_t) = bias {
            let bias_bytes = bridge::tensor_to_bytes(bias_t);
            let bias_f: &[f32] = unsafe { std::slice::from_raw_parts(bias_bytes.as_ptr() as *const f32, bias_bytes.len()/4) };
            for n in 0..batch {
                for oc in 0..c_out {
                    for oh in 0..h_out {
                        for ow in 0..w_out {
                            output[((n*c_out+oc)*h_out+oh)*w_out+ow] += bias_f[oc];
                        }
                    }
                }
            }
        }

        let bytes = unsafe { std::slice::from_raw_parts(output.as_ptr() as *const u8, output.len() * 4) };
        bridge::tensor_from_bytes(bytes, Shape::new([batch, c_out, h_out, w_out]), DType::F32, x.device)
    }

    /// Backward pass for `deform_conv2d`.
    /// WARNING(review): only the bias gradient is actually computed (summing
    /// `output_grad` over batch and spatial dims); the input, offset, weight
    /// and mask gradients are returned as zeros — training through this op
    /// will not learn those parameters.
    fn deform_conv2d_backward(
        x: FloatTensor<F>, offset: FloatTensor<F>, weight: FloatTensor<F>,
        mask: Option<FloatTensor<F>>, bias: Option<FloatTensor<F>>,
        output_grad: FloatTensor<F>, _o: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<F> {
        // CPU fallback backward — compute gradients numerically
        let dev = x.device;
        let dtype_f: burn_std::FloatDType = x.dtype.into();

        // Gradient for bias is just sum of output_grad over batch and spatial dims
        let bias_grad = if bias.is_some() {
            let summed = <F as FloatTensorOps<F>>::float_sum_dim(
                <F as FloatTensorOps<F>>::float_sum_dim(
                    <F as FloatTensorOps<F>>::float_sum_dim(output_grad.clone(), 0),
                    2,
                ),
                3,
            );
            Some(reshape(summed, Shape::new([weight.shape[0]])))
        } else { None };

        // For the other gradients, use zeros as a simple placeholder
        // (full numerical gradient would be too slow for a fallback)
        let x_grad = zeros(x.shape.clone(), &dev, dtype_f);
        let offset_grad = zeros(offset.shape.clone(), &dev, dtype_f);
        let weight_grad = zeros(weight.shape.clone(), &dev, dtype_f);
        let mask_grad = mask.map(|m| zeros(m.shape.clone(), &dev, dtype_f));

        DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    // ── Conv transpose 1d via 2d ────────────────────────────────────────

    /// 1-D transposed convolution: insert a unit height axis, run
    /// `conv_transpose2d`, squeeze the axis back out.
    fn conv_transpose1d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<1>) -> FloatTensor<F> {
        let x4 = reshape(x.clone(), Shape::new([x.shape[0],x.shape[1],1,x.shape[2]]));
        let w4 = reshape(w.clone(), Shape::new([w.shape[0],w.shape[1],1,w.shape[2]]));
        let r = Self::conv_transpose2d(x4, w4, b, ConvTransposeOptions::new([1,o.stride[0]],[0,o.padding[0]],[0,o.padding_out[0]],[1,o.dilation[0]],o.groups));
        reshape(r.clone(), Shape::new([r.shape[0],r.shape[1],r.shape[3]]))
    }

    // ── Conv transpose 2d (native MPSGraph) ─────────────────────────────

    /// 2-D transposed convolution. The output shape (including output
    /// padding) is computed up front and handed to the FFI call; a supplied
    /// bias is reshaped to [1,C,1,1] and broadcast-added in the same graph.
    fn conv_transpose2d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<2>) -> FloatTensor<F> {
        // Transposed conv weight layout gives C_out per group in dim 1.
        let c_out = w.shape[1]*o.groups;
        let h = calculate_conv_transpose_output_size(w.shape[2], o.stride[0], o.padding[0], o.padding_out[0], o.dilation[0], x.shape[2]);
        let ww = calculate_conv_transpose_output_size(w.shape[3], o.stride[1], o.padding[1], o.padding_out[1], o.dilation[1], x.shape[3]);
        let os_ns = bridge::shape_to_ns(&Shape::new([x.shape[0],c_out,h,ww]));

        if let Some(ref bt) = b {
            bridge::run_multi_ctx(&[&x,&w,bt], x.device, |g, phs| unsafe {
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                let conv = ffi::graph_conv_transpose2d(g, phs[0], phs[1], os_ns, desc);
                let bs = bridge::shape_to_ns(&Shape::new([1,bt.shape[0],1,1]));
                let br = ffi::graph_reshape(g, phs[2], bs);
                ffi::graph_binary(g, "additionWithPrimaryTensor:secondaryTensor:name:", conv, br)
            })
        } else {
            bridge::run_binary_ctx(&x, &w, |g,px,pw| unsafe {
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                ffi::graph_conv_transpose2d(g, px, pw, os_ns, desc)
            })
        }
    }

    // ── Conv transpose 3d via loop over depth + conv_transpose2d ────────

    /// 3-D transposed convolution, scatter-style: each input depth slice
    /// contributes (via `conv_transpose2d`) to output depth
    /// `id*stride + kd_i*dilation - padding`; contributions are accumulated
    /// with read-add-write on the output slice. Bias is added at the end.
    fn conv_transpose3d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<3>) -> FloatTensor<F> {
        let (batch, c_in, d_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]);
        let (_, c_out_per_g, kd, kh, kw) = (w.shape[0], w.shape[1], w.shape[2], w.shape[3], w.shape[4]);
        let c_out = c_out_per_g * o.groups;
        let d_out = calculate_conv_transpose_output_size(kd, o.stride[0], o.padding[0], o.padding_out[0], o.dilation[0], d_in);
        let h_out = calculate_conv_transpose_output_size(kh, o.stride[1], o.padding[1], o.padding_out[1], o.dilation[1], h_in);
        let w_out = calculate_conv_transpose_output_size(kw, o.stride[2], o.padding[2], o.padding_out[2], o.dilation[2], w_in);

        let dev = x.device;
        let dtype_f: burn_std::FloatDType = x.dtype.into();
        let mut output = zeros(Shape::new([batch, c_out, d_out, h_out, w_out]), &dev, dtype_f);

        // 2-D options carry the H/W components ([1], [2]) of the 3-D options.
        let o2 = ConvTransposeOptions::new(
            [o.stride[1], o.stride[2]], [o.padding[1], o.padding[2]],
            [o.padding_out[1], o.padding_out[2]], [o.dilation[1], o.dilation[2]], o.groups,
        );

        for id in 0..d_in {
            // Extract x[:,:,id,:,:] -> [N, C_in, H, W]
            let x_slice = slice_t(x.clone(), &[
                Slice::new(0, Some(batch as isize), 1),
                Slice::new(0, Some(c_in as isize), 1),
                Slice::new(id as isize, Some(id as isize + 1), 1),
                Slice::new(0, Some(h_in as isize), 1),
                Slice::new(0, Some(w_in as isize), 1),
            ]);
            let x_2d = reshape(x_slice, Shape::new([batch, c_in, h_in, w_in]));

            for kd_i in 0..kd {
                let od = id * o.stride[0] + kd_i * o.dilation[0];
                // Taps falling into depth padding or past the end contribute nothing.
                if od < o.padding[0] { continue; }
                let od_actual = od - o.padding[0];
                if od_actual >= d_out { continue; }

                // Extract weight[:,:,kd_i,:,:] -> [C_in, C_out/g, kH, kW]
                let w_slice = slice_t(w.clone(), &[
                    Slice::new(0, Some(w.shape[0] as isize), 1),
                    Slice::new(0, Some(c_out_per_g as isize), 1),
                    Slice::new(kd_i as isize, Some(kd_i as isize + 1), 1),
                    Slice::new(0, Some(kh as isize), 1),
                    Slice::new(0, Some(kw as isize), 1),
                ]);
                let w_2d = reshape(w_slice, Shape::new([w.shape[0], c_out_per_g, kh, kw]));

                let conv_result = Self::conv_transpose2d(x_2d.clone(), w_2d, None, o2.clone());
                let conv_5d = reshape(conv_result, Shape::new([batch, c_out, 1, h_out, w_out]));

                // Accumulate
                let existing = slice_t(output.clone(), &[
                    Slice::new(0, Some(batch as isize), 1),
                    Slice::new(0, Some(c_out as isize), 1),
                    Slice::new(od_actual as isize, Some(od_actual as isize + 1), 1),
                    Slice::new(0, Some(h_out as isize), 1),
                    Slice::new(0, Some(w_out as isize), 1),
                ]);
                let summed = add(existing, conv_5d);
                output = slice_assign(output, &[
                    Slice::new(0, Some(batch as isize), 1),
                    Slice::new(0, Some(c_out as isize), 1),
                    Slice::new(od_actual as isize, Some(od_actual as isize + 1), 1),
                    Slice::new(0, Some(h_out as isize), 1),
                    Slice::new(0, Some(w_out as isize), 1),
                ], summed);
            }
        }

        if let Some(bias) = b {
            let bias_5d = reshape(bias, Shape::new([1, c_out, 1, 1, 1]));
            let bias_expanded = <F as FloatTensorOps<F>>::float_expand(bias_5d, output.shape.clone());
            output = add(output, bias_expanded);
        }

        output
    }

    // ── Pooling (native MPSGraph) ───────────────────────────────────────

    /// Average pooling. `count_include_pad` toggles whether zero padding
    /// counts toward the divisor; ceil mode is ignored (`_ceil`).
    fn avg_pool2d(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], count_include_pad: bool, _ceil: bool) -> FloatTensor<F> {
        bridge::run_unary_ctx(&x, |g,ph| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], 1,1, pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_include_zero_pad(desc, count_include_pad);
            ffi::graph_avg_pool2d(g, ph, desc)
        })
    }

    /// Gradient of `avg_pool2d`, built from the same descriptor; ceil mode ignored.
    fn avg_pool2d_backward(x: FloatTensor<F>, grad: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], count_include_pad: bool, _ceil: bool) -> FloatTensor<F> {
        bridge::run_binary_ctx(&x, &grad, |g,px,pg| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], 1,1, pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_include_zero_pad(desc, count_include_pad);
            ffi::graph_avg_pool2d_grad(g, pg, px, desc)
        })
    }

    /// Adaptive average pooling emulated with a uniform kernel.
    /// NOTE(review): integer division — only exact when the input extent is a
    /// multiple of the output extent; other sizes would need non-uniform windows.
    fn adaptive_avg_pool2d(x: FloatTensor<F>, out: [usize;2]) -> FloatTensor<F> {
        let k = [x.shape[2]/out[0], x.shape[3]/out[1]];
        Self::avg_pool2d(x, k, k, [0,0], true, false)
    }

    /// Gradient of `adaptive_avg_pool2d`; kernel recovered from the grad shape.
    /// Same divisibility caveat as the forward pass.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<F>, grad: FloatTensor<F>) -> FloatTensor<F> {
        let k = [x.shape[2]/grad.shape[2], x.shape[3]/grad.shape[3]];
        Self::avg_pool2d_backward(x, grad, k, k, [0,0], true, false)
    }

    /// Max pooling; ceil mode is ignored (`_ceil`).
    fn max_pool2d(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool) -> FloatTensor<F> {
        bridge::run_unary_ctx(&x, |g,ph| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], dil[1],dil[0], pad[1],pad[1],pad[0],pad[0]);
            ffi::graph_max_pool2d(g, ph, desc)
        })
    }

    /// Max pooling that also returns the argmax indices. MPSGraph yields a
    /// 2-element array (values, indices); the index tensor is relabelled I32
    /// so downstream code treats it as an integer tensor.
    fn max_pool2d_with_indices(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool) -> MaxPool2dWithIndices<F> {
        let (vals, mut idxs) = bridge::run_unary_two_outputs(&x, |g,ph| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], dil[1],dil[0], pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_return_indices(desc);
            let arr = ffi::graph_max_pool2d_return_indices(g, ph, desc);
            (ffi::ns_array_get(arr, 0), ffi::ns_array_get(arr, 1))
        });
        idxs.dtype = DType::I32;
        MaxPool2dWithIndices::new(vals, idxs)
    }

    /// Gradient of max pooling, routed through the saved argmax indices.
    fn max_pool2d_with_indices_backward(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool, grad: FloatTensor<F>, idx: IntTensor<F>) -> MaxPool2dBackward<F> {
        let r = bridge::run_multi_ctx(&[&grad,&idx,&x], x.device, |g,phs| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], dil[1],dil[0], pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_return_indices(desc);
            ffi::graph_max_pool2d_indices_grad(g, phs[0], phs[1], phs[2], desc)
        });
        MaxPool2dBackward::new(r)
    }

    // ── Interpolation (native MPSGraph) ─────────────────────────────────

    /// Spatial resize to `out_size`. Only `Nearest` is mapped explicitly —
    /// every other interpolation mode falls back to bilinear.
    fn interpolate(x: FloatTensor<F>, out_size: [usize;2], opts: InterpolateOptions) -> FloatTensor<F> {
        let mode = match opts.mode { InterpolateMode::Nearest => ffi::MPSGraphResizeMode::NEAREST, _ => ffi::MPSGraphResizeMode::BILINEAR };
        bridge::run_unary_ctx(&x, |g,ph| unsafe {
            let sz = ffi::ns_usize_array(&out_size);
            ffi::graph_resize(g, ph, sz, mode, true, opts.align_corners)
        })
    }

    /// Gradient of `interpolate`; the target size is recovered from `x`'s shape
    /// by the FFI call, so `_out_size` is unused. Same bilinear fallback as forward.
    fn interpolate_backward(x: FloatTensor<F>, grad: FloatTensor<F>, _out_size: [usize;2], opts: InterpolateOptions) -> FloatTensor<F> {
        let mode = match opts.mode { InterpolateMode::Nearest => ffi::MPSGraphResizeMode::NEAREST, _ => ffi::MPSGraphResizeMode::BILINEAR };
        bridge::run_binary_ctx(&x, &grad, |g,px,pg| unsafe { ffi::graph_resize_grad(g, pg, px, mode, true, opts.align_corners) })
    }

    // ── Attention (single graph) ────────────────────────────────────────

    /// Scaled dot-product attention assembled from primitive graph ops:
    /// softmax(q·kᵀ / sqrt(d)) · v, with a max-subtracted (numerically stable)
    /// softmax. Masked positions are filled with -1e9 before the softmax.
    /// NOTE(review): `_bias` and `_opts` are ignored — attention bias and any
    /// option-driven behavior (e.g. a custom scale) are not applied.
    fn attention(q: FloatTensor<F>, k: FloatTensor<F>, v: FloatTensor<F>, mask: Option<BoolTensor<F>>, _bias: Option<FloatTensor<F>>, _opts: AttentionModuleOptions) -> FloatTensor<F> {
        let d = q.shape[q.shape.num_dims()-1] as f64;
        let scale = 1.0 / d.sqrt();
        let nd = q.shape.num_dims();

        if let Some(ref m) = mask {
            bridge::run_multi_ctx(&[&q,&k,&v,m], q.device, |g, phs| unsafe {
                // scores = q · kᵀ over the last two axes, then scale by 1/sqrt(d).
                let kt = ffi::graph_transpose(g, phs[1], nd-2, nd-1);
                let scores = ffi::graph_matmul(g, phs[0], kt);
                let scaled = ffi::graph_binary(g, "multiplicationWithPrimaryTensor:secondaryTensor:name:", scores, ffi::graph_constant_scalar(g, scale, ffi::MPSDataType::FLOAT32));
                // Where the mask selects, substitute a large negative so softmax → ~0.
                let masked = ffi::graph_select(g, phs[3], ffi::graph_constant_scalar(g, -1e9, ffi::MPSDataType::FLOAT32), scaled);
                // Manual softmax over the last axis with max subtraction for stability.
                let max = ffi::graph_reduction_max_axis(g, masked, (nd-1) as isize);
                let shifted = ffi::graph_binary(g, "subtractionWithPrimaryTensor:secondaryTensor:name:", masked, max);
                let e = ffi::graph_unary(g, "exponentWithTensor:name:", shifted);
                let s = ffi::graph_reduction_sum_axis(g, e, (nd-1) as isize);
                let sm = ffi::graph_binary(g, "divisionWithPrimaryTensor:secondaryTensor:name:", e, s);
                ffi::graph_matmul(g, sm, phs[2])
            })
        } else {
            bridge::run_multi_ctx(&[&q,&k,&v], q.device, |g, phs| unsafe {
                let kt = ffi::graph_transpose(g, phs[1], nd-2, nd-1);
                let scores = ffi::graph_matmul(g, phs[0], kt);
                let scaled = ffi::graph_binary(g, "multiplicationWithPrimaryTensor:secondaryTensor:name:", scores, ffi::graph_constant_scalar(g, scale, ffi::MPSDataType::FLOAT32));
                // Same stabilized softmax as above, without the mask substitution.
                let max = ffi::graph_reduction_max_axis(g, scaled, (nd-1) as isize);
                let shifted = ffi::graph_binary(g, "subtractionWithPrimaryTensor:secondaryTensor:name:", scaled, max);
                let e = ffi::graph_unary(g, "exponentWithTensor:name:", shifted);
                let s = ffi::graph_reduction_sum_axis(g, e, (nd-1) as isize);
                let sm = ffi::graph_binary(g, "divisionWithPrimaryTensor:secondaryTensor:name:", e, s);
                ffi::graph_matmul(g, sm, phs[2])
            })
        }
    }
}
476
477// ─── Bilinear interpolation for deform_conv2d ───────────────────────────────
478
/// Bilinearly samples `data` (NCHW-contiguous `f32`) at fractional coordinates
/// `(y, x)` for batch `n`, channel `c`. Taps that fall outside the `h`×`w`
/// plane contribute zero, and positions entirely outside the support of any
/// in-bounds pixel return 0 immediately.
fn bilinear_sample(data: &[f32], n: usize, c: usize, h: usize, w: usize, y: f32, x: f32, channels: usize) -> f32 {
    // No in-bounds pixel can influence a point at or beyond these limits.
    if y <= -1.0 || y >= h as f32 || x <= -1.0 || x >= w as f32 {
        return 0.0;
    }

    let (y0, x0) = (y.floor() as isize, x.floor() as isize);
    let (y1, x1) = (y0 + 1, x0 + 1);

    // Flat offset of the (n, c) plane; each tap reads one pixel or zero.
    let plane = (n * channels + c) * h;
    let tap = |row: isize, col: isize| -> f32 {
        let inside = row >= 0 && row < h as isize && col >= 0 && col < w as isize;
        if inside { data[(plane + row as usize) * w + col as usize] } else { 0.0 }
    };

    // Interpolation weights: `l*` toward the high corner, `h*` toward the low one.
    let (ly, lx) = (y - y0 as f32, x - x0 as f32);
    let (hy, hx) = (1.0 - ly, 1.0 - lx);

    hy * hx * tap(y0, x0) + hy * lx * tap(y0, x1) +
    ly * hx * tap(y1, x0) + ly * lx * tap(y1, x1)
}