// iocaine 2.2.0
//
// The deadliest poison known to AI
// Documentation
// SPDX-FileCopyrightText: 2017-2024 Martin Geisler
// SPDX-FileCopyrightText: 2025 Gergely Nagy
// SPDX-FileContributor: Martin Geisler
// SPDX-FileContributor: Gergely Nagy
//
// SPDX-License-Identifier: MIT
//
// Originally based on code borrowed from https://github.com/mgeisler/lipsum

use rand::{Rng, seq::IndexedRandom};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read as _;
use substrings::{Interner, Substr, WhitespaceSplitIterator};

mod substrings;

/// Two consecutive words. Used both as the key of the transition map and as
/// the current state of a running [`Words`] iterator.
type Bigram = (Substr, Substr);
/// Random text generator that records, for every pair of consecutive words in
/// its training text, which words were seen directly after that pair, and
/// replays those transitions at random.
#[derive(Debug, Default)]
pub struct WurstsalatGeneratorPro {
    /// The complete training text. The `Substr` values stored in `map` and
    /// `keys` are resolved against this string.
    string: String,
    /// For each learned bigram, every word observed right after it. A word is
    /// pushed once per occurrence, so repeated followers appear repeatedly.
    map: HashMap<Bigram, Vec<Substr>>,
    /// The keys of `map`, sorted by the text of their two words.
    // NOTE(review): presumably sorted so that a seeded RNG picks the same
    // bigram regardless of `HashMap` iteration order -- confirm.
    keys: Vec<Bigram>,
}

impl WurstsalatGeneratorPro {
    /// Creates a generator that has not learned anything.
    ///
    /// Its transition map is empty, so iterators obtained from
    /// [`generate`](Self::generate) yield no words.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Builds the transition map from `string`.
    ///
    /// `breaks` holds offsets into `string`, consumed from the front. A
    /// triple of consecutive words is not learned when a break lies after
    /// the start of its first word and at or before the start of its third
    /// word, i.e. when the triple straddles that break.
    fn learn(string: String, mut breaks: &[usize]) -> Self {
        let mut interner = Interner::new();
        let mut map: HashMap<Bigram, Vec<Substr>> = HashMap::new();
        let tokens: Vec<_> = WhitespaceSplitIterator::new(&string).collect();

        for triple in tokens.windows(3) {
            let (first, second, third) = (triple[0], triple[1], triple[2]);

            // Keeps the result identical to learning from each file on its
            // own: discard breaks the window has moved past, then check
            // whether the nearest remaining one falls inside the window.
            let spans_break = loop {
                match breaks.first() {
                    Some(&boundary) if boundary <= third.start => {
                        if boundary <= first.start {
                            // The whole window is beyond this break.
                            breaks = &breaks[1..];
                        } else {
                            break true;
                        }
                    }
                    _ => break false,
                }
            };

            if spans_break {
                continue;
            }

            let key = (
                interner.intern(&string, first),
                interner.intern(&string, second),
            );
            let follower = interner.intern(&string, third);
            map.entry(key).or_default().push(follower);
        }

        // Order the keys by the text they refer to.
        let text_of = |word: Substr| &string[word.start..word.end];
        let mut keys: Vec<Bigram> = map.keys().copied().collect();
        keys.sort_unstable_by_key(|&(left, right)| (text_of(left), text_of(right)));

        Self { string, map, keys }
    }

    /// Reads every file in `files` and learns from their contents.
    ///
    /// The contents are appended to one buffer, each followed by a single
    /// space. The end of each file's content is recorded as a break, so no
    /// word triple that reaches across two files is learned.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while opening or reading a file.
    pub fn learn_from_files(files: &[String]) -> Result<Self, std::io::Error> {
        let mut corpus = String::new();
        let mut boundaries = Vec::with_capacity(files.len());

        for path in files {
            File::open(path)?.read_to_string(&mut corpus)?;
            // The break sits just before the separating space.
            boundaries.push(corpus.len());
            corpus.push(' ');
        }

        Ok(Self::learn(corpus, &boundaries))
    }

    /// Returns an iterator of generated words, starting from a bigram picked
    /// at random with `rng`.
    ///
    /// When nothing has been learned, the default bigram is used as the
    /// starting state and the iterator yields nothing.
    pub fn generate<R: Rng>(&self, mut rng: R) -> Words<'_, R> {
        let start = match self.keys.choose(&mut rng) {
            Some(&bigram) => bigram,
            None => Default::default(),
        };
        self.iter_with_rng_from(rng, start)
    }

    /// Returns an iterator of generated words whose initial state is `from`.
    fn iter_with_rng_from<R: Rng>(&self, rng: R, from: Bigram) -> Words<'_, R> {
        Words {
            string: &self.string,
            map: &self.map,
            keys: self.keys.as_slice(),
            state: from,
            rng,
        }
    }
}

/// Iterator over generated words, borrowing the data learned by a
/// [`WurstsalatGeneratorPro`].
///
/// It yields nothing when the transition map is empty and never ends
/// otherwise.
#[derive(Clone)]
pub struct Words<'a, R: Rng> {
    /// The training text that every `Substr` is resolved against.
    string: &'a str,
    /// Learned transitions: bigram -> words seen after it.
    map: &'a HashMap<Bigram, Vec<Substr>>,
    /// Source of randomness for picking followers and restart bigrams.
    rng: R,
    /// The keys of `map`; a new state is drawn from here when the current
    /// state has no entry in `map`.
    keys: &'a [Bigram],
    /// Current bigram; its first word is what the next call to `next` returns.
    state: Bigram,
}

impl<'a, R: Rng> Iterator for Words<'a, R> {
    type Item = &'a str;

    /// Returns the first word of the current bigram and advances the state.
    ///
    /// Yields `None` only when the transition map is empty. If the current
    /// bigram has no known followers, a fresh bigram is drawn from `keys`
    /// before a follower is chosen.
    ///
    /// # Panics
    ///
    /// The `unwrap` calls rely on `keys` being non-empty whenever `map` is,
    /// and on every follower list holding at least one word.
    fn next(&mut self) -> Option<Self::Item> {
        if self.map.is_empty() {
            return None;
        }

        let current = self.state.0.extract_str(self.string);

        let followers = match self.map.get(&self.state) {
            Some(found) => found,
            None => {
                // Dead end: restart from a randomly chosen known bigram.
                self.state = *self.keys.choose(&mut self.rng).unwrap();
                &self.map[&self.state]
            }
        };
        let follower = *followers.choose(&mut self.rng).unwrap();

        // Slide the window one word forward.
        self.state = (self.state.1, follower);

        Some(current)
    }
}

/// Returns `true` when `c` is an ASCII punctuation character.
///
/// This is a by-value wrapper around [`char::is_ascii_punctuation`], which
/// takes `&char`; the by-value signature lets the function be passed
/// directly as a `char` predicate to string methods such as
/// `trim_end_matches`.
fn is_ascii_punctuation(c: char) -> bool {
    char::is_ascii_punctuation(&c)
}

/// Returns `word` with its first character upper-cased and the rest left
/// untouched.
///
/// Upper-casing is Unicode-aware, so a single leading character may expand
/// into several (for example `ß` becomes `SS`). An empty `word` gives an
/// empty string.
fn capitalize(word: &str) -> String {
    let mut rest = word.chars();

    match rest.next() {
        None => String::new(),
        Some(first) => {
            let mut out = String::with_capacity(word.len());
            out.extend(first.to_uppercase());
            // `rest` has consumed exactly the first character.
            out.push_str(rest.as_str());
            out
        }
    }
}

/// Joins the words of an iterator into one space-separated sentence.
///
/// The first word is capitalized, and so is every word that follows a word
/// ending in `.`, `!` or `?`. If the joined text does not already end in one
/// of those three characters, all trailing ASCII punctuation is removed and
/// a `'.'` is appended. An empty iterator gives an empty string.
pub fn join_words<'a, I: Iterator<Item = &'a str>>(mut words: I) -> String {
    // Characters that end a sentence.
    const SENTENCE_END: &[char] = &['.', '!', '?'];

    let first = match words.next() {
        Some(word) => word,
        None => return String::new(),
    };

    let mut sentence = capitalize(first);
    let mut capitalize_next = sentence.ends_with(SENTENCE_END);

    // Append the remaining words, each preceded by a single space.
    for word in words {
        sentence.push(' ');

        if capitalize_next {
            sentence.push_str(&capitalize(word));
        } else {
            sentence.push_str(word);
        }

        capitalize_next = word.ends_with(SENTENCE_END);
    }

    // Make sure the sentence is properly terminated.
    if !sentence.ends_with(SENTENCE_END) {
        // Remove trailing punctuation first, so that the result never ends
        // in something like ",.".
        while sentence.ends_with(is_ascii_punctuation) {
            sentence.pop();
        }
        sentence.push('.');
    }

    sentence
}