1use alloc::boxed::Box;
45use alloc::vec;
46use alloc::vec::Vec;
47
48use crate::math;
49
/// A trainable model living in a single tree leaf.
///
/// Implementations are driven gradient-boosting style: each sample arrives as
/// the (gradient, hessian) of the loss at the current prediction. The
/// `Send + Sync` supertraits make boxed leaves shareable across threads.
pub trait LeafModel: Send + Sync {
    /// Returns this leaf's prediction for `features`.
    fn predict(&self, features: &[f64]) -> f64;

    /// Performs one online training step from a single sample's gradient
    /// statistics; `lambda` acts as an L2-style regularization term.
    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64);

    /// Returns a brand-new, untrained copy with the same hyperparameters.
    fn clone_fresh(&self) -> Box<dyn LeafModel>;

    /// Returns a copy that carries over learned parameters where meaningful.
    /// Defaults to `clone_fresh` for models with no state worth keeping.
    fn clone_warm(&self) -> Box<dyn LeafModel> {
        self.clone_fresh()
    }
}
73
/// The classic GBDT leaf: accumulates first/second-order loss statistics and
/// outputs the closed-form optimum `-G / (H + lambda)`.
pub struct ClosedFormLeaf {
    // Running sum of gradients (G).
    grad_sum: f64,
    // Running sum of hessians (H).
    hess_sum: f64,
    // Cached optimal weight, refreshed on every `update`.
    weight: f64,
}
87
88impl Default for ClosedFormLeaf {
89 fn default() -> Self {
90 Self {
91 grad_sum: 0.0,
92 hess_sum: 0.0,
93 weight: 0.0,
94 }
95 }
96}
97
98impl ClosedFormLeaf {
99 pub fn new() -> Self {
101 Self::default()
102 }
103}
104
105impl LeafModel for ClosedFormLeaf {
106 fn predict(&self, _features: &[f64]) -> f64 {
107 self.weight
108 }
109
110 fn update(&mut self, _features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
111 self.grad_sum += gradient;
112 self.hess_sum += hessian;
113 self.weight = -self.grad_sum / (self.hess_sum + lambda);
114 }
115
116 fn clone_fresh(&self) -> Box<dyn LeafModel> {
117 Box::new(ClosedFormLeaf::new())
118 }
119}
120
/// A per-leaf linear model `bias + w·x`, trained online with SGD,
/// optionally with per-coordinate AdaGrad step sizes and/or multiplicative
/// weight decay.
pub struct LinearLeafModel {
    weights: Vec<f64>,
    bias: f64,
    learning_rate: f64,
    // Optional multiplicative shrink applied to parameters before each update.
    decay: Option<f64>,
    use_adagrad: bool,
    // Per-coordinate sums of squared gradients (AdaGrad denominators).
    sq_grad_accum: Vec<f64>,
    sq_bias_accum: f64,
    // Vectors are sized lazily on the first `update`; false until then.
    initialized: bool,
}
148
149impl LinearLeafModel {
150 pub fn new(learning_rate: f64, decay: Option<f64>, use_adagrad: bool) -> Self {
161 Self {
162 weights: Vec::new(),
163 bias: 0.0,
164 learning_rate,
165 decay,
166 use_adagrad,
167 sq_grad_accum: Vec::new(),
168 sq_bias_accum: 0.0,
169 initialized: false,
170 }
171 }
172}
173
/// Small constant added to AdaGrad denominators to avoid division by zero.
const ADAGRAD_EPS: f64 = 1e-8;
176
177impl LeafModel for LinearLeafModel {
178 fn predict(&self, features: &[f64]) -> f64 {
179 if !self.initialized {
180 return 0.0;
181 }
182 let mut dot = self.bias;
183 for (w, x) in self.weights.iter().zip(features.iter()) {
184 dot += w * x;
185 }
186 dot
187 }
188
189 fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
190 if !self.initialized {
191 let d = features.len();
192 self.weights = vec![0.0; d];
193 self.sq_grad_accum = vec![0.0; d];
194 self.initialized = true;
195 }
196
197 if let Some(d) = self.decay {
199 for w in self.weights.iter_mut() {
200 *w *= d;
201 }
202 self.bias *= d;
203 }
204
205 let base_lr = self.learning_rate / (math::abs(hessian) + lambda);
207
208 if self.use_adagrad {
209 for (i, (w, x)) in self.weights.iter_mut().zip(features.iter()).enumerate() {
211 let g = gradient * x;
212 self.sq_grad_accum[i] += g * g;
213 let adaptive_lr = base_lr / (math::sqrt(self.sq_grad_accum[i]) + ADAGRAD_EPS);
214 *w -= adaptive_lr * g;
215 }
216 self.sq_bias_accum += gradient * gradient;
217 let bias_lr = base_lr / (math::sqrt(self.sq_bias_accum) + ADAGRAD_EPS);
218 self.bias -= bias_lr * gradient;
219 } else {
220 for (w, x) in self.weights.iter_mut().zip(features.iter()) {
222 *w -= base_lr * gradient * x;
223 }
224 self.bias -= base_lr * gradient;
225 }
226 }
227
228 fn clone_fresh(&self) -> Box<dyn LeafModel> {
229 Box::new(LinearLeafModel::new(
230 self.learning_rate,
231 self.decay,
232 self.use_adagrad,
233 ))
234 }
235
236 fn clone_warm(&self) -> Box<dyn LeafModel> {
237 Box::new(LinearLeafModel {
238 weights: self.weights.clone(),
239 bias: self.bias,
240 learning_rate: self.learning_rate,
241 decay: self.decay,
242 use_adagrad: self.use_adagrad,
243 sq_grad_accum: vec![0.0; self.weights.len()],
248 sq_bias_accum: 0.0,
249 initialized: self.initialized,
250 })
251 }
252}
253
/// A tiny one-hidden-layer ReLU network used as a leaf model.
///
/// Layers are allocated lazily on the first `update`; forward-pass
/// activations are cached so the backward pass can reuse them.
pub struct MLPLeafModel {
    hidden_weights: Vec<Vec<f64>>,
    hidden_bias: Vec<f64>,
    output_weights: Vec<f64>,
    output_bias: f64,
    hidden_size: usize,
    learning_rate: f64,
    decay: Option<f64>,
    // Seed for deterministic weight initialization.
    seed: u64,
    initialized: bool,
    // Scratch buffers written by `forward`, read by `update`.
    hidden_activations: Vec<f64>,
    hidden_pre_activations: Vec<f64>,
}
279
280impl MLPLeafModel {
281 pub fn new(hidden_size: usize, learning_rate: f64, seed: u64, decay: Option<f64>) -> Self {
288 Self {
289 hidden_weights: Vec::new(),
290 hidden_bias: Vec::new(),
291 output_weights: Vec::new(),
292 output_bias: 0.0,
293 hidden_size,
294 learning_rate,
295 decay,
296 seed,
297 initialized: false,
298 hidden_activations: Vec::new(),
299 hidden_pre_activations: Vec::new(),
300 }
301 }
302
303 fn initialize(&mut self, input_size: usize) {
305 let mut state = self.seed ^ (self.hidden_size as u64);
306
307 self.hidden_weights = Vec::with_capacity(self.hidden_size);
308 for _ in 0..self.hidden_size {
309 let mut row = Vec::with_capacity(input_size);
310 for _ in 0..input_size {
311 let r = xorshift64(&mut state);
312 let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
314 row.push(val);
315 }
316 self.hidden_weights.push(row);
317 }
318
319 self.hidden_bias = Vec::with_capacity(self.hidden_size);
320 for _ in 0..self.hidden_size {
321 let r = xorshift64(&mut state);
322 let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
323 self.hidden_bias.push(val);
324 }
325
326 self.output_weights = Vec::with_capacity(self.hidden_size);
327 for _ in 0..self.hidden_size {
328 let r = xorshift64(&mut state);
329 let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
330 self.output_weights.push(val);
331 }
332
333 {
334 let r = xorshift64(&mut state);
335 self.output_bias = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
336 }
337
338 self.hidden_activations = vec![0.0; self.hidden_size];
339 self.hidden_pre_activations = vec![0.0; self.hidden_size];
340 self.initialized = true;
341 }
342
343 fn forward(&mut self, features: &[f64]) -> f64 {
345 for h in 0..self.hidden_size {
347 let mut z = self.hidden_bias[h];
348 for (j, x) in features.iter().enumerate() {
349 if j < self.hidden_weights[h].len() {
350 z += self.hidden_weights[h][j] * x;
351 }
352 }
353 self.hidden_pre_activations[h] = z;
354 self.hidden_activations[h] = if z > 0.0 { z } else { 0.0 };
356 }
357
358 let mut out = self.output_bias;
360 for (w, a) in self
361 .output_weights
362 .iter()
363 .zip(self.hidden_activations.iter())
364 {
365 out += w * a;
366 }
367 out
368 }
369}
370
371impl LeafModel for MLPLeafModel {
372 fn predict(&self, features: &[f64]) -> f64 {
373 if !self.initialized {
374 return 0.0;
375 }
376 let hidden_acts: Vec<f64> = self
378 .hidden_weights
379 .iter()
380 .zip(self.hidden_bias.iter())
381 .map(|(hw, &hb)| {
382 let mut z = hb;
383 for (j, x) in features.iter().enumerate() {
384 if j < hw.len() {
385 z += hw[j] * x;
386 }
387 }
388 if z > 0.0 {
389 z
390 } else {
391 0.0
392 }
393 })
394 .collect();
395 let mut out = self.output_bias;
396 for (w, a) in self.output_weights.iter().zip(hidden_acts.iter()) {
397 out += w * a;
398 }
399 out
400 }
401
402 fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
403 if !self.initialized {
404 self.initialize(features.len());
405 }
406
407 if let Some(d) = self.decay {
409 for row in self.hidden_weights.iter_mut() {
410 for w in row.iter_mut() {
411 *w *= d;
412 }
413 }
414 for b in self.hidden_bias.iter_mut() {
415 *b *= d;
416 }
417 for w in self.output_weights.iter_mut() {
418 *w *= d;
419 }
420 self.output_bias *= d;
421 }
422
423 let _output = self.forward(features);
425
426 let effective_lr = self.learning_rate / (math::abs(hessian) + lambda);
427
428 let d_output = gradient;
430
431 for h in 0..self.hidden_size {
435 self.output_weights[h] -= effective_lr * d_output * self.hidden_activations[h];
436 }
437 self.output_bias -= effective_lr * d_output;
438
439 for h in 0..self.hidden_size {
441 let d_hidden_act = d_output * self.output_weights[h];
443
444 let d_relu = if self.hidden_pre_activations[h] > 0.0 {
446 d_hidden_act
447 } else {
448 0.0
449 };
450
451 for (j, x) in features.iter().enumerate() {
453 if j < self.hidden_weights[h].len() {
454 self.hidden_weights[h][j] -= effective_lr * d_relu * x;
455 }
456 }
457 self.hidden_bias[h] -= effective_lr * d_relu;
458 }
459 }
460
461 fn clone_fresh(&self) -> Box<dyn LeafModel> {
462 let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
464 Box::new(MLPLeafModel::new(
465 self.hidden_size,
466 self.learning_rate,
467 derived_seed,
468 self.decay,
469 ))
470 }
471
472 fn clone_warm(&self) -> Box<dyn LeafModel> {
473 Box::new(MLPLeafModel {
474 hidden_weights: self.hidden_weights.clone(),
475 hidden_bias: self.hidden_bias.clone(),
476 output_weights: self.output_weights.clone(),
477 output_bias: self.output_bias,
478 hidden_size: self.hidden_size,
479 learning_rate: self.learning_rate,
480 decay: self.decay,
481 seed: self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(2),
483 initialized: self.initialized,
484 hidden_activations: vec![0.0; self.hidden_size],
485 hidden_pre_activations: vec![0.0; self.hidden_size],
486 })
487 }
488}
489
/// Serves predictions from a cheap model while training a richer "shadow"
/// model on the same stream; swaps the shadow in once it is statistically
/// better under a Hoeffding-style confidence test.
pub struct AdaptiveLeafModel {
    // Model currently answering `predict`; starts as a ClosedFormLeaf.
    active: Box<dyn LeafModel>,
    // Candidate model trained in parallel but not yet serving.
    shadow: Box<dyn LeafModel>,
    // Config used to rebuild the shadow for fresh clones.
    promote_to: LeafModelType,
    // Running sum of (loss_active - loss_shadow); positive favors the shadow.
    cumulative_advantage: f64,
    // Number of samples observed before promotion.
    n: u64,
    // Largest |loss diff| seen; used as the range R in the confidence bound.
    max_loss_diff: f64,
    // Confidence parameter: promote when the mean advantage exceeds eps(delta, n).
    delta: f64,
    // Set once the shadow has been swapped in; the race is then over.
    promoted: bool,
    // Seed forwarded to freshly created children.
    seed: u64,
}
533
534impl AdaptiveLeafModel {
535 pub fn new(
540 shadow: Box<dyn LeafModel>,
541 promote_to: LeafModelType,
542 delta: f64,
543 seed: u64,
544 ) -> Self {
545 Self {
546 active: Box::new(ClosedFormLeaf::new()),
547 shadow,
548 promote_to,
549 cumulative_advantage: 0.0,
550 n: 0,
551 max_loss_diff: 0.0,
552 delta,
553 promoted: false,
554 seed,
555 }
556 }
557}
558
impl LeafModel for AdaptiveLeafModel {
    /// Predictions always come from the currently active model.
    fn predict(&self, features: &[f64]) -> f64 {
        self.active.predict(features)
    }

    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
        // After promotion the race is over: train only the active model.
        if self.promoted {
            self.active.update(features, gradient, hessian, lambda);
            return;
        }

        // Score both models on this sample BEFORE either one is updated.
        let pred_active = self.active.predict(features);
        let pred_shadow = self.shadow.predict(features);

        // Second-order surrogate loss g*p + 0.5*h*p^2 for each prediction.
        let loss_active = gradient * pred_active + 0.5 * hessian * pred_active * pred_active;
        let loss_shadow = gradient * pred_shadow + 0.5 * hessian * pred_shadow * pred_shadow;

        // Positive diff means the shadow was better on this sample.
        let diff = loss_active - loss_shadow;
        self.cumulative_advantage += diff;
        self.n += 1;

        // Track the largest observed |diff| as the range R for the bound below.
        let abs_diff = math::abs(diff);
        if abs_diff > self.max_loss_diff {
            self.max_loss_diff = abs_diff;
        }

        self.active.update(features, gradient, hessian, lambda);
        self.shadow.update(features, gradient, hessian, lambda);

        // Promotion test: promote once the mean advantage exceeds
        // eps = sqrt(R^2 * ln(1/delta) / (2n)), after a minimum of 10 samples.
        // NOTE(review): R here is the running max |diff|, not an a-priori
        // range, so this is a heuristic variant of the Hoeffding bound.
        if self.n >= 10 && self.max_loss_diff > 0.0 {
            let mean_advantage = self.cumulative_advantage / self.n as f64;
            if mean_advantage > 0.0 {
                let r_squared = self.max_loss_diff * self.max_loss_diff;
                let ln_inv_delta = math::ln(1.0 / self.delta);
                let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * self.n as f64));

                if mean_advantage > epsilon {
                    // Promote: the shadow becomes the active model.
                    self.promoted = true;
                    core::mem::swap(&mut self.active, &mut self.shadow);
                }
            }
        }
    }

    fn clone_fresh(&self) -> Box<dyn LeafModel> {
        // Derive a new seed so sibling clones do not share RNG streams.
        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
        Box::new(AdaptiveLeafModel::new(
            self.promote_to.create(derived_seed, self.delta),
            self.promote_to.clone(),
            self.delta,
            derived_seed,
        ))
    }
}
627
// No manual `unsafe impl Send/Sync` is needed for `AdaptiveLeafModel`:
// `LeafModel` has `Send + Sync` as supertraits, so both `Box<dyn LeafModel>`
// fields are `Send + Sync`, every other field is plain data, and the auto
// traits therefore apply automatically. The previous `unsafe impl`s were
// redundant and would have silently masked a real soundness issue if a
// non-Send field were ever added.
634
/// Serializable configuration describing which [`LeafModel`] to build.
#[derive(Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum LeafModelType {
    /// Classic GBDT leaf with closed-form weight `-G / (H + lambda)`.
    #[default]
    ClosedForm,

    /// SGD-trained linear model per leaf.
    Linear {
        learning_rate: f64,
        /// Optional multiplicative weight decay applied before each update.
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
        /// Use per-coordinate AdaGrad step sizes instead of plain SGD.
        #[cfg_attr(feature = "serde", serde(default))]
        use_adagrad: bool,
    },

    /// Small one-hidden-layer ReLU network per leaf.
    MLP {
        hidden_size: usize,
        learning_rate: f64,
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
    },

    /// Starts as a closed-form leaf and promotes to `promote_to` once the
    /// shadow model proves statistically better (see `AdaptiveLeafModel`).
    Adaptive { promote_to: Box<LeafModelType> },
}
694
695impl LeafModelType {
696 pub fn create(&self, seed: u64, delta: f64) -> Box<dyn LeafModel> {
703 match self {
704 Self::ClosedForm => Box::new(ClosedFormLeaf::new()),
705 Self::Linear {
706 learning_rate,
707 decay,
708 use_adagrad,
709 } => Box::new(LinearLeafModel::new(*learning_rate, *decay, *use_adagrad)),
710 Self::MLP {
711 hidden_size,
712 learning_rate,
713 decay,
714 } => Box::new(MLPLeafModel::new(
715 *hidden_size,
716 *learning_rate,
717 seed,
718 *decay,
719 )),
720 Self::Adaptive { promote_to } => Box::new(AdaptiveLeafModel::new(
721 promote_to.create(seed, delta),
722 *promote_to.clone(),
723 delta,
724 seed,
725 )),
726 }
727 }
728}
729
/// Marsaglia xorshift64 PRNG: mutates `state` in place and returns the new
/// value. Deterministic and allocation-free; note that a zero state is a
/// fixed point (it stays zero forever), so callers must seed it non-zero.
fn xorshift64(state: &mut u64) -> u64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    *state
}
743
#[cfg(test)]
mod tests {
    use super::*;

    // The parent module's private `xorshift64` PRNG is already in scope via
    // `use super::*`; the tests reuse it instead of keeping a duplicate copy.

    /// Deterministic uniform sample in [0, 1).
    fn rand_f64(state: &mut u64) -> f64 {
        xorshift64(state) as f64 / u64::MAX as f64
    }

    #[test]
    fn closed_form_matches_formula() {
        let mut leaf = ClosedFormLeaf::new();
        let lambda = 1.0;

        let updates = [(0.5, 1.0), (-0.3, 0.8), (1.2, 2.0), (-0.1, 0.5)];
        let mut grad_sum = 0.0;
        let mut hess_sum = 0.0;

        for &(g, h) in &updates {
            leaf.update(&[], g, h, lambda);
            grad_sum += g;
            hess_sum += h;
        }

        let expected = -grad_sum / (hess_sum + lambda);
        let predicted = leaf.predict(&[]);

        assert!(
            (predicted - expected).abs() < 1e-12,
            "closed form mismatch: got {predicted}, expected {expected}"
        );
    }

    #[test]
    fn closed_form_clone_fresh_resets() {
        let mut leaf = ClosedFormLeaf::new();
        leaf.update(&[], 5.0, 2.0, 1.0);
        assert!(
            leaf.predict(&[]).abs() > 0.0,
            "leaf should have non-zero weight after update"
        );

        let fresh = leaf.clone_fresh();
        assert!(
            fresh.predict(&[]).abs() < 1e-15,
            "fresh clone should predict 0, got {}",
            fresh.predict(&[])
        );
    }

    #[test]
    fn linear_converges_on_linear_target() {
        let mut model = LinearLeafModel::new(0.01, None, false);
        let lambda = 0.1;
        let mut rng = 42u64;

        for _ in 0..2000 {
            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
            let features = vec![x1, x2];
            let target = 2.0 * x1 + 3.0 * x2;

            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, lambda);
        }

        let test_features = vec![0.5, -0.3];
        let target = 2.0 * 0.5 + 3.0 * (-0.3);
        let pred = model.predict(&test_features);

        assert!(
            (pred - target).abs() < 1.0,
            "linear model should converge within 1.0 of target: pred={pred}, target={target}"
        );
    }

    #[test]
    fn linear_uninitialized_predicts_zero() {
        let model = LinearLeafModel::new(0.01, None, false);
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(
            pred.abs() < 1e-15,
            "uninitialized linear model should predict 0, got {pred}"
        );
    }

    #[test]
    fn linear_clone_warm_preserves_weights() {
        let mut model = LinearLeafModel::new(0.01, None, false);
        let features = vec![1.0, 2.0];

        for _ in 0..100 {
            let target = 3.0 * features[0] + 2.0 * features[1];
            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            model.update(&features, gradient, 2.0, 0.1);
        }

        let trained_pred = model.predict(&features);
        assert!(
            trained_pred.abs() > 0.01,
            "model should have learned something"
        );

        let warm = model.clone_warm();
        let warm_pred = warm.predict(&features);
        assert!(
            (warm_pred - trained_pred).abs() < 1e-12,
            "warm clone should preserve weights: trained={trained_pred}, warm={warm_pred}"
        );

        let fresh = model.clone_fresh();
        let fresh_pred = fresh.predict(&features);
        assert!(
            fresh_pred.abs() < 1e-15,
            "fresh clone should predict 0, got {fresh_pred}"
        );
    }

    #[test]
    fn linear_decay_forgets_old_data() {
        let mut model_decay = LinearLeafModel::new(0.05, Some(0.99), false);
        let mut model_no_decay = LinearLeafModel::new(0.05, None, false);
        let features = vec![1.0];
        let lambda = 0.1;

        for _ in 0..500 {
            let pred_d = model_decay.predict(&features);
            let pred_n = model_no_decay.predict(&features);
            model_decay.update(&features, 2.0 * (pred_d - 5.0), 2.0, lambda);
            model_no_decay.update(&features, 2.0 * (pred_n - 5.0), 2.0, lambda);
        }

        let pred_d_trained = model_decay.predict(&features);
        let pred_n_trained = model_no_decay.predict(&features);
        assert!(
            (pred_d_trained - 5.0).abs() < 2.0,
            "decay model should approximate target"
        );
        assert!(
            (pred_n_trained - 5.0).abs() < 2.0,
            "no-decay model should approximate target"
        );

        // Zero-gradient updates: only the multiplicative decay acts now, so
        // the decaying model should drift back toward zero.
        for _ in 0..200 {
            model_decay.update(&features, 0.0, 1.0, lambda);
            model_no_decay.update(&features, 0.0, 1.0, lambda);
        }

        let pred_d_after = model_decay.predict(&features);
        let pred_n_after = model_no_decay.predict(&features);

        assert!(
            pred_d_after.abs() < pred_n_after.abs(),
            "decay model should forget: decay pred={pred_d_after:.3}, no-decay pred={pred_n_after:.3}"
        );
    }

    #[test]
    fn mlp_produces_finite_predictions() {
        let model_uninit = MLPLeafModel::new(4, 0.01, 42, None);
        let features = vec![1.0, 2.0, 3.0];

        let pred_before = model_uninit.predict(&features);
        assert!(
            pred_before.is_finite(),
            "uninit prediction should be finite"
        );
        assert!(
            pred_before.abs() < 1e-15,
            "uninit prediction should be 0, got {pred_before}"
        );

        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
        for _ in 0..10 {
            model.update(&features, 0.5, 1.0, 0.1);
        }
        let pred_after = model.predict(&features);
        assert!(
            pred_after.is_finite(),
            "prediction after training should be finite, got {pred_after}"
        );
    }

    #[test]
    fn mlp_loss_decreases() {
        let mut model = MLPLeafModel::new(8, 0.05, 123, None);
        let features = vec![1.0, -0.5, 0.3];
        let target = 2.5;
        let lambda = 0.1;

        // A zero-gradient update forces lazy initialization, so the initial
        // prediction reflects the random starting weights.
        model.update(&features, 0.0, 1.0, lambda);
        let initial_pred = model.predict(&features);
        let initial_error = (initial_pred - target).abs();

        for _ in 0..200 {
            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, lambda);
        }

        let final_pred = model.predict(&features);
        let final_error = (final_pred - target).abs();

        assert!(
            final_error < initial_error,
            "MLP error should decrease: initial={initial_error}, final={final_error}"
        );
    }

    #[test]
    fn mlp_clone_fresh_resets() {
        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
        let features = vec![1.0, 2.0];

        for _ in 0..20 {
            model.update(&features, 0.5, 1.0, 0.1);
        }

        let trained_pred = model.predict(&features);
        assert!(
            trained_pred.abs() > 1e-10,
            "trained model should have non-zero prediction"
        );

        let fresh = model.clone_fresh();
        let fresh_pred = fresh.predict(&features);
        assert!(
            fresh_pred.abs() < 1e-15,
            "fresh clone should predict 0, got {fresh_pred}"
        );
    }

    #[test]
    fn mlp_clone_warm_preserves_weights() {
        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
        let features = vec![1.0, 2.0];

        for _ in 0..50 {
            model.update(&features, 0.5, 1.0, 0.1);
        }

        let trained_pred = model.predict(&features);
        let warm = model.clone_warm();
        let warm_pred = warm.predict(&features);

        assert!(
            (warm_pred - trained_pred).abs() < 1e-10,
            "warm clone should preserve predictions: trained={trained_pred}, warm={warm_pred}"
        );
    }

    #[test]
    fn leaf_model_type_default_is_closed_form() {
        let default_type = LeafModelType::default();
        assert!(
            matches!(default_type, LeafModelType::ClosedForm),
            "default LeafModelType should be ClosedForm, got {default_type:?}"
        );
    }

    #[test]
    fn leaf_model_type_create_all_variants() {
        let features = vec![1.0, 2.0, 3.0];
        let delta = 1e-7;

        let mut closed = LeafModelType::ClosedForm.create(0, delta);
        closed.update(&features, 1.0, 1.0, 0.1);
        let p = closed.predict(&features);
        assert!(p.is_finite(), "ClosedForm prediction should be finite");

        let mut linear = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        }
        .create(0, delta);
        linear.update(&features, 1.0, 1.0, 0.1);
        let p = linear.predict(&features);
        assert!(p.is_finite(), "Linear prediction should be finite");

        let mut mlp = LeafModelType::MLP {
            hidden_size: 4,
            learning_rate: 0.01,
            decay: None,
        }
        .create(99, delta);
        mlp.update(&features, 1.0, 1.0, 0.1);
        let p = mlp.predict(&features);
        assert!(p.is_finite(), "MLP prediction should be finite");

        let mut adaptive = LeafModelType::Adaptive {
            promote_to: Box::new(LeafModelType::Linear {
                learning_rate: 0.01,
                decay: None,
                use_adagrad: false,
            }),
        }
        .create(42, delta);
        adaptive.update(&features, 1.0, 1.0, 0.1);
        let p = adaptive.predict(&features);
        assert!(p.is_finite(), "Adaptive prediction should be finite");
    }

    #[test]
    fn adaptive_promotes_on_linear_target() {
        let promote_to = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        };
        let shadow = promote_to.create(42, 1e-7);
        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);

        let mut rng = 42u64;
        for _ in 0..5000 {
            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
            let features = vec![x1, x2];
            let target = 3.0 * x1 + 2.0 * x2;

            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, 0.1);
        }

        assert!(
            model.promoted,
            "adaptive model should have promoted on linear target after 5000 samples"
        );
    }

    #[test]
    fn adaptive_does_not_promote_on_constant_target() {
        let promote_to = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        };
        let shadow = promote_to.create(42, 1e-7);
        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-7, 42);

        for _ in 0..2000 {
            let features = vec![1.0, 2.0];
            let target = 5.0;
            let pred = model.predict(&features);
            let gradient = 2.0 * (pred - target);
            let hessian = 2.0;
            model.update(&features, gradient, hessian, 0.1);
        }

        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.is_finite(), "prediction should be finite");
    }

    #[test]
    fn adaptive_clone_fresh_resets_promotion() {
        let promote_to = LeafModelType::Linear {
            learning_rate: 0.01,
            decay: None,
            use_adagrad: false,
        };
        let shadow = promote_to.create(42, 1e-3);
        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);

        let mut rng = 42u64;
        for _ in 0..5000 {
            let x = rand_f64(&mut rng) * 2.0 - 1.0;
            let features = vec![x];
            let pred = model.predict(&features);
            model.update(&features, 2.0 * (pred - 3.0 * x), 2.0, 0.1);
        }

        let fresh = model.clone_fresh();
        let p = fresh.predict(&[0.5]);
        assert!(
            p.abs() < 1e-10,
            "fresh adaptive clone should predict ~0, got {p}"
        );
    }
}