use serde::{Deserialize, Serialize};
use crate::events::{EmlEvent, EmlEventLog};
use crate::operator::{eml_safe, random_params, softmax3};
/// One sample recorded via `EmlModel::record` for later training.
///
/// `targets` has one entry per output head. A `None` entry means the head has
/// no label for this sample, and it is skipped when computing the training loss.
#[derive(Debug, Clone)]
struct TrainingPoint {
    /// Feature vector; `record` enforces that its length equals the model's `input_count`.
    inputs: Vec<f64>,
    /// Per-head targets; `record` enforces that its length equals the model's `head_count`.
    targets: Vec<Option<f64>>,
}
/// A small fixed-topology tree model built from the `eml_safe` operator.
///
/// Eight leaf nodes mix pairs of input features. Further mixing levels (how many
/// depends on `depth`) reduce them to a two-node trunk, and each of the
/// `head_count` output heads combines that trunk into one non-negative output.
///
/// Only the model definition (`depth`, `input_count`, `head_count`, `params`,
/// `trained`) is serialized. Recorded samples, pending events and the model
/// name are `#[serde(skip)]` and come back empty after deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmlModel {
    /// Number of tree levels; `new` restricts it to `2..=5`.
    depth: usize,
    /// Length of the input vector expected by `predict` and `record`.
    input_count: usize,
    /// Number of output heads (at least one).
    head_count: usize,
    /// Flat parameter vector: trunk parameters first, then two weights per head.
    params: Vec<f64>,
    /// Set by `train` (true when the final MSE is below 0.01) or by `mark_trained`.
    trained: bool,
    /// Samples collected by `record`; not persisted.
    #[serde(skip)]
    training_data: Vec<TrainingPoint>,
    /// Events not yet drained by the caller; not persisted.
    /// NOTE(review): `#[serde(skip)]` needs `EmlEventLog: Default` — defined elsewhere.
    #[serde(skip)]
    event_log: EmlEventLog,
    /// Optional human-readable name used in emitted events; not persisted, so it
    /// is lost on a JSON round-trip.
    #[serde(skip)]
    model_name: String,
}
impl EmlModel {
    /// Creates an untrained model whose parameters are all `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `depth` is not in `2..=5`, if `head_count` is zero, or if
    /// `input_count` is zero. Leaf nodes select their inputs modulo
    /// `input_count`, so a zero-input model could never be evaluated.
    pub fn new(depth: usize, input_count: usize, head_count: usize) -> Self {
        assert!(
            (2..=5).contains(&depth),
            "EmlModel depth must be 2, 3, 4, or 5, got {depth}"
        );
        assert!(head_count > 0, "head_count must be >= 1");
        assert!(input_count > 0, "input_count must be >= 1");
        let param_count = Self::compute_param_count(depth, head_count);
        Self {
            depth,
            input_count,
            head_count,
            params: vec![0.0; param_count],
            trained: false,
            training_data: Vec::new(),
            event_log: EmlEventLog::new(),
            model_name: String::new(),
        }
    }

    /// Total number of parameters (trunk parameters plus two per head).
    pub fn param_count(&self) -> usize {
        self.params.len()
    }

    /// Read-only view of the flat parameter vector.
    pub fn params_slice(&self) -> &[f64] {
        &self.params
    }

    /// Mutable view of the parameter vector; its length cannot be changed through it.
    pub fn params_slice_mut(&mut self) -> &mut [f64] {
        &mut self.params
    }

    /// Overrides the `trained` flag (e.g. after setting parameters manually).
    pub fn mark_trained(&mut self, trained: bool) {
        self.trained = trained;
    }

    /// Whether the last `train` converged, or the flag was set via `mark_trained`.
    pub fn is_trained(&self) -> bool {
        self.trained
    }

    /// Number of samples collected by `record`.
    pub fn training_sample_count(&self) -> usize {
        self.training_data.len()
    }

    /// Tree depth (2..=5).
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Expected length of input vectors.
    pub fn input_count(&self) -> usize {
        self.input_count
    }

    /// Number of output heads.
    pub fn head_count(&self) -> usize {
        self.head_count
    }

    /// Sets the name reported in emitted events.
    pub fn set_model_name(&mut self, name: impl Into<String>) {
        self.model_name = name.into();
    }

    /// The name set via `set_model_name` (empty by default).
    pub fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Removes and returns all pending events.
    pub fn drain_events(&mut self) -> Vec<EmlEvent> {
        self.event_log.drain()
    }

    /// Appends an event to the pending log.
    pub fn push_event(&mut self, event: EmlEvent) {
        self.event_log.push(event);
    }

    /// Number of events waiting to be drained.
    pub fn pending_event_count(&self) -> usize {
        self.event_log.len()
    }

    /// Number of parameters needed for a model of the given shape.
    ///
    /// Layout: 8 leaves × 3 softmax logits, then `depth - 3` softmax mixing
    /// levels of 4 nodes × 3 logits (depths 4 and 5), then one linear mixing
    /// level of 2 nodes × 4 weights (depths 3..=5), then 2 weights per head.
    /// Callers must have validated `depth` to be in `2..=5`.
    fn compute_param_count(depth: usize, head_count: usize) -> usize {
        const LEAF_PARAMS: usize = 8 * 3;
        const SOFTMAX_LEVEL_PARAMS: usize = 4 * 3;
        const LINEAR_LEVEL_PARAMS: usize = 2 * 4;
        let trunk = match depth {
            2 => LEAF_PARAMS,
            3 => LEAF_PARAMS + LINEAR_LEVEL_PARAMS,
            4 => LEAF_PARAMS + SOFTMAX_LEVEL_PARAMS + LINEAR_LEVEL_PARAMS,
            5 => LEAF_PARAMS + 2 * SOFTMAX_LEVEL_PARAMS + LINEAR_LEVEL_PARAMS,
            _ => unreachable!(),
        };
        trunk + head_count * 2
    }

    /// Evaluates every head with the model's current parameters.
    ///
    /// Each returned value is non-negative (heads are clamped with `max(0.0)`).
    ///
    /// # Panics
    ///
    /// Panics if `inputs.len() != self.input_count()`.
    pub fn predict(&self, inputs: &[f64]) -> Vec<f64> {
        assert_eq!(
            inputs.len(),
            self.input_count,
            "expected {} inputs, got {}",
            self.input_count,
            inputs.len()
        );
        self.evaluate_with_params(&self.params, inputs)
    }

    /// Output of the first head; same panics as `predict`.
    pub fn predict_primary(&self, inputs: &[f64]) -> f64 {
        self.predict(inputs)[0]
    }

    /// Evaluates the tree using `params` (which must have `self.param_count()` entries).
    fn evaluate_with_params(&self, params: &[f64], inputs: &[f64]) -> Vec<f64> {
        // (left_a, left_b, right_a, right_b) child indices for each node of the
        // first softmax mixing level; shared by depths 4 and 5.
        const LEVEL2_PAIRS: [(usize, usize, usize, usize); 4] =
            [(0, 1, 2, 3), (0, 1, 2, 3), (0, 2, 1, 3), (1, 3, 0, 2)];
        // Child wiring of the second softmax mixing level (depth 5 only).
        const LEVEL3_PAIRS: [(usize, usize, usize, usize); 4] =
            [(0, 1, 2, 3), (0, 2, 1, 3), (1, 3, 0, 2), (0, 3, 1, 2)];

        let feature_pairs = Self::feature_pairs(self.input_count);
        // Leaves: softmax-weighted affine combination of two input features.
        let mut a = [0.0f64; 8];
        for (i, leaf) in a.iter_mut().enumerate() {
            let base = i * 3;
            let (alpha, beta, gamma) = softmax3(params[base], params[base + 1], params[base + 2]);
            let (j, k) = feature_pairs[i];
            *leaf = (alpha + beta * inputs[j] + gamma * inputs[k]).clamp(-10.0, 10.0);
        }
        let b = [
            eml_safe(a[0], a[1]),
            eml_safe(a[2], a[3]),
            eml_safe(a[4], a[5]),
            eml_safe(a[6], a[7]),
        ];

        // Every depth reduces to a two-node trunk; the heads read only these two values.
        let (t0, t1) = match self.depth {
            // NOTE(review): at depth 2 the heads read only b[0] and b[1], so
            // leaves 4..8 (params 12..24) have no effect on the output.
            2 => (b[0], b[1]),
            3 => {
                let c = Self::linear_mix_level(params, 24, &b, &[(0, 1, 2, 3), (0, 1, 2, 3)]);
                (c[0], c[1])
            }
            4 => {
                let c = Self::softmax_mix_level(params, 24, &b, &LEVEL2_PAIRS);
                let d = Self::linear_mix_level(params, 36, &c, &[(0, 1, 2, 3), (0, 2, 1, 3)]);
                (d[0], d[1])
            }
            5 => {
                let c = Self::softmax_mix_level(params, 24, &b, &LEVEL2_PAIRS);
                let e = Self::softmax_mix_level(params, 36, &c, &LEVEL3_PAIRS);
                let f = Self::linear_mix_level(params, 48, &e, &[(0, 1, 2, 3), (2, 3, 0, 1)]);
                (f[0], f[1])
            }
            _ => unreachable!(),
        };

        // Head weights occupy the last `2 * head_count` parameters.
        let head_base = self.param_count() - self.head_count * 2;
        (0..self.head_count)
            .map(|k| {
                let w0 = params[head_base + k * 2];
                let w1 = params[head_base + k * 2 + 1];
                let left = (w0 * t0 + (1.0 - w0) * t1).clamp(-10.0, 10.0);
                let right = (w1 * t0 + (1.0 - w1) * t1).clamp(0.01, 10.0);
                eml_safe(left, right).max(0.0)
            })
            .collect()
    }

    /// One softmax-gated mixing level of four nodes, three logits each, starting at
    /// `params[base]`. The right operand reuses the node's logits shifted by
    /// (+0.5, -0.5, 0) and is clamped to be strictly positive.
    fn softmax_mix_level(
        params: &[f64],
        base: usize,
        prev: &[f64; 4],
        pairs: &[(usize, usize, usize, usize); 4],
    ) -> [f64; 4] {
        let mut out = [0.0f64; 4];
        for (i, (node, &(li, lj, ri, rj))) in out.iter_mut().zip(pairs.iter()).enumerate() {
            let p = base + i * 3;
            let (alpha, beta, gamma) = softmax3(params[p], params[p + 1], params[p + 2]);
            let mix_left = (alpha + beta * prev[li] + gamma * prev[lj]).clamp(-10.0, 10.0);
            let (ar, br, gr) = softmax3(params[p] + 0.5, params[p + 1] - 0.5, params[p + 2]);
            let mix_right = (ar + br * prev[ri] + gr * prev[rj]).clamp(0.01, 10.0);
            *node = eml_safe(mix_left, mix_right);
        }
        out
    }

    /// The final linear mixing level: two nodes, four weights each, starting at
    /// `params[base]`. Each operand is `w0 + w1 * x + (1 - w0 - w1) * y`.
    fn linear_mix_level(
        params: &[f64],
        base: usize,
        prev: &[f64; 4],
        pairs: &[(usize, usize, usize, usize); 2],
    ) -> [f64; 2] {
        let mut out = [0.0f64; 2];
        for (i, (node, &(li, lj, ri, rj))) in out.iter_mut().zip(pairs.iter()).enumerate() {
            let p = base + i * 4;
            let mix_left = (params[p]
                + params[p + 1] * prev[li]
                + (1.0 - params[p] - params[p + 1]) * prev[lj])
                .clamp(-10.0, 10.0);
            let mix_right = (params[p + 2]
                + params[p + 3] * prev[ri]
                + (1.0 - params[p + 2] - params[p + 3]) * prev[rj])
                .clamp(0.01, 10.0);
            *node = eml_safe(mix_left, mix_right);
        }
        out
    }

    /// Input indices read by each of the 8 leaves: leaf `i` uses
    /// `(2i mod n, 2i + 1 mod n)`, so small inputs wrap around. `input_count` must be > 0.
    fn feature_pairs(input_count: usize) -> [(usize, usize); 8] {
        let mut pairs = [(0usize, 0usize); 8];
        for (i, pair) in pairs.iter_mut().enumerate() {
            *pair = ((i * 2) % input_count, (i * 2 + 1) % input_count);
        }
        pairs
    }

    /// Stores one training sample for a later `train` call.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` does not have `input_count` entries or `targets`
    /// does not have `head_count` entries.
    pub fn record(&mut self, inputs: &[f64], targets: &[Option<f64>]) {
        assert_eq!(
            inputs.len(),
            self.input_count,
            "expected {} inputs, got {}",
            self.input_count,
            inputs.len()
        );
        assert_eq!(
            targets.len(),
            self.head_count,
            "expected {} targets, got {}",
            self.head_count,
            targets.len()
        );
        self.training_data.push(TrainingPoint {
            inputs: inputs.to_vec(),
            targets: targets.to_vec(),
        });
    }

    /// Fits the parameters to the recorded samples.
    ///
    /// Returns `false` without touching the model when fewer than 50 samples
    /// were recorded. Otherwise it runs random restarts from a fixed seed,
    /// followed by coordinate-wise hill climbing. It then stores the best
    /// parameters, sets `trained` to `final MSE < 0.01`, pushes an
    /// `EmlEvent::Trained` event and returns the new `trained` flag.
    pub fn train(&mut self) -> bool {
        if self.training_data.len() < 50 {
            return false;
        }
        let param_count = self.params.len();
        let mse_before = self.evaluate_mse(&self.params);
        let mut best_params = self.params.clone();
        let mut best_mse = mse_before;

        // Phase 1: random restarts; keep whichever candidate beats the current best.
        let restart_count = if param_count > 40 { 200 } else { 100 };
        let mut rng_state: u64 = 0xDEAD_BEEF_CAFE_1234;
        for _ in 0..restart_count {
            let candidate = random_params(&mut rng_state, param_count);
            let mse = self.evaluate_mse(&candidate);
            if mse < best_mse {
                best_mse = mse;
                best_params = candidate;
            }
        }

        // Phase 2: try nudging each parameter by each delta; keep strict improvements.
        let deltas = [-0.1, -0.01, -0.001, 0.001, 0.01, 0.1];
        for _ in 0..1000 {
            let mut improved = false;
            for i in 0..param_count {
                for &delta in &deltas {
                    // Mutate in place instead of cloning the whole vector per candidate.
                    let previous = best_params[i];
                    best_params[i] = previous + delta;
                    let mse = self.evaluate_mse(&best_params);
                    if mse < best_mse {
                        best_mse = mse;
                        improved = true;
                    } else {
                        // Restore the exact old value; `-= delta` could drift by rounding.
                        best_params[i] = previous;
                    }
                }
            }
            if !improved {
                break;
            }
        }

        self.params = best_params;
        self.trained = best_mse < 0.01;
        let name = if self.model_name.is_empty() {
            format!("eml_d{}x{}x{}", self.depth, self.input_count, self.head_count)
        } else {
            self.model_name.clone()
        };
        self.event_log.push(EmlEvent::Trained {
            model_name: name,
            samples_used: self.training_data.len(),
            mse_before,
            mse_after: best_mse,
            converged: self.trained,
            param_count: self.params.len(),
        });
        self.trained
    }

    /// Weighted mean squared error of `params` over the recorded samples.
    ///
    /// Head 0 has weight 1.0 and every other head 0.3; `None` targets are
    /// skipped. Returns `f64::MAX` when there are no samples or no labelled targets.
    fn evaluate_mse(&self, params: &[f64]) -> f64 {
        if self.training_data.is_empty() {
            return f64::MAX;
        }
        let mut total_loss = 0.0;
        let mut total_weight = 0.0;
        for tp in &self.training_data {
            let predicted = self.evaluate_with_params(params, &tp.inputs);
            for (k, target) in tp.targets.iter().enumerate() {
                if let Some(t) = target {
                    let weight = if k == 0 { 1.0 } else { 0.3 };
                    total_loss += weight * (predicted[k] - t).powi(2);
                    total_weight += weight;
                }
            }
        }
        if total_weight > 0.0 {
            total_loss / total_weight
        } else {
            f64::MAX
        }
    }

    /// Trains a shallower "student" model to imitate this one.
    ///
    /// Draws `max(num_samples, 50)` input vectors with components in `[0, 1)`
    /// from a fixed-seed LCG. Each vector is labelled with this model's
    /// predictions for every head, then the student is trained. It is returned
    /// whether or not training converged.
    ///
    /// # Panics
    ///
    /// Panics if `target_depth >= self.depth()`, or if `target_depth < 2`
    /// (rejected by `EmlModel::new`).
    pub fn distill(&self, target_depth: usize, num_samples: usize) -> EmlModel {
        assert!(
            target_depth < self.depth,
            "student depth ({target_depth}) must be less than teacher depth ({})",
            self.depth
        );
        let mut student = EmlModel::new(target_depth, self.input_count, self.head_count);
        let mut rng_state: u64 = 0xCAFE_BABE_1234_5678;
        // 64-bit LCG; the top 31 bits scaled by 2^-31 give a value in [0, 1).
        let lcg_next = |state: &mut u64| -> f64 {
            *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            (*state >> 33) as f64 / (1u64 << 31) as f64
        };
        for _ in 0..num_samples.max(50) {
            let inputs: Vec<f64> = (0..self.input_count)
                .map(|_| lcg_next(&mut rng_state))
                .collect();
            let teacher_out = self.predict(&inputs);
            let targets: Vec<Option<f64>> = teacher_out.into_iter().map(Some).collect();
            student.record(&inputs, &targets);
        }
        student.train();
        student
    }

    /// Serializes the model definition (not samples, events or name) to JSON.
    pub fn to_json(&self) -> String {
        serde_json::to_string(self).expect("EmlModel serialization should not fail")
    }

    /// Parses a model produced by `to_json`.
    ///
    /// Returns `None` for malformed JSON and also for JSON that decodes but
    /// describes an impossible model. That covers a depth outside `2..=5`,
    /// zero inputs or heads, or a parameter vector whose length does not match
    /// the shape. Such models would otherwise panic on the first `predict`.
    pub fn from_json(json: &str) -> Option<Self> {
        let model: Self = serde_json::from_str(json).ok()?;
        if model.is_structurally_valid() {
            Some(model)
        } else {
            None
        }
    }

    /// Checks the invariants that `new` establishes and evaluation relies on.
    fn is_structurally_valid(&self) -> bool {
        if !(2..=5).contains(&self.depth) || self.input_count == 0 || self.head_count == 0 {
            return false;
        }
        // Checked arithmetic: `head_count` comes from untrusted JSON and may be huge.
        let expected = self
            .head_count
            .checked_mul(2)
            .and_then(|heads| heads.checked_add(Self::compute_param_count(self.depth, 0)));
        expected == Some(self.params.len())
    }
}
/// Unit tests for `EmlModel`: construction, parameter layout, prediction,
/// training, serialization and distillation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_model_defaults() {
        let m = EmlModel::new(4, 7, 3);
        assert_eq!(m.depth(), 4);
        assert_eq!(m.input_count(), 7);
        assert_eq!(m.head_count(), 3);
        assert!(!m.is_trained());
        assert_eq!(m.training_sample_count(), 0);
    }

    #[test]
    fn param_count_depth_2() {
        let m = EmlModel::new(2, 5, 1);
        assert_eq!(m.param_count(), 26);
    }

    #[test]
    fn param_count_depth_3() {
        let m = EmlModel::new(3, 7, 1);
        assert_eq!(m.param_count(), 34);
    }

    #[test]
    fn param_count_depth_4_single_head() {
        let m = EmlModel::new(4, 7, 1);
        assert_eq!(m.param_count(), 46);
    }

    #[test]
    fn param_count_depth_4_three_heads() {
        let m = EmlModel::new(4, 7, 3);
        assert_eq!(m.param_count(), 50);
    }

    #[test]
    fn param_count_depth_5() {
        let m = EmlModel::new(5, 4, 2);
        assert_eq!(m.param_count(), 60);
    }

    /// Trunk sizes per depth plus two weights per head, across several head counts.
    #[test]
    fn param_count_matches_layout_for_all_depths() {
        for (depth, trunk) in [(2, 24), (3, 32), (4, 44), (5, 56)] {
            for heads in 1..=4 {
                assert_eq!(
                    EmlModel::new(depth, 3, heads).param_count(),
                    trunk + 2 * heads,
                    "depth {depth}, {heads} heads"
                );
            }
        }
    }

    #[test]
    fn predict_untrained_produces_values() {
        let m = EmlModel::new(4, 7, 3);
        let inputs = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
        let result = m.predict(&inputs);
        assert_eq!(result.len(), 3);
        for &v in &result {
            assert!(v.is_finite(), "prediction should be finite");
            assert!(v >= 0.0, "prediction should be non-negative");
        }
    }

    #[test]
    fn predict_primary_matches_first_head() {
        let m = EmlModel::new(3, 5, 3);
        let inputs = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let all = m.predict(&inputs);
        let primary = m.predict_primary(&inputs);
        assert!(
            (primary - all[0]).abs() < 1e-12,
            "predict_primary should match predict()[0]"
        );
    }

    #[test]
    #[should_panic(expected = "expected 3 inputs, got 2")]
    fn predict_wrong_input_len_panics() {
        let m = EmlModel::new(3, 3, 1);
        m.predict(&[0.1, 0.2]);
    }

    #[test]
    fn record_increments_count() {
        let mut m = EmlModel::new(3, 3, 1);
        assert_eq!(m.training_sample_count(), 0);
        m.record(&[0.1, 0.2, 0.3], &[Some(1.0)]);
        assert_eq!(m.training_sample_count(), 1);
    }

    #[test]
    #[should_panic(expected = "expected 2 targets, got 1")]
    fn record_wrong_target_len_panics() {
        let mut m = EmlModel::new(3, 3, 2);
        m.record(&[0.1, 0.2, 0.3], &[Some(1.0)]);
    }

    #[test]
    fn train_insufficient_data_returns_false() {
        let mut m = EmlModel::new(3, 3, 1);
        for i in 0..10 {
            m.record(&[i as f64 / 10.0, 0.5, 0.5], &[Some(1.0)]);
        }
        assert!(!m.train());
        assert!(!m.is_trained());
        // The early return happens before any event is logged.
        assert_eq!(m.pending_event_count(), 0);
    }

    #[test]
    fn train_emits_one_trained_event() {
        let mut m = EmlModel::new(2, 1, 1);
        m.set_model_name("probe");
        for i in 0..50 {
            let x = i as f64 / 50.0;
            m.record(&[x], &[Some(0.5 * x)]);
        }
        let converged = m.train();
        assert_eq!(converged, m.is_trained());
        assert_eq!(m.pending_event_count(), 1);
        let events = m.drain_events();
        assert_eq!(events.len(), 1);
        match &events[0] {
            EmlEvent::Trained {
                model_name,
                samples_used,
                converged: reported,
                param_count,
                ..
            } => {
                assert_eq!(model_name.as_str(), "probe");
                assert_eq!(*samples_used, 50);
                assert_eq!(*reported, converged);
                assert_eq!(*param_count, 26);
            }
            #[allow(unreachable_patterns)]
            _ => panic!("expected EmlEvent::Trained"),
        }
    }

    #[test]
    fn training_convergence_polynomial() {
        let mut m = EmlModel::new(4, 1, 1);
        for i in 0..100 {
            let x = i as f64 / 100.0;
            let y = x * x;
            m.record(&[x], &[Some(y)]);
        }
        let _ = m.train();
        let pred = m.predict_primary(&[0.5]);
        assert!(pred.is_finite());
    }

    #[test]
    fn multi_head_training() {
        let mut m = EmlModel::new(4, 2, 3);
        for i in 0..80 {
            let x = i as f64 / 80.0;
            let y = (i + 10) as f64 / 80.0;
            m.record(&[x, y], &[Some(x + y), Some(x * y), None]);
        }
        let _ = m.train();
        let pred = m.predict(&[0.5, 0.5]);
        assert_eq!(pred.len(), 3);
        for &v in &pred {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn serialization_roundtrip() {
        let mut m = EmlModel::new(4, 5, 2);
        for (i, p) in m.params.iter_mut().enumerate() {
            *p = (i as f64 * 0.1).sin();
        }
        m.trained = true;
        let json = m.to_json();
        let m2 = EmlModel::from_json(&json).expect("should deserialize");
        assert_eq!(m.depth, m2.depth);
        assert_eq!(m.input_count, m2.input_count);
        assert_eq!(m.head_count, m2.head_count);
        assert_eq!(m.params.len(), m2.params.len());
        for (i, (a, b)) in m.params.iter().zip(m2.params.iter()).enumerate() {
            assert!((a - b).abs() < 1e-14, "param[{i}] mismatch: {a} vs {b}");
        }
        assert_eq!(m.trained, m2.trained);
        assert_eq!(m2.training_sample_count(), 0);
    }

    /// A JSON round-trip must not change what the model predicts.
    #[test]
    fn json_roundtrip_preserves_predictions() {
        let mut m = EmlModel::new(5, 3, 2);
        for (i, p) in m.params_slice_mut().iter_mut().enumerate() {
            *p = ((i as f64) * 0.23).cos() * 0.4;
        }
        let m2 = EmlModel::from_json(&m.to_json()).expect("should deserialize");
        let inputs = [0.2, 0.4, 0.9];
        for (a, b) in m.predict(&inputs).iter().zip(m2.predict(&inputs)) {
            // serde_json float parsing may differ by an ulp, so compare with a relative tolerance.
            assert!((a - b).abs() <= 1e-9 * a.abs().max(1.0), "{a} vs {b}");
        }
    }

    #[test]
    fn from_json_invalid_returns_none() {
        assert!(EmlModel::from_json("not valid json").is_none());
    }

    #[test]
    fn various_depths_produce_finite_output() {
        for depth in 2..=5 {
            let m = EmlModel::new(depth, 4, 2);
            let inputs = vec![0.3, 0.5, 0.7, 0.1];
            let result = m.predict(&inputs);
            assert_eq!(result.len(), 2);
            for &v in &result {
                assert!(v.is_finite(), "depth-{depth} should produce finite output");
                assert!(v >= 0.0, "depth-{depth} should produce non-negative output");
            }
        }
    }

    #[test]
    #[should_panic(expected = "EmlModel depth must be 2, 3, 4, or 5")]
    fn invalid_depth_panics() {
        EmlModel::new(6, 3, 1);
    }

    #[test]
    #[should_panic(expected = "head_count must be >= 1")]
    fn zero_heads_panics() {
        EmlModel::new(3, 3, 0);
    }

    #[test]
    fn distill_depth_4_to_depth_2() {
        let mut teacher = EmlModel::new(4, 2, 1);
        for (i, p) in teacher.params.iter_mut().enumerate() {
            *p = ((i as f64) * 0.37).sin() * 0.5;
        }
        teacher.trained = true;
        let student = teacher.distill(2, 500);
        assert_eq!(student.depth(), 2);
        assert_eq!(student.input_count(), 2);
        assert_eq!(student.head_count(), 1);
        let mut total_err = 0.0;
        let mut count = 0;
        for i in 0..10 {
            for j in 0..10 {
                let x = i as f64 / 10.0;
                let y = j as f64 / 10.0;
                let t = teacher.predict_primary(&[x, y]);
                let s = student.predict_primary(&[x, y]);
                assert!(t.is_finite());
                assert!(s.is_finite());
                total_err += (t - s).abs();
                count += 1;
            }
        }
        let mae = total_err / count as f64;
        assert!(mae < 50.0, "distilled model MAE should be reasonable, got {mae}");
    }

    #[test]
    fn distill_multi_head() {
        let mut teacher = EmlModel::new(4, 2, 2);
        for i in 0..100 {
            let x = i as f64 / 100.0;
            let y = (i + 20) as f64 / 100.0;
            teacher.record(&[x, y], &[Some(x + y), Some(x * y)]);
        }
        teacher.train();
        let student = teacher.distill(2, 200);
        assert_eq!(student.depth(), 2);
        assert_eq!(student.head_count(), 2);
        let pred = student.predict(&[0.5, 0.7]);
        assert_eq!(pred.len(), 2);
        for &v in &pred {
            assert!(v.is_finite());
        }
    }

    #[test]
    #[should_panic(expected = "student depth")]
    fn distill_same_depth_panics() {
        let teacher = EmlModel::new(4, 3, 1);
        teacher.distill(4, 100);
    }
}