// mudders 0.0.4
//
// Generating Lexicographically-Evenly-Spaced Strings, or: Mudder.js in Rust.
// Documentation
/*!
Generate lexicographically-evenly-spaced strings between two strings
from pre-defined alphabets.

This is a rewrite of [mudderjs](https://github.com/fasiha/mudderjs); thanks
for the original work of the author and their contributors!

## Usage
Add a dependency in your Cargo.toml:

```toml
mudders = "0.0.4"
```

Now you can generate lexicographically-spaced strings in a few different ways:

```
use mudders::SymbolTable;
// The mudder method takes a NonZeroUsize as the amount,
// so you cannot pass in an invalid value.
use std::num::NonZeroUsize;

// You can use the included alphabet table
let table = SymbolTable::alphabet();
// SymbolTable::mudder_one() returns a single String (SymbolTable::mudder() returns a Vec of `amount` Strings).
let result = table.mudder_one("a", "z").unwrap();
// These strings are always lexicographically placed between `start` and `end`.
let one_str = result.as_str();
assert!(one_str > "a");
assert!(one_str < "z");

// You can also define your own symbol tables
let table = SymbolTable::from_chars(&['a', 'b']).unwrap();
let result = table.mudder("a", "b", NonZeroUsize::new(2).unwrap()).unwrap();
assert_eq!(result.len(), 2);
assert!(result[0].as_str() > "a" && result[1].as_str() > "a");
assert!(result[0].as_str() < "b" && result[1].as_str() < "b");

// The strings *should* be evenly-spaced and as short as they can be.
let table = SymbolTable::alphabet();
let result = table.mudder("anhui", "azazel", NonZeroUsize::new(3).unwrap()).unwrap();
assert_eq!(result.len(), 3);
assert_eq!(vec!["aq", "as", "av"], result);
```

## Notes
The most notable difference from Mudder.js is that, currently, mudders only
supports ASCII characters (because 128 characters ought to be enough for
everyone™). Our default `SymbolTable::alphabet()` also only contains lowercase letters.

*/

use core::num::NonZeroUsize;
use std::{convert::TryFrom, str::FromStr};

#[macro_use]
pub mod error;
use error::*;

/// The functionality of the crate lives here.
///
/// A symbol table is, internally, a sorted and deduplicated vector of ASCII
/// bytes that are used to generate lexicographically evenly-spaced strings.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SymbolTable(Vec<u8>);

impl SymbolTable {
    /// Creates a new symbol table from the given byte slice.
    ///
    /// The bytes are copied, sorted and deduplicated, so the order in which
    /// they are passed in does not matter.
    ///
    /// # Errors
    /// Returns an error if `source` is empty or if one of the given bytes is
    /// out of ASCII range.
    pub fn new(source: &[u8]) -> Result<Self, CreationError> {
        ensure! { !source.is_empty(), CreationError::EmptySlice }
        ensure! { all_chars_ascii(source), NonAsciiError::NonAsciiU8 }
        // Key generation walks the symbols in order, so they must be sorted;
        // duplicate symbols would produce duplicate keys.
        let mut symbols = source.to_vec();
        symbols.sort_unstable();
        symbols.dedup();
        Ok(Self(symbols))
    }

    /// Creates a new symbol table from the given characters.
    ///
    /// The characters are sorted and deduplicated, just like in [`SymbolTable::new`].
    ///
    /// # Errors
    /// Returns an error if `source` is empty or if one of the given characters
    /// is not ASCII.
    pub fn from_chars(source: &[char]) -> Result<Self, CreationError> {
        let bytes = source
            .iter()
            .map(|&c| try_ascii_u8_from_char(c))
            .collect::<Result<Vec<u8>, _>>()?;
        Self::new(&bytes)
    }

    /// Returns a SymbolTable which contains the lowercase latin alphabet (`[a-z]`).
    pub fn alphabet() -> Self {
        let letters: Vec<u8> = (b'a'..=b'z').collect();
        Self::new(&letters).expect("`a..=z` is a non-empty range of ASCII bytes")
    }

    /// Generate `amount` strings that lexicographically sort between `a` and `b`.
    /// The algorithm will try to make them as evenly-spaced as possible.
    ///
    /// An empty `a` or `b` leaves that side unbounded. When both parameters are
    /// empty strings, `amount` new strings that are in lexicographical order
    /// are returned.
    ///
    /// If both parameters are non-empty and `b` is lexicographically before
    /// `a`, they are swapped internally.
    ///
    /// ```
    /// # use mudders::SymbolTable;
    /// # use std::num::NonZeroUsize;
    /// // Using the included alphabet table
    /// let table = SymbolTable::alphabet();
    /// // Generate 10 strings from scratch
    /// let results = table.mudder("", "", NonZeroUsize::new(10).unwrap()).unwrap();
    /// assert!(results.len() == 10);
    /// // results should look something like ["b", "d", "f", ..., "r", "t"]
    /// ```
    ///
    /// # Errors
    /// Returns an error if `a` or `b` contains non-ASCII characters or
    /// characters that are not part of this table, if `a` and `b` are the same
    /// non-empty string, or if generation fails internally (for example when
    /// the first differing characters of the inputs are in the wrong order).
    ///
    /// # Panics
    /// Panics with an "Internal error" report if the computed tree depth does
    /// not yield enough candidate strings; that indicates a bug in this crate.
    pub fn mudder(
        &self,
        a: &str,
        b: &str,
        amount: NonZeroUsize,
    ) -> Result<Vec<String>, GenerationError> {
        use error::InternalError::*;
        use GenerationError::*;
        ensure! { all_chars_ascii(a), NonAsciiError::NonAsciiU8 }
        ensure! { all_chars_ascii(b), NonAsciiError::NonAsciiU8 }
        ensure! { self.contains_all_chars(a), UnknownCharacters(a.to_string()) }
        ensure! { self.contains_all_chars(b), UnknownCharacters(b.to_string()) }
        let (a, b) = if a.is_empty() || b.is_empty() {
            // An empty bound is unbounded on that side, so keep the order.
            (a, b)
        } else if b < a {
            // Both are non-empty and b is lexicographically prior to a: swap them.
            (b, a)
        } else {
            // You can't generate values between two matching strings.
            ensure! { a != b, MatchingStrings(a.to_string()) }
            (a, b)
        };

        // TODO: Reject lexicographically adjacent inputs (e.g. "ba" and "baa"),
        // between which no string can be generated.

        // Count the characters `a` and `b` have in common at the start.
        let matching_count: usize = {
            let (mut start_chars, mut end_chars) = (a.chars(), b.chars());
            // In the case of `a` == "a" and `b` == "aab", we actually need to
            // compare "" to "b" later on, not "" to "a"; so once `a` runs out,
            // `b` characters repeating `a`'s last character still count.
            let mut last_start_char = '\0';
            let mut count: usize = 0;
            loop {
                match (start_chars.next(), end_chars.next()) {
                    (Some(sc), Some(ec)) if sc == ec => {
                        last_start_char = sc;
                        count += 1;
                    }
                    (None, Some(ec)) if ec == last_start_char => count += 1,
                    // Any mismatch, or both inputs exhausted, ends the prefix.
                    _ => break count,
                }
            }
        };

        // The pool generated by `traverse` may contain the non-empty inputs
        // themselves, so we need that many extra candidates.
        let non_empty_input_count = [a, b].iter().filter(|s| !s.is_empty()).count();
        let computed_amount = amount.get() + non_empty_input_count;
        let has_enough =
            |pool_len: usize| pool_len.saturating_sub(non_empty_input_count) >= amount.get();

        // Distance between the first non-matching characters.
        // `matching_count` may exceed `a.len()` because we may count past a's end.
        let branching_factor = self.distance_between_first_chars(
            &a[matching_count.min(a.len())..],
            &b[matching_count..],
        )?;
        // A common prefix of `matching_count` characters is shared by every
        // generated string, so the tree must be at least that much deeper.
        let mut depth = depth_for(branching_factor, computed_amount) + matching_count;

        // TODO: Maybe keeping this as an iterator would be more efficient,
        // but it would have to be cloned at least once to get the pool length.
        let mut pool: Vec<String> = self.traverse(String::new(), a, b, depth).collect();
        if !has_enough(pool.len()) {
            depth += depth_for(branching_factor, computed_amount + pool.len());
            pool = self.traverse(String::new(), a, b, depth).collect();
        }
        if !has_enough(pool.len()) {
            // We still don't have enough items, so bail
            panic!(
                "Internal error: Failed to calculate the correct tree depth!
This is a bug. Please report it at: https://github.com/Follpvosten/mudders/issues
and make sure to include the following information:

Symbols in table: {symbols:?}
Given inputs: {a:?}, {b:?}, amount: {amount}
matching_count: {m_count}
non_empty_input_count: {ne_input_count}
required pool length (computed amount): {comp_amount}
branching_factor: {b_factor}
final depth: {depth}
pool: {pool:?} (length: {pool_len})",
                symbols = self.0.iter().map(|i| *i as char).collect::<Box<[_]>>(),
                a = a,
                b = b,
                amount = amount,
                m_count = matching_count,
                ne_input_count = non_empty_input_count,
                comp_amount = computed_amount,
                b_factor = branching_factor,
                depth = depth,
                pool = pool,
                pool_len = pool.len(),
            )
        }

        if amount.get() == 1 {
            // A single string is taken from the middle of the pool.
            let middle = pool.len() / 2;
            let item = pool.into_iter().nth(middle).ok_or(FailedToGetMiddle)?;
            return Ok(vec![item]);
        }

        // Pick items at (roughly) regular intervals: `counter` advances by
        // `step` per pool item, and an item is kept whenever the integer part
        // of `counter` increases.
        let step = computed_amount as f64 / pool.len() as f64;
        let mut counter = 0f64;
        let mut last_value = 0;
        let result: Vec<String> = pool
            .into_iter()
            .filter(|_| {
                counter += step;
                let new_value = counter.floor() as usize;
                if new_value > last_value {
                    last_value = new_value;
                    true
                } else {
                    false
                }
            })
            .take(amount.get())
            .collect();
        ensure! { result.len() == amount.get(), NotEnoughItemsInPool }
        Ok(result)
    }

    /// Convenience wrapper around `mudder` to generate exactly one string.
    ///
    /// # Errors
    /// Returns the same errors as [`SymbolTable::mudder`].
    pub fn mudder_one(&self, a: &str, b: &str) -> Result<String, GenerationError> {
        let one = NonZeroUsize::new(1).expect("1 is not zero");
        // With an amount of one, `mudder` returns a single-element Vec on success.
        self.mudder(a, b, one).map(|mut strings| strings.remove(0))
    }

    /// Convenience wrapper around `mudder` to generate an amount of fresh strings.
    ///
    /// `SymbolTable.generate(amount)` is equivalent to `SymbolTable.mudder("", "", amount)`.
    pub fn generate(&self, amount: NonZeroUsize) -> Result<Vec<String>, GenerationError> {
        self.mudder("", "", amount)
    }

    /// Traverses a virtual tree of strings to the given depth.
    ///
    /// Each node's children are `curr_key` extended by one symbol, visited in
    /// the table's (ascending) symbol order. A node is yielded before its
    /// subtree, so keys come out in lexicographical order. `depth` limits how
    /// many symbols are appended below `curr_key`.
    ///
    /// - Keys after a non-empty `end` are pruned together with their subtrees.
    /// - Keys before `start` are skipped; their subtree is only visited when
    ///   the key is a prefix of `start`.
    /// - A single-character key equal to `end` is yielded without its subtree.
    /// - Keys of two or more characters made of one repeated symbol (e.g. `"aa"`)
    ///   are not yielded themselves, only their subtrees.
    ///   NOTE(review): this also skips keys such as `"bb"`; confirm that is intended.
    fn traverse<'a>(
        &'a self,
        curr_key: String,
        start: &'a str,
        end: &'a str,
        depth: usize,
    ) -> Box<dyn Iterator<Item = String> + 'a> {
        if depth == 0 {
            // If we've reached depth 0, we don't go further.
            Box::new(std::iter::empty())
        } else {
            // Generate all possible mutations on the current depth
            Box::new(
                self.0
                    .iter()
                    .filter_map(move |c| -> Option<Box<dyn Iterator<Item = String>>> {
                        // TODO: Performance - this probably still isn't the best option.
                        let key = {
                            let the_char = *c as char;
                            let mut string =
                                String::with_capacity(curr_key.len() + the_char.len_utf8());
                            string.push_str(&curr_key);
                            string.push(the_char);
                            string
                        };

                        // After the end key, we definitely do not continue.
                        if key.as_str() > end && !end.is_empty() {
                            None
                        } else if key.as_str() < start {
                            // If we're prior to the start key...
                            // ...and the current key is a prefix of the start key...
                            if start.starts_with(&key) {
                                // ...only traverse the subtree, ignoring the key itself.
                                Some(Box::new(self.traverse(key, start, end, depth - 1)))
                            } else {
                                None
                            }
                        } else {
                            // Traverse normally, returning both the parent and sub key,
                            // in all other cases.
                            if key.len() < 2 {
                                let iter = std::iter::once(key.clone());
                                Some(if key == end {
                                    Box::new(iter)
                                } else {
                                    Box::new(iter.chain(self.traverse(key, start, end, depth - 1)))
                                })
                            } else {
                                let first = key.chars().next().unwrap();
                                Some(if key.chars().all(|c| c == first) {
                                    // If our characters are all the same,
                                    // don't add key to the list, only the subtree.
                                    Box::new(self.traverse(key, start, end, depth - 1))
                                } else {
                                    Box::new(std::iter::once(key.clone()).chain(self.traverse(
                                        key,
                                        start,
                                        end,
                                        depth - 1,
                                    )))
                                })
                            }
                        }
                    })
                    .flatten(),
            )
        }
    }

    /// Returns the branching factor used to size the traversal: the number of
    /// symbols spanned (inclusively) by the first characters of `start` and `end`.
    ///
    /// An empty `start` counts from the table's first symbol, an empty `end`
    /// up to its last symbol, and two empty inputs span the whole table. The
    /// result is at least 2 unless both inputs are empty and the table has a
    /// single symbol.
    ///
    /// # Errors
    /// Returns a `WrongCharOrder` error if the first characters are not in
    /// ascending order, and a non-ASCII error if a first character is not ASCII.
    fn distance_between_first_chars(
        &self,
        start: &str,
        end: &str,
    ) -> Result<usize, GenerationError> {
        use InternalError::WrongCharOrder;
        Ok(match (start.chars().next(), end.chars().next()) {
            // Both inputs have a first char: compare them directly.
            (Some(start_char), Some(end_char)) => {
                ensure! { start_char < end_char, WrongCharOrder(start_char, end_char) }
                let distance =
                    try_ascii_u8_from_char(end_char)? - try_ascii_u8_from_char(start_char)?;
                distance as usize + 1
            }
            // Only the start has a first char: compare it to our last possible symbol.
            (Some(start_char), None) => {
                let end_u8 = *self.0.last().expect("symbol tables are never empty");
                // Equality is allowed here: you can generate something after the
                // last char (by appending), but not before the first char.
                ensure! { start_char <= end_u8 as char, WrongCharOrder(start_char, end_u8 as char) }
                let distance = end_u8 - try_ascii_u8_from_char(start_char)?;
                if distance == 0 {
                    2
                } else {
                    distance as usize + 1
                }
            }
            // Only the end has a first char: compare it to our first possible symbol.
            (None, Some(end_char)) => {
                let start_u8 = *self.0.first().expect("symbol tables are never empty");
                ensure! { start_u8 < end_char as u8, WrongCharOrder(start_u8 as char, end_char) }
                // The check above guarantees a distance of at least 1.
                let distance = try_ascii_u8_from_char(end_char)? - start_u8;
                distance as usize + 1
            }
            // No characters given: the whole symbol table is our range.
            (None, None) => self.0.len(),
        })
    }

    /// Returns `true` if every byte of `chars` is one of this table's symbols.
    fn contains_all_chars(&self, chars: impl AsRef<[u8]>) -> bool {
        // `new` keeps the symbols sorted, so a binary search is sufficient.
        chars
            .as_ref()
            .iter()
            .all(|c| self.0.binary_search(c).is_ok())
    }
}

/// Calculate the tree depth needed to hold at least `n_elements` keys when
/// every level multiplies the number of keys by `branching_factor`.
///
/// This is `ceil(log_branching_factor(n_elements))` cast to `usize`, with two guards:
/// - A `branching_factor` below 2 is treated as 2. Logarithms to base 0 or 1
///   are undefined and would otherwise yield `usize::MAX` or 0; this happens
///   for single-symbol tables.
/// - The result is at least 1. A depth of 0 traverses nothing, so requesting a
///   single element (e.g. `generate(1)`) would otherwise produce an empty pool.
fn depth_for(branching_factor: usize, n_elements: usize) -> usize {
    let base = branching_factor.max(2) as f64;
    let depth = (n_elements as f64).log(base).ceil() as usize;
    depth.max(1)
}

/// Converts `c` to its ASCII byte value.
///
/// # Errors
/// Returns a `NonAsciiError` if `c` is not ASCII. Characters that do not fit
/// into a `u8` at all report the error converted from the failed integer
/// conversion. Characters in `U+0080..=U+00FF` (e.g. `'é'`) fit into a `u8`
/// but are not ASCII, and report `NonAsciiError::NonAsciiU8`.
fn try_ascii_u8_from_char(c: char) -> Result<u8, NonAsciiError> {
    let byte = u8::try_from(u32::from(c)).map_err(NonAsciiError::from)?;
    if byte.is_ascii() {
        Ok(byte)
    } else {
        Err(NonAsciiError::NonAsciiU8)
    }
}
/// Returns `true` if every byte of `chars` is an ASCII byte (`0x00..=0x7F`).
///
/// An empty input is vacuously ASCII.
fn all_chars_ascii(chars: impl AsRef<[u8]>) -> bool {
    let bytes: &[u8] = chars.as_ref();
    bytes.is_ascii()
}

impl FromStr for SymbolTable {
    type Err = CreationError;

    /// Builds a symbol table from every character of `s`.
    ///
    /// This is equivalent to calling `SymbolTable::from_chars` with the
    /// characters of `s`, and fails under the same conditions.
    fn from_str(s: &str) -> Result<Self, CreationError> {
        let symbols: Vec<char> = s.chars().collect();
        Self::from_chars(&symbols)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::num::NonZeroUsize;

    /// Create and unwrap a NonZeroUsize from the given usize.
    fn n(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }

    // Public API tests:

    #[test]
    fn valid_tables_work() {
        assert!(SymbolTable::new(&[1, 2, 3, 4, 5]).is_ok());
        assert!(SymbolTable::new(&[125, 126, 127]).is_ok());
        // Possible, but to be discouraged
        assert!(SymbolTable::new(&[b'a', b'f']).is_ok());
        assert!(SymbolTable::from_chars(&['a', 'b', 'c']).is_ok());
        assert!(SymbolTable::from_str("0123").is_ok());
    }

    #[test]
    fn invalid_tables_error() {
        assert!(SymbolTable::from_str("🍅😂👶🏻").is_err());
        assert!(SymbolTable::from_chars(&['🍌', '🍣', '⛄']).is_err());
        // Characters in U+0080..=U+00FF fit into a u8 but are not ASCII.
        assert!(SymbolTable::from_chars(&['é']).is_err());
        assert!(SymbolTable::new(&[128, 129, 130]).is_err());
        assert!(SymbolTable::new(&[]).is_err());
        assert!(SymbolTable::from_chars(&[]).is_err());
        assert!(SymbolTable::from_str("").is_err());
    }

    #[test]
    fn unknown_chars_error() {
        use error::GenerationError::UnknownCharacters;
        // You cannot pass in strings with characters not in the SymbolTable:
        let table = SymbolTable::alphabet();
        assert_eq!(
            table.mudder_one("123", "()/"),
            Err(UnknownCharacters("123".into()))
        );
        assert_eq!(
            table.mudder_one("a", "123"),
            Err(UnknownCharacters("123".into()))
        );
        assert_eq!(
            table.mudder_one("0)(", "b"),
            Err(UnknownCharacters("0)(".into()))
        );
        let table = SymbolTable::from_str("123").unwrap();
        assert_eq!(
            table.mudder_one("a", "b"),
            Err(UnknownCharacters("a".into()))
        );
        assert_eq!(
            table.mudder_one("456", "1"),
            Err(UnknownCharacters("456".into()))
        );
        assert_eq!(
            table.mudder_one("2", "abc"),
            Err(UnknownCharacters("abc".into()))
        );
    }

    #[test]
    fn equal_strings_error() {
        use error::GenerationError::MatchingStrings;
        let table = SymbolTable::alphabet();
        assert_eq!(
            table.mudder_one("abc", "abc"),
            Err(MatchingStrings("abc".into()))
        );
        assert_eq!(
            table.mudder_one("xyz", "xyz"),
            Err(MatchingStrings("xyz".into()))
        );
    }

    // TODO: Add a test that lexicographically adjacent inputs are rejected,
    // e.g. `SymbolTable::alphabet().mudder("ba", "baa", n(1))` should be an error.
    // This needs a way to tell whether two strings are lexicographically inseparable.

    #[test]
    fn reasonable_values() {
        let table = SymbolTable::from_str("ab").unwrap();
        let result = table.mudder_one("a", "b").unwrap();
        assert_eq!(result, "ab");
        let table = SymbolTable::from_str("0123456789").unwrap();
        let result = table.mudder_one("1", "2").unwrap();
        assert_eq!(result, "15");
    }

    #[test]
    fn outputs_more_or_less_match_mudderjs() {
        let table = SymbolTable::from_str("abc").unwrap();
        let result = table.mudder_one("a", "b").unwrap();
        assert_eq!(result, "ac");
        let table = SymbolTable::alphabet();
        let result = table.mudder("anhui", "azazel", n(3)).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(vec!["aq", "as", "av"], result);
    }

    #[test]
    fn empty_start() {
        let table = SymbolTable::from_str("abc").unwrap();
        let result = table.mudder("", "c", n(2)).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn empty_end() {
        let table = SymbolTable::from_str("abc").unwrap();
        let result = table.mudder("b", "", n(2)).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn generate_after_z() {
        let table = SymbolTable::alphabet();
        let result = table.mudder("z", "", n(10)).unwrap();
        assert_eq!(result.len(), 10);
        assert!(result.iter().all(|k| k.as_str() > "z"));
    }

    #[test]
    fn only_amount() {
        let table = SymbolTable::alphabet();
        let result = table.generate(n(10)).unwrap();
        assert_eq!(result.len(), 10);
    }

    #[test]
    fn values_sorting_correct() {
        // Every adjacent pair must be strictly ascending.
        let values = SymbolTable::alphabet().generate(n(12)).unwrap();
        assert_eq!(values.len(), 12);
        assert!(values.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn differing_input_lengths() {
        let table = SymbolTable::alphabet();
        let result = table.mudder_one("a", "ab").unwrap();
        assert!(result.starts_with('a'));
    }

    #[test]
    fn values_consistently_between_start_and_end() {
        let table = SymbolTable::alphabet();
        {
            // From z to a
            let mut right = String::from("z");
            for _ in 0..500 {
                let new_val = table.mudder_one("a", &right).unwrap();
                assert!(new_val < right);
                assert!(new_val.as_str() > "a");
                right = new_val;
            }
        }
        {
            // And from a to z
            let mut left = String::from("a");
            // TODO:    vv this test fails for higher numbers. FIXME!
            for _ in 0..17 {
                let new_val = table.mudder_one(&left, "z").unwrap();
                assert!(new_val > left);
                assert!(new_val.as_str() < "z");
                left = new_val;
            }
        }
    }

    // Internal/private method tests:

    #[test]
    fn traverse_alphabet() {
        fn traverse_alphabet(a: &str, b: &str, depth: usize) -> Vec<String> {
            SymbolTable::alphabet()
                .traverse("".into(), a, b, depth)
                .collect()
        }
        assert_eq!(traverse_alphabet("a", "d", 1), vec!["a", "b", "c", "d"]);
        assert_eq!(
            traverse_alphabet("a", "z", 1),
            (b'a'..=b'z')
                .map(|c| (c as char).to_string())
                .collect::<Vec<_>>()
        );
        assert_eq!(
            traverse_alphabet("a", "b", 2),
            vec![
                "a", "ab", "ac", "ad", "ae", "af", "ag", "ah", "ai", "aj", "ak", "al", "am", "an",
                "ao", "ap", "aq", "ar", "as", "at", "au", "av", "aw", "ax", "ay", "az", "b"
            ]
        )
    }

    #[test]
    fn traverse_custom() {
        fn traverse(table: &str, a: &str, b: &str, depth: usize) -> Vec<String> {
            let table = SymbolTable::from_str(table).unwrap();
            table.traverse("".into(), a, b, depth).collect()
        }
        assert_eq!(traverse("abc", "a", "c", 1), vec!["a", "b", "c"]);
        assert_eq!(
            traverse("abc", "a", "c", 2),
            vec!["a", "ab", "ac", "b", "ba", "bc", "c"]
        );
        assert_eq!(
            traverse("0123456789", "1", "2", 2),
            vec!["1", "10", "12", "13", "14", "15", "16", "17", "18", "19", "2"]
        );
    }

    #[test]
    fn distance_between_first_chars_correct() {
        let table = SymbolTable::alphabet();
        assert_eq!(table.distance_between_first_chars("a", "b").unwrap(), 2);
        assert_eq!(table.distance_between_first_chars("a", "z").unwrap(), 26);
        assert_eq!(table.distance_between_first_chars("", "").unwrap(), 26);
        assert_eq!(table.distance_between_first_chars("n", "").unwrap(), 13);
        assert_eq!(table.distance_between_first_chars("", "n").unwrap(), 14);
        assert_eq!(table.distance_between_first_chars("y", "z").unwrap(), 2);
        assert_eq!(table.distance_between_first_chars("a", "y").unwrap(), 25);
        assert_eq!(
            table.distance_between_first_chars("aaaa", "zzzz").unwrap(),
            table.distance_between_first_chars("aa", "zz").unwrap()
        );

        let table = SymbolTable::from_str("12345").unwrap();
        assert_eq!(table.distance_between_first_chars("1", "2").unwrap(), 2);
        assert_eq!(table.distance_between_first_chars("1", "3").unwrap(), 3);
        assert_eq!(table.distance_between_first_chars("2", "3").unwrap(), 2);
    }
}