use crate::internal::*;
/// Number of output columns processed together in one register tile by the
/// const-generic kernel; columns past the last full tile use a scalar tail.
const WB: usize = 16;

/// Grouped convolution over NCHW f32 data whose kernel extends only along
/// the height axis (there is no kernel width: output column `w` only reads
/// input column `w`).
///
/// All geometry is fixed at construction time; `h_out` is supplied by the
/// creator of the op rather than derived here.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct BlockedConv {
    /// Batch size.
    pub n: usize,
    /// Number of input channels (split evenly across `group`).
    pub c_in: usize,
    /// Input height.
    pub h_in: usize,
    /// Width, shared by input and output.
    pub w: usize,
    /// Number of output channels (split evenly across `group`).
    pub oc: usize,
    /// Number of channel groups.
    pub group: usize,
    /// Kernel height, i.e. number of taps along H.
    pub kh: usize,
    /// Step between consecutive output rows, in input rows.
    pub stride_h: usize,
    /// Spacing between kernel taps, in input rows.
    pub dil_h: usize,
    /// Implicit zero rows above the input.
    pub pad_before_h: usize,
    /// Output height.
    pub h_out: usize,
}
impl BlockedConv {
#[inline]
fn icg(&self) -> usize {
self.c_in / self.group
}
#[inline]
fn ocg(&self) -> usize {
self.oc / self.group
}
}
impl Op for BlockedConv {
    fn name(&self) -> StaticName {
        "BlockedConv".into()
    }

    /// Single summary line describing the full convolution geometry.
    fn info(&self) -> TractResult<Vec<String>> {
        let (icg, ocg) = (self.icg(), self.ocg());
        let summary = format!(
            "N={} C={}->OC={} group={} kh={} (icg={} ocg={}) HxW={}x{} -> H_out={} pad_before={} stride_h={} dil_h={}",
            self.n,
            self.c_in,
            self.oc,
            self.group,
            self.kh,
            icg,
            ocg,
            self.h_in,
            self.w,
            self.h_out,
            self.pad_before_h,
            self.stride_h,
            self.dil_h,
        );
        Ok(vec![summary])
    }

    op_as_typed_op!();
}
impl EvalOp for BlockedConv {
fn is_stateless(&self) -> bool {
true
}
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let x_t = inputs[0].cast_to::<f32>()?;
let k_t = inputs[1].cast_to::<f32>()?;
let b_t = inputs[2].cast_to::<f32>()?;
let x = unsafe { x_t.as_slice_unchecked::<f32>() };
let kernel = unsafe { k_t.as_slice_unchecked::<f32>() };
let bias_raw = unsafe { b_t.as_slice_unchecked::<f32>() };
let bias_vec: Vec<f32> = match bias_raw.len() {
0 => vec![0.0; self.oc],
1 => vec![bias_raw[0]; self.oc],
_ => bias_raw.to_vec(),
};
let bias = bias_vec.as_slice();
let mut output =
unsafe { Tensor::uninitialized::<f32>(&[self.n, self.oc, self.h_out, self.w])? };
let out = unsafe { output.as_slice_mut_unchecked::<f32>() };
let ocg = self.ocg();
match ocg {
1 => self.run::<1>(x, kernel, bias, out),
2 => self.run::<2>(x, kernel, bias, out),
3 => self.run::<3>(x, kernel, bias, out),
4 => self.run::<4>(x, kernel, bias, out),
5 => self.run::<5>(x, kernel, bias, out),
6 => self.run::<6>(x, kernel, bias, out),
8 => self.run::<8>(x, kernel, bias, out),
_ => self.run_generic(x, kernel, bias, out),
}
Ok(tvec!(output.into_tvalue()))
}
}
impl BlockedConv {
    /// Register-tiled kernel for a per-group output channel count `OCG` known
    /// at compile time. The caller passes `OCG == self.ocg()`.
    ///
    /// For each batch item, group and output row, every full `WB`-column tile
    /// is accumulated in an `OCG x WB` stack array seeded with the bias and
    /// then stored at once. Columns after the last full tile go through a
    /// bounds-checked scalar tail that accumulates directly into `out`.
    ///
    /// Slices are indexed as row-major `x: [n, c_in, h_in, w]`,
    /// `kernel: [oc, icg, kh]` and `out: [n, oc, h_out, w]`.
    ///
    /// # Panics
    /// Panics if `group * OCG > oc` or any slice is shorter than that geometry
    /// requires. These checks run once per call and make the unchecked loads
    /// and stores in the tiled loop sound independently of the caller.
    // Index loops mirror the tensor index formulas and keep the fixed trip
    // counts (`OCG`, `WB`) visible to the optimizer.
    #[allow(clippy::needless_range_loop)]
    fn run<const OCG: usize>(&self, x: &[f32], kernel: &[f32], bias: &[f32], out: &mut [f32]) {
        let (icg, w, h_in, h_out, kh) = (self.icg(), self.w, self.h_in, self.h_out, self.kh);
        let (sh, dh, pb) =
            (self.stride_h as isize, self.dil_h as isize, self.pad_before_h as isize);
        let kstride_oc = icg * kh;
        let n_full = w / WB;

        assert!(self.group * OCG <= self.oc, "BlockedConv: group * OCG exceeds oc");
        assert!(x.len() >= self.n * self.c_in * h_in * w, "BlockedConv: input slice too short");
        assert!(
            kernel.len() >= self.group * OCG * kstride_oc,
            "BlockedConv: kernel slice too short"
        );
        assert!(out.len() >= self.n * self.oc * h_out * w, "BlockedConv: output slice too short");

        for ni in 0..self.n {
            // Per-batch views. By the assertions above, `x_n` holds at least
            // c_in * h_in * w values and `out_n` at least oc * h_out * w.
            let x_n = &x[ni * self.c_in * h_in * w..];
            let out_n = &mut out[ni * self.oc * h_out * w..];
            for g in 0..self.group {
                let oc0 = g * OCG;
                let ic0 = g * icg;
                for oh in 0..h_out {
                    for blk in 0..n_full {
                        let wb = blk * WB;
                        let mut acc = [[0f32; WB]; OCG];
                        for ocl in 0..OCG {
                            acc[ocl] = [bias[oc0 + ocl]; WB];
                        }
                        for kh_i in 0..kh {
                            let ih = oh as isize * sh + kh_i as isize * dh - pb;
                            if ih < 0 || ih >= h_in as isize {
                                continue; // tap lands in padding: contributes zero
                            }
                            let ih = ih as usize;
                            // Offset of (channel ic0, row ih, column wb) in this batch item.
                            let row0 = (ic0 * h_in + ih) * w + wb;
                            for icl in 0..icg {
                                let row_base = row0 + icl * h_in * w;
                                for ocl in 0..OCG {
                                    // SAFETY: oc0 + ocl < group * OCG and
                                    // icl * kh + kh_i < kstride_oc, so the index is
                                    // below group * OCG * kstride_oc <= kernel.len().
                                    let wv = unsafe {
                                        *kernel.get_unchecked(
                                            (oc0 + ocl) * kstride_oc + icl * kh + kh_i,
                                        )
                                    };
                                    let a = &mut acc[ocl];
                                    for j in 0..WB {
                                        // SAFETY: ic0 + icl < group * icg <= c_in,
                                        // ih < h_in and wb + j < w, so the index is
                                        // below c_in * h_in * w <= x_n.len().
                                        a[j] += unsafe { *x_n.get_unchecked(row_base + j) } * wv;
                                    }
                                }
                            }
                        }
                        for ocl in 0..OCG {
                            let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
                            // SAFETY: oc0 + ocl < oc, oh < h_out and wb + WB <= w, so
                            // ob + WB <= oc * h_out * w <= out_n.len().
                            let dst = unsafe { out_n.get_unchecked_mut(ob..ob + WB) };
                            dst.copy_from_slice(&acc[ocl]);
                        }
                    }

                    // Scalar tail for the last `w % WB` columns (bounds-checked).
                    let wb = n_full * WB;
                    if wb < w {
                        let rem = w - wb;
                        for ocl in 0..OCG {
                            let b = bias[oc0 + ocl];
                            let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
                            for j in 0..rem {
                                out_n[ob + j] = b;
                            }
                        }
                        for kh_i in 0..kh {
                            let ih = oh as isize * sh + kh_i as isize * dh - pb;
                            if ih < 0 || ih >= h_in as isize {
                                continue;
                            }
                            let ih = ih as usize;
                            for icl in 0..icg {
                                let row_base = ((ic0 + icl) * h_in + ih) * w + wb;
                                for ocl in 0..OCG {
                                    let wv = kernel[(oc0 + ocl) * kstride_oc + icl * kh + kh_i];
                                    let ob = ((oc0 + ocl) * h_out + oh) * w + wb;
                                    for j in 0..rem {
                                        out_n[ob + j] += x_n[row_base + j] * wv;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /// Fallback kernel for any per-group output channel count.
    ///
    /// Works one output row at a time. For every output channel of the group,
    /// a full-width accumulator row is seeded with the bias, accumulates the
    /// `kh` taps over the group's `icg` input channels, and is then copied
    /// into `out`. Uses the same layouts as `run`, and all indexing is
    /// bounds-checked.
    // Index loops mirror the tensor index formulas used by `run`.
    #[allow(clippy::needless_range_loop)]
    fn run_generic(&self, x: &[f32], kernel: &[f32], bias: &[f32], out: &mut [f32]) {
        let (icg, ocg, w, h_in, h_out, kh) =
            (self.icg(), self.ocg(), self.w, self.h_in, self.h_out, self.kh);
        let (sh, dh, pb) =
            (self.stride_h as isize, self.dil_h as isize, self.pad_before_h as isize);
        let kstride_oc = icg * kh;
        // One scratch buffer reused across all rows: ocg rows of width w.
        let mut acc = vec![0f32; ocg * w];
        for ni in 0..self.n {
            let x_n = &x[ni * self.c_in * h_in * w..];
            let out_n = &mut out[ni * self.oc * h_out * w..];
            for g in 0..self.group {
                let oc0 = g * ocg;
                let ic0 = g * icg;
                for oh in 0..h_out {
                    for ocl in 0..ocg {
                        acc[ocl * w..ocl * w + w].fill(bias[oc0 + ocl]);
                    }
                    for kh_i in 0..kh {
                        let ih = oh as isize * sh + kh_i as isize * dh - pb;
                        if ih < 0 || ih >= h_in as isize {
                            continue; // tap lands in padding: contributes zero
                        }
                        let ih = ih as usize;
                        for icl in 0..icg {
                            let ic = ic0 + icl;
                            let row = &x_n[(ic * h_in + ih) * w..(ic * h_in + ih) * w + w];
                            for ocl in 0..ocg {
                                let wv = kernel[(oc0 + ocl) * kstride_oc + icl * kh + kh_i];
                                let a = &mut acc[ocl * w..ocl * w + w];
                                for j in 0..w {
                                    a[j] += row[j] * wv;
                                }
                            }
                        }
                    }
                    for ocl in 0..ocg {
                        let ob = ((oc0 + ocl) * h_out + oh) * w;
                        out_n[ob..ob + w].copy_from_slice(&acc[ocl * w..ocl * w + w]);
                    }
                }
            }
        }
    }
}
impl TypedOp for BlockedConv {
    /// Declares a single f32 output of shape `[n, oc, h_out, w]`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == 3, "BlockedConv expects 3 inputs (X, kernel, bias)");
        let shape = [self.n, self.oc, self.h_out, self.w];
        let fact = f32::datum_type().fact(shape);
        Ok(tvec!(fact))
    }

    /// One FMA per (output element, input channel in group, kernel tap). Taps
    /// that land in padding are counted too, so this is an upper bound.
    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let outputs = self.n * self.oc * self.h_out * self.w;
        let fmas_per_output = self.icg() * self.kh;
        let total = (outputs * fmas_per_output).to_dim();
        Ok(tvec!((Cost::FMA(f32::datum_type()), total)))
    }

    as_op!();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Naive per-output-element convolution used as ground truth.
    ///
    /// It accumulates in the same order as the op (bias first, then `kh_i`
    /// outer and `icl` inner), so results agree up to float rounding.
    fn reference(op: &BlockedConv, x: &[f32], kernel: &[f32], bias: &[f32]) -> Vec<f32> {
        let (icg, ocg) = (op.icg(), op.ocg());
        let (h_in, w, kh) = (op.h_in, op.w, op.kh);
        let (sh, dh, pb) = (op.stride_h as isize, op.dil_h as isize, op.pad_before_h as isize);
        let mut out = vec![0f32; op.n * op.oc * op.h_out * w];
        for ni in 0..op.n {
            for oc in 0..op.oc {
                let g = oc / ocg;
                for oh in 0..op.h_out {
                    for wi in 0..w {
                        let mut acc = bias[oc];
                        for kh_i in 0..kh {
                            let ih = oh as isize * sh + kh_i as isize * dh - pb;
                            if ih < 0 || ih >= h_in as isize {
                                continue;
                            }
                            let ih = ih as usize;
                            for icl in 0..icg {
                                let ic = g * icg + icl;
                                let xv = x[((ni * op.c_in + ic) * h_in + ih) * w + wi];
                                acc += xv * kernel[oc * (icg * kh) + icl * kh + kh_i];
                            }
                        }
                        out[((ni * op.oc + oc) * op.h_out + oh) * w + wi] = acc;
                    }
                }
            }
        }
        out
    }

    /// Single-batch, unit-stride, unit-dilation case.
    fn run_case(c_in: usize, oc: usize, group: usize, kh: usize, h_in: usize, w: usize, pb: usize) {
        run_case_strided(1, c_in, oc, group, kh, h_in, w, pb, 1, 1);
    }

    /// Builds a `BlockedConv` with the given geometry, evaluates it on
    /// deterministic pseudo-random data and compares the result with `reference`.
    // One parameter per convolution hyper-parameter keeps the call sites
    // readable as a table of cases.
    #[allow(clippy::too_many_arguments)]
    fn run_case_strided(
        n: usize,
        c_in: usize,
        oc: usize,
        group: usize,
        kh: usize,
        h_in: usize,
        w: usize,
        pb: usize,
        stride_h: usize,
        dil_h: usize,
    ) {
        let icg = c_in / group;
        // "Valid" output height with padding applied only before the first row.
        let h_out = (h_in + pb - dil_h * (kh - 1) - 1) / stride_h + 1;
        let op = BlockedConv {
            n,
            c_in,
            h_in,
            w,
            oc,
            group,
            kh,
            stride_h,
            dil_h,
            pad_before_h: pb,
            h_out,
        };
        let x: Vec<f32> =
            (0..n * c_in * h_in * w).map(|i| ((i as f32 * 0.137).sin()) * 0.7).collect();
        let kernel: Vec<f32> =
            (0..oc * icg * kh).map(|i| ((i as f32 * 0.091).cos()) * 0.3).collect();
        let bias: Vec<f32> = (0..oc).map(|i| (i as f32 * 0.05) - 0.1).collect();
        let want = reference(&op, &x, &kernel, &bias);
        let got = op
            .eval(tvec![
                Tensor::from_shape(&[n, c_in, h_in, w], &x).unwrap().into_tvalue(),
                Tensor::from_shape(&[oc, icg * kh], &kernel).unwrap().into_tvalue(),
                Tensor::from_shape(&[oc], &bias).unwrap().into_tvalue(),
            ])
            .unwrap();
        let got_view = got[0].to_plain_array_view::<f32>().unwrap();
        let got = got_view.as_slice().unwrap();
        assert_eq!(got.len(), want.len());
        let max_abs = got.iter().zip(&want).map(|(a, b)| (a - b).abs()).fold(0.0, f32::max);
        assert!(
            max_abs < 1e-5,
            "BlockedConv mismatch (n={n} c_in={c_in} oc={oc} g={group} kh={kh} h={h_in} w={w} pb={pb} sh={stride_h} dh={dil_h}): max_abs={max_abs}"
        );
    }

    #[test]
    fn blocked_conv_matches_reference() {
        run_case(64, 10, 2, 5, 12, 96, 4);
        run_case(4, 4, 2, 3, 5, 20, 1);
        run_case(8, 6, 2, 4, 7, 5, 2);
        run_case(6, 3, 1, 3, 8, 33, 0);
        run_case(4, 2, 2, 2, 6, 17, 1);
    }

    #[test]
    fn blocked_conv_generic_path_matches_reference() {
        // ocg = 7 and ocg = 12 have no specialised kernel and go through `run_generic`.
        run_case(4, 14, 2, 3, 5, 20, 1);
        run_case(3, 12, 1, 2, 4, 9, 0);
    }

    #[test]
    fn blocked_conv_batched_strided_dilated_matches_reference() {
        // Tiled path (ocg = 2), batch 2, stride 2.
        run_case_strided(2, 4, 4, 2, 3, 9, 20, 1, 2, 1);
        // Generic path (ocg = 7), batch 2, dilation 3.
        run_case_strided(2, 6, 21, 3, 2, 7, 18, 2, 1, 3);
        // Tiled path (ocg = 4) with tail columns, batch 3, stride 2, dilation 2.
        run_case_strided(3, 8, 8, 2, 3, 11, 33, 3, 2, 2);
    }
}