use irithyll_core::math::sigmoid;
use crate::learner::StreamingLearner;
use crate::learners::rls::RecursiveLeastSquares;
/// How the wrapped learner's raw output is interpreted.
///
/// Marked `#[non_exhaustive]` so new modes can be added without a breaking
/// change; downstream `match`es must carry a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum ClassificationMode {
    /// Pass-through: the inner model's raw prediction is used unchanged.
    Regression,
    /// Two-class mode: the raw prediction is treated as a logit.
    Binary,
    /// One-vs-rest multiclass mode with one head per class.
    Multiclass {
        /// Total number of classes (constructors require at least 2).
        n_classes: usize,
    },
}
/// Numerically stable softmax over `logits`.
///
/// Shifts every logit by the maximum before exponentiating, so `exp` never
/// overflows even for very large inputs; the shift cancels in the ratio and
/// leaves the probabilities unchanged.
fn stable_softmax(logits: &[f64]) -> Vec<f64> {
    let peak = logits
        .iter()
        .fold(f64::NEG_INFINITY, |acc, &z| acc.max(z));
    let mut weights = Vec::with_capacity(logits.len());
    let mut total = 0.0;
    for &z in logits {
        let w = (z - peak).exp();
        total += w;
        weights.push(w);
    }
    for w in &mut weights {
        *w /= total;
    }
    weights
}
/// Adapts a regression-style [`StreamingLearner`] to classification tasks.
///
/// In multiclass mode the wrapped `inner` model serves as the head for
/// class 0 and `extra_heads` holds one RLS head per remaining class.
pub struct ClassificationWrapper {
    // Primary model: the sole head in Regression/Binary mode, class-0 head
    // in Multiclass mode.
    inner: Box<dyn StreamingLearner>,
    // How raw predictions are mapped to class outputs.
    mode: ClassificationMode,
    // One-vs-rest heads for classes 1..n_classes; empty outside Multiclass.
    extra_heads: Vec<RecursiveLeastSquares>,
    // Total samples fed through `train_one`, across all modes.
    samples_seen: u64,
}
impl ClassificationWrapper {
    /// Wraps `model` as a plain regression pass-through (no link function).
    ///
    /// Added for parity with [`Self::binary`] / [`Self::multiclass`]; before
    /// this, callers had to construct the struct literally to get
    /// `Regression` mode.
    pub fn regression(model: Box<dyn StreamingLearner>) -> Self {
        Self {
            inner: model,
            mode: ClassificationMode::Regression,
            extra_heads: Vec::new(),
            samples_seen: 0,
        }
    }

    /// Wraps `model` as a binary classifier.
    ///
    /// The raw prediction is treated as a logit: `predict_proba` squashes it
    /// through a sigmoid, and `predict` thresholds it at zero.
    pub fn binary(model: Box<dyn StreamingLearner>) -> Self {
        Self {
            inner: model,
            mode: ClassificationMode::Binary,
            extra_heads: Vec::new(),
            samples_seen: 0,
        }
    }

    /// Wraps `model` as a one-vs-rest multiclass classifier.
    ///
    /// `model` becomes the head for class 0; `n_classes - 1` additional RLS
    /// heads (forgetting factor 0.99) are created for classes `1..n_classes`.
    ///
    /// # Panics
    /// Panics if `n_classes < 2`.
    pub fn multiclass(model: Box<dyn StreamingLearner>, n_classes: usize) -> Self {
        assert!(
            n_classes >= 2,
            "multiclass requires n_classes >= 2, got {n_classes}"
        );
        let extra_heads = (0..n_classes - 1)
            .map(|_| RecursiveLeastSquares::new(0.99))
            .collect();
        Self {
            inner: model,
            mode: ClassificationMode::Multiclass { n_classes },
            extra_heads,
            samples_seen: 0,
        }
    }

    /// Returns the mode this wrapper was constructed with.
    pub fn mode(&self) -> ClassificationMode {
        self.mode
    }

    /// Per-class probability vector for `features`.
    ///
    /// - `Regression`: a single-element vec holding the raw prediction
    ///   (NOT a probability; provided for interface uniformity).
    /// - `Binary`: `[P(class 0), P(class 1)]`, summing to 1.
    /// - `Multiclass`: softmax over the per-class logits, length `n_classes`.
    pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
        match self.mode {
            ClassificationMode::Regression => {
                vec![self.inner.predict(features)]
            }
            ClassificationMode::Binary => {
                let raw = self.inner.predict(features);
                let p1 = sigmoid(raw);
                vec![1.0 - p1, p1]
            }
            ClassificationMode::Multiclass { n_classes } => {
                // Logit 0 comes from the inner model, the rest from the
                // extra heads, in class order.
                let mut logits = Vec::with_capacity(n_classes);
                logits.push(self.inner.predict(features));
                for head in &self.extra_heads {
                    logits.push(head.predict(features));
                }
                stable_softmax(&logits)
            }
        }
    }
}
impl StreamingLearner for ClassificationWrapper {
    /// Feeds one weighted sample to the underlying model(s).
    ///
    /// Target encoding per mode:
    /// - `Regression`: `target` passes through unchanged.
    /// - `Binary`: `target > 0.5` maps to `+1.0`, otherwise `-1.0`.
    /// - `Multiclass`: `target` is the class index; every head is trained
    ///   one-vs-rest with `+1.0` for its own class and `-1.0` otherwise.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        self.samples_seen += 1;
        match self.mode {
            ClassificationMode::Regression => {
                self.inner.train_one(features, target, weight);
            }
            ClassificationMode::Binary => {
                // Bipolar encoding keeps the regression head centered at 0,
                // so sign(raw) is the decision boundary.
                let bipolar = if target > 0.5 { 1.0 } else { -1.0 };
                self.inner.train_one(features, bipolar, weight);
            }
            ClassificationMode::Multiclass { n_classes } => {
                // NOTE: `as` saturates, so a negative or NaN target
                // collapses to class 0 rather than wrapping.
                let class_idx = target as usize;
                // Fix: validate the class index BEFORE updating any head.
                // Previously this assert ran after training, so in debug
                // builds the heads were already polluted with the bad
                // sample before the panic fired.
                debug_assert!(
                    class_idx < n_classes,
                    "class index {} out of range for {} classes",
                    class_idx,
                    n_classes,
                );
                let target_0 = if class_idx == 0 { 1.0 } else { -1.0 };
                self.inner.train_one(features, target_0, weight);
                for (k_minus_1, head) in self.extra_heads.iter_mut().enumerate() {
                    let class_k = k_minus_1 + 1;
                    let target_k = if class_idx == class_k { 1.0 } else { -1.0 };
                    head.train_one(features, target_k, weight);
                }
            }
        }
    }

    /// Point prediction: raw value (`Regression`), 0.0/1.0 label (`Binary`),
    /// or the argmax class index as `f64` (`Multiclass`).
    fn predict(&self, features: &[f64]) -> f64 {
        match self.mode {
            ClassificationMode::Regression => self.inner.predict(features),
            ClassificationMode::Binary => {
                // Threshold at 0: sigmoid(raw) >= 0.5 exactly when raw >= 0.
                let raw = self.inner.predict(features);
                if raw >= 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ClassificationMode::Multiclass { n_classes } => {
                let mut logits = Vec::with_capacity(n_classes);
                logits.push(self.inner.predict(features));
                for head in &self.extra_heads {
                    logits.push(head.predict(features));
                }
                // Argmax over the probabilities; incomparable (NaN) pairs are
                // treated as Equal rather than panicking.
                let proba = stable_softmax(&logits);
                proba
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(core::cmp::Ordering::Equal))
                    .map(|(idx, _)| idx as f64)
                    .unwrap_or(0.0)
            }
        }
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Clears all learned state (inner model, every extra head) and the
    /// sample counter; the mode and head count are preserved.
    fn reset(&mut self) {
        self.inner.reset();
        for head in &mut self.extra_heads {
            head.reset();
        }
        self.samples_seen = 0;
    }
}
// Manual Debug impl: `inner` is a `Box<dyn StreamingLearner>` with no Debug
// bound, so derive is unavailable; summarize the heads by count instead.
impl core::fmt::Debug for ClassificationWrapper {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("ClassificationWrapper")
            .field("mode", &self.mode)
            .field("samples_seen", &self.samples_seen)
            .field("n_extra_heads", &self.extra_heads.len())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learners::rls::RecursiveLeastSquares;

    // NOTE(review): the tests call `clf.train(features, label)` — presumably a
    // provided method on `StreamingLearner` delegating to `train_one` with a
    // default weight; the trait definition is not visible here — confirm.

    // --- sigmoid helper ---------------------------------------------------

    #[test]
    fn sigmoid_at_zero_is_half() {
        let result = sigmoid(0.0);
        assert!(
            (result - 0.5).abs() < 1e-12,
            "sigmoid(0) should be 0.5, got {result}"
        );
    }

    #[test]
    fn sigmoid_extreme_values_are_finite() {
        // Large-magnitude logits must saturate to ~1 / ~0 without overflow.
        let p_high = sigmoid(1000.0);
        let p_low = sigmoid(-1000.0);
        assert!(p_high.is_finite(), "sigmoid(1000) should be finite");
        assert!(p_low.is_finite(), "sigmoid(-1000) should be finite");
        assert!(
            (p_high - 1.0).abs() < 1e-10,
            "sigmoid(1000) should be ~1.0, got {p_high}"
        );
        assert!(
            p_low.abs() < 1e-10,
            "sigmoid(-1000) should be ~0.0, got {p_low}"
        );
    }

    // --- stable_softmax ---------------------------------------------------

    #[test]
    fn softmax_uniform_logits_are_equal() {
        let logits = vec![1.0, 1.0, 1.0];
        let proba = stable_softmax(&logits);
        assert_eq!(proba.len(), 3, "softmax output should have 3 elements");
        for p in &proba {
            assert!(
                (p - 1.0 / 3.0).abs() < 1e-10,
                "uniform logits should give equal probabilities, got {p}"
            );
        }
    }

    #[test]
    fn softmax_sums_to_one() {
        let logits = vec![2.0, 1.0, 0.1, -1.0];
        let proba = stable_softmax(&logits);
        let sum: f64 = proba.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "softmax probabilities should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn softmax_extreme_logits_are_stable() {
        // The max-shift must prevent exp overflow for logits like 1000.
        let logits = vec![1000.0, 0.0, -1000.0];
        let proba = stable_softmax(&logits);
        assert!(
            proba.iter().all(|p| p.is_finite()),
            "softmax should be finite for extreme logits"
        );
        assert!(
            (proba[0] - 1.0).abs() < 1e-10,
            "dominant logit should have probability ~1.0, got {}",
            proba[0]
        );
    }

    // --- ClassificationMode -----------------------------------------------

    #[test]
    fn classification_mode_equality() {
        assert_eq!(ClassificationMode::Binary, ClassificationMode::Binary);
        assert_eq!(
            ClassificationMode::Multiclass { n_classes: 3 },
            ClassificationMode::Multiclass { n_classes: 3 }
        );
        assert_ne!(ClassificationMode::Binary, ClassificationMode::Regression);
    }

    // --- binary mode ------------------------------------------------------

    #[test]
    fn binary_wrapper_returns_zero_or_one() {
        let model = RecursiveLeastSquares::new(0.99);
        let mut clf = ClassificationWrapper::binary(Box::new(model));
        for i in 0..100 {
            let x = i as f64 * 0.1;
            let label = if x > 5.0 { 1.0 } else { 0.0 };
            clf.train(&[x], label);
        }
        let pred = clf.predict(&[8.0]);
        assert!(
            pred == 0.0 || pred == 1.0,
            "binary predict should return 0.0 or 1.0, got {pred}"
        );
    }

    #[test]
    fn binary_wrapper_predict_proba_returns_two_classes() {
        // Untrained model: probabilities must still be well-formed.
        let model = RecursiveLeastSquares::new(0.99);
        let clf = ClassificationWrapper::binary(Box::new(model));
        let proba = clf.predict_proba(&[1.0, 2.0]);
        assert_eq!(
            proba.len(),
            2,
            "binary predict_proba should return 2 probabilities"
        );
        let sum: f64 = proba.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "binary probabilities should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn binary_wrapper_learns_sine_classification() {
        // The (sin, cos) features make the sign of sin(x) linearly separable.
        let model = RecursiveLeastSquares::new(0.998);
        let mut clf = ClassificationWrapper::binary(Box::new(model));
        for i in 0..500 {
            let x = (i as f64) * 0.05;
            let label = if x.sin() > 0.0 { 1.0 } else { 0.0 };
            clf.train(&[x.sin(), x.cos()], label);
        }
        let mut correct = 0;
        let test_points = 50;
        for i in 0..test_points {
            let x = (i as f64) * 0.1 + 0.05;
            let expected = if x.sin() > 0.0 { 1.0 } else { 0.0 };
            let pred = clf.predict(&[x.sin(), x.cos()]);
            if (pred - expected).abs() < 1e-10 {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / test_points as f64;
        assert!(
            accuracy > 0.7,
            "binary sine classification accuracy should be > 70%, got {:.1}%",
            accuracy * 100.0
        );
    }

    // --- multiclass mode --------------------------------------------------

    #[test]
    #[should_panic(expected = "n_classes >= 2")]
    fn multiclass_panics_on_fewer_than_two_classes() {
        let model = RecursiveLeastSquares::new(0.99);
        let _ = ClassificationWrapper::multiclass(Box::new(model), 1);
    }

    #[test]
    fn multiclass_wrapper_returns_valid_class_index() {
        let model = RecursiveLeastSquares::new(0.99);
        let mut clf = ClassificationWrapper::multiclass(Box::new(model), 3);
        // One-hot features, class label cycling 0, 1, 2.
        for i in 0..60 {
            let class = (i % 3) as f64;
            let x0 = if i % 3 == 0 { 1.0 } else { 0.0 };
            let x1 = if i % 3 == 1 { 1.0 } else { 0.0 };
            let x2 = if i % 3 == 2 { 1.0 } else { 0.0 };
            clf.train(&[x0, x1, x2], class);
        }
        let pred = clf.predict(&[1.0, 0.0, 0.0]);
        assert!(
            (0.0..3.0).contains(&pred),
            "multiclass predict should return class index in [0, 3), got {pred}"
        );
        assert!(
            (pred - pred.round()).abs() < 1e-10,
            "multiclass predict should return an integer class index, got {pred}"
        );
    }

    #[test]
    fn multiclass_predict_proba_returns_k_probabilities() {
        let model = RecursiveLeastSquares::new(0.99);
        let clf = ClassificationWrapper::multiclass(Box::new(model), 4);
        let proba = clf.predict_proba(&[1.0, 2.0, 3.0]);
        assert_eq!(
            proba.len(),
            4,
            "multiclass(4) predict_proba should return 4 probabilities"
        );
        let sum: f64 = proba.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "multiclass probabilities should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn multiclass_learns_three_class_data() {
        // Perfectly separable one-hot data: each class should be recovered.
        let model = RecursiveLeastSquares::new(0.998);
        let mut clf = ClassificationWrapper::multiclass(Box::new(model), 3);
        for _ in 0..200 {
            clf.train(&[1.0, 0.0, 0.0], 0.0);
            clf.train(&[0.0, 1.0, 0.0], 1.0);
            clf.train(&[0.0, 0.0, 1.0], 2.0);
        }
        let pred_0 = clf.predict(&[1.0, 0.0, 0.0]);
        let pred_1 = clf.predict(&[0.0, 1.0, 0.0]);
        let pred_2 = clf.predict(&[0.0, 0.0, 1.0]);
        assert!(
            (pred_0 - 0.0).abs() < 1e-10,
            "pure class 0 features should predict class 0, got {pred_0}"
        );
        assert!(
            (pred_1 - 1.0).abs() < 1e-10,
            "pure class 1 features should predict class 1, got {pred_1}"
        );
        assert!(
            (pred_2 - 2.0).abs() < 1e-10,
            "pure class 2 features should predict class 2, got {pred_2}"
        );
    }

    // --- bookkeeping: samples_seen / reset / mode / Debug -----------------

    #[test]
    fn wrapper_tracks_samples_seen() {
        let model = RecursiveLeastSquares::new(0.99);
        let mut clf = ClassificationWrapper::binary(Box::new(model));
        assert_eq!(
            clf.n_samples_seen(),
            0,
            "fresh wrapper should have 0 samples"
        );
        clf.train(&[1.0], 1.0);
        clf.train(&[2.0], 0.0);
        assert_eq!(clf.n_samples_seen(), 2, "wrapper should track samples seen");
    }

    #[test]
    fn wrapper_reset_clears_state() {
        let model = RecursiveLeastSquares::new(0.99);
        let mut clf = ClassificationWrapper::binary(Box::new(model));
        clf.train(&[1.0], 1.0);
        clf.train(&[2.0], 0.0);
        clf.reset();
        assert_eq!(
            clf.n_samples_seen(),
            0,
            "samples_seen should be 0 after reset"
        );
    }

    #[test]
    fn multiclass_reset_clears_all_heads() {
        let model = RecursiveLeastSquares::new(0.99);
        let mut clf = ClassificationWrapper::multiclass(Box::new(model), 3);
        for i in 0..30 {
            clf.train(&[1.0, 0.0], (i % 3) as f64);
        }
        assert_eq!(clf.n_samples_seen(), 30);
        clf.reset();
        assert_eq!(clf.n_samples_seen(), 0, "reset should clear all state");
        // reset must preserve the mode / head count, only clearing state.
        let proba = clf.predict_proba(&[1.0, 0.0]);
        assert_eq!(
            proba.len(),
            3,
            "predict_proba should still return 3 classes after reset"
        );
    }

    #[test]
    fn wrapper_mode_accessor() {
        let model = RecursiveLeastSquares::new(0.99);
        let clf = ClassificationWrapper::binary(Box::new(model));
        assert_eq!(clf.mode(), ClassificationMode::Binary);
        let model2 = RecursiveLeastSquares::new(0.99);
        let clf2 = ClassificationWrapper::multiclass(Box::new(model2), 5);
        assert_eq!(clf2.mode(), ClassificationMode::Multiclass { n_classes: 5 });
    }

    #[test]
    fn wrapper_debug_format() {
        let model = RecursiveLeastSquares::new(0.99);
        let clf = ClassificationWrapper::binary(Box::new(model));
        let debug = format!("{:?}", clf);
        assert!(
            debug.contains("ClassificationWrapper"),
            "debug output should contain struct name, got: {debug}"
        );
        assert!(
            debug.contains("Binary"),
            "debug output should contain mode, got: {debug}"
        );
    }

    #[test]
    fn regression_mode_is_passthrough() {
        // Struct literal works here because tests live in the same module.
        let model = RecursiveLeastSquares::new(0.99);
        let mut clf = ClassificationWrapper {
            inner: Box::new(model),
            mode: ClassificationMode::Regression,
            extra_heads: Vec::new(),
            samples_seen: 0,
        };
        for i in 0..100 {
            let x = i as f64 * 0.1;
            clf.train(&[x], 2.0 * x);
        }
        let pred = clf.predict(&[5.0]);
        assert!(
            (pred - 10.0).abs() < 0.5,
            "regression passthrough should approximate y=2x, got {pred}"
        );
    }
}