iocaine 2.0.0

The deadliest poison known to AI
Documentation
// SPDX-FileCopyrightText: 2017-2024 Martin Geisler
// SPDX-FileCopyrightText: 2025 Gergely Nagy
// SPDX-FileContributor: Martin Geisler
// SPDX-FileContributor: Gergely Nagy
//
// SPDX-License-Identifier: MIT
//
// Originally based on code borrowed from https://github.com/mgeisler/lipsum

use rand::{seq::IndexedRandom, Rng};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read as _;
use substrings::{Interner, Substr, WhitespaceSplitIterator};

mod substrings;

/// Two adjacent words of the learned text, each stored as a [`Substr`].
/// Used as the lookup key of the generator's transition map.
type Bigram = (Substr, Substr);
/// Random text generator that learns which word follows each pair of
/// adjacent words in a text and replays those transitions at random.
///
/// The `Default` value has learned nothing: the text, map and key
/// list are all empty.
#[derive(Debug, Default)]
pub struct WurstsalatGeneratorPro {
    /// The complete learned text; every stored `Substr` is a range
    /// into this string.
    string: String,
    /// For each bigram, the words that were seen directly after it.
    map: HashMap<Bigram, Vec<Substr>>,
    /// The keys of `map`, sorted by the text of their two words.
    /// Starting points for generation are picked from this list.
    keys: Vec<Bigram>,
}

impl WurstsalatGeneratorPro {
    /// Creates a generator that has not learned anything.
    ///
    /// Iterators returned by [`generate`](Self::generate) on such a
    /// generator yield no words.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a generator from `string`.
    ///
    /// `breaks` lists byte offsets into `string` at which one source
    /// ends and the next begins; only the first remaining entry is
    /// ever inspected, so the offsets are expected in ascending order.
    /// A word triple that straddles such an offset is not recorded.
    fn learn(string: String, mut breaks: &[usize]) -> Self {
        let mut interner = Interner::new();
        let mut map: HashMap<Bigram, Vec<Substr>> = HashMap::new();

        let tokens: Vec<_> = WhitespaceSplitIterator::new(&string).collect();
        for triple in tokens.windows(3) {
            let (first, second, third) = (triple[0], triple[1], triple[2]);

            // Mimic learning from each source independently: a triple
            // whose first and last words lie on opposite sides of a
            // break must not be recorded.
            let crosses_break = loop {
                match breaks.split_first() {
                    Some((&boundary, rest)) if boundary <= third.start => {
                        if boundary > first.start {
                            // The break sits between the first and the
                            // last word of this triple.
                            break true;
                        }
                        // The whole triple is past this break, so it can
                        // never matter again; forget it.
                        breaks = rest;
                    }
                    _ => break false,
                }
            };
            if crosses_break {
                continue;
            }

            let key = (
                interner.intern(&string, first),
                interner.intern(&string, second),
            );
            let follower = interner.intern(&string, third);
            map.entry(key).or_default().push(follower);
        }

        // Order the keys by the text of their two words instead of
        // leaving them in the map's arbitrary iteration order.
        let mut keys: Vec<Bigram> = map.keys().copied().collect();
        keys.sort_unstable_by_key(|&(left, right)| {
            (
                &string[left.start..left.end],
                &string[right.start..right.end],
            )
        });

        Self { string, keys, map }
    }

    /// Reads every file in `files` and learns from their combined
    /// contents.
    ///
    /// The files are concatenated with a single space between them,
    /// and the end of each file is recorded as a break so that no
    /// learned triple mixes words from two different files.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if a file cannot be opened or
    /// read into a string.
    pub fn learn_from_files(files: &[String]) -> Result<Self, std::io::Error> {
        let mut corpus = String::new();
        let mut breaks = Vec::with_capacity(files.len());
        for path in files {
            File::open(path)?.read_to_string(&mut corpus)?;
            // Remember where this file ends, then separate it from the
            // next one so their words cannot fuse together.
            breaks.push(corpus.len());
            corpus.push(' ');
        }

        Ok(Self::learn(corpus, &breaks))
    }

    /// Returns an iterator of generated words driven by `rng`.
    ///
    /// The starting bigram is picked at random from the learned keys;
    /// when there are none, the default bigram is used instead.
    pub fn generate<R: Rng>(&self, mut rng: R) -> Words<'_, R> {
        let start = self.keys.choose(&mut rng).copied().unwrap_or_default();
        self.iter_with_rng_from(rng, start)
    }

    /// Returns a [`Words`] iterator over this generator's data that
    /// starts in state `from` and draws its randomness from `rng`.
    fn iter_with_rng_from<R: Rng>(&self, rng: R, from: Bigram) -> Words<'_, R> {
        Words {
            string: &self.string,
            map: &self.map,
            keys: &self.keys,
            state: from,
            rng,
        }
    }
}

/// Iterator over generated words, as returned by
/// [`WurstsalatGeneratorPro::generate`].
///
/// It borrows the learned text, transition map and key list from the
/// generator and owns the random number generator plus the current
/// bigram state.
#[derive(Clone)]
pub struct Words<'a, R: Rng> {
    /// The learned text that every `Substr` is resolved against.
    string: &'a str,
    /// For each bigram, the words that may follow it.
    map: &'a HashMap<Bigram, Vec<Substr>>,
    /// Source of randomness for choosing followers and restart points.
    rng: R,
    /// All bigrams that have followers; used to restart when the
    /// current state is not a key of `map`.
    keys: &'a [Bigram],
    /// The current bigram; its first word is the next one to be yielded.
    state: Bigram,
}

impl<'a, R: Rng> Iterator for Words<'a, R> {
    type Item = &'a str;

    /// Yields the first word of the current state and then advances
    /// the state by one randomly chosen follower.
    ///
    /// Returns `None` only when nothing has been learned (the map is
    /// empty); otherwise the iterator never ends.
    fn next(&mut self) -> Option<&'a str> {
        if self.map.is_empty() {
            return None;
        }

        // Resolve the word to hand out before the state moves on.
        let emitted = self.state.0.extract_str(self.string);

        let followers = match self.map.get(&self.state) {
            Some(list) => list,
            None => {
                // Dead end: the current bigram was never followed by
                // anything, so jump to a random known bigram instead.
                self.state = *self.keys.choose(&mut self.rng).unwrap();
                &self.map[&self.state]
            }
        };
        let picked = *followers.choose(&mut self.rng).unwrap();
        self.state = (self.state.1, picked);

        Some(emitted)
    }
}

/// Returns `true` when `c` is an ASCII punctuation character.
///
/// This is a by-value wrapper around [`char::is_ascii_punctuation`],
/// which makes it usable directly as a `char` predicate.
fn is_ascii_punctuation(c: char) -> bool {
    char::is_ascii_punctuation(&c)
}

/// Returns a copy of `word` with its first character upper-cased.
///
/// The rest of the word is copied unchanged, and an empty `word`
/// gives an empty string. The upper-case form of the first character
/// may consist of several characters.
fn capitalize(word: &str) -> String {
    let mut rest = word.chars();
    let mut result = String::with_capacity(word.len());

    if let Some(first) = rest.next() {
        result.extend(first.to_uppercase());
    }
    // Whatever the iterator has not consumed is the unchanged tail.
    result.push_str(rest.as_str());
    result
}

/// Joins the words of an iterator into one space-separated string.
///
/// The first word, and every word that follows a word ending in `.`,
/// `!` or `?`, is capitalized. If the result does not already end in
/// one of those three characters, any trailing ASCII punctuation is
/// stripped and a `.` is appended. An empty iterator gives an empty
/// string.
pub fn join_words<'a, I: Iterator<Item = &'a str>>(mut words: I) -> String {
    // Characters that terminate a sentence.
    let terminators: &[char] = &['.', '!', '?'];

    let first = match words.next() {
        Some(word) => word,
        None => return String::new(),
    };

    let mut sentence = capitalize(first);
    let mut starts_sentence = sentence.ends_with(terminators);

    for word in words {
        sentence.push(' ');

        if starts_sentence {
            sentence.push_str(&capitalize(word));
        } else {
            sentence.push_str(word);
        }

        starts_sentence = word.ends_with(terminators);
    }

    // Already properly terminated: nothing left to do.
    if sentence.ends_with(terminators) {
        return sentence;
    }

    // Drop trailing punctuation such as ',' or ';' first, so that the
    // final '.' does not end up right after it.
    let kept = sentence.trim_end_matches(is_ascii_punctuation).len();
    sentence.truncate(kept);
    sentence.push('.');
    sentence
}