Skip to main content

moe_platform/
lib.rs

1pub mod file_loader;
2
3use anyhow::Result;
4use moe_core::core::inference::InferenceEngine;
5use file_loader::{load_expert_bank, EXPERT_INPUT_DIM, EXPERT_OUTPUT_DIM};
6
7/// A loaded, ready-to-run model instance.
8pub struct LoadedModel {
9    pub(crate) engine: InferenceEngine,
10    pub model_id: String,
11    pub input_dim: usize,
12    pub output_dim: usize,
13}
14
/// The ternary verdict and metadata produced by a single inference pass.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare
/// results (`Eq` is impossible because of the `f32` fields).
#[derive(Debug, Clone, PartialEq)]
pub struct InferenceResult {
    /// Ternary decision: -1 (reject), 0 (hold), +1 (affirm).
    pub trit_verdict: i8,
    /// Strength of the signal driving the verdict, in `0.0..=1.0`.
    pub confidence: f32,
    /// Raw output activation vector from the expert bank.
    pub output_vec: Vec<f32>,
    /// Human-readable routing summary.
    pub routing_summary: String,
}
26
/// Stateless entry point for loading models and running inference on them.
///
/// Being a unit struct, it is trivially `Copy`; `Default` is implemented
/// separately and delegates to [`Platform::new`].
#[derive(Debug, Clone, Copy)]
pub struct Platform;
28
29impl Platform {
30    pub fn new() -> Self {
31        Self
32    }
33
34    /// Instantiate a synthetic InferenceEngine with default dimensions.
35    /// Use this when no pre-trained weights are available.
36    pub fn load_model(&self, model_id: &str) -> Result<LoadedModel> {
37        Ok(LoadedModel {
38            engine: InferenceEngine::new(
39                format!("epis-v1.0/{}", model_id),
40                EXPERT_INPUT_DIM,
41                EXPERT_OUTPUT_DIM,
42            ),
43            model_id: model_id.to_string(),
44            input_dim: EXPERT_INPUT_DIM,
45            output_dim: EXPERT_OUTPUT_DIM,
46        })
47    }
48
49    /// Load a real ternarized model from a `.tern.bin` file.
50    ///
51    /// The file must be a `ModelCoherence` binary produced by `scripts/transmute_llama.py`.
52    /// Layers are mapped round-robin to the 13 EPIS experts.
53    ///
54    /// ```no_run
55    /// use moe_platform::Platform;
56    /// let platform = Platform::new();
57    /// let model = platform.load_model_from_file("/path/to/llama32-1b.tern.bin").unwrap();
58    /// let result = platform.run_inference(&model, "Should we proceed?").unwrap();
59    /// println!("Verdict: {}", result.trit_verdict);
60    /// ```
61    pub fn load_model_from_file(&self, path: &str) -> Result<LoadedModel> {
62        let (expert_bank, info) = load_expert_bank(path)?;
63
64        log::info!(
65            "Loaded '{}' — {} layers → 13 experts | sparsity {:.1}% | ᾱ={:.4}",
66            info.source_model,
67            info.num_layers,
68            info.sparsity * 100.0,
69            info.mean_alpha,
70        );
71
72        let mut engine = InferenceEngine::new(
73            format!("epis-v1.0/{}", info.source_model),
74            EXPERT_INPUT_DIM,
75            EXPERT_OUTPUT_DIM,
76        );
77        // Swap in real weights from file — overwrites the randomly-initialised bank
78        engine.expert_bank = expert_bank;
79
80        Ok(LoadedModel {
81            engine,
82            model_id: info.source_model,
83            input_dim: EXPERT_INPUT_DIM,
84            output_dim: EXPERT_OUTPUT_DIM,
85        })
86    }
87
88    /// Run a forward pass and return a structured ternary result.
89    pub fn run_inference(&self, model: &LoadedModel, prompt: &str) -> Result<InferenceResult> {
90        let mut input = encode_prompt(prompt, model.input_dim);
91        let output = model.engine.forward(&mut input)?;
92        Ok(decode_result(output, model))
93    }
94}
95
96impl Default for Platform {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
/// Encodes a text prompt into an L2-normalised float activation vector of length `dim`.
///
/// Each UTF-8 byte `b` contributes `(b - 128) / 128` (a value in `[-1.0, 1.0)`) to
/// slot `i % dim`, where `i` is the byte's index. Prompts longer than `dim` bytes
/// therefore wrap around and accumulate. The vector is then scaled to unit
/// Euclidean length.
///
/// Two edge cases are handled explicitly:
/// - An empty prompt, or one whose contributions cancel to zero, yields the
///   all-zero vector rather than NaNs.
/// - `dim == 0` yields an empty vector instead of panicking.
fn encode_prompt(prompt: &str, dim: usize) -> Vec<f32> {
    // `i % 0` panics with a division-by-zero error, so bail out early.
    if dim == 0 {
        return Vec::new();
    }

    let mut vec = vec![0.0f32; dim];
    for (i, b) in prompt.bytes().enumerate() {
        vec[i % dim] += (f32::from(b) - 128.0) / 128.0;
    }

    // Clamp the norm away from zero so an all-zero vector stays zero instead of
    // becoming 0/0 = NaN.
    let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-9);
    vec.iter_mut().for_each(|x| *x /= norm);
    vec
}
112
113/// Map raw output activations to a ternary verdict + metadata.
114fn decode_result(output: Vec<f32>, model: &LoadedModel) -> InferenceResult {
115    let mean = output.iter().sum::<f32>() / output.len() as f32;
116    let trit_verdict: i8 = if mean > 0.05 { 1 } else if mean < -0.05 { -1 } else { 0 };
117    let confidence = mean.abs().min(1.0);
118
119    let verdict_label = match trit_verdict {
120        1  => "affirm (+1)",
121        -1 => "reject (-1)",
122        _  => "hold   ( 0)",
123    };
124
125    let routing_summary = format!(
126        "model={} | kernel={} | dims={}→{} | verdict={} | confidence={:.3}",
127        model.model_id,
128        model.engine.kernel_version,
129        model.input_dim,
130        model.output_dim,
131        verdict_label,
132        confidence,
133    );
134
135    InferenceResult { trit_verdict, confidence, output_vec: output, routing_summary }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable that, when set, points the file-load smoke test at a
    /// specific `.tern.bin` before the built-in candidate paths are tried.
    const TERN_BIN_ENV: &str = "EPIS_TERN_BIN";

    #[test]
    fn test_load_and_infer() {
        let platform = Platform::new();
        let model = platform.load_model("test-epis").unwrap();
        let result = platform.run_inference(&model, "Should we proceed?").unwrap();
        assert!([-1i8, 0, 1].contains(&result.trit_verdict));
        assert!(result.confidence >= 0.0 && result.confidence <= 1.0);
        assert_eq!(result.output_vec.len(), model.output_dim);
    }

    #[test]
    fn test_epis_determinism() {
        let platform = Platform::new();
        let model = platform.load_model("test-epis").unwrap();
        let prompt = "Is this action safe?";
        let a = platform.run_inference(&model, prompt).unwrap();
        let b = platform.run_inference(&model, prompt).unwrap();
        assert_eq!(a.trit_verdict, b.trit_verdict,
            "EPIS must produce identical verdicts for identical input");
        assert_eq!(a.output_vec, b.output_vec,
            "EPIS must produce identical activations for identical input");
    }

    #[test]
    fn test_different_prompts_may_differ() {
        let platform = Platform::new();
        let model = platform.load_model("test-epis").unwrap();
        let a = platform.run_inference(&model, "proceed").unwrap();
        let b = platform.run_inference(&model, "abort").unwrap();
        assert_ne!(a.output_vec, b.output_vec,
            "Different prompts must produce different activations");
    }

    /// Smoke test for file loading; it is skipped if no `.tern.bin` is present.
    ///
    /// `$EPIS_TERN_BIN` is checked first. The test then falls back to the
    /// original developer-machine locations, so it stays portable across checkouts.
    #[test]
    fn test_load_from_file_if_available() {
        let mut candidates: Vec<String> = std::env::var(TERN_BIN_ENV).ok().into_iter().collect();
        candidates.extend(
            [
                "/home/eri-irfos/llama32-1b.tern.bin",
                "/home/eri-irfos/Desktop/llama32-1b.tern.bin",
            ]
            .iter()
            .map(|p| p.to_string()),
        );
        let path = candidates
            .iter()
            .find(|p| std::path::Path::new(p.as_str()).exists());

        if let Some(p) = path {
            let platform = Platform::new();
            let model = platform.load_model_from_file(p).unwrap();
            let result = platform.run_inference(&model, "What is ternary logic?").unwrap();
            assert!([-1i8, 0, 1].contains(&result.trit_verdict));
            println!("✓ Real model loaded: {}", result.routing_summary);
        } else {
            println!("⚠  No .tern.bin found — skipping file-load smoke test");
            println!("   Run: python3 scripts/transmute_llama.py to generate one");
            println!("   Or set {}=/path/to/model.tern.bin", TERN_BIN_ENV);
        }
    }
}