1use super::selectors::interface::MarkovSelector;
6use super::selectors::interface::SelectionType;
7use super::token::*;
8use crate::sentence::lex::{Lexer, Token as LexedToken};
9use rand::{distributions::Uniform, prelude::*};
10use std::collections::HashMap;
11use std::collections::LinkedList;
12use std::rc::Rc;
13
/// Direction in which the chain is walked when picking a neighbouring word.
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub enum MarkovTraverseDir {
    /// Follow edges from their source to their destination
    /// (looked up through the chain's `edges` map).
    Forward,
    /// Follow edges from their destination back to their source
    /// (looked up through the chain's `reverse_edges` map).
    Reverse,
}
20
/// Starting point for a selection or a sentence composition.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MarkovSeed<'a> {
    /// Start at the textlet with this exact text; resolving the seed fails
    /// if the text is not known to the chain.
    Word(&'a str),
    /// Start at the textlet with this index in the chain's textlet bag.
    Id(usize),
    /// Start at a textlet drawn uniformly from the chain's `words` list.
    Random,
}
27
/// A directed transition between two textlets of a [`MarkovChain`].
///
/// All three `*_idx` fields are indices into the chain's textlet bag and can
/// be resolved with [`MarkovChain::get_textlet`].
pub struct Edge {
    /// Index of the textlet this edge starts at.
    pub src_idx: usize,

    /// Index of the textlet this edge leads to.
    pub dst_idx: usize,

    /// How many times this exact (source, destination, punctuation) triple
    /// has been registered. New edges start at 1 unless told otherwise.
    pub hits: usize,

    /// Index of the textlet registered as the punctuation for this
    /// transition (the token lexed between source and destination).
    pub pct_idx: usize,
}
42
43impl Edge {
44 pub fn get_source<'a>(&self, chain: &'a MarkovChain) -> MarkovToken<'a> {
46 chain.get_textlet(self.src_idx).unwrap()
47 }
48
49 pub fn get_dest<'a>(&self, chain: &'a MarkovChain) -> MarkovToken<'a> {
51 chain.get_textlet(self.dst_idx).unwrap()
52 }
53
54 pub fn get_punct<'a>(&self, chain: &'a MarkovChain) -> MarkovToken<'a> {
56 chain.get_textlet(self.pct_idx).unwrap()
57 }
58}
59
/// A Markov chain over "textlets" (words and punctuation), stored as an
/// indexed bag of tokens plus a list of weighted edges between them.
pub struct MarkovChain {
    /// Every textlet known to the chain. `new()` seeds it with the `Begin`
    /// and `End` markers at indices 0 and 1.
    textlet_bag: Vec<MarkovTokenOwned>,
    /// Maps a textlet's text to its index in `textlet_bag`.
    textlet_indices: HashMap<Rc<str>, usize>,
    /// Textlet indices that have been used as an endpoint of a registered
    /// edge; `MarkovSeed::Random` draws from this list.
    words: Vec<usize>,

    /// Every edge of the chain; the maps below hold indices into this list.
    edge_list: Vec<Edge>,
    /// Source textlet index -> indices (into `edge_list`) of outgoing edges.
    edges: HashMap<usize, Vec<usize>>,
    /// Destination textlet index -> indices (into `edge_list`) of incoming edges.
    reverse_edges: HashMap<usize, Vec<usize>>,
}
72
73impl Default for MarkovChain {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl MarkovChain {
80 pub fn new() -> MarkovChain {
84 MarkovChain {
85 textlet_bag: vec![MarkovTokenOwned::Begin, MarkovTokenOwned::End],
86 textlet_indices: HashMap::new(),
87 words: Vec::new(),
88
89 edge_list: Vec::new(),
90 edges: HashMap::new(),
91 reverse_edges: HashMap::new(),
92 }
93 }
94
95 pub fn ensure_textlet_index(&mut self, word: &str) -> usize {
100 match self.textlet_indices.get(word) {
101 Some(a) => *a,
102 None => {
103 let i = self.textlet_bag.len();
104 let rcword: Rc<str> = Rc::from(word);
105
106 self.textlet_bag
107 .push(MarkovTokenOwned::Textlet(rcword.clone()));
108
109 self.textlet_indices.insert(rcword, i);
110
111 i
112 }
113 }
114 }
115
116 pub fn ensure_textlet_from_token(&mut self, token: LexedToken) -> usize {
122 match token {
123 LexedToken::Begin => 0,
124 LexedToken::End => 1,
125 LexedToken::Punct(word) => self.ensure_textlet_index(word),
126 LexedToken::Word(word) => self.ensure_textlet_index(word),
127 }
128 }
129
130 pub fn try_get_textlet_index(&self, word: &str) -> Option<usize> {
136 self.textlet_indices.get(word).copied()
137 }
138
139 pub fn get_textlet(&self, index: usize) -> Option<MarkovToken<'_>> {
143 self.textlet_bag.get(index).map(MarkovToken::from)
144 }
145
146 fn push_new_edge(
147 &mut self,
148 from: usize,
149 to: usize,
150 punct: usize,
151 hits: Option<usize>,
152 ) -> usize {
153 let edge = Edge {
154 src_idx: from,
155 dst_idx: to,
156 hits: hits.unwrap_or(1),
157 pct_idx: punct,
158 };
159
160 let idx = self.edge_list.len();
161 self.edge_list.push(edge);
162
163 idx
164 }
165
166 fn add_reverse_edge(&mut self, edge_idx: usize) {
167 let edge = &self.edge_list[edge_idx];
168
169 match self.reverse_edges.get_mut(&edge.dst_idx) {
170 None => {
171 let rev_vec = vec![edge_idx];
172
173 self.reverse_edges.insert(edge.dst_idx, rev_vec);
174 }
175
176 Some(rev_vec) => {
177 for oedge in rev_vec.iter() {
178 let oedge = self.edge_list.get(*oedge).unwrap();
179
180 if edge.src_idx == oedge.src_idx && edge.pct_idx == oedge.pct_idx {
181 return;
182 }
183 }
184
185 rev_vec.push(edge_idx);
186 }
187 }
188 }
189
190 pub fn register_edge(&mut self, from: usize, to: usize, punct: usize) {
200 for item in [from, to] {
201 if !self.words.contains(&item) {
202 self.words.push(item);
203 }
204 }
205
206 if let Some(edgevec) = self.edges.get_mut(&from) {
207 for edge in edgevec.iter() {
208 let edge: &mut Edge = self.edge_list.get_mut(*edge).unwrap();
209
210 if edge.dst_idx == to && edge.pct_idx == punct {
211 edge.hits += 1;
212 return;
213 }
214 }
215 }
216
217 let idx = self.push_new_edge(from, to, punct, None);
218 self.edges.insert(from, vec![idx]);
219
220 if let Some(edgevec) = self.edges.get_mut(&from) {
221 edgevec.push(idx);
222 } else {
223 self.edges.insert(from, vec![idx]);
224 }
225
226 self.add_reverse_edge(idx);
227 }
228
229 fn get_seed<T: Rng>(&self, seed: MarkovSeed, rng: &mut T) -> Result<usize, String> {
230 use MarkovSeed::*;
231
232 match seed {
233 Word(seed) => {
234 let from = self.try_get_textlet_index(seed);
235
236 if from.is_none() {
237 return Err(format!(
238 "Seed word {:?} not found in this Markov chain!",
239 seed
240 ));
241 }
242
243 Ok(from.unwrap())
244 }
245
246 Id(seed) => Ok(seed),
247
248 Random => {
249 let from: usize = Uniform::new(0, self.words.len()).sample(rng);
250 Ok(self.words[from])
251 }
252 }
253 }
254
255 fn _weighted_select<R>(
256 &self,
257 sel_type: SelectionType,
258 edges: &[usize],
259 weights: &[f32],
260 rng: &mut R,
261 ) -> &Edge
262 where
263 R: Rng,
264 {
265 match sel_type {
266 SelectionType::Lowest => {
267 edges
268 .iter()
269 .map(|e| &self.edge_list[*e])
270 .zip(weights.iter())
271 .reduce(|ewc, ewn| if ewc.1 < ewn.1 { ewc } else { ewn })
272 .unwrap()
273 .0
274 }
275
276 SelectionType::Highest => {
277 edges
278 .iter()
279 .map(|e| &self.edge_list[*e])
280 .zip(weights.iter())
281 .reduce(|ewc, ewn| if ewc.1 > ewn.1 { ewc } else { ewn })
282 .unwrap()
283 .0
284 }
285
286 SelectionType::WeightedRandom => {
287 let total: f32 = weights.iter().sum();
288 let pick = Uniform::new(0.0_f32, total).sample(rng);
289
290 let mut curr = 0.0;
291 let mut res = None;
292
293 for (edge, weight) in edges
294 .iter()
295 .map(|e| &self.edge_list[*e])
296 .zip(weights.iter())
297 {
298 curr += weight;
299
300 if curr >= pick {
301 res = Some(edge);
302 break;
303 }
304 }
305
306 res.unwrap()
307 }
308 }
309 }
310
311 pub fn select_next_word(
325 &self,
326 seed: MarkovSeed,
327 selector: &mut dyn MarkovSelector,
328 direction: MarkovTraverseDir,
329 ) -> Result<(MarkovToken<'_>, MarkovToken<'_>, usize, usize), String> {
330 use MarkovTraverseDir::*;
331
332 let mut rng = thread_rng();
333
334 let from: usize = self.get_seed(seed, &mut rng)?;
335
336 let edges = match direction {
337 MarkovTraverseDir::Forward => self.edges.get(&from),
338 MarkovTraverseDir::Reverse => self.reverse_edges.get(&from),
339 };
340
341 if edges.is_none() {
342 return Err(format!(
343 "Seed textlet {:?} is not connected to anything in this Markov chain!",
344 self.get_textlet(from)
345 ));
346 }
347
348 let edges = edges.unwrap();
349
350 if edges.is_empty() {
351 return Err(format!("Seed textlet {:?} is not connected to anything in this Markov chain, but in a weird way!", self.get_textlet(from)));
352 }
353
354 let mut weights: Vec<f32> = vec![0.0; edges.len()];
355
356 selector.reset(direction);
357
358 for (edge, weight) in edges
359 .iter()
360 .map(|e| &self.edge_list[*e])
361 .zip(weights.iter_mut())
362 {
363 *weight = selector.weight(
364 &edge.get_source(self),
365 &edge.get_dest(self),
366 &edge.get_punct(self),
367 edge.hits,
368 );
369 }
370
371 let sel_type = selector.selection_type();
372
373 let best_edge: &Edge = self._weighted_select(sel_type, edges, &weights, &mut rng);
374
375 match direction {
376 Forward => Ok((
377 best_edge.get_dest(self),
378 best_edge.get_punct(self),
379 best_edge.dst_idx,
380 best_edge.pct_idx,
381 )),
382
383 Reverse => Ok((
384 best_edge.get_source(self),
385 best_edge.get_punct(self),
386 best_edge.src_idx,
387 best_edge.pct_idx,
388 )),
389 }
390 }
391
    /// Returns how many distinct textlet indices have been used as an
    /// endpoint of a registered edge (the length of the `words` list that
    /// `MarkovSeed::Random` draws from).
    pub fn num_words(&self) -> usize {
        self.words.len()
    }
404
    /// Returns the number of entries in the textlet bag. This includes the
    /// `Begin` and `End` markers that `new()` places there.
    pub fn num_textlets(&self) -> usize {
        self.textlet_bag.len()
    }
414
    /// Returns the number of distinct edges stored in the chain (repeated
    /// registrations of the same edge only raise its hit count).
    pub fn num_edges(&self) -> usize {
        self.edge_list.len()
    }
424
425 pub fn parse_sentence(&mut self, sentence: &str) {
430 let mut lexer = Lexer::new(sentence);
431 let mut curr_token = lexer.next();
432
433 let mut to_register: Vec<(LexedToken, LexedToken, LexedToken)> = vec![];
434
435 if sentence.is_empty() {
436 return;
437 }
438
439 loop {
440 if curr_token.is_none() {
441 panic!("Found a none token prematurely!");
442 }
443
444 let token = curr_token.unwrap();
445
446 let punct = lexer.next();
447 let next_token = lexer.next();
448
449 if punct.is_none() || next_token.is_none() {
450 return;
451 }
452
453 let punct = punct.unwrap();
454 let next_token = next_token.unwrap();
455
456 to_register.push((token, punct, next_token.clone()));
457
458 if next_token == LexedToken::End {
459 break;
460 }
461
462 curr_token = Some(next_token);
463 }
464
465 for (src, pct, dst) in to_register {
466 let src = self.ensure_textlet_from_token(src);
467 let pct = self.ensure_textlet_from_token(pct);
468 let dst = self.ensure_textlet_from_token(dst);
469
470 self.register_edge(src, dst, pct);
471 }
472 }
473
474 pub fn begin(&self) -> usize {
476 self.textlet_bag
477 .iter()
478 .position(|a| a == &MarkovTokenOwned::Begin)
479 .unwrap()
480 }
481
482 pub fn end(&self) -> usize {
484 self.textlet_bag
485 .iter()
486 .position(|a| a == &MarkovTokenOwned::End)
487 .unwrap()
488 }
489
    /// Returns `true` if no edge has been registered yet, i.e. the `words`
    /// list is empty. The textlet bag itself is never empty, since it always
    /// holds the `Begin` and `End` markers.
    pub fn is_empty(&self) -> bool {
        self.words.is_empty()
    }
494
495 pub fn compose_sentence<'a>(
500 &'a self,
501 seed: MarkovSeed,
502 selector: &mut dyn MarkovSelector,
503 max_len: Option<usize>,
504 ) -> Result<TokenList<'a>, String> {
505 use MarkovSeed::Id;
506 use MarkovToken::*;
507 use MarkovTraverseDir::*;
508
509 let mut rng = thread_rng();
510
511 if self.is_empty() {
512 return Err("Cannot compose a sentence from an empty chain".into());
513 }
514
515 let seed = self.get_seed(seed, &mut rng)?;
516
517 let mut sentence: LinkedList<MarkovToken<'a>> =
518 LinkedList::from([self.get_textlet(seed).unwrap()]);
519
520 let mut len = self.get_textlet(seed).unwrap().len();
521
522 let mut curr_backward = seed;
523 let mut curr_forward = seed;
524
525 let capped = max_len.is_some();
526 let max_half_len: Option<usize> = max_len.map(|x| x / 2);
527
528 while curr_backward != self.begin() {
529 let (prev, punct, prvidx, _) =
530 self.select_next_word(Id(curr_backward), selector, Reverse)?;
531
532 let new_len = len + punct.len() + prev.len();
533
534 if capped && new_len > max_half_len.unwrap() {
535 break;
536 }
537
538 len = new_len;
539
540 sentence.push_front(punct);
541
542 if prev == Begin {
543 break;
544 }
545
546 sentence.push_front(prev);
547
548 curr_backward = prvidx;
549 }
550
551 while curr_forward != self.begin() {
552 let (next, punct, nxtidx, _) =
553 self.select_next_word(Id(curr_forward), selector, Forward)?;
554
555 let new_len = len + punct.len() + next.len();
556
557 if capped && new_len > max_len.unwrap() {
558 break;
559 }
560
561 len = new_len;
562
563 sentence.push_back(punct);
564
565 if next == End {
566 break;
567 }
568
569 sentence.push_back(next);
570
571 curr_forward = nxtidx;
572 }
573
574 Ok(TokenList(sentence))
575 }
576}