markov_generator/
lib.rs

1/*
2 * Copyright (C) 2022, 2024 taylor.fish <contact@taylor.fish>
3 *
4 * This file is part of markov-generator.
5 *
6 * markov-generator is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * markov-generator is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with markov-generator. If not, see <https://www.gnu.org/licenses/>.
18 */
19
20#![warn(missing_docs)]
21#![deny(unsafe_code)]
22#![cfg_attr(not(any(feature = "std", doc)), no_std)]
23#![cfg_attr(feature = "doc_cfg", feature(doc_cfg))]
24
25//! A highly customizable Rust library for building [Markov chains] and
26//! generating random sequences from them.
27//!
28//! [Markov chains]: https://en.wikipedia.org/wiki/Markov_chain
29//!
30//! [`Chain`] implements [Serde]’s [`Serialize`] and [`Deserialize`] traits, so
31//! you can reuse chains without having to regenerate them every time (which
32//! can be a lengthy process).
33//!
34//! Example
35//! -------
36//!
37//! ```rust
38//! use markov_generator::{AddEdges, HashChain};
39//! # use std::fs::File;
40//! # use std::io::{self, BufRead, BufReader, BufWriter, Write};
41//!
42//! # (|| -> std::io::Result<()> {
43//! const DEPTH: usize = 6;
44//! // Maps each sequence of 6 items to a list of possible next items.
45//! // `HashChain` uses hash maps internally; b-trees can also be used.
46//! let mut chain = HashChain::new(DEPTH);
47//!
48//! // In this case, corpus.txt contains one paragraph per line.
49//! let file = File::open("examples/corpus.txt")?;
50//! let mut reader = BufReader::new(file);
51//! let mut line = String::new();
52//! // The previous `DEPTH` characters before `line`.
53//! let mut prev = Vec::<char>::new();
54//!
55//! while let Ok(1..) = reader.read_line(&mut line) {
56//!     // `Both` means that the generated random output could start with the
57//!     // beginning of `line` or end after the end of `line`.
58//!     chain.add_all(line.chars(), AddEdges::Both);
59//!
60//!     // This makes sure there's a chance that the end of the previous line
61//!     // could be followed by the start of the current line when generating
62//!     // random output.
63//!     chain.add_all(
64//!         prev.iter().copied().chain(line.chars().take(DEPTH)),
65//!         AddEdges::Neither,
66//!     );
67//!
68//!     if let Some((n, (i, _c))) =
69//!         line.char_indices().rev().enumerate().take(DEPTH).last()
70//!     {
71//!         // Keep only the most recent `DEPTH` characters.
72//!         prev.drain(0..prev.len().saturating_sub(DEPTH - n - 1));
73//!         prev.extend(line[i..].chars());
74//!     }
75//!     line.clear();
76//! }
77//!
78//! // Generate and print random Markov data.
79//! let mut stdout = BufWriter::new(io::stdout().lock());
80//! let mut prev_newline = false;
81//! for &c in chain.generate() {
82//!     if prev_newline {
83//!         writeln!(stdout)?;
84//!     }
85//!     prev_newline = c == '\n';
86//!     write!(stdout, "{c}")?;
87//! }
88//! stdout.flush()?;
89//! # Ok(())
90//! # })().unwrap();
91//! ```
92//!
93//! Crate features
94//! --------------
95//!
96//! * `std` (default: enabled): Use [`std`]. If disabled, this crate is marked
97//!   `no_std`.
98//! * `serde` (default: enabled): Implement [Serde]’s [`Serialize`] and
99//!   [`Deserialize`] traits for [`Chain`].
100//!
101//! [Serde]: serde
102
103extern crate alloc;
104
105use alloc::boxed::Box;
106use alloc::collections::VecDeque;
107use core::borrow::Borrow;
108use core::fmt::{self, Debug};
109use core::iter::{DoubleEndedIterator, ExactSizeIterator};
110use core::mem;
111use rand::Rng;
112#[cfg(feature = "serde")]
113use serde::{Deserialize, Serialize};
114
115pub mod map;
116use map::{Map, MapFrom, MapFromSlice, MapOps, MapOpsSlice};
117use map::{OwnedSliceKey, SliceKey};
118
119#[cfg_attr(
120    feature = "serde",
121    derive(Serialize, Deserialize),
122    serde(bound(serialize = "T: Serialize, M: Map<T>")),
123    serde(bound(deserialize = "T: Deserialize<'de>, M: Map<T>"))
124)]
125struct FrequencyMap<T, M: Map<T>> {
126    #[cfg_attr(
127        feature = "serde",
128        serde(serialize_with = "MapOps::serialize"),
129        serde(deserialize_with = "MapOps::deserialize")
130    )]
131    map: <M as MapFrom<T>>::To<usize>,
132    total: usize,
133}
134
135impl<T: Debug, M: Map<T>> Debug for FrequencyMap<T, M> {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        f.debug_struct("FrequencyMap")
138            .field("map", &map::MapDebug::new(&self.map))
139            .field("total", &self.total)
140            .finish()
141    }
142}
143
144impl<T: Clone, M: Map<T>> Clone for FrequencyMap<T, M> {
145    fn clone(&self) -> Self {
146        Self {
147            map: self.map.clone(),
148            total: self.total,
149        }
150    }
151}
152
153impl<T, M: Map<T>> Default for FrequencyMap<T, M> {
154    fn default() -> Self {
155        Self {
156            map: Default::default(),
157            total: 0,
158        }
159    }
160}
161
162#[cfg_attr(
163    feature = "serde",
164    derive(Serialize, Deserialize),
165    serde(bound(serialize = "T: Serialize, M: Map<T>")),
166    serde(bound(deserialize = "T: Deserialize<'de>, M: Map<T>"))
167)]
168/// A Markov chain.
169///
170/// This type implements [`Serialize`] and [`Deserialize`] when the
171/// `serde` feature is enabled (which it is by default).
172pub struct Chain<T, M: Map<T> = map::BTree> {
173    #[cfg_attr(
174        feature = "serde",
175        serde(serialize_with = "MapOps::serialize"),
176        serde(deserialize_with = "MapOps::deserialize")
177    )]
178    map: <M as MapFromSlice<T>>::To<Box<FrequencyMap<T, M>>>,
179    depth: usize,
180    #[cfg_attr(feature = "serde", serde(skip))]
181    buf: VecDeque<T>,
182}
183
184/// Alias of <code>[Chain]<T, [map::BTree]></code>.
185pub type BTreeChain<T> = Chain<T, map::BTree>;
186
187#[cfg(feature = "std")]
188#[cfg_attr(feature = "doc_cfg", doc(cfg(feature = "std")))]
189/// Alias of <code>[Chain]<T, [map::Hash]></code>.
190pub type HashChain<T> = Chain<T, map::Hash>;
191
192impl<T, M: Map<T>> Chain<T, M> {
193    /// Creates a new chain.
194    ///
195    /// See [`Self::depth`] for an explanation of the depth.
196    pub fn new(depth: usize) -> Self {
197        Self {
198            map: Default::default(),
199            depth,
200            buf: Default::default(),
201        }
202    }
203
204    /// Gets the chain's depth.
205    ///
206    /// A depth of `n` means the chain maps sequences of `n` items of type `T`
207    /// to a list of possible next items.
208    pub fn depth(&self) -> usize {
209        self.depth
210    }
211
212    fn take_buf(&mut self) -> VecDeque<T> {
213        let mut buf = mem::take(&mut self.buf);
214        buf.clear();
215        if buf.capacity() == 0 {
216            buf.reserve_exact(self.depth);
217        }
218        buf
219    }
220
221    /// Adds all items in an iterator to the chain.
222    ///
223    /// This essentially calls [`Self::add`] on every overlapping window
224    /// of [`self.depth()`] elements plus the item following each window.
225    /// (Smaller windows at the start of the sequence may also be added
226    /// depending on the value of `edges`.)
227    ///
228    /// `edges` controls whether the start or end of `items` may be used as the
229    /// start or end of a generated sequence from the chain. See the
230    /// documentation for [`AddEdges`] for more information.
231    ///
232    /// [`self.depth()`]: Self::depth
233    pub fn add_all<I>(&mut self, items: I, edges: AddEdges)
234    where
235        I: IntoIterator<Item = T>,
236        T: Clone,
237    {
238        let mut buf = self.take_buf();
239        let mut iter = items.into_iter();
240        let mut item_opt = iter.next();
241        while let Some(item) = item_opt {
242            debug_assert!(buf.len() <= self.depth);
243            let next = iter.next();
244            if buf.len() == self.depth || edges.has_start() {
245                if next.is_some() || edges.has_end() {
246                    self.add_with_key(&buf, Some(item.clone()));
247                } else {
248                    self.add_with_key(&mut buf, Some(item));
249                    break;
250                }
251            }
252            if buf.len() == self.depth {
253                buf.pop_front();
254            }
255            buf.push_back(item);
256            item_opt = next;
257        }
258        if !buf.is_empty() && edges.has_end() {
259            self.add_with_key(&mut buf, None);
260        }
261        self.buf = buf;
262    }
263
264    /// Adds items to the chain.
265    ///
266    /// If `items` yields at least [`self.depth()`] items, this increases the
267    /// chance that the first [`self.depth()`] items will be followed by `next`
268    /// in a generated sequence (additional items in `items` are ignored).
269    ///
270    /// If `items` yields fewer than [`self.depth()`] items, this method
271    /// increases the chance that those items, when they are the only items
272    /// that have been generated so far in a sequence (i.e., the start of a
273    /// sequence), will be followed by `next`. In the case where `items` yields
274    /// no elements, this increases the chance that `item` will be produced as
275    /// the first element of a generated sequence.
276    ///
277    /// If `next` is `None`, this method increases the chance that `items` will
278    /// be considered the end of a sequence.
279    ///
280    /// [`self.depth()`]: Self::depth
281    pub fn add<I>(&mut self, items: I, next: Option<T>)
282    where
283        I: IntoIterator<Item = T>,
284    {
285        let mut buf = self.take_buf();
286        buf.extend(items.into_iter().take(self.depth));
287        self.add_with_key(&mut buf, next);
288        self.buf = buf;
289    }
290
291    fn add_with_key<S>(&mut self, key: S, next: Option<T>)
292    where
293        S: SliceKey<T> + Into<OwnedSliceKey<T>>,
294    {
295        debug_assert!(key.get(self.depth()).is_none());
296        let freq = self.map.slice_get_or_insert_with(key, Default::default);
297        if let Some(v) = next {
298            *freq.map.get_or_insert_with(v, Default::default) += 1;
299        }
300        freq.total += 1;
301    }
302
303    #[cfg(feature = "std")]
304    #[cfg_attr(feature = "doc_cfg", doc(cfg(feature = "std")))]
305    /// Generates random Markov chain data.
306    ///
307    /// Returns an iterator that yields the elements by reference. If you want
308    /// them by value, simply use [`Iterator::cloned`] (as long as `T` is
309    /// [`Clone`]).
310    pub fn generate(&self) -> Generator<'_, T, rand::rngs::ThreadRng, M> {
311        self.generate_with_rng(rand::thread_rng())
312    }
313
314    /// Like [`Self::generate`], but takes a custom random number generator.
315    pub fn generate_with_rng<R: Rng>(&self, rng: R) -> Generator<'_, T, R, M> {
316        Generator::new(self, rng)
317    }
318
319    #[cfg(feature = "std")]
320    #[cfg_attr(feature = "doc_cfg", doc(cfg(feature = "std")))]
321    /// Gets a random item that has followed `items` in the added data.
322    ///
323    /// If `items` yields more than [`self.depth()`] items, only the first
324    /// [`self.depth()`] items are considered. If `items` yields fewer than
325    /// [`self.depth()`] items (potentially zero), those items are interpreted
326    /// as the beginning of a generated sequence.
327    ///
328    /// [`self.depth()`]: Self::depth
329    pub fn get<'a, B>(&'a self, items: &[B]) -> Option<&'a T>
330    where
331        B: Borrow<T>,
332    {
333        self.get_with_rng(items, rand::thread_rng())
334    }
335
336    /// Like [`Self::get`], but takes a custom random number generator.
337    pub fn get_with_rng<'a, B, R>(
338        &'a self,
339        items: &[B],
340        rng: R,
341    ) -> Option<&'a T>
342    where
343        B: Borrow<T>,
344        R: Rng,
345    {
346        self.get_with_key(items, rng)
347    }
348
349    fn get_with_key<S, R>(&self, items: S, mut rng: R) -> Option<&T>
350    where
351        S: SliceKey<T>,
352        R: Rng,
353    {
354        let freq = self.map.slice_get(&items)?;
355        let mut n = rng.gen_range(0..freq.total);
356        for (item, count) in freq.map.iter() {
357            n = if let Some(n) = n.checked_sub(*count) {
358                n
359            } else {
360                return Some(item);
361            }
362        }
363        None
364    }
365}
366
367impl<T: Debug, M: Map<T>> Debug for Chain<T, M> {
368    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369        f.debug_struct("Chain")
370            .field("map", &map::MapDebug::new(&self.map))
371            .field("depth", &self.depth)
372            .field("buf", &self.buf)
373            .finish()
374    }
375}
376
377impl<T: Clone, M: Map<T>> Clone for Chain<T, M> {
378    fn clone(&self) -> Self {
379        Self {
380            map: self.map.clone(),
381            depth: self.depth,
382            buf: Default::default(),
383        }
384    }
385}
386
/// Controls the behavior of [`Chain::add_all`].
///
/// This enum determines whether the start or end of the provided items can be
/// used as start or end data for the chain (see individual variants'
/// descriptions).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AddEdges {
    /// Allows any of the first [`Chain::depth`] items of the provided iterator
    /// to be returned by calling [`Chain::get`] with the items preceding it
    /// (of which there are necessarily fewer than [`Chain::depth`], and
    /// potentially none). This also means that [`Chain::generate`] may yield
    /// these items as its initial elements.
    Start,

    /// Allows the last [`Chain::depth`] items of the provided iterator (or
    /// fewer, if it doesn't yield that many) to be considered the end of the
    /// sequence, represented by a return value of [`None`] from
    /// [`Chain::get`].
    End,

    /// Performs the behavior of both [`Self::Start`] and [`Self::End`].
    Both,

    /// Performs the behavior of neither [`Self::Start`] nor [`Self::End`].
    Neither,
}
413
414impl AddEdges {
415    fn has_start(&self) -> bool {
416        matches!(self, Self::Start | Self::Both)
417    }
418
419    fn has_end(&self) -> bool {
420        matches!(self, Self::End | Self::Both)
421    }
422}
423
424/// Iterator returned by [`Chain::generate`].
425pub struct Generator<'a, T, R, M: Map<T>> {
426    chain: &'a Chain<T, M>,
427    rng: R,
428    buf: VecDeque<&'a T>,
429}
430
431impl<'a, T, R, M: Map<T>> Generator<'a, T, R, M> {
432    /// Creates a new random Markov data generator. See [`Chain::generate`].
433    pub fn new(chain: &'a Chain<T, M>, rng: R) -> Self {
434        let mut buf = VecDeque::new();
435        buf.reserve_exact(chain.depth);
436        Self {
437            chain,
438            rng,
439            buf,
440        }
441    }
442
443    /// Gets the generator's state.
444    ///
445    /// This is a sequence of exactly [`self.chain.depth()`](Chain::depth)
446    /// items. [`Self::next`] will pass the state to [`Chain::get`] and then
447    /// update the state accordingly by popping the front item and pushing the
448    /// result of [`Chain::get`] to the back.
449    ///
450    /// The initial state is empty.
451    #[rustfmt::skip]
452    pub fn state(
453        &self,
454    ) -> impl '_
455        + Clone
456        + DoubleEndedIterator<Item = &T>
457        + ExactSizeIterator
458    {
459        self.buf.iter().copied()
460    }
461
462    /// Sets the generator's state.
463    ///
464    /// This method sets the generator's state to (up to) the first
465    /// [`self.chain.depth()`](Chain::depth) items in `state`.
466    ///
467    /// See [`Self::state`] for an explanation of how the state is used.
468    pub fn set_state<I>(&mut self, state: I)
469    where
470        I: IntoIterator<Item = &'a T>,
471    {
472        self.buf.clear();
473        let iter = state.into_iter().take(self.chain.depth());
474        self.buf.extend(iter);
475    }
476}
477
478impl<'a, T, R, M> Iterator for Generator<'a, T, R, M>
479where
480    T: Clone,
481    R: Rng,
482    M: Map<T>,
483{
484    type Item = &'a T;
485
486    fn next(&mut self) -> Option<&'a T> {
487        let next = self.chain.get_with_key(&self.buf, &mut self.rng)?;
488        debug_assert!(self.buf.len() <= self.chain.depth());
489        if self.buf.len() == self.chain.depth() {
490            self.buf.pop_front();
491        }
492        self.buf.push_back(next);
493        Some(next)
494    }
495}