use hashbrown::HashMap;
use itertools::Itertools;
use rand::Rng;
use rand::seq::IteratorRandom;
use unicode_segmentation::UnicodeSegmentation;
use crate::distribution::{TokenDistribution, TokenDistributionBuilder};
use crate::token::{TokenPair, TokenPairRef, TokenRef};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// An immutable Markov chain over tokens.
///
/// Every key is a pair of consecutive tokens; the value is the distribution
/// from which the token following that pair is drawn. Instances are created
/// through [`Chain::from_text`] or a [`ChainBuilder`].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Chain {
/// Token pair -> distribution of the tokens that may follow it.
map: HashMap<TokenPair, TokenDistribution>,
}
impl Chain {
    /// Builds a chain directly from `content`.
    ///
    /// The text is tokenized and recorded via [`ChainBuilder::feed_str`] and
    /// the resulting builder is turned into a chain right away.
    ///
    /// # Errors
    ///
    /// Hands back the builder when `content` does not contain enough tokens
    /// to record at least one `(pair, next)` occurrence.
    pub fn from_text(content: &str) -> Result<Self, ChainBuilder> {
        Self::builder().feed_str(content)?.chain_builder.build()
    }

    /// Returns an empty [`ChainBuilder`].
    pub fn builder() -> ChainBuilder {
        ChainBuilder::new()
    }

    /// Iterates over every token pair known to the chain.
    pub fn pairs(&self) -> impl Iterator<Item = &TokenPair> {
        self.map.keys()
    }

    /// Picks one of the known token pairs at random.
    ///
    /// Returns `None` only when the chain contains no pairs.
    pub fn start_tokens(&self, rng: &mut impl Rng) -> Option<&TokenPair> {
        self.pairs().choose(rng)
    }

    /// Generates exactly `n` tokens, starting after a randomly chosen pair.
    ///
    /// Returns `None` when no start pair can be chosen, i.e. the chain is
    /// empty. See [`Chain::generate_n_tokens`] for how dead ends are handled.
    pub fn generate_str(&self, rng: &mut impl Rng, n: usize) -> Option<Vec<&str>> {
        let start = self.start_tokens(rng)?;
        self.generate_n_tokens(rng, &start.as_ref(), n)
    }

    /// Draws one token that may follow the pair `prev`.
    ///
    /// Returns `None` when `prev` is not a pair known to the chain.
    pub fn generate_next_token(
        &self,
        rng: &mut impl Rng,
        prev: &TokenPairRef<'_>,
    ) -> Option<TokenRef<'_>> {
        let dist = self.map.get(prev)?;
        Some(dist.get_random_token(rng))
    }

    /// Generates exactly `n` tokens following the pair `prev`.
    ///
    /// Whenever the walk reaches a pair without a known successor, it is
    /// restarted from a random pair (see [`Chain::start_tokens`]); the tokens
    /// of that pair become part of the output, truncated so that the result
    /// never exceeds `n` tokens.
    ///
    /// Returns `Some(vec![])` for `n == 0`, and `None` when `prev` itself is
    /// unknown to the chain.
    pub fn generate_n_tokens(
        &self,
        rng: &mut impl Rng,
        prev: &TokenPairRef<'_>,
        n: usize,
    ) -> Option<Vec<TokenRef<'_>>> {
        if n == 0 {
            return Some(Vec::new());
        }
        let first = self.generate_next_token(rng, prev)?;
        let mut res = Vec::with_capacity(n);
        res.push(first);
        let (mut left, mut right) = (prev.1, first);
        while res.len() < n {
            if let Some(next) = self.generate_next_token(rng, &(left, right)) {
                res.push(next);
                left = right;
                right = next;
                continue;
            }
            // Dead end: restart from a random pair. The lookup for `first`
            // succeeded, so the map is non-empty and a pair is always found;
            // `?` merely avoids carrying a panic path for that invariant.
            let tp = self.start_tokens(rng)?;
            res.push(&tp.0);
            // Only emit the second half of the pair if there is room left;
            // otherwise the loop condition ends the walk right here.
            if res.len() < n {
                res.push(&tp.1);
            }
            left = &tp.0;
            right = &tp.1;
        }
        Some(res)
    }

    /// Generates at most `n` tokens following the pair `prev`.
    ///
    /// Unlike [`Chain::generate_n_tokens`] the walk is not restarted: it
    /// simply stops at the first pair without a known successor, so the
    /// result may be shorter than `n`.
    ///
    /// Returns `Some(vec![])` for `n == 0`, and `None` when `prev` itself is
    /// unknown to the chain.
    pub fn generate_max_n_tokens(
        &self,
        rng: &mut impl Rng,
        prev: &TokenPairRef<'_>,
        n: usize,
    ) -> Option<Vec<TokenRef<'_>>> {
        if n == 0 {
            return Some(Vec::new());
        }
        let first = self.generate_next_token(rng, prev)?;
        let mut res = Vec::with_capacity(n);
        res.push(first);
        let (mut left, mut right) = (prev.1, first);
        while res.len() < n {
            let Some(next) = self.generate_next_token(rng, &(left, right)) else {
                break;
            };
            res.push(next);
            left = right;
            right = next;
        }
        Some(res)
    }
}
/// Result of feeding tokens into a [`ChainBuilder`].
///
/// `Ok` carries the updated builder together with feed statistics; `Err`
/// hands the builder back when the input was too short to record anything.
pub type FeedResult = Result<UpdatedChainBuilder, ChainBuilder>;
/// Mutable accumulator from which a [`Chain`] is built.
///
/// Token occurrences are recorded per token pair and turned into final
/// distributions by [`ChainBuilder::build`].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ChainBuilder {
/// Token pair -> builder collecting the tokens seen after that pair.
map: HashMap<TokenPair, TokenDistributionBuilder>,
}
impl ChainBuilder {
    /// Creates a builder that has not recorded any pair yet.
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Finalizes every per-pair distribution and produces the [`Chain`].
    ///
    /// # Errors
    ///
    /// Returns the builder unchanged when nothing has been recorded.
    pub fn build(self) -> Result<Chain, ChainBuilder> {
        if self.map.is_empty() {
            return Err(self);
        }
        let map = self
            .map
            .into_iter()
            .map(|(pair, builder)| (pair, builder.build()))
            .collect();
        Ok(Chain { map })
    }

    /// Records that `next` was seen directly after the pair `prev`.
    ///
    /// Reports whether `prev` was already known ([`AddedPair::Updated`]) or
    /// had to be inserted ([`AddedPair::New`]). The spelling of the method
    /// name is part of the public API and therefore kept as is.
    pub fn add_occurance(&mut self, prev: &TokenPairRef<'_>, next: &str) -> AddedPair {
        // Look up by the borrowed pair first so that an owned `TokenPair`
        // is only allocated for pairs that are genuinely new.
        if let Some(known) = self.map.get_mut(prev) {
            known.add_token(next);
            return AddedPair::Updated;
        }
        let mut fresh = TokenDistributionBuilder::new();
        fresh.add_token(next);
        self.map.insert(TokenPair::from(prev), fresh);
        AddedPair::New
    }

    /// Splits `content` on Unicode word boundaries and feeds the tokens.
    ///
    /// See [`ChainBuilder::feed_tokens`] for the result semantics.
    pub fn feed_str(self, content: &str) -> FeedResult {
        self.feed_tokens(content.split_word_bounds())
    }

    /// Records every window of three consecutive tokens as `(pair, next)`.
    ///
    /// # Errors
    ///
    /// Returns the builder unchanged when `tokens` yields fewer than three
    /// items, i.e. when not a single window could be recorded.
    pub fn feed_tokens<'a, T: Iterator<Item = TokenRef<'a>>>(mut self, tokens: T) -> FeedResult {
        let (mut new_pairs, mut updated_pairs) = (0_usize, 0_usize);
        for (left, right, next) in tokens.tuple_windows() {
            match self.add_occurance(&(left, right), next) {
                AddedPair::New => new_pairs += 1,
                AddedPair::Updated => updated_pairs += 1,
            }
        }
        // Each window bumps exactly one counter, so two zeros mean that the
        // input produced no window at all.
        if new_pairs == 0 && updated_pairs == 0 {
            return Err(self);
        }
        Ok(UpdatedChainBuilder {
            chain_builder: self,
            new_pairs,
            updated_pairs,
        })
    }
}
impl Default for ChainBuilder {
fn default() -> Self {
Self::new()
}
}
/// A [`ChainBuilder`] after a successful feed, plus statistics about that feed.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UpdatedChainBuilder {
/// The builder including the occurrences recorded by the feed.
pub chain_builder: ChainBuilder,
/// Number of occurrences whose token pair was not known before.
pub new_pairs: usize,
/// Number of occurrences whose token pair was already known.
pub updated_pairs: usize,
}
impl From<UpdatedChainBuilder> for ChainBuilder {
fn from(value: UpdatedChainBuilder) -> Self {
value.chain_builder
}
}
impl From<FeedResult> for ChainBuilder {
    /// Recovers the builder from either outcome of a feed: the updated one
    /// on success, the untouched one on failure.
    fn from(value: FeedResult) -> Self {
        value.map_or_else(|unchanged| unchanged, |updated| updated.chain_builder)
    }
}
/// Outcome of recording a single `(pair, next)` occurrence.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AddedPair {
    /// The token pair was not known before and has been inserted.
    New,
    /// The token pair was already known; its distribution was extended.
    Updated,
}
// Private marker trait used as a supertrait of `IntoChainBuilder`. Because it
// is not `pub`, only the types listed here can implement `IntoChainBuilder`.
trait SealedIntoChainBuilder {}
impl SealedIntoChainBuilder for FeedResult {}
impl SealedIntoChainBuilder for UpdatedChainBuilder {}
/// Conversion of a feed outcome back into a plain [`ChainBuilder`].
///
/// Implemented for [`FeedResult`] and [`UpdatedChainBuilder`] only.
// The supertrait is private on purpose (it seals this trait), which is what
// the `private_bounds` lint would otherwise complain about.
#[allow(private_bounds)]
pub trait IntoChainBuilder: SealedIntoChainBuilder {
/// Consumes `self` and returns the builder it contains.
fn into_cb(self) -> ChainBuilder;
}
impl IntoChainBuilder for FeedResult {
fn into_cb(self) -> ChainBuilder {
match self {
Ok(ucb) => ucb.chain_builder,
Err(cb) => cb,
}
}
}
impl IntoChainBuilder for UpdatedChainBuilder {
fn into_cb(self) -> ChainBuilder {
self.chain_builder
}
}
#[cfg(test)]
mod tests {
    use rand::rng;

    use crate::{Chain, ChainBuilder, chain::IntoChainBuilder, distribution::TokenDistribution};

    /// Every token of this text is unique, so each pair has exactly one
    /// possible successor and generation from `("I", " ")` is deterministic.
    const CATS: &str = "I am-full!of?cats";

    /// Builds the chain for [`CATS`].
    fn cats_chain() -> Chain {
        Chain::builder().feed_str(CATS).into_cb().build().unwrap()
    }

    #[test]
    #[should_panic]
    fn empty_chain_builder_panics() {
        let _ = Chain::builder().build().unwrap();
    }

    #[test]
    #[should_panic]
    fn empty_token_dist_builder_panics() {
        let _ = TokenDistribution::builder().build();
    }

    #[test]
    fn feed_too_few_tokens() {
        let s = "I ";
        assert!(Chain::builder().feed_str(s).is_err());
    }

    #[test]
    fn simple_single_possible_token() {
        let s = "I am";
        let cb = Chain::builder().feed_str(s).into_cb();
        let chain = cb.build().unwrap();
        assert_eq!(
            chain.generate_next_token(&mut rng(), &("I", " ")).unwrap(),
            "am"
        );
    }

    #[test]
    fn simple_single_impossible_token() {
        let s = "I am";
        let cb = Chain::builder().feed_str(s).into_cb();
        let chain = cb.build().unwrap();
        assert!(
            chain
                .generate_next_token(&mut rng(), &("You", " "))
                .is_none()
        );
    }

    #[test]
    fn simple_generate_max_n_tokens() {
        let chain = cats_chain();
        assert_eq!(
            chain
                .generate_max_n_tokens(&mut rng(), &("I", " "), 7)
                .unwrap(),
            vec!["am", "-", "full", "!", "of", "?", "cats"],
        );
        assert_eq!(
            chain
                .generate_max_n_tokens(&mut rng(), &("I", " "), 2)
                .unwrap(),
            vec!["am", "-"],
        );
        // The walk stops at the end of the text instead of restarting.
        assert_eq!(
            chain
                .generate_max_n_tokens(&mut rng(), &("I", " "), 13)
                .unwrap()
                .len(),
            7
        );
    }

    #[test]
    fn simple_generate_n_tokens() {
        let chain = cats_chain();
        assert_eq!(
            chain.generate_n_tokens(&mut rng(), &("I", " "), 7).unwrap(),
            vec!["am", "-", "full", "!", "of", "?", "cats"],
        );
        assert_eq!(
            chain.generate_n_tokens(&mut rng(), &("I", " "), 2).unwrap(),
            vec!["am", "-"],
        );
        // Beyond the end of the text the walk restarts, so the requested
        // length is always reached (13 and 8 cover both restart paths).
        assert_eq!(
            chain
                .generate_n_tokens(&mut rng(), &("I", " "), 13)
                .unwrap()
                .len(),
            13
        );
        assert_eq!(
            chain
                .generate_n_tokens(&mut rng(), &("I", " "), 8)
                .unwrap()
                .len(),
            8
        );
    }

    #[test]
    fn simple_generate_max_n_tokens_zero() {
        let chain = cats_chain();
        assert!(
            chain
                .generate_max_n_tokens(&mut rng(), &("I", " "), 0)
                .unwrap()
                .is_empty()
        )
    }

    #[test]
    fn simple_generate_max_n_tokens_impossible_first() {
        let chain = cats_chain();
        assert!(
            chain
                .generate_max_n_tokens(&mut rng(), &("You", " "), 13)
                .is_none()
        )
    }

    #[test]
    fn simple_generate_n_tokens_zero() {
        let chain = cats_chain();
        assert!(
            chain
                .generate_n_tokens(&mut rng(), &("I", " "), 0)
                .unwrap()
                .is_empty()
        )
    }

    #[test]
    fn simple_generate_n_tokens_impossible_first() {
        let chain = cats_chain();
        assert!(
            chain
                .generate_n_tokens(&mut rng(), &("You", " "), 13)
                .is_none()
        )
    }

    #[test]
    fn generate_long_from_start_tokens() {
        let s = r#"
Coach: How's it going, Norm?
Norm: Daddy's rich and Momma's good lookin'.
-- Cheers, Truce or Consequences
Sam: What's up, Norm?
Norm: My nipples. It's freezing out there.
-- Cheers, Coach Returns to Action
Coach: What's the story, Norm?
Norm: Thirsty guy walks into a bar. You finish it.
-- Cheers, Endless Slumper
"#;
        let cb = Chain::builder().feed_str(s).into_cb();
        let chain = cb.build().unwrap();
        let mut rng = rng();
        for _ in 0..100 {
            let start = chain.start_tokens(&mut rng).unwrap();
            // A start pair taken from the chain always has a successor, and
            // dead ends are restarted, so exactly 100 tokens must come back.
            let tokens = chain
                .generate_n_tokens(&mut rng, &start.as_ref(), 100)
                .unwrap();
            assert_eq!(tokens.len(), 100);
        }
    }

    #[test]
    fn generate_long_using_generate_str() {
        let s = r#"
The difference between a program and a script isn't as subtle as most people
think. A script is interpreted, and a program is compiled.
Of course, there's no reason you can't write a compiler that immediately
executes the compiled form of a program without writing compilation artifacts
to disk, but that's an implementation detail, and precision in technical
matters is important.
Though Perl 5, for example, doesn't write out the artifacts of compilation to
disk and Java and .Net do, Perl 5 is clearly an interpreter even though it
evaluates the compiled form of code in the same way that the JVM and the CLR
do. Why? Because it's a scripting language.
Okay, that's a facetious explanation.
The difference between a program and a script is if there's native compilation
available in at least one widely-used implementation. Thus Java before the
prevalence of even the HotSpot JVM and its JIT was a scripting language and
now it's a programming language, except that you can write a C interpreter
that doesn't have a JIT and C programs become scripts.
-- chromatic
-- "Program vs. Script" ( http://use.perl.org/~chromatic/journal/35804 )
"#;
        let chain = Chain::from_text(s).unwrap();
        for _ in 0..100 {
            let tokens = chain.generate_str(&mut rng(), 100).unwrap();
            assert_eq!(tokens.len(), 100);
        }
    }

    #[test]
    fn get_pairs() {
        let s = r#"
This is a text.
There are many like it, but this one is mine.
-- Unknown
"#;
        let chain = Chain::from_text(s).unwrap();
        let good_starting_points: Vec<_> =
            chain.pairs().filter(|tp| tp.0.as_str() == "\n").collect();
        assert_eq!(good_starting_points.len(), 3);
    }

    #[test]
    fn feed_stats() {
        let cb = ChainBuilder::new();
        let ucb = cb
            .feed_tokens("hi hi what hi hi end".split_whitespace())
            .unwrap();
        assert_eq!(ucb.new_pairs, 3);
        assert_eq!(ucb.updated_pairs, 1, "hi hi should be updated once");
    }
}