1use crate::chain::{Chain, BEGIN};
4use crate::errors::{MarkovError, Result};
5use crate::splitters::split_into_sentences;
6use lazy_static::lazy_static;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9
/// Overlap ratio used by sentence generation when the caller passes `None`
/// for `max_overlap_ratio`.
const DEFAULT_MAX_OVERLAP_RATIO: f64 = 0.7;

/// Overlap word cap used by sentence generation when the caller passes `None`
/// for `max_overlap_total`.
const DEFAULT_MAX_OVERLAP_TOTAL: usize = 15;

/// Number of generation attempts used when the caller passes `None` for `tries`.
const DEFAULT_TRIES: usize = 10;
16
17lazy_static! {
18 static ref REJECT_PAT: Regex = Regex::new(r#"(^')|('$)|\s'|'\s|["(\(\)\[\])]"#).unwrap();
20 static ref WORD_SPLIT_PATTERN: Regex = Regex::new(r"\s+").unwrap();
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TextData {
27 pub state_size: usize,
28 pub chain: String, pub parsed_sentences: Option<Vec<Vec<String>>>,
30}
31
32#[derive(Debug, Clone)]
34pub struct Text {
35 state_size: usize,
36 chain: Chain,
37 parsed_sentences: Option<Vec<Vec<String>>>,
38 rejoined_text: Option<String>,
39 retain_original: bool,
40 well_formed: bool,
41 reject_pat: Regex,
42}
43
44impl Text {
45 pub fn new(
54 input_text: &str,
55 state_size: usize,
56 retain_original: bool,
57 well_formed: bool,
58 reject_reg: Option<&str>,
59 ) -> Result<Self> {
60 let reject_pat = if let Some(reg) = reject_reg {
61 Regex::new(reg).map_err(|e| MarkovError::ParamError(format!("Invalid regex: {}", e)))?
62 } else {
63 REJECT_PAT.clone()
64 };
65
66 let parsed_sentences: Vec<Vec<String>> =
67 Self::generate_corpus(input_text, &reject_pat, well_formed)
68 .into_iter()
69 .collect();
70
71 let rejoined_text = if retain_original && !parsed_sentences.is_empty() {
72 Some(Self::sentence_join_static(
73 &parsed_sentences
74 .iter()
75 .map(|s| Self::word_join_static(s))
76 .collect::<Vec<_>>(),
77 ))
78 } else {
79 None
80 };
81
82 let chain = Chain::new(&parsed_sentences, state_size);
83
84 Ok(Text {
85 state_size,
86 chain,
87 parsed_sentences: if retain_original {
88 Some(parsed_sentences)
89 } else {
90 None
91 },
92 rejoined_text,
93 retain_original,
94 well_formed,
95 reject_pat,
96 })
97 }
98
99 pub fn from_chain(
101 chain: Chain,
102 parsed_sentences: Option<Vec<Vec<String>>>,
103 retain_original: bool,
104 ) -> Self {
105 let state_size = chain.state_size();
106
107 let rejoined_text = if retain_original {
108 parsed_sentences.as_ref().map(|sentences| {
109 Self::sentence_join_static(
110 &sentences
111 .iter()
112 .map(|s| Self::word_join_static(s))
113 .collect::<Vec<_>>(),
114 )
115 })
116 } else {
117 None
118 };
119
120 Text {
121 state_size,
122 chain,
123 parsed_sentences,
124 rejoined_text,
125 retain_original,
126 well_formed: true,
127 reject_pat: REJECT_PAT.clone(),
128 }
129 }
130
131 pub fn sentence_split(&self, text: &str) -> Vec<String> {
133 split_into_sentences(text)
134 }
135
136 pub fn sentence_join(&self, sentences: &[String]) -> String {
138 sentences.join(" ")
139 }
140
141 pub fn word_split(&self, sentence: &str) -> Vec<String> {
143 WORD_SPLIT_PATTERN
144 .split(sentence)
145 .filter(|s| !s.is_empty())
146 .map(|s| s.to_string())
147 .collect()
148 }
149
150 pub fn word_join(&self, words: &[String]) -> String {
152 words.join(" ")
153 }
154
155 pub fn test_sentence_input(&self, sentence: &str) -> bool {
157 if sentence.trim().is_empty() {
158 return false;
159 }
160
161 if self.well_formed && self.reject_pat.is_match(sentence) {
162 return false;
163 }
164
165 true
166 }
167
168 fn generate_corpus(text: &str, reject_pat: &Regex, well_formed: bool) -> Vec<Vec<String>> {
170 let sentences = split_into_sentences(text);
171
172 sentences
173 .into_iter()
174 .filter(|s| {
175 if !well_formed {
176 return true;
177 }
178 if s.trim().is_empty() {
180 return false;
181 }
182 if reject_pat.is_match(s) {
183 return false;
184 }
185 true
186 })
187 .map(|s| {
188 WORD_SPLIT_PATTERN
189 .split(&s)
190 .filter(|s| !s.is_empty())
191 .map(|s| s.to_string())
192 .collect()
193 })
194 .collect()
195 }
196
197 fn test_sentence_output(
199 &self,
200 words: &[String],
201 max_overlap_ratio: f64,
202 max_overlap_total: usize,
203 ) -> bool {
204 if let Some(ref rejoined) = self.rejoined_text {
205 let overlap_ratio = ((max_overlap_ratio * words.len() as f64).round() as usize).max(1);
206 let overlap_max = overlap_ratio.min(max_overlap_total);
207 let overlap_over = overlap_max + 1;
208 let gram_count = words.len().saturating_sub(overlap_max).max(1);
209
210 for i in 0..gram_count {
211 let gram = &words[i..(i + overlap_over).min(words.len())];
212 let gram_joined = self.word_join(gram);
213 if rejoined.contains(&gram_joined) {
214 return false;
215 }
216 }
217 }
218 true
219 }
220
221 #[allow(clippy::too_many_arguments)]
232 pub fn make_sentence(
233 &self,
234 init_state: Option<&[String]>,
235 tries: Option<usize>,
236 max_overlap_ratio: Option<f64>,
237 max_overlap_total: Option<usize>,
238 test_output: Option<bool>,
239 max_words: Option<usize>,
240 min_words: Option<usize>,
241 ) -> Option<String> {
242 let tries = tries.unwrap_or(DEFAULT_TRIES);
243 let mor = max_overlap_ratio.unwrap_or(DEFAULT_MAX_OVERLAP_RATIO);
244 let mot = max_overlap_total.unwrap_or(DEFAULT_MAX_OVERLAP_TOTAL);
245 let test = test_output.unwrap_or(true);
246
247 let prefix: Vec<String> = if let Some(state) = init_state {
248 state.iter().filter(|w| *w != BEGIN).cloned().collect()
249 } else {
250 vec![]
251 };
252
253 for _ in 0..tries {
254 let mut words = prefix.clone();
255 words.extend(self.chain.walk(init_state));
256
257 if let Some(max) = max_words {
259 if words.len() > max {
260 continue;
261 }
262 }
263 if let Some(min) = min_words {
264 if words.len() < min {
265 continue;
266 }
267 }
268
269 if test && self.rejoined_text.is_some() {
271 if self.test_sentence_output(&words, mor, mot) {
272 return Some(self.word_join(&words));
273 }
274 } else {
275 return Some(self.word_join(&words));
276 }
277 }
278
279 None
280 }
281
282 #[allow(clippy::too_many_arguments)]
284 pub fn make_short_sentence(
285 &self,
286 max_chars: usize,
287 min_chars: Option<usize>,
288 init_state: Option<&[String]>,
289 tries: Option<usize>,
290 max_overlap_ratio: Option<f64>,
291 max_overlap_total: Option<usize>,
292 test_output: Option<bool>,
293 max_words: Option<usize>,
294 min_words: Option<usize>,
295 ) -> Option<String> {
296 let tries = tries.unwrap_or(DEFAULT_TRIES);
297 let min_chars = min_chars.unwrap_or(0);
298
299 for _ in 0..tries {
300 if let Some(sentence) = self.make_sentence(
301 init_state,
302 Some(tries),
303 max_overlap_ratio,
304 max_overlap_total,
305 test_output,
306 max_words,
307 min_words,
308 ) {
309 let len = sentence.len();
310 if len >= min_chars && len <= max_chars {
311 return Some(sentence);
312 }
313 }
314 }
315
316 None
317 }
318
319 #[allow(clippy::too_many_arguments)]
321 pub fn make_sentence_with_start(
322 &self,
323 beginning: &str,
324 strict: bool,
325 tries: Option<usize>,
326 max_overlap_ratio: Option<f64>,
327 max_overlap_total: Option<usize>,
328 test_output: Option<bool>,
329 max_words: Option<usize>,
330 min_words: Option<usize>,
331 ) -> Result<String> {
332 let split = self.word_split(beginning);
333 let word_count = split.len();
334
335 if word_count > self.state_size {
336 return Err(MarkovError::ParamError(format!(
337 "`make_sentence_with_start` for this model requires a string containing 1 to {} words. Yours has {}: {:?}",
338 self.state_size, word_count, split
339 )));
340 }
341
342 let init_states: Vec<Vec<String>> = if word_count == self.state_size {
343 vec![split.clone()]
344 } else if word_count < self.state_size {
345 if strict {
346 let mut state = vec![BEGIN.to_string(); self.state_size - word_count];
348 state.extend(split.clone());
349 vec![state]
350 } else {
351 self.find_init_states_from_chain(&split)
353 }
354 } else {
355 return Err(MarkovError::ParamError(format!(
356 "Invalid word count: {}",
357 word_count
358 )));
359 };
360
361 if init_states.is_empty() {
362 return Err(MarkovError::ParamError(format!(
363 "Cannot find sentence beginning with: {}",
364 beginning
365 )));
366 }
367
368 for init_state in init_states {
370 if let Some(output) = self.make_sentence(
371 Some(&init_state),
372 tries,
373 max_overlap_ratio,
374 max_overlap_total,
375 test_output,
376 max_words,
377 min_words,
378 ) {
379 return Ok(output);
380 }
381 }
382
383 Err(MarkovError::ParamError(format!(
384 "Cannot generate sentence beginning with: {}",
385 beginning
386 )))
387 }
388
389 fn find_init_states_from_chain(&self, split: &[String]) -> Vec<Vec<String>> {
391 let word_count = split.len();
392 let mut states = Vec::new();
393
394 for key in self.chain.model().keys() {
395 let filtered: Vec<&String> = key.iter().filter(|w| *w != BEGIN).collect();
397 if filtered.len() >= word_count
398 && filtered[..word_count]
399 .iter()
400 .zip(split.iter())
401 .all(|(a, b)| *a == b)
402 {
403 states.push(key.clone());
404 }
405 }
406
407 states
408 }
409
410 pub fn compile(&self) -> Self {
412 let compiled_chain = self.chain.compile();
413
414 Text {
415 state_size: self.state_size,
416 chain: compiled_chain,
417 parsed_sentences: self.parsed_sentences.clone(),
418 rejoined_text: self.rejoined_text.clone(),
419 retain_original: self.retain_original,
420 well_formed: self.well_formed,
421 reject_pat: self.reject_pat.clone(),
422 }
423 }
424
425 pub fn compile_inplace(&mut self) {
427 self.chain = self.chain.compile();
428 }
429
430 pub fn state_size(&self) -> usize {
432 self.state_size
433 }
434
435 pub fn chain(&self) -> &Chain {
437 &self.chain
438 }
439
440 pub fn to_json(&self) -> Result<String> {
442 let data = TextData {
443 state_size: self.state_size,
444 chain: self.chain.to_json()?,
445 parsed_sentences: self.parsed_sentences.clone(),
446 };
447 Ok(serde_json::to_string(&data)?)
448 }
449
450 pub fn from_json(json_str: &str) -> Result<Self> {
452 let data: TextData = serde_json::from_str(json_str)?;
453 let chain = Chain::from_json(&data.chain)?;
454
455 Ok(Text {
456 state_size: data.state_size,
457 chain,
458 parsed_sentences: data.parsed_sentences.clone(),
459 rejoined_text: data.parsed_sentences.as_ref().map(|sentences| {
460 Self::sentence_join_static(
461 &sentences
462 .iter()
463 .map(|s| Self::word_join_static(s))
464 .collect::<Vec<_>>(),
465 )
466 }),
467 retain_original: data.parsed_sentences.is_some(),
468 well_formed: true,
469 reject_pat: REJECT_PAT.clone(),
470 })
471 }
472
473 pub fn retain_original(&self) -> bool {
475 self.retain_original
476 }
477
478 pub fn parsed_sentences(&self) -> Option<&Vec<Vec<String>>> {
480 self.parsed_sentences.as_ref()
481 }
482
/// Joins sentences with a single space; usable without a model instance.
fn sentence_join_static(sentences: &[String]) -> String {
    let separator = " ";
    sentences.join(separator)
}
486
/// Joins words with a single space; usable without a model instance.
fn word_join_static(words: &[String]) -> String {
    let separator = " ";
    words.join(separator)
}
490}
491
492#[derive(Debug, Clone)]
494pub struct NewlineText {
495 inner: Text,
496}
497
498impl NewlineText {
499 pub fn new(
501 input_text: &str,
502 state_size: usize,
503 retain_original: bool,
504 well_formed: bool,
505 reject_reg: Option<&str>,
506 ) -> Result<Self> {
507 let text = Text::new(
508 input_text,
509 state_size,
510 retain_original,
511 well_formed,
512 reject_reg,
513 )?;
514 Ok(NewlineText { inner: text })
515 }
516
517 pub fn sentence_split(&self, text: &str) -> Vec<String> {
519 text.split('\n')
520 .map(|s| s.trim().to_string())
521 .filter(|s| !s.is_empty())
522 .collect()
523 }
524
525 #[allow(clippy::too_many_arguments)]
527 pub fn make_sentence(
528 &self,
529 init_state: Option<&[String]>,
530 tries: Option<usize>,
531 max_overlap_ratio: Option<f64>,
532 max_overlap_total: Option<usize>,
533 test_output: Option<bool>,
534 max_words: Option<usize>,
535 min_words: Option<usize>,
536 ) -> Option<String> {
537 self.inner.make_sentence(
538 init_state,
539 tries,
540 max_overlap_ratio,
541 max_overlap_total,
542 test_output,
543 max_words,
544 min_words,
545 )
546 }
547
548 #[allow(clippy::too_many_arguments)]
550 pub fn make_short_sentence(
551 &self,
552 max_chars: usize,
553 min_chars: Option<usize>,
554 init_state: Option<&[String]>,
555 tries: Option<usize>,
556 max_overlap_ratio: Option<f64>,
557 max_overlap_total: Option<usize>,
558 test_output: Option<bool>,
559 max_words: Option<usize>,
560 min_words: Option<usize>,
561 ) -> Option<String> {
562 self.inner.make_short_sentence(
563 max_chars,
564 min_chars,
565 init_state,
566 tries,
567 max_overlap_ratio,
568 max_overlap_total,
569 test_output,
570 max_words,
571 min_words,
572 )
573 }
574
575 pub fn to_json(&self) -> Result<String> {
577 self.inner.to_json()
578 }
579
580 pub fn from_json(json_str: &str) -> Result<Self> {
582 let text = Text::from_json(json_str)?;
583 Ok(NewlineText { inner: text })
584 }
585
586 pub fn inner(&self) -> &Text {
588 &self.inner
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built model reports the state size it was given.
    #[test]
    fn test_text_creation() {
        let corpus = "Hello world. This is a test.";
        let built = Text::new(corpus, 2, true, true, None);
        let model = built.unwrap();
        assert_eq!(model.state_size(), 2);
    }

    /// With default settings a small corpus yields some sentence.
    #[test]
    fn test_make_sentence() {
        let corpus = "The cat sat on the mat. The dog ran in the park. The bird flew over the tree. The cat chased the mouse. The dog barked loudly.";
        let model = Text::new(corpus, 1, true, true, None).unwrap();
        let generated = model.make_sentence(None, None, None, None, None, None, None);
        assert!(generated.is_some());
    }

    /// A model survives a JSON round trip with its state size intact.
    #[test]
    fn test_json_serialization() {
        let corpus = "Hello world. This is a test.";
        let model = Text::new(corpus, 2, true, true, None).unwrap();
        let encoded = model.to_json().unwrap();
        let restored = Text::from_json(&encoded).unwrap();
        assert_eq!(model.state_size(), restored.state_size());
    }

    /// Newline splitting produces one sentence per line.
    #[test]
    fn test_newline_text() {
        let text = "Line one
Line two
Line three";
        let model = NewlineText::new(text, 2, true, true, None).unwrap();
        let lines = model.sentence_split(text);
        assert_eq!(lines.len(), 3);
    }
}