use crate::ensemble::config::SGBTConfig;
use crate::ensemble::step::BoostingStep;
use crate::ensemble::SGBT;
use crate::loss::Loss;

impl<L: Loss> SGBT<L> {
    /// Number of boosting steps (trees) in the ensemble.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total number of trees, counting both active trees and their alternates.
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Total leaves across all active trees.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Total number of training samples seen so far.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Current tree contribution standard deviation (honest uncertainty).
    ///
    /// This is the EWMA of per-sample contribution sigma across trees,
    /// computed from the `adaptive_mts` machinery. It reflects how much
    /// individual trees disagree; higher values indicate more model
    /// uncertainty or regime change.
    ///
    /// Returns 0.0 if `adaptive_mts` is not enabled or the model has
    /// not yet been trained.
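    ///
    /// # Examples
    ///
    /// A sketch of uncertainty gating; the `0.5` threshold and the
    /// surrounding `model` are illustrative assumptions:
    ///
    /// ```ignore
    /// // Back off when trees disagree strongly (possible regime change).
    /// if model.contribution_sigma() > 0.5 {
    ///     // e.g. widen prediction intervals or discount the prediction
    /// }
    /// ```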
    #[inline]
    pub fn contribution_sigma(&self) -> f64 {
        self.rolling_contribution_sigma
    }

    /// The current base prediction.
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether the base prediction has been initialized.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// Access the configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Set the learning rate for future boosting rounds.
    ///
    /// This allows external schedulers to adapt the rate over time without
    /// rebuilding the model.
    ///
    /// # Panics
    ///
    /// Panics if `lr` is not in `(0.0, 1.0]` or is not finite.
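    ///
    /// # Examples
    ///
    /// A sketch of an external scheduler decaying the rate per sample;
    /// the decay constants and the `train_one(x, y)` signature are
    /// illustrative assumptions:
    ///
    /// ```ignore
    /// let mut lr = 0.1;
    /// for (x, y) in stream {
    ///     model.train_one(x, y);
    ///     lr = (lr * 0.9999).max(0.01); // geometric decay with a floor
    ///     model.set_learning_rate(lr);
    /// }
    /// ```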
    #[inline]
    pub fn set_learning_rate(&mut self, lr: f64) {
        assert!(
            lr > 0.0 && lr <= 1.0 && lr.is_finite(),
            "learning_rate must be in (0.0, 1.0], got {}",
            lr
        );
        self.config.learning_rate = lr;
    }

    /// Set the L2 regularization parameter (lambda) for future boosting rounds.
    ///
    /// Higher lambda increases regularization, shrinking leaf weights toward
    /// zero. Takes effect immediately for subsequent leaf weight computations.
    ///
    /// # Arguments
    ///
    /// * `lambda` -- new L2 regularization value (must be >= 0)
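    ///
    /// # Examples
    ///
    /// Assuming the usual second-order leaf weight `w = -G / (H + lambda)`
    /// (an assumption about the tree internals, not shown in this file),
    /// raising lambda pulls every subsequently computed leaf toward zero:
    ///
    /// ```ignore
    /// model.set_lambda(model.config().lambda * 2.0); // double the shrinkage
    /// ```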
    #[inline]
    pub fn set_lambda(&mut self, lambda: f64) {
        self.config.lambda = lambda.max(0.0);
    }

    /// Set the maximum tree depth for future replacement trees.
    ///
    /// Existing trees are not affected -- only new trees created during
    /// drift-triggered or proactive replacement will use the updated depth.
    ///
    /// # Arguments
    ///
    /// * `depth` -- new maximum depth (clamped to 1..=20)
    #[inline]
    pub fn set_max_depth(&mut self, depth: usize) {
        self.config.max_depth = depth.clamp(1, 20);
    }

    /// Adjust the number of boosting steps (trees in the ensemble).
    ///
    /// - **Growing** (`n > current`): appends fresh trees using the current config.
    /// - **Shrinking** (`n < current`): removes trailing steps (newest trees).
    /// - Clamped to `3..=1000` to prevent degenerate ensembles.
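    ///
    /// # Examples
    ///
    /// A sketch of capacity scaling; the sizes are illustrative:
    ///
    /// ```ignore
    /// model.set_n_steps(50); // grow: appends fresh, untrained trees
    /// model.set_n_steps(20); // shrink: drops the 30 newest steps
    /// model.set_n_steps(1);  // clamped up to the minimum of 3
    /// ```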
    pub fn set_n_steps(&mut self, n: usize) {
        let n = n.clamp(3, 1000);
        let current = self.steps.len();
        if n > current {
            let leaf_decay_alpha = self
                .config
                .leaf_half_life
                .map(|hl| (-(2.0_f64.ln()) / hl as f64).exp());
            let tree_config = crate::ensemble::config::build_tree_config(&self.config)
                .leaf_decay_alpha_opt(leaf_decay_alpha);
            let mts = self.config.max_tree_samples;
            let shadow_warmup = self.config.shadow_warmup.unwrap_or(0);
            for i in current..n {
                let mut tc = tree_config.clone();
                tc.seed = self.config.seed ^ (i as u64);
                let detector = self.config.drift_detector.create();
                let step = if shadow_warmup > 0 {
                    BoostingStep::new_with_graduated(tc, detector, mts, shadow_warmup)
                } else {
                    BoostingStep::new_with_max_samples(tc, detector, mts)
                };
                self.steps.push(step);
            }
        } else if n < current {
            self.steps.truncate(n);
        }
        self.diag.contribution_accuracy.resize(n, 0.0);
        self.config.n_steps = n;
    }

    /// Total tree replacements across all boosting steps.
    pub fn total_replacements(&self) -> u64 {
        self.steps.iter().map(|s| s.slot().replacements()).sum()
    }

    /// Manually trigger a proactive prune check.
    ///
    /// In accuracy-based mode, finds the worst mature tree (past the grace
    /// period) and replaces it if its contribution accuracy is negative. In
    /// variance-based mode, replaces the tree whose prediction standard
    /// deviation is smallest, with no grace-period filter. The contribution
    /// accuracy EWMAs are updated every sample inside `train_one()`; this
    /// method only performs the replacement decision.
    ///
    /// Returns `true` if a tree was replaced, `false` otherwise.
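    ///
    /// # Examples
    ///
    /// A sketch of a periodic prune check inside a training loop; the
    /// interval and the `train_one(x, y)` signature are illustrative
    /// assumptions:
    ///
    /// ```ignore
    /// for (i, (x, y)) in stream.enumerate() {
    ///     model.train_one(x, y);
    ///     if i % 10_000 == 0 && model.check_proactive_prune() {
    ///         // a stale tree was just replaced with a fresh one
    ///     }
    /// }
    /// ```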
    pub fn check_proactive_prune(&mut self) -> bool {
        if self.steps.len() <= 1 {
            return false;
        }
        if self.config.accuracy_based_pruning {
            let grace_period = self.config.grace_period as u64;
            let worst = self
                .steps
                .iter()
                .enumerate()
                .zip(self.diag.contribution_accuracy.iter())
                .filter(|((_, step), _)| step.slot().n_samples_seen() >= grace_period)
                .min_by(|((_, _), a), ((_, _), b)| {
                    a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                });
            if let Some(((worst_idx, _), &worst_acc)) = worst {
                if worst_acc < 0.0 {
                    self.steps[worst_idx].slot_mut().replace_active();
                    self.diag.contribution_accuracy[worst_idx] = 0.0;
                    return true;
                }
            }
            false
        } else {
            let worst_idx = self
                .steps
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    let a_std = a.slot().prediction_std();
                    let b_std = b.slot().prediction_std();
                    a_std
                        .partial_cmp(&b_std)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(i, _)| i)
                .unwrap_or(0);
            self.steps[worst_idx].slot_mut().replace_active();
            true
        }
    }

    /// Dynamically set the contribution accuracy EWMA half-life.
    ///
    /// Recomputes `prune_alpha` from the given half-life so each correction
    /// batch contributes equally regardless of size.
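    ///
    /// With `alpha = 1 - exp(-ln(2) / hl)`, an observation's weight in the
    /// EWMA halves after `hl` updates, since `(1 - alpha)^hl = 1/2`.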
    pub fn set_prune_half_life(&mut self, hl: usize) {
        // Half-life -> EWMA alpha, matching the leaf decay formula used in
        // `set_n_steps`: alpha = 1 - exp(-ln(2) / hl).
        self.diag.prune_alpha = 1.0 - (-(2.0_f64.ln()) / hl.max(1) as f64).exp();
    }
}