use crate::error::{ConfigError, IrithyllError, Result};
use crate::learner::{StreamingLearner, Tunable};
use crate::preprocessing::IncrementalNormalizer;
use crate::projection::SubspaceTracker;
use std::fmt;
/// Configuration for a [`ProjectedLearner`] and its [`SubspaceTracker`].
///
/// Validated ranges (enforced by [`ProjectionConfigBuilder::build`]):
/// `rank >= 1`, `lambda` in (0, 1], `delta > 0`, `seed != 0` (xorshift64),
/// `supervised_lr` in (0, 1]. `warmup` is not validated.
#[derive(Clone, Debug)]
pub struct ProjectionConfig {
    /// Target dimensionality of the tracked subspace (must be >= 1).
    pub rank: usize,
    /// Forgetting factor passed to the subspace tracker, in (0, 1].
    pub lambda: f64,
    /// Scale passed to `SubspaceTracker::new`; must be > 0.
    /// NOTE(review): presumably an initialization scale — confirm in
    /// `SubspaceTracker`.
    pub delta: f64,
    /// Number of initial training samples during which the tracker's
    /// subspace is left untouched (only the normalizer and inner
    /// learner are updated).
    pub warmup: usize,
    /// RNG seed for the tracker; must be non-zero (xorshift64).
    pub seed: u64,
    /// Learning rate for supervised subspace updates, in (0, 1].
    pub supervised_lr: f64,
}
impl Default for ProjectionConfig {
fn default() -> Self {
Self {
rank: 8,
lambda: 0.9999,
delta: 100.0,
warmup: 200,
seed: 42,
supervised_lr: 0.001,
}
}
}
impl ProjectionConfig {
    /// Starts a fluent builder pre-populated with the default settings.
    pub fn builder() -> ProjectionConfigBuilder {
        ProjectionConfigBuilder::default()
    }
}
/// Fluent builder for [`ProjectionConfig`]; all validation happens in
/// [`ProjectionConfigBuilder::build`].
#[derive(Debug, Clone)]
pub struct ProjectionConfigBuilder {
    // Accumulated settings, starting from `ProjectionConfig::default()`.
    config: ProjectionConfig,
}
impl ProjectionConfigBuilder {
pub fn new() -> Self {
Self {
config: ProjectionConfig::default(),
}
}
pub fn rank(mut self, rank: usize) -> Self {
self.config.rank = rank;
self
}
pub fn lambda(mut self, lambda: f64) -> Self {
self.config.lambda = lambda;
self
}
pub fn delta(mut self, delta: f64) -> Self {
self.config.delta = delta;
self
}
pub fn warmup(mut self, warmup: usize) -> Self {
self.config.warmup = warmup;
self
}
pub fn seed(mut self, seed: u64) -> Self {
self.config.seed = seed;
self
}
pub fn supervised_lr(mut self, lr: f64) -> Self {
self.config.supervised_lr = lr;
self
}
pub fn build(self) -> Result<ProjectionConfig> {
let c = &self.config;
if c.rank < 1 {
return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
"rank",
"must be >= 1",
c.rank,
)));
}
if c.lambda <= 0.0 || c.lambda > 1.0 {
return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
"lambda",
"must be in (0, 1]",
c.lambda,
)));
}
if c.delta <= 0.0 {
return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
"delta",
"must be > 0",
c.delta,
)));
}
if c.seed == 0 {
return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
"seed",
"must be non-zero (xorshift64)",
c.seed,
)));
}
if c.supervised_lr <= 0.0 || c.supervised_lr > 1.0 {
return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
"supervised_lr",
"must be in (0, 1]",
c.supervised_lr,
)));
}
Ok(self.config)
}
}
impl Default for ProjectionConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// Wraps any [`StreamingLearner`] behind an online normalize-then-project
/// pipeline: raw inputs are incrementally normalized, projected into a
/// tracked low-rank subspace, and the projected vectors are what the
/// inner learner trains and predicts on.
pub struct ProjectedLearner {
    // Downstream learner; sees `config.rank`-dimensional projected features.
    inner: Box<dyn StreamingLearner>,
    // Low-rank subspace tracker; only updated once warmup has elapsed.
    tracker: SubspaceTracker,
    // Running feature normalizer, updated on every `train_one` call.
    normalizer: IncrementalNormalizer,
    // Settings captured at construction time.
    config: ProjectionConfig,
    // Total samples seen via `train_one`; drives the warmup gate.
    n_samples: u64,
}
impl ProjectedLearner {
    /// Wraps a boxed learner with a fresh tracker/normalizer pair.
    ///
    /// `d_in` is the raw input dimensionality; the inner learner receives
    /// `config.rank`-dimensional projected vectors instead. `config` is
    /// used as-is — obtain it via `ProjectionConfig::builder().build()`
    /// if validation is desired.
    pub fn new(inner: Box<dyn StreamingLearner>, d_in: usize, config: ProjectionConfig) -> Self {
        // Tracker argument order is positional: (d_in, rank, lambda, delta, seed).
        let tracker =
            SubspaceTracker::new(d_in, config.rank, config.lambda, config.delta, config.seed);
        let normalizer = IncrementalNormalizer::new();
        Self {
            inner,
            tracker,
            normalizer,
            config,
            n_samples: 0,
        }
    }
    /// Convenience constructor that boxes a concrete learner.
    pub fn from_learner(
        inner: impl StreamingLearner + 'static,
        d_in: usize,
        config: ProjectionConfig,
    ) -> Self {
        Self::new(Box::new(inner), d_in, config)
    }
    /// Read-only access to the subspace tracker.
    #[inline]
    pub fn tracker(&self) -> &SubspaceTracker {
        &self.tracker
    }
    /// Read-only access to the wrapped learner.
    #[inline]
    pub fn inner(&self) -> &dyn StreamingLearner {
        &*self.inner
    }
    /// Mutable access to the wrapped learner.
    #[inline]
    pub fn inner_mut(&mut self) -> &mut dyn StreamingLearner {
        &mut *self.inner
    }
    /// The configuration this learner was constructed with.
    #[inline]
    pub fn config(&self) -> &ProjectionConfig {
        &self.config
    }
    /// True once at least `config.warmup` samples have been trained on.
    #[inline]
    pub fn warmup_complete(&self) -> bool {
        self.n_samples >= self.config.warmup as u64
    }
}
impl StreamingLearner for ProjectedLearner {
    /// Trains on one sample: update normalizer stats, normalize, project,
    /// train the inner learner, and — once warmup has elapsed — adapt the
    /// subspace itself.
    #[allow(deprecated)]
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        self.normalizer.update(features);
        let normed = self.normalizer.transform(features);
        let projected = self.tracker.project(&normed);
        self.inner.train_one(&projected, target, weight);
        // During warmup (n_samples < warmup) only the normalizer and the
        // inner learner learn; the tracker keeps its initial subspace.
        if self.n_samples >= self.config.warmup as u64 {
            if let Some(beta) = self.inner.readout_weights() {
                if beta.len() == self.tracker.rank() {
                    // Supervised update: steer the subspace using this
                    // sample's residual against the inner model.
                    // NOTE(review): `predict` runs after `train_one`, so the
                    // residual is post-update — confirm this is intended.
                    let pred = self.inner.predict(&projected);
                    let residual = target - pred;
                    // Ramp the supervised LR linearly over the first 1000
                    // post-warmup samples, presumably to avoid instability
                    // while the readout is still settling.
                    let post_warmup = self.n_samples - self.config.warmup as u64;
                    let ramp = (post_warmup as f64 / 1000.0).min(1.0);
                    let lr = self.config.supervised_lr * ramp;
                    self.tracker.supervised_update(&normed, residual, beta, lr);
                } else {
                    // Readout length does not match the tracker rank: fall
                    // back to an unsupervised update. The 0.0 second arg's
                    // semantics are defined by SubspaceTracker — presumably
                    // "no residual signal"; confirm there.
                    self.tracker.update(&normed, 0.0);
                }
            } else {
                // Inner learner exposes no readout weights: unsupervised update.
                self.tracker.update(&normed, 0.0);
            }
        }
        // Incremented last, so the warmup gate above sees the pre-sample count.
        self.n_samples += 1;
    }
    /// Predicts through the same normalize → project → inner pipeline.
    fn predict(&self, features: &[f64]) -> f64 {
        // Before any training the normalizer has no statistics; project the
        // raw features instead of transforming through empty stats.
        if self.normalizer.count() == 0 {
            let projected = self.tracker.project(features);
            return self.inner.predict(&projected);
        }
        let normed = self.normalizer.transform(features);
        let projected = self.tracker.project(&normed);
        self.inner.predict(&projected)
    }
    /// Number of samples this wrapper has trained on.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }
    /// Resets the inner learner, tracker, normalizer, and sample count.
    fn reset(&mut self) {
        self.inner.reset();
        self.tracker.reset();
        self.normalizer = IncrementalNormalizer::new();
        self.n_samples = 0;
    }
    // The remaining (deprecated) trait methods delegate straight to the
    // inner learner; the projection layer adds nothing to them.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        self.inner.diagnostics_array()
    }
    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        self.inner.adjust_config(lr_multiplier, lambda_delta);
    }
    #[allow(deprecated)]
    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
        self.inner.apply_structural_change(depth_delta, steps_delta);
    }
    #[allow(deprecated)]
    fn replacement_count(&self) -> u64 {
        self.inner.replacement_count()
    }
    #[allow(deprecated)]
    fn check_proactive_prune(&mut self) -> bool {
        self.inner.check_proactive_prune()
    }
    #[allow(deprecated)]
    fn set_prune_half_life(&mut self, hl: usize) {
        self.inner.set_prune_half_life(hl);
    }
    #[allow(deprecated)]
    fn readout_weights(&self) -> Option<&[f64]> {
        self.inner.readout_weights()
    }
}
impl crate::automl::DiagnosticSource for ProjectedLearner {
    /// Always `None`: the wrapper exposes no structured config diagnostics.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        None
    }
}
impl Tunable for ProjectedLearner {
    // Tuning passes straight through to the inner learner; the projection
    // layer itself exposes no tunable parameters via this trait.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        self.inner.diagnostics_array()
    }
    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        self.inner.adjust_config(lr_multiplier, lambda_delta);
    }
}
impl fmt::Debug for ProjectedLearner {
    /// Compact debug view: reports shape and warmup state rather than the
    /// inner learner trait object (which need not implement `Debug`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ProjectedLearner")
            .field("d_in", &self.tracker.d_in())
            .field("rank", &self.config.rank)
            .field("warmup", &self.config.warmup)
            .field("n_samples", &self.n_samples)
            .field("warmup_complete", &self.warmup_complete())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests: builder validation, the projection pipeline's wiring,
    //! and trait-object ergonomics, using a trivial mean-predicting stub.
    use super::*;
    // Minimal StreamingLearner stub: ignores features entirely and
    // predicts the running mean of the targets it was trained on.
    struct MeanLearner {
        sum: f64,
        count: u64,
    }
    impl MeanLearner {
        fn new() -> Self {
            Self { sum: 0.0, count: 0 }
        }
    }
    impl StreamingLearner for MeanLearner {
        fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
            self.sum += target;
            self.count += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            // 0.0 before any training avoids a 0/0 NaN.
            if self.count == 0 {
                return 0.0;
            }
            self.sum / self.count as f64
        }
        fn n_samples_seen(&self) -> u64 {
            self.count
        }
        fn reset(&mut self) {
            self.sum = 0.0;
            self.count = 0;
        }
    }
    // Builder with no overrides must yield the documented defaults.
    #[test]
    fn default_config_builds() {
        let config = ProjectionConfig::builder().build().unwrap();
        assert_eq!(config.rank, 8);
        assert!((config.lambda - 0.9999).abs() < 1e-12);
        assert!((config.delta - 100.0).abs() < 1e-12);
        assert_eq!(config.warmup, 200);
        assert_eq!(config.seed, 42);
    }
    // Every setter must land on the corresponding field.
    #[test]
    fn custom_config_builds() {
        let config = ProjectionConfig::builder()
            .rank(4)
            .lambda(0.995)
            .delta(50.0)
            .warmup(20)
            .seed(123)
            .build()
            .unwrap();
        assert_eq!(config.rank, 4);
        assert!((config.lambda - 0.995).abs() < 1e-12);
        assert!((config.delta - 50.0).abs() < 1e-12);
        assert_eq!(config.warmup, 20);
        assert_eq!(config.seed, 123);
    }
    // --- build() validation: each out-of-range field must be rejected ---
    #[test]
    fn zero_rank_fails() {
        let result = ProjectionConfig::builder().rank(0).build();
        assert!(result.is_err(), "rank=0 should fail validation");
    }
    #[test]
    fn lambda_zero_fails() {
        let result = ProjectionConfig::builder().lambda(0.0).build();
        assert!(result.is_err(), "lambda=0 should fail validation");
    }
    #[test]
    fn lambda_above_one_fails() {
        let result = ProjectionConfig::builder().lambda(1.01).build();
        assert!(result.is_err(), "lambda>1 should fail validation");
    }
    #[test]
    fn delta_zero_fails() {
        let result = ProjectionConfig::builder().delta(0.0).build();
        assert!(result.is_err(), "delta=0 should fail validation");
    }
    #[test]
    fn seed_zero_fails() {
        let result = ProjectionConfig::builder().seed(0).build();
        assert!(result.is_err(), "seed=0 should fail validation");
    }
    // Helper: a ProjectedLearner over a MeanLearner with the given shape.
    fn make_projected(d_in: usize, rank: usize, warmup: usize) -> ProjectedLearner {
        let config = ProjectionConfig::builder()
            .rank(rank)
            .warmup(warmup)
            .build()
            .unwrap();
        ProjectedLearner::from_learner(MeanLearner::new(), d_in, config)
    }
    // `train` here is a trait-provided helper defined outside this file —
    // presumably `train_one` with a default weight; TODO confirm on the trait.
    #[test]
    fn train_and_predict_basic() {
        let mut model = make_projected(6, 3, 5);
        for i in 0..20 {
            let x = vec![i as f64 * 0.1; 6];
            model.train(&x, i as f64);
        }
        assert_eq!(model.n_samples_seen(), 20);
        let pred = model.predict(&[0.5; 6]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }
    // Exercises the `normalizer.count() == 0` branch of `predict`.
    #[test]
    fn predict_before_training_returns_finite() {
        let model = make_projected(4, 2, 5);
        let pred = model.predict(&[1.0; 4]);
        assert!(
            pred.is_finite(),
            "predict before training should return finite value, got {}",
            pred
        );
    }
    // Crossing the warmup boundary must start feeding the tracker without
    // disturbing the sample counts.
    #[test]
    fn warmup_transition_is_seamless() {
        let warmup = 10;
        let mut model = make_projected(8, 4, warmup);
        for i in 0..warmup + 20 {
            let x = vec![(i as f64 * 0.1).sin(); 8];
            model.train(&x, i as f64);
        }
        assert!(model.warmup_complete(), "warmup should be complete");
        assert_eq!(model.n_samples_seen(), (warmup + 20) as u64);
        assert!(
            model.tracker().n_samples() > 0,
            "tracker should have samples after warmup"
        );
    }
    // reset() must cascade to the inner learner, tracker, and counters.
    #[test]
    fn reset_clears_everything() {
        let mut model = make_projected(6, 3, 5);
        for i in 0..30 {
            model.train(&[i as f64; 6], i as f64);
        }
        assert!(model.n_samples_seen() > 0);
        assert!(model.warmup_complete());
        model.reset();
        assert_eq!(
            model.n_samples_seen(),
            0,
            "n_samples should be 0 after reset"
        );
        assert!(
            !model.warmup_complete(),
            "warmup should not be complete after reset"
        );
        assert_eq!(model.tracker().n_samples(), 0, "tracker should be reset");
        assert_eq!(
            model.inner().n_samples_seen(),
            0,
            "inner model should be reset"
        );
    }
    // Accessors reflect construction-time configuration.
    #[test]
    fn accessors_work() {
        let model = make_projected(10, 4, 20);
        assert_eq!(model.tracker().d_in(), 10);
        assert_eq!(model.tracker().rank(), 4);
        assert_eq!(model.config().rank, 4);
        assert_eq!(model.config().warmup, 20);
        assert_eq!(model.inner().n_samples_seen(), 0);
    }
    // `ProjectedLearner::new` accepts an already-boxed trait object.
    #[test]
    fn from_boxed_constructor_works() {
        let inner: Box<dyn StreamingLearner> = Box::new(MeanLearner::new());
        let config = ProjectionConfig::builder().rank(3).build().unwrap();
        let mut model = ProjectedLearner::new(inner, 6, config);
        model.train(&[1.0; 6], 5.0);
        model.train(&[2.0; 6], 10.0);
        assert_eq!(model.n_samples_seen(), 2);
        let pred = model.predict(&[1.5; 6]);
        assert!(pred.is_finite());
    }
    // ProjectedLearner itself is usable as a `dyn StreamingLearner`.
    #[test]
    fn as_trait_object() {
        let model = make_projected(6, 3, 5);
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        boxed.train(&[1.0; 6], 7.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        let pred = boxed.predict(&[1.0; 6]);
        assert!(
            pred.is_finite(),
            "trait object predict should work: got {}",
            pred
        );
    }
    // The custom Debug impl must surface the struct name and key fields.
    #[test]
    fn debug_format_is_informative() {
        let model = make_projected(10, 4, 20);
        let debug = format!("{:?}", model);
        assert!(
            debug.contains("ProjectedLearner"),
            "debug should contain struct name"
        );
        assert!(debug.contains("rank"), "debug should contain rank field");
        assert!(debug.contains("d_in"), "debug should contain d_in field");
    }
    #[test]
    fn inner_model_sees_rank_dim_features() {
        let mut model = make_projected(20, 5, 10);
        for i in 0..50 {
            let x = vec![i as f64 * 0.02; 20];
            model.train(&x, i as f64);
        }
        assert_eq!(model.inner().n_samples_seen(), 50);
        // 40 = 50 samples - 10 warmup: the tracker is only updated on
        // post-warmup samples (see train_one's warmup gate).
        assert_eq!(model.tracker().n_samples(), 40); }
    // MeanLearner ignores features, so this test uses a real RLS inner
    // learner to confirm `predict` reads the current input vector.
    #[test]
    fn predict_reads_current_input() {
        use crate::learners::RecursiveLeastSquares;
        let config = ProjectionConfig::builder()
            .rank(2)
            .warmup(0)
            .seed(42)
            .build()
            .unwrap();
        let inner = RecursiveLeastSquares::new(0.999);
        let mut model = ProjectedLearner::from_learner(inner, 4, config);
        for i in 0..20 {
            let x = [i as f64 * 0.1, (i as f64).sin(), (i as f64).cos(), 1.0];
            model.train(&x, x[0] + x[1]);
        }
        let out_a = model.predict(&[0.0, 0.0, 1.0, 1.0]);
        let out_b = model.predict(&[10.0, 10.0, 1.0, 1.0]);
        assert!(
            out_a.is_finite() && out_b.is_finite(),
            "both predictions must be finite: out_a={out_a}, out_b={out_b}"
        );
        assert!(
            (out_a - out_b).abs() > 1e-9,
            "predict must read current x_t, got identical {out_a} for different inputs"
        );
    }
    // `predict` takes &self; verify it mutates neither counters nor tracker.
    #[test]
    fn predict_is_side_effect_free() {
        let mut model = make_projected(6, 3, 5);
        for i in 0..10 {
            model.train(&[i as f64; 6], i as f64);
        }
        let n_before = model.n_samples_seen();
        let tracker_n_before = model.tracker().n_samples();
        let _ = model.predict(&[0.5; 6]);
        let _ = model.predict(&[0.5; 6]);
        let _ = model.predict(&[0.5; 6]);
        assert_eq!(
            model.n_samples_seen(),
            n_before,
            "predict should not change n_samples"
        );
        assert_eq!(
            model.tracker().n_samples(),
            tracker_n_before,
            "predict should not change tracker samples"
        );
    }
}