1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
use crate::MultiMarkov;
use log::{debug, info};
use rand::rngs::SmallRng;
use rand::{thread_rng, Rng, RngCore, SeedableRng};
use std::cmp::max;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::Hash;
/// Builder for a [`MultiMarkov`] model: ingests training sequences, accumulates
/// the observed state transitions, then hands everything off via `build()`.
pub struct MultiMarkovBuilder<T>
where
T: Eq + Hash + Clone + std::cmp::Ord,
{
/// Observed transitions: maps a preceding window of states (length 1..=order)
/// to the states seen following it, each with its accumulated weight
/// (observation count during training; may later include prior weights).
pub markov_chain: HashMap<Vec<T>, BTreeMap<T, f64>>,
/// Every state encountered anywhere in the training data.
pub known_states: HashSet<T>,
// Markov order: the maximum number of preceding states a window may contain.
order: i32,
// Prior weight for unobserved transitions; `None` disables priors entirely.
prior: Option<f64>,
// RNG handed to the built `MultiMarkov` for random generation.
rng: Box<dyn RngCore + Send + Sync>,
}
impl<T> MultiMarkovBuilder<T>
where
T: Eq + Hash + Clone + std::cmp::Ord,
{
/// Instantiate a new builder.
pub fn new() -> Self {
Self {
markov_chain: HashMap::new(),
known_states: HashSet::new(),
order: MultiMarkov::<T>::DEFAULT_ORDER,
prior: Some(MultiMarkov::<T>::DEFAULT_PRIOR),
rng: Box::new(Box::new(SmallRng::seed_from_u64(thread_rng().gen()))),
}
}
/// Specify the "order" of the Markov model. Must be a positive integer.
/// We recommend small values from about 1 to 3. Higher values will make the procedurally
/// generated data more similar to the training data, less random, and will make the process
/// slower and require more memory.
///
/// The default is MultiMarkov::DEFAULT_ORDER
pub fn with_order(mut self, order: i32) -> Self {
assert!(order > 0, "Order must be an integer greater than zero.");
self.order = order;
self
}
/// Specifies the "prior probability" of transition from any known state to any other known state,
/// if that transition was not observed in the training data. Small fractions are recommended,
/// so that this "true randomness" will be less common than transitions based on the training data.
///
/// The default is MultiMarkov::DEFAULT_PRIOR
pub fn with_prior(mut self, prior: f64) -> Self {
if prior == 0.0 {
self.prior = None;
} else {
self.prior = Some(prior);
}
self
}
/// Specifies that there will be no use of "prior probability" in this model. The only state
/// transitions possible will be those seen in the training data.
pub fn without_prior(mut self) -> Self {
self.prior = None;
self
}
/// Sets a custom Random Number Generator (RNG) for the model.
pub fn with_rng(mut self, rng: Box<dyn RngCore + Send + Sync>) -> Self {
self.rng = rng;
self
}
/// Ingest an iterator of sequences, adding the observed state transitions to the internal
/// statistical model.
pub fn train(mut self, sequences: impl Iterator<Item = Vec<T>>) -> Self {
let mut success_count: usize = 0;
let mut error_count: usize = 0;
for sequence in sequences {
match self.train_sequence(sequence) {
Ok(()) => success_count += 1,
Err(_) => error_count += 1,
};
}
debug!(
"{} sequences successfully trained; {} errors.",
success_count, error_count
);
self
}
/// Learn all the transitions possible from one training sequence, adding observations to the Markov model.
fn train_sequence(&mut self, sequence: Vec<T>) -> Result<(), &str> {
if sequence.len() < 2 {
return Err("Sequence was too short, must contain at least two states.");
}
// loop backwards through the characters in the sequence
for i in (1..sequence.len()).rev() {
// Build a running set of all known characters while we're at it
self.known_states.insert(sequence[i].clone());
// For the sequences preceding character (i), record that character (i) was observed following them.
// IE if the char_vec is ['R','U','S','T'] and this is a 3rd-order model, then for the three models ['S'], ['U','S'], and ['R','U','S'] we record that ['T'] is a known follower.
for j in (max(0, i as i32 - self.order) as usize)..i {
if let Some(transitions_from) = self.markov_chain.get_mut(&sequence[j..i]) {
// "from" sequence has been seen before
if let Some(weight) = transitions_from.get_mut(&sequence[i]) {
// it has been seen before with this transition; add one observance
*weight += 1.0;
} else {
// it hasn't been seen before with this transition; insert transition with one observance
transitions_from.insert(sequence[i].clone(), 1.0);
}
} else {
// "from" sequence hasn't been seen before; add it and add the observed transition
let mut observed_transition = BTreeMap::new();
observed_transition.insert(sequence[i].clone(), 1.0);
self.markov_chain
.insert(Vec::from(&sequence[j..i]), observed_transition);
}
// The following one-liner might accomplish all of the above, but is pretty hard on the eyes:
// *self.markov_chain.entry(Vec::from(&sequence[j..i])).or_insert(HashMap::new()).entry(sequence[i].clone()).or_insert(0.0) += 1.0;
}
}
Ok(())
}
/// Adds prior probabilities (if any) and builds the MultiMarkov object.
pub fn build(mut self) -> MultiMarkov<T> {
self.add_priors();
MultiMarkov {
markov_chain: self.markov_chain,
known_states: self.known_states,
order: self.order,
rng: self.rng,
}
}
/// Fills in missing state transitions with a given value so that any known state (except
/// those only seen at the end of sequences) can transition to any other known state.
/// Should be called after training is complete, because only then do we know the full set of
/// known states, and which transitions are unobserved.
fn add_priors(&mut self) {
let mut num_priors_added: usize = 0;
match self.prior {
Some(p) => {
for v in self.markov_chain.values_mut() {
for a in self.known_states.iter() {
v.entry(a.clone()).or_insert_with(|| {
num_priors_added += 1;
p
});
}
}
info!(
"Model has {} known states and {} trained sequences. {} priors added.",
self.markov_chain.len(),
self.known_states.len(),
num_priors_added
);
}
None => (),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Character training data; the first sequence is deliberately too short
    /// and should be skipped rather than causing an error to propagate.
    fn char_data() -> Vec<Vec<char>> {
        vec![
            vec!['a'], // can't be used, but should be skipped over rather than causing error to propagate
            vec!['a', 'c', 'e'],
            vec!['f', 'o', 'o', 'b', 'a', 'r'],
            vec!['b', 'a', 'z'],
        ]
    }

    /// Same shape as `char_data`, but with owned `String` states.
    fn string_data() -> Vec<Vec<String>> {
        vec![
            vec!["a"], // can't be used, but should be skipped over rather than causing error to propagate
            vec!["a", "c", "e"],
            vec!["f", "o", "o", "b", "a", "r"],
            vec!["b", "a", "z"],
        ]
        .into_iter()
        .map(|seq| seq.into_iter().map(String::from).collect())
        .collect()
    }

    #[test]
    fn test_can_train_char_sequences() {
        let _mm = MultiMarkov::<char>::builder()
            .with_order(2)
            .train(char_data().into_iter());
    }

    #[test]
    fn test_can_train_string_sequences() {
        let _mm = MultiMarkov::<String>::builder()
            .with_order(2)
            .train(string_data().into_iter());
    }

    #[test]
    fn sequences_in_training_show_up_in_model() {
        let mm = MultiMarkov::<char>::builder()
            .with_order(2)
            .train(char_data().into_iter());
        let transitions = |from: &[char]| mm.markov_chain.get(from);
        // 'e' comes after 'c' (end of 2nd sequence trained properly)
        assert!(transitions(&['c']).unwrap().contains_key(&'e'));
        // 'a' -> 'c' (beginning of 2nd sequence trained properly)
        assert!(transitions(&['a']).unwrap().contains_key(&'c'));
        // a second-order sequence: ['a','c'] -> 'e'
        assert!(transitions(&['a', 'c']).unwrap().contains_key(&'e'));
        // 'b' -> 'a' observed twice
        assert_eq!(transitions(&['b']).unwrap().get(&'a'), Some(&2.0));
        // 'z' is in the alphabet of known states, but has no transitions because it was only seen at the end of a sequence
        assert!(mm.known_states.contains(&'z'));
        assert!(transitions(&['z']).is_none());
        // we haven't added priors yet, so there should be no transition from 'a' -> 'b' available
        assert!(!transitions(&['a']).unwrap().contains_key(&'b'));
    }

    #[test]
    fn can_set_priors_and_they_work() {
        let mm = MultiMarkov::<char>::builder()
            .with_order(2)
            .train(char_data().into_iter())
            .with_prior(0.015)
            .build();
        // prior should be set for a non-observed transition such as 'a' -> 'b'
        let from_a = mm.markov_chain.get(&['a'][..]).unwrap();
        assert!(from_a.contains_key(&'b'));
        assert_eq!(from_a.get(&'b'), Some(&0.015));
    }

    #[test]
    fn make_sure_it_works_with_strings_too() {
        let mm = MultiMarkov::<String>::builder()
            .with_order(2)
            .train(string_data().into_iter())
            .with_prior(0.011)
            .build();
        // prior should be set for a non-observed transition such as "a" -> "b"
        let key = vec![String::from("a")];
        let from_a = mm.markov_chain.get(&key).unwrap();
        assert!(from_a.contains_key("b"));
        assert_eq!(from_a.get("b"), Some(&0.011));
    }

    #[test]
    fn can_specify_no_priors_and_build() {
        let mm = MultiMarkov::<char>::builder()
            .with_order(2)
            .train(char_data().into_iter())
            .without_prior()
            .build();
        // a non-observed transition such as 'a' -> 'b' should have no entry in the model
        assert!(!mm.markov_chain.get(&['a'][..]).unwrap().contains_key(&'b'));
    }

    #[test]
    #[should_panic(expected = "Order must be an integer greater than zero.")]
    fn order_cannot_be_zero_or_negative() {
        let _mm = MultiMarkov::<char>::builder()
            .with_order(0)
            .train(char_data().into_iter());
    }

    #[test]
    fn test_that_seeded_rngs_give_the_same_output_every_time() {
        use rand::{rngs::SmallRng, SeedableRng};
        // Two models built from the same seed must stay in lockstep, draw after draw.
        let build_seeded = || {
            MultiMarkov::<char>::builder()
                .with_rng(Box::new(SmallRng::seed_from_u64(1234)))
                .train(char_data().into_iter())
                .without_prior()
                .build()
        };
        let mut first = build_seeded();
        let mut second = build_seeded();
        for _ in 0..10 {
            assert_eq!(first.random_next(&vec!['a']), second.random_next(&vec!['a']));
        }
    }
}