Skip to main content

markovify_rs/
text.rs

1//! Text processing and sentence generation
2
3use crate::chain::{Chain, BEGIN};
4use crate::errors::{MarkovError, Result};
5use crate::splitters::split_into_sentences;
6use lazy_static::lazy_static;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9
/// How many walks `make_sentence` attempts before giving up, unless the caller overrides it.
const DEFAULT_TRIES: usize = 10;

/// Fraction of a generated sentence's words that may appear verbatim, in order, in the corpus.
const DEFAULT_MAX_OVERLAP_RATIO: f64 = 0.70;
/// Absolute cap on the number of words that may appear verbatim, in order, in the corpus.
const DEFAULT_MAX_OVERLAP_TOTAL: usize = 15;
16
17lazy_static! {
18    /// Pattern to reject sentences with problematic characters
19    static ref REJECT_PAT: Regex = Regex::new(r#"(^')|('$)|\s'|'\s|["(\(\)\[\])]"#).unwrap();
20    /// Pattern for splitting words
21    static ref WORD_SPLIT_PATTERN: Regex = Regex::new(r"\s+").unwrap();
22}
23
/// Serialized representation of [`Text`], as written by `Text::to_json` and read by
/// `Text::from_json`.
///
/// Only the fields below are stored: a model's `well_formed` flag and rejection pattern are
/// not persisted, and `Text::from_json` restores them as `true` / `REJECT_PAT`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextData {
    /// Number of words in each chain state.
    pub state_size: usize,
    /// The chain, itself already serialized to JSON via `Chain::to_json`, so it is nested in
    /// the outer document as a JSON string rather than an object.
    pub chain: String,
    /// The model's tokenized corpus (one word list per sentence), if it had one.
    pub parsed_sentences: Option<Vec<Vec<String>>>,
}
31
/// A Markov chain text model for generating random sentences.
#[derive(Debug, Clone)]
pub struct Text {
    /// Number of words in each chain state.
    state_size: usize,
    /// The Markov chain that `make_sentence` walks to produce new word sequences.
    chain: Chain,
    /// The tokenized corpus (one word list per sentence), if retained.
    parsed_sentences: Option<Vec<Vec<String>>>,
    /// The retained corpus re-joined into one space-separated string. Generated sentences are
    /// checked against it for verbatim overlap; `None` disables that check.
    rejoined_text: Option<String>,
    /// Whether the original corpus is retained.
    retain_original: bool,
    /// Whether input sentences matching `reject_pat` are rejected.
    well_formed: bool,
    /// Pattern used to reject input sentences when `well_formed` is set.
    reject_pat: Regex,
}
43
44impl Text {
45    /// Create a new Text model from input text
46    ///
47    /// # Arguments
48    /// * `input_text` - The source text to build the model from
49    /// * `state_size` - Number of words in the model's state (default: 2)
50    /// * `retain_original` - Whether to keep the original corpus for overlap checking
51    /// * `well_formed` - Whether to reject sentences with unmatched quotes/parenthesis
52    /// * `reject_reg` - Optional custom regex pattern for rejecting sentences
53    pub fn new(
54        input_text: &str,
55        state_size: usize,
56        retain_original: bool,
57        well_formed: bool,
58        reject_reg: Option<&str>,
59    ) -> Result<Self> {
60        let reject_pat = if let Some(reg) = reject_reg {
61            Regex::new(reg).map_err(|e| MarkovError::ParamError(format!("Invalid regex: {}", e)))?
62        } else {
63            REJECT_PAT.clone()
64        };
65
66        let parsed_sentences: Vec<Vec<String>> =
67            Self::generate_corpus(input_text, &reject_pat, well_formed)
68                .into_iter()
69                .collect();
70
71        let rejoined_text = if retain_original && !parsed_sentences.is_empty() {
72            Some(Self::sentence_join_static(
73                &parsed_sentences
74                    .iter()
75                    .map(|s| Self::word_join_static(s))
76                    .collect::<Vec<_>>(),
77            ))
78        } else {
79            None
80        };
81
82        let chain = Chain::new(&parsed_sentences, state_size);
83
84        Ok(Text {
85            state_size,
86            chain,
87            parsed_sentences: if retain_original {
88                Some(parsed_sentences)
89            } else {
90                None
91            },
92            rejoined_text,
93            retain_original,
94            well_formed,
95            reject_pat,
96        })
97    }
98
99    /// Create a Text model from an existing chain
100    pub fn from_chain(
101        chain: Chain,
102        parsed_sentences: Option<Vec<Vec<String>>>,
103        retain_original: bool,
104    ) -> Self {
105        let state_size = chain.state_size();
106
107        let rejoined_text = if retain_original {
108            parsed_sentences.as_ref().map(|sentences| {
109                Self::sentence_join_static(
110                    &sentences
111                        .iter()
112                        .map(|s| Self::word_join_static(s))
113                        .collect::<Vec<_>>(),
114                )
115            })
116        } else {
117            None
118        };
119
120        Text {
121            state_size,
122            chain,
123            parsed_sentences,
124            rejoined_text,
125            retain_original,
126            well_formed: true,
127            reject_pat: REJECT_PAT.clone(),
128        }
129    }
130
131    /// Split text into sentences
132    pub fn sentence_split(&self, text: &str) -> Vec<String> {
133        split_into_sentences(text)
134    }
135
136    /// Join sentences into text
137    pub fn sentence_join(&self, sentences: &[String]) -> String {
138        sentences.join(" ")
139    }
140
141    /// Split a sentence into words
142    pub fn word_split(&self, sentence: &str) -> Vec<String> {
143        WORD_SPLIT_PATTERN
144            .split(sentence)
145            .filter(|s| !s.is_empty())
146            .map(|s| s.to_string())
147            .collect()
148    }
149
150    /// Join words into a sentence
151    pub fn word_join(&self, words: &[String]) -> String {
152        words.join(" ")
153    }
154
155    /// Test if a sentence input is valid
156    pub fn test_sentence_input(&self, sentence: &str) -> bool {
157        if sentence.trim().is_empty() {
158            return false;
159        }
160
161        if self.well_formed && self.reject_pat.is_match(sentence) {
162            return false;
163        }
164
165        true
166    }
167
168    /// Generate a corpus from text
169    fn generate_corpus(text: &str, reject_pat: &Regex, well_formed: bool) -> Vec<Vec<String>> {
170        let sentences = split_into_sentences(text);
171
172        sentences
173            .into_iter()
174            .filter(|s| {
175                if !well_formed {
176                    return true;
177                }
178                // Test sentence input
179                if s.trim().is_empty() {
180                    return false;
181                }
182                if reject_pat.is_match(s) {
183                    return false;
184                }
185                true
186            })
187            .map(|s| {
188                WORD_SPLIT_PATTERN
189                    .split(&s)
190                    .filter(|s| !s.is_empty())
191                    .map(|s| s.to_string())
192                    .collect()
193            })
194            .collect()
195    }
196
197    /// Test if a generated sentence output is acceptable
198    fn test_sentence_output(
199        &self,
200        words: &[String],
201        max_overlap_ratio: f64,
202        max_overlap_total: usize,
203    ) -> bool {
204        if let Some(ref rejoined) = self.rejoined_text {
205            let overlap_ratio = ((max_overlap_ratio * words.len() as f64).round() as usize).max(1);
206            let overlap_max = overlap_ratio.min(max_overlap_total);
207            let overlap_over = overlap_max + 1;
208            let gram_count = words.len().saturating_sub(overlap_max).max(1);
209
210            for i in 0..gram_count {
211                let gram = &words[i..(i + overlap_over).min(words.len())];
212                let gram_joined = self.word_join(gram);
213                if rejoined.contains(&gram_joined) {
214                    return false;
215                }
216            }
217        }
218        true
219    }
220
221    /// Generate a random sentence
222    ///
223    /// # Arguments
224    /// * `init_state` - Optional starting state (tuple of words)
225    /// * `tries` - Maximum number of attempts (default: 10)
226    /// * `max_overlap_ratio` - Maximum overlap ratio with original text (default: 0.7)
227    /// * `max_overlap_total` - Maximum overlap total with original text (default: 15)
228    /// * `test_output` - Whether to test output for overlap (default: true)
229    /// * `max_words` - Maximum number of words in the sentence
230    /// * `min_words` - Minimum number of words in the sentence
231    #[allow(clippy::too_many_arguments)]
232    pub fn make_sentence(
233        &self,
234        init_state: Option<&[String]>,
235        tries: Option<usize>,
236        max_overlap_ratio: Option<f64>,
237        max_overlap_total: Option<usize>,
238        test_output: Option<bool>,
239        max_words: Option<usize>,
240        min_words: Option<usize>,
241    ) -> Option<String> {
242        let tries = tries.unwrap_or(DEFAULT_TRIES);
243        let mor = max_overlap_ratio.unwrap_or(DEFAULT_MAX_OVERLAP_RATIO);
244        let mot = max_overlap_total.unwrap_or(DEFAULT_MAX_OVERLAP_TOTAL);
245        let test = test_output.unwrap_or(true);
246
247        let prefix: Vec<String> = if let Some(state) = init_state {
248            state.iter().filter(|w| *w != BEGIN).cloned().collect()
249        } else {
250            vec![]
251        };
252
253        for _ in 0..tries {
254            let mut words = prefix.clone();
255            words.extend(self.chain.walk(init_state));
256
257            // Check word count constraints
258            if let Some(max) = max_words {
259                if words.len() > max {
260                    continue;
261                }
262            }
263            if let Some(min) = min_words {
264                if words.len() < min {
265                    continue;
266                }
267            }
268
269            // Test output if required
270            if test && self.rejoined_text.is_some() {
271                if self.test_sentence_output(&words, mor, mot) {
272                    return Some(self.word_join(&words));
273                }
274            } else {
275                return Some(self.word_join(&words));
276            }
277        }
278
279        None
280    }
281
282    /// Generate a short sentence with a maximum character count
283    #[allow(clippy::too_many_arguments)]
284    pub fn make_short_sentence(
285        &self,
286        max_chars: usize,
287        min_chars: Option<usize>,
288        init_state: Option<&[String]>,
289        tries: Option<usize>,
290        max_overlap_ratio: Option<f64>,
291        max_overlap_total: Option<usize>,
292        test_output: Option<bool>,
293        max_words: Option<usize>,
294        min_words: Option<usize>,
295    ) -> Option<String> {
296        let tries = tries.unwrap_or(DEFAULT_TRIES);
297        let min_chars = min_chars.unwrap_or(0);
298
299        for _ in 0..tries {
300            if let Some(sentence) = self.make_sentence(
301                init_state,
302                Some(tries),
303                max_overlap_ratio,
304                max_overlap_total,
305                test_output,
306                max_words,
307                min_words,
308            ) {
309                let len = sentence.len();
310                if len >= min_chars && len <= max_chars {
311                    return Some(sentence);
312                }
313            }
314        }
315
316        None
317    }
318
319    /// Generate a sentence that starts with a specific string
320    #[allow(clippy::too_many_arguments)]
321    pub fn make_sentence_with_start(
322        &self,
323        beginning: &str,
324        strict: bool,
325        tries: Option<usize>,
326        max_overlap_ratio: Option<f64>,
327        max_overlap_total: Option<usize>,
328        test_output: Option<bool>,
329        max_words: Option<usize>,
330        min_words: Option<usize>,
331    ) -> Result<String> {
332        let split = self.word_split(beginning);
333        let word_count = split.len();
334
335        if word_count > self.state_size {
336            return Err(MarkovError::ParamError(format!(
337                "`make_sentence_with_start` for this model requires a string containing 1 to {} words. Yours has {}: {:?}",
338                self.state_size, word_count, split
339            )));
340        }
341
342        let init_states: Vec<Vec<String>> = if word_count == self.state_size {
343            vec![split.clone()]
344        } else if word_count < self.state_size {
345            if strict {
346                // Pad with BEGIN tokens
347                let mut state = vec![BEGIN.to_string(); self.state_size - word_count];
348                state.extend(split.clone());
349                vec![state]
350            } else {
351                // Find all chains containing this sequence
352                self.find_init_states_from_chain(&split)
353            }
354        } else {
355            return Err(MarkovError::ParamError(format!(
356                "Invalid word count: {}",
357                word_count
358            )));
359        };
360
361        if init_states.is_empty() {
362            return Err(MarkovError::ParamError(format!(
363                "Cannot find sentence beginning with: {}",
364                beginning
365            )));
366        }
367
368        // Try each init state
369        for init_state in init_states {
370            if let Some(output) = self.make_sentence(
371                Some(&init_state),
372                tries,
373                max_overlap_ratio,
374                max_overlap_total,
375                test_output,
376                max_words,
377                min_words,
378            ) {
379                return Ok(output);
380            }
381        }
382
383        Err(MarkovError::ParamError(format!(
384            "Cannot generate sentence beginning with: {}",
385            beginning
386        )))
387    }
388
389    /// Find all initial states from the chain that contain the given split
390    fn find_init_states_from_chain(&self, split: &[String]) -> Vec<Vec<String>> {
391        let word_count = split.len();
392        let mut states = Vec::new();
393
394        for key in self.chain.model().keys() {
395            // Filter out BEGIN tokens and check if it starts with split
396            let filtered: Vec<&String> = key.iter().filter(|w| *w != BEGIN).collect();
397            if filtered.len() >= word_count
398                && filtered[..word_count]
399                    .iter()
400                    .zip(split.iter())
401                    .all(|(a, b)| *a == b)
402            {
403                states.push(key.clone());
404            }
405        }
406
407        states
408    }
409
410    /// Compile the model for faster generation
411    pub fn compile(&self) -> Self {
412        let compiled_chain = self.chain.compile();
413
414        Text {
415            state_size: self.state_size,
416            chain: compiled_chain,
417            parsed_sentences: self.parsed_sentences.clone(),
418            rejoined_text: self.rejoined_text.clone(),
419            retain_original: self.retain_original,
420            well_formed: self.well_formed,
421            reject_pat: self.reject_pat.clone(),
422        }
423    }
424
425    /// Compile the model in place (returns self)
426    pub fn compile_inplace(&mut self) {
427        self.chain = self.chain.compile();
428    }
429
430    /// Get the state size
431    pub fn state_size(&self) -> usize {
432        self.state_size
433    }
434
435    /// Get the chain
436    pub fn chain(&self) -> &Chain {
437        &self.chain
438    }
439
440    /// Serialize to JSON
441    pub fn to_json(&self) -> Result<String> {
442        let data = TextData {
443            state_size: self.state_size,
444            chain: self.chain.to_json()?,
445            parsed_sentences: self.parsed_sentences.clone(),
446        };
447        Ok(serde_json::to_string(&data)?)
448    }
449
450    /// Deserialize from JSON
451    pub fn from_json(json_str: &str) -> Result<Self> {
452        let data: TextData = serde_json::from_str(json_str)?;
453        let chain = Chain::from_json(&data.chain)?;
454
455        Ok(Text {
456            state_size: data.state_size,
457            chain,
458            parsed_sentences: data.parsed_sentences.clone(),
459            rejoined_text: data.parsed_sentences.as_ref().map(|sentences| {
460                Self::sentence_join_static(
461                    &sentences
462                        .iter()
463                        .map(|s| Self::word_join_static(s))
464                        .collect::<Vec<_>>(),
465                )
466            }),
467            retain_original: data.parsed_sentences.is_some(),
468            well_formed: true,
469            reject_pat: REJECT_PAT.clone(),
470        })
471    }
472
473    /// Check if the model retains original sentences
474    pub fn retain_original(&self) -> bool {
475        self.retain_original
476    }
477
478    /// Get the parsed sentences if available
479    pub fn parsed_sentences(&self) -> Option<&Vec<Vec<String>>> {
480        self.parsed_sentences.as_ref()
481    }
482
483    fn sentence_join_static(sentences: &[String]) -> String {
484        sentences.join(" ")
485    }
486
487    fn word_join_static(words: &[String]) -> String {
488        words.join(" ")
489    }
490}
491
/// A text model whose [`NewlineText::sentence_split`] splits on newlines instead of
/// sentence punctuation.
///
/// Wraps a [`Text`]; sentence generation and (de)serialization are delegated to `inner`.
#[derive(Debug, Clone)]
pub struct NewlineText {
    /// The wrapped model that generation and (de)serialization calls delegate to.
    inner: Text,
}
497
498impl NewlineText {
499    /// Create a new NewlineText model
500    pub fn new(
501        input_text: &str,
502        state_size: usize,
503        retain_original: bool,
504        well_formed: bool,
505        reject_reg: Option<&str>,
506    ) -> Result<Self> {
507        let text = Text::new(
508            input_text,
509            state_size,
510            retain_original,
511            well_formed,
512            reject_reg,
513        )?;
514        Ok(NewlineText { inner: text })
515    }
516
517    /// Split text on newlines
518    pub fn sentence_split(&self, text: &str) -> Vec<String> {
519        text.split('\n')
520            .map(|s| s.trim().to_string())
521            .filter(|s| !s.is_empty())
522            .collect()
523    }
524
525    /// Generate a sentence
526    #[allow(clippy::too_many_arguments)]
527    pub fn make_sentence(
528        &self,
529        init_state: Option<&[String]>,
530        tries: Option<usize>,
531        max_overlap_ratio: Option<f64>,
532        max_overlap_total: Option<usize>,
533        test_output: Option<bool>,
534        max_words: Option<usize>,
535        min_words: Option<usize>,
536    ) -> Option<String> {
537        self.inner.make_sentence(
538            init_state,
539            tries,
540            max_overlap_ratio,
541            max_overlap_total,
542            test_output,
543            max_words,
544            min_words,
545        )
546    }
547
548    /// Generate a short sentence
549    #[allow(clippy::too_many_arguments)]
550    pub fn make_short_sentence(
551        &self,
552        max_chars: usize,
553        min_chars: Option<usize>,
554        init_state: Option<&[String]>,
555        tries: Option<usize>,
556        max_overlap_ratio: Option<f64>,
557        max_overlap_total: Option<usize>,
558        test_output: Option<bool>,
559        max_words: Option<usize>,
560        min_words: Option<usize>,
561    ) -> Option<String> {
562        self.inner.make_short_sentence(
563            max_chars,
564            min_chars,
565            init_state,
566            tries,
567            max_overlap_ratio,
568            max_overlap_total,
569            test_output,
570            max_words,
571            min_words,
572        )
573    }
574
575    /// Serialize to JSON
576    pub fn to_json(&self) -> Result<String> {
577        self.inner.to_json()
578    }
579
580    /// Deserialize from JSON
581    pub fn from_json(json_str: &str) -> Result<Self> {
582        let text = Text::from_json(json_str)?;
583        Ok(NewlineText { inner: text })
584    }
585
586    /// Get the inner Text model
587    pub fn inner(&self) -> &Text {
588        &self.inner
589    }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Several short sentences sharing words ("The", "cat", "dog"), so a chain with a
    /// one-word state has more than one continuation to choose from.
    const ANIMAL_CORPUS: &str = "The cat sat on the mat. The dog ran in the park. The bird flew over the tree. The cat chased the mouse. The dog barked loudly.";

    #[test]
    fn test_text_creation() {
        let model = Text::new("Hello world. This is a test.", 2, true, true, None)
            .expect("plain prose should build a model");
        assert_eq!(model.state_size(), 2);
    }

    #[test]
    fn test_make_sentence() {
        // A state size of 1 keeps this small corpus generative enough to pass overlap checks.
        let model = Text::new(ANIMAL_CORPUS, 1, true, true, None)
            .expect("corpus should build a model");
        let generated = model.make_sentence(None, None, None, None, None, None, None);
        assert!(generated.is_some());
    }

    #[test]
    fn test_json_serialization() {
        let original = Text::new("Hello world. This is a test.", 2, true, true, None)
            .expect("plain prose should build a model");
        let encoded = original.to_json().expect("model should serialize");
        let restored = Text::from_json(&encoded).expect("serialized model should load");
        assert_eq!(original.state_size(), restored.state_size());
    }

    #[test]
    fn test_newline_text() {
        let lines = "Line one\nLine two\nLine three";
        let model = NewlineText::new(lines, 2, true, true, None)
            .expect("line-based input should build a model");
        assert_eq!(model.sentence_split(lines).len(), 3);
    }
}
630}