use burn_backend::ops::{
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConvOptions,
DeformConv2dBackward, FloatTensorOps, InterpolateMode, InterpolateOptions,
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
conv::{calculate_conv_output_size, calculate_conv_transpose_output_size},
};
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
use burn_backend::DType;
use burn_std::{Shape, Slice};
use crate::bridge::{self};
use crate::ffi::{self};
use crate::{MpsGraph, MpsGraphDevice};
type F = MpsGraph;
fn reshape(t: FloatTensor<F>, s: Shape) -> FloatTensor<F> {
<F as FloatTensorOps<F>>::float_reshape(t, s)
}
fn zeros(shape: Shape, dev: &MpsGraphDevice, dtype: burn_std::FloatDType) -> FloatTensor<F> {
<F as FloatTensorOps<F>>::float_zeros(shape, dev, dtype)
}
fn add(a: FloatTensor<F>, b: FloatTensor<F>) -> FloatTensor<F> {
<F as FloatTensorOps<F>>::float_add(a, b)
}
fn slice_t(t: FloatTensor<F>, s: &[Slice]) -> FloatTensor<F> {
<F as FloatTensorOps<F>>::float_slice(t, s)
}
fn slice_assign(t: FloatTensor<F>, s: &[Slice], v: FloatTensor<F>) -> FloatTensor<F> {
<F as FloatTensorOps<F>>::float_slice_assign(t, s, v)
}
fn scatter_add(dim: usize, t: FloatTensor<F>, i: IntTensor<F>, v: FloatTensor<F>) -> FloatTensor<F> {
<F as FloatTensorOps<F>>::float_scatter_add(dim, t, i, v)
}
impl ModuleOps<MpsGraph> for MpsGraph {
/// Embedding lookup: gathers entries of the weight table `w` selected by `idx`.
///
/// Implemented as a single graph gather. The trailing `0, 0` arguments are
/// presumably the gather axis and the batch-dimension count; confirm against
/// `ffi::graph_gather`.
fn embedding(w: FloatTensor<F>, idx: IntTensor<F>) -> FloatTensor<F> {
    bridge::run_binary(&w, &idx, |graph, table, indices| unsafe {
        ffi::graph_gather(graph, table, indices, 0, 0)
    })
}
fn embedding_backward(w: FloatTensor<F>, grad: FloatTensor<F>, idx: IntTensor<F>) -> FloatTensor<F> {
scatter_add(0, zeros(w.shape.clone(), &w.device, w.dtype.into()), idx, grad)
}
/// 1D convolution expressed through [`conv2d`](Self::conv2d).
///
/// A unit height axis is inserted into both the input and the kernel, the 2D
/// convolution runs with stride 1, padding 0 and dilation 1 on that axis, and
/// the unit axis is dropped from the result.
fn conv1d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<1>) -> FloatTensor<F> {
    let input_dims = [x.shape[0], x.shape[1], 1, x.shape[2]];
    let kernel_dims = [w.shape[0], w.shape[1], 1, w.shape[2]];
    let options_2d = ConvOptions::new(
        [1, o.stride[0]],
        [0, o.padding[0]],
        [1, o.dilation[0]],
        o.groups,
    );

    let input_4d = reshape(x, Shape::new(input_dims));
    let kernel_4d = reshape(w, Shape::new(kernel_dims));
    let out_4d = Self::conv2d(input_4d, kernel_4d, b, options_2d);

    let out_dims = [out_4d.shape[0], out_4d.shape[1], out_4d.shape[3]];
    reshape(out_4d, Shape::new(out_dims))
}
/// 2D convolution built as an MPSGraph convolution node, with an optional
/// bias added afterwards.
///
/// The descriptor receives the width component (index 1) before the height
/// component (index 0) for stride and dilation; each padding value is passed
/// twice (width pair first, then height pair). When a bias is present it is
/// reshaped to `[1, C, 1, 1]` inside the graph and added to the result.
fn conv2d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<2>) -> FloatTensor<F> {
    match b.as_ref() {
        None => bridge::run_binary_ctx(&x, &w, |graph, input, kernel| unsafe {
            let descriptor = ffi::conv2d_desc(
                o.stride[1], o.stride[0],
                o.dilation[1], o.dilation[0],
                o.groups,
                o.padding[1], o.padding[1], o.padding[0], o.padding[0],
            );
            ffi::graph_conv2d(graph, input, kernel, descriptor)
        }),
        Some(bias) => bridge::run_multi_ctx(&[&x, &w, bias], x.device, |graph, placeholders| unsafe {
            let descriptor = ffi::conv2d_desc(
                o.stride[1], o.stride[0],
                o.dilation[1], o.dilation[0],
                o.groups,
                o.padding[1], o.padding[1], o.padding[0], o.padding[0],
            );
            let convolved = ffi::graph_conv2d(graph, placeholders[0], placeholders[1], descriptor);
            // Give the 1D bias a broadcastable [1, C, 1, 1] layout.
            let bias_dims = bridge::shape_to_ns(&Shape::new([1, bias.shape[0], 1, 1]));
            let bias_4d = ffi::graph_reshape(graph, placeholders[2], bias_dims);
            ffi::graph_binary(graph, "additionWithPrimaryTensor:secondaryTensor:name:", convolved, bias_4d)
        }),
    }
}
fn conv3d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvOptions<3>) -> FloatTensor<F> {
let (batch, c_in, d_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]);
let (c_out, _, kd, kh, kw) = (w.shape[0], w.shape[1], w.shape[2], w.shape[3], w.shape[4]);
let d_out = calculate_conv_output_size(kd, o.stride[0], o.padding[0], o.dilation[0], d_in);
let h_out = calculate_conv_output_size(kh, o.stride[1], o.padding[1], o.dilation[1], h_in);
let w_out = calculate_conv_output_size(kw, o.stride[2], o.padding[2], o.dilation[2], w_in);
let dev = x.device;
let dtype_f: burn_std::FloatDType = x.dtype.into();
let mut output = zeros(Shape::new([batch, c_out, d_out, h_out, w_out]), &dev, dtype_f);
let o2 = ConvOptions::new([o.stride[1], o.stride[2]], [o.padding[1], o.padding[2]], [o.dilation[1], o.dilation[2]], o.groups);
for od in 0..d_out {
let mut accum = zeros(Shape::new([batch, c_out, h_out, w_out]), &dev, dtype_f);
for kd_i in 0..kd {
let id = od * o.stride[0] + kd_i * o.dilation[0];
if id < o.padding[0] || id - o.padding[0] >= d_in { continue; }
let id_actual = id - o.padding[0];
let x_slice = slice_t(x.clone(), &[
Slice::new(0, Some(batch as isize), 1),
Slice::new(0, Some(c_in as isize), 1),
Slice::new(id_actual as isize, Some(id_actual as isize + 1), 1),
Slice::new(0, Some(h_in as isize), 1),
Slice::new(0, Some(w_in as isize), 1),
]);
let x_2d = reshape(x_slice, Shape::new([batch, c_in, h_in, w_in]));
let w_slice = slice_t(w.clone(), &[
Slice::new(0, Some(c_out as isize), 1),
Slice::new(0, Some(w.shape[1] as isize), 1),
Slice::new(kd_i as isize, Some(kd_i as isize + 1), 1),
Slice::new(0, Some(kh as isize), 1),
Slice::new(0, Some(kw as isize), 1),
]);
let w_2d = reshape(w_slice, Shape::new([c_out, w.shape[1], kh, kw]));
let conv_result = Self::conv2d(x_2d, w_2d, None, o2.clone());
accum = add(accum, conv_result);
}
let accum_5d = reshape(accum, Shape::new([batch, c_out, 1, h_out, w_out]));
output = slice_assign(output, &[
Slice::new(0, Some(batch as isize), 1),
Slice::new(0, Some(c_out as isize), 1),
Slice::new(od as isize, Some(od as isize + 1), 1),
Slice::new(0, Some(h_out as isize), 1),
Slice::new(0, Some(w_out as isize), 1),
], accum_5d);
}
if let Some(bias) = b {
let bias_5d = reshape(bias, Shape::new([1, c_out, 1, 1, 1]));
let bias_expanded = <F as FloatTensorOps<F>>::float_expand(bias_5d, output.shape.clone());
output = add(output, bias_expanded);
}
output
}
/// Deformable 2D convolution computed on the host.
///
/// All inputs are read back as raw bytes and interpreted as `f32`; the result
/// is uploaded as an `F32` tensor on `x`'s device. Inputs of any other float
/// width are not handled by this routine.
///
/// For every output element, each kernel tap samples the input bilinearly at
/// the regular grid position shifted by the learned `(dy, dx)` offset. The
/// sample is scaled by the optional modulation mask, multiplied by the kernel
/// weight and accumulated. The offset layout is
/// `[batch, offset_groups * kh * kw * 2, h_out, w_out]` with `dy` stored
/// before `dx`, and the mask layout is
/// `[batch, offset_groups * kh * kw, h_out, w_out]`, as implied by the index
/// arithmetic below.
fn deform_conv2d(
    x: FloatTensor<F>, offset: FloatTensor<F>, weight: FloatTensor<F>,
    mask: Option<FloatTensor<F>>, bias: Option<FloatTensor<F>>,
    o: DeformConvOptions<2>,
) -> FloatTensor<F> {
    // Decodes native-endian f32 values without assuming the byte buffer is
    // 4-byte aligned (reinterpreting a byte pointer as `*const f32` would).
    fn decode_f32(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    let x_f = decode_f32(&bridge::tensor_to_bytes(&x));
    let off_f = decode_f32(&bridge::tensor_to_bytes(&offset));
    let w_f = decode_f32(&bridge::tensor_to_bytes(&weight));
    // Decoded once up front instead of once per kernel tap.
    let mask_f: Option<Vec<f32>> = mask.as_ref().map(|m| decode_f32(&bridge::tensor_to_bytes(m)));
    let bias_f: Option<Vec<f32>> = bias.as_ref().map(|b| decode_f32(&bridge::tensor_to_bytes(b)));

    let (batch, c_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3]);
    let (c_out, c_in_per_g, kh, kw) = (weight.shape[0], weight.shape[1], weight.shape[2], weight.shape[3]);
    let h_out = calculate_conv_output_size(kh, o.stride[0], o.padding[0], o.dilation[0], h_in);
    let w_out = calculate_conv_output_size(kw, o.stride[1], o.padding[1], o.dilation[1], w_in);
    let groups = o.weight_groups;
    let offset_groups = o.offset_groups;
    let plane = h_out * w_out;

    let mut output = vec![0.0f32; batch * c_out * plane];
    for n in 0..batch {
        for g in 0..groups {
            let c_out_start = g * (c_out / groups);
            let c_out_end = c_out_start + c_out / groups;
            let c_in_start = g * (c_in / groups);
            for oc in c_out_start..c_out_end {
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        let pixel = oh * w_out + ow;
                        let mut val = 0.0f32;
                        for ic in 0..(c_in / groups) {
                            let abs_ic = c_in_start + ic;
                            // Offset group that this input channel belongs to.
                            let og = abs_ic / (c_in / offset_groups);
                            for ky in 0..kh {
                                for kx in 0..kw {
                                    let tap = ky * kw + kx;
                                    let off_idx = ((n * offset_groups + og) * 2 * kh * kw + tap * 2) * plane + pixel;
                                    let dy = off_f[off_idx];
                                    let dx = off_f[off_idx + plane];
                                    let y = oh as f32 * o.stride[0] as f32 + ky as f32 * o.dilation[0] as f32 - o.padding[0] as f32 + dy;
                                    let xx = ow as f32 * o.stride[1] as f32 + kx as f32 * o.dilation[1] as f32 - o.padding[1] as f32 + dx;
                                    let sample = bilinear_sample(&x_f, n, abs_ic, h_in, w_in, y, xx, c_in);
                                    let m = match mask_f.as_ref() {
                                        Some(mf) => mf[((n * offset_groups + og) * kh * kw + tap) * plane + pixel],
                                        None => 1.0,
                                    };
                                    let w_idx = ((oc * c_in_per_g + ic) * kh + ky) * kw + kx;
                                    val += sample * w_f[w_idx] * m;
                                }
                            }
                        }
                        // The per-channel bias is added after the full accumulation.
                        output[(n * c_out + oc) * plane + pixel] = match bias_f.as_ref() {
                            Some(bf) => val + bf[oc],
                            None => val,
                        };
                    }
                }
            }
        }
    }

    // SAFETY: `output` is a live, initialized `Vec<f32>`; viewing it as bytes
    // only relaxes the alignment requirement, and the byte length is exactly
    // `len * size_of::<f32>()`.
    let bytes = unsafe { std::slice::from_raw_parts(output.as_ptr() as *const u8, output.len() * 4) };
    bridge::tensor_from_bytes(bytes, Shape::new([batch, c_out, h_out, w_out]), DType::F32, x.device)
}
/// Backward pass of [`deform_conv2d`](Self::deform_conv2d), computed on the host.
///
/// Produces gradients for the input, the offsets, the weights, the optional
/// mask and the optional bias. Like the forward pass, tensors are read back as
/// `f32` and the input/offset/weight/mask gradients are uploaded as `F32`.
///
/// For each output element and kernel tap, with upstream gradient `go`,
/// weight `w`, mask `m` and bilinear sample `s` taken at `(y, x)`:
/// * weight gradient accumulates `go * s * m`
/// * mask gradient accumulates `go * w * s`
/// * input gradient receives `go * w * m` split over the four neighbouring
///   pixels using their bilinear weights
/// * offset gradient accumulates `go * w * m * ds/dy` and `go * w * m * ds/dx`
///
/// Samples that the forward pass treats as outside the image contribute nothing.
/// The bias gradient is the sum of `output_grad` over the batch and both
/// spatial axes.
fn deform_conv2d_backward(
    x: FloatTensor<F>, offset: FloatTensor<F>, weight: FloatTensor<F>,
    mask: Option<FloatTensor<F>>, bias: Option<FloatTensor<F>>,
    output_grad: FloatTensor<F>, o: DeformConvOptions<2>,
) -> DeformConv2dBackward<F> {
    // Decodes native-endian f32 values without assuming buffer alignment.
    fn decode_f32(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }
    // Serializes f32 values to native-endian bytes.
    fn encode_f32(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    let device = x.device;
    let (batch, c_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3]);
    let (c_out, c_in_per_g, kh, kw) = (weight.shape[0], weight.shape[1], weight.shape[2], weight.shape[3]);
    let (h_out, w_out) = (output_grad.shape[2], output_grad.shape[3]);
    let groups = o.weight_groups;
    let offset_groups = o.offset_groups;
    let plane = h_out * w_out;

    let x_f = decode_f32(&bridge::tensor_to_bytes(&x));
    let off_f = decode_f32(&bridge::tensor_to_bytes(&offset));
    let w_f = decode_f32(&bridge::tensor_to_bytes(&weight));
    let mask_f: Option<Vec<f32>> = mask.as_ref().map(|m| decode_f32(&bridge::tensor_to_bytes(m)));
    let go_f = decode_f32(&bridge::tensor_to_bytes(&output_grad));

    let mut x_grad = vec![0.0f32; x_f.len()];
    let mut offset_grad = vec![0.0f32; off_f.len()];
    let mut weight_grad = vec![0.0f32; w_f.len()];
    let mut mask_grad: Option<Vec<f32>> = mask_f.as_ref().map(|m| vec![0.0f32; m.len()]);

    for n in 0..batch {
        for g in 0..groups {
            let c_out_start = g * (c_out / groups);
            let c_out_end = c_out_start + c_out / groups;
            let c_in_start = g * (c_in / groups);
            for oc in c_out_start..c_out_end {
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        let pixel = oh * w_out + ow;
                        let go = go_f[(n * c_out + oc) * plane + pixel];
                        for ic in 0..(c_in / groups) {
                            let abs_ic = c_in_start + ic;
                            let og = abs_ic / (c_in / offset_groups);
                            let channel_base = (n * c_in + abs_ic) * h_in * w_in;
                            for ky in 0..kh {
                                for kx in 0..kw {
                                    let tap = ky * kw + kx;
                                    let off_idx = ((n * offset_groups + og) * 2 * kh * kw + tap * 2) * plane + pixel;
                                    let m_idx = ((n * offset_groups + og) * kh * kw + tap) * plane + pixel;
                                    let w_idx = ((oc * c_in_per_g + ic) * kh + ky) * kw + kx;

                                    let y = oh as f32 * o.stride[0] as f32 + ky as f32 * o.dilation[0] as f32 - o.padding[0] as f32 + off_f[off_idx];
                                    let xx = ow as f32 * o.stride[1] as f32 + kx as f32 * o.dilation[1] as f32 - o.padding[1] as f32 + off_f[off_idx + plane];
                                    // Same rejection rule as the forward sampler: such taps sample 0.
                                    if y <= -1.0 || y >= h_in as f32 || xx <= -1.0 || xx >= w_in as f32 {
                                        continue;
                                    }

                                    let y_low = y.floor() as isize;
                                    let x_low = xx.floor() as isize;
                                    let ly = y - y_low as f32;
                                    let lx = xx - x_low as f32;
                                    let hy = 1.0 - ly;
                                    let hx = 1.0 - lx;

                                    // Flat index of a neighbouring pixel, or None outside the image.
                                    let pixel_at = |row: isize, col: isize| -> Option<usize> {
                                        if row < 0 || row >= h_in as isize || col < 0 || col >= w_in as isize {
                                            None
                                        } else {
                                            Some(channel_base + row as usize * w_in + col as usize)
                                        }
                                    };
                                    // (pixel, bilinear weight, d weight / dy, d weight / dx)
                                    let corners = [
                                        (pixel_at(y_low, x_low), hy * hx, -hx, -hy),
                                        (pixel_at(y_low, x_low + 1), hy * lx, -lx, hy),
                                        (pixel_at(y_low + 1, x_low), ly * hx, hx, -ly),
                                        (pixel_at(y_low + 1, x_low + 1), ly * lx, lx, ly),
                                    ];

                                    let m = mask_f.as_ref().map_or(1.0, |mf| mf[m_idx]);
                                    let upstream = go * w_f[w_idx];
                                    let mut sample = 0.0f32;
                                    let mut d_sample_dy = 0.0f32;
                                    let mut d_sample_dx = 0.0f32;
                                    for &(at, weight_here, d_dy, d_dx) in corners.iter() {
                                        if let Some(i) = at {
                                            let v = x_f[i];
                                            sample += weight_here * v;
                                            d_sample_dy += d_dy * v;
                                            d_sample_dx += d_dx * v;
                                            x_grad[i] += upstream * m * weight_here;
                                        }
                                    }

                                    weight_grad[w_idx] += go * sample * m;
                                    offset_grad[off_idx] += upstream * m * d_sample_dy;
                                    offset_grad[off_idx + plane] += upstream * m * d_sample_dx;
                                    if let Some(mg) = mask_grad.as_mut() {
                                        mg[m_idx] += upstream * sample;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // Bias gradient: reduce the upstream gradient over batch, height and width.
    let bias_grad = bias.map(|_| {
        let over_batch = <F as FloatTensorOps<F>>::float_sum_dim(output_grad, 0);
        let over_height = <F as FloatTensorOps<F>>::float_sum_dim(over_batch, 2);
        let over_width = <F as FloatTensorOps<F>>::float_sum_dim(over_height, 3);
        reshape(over_width, Shape::new([c_out]))
    });

    let upload = |values: &[f32], shape: Shape| {
        bridge::tensor_from_bytes(encode_f32(values).as_slice(), shape, DType::F32, device)
    };
    let mask_grad_tensor = match (mask_grad.as_ref(), mask.as_ref()) {
        (Some(values), Some(m)) => Some(upload(values.as_slice(), m.shape.clone())),
        _ => None,
    };
    DeformConv2dBackward::new(
        upload(x_grad.as_slice(), x.shape.clone()),
        upload(offset_grad.as_slice(), offset.shape.clone()),
        upload(weight_grad.as_slice(), weight.shape.clone()),
        mask_grad_tensor,
        bias_grad,
    )
}
/// 1D transposed convolution expressed through
/// [`conv_transpose2d`](Self::conv_transpose2d).
///
/// A unit height axis is inserted into the input and the kernel; on that axis
/// the 2D op uses stride 1, padding 0, output padding 0 and dilation 1. The
/// unit axis is removed from the result.
fn conv_transpose1d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<1>) -> FloatTensor<F> {
    let input_dims = [x.shape[0], x.shape[1], 1, x.shape[2]];
    let kernel_dims = [w.shape[0], w.shape[1], 1, w.shape[2]];
    let options_2d = ConvTransposeOptions::new(
        [1, o.stride[0]],
        [0, o.padding[0]],
        [0, o.padding_out[0]],
        [1, o.dilation[0]],
        o.groups,
    );

    let input_4d = reshape(x, Shape::new(input_dims));
    let kernel_4d = reshape(w, Shape::new(kernel_dims));
    let out_4d = Self::conv_transpose2d(input_4d, kernel_4d, b, options_2d);

    let out_dims = [out_4d.shape[0], out_4d.shape[1], out_4d.shape[3]];
    reshape(out_4d, Shape::new(out_dims))
}
/// 2D transposed convolution built as an MPSGraph node, with an optional bias.
///
/// The output shape `[batch, w.shape[1] * groups, h, w]` is computed up front
/// with `calculate_conv_transpose_output_size` and handed to the graph op.
/// The descriptor is filled exactly as in the forward convolution (width
/// component before height component). A bias, when present, is reshaped to
/// `[1, C, 1, 1]` and added inside the graph.
fn conv_transpose2d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<2>) -> FloatTensor<F> {
    let out_channels = w.shape[1] * o.groups;
    let out_h = calculate_conv_transpose_output_size(
        w.shape[2], o.stride[0], o.padding[0], o.padding_out[0], o.dilation[0], x.shape[2],
    );
    let out_w = calculate_conv_transpose_output_size(
        w.shape[3], o.stride[1], o.padding[1], o.padding_out[1], o.dilation[1], x.shape[3],
    );
    let out_shape_ns = bridge::shape_to_ns(&Shape::new([x.shape[0], out_channels, out_h, out_w]));

    match b.as_ref() {
        None => bridge::run_binary_ctx(&x, &w, |graph, input, kernel| unsafe {
            let descriptor = ffi::conv2d_desc(
                o.stride[1], o.stride[0],
                o.dilation[1], o.dilation[0],
                o.groups,
                o.padding[1], o.padding[1], o.padding[0], o.padding[0],
            );
            ffi::graph_conv_transpose2d(graph, input, kernel, out_shape_ns, descriptor)
        }),
        Some(bias) => bridge::run_multi_ctx(&[&x, &w, bias], x.device, |graph, placeholders| unsafe {
            let descriptor = ffi::conv2d_desc(
                o.stride[1], o.stride[0],
                o.dilation[1], o.dilation[0],
                o.groups,
                o.padding[1], o.padding[1], o.padding[0], o.padding[0],
            );
            let upsampled = ffi::graph_conv_transpose2d(graph, placeholders[0], placeholders[1], out_shape_ns, descriptor);
            let bias_dims = bridge::shape_to_ns(&Shape::new([1, bias.shape[0], 1, 1]));
            let bias_4d = ffi::graph_reshape(graph, placeholders[2], bias_dims);
            ffi::graph_binary(graph, "additionWithPrimaryTensor:secondaryTensor:name:", upsampled, bias_4d)
        }),
    }
}
fn conv_transpose3d(x: FloatTensor<F>, w: FloatTensor<F>, b: Option<FloatTensor<F>>, o: ConvTransposeOptions<3>) -> FloatTensor<F> {
let (batch, c_in, d_in, h_in, w_in) = (x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]);
let (_, c_out_per_g, kd, kh, kw) = (w.shape[0], w.shape[1], w.shape[2], w.shape[3], w.shape[4]);
let c_out = c_out_per_g * o.groups;
let d_out = calculate_conv_transpose_output_size(kd, o.stride[0], o.padding[0], o.padding_out[0], o.dilation[0], d_in);
let h_out = calculate_conv_transpose_output_size(kh, o.stride[1], o.padding[1], o.padding_out[1], o.dilation[1], h_in);
let w_out = calculate_conv_transpose_output_size(kw, o.stride[2], o.padding[2], o.padding_out[2], o.dilation[2], w_in);
let dev = x.device;
let dtype_f: burn_std::FloatDType = x.dtype.into();
let mut output = zeros(Shape::new([batch, c_out, d_out, h_out, w_out]), &dev, dtype_f);
let o2 = ConvTransposeOptions::new(
[o.stride[1], o.stride[2]], [o.padding[1], o.padding[2]],
[o.padding_out[1], o.padding_out[2]], [o.dilation[1], o.dilation[2]], o.groups,
);
for id in 0..d_in {
let x_slice = slice_t(x.clone(), &[
Slice::new(0, Some(batch as isize), 1),
Slice::new(0, Some(c_in as isize), 1),
Slice::new(id as isize, Some(id as isize + 1), 1),
Slice::new(0, Some(h_in as isize), 1),
Slice::new(0, Some(w_in as isize), 1),
]);
let x_2d = reshape(x_slice, Shape::new([batch, c_in, h_in, w_in]));
for kd_i in 0..kd {
let od = id * o.stride[0] + kd_i * o.dilation[0];
if od < o.padding[0] { continue; }
let od_actual = od - o.padding[0];
if od_actual >= d_out { continue; }
let w_slice = slice_t(w.clone(), &[
Slice::new(0, Some(w.shape[0] as isize), 1),
Slice::new(0, Some(c_out_per_g as isize), 1),
Slice::new(kd_i as isize, Some(kd_i as isize + 1), 1),
Slice::new(0, Some(kh as isize), 1),
Slice::new(0, Some(kw as isize), 1),
]);
let w_2d = reshape(w_slice, Shape::new([w.shape[0], c_out_per_g, kh, kw]));
let conv_result = Self::conv_transpose2d(x_2d.clone(), w_2d, None, o2.clone());
let conv_5d = reshape(conv_result, Shape::new([batch, c_out, 1, h_out, w_out]));
let existing = slice_t(output.clone(), &[
Slice::new(0, Some(batch as isize), 1),
Slice::new(0, Some(c_out as isize), 1),
Slice::new(od_actual as isize, Some(od_actual as isize + 1), 1),
Slice::new(0, Some(h_out as isize), 1),
Slice::new(0, Some(w_out as isize), 1),
]);
let summed = add(existing, conv_5d);
output = slice_assign(output, &[
Slice::new(0, Some(batch as isize), 1),
Slice::new(0, Some(c_out as isize), 1),
Slice::new(od_actual as isize, Some(od_actual as isize + 1), 1),
Slice::new(0, Some(h_out as isize), 1),
Slice::new(0, Some(w_out as isize), 1),
], summed);
}
}
if let Some(bias) = b {
let bias_5d = reshape(bias, Shape::new([1, c_out, 1, 1, 1]));
let bias_expanded = <F as FloatTensorOps<F>>::float_expand(bias_5d, output.shape.clone());
output = add(output, bias_expanded);
}
output
}
/// 2D average pooling through an MPSGraph pooling node.
///
/// The descriptor takes width values before height values, a fixed dilation of
/// 1, and each padding value twice. `count_include_pad` is forwarded to the
/// descriptor. The ceil-mode flag (`_ceil`) is not used by this
/// implementation.
fn avg_pool2d(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], count_include_pad: bool, _ceil: bool) -> FloatTensor<F> {
    let [kernel_h, kernel_w] = ks;
    let [stride_h, stride_w] = stride;
    let [pad_h, pad_w] = pad;
    bridge::run_unary_ctx(&x, |graph, input| unsafe {
        let descriptor = ffi::pool2d_desc(
            kernel_w, kernel_h,
            stride_w, stride_h,
            1, 1,
            pad_w, pad_w, pad_h, pad_h,
        );
        ffi::pool_desc_set_include_zero_pad(descriptor, count_include_pad);
        ffi::graph_avg_pool2d(graph, input, descriptor)
    })
}
/// Gradient of [`avg_pool2d`](Self::avg_pool2d) with respect to its input.
///
/// Rebuilds the same pooling descriptor as the forward pass and calls the
/// graph's average-pool gradient op with the upstream gradient first and the
/// forward input second. The ceil-mode flag (`_ceil`) is not used.
fn avg_pool2d_backward(x: FloatTensor<F>, grad: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], count_include_pad: bool, _ceil: bool) -> FloatTensor<F> {
    let [kernel_h, kernel_w] = ks;
    let [stride_h, stride_w] = stride;
    let [pad_h, pad_w] = pad;
    bridge::run_binary_ctx(&x, &grad, |graph, input, upstream| unsafe {
        let descriptor = ffi::pool2d_desc(
            kernel_w, kernel_h,
            stride_w, stride_h,
            1, 1,
            pad_w, pad_w, pad_h, pad_h,
        );
        ffi::pool_desc_set_include_zero_pad(descriptor, count_include_pad);
        ffi::graph_avg_pool2d_grad(graph, upstream, input, descriptor)
    })
}
fn adaptive_avg_pool2d(x: FloatTensor<F>, out: [usize;2]) -> FloatTensor<F> {
let k = [x.shape[2]/out[0], x.shape[3]/out[1]];
Self::avg_pool2d(x, k, k, [0,0], true, false)
}
fn adaptive_avg_pool2d_backward(x: FloatTensor<F>, grad: FloatTensor<F>) -> FloatTensor<F> {
let k = [x.shape[2]/grad.shape[2], x.shape[3]/grad.shape[3]];
Self::avg_pool2d_backward(x, grad, k, k, [0,0], true, false)
}
/// 2D max pooling through an MPSGraph pooling node.
///
/// Kernel, stride and dilation are passed to the descriptor width-first, and
/// each padding value is passed twice. The ceil-mode flag (`_ceil`) is not
/// used by this implementation.
fn max_pool2d(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool) -> FloatTensor<F> {
    let [kernel_h, kernel_w] = ks;
    let [stride_h, stride_w] = stride;
    let [pad_h, pad_w] = pad;
    let [dilation_h, dilation_w] = dil;
    bridge::run_unary_ctx(&x, |graph, input| unsafe {
        let descriptor = ffi::pool2d_desc(
            kernel_w, kernel_h,
            stride_w, stride_h,
            dilation_w, dilation_h,
            pad_w, pad_w, pad_h, pad_h,
        );
        ffi::graph_max_pool2d(graph, input, descriptor)
    })
}
/// 2D max pooling that also returns the indices of the selected maxima.
///
/// The descriptor is switched into return-indices mode. The graph op yields an
/// array whose element 0 is used as the pooled values and element 1 as the
/// indices. The index tensor's dtype tag is set to `I32` before it is
/// returned. The ceil-mode flag (`_ceil`) is not used.
fn max_pool2d_with_indices(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool) -> MaxPool2dWithIndices<F> {
    let [kernel_h, kernel_w] = ks;
    let [stride_h, stride_w] = stride;
    let [pad_h, pad_w] = pad;
    let [dilation_h, dilation_w] = dil;
    let (pooled, mut indices) = bridge::run_unary_two_outputs(&x, |graph, input| unsafe {
        let descriptor = ffi::pool2d_desc(
            kernel_w, kernel_h,
            stride_w, stride_h,
            dilation_w, dilation_h,
            pad_w, pad_w, pad_h, pad_h,
        );
        ffi::pool_desc_set_return_indices(descriptor);
        let outputs = ffi::graph_max_pool2d_return_indices(graph, input, descriptor);
        (ffi::ns_array_get(outputs, 0), ffi::ns_array_get(outputs, 1))
    });
    indices.dtype = DType::I32;
    MaxPool2dWithIndices::new(pooled, indices)
}
/// Gradient of [`max_pool2d_with_indices`](Self::max_pool2d_with_indices)
/// with respect to its input.
///
/// Rebuilds the forward descriptor in return-indices mode and calls the
/// graph's indices-based max-pool gradient op with the upstream gradient, the
/// saved indices and the forward input, in that order. The ceil-mode flag
/// (`_ceil`) is not used.
fn max_pool2d_with_indices_backward(x: FloatTensor<F>, ks: [usize;2], stride: [usize;2], pad: [usize;2], dil: [usize;2], _ceil: bool, grad: FloatTensor<F>, idx: IntTensor<F>) -> MaxPool2dBackward<F> {
    let [kernel_h, kernel_w] = ks;
    let [stride_h, stride_w] = stride;
    let [pad_h, pad_w] = pad;
    let [dilation_h, dilation_w] = dil;
    let input_grad = bridge::run_multi_ctx(&[&grad, &idx, &x], x.device, |graph, placeholders| unsafe {
        let descriptor = ffi::pool2d_desc(
            kernel_w, kernel_h,
            stride_w, stride_h,
            dilation_w, dilation_h,
            pad_w, pad_w, pad_h, pad_h,
        );
        ffi::pool_desc_set_return_indices(descriptor);
        ffi::graph_max_pool2d_indices_grad(graph, placeholders[0], placeholders[1], placeholders[2], descriptor)
    });
    MaxPool2dBackward::new(input_grad)
}
/// Resizes `x` to `out_size` with an MPSGraph resize node.
///
/// `InterpolateMode::Nearest` maps to the graph's nearest mode; every other
/// interpolation mode is executed as bilinear. `opts.align_corners` is
/// forwarded. The literal `true` passed before it is presumably the
/// "center result" flag; confirm against `ffi::graph_resize`.
fn interpolate(x: FloatTensor<F>, out_size: [usize;2], opts: InterpolateOptions) -> FloatTensor<F> {
    let resize_mode = if matches!(opts.mode, InterpolateMode::Nearest) {
        ffi::MPSGraphResizeMode::NEAREST
    } else {
        ffi::MPSGraphResizeMode::BILINEAR
    };
    bridge::run_unary_ctx(&x, |graph, input| unsafe {
        let target_size = ffi::ns_usize_array(&out_size);
        ffi::graph_resize(graph, input, target_size, resize_mode, true, opts.align_corners)
    })
}
/// Gradient of [`interpolate`](Self::interpolate) with respect to its input.
///
/// Uses the same mode mapping as the forward pass (nearest stays nearest, all
/// other modes run as bilinear). The graph's resize-gradient op is called with
/// the upstream gradient first and the forward input second. `_out_size` is
/// not used.
fn interpolate_backward(x: FloatTensor<F>, grad: FloatTensor<F>, _out_size: [usize;2], opts: InterpolateOptions) -> FloatTensor<F> {
    let resize_mode = if matches!(opts.mode, InterpolateMode::Nearest) {
        ffi::MPSGraphResizeMode::NEAREST
    } else {
        ffi::MPSGraphResizeMode::BILINEAR
    };
    bridge::run_binary_ctx(&x, &grad, |graph, input, upstream| unsafe {
        ffi::graph_resize_grad(graph, upstream, input, resize_mode, true, opts.align_corners)
    })
}
/// Scaled dot-product attention built as a single MPSGraph.
///
/// Computes `softmax(q @ k^T / sqrt(d) + bias, masked) @ v`, where `d` is the
/// size of the last axis of `q`:
/// 1. scores = `q @ transpose(k)` over the last two axes, scaled by `1/sqrt(d)`
/// 2. the additive `bias`, when provided, is added to the scaled scores
/// 3. positions where `mask` is true are replaced by `-1e9`
/// 4. a max-shifted softmax is taken over the last axis
/// 5. the result is multiplied by `v`
///
/// NOTE(review): `_opts` is not consulted. If `AttentionModuleOptions`
/// carries a custom scale, soft-capping or causal flag, those are currently
/// ignored here; confirm against the options type and extend as needed.
fn attention(q: FloatTensor<F>, k: FloatTensor<F>, v: FloatTensor<F>, mask: Option<BoolTensor<F>>, bias: Option<FloatTensor<F>>, _opts: AttentionModuleOptions) -> FloatTensor<F> {
    let nd = q.shape.num_dims();
    let head_dim = q.shape[nd - 1] as f64;
    let scale = 1.0 / head_dim.sqrt();

    // Graph inputs: q, k, v first, then whichever optional tensors are present.
    // The slot of each optional tensor is remembered for use inside the graph.
    let mut inputs: Vec<&FloatTensor<F>> = vec![&q, &k, &v];
    let mut bias_slot: Option<usize> = None;
    if let Some(b) = bias.as_ref() {
        bias_slot = Some(inputs.len());
        inputs.push(b);
    }
    let mut mask_slot: Option<usize> = None;
    if let Some(m) = mask.as_ref() {
        mask_slot = Some(inputs.len());
        inputs.push(m);
    }

    bridge::run_multi_ctx(inputs.as_slice(), q.device, |g, phs| unsafe {
        let kt = ffi::graph_transpose(g, phs[1], nd - 2, nd - 1);
        let scores = ffi::graph_matmul(g, phs[0], kt);
        let mut logits = ffi::graph_binary(
            g,
            "multiplicationWithPrimaryTensor:secondaryTensor:name:",
            scores,
            ffi::graph_constant_scalar(g, scale, ffi::MPSDataType::FLOAT32),
        );
        // The additive bias goes on the scaled scores, before masking.
        if let Some(slot) = bias_slot {
            logits = ffi::graph_binary(
                g,
                "additionWithPrimaryTensor:secondaryTensor:name:",
                logits,
                phs[slot],
            );
        }
        // Masked positions get a large negative score so softmax sends them to ~0.
        if let Some(slot) = mask_slot {
            logits = ffi::graph_select(
                g,
                phs[slot],
                ffi::graph_constant_scalar(g, -1e9, ffi::MPSDataType::FLOAT32),
                logits,
            );
        }
        // Numerically stable softmax over the last axis.
        let max = ffi::graph_reduction_max_axis(g, logits, (nd - 1) as isize);
        let shifted = ffi::graph_binary(g, "subtractionWithPrimaryTensor:secondaryTensor:name:", logits, max);
        let e = ffi::graph_unary(g, "exponentWithTensor:name:", shifted);
        let s = ffi::graph_reduction_sum_axis(g, e, (nd - 1) as isize);
        let sm = ffi::graph_binary(g, "divisionWithPrimaryTensor:secondaryTensor:name:", e, s);
        ffi::graph_matmul(g, sm, phs[2])
    })
}
}
/// Bilinearly samples channel `c` of image `n` at the fractional position
/// `(y, x)`.
///
/// `data` is indexed as a contiguous `[_, channels, h, w]` buffer. Positions
/// at or beyond one pixel outside the image (`<= -1` or `>= h` / `>= w`)
/// sample as `0.0`. Inside that band, any of the four neighbouring pixels that
/// lies outside the image contributes `0.0`.
fn bilinear_sample(data: &[f32], n: usize, c: usize, h: usize, w: usize, y: f32, x: f32, channels: usize) -> f32 {
    let rejected = y <= -1.0 || y >= h as f32 || x <= -1.0 || x >= w as f32;
    if rejected {
        return 0.0;
    }

    let row0 = y.floor() as isize;
    let col0 = x.floor() as isize;
    let row1 = row0 + 1;
    let col1 = col0 + 1;

    // Pixel value at (row, col), or 0.0 when the coordinate is off-image.
    let texel = |row: isize, col: isize| -> f32 {
        let on_image = row >= 0 && row < h as isize && col >= 0 && col < w as isize;
        if on_image {
            data[((n * channels + c) * h + row as usize) * w + col as usize]
        } else {
            0.0
        }
    };

    // Fractional distances to the low corner and their complements.
    let frac_y = y - row0 as f32;
    let frac_x = x - col0 as f32;
    let rest_y = 1.0 - frac_y;
    let rest_x = 1.0 - frac_x;

    let mut blended = rest_y * rest_x * texel(row0, col0);
    blended += rest_y * frac_x * texel(row0, col1);
    blended += frac_y * rest_x * texel(row1, col0);
    blended += frac_y * frac_x * texel(row1, col1);
    blended
}