Skip to main content

rlx_clinicalbert/
classifier.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Multinomial logistic regression on top of frozen ClinicalBERT pooler
17//! features — for sentence-pair classification benchmarks (MedNLI).
18//!
19//! Pure Rust, dense FP32, mini-batch SGD with momentum + L2 weight decay.
20//! Uses `rlx_cpu::blas::sgemm_bias` for the forward matmul so a 14k-example
21//! training run finishes in seconds.
22
23use anyhow::{Result, bail};
24
/// One training/evaluation example: a borrowed feature row of length
/// `hidden` plus its integer class label.
///
/// The type only borrows the features, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct LabeledFeature<'a> {
    /// Feature vector; must have exactly `hidden` entries.
    pub features: &'a [f32],
    /// Class index; training requires `label < num_classes`.
    pub label: usize,
}
30
/// Trained linear classifier: a weight matrix plus one bias per class.
///
/// Logits for a feature row `x` of length `hidden` are `x @ weight_t + bias`.
#[derive(Debug, Clone)]
pub struct LinearClassifier {
    /// Width of the input feature vectors.
    pub hidden: usize,
    /// Number of output classes.
    pub num_classes: usize,
    /// Weights stored transposed, as a row-major `[hidden, num_classes]`
    /// matrix (entry `(h, c)` lives at `h * num_classes + c`), so that
    /// `features @ weight_t + bias` yields the logits directly.
    weight_t: Vec<f32>,
    /// Per-class bias, length `num_classes`.
    bias: Vec<f32>,
}
40
41impl LinearClassifier {
42    /// Initialize weights with a small Gaussian scale (Xavier-ish).
43    pub fn new(hidden: usize, num_classes: usize) -> Self {
44        let scale = (2.0_f32 / hidden as f32).sqrt();
45        let mut weight_t = vec![0f32; hidden * num_classes];
46        // Deterministic pseudo-random — Park-Miller LCG so runs are
47        // reproducible without pulling in a `rand` dependency.
48        let mut state: u32 = 0x9e37_79b9;
49        for w in weight_t.iter_mut() {
50            state = state.wrapping_mul(48_271).wrapping_add(0x9e37_79b9);
51            let u = (state >> 16) as f32 / 65536.0;
52            // Box-Muller-lite: shift so range is roughly [-1, 1).
53            *w = (u * 2.0 - 1.0) * scale * 0.5;
54        }
55        Self {
56            hidden,
57            num_classes,
58            weight_t,
59            bias: vec![0f32; num_classes],
60        }
61    }
62
63    /// Predict the argmax class for one example.
64    pub fn predict(&self, features: &[f32]) -> Result<usize> {
65        if features.len() != self.hidden {
66            bail!(
67                "LinearClassifier::predict: expected {} features, got {}",
68                self.hidden,
69                features.len()
70            );
71        }
72        let mut logits = vec![0f32; self.num_classes];
73        // logits = features @ weight_t + bias
74        rlx_cpu::blas::sgemm_bias(
75            features,
76            &self.weight_t,
77            &self.bias,
78            &mut logits,
79            1,
80            self.hidden,
81            self.num_classes,
82        );
83        let mut best = 0usize;
84        let mut best_val = logits[0];
85        for (j, &v) in logits.iter().enumerate().skip(1) {
86            if v > best_val {
87                best_val = v;
88                best = j;
89            }
90        }
91        Ok(best)
92    }
93
94    /// Batched argmax accuracy on a frozen dataset.
95    pub fn accuracy(&self, examples: &[LabeledFeature<'_>]) -> Result<f32> {
96        if examples.is_empty() {
97            return Ok(0.0);
98        }
99        // Pack all features into a contiguous `[N, hidden]` matrix and do a
100        // single GEMM. Cheap on CPU; saves N kernel-launch overhead.
101        let n = examples.len();
102        let mut feats = vec![0f32; n * self.hidden];
103        let mut labels = vec![0usize; n];
104        for (i, ex) in examples.iter().enumerate() {
105            if ex.features.len() != self.hidden {
106                bail!(
107                    "LinearClassifier::accuracy: row {i} has {} features, expected {}",
108                    ex.features.len(),
109                    self.hidden
110                );
111            }
112            feats[i * self.hidden..(i + 1) * self.hidden].copy_from_slice(ex.features);
113            labels[i] = ex.label;
114        }
115        let mut logits = vec![0f32; n * self.num_classes];
116        rlx_cpu::blas::sgemm_bias(
117            &feats,
118            &self.weight_t,
119            &self.bias,
120            &mut logits,
121            n,
122            self.hidden,
123            self.num_classes,
124        );
125        let mut correct = 0usize;
126        for i in 0..n {
127            let row = &logits[i * self.num_classes..(i + 1) * self.num_classes];
128            let pred = row
129                .iter()
130                .enumerate()
131                .fold(
132                    (0usize, row[0]),
133                    |(bi, bv), (j, &v)| {
134                        if v > bv { (j, v) } else { (bi, bv) }
135                    },
136                )
137                .0;
138            if pred == labels[i] {
139                correct += 1;
140            }
141        }
142        Ok(correct as f32 / n as f32)
143    }
144}
145
/// Training hyperparameters with sensible defaults for short, dense probes.
///
/// [`Default`] gives 20 epochs, batch size 32, `lr = 0.1`, `l2 = 1e-4`, and
/// `momentum = 0.9`.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainConfig {
    /// Number of full passes over the training set.
    pub epochs: usize,
    /// Mini-batch size; must be at least 1. The last batch of an epoch may be
    /// smaller when the training set size is not a multiple of it.
    pub batch: usize,
    /// SGD step size.
    pub lr: f32,
    /// L2 weight decay coefficient (applied to the weights, not the bias).
    pub l2: f32,
    /// Momentum (0.0 to disable).
    pub momentum: f32,
}

impl Default for TrainConfig {
    fn default() -> Self {
        Self {
            epochs: 20,
            batch: 32,
            lr: 0.1,
            l2: 1e-4,
            momentum: 0.9,
        }
    }
}
172
173/// Train a multinomial logistic regression on frozen features with mini-batch
174/// SGD. Each row of `train` is one `(features [H], label)` pair. Reports the
175/// final training accuracy and prints a per-epoch summary when `verbose`.
176pub fn train_logreg(
177    hidden: usize,
178    num_classes: usize,
179    train: &[LabeledFeature<'_>],
180    cfg: &TrainConfig,
181    verbose: bool,
182) -> Result<LinearClassifier> {
183    if train.is_empty() {
184        bail!("train_logreg: empty training set");
185    }
186    let mut clf = LinearClassifier::new(hidden, num_classes);
187
188    // Pre-pack training features into a contiguous `[N, hidden]` so each
189    // mini-batch is a simple slice + sgemm_bias.
190    let n = train.len();
191    let mut feats = vec![0f32; n * hidden];
192    let mut labels = vec![0u32; n];
193    for (i, ex) in train.iter().enumerate() {
194        if ex.features.len() != hidden {
195            bail!(
196                "train row {i} has {} features, expected {hidden}",
197                ex.features.len()
198            );
199        }
200        if ex.label >= num_classes {
201            bail!(
202                "train row {i} label {} ≥ num_classes {num_classes}",
203                ex.label
204            );
205        }
206        feats[i * hidden..(i + 1) * hidden].copy_from_slice(ex.features);
207        labels[i] = ex.label as u32;
208    }
209
210    // Mini-batch index permutation (Park-Miller LCG, reproducible).
211    let mut perm: Vec<usize> = (0..n).collect();
212    let mut rng_state: u32 = 0x1234_5678;
213    let lcg = |s: &mut u32| -> u32 {
214        *s = s.wrapping_mul(48_271).wrapping_add(0x9e37_79b9);
215        *s
216    };
217
218    // Momentum buffers, same shape as parameters.
219    let mut vel_w = vec![0f32; hidden * num_classes];
220    let mut vel_b = vec![0f32; num_classes];
221    // Scratch for logits and softmax.
222    let mut logits = vec![0f32; cfg.batch * num_classes];
223
224    for epoch in 0..cfg.epochs {
225        // Shuffle perm.
226        for i in (1..n).rev() {
227            let j = (lcg(&mut rng_state) as usize) % (i + 1);
228            perm.swap(i, j);
229        }
230
231        let mut epoch_loss = 0f32;
232        let mut epoch_correct = 0usize;
233        let mut seen = 0usize;
234
235        for chunk in perm.chunks(cfg.batch) {
236            let bs = chunk.len();
237            // Pack a mini-batch into contiguous buffers.
238            let mut xb = vec![0f32; bs * hidden];
239            let mut yb = vec![0u32; bs];
240            for (i, &idx) in chunk.iter().enumerate() {
241                xb[i * hidden..(i + 1) * hidden]
242                    .copy_from_slice(&feats[idx * hidden..(idx + 1) * hidden]);
243                yb[i] = labels[idx];
244            }
245
246            // Forward: logits = xb @ weight_t + bias  (bs × num_classes)
247            if logits.len() < bs * num_classes {
248                logits.resize(bs * num_classes, 0.0);
249            }
250            let logits_slice = &mut logits[..bs * num_classes];
251            rlx_cpu::blas::sgemm_bias(
252                &xb,
253                &clf.weight_t,
254                &clf.bias,
255                logits_slice,
256                bs,
257                hidden,
258                num_classes,
259            );
260
261            // Softmax + cross-entropy + count correct + build gradient
262            // (delta = softmax - one_hot).
263            let mut delta = vec![0f32; bs * num_classes];
264            for i in 0..bs {
265                let row = &mut logits_slice[i * num_classes..(i + 1) * num_classes];
266                let max_logit = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
267                let mut sum = 0f32;
268                for v in row.iter_mut() {
269                    *v = (*v - max_logit).exp();
270                    sum += *v;
271                }
272                let inv = 1.0 / sum;
273                let mut argmax = 0usize;
274                let mut argmax_val = -1f32;
275                for (j, v) in row.iter_mut().enumerate() {
276                    *v *= inv;
277                    if *v > argmax_val {
278                        argmax_val = *v;
279                        argmax = j;
280                    }
281                    delta[i * num_classes + j] = *v;
282                }
283                let y = yb[i] as usize;
284                delta[i * num_classes + y] -= 1.0;
285                epoch_loss += -row[y].max(1e-12).ln();
286                if argmax == y {
287                    epoch_correct += 1;
288                }
289            }
290
291            // Gradient: dW = xb^T @ delta  (hidden × num_classes),
292            // averaged over the batch and with L2 decay.
293            let inv_bs = 1.0 / bs as f32;
294            // grad_w is stored as `[hidden, num_classes]` to match weight_t.
295            let mut grad_w = vec![0f32; hidden * num_classes];
296            // Manual GEMM for the transpose-A pattern (no sgemm_at helper).
297            for h_idx in 0..hidden {
298                for c_idx in 0..num_classes {
299                    let mut s = 0f32;
300                    for i in 0..bs {
301                        s += xb[i * hidden + h_idx] * delta[i * num_classes + c_idx];
302                    }
303                    grad_w[h_idx * num_classes + c_idx] = s * inv_bs;
304                }
305            }
306            let mut grad_b = vec![0f32; num_classes];
307            for i in 0..bs {
308                for c_idx in 0..num_classes {
309                    grad_b[c_idx] += delta[i * num_classes + c_idx];
310                }
311            }
312            for v in grad_b.iter_mut() {
313                *v *= inv_bs;
314            }
315
316            // SGD with momentum: v ← μ·v + g + λ·w ; w ← w - lr·v
317            for j in 0..hidden * num_classes {
318                let g = grad_w[j] + cfg.l2 * clf.weight_t[j];
319                vel_w[j] = cfg.momentum * vel_w[j] + g;
320                clf.weight_t[j] -= cfg.lr * vel_w[j];
321            }
322            for j in 0..num_classes {
323                vel_b[j] = cfg.momentum * vel_b[j] + grad_b[j];
324                clf.bias[j] -= cfg.lr * vel_b[j];
325            }
326
327            seen += bs;
328        }
329
330        if verbose {
331            let acc = epoch_correct as f32 / seen as f32;
332            let loss = epoch_loss / seen as f32;
333            eprintln!(
334                "[clf] epoch {:>3}: train_loss={:.4} train_acc={:.4}",
335                epoch + 1,
336                loss,
337                acc
338            );
339        }
340    }
341
342    Ok(clf)
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow parallel feature/label vectors as training rows.
    fn as_rows<'a>(features: &'a [Vec<f32>], labels: &[usize]) -> Vec<LabeledFeature<'a>> {
        features
            .iter()
            .zip(labels)
            .map(|(f, &label)| LabeledFeature {
                features: f.as_slice(),
                label,
            })
            .collect()
    }

    /// Three parallel 2D line segments, one per class, that are linearly
    /// separable. The optimizer must fit them almost perfectly (> 98%) and
    /// classify each segment's midpoint (its centroid) correctly.
    #[test]
    fn logreg_separates_three_clusters() {
        let hidden = 2;
        let num_classes = 3;
        let centroids = [[0.0_f32, 0.0], [3.0, 0.0], [0.0, 3.0]];
        let mut features: Vec<Vec<f32>> = Vec::new();
        let mut labels: Vec<usize> = Vec::new();
        for (c, ctr) in centroids.iter().enumerate() {
            for k in 0..40 {
                let jitter = (k as f32 * 0.07) - 1.4;
                features.push(vec![ctr[0] + jitter, ctr[1] + jitter * 0.5]);
                labels.push(c);
            }
        }
        let train = as_rows(&features, &labels);
        let cfg = TrainConfig {
            epochs: 100,
            batch: 16,
            lr: 0.2,
            l2: 0.0,
            momentum: 0.9,
        };
        let clf = train_logreg(hidden, num_classes, &train, &cfg, false).unwrap();
        let acc = clf.accuracy(&train).unwrap();
        assert!(acc > 0.98, "got {acc}");
        for (c, ctr) in centroids.iter().enumerate() {
            assert_eq!(clf.predict(ctr).unwrap(), c, "centroid of class {c}");
        }
    }

    /// `train_logreg` rejects an empty set, a wrong feature width, and an
    /// out-of-range label instead of panicking.
    #[test]
    fn train_rejects_invalid_inputs() {
        let cfg = TrainConfig::default();
        assert!(train_logreg(2, 3, &[], &cfg, false).is_err());
        let short = [1.0_f32];
        let right_width = [1.0_f32, 2.0];
        let bad_width = [LabeledFeature { features: &short, label: 0 }];
        assert!(train_logreg(2, 3, &bad_width, &cfg, false).is_err());
        let bad_label = [LabeledFeature { features: &right_width, label: 3 }];
        assert!(train_logreg(2, 3, &bad_label, &cfg, false).is_err());
    }

    /// `predict` and `accuracy` validate the feature width; `accuracy` on an
    /// empty slice is defined as 0.
    #[test]
    fn predict_and_accuracy_check_feature_width() {
        let clf = LinearClassifier::new(2, 3);
        assert!(clf.predict(&[1.0]).is_err());
        let short = [1.0_f32];
        assert!(clf
            .accuracy(&[LabeledFeature { features: &short, label: 0 }])
            .is_err());
        assert_eq!(clf.accuracy(&[]).unwrap(), 0.0);
    }
}