1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use rand::prelude::*;
use std::collections::HashMap;
/// A discrete Markov chain whose states are identified by strings.
///
/// Each source state maps to the list of states it can move to, each paired
/// with an `f32` weight used when sampling the next state.
#[derive(Clone, PartialEq, Debug, Default)]
pub struct MarkovChain {
    /// Outgoing transitions keyed by source state, as `(next_state, weight)` pairs.
    pub transition_prob: HashMap<String, Vec<(String, f32)>>,
}
impl MarkovChain {
    /// Creates an empty chain with no transitions.
    pub fn new() -> MarkovChain {
        MarkovChain {
            transition_prob: HashMap::new(),
        }
    }

    /// Records (or overwrites) the transition from `key` to `probability.0`.
    ///
    /// If `key` already has a transition to the same target state, that entry's
    /// weight is replaced in place, so it keeps its position in the list.
    /// Otherwise the transition is appended. A previously unseen `key` gets a
    /// new list containing just this transition.
    pub fn add_state_choice(&mut self, key: &str, probability: (String, f32)) {
        match self.transition_prob.get_mut(key) {
            Some(choices) => {
                // `position` returns a plain index, so no borrow of `choices`
                // is held when we mutate it below.
                match choices.iter().position(|(state, _)| *state == probability.0) {
                    Some(index) => choices[index] = probability,
                    None => choices.push(probability),
                }
            }
            None => {
                self.transition_prob
                    .insert(key.to_string(), vec![probability]);
            }
        }
    }

    /// Walks the chain for `num_of_states` steps, starting from `current_state`.
    ///
    /// Returns the visited states in order, excluding the starting state. When a
    /// state without transitions is reached, [`MarkovChain::next_state`] yields
    /// `""`. The walk then continues from `""`, which normally produces further
    /// `""` entries.
    pub fn generate_states(&self, mut current_state: String, num_of_states: u16) -> Vec<String> {
        let mut future_states = Vec::with_capacity(usize::from(num_of_states));
        for _ in 0..num_of_states {
            current_state = self.next_state(current_state);
            future_states.push(current_state.clone());
        }
        future_states
    }

    /// Randomly picks a successor of `current_state`, weighting each candidate
    /// by its `f32` weight.
    ///
    /// Returns an empty string when `current_state` has no recorded transitions.
    /// It also returns an empty string when its weights cannot be sampled, for
    /// example an empty list, all-zero weights, or a negative or NaN weight.
    /// In those cases it does not panic.
    pub fn next_state(&self, current_state: String) -> String {
        self.transition_prob
            .get(&current_state)
            .and_then(|choices| {
                choices
                    .choose_weighted(&mut thread_rng(), |(_, weight)| *weight)
                    .ok()
            })
            .map(|(state, _)| state.clone())
            .unwrap_or_default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let m = MarkovChain::new();
        let expected = MarkovChain {
            transition_prob: HashMap::new(),
        };
        // The struct derives PartialEq + Debug, so compare it as a whole.
        assert_eq!(m, expected);
    }

    #[test]
    fn test_add_state_choice() {
        let mut m = MarkovChain::new();
        m.add_state_choice("a", ("c".to_string(), 0.8));
        m.add_state_choice("a", ("b".to_string(), 0.19));
        m.add_state_choice("a", ("a".to_string(), 0.01));
        let expected = vec![
            ("c".to_string(), 0.8),
            ("b".to_string(), 0.19),
            ("a".to_string(), 0.01),
        ];
        assert_eq!(m.transition_prob.get("a"), Some(&expected));
    }

    #[test]
    fn test_add_state_choice_replaces_existing_target() {
        let mut m = MarkovChain::new();
        m.add_state_choice("a", ("c".to_string(), 0.5));
        m.add_state_choice("a", ("b".to_string(), 0.5));
        // Re-adding an existing target overwrites its weight in place.
        m.add_state_choice("a", ("c".to_string(), 0.9));
        let expected = vec![("c".to_string(), 0.9), ("b".to_string(), 0.5)];
        assert_eq!(m.transition_prob.get("a"), Some(&expected));
    }

    #[test]
    fn test_next_state() {
        let mut m = MarkovChain::new();
        m.add_state_choice("a", ("b".to_string(), 1.0));
        m.add_state_choice("a", ("c".to_string(), 0.0));
        assert_eq!(m.next_state("a".to_string()), "b");
        assert!(m.next_state("a".to_string()) != "c");
        assert_eq!(m.next_state("b".to_string()), "");
    }

    #[test]
    fn test_generate_states() {
        let mut m = MarkovChain::new();
        m.add_state_choice("a", ("b".to_string(), 1.0));
        m.add_state_choice("b", ("c".to_string(), 1.0));
        m.add_state_choice("c", ("a".to_string(), 1.0));
        assert_eq!(
            m.generate_states("a".to_string(), 6),
            vec!["b", "c", "a", "b", "c", "a"]
        );
        assert!(m.next_state("a".to_string()) != "c");
    }

    #[test]
    fn test_generate_states_dead_end() {
        let mut m = MarkovChain::new();
        m.add_state_choice("a", ("b".to_string(), 1.0));
        // "b" has no outgoing transitions, so the walk yields "" from then on.
        assert_eq!(m.generate_states("a".to_string(), 3), vec!["b", "", ""]);
        assert!(m.generate_states("a".to_string(), 0).is_empty());
    }
}