use crate::util::libm;
// Length of the embedded weight blob, in bytes.
const _WEIGHTS_LEN: usize = include_bytes!("weights.bin").len();
// Raw network weights, embedded at compile time. The bytes are later
// reinterpreted as f32 values (see `weight` below), so the blob must be
// 4-byte aligned; `include_bytes!` alone only guarantees byte alignment,
// hence the aligned wrapper type.
static WEIGHTS: &[u8; _WEIGHTS_LEN] = {
    // Forces 4-byte alignment onto the wrapped byte array.
    #[repr(C, align(4))]
    struct AlignedData<T: ?Sized>(T);
    const __DATA: &'static AlignedData<[u8; _WEIGHTS_LEN]> = &AlignedData(*include_bytes!("weights.bin"));
    &__DATA.0
};
/// Reinterprets a region of `WEIGHTS` as a fixed-size `f32` array.
///
/// `offset` is counted in f32 elements (not bytes) from the start of the
/// blob — `pointer::add` advances in units of the pointee type.
// SAFETY: `WEIGHTS` is 4-byte aligned via `AlignedData` above, so the f32
// pointer is properly aligned. Callers must keep `offset + SIZE` within the
// blob's length in f32 elements.
// NOTE(review): assumes weights.bin stores f32s in native endianness —
// confirm against the tool that exports the file.
const fn weight<'a, const SIZE: usize>(offset: usize) -> &'a [f32; SIZE] {
    unsafe { &*(WEIGHTS.as_ptr().cast::<f32>().add(offset) as *const [_; SIZE]) }
}
// Layout of the weight blob. Offsets are in f32 elements and the regions are
// contiguous: each offset equals the previous offset plus the previous size
// (19281 floats in total). Keep in sync with the exporter of weights.bin.
static NORM_WEIGHT: &[f32; 40] = weight(0); // per-feature scale used by normalize()
static LAYER1_KERNEL: &[f32; 9] = weight(40); // 3x3 conv kernel
static LAYER1_WEIGHT: &[f32; 16] = weight(49); // per-channel scale
static LAYER1_BIAS: &[f32; 16] = weight(65); // per-channel bias
static LAYER2_KERNEL: &[f32; 48] = weight(81); // 16 channels x 3-tap depthwise kernel
static LAYER2_WEIGHT: &[f32; 256] = weight(129); // 16x16 pointwise matrix
static LAYER2_BIAS: &[f32; 16] = weight(385);
static LAYER3_KERNEL: &[f32; 48] = weight(401);
static LAYER3_WEIGHT: &[f32; 256] = weight(449);
static LAYER3_BIAS: &[f32; 16] = weight(705);
static RNN1_WEIGHT: &[f32; 10240] = weight(721); // 128 x 80 projection (first minGRU)
static RNN2_WEIGHT: &[f32; 8192] = weight(10961); // 128 x 64 projection (second minGRU)
static OUTPUT_WEIGHT: &[f32; 128] = weight(19153); // final logistic readout
/// Predictor that runs the embedded network over normalized feature frames.
pub struct DefaultPredictor {
    // Recurrent hidden state: [..64] belongs to the first minGRU layer,
    // [64..128] to the second. Zeroed by reset().
    state: [f32; 128]
}
impl DefaultPredictor {
pub const fn new() -> Self {
Self { state: [0.0; 128] }
}
}
impl crate::Predictor for DefaultPredictor {
    /// Clears both recurrent layers so the next prediction starts fresh.
    fn reset(&mut self) {
        self.state.fill(0.0);
    }
    /// Scales `features` to unit RMS, then applies the learned per-feature
    /// scale from `NORM_WEIGHT`.
    // NOTE(review): an all-zero frame gives RMS 0, making `i_rms` infinite
    // and propagating NaN/inf downstream — confirm callers never pass pure
    // silence. Also indexes NORM_WEIGHT by position, so it assumes
    // `features.len() <= 40` — verify against the feature extractor.
    fn normalize(&self, features: &mut [f32]) {
        let i_rms = 1. / libm::sqrtf(features.iter().map(|x| x * x).sum::<f32>() / features.len() as f32);
        for (i, v) in features.iter_mut().enumerate() {
            *v = NORM_WEIGHT[i] * *v * i_rms;
        }
    }
    /// Runs the full network: conv stack -> two minGRU layers -> sigmoid
    /// readout. `buffer` is caller-provided scratch; its two halves are
    /// ping-ponged between layers so nothing is allocated here.
    fn predict(&mut self, features: &[f32], buffer: &mut [f32]) -> f32 {
        // buffer1: 288 floats (conv1 output); buffer2: the remainder.
        let (buffer1, buffer2) = buffer.split_at_mut(288);
        input_layer1(&features, buffer1);
        // 16ch x 18 cols -> 16ch x 9 cols (channel-major output layout).
        input_layer2_3::<18, 9, false>(&buffer1[..288], LAYER2_KERNEL, LAYER2_WEIGHT, LAYER2_BIAS, &mut buffer2[..144]);
        // 16ch x 9 cols -> 5 cols x 16ch (column-major, flattened to 80 for the RNN).
        input_layer2_3::<9, 5, true>(&buffer2[..144], LAYER3_KERNEL, LAYER3_WEIGHT, LAYER3_BIAS, &mut buffer1[..80]);
        // First minGRU: 80 inputs -> 128 outputs (64 candidates + 64 gates);
        // the blended first 64 become the layer's new hidden state.
        mingru::<80>(&buffer1[..80], &self.state[..64], RNN1_WEIGHT, &mut buffer2[..128]);
        self.state[..64].copy_from_slice(&buffer2[..64]);
        // Second minGRU, fed by the first layer's freshly updated state.
        mingru::<64>(&buffer2[..64], &self.state[64..128], RNN2_WEIGHT, &mut buffer1[..128]);
        self.state[64..128].copy_from_slice(&buffer1[..64]);
        // Logistic readout over both layers' new hidden states.
        output(&buffer2[..64], &buffer1[..64])
    }
}
#[inline(never)]
/// First input layer: a single 3x3 convolution over the 3x40 feature map
/// (valid padding -> one row of 38), then per-channel affine (scale + bias)
/// and 3-wide stride-2 max pooling into 16 channels x 18 columns.
///
/// `output` is zero-filled first, so pooling against it also clamps every
/// pooled value at 0 (an implicit ReLU).
fn input_layer1(features: &[f32], output: &mut [f32]) {
    const NUM_FRAMES: usize = 3;
    const NUM_FEATURES: usize = 40;
    const KERNEL_SIZE: usize = 3;
    // The kernel must consume all frames in one step (single output row).
    const {
        assert!((NUM_FRAMES - KERNEL_SIZE) / 1 + 1 == 1);
    };
    const DEPTHWISE_NUM_FEATURES: usize = (NUM_FEATURES - KERNEL_SIZE) / 1 + 1;
    const OUT_CHANNELS: usize = 16;
    const POOL_KERNEL_SIZE: usize = 3;
    const POOL_STRIDE: usize = 2;
    const POOLED_COLS: usize = (DEPTHWISE_NUM_FEATURES - POOL_KERNEL_SIZE) / POOL_STRIDE + 1;

    output.fill(0.0);

    // Stage 1: spatial convolution, shared by all output channels.
    let mut conv_row = [0.0_f32; DEPTHWISE_NUM_FEATURES];
    for (col, slot) in conv_row.iter_mut().enumerate() {
        let mut acc = 0.0;
        for frame in 0..KERNEL_SIZE {
            for tap in 0..KERNEL_SIZE {
                acc += features[frame * NUM_FEATURES + col + tap]
                    * LAYER1_KERNEL[frame * KERNEL_SIZE + tap];
            }
        }
        *slot = acc;
    }

    // Stage 2: per-channel affine fused into the max-pool taps; the affine
    // value for a tap is identical to materializing the scaled row first.
    for channel in 0..OUT_CHANNELS {
        let base = channel * POOLED_COLS;
        for pooled in 0..POOLED_COLS {
            let mut best = output[base + pooled];
            for tap in 0..POOL_KERNEL_SIZE {
                let v = (conv_row[pooled * POOL_STRIDE + tap] * LAYER1_WEIGHT[channel])
                    + LAYER1_BIAS[channel];
                best = best.max(v);
            }
            output[base + pooled] = best;
        }
    }
}
#[inline(never)]
/// Depthwise-separable conv layer shared by layers 2 and 3.
///
/// For each of `OUT_FEATURES` output columns: a 3-tap depthwise convolution
/// per channel (stride 2, left padding of 1, out-of-range taps skipped),
/// followed by a 16x16 pointwise mix, bias, and ReLU.
///
/// `LAYER3` selects the output layout: `false` stores channel-major
/// (`channel * OUT_FEATURES + column`), `true` stores column-major
/// (`column * CHANNELS + channel`) so the result flattens for the RNN.
fn input_layer2_3<const IN_FEATURES: usize, const OUT_FEATURES: usize, const LAYER3: bool>(
    features: &[f32],
    kernel: &[f32; 48],
    weight: &[f32; 256],
    bias: &[f32; 16],
    output: &mut [f32]
) {
    const HORIZONTAL_KERNEL_SIZE: usize = 3;
    const STRIDE: usize = 2;
    const CHANNELS: usize = 16;
    output.fill(0.0);
    for ox in 0..OUT_FEATURES {
        // Depthwise stage: one 3-tap dot product per channel.
        let mut dw = [0.0_f32; CHANNELS];
        for (c, slot) in dw.iter_mut().enumerate() {
            let mut acc = 0.0;
            for kw in 0..HORIZONTAL_KERNEL_SIZE {
                // checked_sub models the left padding of 1: tap index -1
                // (and anything past the end) contributes nothing.
                match (ox * STRIDE + kw).checked_sub(1) {
                    Some(ix) if ix < IN_FEATURES => {
                        acc += features[(c * IN_FEATURES) + ix]
                            * kernel[(c * HORIZONTAL_KERNEL_SIZE) + kw];
                    }
                    _ => {}
                }
            }
            *slot = acc;
        }
        // Pointwise stage: mix channels, add bias, ReLU.
        for oc in 0..CHANNELS {
            let mixed: f32 = dw
                .iter()
                .zip(&weight[oc * CHANNELS..(oc + 1) * CHANNELS])
                .map(|(d, w)| d * w)
                .sum();
            let idx = if LAYER3 { (ox * CHANNELS) + oc } else { (oc * OUT_FEATURES) + ox };
            output[idx] = (mixed + bias[oc]).max(0.0);
        }
    }
}
#[inline(never)]
/// Minimal-GRU step. Projects `features` through `weight` (a 128 x IN_DIM
/// row-major matrix) into 128 values: the first 64 are candidate activations,
/// the last 64 are gate pre-activations. Each candidate is then blended with
/// the previous hidden value `h[i]` using a hard-sigmoid gate
/// `clamp(0.25 * pre, 0, 1)`. On return, `out[..64]` holds the new hidden
/// state and `out[64..128]` still holds the raw gate pre-activations.
fn mingru<const IN_DIM: usize>(features: &[f32], h: &[f32], weight: &[f32], out: &mut [f32]) {
    // Dense projection: out[d] = features . weight[d * IN_DIM ..].
    for d in 0..128 {
        let base = d * IN_DIM;
        let mut acc = 0.0_f32;
        for f in 0..IN_DIM {
            acc += features[f] * weight[base + f];
        }
        out[d] = acc;
    }
    // Gated blend against the previous hidden state.
    let (candidates, gates) = out.split_at_mut(64);
    for (i, c) in candidates.iter_mut().enumerate() {
        let g = (gates[i] * 0.25).clamp(0.0, 1.0);
        *c = (1. - g) * h[i] + g * *c;
    }
}
#[inline]
/// Standard logistic function: maps any real to (0, 1).
fn sigmoid(x: f32) -> f32 {
    let neg_exp = libm::expf(-x);
    1. / (1. + neg_exp)
}
#[inline(never)]
/// Final readout: dot both 64-wide hidden states with the two halves of
/// `OUTPUT_WEIGHT` (accumulated in the same order as a flat 128-element dot
/// product), then squash through the sigmoid.
fn output(out_1: &[f32], out_2: &[f32]) -> f32 {
    let (w1, w2) = OUTPUT_WEIGHT.split_at(64);
    let mut acc = 0.0_f32;
    for (&v, &w) in out_1[..64].iter().zip(w1) {
        acc += v * w;
    }
    for (&v, &w) in out_2[..64].iter().zip(w2) {
        acc += v * w;
    }
    sigmoid(acc)
}