1use crate::operator::{eml_safe, softmax3};
9
/// Shape descriptor for a fixed binary tree built from `eml_safe` nodes.
///
/// The tree stores only its shape. The numeric behaviour comes entirely from
/// the flat parameter slice handed to [`EmlTree::evaluate`] on each call.
#[derive(Clone, Debug)]
pub struct EmlTree {
    /// Tree depth; [`EmlTree::new`] only accepts values in `2..=5`.
    depth: usize,
    /// Length of the `inputs` slice that `evaluate` expects.
    input_count: usize,
    /// Length of the `params` slice that `evaluate` expects, computed once at
    /// construction time.
    param_count: usize,
}
28
29impl EmlTree {
30 pub fn new(depth: usize, input_count: usize) -> Self {
39 assert!(
40 (2..=5).contains(&depth),
41 "EmlTree depth must be 2, 3, 4, or 5, got {depth}"
42 );
43 let param_count = Self::compute_param_count(depth, input_count);
44 Self {
45 depth,
46 input_count,
47 param_count,
48 }
49 }
50
51 pub fn param_count(&self) -> usize {
53 self.param_count
54 }
55
56 pub fn depth(&self) -> usize {
58 self.depth
59 }
60
61 pub fn input_count(&self) -> usize {
63 self.input_count
64 }
65
66 fn compute_param_count(depth: usize, _input_count: usize) -> usize {
71 let width = 1usize << (depth - 1); let mut total = width * 3;
75
76 let mut w = width / 2; for level in 2..depth {
80 let params_per_node = if level < depth - 1 { 3 } else { 2 };
83 total += w * params_per_node;
84 w /= 2;
85 if w == 0 {
86 w = 1;
87 }
88 }
89
90 total += 2;
92
93 total
94 }
95
96 pub fn evaluate(&self, params: &[f64], inputs: &[f64]) -> f64 {
105 assert_eq!(
106 params.len(),
107 self.param_count,
108 "expected {} params, got {}",
109 self.param_count,
110 params.len()
111 );
112 assert_eq!(
113 inputs.len(),
114 self.input_count,
115 "expected {} inputs, got {}",
116 self.input_count,
117 inputs.len()
118 );
119
120 let width = 1usize << (self.depth - 1);
121
122 let mut a = vec![0.0f64; width];
124 for i in 0..width {
125 let base = i * 3;
126 let (alpha, beta, gamma) = softmax3(params[base], params[base + 1], params[base + 2]);
127 let j = (i * 2) % self.input_count;
129 let k = (i * 2 + 1) % self.input_count;
130 a[i] = (alpha + beta * inputs[j] + gamma * inputs[k]).clamp(-10.0, 10.0);
131 }
132
133 let mut current: Vec<f64> = a
135 .chunks(2)
136 .map(|pair| eml_safe(pair[0], pair[1].max(0.01)))
137 .collect();
138
139 let mut param_offset = width * 3;
141 for level in 2..self.depth {
142 let is_last_mix = level == self.depth - 1;
143 let params_per_node = if is_last_mix { 2 } else { 3 };
144 let next_width = (current.len() + 1) / 2;
145 let mut next = Vec::with_capacity(next_width);
146
147 for i in 0..next_width {
148 let li = i * 2;
149 let ri = (i * 2 + 1).min(current.len() - 1);
150
151 if params_per_node == 3 {
152 let (alpha, beta, gamma) = softmax3(
153 params[param_offset],
154 params[param_offset + 1],
155 params[param_offset + 2],
156 );
157 let mixed = (alpha + beta * current[li] + gamma * current[ri])
158 .clamp(-10.0, 10.0);
159 let (ar, br, gr) = softmax3(
161 params[param_offset] + 0.5,
162 params[param_offset + 1] - 0.5,
163 params[param_offset + 2],
164 );
165 let mixed_r = (ar + br * current[ri] + gr * current[li]).clamp(0.01, 10.0);
166 next.push(eml_safe(mixed, mixed_r));
167 } else {
168 let w0 = params[param_offset];
169 let w1 = params[param_offset + 1];
170 let left = (w0 * current[li] + (1.0 - w0) * current[ri]).clamp(-10.0, 10.0);
171 let right = (w1 * current[li] + (1.0 - w1) * current[ri]).clamp(0.01, 10.0);
172 next.push(eml_safe(left, right));
173 }
174
175 param_offset += params_per_node;
176 }
177
178 current = next;
179 }
180
181 let w0 = params[param_offset];
183 let w1 = params[param_offset + 1];
184 let (left, right) = if current.len() >= 2 {
185 (
186 (w0 * current[0] + (1.0 - w0) * current[1]).clamp(-10.0, 10.0),
187 (w1 * current[0] + (1.0 - w1) * current[1]).clamp(0.01, 10.0),
188 )
189 } else {
190 (
191 (w0 * current[0]).clamp(-10.0, 10.0),
192 (w1 * current[0]).clamp(0.01, 10.0),
193 )
194 };
195
196 eml_safe(left, right).max(0.0)
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tree_depth_2() {
        let tree = EmlTree::new(2, 3);
        assert_eq!(tree.depth(), 2);
        assert_eq!(tree.input_count(), 3);
        let pc = tree.param_count();
        assert!(pc > 0, "param count should be positive");

        let params = vec![0.1; pc];
        let inputs = vec![0.5, 0.3, 0.7];
        let result = tree.evaluate(&params, &inputs);
        assert!(result.is_finite(), "depth-2 result should be finite");
    }

    #[test]
    fn tree_depth_3() {
        let tree = EmlTree::new(3, 5);
        let pc = tree.param_count();
        let params = vec![0.0; pc];
        let inputs = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let result = tree.evaluate(&params, &inputs);
        assert!(result.is_finite());
    }

    #[test]
    fn tree_depth_4() {
        let tree = EmlTree::new(4, 7);
        let pc = tree.param_count();
        let params = vec![0.1; pc];
        let inputs = vec![0.1; 7];
        let result = tree.evaluate(&params, &inputs);
        assert!(result.is_finite());
    }

    #[test]
    fn tree_depth_5() {
        let tree = EmlTree::new(5, 4);
        let pc = tree.param_count();
        assert!(pc > 0);
        let params = vec![0.0; pc];
        let inputs = vec![0.5; 4];
        let result = tree.evaluate(&params, &inputs);
        assert!(result.is_finite());
    }

    #[test]
    #[should_panic(expected = "EmlTree depth must be 2, 3, 4, or 5")]
    fn tree_invalid_depth() {
        EmlTree::new(1, 3);
    }

    #[test]
    fn tree_output_non_negative() {
        for depth in 2..=5 {
            let tree = EmlTree::new(depth, 4);
            let params = vec![0.5; tree.param_count()];
            let inputs = vec![0.3; 4];
            let result = tree.evaluate(&params, &inputs);
            assert!(
                result >= 0.0,
                "depth-{depth} output should be non-negative, got {result}"
            );
        }
    }

    #[test]
    fn param_count_increases_with_depth() {
        let pc2 = EmlTree::new(2, 4).param_count();
        let pc3 = EmlTree::new(3, 4).param_count();
        let pc4 = EmlTree::new(4, 4).param_count();
        let pc5 = EmlTree::new(5, 4).param_count();
        assert!(pc3 > pc2, "depth 3 should have more params than depth 2");
        assert!(pc4 > pc3, "depth 4 should have more params than depth 3");
        assert!(pc5 > pc4, "depth 5 should have more params than depth 4");
    }

    /// `evaluate` must reject a parameter slice of the wrong length.
    #[test]
    #[should_panic(expected = "expected 8 params, got 7")]
    fn evaluate_rejects_wrong_param_len() {
        let tree = EmlTree::new(2, 3);
        let _ = tree.evaluate(&[0.0; 7], &[0.5, 0.3, 0.7]);
    }

    /// `evaluate` must reject an input slice of the wrong length.
    #[test]
    #[should_panic(expected = "expected 3 inputs, got 2")]
    fn evaluate_rejects_wrong_input_len() {
        let tree = EmlTree::new(2, 3);
        let params = vec![0.0; tree.param_count()];
        let _ = tree.evaluate(&params, &[0.5, 0.3]);
    }
}