use crate::chain::{Chain, BEGIN};
use crate::errors::{MarkovError, Result};
use crate::splitters::split_into_sentences;
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize};
/// Fallback used by `make_sentence` when the caller passes no `max_overlap_ratio`.
const DEFAULT_MAX_OVERLAP_RATIO: f64 = 0.7;
/// Fallback used by `make_sentence` when the caller passes no `max_overlap_total`.
const DEFAULT_MAX_OVERLAP_TOTAL: usize = 15;
/// Fallback number of generation attempts when the caller passes no `tries`.
const DEFAULT_TRIES: usize = 10;
// Regexes compiled once, on first use, and shared by every model.
lazy_static! {
// Default "reject" pattern: matches an apostrophe at the very start or end of
// the text, an apostrophe next to whitespace, or any double quote, round
// bracket or square bracket. Used when no custom reject regex is supplied.
static ref REJECT_PAT: Regex = Regex::new(r#"(^')|('$)|\s'|'\s|["(\(\)\[\])]"#).unwrap();
// One or more whitespace characters; the separator used to split a sentence
// into words.
static ref WORD_SPLIT_PATTERN: Regex = Regex::new(r"\s+").unwrap();
}
/// Serializable snapshot of a [`Text`] model.
///
/// This is the shape written by `Text::to_json` and read back by
/// `Text::from_json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextData {
    /// State size recorded from the model.
    pub state_size: usize,
    /// The chain, already serialized to its own JSON string.
    pub chain: String,
    /// The retained corpus (one word list per sentence), if the model kept it.
    pub parsed_sentences: Option<Vec<Vec<String>>>,
}
/// A Markov-chain text model: the chain itself plus the options that control
/// how input sentences are filtered and how generated sentences are checked.
#[derive(Debug, Clone)]
pub struct Text {
    /// Number of words that make up one chain state.
    state_size: usize,
    /// The underlying Markov chain used to walk out new sentences.
    chain: Chain,
    /// The parsed corpus (one word list per sentence), when it was retained.
    parsed_sentences: Option<Vec<Vec<String>>>,
    /// The retained corpus re-joined into a single string; generated
    /// sentences are compared against it to reject large verbatim overlaps.
    rejoined_text: Option<String>,
    /// Whether the model was asked to keep the original corpus.
    retain_original: bool,
    /// When `true`, input sentences matching `reject_pat` are discarded.
    well_formed: bool,
    /// Pattern identifying input sentences that should be rejected.
    reject_pat: Regex,
}
impl Text {
/// Builds a model from raw text.
///
/// The text is turned into a corpus of word lists (see `generate_corpus`),
/// from which the chain is built.
///
/// * `input_text` - the raw corpus.
/// * `state_size` - number of words per chain state.
/// * `retain_original` - keep the parsed corpus so generated sentences can be
///   checked for overlap with it.
/// * `well_formed` - discard input sentences matching the reject pattern.
/// * `reject_reg` - custom reject pattern; the built-in default is used when
///   `None`.
///
/// # Errors
///
/// Returns `MarkovError::ParamError` when `reject_reg` is not a valid regex.
pub fn new(
    input_text: &str,
    state_size: usize,
    retain_original: bool,
    well_formed: bool,
    reject_reg: Option<&str>,
) -> Result<Self> {
    let reject_pat = match reject_reg {
        Some(pattern) => Regex::new(pattern)
            .map_err(|e| MarkovError::ParamError(format!("Invalid regex: {}", e)))?,
        None => REJECT_PAT.clone(),
    };
    let corpus: Vec<Vec<String>> = Self::generate_corpus(input_text, &reject_pat, well_formed);

    // The re-joined text is only needed (and only built) when the corpus is
    // retained and actually contains something.
    let rejoined_text = match (retain_original, corpus.is_empty()) {
        (true, false) => {
            let joined: Vec<String> = corpus
                .iter()
                .map(|words| Self::word_join_static(words))
                .collect();
            Some(Self::sentence_join_static(&joined))
        }
        _ => None,
    };

    let chain = Chain::new(&corpus, state_size);
    let parsed_sentences = if retain_original { Some(corpus) } else { None };

    Ok(Text {
        state_size,
        chain,
        parsed_sentences,
        rejoined_text,
        retain_original,
        well_formed,
        reject_pat,
    })
}
/// Wraps an already-built chain in a model.
///
/// The state size is taken from the chain. The model always uses the
/// default reject pattern and `well_formed = true`.
///
/// The corpus is kept only when `retain_original` is `true` *and*
/// `parsed_sentences` is `Some`; in every other case the model stores no
/// corpus and reports `retain_original() == false`. This keeps the flag, the
/// stored corpus and the overlap text in agreement with each other, matching
/// what `new` and `from_json` produce, so a JSON round-trip does not change
/// the model's behaviour.
///
/// * `chain` - the chain to generate from.
/// * `parsed_sentences` - the corpus the chain was built from, if available.
/// * `retain_original` - whether to keep that corpus for overlap checking.
pub fn from_chain(
    chain: Chain,
    parsed_sentences: Option<Vec<Vec<String>>>,
    retain_original: bool,
) -> Self {
    let state_size = chain.state_size();

    // Drop the corpus when the caller did not ask for it to be retained.
    let parsed_sentences = if retain_original {
        parsed_sentences
    } else {
        None
    };
    // Retaining is only meaningful when there is a corpus to retain.
    let retain_original = parsed_sentences.is_some();

    let rejoined_text = parsed_sentences.as_ref().map(|sentences| {
        let joined: Vec<String> = sentences
            .iter()
            .map(|words| Self::word_join_static(words))
            .collect();
        Self::sentence_join_static(&joined)
    });

    Text {
        state_size,
        chain,
        parsed_sentences,
        rejoined_text,
        retain_original,
        well_formed: true,
        reject_pat: REJECT_PAT.clone(),
    }
}
/// Splits `text` into sentences using the crate's `split_into_sentences`.
pub fn sentence_split(&self, text: &str) -> Vec<String> {
    let sentences = split_into_sentences(text);
    sentences
}
pub fn sentence_join(&self, sentences: &[String]) -> String {
sentences.join(" ")
}
/// Splits a sentence into words on runs of whitespace, dropping empty pieces
/// (such as those produced by leading or trailing whitespace).
pub fn word_split(&self, sentence: &str) -> Vec<String> {
    let mut words = Vec::new();
    for piece in WORD_SPLIT_PATTERN.split(sentence) {
        if !piece.is_empty() {
            words.push(piece.to_owned());
        }
    }
    words
}
pub fn word_join(&self, words: &[String]) -> String {
words.join(" ")
}
/// Decides whether an input sentence is acceptable for the corpus.
///
/// Blank (empty or whitespace-only) sentences are always rejected. When the
/// model is `well_formed`, sentences matching the reject pattern are rejected
/// too. Returns `true` when the sentence is acceptable.
pub fn test_sentence_input(&self, sentence: &str) -> bool {
    let is_blank = sentence.trim().is_empty();
    // Short-circuiting keeps the regex from running on blank input or when
    // well-formedness checking is disabled.
    !is_blank && !(self.well_formed && self.reject_pat.is_match(sentence))
}
/// Turns raw text into a corpus: one list of words per accepted sentence.
///
/// The text is split with `split_into_sentences`; each sentence is then
/// filtered with the same rules as `test_sentence_input`:
///
/// * blank (empty or whitespace-only) sentences are always dropped, so the
///   corpus never contains empty word lists regardless of `well_formed`;
/// * when `well_formed` is `true`, sentences matching `reject_pat` are
///   dropped as well.
///
/// Surviving sentences are split into words on runs of whitespace.
fn generate_corpus(text: &str, reject_pat: &Regex, well_formed: bool) -> Vec<Vec<String>> {
    split_into_sentences(text)
        .into_iter()
        .filter(|sentence| {
            // The blank check must come first: it applies even when
            // well-formedness checking is switched off.
            if sentence.trim().is_empty() {
                return false;
            }
            !(well_formed && reject_pat.is_match(sentence))
        })
        .map(|sentence| {
            WORD_SPLIT_PATTERN
                .split(&sentence)
                .filter(|word| !word.is_empty())
                .map(String::from)
                .collect()
        })
        .collect()
}
fn test_sentence_output(
&self,
words: &[String],
max_overlap_ratio: f64,
max_overlap_total: usize,
) -> bool {
if let Some(ref rejoined) = self.rejoined_text {
let overlap_ratio = ((max_overlap_ratio * words.len() as f64).round() as usize).max(1);
let overlap_max = overlap_ratio.min(max_overlap_total);
let overlap_over = overlap_max + 1;
let gram_count = words.len().saturating_sub(overlap_max).max(1);
for i in 0..gram_count {
let gram = &words[i..(i + overlap_over).min(words.len())];
let gram_joined = self.word_join(gram);
if rejoined.contains(&gram_joined) {
return false;
}
}
}
true
}
/// Tries to generate one sentence.
///
/// Each attempt walks the chain from `init_state` (or from the chain's own
/// start when `None`) and prepends the non-`BEGIN` words of `init_state`.
/// An attempt is discarded when it has more than `max_words` or fewer than
/// `min_words` words, or when overlap testing is on and the sentence fails
/// `test_sentence_output`.
///
/// * `tries` - number of attempts (default `DEFAULT_TRIES`).
/// * `max_overlap_ratio` / `max_overlap_total` - overlap limits (defaults
///   `DEFAULT_MAX_OVERLAP_RATIO` / `DEFAULT_MAX_OVERLAP_TOTAL`).
/// * `test_output` - whether to run the overlap test (default `true`); it is
///   skipped anyway when the model holds no re-joined corpus.
///
/// Returns the first accepted sentence, or `None` if every attempt failed.
#[allow(clippy::too_many_arguments)]
pub fn make_sentence(
    &self,
    init_state: Option<&[String]>,
    tries: Option<usize>,
    max_overlap_ratio: Option<f64>,
    max_overlap_total: Option<usize>,
    test_output: Option<bool>,
    max_words: Option<usize>,
    min_words: Option<usize>,
) -> Option<String> {
    let attempts = tries.unwrap_or(DEFAULT_TRIES);
    let ratio = max_overlap_ratio.unwrap_or(DEFAULT_MAX_OVERLAP_RATIO);
    let total = max_overlap_total.unwrap_or(DEFAULT_MAX_OVERLAP_TOTAL);
    let check_overlap = test_output.unwrap_or(true) && self.rejoined_text.is_some();

    // Words of the initial state that belong in the output (BEGIN markers
    // are padding, not text).
    let prefix: Vec<String> = init_state
        .map(|state| {
            state
                .iter()
                .filter(|w| *w != BEGIN)
                .cloned()
                .collect::<Vec<String>>()
        })
        .unwrap_or_default();

    (0..attempts).find_map(|_| {
        let mut words = prefix.clone();
        words.extend(self.chain.walk(init_state));

        let too_long = max_words.map_or(false, |max| words.len() > max);
        let too_short = min_words.map_or(false, |min| words.len() < min);
        if too_long || too_short {
            return None;
        }
        if check_overlap && !self.test_sentence_output(&words, ratio, total) {
            return None;
        }
        Some(self.word_join(&words))
    })
}
/// Tries to generate a sentence whose length in characters lies within
/// `min_chars..=max_chars`.
///
/// Length is measured in Unicode scalar values (`chars().count()`), not in
/// UTF-8 bytes, so non-ASCII text is measured the same way as ASCII text.
///
/// * `max_chars` - inclusive upper bound on the sentence length.
/// * `min_chars` - inclusive lower bound (default `0`).
/// * `tries` - number of outer attempts (default `DEFAULT_TRIES`). Each outer
///   attempt calls `make_sentence` with the same `tries` value, so up to
///   `tries * tries` chain walks may be performed.
/// * remaining arguments are forwarded to `make_sentence` unchanged.
///
/// Returns the first sentence that fits, or `None` if none was found.
#[allow(clippy::too_many_arguments)]
pub fn make_short_sentence(
    &self,
    max_chars: usize,
    min_chars: Option<usize>,
    init_state: Option<&[String]>,
    tries: Option<usize>,
    max_overlap_ratio: Option<f64>,
    max_overlap_total: Option<usize>,
    test_output: Option<bool>,
    max_words: Option<usize>,
    min_words: Option<usize>,
) -> Option<String> {
    let attempts = tries.unwrap_or(DEFAULT_TRIES);
    let min_chars = min_chars.unwrap_or(0);

    (0..attempts).find_map(|_| {
        self.make_sentence(
            init_state,
            Some(attempts),
            max_overlap_ratio,
            max_overlap_total,
            test_output,
            max_words,
            min_words,
        )
        .filter(|sentence| {
            // Count characters, not bytes: `String::len` would over-count
            // any multi-byte character.
            let char_count = sentence.chars().count();
            (min_chars..=max_chars).contains(&char_count)
        })
    })
}
/// Generates a sentence that starts with the words of `beginning`.
///
/// `beginning` is split into words and must contain between 1 and
/// `state_size` words.
///
/// * Exactly `state_size` words: they are used directly as the initial state.
/// * Fewer words with `strict == true`: the state is left-padded with `BEGIN`
///   markers, so `beginning` must open a sentence.
/// * Fewer words with `strict == false`: every chain state whose non-`BEGIN`
///   words start with `beginning` is tried in turn.
///
/// The remaining arguments are forwarded to `make_sentence`.
///
/// # Errors
///
/// Returns `MarkovError::ParamError` when
///
/// * `beginning` contains no words or more than `state_size` words,
/// * no chain state starts with `beginning` (non-strict mode), or
/// * no candidate state produced an acceptable sentence.
#[allow(clippy::too_many_arguments)]
pub fn make_sentence_with_start(
    &self,
    beginning: &str,
    strict: bool,
    tries: Option<usize>,
    max_overlap_ratio: Option<f64>,
    max_overlap_total: Option<usize>,
    test_output: Option<bool>,
    max_words: Option<usize>,
    min_words: Option<usize>,
) -> Result<String> {
    let split = self.word_split(beginning);
    let word_count = split.len();

    // An empty beginning would otherwise match every state (or become an
    // all-BEGIN state) and silently produce an unconstrained sentence.
    if word_count == 0 || word_count > self.state_size {
        return Err(MarkovError::ParamError(format!(
            "`make_sentence_with_start` for this model requires a string containing 1 to {} words. Yours has {}: {:?}",
            self.state_size, word_count, split
        )));
    }

    let init_states: Vec<Vec<String>> = if word_count == self.state_size {
        vec![split]
    } else if strict {
        // Pad on the left so the words sit at the start of a sentence.
        let mut state = vec![BEGIN.to_string(); self.state_size - word_count];
        state.extend(split);
        vec![state]
    } else {
        self.find_init_states_from_chain(&split)
    };

    if init_states.is_empty() {
        return Err(MarkovError::ParamError(format!(
            "Cannot find sentence beginning with: {}",
            beginning
        )));
    }

    for init_state in init_states {
        if let Some(output) = self.make_sentence(
            Some(&init_state),
            tries,
            max_overlap_ratio,
            max_overlap_total,
            test_output,
            max_words,
            min_words,
        ) {
            return Ok(output);
        }
    }

    Err(MarkovError::ParamError(format!(
        "Cannot generate sentence beginning with: {}",
        beginning
    )))
}
/// Collects every chain state whose words, ignoring `BEGIN` markers, start
/// with the words in `split`.
///
/// States are returned in the iteration order of the chain's model keys.
fn find_init_states_from_chain(&self, split: &[String]) -> Vec<Vec<String>> {
    let word_count = split.len();
    self.chain
        .model()
        .keys()
        .filter(|key| {
            // Compare the first `word_count` real words of the state with
            // `split`; a state with fewer real words can never be equal.
            key.iter()
                .filter(|w| *w != BEGIN)
                .take(word_count)
                .eq(split.iter())
        })
        .cloned()
        .collect()
}
pub fn compile(&self) -> Self {
let compiled_chain = self.chain.compile();
Text {
state_size: self.state_size,
chain: compiled_chain,
parsed_sentences: self.parsed_sentences.clone(),
rejoined_text: self.rejoined_text.clone(),
retain_original: self.retain_original,
well_formed: self.well_formed,
reject_pat: self.reject_pat.clone(),
}
}
/// Replaces this model's chain with the result of `Chain::compile`.
pub fn compile_inplace(&mut self) {
    let compiled = self.chain.compile();
    self.chain = compiled;
}
pub fn state_size(&self) -> usize {
self.state_size
}
pub fn chain(&self) -> &Chain {
&self.chain
}
/// Serializes the model to a JSON string.
///
/// The output is a `TextData` object holding the state size, the chain's own
/// JSON (as a nested string) and the retained corpus, if any.
///
/// # Errors
///
/// Propagates any error from serializing the chain or the wrapper object.
pub fn to_json(&self) -> Result<String> {
    let chain = self.chain.to_json()?;
    let snapshot = TextData {
        state_size: self.state_size,
        chain,
        parsed_sentences: self.parsed_sentences.clone(),
    };
    let encoded = serde_json::to_string(&snapshot)?;
    Ok(encoded)
}
pub fn from_json(json_str: &str) -> Result<Self> {
let data: TextData = serde_json::from_str(json_str)?;
let chain = Chain::from_json(&data.chain)?;
Ok(Text {
state_size: data.state_size,
chain,
parsed_sentences: data.parsed_sentences.clone(),
rejoined_text: data.parsed_sentences.as_ref().map(|sentences| {
Self::sentence_join_static(
&sentences
.iter()
.map(|s| Self::word_join_static(s))
.collect::<Vec<_>>(),
)
}),
retain_original: data.parsed_sentences.is_some(),
well_formed: true,
reject_pat: REJECT_PAT.clone(),
})
}
pub fn retain_original(&self) -> bool {
self.retain_original
}
/// Borrows the retained corpus (one word list per sentence), or `None` when
/// the model holds no corpus.
pub fn parsed_sentences(&self) -> Option<&Vec<Vec<String>>> {
    match self.parsed_sentences {
        Some(ref sentences) => Some(sentences),
        None => None,
    }
}
/// Joins sentences into one string, separated by single spaces.
///
/// Associated-function form, usable before a model instance exists.
fn sentence_join_static(sentences: &[String]) -> String {
    let separator = " ";
    sentences.join(separator)
}
/// Joins words into one string, separated by single spaces.
///
/// Associated-function form, usable before a model instance exists.
fn word_join_static(words: &[String]) -> String {
    let separator = " ";
    words.join(separator)
}
}
/// A text model for corpora in which each line is one sentence.
///
/// Wraps a [`Text`] and forwards generation and serialization to it.
#[derive(Debug, Clone)]
pub struct NewlineText {
    /// The wrapped general-purpose model.
    inner: Text,
}
impl NewlineText {
pub fn new(
input_text: &str,
state_size: usize,
retain_original: bool,
well_formed: bool,
reject_reg: Option<&str>,
) -> Result<Self> {
let text = Text::new(
input_text,
state_size,
retain_original,
well_formed,
reject_reg,
)?;
Ok(NewlineText { inner: text })
}
/// Splits `text` into sentences, one per line.
///
/// Lines are separated on `'\n'`, trimmed of surrounding whitespace, and
/// dropped when nothing is left after trimming.
pub fn sentence_split(&self, text: &str) -> Vec<String> {
    text.split('\n')
        .filter_map(|line| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                None
            } else {
                Some(trimmed.to_string())
            }
        })
        .collect()
}
/// Tries to generate one sentence.
///
/// Forwards every argument unchanged to `Text::make_sentence` on the wrapped
/// model and returns its result.
#[allow(clippy::too_many_arguments)]
pub fn make_sentence(
    &self,
    init_state: Option<&[String]>,
    tries: Option<usize>,
    max_overlap_ratio: Option<f64>,
    max_overlap_total: Option<usize>,
    test_output: Option<bool>,
    max_words: Option<usize>,
    min_words: Option<usize>,
) -> Option<String> {
    let model = &self.inner;
    model.make_sentence(
        init_state,
        tries,
        max_overlap_ratio,
        max_overlap_total,
        test_output,
        max_words,
        min_words,
    )
}
/// Tries to generate a sentence within the given length limits.
///
/// Forwards every argument unchanged to `Text::make_short_sentence` on the
/// wrapped model and returns its result.
#[allow(clippy::too_many_arguments)]
pub fn make_short_sentence(
    &self,
    max_chars: usize,
    min_chars: Option<usize>,
    init_state: Option<&[String]>,
    tries: Option<usize>,
    max_overlap_ratio: Option<f64>,
    max_overlap_total: Option<usize>,
    test_output: Option<bool>,
    max_words: Option<usize>,
    min_words: Option<usize>,
) -> Option<String> {
    let model = &self.inner;
    model.make_short_sentence(
        max_chars,
        min_chars,
        init_state,
        tries,
        max_overlap_ratio,
        max_overlap_total,
        test_output,
        max_words,
        min_words,
    )
}
/// Serializes the wrapped model to JSON via `Text::to_json`.
///
/// # Errors
///
/// Propagates any error returned by `Text::to_json`.
pub fn to_json(&self) -> Result<String> {
    let model = &self.inner;
    model.to_json()
}
/// Rebuilds a `NewlineText` from JSON by parsing it with `Text::from_json`
/// and wrapping the result.
///
/// # Errors
///
/// Propagates any error returned by `Text::from_json`.
pub fn from_json(json_str: &str) -> Result<Self> {
    Text::from_json(json_str).map(|inner| NewlineText { inner })
}
pub fn inner(&self) -> &Text {
&self.inner
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A model built with state size 2 reports that state size.
    #[test]
    fn test_text_creation() {
        let corpus = "Hello world. This is a test.";
        let built = Text::new(corpus, 2, true, true, None).unwrap();
        assert_eq!(built.state_size(), 2);
    }

    // With default options a small corpus yields some sentence.
    #[test]
    fn test_make_sentence() {
        let corpus = "The cat sat on the mat. The dog ran in the park. The bird flew over the tree. The cat chased the mouse. The dog barked loudly.";
        let built = Text::new(corpus, 1, true, true, None).unwrap();
        let generated = built.make_sentence(None, None, None, None, None, None, None);
        assert!(generated.is_some());
    }

    // The state size survives a to_json / from_json round-trip.
    #[test]
    fn test_json_serialization() {
        let corpus = "Hello world. This is a test.";
        let original = Text::new(corpus, 2, true, true, None).unwrap();
        let encoded = original.to_json().unwrap();
        let decoded = Text::from_json(&encoded).unwrap();
        assert_eq!(original.state_size(), decoded.state_size());
    }

    // Three lines of input split into three sentences.
    #[test]
    fn test_newline_text() {
        let text = "Line one
Line two
Line three";
        let built = NewlineText::new(text, 2, true, true, None).unwrap();
        let lines = built.sentence_split(text);
        assert_eq!(lines.len(), 3);
    }
}