use hashbrown::{hash_map::RawEntryMut, HashMap};
use lasso::{Capacity, Rodeo, Spur};
use rand::{
    seq::{IteratorRandom, SliceRandom},
    RngCore,
};
use regex::Regex;
use smallvec::SmallVec;
#[cfg(feature = "serialize")]
use {
    serde::{Deserialize, Serialize},
    serde_json_any_key::*,
};
/// A Markov chain over regex-matched tokens.
///
/// Token strings are interned in `cache`. Transitions are keyed by the
/// interned context (a run of tokens) that preceded each observed token.
/// `N` is only the inline capacity of each context key's `SmallVec`; it does
/// not limit `state_size`.
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct RawMarkovChain<const N: usize> {
    /// Maps a context of interned tokens (at most `state_size` long) to the
    /// tokens observed right after it.
    // NOTE(review): `any_key_map` is presumably needed because the keys are
    // not strings and formats such as JSON require string map keys — confirm.
    #[cfg_attr(feature = "serialize", serde(with = "any_key_map"))]
    items: HashMap<SmallVec<[Spur; N]>, ChainItem, foldhash::fast::FixedState>,
    /// Longest context, in tokens, recorded during training and kept while generating.
    state_size: usize,
    /// Pattern whose matches in input text become tokens.
    #[cfg_attr(feature = "serialize", serde(with = "serde_regex"))]
    regex: Regex,
    /// Interner mapping token strings to `Spur` keys and back.
    cache: Rodeo,
}
/// [`RawMarkovChain`] whose context keys store up to four tokens inline
/// before spilling to the heap.
pub type MarkovChain = RawMarkovChain<4>;
impl<const N: usize> RawMarkovChain<N> {
    /// Creates an empty chain.
    ///
    /// * `state_size` - longest context, in tokens, used to predict the next token.
    /// * `regex` - pattern whose matches in the input text become tokens.
    #[inline]
    pub fn new(state_size: usize, regex: Regex) -> RawMarkovChain<N> {
        RawMarkovChain {
            items: HashMap::with_hasher(foldhash::fast::FixedState::default()),
            state_size,
            regex,
            cache: Rodeo::new(),
        }
    }

    /// Creates an empty chain with space reserved up front.
    ///
    /// `capacity` is used both as the number of contexts reserved in the
    /// transition map and as the number of strings reserved in the interner.
    /// See [`RawMarkovChain::new`] for the other parameters.
    #[inline]
    pub fn with_capacity(
        state_size: usize,
        capacity: usize,
        regex: Regex,
    ) -> RawMarkovChain<N> {
        RawMarkovChain {
            items: HashMap::with_capacity_and_hasher(
                capacity,
                foldhash::fast::FixedState::default(),
            ),
            state_size,
            regex,
            cache: Rodeo::with_capacity(Capacity::for_strings(capacity)),
        }
    }

    /// Trains the chain on `text`.
    ///
    /// The text is split into tokens with the chain's regex. For every token
    /// after the first, each context of 1 to `state_size` tokens that
    /// immediately precedes it gains one observation of that token. This
    /// includes the tokens near the start of the text.
    pub fn add_text(&mut self, text: &str) {
        let tokens = self.tokenize(text);
        for (context, next) in Self::transitions(&tokens, self.state_size) {
            match self.items.raw_entry_mut().from_key(context) {
                RawEntryMut::Occupied(mut view) => view.get_mut().add(next),
                RawEntryMut::Vacant(view) => {
                    view.insert(SmallVec::from_slice(context), ChainItem::new(next));
                }
            }
        }
    }

    /// Same as [`RawMarkovChain::add_text`], but every transition is counted
    /// `weight` times. A `weight` of zero is a no-op.
    pub fn add_text_weighted(&mut self, text: &str, weight: usize) {
        // Bail out before tokenizing so a zero weight also leaves the interner untouched.
        if weight == 0 {
            return;
        }
        let tokens = self.tokenize(text);
        for (context, next) in Self::transitions(&tokens, self.state_size) {
            match self.items.raw_entry_mut().from_key(context) {
                RawEntryMut::Occupied(mut view) => view.get_mut().add_weighted(next, weight),
                RawEntryMut::Vacant(view) => {
                    view.insert(
                        SmallVec::from_slice(context),
                        ChainItem::new_weighted(next, weight),
                    );
                }
            }
        }
    }

    /// Generates at most `length` tokens and joins them with single spaces.
    ///
    /// Returns `None` if the chain has no transitions yet.
    pub fn generate(&self, length: usize, rng: &mut impl RngCore) -> Option<String> {
        if self.is_empty() {
            return None;
        }
        Some(Self::join_tokens(self.iter(length, rng)))
    }

    /// Like [`RawMarkovChain::generate`], but seeds the context with the last
    /// `state_size` tokens of `start` (tokens never seen in training are
    /// skipped). The seed tokens themselves are not part of the output.
    ///
    /// Returns `None` if the chain has no transitions yet.
    pub fn generate_start(
        &self,
        start: &str,
        length: usize,
        rng: &mut impl RngCore,
    ) -> Option<String> {
        if self.is_empty() {
            return None;
        }
        Some(Self::join_tokens(self.iter_start(start, length, rng)))
    }

    /// Number of distinct contexts stored in the chain.
    #[inline]
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Number of distinct token strings interned so far.
    #[inline]
    pub fn cache_len(&self) -> usize {
        self.cache.len()
    }

    /// `true` when no contexts have been recorded.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Longest context, in tokens, the chain records and conditions on.
    #[inline]
    pub fn state_size(&self) -> usize {
        self.state_size
    }

    /// Returns a clone of the tokenizing regex.
    #[inline]
    pub fn regex(&self) -> Regex {
        self.regex.clone()
    }

    /// Returns an iterator yielding at most `count` generated tokens,
    /// starting from an empty context.
    #[inline]
    pub fn iter<'a>(
        &'a self,
        count: usize,
        rng: &'a mut dyn RngCore,
    ) -> MarkovChainIter<'a, N> {
        MarkovChainIter {
            chain: self,
            count,
            rng,
            prev: Vec::with_capacity(self.state_size),
        }
    }

    /// Returns an iterator yielding at most `count` generated tokens.
    ///
    /// The context is seeded with the last `state_size` tokens of `start`.
    /// Tokens that were never seen in training are dropped from the seed.
    #[inline]
    pub fn iter_start<'a>(
        &'a self,
        start: &str,
        count: usize,
        rng: &'a mut dyn RngCore,
    ) -> MarkovChainIter<'a, N> {
        let words: Vec<&str> = self.regex.find_iter(start).map(|m| m.as_str()).collect();
        // Only the trailing `state_size` words can influence the next token.
        let tail = &words[words.len().saturating_sub(self.state_size)..];
        let prev = tail
            .iter()
            .copied()
            .filter_map(|word| self.cache.get(word))
            .collect();
        MarkovChainIter {
            chain: self,
            count,
            rng,
            prev,
        }
    }

    /// Picks the token that follows `prev`.
    ///
    /// Contexts are tried from the longest suffix of `prev` down to its last
    /// token alone. If none has been observed (or `prev` is empty), a successor
    /// list is chosen at random from the whole chain. Returns `None` when the
    /// chain is empty.
    fn next_step(&self, prev: &[Spur], rng: &mut impl RngCore) -> Option<Spur> {
        let known = (0..prev.len()).find_map(|start| self.items.get(&prev[start..]));
        match known {
            Some(item) => item.get_rand(rng),
            // Sample straight from the map iterator instead of collecting every
            // value into a temporary Vec on each fallback.
            None => self.items.values().choose(rng)?.get_rand(rng),
        }
    }

    /// Splits `text` into regex matches and interns each one.
    fn tokenize(&mut self, text: &str) -> Vec<Spur> {
        let Self { regex, cache, .. } = self;
        regex
            .find_iter(text)
            .map(|m| cache.get_or_intern(m.as_str()))
            .collect()
    }

    /// Yields every `(context, next)` pair in `tokens`. For each position
    /// after the first, these are the contexts of length `1..=state_size`
    /// (clamped to the tokens available) that end right before it.
    fn transitions(
        tokens: &[Spur],
        state_size: usize,
    ) -> impl Iterator<Item = (&[Spur], Spur)> + '_ {
        (1..tokens.len()).flat_map(move |end| {
            (1..=state_size.min(end)).map(move |len| (&tokens[end - len..end], tokens[end]))
        })
    }

    /// Joins tokens with single spaces (no leading or trailing separator).
    fn join_tokens<'t>(tokens: impl Iterator<Item = &'t str>) -> String {
        let mut out = String::new();
        for (i, token) in tokens.enumerate() {
            if i > 0 {
                out.push(' ');
            }
            out.push_str(token);
        }
        out
    }
}
/// Iterator over generated tokens, created by [`RawMarkovChain::iter`] and
/// [`RawMarkovChain::iter_start`].
///
/// Yields at most `count` tokens, each resolved from the chain's interner.
pub struct MarkovChainIter<'a, const N: usize> {
    /// Chain being sampled.
    chain: &'a RawMarkovChain<N>,
    /// Number of tokens this iterator may still produce.
    count: usize,
    /// Source of randomness for every sampling step.
    rng: &'a mut dyn RngCore,
    /// Most recent tokens, oldest first, used as the context for the next step.
    prev: Vec<Spur>,
}
impl<'a, const N: usize> Iterator for MarkovChainIter<'a, N> {
    type Item = &'a str;

    /// Samples the next token. Returns `None` once `count` tokens have been
    /// requested or when the chain has nothing to sample from.
    fn next(&mut self) -> Option<Self::Item> {
        if self.count == 0 {
            return None;
        }
        self.count -= 1;
        let next_spur = self.chain.next_step(&self.prev, &mut self.rng)?;
        let next = self.chain.cache.resolve(&next_spur);
        // Slide the context window: append the new token, then drop the oldest
        // ones so at most `state_size` remain. Trimming after the push means a
        // `state_size` of 0 just keeps the context empty instead of calling
        // `remove(0)` on an empty Vec (which panics).
        self.prev.push(next_spur);
        let excess = self.prev.len().saturating_sub(self.chain.state_size);
        self.prev.drain(..excess);
        Some(next)
    }

    /// At most `count` more tokens will be produced. There may be fewer if
    /// the chain has nothing to sample.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.count))
    }
}
/// Successor list for one context.
///
/// Every observation (or unit of weight) is stored as a separate entry, so a
/// uniform pick from `items` selects each token in proportion to its count.
/// With the `serialize` feature it (de)serializes as the bare `items` list.
#[cfg_attr(
    feature = "serialize",
    derive(Serialize, Deserialize),
    serde(transparent)
)]
#[derive(Clone)]
struct ChainItem {
    /// Interned successor tokens, one entry per observation.
    items: Vec<Spur>,
}
impl ChainItem {
    /// Starts a successor list holding a single observation of `s`.
    #[inline]
    fn new(s: Spur) -> Self {
        Self {
            items: Vec::from([s]),
        }
    }

    /// Starts a successor list holding `weight` observations of `s`.
    #[inline]
    fn new_weighted(s: Spur, weight: usize) -> Self {
        let items = vec![s; weight];
        Self { items }
    }

    /// Records one more observation of `s`.
    #[inline]
    fn add(&mut self, s: Spur) {
        self.items.push(s);
    }

    /// Records `weight` more observations of `s`.
    #[inline]
    fn add_weighted(&mut self, s: Spur, weight: usize) {
        self.items.extend((0..weight).map(|_| s));
    }

    /// Picks one stored observation uniformly at random. Each token is
    /// therefore chosen in proportion to how often it was recorded.
    /// Returns `None` if the list is empty.
    #[inline]
    fn get_rand(&self, rng: &mut impl RngCore) -> Option<Spur> {
        self.items.choose(rng).copied()
    }
}