Skip to main content

eml_core/
tree.rs

//! Depth-configurable EML tree evaluation.
//!
//! An [`EmlTree`] is a fixed-depth tree of EML operators with trainable
//! mixing weights. It maps N input features to a single scalar output.
//!
//! Supported depths: 2, 3, 4, 5.
7
8use crate::operator::{eml_safe, softmax3};
9
10/// Depth-configurable EML evaluation tree.
11///
12/// The tree maps `input_count` features through layers of affine mixing
13/// and EML operators to produce a single scalar output.
14///
15/// # Architecture
16///
17/// - **Level 0**: `2^(depth-1)` affine combinations of input features
18///   (3 params each via softmax3 mixing).
19/// - **Levels 1..depth-1**: EML nodes halving the width at each level,
20///   with mixing weights.
21/// - **Output**: final EML node producing a single scalar.
22#[derive(Debug, Clone)]
23pub struct EmlTree {
24    depth: usize,
25    input_count: usize,
26    param_count: usize,
27}
28
29impl EmlTree {
30    /// Create a new EML tree specification.
31    ///
32    /// # Arguments
33    /// - `depth`: Tree depth (2, 3, 4, or 5).
34    /// - `input_count`: Number of input features.
35    ///
36    /// # Panics
37    /// Panics if depth is not in {2, 3, 4, 5}.
38    pub fn new(depth: usize, input_count: usize) -> Self {
39        assert!(
40            (2..=5).contains(&depth),
41            "EmlTree depth must be 2, 3, 4, or 5, got {depth}"
42        );
43        let param_count = Self::compute_param_count(depth, input_count);
44        Self {
45            depth,
46            input_count,
47            param_count,
48        }
49    }
50
51    /// Number of trainable parameters for this tree configuration.
52    pub fn param_count(&self) -> usize {
53        self.param_count
54    }
55
56    /// Tree depth.
57    pub fn depth(&self) -> usize {
58        self.depth
59    }
60
61    /// Number of input features.
62    pub fn input_count(&self) -> usize {
63        self.input_count
64    }
65
66    /// Compute the total parameter count for a given depth and input count.
67    ///
68    /// Level 0: `width` nodes * 3 params each (softmax3 mixing of 2 inputs + bias).
69    /// Each subsequent level halves the width and adds mixing params.
70    fn compute_param_count(depth: usize, _input_count: usize) -> usize {
71        let width = 1usize << (depth - 1); // 2^(depth-1) nodes at level 0
72
73        // Level 0: each node has 3 softmax params
74        let mut total = width * 3;
75
76        // Levels 2..depth: each level halves, each node needs mixing weights.
77        // Level 1 is pure EML (no extra params — just pairs level-0 outputs).
78        let mut w = width / 2; // level 1 width (after first EML pairing)
79        for level in 2..depth {
80            // Each node at this level mixes two inputs: 2 weights
81            // plus for deeper trees we use 3-weight softmax mixing
82            let params_per_node = if level < depth - 1 { 3 } else { 2 };
83            total += w * params_per_node;
84            w /= 2;
85            if w == 0 {
86                w = 1;
87            }
88        }
89
90        // Output level: 2 mixing weights for the final EML
91        total += 2;
92
93        total
94    }
95
96    /// Evaluate the tree with given parameters and inputs.
97    ///
98    /// # Arguments
99    /// - `params`: Trainable parameters (length must equal `param_count()`).
100    /// - `inputs`: Input feature values (length must equal `input_count`).
101    ///
102    /// # Panics
103    /// Panics if `params.len() != param_count()` or `inputs.len() != input_count`.
104    pub fn evaluate(&self, params: &[f64], inputs: &[f64]) -> f64 {
105        assert_eq!(
106            params.len(),
107            self.param_count,
108            "expected {} params, got {}",
109            self.param_count,
110            params.len()
111        );
112        assert_eq!(
113            inputs.len(),
114            self.input_count,
115            "expected {} inputs, got {}",
116            self.input_count,
117            inputs.len()
118        );
119
120        let width = 1usize << (self.depth - 1);
121
122        // Level 0: affine combinations via softmax3
123        let mut a = vec![0.0f64; width];
124        for i in 0..width {
125            let base = i * 3;
126            let (alpha, beta, gamma) = softmax3(params[base], params[base + 1], params[base + 2]);
127            // Pick two input features (cycling through available inputs)
128            let j = (i * 2) % self.input_count;
129            let k = (i * 2 + 1) % self.input_count;
130            a[i] = (alpha + beta * inputs[j] + gamma * inputs[k]).clamp(-10.0, 10.0);
131        }
132
133        // Level 1: pair up with EML (no extra params)
134        let mut current: Vec<f64> = a
135            .chunks(2)
136            .map(|pair| eml_safe(pair[0], pair[1].max(0.01)))
137            .collect();
138
139        // Levels 2..depth-1: mix + EML
140        let mut param_offset = width * 3;
141        for level in 2..self.depth {
142            let is_last_mix = level == self.depth - 1;
143            let params_per_node = if is_last_mix { 2 } else { 3 };
144            let next_width = (current.len() + 1) / 2;
145            let mut next = Vec::with_capacity(next_width);
146
147            for i in 0..next_width {
148                let li = i * 2;
149                let ri = (i * 2 + 1).min(current.len() - 1);
150
151                if params_per_node == 3 {
152                    let (alpha, beta, gamma) = softmax3(
153                        params[param_offset],
154                        params[param_offset + 1],
155                        params[param_offset + 2],
156                    );
157                    let mixed = (alpha + beta * current[li] + gamma * current[ri])
158                        .clamp(-10.0, 10.0);
159                    // Use shifted softmax for the right side
160                    let (ar, br, gr) = softmax3(
161                        params[param_offset] + 0.5,
162                        params[param_offset + 1] - 0.5,
163                        params[param_offset + 2],
164                    );
165                    let mixed_r = (ar + br * current[ri] + gr * current[li]).clamp(0.01, 10.0);
166                    next.push(eml_safe(mixed, mixed_r));
167                } else {
168                    let w0 = params[param_offset];
169                    let w1 = params[param_offset + 1];
170                    let left = (w0 * current[li] + (1.0 - w0) * current[ri]).clamp(-10.0, 10.0);
171                    let right = (w1 * current[li] + (1.0 - w1) * current[ri]).clamp(0.01, 10.0);
172                    next.push(eml_safe(left, right));
173                }
174
175                param_offset += params_per_node;
176            }
177
178            current = next;
179        }
180
181        // Output: final mixing
182        let w0 = params[param_offset];
183        let w1 = params[param_offset + 1];
184        let (left, right) = if current.len() >= 2 {
185            (
186                (w0 * current[0] + (1.0 - w0) * current[1]).clamp(-10.0, 10.0),
187                (w1 * current[0] + (1.0 - w1) * current[1]).clamp(0.01, 10.0),
188            )
189        } else {
190            (
191                (w0 * current[0]).clamp(-10.0, 10.0),
192                (w1 * current[0]).clamp(0.01, 10.0),
193            )
194        };
195
196        eml_safe(left, right).max(0.0)
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `tree` with every parameter set to `fill`.
    fn eval_with_fill(tree: &EmlTree, fill: f64, inputs: &[f64]) -> f64 {
        let weights = vec![fill; tree.param_count()];
        tree.evaluate(&weights, inputs)
    }

    /// Parameter count of a four-input tree of the given depth.
    fn params_at_depth(depth: usize) -> usize {
        EmlTree::new(depth, 4).param_count()
    }

    // Smallest supported depth: accessors report the constructor arguments
    // and evaluation yields a finite value.
    #[test]
    fn tree_depth_2() {
        let tree = EmlTree::new(2, 3);
        assert_eq!(tree.depth(), 2);
        assert_eq!(tree.input_count(), 3);
        assert!(tree.param_count() > 0, "param count should be positive");

        let features = [0.5, 0.3, 0.7];
        let out = eval_with_fill(&tree, 0.1, &features);
        assert!(out.is_finite(), "depth-2 result should be finite");
    }

    // Depth 3 with all-zero parameters and five distinct inputs.
    #[test]
    fn tree_depth_3() {
        let tree = EmlTree::new(3, 5);
        let features = [0.1, 0.2, 0.3, 0.4, 0.5];
        assert!(eval_with_fill(&tree, 0.0, &features).is_finite());
    }

    // Depth 4 with seven identical inputs.
    #[test]
    fn tree_depth_4() {
        let tree = EmlTree::new(4, 7);
        let features = [0.1; 7];
        assert!(eval_with_fill(&tree, 0.1, &features).is_finite());
    }

    // Largest supported depth with all-zero parameters.
    #[test]
    fn tree_depth_5() {
        let tree = EmlTree::new(5, 4);
        assert!(tree.param_count() > 0);
        let features = [0.5; 4];
        assert!(eval_with_fill(&tree, 0.0, &features).is_finite());
    }

    // Depths outside 2..=5 are rejected at construction time.
    #[test]
    #[should_panic(expected = "EmlTree depth must be 2, 3, 4, or 5")]
    fn tree_invalid_depth() {
        let _ = EmlTree::new(1, 3);
    }

    // Every supported depth produces an output that is at least zero.
    #[test]
    fn tree_output_non_negative() {
        let features = [0.3; 4];
        for depth in 2..=5 {
            let tree = EmlTree::new(depth, 4);
            let result = eval_with_fill(&tree, 0.5, &features);
            assert!(
                result >= 0.0,
                "depth-{depth} output should be non-negative, got {result}"
            );
        }
    }

    // The parameter count grows strictly with each extra level.
    #[test]
    fn param_count_increases_with_depth() {
        let (pc2, pc3, pc4, pc5) = (
            params_at_depth(2),
            params_at_depth(3),
            params_at_depth(4),
            params_at_depth(5),
        );
        assert!(pc3 > pc2, "depth 3 should have more params than depth 2");
        assert!(pc4 > pc3, "depth 4 should have more params than depth 3");
        assert!(pc5 > pc4, "depth 5 should have more params than depth 4");
    }
}