/// Hyper-parameters of the multi-scale deformable attention op, decoded from the packed
/// attribute blob by [`decode_attrs`].
#[derive(Debug, Clone, PartialEq, Eq)]
struct Attrs {
    /// Channel width of query/value rows; `execute` splits it evenly across `nh` heads.
    d: usize,
    /// Number of attention heads.
    nh: usize,
    /// Number of sampling points per head and level.
    np: usize,
    /// Values per reference point: 2 is an `(x, y)` point; otherwise `execute` also reads
    /// indices 2 and 3 as per-axis scales for the sampling offsets.
    ref_dim: usize,
    /// `(height, width)` of each feature level, in order.
    shapes: Vec<(usize, usize)>,
}

/// Parses the attribute blob: little-endian `u32` words
/// `[d, nh, np, ref_dim, nl, h_0, w_0, h_1, w_1, ..., h_{nl-1}, w_{nl-1}]`.
///
/// Trailing bytes after the last shape pair are ignored.
///
/// # Errors
/// Returns `Err` if the 5-word header is incomplete, or if fewer than `nl` shape pairs
/// follow it (including when `nl` is so large the required length overflows `usize`).
fn decode_attrs(bytes: &[u8]) -> Result<Attrs, String> {
    const HEADER_WORDS: usize = 5;
    if bytes.len() < HEADER_WORDS * 4 {
        return Err("ms_deform_attn: attrs too short".into());
    }
    // Reads word `i`; every call site is covered by one of the length checks.
    let rd = |i: usize| -> usize {
        let w = &bytes[i * 4..i * 4 + 4];
        u32::from_le_bytes([w[0], w[1], w[2], w[3]]) as usize
    };
    let nl = rd(4);
    // Checked so an attacker-chosen level count cannot wrap this bound on 32-bit targets
    // and let the shape reads below run past the end of `bytes`.
    let needed = nl
        .checked_mul(2)
        .and_then(|n| n.checked_add(HEADER_WORDS))
        .and_then(|n| n.checked_mul(4));
    if needed.filter(|&n| n <= bytes.len()).is_none() {
        return Err("ms_deform_attn: attrs truncated shapes".into());
    }
    let shapes = (0..nl)
        .map(|l| (rd(HEADER_WORDS + 2 * l), rd(HEADER_WORDS + 2 * l + 1)))
        .collect();
    Ok(Attrs {
        d: rd(0),
        nh: rd(1),
        np: rd(2),
        ref_dim: rd(3),
        shapes,
    })
}
/// Row-wise affine map: `out[r][o] = b[o] + sum_i x[r][i] * w[o][i]`.
///
/// `x` is `rows x in_dim` and `w` is `out_dim x in_dim`, both row-major. An empty `b`
/// means "no bias". Returns a new row-major `rows x out_dim` buffer.
fn linear(x: &[f32], rows: usize, in_dim: usize, w: &[f32], out_dim: usize, b: &[f32]) -> Vec<f32> {
    let bias_at = |o: usize| if b.is_empty() { 0.0 } else { b[o] };
    let mut result = Vec::with_capacity(rows * out_dim);
    for r in 0..rows {
        let x_lo = r * in_dim;
        result.extend((0..out_dim).map(|o| {
            let start = bias_at(o);
            let xs = &x[x_lo..x_lo + in_dim];
            let ws = &w[o * in_dim..o * in_dim + in_dim];
            // Accumulate left-to-right from the bias, matching a plain `+=` loop.
            xs.iter().zip(ws).fold(start, |acc, (&xv, &wv)| acc + xv * wv)
        }));
    }
    result
}
/// Forward pass of multi-scale deformable attention.
///
/// `attrs` is decoded with [`decode_attrs`]. `inputs` must contain exactly 11 slices:
///
/// - 0: query, `nq * d` values
/// - 1: value source, all levels flattened row after row, `seq * d` values
/// - 2: reference points, `nq * nl * ref_dim` values
/// - 3, 4: value projection weight (`d * d`) and bias (`d`)
/// - 5, 6: offset projection weight (`nh*nl*np*2 * d`) and bias (`nh*nl*np*2`)
/// - 7, 8: attention-logit projection weight (`nh*nl*np * d`) and bias (`nh*nl*np`)
/// - 9, 10: output projection weight (`d * d`) and bias (`d`)
///
/// Weights are row-major `[out_dim][in_dim]`; any bias may be empty to mean "no bias".
/// `seq` must equal the sum of `h * w` over all levels. The `nq * d` result goes to `out`.
///
/// # Errors
/// Returns `Err` (instead of panicking or silently truncating) when the input count or
/// attributes are malformed, `nh` does not evenly divide a nonzero `d`, `ref_dim` is not
/// 2 or 4, any input length disagrees with the list above, or `out.len() != nq * d`.
pub fn execute(inputs: &[&[f32]], attrs: &[u8], out: &mut [f32]) -> Result<(), String> {
    if inputs.len() != 11 {
        return Err(format!(
            "ms_deform_attn: expected 11 inputs, got {}",
            inputs.len()
        ));
    }
    let a = decode_attrs(attrs)?;
    let (d, nh, np, ref_dim) = (a.d, a.nh, a.np, a.ref_dim);
    let nl = a.shapes.len();
    // `d / nh` would panic on zero and silently drop channels on a remainder.
    if d == 0 || nh == 0 || d % nh != 0 {
        return Err(format!(
            "ms_deform_attn: d {} is not a positive multiple of nh {}",
            d, nh
        ));
    }
    if ref_dim != 2 && ref_dim != 4 {
        return Err(format!(
            "ms_deform_attn: ref_dim must be 2 or 4, got {}",
            ref_dim
        ));
    }
    let hd = d / nh;
    let (query, value_src, reference) = (inputs[0], inputs[1], inputs[2]);
    if query.len() % d != 0 || value_src.len() % d != 0 {
        return Err(format!(
            "ms_deform_attn: query len {} / value len {} not multiples of d {}",
            query.len(),
            value_src.len(),
            d
        ));
    }
    let nq = query.len() / d;
    let seq = value_src.len() / d;

    // First value row of each level inside the flattened value sequence.
    let mut starts = Vec::with_capacity(nl);
    let mut total = 0usize;
    for &(h, w) in &a.shapes {
        starts.push(total);
        total = h
            .checked_mul(w)
            .and_then(|hw| total.checked_add(hw))
            .ok_or_else(|| "ms_deform_attn: level sizes overflow usize".to_string())?;
    }
    // `sample` reads rows up to `total`, so the value sequence must cover the levels exactly.
    if total != seq {
        return Err(format!(
            "ms_deform_attn: value rows {} != sum of level sizes {}",
            seq, total
        ));
    }

    let n_attn = checked_size(&[nh, nl, np])?;
    let n_off = checked_size(&[n_attn, 2])?;
    check_len("reference", reference, checked_size(&[nq, nl, ref_dim])?)?;
    for &(wname, bname, wi, out_dim) in &[
        ("value proj weight", "value proj bias", 3usize, d),
        ("offset proj weight", "offset proj bias", 5, n_off),
        ("attn proj weight", "attn proj bias", 7, n_attn),
        ("output proj weight", "output proj bias", 9, d),
    ] {
        check_len(wname, inputs[wi], checked_size(&[out_dim, d])?)?;
        if !inputs[wi + 1].is_empty() {
            check_len(bname, inputs[wi + 1], out_dim)?;
        }
    }
    // Fail before doing any work; nq * d == query.len() by the divisibility check above.
    if out.len() != query.len() {
        return Err(format!(
            "ms_deform_attn: out len {} != {}",
            out.len(),
            query.len()
        ));
    }

    let value = linear(value_src, seq, d, inputs[3], d, inputs[4]);
    let offsets = linear(query, nq, d, inputs[5], n_off, inputs[6]);
    let mut attn = linear(query, nq, d, inputs[7], n_attn, inputs[8]);
    // Softmax jointly over all levels and points of each (query, head).
    softmax_rows(&mut attn, nq * nh, nl * np);

    let mut combined = vec![0f32; query.len()];
    for (q, q_row) in combined.chunks_exact_mut(d).enumerate() {
        // Each head accumulates straight into its own `hd`-wide slice of the query row.
        for (m, acc) in q_row.chunks_exact_mut(hd).enumerate() {
            let qm = q * nh + m;
            for (l, (&(h, w), &base)) in a.shapes.iter().zip(&starts).enumerate() {
                let rb = (q * nl + l) * ref_dim;
                let r = &reference[rb..rb + ref_dim];
                let lvl = qm * nl + l;
                for p in 0..np {
                    let aw = attn[lvl * np + p];
                    if aw == 0.0 {
                        continue;
                    }
                    let off_base = (lvl * np + p) * 2;
                    let (off_x, off_y) = (offsets[off_base], offsets[off_base + 1]);
                    // ref_dim 2: offsets are in pixels of this level.
                    // ref_dim 4: offsets are scaled by r[2], r[3] and divided by 2 * np.
                    let (loc_x, loc_y) = if ref_dim == 2 {
                        (r[0] + off_x / w as f32, r[1] + off_y / h as f32)
                    } else {
                        (
                            r[0] + off_x / np as f32 * r[2] * 0.5,
                            r[1] + off_y / np as f32 * r[3] * 0.5,
                        )
                    };
                    sample(&value, d, base, h, w, m, hd, loc_x, loc_y, aw, acc);
                }
            }
        }
    }
    let res = linear(&combined, nq, d, inputs[9], d, inputs[10]);
    out.copy_from_slice(&res);
    Ok(())
}

/// Product of `dims`, or an error if it overflows `usize`.
fn checked_size(dims: &[usize]) -> Result<usize, String> {
    dims.iter()
        .try_fold(1usize, |acc, &n| acc.checked_mul(n))
        .ok_or_else(|| "ms_deform_attn: tensor size overflows usize".to_string())
}

/// Returns an error unless `buf` holds exactly `want` elements.
fn check_len(what: &str, buf: &[f32], want: usize) -> Result<(), String> {
    if buf.len() == want {
        Ok(())
    } else {
        Err(format!(
            "ms_deform_attn: {} len {} != {}",
            what,
            buf.len(),
            want
        ))
    }
}
/// Runs [`execute`] on tensors stored in a shared `f32` arena.
///
/// `in_offs` lists the `(offset, len)` ranges of `arena` used as the op inputs, in the
/// order `execute` expects. The result is written to
/// `arena[out_f32_off..out_f32_off + out_f32_len]`.
///
/// Inputs are borrowed in place rather than copied, and the result is staged in a
/// temporary buffer. An output range that overlaps an input therefore still sees the
/// original input values.
///
/// # Errors
/// Returns `Err` if any input or output range lies outside `arena` (or its end overflows
/// `usize`), or if `execute` fails.
pub fn execute_in_arena(
    arena: &mut [f32],
    in_offs: &[(usize, usize)],
    out_f32_off: usize,
    out_f32_len: usize,
    attrs: &[u8],
) -> Result<(), String> {
    let out_end = out_f32_off
        .checked_add(out_f32_len)
        .filter(|&end| end <= arena.len())
        .ok_or_else(|| {
            format!(
                "ms_deform_attn: output range {}+{} exceeds arena len {}",
                out_f32_off,
                out_f32_len,
                arena.len()
            )
        })?;
    let mut out = vec![0f32; out_f32_len];
    {
        // Shared borrow for the whole computation; it ends before the write-back below.
        let view: &[f32] = &*arena;
        let in_refs = in_offs
            .iter()
            .map(|&(off, len)| {
                off.checked_add(len)
                    .and_then(|end| view.get(off..end))
                    .ok_or_else(|| {
                        format!(
                            "ms_deform_attn: input range {}+{} exceeds arena len {}",
                            off,
                            len,
                            view.len()
                        )
                    })
            })
            .collect::<Result<Vec<&[f32]>, String>>()?;
        execute(&in_refs, attrs, &mut out)?;
    }
    arena[out_f32_off..out_end].copy_from_slice(&out);
    Ok(())
}
/// Bilinearly samples one head's channels from one feature level and adds them, scaled
/// by `weight`, into `acc[..hd]`.
///
/// `value` holds rows of `d` channels. The level occupies rows `base..base + h * w` in
/// row-major `(y, x)` order, and head `m` owns channels `m * hd..(m + 1) * hd` of each row.
/// `(loc_x, loc_y)` is normalized so that the pixel coordinate is `loc_x * w - 0.5`
/// (0 and 1 are the outer edges of the map). Taps that fall outside the map contribute
/// zero.
// Many scalar arguments: this is the full addressing context of a single sample.
#[allow(clippy::too_many_arguments)]
fn sample(
    value: &[f32],
    d: usize,
    base: usize,
    h: usize,
    w: usize,
    m: usize,
    hd: usize,
    loc_x: f32,
    loc_y: f32,
    weight: f32,
    acc: &mut [f32],
) {
    let ix = ((2.0 * loc_x - 1.0 + 1.0) * w as f32 - 1.0) * 0.5;
    let iy = ((2.0 * loc_y - 1.0 + 1.0) * h as f32 - 1.0) * 0.5;
    // Outside (-1, w) x (-1, h) no tap with nonzero weight lands on the map, so skip.
    // This also rejects NaN (every comparison is false), and it rejects huge coordinates
    // whose saturating `as isize` cast would make `x0 + 1` / `y0 + 1` overflow.
    if !(ix > -1.0 && iy > -1.0 && ix < w as f32 && iy < h as f32) {
        return;
    }
    let x0 = ix.floor() as isize;
    let y0 = iy.floor() as isize;
    let wx1 = ix - x0 as f32;
    let wy1 = iy - y0 as f32;
    let corners = [
        (y0, x0, (1.0 - wy1) * (1.0 - wx1)),
        (y0, x0 + 1, (1.0 - wy1) * wx1),
        (y0 + 1, x0, wy1 * (1.0 - wx1)),
        (y0 + 1, x0 + 1, wy1 * wx1),
    ];
    for (cy, cx, cw) in corners {
        // Zero padding: taps outside the map are skipped.
        if cy < 0 || cx < 0 || cy >= h as isize || cx >= w as isize {
            continue;
        }
        let row = base + cy as usize * w + cx as usize;
        let voff = row * d + m * hd;
        let cw = cw * weight;
        for (a, &v) in acc[..hd].iter_mut().zip(&value[voff..voff + hd]) {
            *a += cw * v;
        }
    }
}
/// In-place softmax over each of the first `rows` rows of a row-major `rows x cols` buffer.
///
/// The row maximum is subtracted before `exp` for numerical stability. A row whose
/// exponentials do not sum to a positive value (e.g. an empty row, or one containing NaN)
/// is left holding the raw exponentials.
fn softmax_rows(x: &mut [f32], rows: usize, cols: usize) {
    for start in (0..rows).map(|r| r * cols) {
        let row = &mut x[start..start + cols];
        let peak = row
            .iter()
            .fold(f32::NEG_INFINITY, |best, &v| if v > best { v } else { best });
        let total = row.iter_mut().fold(0f32, |running, v| {
            *v = (*v - peak).exp();
            running + *v
        });
        if total > 0.0 {
            row.iter_mut().for_each(|v| *v /= total);
        }
    }
}