lmarkov/
lib.rs

1#[cfg(test)]
2mod tests;
3
4use std::collections::HashMap;
5
6#[cfg(feature = "serialization")]
7use std::fmt;
8
9use rand::seq::SliceRandom;
10
11#[cfg(feature = "serialization")]
12use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
13
/// Separates the words of a `ChainKey` when it is encoded as a single string.
#[cfg(feature = "serialization")]
const KEY_SEPARATOR: &str = " ";

/// Stands in for an empty (`None`) slot when a `ChainKey` is encoded as a
/// single string.
#[cfg(feature = "serialization")]
const KEY_NO_WORD: &str = "\n";
18
/// A sequence of words, used as the key in a `Chain`'s map.
///
/// Each slot is either `Some(word)` or `None`; `None` marks a position
/// before the start (or at the end) of a training input.
#[derive(Debug, Hash, Clone, PartialEq, Eq)]
pub struct ChainKey(Vec<Option<String>>);

impl ChainKey {
    /// Creates a key of `order` empty (`None`) slots, i.e. the key that
    /// matches the very start of a training input.
    pub fn blank(order: usize) -> Self {
        ChainKey(vec![None; order])
    }

    /// Wraps an existing sequence of words as a key.
    pub fn from_vec(vec: Vec<Option<String>>) -> Self {
        ChainKey(vec)
    }

    /// Consumes the key and returns the underlying sequence of words.
    pub fn to_vec(self) -> Vec<Option<String>> {
        self.0
    }

    /// Slides the key forward by one word: drops the oldest word and appends
    /// `next_word`, so the key keeps its length.
    ///
    /// A zero-length key (an order-0 chain) has no history to slide, so it is
    /// left empty instead of panicking on an out-of-range slice.
    pub fn advance(&mut self, next_word: &Option<String>) {
        if self.0.is_empty() {
            return;
        }

        // Shift in place rather than re-allocating the whole key each step.
        self.0.remove(0);
        self.0.push(next_word.clone());
    }

    /// Encodes the key as a single string.
    ///
    /// Words are joined with `KEY_SEPARATOR`, and `None` slots are written as
    /// `KEY_NO_WORD`. The encoding is only reversible for non-empty words that
    /// contain neither of those strings. Words produced by `Chain::train`
    /// satisfy this, because it splits its input on whitespace.
    #[cfg(feature = "serialization")]
    fn to_string(&self) -> String {
        self.0
            .iter()
            .map(|word| word.as_deref().unwrap_or(KEY_NO_WORD))
            .collect::<Vec<&str>>()
            .join(KEY_SEPARATOR)
    }

    /// Decodes a key produced by `to_string`.
    ///
    /// An empty string decodes to an empty key (an order-0 chain). Splitting
    /// it naively would instead yield one empty word and break the round trip.
    ///
    /// TODO: Check input for correctness.
    #[cfg(feature = "serialization")]
    fn from_str(string: &str) -> Self {
        if string.is_empty() {
            return ChainKey(Vec::new());
        }

        ChainKey(
            string
                .split(KEY_SEPARATOR)
                .map(|word| {
                    if word == KEY_NO_WORD {
                        None
                    } else {
                        Some(word.to_string())
                    }
                })
                .collect(),
        )
    }
}
80
#[cfg(feature = "serialization")]
impl Serialize for ChainKey {
    /// Writes the key as one string (see `ChainKey::to_string`), so that it
    /// can be used as a map key in formats that require string keys, such
    /// as JSON.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let encoded = self.to_string();
        serializer.serialize_str(&encoded)
    }
}
90
/// Serde visitor that rebuilds a `ChainKey` from its string encoding.
#[cfg(feature = "serialization")]
struct ChainKeyVisitor;

#[cfg(feature = "serialization")]
impl<'de> Visitor<'de> for ChainKeyVisitor {
    type Value = ChainKey;

    /// Describes the accepted input for serde's error messages.
    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a string")
    }

    /// Decodes the string with `ChainKey::from_str`; every string is accepted.
    fn visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E> {
        let key = ChainKey::from_str(value);
        Ok(key)
    }
}
109
#[cfg(feature = "serialization")]
impl<'de> Deserialize<'de> for ChainKey {
    /// Reads the key back from the string form produced by its `Serialize`
    /// implementation.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = ChainKeyVisitor;
        deserializer.deserialize_str(visitor)
    }
}
119
120/// A Markov chain.
121#[derive(Clone, Debug)]
122#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
123pub struct Chain {
124    /// A map from `order` words to the possible following words.
125    map: HashMap<ChainKey, Vec<Option<String>>>,
126    order: usize,
127}
128
129impl Chain {
130    pub fn new(order: usize) -> Self {
131        Chain {
132            map: HashMap::new(),
133            order,
134        }
135    }
136
137    pub fn train(&mut self, string: &str) {
138        // Create a Vec that starts with `self.order` `None`s, then all the
139        // words in the string wrapped in `Some()`, then a single `None`.
140        let mut words = vec![None; self.order];
141
142        for word in string.split_whitespace() {
143            words.push(Some(word.to_string()));
144        }
145
146        words.push(None);
147
148        // Now slide a window over `words` to produce slices where the last
149        // element is the resulting word and the rest is the key to that word.
150        for window in words.windows(self.order + 1) {
151            let key = &window[..self.order];
152            let word = &window[self.order];
153
154            let map_entry = self
155                .map
156                .entry(ChainKey::from_vec(key.to_vec()))
157                .or_insert(Vec::new());
158            map_entry.push(word.clone());
159        }
160    }
161
162    // Generate a string.
163    pub fn generate(&self) -> Option<String> {
164        // Start with a key of all `None` to match starting from the start of
165        // one of the training inputs.
166        let seed = ChainKey::blank(self.order);
167
168        self.generate_from_seed(&seed)
169    }
170
171    /// Generate a string based on some seed words.
172    /// Returns `None` if there is no way to start a generated string with
173    /// that seed, eg. it is longer than `self.order`.
174    pub fn generate_from_seed(&self, seed: &ChainKey) -> Option<String> {
175        if !self.map.contains_key(seed) {
176            return None;
177        }
178
179        let mut rng = rand::thread_rng();
180        let mut result: Vec<String> = Vec::new();
181
182        let mut cursor = seed.clone();
183
184        loop {
185            let possible_words = &self.map[&cursor];
186
187            // Any entry in the map is guaranteed to have at least one word in
188            // it, so this unwrap is okay.
189            let next_word = possible_words.choose(&mut rng).unwrap();
190
191            if let Some(next_word) = next_word {
192                result.push(next_word.clone());
193            } else {
194                // Terminator word.
195                break;
196            }
197
198            // Advance the cursor along by popping the front and appending the
199            // new word on the end.
200            cursor.advance(next_word);
201        }
202
203        Some(result.join(" "))
204    }
205
206    /// Serialize this chain to JSON.
207    #[cfg(feature = "serialization")]
208    pub fn to_json(&self) -> serde_json::Result<String> {
209        serde_json::to_string(self)
210    }
211
212    /// Load a chain from JSON.
213    #[cfg(feature = "serialization")]
214    pub fn from_json(json: &str) -> serde_json::Result<Self> {
215        serde_json::from_str(json)
216    }
217}