1use burn_backend::ops::{
2 AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConvOptions,
3 DeformConv2dBackward, FloatTensorOps, InterpolateMode, InterpolateOptions,
4 MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
5 conv::{calculate_conv_output_size, calculate_conv_transpose_output_size},
6};
7use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
8use burn_backend::DType;
9use burn_std::{Shape, Slice};
10
11use crate::bridge::{self};
12use crate::ffi::{self};
13use crate::{MpsGraph, MpsGraphDevice};
14
15type F = MpsGraph; fn reshape(t: FloatTensor<F>, s: Shape) -> FloatTensor<F> {
19 <F as FloatTensorOps<F>>::float_reshape(t, s)
20}
21fn zeros(shape: Shape, dev: &MpsGraphDevice, dtype: burn_std::FloatDType) -> FloatTensor<F> {
22 <F as FloatTensorOps<F>>::float_zeros(shape, dev, dtype)
23}
24fn add(a: FloatTensor<F>, b: FloatTensor<F>) -> FloatTensor<F> {
25 <F as FloatTensorOps<F>>::float_add(a, b)
26}
27fn slice_t(t: FloatTensor<F>, s: &[Slice]) -> FloatTensor<F> {
28 <F as FloatTensorOps<F>>::float_slice(t, s)
29}
30fn slice_assign(t: FloatTensor<F>, s: &[Slice], v: FloatTensor<F>) -> FloatTensor<F> {
31 <F as FloatTensorOps<F>>::float_slice_assign(t, s, v)
32}
33fn scatter_add(dim: usize, t: FloatTensor<F>, i: IntTensor<F>, v: FloatTensor<F>) -> FloatTensor<F> {
34 <F as FloatTensorOps<F>>::float_scatter_add(dim, t, i, v)
35}
36
impl ModuleOps<MpsGraph> for MpsGraph {
    /// Embedding lookup: gathers rows of the weight matrix `w` along dim 0 at
    /// the positions in `idx`, as a single MPSGraph gather op.
    fn embedding(w: FloatTensor<F>, idx: IntTensor<F>) -> FloatTensor<F> {
        // graph_gather(graph, source, indices, 0, 0) — axis 0, no batch dims
        // (assuming the FFI's trailing args are (axis, batch_dims) — TODO confirm).
        bridge::run_binary(&w,&idx, |g,pw,pi| unsafe { ffi::graph_gather(g,pw,pi,0,0) })
    }

    /// Gradient of `embedding` w.r.t. the weight table: scatter-adds `grad`
    /// rows back into a zero tensor shaped like `w`; repeated indices accumulate.
    fn embedding_backward(w: FloatTensor<F>, grad: FloatTensor<F>, idx: IntTensor<F>) -> FloatTensor<F> {
        scatter_add(0, zeros(w.shape.clone(), &w.device, w.dtype.into()), idx, grad)
    }

    /// 1D convolution, implemented by lifting to `conv2d` with a dummy height
    /// dimension of 1 and then squeezing it back out.
    fn conv1d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<1>) -> FloatTensor<F> {
        // clone() keeps the original handles alive so `.shape` can still be read
        // after the tensor is moved into `reshape`.
        let x4 = reshape(x.clone(), Shape::new([x.shape[0],x.shape[1],1,x.shape[2]]));
        let w4 = reshape(w.clone(), Shape::new([w.shape[0],w.shape[1],1,w.shape[2]]));
        // Height gets stride 1 / padding 0 / dilation 1 — it is a size-1 dummy axis.
        let r = Self::conv2d(x4, w4, b, ConvOptions::new([1,o.stride[0]],[0,o.padding[0]],[1,o.dilation[0]],o.groups));
        // Drop the dummy height axis (dim 2) from the NCHW result.
        reshape(r.clone(), Shape::new([r.shape[0],r.shape[1],r.shape[3]]))
    }

    /// 2D convolution on the GPU. When a bias is supplied, the bias add is
    /// fused into the same graph (bias reshaped to [1, C, 1, 1] so the binary
    /// add broadcasts over batch and spatial dims).
    fn conv2d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<2>) -> FloatTensor<F> {
        if let Some(ref bt) = b {
            bridge::run_multi_ctx(&[&x,&w,bt], x.device, |g, phs| unsafe {
                // Descriptor args appear to be (strideX, strideY, dilX, dilY, groups,
                // padL, padR, padT, padB); padding is symmetric per axis here —
                // TODO confirm against the ffi::conv2d_desc declaration.
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                let conv = ffi::graph_conv2d(g, phs[0], phs[1], desc);
                let bs = bridge::shape_to_ns(&Shape::new([1,bt.shape[0],1,1]));
                let br = ffi::graph_reshape(g, phs[2], bs);
                ffi::graph_binary(g, "additionWithPrimaryTensor:secondaryTensor:name:", conv, br)
            })
        } else {
            // No bias: plain convolution node.
            bridge::run_binary_ctx(&x, &w, |g, px, pw| unsafe {
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                ffi::graph_conv2d(g, px, pw, desc)
            })
        }
    }

    /// 3D convolution, decomposed into a sum of 2D convolutions: for every
    /// output depth index `od` and kernel depth tap `kd_i`, one depth slice of
    /// the input is convolved with one depth slice of the kernel and the
    /// results are accumulated. Bias is added once at the end.
    fn conv3d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<3>) -> FloatTensor<F> {
        let (batch, c_in, d_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]);
        let (c_out, _, kd, kh, kw) = (w.shape[0], w.shape[1], w.shape[2], w.shape[3], w.shape[4]);
        let d_out = calculate_conv_output_size(kd, o.stride[0], o.padding[0], o.dilation[0], d_in);
        let h_out = calculate_conv_output_size(kh, o.stride[1], o.padding[1], o.dilation[1], h_in);
        let w_out = calculate_conv_output_size(kw, o.stride[2], o.padding[2], o.dilation[2], w_in);

        let dev = x.device;
        let dtype_f: burn_std::FloatDType = x.dtype.into();
        let mut output = zeros(Shape::new([batch, c_out, d_out, h_out, w_out]), &dev, dtype_f);

        // 2D options reuse the H/W components of the 3D options.
        let o2 = ConvOptions::new([o.stride[1], o.stride[2]], [o.padding[1], o.padding[2]], [o.dilation[1], o.dilation[2]], o.groups);

        for od in 0..d_out {
            // Accumulator for all depth kernel taps contributing to this output slice.
            let mut accum = zeros(Shape::new([batch, c_out, h_out, w_out]), &dev, dtype_f);
            for kd_i in 0..kd {
                let id = od * o.stride[0] + kd_i * o.dilation[0];
                // Tap falls inside the (virtual) depth padding: contributes zero.
                // The `id < padding` check also guards the usize subtraction below.
                if id < o.padding[0] || id - o.padding[0] >= d_in { continue; }
                let id_actual = id - o.padding[0];

                // One depth slice of the input, squeezed to NCHW.
                let x_slice = slice_t(x.clone(), &[
                    Slice::new(0, Some(batch as isize), 1),
                    Slice::new(0, Some(c_in as isize), 1),
                    Slice::new(id_actual as isize, Some(id_actual as isize + 1), 1),
                    Slice::new(0, Some(h_in as isize), 1),
                    Slice::new(0, Some(w_in as isize), 1),
                ]);
                let x_2d = reshape(x_slice, Shape::new([batch, c_in, h_in, w_in]));

                // Matching depth tap of the kernel, squeezed to OIHW.
                let w_slice = slice_t(w.clone(), &[
                    Slice::new(0, Some(c_out as isize), 1),
                    Slice::new(0, Some(w.shape[1] as isize), 1),
                    Slice::new(kd_i as isize, Some(kd_i as isize + 1), 1),
                    Slice::new(0, Some(kh as isize), 1),
                    Slice::new(0, Some(kw as isize), 1),
                ]);
                let w_2d = reshape(w_slice, Shape::new([c_out, w.shape[1], kh, kw]));

                let conv_result = Self::conv2d(x_2d, w_2d, None, o2.clone());
                accum = add(accum, conv_result);
            }
            // Write the finished depth slice into the 5D output.
            let accum_5d = reshape(accum, Shape::new([batch, c_out, 1, h_out, w_out]));
            output = slice_assign(output, &[
                Slice::new(0, Some(batch as isize), 1),
                Slice::new(0, Some(c_out as isize), 1),
                Slice::new(od as isize, Some(od as isize + 1), 1),
                Slice::new(0, Some(h_out as isize), 1),
                Slice::new(0, Some(w_out as isize), 1),
            ], accum_5d);
        }

        if let Some(bias) = b {
            // Broadcast bias over batch, depth and spatial dims.
            let bias_5d = reshape(bias, Shape::new([1, c_out, 1, 1, 1]));
            let bias_expanded = <F as FloatTensorOps<F>>::float_expand(bias_5d, output.shape.clone());
            output = add(output, bias_expanded);
        }

        output
    }

    /// Deformable 2D convolution (torchvision-style v2, with optional
    /// modulation mask), computed on the CPU as a reference implementation:
    /// tensors are copied to host memory, sampled with bilinear interpolation
    /// at offset-shifted kernel positions, and the result is uploaded back.
    /// NOTE(review): O(N·C·K²·H·W) scalar loop — correctness over speed.
    fn deform_conv2d(
        x: FloatTensor<F>, offset: FloatTensor<F>, weight: FloatTensor<F>,
        mask: Option<FloatTensor<F>>, bias: Option<FloatTensor<F>>,
        o: DeformConvOptions<2>,
    ) -> FloatTensor<F> {
        let x_bytes = bridge::tensor_to_bytes(&x);
        let offset_bytes = bridge::tensor_to_bytes(&offset);
        let weight_bytes = bridge::tensor_to_bytes(&weight);
        let mask_bytes = mask.as_ref().map(|m| bridge::tensor_to_bytes(m));

        // SAFETY: reinterprets the raw tensor bytes as f32. This assumes every
        // input tensor is F32 and the byte buffers are 4-byte aligned —
        // TODO(review): confirm callers can never pass F16/BF16 here.
        let x_f: &[f32] = unsafe { std::slice::from_raw_parts(x_bytes.as_ptr() as *const f32, x_bytes.len()/4) };
        let off_f: &[f32] = unsafe { std::slice::from_raw_parts(offset_bytes.as_ptr() as *const f32, offset_bytes.len()/4) };
        let w_f: &[f32] = unsafe { std::slice::from_raw_parts(weight_bytes.as_ptr() as *const f32, weight_bytes.len()/4) };

        let (batch, c_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3]);
        let (c_out, c_in_per_g, kh, kw) = (weight.shape[0], weight.shape[1], weight.shape[2], weight.shape[3]);
        let h_out = calculate_conv_output_size(kh, o.stride[0], o.padding[0], o.dilation[0], h_in);
        let w_out = calculate_conv_output_size(kw, o.stride[1], o.padding[1], o.dilation[1], w_in);
        let groups = o.weight_groups;
        let offset_groups = o.offset_groups;

        let mut output = vec![0.0f32; batch * c_out * h_out * w_out];

        for n in 0..batch {
            for g in 0..groups {
                // Weight group g maps input channels [c_in_start, c_in_start + c_in/groups)
                // to output channels [c_out_start, c_out_end).
                let c_out_start = g * (c_out / groups);
                let c_out_end = c_out_start + c_out / groups;
                let c_in_start = g * (c_in / groups);

                for oc in c_out_start..c_out_end {
                    for oh in 0..h_out {
                        for ow in 0..w_out {
                            let mut val = 0.0f32;
                            for ic in 0..(c_in / groups) {
                                let abs_ic = c_in_start + ic;
                                // Offset group of this input channel.
                                let og = abs_ic / (c_in / offset_groups);
                                for ky in 0..kh {
                                    for kx in 0..kw {
                                        // Offset layout assumed [N, og*2*kh*kw, h_out, w_out]
                                        // with (dy, dx) in consecutive channels per kernel tap
                                        // (torchvision convention) — TODO confirm producer layout.
                                        let off_idx = ((n * offset_groups + og) * 2 * kh * kw + (ky * kw + kx) * 2) * h_out * w_out + oh * w_out + ow;
                                        let dy = off_f[off_idx];
                                        // dx lives one channel (h_out*w_out elements) after dy.
                                        let dx = off_f[off_idx + h_out * w_out];

                                        // Deformed sampling position in input coordinates.
                                        let y = oh as f32 * o.stride[0] as f32 + ky as f32 * o.dilation[0] as f32 - o.padding[0] as f32 + dy;
                                        let xx = ow as f32 * o.stride[1] as f32 + kx as f32 * o.dilation[1] as f32 - o.padding[1] as f32 + dx;

                                        let sample = bilinear_sample(x_f, n, abs_ic, h_in, w_in, y, xx, c_in);

                                        // Modulation scalar; 1.0 when no mask is supplied.
                                        let m = if let Some(ref mb) = mask_bytes {
                                            // SAFETY: same f32-reinterpretation assumption as above.
                                            let mf: &[f32] = unsafe { std::slice::from_raw_parts(mb.as_ptr() as *const f32, mb.len()/4) };
                                            let midx = ((n * offset_groups + og) * kh * kw + ky * kw + kx) * h_out * w_out + oh * w_out + ow;
                                            mf[midx]
                                        } else { 1.0 };

                                        // Weight layout [c_out, c_in_per_g, kh, kw]; `oc` is the
                                        // absolute output channel, `ic` the within-group input channel.
                                        let w_idx = ((oc * c_in_per_g + ic) * kh + ky) * kw + kx;
                                        val += sample * w_f[w_idx] * m;
                                    }
                                }
                            }
                            output[((n * c_out + oc) * h_out + oh) * w_out + ow] = val;
                        }
                    }
                }
            }
        }

        if let Some(ref bias_t) = bias {
            let bias_bytes = bridge::tensor_to_bytes(bias_t);
            // SAFETY: same f32-reinterpretation assumption as above.
            let bias_f: &[f32] = unsafe { std::slice::from_raw_parts(bias_bytes.as_ptr() as *const f32, bias_bytes.len()/4) };
            for n in 0..batch {
                for oc in 0..c_out {
                    for oh in 0..h_out {
                        for ow in 0..w_out {
                            output[((n*c_out+oc)*h_out+oh)*w_out+ow] += bias_f[oc];
                        }
                    }
                }
            }
        }

        // SAFETY: viewing the f32 output buffer as bytes; length is exact and
        // the buffer outlives the call.
        let bytes = unsafe { std::slice::from_raw_parts(output.as_ptr() as *const u8, output.len() * 4) };
        bridge::tensor_from_bytes(bytes, Shape::new([batch, c_out, h_out, w_out]), DType::F32, x.device)
    }

    /// Backward pass for `deform_conv2d`.
    ///
    /// NOTE(review): this is a stub — only the bias gradient is actually
    /// computed; input, offset, weight and mask gradients are returned as
    /// zeros, so training through deformable convolution will not learn those
    /// parameters on this backend. `_o` is unused for the same reason.
    fn deform_conv2d_backward(
        x: FloatTensor<F>, offset: FloatTensor<F>, weight: FloatTensor<F>,
        mask: Option<FloatTensor<F>>, bias: Option<FloatTensor<F>>,
        output_grad: FloatTensor<F>, _o: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<F> {
        let dev = x.device;
        let dtype_f: burn_std::FloatDType = x.dtype.into();

        let bias_grad = if bias.is_some() {
            // dL/db = sum of output_grad over batch (0) and spatial dims (2, 3);
            // float_sum_dim presumably keeps reduced dims, hence the final
            // reshape down to [c_out] — TODO confirm keep-dim semantics.
            let summed = <F as FloatTensorOps<F>>::float_sum_dim(
                <F as FloatTensorOps<F>>::float_sum_dim(
                    <F as FloatTensorOps<F>>::float_sum_dim(output_grad.clone(), 0),
                    2,
                ),
                3,
            );
            Some(reshape(summed, Shape::new([weight.shape[0]])))
        } else { None };

        // Placeholder zero gradients (see NOTE above).
        let x_grad = zeros(x.shape.clone(), &dev, dtype_f);
        let offset_grad = zeros(offset.shape.clone(), &dev, dtype_f);
        let weight_grad = zeros(weight.shape.clone(), &dev, dtype_f);
        let mask_grad = mask.map(|m| zeros(m.shape.clone(), &dev, dtype_f));

        DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)
    }

    /// Transposed 1D convolution, lifted to `conv_transpose2d` with a dummy
    /// height axis of 1 (same trick as `conv1d`).
    fn conv_transpose1d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<1>) -> FloatTensor<F> {
        let x4 = reshape(x.clone(), Shape::new([x.shape[0],x.shape[1],1,x.shape[2]]));
        let w4 = reshape(w.clone(), Shape::new([w.shape[0],w.shape[1],1,w.shape[2]]));
        let r = Self::conv_transpose2d(x4, w4, b, ConvTransposeOptions::new([1,o.stride[0]],[0,o.padding[0]],[0,o.padding_out[0]],[1,o.dilation[0]],o.groups));
        reshape(r.clone(), Shape::new([r.shape[0],r.shape[1],r.shape[3]]))
    }

    /// Transposed 2D convolution. MPSGraph needs the explicit output shape, so
    /// it is computed up front; bias (if any) is fused as a broadcast add.
    fn conv_transpose2d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<2>) -> FloatTensor<F> {
        // Transposed-conv weight layout has per-group output channels in dim 1.
        let c_out = w.shape[1]*o.groups;
        let h = calculate_conv_transpose_output_size(w.shape[2], o.stride[0], o.padding[0], o.padding_out[0], o.dilation[0], x.shape[2]);
        let ww = calculate_conv_transpose_output_size(w.shape[3], o.stride[1], o.padding[1], o.padding_out[1], o.dilation[1], x.shape[3]);
        let os_ns = bridge::shape_to_ns(&Shape::new([x.shape[0],c_out,h,ww]));

        if let Some(ref bt) = b {
            bridge::run_multi_ctx(&[&x,&w,bt], x.device, |g, phs| unsafe {
                // Same descriptor convention as conv2d (see comment there).
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                let conv = ffi::graph_conv_transpose2d(g, phs[0], phs[1], os_ns, desc);
                let bs = bridge::shape_to_ns(&Shape::new([1,bt.shape[0],1,1]));
                let br = ffi::graph_reshape(g, phs[2], bs);
                ffi::graph_binary(g, "additionWithPrimaryTensor:secondaryTensor:name:", conv, br)
            })
        } else {
            bridge::run_binary_ctx(&x, &w, |g,px,pw| unsafe {
                let desc = ffi::conv2d_desc(o.stride[1],o.stride[0],o.dilation[1],o.dilation[0],o.groups,o.padding[1],o.padding[1],o.padding[0],o.padding[0]);
                ffi::graph_conv_transpose2d(g, px, pw, os_ns, desc)
            })
        }
    }

    /// Transposed 3D convolution, decomposed along depth: each input depth
    /// slice is transpose-convolved in 2D and scattered (accumulated) into the
    /// output depth positions reached by each depth kernel tap.
    fn conv_transpose3d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<3>) -> FloatTensor<F> {
        let (batch, c_in, d_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]);
        let (_, c_out_per_g, kd, kh, kw) = (w.shape[0], w.shape[1], w.shape[2], w.shape[3], w.shape[4]);
        let c_out = c_out_per_g * o.groups;
        let d_out = calculate_conv_transpose_output_size(kd, o.stride[0], o.padding[0], o.padding_out[0], o.dilation[0], d_in);
        let h_out = calculate_conv_transpose_output_size(kh, o.stride[1], o.padding[1], o.padding_out[1], o.dilation[1], h_in);
        let w_out = calculate_conv_transpose_output_size(kw, o.stride[2], o.padding[2], o.padding_out[2], o.dilation[2], w_in);

        let dev = x.device;
        let dtype_f: burn_std::FloatDType = x.dtype.into();
        let mut output = zeros(Shape::new([batch, c_out, d_out, h_out, w_out]), &dev, dtype_f);

        // 2D options reuse the H/W components of the 3D options.
        let o2 = ConvTransposeOptions::new(
            [o.stride[1], o.stride[2]], [o.padding[1], o.padding[2]],
            [o.padding_out[1], o.padding_out[2]], [o.dilation[1], o.dilation[2]], o.groups,
        );

        for id in 0..d_in {
            // One input depth slice, squeezed to NCHW.
            let x_slice = slice_t(x.clone(), &[
                Slice::new(0, Some(batch as isize), 1),
                Slice::new(0, Some(c_in as isize), 1),
                Slice::new(id as isize, Some(id as isize + 1), 1),
                Slice::new(0, Some(h_in as isize), 1),
                Slice::new(0, Some(w_in as isize), 1),
            ]);
            let x_2d = reshape(x_slice, Shape::new([batch, c_in, h_in, w_in]));

            for kd_i in 0..kd {
                // Output depth index this (input slice, kernel tap) pair lands on;
                // skip taps that fall into padding or past the output extent.
                // The `od < padding` check also guards the usize subtraction.
                let od = id * o.stride[0] + kd_i * o.dilation[0];
                if od < o.padding[0] { continue; }
                let od_actual = od - o.padding[0];
                if od_actual >= d_out { continue; }

                // Matching depth tap of the kernel, squeezed to 4D.
                let w_slice = slice_t(w.clone(), &[
                    Slice::new(0, Some(w.shape[0] as isize), 1),
                    Slice::new(0, Some(c_out_per_g as isize), 1),
                    Slice::new(kd_i as isize, Some(kd_i as isize + 1), 1),
                    Slice::new(0, Some(kh as isize), 1),
                    Slice::new(0, Some(kw as isize), 1),
                ]);
                let w_2d = reshape(w_slice, Shape::new([w.shape[0], c_out_per_g, kh, kw]));

                let conv_result = Self::conv_transpose2d(x_2d.clone(), w_2d, None, o2.clone());
                let conv_5d = reshape(conv_result, Shape::new([batch, c_out, 1, h_out, w_out]));

                // Read-modify-write: multiple (id, kd_i) pairs can hit the same
                // od_actual, so accumulate into the existing slice.
                let existing = slice_t(output.clone(), &[
                    Slice::new(0, Some(batch as isize), 1),
                    Slice::new(0, Some(c_out as isize), 1),
                    Slice::new(od_actual as isize, Some(od_actual as isize + 1), 1),
                    Slice::new(0, Some(h_out as isize), 1),
                    Slice::new(0, Some(w_out as isize), 1),
                ]);
                let summed = add(existing, conv_5d);
                output = slice_assign(output, &[
                    Slice::new(0, Some(batch as isize), 1),
                    Slice::new(0, Some(c_out as isize), 1),
                    Slice::new(od_actual as isize, Some(od_actual as isize + 1), 1),
                    Slice::new(0, Some(h_out as isize), 1),
                    Slice::new(0, Some(w_out as isize), 1),
                ], summed);
            }
        }

        if let Some(bias) = b {
            // Broadcast bias over batch, depth and spatial dims.
            let bias_5d = reshape(bias, Shape::new([1, c_out, 1, 1, 1]));
            let bias_expanded = <F as FloatTensorOps<F>>::float_expand(bias_5d, output.shape.clone());
            output = add(output, bias_expanded);
        }

        output
    }

    /// 2D average pooling. `count_include_pad` toggles whether padded zeros
    /// count in the divisor. NOTE(review): the ceil-mode flag is ignored —
    /// TODO confirm MPSGraph's default matches floor mode.
    fn avg_pool2d(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], count_include_pad: bool, _ceil: bool) -> FloatTensor<F> {
        bridge::run_unary_ctx(&x, |g,ph| unsafe {
            // pool2d_desc takes X (width) parameters before Y (height), dilation 1.
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], 1,1, pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_include_zero_pad(desc, count_include_pad);
            ffi::graph_avg_pool2d(g, ph, desc)
        })
    }

    /// Gradient of `avg_pool2d`: distributes `grad` back over each pooling
    /// window using the same descriptor as the forward pass.
    fn avg_pool2d_backward(x: FloatTensor<F>, grad: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], count_include_pad: bool, _ceil: bool) -> FloatTensor<F> {
        bridge::run_binary_ctx(&x, &grad, |g,px,pg| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], 1,1, pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_include_zero_pad(desc, count_include_pad);
            // Gradient tensor first, then the source (forward input).
            ffi::graph_avg_pool2d_grad(g, pg, px, desc)
        })
    }

    /// Adaptive average pooling via a fixed-kernel avg_pool2d.
    /// NOTE(review): only exact when the input spatial dims are integer
    /// multiples of `out`; otherwise the integer division truncates the kernel.
    fn adaptive_avg_pool2d(x: FloatTensor<F>, out: [usize;2]) -> FloatTensor<F> {
        let k = [x.shape[2]/out[0], x.shape[3]/out[1]];
        Self::avg_pool2d(x, k, k, [0,0], true, false)
    }

    /// Gradient of `adaptive_avg_pool2d`; reconstructs the kernel size from the
    /// forward input/grad shapes (same divisibility caveat as the forward pass).
    fn adaptive_avg_pool2d_backward(x: FloatTensor<F>, grad: FloatTensor<F>) -> FloatTensor<F> {
        let k = [x.shape[2]/grad.shape[2], x.shape[3]/grad.shape[3]];
        Self::avg_pool2d_backward(x, grad, k, k, [0,0], true, false)
    }

    /// 2D max pooling (values only; no indices). Ceil mode ignored, as above.
    fn max_pool2d(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool) -> FloatTensor<F> {
        bridge::run_unary_ctx(&x, |g,ph| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], dil[1],dil[0], pad[1],pad[1],pad[0],pad[0]);
            ffi::graph_max_pool2d(g, ph, desc)
        })
    }

    /// Max pooling returning both values and argmax indices (needed for the
    /// backward pass). MPSGraph returns a two-element array: [values, indices].
    fn max_pool2d_with_indices(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool) -> MaxPool2dWithIndices<F> {
        let (vals, mut idxs) = bridge::run_unary_two_outputs(&x, |g,ph| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], dil[1],dil[0], pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_return_indices(desc);
            let arr = ffi::graph_max_pool2d_return_indices(g, ph, desc);
            (ffi::ns_array_get(arr, 0), ffi::ns_array_get(arr, 1))
        });
        // Retag the index tensor's dtype so the rest of the stack treats it as
        // I32 — assumes MPSGraph emits 32-bit indices here (TODO confirm).
        idxs.dtype = DType::I32;
        MaxPool2dWithIndices::new(vals, idxs)
    }

    /// Gradient of max pooling: routes `grad` to the positions recorded in `idx`.
    fn max_pool2d_with_indices_backward(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool, grad: FloatTensor<F>, idx: IntTensor<F>) -> MaxPool2dBackward<F> {
        let r = bridge::run_multi_ctx(&[&grad,&idx,&x], x.device, |g,phs| unsafe {
            let desc = ffi::pool2d_desc(ks[1],ks[0], stride[1],stride[0], dil[1],dil[0], pad[1],pad[1],pad[0],pad[0]);
            ffi::pool_desc_set_return_indices(desc);
            ffi::graph_max_pool2d_indices_grad(g, phs[0], phs[1], phs[2], desc)
        });
        MaxPool2dBackward::new(r)
    }

    /// Spatial resize. Only nearest is mapped explicitly; every other mode
    /// (bilinear, bicubic, ...) falls back to MPSGraph's bilinear resize.
    fn interpolate(x: FloatTensor<F>, out_size: [usize;2], opts: InterpolateOptions) -> FloatTensor<F> {
        let mode = match opts.mode { InterpolateMode::Nearest => ffi::MPSGraphResizeMode::NEAREST, _ => ffi::MPSGraphResizeMode::BILINEAR };
        bridge::run_unary_ctx(&x, |g,ph| unsafe {
            let sz = ffi::ns_usize_array(&out_size);
            ffi::graph_resize(g, ph, sz, mode, true, opts.align_corners)
        })
    }

    /// Gradient of `interpolate`; the output size is recoverable from the
    /// tensors, hence `_out_size` is unused.
    fn interpolate_backward(x: FloatTensor<F>, grad: FloatTensor<F>, _out_size: [usize;2], opts: InterpolateOptions) -> FloatTensor<F> {
        let mode = match opts.mode { InterpolateMode::Nearest => ffi::MPSGraphResizeMode::NEAREST, _ => ffi::MPSGraphResizeMode::BILINEAR };
        bridge::run_binary_ctx(&x, &grad, |g,px,pg| unsafe { ffi::graph_resize_grad(g, pg, px, mode, true, opts.align_corners) })
    }

    /// Scaled dot-product attention: softmax(q·kᵀ / sqrt(d)) · v, built as one
    /// MPSGraph. The softmax is written out manually with max-subtraction for
    /// numerical stability. Attention bias and extra options are ignored.
    /// NOTE(review): scalar constants are created as FLOAT32 — confirm this
    /// interacts correctly with half-precision q/k/v.
    fn attention(q: FloatTensor<F>, k: FloatTensor<F>, v: FloatTensor<F>, mask: Option<BoolTensor<F>>, _bias: Option<FloatTensor<F>>, _opts: AttentionModuleOptions) -> FloatTensor<F> {
        // d = head dim (last axis of q); classic 1/sqrt(d) scaling.
        let d = q.shape[q.shape.num_dims()-1] as f64;
        let scale = 1.0 / d.sqrt();
        let nd = q.shape.num_dims();

        if let Some(ref m) = mask {
            bridge::run_multi_ctx(&[&q,&k,&v,m], q.device, |g, phs| unsafe {
                // scores = (q · kᵀ) * scale
                let kt = ffi::graph_transpose(g, phs[1], nd-2, nd-1);
                let scores = ffi::graph_matmul(g, phs[0], kt);
                let scaled = ffi::graph_binary(g, "multiplicationWithPrimaryTensor:secondaryTensor:name:", scores, ffi::graph_constant_scalar(g, scale, ffi::MPSDataType::FLOAT32));
                // Where mask holds, replace the score with -1e9 so it vanishes
                // after softmax (assumes select(cond, a, b) yields `a` where
                // cond is true — TODO confirm FFI select semantics).
                let masked = ffi::graph_select(g, phs[3], ffi::graph_constant_scalar(g, -1e9, ffi::MPSDataType::FLOAT32), scaled);
                // Manual softmax along the last axis: exp(x - max) / sum.
                let max = ffi::graph_reduction_max_axis(g, masked, (nd-1) as isize);
                let shifted = ffi::graph_binary(g, "subtractionWithPrimaryTensor:secondaryTensor:name:", masked, max);
                let e = ffi::graph_unary(g, "exponentWithTensor:name:", shifted);
                let s = ffi::graph_reduction_sum_axis(g, e, (nd-1) as isize);
                let sm = ffi::graph_binary(g, "divisionWithPrimaryTensor:secondaryTensor:name:", e, s);
                ffi::graph_matmul(g, sm, phs[2])
            })
        } else {
            // Same pipeline without the masking step.
            bridge::run_multi_ctx(&[&q,&k,&v], q.device, |g, phs| unsafe {
                let kt = ffi::graph_transpose(g, phs[1], nd-2, nd-1);
                let scores = ffi::graph_matmul(g, phs[0], kt);
                let scaled = ffi::graph_binary(g, "multiplicationWithPrimaryTensor:secondaryTensor:name:", scores, ffi::graph_constant_scalar(g, scale, ffi::MPSDataType::FLOAT32));
                let max = ffi::graph_reduction_max_axis(g, scaled, (nd-1) as isize);
                let shifted = ffi::graph_binary(g, "subtractionWithPrimaryTensor:secondaryTensor:name:", scaled, max);
                let e = ffi::graph_unary(g, "exponentWithTensor:name:", shifted);
                let s = ffi::graph_reduction_sum_axis(g, e, (nd-1) as isize);
                let sm = ffi::graph_binary(g, "divisionWithPrimaryTensor:secondaryTensor:name:", e, s);
                ffi::graph_matmul(g, sm, phs[2])
            })
        }
    }
}
476
/// Bilinearly interpolate one value from a flattened NCHW buffer.
///
/// `data` holds `channels` images of size `h`×`w` per batch element; the
/// sample is taken from batch `n`, channel `c`, at the fractional coordinate
/// (`y`, `x`). Corners falling outside the image contribute zero (zero
/// padding), and positions entirely outside the valid window return 0.0.
fn bilinear_sample(data: &[f32], n: usize, c: usize, h: usize, w: usize, y: f32, x: f32, channels: usize) -> f32 {
    // Entirely out of range: no interpolation corner overlaps the image.
    if y <= -1.0 || y >= h as f32 || x <= -1.0 || x >= w as f32 { return 0.0; }

    // Integer corner below/left of the sample point.
    let y0 = y.floor() as isize;
    let x0 = x.floor() as isize;

    // Interpolation weights toward the high (w*1) and low (w*0) corners.
    let wy1 = y - y0 as f32;
    let wx1 = x - x0 as f32;
    let wy0 = 1.0 - wy1;
    let wx0 = 1.0 - wx1;

    // Fetch one pixel, treating out-of-bounds coordinates as zero padding.
    let pixel = |row: isize, col: isize| -> f32 {
        if row < 0 || row >= h as isize || col < 0 || col >= w as isize {
            0.0
        } else {
            data[((n * channels + c) * h + row as usize) * w + col as usize]
        }
    };

    wy0 * wx0 * pixel(y0, x0)
        + wy0 * wx1 * pixel(y0, x0 + 1)
        + wy1 * wx0 * pixel(y0 + 1, x0)
        + wy1 * wx1 * pixel(y0 + 1, x0 + 1)
}