// Splices in `model_version.rs`, a file expected in the build's `OUT_DIR`.
// NOTE(review): its contents are not visible here — presumably a model version
// identifier emitted by a build script; confirm.
include!(concat!(env!("OUT_DIR"), "/model_version.rs"));
/// Raw bytes of `weights.bin`, embedded into the binary at compile time.
/// `all_weights` decodes them as consecutive little-endian `f32` values.
const WEIGHTS: &[u8] = include_bytes!("weights.bin");
/// Input width; a factor of the gate and first-layer weight counts.
const INPUT_DIM: usize = 42;
/// Number of experts stored in the blob.
const EXPERT_COUNT: usize = 6;
/// Output width of each expert's first layer (`fc1`).
const EXPERT_FC1_OUT: usize = 32;
/// Output width of each expert's second layer (`fc2`).
const EXPERT_FC2_OUT: usize = 16;
// Element counts (in `f32` values, not bytes) of each parameter tensor.
const GATE_W_COUNT: usize = INPUT_DIM * EXPERT_COUNT;
const GATE_B_COUNT: usize = EXPERT_COUNT;
const EXPERT_FC1_W_COUNT: usize = INPUT_DIM * EXPERT_FC1_OUT;
const EXPERT_FC1_B_COUNT: usize = EXPERT_FC1_OUT;
const EXPERT_FC2_W_COUNT: usize = EXPERT_FC1_OUT * EXPERT_FC2_OUT;
const EXPERT_FC2_B_COUNT: usize = EXPERT_FC2_OUT;
/// The third layer (`fc3`) has one weight per `fc2` output and a single bias.
const EXPERT_FC3_W_COUNT: usize = EXPERT_FC2_OUT;
const EXPERT_FC3_B_COUNT: usize = 1;
/// Total number of `f32` values occupied by one expert. The accessors below
/// read the tensors in this order: fc1 weights, fc1 bias, fc2 weights,
/// fc2 bias, fc3 weights, fc3 bias.
const EXPERT_PARAM_COUNT: usize = EXPERT_FC1_W_COUNT
+ EXPERT_FC1_B_COUNT
+ EXPERT_FC2_W_COUNT
+ EXPERT_FC2_B_COUNT
+ EXPERT_FC3_W_COUNT
+ EXPERT_FC3_B_COUNT;
// Offsets (in `f32` elements) into the decoded blob: gate weights first,
// then gate biases, then the experts back to back.
const GATE_W_OFF: usize = 0;
const GATE_B_OFF: usize = GATE_W_OFF + GATE_W_COUNT;
const EXPERTS_OFF: usize = GATE_B_OFF + GATE_B_COUNT;
/// Number of `f32` values the whole blob must contain; `all_weights` asserts
/// that `WEIGHTS` is exactly four times this many bytes.
const TOTAL_F32_COUNT: usize = EXPERTS_OFF + EXPERT_COUNT * EXPERT_PARAM_COUNT;
/// Returns the embedded weight blob decoded into `f32` values.
///
/// The bytes of `WEIGHTS` are decoded once, four at a time, as little-endian
/// `f32`s; the result is cached and every later call returns the same slice.
///
/// # Panics
///
/// Panics if `WEIGHTS` is not exactly `TOTAL_F32_COUNT * 4` bytes long.
fn all_weights() -> &'static [f32] {
    use std::sync::OnceLock;

    static PARSED: OnceLock<Box<[f32]>> = OnceLock::new();

    /// Validates the blob size and converts it to floats.
    fn decode() -> Box<[f32]> {
        assert_eq!(
            WEIGHTS.len(),
            TOTAL_F32_COUNT * 4,
            "weights.bin size does not match expected f32 count"
        );
        let mut values = Vec::with_capacity(TOTAL_F32_COUNT);
        for quad in WEIGHTS.chunks_exact(4) {
            values.push(f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
        }
        values.into_boxed_slice()
    }

    PARSED.get_or_init(decode)
}
/// Borrows `count` consecutive values of the decoded weights, starting at
/// element index `offset`.
///
/// # Panics
///
/// Panics if `offset..offset + count` does not lie within the decoded
/// weights, or if decoding itself fails (see `all_weights`).
fn load_f32_slice(offset: usize, count: usize) -> &'static [f32] {
    let weights = all_weights();
    let window = offset..offset + count;
    &weights[window]
}
/// Borrowed parameter views for a single expert, as assembled by `model`.
pub(crate) struct ExpertWeights {
/// First-layer weights after `transpose_static`: index `k * EXPERT_FC1_OUT + o`
/// holds the stored value at `o * INPUT_DIM + k`.
pub fc1_weight_t: &'static [f32],
/// First-layer biases (`EXPERT_FC1_B_COUNT` values).
pub fc1_bias: &'static [f32],
/// Second-layer weights after `transpose_static`: index `k * EXPERT_FC2_OUT + o`
/// holds the stored value at `o * EXPERT_FC1_OUT + k`.
pub fc2_weight_t: &'static [f32],
/// Second-layer biases (`EXPERT_FC2_B_COUNT` values).
pub fc2_bias: &'static [f32],
/// Third-layer weights in their stored order (`EXPERT_FC3_W_COUNT` values).
pub fc3_weight: &'static [f32],
/// The single third-layer bias value.
pub fc3_bias: f32,
}
/// The full set of model parameters as returned by `model`: the gate plus
/// `EXPERT_COUNT` experts.
pub(crate) struct MoeModel {
/// Gate weights in their stored order (`GATE_W_COUNT` values).
pub gate_weight: &'static [f32],
/// Gate biases (`GATE_B_COUNT` values).
pub gate_bias: &'static [f32],
/// Per-expert parameters, indexed by expert number.
pub experts: [ExpertWeights; EXPERT_COUNT],
}
/// Builds a transposed copy of a row-major `rows` x `cols` matrix.
///
/// The value at `src[r * cols + c]` is placed at index `c * rows + r` of the
/// result, so the output is the `cols` x `rows` transpose in row-major order.
/// The output buffer is intentionally leaked to obtain a `'static` slice and
/// is never freed.
///
/// # Panics
///
/// Panics if `src.len()` differs from `rows * cols`.
fn transpose_static(src: &[f32], rows: usize, cols: usize) -> &'static [f32] {
    let len = rows * cols;
    assert_eq!(src.len(), len, "transpose dimensions must match");

    // Walking output positions in order (column-major over the source) yields
    // exactly index `col * rows + row` for each produced element.
    let transposed: Vec<f32> = (0..cols)
        .flat_map(|col| (0..rows).map(move |row| src[row * cols + col]))
        .collect();

    Box::leak(transposed.into_boxed_slice())
}
pub(crate) fn model() -> &'static MoeModel {
static MODEL: std::sync::OnceLock<MoeModel> = std::sync::OnceLock::new();
MODEL.get_or_init(|| {
let _ = all_weights();
let experts = std::array::from_fn(|expert_idx| ExpertWeights {
fc1_weight_t: transpose_static(
expert_fc1_weight(expert_idx),
EXPERT_FC1_OUT,
INPUT_DIM,
),
fc1_bias: expert_fc1_bias(expert_idx),
fc2_weight_t: transpose_static(
expert_fc2_weight(expert_idx),
EXPERT_FC2_OUT,
EXPERT_FC1_OUT,
),
fc2_bias: expert_fc2_bias(expert_idx),
fc3_weight: expert_fc3_weight(expert_idx),
fc3_bias: expert_fc3_bias(expert_idx)[0],
});
MoeModel {
gate_weight: gate_weight(),
gate_bias: gate_bias(),
experts,
}
})
}
/// Public view of the entire decoded weight buffer (every `f32` in the blob,
/// in stored order). Only compiled when the `gpu` feature is enabled.
///
/// # Panics
///
/// Panics if the embedded weights fail the size check in `all_weights`.
#[cfg(feature = "gpu")]
pub fn all_weights_slice() -> &'static [f32] {
all_weights()
}
/// Returns the gate weights: `GATE_W_COUNT` values starting at `GATE_W_OFF`
/// in the decoded blob, in stored order.
///
/// # Panics
///
/// Panics if the embedded weights fail the size check in `all_weights`.
pub fn gate_weight() -> &'static [f32] {
    let (start, len) = (GATE_W_OFF, GATE_W_COUNT);
    load_f32_slice(start, len)
}
/// Returns the gate biases: `GATE_B_COUNT` values starting at `GATE_B_OFF`
/// in the decoded blob.
///
/// # Panics
///
/// Panics if the embedded weights fail the size check in `all_weights`.
pub fn gate_bias() -> &'static [f32] {
    let (start, len) = (GATE_B_OFF, GATE_B_COUNT);
    load_f32_slice(start, len)
}
/// Computes the element index at which the parameters of expert
/// `expert_idx` begin within the decoded blob.
///
/// # Panics
///
/// Panics if `expert_idx` is not below `EXPERT_COUNT`.
fn expert_base_offset(expert_idx: usize) -> usize {
    if expert_idx >= EXPERT_COUNT {
        panic!("expert index out of range");
    }
    // Experts are stored back to back, each taking `EXPERT_PARAM_COUNT` values.
    let preceding = expert_idx * EXPERT_PARAM_COUNT;
    EXPERTS_OFF + preceding
}
/// Returns the first-layer weights of expert `expert_idx` in stored order
/// (`EXPERT_FC1_W_COUNT` values at the very start of the expert's region).
///
/// # Panics
///
/// Panics if `expert_idx` is not below `EXPERT_COUNT`, or if the embedded
/// weights fail the size check in `all_weights`.
pub fn expert_fc1_weight(expert_idx: usize) -> &'static [f32] {
    load_f32_slice(expert_base_offset(expert_idx), EXPERT_FC1_W_COUNT)
}
/// Returns the first-layer biases of expert `expert_idx`
/// (`EXPERT_FC1_B_COUNT` values).
///
/// # Panics
///
/// Panics if `expert_idx` is not below `EXPERT_COUNT`, or if the embedded
/// weights fail the size check in `all_weights`.
pub fn expert_fc1_bias(expert_idx: usize) -> &'static [f32] {
    // The biases follow directly after the first-layer weights.
    const WITHIN_EXPERT: usize = EXPERT_FC1_W_COUNT;
    let start = expert_base_offset(expert_idx) + WITHIN_EXPERT;
    load_f32_slice(start, EXPERT_FC1_B_COUNT)
}
/// Returns the second-layer weights of expert `expert_idx` in stored order
/// (`EXPERT_FC2_W_COUNT` values).
///
/// # Panics
///
/// Panics if `expert_idx` is not below `EXPERT_COUNT`, or if the embedded
/// weights fail the size check in `all_weights`.
pub fn expert_fc2_weight(expert_idx: usize) -> &'static [f32] {
    // Skip the whole first layer (weights, then biases).
    const WITHIN_EXPERT: usize = EXPERT_FC1_W_COUNT + EXPERT_FC1_B_COUNT;
    let start = expert_base_offset(expert_idx) + WITHIN_EXPERT;
    load_f32_slice(start, EXPERT_FC2_W_COUNT)
}
/// Returns the second-layer biases of expert `expert_idx`
/// (`EXPERT_FC2_B_COUNT` values).
///
/// # Panics
///
/// Panics if `expert_idx` is not below `EXPERT_COUNT`, or if the embedded
/// weights fail the size check in `all_weights`.
pub fn expert_fc2_bias(expert_idx: usize) -> &'static [f32] {
    // Skip the first layer and the second-layer weights.
    const WITHIN_EXPERT: usize = EXPERT_FC1_W_COUNT + EXPERT_FC1_B_COUNT + EXPERT_FC2_W_COUNT;
    let start = expert_base_offset(expert_idx) + WITHIN_EXPERT;
    load_f32_slice(start, EXPERT_FC2_B_COUNT)
}
/// Returns the third-layer weights of expert `expert_idx`
/// (`EXPERT_FC3_W_COUNT` values).
///
/// # Panics
///
/// Panics if `expert_idx` is not below `EXPERT_COUNT`, or if the embedded
/// weights fail the size check in `all_weights`.
pub fn expert_fc3_weight(expert_idx: usize) -> &'static [f32] {
    // Skip the first and second layers in full.
    const WITHIN_EXPERT: usize =
        EXPERT_FC1_W_COUNT + EXPERT_FC1_B_COUNT + EXPERT_FC2_W_COUNT + EXPERT_FC2_B_COUNT;
    let start = expert_base_offset(expert_idx) + WITHIN_EXPERT;
    load_f32_slice(start, EXPERT_FC3_W_COUNT)
}
/// Returns the third-layer bias of expert `expert_idx` as a slice of
/// `EXPERT_FC3_B_COUNT` values, the last entry of the expert's region.
///
/// # Panics
///
/// Panics if `expert_idx` is not below `EXPERT_COUNT`, or if the embedded
/// weights fail the size check in `all_weights`.
pub fn expert_fc3_bias(expert_idx: usize) -> &'static [f32] {
    // Skip everything that precedes the final bias within the expert.
    const WITHIN_EXPERT: usize = EXPERT_FC1_W_COUNT
        + EXPERT_FC1_B_COUNT
        + EXPERT_FC2_W_COUNT
        + EXPERT_FC2_B_COUNT
        + EXPERT_FC3_W_COUNT;
    let start = expert_base_offset(expert_idx) + WITHIN_EXPERT;
    load_f32_slice(start, EXPERT_FC3_B_COUNT)
}