wordmarkov/chain/
body.rs

1/*!
2 * Actual Markov chain container.
3 */
4
5use super::selectors::interface::MarkovSelector;
6use super::selectors::interface::SelectionType;
7use super::token::*;
8use crate::sentence::lex::{Lexer, Token as LexedToken};
9use rand::{distributions::Uniform, prelude::*};
10use std::collections::HashMap;
11use std::collections::LinkedList;
12use std::rc::Rc;
13
/// The direction in which to traverse the Markov chain.
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub enum MarkovTraverseDir {
    /// Follow edges from their source word towards their destination word.
    Forward,
    /// Follow edges from their destination word back towards their source word.
    Reverse,
}
20
/// The starting point for a traversal of the Markov chain.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MarkovSeed<'a> {
    /// Start from the textlet with this exact text; it must already be
    /// registered in the chain.
    Word(&'a str),
    /// Start from the textlet with this internal index.
    Id(usize),
    /// Start from a word picked uniformly at random among the chain's words.
    Random,
}
27
/// An edge linking two words in the Markov chain.
///
/// All fields are indices into the owning [MarkovChain]'s textlet storage,
/// except for `hits`, which is an occurrence counter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Edge {
    /// The word this edge comes from.
    pub src_idx: usize,

    /// The word this edge leads into.
    pub dst_idx: usize,

    /// How many times this edge has been found.
    pub hits: usize,

    /// The punctuation in this edge.
    pub pct_idx: usize,
}
42
43impl Edge {
44    /// Get the word from which this Edge sprouts.
45    pub fn get_source<'a>(&self, chain: &'a MarkovChain) -> MarkovToken<'a> {
46        chain.get_textlet(self.src_idx).unwrap()
47    }
48
49    /// Get the word into which this Edge leads.
50    pub fn get_dest<'a>(&self, chain: &'a MarkovChain) -> MarkovToken<'a> {
51        chain.get_textlet(self.dst_idx).unwrap()
52    }
53
54    /// Get the punctuation between the words this Edge connects.
55    pub fn get_punct<'a>(&self, chain: &'a MarkovChain) -> MarkovToken<'a> {
56        chain.get_textlet(self.pct_idx).unwrap()
57    }
58}
59
/**
 * A graph that links tokens together.
 *
 * Textlets (words and punctuation) are stored once and referred to by
 * index; edges are likewise stored once in a flat list and referred to by
 * index from the forward and reverse adjacency maps.
 */
pub struct MarkovChain {
    /// Every textlet known to the chain. A new chain starts with the
    /// Begin marker at index 0 and the End marker at index 1.
    textlet_bag: Vec<MarkovTokenOwned>,
    /// Maps a textlet's text to its index in `textlet_bag`.
    textlet_indices: HashMap<Rc<str>, usize>,
    /// Textlet indices that have been used as an endpoint of an edge.
    words: Vec<usize>,

    /// Every edge registered in the chain.
    edge_list: Vec<Edge>,
    /// Maps a source textlet index to indices into `edge_list`.
    edges: HashMap<usize, Vec<usize>>,
    /// Maps a destination textlet index to indices into `edge_list`.
    reverse_edges: HashMap<usize, Vec<usize>>,
}
72
73impl Default for MarkovChain {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl MarkovChain {
80    /**
81     * Makes a new empty [MarkovChain].
82     */
83    pub fn new() -> MarkovChain {
84        MarkovChain {
85            textlet_bag: vec![MarkovTokenOwned::Begin, MarkovTokenOwned::End],
86            textlet_indices: HashMap::new(),
87            words: Vec::new(),
88
89            edge_list: Vec::new(),
90            edges: HashMap::new(),
91            reverse_edges: HashMap::new(),
92        }
93    }
94
95    /**
96     * Gets the index of a textlet in this chain; if the textlet is not found,
97     * makes a new one and returns that instead.
98     */
99    pub fn ensure_textlet_index(&mut self, word: &str) -> usize {
100        match self.textlet_indices.get(word) {
101            Some(a) => *a,
102            None => {
103                let i = self.textlet_bag.len();
104                let rcword: Rc<str> = Rc::from(word);
105
106                self.textlet_bag
107                    .push(MarkovTokenOwned::Textlet(rcword.clone()));
108
109                self.textlet_indices.insert(rcword, i);
110
111                i
112            }
113        }
114    }
115
116    /**
117     * Get a textlet index from a [crate::sentence::lex::Token].
118     *
119     * If one does not exist, make one and return that instead.
120     */
121    pub fn ensure_textlet_from_token(&mut self, token: LexedToken) -> usize {
122        match token {
123            LexedToken::Begin => 0,
124            LexedToken::End => 1,
125            LexedToken::Punct(word) => self.ensure_textlet_index(word),
126            LexedToken::Word(word) => self.ensure_textlet_index(word),
127        }
128    }
129
130    /**
131     * Tries to get the index of a textlet in this chain.
132     *
133     * If the textlet is not registered, returns None.
134     */
135    pub fn try_get_textlet_index(&self, word: &str) -> Option<usize> {
136        self.textlet_indices.get(word).copied()
137    }
138
139    /**
140     * Gets the [MarkovToken] of a textlet by its index.
141     */
142    pub fn get_textlet(&self, index: usize) -> Option<MarkovToken<'_>> {
143        self.textlet_bag.get(index).map(MarkovToken::from)
144    }
145
146    fn push_new_edge(
147        &mut self,
148        from: usize,
149        to: usize,
150        punct: usize,
151        hits: Option<usize>,
152    ) -> usize {
153        let edge = Edge {
154            src_idx: from,
155            dst_idx: to,
156            hits: hits.unwrap_or(1),
157            pct_idx: punct,
158        };
159
160        let idx = self.edge_list.len();
161        self.edge_list.push(edge);
162
163        idx
164    }
165
166    fn add_reverse_edge(&mut self, edge_idx: usize) {
167        let edge = &self.edge_list[edge_idx];
168
169        match self.reverse_edges.get_mut(&edge.dst_idx) {
170            None => {
171                let rev_vec = vec![edge_idx];
172
173                self.reverse_edges.insert(edge.dst_idx, rev_vec);
174            }
175
176            Some(rev_vec) => {
177                for oedge in rev_vec.iter() {
178                    let oedge = self.edge_list.get(*oedge).unwrap();
179
180                    if edge.src_idx == oedge.src_idx && edge.pct_idx == oedge.pct_idx {
181                        return;
182                    }
183                }
184
185                rev_vec.push(edge_idx);
186            }
187        }
188    }
189
190    /**
191     * Register a new edge between two word tokens in this chain.
192     *
193     * `from` and `to` must be existing textlet indices. Same with
194     * `punct` – it must be an existing index, and not a space.
195     *
196     * For both `from` and `to`, if the index is not found in the
197     * `self.words` list, it will be added to it.
198     */
199    pub fn register_edge(&mut self, from: usize, to: usize, punct: usize) {
200        for item in [from, to] {
201            if !self.words.contains(&item) {
202                self.words.push(item);
203            }
204        }
205
206        if let Some(edgevec) = self.edges.get_mut(&from) {
207            for edge in edgevec.iter() {
208                let edge: &mut Edge = self.edge_list.get_mut(*edge).unwrap();
209
210                if edge.dst_idx == to && edge.pct_idx == punct {
211                    edge.hits += 1;
212                    return;
213                }
214            }
215        }
216
217        let idx = self.push_new_edge(from, to, punct, None);
218        self.edges.insert(from, vec![idx]);
219
220        if let Some(edgevec) = self.edges.get_mut(&from) {
221            edgevec.push(idx);
222        } else {
223            self.edges.insert(from, vec![idx]);
224        }
225
226        self.add_reverse_edge(idx);
227    }
228
229    fn get_seed<T: Rng>(&self, seed: MarkovSeed, rng: &mut T) -> Result<usize, String> {
230        use MarkovSeed::*;
231
232        match seed {
233            Word(seed) => {
234                let from = self.try_get_textlet_index(seed);
235
236                if from.is_none() {
237                    return Err(format!(
238                        "Seed word {:?} not found in this Markov chain!",
239                        seed
240                    ));
241                }
242
243                Ok(from.unwrap())
244            }
245
246            Id(seed) => Ok(seed),
247
248            Random => {
249                let from: usize = Uniform::new(0, self.words.len()).sample(rng);
250                Ok(self.words[from])
251            }
252        }
253    }
254
255    fn _weighted_select<R>(
256        &self,
257        sel_type: SelectionType,
258        edges: &[usize],
259        weights: &[f32],
260        rng: &mut R,
261    ) -> &Edge
262    where
263        R: Rng,
264    {
265        match sel_type {
266            SelectionType::Lowest => {
267                edges
268                    .iter()
269                    .map(|e| &self.edge_list[*e])
270                    .zip(weights.iter())
271                    .reduce(|ewc, ewn| if ewc.1 < ewn.1 { ewc } else { ewn })
272                    .unwrap()
273                    .0
274            }
275
276            SelectionType::Highest => {
277                edges
278                    .iter()
279                    .map(|e| &self.edge_list[*e])
280                    .zip(weights.iter())
281                    .reduce(|ewc, ewn| if ewc.1 > ewn.1 { ewc } else { ewn })
282                    .unwrap()
283                    .0
284            }
285
286            SelectionType::WeightedRandom => {
287                let total: f32 = weights.iter().sum();
288                let pick = Uniform::new(0.0_f32, total).sample(rng);
289
290                let mut curr = 0.0;
291                let mut res = None;
292
293                for (edge, weight) in edges
294                    .iter()
295                    .map(|e| &self.edge_list[*e])
296                    .zip(weights.iter())
297                {
298                    curr += weight;
299
300                    if curr >= pick {
301                        res = Some(edge);
302                        break;
303                    }
304                }
305
306                res.unwrap()
307            }
308        }
309    }
310
311    /**
312     * Selects the word following the current one (`from`) based om the
313     * criteria of a [MarkovSelector] (`selector`).
314     *
315     * Returns a tuple (`dest`, `inbetween`, `dest_idx`, `inbetween_idx`).
316     * The first two items can be converted into strings because MarkovToken
317     * has Into<&str>. The last two items are the corresponding internal
318     * indices, which can be reused in functions which take `usize`.
319     *
320     * `inbetween` is all of the whitespace and punctuation lying between
321     * `from` and `dest`. Simply concatenate `from` with `inbetween.into()`
322     * with `dest.into()`, in that order.
323     */
324    pub fn select_next_word(
325        &self,
326        seed: MarkovSeed,
327        selector: &mut dyn MarkovSelector,
328        direction: MarkovTraverseDir,
329    ) -> Result<(MarkovToken<'_>, MarkovToken<'_>, usize, usize), String> {
330        use MarkovTraverseDir::*;
331
332        let mut rng = thread_rng();
333
334        let from: usize = self.get_seed(seed, &mut rng)?;
335
336        let edges = match direction {
337            MarkovTraverseDir::Forward => self.edges.get(&from),
338            MarkovTraverseDir::Reverse => self.reverse_edges.get(&from),
339        };
340
341        if edges.is_none() {
342            return Err(format!(
343                "Seed textlet {:?} is not connected to anything in this Markov chain!",
344                self.get_textlet(from)
345            ));
346        }
347
348        let edges = edges.unwrap();
349
350        if edges.is_empty() {
351            return Err(format!("Seed textlet {:?} is not connected to anything in this Markov chain, but in a weird way!", self.get_textlet(from)));
352        }
353
354        let mut weights: Vec<f32> = vec![0.0; edges.len()];
355
356        selector.reset(direction);
357
358        for (edge, weight) in edges
359            .iter()
360            .map(|e| &self.edge_list[*e])
361            .zip(weights.iter_mut())
362        {
363            *weight = selector.weight(
364                &edge.get_source(self),
365                &edge.get_dest(self),
366                &edge.get_punct(self),
367                edge.hits,
368            );
369        }
370
371        let sel_type = selector.selection_type();
372
373        let best_edge: &Edge = self._weighted_select(sel_type, edges, &weights, &mut rng);
374
375        match direction {
376            Forward => Ok((
377                best_edge.get_dest(self),
378                best_edge.get_punct(self),
379                best_edge.dst_idx,
380                best_edge.pct_idx,
381            )),
382
383            Reverse => Ok((
384                best_edge.get_source(self),
385                best_edge.get_punct(self),
386                best_edge.src_idx,
387                best_edge.pct_idx,
388            )),
389        }
390    }
391
    /**
     * The number of words in this chain, i.e. the number of distinct
     * textlets that have been used as an endpoint of a registered edge.
     *
     * Includes the internal tokens [MarkovTokenOwned::Begin] and
     * [MarkovTokenOwned::End] once they have been used in an edge.
     *
     * Should be smaller than or, in extreme cases, equal to,
     * [Self::num_textlets()].
     */
    pub fn num_words(&self) -> usize {
        self.words.len()
    }
404
    /**
     * The number of textlets stored in this chain.
     *
     * Includes unique instances of whitespace or punctuation, as well as the
     * internal tokens [MarkovTokenOwned::Begin] and [MarkovTokenOwned::End],
     * so a freshly created chain already reports 2.
     */
    pub fn num_textlets(&self) -> usize {
        self.textlet_bag.len()
    }
414
    /**
     * The number of distinct [Edge]s registered in this chain.
     *
     * Registering the same edge again only raises its hit count and does
     * not change this number.
     *
     * Includes edges connected to the internal tokens
     * [MarkovTokenOwned::Begin] and [MarkovTokenOwned::End].
     */
    pub fn num_edges(&self) -> usize {
        self.edge_list.len()
    }
424
425    /**
426     * Parse a sentence, registering textlets and edges
427     * for it.
428     */
429    pub fn parse_sentence(&mut self, sentence: &str) {
430        let mut lexer = Lexer::new(sentence);
431        let mut curr_token = lexer.next();
432
433        let mut to_register: Vec<(LexedToken, LexedToken, LexedToken)> = vec![];
434
435        if sentence.is_empty() {
436            return;
437        }
438
439        loop {
440            if curr_token.is_none() {
441                panic!("Found a none token prematurely!");
442            }
443
444            let token = curr_token.unwrap();
445
446            let punct = lexer.next();
447            let next_token = lexer.next();
448
449            if punct.is_none() || next_token.is_none() {
450                return;
451            }
452
453            let punct = punct.unwrap();
454            let next_token = next_token.unwrap();
455
456            to_register.push((token, punct, next_token.clone()));
457
458            if next_token == LexedToken::End {
459                break;
460            }
461
462            curr_token = Some(next_token);
463        }
464
465        for (src, pct, dst) in to_register {
466            let src = self.ensure_textlet_from_token(src);
467            let pct = self.ensure_textlet_from_token(pct);
468            let dst = self.ensure_textlet_from_token(dst);
469
470            self.register_edge(src, dst, pct);
471        }
472    }
473
474    /// Get the textlet identifier for [MarkovTokenOwned::Begin].
475    pub fn begin(&self) -> usize {
476        self.textlet_bag
477            .iter()
478            .position(|a| a == &MarkovTokenOwned::Begin)
479            .unwrap()
480    }
481
482    /// Get the textlet identifier for [MarkovTokenOwned::End].
483    pub fn end(&self) -> usize {
484        self.textlet_bag
485            .iter()
486            .position(|a| a == &MarkovTokenOwned::End)
487            .unwrap()
488    }
489
    /// Returns whether the chain is empty – has no words in it.
    ///
    /// A chain has words only once at least one edge has been registered;
    /// the internal Begin/End textlets alone do not count.
    pub fn is_empty(&self) -> bool {
        self.words.is_empty()
    }
494
495    /**
496     * Composes a sentence by traversing this chain forward and backward from a
497     * given 'seed word'.
498     */
499    pub fn compose_sentence<'a>(
500        &'a self,
501        seed: MarkovSeed,
502        selector: &mut dyn MarkovSelector,
503        max_len: Option<usize>,
504    ) -> Result<TokenList<'a>, String> {
505        use MarkovSeed::Id;
506        use MarkovToken::*;
507        use MarkovTraverseDir::*;
508
509        let mut rng = thread_rng();
510
511        if self.is_empty() {
512            return Err("Cannot compose a sentence from an empty chain".into());
513        }
514
515        let seed = self.get_seed(seed, &mut rng)?;
516
517        let mut sentence: LinkedList<MarkovToken<'a>> =
518            LinkedList::from([self.get_textlet(seed).unwrap()]);
519
520        let mut len = self.get_textlet(seed).unwrap().len();
521
522        let mut curr_backward = seed;
523        let mut curr_forward = seed;
524
525        let capped = max_len.is_some();
526        let max_half_len: Option<usize> = max_len.map(|x| x / 2);
527
528        while curr_backward != self.begin() {
529            let (prev, punct, prvidx, _) =
530                self.select_next_word(Id(curr_backward), selector, Reverse)?;
531
532            let new_len = len + punct.len() + prev.len();
533
534            if capped && new_len > max_half_len.unwrap() {
535                break;
536            }
537
538            len = new_len;
539
540            sentence.push_front(punct);
541
542            if prev == Begin {
543                break;
544            }
545
546            sentence.push_front(prev);
547
548            curr_backward = prvidx;
549        }
550
551        while curr_forward != self.begin() {
552            let (next, punct, nxtidx, _) =
553                self.select_next_word(Id(curr_forward), selector, Forward)?;
554
555            let new_len = len + punct.len() + next.len();
556
557            if capped && new_len > max_len.unwrap() {
558                break;
559            }
560
561            len = new_len;
562
563            sentence.push_back(punct);
564
565            if next == End {
566                break;
567            }
568
569            sentence.push_back(next);
570
571            curr_forward = nxtidx;
572        }
573
574        Ok(TokenList(sentence))
575    }
576}