use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;

use crate::math;
use crate::rng::xorshift64;

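/// An online model that lives in a tree leaf.
///
/// A leaf model consumes per-sample `(gradient, hessian)` statistics of the
/// objective and emits a prediction for a feature vector. `lambda` is the L2
/// regularization strength applied to the leaf weight.
///
/// A minimal usage sketch (marked `ignore` since the import path depends on
/// how this module is exposed):
///
/// ```ignore
/// let mut leaf = ClosedFormLeaf::new();
/// leaf.update(&[], 1.0, 2.0, 1.0); // g = 1, h = 2, lambda = 1
/// assert!((leaf.predict(&[]) + 1.0 / 3.0).abs() < 1e-12); // w = -g / (h + lambda)
/// ```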
pub trait LeafModel: Send + Sync {
    /// Predict the leaf value for `features`.
    fn predict(&self, features: &[f64]) -> f64;

    /// Consume one sample's gradient/hessian under L2 penalty `lambda`.
    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64);

    /// Clone the configuration with all learned state reset.
    fn clone_fresh(&self) -> Box<dyn LeafModel>;

    /// Clone keeping learned state; by default falls back to a fresh clone.
    fn clone_warm(&self) -> Box<dyn LeafModel> {
        self.clone_fresh()
    }
}

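/// Constant-valued leaf with the classic closed-form GBDT weight.
///
/// Accumulates gradient and hessian sums and maintains
/// `weight = -grad_sum / (hess_sum + lambda)`, the minimizer of the
/// accumulated second-order loss.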
#[derive(Default)]
pub struct ClosedFormLeaf {
    grad_sum: f64,
    hess_sum: f64,
    weight: f64,
}

impl ClosedFormLeaf {
    pub fn new() -> Self {
        Self::default()
    }
}

impl LeafModel for ClosedFormLeaf {
    fn predict(&self, _features: &[f64]) -> f64 {
        self.weight
    }

    fn update(&mut self, _features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
        self.grad_sum += gradient;
        self.hess_sum += hessian;
        // Closed-form minimizer of the accumulated second-order loss.
        self.weight = -self.grad_sum / (self.hess_sum + lambda);
    }

    fn clone_fresh(&self) -> Box<dyn LeafModel> {
        Box::new(ClosedFormLeaf::new())
    }
}

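/// Linear model `f(x) = w . x + b` trained online by SGD.
///
/// Each update takes a step of size `learning_rate / (|hessian| + lambda)`
/// against the gradient `g * x`. Optional multiplicative `decay` shrinks the
/// weights before each step, and `use_adagrad` switches to per-coordinate
/// AdaGrad step sizes. Weight vectors are allocated lazily on the first
/// `update`, sized to the feature dimension seen there.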
pub struct LinearLeafModel {
    weights: Vec<f64>,
    bias: f64,
    learning_rate: f64,
    decay: Option<f64>,
    use_adagrad: bool,
    /// Per-coordinate accumulators of squared gradients (AdaGrad).
    sq_grad_accum: Vec<f64>,
    /// Squared-gradient accumulator for the bias term.
    sq_bias_accum: f64,
    /// Set once the weight vectors have been sized on the first update.
    initialized: bool,
}

impl LinearLeafModel {
    pub fn new(learning_rate: f64, decay: Option<f64>, use_adagrad: bool) -> Self {
        Self {
            weights: Vec::new(),
            bias: 0.0,
            learning_rate,
            decay,
            use_adagrad,
            sq_grad_accum: Vec::new(),
            sq_bias_accum: 0.0,
            initialized: false,
        }
    }
}

/// Small constant keeping AdaGrad denominators strictly positive.
const ADAGRAD_EPS: f64 = 1e-8;

impl LeafModel for LinearLeafModel {
    fn predict(&self, features: &[f64]) -> f64 {
        if !self.initialized {
            return 0.0;
        }
        let mut dot = self.bias;
        for (w, x) in self.weights.iter().zip(features.iter()) {
            dot += w * x;
        }
        dot
    }

    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
        if !self.initialized {
            let d = features.len();
            self.weights = vec![0.0; d];
            self.sq_grad_accum = vec![0.0; d];
            self.initialized = true;
        }

        // Optional multiplicative shrink so old data is gradually forgotten.
        if let Some(d) = self.decay {
            for w in self.weights.iter_mut() {
                *w *= d;
            }
            self.bias *= d;
        }

        // Curvature-aware step size: larger hessians mean smaller steps.
        let base_lr = self.learning_rate / (math::abs(hessian) + lambda);

        if self.use_adagrad {
            for (i, (w, x)) in self.weights.iter_mut().zip(features.iter()).enumerate() {
                let g = gradient * x;
                self.sq_grad_accum[i] += g * g;
                let adaptive_lr = base_lr / (math::sqrt(self.sq_grad_accum[i]) + ADAGRAD_EPS);
                *w -= adaptive_lr * g;
            }
            self.sq_bias_accum += gradient * gradient;
            let bias_lr = base_lr / (math::sqrt(self.sq_bias_accum) + ADAGRAD_EPS);
            self.bias -= bias_lr * gradient;
        } else {
            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
                *w -= base_lr * gradient * x;
            }
            self.bias -= base_lr * gradient;
        }
    }

    fn clone_fresh(&self) -> Box<dyn LeafModel> {
        Box::new(LinearLeafModel::new(
            self.learning_rate,
            self.decay,
            self.use_adagrad,
        ))
    }

    fn clone_warm(&self) -> Box<dyn LeafModel> {
        Box::new(LinearLeafModel {
            weights: self.weights.clone(),
            bias: self.bias,
            learning_rate: self.learning_rate,
            decay: self.decay,
            use_adagrad: self.use_adagrad,
            // AdaGrad history is deliberately reset so the clone adapts
            // quickly in its new leaf.
            sq_grad_accum: vec![0.0; self.weights.len()],
            sq_bias_accum: 0.0,
            initialized: self.initialized,
        })
    }
}

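/// Single-hidden-layer ReLU MLP trained online by backpropagation.
///
/// Weights are initialized uniformly in `[-0.1, 0.1]` from a deterministic
/// xorshift stream seeded by `seed`, so construction is reproducible. Each
/// `update` runs one forward pass (caching pre-activations) and one gradient
/// step of size `learning_rate / (|hessian| + lambda)`.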
pub struct MLPLeafModel {
    hidden_weights: Vec<Vec<f64>>,
    hidden_bias: Vec<f64>,
    output_weights: Vec<f64>,
    output_bias: f64,
    hidden_size: usize,
    learning_rate: f64,
    decay: Option<f64>,
    seed: u64,
    initialized: bool,
    /// Cached forward-pass values reused by the backward pass.
    hidden_activations: Vec<f64>,
    hidden_pre_activations: Vec<f64>,
}

impl MLPLeafModel {
    pub fn new(hidden_size: usize, learning_rate: f64, seed: u64, decay: Option<f64>) -> Self {
        Self {
            hidden_weights: Vec::new(),
            hidden_bias: Vec::new(),
            output_weights: Vec::new(),
            output_bias: 0.0,
            hidden_size,
            learning_rate,
            decay,
            seed,
            initialized: false,
            hidden_activations: Vec::new(),
            hidden_pre_activations: Vec::new(),
        }
    }

    /// Allocate and randomly initialize all layers for `input_size` inputs.
    fn initialize(&mut self, input_size: usize) {
        let mut state = self.seed ^ (self.hidden_size as u64);

        self.hidden_weights = Vec::with_capacity(self.hidden_size);
        for _ in 0..self.hidden_size {
            let mut row = Vec::with_capacity(input_size);
            for _ in 0..input_size {
                let r = xorshift64(&mut state);
                // Uniform in [-0.1, 0.1].
                let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
                row.push(val);
            }
            self.hidden_weights.push(row);
        }

        self.hidden_bias = Vec::with_capacity(self.hidden_size);
        for _ in 0..self.hidden_size {
            let r = xorshift64(&mut state);
            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
            self.hidden_bias.push(val);
        }

        self.output_weights = Vec::with_capacity(self.hidden_size);
        for _ in 0..self.hidden_size {
            let r = xorshift64(&mut state);
            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
            self.output_weights.push(val);
        }

        {
            let r = xorshift64(&mut state);
            self.output_bias = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
        }

        self.hidden_activations = vec![0.0; self.hidden_size];
        self.hidden_pre_activations = vec![0.0; self.hidden_size];
        self.initialized = true;
    }

    /// Forward pass caching pre-activations and activations for backprop.
    fn forward(&mut self, features: &[f64]) -> f64 {
        for h in 0..self.hidden_size {
            let mut z = self.hidden_bias[h];
            for (j, x) in features.iter().enumerate() {
                if j < self.hidden_weights[h].len() {
                    z += self.hidden_weights[h][j] * x;
                }
            }
            self.hidden_pre_activations[h] = z;
            // ReLU.
            self.hidden_activations[h] = if z > 0.0 { z } else { 0.0 };
        }

        let mut out = self.output_bias;
        for (w, a) in self
            .output_weights
            .iter()
            .zip(self.hidden_activations.iter())
        {
            out += w * a;
        }
        out
    }
}

impl LeafModel for MLPLeafModel {
    fn predict(&self, features: &[f64]) -> f64 {
        if !self.initialized {
            return 0.0;
        }
        // Immutable forward pass: recompute activations without touching the
        // cached buffers used by `update`.
        let hidden_acts: Vec<f64> = self
            .hidden_weights
            .iter()
            .zip(self.hidden_bias.iter())
            .map(|(hw, &hb)| {
                let mut z = hb;
                for (j, x) in features.iter().enumerate() {
                    if j < hw.len() {
                        z += hw[j] * x;
                    }
                }
                if z > 0.0 {
                    z
                } else {
                    0.0
                }
            })
            .collect();
        let mut out = self.output_bias;
        for (w, a) in self.output_weights.iter().zip(hidden_acts.iter()) {
            out += w * a;
        }
        out
    }

    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
        if !self.initialized {
            self.initialize(features.len());
        }

        if let Some(d) = self.decay {
            for row in self.hidden_weights.iter_mut() {
                for w in row.iter_mut() {
                    *w *= d;
                }
            }
            for b in self.hidden_bias.iter_mut() {
                *b *= d;
            }
            for w in self.output_weights.iter_mut() {
                *w *= d;
            }
            self.output_bias *= d;
        }

        // Populate the activation caches for the backward pass below.
        let _output = self.forward(features);

        let effective_lr = self.learning_rate / (math::abs(hessian) + lambda);

        // The incoming `gradient` is dLoss/dOutput.
        let d_output = gradient;

        // Output layer.
        for h in 0..self.hidden_size {
            self.output_weights[h] -= effective_lr * d_output * self.hidden_activations[h];
        }
        self.output_bias -= effective_lr * d_output;

        // Hidden layer, backpropagating through the ReLU.
        for h in 0..self.hidden_size {
            let d_hidden_act = d_output * self.output_weights[h];

            let d_relu = if self.hidden_pre_activations[h] > 0.0 {
                d_hidden_act
            } else {
                0.0
            };

            for (j, x) in features.iter().enumerate() {
                if j < self.hidden_weights[h].len() {
                    self.hidden_weights[h][j] -= effective_lr * d_relu * x;
                }
            }
            self.hidden_bias[h] -= effective_lr * d_relu;
        }
    }

    fn clone_fresh(&self) -> Box<dyn LeafModel> {
        // Derive a decorrelated seed (golden-ratio multiplier).
        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
        Box::new(MLPLeafModel::new(
            self.hidden_size,
            self.learning_rate,
            derived_seed,
            self.decay,
        ))
    }

    fn clone_warm(&self) -> Box<dyn LeafModel> {
        Box::new(MLPLeafModel {
            hidden_weights: self.hidden_weights.clone(),
            hidden_bias: self.hidden_bias.clone(),
            output_weights: self.output_weights.clone(),
            output_bias: self.output_bias,
            hidden_size: self.hidden_size,
            learning_rate: self.learning_rate,
            decay: self.decay,
            seed: self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(2),
            initialized: self.initialized,
            hidden_activations: vec![0.0; self.hidden_size],
            hidden_pre_activations: vec![0.0; self.hidden_size],
        })
    }
}

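/// Leaf that starts as a `ClosedFormLeaf` and promotes itself to a richer
/// shadow model once the shadow is demonstrably better.
///
/// Both models are updated on every sample; the per-sample second-order
/// losses `g * w + 0.5 * h * w^2` are compared, and the running mean
/// advantage of the shadow is tested against a Hoeffding-style bound
/// `epsilon = R * sqrt(ln(1 / delta) / (2 * n))`, where `R` is the largest
/// observed |loss difference|. When the mean advantage exceeds `epsilon`,
/// the shadow is swapped in as the active model.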
pub struct AdaptiveLeafModel {
    /// Model currently serving predictions.
    active: Box<dyn LeafModel>,
    /// Candidate model trained in parallel.
    shadow: Box<dyn LeafModel>,
    /// Configuration used to rebuild the shadow in fresh clones.
    promote_to: LeafModelType,
    /// Running sum of (active loss - shadow loss); positive favors the shadow.
    cumulative_advantage: f64,
    /// Number of samples observed.
    n: u64,
    /// Largest |loss difference| seen; used as the range `R` in the bound.
    max_loss_diff: f64,
    /// Confidence parameter of the promotion test.
    delta: f64,
    /// Whether the shadow has been swapped in.
    promoted: bool,
    seed: u64,
}

impl AdaptiveLeafModel {
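    /// Build an adaptive leaf: the active model starts as a fresh
    /// `ClosedFormLeaf`; `promote_to` records the shadow's configuration so
    /// fresh clones can rebuild it; `delta` sets the confidence of the
    /// promotion test.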
    pub fn new(
        shadow: Box<dyn LeafModel>,
        promote_to: LeafModelType,
        delta: f64,
        seed: u64,
    ) -> Self {
        Self {
            active: Box::new(ClosedFormLeaf::new()),
            shadow,
            promote_to,
            cumulative_advantage: 0.0,
            n: 0,
            max_loss_diff: 0.0,
            delta,
            promoted: false,
            seed,
        }
    }
}

impl LeafModel for AdaptiveLeafModel {
    fn predict(&self, features: &[f64]) -> f64 {
        self.active.predict(features)
    }

    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
        if self.promoted {
            // After promotion only the active model keeps learning.
            self.active.update(features, gradient, hessian, lambda);
            return;
        }

        let pred_active = self.active.predict(features);
        let pred_shadow = self.shadow.predict(features);

        // Second-order surrogate loss at each model's prediction:
        // l(w) = g * w + 0.5 * h * w^2.
        let loss_active = gradient * pred_active + 0.5 * hessian * pred_active * pred_active;
        let loss_shadow = gradient * pred_shadow + 0.5 * hessian * pred_shadow * pred_shadow;

        // Positive diff means the shadow incurred the smaller loss.
        let diff = loss_active - loss_shadow;
        self.cumulative_advantage += diff;
        self.n += 1;

        let abs_diff = math::abs(diff);
        if abs_diff > self.max_loss_diff {
            self.max_loss_diff = abs_diff;
        }

        self.active.update(features, gradient, hessian, lambda);
        self.shadow.update(features, gradient, hessian, lambda);

        // Promote once the mean advantage clears the Hoeffding-style bound
        // epsilon = R * sqrt(ln(1 / delta) / (2 * n)).
        if self.n >= 10 && self.max_loss_diff > 0.0 {
            let mean_advantage = self.cumulative_advantage / self.n as f64;
            if mean_advantage > 0.0 {
                let r_squared = self.max_loss_diff * self.max_loss_diff;
                let ln_inv_delta = math::ln(1.0 / self.delta);
                let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * self.n as f64));

                if mean_advantage > epsilon {
                    self.promoted = true;
                    core::mem::swap(&mut self.active, &mut self.shadow);
                }
            }
        }
    }

    fn clone_fresh(&self) -> Box<dyn LeafModel> {
        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
        Box::new(AdaptiveLeafModel::new(
            self.promote_to.create(derived_seed, self.delta),
            self.promote_to.clone(),
            self.delta,
            derived_seed,
        ))
    }
}

// Explicit markers; the boxed trait objects are already Send + Sync via the
// trait's supertraits, so these impls only state that fact.
unsafe impl Send for AdaptiveLeafModel {}
unsafe impl Sync for AdaptiveLeafModel {}

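/// Serializable configuration selecting which `LeafModel` to build.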
#[derive(Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LeafModelType {
    /// Constant leaf with the closed-form weight (the default).
    #[default]
    ClosedForm,

    /// Online linear model; see [`LinearLeafModel`].
    Linear {
        learning_rate: f64,
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
        #[cfg_attr(feature = "serde", serde(default))]
        use_adagrad: bool,
    },

    /// Single-hidden-layer ReLU network; see [`MLPLeafModel`].
    MLP {
        hidden_size: usize,
        learning_rate: f64,
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
    },

    /// Closed-form leaf that may promote itself to `promote_to`;
    /// see [`AdaptiveLeafModel`].
    Adaptive {
        promote_to: Box<LeafModelType>,
    },
}

impl LeafModelType {
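    /// Build a boxed `LeafModel` from this configuration.
    ///
    /// `seed` feeds deterministic initialization (MLP weights, derived
    /// seeds) and `delta` is forwarded to `Adaptive` promotion tests.
    ///
    /// A sketch of typical use (marked `ignore` since the path depends on
    /// how the module is exposed):
    ///
    /// ```ignore
    /// let cfg = LeafModelType::Linear { learning_rate: 0.01, decay: None, use_adagrad: true };
    /// let mut leaf = cfg.create(7, 1e-6);
    /// leaf.update(&[0.5, -1.0], 0.3, 1.0, 1.0);
    /// assert!(leaf.predict(&[0.5, -1.0]).is_finite());
    /// ```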
    pub fn create(&self, seed: u64, delta: f64) -> Box<dyn LeafModel> {
        match self {
            Self::ClosedForm => Box::new(ClosedFormLeaf::new()),
            Self::Linear {
                learning_rate,
                decay,
                use_adagrad,
            } => Box::new(LinearLeafModel::new(*learning_rate, *decay, *use_adagrad)),
            Self::MLP {
                hidden_size,
                learning_rate,
                decay,
            } => Box::new(MLPLeafModel::new(
                *hidden_size,
                *learning_rate,
                seed,
                *decay,
            )),
            Self::Adaptive { promote_to } => Box::new(AdaptiveLeafModel::new(
                promote_to.create(seed, delta),
                *promote_to.clone(),
                delta,
                seed,
            )),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic PRNG for reproducible tests (shadows the glob import).
    fn xorshift64(state: &mut u64) -> u64 {
        let mut s = *state;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        *state = s;
        s
    }

    /// Map a xorshift draw to [0, 1].
    fn rand_f64(state: &mut u64) -> f64 {
        xorshift64(state) as f64 / u64::MAX as f64
    }

    #[test]
    fn closed_form_matches_formula() {
        let mut leaf = ClosedFormLeaf::new();
        let lambda = 1.0;

        let updates = [(0.5, 1.0), (-0.3, 0.8), (1.2, 2.0), (-0.1, 0.5)];
        let mut grad_sum = 0.0;
        let mut hess_sum = 0.0;

        for &(g, h) in &updates {
            leaf.update(&[], g, h, lambda);
            grad_sum += g;
            hess_sum += h;
        }

        let expected = -grad_sum / (hess_sum + lambda);
        let predicted = leaf.predict(&[]);

        assert!(
            (predicted - expected).abs() < 1e-12,
            "closed form mismatch: got {predicted}, expected {expected}"
        );
    }

    #[test]
    fn closed_form_clone_fresh_resets() {
        let mut leaf = ClosedFormLeaf::new();
        leaf.update(&[], 5.0, 2.0, 1.0);
        assert!(
            leaf.predict(&[]).abs() > 0.0,
            "leaf should have non-zero weight after update"
        );

        let fresh = leaf.clone_fresh();
        assert!(
            fresh.predict(&[]).abs() < 1e-15,
            "fresh clone should predict 0, got {}",
            fresh.predict(&[])
        );
    }

    #[test]
    fn linear_converges_on_linear_target() {
        let mut model = LinearLeafModel::new(0.01, None, false);
        let lambda = 0.1;
        let mut rng = 42u64;

        for _ in 0..2000 {
            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
            let features = vec![x1, x2];
            let target = 2.0 * x1 + 3.0 * x2;

            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, lambda);
        }

        let test_features = vec![0.5, -0.3];
        let target = 2.0 * 0.5 + 3.0 * (-0.3);
        let pred = model.predict(&test_features);

        assert!(
            (pred - target).abs() < 1.0,
            "linear model should converge within 1.0 of target: pred={pred}, target={target}"
        );
    }

    #[test]
    fn linear_uninitialized_predicts_zero() {
        let model = LinearLeafModel::new(0.01, None, false);
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(
            pred.abs() < 1e-15,
            "uninitialized linear model should predict 0, got {pred}"
        );
    }

    #[test]
    fn linear_clone_warm_preserves_weights() {
        let mut model = LinearLeafModel::new(0.01, None, false);
        let features = vec![1.0, 2.0];

        for _ in 0..100 {
            let target = 3.0 * features[0] + 2.0 * features[1];
            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            model.update(&features, gradient, 2.0, 0.1);
        }

        let trained_pred = model.predict(&features);
        assert!(
            trained_pred.abs() > 0.01,
            "model should have learned something"
        );

        let warm = model.clone_warm();
        let warm_pred = warm.predict(&features);
        assert!(
            (warm_pred - trained_pred).abs() < 1e-12,
            "warm clone should preserve weights: trained={trained_pred}, warm={warm_pred}"
        );

        let fresh = model.clone_fresh();
        let fresh_pred = fresh.predict(&features);
        assert!(
            fresh_pred.abs() < 1e-15,
            "fresh clone should predict 0, got {fresh_pred}"
        );
    }

    #[test]
    fn linear_decay_forgets_old_data() {
        let mut model_decay = LinearLeafModel::new(0.05, Some(0.99), false);
        let mut model_no_decay = LinearLeafModel::new(0.05, None, false);
        let features = vec![1.0];
        let lambda = 0.1;

        // Train both models toward a target of 5.0.
        for _ in 0..500 {
            let pred_d = model_decay.predict(&features);
            let pred_n = model_no_decay.predict(&features);
            model_decay.update(&features, 2.0 * (pred_d - 5.0), 2.0, lambda);
            model_no_decay.update(&features, 2.0 * (pred_n - 5.0), 2.0, lambda);
        }

        let pred_d_trained = model_decay.predict(&features);
        let pred_n_trained = model_no_decay.predict(&features);
        assert!(
            (pred_d_trained - 5.0).abs() < 2.0,
            "decay model should approximate target"
        );
        assert!(
            (pred_n_trained - 5.0).abs() < 2.0,
            "no-decay model should approximate target"
        );

        // Feed zero gradients: decay keeps shrinking the weights while the
        // no-decay model holds its estimate.
        for _ in 0..200 {
            model_decay.update(&features, 0.0, 1.0, lambda);
            model_no_decay.update(&features, 0.0, 1.0, lambda);
        }

        let pred_d_after = model_decay.predict(&features);
        let pred_n_after = model_no_decay.predict(&features);

        assert!(
            pred_d_after.abs() < pred_n_after.abs(),
            "decay model should forget: decay pred={pred_d_after:.3}, no-decay pred={pred_n_after:.3}"
        );
    }

    #[test]
    fn mlp_produces_finite_predictions() {
        let model_uninit = MLPLeafModel::new(4, 0.01, 42, None);
        let features = vec![1.0, 2.0, 3.0];

        let pred_before = model_uninit.predict(&features);
        assert!(
            pred_before.is_finite(),
            "uninit prediction should be finite"
        );
        assert!(
            pred_before.abs() < 1e-15,
            "uninit prediction should be 0, got {pred_before}"
        );

        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
        for _ in 0..10 {
            model.update(&features, 0.5, 1.0, 0.1);
        }
        let pred_after = model.predict(&features);
        assert!(
            pred_after.is_finite(),
            "prediction after training should be finite, got {pred_after}"
        );
    }

    #[test]
    fn mlp_loss_decreases() {
        let mut model = MLPLeafModel::new(8, 0.05, 123, None);
        let features = vec![1.0, -0.5, 0.3];
        let target = 2.5;
        let lambda = 0.1;

        // One zero-gradient update to trigger lazy initialization.
        model.update(&features, 0.0, 1.0, lambda);
        let initial_pred = model.predict(&features);
        let initial_error = (initial_pred - target).abs();

        for _ in 0..200 {
            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, lambda);
        }

        let final_pred = model.predict(&features);
        let final_error = (final_pred - target).abs();

        assert!(
            final_error < initial_error,
            "MLP error should decrease: initial={initial_error}, final={final_error}"
        );
    }

    #[test]
    fn mlp_clone_fresh_resets() {
        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
        let features = vec![1.0, 2.0];

        for _ in 0..20 {
            model.update(&features, 0.5, 1.0, 0.1);
        }

        let trained_pred = model.predict(&features);
        assert!(
            trained_pred.abs() > 1e-10,
            "trained model should have non-zero prediction"
        );

        let fresh = model.clone_fresh();
        let fresh_pred = fresh.predict(&features);
        assert!(
            fresh_pred.abs() < 1e-15,
            "fresh clone should predict 0, got {fresh_pred}"
        );
    }

    #[test]
    fn mlp_clone_warm_preserves_weights() {
        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
        let features = vec![1.0, 2.0];

        for _ in 0..50 {
            model.update(&features, 0.5, 1.0, 0.1);
        }

        let trained_pred = model.predict(&features);
        let warm = model.clone_warm();
        let warm_pred = warm.predict(&features);

        assert!(
            (warm_pred - trained_pred).abs() < 1e-10,
            "warm clone should preserve predictions: trained={trained_pred}, warm={warm_pred}"
        );
    }

    #[test]
    fn leaf_model_type_default_is_closed_form() {
        let default_type = LeafModelType::default();
        assert!(
            matches!(default_type, LeafModelType::ClosedForm),
            "default LeafModelType should be ClosedForm, got {default_type:?}"
        );
    }

    #[test]
    fn leaf_model_type_create_all_variants() {
        let features = vec![1.0, 2.0, 3.0];
        let delta = 1e-7;

        let mut closed = LeafModelType::ClosedForm.create(0, delta);
        closed.update(&features, 1.0, 1.0, 0.1);
        let p = closed.predict(&features);
        assert!(p.is_finite(), "ClosedForm prediction should be finite");

        let mut linear = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        }
        .create(0, delta);
        linear.update(&features, 1.0, 1.0, 0.1);
        let p = linear.predict(&features);
        assert!(p.is_finite(), "Linear prediction should be finite");

        let mut mlp = LeafModelType::MLP {
            hidden_size: 4,
            learning_rate: 0.01,
            decay: None,
        }
        .create(99, delta);
        mlp.update(&features, 1.0, 1.0, 0.1);
        let p = mlp.predict(&features);
        assert!(p.is_finite(), "MLP prediction should be finite");

        let mut adaptive = LeafModelType::Adaptive {
            promote_to: Box::new(LeafModelType::Linear {
                learning_rate: 0.01,
                decay: None,
                use_adagrad: false,
            }),
        }
        .create(42, delta);
        adaptive.update(&features, 1.0, 1.0, 0.1);
        let p = adaptive.predict(&features);
        assert!(p.is_finite(), "Adaptive prediction should be finite");
    }

    #[test]
    fn adaptive_promotes_on_linear_target() {
        let promote_to = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        };
        let shadow = promote_to.create(42, 1e-7);
        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);

        let mut rng = 42u64;
        for _ in 0..5000 {
            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
            let features = vec![x1, x2];
            let target = 3.0 * x1 + 2.0 * x2;

            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, 0.1);
        }

        // The linear shadow should clearly beat the constant leaf here.
        assert!(
            model.promoted,
            "adaptive model should have promoted on linear target after 5000 samples"
        );
    }

    #[test]
    fn adaptive_does_not_promote_on_constant_target() {
        let promote_to = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        };
        let shadow = promote_to.create(42, 1e-7);
        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-7, 42);

        for _ in 0..2000 {
            let features = vec![1.0, 2.0];
            let target = 5.0;
            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, 0.1);
        }

        // A constant target is fit equally well by the closed-form leaf, so
        // promotion is not required; just sanity-check the prediction.
        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.is_finite(), "prediction should be finite");
    }

    #[test]
    fn adaptive_clone_fresh_resets_promotion() {
        let promote_to = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        };
        let shadow = promote_to.create(42, 1e-3);
        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);

        let mut rng = 42u64;
        // Train long enough for promotion to occur on a linear target.
        for _ in 0..5000 {
            let x = rand_f64(&mut rng) * 2.0 - 1.0;
            let features = vec![x];
            let pred = model.predict(&features);
            model.update(&features, 2.0 * (pred - 3.0 * x), 2.0, 0.1);
        }

        let fresh = model.clone_fresh();
        // A fresh clone starts unpromoted with zeroed state.
        let p = fresh.predict(&[0.5]);
        assert!(
            p.abs() < 1e-10,
            "fresh adaptive clone should predict ~0, got {p}"
        );
    }
}