//! markovr — a variable-order Markov chain library (crate root: lib.rs).
mod die;
use cfg_if::cfg_if;
use std::collections::HashMap;
use std::convert::TryFrom;

/// Marker trait for the element type stored in a chain: anything that is
/// hashable, comparable for equality, and cheap to copy.
///
/// `Eq` already implies `PartialEq`, and `Copy` implies `Clone`, so the
/// redundant supertraits are omitted; the blanket impl below keeps the set
/// of implementing types identical.
pub trait Element: Eq + Copy + std::hash::Hash {}
impl<T> Element for T where T: Eq + Copy + std::hash::Hash {}

/// Variable-order Markov chain over elements of type `T`.
pub struct MarkovChain<T: Element> {
    // the 'memory' for the MarkovChain chain.
    // 1 is the typical MarkovChain chain that only looks
    // back 1 element.
    // 0 is functionally equivalent to a weighted die.
    order: usize,
    // Maps a window of previous elements to a weighted die over the
    // possible next elements.
    // the number of elements in the key should
    // exactly equal the order of the MarkovChain chain.
    // missing elements should be represented as None.
    probability_map: HashMap<Vec<Option<T>>, die::WeightedDie<T>>,
    // Key indices that are allowed to be None during generation
    // (filtered to those < order by `new`).
    optional_elements: Vec<usize>,
}

23impl<T: Element> MarkovChain<T> {
24    /// Creates a new MarkovChain.
25    ///
26    /// 'order' is the order of the Markov Chain.
27    ///
28    /// A value of 1 is your typical Markov Chain,
29    /// that only looks back one place.
30    ///
31    /// A value of 0 is just a weighted die, since
32    /// there is no memory.
33    ///
34    /// You could think of order as the shape of the
35    /// input tensor, where the input tensor is the
36    /// sliding view / window.
37    ///
38    /// 'optional_keys' allows you to specify None for
39    /// the elements in the given indices during
40    /// generation. This drastically increases the memory
41    /// usage by 2^optional_elements.len().
42    pub fn new(order: usize, optional_elements: &[usize]) -> Self {
43        // filter out optional elements that are too big.
44        let opts: Vec<usize> = optional_elements
45            .clone()
46            .into_iter()
47            .map(|i| *i)
48            .filter(|i| *i < order)
49            .collect();
50        MarkovChain {
51            order,
52            probability_map: HashMap::<Vec<Option<T>>, die::WeightedDie<T>>::new(),
53            optional_elements: opts,
54        }
55    }
56
57    /// Truncates elements as needed
58    fn to_partial_key(order: usize, view: &[Option<T>]) -> Vec<Option<T>> {
59        view.into_iter()
60            .skip(view.len() - order)
61            .take(order)
62            .cloned()
63            .collect()
64    }
65
66    /// Truncates elements as needed
67    fn to_full_key(order: usize, view: &[T]) -> Vec<Option<T>> {
68        view.into_iter()
69            .skip(view.len() - order)
70            .take(order)
71            .cloned()
72            .map(|e| Some(e))
73            .collect()
74    }
75
76    fn permute(
77        key: Vec<Option<T>>,
78        optionals: Vec<usize>,
79        mut perms: Vec<Vec<Option<T>>>,
80    ) -> Vec<Vec<Option<T>>> {
81        if optionals.len() == 0 {
82            perms.push(key);
83            perms
84        } else {
85            let mut off = key.clone();
86            off[optionals[0]] = None;
87            let on = key.clone();
88
89            perms = Self::permute(off, optionals[1..].to_vec(), perms);
90            perms = Self::permute(on, optionals[1..].to_vec(), perms);
91            perms
92        }
93    }
94
95    // this generates 2^(number of optional keys) keys.
96    // this is used during training.
97    fn permute_key(&mut self, key: Vec<T>) -> Vec<Vec<Option<T>>> {
98        let optioned_key: Vec<Option<T>> = key.into_iter().map(|e| Some(e)).collect();
99        Self::permute(optioned_key, self.optional_elements.clone(), vec![])
100    }
101
102    /// Feeds training data into the model.
103    ///
104    /// 'view' is the sliding window of elements to load
105    /// into the MarkovChain chain. the number of elements in
106    /// view should be self.order + 1 (excess will be ignored).
107    /// the last element
108    /// in view is the element to increase the weight of.
109    ///
110    /// 'weight_delta' should be the number of times we're
111    /// loading this view into the model (typically 1 at
112    /// a time).
113    pub fn train(&mut self, view: &[T], result: T, weight_delta: i32) {
114        for partial_key in self.permute_key(view.clone().to_vec()) {
115            // Train not just on the full key, but all partial ones as well.
116            self.probability_map
117                .entry(partial_key)
118                .and_modify(|d| {
119                    d.modify(result, weight_delta);
120                })
121                .or_insert((|| {
122                    let mut d = die::WeightedDie::new();
123                    d.modify(result, weight_delta);
124                    d
125                })());
126        }
127    }
128
129    /// Generates the next value, given the previous item(s).
130    ///
131    /// view is the sliding window of the latest elements.
132    /// only the last self.order elements are looked at.
133    ///
134    /// rand_val allows for a deterministic result, if supplied.
135    pub fn generate_deterministic_from_partial(
136        &self,
137        view: &[Option<T>],
138        rand_val: u64,
139    ) -> Option<T> {
140        let key = MarkovChain::to_partial_key(self.order, view);
141
142        match self.probability_map.get(&key) {
143            Some(v) => v.roll(Some(rand_val)),
144            None => None,
145        }
146    }
147
148    /// Generates the next value, given the previous item(s).
149    ///
150    /// view is the sliding window of the latest elements.
151    /// only the last self.order elements are looked at.
152    ///
153    /// rand_val allows for a deterministic result, if supplied.
154    pub fn generate_deterministic(&self, view: &[T], rand_val: u64) -> Option<T> {
155        let key = MarkovChain::to_full_key(self.order, view);
156
157        match self.probability_map.get(&key) {
158            Some(v) => v.roll(Some(rand_val)),
159            None => None,
160        }
161    }
162
163    cfg_if! {
164        if #[cfg(feature = "rand")] {
165            /// Generates the next value, given the previous item(s).
166            ///
167            /// view is the sliding window of the latest elements.
168            /// only the last self.order elements are looked at.
169            pub fn generate(&self, view: &[T]) -> Option<T> {
170                let key = MarkovChain::to_full_key(self.order, view);
171
172                match self.probability_map.get(&key) {
173                    Some(v) => v.roll(None),
174                    None => None,
175                }
176            }
177
178            /// Generates the next value, given the previous item(s).
179            ///
180            /// view is the sliding window of the latest elements.
181            /// only the last self.order elements are looked at.
182            pub fn generate_from_partial(&self, view: &[Option<T>]) -> Option<T> {
183                let key = MarkovChain::to_partial_key(self.order, view);
184
185                match self.probability_map.get(&key) {
186                    Some(v) => v.roll(None),
187                    None => None,
188                }
189            }
190        }
191    }
192
193    /// Returns the probability of getting 'result', given
194    /// 'view'.
195    pub fn probability(&self, view: &[Option<T>], result: T) -> f32 {
196        let key = MarkovChain::to_partial_key(self.order, view);
197
198        let map = self.probability_map.get_key_value(&key);
199        match map {
200            Some(v) => v.1.get_probability(result),
201            None => 0.0,
202        }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    // An untrained chain can never generate anything, regardless of order.
    #[test]
    fn empty() {
        let m0 = MarkovChain::new(0, &[]);
        assert_eq!(m0.generate(&[]), None);
        assert_eq!(m0.generate(&[1]), None);
        assert_eq!(m0.generate_deterministic(&[], 33), None);

        let m1 = MarkovChain::new(1, &[]);
        assert_eq!(m1.generate(&[1]), None);
        assert_eq!(m1.generate(&[1, 1]), None);
        assert_eq!(m1.generate_deterministic(&[1], 33), None);

        let m2 = MarkovChain::new(2, &[]);
        assert_eq!(m2.generate(&[1, 1]), None);
        assert_eq!(m2.generate(&[1, 1, 1]), None);
        assert_eq!(m2.generate_deterministic(&[1, 1], 33), None);
    }

    // Train on consecutive letter pairs; every letter should then
    // deterministically predict its successor.
    #[test]
    fn alphabet_first_order() {
        let mut m = MarkovChain::new(1, &[]);

        // this could have just been a number range, but it serves as an
        // example of how to encode an alphabet
        let alpha: Vec<char> = "abcdefghijklmnopqrstuvwxyz".chars().collect();
        let encoded: Vec<u64> = (0..alpha.len()).map(|x| x as u64).collect();

        for i in m.order..encoded.len() {
            m.train(&[encoded[i - 1]], encoded[i], 1);
        }

        for i in 0..(encoded.len() - 1) {
            match m.generate(&[encoded[i]]) {
                Some(v) => assert_eq!(v, encoded[i + 1]),
                None => panic!(
                    "can't predict next letter after {} (encoded as {})",
                    alpha[i], encoded[i]
                ),
            };
        }
    }

    // Same as above but with a two-letter memory.
    #[test]
    fn alphabet_second_order() {
        let mut m = MarkovChain::new(2, &[]);

        // this could have just been a number range, but it serves as an
        // example of how to encode an alphabet
        let alpha: Vec<char> = "abcdefghijklmnopqrstuvwxyz".chars().collect();
        let encoded: Vec<u64> = (0..alpha.len()).map(|x| x as u64).collect();

        for i in m.order..encoded.len() {
            m.train(&[encoded[i - 2], encoded[i - 1]], encoded[i], 1);
        }

        for i in 1..(encoded.len() - 1) {
            match m.generate(&[encoded[i - 1], encoded[i]]) {
                Some(v) => assert_eq!(v, encoded[i + 1]),
                None => panic!(
                    "can't predict next letter after {} (encoded as {})",
                    alpha[i], encoded[i]
                ),
            };
        }
    }
}