1use serde::{Deserialize, Serialize};
8
9use crate::events::{EmlEvent, EmlEventLog};
10use crate::operator::{eml_safe, random_params, softmax3};
11
/// One recorded training sample: an input vector and one optional target per head.
#[derive(Debug, Clone)]
struct TrainingPoint {
    /// Input values for this sample (`record` checks the length against `input_count`).
    inputs: Vec<f64>,
    /// Per-head targets; a `None` entry is skipped when the loss is computed.
    targets: Vec<Option<f64>>,
}
22
/// A fixed-topology model with `depth` levels, `input_count` inputs and
/// `head_count` outputs, all driven by one flat parameter vector.
///
/// Only `depth`, `input_count`, `head_count`, `params` and `trained` are
/// serialized. The `#[serde(skip)]` fields are omitted from the output and
/// are filled with their `Default` value when a model is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmlModel {
    /// Number of levels in the network; `new` only accepts 2 through 5.
    depth: usize,
    /// Number of input values expected by `predict` and `record`.
    input_count: usize,
    /// Number of output heads (one prediction per head).
    head_count: usize,
    /// Flat parameter vector; its length is fixed by `depth` and `head_count`.
    params: Vec<f64>,
    /// Convergence flag, set by `train` or overridden through `mark_trained`.
    trained: bool,
    /// Samples accumulated by `record`; not serialized.
    #[serde(skip)]
    training_data: Vec<TrainingPoint>,
    /// Events waiting to be collected through `drain_events`; not serialized.
    #[serde(skip)]
    event_log: EmlEventLog,
    /// Optional label used in emitted events; not serialized.
    #[serde(skip)]
    model_name: String,
}
69
70impl EmlModel {
71 pub fn new(depth: usize, input_count: usize, head_count: usize) -> Self {
81 assert!(
82 (2..=5).contains(&depth),
83 "EmlModel depth must be 2, 3, 4, or 5, got {depth}"
84 );
85 assert!(head_count > 0, "head_count must be >= 1");
86
87 let param_count = Self::compute_param_count(depth, head_count);
88 Self {
89 depth,
90 input_count,
91 head_count,
92 params: vec![0.0; param_count],
93 trained: false,
94 training_data: Vec::new(),
95 event_log: EmlEventLog::new(),
96 model_name: String::new(),
97 }
98 }
99
/// Returns the length of the parameter vector.
pub fn param_count(&self) -> usize {
    self.params.len()
}
104
105 pub fn params_slice(&self) -> &[f64] {
111 &self.params
112 }
113
114 pub fn params_slice_mut(&mut self) -> &mut [f64] {
120 &mut self.params
121 }
122
/// Overwrites the `trained` flag with the given value.
pub fn mark_trained(&mut self, trained: bool) {
    self.trained = trained;
}
128
/// Returns the current value of the `trained` flag.
pub fn is_trained(&self) -> bool {
    self.trained
}
133
/// Returns how many samples have been stored by `record`.
pub fn training_sample_count(&self) -> usize {
    self.training_data.len()
}
138
/// Returns the model depth.
pub fn depth(&self) -> usize {
    self.depth
}
143
/// Returns the number of inputs the model expects.
pub fn input_count(&self) -> usize {
    self.input_count
}
148
/// Returns the number of output heads.
pub fn head_count(&self) -> usize {
    self.head_count
}
153
/// Sets the model name. `train` uses it as the `model_name` of the event it
/// emits; when the name is empty `train` generates one instead.
pub fn set_model_name(&mut self, name: impl Into<String>) {
    self.model_name = name.into();
}
164
165 pub fn model_name(&self) -> &str {
167 &self.model_name
168 }
169
/// Returns the result of `EmlEventLog::drain` on this model's event log.
// NOTE(review): presumably this removes and returns every pending event —
// confirm against `EmlEventLog::drain`.
pub fn drain_events(&mut self) -> Vec<EmlEvent> {
    self.event_log.drain()
}
177
/// Pushes `event` onto this model's event log.
pub fn push_event(&mut self, event: EmlEvent) {
    self.event_log.push(event);
}
182
/// Returns the current length of this model's event log.
pub fn pending_event_count(&self) -> usize {
    self.event_log.len()
}
187
/// Computes the length of the parameter vector for a given topology.
///
/// The layout is: 24 parameters for the first level, then zero or more
/// hidden groups depending on `depth`, then two parameters per head.
///
/// `depth` must already be validated to lie in `2..=5`.
fn compute_param_count(depth: usize, head_count: usize) -> usize {
    // Group sizes used by the hidden levels.
    const FIRST_LEVEL: usize = 24;
    const FOUR_BY_THREE: usize = 4 * 3;
    const TWO_BY_FOUR: usize = 2 * 4;

    let hidden = match depth {
        2 => 0,
        3 => TWO_BY_FOUR,
        4 => FOUR_BY_THREE + TWO_BY_FOUR,
        5 => FOUR_BY_THREE + FOUR_BY_THREE + TWO_BY_FOUR,
        _ => unreachable!(),
    };

    FIRST_LEVEL + hidden + head_count * 2
}
239
240 pub fn predict(&self, inputs: &[f64]) -> Vec<f64> {
249 assert_eq!(
250 inputs.len(),
251 self.input_count,
252 "expected {} inputs, got {}",
253 self.input_count,
254 inputs.len()
255 );
256 self.evaluate_with_params(&self.params, inputs)
257 }
258
259 pub fn predict_primary(&self, inputs: &[f64]) -> f64 {
261 self.predict(inputs)[0]
262 }
263
264 fn evaluate_with_params(&self, params: &[f64], inputs: &[f64]) -> Vec<f64> {
266 let feature_pairs = Self::feature_pairs(self.input_count);
268 let mut a = [0.0f64; 8];
269 for i in 0..8 {
270 let base = i * 3;
271 let (alpha, beta, gamma) = softmax3(params[base], params[base + 1], params[base + 2]);
272 let (j, k) = feature_pairs[i];
273 a[i] = (alpha + beta * inputs[j] + gamma * inputs[k]).clamp(-10.0, 10.0);
274 }
275
276 let b = [
278 eml_safe(a[0], a[1]),
279 eml_safe(a[2], a[3]),
280 eml_safe(a[4], a[5]),
281 eml_safe(a[6], a[7]),
282 ];
283
284 let trunk = match self.depth {
286 2 => {
287 b.to_vec()
289 }
290 3 => {
291 let mut c = [0.0f64; 2];
293 for i in 0..2 {
294 let base = 24 + i * 4;
295 let mix_left = params[base]
296 + params[base + 1] * b[0]
297 + (1.0 - params[base] - params[base + 1]) * b[1];
298 let mix_right = params[base + 2]
299 + params[base + 3] * b[2]
300 + (1.0 - params[base + 2] - params[base + 3]) * b[3];
301 let ml = mix_left.clamp(-10.0, 10.0);
302 let mr = mix_right.clamp(0.01, 10.0);
303 c[i] = eml_safe(ml, mr);
304 }
305 c.to_vec()
306 }
307 4 => {
308 let level2_pairs: [(usize, usize, usize, usize); 4] = [
310 (0, 1, 2, 3),
311 (0, 1, 2, 3),
312 (0, 2, 1, 3),
313 (1, 3, 0, 2),
314 ];
315 let mut c = [0.0f64; 4];
316 for i in 0..4 {
317 let base = 24 + i * 3;
318 let (li, lj, ri, rj) = level2_pairs[i];
319 let (alpha, beta, gamma) =
320 softmax3(params[base], params[base + 1], params[base + 2]);
321 let mix_left = (alpha + beta * b[li] + gamma * b[lj]).clamp(-10.0, 10.0);
322 let (ar, br, gr) = softmax3(
323 params[base] + 0.5,
324 params[base + 1] - 0.5,
325 params[base + 2],
326 );
327 let mix_right = (ar + br * b[ri] + gr * b[rj]).clamp(0.01, 10.0);
328 c[i] = eml_safe(mix_left, mix_right);
329 }
330
331 let level3_pairs: [(usize, usize, usize, usize); 2] =
333 [(0, 1, 2, 3), (0, 2, 1, 3)];
334 let mut d = [0.0f64; 2];
335 for i in 0..2 {
336 let base = 36 + i * 4;
337 let (li, lj, ri, rj) = level3_pairs[i];
338 let mix_left = (params[base]
339 + params[base + 1] * c[li]
340 + (1.0 - params[base] - params[base + 1]) * c[lj])
341 .clamp(-10.0, 10.0);
342 let mix_right = (params[base + 2]
343 + params[base + 3] * c[ri]
344 + (1.0 - params[base + 2] - params[base + 3]) * c[rj])
345 .clamp(0.01, 10.0);
346 d[i] = eml_safe(mix_left, mix_right);
347 }
348 d.to_vec()
349 }
350 5 => {
351 let level2_pairs: [(usize, usize, usize, usize); 4] = [
353 (0, 1, 2, 3),
354 (0, 1, 2, 3),
355 (0, 2, 1, 3),
356 (1, 3, 0, 2),
357 ];
358 let mut c = [0.0f64; 4];
359 for i in 0..4 {
360 let base = 24 + i * 3;
361 let (li, lj, ri, rj) = level2_pairs[i];
362 let (alpha, beta, gamma) =
363 softmax3(params[base], params[base + 1], params[base + 2]);
364 let mix_left = (alpha + beta * b[li] + gamma * b[lj]).clamp(-10.0, 10.0);
365 let (ar, br, gr) = softmax3(
366 params[base] + 0.5,
367 params[base + 1] - 0.5,
368 params[base + 2],
369 );
370 let mix_right = (ar + br * b[ri] + gr * b[rj]).clamp(0.01, 10.0);
371 c[i] = eml_safe(mix_left, mix_right);
372 }
373
374 let level3_pairs: [(usize, usize, usize, usize); 4] = [
376 (0, 1, 2, 3),
377 (0, 2, 1, 3),
378 (1, 3, 0, 2),
379 (0, 3, 1, 2),
380 ];
381 let mut e = [0.0f64; 4];
382 for i in 0..4 {
383 let base = 36 + i * 3;
384 let (li, lj, ri, rj) = level3_pairs[i];
385 let (alpha, beta, gamma) =
386 softmax3(params[base], params[base + 1], params[base + 2]);
387 let mix_left = (alpha + beta * c[li] + gamma * c[lj]).clamp(-10.0, 10.0);
388 let (ar, br, gr) = softmax3(
389 params[base] + 0.5,
390 params[base + 1] - 0.5,
391 params[base + 2],
392 );
393 let mix_right = (ar + br * c[ri] + gr * c[rj]).clamp(0.01, 10.0);
394 e[i] = eml_safe(mix_left, mix_right);
395 }
396
397 let mut f = [0.0f64; 2];
399 for i in 0..2 {
400 let base = 48 + i * 4;
401 let li = i * 2;
402 let lj = i * 2 + 1;
403 let ri = (i * 2 + 2) % 4;
404 let rj = (i * 2 + 3) % 4;
405 let mix_left = (params[base]
406 + params[base + 1] * e[li]
407 + (1.0 - params[base] - params[base + 1]) * e[lj])
408 .clamp(-10.0, 10.0);
409 let mix_right = (params[base + 2]
410 + params[base + 3] * e[ri]
411 + (1.0 - params[base + 2] - params[base + 3]) * e[rj])
412 .clamp(0.01, 10.0);
413 f[i] = eml_safe(mix_left, mix_right);
414 }
415 f.to_vec()
416 }
417 _ => unreachable!(),
418 };
419
420 let head_base = self.param_count() - self.head_count * 2;
422 let mut outputs = Vec::with_capacity(self.head_count);
423 for k in 0..self.head_count {
424 let base = head_base + k * 2;
425 let w0 = params[base];
426 let w1 = params[base + 1];
427 let (left, right) = if trunk.len() >= 2 {
428 (
429 (w0 * trunk[0] + (1.0 - w0) * trunk[1]).clamp(-10.0, 10.0),
430 (w1 * trunk[0] + (1.0 - w1) * trunk[1]).clamp(0.01, 10.0),
431 )
432 } else {
433 (
434 (w0 * trunk[0]).clamp(-10.0, 10.0),
435 (w1 * trunk[0]).clamp(0.01, 10.0),
436 )
437 };
438 outputs.push(eml_safe(left, right).max(0.0));
439 }
440
441 outputs
442 }
443
/// Builds the input wiring for the eight first-level nodes.
///
/// Node `i` reads inputs `2 * i` and `2 * i + 1`, both wrapped modulo
/// `input_count`, so models with fewer than 16 inputs reuse inputs.
///
/// # Panics
///
/// Panics (remainder by zero) if `input_count` is zero.
fn feature_pairs(input_count: usize) -> [(usize, usize); 8] {
    let mut wiring = [(0usize, 0usize); 8];
    for (node, slot) in wiring.iter_mut().enumerate() {
        let even = node * 2;
        let odd = node * 2 + 1;
        *slot = (even % input_count, odd % input_count);
    }
    wiring
}
455
456 pub fn record(&mut self, inputs: &[f64], targets: &[Option<f64>]) {
468 assert_eq!(
469 inputs.len(),
470 self.input_count,
471 "expected {} inputs, got {}",
472 self.input_count,
473 inputs.len()
474 );
475 assert_eq!(
476 targets.len(),
477 self.head_count,
478 "expected {} targets, got {}",
479 self.head_count,
480 targets.len()
481 );
482 self.training_data.push(TrainingPoint {
483 inputs: inputs.to_vec(),
484 targets: targets.to_vec(),
485 });
486 }
487
488 pub fn train(&mut self) -> bool {
493 if self.training_data.len() < 50 {
494 return false;
495 }
496
497 let param_count = self.params.len();
498 let mut best_params = self.params.clone();
499 let mse_before = self.evaluate_mse(&self.params);
500 let mut best_mse = mse_before;
501
502 let restart_count = if param_count > 40 { 200 } else { 100 };
504 let mut rng_state: u64 = 0xDEAD_BEEF_CAFE_1234;
505 for _ in 0..restart_count {
506 let candidate = random_params(&mut rng_state, param_count);
507 let mse = self.evaluate_mse(&candidate);
508 if mse < best_mse {
509 best_mse = mse;
510 best_params = candidate;
511 }
512 }
513
514 let deltas = [-0.1, -0.01, -0.001, 0.001, 0.01, 0.1];
516 for _ in 0..1000 {
517 let mut improved = false;
518 for i in 0..param_count {
519 for &delta in &deltas {
520 let mut candidate = best_params.clone();
521 candidate[i] += delta;
522 let mse = self.evaluate_mse(&candidate);
523 if mse < best_mse {
524 best_mse = mse;
525 best_params = candidate;
526 improved = true;
527 }
528 }
529 }
530 if !improved {
531 break;
532 }
533 }
534
535 self.params = best_params;
536 self.trained = best_mse < 0.01;
537
538 let name = if self.model_name.is_empty() {
540 format!("eml_d{}x{}x{}", self.depth, self.input_count, self.head_count)
541 } else {
542 self.model_name.clone()
543 };
544 self.event_log.push(EmlEvent::Trained {
545 model_name: name,
546 samples_used: self.training_data.len(),
547 mse_before,
548 mse_after: best_mse,
549 converged: self.trained,
550 param_count: self.params.len(),
551 });
552
553 self.trained
554 }
555
556 fn evaluate_mse(&self, params: &[f64]) -> f64 {
558 if self.training_data.is_empty() {
559 return f64::MAX;
560 }
561
562 let mut total_loss = 0.0;
563 let mut total_weight = 0.0;
564
565 for tp in &self.training_data {
566 let predicted = self.evaluate_with_params(params, &tp.inputs);
567 for (k, target) in tp.targets.iter().enumerate() {
568 if let Some(t) = target {
569 let weight = if k == 0 { 1.0 } else { 0.3 };
571 total_loss += weight * (predicted[k] - t).powi(2);
572 total_weight += weight;
573 }
574 }
575 }
576
577 if total_weight > 0.0 {
578 total_loss / total_weight
579 } else {
580 f64::MAX
581 }
582 }
583
584 pub fn distill(&self, target_depth: usize, num_samples: usize) -> EmlModel {
602 assert!(
603 target_depth < self.depth,
604 "student depth ({target_depth}) must be less than teacher depth ({})",
605 self.depth
606 );
607
608 let mut student = EmlModel::new(target_depth, self.input_count, self.head_count);
609
610 let mut rng_state: u64 = 0xCAFE_BABE_1234_5678;
613 let lcg_next = |state: &mut u64| -> f64 {
614 *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
615 (*state >> 33) as f64 / (1u64 << 31) as f64
617 };
618
619 for _ in 0..num_samples.max(50) {
620 let inputs: Vec<f64> = (0..self.input_count)
621 .map(|_| lcg_next(&mut rng_state))
622 .collect();
623 let teacher_out = self.predict(&inputs);
624 let targets: Vec<Option<f64>> = teacher_out.into_iter().map(Some).collect();
625 student.record(&inputs, &targets);
626 }
627
628 student.train();
629 student
630 }
631
/// Serializes the model to a JSON string.
///
/// Fields marked `#[serde(skip)]` (training data, event log, model name)
/// are not included.
///
/// # Panics
///
/// Panics if `serde_json::to_string` returns an error.
pub fn to_json(&self) -> String {
    serde_json::to_string(self).expect("EmlModel serialization should not fail")
}
640
641 pub fn from_json(json: &str) -> Option<Self> {
645 serde_json::from_str(json).ok()
646 }
647}
648
649#[cfg(test)]
650mod tests {
651 use super::*;
652
/// A fresh model reports its constructor arguments and starts untrained
/// with no recorded samples.
#[test]
fn new_model_defaults() {
    let model = EmlModel::new(4, 7, 3);

    // Shape accessors echo the constructor arguments.
    assert_eq!(model.depth(), 4);
    assert_eq!(model.input_count(), 7);
    assert_eq!(model.head_count(), 3);

    // Initial training state.
    let trained = model.is_trained();
    assert!(!trained);
    assert_eq!(model.training_sample_count(), 0);
}
662
/// Depth 2, one head: 24 first-level parameters + 2 head parameters.
#[test]
fn param_count_depth_2() {
    let m = EmlModel::new(2, 5, 1);
    assert_eq!(m.param_count(), 26);
}
669
/// Depth 3, one head: 24 + 8 hidden parameters + 2 head parameters.
#[test]
fn param_count_depth_3() {
    let m = EmlModel::new(3, 7, 1);
    assert_eq!(m.param_count(), 34);
}
676
/// Depth 4, one head: 24 + 12 + 8 hidden parameters + 2 head parameters.
#[test]
fn param_count_depth_4_single_head() {
    let m = EmlModel::new(4, 7, 1);
    assert_eq!(m.param_count(), 46);
}
683
/// Depth 4, three heads: 24 + 12 + 8 hidden parameters + 6 head parameters.
#[test]
fn param_count_depth_4_three_heads() {
    let m = EmlModel::new(4, 7, 3);
    assert_eq!(m.param_count(), 50);
}
690
/// Depth 5, two heads: 24 + 12 + 12 + 8 hidden parameters + 4 head parameters.
#[test]
fn param_count_depth_5() {
    let m = EmlModel::new(5, 4, 2);
    assert_eq!(m.param_count(), 60);
}
697
/// An all-zero (untrained) model still yields one finite, non-negative
/// value per head.
#[test]
fn predict_untrained_produces_values() {
    let model = EmlModel::new(4, 7, 3);
    let outputs = model.predict(&[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]);

    assert_eq!(outputs.len(), 3);
    for value in outputs.iter().copied() {
        assert!(value.is_finite(), "prediction should be finite");
        assert!(value >= 0.0, "prediction should be non-negative");
    }
}
709
/// `predict_primary` agrees with the first element of `predict`.
#[test]
fn predict_primary_matches_first_head() {
    let model = EmlModel::new(3, 5, 3);
    let sample = [0.1, 0.2, 0.3, 0.4, 0.5];

    let first_head = model.predict(&sample)[0];
    let primary = model.predict_primary(&sample);

    let gap = (primary - first_head).abs();
    assert!(gap < 1e-12, "predict_primary should match predict()[0]");
}
721
/// Recording one sample raises the sample count from 0 to 1.
#[test]
fn record_increments_count() {
    let mut model = EmlModel::new(3, 3, 1);
    let before = model.training_sample_count();
    assert_eq!(before, 0);

    model.record(&[0.1, 0.2, 0.3], &[Some(1.0)]);

    let after = model.training_sample_count();
    assert_eq!(after, 1);
}
729
/// With only 10 samples (fewer than the 50 required) `train` refuses to
/// run and the model stays untrained.
#[test]
fn train_insufficient_data_returns_false() {
    let mut model = EmlModel::new(3, 3, 1);
    for step in 0..10 {
        let x = step as f64 / 10.0;
        model.record(&[x, 0.5, 0.5], &[Some(1.0)]);
    }

    let converged = model.train();
    assert!(!converged);
    assert!(!model.is_trained());
}
742
/// Training on samples of `y = x * x` completes and leaves a model whose
/// prediction is finite. Convergence itself is not asserted.
#[test]
fn training_convergence_polynomial() {
    let mut model = EmlModel::new(4, 1, 1);
    for x in (0..100).map(|i| i as f64 / 100.0) {
        model.record(&[x], &[Some(x * x)]);
    }

    let _ = model.train();

    let midpoint = model.predict_primary(&[0.5]);
    assert!(midpoint.is_finite());
}
757
/// Training with a head that never receives a target (`None`) still
/// produces three finite outputs.
#[test]
fn multi_head_training() {
    let mut model = EmlModel::new(4, 2, 3);
    for step in 0..80 {
        let (x, y) = (step as f64 / 80.0, (step + 10) as f64 / 80.0);
        let targets = [Some(x + y), Some(x * y), None];
        model.record(&[x, y], &targets);
    }

    let _ = model.train();

    let outputs = model.predict(&[0.5, 0.5]);
    assert_eq!(outputs.len(), 3);
    for value in outputs.iter() {
        assert!(value.is_finite());
    }
}
776
/// JSON round-trip preserves shape, parameters and the trained flag, and
/// the restored model has no training samples.
#[test]
fn serialization_roundtrip() {
    let mut original = EmlModel::new(4, 5, 2);
    // Distinct, non-trivial values so a mix-up in ordering would show.
    for (index, slot) in original.params.iter_mut().enumerate() {
        *slot = (index as f64 * 0.1).sin();
    }
    original.trained = true;

    let restored = EmlModel::from_json(&original.to_json()).expect("should deserialize");

    assert_eq!(original.depth, restored.depth);
    assert_eq!(original.input_count, restored.input_count);
    assert_eq!(original.head_count, restored.head_count);
    assert_eq!(original.params.len(), restored.params.len());
    for (i, (a, b)) in original.params.iter().zip(restored.params.iter()).enumerate() {
        assert!((a - b).abs() < 1e-14, "param[{i}] mismatch: {a} vs {b}");
    }
    assert_eq!(original.trained, restored.trained);
    assert_eq!(restored.training_sample_count(), 0);
}
803
/// Text that is not JSON at all yields `None` rather than a panic.
#[test]
fn from_json_invalid_returns_none() {
    assert!(EmlModel::from_json("not valid json").is_none());
}
808
/// Every supported depth yields two finite outputs for a two-head model.
#[test]
fn various_depths_produce_finite_output() {
    let sample = [0.3, 0.5, 0.7, 0.1];
    for depth in 2..=5 {
        let outputs = EmlModel::new(depth, 4, 2).predict(&sample);
        assert_eq!(outputs.len(), 2);
        for value in outputs.iter() {
            assert!(
                value.is_finite(),
                "depth-{depth} should produce finite output"
            );
        }
    }
}
824
/// A depth above the supported range is rejected by the constructor.
#[test]
#[should_panic(expected = "EmlModel depth must be 2, 3, 4, or 5")]
fn invalid_depth_panics() {
    EmlModel::new(6, 3, 1);
}
830
/// A model with no output heads is rejected by the constructor.
#[test]
#[should_panic(expected = "head_count must be >= 1")]
fn zero_heads_panics() {
    EmlModel::new(3, 3, 0);
}
836
/// Distilling a depth-4 teacher into a depth-2 student keeps the shape and
/// stays within a loose mean-absolute-error bound on a 10x10 input grid.
#[test]
fn distill_depth_4_to_depth_2() {
    let mut teacher = EmlModel::new(4, 2, 1);
    // Hand-set parameters so the teacher is not the all-zero model.
    for (index, slot) in teacher.params.iter_mut().enumerate() {
        *slot = ((index as f64) * 0.37).sin() * 0.5;
    }
    teacher.trained = true;

    let student = teacher.distill(2, 500);
    assert_eq!(student.depth(), 2);
    assert_eq!(student.input_count(), 2);
    assert_eq!(student.head_count(), 1);

    // Compare teacher and student on a regular grid over [0, 1) x [0, 1).
    let grid = (0..10).flat_map(|i| (0..10).map(move |j| (i as f64 / 10.0, j as f64 / 10.0)));
    let mut total_err = 0.0;
    let mut count = 0;
    for (x, y) in grid {
        let teacher_out = teacher.predict_primary(&[x, y]);
        let student_out = student.predict_primary(&[x, y]);
        assert!(teacher_out.is_finite());
        assert!(student_out.is_finite());
        total_err += (teacher_out - student_out).abs();
        count += 1;
    }
    let mae = total_err / count as f64;

    assert!(
        mae < 50.0,
        "distilled model MAE should be reasonable, got {mae}"
    );
}
881
/// Distillation carries the head count over to the student, and the
/// student produces finite values for every head.
#[test]
fn distill_multi_head() {
    let mut teacher = EmlModel::new(4, 2, 2);
    for step in 0..100 {
        let (x, y) = (step as f64 / 100.0, (step + 20) as f64 / 100.0);
        teacher.record(&[x, y], &[Some(x + y), Some(x * y)]);
    }
    teacher.train();

    let student = teacher.distill(2, 200);
    assert_eq!(student.depth(), 2);
    assert_eq!(student.head_count(), 2);

    let outputs = student.predict(&[0.5, 0.7]);
    assert_eq!(outputs.len(), 2);
    for value in outputs.iter() {
        assert!(value.is_finite());
    }
}
903
/// The student must be strictly shallower than the teacher; an equal
/// depth triggers the assertion in `distill`.
#[test]
#[should_panic(expected = "student depth")]
fn distill_same_depth_panics() {
    let teacher = EmlModel::new(4, 3, 1);
    teacher.distill(4, 100);
}
910}