use anyhow::{Result, bail};
/// A single labelled example: a borrowed feature vector and its class index.
///
/// The struct only holds a slice reference and an integer, so it is cheap to
/// copy and is marked `Copy` accordingly.
#[derive(Debug, Clone, Copy)]
pub struct LabeledFeature<'a> {
    /// Feature values for this example; consumers in this file require the
    /// length to equal the classifier's `hidden` dimension.
    pub features: &'a [f32],
    /// Zero-based class index (training requires `label < num_classes`).
    pub label: usize,
}
/// Parameters of a single linear layer mapping `hidden` features to
/// `num_classes` logits.
#[derive(Debug, Clone)]
pub struct LinearClassifier {
    /// Number of input features expected per example.
    pub hidden: usize,
    /// Number of output classes (logits) produced per example.
    pub num_classes: usize,
    /// Weight matrix stored flat with `hidden * num_classes` entries; the
    /// training code in this file indexes it as `h * num_classes + c`.
    weight_t: Vec<f32>,
    /// Per-class bias, `num_classes` entries.
    bias: Vec<f32>,
}
impl LinearClassifier {
    /// Creates a classifier with deterministic pseudo-random weights and zero
    /// biases.
    ///
    /// Weights come from a fixed-seed linear congruential generator, so two
    /// calls with the same dimensions produce identical parameters. Each
    /// weight lies in `[-0.5 * scale, 0.5 * scale)` with
    /// `scale = sqrt(2 / hidden)`.
    pub fn new(hidden: usize, num_classes: usize) -> Self {
        let scale = (2.0_f32 / hidden as f32).sqrt();
        let mut weight_t = vec![0f32; hidden * num_classes];
        // Fixed seed: initialisation is reproducible from run to run.
        let mut state: u32 = 0x9e37_79b9;
        for w in weight_t.iter_mut() {
            state = state.wrapping_mul(48_271).wrapping_add(0x9e37_79b9);
            // Use the upper 16 bits of the state, mapped to [0, 1).
            let u = (state >> 16) as f32 / 65536.0;
            *w = (u * 2.0 - 1.0) * scale * 0.5;
        }
        Self {
            hidden,
            num_classes,
            weight_t,
            bias: vec![0f32; num_classes],
        }
    }

    /// Index of the largest value in `row`; the lowest index wins ties.
    ///
    /// Comparison is a strict `>`, so a NaN never displaces the current best
    /// (and a NaN in slot 0 is never displaced). Returns 0 for an empty row.
    fn argmax_first(row: &[f32]) -> usize {
        match row.split_first() {
            None => 0,
            Some((&first, rest)) => {
                rest.iter()
                    .enumerate()
                    .fold((0usize, first), |(best, best_val), (j, &v)| {
                        if v > best_val {
                            (j + 1, v)
                        } else {
                            (best, best_val)
                        }
                    })
                    .0
            }
        }
    }

    /// Predicts the class index for a single feature vector.
    ///
    /// Logits are produced by `rlx_cpu::blas::sgemm_bias` called with
    /// dimensions `(1, hidden, num_classes)`; the index of the largest logit
    /// is returned, lowest index on ties.
    ///
    /// # Errors
    ///
    /// Fails if `features.len() != self.hidden`, or if the classifier has no
    /// classes (previously that case panicked on an empty logit buffer).
    pub fn predict(&self, features: &[f32]) -> Result<usize> {
        if features.len() != self.hidden {
            bail!(
                "LinearClassifier::predict: expected {} features, got {}",
                self.hidden,
                features.len()
            );
        }
        if self.num_classes == 0 {
            bail!("LinearClassifier::predict: classifier has no classes");
        }
        let mut logits = vec![0f32; self.num_classes];
        rlx_cpu::blas::sgemm_bias(
            features,
            &self.weight_t,
            &self.bias,
            &mut logits,
            1,
            self.hidden,
            self.num_classes,
        );
        Ok(Self::argmax_first(&logits))
    }

    /// Fraction of `examples` whose predicted class equals their label.
    ///
    /// All examples are packed into one contiguous buffer and scored with a
    /// single `sgemm_bias` call. An empty slice yields `0.0`. A label that is
    /// out of range simply never matches and counts as incorrect.
    ///
    /// # Errors
    ///
    /// Fails if any example's feature length differs from `self.hidden`, or
    /// if the classifier has no classes (previously a panic).
    pub fn accuracy(&self, examples: &[LabeledFeature<'_>]) -> Result<f32> {
        if examples.is_empty() {
            return Ok(0.0);
        }
        let n = examples.len();
        let mut feats = Vec::with_capacity(n * self.hidden);
        for (i, ex) in examples.iter().enumerate() {
            if ex.features.len() != self.hidden {
                bail!(
                    "LinearClassifier::accuracy: row {i} has {} features, expected {}",
                    ex.features.len(),
                    self.hidden
                );
            }
            feats.extend_from_slice(ex.features);
        }
        if self.num_classes == 0 {
            bail!("LinearClassifier::accuracy: classifier has no classes");
        }
        let mut logits = vec![0f32; n * self.num_classes];
        rlx_cpu::blas::sgemm_bias(
            &feats,
            &self.weight_t,
            &self.bias,
            &mut logits,
            n,
            self.hidden,
            self.num_classes,
        );
        // `num_classes > 0` here, so `chunks_exact` cannot panic and yields
        // exactly one logit row per example.
        let correct = logits
            .chunks_exact(self.num_classes)
            .zip(examples)
            .filter(|(row, ex)| Self::argmax_first(row) == ex.label)
            .count();
        Ok(correct as f32 / n as f32)
    }
}
/// Hyper-parameters consumed by `train_logreg`.
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Number of full passes over the training set.
    pub epochs: usize,
    /// Maximum number of examples per mini-batch.
    pub batch: usize,
    /// Learning rate applied to the momentum velocity at each step.
    pub lr: f32,
    /// L2 coefficient; `l2 * weight` is added to the weight gradient
    /// (the bias is not regularised).
    pub l2: f32,
    /// Factor by which the previous velocity is carried into the next step.
    pub momentum: f32,
}

impl Default for TrainConfig {
    /// Defaults: 20 epochs, batch 32, lr 0.1, l2 1e-4, momentum 0.9.
    fn default() -> Self {
        let (epochs, batch) = (20, 32);
        let (lr, l2, momentum) = (0.1, 1e-4, 0.9);
        TrainConfig {
            epochs,
            batch,
            lr,
            l2,
            momentum,
        }
    }
}
/// Trains a softmax (multinomial logistic-regression) classifier with
/// mini-batch SGD and momentum.
///
/// Each epoch shuffles the example order with a fixed-seed LCG (so training
/// is deterministic), then walks the data in chunks of at most `cfg.batch`
/// examples. For every chunk the logits are computed with
/// `rlx_cpu::blas::sgemm_bias`, turned into softmax probabilities, and the
/// cross-entropy gradient (`p - one_hot(y)`, averaged over the chunk) is
/// applied through a momentum velocity. `cfg.l2 * w` is added to the weight
/// gradient only; biases are not regularised.
///
/// When `verbose` is set, one line per epoch with the mean training loss and
/// accuracy is written to stderr.
///
/// # Errors
///
/// Fails if `train` is empty, if `cfg.batch` is zero, if any row's feature
/// length differs from `hidden`, or if any label is `>= num_classes`.
pub fn train_logreg(
    hidden: usize,
    num_classes: usize,
    train: &[LabeledFeature<'_>],
    cfg: &TrainConfig,
    verbose: bool,
) -> Result<LinearClassifier> {
    if train.is_empty() {
        bail!("train_logreg: empty training set");
    }
    // `slice::chunks` panics on a zero chunk size; reject it up front.
    if cfg.batch == 0 {
        bail!("train_logreg: batch size must be at least 1");
    }
    let mut clf = LinearClassifier::new(hidden, num_classes);
    let n = train.len();

    // Pack features contiguously and validate every row once.
    let mut feats = Vec::with_capacity(n * hidden);
    let mut labels = Vec::with_capacity(n);
    for (i, ex) in train.iter().enumerate() {
        if ex.features.len() != hidden {
            bail!(
                "train row {i} has {} features, expected {hidden}",
                ex.features.len()
            );
        }
        if ex.label >= num_classes {
            bail!(
                "train row {i} label {} ≥ num_classes {num_classes}",
                ex.label
            );
        }
        feats.extend_from_slice(ex.features);
        labels.push(ex.label);
    }
    // From here on `num_classes >= 1`: a non-empty training set whose labels
    // all passed the range check cannot exist otherwise.

    let mut perm: Vec<usize> = (0..n).collect();
    let mut rng_state: u32 = 0x1234_5678;
    let lcg = |s: &mut u32| -> u32 {
        *s = s.wrapping_mul(48_271).wrapping_add(0x9e37_79b9);
        *s
    };

    let mut vel_w = vec![0f32; hidden * num_classes];
    let mut vel_b = vec![0f32; num_classes];

    // Scratch buffers reused by every batch. No chunk is ever longer than
    // `min(cfg.batch, n)`, so a huge `cfg.batch` does not inflate them.
    let max_bs = cfg.batch.min(n);
    let mut xb = vec![0f32; max_bs * hidden];
    let mut logits = vec![0f32; max_bs * num_classes];
    let mut delta = vec![0f32; max_bs * num_classes];
    let mut grad_w = vec![0f32; hidden * num_classes];
    let mut grad_b = vec![0f32; num_classes];

    for epoch in 0..cfg.epochs {
        // Fisher-Yates shuffle driven by the LCG.
        for i in (1..n).rev() {
            let j = (lcg(&mut rng_state) as usize) % (i + 1);
            perm.swap(i, j);
        }

        let mut epoch_loss = 0f32;
        let mut epoch_correct = 0usize;
        let mut seen = 0usize;

        for chunk in perm.chunks(cfg.batch) {
            let bs = chunk.len();

            // Gather this batch's rows into the contiguous input buffer.
            for (i, &idx) in chunk.iter().enumerate() {
                xb[i * hidden..(i + 1) * hidden]
                    .copy_from_slice(&feats[idx * hidden..(idx + 1) * hidden]);
            }
            let x_batch = &xb[..bs * hidden];

            let logits_slice = &mut logits[..bs * num_classes];
            rlx_cpu::blas::sgemm_bias(
                x_batch,
                &clf.weight_t,
                &clf.bias,
                logits_slice,
                bs,
                hidden,
                num_classes,
            );

            let delta_batch = &mut delta[..bs * num_classes];
            let rows = logits_slice
                .chunks_exact_mut(num_classes)
                .zip(delta_batch.chunks_exact_mut(num_classes))
                .zip(chunk);
            for ((row, d_row), &idx) in rows {
                let y = labels[idx];

                // Numerically stable softmax, computed in place.
                let max_logit = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0f32;
                for v in row.iter_mut() {
                    *v = (*v - max_logit).exp();
                    sum += *v;
                }
                let inv = 1.0 / sum;
                let mut argmax = 0usize;
                let mut argmax_val = -1f32;
                for (j, v) in row.iter_mut().enumerate() {
                    *v *= inv;
                    if *v > argmax_val {
                        argmax_val = *v;
                        argmax = j;
                    }
                }

                // Cross-entropy gradient w.r.t. the logits: p - one_hot(y).
                d_row.copy_from_slice(row);
                d_row[y] -= 1.0;

                // Clamp the probability so ln() stays finite.
                epoch_loss += -row[y].max(1e-12).ln();
                if argmax == y {
                    epoch_correct += 1;
                }
            }

            let inv_bs = 1.0 / bs as f32;

            // grad_w = X^T * delta / bs, stored as (hidden x num_classes).
            for h_idx in 0..hidden {
                for c_idx in 0..num_classes {
                    let mut s = 0f32;
                    for i in 0..bs {
                        s += x_batch[i * hidden + h_idx] * delta_batch[i * num_classes + c_idx];
                    }
                    grad_w[h_idx * num_classes + c_idx] = s * inv_bs;
                }
            }

            // grad_b = column sums of delta / bs.
            for g in grad_b.iter_mut() {
                *g = 0.0;
            }
            for d_row in delta_batch.chunks_exact(num_classes) {
                for (g, &d) in grad_b.iter_mut().zip(d_row) {
                    *g += d;
                }
            }
            for g in grad_b.iter_mut() {
                *g *= inv_bs;
            }

            // Momentum update; L2 applies to weights only.
            for ((w, vel), &gw) in clf
                .weight_t
                .iter_mut()
                .zip(vel_w.iter_mut())
                .zip(grad_w.iter())
            {
                let g = gw + cfg.l2 * *w;
                *vel = cfg.momentum * *vel + g;
                *w -= cfg.lr * *vel;
            }
            for ((b, vel), &gb) in clf
                .bias
                .iter_mut()
                .zip(vel_b.iter_mut())
                .zip(grad_b.iter())
            {
                *vel = cfg.momentum * *vel + gb;
                *b -= cfg.lr * *vel;
            }

            seen += bs;
        }

        if verbose {
            let acc = epoch_correct as f32 / seen as f32;
            let loss = epoch_loss / seen as f32;
            eprintln!(
                "[clf] epoch {:>3}: train_loss={:.4} train_acc={:.4}",
                epoch + 1,
                loss,
                acc
            );
        }
    }
    Ok(clf)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds three 2-D clusters of 40 points each around fixed centroids,
    /// trains on them, and checks that training accuracy exceeds 98%.
    #[test]
    fn logreg_separates_three_clusters() {
        let centroids = [[0.0_f32, 0.0], [3.0, 0.0], [0.0, 3.0]];

        // Points are generated class by class, spread along a line through
        // each centroid by a deterministic jitter in [-1.4, 1.33].
        let samples: Vec<(Vec<f32>, usize)> = centroids
            .iter()
            .enumerate()
            .flat_map(|(class, center)| {
                (0..40).map(move |k| {
                    let jitter = (k as f32 * 0.07) - 1.4;
                    (vec![center[0] + jitter, center[1] + jitter * 0.5], class)
                })
            })
            .collect();

        let train: Vec<LabeledFeature> = samples
            .iter()
            .map(|(point, class)| LabeledFeature {
                features: point.as_slice(),
                label: *class,
            })
            .collect();

        let cfg = TrainConfig {
            epochs: 100,
            batch: 16,
            lr: 0.2,
            l2: 0.0,
            momentum: 0.9,
        };

        let hidden = 2;
        let num_classes = 3;
        let clf = train_logreg(hidden, num_classes, &train, &cfg, false).unwrap();
        let acc = clf.accuracy(&train).unwrap();
        assert!(acc > 0.98, "got {acc}");
    }
}