include!(concat!(env!("OUT_DIR"), "/model_version.rs"));
/// Raw bytes of `weights.bin`, embedded into the binary at compile time by `include_bytes!`.
/// `all_weights` decodes them as consecutive little-endian `f32` values.
const WEIGHTS: &[u8] = include_bytes!("weights.bin");
// ---- Layer dimensions ----
// NOTE(review): the names suggest a mixture-of-experts network (one gate plus
// `EXPERT_COUNT` three-layer experts). This file only fixes parameter counts and
// their order in the blob; confirm the architecture against whatever exports
// `weights.bin`.
/// Input width; a factor of both the gate and the first expert-layer weight counts.
const INPUT_DIM: usize = 42;
/// Number of expert parameter groups stored in the blob.
const EXPERT_COUNT: usize = 6;
/// Output width of each expert's first layer (`fc1`).
const EXPERT_FC1_OUT: usize = 32;
/// Output width of each expert's second layer (`fc2`).
const EXPERT_FC2_OUT: usize = 16;
// ---- Parameter counts, measured in `f32` elements (not bytes) ----
/// Gate weights: one per (input, expert) pair.
const GATE_W_COUNT: usize = INPUT_DIM * EXPERT_COUNT;
/// Gate biases: one per expert.
const GATE_B_COUNT: usize = EXPERT_COUNT;
/// `fc1` weights of a single expert: one per (input, `fc1` output) pair.
const EXPERT_FC1_W_COUNT: usize = INPUT_DIM * EXPERT_FC1_OUT;
/// `fc1` biases of a single expert: one per `fc1` output.
const EXPERT_FC1_B_COUNT: usize = EXPERT_FC1_OUT;
/// `fc2` weights of a single expert: one per (`fc1` output, `fc2` output) pair.
const EXPERT_FC2_W_COUNT: usize = EXPERT_FC1_OUT * EXPERT_FC2_OUT;
/// `fc2` biases of a single expert: one per `fc2` output.
const EXPERT_FC2_B_COUNT: usize = EXPERT_FC2_OUT;
/// `fc3` weights of a single expert: one per `fc2` output
/// (presumably `fc3` has a single output unit; confirm against the model definition).
const EXPERT_FC3_W_COUNT: usize = EXPERT_FC2_OUT;
/// `fc3` biases of a single expert: a single value.
const EXPERT_FC3_B_COUNT: usize = 1;
/// Total number of parameters belonging to one expert (weights and biases of all three layers).
const EXPERT_PARAM_COUNT: usize = EXPERT_FC1_W_COUNT
+ EXPERT_FC1_B_COUNT
+ EXPERT_FC2_W_COUNT
+ EXPERT_FC2_B_COUNT
+ EXPERT_FC3_W_COUNT
+ EXPERT_FC3_B_COUNT;
// ---- Element offsets into the decoded `f32` buffer ----
// Layout: gate weights, then gate biases, then the experts stored back to back.
// Within each expert the accessor functions read fc1 weights, fc1 biases,
// fc2 weights, fc2 biases, fc3 weights, fc3 biases, in that order.
/// Index of the first gate weight (the very start of the buffer).
const GATE_W_OFF: usize = 0;
/// Index of the first gate bias, directly after the gate weights.
const GATE_B_OFF: usize = GATE_W_OFF + GATE_W_COUNT;
/// Index of the first parameter of expert 0, directly after the gate biases.
const EXPERTS_OFF: usize = GATE_B_OFF + GATE_B_COUNT;
/// Number of `f32` values the blob must contain; `all_weights` asserts that `WEIGHTS`
/// is exactly four times this many bytes long.
const TOTAL_F32_COUNT: usize = EXPERTS_OFF + EXPERT_COUNT * EXPERT_PARAM_COUNT;
/// Returns every parameter of the embedded blob as one flat `f32` slice.
///
/// The bytes in `WEIGHTS` are interpreted as consecutive little-endian `f32` values. Decoding
/// happens the first time this function is called; the decoded buffer is stored in a
/// function-local `OnceLock` and handed out again on every later call.
///
/// # Panics
///
/// Panics if `WEIGHTS` is not exactly `TOTAL_F32_COUNT * 4` bytes long. The check runs when
/// the blob is decoded.
fn all_weights() -> &'static [f32] {
    use std::sync::OnceLock;

    // Validates the blob length, then converts it four bytes at a time.
    fn decode_blob() -> Box<[f32]> {
        assert_eq!(
            WEIGHTS.len(),
            TOTAL_F32_COUNT * 4,
            "weights.bin size does not match expected f32 count"
        );

        let mut decoded = Vec::with_capacity(TOTAL_F32_COUNT);
        for word in WEIGHTS.chunks_exact(4) {
            // `chunks_exact(4)` only yields 4-byte windows, so this copy cannot fail.
            let mut raw = [0u8; 4];
            raw.copy_from_slice(word);
            decoded.push(f32::from_le_bytes(raw));
        }
        decoded.into_boxed_slice()
    }

    static CACHE: OnceLock<Box<[f32]>> = OnceLock::new();
    CACHE.get_or_init(decode_blob)
}
/// Borrows `count` consecutive parameters starting at element index `offset`.
///
/// Both arguments are measured in `f32` elements of the buffer returned by `all_weights`,
/// not in bytes.
///
/// # Panics
///
/// Panics if the range `offset..offset + count` does not fit inside the decoded buffer.
fn load_f32_slice(offset: usize, count: usize) -> &'static [f32] {
    let weights = all_weights();
    let end = offset + count;
    &weights[offset..end]
}
/// Returns the entire decoded parameter buffer as one flat `f32` slice.
///
/// This is the public counterpart of the private `all_weights` helper and simply forwards
/// to it. The slice holds the gate parameters followed by every expert's parameters.
///
/// # Panics
///
/// Panics if the embedded blob does not have the expected size (see `all_weights`).
pub fn all_weights_slice() -> &'static [f32] {
all_weights()
}
/// Returns the gate weights: the `GATE_W_COUNT` (`INPUT_DIM * EXPERT_COUNT`) values stored
/// at the very start of the parameter buffer.
///
/// The slice is flat. NOTE(review): whether it is laid out input-major or expert-major is
/// not determined by this file; confirm against the code that exports `weights.bin`.
pub fn gate_weight() -> &'static [f32] {
load_f32_slice(GATE_W_OFF, GATE_W_COUNT)
}
/// Returns the gate biases: `GATE_B_COUNT` values (one per expert) stored directly after
/// the gate weights in the parameter buffer.
pub fn gate_bias() -> &'static [f32] {
load_f32_slice(GATE_B_OFF, GATE_B_COUNT)
}
/// Element index at which the parameters of expert `expert_idx` begin.
///
/// Experts are stored back to back starting at `EXPERTS_OFF`, each occupying
/// `EXPERT_PARAM_COUNT` values.
///
/// # Panics
///
/// Panics with "expert index out of range" unless `expert_idx < EXPERT_COUNT`.
fn expert_base_offset(expert_idx: usize) -> usize {
    assert!(expert_idx < EXPERT_COUNT, "expert index out of range");
    // Parameters owned by the experts that come before this one.
    let preceding = expert_idx * EXPERT_PARAM_COUNT;
    EXPERTS_OFF + preceding
}
/// Returns the first-layer (`fc1`) weights of expert `expert_idx`.
///
/// These are the first `EXPERT_FC1_W_COUNT` values of the expert's parameter group.
///
/// # Panics
///
/// Panics if `expert_idx` is not a valid expert index (checked by `expert_base_offset`).
pub fn expert_fc1_weight(expert_idx: usize) -> &'static [f32] {
    load_f32_slice(expert_base_offset(expert_idx), EXPERT_FC1_W_COUNT)
}
/// Returns the first-layer (`fc1`) biases of expert `expert_idx`.
///
/// The `EXPERT_FC1_B_COUNT` bias values sit directly after the expert's `fc1` weights.
///
/// # Panics
///
/// Panics if `expert_idx` is not a valid expert index (checked by `expert_base_offset`).
pub fn expert_fc1_bias(expert_idx: usize) -> &'static [f32] {
    // Distance from the start of the expert's group to its fc1 biases.
    const FC1_BIAS_REL: usize = EXPERT_FC1_W_COUNT;

    let start = expert_base_offset(expert_idx) + FC1_BIAS_REL;
    load_f32_slice(start, EXPERT_FC1_B_COUNT)
}
/// Returns the second-layer (`fc2`) weights of expert `expert_idx`.
///
/// The `EXPERT_FC2_W_COUNT` values follow the expert's `fc1` weights and biases.
///
/// # Panics
///
/// Panics if `expert_idx` is not a valid expert index (checked by `expert_base_offset`).
pub fn expert_fc2_weight(expert_idx: usize) -> &'static [f32] {
    // Distance from the start of the expert's group to its fc2 weights.
    const FC2_WEIGHT_REL: usize = EXPERT_FC1_W_COUNT + EXPERT_FC1_B_COUNT;

    let start = expert_base_offset(expert_idx) + FC2_WEIGHT_REL;
    load_f32_slice(start, EXPERT_FC2_W_COUNT)
}
/// Returns the second-layer (`fc2`) biases of expert `expert_idx`.
///
/// The `EXPERT_FC2_B_COUNT` bias values sit directly after the expert's `fc2` weights.
///
/// # Panics
///
/// Panics if `expert_idx` is not a valid expert index (checked by `expert_base_offset`).
pub fn expert_fc2_bias(expert_idx: usize) -> &'static [f32] {
    // Distance from the start of the expert's group to its fc2 biases.
    const FC2_BIAS_REL: usize = EXPERT_FC1_W_COUNT + EXPERT_FC1_B_COUNT + EXPERT_FC2_W_COUNT;

    let start = expert_base_offset(expert_idx) + FC2_BIAS_REL;
    load_f32_slice(start, EXPERT_FC2_B_COUNT)
}
/// Returns the third-layer (`fc3`) weights of expert `expert_idx`.
///
/// The `EXPERT_FC3_W_COUNT` values follow the expert's `fc1` and `fc2` weights and biases.
///
/// # Panics
///
/// Panics if `expert_idx` is not a valid expert index (checked by `expert_base_offset`).
pub fn expert_fc3_weight(expert_idx: usize) -> &'static [f32] {
    // Distance from the start of the expert's group to its fc3 weights.
    const FC3_WEIGHT_REL: usize =
        EXPERT_FC1_W_COUNT + EXPERT_FC1_B_COUNT + EXPERT_FC2_W_COUNT + EXPERT_FC2_B_COUNT;

    let start = expert_base_offset(expert_idx) + FC3_WEIGHT_REL;
    load_f32_slice(start, EXPERT_FC3_W_COUNT)
}
/// Returns the third-layer (`fc3`) bias of expert `expert_idx`.
///
/// The result is a slice of `EXPERT_FC3_B_COUNT` values (one, as declared in this file),
/// taken from the very end of the expert's parameter group.
///
/// # Panics
///
/// Panics if `expert_idx` is not a valid expert index (checked by `expert_base_offset`).
pub fn expert_fc3_bias(expert_idx: usize) -> &'static [f32] {
    // Distance from the start of the expert's group to its fc3 bias.
    const FC3_BIAS_REL: usize = EXPERT_FC1_W_COUNT
        + EXPERT_FC1_B_COUNT
        + EXPERT_FC2_W_COUNT
        + EXPERT_FC2_B_COUNT
        + EXPERT_FC3_W_COUNT;

    let start = expert_base_offset(expert_idx) + FC3_BIAS_REL;
    load_f32_slice(start, EXPERT_FC3_B_COUNT)
}