Skip to main content

tsetlin_rs/
small.rs

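//! Small-sized Tsetlin Machine with const generics for compile-time
//! optimization.
//!
//! Intended for small feature sets: the clause and machine types store all
//! of their state inline in fixed-size arrays (no heap allocations), and the
//! compile-time sizes give the compiler constant loop bounds that it may
//! unroll where it judges this profitable.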
6
7use core::array;
8
9use crate::Automaton;
10
11/// A clause with compile-time known feature count.
12///
13/// Stack-allocated, zero heap allocations. The compiler fully unrolls
14/// all loops for known N, enabling maximum optimization.
15///
16/// # Memory Layout
17///
18/// Fields ordered to minimize padding:
19/// - `include: [Automaton; N]` (N * 4 bytes)
20/// - `negated: [Automaton; N]` (N * 4 bytes)
21/// - `weight: f32` (4 bytes)
22/// - `activations: u32` (4 bytes)
23/// - `correct: u32` (4 bytes)
24/// - `incorrect: u32` (4 bytes)
25/// - `polarity: i8` (1 byte + padding to 64-byte alignment)
26///
27/// # Performance
28///
29/// Optimal for N <= 64. For larger feature sets, use `BitwiseClause`.
30// Note: serde not supported for const generic arrays without serde_big_array.
31// Use Vec-based Clause for serialization needs.
32#[derive(Debug, Clone)]
33#[repr(align(64))]
34pub struct SmallClause<const N: usize> {
35    include:     [Automaton; N],
36    negated:     [Automaton; N],
37    weight:      f32,
38    activations: u32,
39    correct:     u32,
40    incorrect:   u32,
41    polarity:    i8
42}
43
44impl<const N: usize> SmallClause<N> {
45    /// Creates clause with given states and polarity.
46    ///
47    /// # Arguments
48    ///
49    /// * `n_states` - States per automaton (threshold for action)
50    /// * `polarity` - Must be +1 or -1
51    ///
52    /// # Panics
53    ///
54    /// Debug-asserts that polarity is +1 or -1.
55    #[inline]
56    #[must_use]
57    pub fn new(n_states: i16, polarity: i8) -> Self {
58        debug_assert!(polarity == 1 || polarity == -1);
59        Self {
60            include: array::from_fn(|_| Automaton::new(n_states)),
61            negated: array::from_fn(|_| Automaton::new(n_states)),
62            weight: 1.0,
63            activations: 0,
64            correct: 0,
65            incorrect: 0,
66            polarity
67        }
68    }
69
70    /// Returns the clause polarity (+1 or -1).
71    #[inline(always)]
72    #[must_use]
73    pub const fn polarity(&self) -> i8 {
74        self.polarity
75    }
76
77    /// Returns the number of features (compile-time constant).
78    #[inline(always)]
79    #[must_use]
80    pub const fn n_features(&self) -> usize {
81        N
82    }
83
84    /// Returns the current clause weight.
85    #[inline(always)]
86    #[must_use]
87    pub const fn weight(&self) -> f32 {
88        self.weight
89    }
90
91    /// Returns the activation count since last reset.
92    #[inline(always)]
93    #[must_use]
94    pub const fn activations(&self) -> u32 {
95        self.activations
96    }
97
98    /// Returns read-only access to include automata.
99    #[inline(always)]
100    #[must_use]
101    pub const fn include_automata(&self) -> &[Automaton; N] {
102        &self.include
103    }
104
105    /// Returns mutable access to include automata.
106    #[inline(always)]
107    pub fn include_automata_mut(&mut self) -> &mut [Automaton; N] {
108        &mut self.include
109    }
110
111    /// Returns read-only access to negated automata.
112    #[inline(always)]
113    #[must_use]
114    pub const fn negated_automata(&self) -> &[Automaton; N] {
115        &self.negated
116    }
117
118    /// Returns mutable access to negated automata.
119    #[inline(always)]
120    pub fn negated_automata_mut(&mut self) -> &mut [Automaton; N] {
121        &mut self.negated
122    }
123
124    /// Evaluates clause on input with early exit on violation.
125    ///
126    /// Returns `true` if all included literals are satisfied.
127    /// Loop is fully unrolled at compile time for maximum performance.
128    ///
129    /// # Performance
130    ///
131    /// Uses unchecked indexing. Safety is guaranteed by const generic bounds.
132    #[inline]
133    #[must_use]
134    pub fn evaluate(&self, x: &[u8; N]) -> bool {
135        for k in 0..N {
136            // SAFETY: k is always in bounds [0, N) due to loop bounds
137            // and array size is exactly N.
138            let include_action = unsafe { self.include.get_unchecked(k).action() };
139            let negated_action = unsafe { self.negated.get_unchecked(k).action() };
140            let xk = unsafe { *x.get_unchecked(k) };
141
142            if include_action && xk == 0 {
143                return false;
144            }
145            if negated_action && xk == 1 {
146                return false;
147            }
148        }
149        true
150    }
151
152    /// Evaluates clause and tracks activation count.
153    ///
154    /// Use during training to track which clauses are active.
155    #[inline]
156    pub fn evaluate_tracked(&mut self, x: &[u8; N]) -> bool {
157        let fires = self.evaluate(x);
158        if fires {
159            self.activations = self.activations.saturating_add(1);
160        }
161        fires
162    }
163
164    /// Returns weighted vote: `polarity * weight` if clause fires, `0.0`
165    /// otherwise.
166    #[inline(always)]
167    #[must_use]
168    pub fn vote_weighted(&self, x: &[u8; N]) -> f32 {
169        if self.evaluate(x) {
170            self.polarity as f32 * self.weight
171        } else {
172            0.0
173        }
174    }
175
176    /// Returns unweighted vote: `polarity` if fires, `0` otherwise.
177    #[inline(always)]
178    #[must_use]
179    pub fn vote(&self, x: &[u8; N]) -> i32 {
180        if self.evaluate(x) {
181            self.polarity as i32
182        } else {
183            0
184        }
185    }
186
187    /// Records prediction outcome for weight learning.
188    ///
189    /// Call when clause fired and prediction was made.
190    #[inline]
191    pub fn record_outcome(&mut self, was_correct: bool) {
192        if was_correct {
193            self.correct = self.correct.saturating_add(1);
194        } else {
195            self.incorrect = self.incorrect.saturating_add(1);
196        }
197    }
198
199    /// Updates weight based on accumulated outcomes.
200    ///
201    /// Weight increases when clause predictions are accurate,
202    /// decreases when inaccurate. Call at end of each epoch.
203    ///
204    /// # Arguments
205    ///
206    /// * `learning_rate` - How fast weight changes (0.0 - 1.0)
207    /// * `min_weight` - Minimum allowed weight
208    /// * `max_weight` - Maximum allowed weight
209    pub fn update_weight(&mut self, learning_rate: f32, min_weight: f32, max_weight: f32) {
210        let total = self.correct + self.incorrect;
211        if total == 0 {
212            return;
213        }
214
215        let accuracy = self.correct as f32 / total as f32;
216        let adjustment = (accuracy - 0.5) * 2.0 * learning_rate;
217        self.weight = (self.weight + adjustment).clamp(min_weight, max_weight);
218
219        self.correct = 0;
220        self.incorrect = 0;
221    }
222
223    /// Returns `true` if clause is "dead" (rarely activates or very low
224    /// weight).
225    ///
226    /// Dead clauses can be pruned and reset during training.
227    #[inline]
228    #[must_use]
229    pub const fn is_dead(&self, min_activations: u32, min_weight: f32) -> bool {
230        self.activations < min_activations || self.weight < min_weight
231    }
232
233    /// Resets activation counter. Call at start of each epoch.
234    #[inline]
235    pub fn reset_activations(&mut self) {
236        self.activations = 0;
237    }
238
239    /// Resets all statistics (activations, correct, incorrect).
240    #[inline]
241    pub fn reset_stats(&mut self) {
242        self.activations = 0;
243        self.correct = 0;
244        self.incorrect = 0;
245    }
246}
247
248/// Type alias for 2-feature clause.
249pub type Clause2 = SmallClause<2>;
250/// Type alias for 4-feature clause (XOR problems).
251pub type Clause4 = SmallClause<4>;
252/// Type alias for 8-feature clause.
253pub type Clause8 = SmallClause<8>;
254/// Type alias for 16-feature clause.
255pub type Clause16 = SmallClause<16>;
256/// Type alias for 32-feature clause.
257pub type Clause32 = SmallClause<32>;
258/// Type alias for 64-feature clause.
259pub type Clause64 = SmallClause<64>;
260
261/// Bitwise clause with compile-time known feature count.
262///
263/// Uses packed u64 bitmasks for SIMD-friendly evaluation.
264/// Processes 64 features per AND operation.
265///
266/// # Type Parameters
267///
268/// * `N` - Number of features (compile-time constant)
269/// * `W` - Number of u64 words needed: `(N + 63) / 64`
270// Note: serde not supported for const generic arrays without serde_big_array.
271#[derive(Debug, Clone)]
272#[repr(align(64))]
273pub struct SmallBitwiseClause<const N: usize, const W: usize> {
274    include:  [Automaton; N],
275    negated:  [Automaton; N],
276    inc_mask: [u64; W],
277    neg_mask: [u64; W],
278    weight:   f32,
279    polarity: i8,
280    dirty:    bool
281}
282
283impl<const N: usize, const W: usize> SmallBitwiseClause<N, W> {
284    /// Creates bitwise clause with given states and polarity.
285    ///
286    /// # Panics
287    ///
288    /// Debug-asserts that W == (N + 63) / 64.
289    #[inline]
290    #[must_use]
291    pub fn new(n_states: i16, polarity: i8) -> Self {
292        debug_assert!(polarity == 1 || polarity == -1);
293        debug_assert_eq!(W, N.div_ceil(64), "W must equal ceil(N/64)");
294        Self {
295            include: array::from_fn(|_| Automaton::new(n_states)),
296            negated: array::from_fn(|_| Automaton::new(n_states)),
297            inc_mask: [0; W],
298            neg_mask: [0; W],
299            weight: 1.0,
300            polarity,
301            dirty: true
302        }
303    }
304
305    /// Returns the clause polarity (+1 or -1).
306    #[inline(always)]
307    #[must_use]
308    pub const fn polarity(&self) -> i8 {
309        self.polarity
310    }
311
312    /// Returns the number of features (compile-time constant).
313    #[inline(always)]
314    #[must_use]
315    pub const fn n_features(&self) -> usize {
316        N
317    }
318
319    /// Returns read-only access to include automata.
320    #[inline(always)]
321    #[must_use]
322    pub const fn include_automata(&self) -> &[Automaton; N] {
323        &self.include
324    }
325
326    /// Returns mutable access to include automata.
327    #[inline(always)]
328    pub fn include_automata_mut(&mut self) -> &mut [Automaton; N] {
329        self.dirty = true;
330        &mut self.include
331    }
332
333    /// Returns read-only access to negated automata.
334    #[inline(always)]
335    #[must_use]
336    pub const fn negated_automata(&self) -> &[Automaton; N] {
337        &self.negated
338    }
339
340    /// Returns mutable access to negated automata.
341    #[inline(always)]
342    pub fn negated_automata_mut(&mut self) -> &mut [Automaton; N] {
343        self.dirty = true;
344        &mut self.negated
345    }
346
347    /// Rebuilds bitmasks from automaton states.
348    ///
349    /// Call after training before evaluation.
350    pub fn rebuild_masks(&mut self) {
351        if !self.dirty {
352            return;
353        }
354
355        for word in &mut self.inc_mask {
356            *word = 0;
357        }
358        for word in &mut self.neg_mask {
359            *word = 0;
360        }
361
362        for k in 0..N {
363            let word_idx = k / 64;
364            let bit_idx = k % 64;
365
366            // SAFETY: k < N, so word_idx < W and k < N (automata bounds)
367            if unsafe { self.include.get_unchecked(k).action() } {
368                self.inc_mask[word_idx] |= 1u64 << bit_idx;
369            }
370            if unsafe { self.negated.get_unchecked(k).action() } {
371                self.neg_mask[word_idx] |= 1u64 << bit_idx;
372            }
373        }
374
375        self.dirty = false;
376    }
377
378    /// Evaluates clause using bitwise AND operations.
379    ///
380    /// Processes 64 features per CPU instruction for massive speedup.
381    ///
382    /// # Panics
383    ///
384    /// Debug-asserts that `rebuild_masks()` was called after training.
385    #[inline]
386    #[must_use]
387    pub fn evaluate_packed(&self, x_packed: &[u64; W]) -> bool {
388        debug_assert!(!self.dirty, "call rebuild_masks() first");
389
390        for i in 0..W {
391            // SAFETY: i < W, all arrays have size W
392            let x = unsafe { *x_packed.get_unchecked(i) };
393            let inc = unsafe { *self.inc_mask.get_unchecked(i) };
394            let neg = unsafe { *self.neg_mask.get_unchecked(i) };
395
396            // include violation: inc & !x != 0 (required bit is 0)
397            // negated violation: neg & x != 0 (forbidden bit is 1)
398            if (inc & !x) | (neg & x) != 0 {
399                return false;
400            }
401        }
402        true
403    }
404
405    /// Returns polarity if fires, 0 otherwise.
406    #[inline(always)]
407    #[must_use]
408    pub fn vote_packed(&self, x_packed: &[u64; W]) -> i32 {
409        if self.evaluate_packed(x_packed) {
410            self.polarity as i32
411        } else {
412            0
413        }
414    }
415
416    /// Returns weighted vote if fires.
417    #[inline(always)]
418    #[must_use]
419    pub fn vote_weighted_packed(&self, x_packed: &[u64; W]) -> f32 {
420        if self.evaluate_packed(x_packed) {
421            self.polarity as f32 * self.weight
422        } else {
423            0.0
424        }
425    }
426}
427
/// Packs a byte-per-feature input into `W` `u64` words for bitwise
/// evaluation.
///
/// Feature `k` maps to bit `k % 64` of word `k / 64`. Any non-zero byte sets
/// its bit; zero bytes leave it clear. Both sizes are compile-time
/// constants, so no allocation takes place.
///
/// # Panics
///
/// Panics (index out of bounds) if a non-zero feature has index
/// `>= 64 * W`.
#[inline]
#[must_use]
pub fn pack_input_small<const N: usize, const W: usize>(x: &[u8; N]) -> [u64; W] {
    x.iter()
        .enumerate()
        .filter(|&(_, &byte)| byte != 0)
        .fold([0u64; W], |mut words, (idx, _)| {
            words[idx / 64] |= 1u64 << (idx % 64);
            words
        })
}
444
445/// Bitwise clause for 64 features (1 u64 word).
446pub type BitwiseClause64 = SmallBitwiseClause<64, 1>;
447/// Bitwise clause for 128 features (2 u64 words).
448pub type BitwiseClause128 = SmallBitwiseClause<128, 2>;
449/// Bitwise clause for 256 features (4 u64 words).
450pub type BitwiseClause256 = SmallBitwiseClause<256, 4>;
451
452/// Binary classification Tsetlin Machine with compile-time known dimensions.
453///
454/// Stack-allocated with zero heap allocations. All loops unrolled at compile
455/// time.
456///
457/// # Type Parameters
458///
459/// * `N` - Number of features (compile-time constant)
460/// * `C` - Number of clauses (compile-time constant, must be even)
461///
462/// # Performance
463///
464/// Up to 3x faster than dynamic [`TsetlinMachine`](crate::TsetlinMachine)
465/// for small dimensions due to:
466/// - No heap allocations
467/// - Full loop unrolling
468/// - Better cache locality
469///
470/// # Example
471///
472/// ```
473/// use tsetlin_rs::SmallTsetlinMachine;
474///
475/// // XOR with 2 features and 20 clauses
476/// let mut tm: SmallTsetlinMachine<2, 20> = SmallTsetlinMachine::new(100, 15);
477///
478/// let x = [[0, 0], [0, 1], [1, 0], [1, 1]];
479/// let y = [0u8, 1, 1, 0];
480///
481/// tm.fit(&x, &y, 200, 42);
482/// assert!(tm.evaluate(&x, &y) >= 0.75);
483/// ```
484#[derive(Debug, Clone)]
485#[repr(align(64))]
486pub struct SmallTsetlinMachine<const N: usize, const C: usize> {
487    clauses: [SmallClause<N>; C],
488    s:       f32,
489    t:       f32
490}
491
492impl<const N: usize, const C: usize> SmallTsetlinMachine<N, C> {
493    /// Creates new machine with given states and threshold.
494    ///
495    /// Half clauses get +1 polarity, half get -1.
496    ///
497    /// # Panics
498    ///
499    /// Debug-asserts that C is even.
500    #[must_use]
501    pub fn new(n_states: i16, threshold: i32) -> Self {
502        debug_assert!(C.is_multiple_of(2), "C must be even");
503        Self {
504            clauses: array::from_fn(|i| {
505                let p = if i % 2 == 0 { 1 } else { -1 };
506                SmallClause::new(n_states, p)
507            }),
508            s:       3.9,
509            t:       threshold as f32
510        }
511    }
512
513    /// Creates machine with custom specificity parameter.
514    #[must_use]
515    pub fn with_s(n_states: i16, threshold: i32, s: f32) -> Self {
516        let mut tm = Self::new(n_states, threshold);
517        tm.s = s;
518        tm
519    }
520
521    /// Returns the number of features (compile-time constant).
522    #[inline(always)]
523    #[must_use]
524    pub const fn n_features(&self) -> usize {
525        N
526    }
527
528    /// Returns the number of clauses (compile-time constant).
529    #[inline(always)]
530    #[must_use]
531    pub const fn n_clauses(&self) -> usize {
532        C
533    }
534
535    /// Returns current threshold.
536    #[inline(always)]
537    #[must_use]
538    pub fn threshold(&self) -> f32 {
539        self.t
540    }
541
542    /// Returns read-only access to clauses.
543    #[inline(always)]
544    #[must_use]
545    pub const fn clauses(&self) -> &[SmallClause<N>; C] {
546        &self.clauses
547    }
548
549    /// Sum of clause votes for input x.
550    #[inline]
551    #[must_use]
552    pub fn sum_votes(&self, x: &[u8; N]) -> i32 {
553        let mut sum = 0i32;
554        for i in 0..C {
555            // SAFETY: i < C, array size is exactly C
556            sum += unsafe { self.clauses.get_unchecked(i).vote(x) };
557        }
558        sum
559    }
560
561    /// Predicts class (0 or 1).
562    #[inline(always)]
563    #[must_use]
564    pub fn predict(&self, x: &[u8; N]) -> u8 {
565        if self.sum_votes(x) >= 0 { 1 } else { 0 }
566    }
567
568    /// Trains on single example.
569    pub fn train_one(&mut self, x: &[u8; N], y: u8, rng: &mut impl rand::Rng) {
570        let sum = (self.sum_votes(x) as f32).clamp(-self.t, self.t);
571        let inv_2t = 1.0 / (2.0 * self.t);
572        let s = self.s;
573
574        let prob = if y == 1 {
575            (self.t - sum) * inv_2t
576        } else {
577            (self.t + sum) * inv_2t
578        };
579
580        for i in 0..C {
581            // SAFETY: i < C
582            let clause = unsafe { self.clauses.get_unchecked_mut(i) };
583            let fires = clause.evaluate(x);
584            let p = clause.polarity();
585
586            if y == 1 {
587                if p == 1 && rng.random::<f32>() <= prob {
588                    small_type_i(clause, x, fires, s, rng);
589                } else if p == -1 && fires && rng.random::<f32>() <= prob {
590                    small_type_ii(clause, x);
591                }
592            } else if p == -1 && rng.random::<f32>() <= prob {
593                small_type_i(clause, x, fires, s, rng);
594            } else if p == 1 && fires && rng.random::<f32>() <= prob {
595                small_type_ii(clause, x);
596            }
597        }
598    }
599
600    /// Simple training for given epochs.
601    pub fn fit(&mut self, x: &[[u8; N]], y: &[u8], epochs: usize, seed: u64) {
602        let mut rng = crate::utils::rng_from_seed(seed);
603
604        for _ in 0..epochs {
605            for (xi, &yi) in x.iter().zip(y.iter()) {
606                self.train_one(xi, yi, &mut rng);
607            }
608        }
609    }
610
611    /// Evaluates accuracy on test data.
612    #[must_use]
613    pub fn evaluate(&self, x: &[[u8; N]], y: &[u8]) -> f32 {
614        if x.is_empty() {
615            return 0.0;
616        }
617        let correct = x
618            .iter()
619            .zip(y.iter())
620            .filter(|(xi, yi)| self.predict(xi) == **yi)
621            .count();
622        correct as f32 / x.len() as f32
623    }
624}
625
626/// Type I feedback for SmallClause (reinforces patterns).
627fn small_type_i<const N: usize>(
628    clause: &mut SmallClause<N>,
629    x: &[u8; N],
630    fires: bool,
631    s: f32,
632    rng: &mut impl rand::Rng
633) {
634    let prob_strengthen = (s - 1.0) / s;
635    let prob_weaken = 1.0 / s;
636
637    if !fires {
638        // Clause doesn't fire: weaken all automata
639        for k in 0..N {
640            if rng.random::<f32>() <= prob_weaken {
641                // SAFETY: k < N
642                unsafe {
643                    clause
644                        .include_automata_mut()
645                        .get_unchecked_mut(k)
646                        .decrement()
647                };
648            }
649            if rng.random::<f32>() <= prob_weaken {
650                unsafe {
651                    clause
652                        .negated_automata_mut()
653                        .get_unchecked_mut(k)
654                        .decrement()
655                };
656            }
657        }
658    } else {
659        // Clause fires: reinforce matching pattern
660        for k in 0..N {
661            // SAFETY: k < N
662            let xk = unsafe { *x.get_unchecked(k) };
663
664            if xk == 1 {
665                if rng.random::<f32>() <= prob_strengthen {
666                    unsafe {
667                        clause
668                            .include_automata_mut()
669                            .get_unchecked_mut(k)
670                            .increment()
671                    };
672                }
673                if rng.random::<f32>() <= prob_weaken {
674                    unsafe {
675                        clause
676                            .negated_automata_mut()
677                            .get_unchecked_mut(k)
678                            .decrement()
679                    };
680                }
681            } else {
682                if rng.random::<f32>() <= prob_strengthen {
683                    unsafe {
684                        clause
685                            .negated_automata_mut()
686                            .get_unchecked_mut(k)
687                            .increment()
688                    };
689                }
690                if rng.random::<f32>() <= prob_weaken {
691                    unsafe {
692                        clause
693                            .include_automata_mut()
694                            .get_unchecked_mut(k)
695                            .decrement()
696                    };
697                }
698            }
699        }
700    }
701}
702
703/// Type II feedback for SmallClause (corrects false positives).
704fn small_type_ii<const N: usize>(clause: &mut SmallClause<N>, x: &[u8; N]) {
705    for k in 0..N {
706        // SAFETY: k < N
707        let xk = unsafe { *x.get_unchecked(k) };
708        let inc_action = unsafe { clause.include_automata().get_unchecked(k).action() };
709        let neg_action = unsafe { clause.negated_automata().get_unchecked(k).action() };
710
711        if xk == 0 && !inc_action {
712            unsafe {
713                clause
714                    .include_automata_mut()
715                    .get_unchecked_mut(k)
716                    .increment()
717            };
718        }
719        if xk == 1 && !neg_action {
720            unsafe {
721                clause
722                    .negated_automata_mut()
723                    .get_unchecked_mut(k)
724                    .increment()
725            };
726        }
727    }
728}
729
730/// 2-feature, 20-clause machine (XOR).
731pub type TM2x20 = SmallTsetlinMachine<2, 20>;
732/// 4-feature, 40-clause machine.
733pub type TM4x40 = SmallTsetlinMachine<4, 40>;
734/// 8-feature, 80-clause machine.
735pub type TM8x80 = SmallTsetlinMachine<8, 80>;
736/// 16-feature, 160-clause machine.
737pub type TM16x160 = SmallTsetlinMachine<16, 160>;
738
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for the floating-point comparisons below.
    const EPS: f32 = 0.001;

    // ---- SmallClause ----------------------------------------------------

    #[test]
    fn small_clause_new() {
        let clause = SmallClause::<4>::new(100, 1);
        assert_eq!(clause.n_features(), 4);
        assert_eq!(clause.polarity(), 1);
        assert!((clause.weight() - 1.0).abs() < EPS);
        assert_eq!(clause.activations(), 0);
    }

    #[test]
    fn small_clause_evaluate() {
        // A freshly created clause fires on this input.
        let clause = SmallClause::<4>::new(100, 1);
        assert!(clause.evaluate(&[0, 1, 0, 1]));
    }

    #[test]
    fn small_clause_vote() {
        let clause = SmallClause::<4>::new(100, -1);
        assert_eq!(clause.vote(&[1; 4]), -1);
    }

    #[test]
    fn small_clause_weighted_vote() {
        let mut clause = SmallClause::<4>::new(100, 1);
        clause.weight = 0.5;
        assert!((clause.vote_weighted(&[0; 4]) - 0.5).abs() < EPS);
    }

    #[test]
    fn small_clause_activation_tracking() {
        let mut clause = SmallClause::<2>::new(100, 1);
        for input in [[0, 0], [1, 1]] {
            clause.evaluate_tracked(&input);
        }
        assert_eq!(clause.activations(), 2);
    }

    #[test]
    fn small_clause_weight_update() {
        let mut clause = SmallClause::<2>::new(100, 1);
        clause.correct = 8;
        clause.incorrect = 2;
        clause.update_weight(0.1, 0.1, 2.0);
        assert!(clause.weight() > 1.0);
    }

    #[test]
    fn small_clause_is_dead() {
        let mut clause = SmallClause::<2>::new(100, 1);
        clause.weight = 0.05;
        assert!(clause.is_dead(10, 0.1));
    }

    // ---- SmallBitwiseClause / packing -------------------------------------

    #[test]
    fn small_bitwise_clause_new() {
        let clause = SmallBitwiseClause::<64, 1>::new(100, 1);
        assert_eq!(clause.n_features(), 64);
        assert_eq!(clause.polarity(), 1);
    }

    #[test]
    fn small_bitwise_evaluate() {
        let mut clause: BitwiseClause64 = SmallBitwiseClause::new(100, 1);
        clause.rebuild_masks();
        assert!(clause.evaluate_packed(&[u64::MAX]));
    }

    #[test]
    fn small_bitwise_violation() {
        let mut clause: BitwiseClause64 = SmallBitwiseClause::new(100, 1);

        // Push include[0] far enough that it becomes active.
        (0..200).for_each(|_| {
            clause.include_automata_mut()[0].increment();
        });
        clause.rebuild_masks();

        assert!(!clause.evaluate_packed(&[0]), "x[0] = 0 must violate");
        assert!(clause.evaluate_packed(&[1]), "x[0] = 1 must pass");
    }

    #[test]
    fn pack_input_small_test() {
        let packed: [u64; 1] = pack_input_small(&[1, 0, 1, 1, 0, 0, 0, 1]);
        assert_eq!(packed, [0b1000_1101]); // bits 0, 2, 3 and 7 set
    }

    // ---- SmallTsetlinMachine ----------------------------------------------

    #[test]
    fn small_tm_new() {
        let tm = SmallTsetlinMachine::<2, 20>::new(100, 15);
        assert_eq!((tm.n_features(), tm.n_clauses()), (2, 20));
        assert!((tm.threshold() - 15.0).abs() < EPS);
    }

    #[test]
    fn small_tm_xor_convergence() {
        let xs = [[0, 0], [0, 1], [1, 0], [1, 1]];
        let ys = [0u8, 1, 1, 0];
        let mut tm = SmallTsetlinMachine::<2, 20>::new(100, 10);
        tm.fit(&xs, &ys, 200, 42);
        assert!(tm.evaluate(&xs, &ys) >= 0.75);
    }

    #[test]
    fn small_tm_type_alias() {
        let tm = TM2x20::new(100, 10);
        assert_eq!(tm.n_features(), 2);
        assert_eq!(tm.n_clauses(), 20);
    }
}