Skip to main content

tsetlin_rs/
multiclass.rs

1//! Multi-class classification Tsetlin Machine.
2
3#[cfg(not(feature = "std"))]
4use alloc::vec::Vec;
5use core::cmp::Ordering;
6
7use rand::Rng;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    Clause, Config,
13    feedback::{type_i, type_ii},
14    utils::rng_from_seed
15};
16
17/// Multi-class Tsetlin Machine using one-vs-all strategy.
18///
19/// Each class has its own set of clauses. Prediction is the class
20/// with the highest vote sum.
21///
22/// # Example
23///
24/// ```
25/// use tsetlin_rs::{Config, MultiClass};
26///
27/// let config = Config::builder().clauses(100).features(4).build().unwrap();
28/// let mut tm = MultiClass::new(config, 3, 50);
29///
30/// // Train on data where label is class index (0, 1, or 2)
31/// let x = vec![vec![1, 1, 0, 0], vec![0, 0, 1, 1]];
32/// let y = vec![0, 1];
33/// tm.fit(&x, &y, 100, 42);
34/// ```
35#[derive(Debug, Clone)]
36#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
37pub struct MultiClass {
38    clauses:   Vec<Vec<Clause>>,
39    config:    Config,
40    threshold: i32
41}
42
43impl MultiClass {
44    /// Creates multi-class machine with given number of classes.
45    ///
46    /// # Arguments
47    ///
48    /// * `config` - Machine configuration (clauses, features, etc.)
49    /// * `n_classes` - Number of output classes
50    /// * `threshold` - Vote threshold for training
51    #[must_use]
52    pub fn new(config: Config, n_classes: usize, threshold: i32) -> Self {
53        let clauses = (0..n_classes)
54            .map(|_| {
55                (0..config.n_clauses)
56                    .map(|i| {
57                        let p = if i % 2 == 0 { 1 } else { -1 };
58                        Clause::new(config.n_features, config.n_states, p)
59                    })
60                    .collect()
61            })
62            .collect();
63
64        Self {
65            clauses,
66            config,
67            threshold
68        }
69    }
70
71    /// Returns the number of classes.
72    #[inline]
73    #[must_use]
74    pub fn n_classes(&self) -> usize {
75        self.clauses.len()
76    }
77
78    /// Computes vote sums for each class.
79    #[must_use]
80    pub fn class_votes(&self, x: &[u8]) -> Vec<f32> {
81        self.clauses
82            .iter()
83            .map(|cls| cls.iter().map(|c| c.vote(x)).sum())
84            .collect()
85    }
86
87    /// Predicts class with highest vote sum.
88    ///
89    /// Returns 0 if votes are empty or contain NaN.
90    #[must_use]
91    pub fn predict(&self, x: &[u8]) -> usize {
92        self.class_votes(x)
93            .iter()
94            .enumerate()
95            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
96            .map_or(0, |(i, _)| i)
97    }
98
99    /// Trains on a single example.
100    pub fn train_one<R: Rng>(&mut self, x: &[u8], y: usize, rng: &mut R) {
101        let votes = self.class_votes(x);
102        let t = self.threshold as f32;
103
104        for (class_idx, class_clauses) in self.clauses.iter_mut().enumerate() {
105            let is_target = class_idx == y;
106            let sum = votes[class_idx].clamp(-t, t);
107
108            for clause in class_clauses {
109                let fires = clause.evaluate(x);
110                let p = clause.polarity();
111
112                if is_target {
113                    let prob = (t - sum) / (2.0 * t);
114                    if p == 1 && rng.random::<f32>() <= prob {
115                        type_i(clause, x, fires, self.config.s, rng);
116                    } else if p == -1 && fires && rng.random::<f32>() <= prob {
117                        type_ii(clause, x);
118                    }
119                } else {
120                    let prob = (t + sum) / (2.0 * t);
121                    if p == -1 && rng.random::<f32>() <= prob {
122                        type_i(clause, x, fires, self.config.s, rng);
123                    } else if p == 1 && fires && rng.random::<f32>() <= prob {
124                        type_ii(clause, x);
125                    }
126                }
127            }
128        }
129    }
130
131    /// Trains for given number of epochs.
132    ///
133    /// # Arguments
134    ///
135    /// * `x` - Training inputs (binary features)
136    /// * `y` - Class labels (0 to n_classes-1)
137    /// * `epochs` - Number of training epochs
138    /// * `seed` - Random seed for reproducibility
139    pub fn fit(&mut self, x: &[Vec<u8>], y: &[usize], epochs: usize, seed: u64) {
140        let mut rng = rng_from_seed(seed);
141        let mut indices: Vec<usize> = (0..x.len()).collect();
142
143        for _ in 0..epochs {
144            crate::utils::shuffle(&mut indices, &mut rng);
145            for &i in &indices {
146                self.train_one(&x[i], y[i], &mut rng);
147            }
148        }
149    }
150
151    /// Evaluates classification accuracy on test data.
152    ///
153    /// Returns fraction of correct predictions (0.0 to 1.0).
154    #[must_use]
155    pub fn evaluate(&self, x: &[Vec<u8>], y: &[usize]) -> f32 {
156        if x.is_empty() {
157            return 0.0;
158        }
159        let correct = x
160            .iter()
161            .zip(y)
162            .filter(|(xi, yi)| self.predict(xi) == **yi)
163            .count();
164        correct as f32 / x.len() as f32
165    }
166
167    /// Quick constructor with sensible defaults.
168    ///
169    /// # Panics
170    ///
171    /// Panics if n_clauses is odd or zero, or n_features is zero.
172    #[must_use]
173    pub fn quick(n_clauses: usize, n_features: usize, n_classes: usize, threshold: i32) -> Self {
174        let config = Config::builder()
175            .clauses(n_clauses)
176            .features(n_features)
177            .build()
178            .expect("invalid quick config");
179        Self::new(config, n_classes, threshold)
180    }
181
182    /// Trains on a single sample (online/incremental learning).
183    ///
184    /// # Arguments
185    ///
186    /// * `x` - Single input sample (binary features)
187    /// * `y` - Class label (0 to n_classes-1)
188    /// * `seed` - Random seed for this update
189    #[inline]
190    pub fn partial_fit(&mut self, x: &[u8], y: usize, seed: u64) {
191        let mut rng = rng_from_seed(seed);
192        self.train_one(x, y, &mut rng);
193    }
194
195    /// Trains on a mini-batch of samples (online/incremental learning).
196    ///
197    /// # Arguments
198    ///
199    /// * `xs` - Batch of input samples
200    /// * `ys` - Batch of class labels
201    /// * `seed` - Random seed for this batch
202    pub fn partial_fit_batch(&mut self, xs: &[Vec<u8>], ys: &[usize], seed: u64) {
203        if xs.is_empty() || xs.len() != ys.len() {
204            return;
205        }
206
207        let mut rng = rng_from_seed(seed);
208        for (x, &y) in xs.iter().zip(ys) {
209            self.train_one(x, y, &mut rng);
210        }
211    }
212}
213
214impl crate::model::TsetlinModel<Vec<u8>, usize> for MultiClass {
215    fn fit(&mut self, x: &[Vec<u8>], y: &[usize], epochs: usize, seed: u64) {
216        MultiClass::fit(self, x, y, epochs, seed);
217    }
218
219    fn predict(&self, x: &Vec<u8>) -> usize {
220        MultiClass::predict(self, x)
221    }
222
223    fn evaluate(&self, x: &[Vec<u8>], y: &[usize]) -> f32 {
224        MultiClass::evaluate(self, x, y)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a four-feature machine through the explicit `Config` builder.
    fn machine(n_clauses: usize, n_classes: usize, threshold: i32) -> MultiClass {
        let config = Config::builder()
            .clauses(n_clauses)
            .features(4)
            .build()
            .unwrap();
        MultiClass::new(config, n_classes, threshold)
    }

    /// Two-class toy set: class 0 sets leading bits, class 1 sets trailing
    /// bits.
    fn leading_vs_trailing() -> (Vec<Vec<u8>>, Vec<usize>) {
        let x = vec![
            vec![1, 1, 0, 0],
            vec![1, 0, 0, 0],
            vec![0, 0, 1, 1],
            vec![0, 0, 0, 1],
        ];
        (x, vec![0, 0, 1, 1])
    }

    #[test]
    fn predict_valid_class() {
        let tm = machine(10, 3, 15);

        assert!(tm.predict(&[1, 0, 1, 0]) < 3);
    }

    #[test]
    fn n_classes_correct() {
        assert_eq!(machine(10, 5, 15).n_classes(), 5);
    }

    #[test]
    fn class_votes_returns_all_classes() {
        let tm = machine(10, 3, 15);

        assert_eq!(tm.class_votes(&[1, 0, 1, 0]).len(), 3);
    }

    #[test]
    fn quick_constructor() {
        let tm = MultiClass::quick(20, 4, 3, 15);
        assert_eq!(tm.n_classes(), 3);
    }

    #[test]
    fn evaluate_empty_returns_zero() {
        let tm = machine(10, 3, 15);

        let acc = tm.evaluate(&[], &[]);
        assert!(acc.abs() < 0.001);
    }

    #[test]
    fn train_one_modifies_state() {
        let mut tm = machine(10, 2, 15);
        let mut rng = rng_from_seed(42);

        // One example per class, sharing a single generator.
        tm.train_one(&[1, 0, 1, 0], 0, &mut rng);
        tm.train_one(&[0, 1, 0, 1], 1, &mut rng);

        // Training must leave the machine able to predict a valid class.
        assert!(tm.predict(&[1, 0, 1, 0]) < 2);
    }

    #[test]
    fn fit_and_evaluate() {
        let mut tm = MultiClass::quick(50, 4, 2, 25);
        let (x, y) = leading_vs_trailing();

        tm.fit(&x, &y, 100, 42);

        // Accuracy is a fraction, whatever the machine has learned.
        let acc = tm.evaluate(&x, &y);
        assert!((0.0..=1.0).contains(&acc));
    }

    #[test]
    fn trait_impl_works() {
        use crate::model::TsetlinModel;

        let mut tm = machine(20, 2, 15);

        let x = vec![vec![1, 1, 0, 0], vec![0, 0, 1, 1]];
        let y = vec![0, 1];

        TsetlinModel::fit(&mut tm, &x, &y, 50, 42);
        assert!(TsetlinModel::predict(&tm, &x[0]) < 2);

        let acc = TsetlinModel::evaluate(&tm, &x, &y);
        assert!((0.0..=1.0).contains(&acc));
    }

    #[test]
    fn partial_fit_single_sample() {
        let mut tm = MultiClass::quick(20, 4, 3, 15);

        // One update per class, each with its own seed.
        tm.partial_fit(&[1, 1, 0, 0], 0, 42);
        tm.partial_fit(&[0, 1, 1, 0], 1, 43);
        tm.partial_fit(&[0, 0, 1, 1], 2, 44);

        assert!(tm.predict(&[1, 1, 0, 0]) < 3);
    }

    #[test]
    fn partial_fit_batch() {
        let mut tm = MultiClass::quick(50, 4, 2, 25);
        let (x, y) = leading_vs_trailing();

        // 100 passes over the batch, seeded 42, 43, ..., 141.
        for seed in 42..142 {
            tm.partial_fit_batch(&x, &y, seed);
        }

        let acc = tm.evaluate(&x, &y);
        assert!((0.0..=1.0).contains(&acc));
    }

    #[test]
    fn partial_fit_empty_batch() {
        let mut tm = MultiClass::quick(10, 4, 2, 15);
        tm.partial_fit_batch(&[], &[], 42);
        assert_eq!(tm.n_classes(), 2);
    }

    #[test]
    fn partial_fit_batch_ignores_mismatched_lengths() {
        let mut tm = MultiClass::quick(10, 4, 2, 15);
        let before = format!("{:?}", tm);

        // One sample but no label: the batch must be skipped entirely.
        tm.partial_fit_batch(&[vec![1, 0, 1, 0]], &[], 42);

        assert_eq!(before, format!("{:?}", tm));
    }
}
357}