1use anyhow::{Result, bail};
24
/// One labelled example: a borrowed feature vector plus its class index.
///
/// `features` is expected to have exactly `hidden` entries for the classifier
/// it is used with; `train_logreg` additionally requires `label < num_classes`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LabeledFeature<'a> {
    /// Feature vector for this example (borrowed, not copied).
    pub features: &'a [f32],
    /// Class index of this example.
    pub label: usize,
}
30
/// A linear multi-class classifier: logits are `x · W + b` and the predicted
/// class is the index of the largest logit.
///
/// Weights and biases are private so they can only be produced by
/// [`LinearClassifier::new`] and the training routine in this module.
#[derive(Debug, Clone)]
pub struct LinearClassifier {
    /// Required length of every input feature vector.
    pub hidden: usize,
    /// Number of output classes (length of the logit vector).
    pub num_classes: usize,
    /// Weight matrix stored row-major as `hidden × num_classes`
    /// (element `(h, c)` at index `h * num_classes + c`), i.e. transposed
    /// relative to a `classes × hidden` layout.
    weight_t: Vec<f32>,
    /// One bias per class.
    bias: Vec<f32>,
}
40
41impl LinearClassifier {
42 pub fn new(hidden: usize, num_classes: usize) -> Self {
44 let scale = (2.0_f32 / hidden as f32).sqrt();
45 let mut weight_t = vec![0f32; hidden * num_classes];
46 let mut state: u32 = 0x9e37_79b9;
49 for w in weight_t.iter_mut() {
50 state = state.wrapping_mul(48_271).wrapping_add(0x9e37_79b9);
51 let u = (state >> 16) as f32 / 65536.0;
52 *w = (u * 2.0 - 1.0) * scale * 0.5;
54 }
55 Self {
56 hidden,
57 num_classes,
58 weight_t,
59 bias: vec![0f32; num_classes],
60 }
61 }
62
63 pub fn predict(&self, features: &[f32]) -> Result<usize> {
65 if features.len() != self.hidden {
66 bail!(
67 "LinearClassifier::predict: expected {} features, got {}",
68 self.hidden,
69 features.len()
70 );
71 }
72 let mut logits = vec![0f32; self.num_classes];
73 rlx_cpu::blas::sgemm_bias(
75 features,
76 &self.weight_t,
77 &self.bias,
78 &mut logits,
79 1,
80 self.hidden,
81 self.num_classes,
82 );
83 let mut best = 0usize;
84 let mut best_val = logits[0];
85 for (j, &v) in logits.iter().enumerate().skip(1) {
86 if v > best_val {
87 best_val = v;
88 best = j;
89 }
90 }
91 Ok(best)
92 }
93
94 pub fn accuracy(&self, examples: &[LabeledFeature<'_>]) -> Result<f32> {
96 if examples.is_empty() {
97 return Ok(0.0);
98 }
99 let n = examples.len();
102 let mut feats = vec![0f32; n * self.hidden];
103 let mut labels = vec![0usize; n];
104 for (i, ex) in examples.iter().enumerate() {
105 if ex.features.len() != self.hidden {
106 bail!(
107 "LinearClassifier::accuracy: row {i} has {} features, expected {}",
108 ex.features.len(),
109 self.hidden
110 );
111 }
112 feats[i * self.hidden..(i + 1) * self.hidden].copy_from_slice(ex.features);
113 labels[i] = ex.label;
114 }
115 let mut logits = vec![0f32; n * self.num_classes];
116 rlx_cpu::blas::sgemm_bias(
117 &feats,
118 &self.weight_t,
119 &self.bias,
120 &mut logits,
121 n,
122 self.hidden,
123 self.num_classes,
124 );
125 let mut correct = 0usize;
126 for i in 0..n {
127 let row = &logits[i * self.num_classes..(i + 1) * self.num_classes];
128 let pred = row
129 .iter()
130 .enumerate()
131 .fold(
132 (0usize, row[0]),
133 |(bi, bv), (j, &v)| {
134 if v > bv { (j, v) } else { (bi, bv) }
135 },
136 )
137 .0;
138 if pred == labels[i] {
139 correct += 1;
140 }
141 }
142 Ok(correct as f32 / n as f32)
143 }
144}
145
/// Hyper-parameters for `train_logreg` (mini-batch SGD with momentum).
#[derive(Debug, Clone, PartialEq)]
pub struct TrainConfig {
    /// Number of passes over the (reshuffled) training set.
    pub epochs: usize,
    /// Mini-batch size; must be at least 1.
    pub batch: usize,
    /// Learning rate: each step subtracts `lr * velocity` from the parameters.
    pub lr: f32,
    /// L2 penalty coefficient; `l2 * w` is added to every weight gradient
    /// (biases are not regularised).
    pub l2: f32,
    /// Momentum coefficient: `velocity = momentum * velocity + gradient`.
    pub momentum: f32,
}
160
161impl Default for TrainConfig {
162 fn default() -> Self {
163 Self {
164 epochs: 20,
165 batch: 32,
166 lr: 0.1,
167 l2: 1e-4,
168 momentum: 0.9,
169 }
170 }
171}
172
173pub fn train_logreg(
177 hidden: usize,
178 num_classes: usize,
179 train: &[LabeledFeature<'_>],
180 cfg: &TrainConfig,
181 verbose: bool,
182) -> Result<LinearClassifier> {
183 if train.is_empty() {
184 bail!("train_logreg: empty training set");
185 }
186 let mut clf = LinearClassifier::new(hidden, num_classes);
187
188 let n = train.len();
191 let mut feats = vec![0f32; n * hidden];
192 let mut labels = vec![0u32; n];
193 for (i, ex) in train.iter().enumerate() {
194 if ex.features.len() != hidden {
195 bail!(
196 "train row {i} has {} features, expected {hidden}",
197 ex.features.len()
198 );
199 }
200 if ex.label >= num_classes {
201 bail!(
202 "train row {i} label {} ≥ num_classes {num_classes}",
203 ex.label
204 );
205 }
206 feats[i * hidden..(i + 1) * hidden].copy_from_slice(ex.features);
207 labels[i] = ex.label as u32;
208 }
209
210 let mut perm: Vec<usize> = (0..n).collect();
212 let mut rng_state: u32 = 0x1234_5678;
213 let lcg = |s: &mut u32| -> u32 {
214 *s = s.wrapping_mul(48_271).wrapping_add(0x9e37_79b9);
215 *s
216 };
217
218 let mut vel_w = vec![0f32; hidden * num_classes];
220 let mut vel_b = vec![0f32; num_classes];
221 let mut logits = vec![0f32; cfg.batch * num_classes];
223
224 for epoch in 0..cfg.epochs {
225 for i in (1..n).rev() {
227 let j = (lcg(&mut rng_state) as usize) % (i + 1);
228 perm.swap(i, j);
229 }
230
231 let mut epoch_loss = 0f32;
232 let mut epoch_correct = 0usize;
233 let mut seen = 0usize;
234
235 for chunk in perm.chunks(cfg.batch) {
236 let bs = chunk.len();
237 let mut xb = vec![0f32; bs * hidden];
239 let mut yb = vec![0u32; bs];
240 for (i, &idx) in chunk.iter().enumerate() {
241 xb[i * hidden..(i + 1) * hidden]
242 .copy_from_slice(&feats[idx * hidden..(idx + 1) * hidden]);
243 yb[i] = labels[idx];
244 }
245
246 if logits.len() < bs * num_classes {
248 logits.resize(bs * num_classes, 0.0);
249 }
250 let logits_slice = &mut logits[..bs * num_classes];
251 rlx_cpu::blas::sgemm_bias(
252 &xb,
253 &clf.weight_t,
254 &clf.bias,
255 logits_slice,
256 bs,
257 hidden,
258 num_classes,
259 );
260
261 let mut delta = vec![0f32; bs * num_classes];
264 for i in 0..bs {
265 let row = &mut logits_slice[i * num_classes..(i + 1) * num_classes];
266 let max_logit = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
267 let mut sum = 0f32;
268 for v in row.iter_mut() {
269 *v = (*v - max_logit).exp();
270 sum += *v;
271 }
272 let inv = 1.0 / sum;
273 let mut argmax = 0usize;
274 let mut argmax_val = -1f32;
275 for (j, v) in row.iter_mut().enumerate() {
276 *v *= inv;
277 if *v > argmax_val {
278 argmax_val = *v;
279 argmax = j;
280 }
281 delta[i * num_classes + j] = *v;
282 }
283 let y = yb[i] as usize;
284 delta[i * num_classes + y] -= 1.0;
285 epoch_loss += -row[y].max(1e-12).ln();
286 if argmax == y {
287 epoch_correct += 1;
288 }
289 }
290
291 let inv_bs = 1.0 / bs as f32;
294 let mut grad_w = vec![0f32; hidden * num_classes];
296 for h_idx in 0..hidden {
298 for c_idx in 0..num_classes {
299 let mut s = 0f32;
300 for i in 0..bs {
301 s += xb[i * hidden + h_idx] * delta[i * num_classes + c_idx];
302 }
303 grad_w[h_idx * num_classes + c_idx] = s * inv_bs;
304 }
305 }
306 let mut grad_b = vec![0f32; num_classes];
307 for i in 0..bs {
308 for c_idx in 0..num_classes {
309 grad_b[c_idx] += delta[i * num_classes + c_idx];
310 }
311 }
312 for v in grad_b.iter_mut() {
313 *v *= inv_bs;
314 }
315
316 for j in 0..hidden * num_classes {
318 let g = grad_w[j] + cfg.l2 * clf.weight_t[j];
319 vel_w[j] = cfg.momentum * vel_w[j] + g;
320 clf.weight_t[j] -= cfg.lr * vel_w[j];
321 }
322 for j in 0..num_classes {
323 vel_b[j] = cfg.momentum * vel_b[j] + grad_b[j];
324 clf.bias[j] -= cfg.lr * vel_b[j];
325 }
326
327 seen += bs;
328 }
329
330 if verbose {
331 let acc = epoch_correct as f32 / seen as f32;
332 let loss = epoch_loss / seen as f32;
333 eprintln!(
334 "[clf] epoch {:>3}: train_loss={:.4} train_acc={:.4}",
335 epoch + 1,
336 loss,
337 acc
338 );
339 }
340 }
341
342 Ok(clf)
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Three 2-D clusters centred at (0,0), (3,0) and (0,3), 40 points each,
    /// spread along a fixed diagonal jitter; labels are 0, 1 and 2.
    fn three_clusters() -> (Vec<Vec<f32>>, Vec<usize>) {
        let centroids = [[0.0_f32, 0.0], [3.0, 0.0], [0.0, 3.0]];
        let mut features = Vec::new();
        let mut labels = Vec::new();
        for (c, ctr) in centroids.iter().enumerate() {
            for k in 0..40 {
                let jitter = (k as f32 * 0.07) - 1.4;
                features.push(vec![ctr[0] + jitter, ctr[1] + jitter * 0.5]);
                labels.push(c);
            }
        }
        (features, labels)
    }

    /// Borrows parallel feature / label vectors as `LabeledFeature`s.
    fn examples<'a>(features: &'a [Vec<f32>], labels: &[usize]) -> Vec<LabeledFeature<'a>> {
        features
            .iter()
            .zip(labels)
            .map(|(f, &label)| LabeledFeature {
                features: f.as_slice(),
                label,
            })
            .collect()
    }

    /// Training configuration used by the cluster tests.
    fn cluster_cfg() -> TrainConfig {
        TrainConfig {
            epochs: 100,
            batch: 16,
            lr: 0.2,
            l2: 0.0,
            momentum: 0.9,
        }
    }

    #[test]
    fn logreg_separates_three_clusters() {
        let (features, labels) = three_clusters();
        let train = examples(&features, &labels);
        let clf = train_logreg(2, 3, &train, &cluster_cfg(), false).unwrap();
        let acc = clf.accuracy(&train).unwrap();
        assert!(acc > 0.98, "got {acc}");
    }

    #[test]
    fn training_is_deterministic() {
        let (features, labels) = three_clusters();
        let train = examples(&features, &labels);
        let cfg = TrainConfig {
            epochs: 5,
            ..cluster_cfg()
        };
        let a = train_logreg(2, 3, &train, &cfg, false).unwrap();
        let b = train_logreg(2, 3, &train, &cfg, false).unwrap();
        assert_eq!(a.weight_t, b.weight_t);
        assert_eq!(a.bias, b.bias);
    }

    #[test]
    fn predict_checks_feature_count() {
        let clf = LinearClassifier::new(2, 3);
        assert!(clf.predict(&[1.0]).is_err());
        assert!(clf.predict(&[1.0, 2.0, 3.0]).is_err());
        assert!(clf.predict(&[1.0, 2.0]).unwrap() < 3);
    }

    #[test]
    fn accuracy_of_empty_set_is_zero() {
        let clf = LinearClassifier::new(2, 3);
        assert_eq!(clf.accuracy(&[]).unwrap(), 0.0);
    }

    #[test]
    fn accuracy_checks_feature_count() {
        let clf = LinearClassifier::new(2, 3);
        let bad = [1.0_f32];
        let rows = [LabeledFeature {
            features: &bad,
            label: 0,
        }];
        assert!(clf.accuracy(&rows).is_err());
    }

    #[test]
    fn train_rejects_invalid_input() {
        let cfg = cluster_cfg();
        assert!(train_logreg(2, 3, &[], &cfg, false).is_err());

        let short = [1.0_f32];
        let wrong_len = [LabeledFeature {
            features: &short,
            label: 0,
        }];
        assert!(train_logreg(2, 3, &wrong_len, &cfg, false).is_err());

        let ok = [1.0_f32, 2.0];
        let bad_label = [LabeledFeature {
            features: &ok,
            label: 3,
        }];
        assert!(train_logreg(2, 3, &bad_label, &cfg, false).is_err());
    }
}