1use core::array;
8
9use crate::Automaton;
10
11#[derive(Debug, Clone)]
33#[repr(align(64))]
34pub struct SmallClause<const N: usize> {
35 include: [Automaton; N],
36 negated: [Automaton; N],
37 weight: f32,
38 activations: u32,
39 correct: u32,
40 incorrect: u32,
41 polarity: i8
42}
43
44impl<const N: usize> SmallClause<N> {
45 #[inline]
56 #[must_use]
57 pub fn new(n_states: i16, polarity: i8) -> Self {
58 debug_assert!(polarity == 1 || polarity == -1);
59 Self {
60 include: array::from_fn(|_| Automaton::new(n_states)),
61 negated: array::from_fn(|_| Automaton::new(n_states)),
62 weight: 1.0,
63 activations: 0,
64 correct: 0,
65 incorrect: 0,
66 polarity
67 }
68 }
69
70 #[inline(always)]
72 #[must_use]
73 pub const fn polarity(&self) -> i8 {
74 self.polarity
75 }
76
77 #[inline(always)]
79 #[must_use]
80 pub const fn n_features(&self) -> usize {
81 N
82 }
83
84 #[inline(always)]
86 #[must_use]
87 pub const fn weight(&self) -> f32 {
88 self.weight
89 }
90
91 #[inline(always)]
93 #[must_use]
94 pub const fn activations(&self) -> u32 {
95 self.activations
96 }
97
98 #[inline(always)]
100 #[must_use]
101 pub const fn include_automata(&self) -> &[Automaton; N] {
102 &self.include
103 }
104
105 #[inline(always)]
107 pub fn include_automata_mut(&mut self) -> &mut [Automaton; N] {
108 &mut self.include
109 }
110
111 #[inline(always)]
113 #[must_use]
114 pub const fn negated_automata(&self) -> &[Automaton; N] {
115 &self.negated
116 }
117
118 #[inline(always)]
120 pub fn negated_automata_mut(&mut self) -> &mut [Automaton; N] {
121 &mut self.negated
122 }
123
124 #[inline]
133 #[must_use]
134 pub fn evaluate(&self, x: &[u8; N]) -> bool {
135 for k in 0..N {
136 let include_action = unsafe { self.include.get_unchecked(k).action() };
139 let negated_action = unsafe { self.negated.get_unchecked(k).action() };
140 let xk = unsafe { *x.get_unchecked(k) };
141
142 if include_action && xk == 0 {
143 return false;
144 }
145 if negated_action && xk == 1 {
146 return false;
147 }
148 }
149 true
150 }
151
152 #[inline]
156 pub fn evaluate_tracked(&mut self, x: &[u8; N]) -> bool {
157 let fires = self.evaluate(x);
158 if fires {
159 self.activations = self.activations.saturating_add(1);
160 }
161 fires
162 }
163
164 #[inline(always)]
167 #[must_use]
168 pub fn vote_weighted(&self, x: &[u8; N]) -> f32 {
169 if self.evaluate(x) {
170 self.polarity as f32 * self.weight
171 } else {
172 0.0
173 }
174 }
175
176 #[inline(always)]
178 #[must_use]
179 pub fn vote(&self, x: &[u8; N]) -> i32 {
180 if self.evaluate(x) {
181 self.polarity as i32
182 } else {
183 0
184 }
185 }
186
187 #[inline]
191 pub fn record_outcome(&mut self, was_correct: bool) {
192 if was_correct {
193 self.correct = self.correct.saturating_add(1);
194 } else {
195 self.incorrect = self.incorrect.saturating_add(1);
196 }
197 }
198
199 pub fn update_weight(&mut self, learning_rate: f32, min_weight: f32, max_weight: f32) {
210 let total = self.correct + self.incorrect;
211 if total == 0 {
212 return;
213 }
214
215 let accuracy = self.correct as f32 / total as f32;
216 let adjustment = (accuracy - 0.5) * 2.0 * learning_rate;
217 self.weight = (self.weight + adjustment).clamp(min_weight, max_weight);
218
219 self.correct = 0;
220 self.incorrect = 0;
221 }
222
223 #[inline]
228 #[must_use]
229 pub const fn is_dead(&self, min_activations: u32, min_weight: f32) -> bool {
230 self.activations < min_activations || self.weight < min_weight
231 }
232
233 #[inline]
235 pub fn reset_activations(&mut self) {
236 self.activations = 0;
237 }
238
239 #[inline]
241 pub fn reset_stats(&mut self) {
242 self.activations = 0;
243 self.correct = 0;
244 self.incorrect = 0;
245 }
246}
247
248pub type Clause2 = SmallClause<2>;
250pub type Clause4 = SmallClause<4>;
252pub type Clause8 = SmallClause<8>;
254pub type Clause16 = SmallClause<16>;
256pub type Clause32 = SmallClause<32>;
258pub type Clause64 = SmallClause<64>;
260
261#[derive(Debug, Clone)]
272#[repr(align(64))]
273pub struct SmallBitwiseClause<const N: usize, const W: usize> {
274 include: [Automaton; N],
275 negated: [Automaton; N],
276 inc_mask: [u64; W],
277 neg_mask: [u64; W],
278 weight: f32,
279 polarity: i8,
280 dirty: bool
281}
282
283impl<const N: usize, const W: usize> SmallBitwiseClause<N, W> {
284 #[inline]
290 #[must_use]
291 pub fn new(n_states: i16, polarity: i8) -> Self {
292 debug_assert!(polarity == 1 || polarity == -1);
293 debug_assert_eq!(W, N.div_ceil(64), "W must equal ceil(N/64)");
294 Self {
295 include: array::from_fn(|_| Automaton::new(n_states)),
296 negated: array::from_fn(|_| Automaton::new(n_states)),
297 inc_mask: [0; W],
298 neg_mask: [0; W],
299 weight: 1.0,
300 polarity,
301 dirty: true
302 }
303 }
304
305 #[inline(always)]
307 #[must_use]
308 pub const fn polarity(&self) -> i8 {
309 self.polarity
310 }
311
312 #[inline(always)]
314 #[must_use]
315 pub const fn n_features(&self) -> usize {
316 N
317 }
318
319 #[inline(always)]
321 #[must_use]
322 pub const fn include_automata(&self) -> &[Automaton; N] {
323 &self.include
324 }
325
326 #[inline(always)]
328 pub fn include_automata_mut(&mut self) -> &mut [Automaton; N] {
329 self.dirty = true;
330 &mut self.include
331 }
332
333 #[inline(always)]
335 #[must_use]
336 pub const fn negated_automata(&self) -> &[Automaton; N] {
337 &self.negated
338 }
339
340 #[inline(always)]
342 pub fn negated_automata_mut(&mut self) -> &mut [Automaton; N] {
343 self.dirty = true;
344 &mut self.negated
345 }
346
347 pub fn rebuild_masks(&mut self) {
351 if !self.dirty {
352 return;
353 }
354
355 for word in &mut self.inc_mask {
356 *word = 0;
357 }
358 for word in &mut self.neg_mask {
359 *word = 0;
360 }
361
362 for k in 0..N {
363 let word_idx = k / 64;
364 let bit_idx = k % 64;
365
366 if unsafe { self.include.get_unchecked(k).action() } {
368 self.inc_mask[word_idx] |= 1u64 << bit_idx;
369 }
370 if unsafe { self.negated.get_unchecked(k).action() } {
371 self.neg_mask[word_idx] |= 1u64 << bit_idx;
372 }
373 }
374
375 self.dirty = false;
376 }
377
378 #[inline]
386 #[must_use]
387 pub fn evaluate_packed(&self, x_packed: &[u64; W]) -> bool {
388 debug_assert!(!self.dirty, "call rebuild_masks() first");
389
390 for i in 0..W {
391 let x = unsafe { *x_packed.get_unchecked(i) };
393 let inc = unsafe { *self.inc_mask.get_unchecked(i) };
394 let neg = unsafe { *self.neg_mask.get_unchecked(i) };
395
396 if (inc & !x) | (neg & x) != 0 {
399 return false;
400 }
401 }
402 true
403 }
404
405 #[inline(always)]
407 #[must_use]
408 pub fn vote_packed(&self, x_packed: &[u64; W]) -> i32 {
409 if self.evaluate_packed(x_packed) {
410 self.polarity as i32
411 } else {
412 0
413 }
414 }
415
416 #[inline(always)]
418 #[must_use]
419 pub fn vote_weighted_packed(&self, x_packed: &[u64; W]) -> f32 {
420 if self.evaluate_packed(x_packed) {
421 self.polarity as f32 * self.weight
422 } else {
423 0.0
424 }
425 }
426}
427
/// Packs a byte-per-feature input into `W` little-endian bit words.
///
/// Feature `k` maps to bit `k % 64` of word `k / 64`; any non-zero byte is
/// treated as `1`. Words beyond the last feature stay zero.
///
/// # Panics
///
/// Panics on an out-of-bounds index if `W` is too small to hold `N` bits.
#[inline]
#[must_use]
pub fn pack_input_small<const N: usize, const W: usize>(x: &[u8; N]) -> [u64; W] {
    let mut words = [0u64; W];

    let set_features = x.iter().enumerate().filter(|&(_, &value)| value != 0);
    for (idx, _) in set_features {
        let (word, offset) = (idx / 64, idx % 64);
        words[word] |= 1u64 << offset;
    }

    words
}
444
445pub type BitwiseClause64 = SmallBitwiseClause<64, 1>;
447pub type BitwiseClause128 = SmallBitwiseClause<128, 2>;
449pub type BitwiseClause256 = SmallBitwiseClause<256, 4>;
451
452#[derive(Debug, Clone)]
485#[repr(align(64))]
486pub struct SmallTsetlinMachine<const N: usize, const C: usize> {
487 clauses: [SmallClause<N>; C],
488 s: f32,
489 t: f32
490}
491
492impl<const N: usize, const C: usize> SmallTsetlinMachine<N, C> {
493 #[must_use]
501 pub fn new(n_states: i16, threshold: i32) -> Self {
502 debug_assert!(C.is_multiple_of(2), "C must be even");
503 Self {
504 clauses: array::from_fn(|i| {
505 let p = if i % 2 == 0 { 1 } else { -1 };
506 SmallClause::new(n_states, p)
507 }),
508 s: 3.9,
509 t: threshold as f32
510 }
511 }
512
513 #[must_use]
515 pub fn with_s(n_states: i16, threshold: i32, s: f32) -> Self {
516 let mut tm = Self::new(n_states, threshold);
517 tm.s = s;
518 tm
519 }
520
521 #[inline(always)]
523 #[must_use]
524 pub const fn n_features(&self) -> usize {
525 N
526 }
527
528 #[inline(always)]
530 #[must_use]
531 pub const fn n_clauses(&self) -> usize {
532 C
533 }
534
535 #[inline(always)]
537 #[must_use]
538 pub fn threshold(&self) -> f32 {
539 self.t
540 }
541
542 #[inline(always)]
544 #[must_use]
545 pub const fn clauses(&self) -> &[SmallClause<N>; C] {
546 &self.clauses
547 }
548
549 #[inline]
551 #[must_use]
552 pub fn sum_votes(&self, x: &[u8; N]) -> i32 {
553 let mut sum = 0i32;
554 for i in 0..C {
555 sum += unsafe { self.clauses.get_unchecked(i).vote(x) };
557 }
558 sum
559 }
560
561 #[inline(always)]
563 #[must_use]
564 pub fn predict(&self, x: &[u8; N]) -> u8 {
565 if self.sum_votes(x) >= 0 { 1 } else { 0 }
566 }
567
568 pub fn train_one(&mut self, x: &[u8; N], y: u8, rng: &mut impl rand::Rng) {
570 let sum = (self.sum_votes(x) as f32).clamp(-self.t, self.t);
571 let inv_2t = 1.0 / (2.0 * self.t);
572 let s = self.s;
573
574 let prob = if y == 1 {
575 (self.t - sum) * inv_2t
576 } else {
577 (self.t + sum) * inv_2t
578 };
579
580 for i in 0..C {
581 let clause = unsafe { self.clauses.get_unchecked_mut(i) };
583 let fires = clause.evaluate(x);
584 let p = clause.polarity();
585
586 if y == 1 {
587 if p == 1 && rng.random::<f32>() <= prob {
588 small_type_i(clause, x, fires, s, rng);
589 } else if p == -1 && fires && rng.random::<f32>() <= prob {
590 small_type_ii(clause, x);
591 }
592 } else if p == -1 && rng.random::<f32>() <= prob {
593 small_type_i(clause, x, fires, s, rng);
594 } else if p == 1 && fires && rng.random::<f32>() <= prob {
595 small_type_ii(clause, x);
596 }
597 }
598 }
599
600 pub fn fit(&mut self, x: &[[u8; N]], y: &[u8], epochs: usize, seed: u64) {
602 let mut rng = crate::utils::rng_from_seed(seed);
603
604 for _ in 0..epochs {
605 for (xi, &yi) in x.iter().zip(y.iter()) {
606 self.train_one(xi, yi, &mut rng);
607 }
608 }
609 }
610
611 #[must_use]
613 pub fn evaluate(&self, x: &[[u8; N]], y: &[u8]) -> f32 {
614 if x.is_empty() {
615 return 0.0;
616 }
617 let correct = x
618 .iter()
619 .zip(y.iter())
620 .filter(|(xi, yi)| self.predict(xi) == **yi)
621 .count();
622 correct as f32 / x.len() as f32
623 }
624}
625
626fn small_type_i<const N: usize>(
628 clause: &mut SmallClause<N>,
629 x: &[u8; N],
630 fires: bool,
631 s: f32,
632 rng: &mut impl rand::Rng
633) {
634 let prob_strengthen = (s - 1.0) / s;
635 let prob_weaken = 1.0 / s;
636
637 if !fires {
638 for k in 0..N {
640 if rng.random::<f32>() <= prob_weaken {
641 unsafe {
643 clause
644 .include_automata_mut()
645 .get_unchecked_mut(k)
646 .decrement()
647 };
648 }
649 if rng.random::<f32>() <= prob_weaken {
650 unsafe {
651 clause
652 .negated_automata_mut()
653 .get_unchecked_mut(k)
654 .decrement()
655 };
656 }
657 }
658 } else {
659 for k in 0..N {
661 let xk = unsafe { *x.get_unchecked(k) };
663
664 if xk == 1 {
665 if rng.random::<f32>() <= prob_strengthen {
666 unsafe {
667 clause
668 .include_automata_mut()
669 .get_unchecked_mut(k)
670 .increment()
671 };
672 }
673 if rng.random::<f32>() <= prob_weaken {
674 unsafe {
675 clause
676 .negated_automata_mut()
677 .get_unchecked_mut(k)
678 .decrement()
679 };
680 }
681 } else {
682 if rng.random::<f32>() <= prob_strengthen {
683 unsafe {
684 clause
685 .negated_automata_mut()
686 .get_unchecked_mut(k)
687 .increment()
688 };
689 }
690 if rng.random::<f32>() <= prob_weaken {
691 unsafe {
692 clause
693 .include_automata_mut()
694 .get_unchecked_mut(k)
695 .decrement()
696 };
697 }
698 }
699 }
700 }
701}
702
703fn small_type_ii<const N: usize>(clause: &mut SmallClause<N>, x: &[u8; N]) {
705 for k in 0..N {
706 let xk = unsafe { *x.get_unchecked(k) };
708 let inc_action = unsafe { clause.include_automata().get_unchecked(k).action() };
709 let neg_action = unsafe { clause.negated_automata().get_unchecked(k).action() };
710
711 if xk == 0 && !inc_action {
712 unsafe {
713 clause
714 .include_automata_mut()
715 .get_unchecked_mut(k)
716 .increment()
717 };
718 }
719 if xk == 1 && !neg_action {
720 unsafe {
721 clause
722 .negated_automata_mut()
723 .get_unchecked_mut(k)
724 .increment()
725 };
726 }
727 }
728}
729
730pub type TM2x20 = SmallTsetlinMachine<2, 20>;
732pub type TM4x40 = SmallTsetlinMachine<4, 40>;
734pub type TM8x80 = SmallTsetlinMachine<8, 80>;
736pub type TM16x160 = SmallTsetlinMachine<16, 160>;
738
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for every floating-point comparison in this module.
    const EPS: f32 = 0.001;

    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn small_clause_new() {
        let clause = SmallClause::<4>::new(100, 1);
        assert_eq!(clause.n_features(), 4);
        assert_eq!(clause.polarity(), 1);
        assert!(approx_eq(clause.weight(), 1.0));
        assert_eq!(clause.activations(), 0);
    }

    #[test]
    fn small_clause_evaluate() {
        // Freshly created automata exclude every literal, so any input matches.
        let clause = SmallClause::<4>::new(100, 1);
        let input = [0, 1, 0, 1];
        assert!(clause.evaluate(&input));
    }

    #[test]
    fn small_clause_vote() {
        let clause = SmallClause::<4>::new(100, -1);
        assert_eq!(clause.vote(&[1; 4]), -1);
    }

    #[test]
    fn small_clause_weighted_vote() {
        let mut clause = SmallClause::<4>::new(100, 1);
        clause.weight = 0.5;
        assert!(approx_eq(clause.vote_weighted(&[0; 4]), 0.5));
    }

    #[test]
    fn small_clause_activation_tracking() {
        let mut clause = SmallClause::<2>::new(100, 1);
        for input in [[0, 0], [1, 1]] {
            clause.evaluate_tracked(&input);
        }
        assert_eq!(clause.activations(), 2);
    }

    #[test]
    fn small_clause_weight_update() {
        let mut clause = SmallClause::<2>::new(100, 1);
        clause.correct = 8;
        clause.incorrect = 2;
        clause.update_weight(0.1, 0.1, 2.0);
        assert!(clause.weight() > 1.0);
    }

    #[test]
    fn small_clause_is_dead() {
        let mut clause = SmallClause::<2>::new(100, 1);
        clause.weight = 0.05;
        assert!(clause.is_dead(10, 0.1));
    }

    #[test]
    fn small_bitwise_clause_new() {
        let clause = SmallBitwiseClause::<64, 1>::new(100, 1);
        assert_eq!(clause.n_features(), 64);
        assert_eq!(clause.polarity(), 1);
    }

    #[test]
    fn small_bitwise_evaluate() {
        let mut clause = BitwiseClause64::new(100, 1);
        clause.rebuild_masks();
        assert!(clause.evaluate_packed(&[u64::MAX]));
    }

    #[test]
    fn small_bitwise_violation() {
        let mut clause = BitwiseClause64::new(100, 1);

        // Drive feature 0's include automaton far past its decision boundary.
        for _ in 0..200 {
            clause.include_automata_mut()[0].increment();
        }
        clause.rebuild_masks();

        // Feature 0 is now required: bit 0 clear violates, bit 0 set satisfies.
        assert!(!clause.evaluate_packed(&[0]));
        assert!(clause.evaluate_packed(&[1]));
    }

    #[test]
    fn pack_input_small_test() {
        let bits: [u8; 8] = [1, 0, 1, 1, 0, 0, 0, 1];
        let packed: [u64; 1] = pack_input_small(&bits);
        // Feature k lands in bit k, so bits 0, 2, 3 and 7 are set.
        assert_eq!(packed, [0b1000_1101]);
    }

    #[test]
    fn small_tm_new() {
        let tm = SmallTsetlinMachine::<2, 20>::new(100, 15);
        assert_eq!((tm.n_features(), tm.n_clauses()), (2, 20));
        assert!(approx_eq(tm.threshold(), 15.0));
    }

    #[test]
    fn small_tm_xor_convergence() {
        let mut tm = SmallTsetlinMachine::<2, 20>::new(100, 10);
        let inputs = [[0, 0], [0, 1], [1, 0], [1, 1]];
        let labels = [0u8, 1, 1, 0];

        tm.fit(&inputs, &labels, 200, 42);
        assert!(tm.evaluate(&inputs, &labels) >= 0.75);
    }

    #[test]
    fn small_tm_type_alias() {
        let tm = TM2x20::new(100, 10);
        assert_eq!(tm.n_features(), 2);
        assert_eq!(tm.n_clauses(), 20);
    }
}