use anyhow::{Context, Result};
use std::path::Path;
use ternlang_ml::coherence::{ModelCoherence, unpack_layer};
use moe_core::core::mock_layer::TernaryLayer;
use moe_core::core::routing::ExpertBank13;
/// Input width assigned to every expert `TernaryLayer` built by `load_expert_bank`.
pub const EXPERT_INPUT_DIM: usize = 64;
/// Output width (and bias length) assigned to every expert `TernaryLayer` built by
/// `load_expert_bank`.
pub const EXPERT_OUTPUT_DIM: usize = 64;
/// Summary statistics about a model file, returned alongside the expert bank by
/// `load_expert_bank`.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelFileInfo {
    /// The `source_model` string recorded in the loaded file.
    pub source_model: String,
    /// Number of layers stored in the file (always at least 1 for a successful load).
    pub num_layers: usize,
    /// Fraction of zero trits across all unpacked layer weights, in `[0, 1]`.
    pub sparsity: f32,
    /// Mean `alpha` over the experts that were populated from the file.
    pub mean_alpha: f32,
}
/// Loads a `tern.bin` model from `path` and folds its layers into a 13-expert bank.
///
/// Layer `i` of the file is assigned to expert `i % 13`:
///
/// * The first layer an expert receives becomes its weight shard, truncated or
///   zero-padded to `EXPERT_INPUT_DIM * EXPERT_OUTPUT_DIM` trits.
/// * Every later layer for the same expert is merged into that shard position by
///   position with `majority_trit`. If the later layer is shorter than the shard,
///   the uncovered tail is left untouched.
/// * The expert's `alpha` is the arithmetic mean of the `scale` of every layer it
///   received.
///
/// Experts that receive no layer (files with fewer than 13 layers) get all-zero
/// weights and `alpha = 1.0`. Every expert gets a zero bias of length
/// `EXPERT_OUTPUT_DIM`.
///
/// The returned [`ModelFileInfo`] reports:
/// * the file's source model name;
/// * its layer count;
/// * the fraction of zero trits across all unpacked layer weights, counted before
///   truncation;
/// * the mean `alpha` over the experts that received at least one layer.
///
/// # Errors
///
/// Returns an error if `ModelCoherence::load_bin` fails for `path`, or if the
/// loaded model contains no layers.
pub fn load_expert_bank(path: &str) -> Result<(ExpertBank13, ModelFileInfo)> {
    // Number of slots in an `ExpertBank13`; layers are distributed round-robin over them.
    const NUM_EXPERTS: usize = 13;

    let coherence = ModelCoherence::load_bin(Path::new(path))
        .with_context(|| format!("Failed to load tern.bin from '{}'", path))?;
    let source_model = coherence.source_model.clone();
    let num_layers = coherence.layers.len();
    if num_layers == 0 {
        anyhow::bail!("Model '{}' has no layers — file may be corrupted", path);
    }

    let needed = EXPERT_INPUT_DIM * EXPERT_OUTPUT_DIM;
    let mut shards: Vec<Option<Vec<i8>>> = vec![None; NUM_EXPERTS];
    // Per-expert sum of layer scales and layer count. Keeping both makes each alpha a
    // true mean. A pairwise `(a + b) * 0.5` update over-weights later layers as soon as
    // an expert receives three or more.
    let mut scale_sums = [0.0f32; NUM_EXPERTS];
    let mut layer_counts = [0usize; NUM_EXPERTS];
    let mut total_weights = 0usize;
    let mut total_zeros = 0usize;

    for (i, layer) in coherence.layers.iter().enumerate() {
        let eid = i % NUM_EXPERTS;
        let raw = unpack_layer(layer).to_i8_vec();
        total_weights += raw.len();
        total_zeros += raw.iter().filter(|&&w| w == 0).count();
        scale_sums[eid] += layer.scale;
        layer_counts[eid] += 1;

        match shards[eid].as_mut() {
            None => {
                let mut shard: Vec<i8> = raw.into_iter().take(needed).collect();
                shard.resize(needed, 0);
                shards[eid] = Some(shard);
            }
            Some(existing) => {
                // NOTE(review): `majority_trit` yields 0 on any disagreement, and a 0
                // can never be overturned by a later layer. With three or more layers
                // per expert, a trit therefore survives only if every merged layer
                // agrees (unanimity), not by majority. Confirm this is intended.
                // `zip` stops at the shorter side, so at most `needed` trits are merged.
                for (slot, new_w) in existing.iter_mut().zip(raw) {
                    *slot = majority_trit(*slot, new_w);
                }
            }
        }
    }

    // Experts that never received a layer keep the placeholder alpha of 1.0.
    let alphas: Vec<f32> = scale_sums
        .iter()
        .zip(&layer_counts)
        .map(|(&sum, &count)| if count > 0 { sum / count as f32 } else { 1.0 })
        .collect();

    // Average only over experts populated from the file: the 1.0 placeholders of empty
    // experts are not model data. Expert 0 is always populated because
    // `num_layers > 0`, so the divisor is never zero.
    let (alpha_total, populated) = alphas
        .iter()
        .zip(&layer_counts)
        .filter(|&(_, &count)| count > 0)
        .fold((0.0f32, 0usize), |(total, n), (&alpha, _)| (total + alpha, n + 1));
    let mean_alpha = alpha_total / populated as f32;

    let sparsity = if total_weights > 0 {
        total_zeros as f32 / total_weights as f32
    } else {
        0.0
    };

    let experts: Vec<TernaryLayer> = shards
        .into_iter()
        .zip(&alphas)
        .map(|(shard, &alpha)| TernaryLayer {
            weights: shard.unwrap_or_else(|| vec![0i8; needed]),
            alpha,
            bias: vec![0.0f32; EXPERT_OUTPUT_DIM],
            input_dim: EXPERT_INPUT_DIM,
            output_dim: EXPERT_OUTPUT_DIM,
        })
        .collect();

    let info = ModelFileInfo {
        source_model,
        num_layers,
        sparsity,
        mean_alpha,
    };
    Ok((ExpertBank13::from_layers(experts), info))
}
/// Merges two ternary weights.
///
/// Returns the shared value when `a` and `b` are equal; any disagreement collapses
/// to the neutral trit `0`.
#[inline]
fn majority_trit(a: i8, b: i8) -> i8 {
    if a != b {
        return 0;
    }
    a
}