markov_generator/lib.rs
1/*
2 * Copyright (C) 2022, 2024 taylor.fish <contact@taylor.fish>
3 *
4 * This file is part of markov-generator.
5 *
6 * markov-generator is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * markov-generator is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with markov-generator. If not, see <https://www.gnu.org/licenses/>.
18 */
19
20#![warn(missing_docs)]
21#![deny(unsafe_code)]
22#![cfg_attr(not(any(feature = "std", doc)), no_std)]
23#![cfg_attr(feature = "doc_cfg", feature(doc_cfg))]
24
25//! A highly customizable Rust library for building [Markov chains] and
26//! generating random sequences from them.
27//!
28//! [Markov chains]: https://en.wikipedia.org/wiki/Markov_chain
29//!
30//! [`Chain`] implements [Serde]’s [`Serialize`] and [`Deserialize`] traits, so
31//! you can reuse chains without having to regenerate them every time (which
32//! can be a lengthy process).
33//!
34//! Example
35//! -------
36//!
37//! ```rust
38//! use markov_generator::{AddEdges, HashChain};
39//! # use std::fs::File;
40//! # use std::io::{self, BufRead, BufReader, BufWriter, Write};
41//!
42//! # (|| -> std::io::Result<()> {
43//! const DEPTH: usize = 6;
44//! // Maps each sequence of 6 items to a list of possible next items.
45//! // `HashChain` uses hash maps internally; b-trees can also be used.
46//! let mut chain = HashChain::new(DEPTH);
47//!
48//! // In this case, corpus.txt contains one paragraph per line.
49//! let file = File::open("examples/corpus.txt")?;
50//! let mut reader = BufReader::new(file);
51//! let mut line = String::new();
52//! // The previous `DEPTH` characters before `line`.
53//! let mut prev = Vec::<char>::new();
54//!
55//! while let Ok(1..) = reader.read_line(&mut line) {
56//! // `Both` means that the generated random output could start with the
57//! // beginning of `line` or end after the end of `line`.
58//! chain.add_all(line.chars(), AddEdges::Both);
59//!
60//! // This makes sure there's a chance that the end of the previous line
61//! // could be followed by the start of the current line when generating
62//! // random output.
63//! chain.add_all(
64//! prev.iter().copied().chain(line.chars().take(DEPTH)),
65//! AddEdges::Neither,
66//! );
67//!
68//! if let Some((n, (i, _c))) =
69//! line.char_indices().rev().enumerate().take(DEPTH).last()
70//! {
71//! // Keep only the most recent `DEPTH` characters.
72//! prev.drain(0..prev.len().saturating_sub(DEPTH - n - 1));
73//! prev.extend(line[i..].chars());
74//! }
75//! line.clear();
76//! }
77//!
78//! // Generate and print random Markov data.
79//! let mut stdout = BufWriter::new(io::stdout().lock());
80//! let mut prev_newline = false;
81//! for &c in chain.generate() {
82//! if prev_newline {
83//! writeln!(stdout)?;
84//! }
85//! prev_newline = c == '\n';
86//! write!(stdout, "{c}")?;
87//! }
88//! stdout.flush()?;
89//! # Ok(())
90//! # })().unwrap();
91//! ```
92//!
93//! Crate features
94//! --------------
95//!
96//! * `std` (default: enabled): Use [`std`]. If disabled, this crate is marked
97//! `no_std`.
98//! * `serde` (default: enabled): Implement [Serde]’s [`Serialize`] and
99//! [`Deserialize`] traits for [`Chain`].
100//!
101//! [Serde]: serde
102
103extern crate alloc;
104
105use alloc::boxed::Box;
106use alloc::collections::VecDeque;
107use core::borrow::Borrow;
108use core::fmt::{self, Debug};
109use core::iter::{DoubleEndedIterator, ExactSizeIterator};
110use core::mem;
111use rand::Rng;
112#[cfg(feature = "serde")]
113use serde::{Deserialize, Serialize};
114
115pub mod map;
116use map::{Map, MapFrom, MapFromSlice, MapOps, MapOpsSlice};
117use map::{OwnedSliceKey, SliceKey};
118
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(bound(serialize = "T: Serialize, M: Map<T>")),
    serde(bound(deserialize = "T: Deserialize<'de>, M: Map<T>"))
)]
/// Occurrence counts for the items recorded after one particular key.
struct FrequencyMap<T, M: Map<T>> {
    /// Number of times each item was recorded as the item following the key.
    #[cfg_attr(
        feature = "serde",
        serde(serialize_with = "MapOps::serialize"),
        serde(deserialize_with = "MapOps::deserialize")
    )]
    map: <M as MapFrom<T>>::To<usize>,
    /// Total number of recordings for the key. Recordings without a next
    /// item (end of sequence) are counted here but not in `map`, so this may
    /// exceed the sum of the counts in `map`.
    total: usize,
}
134
135impl<T: Debug, M: Map<T>> Debug for FrequencyMap<T, M> {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.debug_struct("FrequencyMap")
138 .field("map", &map::MapDebug::new(&self.map))
139 .field("total", &self.total)
140 .finish()
141 }
142}
143
144impl<T: Clone, M: Map<T>> Clone for FrequencyMap<T, M> {
145 fn clone(&self) -> Self {
146 Self {
147 map: self.map.clone(),
148 total: self.total,
149 }
150 }
151}
152
153impl<T, M: Map<T>> Default for FrequencyMap<T, M> {
154 fn default() -> Self {
155 Self {
156 map: Default::default(),
157 total: 0,
158 }
159 }
160}
161
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(bound(serialize = "T: Serialize, M: Map<T>")),
    serde(bound(deserialize = "T: Deserialize<'de>, M: Map<T>"))
)]
/// A Markov chain.
///
/// `T` is the item type. `M` selects the map implementation used internally
/// and defaults to [`map::BTree`].
///
/// This type implements [`Serialize`] and [`Deserialize`] when the
/// `serde` feature is enabled (which it is by default).
pub struct Chain<T, M: Map<T> = map::BTree> {
    /// Maps each key (a sequence of at most `depth` items) to the
    /// frequencies of the items that were recorded after it.
    #[cfg_attr(
        feature = "serde",
        serde(serialize_with = "MapOps::serialize"),
        serde(deserialize_with = "MapOps::deserialize")
    )]
    map: <M as MapFromSlice<T>>::To<Box<FrequencyMap<T, M>>>,
    /// Maximum key length; see [`Chain::depth`].
    depth: usize,
    /// Scratch buffer reused by [`Chain::add`] and [`Chain::add_all`] to
    /// build keys. It is skipped during (de)serialization and is not copied
    /// when the chain is cloned.
    #[cfg_attr(feature = "serde", serde(skip))]
    buf: VecDeque<T>,
}
183
/// Alias of <code>[Chain]<T, [map::BTree]></code>.
///
/// Because [`map::BTree`] is the default map parameter of [`Chain`], this is
/// the same type as `Chain<T>`.
pub type BTreeChain<T> = Chain<T, map::BTree>;
186
#[cfg(feature = "std")]
#[cfg_attr(feature = "doc_cfg", doc(cfg(feature = "std")))]
/// Alias of <code>[Chain]<T, [map::Hash]></code>.
///
/// Only available when the `std` feature is enabled.
pub type HashChain<T> = Chain<T, map::Hash>;
191
192impl<T, M: Map<T>> Chain<T, M> {
193 /// Creates a new chain.
194 ///
195 /// See [`Self::depth`] for an explanation of the depth.
196 pub fn new(depth: usize) -> Self {
197 Self {
198 map: Default::default(),
199 depth,
200 buf: Default::default(),
201 }
202 }
203
/// Gets the chain's depth.
///
/// A depth of `n` means the chain maps sequences of `n` items of type `T`
/// to a list of possible next items. This is the value that was passed to
/// [`Self::new`].
pub fn depth(&self) -> usize {
    self.depth
}
211
212 fn take_buf(&mut self) -> VecDeque<T> {
213 let mut buf = mem::take(&mut self.buf);
214 buf.clear();
215 if buf.capacity() == 0 {
216 buf.reserve_exact(self.depth);
217 }
218 buf
219 }
220
221 /// Adds all items in an iterator to the chain.
222 ///
223 /// This essentially calls [`Self::add`] on every overlapping window
224 /// of [`self.depth()`] elements plus the item following each window.
225 /// (Smaller windows at the start of the sequence may also be added
226 /// depending on the value of `edges`.)
227 ///
228 /// `edges` controls whether the start or end of `items` may be used as the
229 /// start or end of a generated sequence from the chain. See the
230 /// documentation for [`AddEdges`] for more information.
231 ///
232 /// [`self.depth()`]: Self::depth
233 pub fn add_all<I>(&mut self, items: I, edges: AddEdges)
234 where
235 I: IntoIterator<Item = T>,
236 T: Clone,
237 {
238 let mut buf = self.take_buf();
239 let mut iter = items.into_iter();
240 let mut item_opt = iter.next();
241 while let Some(item) = item_opt {
242 debug_assert!(buf.len() <= self.depth);
243 let next = iter.next();
244 if buf.len() == self.depth || edges.has_start() {
245 if next.is_some() || edges.has_end() {
246 self.add_with_key(&buf, Some(item.clone()));
247 } else {
248 self.add_with_key(&mut buf, Some(item));
249 break;
250 }
251 }
252 if buf.len() == self.depth {
253 buf.pop_front();
254 }
255 buf.push_back(item);
256 item_opt = next;
257 }
258 if !buf.is_empty() && edges.has_end() {
259 self.add_with_key(&mut buf, None);
260 }
261 self.buf = buf;
262 }
263
264 /// Adds items to the chain.
265 ///
266 /// If `items` yields at least [`self.depth()`] items, this increases the
267 /// chance that the first [`self.depth()`] items will be followed by `next`
268 /// in a generated sequence (additional items in `items` are ignored).
269 ///
270 /// If `items` yields fewer than [`self.depth()`] items, this method
271 /// increases the chance that those items, when they are the only items
272 /// that have been generated so far in a sequence (i.e., the start of a
273 /// sequence), will be followed by `next`. In the case where `items` yields
274 /// no elements, this increases the chance that `item` will be produced as
275 /// the first element of a generated sequence.
276 ///
277 /// If `next` is `None`, this method increases the chance that `items` will
278 /// be considered the end of a sequence.
279 ///
280 /// [`self.depth()`]: Self::depth
281 pub fn add<I>(&mut self, items: I, next: Option<T>)
282 where
283 I: IntoIterator<Item = T>,
284 {
285 let mut buf = self.take_buf();
286 buf.extend(items.into_iter().take(self.depth));
287 self.add_with_key(&mut buf, next);
288 self.buf = buf;
289 }
290
291 fn add_with_key<S>(&mut self, key: S, next: Option<T>)
292 where
293 S: SliceKey<T> + Into<OwnedSliceKey<T>>,
294 {
295 debug_assert!(key.get(self.depth()).is_none());
296 let freq = self.map.slice_get_or_insert_with(key, Default::default);
297 if let Some(v) = next {
298 *freq.map.get_or_insert_with(v, Default::default) += 1;
299 }
300 freq.total += 1;
301 }
302
#[cfg(feature = "std")]
#[cfg_attr(feature = "doc_cfg", doc(cfg(feature = "std")))]
/// Generates random Markov chain data using the thread-local random number
/// generator.
///
/// The returned iterator yields the elements by reference. To get them by
/// value, use [`Iterator::cloned`] (as long as `T` is [`Clone`]).
pub fn generate(&self) -> Generator<'_, T, rand::rngs::ThreadRng, M> {
    Generator::new(self, rand::thread_rng())
}
313
/// Like [`Self::generate`], but takes a custom random number generator.
///
/// The returned [`Generator`] borrows the chain and takes ownership of
/// `rng`.
pub fn generate_with_rng<R: Rng>(&self, rng: R) -> Generator<'_, T, R, M> {
    Generator::new(self, rng)
}
318
#[cfg(feature = "std")]
#[cfg_attr(feature = "doc_cfg", doc(cfg(feature = "std")))]
/// Gets a random item that has followed `items` in the added data, using
/// the thread-local random number generator.
///
/// If `items` yields more than [`self.depth()`] items, only the first
/// [`self.depth()`] items are considered. If `items` yields fewer than
/// [`self.depth()`] items (potentially zero), those items are interpreted
/// as the beginning of a generated sequence.
///
/// [`self.depth()`]: Self::depth
pub fn get<'a, B>(&'a self, items: &[B]) -> Option<&'a T>
where
    B: Borrow<T>,
{
    let rng = rand::thread_rng();
    self.get_with_key(items, rng)
}
335
/// Like [`Self::get`], but takes a custom random number generator.
///
/// `rng` is used to choose among the items recorded after `items`.
pub fn get_with_rng<'a, B, R>(
    &'a self,
    items: &[B],
    rng: R,
) -> Option<&'a T>
where
    B: Borrow<T>,
    R: Rng,
{
    self.get_with_key(items, rng)
}
348
349 fn get_with_key<S, R>(&self, items: S, mut rng: R) -> Option<&T>
350 where
351 S: SliceKey<T>,
352 R: Rng,
353 {
354 let freq = self.map.slice_get(&items)?;
355 let mut n = rng.gen_range(0..freq.total);
356 for (item, count) in freq.map.iter() {
357 n = if let Some(n) = n.checked_sub(*count) {
358 n
359 } else {
360 return Some(item);
361 }
362 }
363 None
364 }
365}
366
367impl<T: Debug, M: Map<T>> Debug for Chain<T, M> {
368 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369 f.debug_struct("Chain")
370 .field("map", &map::MapDebug::new(&self.map))
371 .field("depth", &self.depth)
372 .field("buf", &self.buf)
373 .finish()
374 }
375}
376
377impl<T: Clone, M: Map<T>> Clone for Chain<T, M> {
378 fn clone(&self) -> Self {
379 Self {
380 map: self.map.clone(),
381 depth: self.depth,
382 buf: Default::default(),
383 }
384 }
385}
386
/// Controls the behavior of [`Chain::add_all`].
///
/// This enum determines whether the start or end of the provided items can be
/// used as start or end data for the chain (see individual variants'
/// descriptions).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AddEdges {
    /// Allows any of the first [`Chain::depth`] items of the provided iterator
    /// to be returned by calling [`Chain::get`] with the items preceding it
    /// (of which there are necessarily fewer than [`Chain::depth`], and
    /// potentially none). This also means that [`Chain::generate`] may yield
    /// these items as its initial elements.
    Start,

    /// Allows the last [`Chain::depth`] items of the provided iterator (or
    /// fewer, if it doesn't yield that many) to be considered the end of the
    /// sequence, represented by a return value of [`None`] from
    /// [`Chain::get`].
    End,

    /// Performs the behavior of both [`Self::Start`] and [`Self::End`].
    Both,

    /// Performs the behavior of neither [`Self::Start`] nor [`Self::End`].
    Neither,
}

impl AddEdges {
    /// Returns `true` for [`Self::Start`] and [`Self::Both`].
    fn has_start(self) -> bool {
        match self {
            Self::Start | Self::Both => true,
            Self::End | Self::Neither => false,
        }
    }

    /// Returns `true` for [`Self::End`] and [`Self::Both`].
    fn has_end(self) -> bool {
        match self {
            Self::End | Self::Both => true,
            Self::Start | Self::Neither => false,
        }
    }
}
423
/// Iterator returned by [`Chain::generate`].
pub struct Generator<'a, T, R, M: Map<T>> {
    /// The chain that items are drawn from.
    chain: &'a Chain<T, M>,
    /// Source of randomness used for every draw.
    rng: R,
    /// The most recently generated items, oldest first; used as the lookup
    /// key for the next draw.
    buf: VecDeque<&'a T>,
}
430
431impl<'a, T, R, M: Map<T>> Generator<'a, T, R, M> {
432 /// Creates a new random Markov data generator. See [`Chain::generate`].
433 pub fn new(chain: &'a Chain<T, M>, rng: R) -> Self {
434 let mut buf = VecDeque::new();
435 buf.reserve_exact(chain.depth);
436 Self {
437 chain,
438 rng,
439 buf,
440 }
441 }
442
/// Gets the generator's state.
///
/// The state is the sequence of most recently generated items, oldest
/// first. For a chain with a nonzero depth it holds at most
/// [`self.chain.depth()`](Chain::depth) items. [`Self::next`] looks the
/// state up in the chain and, when an item is produced, pushes it to the
/// back of the state, first popping the front item if the state already
/// holds [`self.chain.depth()`](Chain::depth) items.
///
/// The initial state is empty.
#[rustfmt::skip]
pub fn state(
    &self,
) -> impl '_
    + Clone
    + DoubleEndedIterator<Item = &T>
    + ExactSizeIterator
{
    self.buf.iter().copied()
}
461
462 /// Sets the generator's state.
463 ///
464 /// This method sets the generator's state to (up to) the first
465 /// [`self.chain.depth()`](Chain::depth) items in `state`.
466 ///
467 /// See [`Self::state`] for an explanation of how the state is used.
468 pub fn set_state<I>(&mut self, state: I)
469 where
470 I: IntoIterator<Item = &'a T>,
471 {
472 self.buf.clear();
473 let iter = state.into_iter().take(self.chain.depth());
474 self.buf.extend(iter);
475 }
476}
477
478impl<'a, T, R, M> Iterator for Generator<'a, T, R, M>
479where
480 T: Clone,
481 R: Rng,
482 M: Map<T>,
483{
484 type Item = &'a T;
485
486 fn next(&mut self) -> Option<&'a T> {
487 let next = self.chain.get_with_key(&self.buf, &mut self.rng)?;
488 debug_assert!(self.buf.len() <= self.chain.depth());
489 if self.buf.len() == self.chain.depth() {
490 self.buf.pop_front();
491 }
492 self.buf.push_back(next);
493 Some(next)
494 }
495}