use crate::learner::StreamingLearner;
use crate::learners::rls::RecursiveLeastSquares;
/// A thin wrapper around [`RecursiveLeastSquares`] that adds a per-call output
/// clamp via [`BoundedRls::predict_clipped`].
///
/// Training, plain prediction, sample counting, and reset are forwarded
/// unchanged to the wrapped learner; only `predict_clipped` bounds the output.
pub struct BoundedRls {
    // The wrapped learner; every `StreamingLearner` call on the wrapper is
    // delegated to it.
    inner: RecursiveLeastSquares,
}
impl BoundedRls {
    /// Wraps an existing [`RecursiveLeastSquares`] learner.
    ///
    /// The learner is stored as-is; nothing about its state is reset or modified.
    pub fn new(rls: RecursiveLeastSquares) -> Self {
        Self { inner: rls }
    }

    /// Consumes the wrapper and returns the underlying learner.
    pub fn into_inner(self) -> RecursiveLeastSquares {
        self.inner
    }

    /// Returns a shared reference to the underlying learner.
    pub fn inner(&self) -> &RecursiveLeastSquares {
        &self.inner
    }

    /// Returns a mutable reference to the underlying learner.
    ///
    /// Any change made through this reference (e.g. additional training) is
    /// observed by the wrapper, which delegates all learning calls to it.
    pub fn inner_mut(&mut self) -> &mut RecursiveLeastSquares {
        &mut self.inner
    }

    /// Predicts with the wrapped learner and clamps the result to `[lo, hi]`.
    ///
    /// Predictions already inside the interval are returned unchanged; values
    /// below `lo` become `lo` and values above `hi` become `hi`. A NaN raw
    /// prediction is returned as NaN (this is the behavior of [`f64::clamp`]).
    ///
    /// # Panics
    ///
    /// Panics if `lo > hi` or if either bound is NaN, in every build profile.
    #[inline]
    pub fn predict_clipped(&self, features: &[f64], lo: f64, hi: f64) -> f64 {
        // `f64::clamp` would also panic on these bounds, but only with a generic
        // std message. A full (non-debug) assertion gives the same
        // wrapper-specific message in debug and release builds. `lo <= hi` is
        // false whenever either bound is NaN, so NaN bounds are rejected too.
        assert!(lo <= hi, "BoundedRls: lo ({lo}) must be <= hi ({hi})");
        self.inner.predict(features).clamp(lo, hi)
    }
}
/// Pure delegation: every call is forwarded to the wrapped
/// [`RecursiveLeastSquares`]. In particular `predict` here is *not* clamped;
/// use [`BoundedRls::predict_clipped`] when bounded output is required.
impl StreamingLearner for BoundedRls {
    /// Forwards one weighted observation to the wrapped learner.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let Self { inner } = self;
        inner.train_one(features, target, weight);
    }

    /// Returns the wrapped learner's raw (unclamped) prediction.
    fn predict(&self, features: &[f64]) -> f64 {
        let Self { inner } = self;
        inner.predict(features)
    }

    /// Reports the wrapped learner's sample count.
    fn n_samples_seen(&self) -> u64 {
        let Self { inner } = self;
        inner.n_samples_seen()
    }

    /// Resets the wrapped learner.
    fn reset(&mut self) {
        let Self { inner } = self;
        inner.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn in_range_prediction_passes_through() {
        let mut bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        for i in 0..100 {
            let x = i as f64 * 0.1;
            bounded.train(&[x], 2.0 * x);
        }
        let pred = bounded.predict_clipped(&[1.0], -10.0, 10.0);
        assert!(
            (pred - 2.0).abs() < 0.5,
            "in-range prediction should be near 2.0, got {pred}"
        );
        assert!(
            (-10.0..=10.0).contains(&pred),
            "prediction {pred} must be within [-10, 10]"
        );
    }
    #[test]
    fn exceeds_hi_clipped_to_hi() {
        let mut bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        for i in 0..50 {
            let x = i as f64;
            bounded.train(&[x], 1000.0 * x);
        }
        let pred = bounded.predict_clipped(&[100.0], -1.0, 1.0);
        assert!(
            (pred - 1.0).abs() < 1e-9,
            "expected hi clip = 1.0, got {pred}"
        );
    }
    #[test]
    fn below_lo_clipped_to_lo() {
        let mut bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        for i in 0..50 {
            let x = i as f64;
            bounded.train(&[x], -1000.0 * x);
        }
        let pred = bounded.predict_clipped(&[100.0], -1.0, 1.0);
        assert!(
            (pred - (-1.0)).abs() < 1e-9,
            "expected lo clip = -1.0, got {pred}"
        );
    }
    #[test]
    fn cold_start_predict_clipped_returns_zero_within_bounds() {
        let bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        let pred = bounded.predict_clipped(&[1.0, 2.0], -5.0, 5.0);
        assert_eq!(pred, 0.0, "cold-start RLS predicts 0.0");
    }
    #[test]
    fn n_samples_seen_delegates_to_inner() {
        let mut bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        assert_eq!(bounded.n_samples_seen(), 0);
        for i in 0..7 {
            bounded.train(&[i as f64], i as f64);
        }
        assert_eq!(bounded.n_samples_seen(), 7);
    }
    /// With `lo == hi` the interval is a single point, so any finite
    /// prediction must collapse onto it exactly.
    #[test]
    fn equal_bounds_pin_prediction_to_that_value() {
        let mut bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        for i in 0..100 {
            let x = i as f64 * 0.1;
            bounded.train(&[x], 2.0 * x);
        }
        assert_eq!(bounded.predict_clipped(&[1.0], 3.0, 3.0), 3.0);
    }
    /// Inverted bounds are a caller bug and must panic. No message is
    /// matched, because the panic may come from the wrapper's own assertion
    /// or from `f64::clamp`.
    #[test]
    #[should_panic]
    fn inverted_bounds_panic() {
        let bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        let _ = bounded.predict_clipped(&[1.0], 1.0, -1.0);
    }
    /// The `StreamingLearner::predict` path is plain delegation: it must equal
    /// the inner learner's prediction and must not be clipped.
    #[test]
    fn plain_predict_is_unclipped_and_matches_inner() {
        let mut bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        for i in 0..50 {
            let x = i as f64;
            bounded.train(&[x], 1000.0 * x);
        }
        let raw = bounded.predict(&[100.0]);
        assert_eq!(raw, bounded.inner().predict(&[100.0]));
        assert!(raw > 1.0, "StreamingLearner::predict must not clip, got {raw}");
    }
    /// Training through `inner_mut` is visible via the wrapper, and
    /// `into_inner` hands back that same learner state.
    #[test]
    fn inner_mut_and_into_inner_share_state() {
        let mut bounded = BoundedRls::new(RecursiveLeastSquares::new(1.0));
        bounded.inner_mut().train_one(&[1.0], 2.0, 1.0);
        assert_eq!(bounded.n_samples_seen(), 1);
        let rls = bounded.into_inner();
        assert_eq!(rls.n_samples_seen(), 1);
    }
}