1mod die;
2use cfg_if::cfg_if;
3use std::collections::HashMap;
4use std::convert::TryFrom;
5
/// Marker for values a `MarkovChain` can hold.
///
/// Elements end up inside `HashMap` keys, so they must be comparable (`Eq`)
/// and hashable (`Hash`). They are also copied freely (`Copy`). `Eq` already
/// implies `PartialEq` and `Copy` implies `Clone`, so the accepted set of types
/// is exactly those four traits.
pub trait Element: Eq + Copy + std::hash::Hash {}

/// Every type satisfying the bounds is automatically an [`Element`].
impl<E> Element for E where E: Eq + Copy + std::hash::Hash {}
8
9pub struct MarkovChain<T: Element> {
11 order: usize,
16 probability_map: HashMap<Vec<Option<T>>, die::WeightedDie<T>>,
20 optional_elements: Vec<usize>,
21}
22
23impl<T: Element> MarkovChain<T> {
24 pub fn new(order: usize, optional_elements: &[usize]) -> Self {
43 let opts: Vec<usize> = optional_elements
45 .clone()
46 .into_iter()
47 .map(|i| *i)
48 .filter(|i| *i < order)
49 .collect();
50 MarkovChain {
51 order,
52 probability_map: HashMap::<Vec<Option<T>>, die::WeightedDie<T>>::new(),
53 optional_elements: opts,
54 }
55 }
56
57 fn to_partial_key(order: usize, view: &[Option<T>]) -> Vec<Option<T>> {
59 view.into_iter()
60 .skip(view.len() - order)
61 .take(order)
62 .cloned()
63 .collect()
64 }
65
66 fn to_full_key(order: usize, view: &[T]) -> Vec<Option<T>> {
68 view.into_iter()
69 .skip(view.len() - order)
70 .take(order)
71 .cloned()
72 .map(|e| Some(e))
73 .collect()
74 }
75
76 fn permute(
77 key: Vec<Option<T>>,
78 optionals: Vec<usize>,
79 mut perms: Vec<Vec<Option<T>>>,
80 ) -> Vec<Vec<Option<T>>> {
81 if optionals.len() == 0 {
82 perms.push(key);
83 perms
84 } else {
85 let mut off = key.clone();
86 off[optionals[0]] = None;
87 let on = key.clone();
88
89 perms = Self::permute(off, optionals[1..].to_vec(), perms);
90 perms = Self::permute(on, optionals[1..].to_vec(), perms);
91 perms
92 }
93 }
94
95 fn permute_key(&mut self, key: Vec<T>) -> Vec<Vec<Option<T>>> {
98 let optioned_key: Vec<Option<T>> = key.into_iter().map(|e| Some(e)).collect();
99 Self::permute(optioned_key, self.optional_elements.clone(), vec![])
100 }
101
102 pub fn train(&mut self, view: &[T], result: T, weight_delta: i32) {
114 for partial_key in self.permute_key(view.clone().to_vec()) {
115 self.probability_map
117 .entry(partial_key)
118 .and_modify(|d| {
119 d.modify(result, weight_delta);
120 })
121 .or_insert((|| {
122 let mut d = die::WeightedDie::new();
123 d.modify(result, weight_delta);
124 d
125 })());
126 }
127 }
128
129 pub fn generate_deterministic_from_partial(
136 &self,
137 view: &[Option<T>],
138 rand_val: u64,
139 ) -> Option<T> {
140 let key = MarkovChain::to_partial_key(self.order, view);
141
142 match self.probability_map.get(&key) {
143 Some(v) => v.roll(Some(rand_val)),
144 None => None,
145 }
146 }
147
148 pub fn generate_deterministic(&self, view: &[T], rand_val: u64) -> Option<T> {
155 let key = MarkovChain::to_full_key(self.order, view);
156
157 match self.probability_map.get(&key) {
158 Some(v) => v.roll(Some(rand_val)),
159 None => None,
160 }
161 }
162
163 cfg_if! {
164 if #[cfg(feature = "rand")] {
165 pub fn generate(&self, view: &[T]) -> Option<T> {
170 let key = MarkovChain::to_full_key(self.order, view);
171
172 match self.probability_map.get(&key) {
173 Some(v) => v.roll(None),
174 None => None,
175 }
176 }
177
178 pub fn generate_from_partial(&self, view: &[Option<T>]) -> Option<T> {
183 let key = MarkovChain::to_partial_key(self.order, view);
184
185 match self.probability_map.get(&key) {
186 Some(v) => v.roll(None),
187 None => None,
188 }
189 }
190 }
191 }
192
193 pub fn probability(&self, view: &[Option<T>], result: T) -> f32 {
196 let key = MarkovChain::to_partial_key(self.order, view);
197
198 let map = self.probability_map.get_key_value(&key);
199 match map {
200 Some(v) => v.1.get_probability(result),
201 None => 0.0,
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic lookups on an untrained chain find nothing, whatever the
    /// order. This test needs no optional features.
    #[test]
    fn empty() {
        let m0 = MarkovChain::<i32>::new(0, &[]);
        assert_eq!(m0.generate_deterministic(&[], 33), None);

        let m1 = MarkovChain::<i32>::new(1, &[]);
        assert_eq!(m1.generate_deterministic(&[1], 33), None);

        let m2 = MarkovChain::<i32>::new(2, &[]);
        assert_eq!(m2.generate_deterministic(&[1, 1], 33), None);
    }

    // `generate` only exists with the `rand` feature, so every test that
    // calls it is gated too; otherwise `cargo test` fails to compile without
    // that feature.
    #[cfg(feature = "rand")]
    #[test]
    fn empty_random() {
        let m0 = MarkovChain::<i32>::new(0, &[]);
        assert_eq!(m0.generate(&[]), None);
        assert_eq!(m0.generate(&[1]), None);

        let m1 = MarkovChain::<i32>::new(1, &[]);
        assert_eq!(m1.generate(&[1]), None);
        assert_eq!(m1.generate(&[1, 1]), None);

        let m2 = MarkovChain::<i32>::new(2, &[]);
        assert_eq!(m2.generate(&[1, 1]), None);
        assert_eq!(m2.generate(&[1, 1, 1]), None);
    }

    /// A sample is reachable through its full key and through wildcards on
    /// optional positions only.
    #[test]
    fn optional_positions_match_wildcards() {
        let mut m = MarkovChain::new(2, &[0]);
        m.train(&[1u64, 2], 3, 1);

        assert!(m.probability(&[Some(1), Some(2)], 3) > 0.0);
        // Position 0 is optional, so the sample is also stored as `[None, Some(2)]`.
        assert!(m.probability(&[None, Some(2)], 3) > 0.0);
        // Position 1 is not optional, so no key blanks it.
        assert_eq!(m.probability(&[Some(1), None], 3), 0.0);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn alphabet_first_order() {
        let mut m = MarkovChain::new(1, &[]);

        let alpha: Vec<char> = "abcdefghijklmnopqrstuvwxyz".chars().collect();
        let encoded: Vec<u64> = (0..alpha.len()).map(|x| x as u64).collect();

        for i in m.order..encoded.len() {
            m.train(&[encoded[i - 1]], encoded[i], 1);
        }

        for i in 0..(encoded.len() - 1) {
            match m.generate(&[encoded[i]]) {
                Some(v) => assert_eq!(v, encoded[i + 1]),
                None => panic!(
                    "can't predict next letter after {} (encoded as {})",
                    alpha[i], encoded[i]
                ),
            };
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn alphabet_second_order() {
        let mut m = MarkovChain::new(2, &[]);

        let alpha: Vec<char> = "abcdefghijklmnopqrstuvwxyz".chars().collect();
        let encoded: Vec<u64> = (0..alpha.len()).map(|x| x as u64).collect();

        for i in m.order..encoded.len() {
            m.train(&[encoded[i - 2], encoded[i - 1]], encoded[i], 1);
        }

        for i in 1..(encoded.len() - 1) {
            match m.generate(&[encoded[i - 1], encoded[i]]) {
                Some(v) => assert_eq!(v, encoded[i + 1]),
                // Report the whole two-letter context, not just its last letter.
                None => panic!(
                    "can't predict next letter after {}{} (encoded as {}, {})",
                    alpha[i - 1],
                    alpha[i],
                    encoded[i - 1],
                    encoded[i]
                ),
            };
        }
    }
}