Skip to main content

moe_platform/
file_loader.rs

1// SPDX-License-Identifier: MIT
2//! # Model File Loader
3//!
4//! Loads a `.tern.bin` ModelCoherence file produced by `transmute_llama.py`
5//! and maps its ternarized layers into a 13-expert ExpertBank for EPIS inference.
6//!
7//! ## Layer → Expert mapping
8//! Round-robin: layer i → expert (i % 13).
9//! When multiple layers land on the same expert they are fused by majority vote:
10//! agreement → keep the value, disagreement → 0 (hold).
11//!
12//! ## Dimension handling
//! Large transformer layers (e.g. 2048×2048) are truncated to their first
//! `EXPERT_INPUT_DIM * EXPERT_OUTPUT_DIM` unpacked weights.
//! Layers with fewer weights than that are zero-padded.
15
16use anyhow::{Context, Result};
17use std::path::Path;
18use ternlang_ml::coherence::{ModelCoherence, unpack_layer};
19use moe_core::core::mock_layer::TernaryLayer;
20use moe_core::core::routing::ExpertBank13;
21
/// Input width assigned to every expert layer built by this loader.
/// Together with `EXPERT_OUTPUT_DIM` it fixes the number of weights kept per
/// expert (`EXPERT_INPUT_DIM * EXPERT_OUTPUT_DIM`).
pub const EXPERT_INPUT_DIM: usize = 64;
/// Output width assigned to every expert layer built by this loader; also the
/// length of each expert's (all-zero) bias vector.
pub const EXPERT_OUTPUT_DIM: usize = 64;
24
/// Metadata returned alongside a loaded ExpertBank.
///
/// Derives `Debug` and `Clone` so callers can log the summary or keep a copy
/// of it independently of the bank.
#[derive(Debug, Clone)]
pub struct ModelFileInfo {
    /// Source model name embedded in the .tern.bin file.
    pub source_model: String,
    /// Total number of layers read from the file.
    pub num_layers: usize,
    /// Overall weight sparsity across all layers (fraction of zeros).
    /// Counted over the full unpacked layers, before any truncation to the
    /// expert dimensions; `0.0` when the layers contain no weights at all.
    pub sparsity: f32,
    /// Mean scale (alpha) across all 13 experts. Experts that received no
    /// layer contribute their default alpha of `1.0` to this mean.
    pub mean_alpha: f32,
}
36
37/// Load a `.tern.bin` file and construct a real ExpertBank13 from its layers.
38pub fn load_expert_bank(path: &str) -> Result<(ExpertBank13, ModelFileInfo)> {
39    let model_path = Path::new(path);
40    let coherence = ModelCoherence::load_bin(model_path)
41        .with_context(|| format!("Failed to load tern.bin from '{}'", path))?;
42
43    let source_model = coherence.source_model.clone();
44    let num_layers = coherence.layers.len();
45
46    if num_layers == 0 {
47        anyhow::bail!("Model '{}' has no layers — file may be corrupted", path);
48    }
49
50    let needed = EXPERT_INPUT_DIM * EXPERT_OUTPUT_DIM;
51    let mut shards: Vec<Option<Vec<i8>>> = (0..13).map(|_| None).collect();
52    let mut alphas = vec![1.0f32; 13];
53    let mut total_weights = 0usize;
54    let mut total_zeros = 0usize;
55
56    for (i, layer) in coherence.layers.iter().enumerate() {
57        let eid = i % 13;
58        let raw = unpack_layer(layer).to_i8_vec();
59
60        total_weights += raw.len();
61        total_zeros += raw.iter().filter(|&&w| w == 0).count();
62
63        match shards[eid].as_mut() {
64            None => {
65                // First layer for this expert — truncate or zero-pad to target dims
66                let mut shard: Vec<i8> = raw.into_iter().take(needed).collect();
67                shard.resize(needed, 0);
68                shards[eid] = Some(shard);
69                alphas[eid] = layer.scale;
70            }
71            Some(existing) => {
72                // Fuse subsequent layers via majority vote to preserve signal
73                for (j, new_w) in raw.into_iter().take(needed).enumerate() {
74                    existing[j] = majority_trit(existing[j], new_w);
75                }
76                // Update alpha toward the mean
77                alphas[eid] = (alphas[eid] + layer.scale) * 0.5;
78            }
79        }
80    }
81
82    let experts: Vec<TernaryLayer> = (0..13)
83        .map(|eid| TernaryLayer {
84            weights: shards[eid].take().unwrap_or_else(|| vec![0i8; needed]),
85            alpha: alphas[eid],
86            bias: vec![0.0f32; EXPERT_OUTPUT_DIM],
87            input_dim: EXPERT_INPUT_DIM,
88            output_dim: EXPERT_OUTPUT_DIM,
89        })
90        .collect();
91
92    let sparsity = if total_weights > 0 {
93        total_zeros as f32 / total_weights as f32
94    } else {
95        0.0
96    };
97
98    let mean_alpha = alphas.iter().sum::<f32>() / 13.0;
99
100    let info = ModelFileInfo {
101        source_model,
102        num_layers,
103        sparsity,
104        mean_alpha,
105    };
106
107    Ok((ExpertBank13::from_layers(experts), info))
108}
109
/// Fuse two trits by agreement: identical inputs pass through unchanged,
/// any mismatch collapses to 0 (hold/uncertain).
#[inline]
fn majority_trit(a: i8, b: i8) -> i8 {
    if a != b {
        return 0;
    }
    a
}