include!(concat!(env!("OUT_DIR"), "/model_version.rs"));
// Flat parameter layout of `weights.bin`, in units of f32 (4 bytes each,
// little-endian):
//
//   [ gate W | gate b | expert 0 params | expert 1 params | ... ]
//
// where every expert block is laid out as
//
//   [ fc1 W | fc1 b | fc2 W | fc2 b | fc3 W | fc3 b ]
//
// NOTE(review): the element order *inside* each weight matrix (row- vs
// column-major) is not visible here — confirm against the exporter.

/// Raw parameter blob, embedded into the binary at compile time.
const WEIGHTS: &[u8] = include_bytes!("weights.bin");
/// Input width shared by the gate and each expert's first layer.
const INPUT_DIM: usize = 41;
/// Number of experts (and therefore of gate outputs).
const EXPERT_COUNT: usize = 6;
/// Output width of each expert's first layer.
const EXPERT_FC1_OUT: usize = 32;
/// Output width of each expert's second layer.
const EXPERT_FC2_OUT: usize = 16;
const GATE_W_COUNT: usize = INPUT_DIM * EXPERT_COUNT;
const GATE_B_COUNT: usize = EXPERT_COUNT;
const EXPERT_FC1_W_COUNT: usize = INPUT_DIM * EXPERT_FC1_OUT;
const EXPERT_FC1_B_COUNT: usize = EXPERT_FC1_OUT;
const EXPERT_FC2_W_COUNT: usize = EXPERT_FC1_OUT * EXPERT_FC2_OUT;
const EXPERT_FC2_B_COUNT: usize = EXPERT_FC2_OUT;
// The third layer maps EXPERT_FC2_OUT values to a single scalar.
const EXPERT_FC3_W_COUNT: usize = EXPERT_FC2_OUT;
const EXPERT_FC3_B_COUNT: usize = 1;
/// Number of f32 parameters in one expert's block.
const EXPERT_PARAM_COUNT: usize = EXPERT_FC1_W_COUNT
    + EXPERT_FC1_B_COUNT
    + EXPERT_FC2_W_COUNT
    + EXPERT_FC2_B_COUNT
    + EXPERT_FC3_W_COUNT
    + EXPERT_FC3_B_COUNT;
const GATE_W_OFF: usize = 0;
const GATE_B_OFF: usize = GATE_W_OFF + GATE_W_COUNT;
/// f32 index at which expert 0's block begins.
const EXPERTS_OFF: usize = GATE_B_OFF + GATE_B_COUNT;
/// Total number of f32 values the blob must contain.
const TOTAL_F32_COUNT: usize = EXPERTS_OFF + EXPERT_COUNT * EXPERT_PARAM_COUNT;

// Both sides are compile-time constants, so a mismatched weights.bin fails the
// build instead of panicking on the first weight access at runtime.
const _: () = assert!(
    WEIGHTS.len() == TOTAL_F32_COUNT * 4,
    "weights.bin size does not match expected f32 count"
);
/// Returns every model parameter as one flat `f32` slice.
///
/// The embedded blob is decoded on the first call: each consecutive 4-byte
/// group is read as a little-endian `f32`. The result is cached in a
/// process-wide `OnceLock`, so later calls just return the cached slice.
///
/// # Panics
///
/// Panics on first use if the blob is not exactly `TOTAL_F32_COUNT * 4`
/// bytes long.
fn all_weights() -> &'static [f32] {
    static PARSED: std::sync::OnceLock<Box<[f32]>> = std::sync::OnceLock::new();
    PARSED.get_or_init(|| {
        let byte_len = WEIGHTS.len();
        assert_eq!(
            byte_len,
            TOTAL_F32_COUNT * 4,
            "weights.bin size does not match expected f32 count"
        );
        let mut decoded = Vec::with_capacity(TOTAL_F32_COUNT);
        for word in WEIGHTS.chunks_exact(4) {
            let &[b0, b1, b2, b3] = word else {
                unreachable!("chunks_exact(4) only yields 4-byte chunks");
            };
            decoded.push(f32::from_le_bytes([b0, b1, b2, b3]));
        }
        decoded.into_boxed_slice()
    })
}
/// Borrows `count` consecutive parameters starting at f32 index `offset`.
///
/// # Panics
///
/// Panics if `offset + count` exceeds the number of parsed parameters, or if
/// the underlying blob fails its size check in `all_weights`.
fn load_f32_slice(offset: usize, count: usize) -> &'static [f32] {
    let weights = all_weights();
    let end = offset + count;
    &weights[offset..end]
}
/// Exposes the whole flat parameter buffer, e.g. for a single bulk upload.
///
/// Only compiled with the `gpu` feature.
#[cfg(feature = "gpu")]
// Not every gpu-enabled configuration necessarily calls this; silence the
// unused warning rather than gate it further.
#[allow(dead_code)]
pub fn all_weights_slice() -> &'static [f32] {
    // Starting at index 0 and spanning TOTAL_F32_COUNT floats covers the whole
    // buffer, because all_weights() guarantees exactly that length.
    load_f32_slice(0, TOTAL_F32_COUNT)
}
pub fn gate_weight() -> &'static [f32] {
load_f32_slice(GATE_W_OFF, GATE_W_COUNT)
}
pub fn gate_bias() -> &'static [f32] {
load_f32_slice(GATE_B_OFF, GATE_B_COUNT)
}
/// f32 index where expert `expert_idx`'s parameter block begins.
///
/// # Panics
///
/// Panics with "expert index out of range" if `expert_idx >= EXPERT_COUNT`.
fn expert_base_offset(expert_idx: usize) -> usize {
    if expert_idx >= EXPERT_COUNT {
        panic!("expert index out of range");
    }
    let preceding_experts = expert_idx * EXPERT_PARAM_COUNT;
    EXPERTS_OFF + preceding_experts
}
pub fn expert_fc1_weight(expert_idx: usize) -> &'static [f32] {
let base = expert_base_offset(expert_idx);
load_f32_slice(base, EXPERT_FC1_W_COUNT)
}
pub fn expert_fc1_bias(expert_idx: usize) -> &'static [f32] {
let base = expert_base_offset(expert_idx) + EXPERT_FC1_W_COUNT;
load_f32_slice(base, EXPERT_FC1_B_COUNT)
}
pub fn expert_fc2_weight(expert_idx: usize) -> &'static [f32] {
let base = expert_base_offset(expert_idx) + EXPERT_FC1_W_COUNT + EXPERT_FC1_B_COUNT;
load_f32_slice(base, EXPERT_FC2_W_COUNT)
}
pub fn expert_fc2_bias(expert_idx: usize) -> &'static [f32] {
let base = expert_base_offset(expert_idx)
+ EXPERT_FC1_W_COUNT
+ EXPERT_FC1_B_COUNT
+ EXPERT_FC2_W_COUNT;
load_f32_slice(base, EXPERT_FC2_B_COUNT)
}
pub fn expert_fc3_weight(expert_idx: usize) -> &'static [f32] {
let base = expert_base_offset(expert_idx)
+ EXPERT_FC1_W_COUNT
+ EXPERT_FC1_B_COUNT
+ EXPERT_FC2_W_COUNT
+ EXPERT_FC2_B_COUNT;
load_f32_slice(base, EXPERT_FC3_W_COUNT)
}
pub fn expert_fc3_bias(expert_idx: usize) -> &'static [f32] {
let base = expert_base_offset(expert_idx)
+ EXPERT_FC1_W_COUNT
+ EXPERT_FC1_B_COUNT
+ EXPERT_FC2_W_COUNT
+ EXPERT_FC2_B_COUNT
+ EXPERT_FC3_W_COUNT;
load_f32_slice(base, EXPERT_FC3_B_COUNT)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every per-expert accessor, paired with the label used in assertion
    /// messages, listed in the order the tensors are stored.
    const EXPERT_ACCESSORS: [(&str, fn(usize) -> &'static [f32]); 6] = [
        ("fc1 weights", expert_fc1_weight),
        ("fc1 bias", expert_fc1_bias),
        ("fc2 weights", expert_fc2_weight),
        ("fc2 bias", expert_fc2_bias),
        ("fc3 weights", expert_fc3_weight),
        ("fc3 bias", expert_fc3_bias),
    ];

    /// Index (in f32 units) of `slice`'s first element within the flat buffer.
    fn offset_in_buffer(slice: &[f32]) -> usize {
        let base = all_weights().as_ptr() as usize;
        (slice.as_ptr() as usize - base) / std::mem::size_of::<f32>()
    }

    #[test]
    fn weights_load_successfully() {
        assert!(!gate_weight().is_empty(), "gate weights empty");
        assert!(!gate_bias().is_empty(), "gate bias empty");
        for i in 0..EXPERT_COUNT {
            for &(label, accessor) in EXPERT_ACCESSORS.iter() {
                assert!(!accessor(i).is_empty(), "expert {} {} empty", i, label);
            }
        }
    }

    #[test]
    fn weights_have_correct_dimensions() {
        assert_eq!(gate_weight().len(), INPUT_DIM * EXPERT_COUNT);
        assert_eq!(gate_bias().len(), EXPERT_COUNT);
        for i in 0..EXPERT_COUNT {
            assert_eq!(expert_fc1_weight(i).len(), INPUT_DIM * EXPERT_FC1_OUT);
            assert_eq!(expert_fc1_bias(i).len(), EXPERT_FC1_OUT);
            assert_eq!(expert_fc2_weight(i).len(), EXPERT_FC1_OUT * EXPERT_FC2_OUT);
            assert_eq!(expert_fc2_bias(i).len(), EXPERT_FC2_OUT);
            assert_eq!(expert_fc3_weight(i).len(), EXPERT_FC2_OUT);
            assert_eq!(expert_fc3_bias(i).len(), 1);
        }
    }

    /// Length checks alone cannot catch a wrong offset, because every accessor
    /// asks for exactly its own count. This test walks the accessors in storage
    /// order and checks that they tile the buffer with no gaps, no overlaps,
    /// and nothing left over.
    #[test]
    fn accessors_tile_weight_buffer() {
        let mut slices: Vec<(String, &'static [f32])> = vec![
            ("gate weights".to_string(), gate_weight()),
            ("gate bias".to_string(), gate_bias()),
        ];
        for i in 0..EXPERT_COUNT {
            for &(label, accessor) in EXPERT_ACCESSORS.iter() {
                slices.push((format!("expert {} {}", i, label), accessor(i)));
            }
        }

        let mut expected_offset = 0;
        for (label, slice) in &slices {
            assert_eq!(
                offset_in_buffer(slice),
                expected_offset,
                "{} does not start where the previous tensor ended",
                label
            );
            expected_offset += slice.len();
        }
        assert_eq!(
            expected_offset,
            all_weights().len(),
            "accessors do not cover the whole weight buffer"
        );
    }

    #[test]
    fn weights_are_finite() {
        for &w in all_weights() {
            assert!(w.is_finite(), "weight is NaN or Inf: {w}");
        }
    }

    #[test]
    fn model_version_is_nonempty() {
        assert!(!MODEL_VERSION.is_empty());
    }
}