1#![warn(clippy::all, clippy::nursery, clippy::pedantic, clippy::cargo)]
2
3use std::cmp::Ordering;
4
/// One of the two options that can be picked on each turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Choice {
    /// The left option; encoded as bit `0`.
    Left,
    /// The right option; encoded as bit `1`.
    Right,
}

impl Choice {
    /// Encodes the choice as a single bit: `Left` → `0`, `Right` → `1`.
    ///
    /// This is the encoding used when choices are packed into a history word.
    const fn to_bit(self) -> u32 {
        match self {
            Self::Left => 0,
            Self::Right => 1,
        }
    }

    /// Selects the element of `choices` that corresponds to this choice.
    ///
    /// `choices` is ordered `[left, right]`; the element that is not selected
    /// is dropped. For example, `Choice::Left.display(["L", "R"])` yields `"L"`.
    #[must_use]
    pub fn display<T>(self, choices: [T; 2]) -> T {
        let [left, right] = choices;

        match self {
            Self::Left => left,
            Self::Right => right,
        }
    }
}

impl From<bool> for Choice {
    /// Maps `true` to [`Choice::Right`] and `false` to [`Choice::Left`],
    /// matching the bit encoding (`Right` is bit `1`).
    fn from(b: bool) -> Self {
        if b { Self::Right } else { Self::Left }
    }
}
36
37pub struct Predictor {
38 n: usize,
40 state: u32,
42 count: usize,
44 grams: Vec<Prediction>,
46 pub total_predictions: usize,
48 pub correct_predictions: usize,
50}
51
52impl Predictor {
53 #[must_use]
58 pub fn new(n: usize) -> Self {
59 Self {
60 n,
61 state: 0,
62 count: 0,
63 total_predictions: 0,
64 correct_predictions: 0,
65 grams: vec![Prediction::default(); 1 << n],
66 }
67 }
68
69 pub fn predict(&mut self, choice: Choice) -> Option<Choice> {
71 if self.count < self.n {
72 self.state = (self.state << 1) | choice.to_bit();
73 self.count += 1;
74 return None;
75 }
76
77 let prediction = self.grams[self.state as usize].predict();
78 self.grams[self.state as usize].register(choice);
79
80 self.state = ((self.state << 1) | choice.to_bit()) & ((1 << self.n) - 1);
81
82 self.total_predictions += 1;
83 if prediction == choice {
84 self.correct_predictions += 1;
85 }
86
87 Some(prediction)
88 }
89}
90
91#[derive(Debug, Clone, Copy, Default)]
92struct Prediction {
94 left: usize,
95 right: usize,
96}
97
98impl Prediction {
99 fn predict(&self) -> Choice {
101 match self.left.cmp(&self.right) {
102 Ordering::Less => Choice::Right,
103 Ordering::Greater => Choice::Left,
104 Ordering::Equal => Choice::from(rand::random::<bool>()),
105 }
106 }
107
108 const fn register(&mut self, choice: Choice) {
110 match choice {
111 Choice::Left => self.left += 1,
112 Choice::Right => self.right += 1,
113 }
114 }
115}