use metaltile::{bench_kernel, kernel};
// Fused RMS-norm + 4-bit quantized matrix-vector product; each program
// instance produces one element of `output`.
//
// For the row selected by `program_id::<0>()` the kernel evaluates
//
//   output[row] = sum_d (q[row][d] * scale + bias)
//                       * x[d] * norm_weight[d] * inv_rms
//
// where `q` are 4-bit codes packed eight to a `u32` in `weight`, and
// `scale` / `bias` come from `scales` / `biases`, which hold one entry per
// run of `group_size` consecutive columns in each row.
//
// Parameters:
//   x            input vector; indexed without any row offset, so every row
//                reads the same `in_dim` elements.
//   norm_weight  per-element multiplier applied to `x`, indexed like `x`.
//   weight       packed codes, `in_dim / 8` words per row.
//   scales       dequantization scale, `in_dim / group_size` entries per row.
//   biases       dequantization offset, laid out like `scales`.
//   output       one value per row, written by the thread with `tid == 0`.
//   eps_buf      forwarded untouched to `mt_rms_inv_scalar` (presumably holds
//                the RMS-norm epsilon -- confirm against that helper).
//   in_dim       number of columns (compile-time constant).
//   group_size   columns sharing one scale/bias pair (compile-time constant).
//
// All divisions are integer divisions, so columns past the last full 8-value
// word are never visited; callers presumably pass an `in_dim` that is a
// multiple of both 8 and `group_size`.
#[bench_kernel(
    op="rms_norm_qgemv",
    subop="rms_norm_qgemv",
    class=GenericEmpty,
    tol=1e-3,
    kernel_mode=Reduction,
)]
#[kernel]
pub fn ffai_rms_norm_qgemv<T>(
    x: Tensor<T>,
    norm_weight: Tensor<T>,
    weight: Tensor<u32>,
    scales: Tensor<T>,
    biases: Tensor<T>,
    output: Tensor<T>,
    eps_buf: Tensor<f32>,
    #[constexpr] in_dim: u32,
    #[constexpr] group_size: u32,
) {
    let out_row = program_id::<0>();

    // Phase 1: every thread accumulates the squares of the `x` elements it
    // owns (elements `tid`, `tid + lsize`, `tid + 2 * lsize`, ...).
    let sq_steps = (in_dim + lsize - 1u32) / lsize;
    let mut sq_sum = 0.0f32;
    for sq_step in range(0u32, sq_steps, 1u32) {
        let elem = sq_step * lsize + tid;
        if elem < in_dim {
            let xv = load(x[elem]).cast::<f32>();
            sq_sum = sq_sum + xv * xv;
        }
    }
    // Helper defined outside this file; presumably combines the per-thread
    // partial sums and returns 1 / sqrt(mean(x^2) + eps).
    let rms_scale = mt_rms_inv_scalar(sq_sum, eps_buf, in_dim);

    // Layout of the packed row: eight 4-bit codes per 32-bit word.
    let nibble_mask = 15u32;
    let nibbles_per_word = 8u32;
    let words_per_row = in_dim / nibbles_per_word;
    let words_per_group = group_size / nibbles_per_word;
    let groups_per_row = in_dim / group_size;
    let word_base = out_row * words_per_row;
    let group_base = out_row * groups_per_row;

    // Phase 2: each thread walks its share of the row's packed words and
    // accumulates dequantized-weight * normalized-input products.
    let word_steps = (words_per_row + lsize - 1u32) / lsize;
    let mut partial = 0.0f32;
    for word_step in range(0u32, word_steps, 1u32) {
        let word_idx = word_step * lsize + tid;
        if word_idx < words_per_row {
            let word = load(weight[word_base + word_idx]);
            // All eight codes of a word use the same scale/bias pair.
            let group_idx = word_idx / words_per_group;
            let w_scale = load(scales[group_base + group_idx]).cast::<f32>();
            let w_bias = load(biases[group_base + group_idx]).cast::<f32>();
            let x_base = word_idx * nibbles_per_word;
            for k in range(0u32, nibbles_per_word, 1u32) {
                let col = x_base + k;
                // Code `k` lives in bits [4k, 4k + 3] of the word.
                let code = (word >> (k * 4u32)) & nibble_mask;
                let w = code.cast::<f32>() * w_scale + w_bias;
                let xn =
                    load(x[col]).cast::<f32>() * load(norm_weight[col]).cast::<f32>() * rms_scale;
                partial = partial + w * xn;
            }
        }
    }

    // Combine the per-thread partial sums; a single thread publishes the row.
    let row_sum = reduce_sum(partial);
    if tid == 0u32 {
        store(output[out_row], row_sum.cast::<T>());
    }
}
// Fused RMS-norm + 4-bit quantized matrix-vector product, unrolled so that
// every (tgid_x, simd_id) pair produces FOUR consecutive output rows.
//
// For each of its rows r the kernel evaluates
//
//   output[r] = sum_d (q[r][d] * scale[r][g] + bias[r][g])
//                     * x[d] * norm_weight[d] * inv_rms
//
// rewritten per 16-element chunk as  scale * sum(q * n) + bias * sum(n)
// with n = x * norm_weight * inv_rms, so the bias is applied once per chunk
// instead of once per element.
//
// Parameters:
//   x            input vector; indexed without a row offset (shared by rows).
//   norm_weight  per-element multiplier applied to `x`, indexed like `x`.
//   weight       4-bit codes packed eight to a u32, `in_dim / 8` words per row.
//   scales       dequantization scale, `in_dim / group_size` entries per row.
//   biases       dequantization offset, laid out like `scales`.
//   output       one value per row; written by lane 0 only.
//   eps_buf      forwarded untouched to `mt_rms_inv_scalar` (presumably holds
//                the RMS-norm epsilon -- confirm against that helper).
//   in_dim       number of columns (compile-time constant).
//   group_size   columns sharing one scale/bias pair (compile-time constant).
//
// NOTE(review): nothing in this kernel guards against a ragged shape. The
// main loop reads 16 elements per lane in steps of 512 with no `< in_dim`
// check, and rows row0..row3 are stored unconditionally. Callers presumably
// guarantee that in_dim is a multiple of 512, that group_size is a multiple
// of 16, and that the row count matches the launch grid -- confirm.
#[bench_kernel(
op="rms_norm_qgemv",
subop="rms_norm_qgemv_fast",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn ffai_rms_norm_qgemv_fast<T>(
x: Tensor<T>,
norm_weight: Tensor<T>,
weight: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
output: Tensor<T>,
eps_buf: Tensor<f32>,
#[constexpr] in_dim: u32,
#[constexpr] group_size: u32,
) {
// Row assignment: 8 rows per tgid_x, 4 rows per simd_id within it
// (so presumably two simdgroups per threadgroup -- confirm launch config).
let tg = tgid_x;
let sg = simd_id;
let lane = simd_lane;
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Per-thread partial sum of squares over x, striding by `lsize` from `tid`.
let mut ssq = 0.0f32;
let n_iters = (in_dim + lsize - 1u32) / lsize;
for _iter in range(0u32, n_iters, 1u32) {
let d = _iter * lsize + tid;
if d < in_dim {
let v = load(x[d]).cast::<f32>();
ssq = ssq + v * v;
}
}
// Helper defined outside this file; presumably combines the per-thread
// partial sums and returns 1 / sqrt(mean(x^2) + eps).
let inv_rms = mt_rms_inv_scalar(ssq, eps_buf, in_dim);
// Per-row base offsets into `weight` (in u32 words, 8 codes per word) and
// into `scales` / `biases` (in groups).
let gs_per_row = in_dim / group_size;
let packs_per_row = in_dim / 8u32; let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// One accumulator per output row.
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
// Within each 512-column block a lane owns 16 consecutive columns, i.e.
// two packed words (512 = 32 * 16, presumably 32 lanes per simdgroup).
let lane_x_off = lane * 16u32;
let lane_pack_off = lane * 2u32;
// 1/16, 1/256 and 1/4096: compensate for nibbles that are masked but NOT
// shifted down below (a code left at bit 4 reads as code * 16, etc.).
let s_16 = 0.0625f32;
let s_256 = 0.00390625f32;
let s_4096 = 0.000244140625f32;
// Main loop over 512-column blocks; no bounds guard (see NOTE above).
for _b in range(0u32, in_dim, 512u32) {
// Indices of this lane's 16 columns in the current block.
let xb = _b + lane_x_off;
let xi0 = xb;
let xi1 = xb + 1u32;
let xi2 = xb + 2u32;
let xi3 = xb + 3u32;
let xi4 = xb + 4u32;
let xi5 = xb + 5u32;
let xi6 = xb + 6u32;
let xi7 = xb + 7u32;
let xi8 = xb + 8u32;
let xi9 = xb + 9u32;
let xi10 = xb + 10u32;
let xi11 = xb + 11u32;
let xi12 = xb + 12u32;
let xi13 = xb + 13u32;
let xi14 = xb + 14u32;
let xi15 = xb + 15u32;
// Normalized activations n = x * norm_weight * inv_rms, computed once and
// reused for all four rows.
let n0_raw = load(x[xi0]).cast::<f32>() * load(norm_weight[xi0]).cast::<f32>() * inv_rms;
let n1_raw = load(x[xi1]).cast::<f32>() * load(norm_weight[xi1]).cast::<f32>() * inv_rms;
let n2_raw = load(x[xi2]).cast::<f32>() * load(norm_weight[xi2]).cast::<f32>() * inv_rms;
let n3_raw = load(x[xi3]).cast::<f32>() * load(norm_weight[xi3]).cast::<f32>() * inv_rms;
let n4_raw = load(x[xi4]).cast::<f32>() * load(norm_weight[xi4]).cast::<f32>() * inv_rms;
let n5_raw = load(x[xi5]).cast::<f32>() * load(norm_weight[xi5]).cast::<f32>() * inv_rms;
let n6_raw = load(x[xi6]).cast::<f32>() * load(norm_weight[xi6]).cast::<f32>() * inv_rms;
let n7_raw = load(x[xi7]).cast::<f32>() * load(norm_weight[xi7]).cast::<f32>() * inv_rms;
let n8_raw = load(x[xi8]).cast::<f32>() * load(norm_weight[xi8]).cast::<f32>() * inv_rms;
let n9_raw = load(x[xi9]).cast::<f32>() * load(norm_weight[xi9]).cast::<f32>() * inv_rms;
let n10_raw = load(x[xi10]).cast::<f32>() * load(norm_weight[xi10]).cast::<f32>() * inv_rms;
let n11_raw = load(x[xi11]).cast::<f32>() * load(norm_weight[xi11]).cast::<f32>() * inv_rms;
let n12_raw = load(x[xi12]).cast::<f32>() * load(norm_weight[xi12]).cast::<f32>() * inv_rms;
let n13_raw = load(x[xi13]).cast::<f32>() * load(norm_weight[xi13]).cast::<f32>() * inv_rms;
let n14_raw = load(x[xi14]).cast::<f32>() * load(norm_weight[xi14]).cast::<f32>() * inv_rms;
let n15_raw = load(x[xi15]).cast::<f32>() * load(norm_weight[xi15]).cast::<f32>() * inv_rms;
// Sum of the unscaled activations; multiplies the bias term for every row.
let ns = n0_raw
+ n1_raw
+ n2_raw
+ n3_raw
+ n4_raw
+ n5_raw
+ n6_raw
+ n7_raw
+ n8_raw
+ n9_raw
+ n10_raw
+ n11_raw
+ n12_raw
+ n13_raw
+ n14_raw
+ n15_raw;
// Pre-scaled activations for nibble positions 1..3 of each 16-bit half.
// Position 0 of each half (n0, n4, n8, n12) needs no scaling and is used
// directly as `*_raw`.
let n1 = n1_raw * s_16;
let n2 = n2_raw * s_256;
let n3 = n3_raw * s_4096;
let n5 = n5_raw * s_16;
let n6 = n6_raw * s_256;
let n7 = n7_raw * s_4096;
let n9 = n9_raw * s_16;
let n10 = n10_raw * s_256;
let n11 = n11_raw * s_4096;
let n13 = n13_raw * s_16;
let n14 = n14_raw * s_256;
let n15 = n15_raw * s_4096;
// A single group index is used for all 16 columns of the lane, which
// assumes they never straddle a group boundary -- TODO confirm.
let g = xb / group_size;
// Word offset of this lane's first packed word within a row.
let pack_off = _b / 8u32 + lane_pack_off;
// ---- Row 0 ----
// Two words = 16 codes; `_hi` exposes the upper 16 bits of each word.
let p00 = load(weight[w_base0 + pack_off]);
let p01 = load(weight[w_base0 + pack_off + 1u32]);
let p00_hi = p00 >> 16u32;
let p01_hi = p01 >> 16u32;
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
// Masks 0xF, 0xF0, 0xF00, 0xF000 isolate a nibble in place, yielding
// code * {1, 16, 256, 4096}; the matching n* above is pre-divided.
let q00 = (p00 & 15u32).cast::<f32>();
let q01 = (p00 & 240u32).cast::<f32>();
let q02 = (p00 & 3840u32).cast::<f32>();
let q03 = (p00 & 61440u32).cast::<f32>();
let q04 = (p00_hi & 15u32).cast::<f32>();
let q05 = (p00_hi & 240u32).cast::<f32>();
let q06 = (p00_hi & 3840u32).cast::<f32>();
let q07 = (p00_hi & 61440u32).cast::<f32>();
let q08 = (p01 & 15u32).cast::<f32>();
let q09 = (p01 & 240u32).cast::<f32>();
let q010 = (p01 & 3840u32).cast::<f32>();
let q011 = (p01 & 61440u32).cast::<f32>();
let q012 = (p01_hi & 15u32).cast::<f32>();
let q013 = (p01_hi & 240u32).cast::<f32>();
let q014 = (p01_hi & 3840u32).cast::<f32>();
let q015 = (p01_hi & 61440u32).cast::<f32>();
// sum(code * n) over the 16 columns.
let qd0 = q00 * n0_raw
+ q01 * n1
+ q02 * n2
+ q03 * n3
+ q04 * n4_raw
+ q05 * n5
+ q06 * n6
+ q07 * n7
+ q08 * n8_raw
+ q09 * n9
+ q010 * n10
+ q011 * n11
+ q012 * n12_raw
+ q013 * n13
+ q014 * n14
+ q015 * n15;
// scale * sum(code * n) + bias * sum(n) == sum((code * scale + bias) * n).
acc0 = acc0 + s0 * qd0 + bi0 * ns;
// ---- Row 1: same steps as row 0, using w_base1 / sb_base1 ----
let p10 = load(weight[w_base1 + pack_off]);
let p11 = load(weight[w_base1 + pack_off + 1u32]);
let p10_hi = p10 >> 16u32;
let p11_hi = p11 >> 16u32;
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let q10 = (p10 & 15u32).cast::<f32>();
let q11 = (p10 & 240u32).cast::<f32>();
let q12 = (p10 & 3840u32).cast::<f32>();
let q13 = (p10 & 61440u32).cast::<f32>();
let q14 = (p10_hi & 15u32).cast::<f32>();
let q15 = (p10_hi & 240u32).cast::<f32>();
let q16 = (p10_hi & 3840u32).cast::<f32>();
let q17 = (p10_hi & 61440u32).cast::<f32>();
let q18 = (p11 & 15u32).cast::<f32>();
let q19 = (p11 & 240u32).cast::<f32>();
let q110 = (p11 & 3840u32).cast::<f32>();
let q111 = (p11 & 61440u32).cast::<f32>();
let q112 = (p11_hi & 15u32).cast::<f32>();
let q113 = (p11_hi & 240u32).cast::<f32>();
let q114 = (p11_hi & 3840u32).cast::<f32>();
let q115 = (p11_hi & 61440u32).cast::<f32>();
let qd1 = q10 * n0_raw
+ q11 * n1
+ q12 * n2
+ q13 * n3
+ q14 * n4_raw
+ q15 * n5
+ q16 * n6
+ q17 * n7
+ q18 * n8_raw
+ q19 * n9
+ q110 * n10
+ q111 * n11
+ q112 * n12_raw
+ q113 * n13
+ q114 * n14
+ q115 * n15;
acc1 = acc1 + s1 * qd1 + bi1 * ns;
// ---- Row 2: same steps as row 0, using w_base2 / sb_base2 ----
let p20 = load(weight[w_base2 + pack_off]);
let p21 = load(weight[w_base2 + pack_off + 1u32]);
let p20_hi = p20 >> 16u32;
let p21_hi = p21 >> 16u32;
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let q20 = (p20 & 15u32).cast::<f32>();
let q21 = (p20 & 240u32).cast::<f32>();
let q22 = (p20 & 3840u32).cast::<f32>();
let q23 = (p20 & 61440u32).cast::<f32>();
let q24 = (p20_hi & 15u32).cast::<f32>();
let q25 = (p20_hi & 240u32).cast::<f32>();
let q26 = (p20_hi & 3840u32).cast::<f32>();
let q27 = (p20_hi & 61440u32).cast::<f32>();
let q28 = (p21 & 15u32).cast::<f32>();
let q29 = (p21 & 240u32).cast::<f32>();
let q210 = (p21 & 3840u32).cast::<f32>();
let q211 = (p21 & 61440u32).cast::<f32>();
let q212 = (p21_hi & 15u32).cast::<f32>();
let q213 = (p21_hi & 240u32).cast::<f32>();
let q214 = (p21_hi & 3840u32).cast::<f32>();
let q215 = (p21_hi & 61440u32).cast::<f32>();
let qd2 = q20 * n0_raw
+ q21 * n1
+ q22 * n2
+ q23 * n3
+ q24 * n4_raw
+ q25 * n5
+ q26 * n6
+ q27 * n7
+ q28 * n8_raw
+ q29 * n9
+ q210 * n10
+ q211 * n11
+ q212 * n12_raw
+ q213 * n13
+ q214 * n14
+ q215 * n15;
acc2 = acc2 + s2 * qd2 + bi2 * ns;
// ---- Row 3: same steps as row 0, using w_base3 / sb_base3 ----
let p30 = load(weight[w_base3 + pack_off]);
let p31 = load(weight[w_base3 + pack_off + 1u32]);
let p30_hi = p30 >> 16u32;
let p31_hi = p31 >> 16u32;
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let q30 = (p30 & 15u32).cast::<f32>();
let q31 = (p30 & 240u32).cast::<f32>();
let q32 = (p30 & 3840u32).cast::<f32>();
let q33 = (p30 & 61440u32).cast::<f32>();
let q34 = (p30_hi & 15u32).cast::<f32>();
let q35 = (p30_hi & 240u32).cast::<f32>();
let q36 = (p30_hi & 3840u32).cast::<f32>();
let q37 = (p30_hi & 61440u32).cast::<f32>();
let q38 = (p31 & 15u32).cast::<f32>();
let q39 = (p31 & 240u32).cast::<f32>();
let q310 = (p31 & 3840u32).cast::<f32>();
let q311 = (p31 & 61440u32).cast::<f32>();
let q312 = (p31_hi & 15u32).cast::<f32>();
let q313 = (p31_hi & 240u32).cast::<f32>();
let q314 = (p31_hi & 3840u32).cast::<f32>();
let q315 = (p31_hi & 61440u32).cast::<f32>();
let qd3 = q30 * n0_raw
+ q31 * n1
+ q32 * n2
+ q33 * n3
+ q34 * n4_raw
+ q35 * n5
+ q36 * n6
+ q37 * n7
+ q38 * n8_raw
+ q39 * n9
+ q310 * n10
+ q311 * n11
+ q312 * n12_raw
+ q313 * n13
+ q314 * n14
+ q315 * n15;
acc3 = acc3 + s3 * qd3 + bi3 * ns;
}
// Combine the per-lane partial sums (presumably across the simdgroup).
let r0 = simd_sum(acc0);
let r1 = simd_sum(acc1);
let r2 = simd_sum(acc2);
let r3 = simd_sum(acc3);
// Only lane 0 publishes the four results, cast back to the tensor type T.
if lane == 0u32 {
store(output[row0], r0.cast::<T>());
store(output[row1], r1.cast::<T>());
store(output[row2], r2.cast::<T>());
store(output[row3], r3.cast::<T>());
}
}
// Fused RMS-norm + 8-bit quantized matrix-vector product, unrolled so that
// every (tgid_x, simd_id) pair produces FOUR consecutive output rows.
//
// For each of its rows r the kernel evaluates
//
//   output[r] = sum_d (q[r][d] * scale[r][g] + bias[r][g])
//                     * x[d] * norm_weight[d] * inv_rms
//
// rewritten per 16-element chunk as  scale * sum(q * n) + bias * sum(n)
// with n = x * norm_weight * inv_rms.
//
// Parameters:
//   x            input vector; indexed without a row offset (shared by rows).
//   norm_weight  per-element multiplier applied to `x`, indexed like `x`.
//   weight       8-bit codes packed four to a u32, `in_dim / 4` words per row.
//   scales       dequantization scale, `in_dim / group_size` entries per row.
//   biases       dequantization offset, laid out like `scales`.
//   output       one value per row; written by lane 0 only.
//   eps_buf      forwarded untouched to `mt_rms_inv_scalar` (presumably holds
//                the RMS-norm epsilon -- confirm against that helper).
//   in_dim       number of columns (compile-time constant).
//   group_size   columns sharing one scale/bias pair (compile-time constant).
//
// NOTE(review): nothing in this kernel guards against a ragged shape. The
// main loop reads 16 elements per lane in steps of 512 with no `< in_dim`
// check, and rows row0..row3 are stored unconditionally. Callers presumably
// guarantee that in_dim is a multiple of 512, that group_size is a multiple
// of 16, and that the row count matches the launch grid -- confirm.
#[bench_kernel(
op="rms_norm_qgemv",
subop="rms_norm_qgemv_int8_fast",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn ffai_rms_norm_qgemv_int8_fast<T>(
x: Tensor<T>,
norm_weight: Tensor<T>,
weight: Tensor<u32>,
scales: Tensor<T>,
biases: Tensor<T>,
output: Tensor<T>,
eps_buf: Tensor<f32>,
#[constexpr] in_dim: u32,
#[constexpr] group_size: u32,
) {
// Row assignment: 8 rows per tgid_x, 4 rows per simd_id within it
// (so presumably two simdgroups per threadgroup -- confirm launch config).
let tg = tgid_x;
let sg = simd_id;
let lane = simd_lane;
let row0 = tg * 8u32 + sg * 4u32;
let row1 = row0 + 1u32;
let row2 = row0 + 2u32;
let row3 = row0 + 3u32;
// Per-thread partial sum of squares over x, striding by `lsize` from `tid`.
let mut ssq = 0.0f32;
let n_iters = (in_dim + lsize - 1u32) / lsize;
for _iter in range(0u32, n_iters, 1u32) {
let d = _iter * lsize + tid;
if d < in_dim {
let v = load(x[d]).cast::<f32>();
ssq = ssq + v * v;
}
}
// Helper defined outside this file; presumably combines the per-thread
// partial sums and returns 1 / sqrt(mean(x^2) + eps).
let inv_rms = mt_rms_inv_scalar(ssq, eps_buf, in_dim);
// Per-row base offsets into `weight` (in u32 words, 4 codes per word) and
// into `scales` / `biases` (in groups).
let gs_per_row = in_dim / group_size;
let packs_per_row = in_dim / 4u32;
let w_base0 = row0 * packs_per_row;
let w_base1 = row1 * packs_per_row;
let w_base2 = row2 * packs_per_row;
let w_base3 = row3 * packs_per_row;
let sb_base0 = row0 * gs_per_row;
let sb_base1 = row1 * gs_per_row;
let sb_base2 = row2 * gs_per_row;
let sb_base3 = row3 * gs_per_row;
// One accumulator per output row.
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
// Within each 512-column block a lane owns 16 consecutive columns, i.e.
// four packed words (512 = 32 * 16, presumably 32 lanes per simdgroup).
let lane_x_off = lane * 16u32;
let lane_pack_off = lane * 4u32;
// Main loop over 512-column blocks; no bounds guard (see NOTE above).
for _b in range(0u32, in_dim, 512u32) {
let xb = _b + lane_x_off;
// Normalized activations n = x * norm_weight * inv_rms for this lane's 16
// columns, computed once and reused for all four rows.
let n0 = load(x[xb]).cast::<f32>() * load(norm_weight[xb]).cast::<f32>() * inv_rms;
let n1 =
load(x[xb + 1u32]).cast::<f32>() * load(norm_weight[xb + 1u32]).cast::<f32>() * inv_rms;
let n2 =
load(x[xb + 2u32]).cast::<f32>() * load(norm_weight[xb + 2u32]).cast::<f32>() * inv_rms;
let n3 =
load(x[xb + 3u32]).cast::<f32>() * load(norm_weight[xb + 3u32]).cast::<f32>() * inv_rms;
let n4 =
load(x[xb + 4u32]).cast::<f32>() * load(norm_weight[xb + 4u32]).cast::<f32>() * inv_rms;
let n5 =
load(x[xb + 5u32]).cast::<f32>() * load(norm_weight[xb + 5u32]).cast::<f32>() * inv_rms;
let n6 =
load(x[xb + 6u32]).cast::<f32>() * load(norm_weight[xb + 6u32]).cast::<f32>() * inv_rms;
let n7 =
load(x[xb + 7u32]).cast::<f32>() * load(norm_weight[xb + 7u32]).cast::<f32>() * inv_rms;
let n8 =
load(x[xb + 8u32]).cast::<f32>() * load(norm_weight[xb + 8u32]).cast::<f32>() * inv_rms;
let n9 =
load(x[xb + 9u32]).cast::<f32>() * load(norm_weight[xb + 9u32]).cast::<f32>() * inv_rms;
let n10 = load(x[xb + 10u32]).cast::<f32>()
* load(norm_weight[xb + 10u32]).cast::<f32>()
* inv_rms;
let n11 = load(x[xb + 11u32]).cast::<f32>()
* load(norm_weight[xb + 11u32]).cast::<f32>()
* inv_rms;
let n12 = load(x[xb + 12u32]).cast::<f32>()
* load(norm_weight[xb + 12u32]).cast::<f32>()
* inv_rms;
let n13 = load(x[xb + 13u32]).cast::<f32>()
* load(norm_weight[xb + 13u32]).cast::<f32>()
* inv_rms;
let n14 = load(x[xb + 14u32]).cast::<f32>()
* load(norm_weight[xb + 14u32]).cast::<f32>()
* inv_rms;
let n15 = load(x[xb + 15u32]).cast::<f32>()
* load(norm_weight[xb + 15u32]).cast::<f32>()
* inv_rms;
// Sum of the activations; multiplies the bias term for every row.
let ns =
n0 + n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 + n13 + n14 + n15;
// A single group index is used for all 16 columns of the lane, which
// assumes they never straddle a group boundary -- TODO confirm.
let g = xb / group_size;
// Word offset of this lane's first packed word within a row.
let pack_off = _b / 4u32 + lane_pack_off;
// ---- Row 0 ----
// Four words = 16 codes; byte k of a word is extracted with (w >> 8k) & 255.
let p00 = load(weight[w_base0 + pack_off]);
let p01 = load(weight[w_base0 + pack_off + 1u32]);
let p02 = load(weight[w_base0 + pack_off + 2u32]);
let p03 = load(weight[w_base0 + pack_off + 3u32]);
let s0 = load(scales[sb_base0 + g]).cast::<f32>();
let bi0 = load(biases[sb_base0 + g]).cast::<f32>();
// sum(code * n) over the 16 columns.
let qd0 = (p00 & 255u32).cast::<f32>() * n0
+ ((p00 >> 8u32) & 255u32).cast::<f32>() * n1
+ ((p00 >> 16u32) & 255u32).cast::<f32>() * n2
+ ((p00 >> 24u32) & 255u32).cast::<f32>() * n3
+ (p01 & 255u32).cast::<f32>() * n4
+ ((p01 >> 8u32) & 255u32).cast::<f32>() * n5
+ ((p01 >> 16u32) & 255u32).cast::<f32>() * n6
+ ((p01 >> 24u32) & 255u32).cast::<f32>() * n7
+ (p02 & 255u32).cast::<f32>() * n8
+ ((p02 >> 8u32) & 255u32).cast::<f32>() * n9
+ ((p02 >> 16u32) & 255u32).cast::<f32>() * n10
+ ((p02 >> 24u32) & 255u32).cast::<f32>() * n11
+ (p03 & 255u32).cast::<f32>() * n12
+ ((p03 >> 8u32) & 255u32).cast::<f32>() * n13
+ ((p03 >> 16u32) & 255u32).cast::<f32>() * n14
+ ((p03 >> 24u32) & 255u32).cast::<f32>() * n15;
// scale * sum(code * n) + bias * sum(n) == sum((code * scale + bias) * n).
acc0 = acc0 + s0 * qd0 + bi0 * ns;
// ---- Row 1: same steps as row 0, using w_base1 / sb_base1 ----
let p10 = load(weight[w_base1 + pack_off]);
let p11 = load(weight[w_base1 + pack_off + 1u32]);
let p12 = load(weight[w_base1 + pack_off + 2u32]);
let p13 = load(weight[w_base1 + pack_off + 3u32]);
let s1 = load(scales[sb_base1 + g]).cast::<f32>();
let bi1 = load(biases[sb_base1 + g]).cast::<f32>();
let qd1 = (p10 & 255u32).cast::<f32>() * n0
+ ((p10 >> 8u32) & 255u32).cast::<f32>() * n1
+ ((p10 >> 16u32) & 255u32).cast::<f32>() * n2
+ ((p10 >> 24u32) & 255u32).cast::<f32>() * n3
+ (p11 & 255u32).cast::<f32>() * n4
+ ((p11 >> 8u32) & 255u32).cast::<f32>() * n5
+ ((p11 >> 16u32) & 255u32).cast::<f32>() * n6
+ ((p11 >> 24u32) & 255u32).cast::<f32>() * n7
+ (p12 & 255u32).cast::<f32>() * n8
+ ((p12 >> 8u32) & 255u32).cast::<f32>() * n9
+ ((p12 >> 16u32) & 255u32).cast::<f32>() * n10
+ ((p12 >> 24u32) & 255u32).cast::<f32>() * n11
+ (p13 & 255u32).cast::<f32>() * n12
+ ((p13 >> 8u32) & 255u32).cast::<f32>() * n13
+ ((p13 >> 16u32) & 255u32).cast::<f32>() * n14
+ ((p13 >> 24u32) & 255u32).cast::<f32>() * n15;
acc1 = acc1 + s1 * qd1 + bi1 * ns;
// ---- Row 2: same steps as row 0, using w_base2 / sb_base2 ----
let p20 = load(weight[w_base2 + pack_off]);
let p21 = load(weight[w_base2 + pack_off + 1u32]);
let p22 = load(weight[w_base2 + pack_off + 2u32]);
let p23 = load(weight[w_base2 + pack_off + 3u32]);
let s2 = load(scales[sb_base2 + g]).cast::<f32>();
let bi2 = load(biases[sb_base2 + g]).cast::<f32>();
let qd2 = (p20 & 255u32).cast::<f32>() * n0
+ ((p20 >> 8u32) & 255u32).cast::<f32>() * n1
+ ((p20 >> 16u32) & 255u32).cast::<f32>() * n2
+ ((p20 >> 24u32) & 255u32).cast::<f32>() * n3
+ (p21 & 255u32).cast::<f32>() * n4
+ ((p21 >> 8u32) & 255u32).cast::<f32>() * n5
+ ((p21 >> 16u32) & 255u32).cast::<f32>() * n6
+ ((p21 >> 24u32) & 255u32).cast::<f32>() * n7
+ (p22 & 255u32).cast::<f32>() * n8
+ ((p22 >> 8u32) & 255u32).cast::<f32>() * n9
+ ((p22 >> 16u32) & 255u32).cast::<f32>() * n10
+ ((p22 >> 24u32) & 255u32).cast::<f32>() * n11
+ (p23 & 255u32).cast::<f32>() * n12
+ ((p23 >> 8u32) & 255u32).cast::<f32>() * n13
+ ((p23 >> 16u32) & 255u32).cast::<f32>() * n14
+ ((p23 >> 24u32) & 255u32).cast::<f32>() * n15;
acc2 = acc2 + s2 * qd2 + bi2 * ns;
// ---- Row 3: same steps as row 0, using w_base3 / sb_base3 ----
let p30 = load(weight[w_base3 + pack_off]);
let p31 = load(weight[w_base3 + pack_off + 1u32]);
let p32 = load(weight[w_base3 + pack_off + 2u32]);
let p33 = load(weight[w_base3 + pack_off + 3u32]);
let s3 = load(scales[sb_base3 + g]).cast::<f32>();
let bi3 = load(biases[sb_base3 + g]).cast::<f32>();
let qd3 = (p30 & 255u32).cast::<f32>() * n0
+ ((p30 >> 8u32) & 255u32).cast::<f32>() * n1
+ ((p30 >> 16u32) & 255u32).cast::<f32>() * n2
+ ((p30 >> 24u32) & 255u32).cast::<f32>() * n3
+ (p31 & 255u32).cast::<f32>() * n4
+ ((p31 >> 8u32) & 255u32).cast::<f32>() * n5
+ ((p31 >> 16u32) & 255u32).cast::<f32>() * n6
+ ((p31 >> 24u32) & 255u32).cast::<f32>() * n7
+ (p32 & 255u32).cast::<f32>() * n8
+ ((p32 >> 8u32) & 255u32).cast::<f32>() * n9
+ ((p32 >> 16u32) & 255u32).cast::<f32>() * n10
+ ((p32 >> 24u32) & 255u32).cast::<f32>() * n11
+ (p33 & 255u32).cast::<f32>() * n12
+ ((p33 >> 8u32) & 255u32).cast::<f32>() * n13
+ ((p33 >> 16u32) & 255u32).cast::<f32>() * n14
+ ((p33 >> 24u32) & 255u32).cast::<f32>() * n15;
acc3 = acc3 + s3 * qd3 + bi3 * ns;
}
// Combine the per-lane partial sums (presumably across the simdgroup).
let r0 = simd_sum(acc0);
let r1 = simd_sum(acc1);
let r2 = simd_sum(acc2);
let r3 = simd_sum(acc3);
// Only lane 0 publishes the four results, cast back to the tensor type T.
if lane == 0u32 {
store(output[row0], r0.cast::<T>());
store(output[row1], r1.cast::<T>());
store(output[row2], r2.cast::<T>());
store(output[row3], r3.cast::<T>());
}
}