aaronson_oracle/
lib.rs

1#![warn(clippy::all, clippy::nursery, clippy::pedantic, clippy::cargo)]
2
3use std::cmp::Ordering;
4
/// An abstract representation of two possible actions.
///
/// Every outcome the predictor observes or guesses is one of these two values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Choice {
    /// The first of the two actions.
    Left,
    /// The second of the two actions.
    Right,
}
11
12impl Choice {
13    const fn to_bit(self) -> u32 {
14        match self {
15            Self::Left => 0,
16            Self::Right => 1,
17        }
18    }
19
20    /// Represent the choice as one of two possible values.
21    pub fn display<T>(self, choices: [T; 2]) -> T {
22        let [left, right] = choices;
23
24        match self {
25            Self::Left => left,
26            Self::Right => right,
27        }
28    }
29}
30
31impl From<bool> for Choice {
32    fn from(b: bool) -> Self {
33        if b { Self::Right } else { Self::Left }
34    }
35}
36
37pub struct Predictor {
38    /// The size of the n-grams used by the predictor.
39    n: usize,
40    /// The current state of the predictor.
41    state: u32,
42    /// The number of bits in the state. It will fall out of sync once it reaches `n`.
43    count: usize,
44    /// A vector of predictions, indexed by the state.
45    grams: Vec<Prediction>,
46    /// The total number of predictions made.
47    pub total_predictions: usize,
48    /// The number of correct predictions made.
49    pub correct_predictions: usize,
50}
51
52impl Predictor {
53    /// Create a new predictor with the given n-gram size.
54    ///
55    /// # Arguments
56    /// * `n` - The size of the n-grams used by the predictor.
57    #[must_use]
58    pub fn new(n: usize) -> Self {
59        Self {
60            n,
61            state: 0,
62            count: 0,
63            total_predictions: 0,
64            correct_predictions: 0,
65            grams: vec![Prediction::default(); 1 << n],
66        }
67    }
68
69    /// Predict the next choice, and register the correct choice.
70    pub fn predict(&mut self, choice: Choice) -> Option<Choice> {
71        if self.count < self.n {
72            self.state = (self.state << 1) | choice.to_bit();
73            self.count += 1;
74            return None;
75        }
76
77        let prediction = self.grams[self.state as usize].predict();
78        self.grams[self.state as usize].register(choice);
79
80        self.state = ((self.state << 1) | choice.to_bit()) & ((1 << self.n) - 1);
81
82        self.total_predictions += 1;
83        if prediction == choice {
84            self.correct_predictions += 1;
85        }
86
87        Some(prediction)
88    }
89}
90
/// Tallies of which choice followed one particular n-gram.
#[derive(Debug, Clone, Copy, Default)]
struct Prediction {
    /// How many times `Left` followed this n-gram.
    left: usize,
    /// How many times `Right` followed this n-gram.
    right: usize,
}
97
98impl Prediction {
99    /// Predict the next choice based on the current n-gram.
100    fn predict(&self) -> Choice {
101        match self.left.cmp(&self.right) {
102            Ordering::Less => Choice::Right,
103            Ordering::Greater => Choice::Left,
104            Ordering::Equal => Choice::from(rand::random::<bool>()),
105        }
106    }
107
108    /// Register the given choice.
109    const fn register(&mut self, choice: Choice) {
110        match choice {
111            Choice::Left => self.left += 1,
112            Choice::Right => self.right += 1,
113        }
114    }
115}