Skip to main content

eml_core/
model.rs

//! Multi-head EML model with training via coordinate descent.
//!
//! [`EmlModel`] is a generic, domain-agnostic learned function that maps
//! N input features to M output heads. It uses the EML operator tree
//! internally and trains via random restart + coordinate descent.

7use serde::{Deserialize, Serialize};
8
9use crate::events::{EmlEvent, EmlEventLog};
10use crate::operator::{eml_safe, random_params, softmax3};
11
// ---------------------------------------------------------------------------
// Training point (internal)
// ---------------------------------------------------------------------------

/// One buffered training example: a feature vector and its per-head targets.
///
/// `targets` holds one entry per output head; `None` marks a head with no
/// ground truth for this example, which the loss skips.
#[derive(Clone, Debug)]
struct TrainingPoint {
    // Feature values, one per model input.
    inputs: Vec<f64>,
    // Optional target value for each output head.
    targets: Vec<Option<f64>>,
}
22
23// ---------------------------------------------------------------------------
24// EmlModel
25// ---------------------------------------------------------------------------
26
27/// Multi-head EML model for O(1) function approximation.
28///
29/// # Architecture
30///
31/// The model uses a shared trunk of EML operators that feeds into
32/// multiple output heads. Each head produces one scalar prediction.
33///
34/// ```text
35/// Level 0: 8 affine combinations of input features (24 params)
36/// Level 1: 4 EML nodes (no params — pure EML pairing)
37/// Level 2: mixing + EML (depth-dependent params)
38/// ...
39/// Level D: multi-head output (2 params per head)
40/// ```
41///
42/// Supported depths: 2, 3, 4, 5.
43///
44/// # Training
45///
46/// Training uses gradient-free random restart + coordinate descent,
47/// suitable for the modest parameter counts (typically 30-80 params).
48/// Call [`record`] to accumulate training data, then [`train`] to
49/// optimize parameters.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct EmlModel {
52    depth: usize,
53    input_count: usize,
54    head_count: usize,
55    /// Trainable parameters.
56    params: Vec<f64>,
57    /// Whether the model has been trained to convergence.
58    trained: bool,
59    /// Training data buffer.
60    #[serde(skip)]
61    training_data: Vec<TrainingPoint>,
62    /// Accumulated lifecycle events for ExoChain logging.
63    #[serde(skip)]
64    event_log: EmlEventLog,
65    /// Model name used in event logging (set by the wrapper).
66    #[serde(skip)]
67    model_name: String,
68}
69
70impl EmlModel {
71    /// Create a new untrained EML model.
72    ///
73    /// # Arguments
74    /// - `depth`: Tree depth (2, 3, 4, or 5).
75    /// - `input_count`: Number of input features.
76    /// - `head_count`: Number of output heads (>= 1).
77    ///
78    /// # Panics
79    /// Panics if depth is not in {2, 3, 4, 5} or head_count is 0.
80    pub fn new(depth: usize, input_count: usize, head_count: usize) -> Self {
81        assert!(
82            (2..=5).contains(&depth),
83            "EmlModel depth must be 2, 3, 4, or 5, got {depth}"
84        );
85        assert!(head_count > 0, "head_count must be >= 1");
86
87        let param_count = Self::compute_param_count(depth, head_count);
88        Self {
89            depth,
90            input_count,
91            head_count,
92            params: vec![0.0; param_count],
93            trained: false,
94            training_data: Vec::new(),
95            event_log: EmlEventLog::new(),
96            model_name: String::new(),
97        }
98    }
99
100    /// Total number of trainable parameters.
101    pub fn param_count(&self) -> usize {
102        self.params.len()
103    }
104
105    /// Read-only view of the trainable parameters.
106    ///
107    /// Intended for composed models (e.g., [`crate::ToyEmlAttention`]) that
108    /// need to run coordinate descent over the union of several `EmlModel`s'
109    /// parameters. Prefer [`Self::train`] for single-model training.
110    pub fn params_slice(&self) -> &[f64] {
111        &self.params
112    }
113
114    /// Mutable view of the trainable parameters.
115    ///
116    /// Intended for composed models running joint coordinate descent.
117    /// Callers are responsible for restoring parameters they perturb if a
118    /// candidate is rejected.
119    pub fn params_slice_mut(&mut self) -> &mut [f64] {
120        &mut self.params
121    }
122
123    /// Mark the model as trained (or not). Used by composed models after
124    /// joint coordinate descent converges.
125    pub fn mark_trained(&mut self, trained: bool) {
126        self.trained = trained;
127    }
128
129    /// Whether the model has been trained to convergence.
130    pub fn is_trained(&self) -> bool {
131        self.trained
132    }
133
134    /// Number of training samples collected so far.
135    pub fn training_sample_count(&self) -> usize {
136        self.training_data.len()
137    }
138
139    /// Tree depth.
140    pub fn depth(&self) -> usize {
141        self.depth
142    }
143
144    /// Number of input features.
145    pub fn input_count(&self) -> usize {
146        self.input_count
147    }
148
149    /// Number of output heads.
150    pub fn head_count(&self) -> usize {
151        self.head_count
152    }
153
154    // -------------------------------------------------------------------
155    // Event logging
156    // -------------------------------------------------------------------
157
158    /// Set the model name used in emitted events.
159    ///
160    /// Should be called once by the domain-specific wrapper after creation.
161    pub fn set_model_name(&mut self, name: impl Into<String>) {
162        self.model_name = name.into();
163    }
164
165    /// Get the model name.
166    pub fn model_name(&self) -> &str {
167        &self.model_name
168    }
169
170    /// Drain all accumulated lifecycle events, returning them.
171    ///
172    /// The caller is responsible for forwarding these to the ExoChain
173    /// or other audit sinks.
174    pub fn drain_events(&mut self) -> Vec<EmlEvent> {
175        self.event_log.drain()
176    }
177
178    /// Push a custom event into the event log.
179    pub fn push_event(&mut self, event: EmlEvent) {
180        self.event_log.push(event);
181    }
182
183    /// Number of pending (undrained) events.
184    pub fn pending_event_count(&self) -> usize {
185        self.event_log.len()
186    }
187
188    // -------------------------------------------------------------------
189    // Parameter count
190    // -------------------------------------------------------------------
191
192    /// Compute total parameter count for trunk + heads.
193    ///
194    /// Trunk param layout (same as the depth-4 coherence model):
195    ///   Level 0: 8 * 3 = 24 (affine combos via softmax3)
196    ///   Level 1: 0 (pure EML pairing)
197    ///   Level 2: 4 * 3 = 12 (mixing via softmax3)
198    ///   Level 3: 2 * 4 = 8 (mixing with 4 weights each)
199    ///   Head layer: head_count * 2
200    ///
201    /// For shallower trees, fewer mixing levels.
202    fn compute_param_count(depth: usize, head_count: usize) -> usize {
203        // Level 0: always 8 affine nodes * 3 params
204        let mut total = 24;
205
206        // Level 1: no params (pure EML)
207
208        // Levels 2..depth-1: mixing
209        match depth {
210            2 => {
211                // Only level 0 + heads
212            }
213            3 => {
214                // Level 2: 2 mixing nodes * 4 params
215                total += 2 * 4;
216            }
217            4 => {
218                // Level 2: 4 mixing nodes * 3 params
219                total += 4 * 3;
220                // Level 3: 2 mixing nodes * 4 params
221                total += 2 * 4;
222            }
223            5 => {
224                // Level 2: 4 mixing nodes * 3 params
225                total += 4 * 3;
226                // Level 3: 4 mixing nodes * 3 params
227                total += 4 * 3;
228                // Level 4: 2 mixing nodes * 4 params
229                total += 2 * 4;
230            }
231            _ => unreachable!(),
232        }
233
234        // Head layer: 2 params per head
235        total += head_count * 2;
236
237        total
238    }
239
240    // -------------------------------------------------------------------
241    // Prediction
242    // -------------------------------------------------------------------
243
244    /// Predict all heads from input features.
245    ///
246    /// Returns a Vec with one f64 per head. Values are clamped to be
247    /// non-negative.
248    pub fn predict(&self, inputs: &[f64]) -> Vec<f64> {
249        assert_eq!(
250            inputs.len(),
251            self.input_count,
252            "expected {} inputs, got {}",
253            self.input_count,
254            inputs.len()
255        );
256        self.evaluate_with_params(&self.params, inputs)
257    }
258
259    /// Predict only the primary (first) head.
260    pub fn predict_primary(&self, inputs: &[f64]) -> f64 {
261        self.predict(inputs)[0]
262    }
263
264    /// Evaluate with arbitrary params (used during training).
265    fn evaluate_with_params(&self, params: &[f64], inputs: &[f64]) -> Vec<f64> {
266        // Level 0: 8 affine combinations
267        let feature_pairs = Self::feature_pairs(self.input_count);
268        let mut a = [0.0f64; 8];
269        for i in 0..8 {
270            let base = i * 3;
271            let (alpha, beta, gamma) = softmax3(params[base], params[base + 1], params[base + 2]);
272            let (j, k) = feature_pairs[i];
273            a[i] = (alpha + beta * inputs[j] + gamma * inputs[k]).clamp(-10.0, 10.0);
274        }
275
276        // Level 1: 4 EML nodes (pure pairing)
277        let b = [
278            eml_safe(a[0], a[1]),
279            eml_safe(a[2], a[3]),
280            eml_safe(a[4], a[5]),
281            eml_safe(a[6], a[7]),
282        ];
283
284        // Trunk values before heads
285        let trunk = match self.depth {
286            2 => {
287                // Trunk is just b[0..4], heads mix from these
288                b.to_vec()
289            }
290            3 => {
291                // Level 2: 2 mixing nodes
292                let mut c = [0.0f64; 2];
293                for i in 0..2 {
294                    let base = 24 + i * 4;
295                    let mix_left = params[base]
296                        + params[base + 1] * b[0]
297                        + (1.0 - params[base] - params[base + 1]) * b[1];
298                    let mix_right = params[base + 2]
299                        + params[base + 3] * b[2]
300                        + (1.0 - params[base + 2] - params[base + 3]) * b[3];
301                    let ml = mix_left.clamp(-10.0, 10.0);
302                    let mr = mix_right.clamp(0.01, 10.0);
303                    c[i] = eml_safe(ml, mr);
304                }
305                c.to_vec()
306            }
307            4 => {
308                // Level 2: 4 mixing nodes
309                let level2_pairs: [(usize, usize, usize, usize); 4] = [
310                    (0, 1, 2, 3),
311                    (0, 1, 2, 3),
312                    (0, 2, 1, 3),
313                    (1, 3, 0, 2),
314                ];
315                let mut c = [0.0f64; 4];
316                for i in 0..4 {
317                    let base = 24 + i * 3;
318                    let (li, lj, ri, rj) = level2_pairs[i];
319                    let (alpha, beta, gamma) =
320                        softmax3(params[base], params[base + 1], params[base + 2]);
321                    let mix_left = (alpha + beta * b[li] + gamma * b[lj]).clamp(-10.0, 10.0);
322                    let (ar, br, gr) = softmax3(
323                        params[base] + 0.5,
324                        params[base + 1] - 0.5,
325                        params[base + 2],
326                    );
327                    let mix_right = (ar + br * b[ri] + gr * b[rj]).clamp(0.01, 10.0);
328                    c[i] = eml_safe(mix_left, mix_right);
329                }
330
331                // Level 3: 2 mixing nodes
332                let level3_pairs: [(usize, usize, usize, usize); 2] =
333                    [(0, 1, 2, 3), (0, 2, 1, 3)];
334                let mut d = [0.0f64; 2];
335                for i in 0..2 {
336                    let base = 36 + i * 4;
337                    let (li, lj, ri, rj) = level3_pairs[i];
338                    let mix_left = (params[base]
339                        + params[base + 1] * c[li]
340                        + (1.0 - params[base] - params[base + 1]) * c[lj])
341                        .clamp(-10.0, 10.0);
342                    let mix_right = (params[base + 2]
343                        + params[base + 3] * c[ri]
344                        + (1.0 - params[base + 2] - params[base + 3]) * c[rj])
345                        .clamp(0.01, 10.0);
346                    d[i] = eml_safe(mix_left, mix_right);
347                }
348                d.to_vec()
349            }
350            5 => {
351                // Level 2: 4 mixing nodes (same as depth 4)
352                let level2_pairs: [(usize, usize, usize, usize); 4] = [
353                    (0, 1, 2, 3),
354                    (0, 1, 2, 3),
355                    (0, 2, 1, 3),
356                    (1, 3, 0, 2),
357                ];
358                let mut c = [0.0f64; 4];
359                for i in 0..4 {
360                    let base = 24 + i * 3;
361                    let (li, lj, ri, rj) = level2_pairs[i];
362                    let (alpha, beta, gamma) =
363                        softmax3(params[base], params[base + 1], params[base + 2]);
364                    let mix_left = (alpha + beta * b[li] + gamma * b[lj]).clamp(-10.0, 10.0);
365                    let (ar, br, gr) = softmax3(
366                        params[base] + 0.5,
367                        params[base + 1] - 0.5,
368                        params[base + 2],
369                    );
370                    let mix_right = (ar + br * b[ri] + gr * b[rj]).clamp(0.01, 10.0);
371                    c[i] = eml_safe(mix_left, mix_right);
372                }
373
374                // Level 3: 4 mixing nodes
375                let level3_pairs: [(usize, usize, usize, usize); 4] = [
376                    (0, 1, 2, 3),
377                    (0, 2, 1, 3),
378                    (1, 3, 0, 2),
379                    (0, 3, 1, 2),
380                ];
381                let mut e = [0.0f64; 4];
382                for i in 0..4 {
383                    let base = 36 + i * 3;
384                    let (li, lj, ri, rj) = level3_pairs[i];
385                    let (alpha, beta, gamma) =
386                        softmax3(params[base], params[base + 1], params[base + 2]);
387                    let mix_left = (alpha + beta * c[li] + gamma * c[lj]).clamp(-10.0, 10.0);
388                    let (ar, br, gr) = softmax3(
389                        params[base] + 0.5,
390                        params[base + 1] - 0.5,
391                        params[base + 2],
392                    );
393                    let mix_right = (ar + br * c[ri] + gr * c[rj]).clamp(0.01, 10.0);
394                    e[i] = eml_safe(mix_left, mix_right);
395                }
396
397                // Level 4: 2 mixing nodes
398                let mut f = [0.0f64; 2];
399                for i in 0..2 {
400                    let base = 48 + i * 4;
401                    let li = i * 2;
402                    let lj = i * 2 + 1;
403                    let ri = (i * 2 + 2) % 4;
404                    let rj = (i * 2 + 3) % 4;
405                    let mix_left = (params[base]
406                        + params[base + 1] * e[li]
407                        + (1.0 - params[base] - params[base + 1]) * e[lj])
408                        .clamp(-10.0, 10.0);
409                    let mix_right = (params[base + 2]
410                        + params[base + 3] * e[ri]
411                        + (1.0 - params[base + 2] - params[base + 3]) * e[rj])
412                        .clamp(0.01, 10.0);
413                    f[i] = eml_safe(mix_left, mix_right);
414                }
415                f.to_vec()
416            }
417            _ => unreachable!(),
418        };
419
420        // Head layer: each head mixes the trunk values
421        let head_base = self.param_count() - self.head_count * 2;
422        let mut outputs = Vec::with_capacity(self.head_count);
423        for k in 0..self.head_count {
424            let base = head_base + k * 2;
425            let w0 = params[base];
426            let w1 = params[base + 1];
427            let (left, right) = if trunk.len() >= 2 {
428                (
429                    (w0 * trunk[0] + (1.0 - w0) * trunk[1]).clamp(-10.0, 10.0),
430                    (w1 * trunk[0] + (1.0 - w1) * trunk[1]).clamp(0.01, 10.0),
431                )
432            } else {
433                (
434                    (w0 * trunk[0]).clamp(-10.0, 10.0),
435                    (w1 * trunk[0]).clamp(0.01, 10.0),
436                )
437            };
438            outputs.push(eml_safe(left, right).max(0.0));
439        }
440
441        outputs
442    }
443
444    /// Generate feature pair indices for level 0 (cycling through inputs).
445    fn feature_pairs(input_count: usize) -> [(usize, usize); 8] {
446        let mut pairs = [(0usize, 0usize); 8];
447        for i in 0..8 {
448            pairs[i] = (
449                (i * 2) % input_count,
450                (i * 2 + 1) % input_count,
451            );
452        }
453        pairs
454    }
455
456    // -------------------------------------------------------------------
457    // Training
458    // -------------------------------------------------------------------
459
460    /// Record a training sample.
461    ///
462    /// # Arguments
463    /// - `inputs`: Input feature values.
464    /// - `targets`: Target values for each head. Use `None` for heads
465    ///   without ground truth in this sample (they are skipped in the
466    ///   loss function).
467    pub fn record(&mut self, inputs: &[f64], targets: &[Option<f64>]) {
468        assert_eq!(
469            inputs.len(),
470            self.input_count,
471            "expected {} inputs, got {}",
472            self.input_count,
473            inputs.len()
474        );
475        assert_eq!(
476            targets.len(),
477            self.head_count,
478            "expected {} targets, got {}",
479            self.head_count,
480            targets.len()
481        );
482        self.training_data.push(TrainingPoint {
483            inputs: inputs.to_vec(),
484            targets: targets.to_vec(),
485        });
486    }
487
488    /// Train the model using random restart + coordinate descent.
489    ///
490    /// Requires at least 50 training samples. Returns `true` if the
491    /// model converged (MSE < 0.01).
492    pub fn train(&mut self) -> bool {
493        if self.training_data.len() < 50 {
494            return false;
495        }
496
497        let param_count = self.params.len();
498        let mut best_params = self.params.clone();
499        let mse_before = self.evaluate_mse(&self.params);
500        let mut best_mse = mse_before;
501
502        // Phase 1: random restarts
503        let restart_count = if param_count > 40 { 200 } else { 100 };
504        let mut rng_state: u64 = 0xDEAD_BEEF_CAFE_1234;
505        for _ in 0..restart_count {
506            let candidate = random_params(&mut rng_state, param_count);
507            let mse = self.evaluate_mse(&candidate);
508            if mse < best_mse {
509                best_mse = mse;
510                best_params = candidate;
511            }
512        }
513
514        // Phase 2: coordinate descent
515        let deltas = [-0.1, -0.01, -0.001, 0.001, 0.01, 0.1];
516        for _ in 0..1000 {
517            let mut improved = false;
518            for i in 0..param_count {
519                for &delta in &deltas {
520                    let mut candidate = best_params.clone();
521                    candidate[i] += delta;
522                    let mse = self.evaluate_mse(&candidate);
523                    if mse < best_mse {
524                        best_mse = mse;
525                        best_params = candidate;
526                        improved = true;
527                    }
528                }
529            }
530            if !improved {
531                break;
532            }
533        }
534
535        self.params = best_params;
536        self.trained = best_mse < 0.01;
537
538        // Emit a Trained event for ExoChain logging.
539        let name = if self.model_name.is_empty() {
540            format!("eml_d{}x{}x{}", self.depth, self.input_count, self.head_count)
541        } else {
542            self.model_name.clone()
543        };
544        self.event_log.push(EmlEvent::Trained {
545            model_name: name,
546            samples_used: self.training_data.len(),
547            mse_before,
548            mse_after: best_mse,
549            converged: self.trained,
550            param_count: self.params.len(),
551        });
552
553        self.trained
554    }
555
556    /// Compute weighted MSE over the training set.
557    fn evaluate_mse(&self, params: &[f64]) -> f64 {
558        if self.training_data.is_empty() {
559            return f64::MAX;
560        }
561
562        let mut total_loss = 0.0;
563        let mut total_weight = 0.0;
564
565        for tp in &self.training_data {
566            let predicted = self.evaluate_with_params(params, &tp.inputs);
567            for (k, target) in tp.targets.iter().enumerate() {
568                if let Some(t) = target {
569                    // Primary head (k==0) gets weight 1.0, others 0.3
570                    let weight = if k == 0 { 1.0 } else { 0.3 };
571                    total_loss += weight * (predicted[k] - t).powi(2);
572                    total_weight += weight;
573                }
574            }
575        }
576
577        if total_weight > 0.0 {
578            total_loss / total_weight
579        } else {
580            f64::MAX
581        }
582    }
583
584    // -------------------------------------------------------------------
585    // Distillation
586    // -------------------------------------------------------------------
587
588    /// Distill this (teacher) model to a shallower student model.
589    ///
590    /// Creates a new `EmlModel` with `target_depth` (must be less than
591    /// the teacher's depth) and trains it to mimic the teacher's outputs
592    /// on `num_samples` synthetic inputs drawn uniformly from \[0, 1\].
593    ///
594    /// The student learns from the teacher's predictions, not from the
595    /// original training data. This preserves accuracy while reducing
596    /// computation for constrained devices (WASM, ESP32).
597    ///
598    /// # Panics
599    /// Panics if `target_depth >= self.depth` or `target_depth` is not
600    /// in {2, 3, 4, 5}.
601    pub fn distill(&self, target_depth: usize, num_samples: usize) -> EmlModel {
602        assert!(
603            target_depth < self.depth,
604            "student depth ({target_depth}) must be less than teacher depth ({})",
605            self.depth
606        );
607
608        let mut student = EmlModel::new(target_depth, self.input_count, self.head_count);
609
610        // Generate synthetic inputs in [0, 1] and get teacher predictions.
611        // Use a simple LCG for reproducibility without needing `rand`.
612        let mut rng_state: u64 = 0xCAFE_BABE_1234_5678;
613        let lcg_next = |state: &mut u64| -> f64 {
614            *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
615            // Map to [0, 1]
616            (*state >> 33) as f64 / (1u64 << 31) as f64
617        };
618
619        for _ in 0..num_samples.max(50) {
620            let inputs: Vec<f64> = (0..self.input_count)
621                .map(|_| lcg_next(&mut rng_state))
622                .collect();
623            let teacher_out = self.predict(&inputs);
624            let targets: Vec<Option<f64>> = teacher_out.into_iter().map(Some).collect();
625            student.record(&inputs, &targets);
626        }
627
628        student.train();
629        student
630    }
631
632    // -------------------------------------------------------------------
633    // Serialization
634    // -------------------------------------------------------------------
635
636    /// Serialize the model to a JSON string.
637    pub fn to_json(&self) -> String {
638        serde_json::to_string(self).expect("EmlModel serialization should not fail")
639    }
640
641    /// Deserialize a model from a JSON string.
642    ///
643    /// Returns `None` if the JSON is invalid.
644    pub fn from_json(json: &str) -> Option<Self> {
645        serde_json::from_str(json).ok()
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_model_defaults() {
        let m = EmlModel::new(4, 7, 3);
        assert_eq!(m.depth(), 4);
        assert_eq!(m.input_count(), 7);
        assert_eq!(m.head_count(), 3);
        assert!(!m.is_trained());
        assert_eq!(m.training_sample_count(), 0);
    }

    #[test]
    fn param_count_depth_2() {
        let m = EmlModel::new(2, 5, 1);
        // Level 0: 24, heads: 2 = 26
        assert_eq!(m.param_count(), 26);
    }

    #[test]
    fn param_count_depth_3() {
        let m = EmlModel::new(3, 7, 1);
        // Level 0: 24, level 2: 8, heads: 2 = 34
        assert_eq!(m.param_count(), 34);
    }

    #[test]
    fn param_count_depth_4_single_head() {
        let m = EmlModel::new(4, 7, 1);
        // Level 0: 24, level 2: 12, level 3: 8, heads: 2 = 46
        assert_eq!(m.param_count(), 46);
    }

    #[test]
    fn param_count_depth_4_three_heads() {
        let m = EmlModel::new(4, 7, 3);
        // Level 0: 24, level 2: 12, level 3: 8, heads: 6 = 50
        assert_eq!(m.param_count(), 50);
    }

    #[test]
    fn param_count_depth_5() {
        let m = EmlModel::new(5, 4, 2);
        // Level 0: 24, level 2: 12, level 3: 12, level 4: 8, heads: 4 = 60
        assert_eq!(m.param_count(), 60);
    }

    #[test]
    fn predict_untrained_produces_values() {
        let m = EmlModel::new(4, 7, 3);
        let inputs = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7];
        let result = m.predict(&inputs);
        assert_eq!(result.len(), 3);
        for &v in &result {
            assert!(v.is_finite(), "prediction should be finite");
            assert!(v >= 0.0, "prediction should be non-negative");
        }
    }

    #[test]
    fn predict_primary_matches_first_head() {
        let m = EmlModel::new(3, 5, 3);
        let inputs = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let all = m.predict(&inputs);
        let primary = m.predict_primary(&inputs);
        assert!(
            (primary - all[0]).abs() < 1e-12,
            "predict_primary should match predict()[0]"
        );
    }

    #[test]
    #[should_panic(expected = "expected 3 inputs, got 2")]
    fn predict_wrong_input_len_panics() {
        let m = EmlModel::new(3, 3, 1);
        let _ = m.predict(&[0.1, 0.2]);
    }

    #[test]
    fn record_increments_count() {
        let mut m = EmlModel::new(3, 3, 1);
        assert_eq!(m.training_sample_count(), 0);
        m.record(&[0.1, 0.2, 0.3], &[Some(1.0)]);
        assert_eq!(m.training_sample_count(), 1);
    }

    #[test]
    #[should_panic(expected = "expected 2 targets, got 1")]
    fn record_wrong_target_len_panics() {
        let mut m = EmlModel::new(3, 2, 2);
        m.record(&[0.1, 0.2], &[Some(1.0)]);
    }

    #[test]
    fn train_insufficient_data_returns_false() {
        let mut m = EmlModel::new(3, 3, 1);
        for i in 0..10 {
            m.record(&[i as f64 / 10.0, 0.5, 0.5], &[Some(1.0)]);
        }
        assert!(!m.train());
        assert!(!m.is_trained());
        // The early return must not log a Trained event.
        assert_eq!(m.pending_event_count(), 0);
    }

    #[test]
    fn training_convergence_polynomial() {
        // Train on y = x^2 for x in [0, 1]
        let mut m = EmlModel::new(4, 1, 1);
        for i in 0..100 {
            let x = i as f64 / 100.0;
            let y = x * x;
            m.record(&[x], &[Some(y)]);
        }
        let _ = m.train();
        // Even if not fully converged, should produce finite predictions
        let pred = m.predict_primary(&[0.5]);
        assert!(pred.is_finite());
    }

    #[test]
    fn multi_head_training() {
        let mut m = EmlModel::new(4, 2, 3);
        for i in 0..80 {
            let x = i as f64 / 80.0;
            let y = (i + 10) as f64 / 80.0;
            m.record(&[x, y], &[Some(x + y), Some(x * y), None]);
        }
        let _ = m.train();
        let pred = m.predict(&[0.5, 0.5]);
        assert_eq!(pred.len(), 3);
        for &v in &pred {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn serialization_roundtrip() {
        let mut m = EmlModel::new(4, 5, 2);
        // Set some params to non-zero
        for (i, p) in m.params.iter_mut().enumerate() {
            *p = (i as f64 * 0.1).sin();
        }
        m.trained = true;

        let json = m.to_json();
        let m2 = EmlModel::from_json(&json).expect("should deserialize");

        assert_eq!(m.depth, m2.depth);
        assert_eq!(m.input_count, m2.input_count);
        assert_eq!(m.head_count, m2.head_count);
        assert_eq!(m.params.len(), m2.params.len());
        for (i, (a, b)) in m.params.iter().zip(m2.params.iter()).enumerate() {
            assert!((a - b).abs() < 1e-14, "param[{i}] mismatch: {a} vs {b}");
        }
        assert_eq!(m.trained, m2.trained);
        // training_data is skipped in serde
        assert_eq!(m2.training_sample_count(), 0);

        // The restored model must compute the same function.
        let inputs = [0.2, 0.4, 0.6, 0.8, 1.0];
        let before = m.predict(&inputs);
        let after = m2.predict(&inputs);
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(after.iter()) {
            assert!(
                (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                "prediction mismatch after roundtrip: {a} vs {b}"
            );
        }
    }

    #[test]
    fn from_json_invalid_returns_none() {
        assert!(EmlModel::from_json("not valid json").is_none());
    }

    #[test]
    fn various_depths_produce_finite_output() {
        for depth in 2..=5 {
            let m = EmlModel::new(depth, 4, 2);
            let inputs = vec![0.3, 0.5, 0.7, 0.1];
            let result = m.predict(&inputs);
            assert_eq!(result.len(), 2);
            for &v in &result {
                assert!(v.is_finite(), "depth-{depth} should produce finite output");
            }
        }
    }

    #[test]
    #[should_panic(expected = "EmlModel depth must be 2, 3, 4, or 5")]
    fn invalid_depth_panics() {
        EmlModel::new(6, 3, 1);
    }

    #[test]
    #[should_panic(expected = "head_count must be >= 1")]
    fn zero_heads_panics() {
        EmlModel::new(3, 3, 0);
    }

    #[test]
    fn distill_depth_4_to_depth_2() {
        // Distill a depth-4 model to depth-2.
        // The student should learn to mimic the teacher's output function,
        // regardless of whether the teacher was "well trained" on real data.
        // We verify structural correctness and output agreement.
        let mut teacher = EmlModel::new(4, 2, 1);
        // Give teacher non-trivial params so it has a non-constant function.
        for (i, p) in teacher.params.iter_mut().enumerate() {
            *p = ((i as f64) * 0.37).sin() * 0.5;
        }
        teacher.trained = true;

        let student = teacher.distill(2, 500);
        assert_eq!(student.depth(), 2);
        assert_eq!(student.input_count(), 2);
        assert_eq!(student.head_count(), 1);

        // Evaluate on a grid and compute mean absolute error.
        let mut total_err = 0.0;
        let mut count = 0;
        for i in 0..10 {
            for j in 0..10 {
                let x = i as f64 / 10.0;
                let y = j as f64 / 10.0;
                let t = teacher.predict_primary(&[x, y]);
                let s = student.predict_primary(&[x, y]);
                assert!(t.is_finite());
                assert!(s.is_finite());
                total_err += (t - s).abs();
                count += 1;
            }
        }
        let mae = total_err / count as f64;

        // This is a smoke check of the distillation mechanism (no panics,
        // finite and non-degenerate outputs), not a tight fidelity bound.
        assert!(mae < 50.0, "distilled model MAE should be reasonable, got {mae}");
    }

    #[test]
    fn distill_multi_head() {
        let mut teacher = EmlModel::new(4, 2, 2);
        for i in 0..100 {
            let x = i as f64 / 100.0;
            let y = (i + 20) as f64 / 100.0;
            teacher.record(&[x, y], &[Some(x + y), Some(x * y)]);
        }
        teacher.train();

        let student = teacher.distill(2, 200);
        assert_eq!(student.depth(), 2);
        assert_eq!(student.head_count(), 2);

        // Both heads should produce finite outputs.
        let pred = student.predict(&[0.5, 0.7]);
        assert_eq!(pred.len(), 2);
        for &v in &pred {
            assert!(v.is_finite());
        }
    }

    #[test]
    #[should_panic(expected = "student depth")]
    fn distill_same_depth_panics() {
        let teacher = EmlModel::new(4, 3, 1);
        teacher.distill(4, 100);
    }
}