moe_platform/
file_loader.rs1use anyhow::{Context, Result};
17use std::path::Path;
18use ternlang_ml::coherence::{ModelCoherence, unpack_layer};
19use moe_core::core::mock_layer::TernaryLayer;
20use moe_core::core::routing::ExpertBank13;
21
/// Input dimension assigned to every expert layer built by `load_expert_bank`.
///
/// Together with `EXPERT_OUTPUT_DIM` it fixes the number of weights kept per
/// expert (`EXPERT_INPUT_DIM * EXPERT_OUTPUT_DIM`).
pub const EXPERT_INPUT_DIM: usize = 64;
/// Output dimension assigned to every expert layer built by `load_expert_bank`;
/// also the length of each expert's bias vector.
pub const EXPERT_OUTPUT_DIM: usize = 64;
24
/// Summary statistics describing a model file loaded by `load_expert_bank`.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelFileInfo {
    /// Name of the source model, copied from the loaded file's metadata.
    pub source_model: String,
    /// Number of layers found in the file.
    pub num_layers: usize,
    /// Fraction of zero weights among all unpacked layer weights
    /// (`0.0` when the file holds no weights at all).
    pub sparsity: f32,
    /// Mean of the per-expert alpha scales over all 13 experts; experts that
    /// received no layer contribute their default alpha of `1.0`.
    pub mean_alpha: f32,
}
36
37pub fn load_expert_bank(path: &str) -> Result<(ExpertBank13, ModelFileInfo)> {
39 let model_path = Path::new(path);
40 let coherence = ModelCoherence::load_bin(model_path)
41 .with_context(|| format!("Failed to load tern.bin from '{}'", path))?;
42
43 let source_model = coherence.source_model.clone();
44 let num_layers = coherence.layers.len();
45
46 if num_layers == 0 {
47 anyhow::bail!("Model '{}' has no layers — file may be corrupted", path);
48 }
49
50 let needed = EXPERT_INPUT_DIM * EXPERT_OUTPUT_DIM;
51 let mut shards: Vec<Option<Vec<i8>>> = (0..13).map(|_| None).collect();
52 let mut alphas = vec![1.0f32; 13];
53 let mut total_weights = 0usize;
54 let mut total_zeros = 0usize;
55
56 for (i, layer) in coherence.layers.iter().enumerate() {
57 let eid = i % 13;
58 let raw = unpack_layer(layer).to_i8_vec();
59
60 total_weights += raw.len();
61 total_zeros += raw.iter().filter(|&&w| w == 0).count();
62
63 match shards[eid].as_mut() {
64 None => {
65 let mut shard: Vec<i8> = raw.into_iter().take(needed).collect();
67 shard.resize(needed, 0);
68 shards[eid] = Some(shard);
69 alphas[eid] = layer.scale;
70 }
71 Some(existing) => {
72 for (j, new_w) in raw.into_iter().take(needed).enumerate() {
74 existing[j] = majority_trit(existing[j], new_w);
75 }
76 alphas[eid] = (alphas[eid] + layer.scale) * 0.5;
78 }
79 }
80 }
81
82 let experts: Vec<TernaryLayer> = (0..13)
83 .map(|eid| TernaryLayer {
84 weights: shards[eid].take().unwrap_or_else(|| vec![0i8; needed]),
85 alpha: alphas[eid],
86 bias: vec![0.0f32; EXPERT_OUTPUT_DIM],
87 input_dim: EXPERT_INPUT_DIM,
88 output_dim: EXPERT_OUTPUT_DIM,
89 })
90 .collect();
91
92 let sparsity = if total_weights > 0 {
93 total_zeros as f32 / total_weights as f32
94 } else {
95 0.0
96 };
97
98 let mean_alpha = alphas.iter().sum::<f32>() / 13.0;
99
100 let info = ModelFileInfo {
101 source_model,
102 num_layers,
103 sparsity,
104 mean_alpha,
105 };
106
107 Ok((ExpertBank13::from_layers(experts), info))
108}
109
/// Combines two weight votes for the same position.
///
/// Returns the shared value when `a` and `b` agree, and `0` when they differ.
#[inline]
fn majority_trit(a: i8, b: i8) -> i8 {
    if a != b {
        // Disagreement between the two votes collapses to zero.
        return 0;
    }
    a
}
114}