use core::num::NonZeroUsize;
use std::{convert::TryFrom, str::FromStr};
#[macro_use]
pub mod error;
use error::*;
/// An ordered set of ASCII symbols from which keys are built.
///
/// Construct it with [`SymbolTable::new`], [`SymbolTable::from_chars`],
/// [`SymbolTable::alphabet`] or `str::parse`; all of them go through `new`,
/// which stores the symbols sorted and without duplicates.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SymbolTable(Vec<u8>);
impl SymbolTable {
    /// Creates a symbol table from raw ASCII bytes.
    ///
    /// The bytes are copied, sorted ascending and deduplicated, so neither the
    /// order nor the multiplicity of `source` matters.
    ///
    /// # Errors
    ///
    /// Fails with `CreationError::EmptySlice` if `source` is empty, and with
    /// `NonAsciiError::NonAsciiU8` if any byte is outside `0..=127`.
    pub fn new(source: &[u8]) -> Result<Self, CreationError> {
        ensure! { !source.is_empty(), CreationError::EmptySlice }
        ensure! { all_chars_ascii(source), NonAsciiError::NonAsciiU8 }
        let mut symbols = source.to_vec();
        // `traverse` relies on sorted symbols to emit keys in lexicographic
        // order; the binary searches below rely on sorted, unique symbols.
        symbols.sort_unstable();
        symbols.dedup();
        Ok(Self(symbols))
    }

    /// Creates a symbol table from a slice of characters.
    ///
    /// # Errors
    ///
    /// Fails if any character is not ASCII or if `source` is empty.
    pub fn from_chars(source: &[char]) -> Result<Self, CreationError> {
        let bytes = source
            .iter()
            .map(|&c| try_ascii_u8_from_char(c))
            .collect::<Result<Vec<u8>, _>>()?;
        Self::new(&bytes)
    }

    /// Returns a table of the 26 lowercase ASCII letters `a` through `z`.
    pub fn alphabet() -> Self {
        let letters: Vec<u8> = (b'a'..=b'z').collect();
        Self::new(&letters).expect("the lowercase ASCII alphabet is a valid, non-empty table")
    }

    /// Generates `amount` keys that sort strictly between `a` and `b`.
    ///
    /// An empty `a` means "no lower bound" and an empty `b` means "no upper
    /// bound". When both are non-empty, their order does not matter: the
    /// smaller one is used as the start. The returned keys are ascending and
    /// never equal to either input.
    ///
    /// # Errors
    ///
    /// Fails if an input contains non-ASCII bytes or symbols missing from this
    /// table, if both inputs are equal and non-empty, or with an internal error
    /// if not enough keys could be selected.
    ///
    /// # Panics
    ///
    /// Panics with a bug-report message if the tree-depth estimate is still
    /// too small after one retry.
    pub fn mudder(
        &self,
        a: &str,
        b: &str,
        amount: NonZeroUsize,
    ) -> Result<Vec<String>, GenerationError> {
        use error::InternalError::*;
        use GenerationError::*;
        ensure! { all_chars_ascii(a), NonAsciiError::NonAsciiU8 }
        ensure! { all_chars_ascii(b), NonAsciiError::NonAsciiU8 }
        ensure! { self.contains_all_chars(a), UnknownCharacters(a.to_string()) }
        ensure! { self.contains_all_chars(b), UnknownCharacters(b.to_string()) }
        // An empty side is "unbounded" and is never swapped.
        let (a, b) = if a.is_empty() || b.is_empty() {
            (a, b)
        } else if b < a {
            (b, a)
        } else {
            ensure! { a != b, MatchingStrings(a.to_string()) }
            (a, b)
        };
        // Length of the common prefix of `a` and `b`. Once `a` runs out,
        // characters of `b` that repeat the last shared character also count.
        let matching_count: usize = {
            let (mut start_chars, mut end_chars) = (a.chars(), b.chars());
            let mut last_start_char = '\0';
            let mut i: usize = 0;
            loop {
                match (start_chars.next(), end_chars.next()) {
                    (Some(sc), Some(ec)) if sc == ec => {
                        last_start_char = sc;
                        i += 1;
                    }
                    (None, Some(ec)) if ec == last_start_char => i += 1,
                    _ => break i,
                }
            }
        };
        let non_empty_input_count = [a, b].iter().filter(|s| !s.is_empty()).count();
        // `traverse` may emit the inputs themselves, so the pool must have room
        // for them on top of the `amount` new keys.
        let computed_amount = amount.get() + non_empty_input_count;
        let pool_too_small =
            |pool_len: usize| pool_len.saturating_sub(non_empty_input_count) < amount.get();
        let branching_factor = self.distance_between_first_chars(
            &a[std::cmp::min(matching_count, a.len())..],
            &b[matching_count..],
        )?;
        let mut depth = depth_for(branching_factor, computed_amount) + matching_count;
        let mut pool: Vec<String> = self.traverse("".into(), a, b, depth).collect();
        if pool_too_small(pool.len()) {
            // The first estimate was too shallow: go deeper once and retry.
            depth += depth_for(branching_factor, computed_amount + pool.len());
            pool = self.traverse("".into(), a, b, depth).collect();
        }
        if pool_too_small(pool.len()) {
            panic!(
                "Internal error: Failed to calculate the correct tree depth!\n\
                 This is a bug. Please report it at: https://github.com/Follpvosten/mudders/issues\n\
                 and make sure to include the following information:\n\
                 Symbols in table: {symbols:?}\n\
                 Given inputs: {a:?}, {b:?}, amount: {amount}\n\
                 matching_count: {m_count}\n\
                 non_empty_input_count: {ne_input_count}\n\
                 required pool length (computed amount): {comp_amount}\n\
                 branching_factor: {b_factor}\n\
                 final depth: {depth}\n\
                 pool: {pool:?} (length: {pool_len})",
                symbols = self.0.iter().map(|i| *i as char).collect::<Box<[_]>>(),
                a = a,
                b = b,
                amount = amount,
                m_count = matching_count,
                ne_input_count = non_empty_input_count,
                comp_amount = computed_amount,
                b_factor = branching_factor,
                depth = depth,
                pool = pool,
                pool_len = pool.len(),
            )
        }
        Ok(if amount.get() == 1 {
            let mut middle = pool.len() / 2;
            // With an empty start, the pool can be exactly `[x, b]`, whose
            // middle is the end bound itself; step back to stay below `b`.
            if middle > 0 && pool[middle] == b {
                middle -= 1;
            }
            pool.into_iter()
                .nth(middle)
                .map(|item| vec![item])
                .ok_or(FailedToGetMiddle)?
        } else {
            let step = computed_amount as f64 / pool.len() as f64;
            let mut counter = 0f64;
            let mut last_value = 0;
            let result: Vec<_> = pool
                .into_iter()
                .filter(|_| {
                    counter += step;
                    let new_value = counter.floor() as usize;
                    if new_value > last_value {
                        last_value = new_value;
                        true
                    } else {
                        false
                    }
                })
                // If the pool holds exactly `computed_amount` keys, the step is 1
                // and every key passes the spacing filter, including the inputs.
                // Dropping them after that filter leaves every other case's
                // spacing untouched.
                .filter(|key| key.as_str() != a && key.as_str() != b)
                .take(amount.get())
                .collect();
            ensure! { result.len() == amount.get(), NotEnoughItemsInPool };
            result
        })
    }

    /// Generates a single key between `a` and `b`; see [`SymbolTable::mudder`].
    ///
    /// # Errors
    ///
    /// Same as [`SymbolTable::mudder`].
    pub fn mudder_one(&self, a: &str, b: &str) -> Result<String, GenerationError> {
        let one = NonZeroUsize::new(1).expect("1 is non-zero");
        // `mudder` returns exactly `amount` (here: one) keys on success.
        self.mudder(a, b, one).map(|mut keys| keys.remove(0))
    }

    /// Generates `amount` keys without lower or upper bound.
    ///
    /// # Errors
    ///
    /// Same as [`SymbolTable::mudder`].
    pub fn generate(&self, amount: NonZeroUsize) -> Result<Vec<String>, GenerationError> {
        self.mudder("", "", amount)
    }

    /// Lazily walks the keys that extend `curr_key` by 1 to `depth` symbols,
    /// in lexicographic order.
    ///
    /// The walk applies these rules:
    /// - Keys above a non-empty `end` are skipped together with their subtrees.
    /// - Keys below `start` are skipped. Their subtrees are explored only when
    ///   the key is a prefix of `start`.
    /// - Keys of two or more characters that repeat one symbol (e.g. `"aa"`)
    ///   are not emitted, but their subtrees are.
    /// - A one-character key equal to `end` is emitted without its subtree.
    fn traverse<'a>(
        &'a self,
        curr_key: String,
        start: &'a str,
        end: &'a str,
        depth: usize,
    ) -> Box<dyn Iterator<Item = String> + 'a> {
        if depth == 0 {
            return Box::new(std::iter::empty());
        }
        Box::new(
            self.0
                .iter()
                .filter_map(move |c| -> Option<Box<dyn Iterator<Item = String>>> {
                    let key = {
                        let the_char = *c as char;
                        let mut string =
                            String::with_capacity(curr_key.len() + the_char.len_utf8());
                        string.push_str(&curr_key);
                        string.push(the_char);
                        string
                    };
                    if key.as_str() > end && !end.is_empty() {
                        // Neither this key nor any extension can be <= `end`.
                        None
                    } else if key.as_str() < start {
                        // Only extensions of a prefix of `start` can still reach it.
                        if start.starts_with(&key) {
                            Some(Box::new(self.traverse(key, start, end, depth - 1)))
                        } else {
                            None
                        }
                    } else if key.len() < 2 {
                        let iter = std::iter::once(key.clone());
                        Some(if key == end {
                            Box::new(iter)
                        } else {
                            Box::new(iter.chain(self.traverse(key, start, end, depth - 1)))
                        })
                    } else {
                        let first = key.chars().next().unwrap();
                        Some(if key.chars().all(|c| c == first) {
                            // Single-symbol repeats are not emitted themselves.
                            Box::new(self.traverse(key, start, end, depth - 1))
                        } else {
                            Box::new(std::iter::once(key.clone()).chain(self.traverse(
                                key,
                                start,
                                end,
                                depth - 1,
                            )))
                        })
                    }
                })
                .flatten(),
        )
    }

    /// Index of `c` within this table's sorted symbols.
    ///
    /// # Errors
    ///
    /// Fails if `c` is not ASCII or not a symbol of this table.
    fn symbol_index(&self, c: char) -> Result<usize, GenerationError> {
        let byte = try_ascii_u8_from_char(c)?;
        self.0
            .binary_search(&byte)
            .map_err(|_| GenerationError::UnknownCharacters(c.to_string()))
    }

    /// Number of table symbols spanned by the first characters of `start` and
    /// `end` (inclusive), used as the branching factor of the key tree.
    ///
    /// An empty side stands for the first/last symbol of the table, and a span
    /// of a single symbol on an open side counts as 2. Distances are measured
    /// in table positions rather than code points, so sparse tables such as
    /// `"az"` get a branching factor of 2, not 26.
    ///
    /// # Errors
    ///
    /// Fails with `InternalError::WrongCharOrder` if the characters are out of
    /// order, or if a character is not in the table.
    fn distance_between_first_chars(
        &self,
        start: &str,
        end: &str,
    ) -> Result<usize, GenerationError> {
        use InternalError::WrongCharOrder;
        // `new` guarantees a non-empty table, so these indices are valid.
        let first_symbol = self.0[0] as char;
        let last_index = self.0.len() - 1;
        let last_symbol = self.0[last_index] as char;
        Ok(match (start.chars().next(), end.chars().next()) {
            (Some(start_char), Some(end_char)) => {
                ensure! { start_char < end_char, WrongCharOrder(start_char, end_char) }
                // Sorted table and start < end: this subtraction cannot underflow.
                self.symbol_index(end_char)? - self.symbol_index(start_char)? + 1
            }
            (Some(start_char), None) => {
                ensure! { start_char <= last_symbol, WrongCharOrder(start_char, last_symbol) }
                match last_index - self.symbol_index(start_char)? {
                    0 => 2,
                    distance => distance + 1,
                }
            }
            (None, Some(end_char)) => {
                ensure! { first_symbol < end_char, WrongCharOrder(first_symbol, end_char) }
                match self.symbol_index(end_char)? {
                    0 => 2,
                    distance => distance + 1,
                }
            }
            (None, None) => self.0.len(),
        })
    }

    /// Returns whether every byte of `chars` is a symbol of this table.
    fn contains_all_chars(&self, chars: impl AsRef<[u8]>) -> bool {
        // The symbols are sorted and unique (see `new`), so binary search is valid.
        chars
            .as_ref()
            .iter()
            .all(|c| self.0.binary_search(c).is_ok())
    }
}
/// Returns the tree depth needed to hold `n_elements` keys when each level
/// branches `branching_factor` ways.
///
/// The depth is `ceil(log_{branching_factor}(n_elements))`, computed in
/// floating point, and is never less than 1. A depth of 0 would make the key
/// traversal yield nothing, which previously happened for `n_elements == 1`
/// because the logarithm is 0 there.
///
/// A branching factor below 2 has no usable logarithm: base 1 yields
/// infinity or NaN, which `as usize` saturates to `usize::MAX` or 0. In that
/// case each level adds at most one key, so `n_elements` levels are returned.
fn depth_for(branching_factor: usize, n_elements: usize) -> usize {
    if branching_factor < 2 {
        return n_elements.max(1);
    }
    let depth = f64::log(n_elements as f64, branching_factor as f64).ceil() as usize;
    depth.max(1)
}
/// Converts `c` into its ASCII byte value.
///
/// # Errors
///
/// - Characters above `U+00FF` fail with the error converted from the
///   `u32 -> u8` conversion failure.
/// - Characters in `U+0080..=U+00FF` fit in a byte but are not ASCII, so they
///   fail with `NonAsciiError::NonAsciiU8`.
fn try_ascii_u8_from_char(c: char) -> Result<u8, NonAsciiError> {
    let byte = u8::try_from(c as u32).map_err(NonAsciiError::from)?;
    if byte.is_ascii() {
        Ok(byte)
    } else {
        // e.g. 'é' (0xE9): representable as a byte, but not ASCII.
        Err(NonAsciiError::NonAsciiU8)
    }
}
/// Returns `true` when every byte of `chars` is ASCII (`0..=127`).
///
/// Accepts anything byte-like (`&str`, `String`, `&[u8]`, `Vec<u8>`, ...).
/// An empty input counts as ASCII.
fn all_chars_ascii(chars: impl AsRef<[u8]>) -> bool {
    let bytes: &[u8] = chars.as_ref();
    bytes.is_ascii()
}
impl FromStr for SymbolTable {
    type Err = CreationError;

    /// Builds a table from the characters of `s`, exactly as
    /// [`SymbolTable::from_chars`] would for `s.chars()`.
    fn from_str(s: &str) -> Result<Self, CreationError> {
        let chars: Vec<char> = s.chars().collect();
        Self::from_chars(&chars)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::num::NonZeroUsize;

    /// Shorthand for building a `NonZeroUsize` in tests.
    fn n(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }

    #[test]
    fn valid_tables_work() {
        assert!(SymbolTable::new(&[1, 2, 3, 4, 5]).is_ok());
        assert!(SymbolTable::new(&[125, 126, 127]).is_ok());
        assert!(SymbolTable::new(&[b'a', b'f']).is_ok());
        assert!(SymbolTable::from_chars(&['a', 'b', 'c']).is_ok());
        assert!(SymbolTable::from_str("0123").is_ok());
    }

    #[test]
    fn invalid_tables_error() {
        assert!(SymbolTable::from_str("🍅😂👶🏻").is_err());
        assert!(SymbolTable::from_chars(&['🍌', '🍣', '⛈']).is_err());
        assert!(SymbolTable::new(&[128, 129, 130]).is_err());
        assert!(SymbolTable::new(&[]).is_err());
        assert!(SymbolTable::from_chars(&[]).is_err());
        assert!(SymbolTable::from_str("").is_err());
    }

    #[test]
    fn unknown_chars_error() {
        use error::GenerationError::UnknownCharacters;
        let table = SymbolTable::alphabet();
        assert_eq!(
            table.mudder_one("123", "()/"),
            Err(UnknownCharacters("123".into()))
        );
        assert_eq!(
            table.mudder_one("a", "123"),
            Err(UnknownCharacters("123".into()))
        );
        assert_eq!(
            table.mudder_one("0)(", "b"),
            Err(UnknownCharacters("0)(".into()))
        );
        let table = SymbolTable::from_str("123").unwrap();
        assert_eq!(
            table.mudder_one("a", "b"),
            Err(UnknownCharacters("a".into()))
        );
        assert_eq!(
            table.mudder_one("456", "1"),
            Err(UnknownCharacters("456".into()))
        );
        assert_eq!(
            table.mudder_one("2", "abc"),
            Err(UnknownCharacters("abc".into()))
        );
    }

    #[test]
    fn equal_strings_error() {
        use error::GenerationError::MatchingStrings;
        let table = SymbolTable::alphabet();
        assert_eq!(
            table.mudder_one("abc", "abc"),
            Err(MatchingStrings("abc".into()))
        );
        assert_eq!(
            table.mudder_one("xyz", "xyz"),
            Err(MatchingStrings("xyz".into()))
        );
    }

    #[test]
    fn reasonable_values() {
        let table = SymbolTable::from_str("ab").unwrap();
        let result = table.mudder_one("a", "b").unwrap();
        assert_eq!(result, "ab");
        let table = SymbolTable::from_str("0123456789").unwrap();
        let result = table.mudder_one("1", "2").unwrap();
        assert_eq!(result, "15");
    }

    #[test]
    fn outputs_more_or_less_match_mudderjs() {
        let table = SymbolTable::from_str("abc").unwrap();
        let result = table.mudder_one("a", "b").unwrap();
        assert_eq!(result, "ac");
        let table = SymbolTable::alphabet();
        let result = table.mudder("anhui", "azazel", n(3)).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(vec!["aq", "as", "av"], result);
    }

    #[test]
    fn empty_start() {
        let table = SymbolTable::from_str("abc").unwrap();
        let result = table.mudder("", "c", n(2)).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn empty_end() {
        let table = SymbolTable::from_str("abc").unwrap();
        let result = table.mudder("b", "", n(2)).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn generate_after_z() {
        let table = SymbolTable::alphabet();
        let result = table.mudder("z", "", n(10)).unwrap();
        assert_eq!(result.len(), 10);
        assert!(result.iter().all(|k| k.as_str() > "z"));
    }

    #[test]
    fn only_amount() {
        let table = SymbolTable::alphabet();
        let result = table.generate(n(10)).unwrap();
        assert_eq!(result.len(), 10);
    }

    #[test]
    fn values_sorting_correct() {
        // Unwrap the `Result` so the keys themselves are checked: iterating the
        // `Result` directly yields the whole `Vec` once, which compares nothing.
        let result = SymbolTable::alphabet().generate(n(10)).unwrap();
        assert_eq!(result.len(), 10);
        // Every adjacent pair must be strictly ascending.
        assert!(result.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn differing_input_lengths() {
        let table = SymbolTable::alphabet();
        let result = table.mudder_one("a", "ab").unwrap();
        assert!(result.starts_with('a'));
    }

    #[test]
    fn values_consistently_between_start_and_end() {
        let table = SymbolTable::alphabet();
        {
            let mut right = String::from("z");
            for _ in 0..500 {
                let new_val = table.mudder_one("a", &right).unwrap();
                assert!(new_val < right);
                assert!(new_val.as_str() > "a");
                right = new_val;
            }
        }
        {
            let mut left = String::from("a");
            for _ in 0..17 {
                let new_val = table.mudder_one(&left, "z").unwrap();
                assert!(new_val > left);
                assert!(new_val.as_str() < "z");
                left = new_val;
            }
        }
    }

    #[test]
    fn traverse_alphabet() {
        fn traverse_alphabet(a: &str, b: &str, depth: usize) -> Vec<String> {
            SymbolTable::alphabet()
                .traverse("".into(), a, b, depth)
                .collect()
        }
        assert_eq!(traverse_alphabet("a", "d", 1), vec!["a", "b", "c", "d"]);
        assert_eq!(
            traverse_alphabet("a", "z", 1),
            (b'a'..=b'z')
                .map(|c| (c as char).to_string())
                .collect::<Vec<_>>()
        );
        assert_eq!(
            traverse_alphabet("a", "b", 2),
            vec![
                "a", "ab", "ac", "ad", "ae", "af", "ag", "ah", "ai", "aj", "ak", "al", "am", "an",
                "ao", "ap", "aq", "ar", "as", "at", "au", "av", "aw", "ax", "ay", "az", "b"
            ]
        )
    }

    #[test]
    fn traverse_custom() {
        fn traverse(table: &str, a: &str, b: &str, depth: usize) -> Vec<String> {
            let table = SymbolTable::from_str(table).unwrap();
            table.traverse("".into(), a, b, depth).collect()
        }
        assert_eq!(traverse("abc", "a", "c", 1), vec!["a", "b", "c"]);
        assert_eq!(
            traverse("abc", "a", "c", 2),
            vec!["a", "ab", "ac", "b", "ba", "bc", "c"]
        );
        assert_eq!(
            traverse("0123456789", "1", "2", 2),
            vec!["1", "10", "12", "13", "14", "15", "16", "17", "18", "19", "2"]
        );
    }

    #[test]
    fn distance_between_first_chars_correct() {
        let table = SymbolTable::alphabet();
        assert_eq!(table.distance_between_first_chars("a", "b").unwrap(), 2);
        assert_eq!(table.distance_between_first_chars("a", "z").unwrap(), 26);
        assert_eq!(table.distance_between_first_chars("", "").unwrap(), 26);
        assert_eq!(table.distance_between_first_chars("n", "").unwrap(), 13);
        assert_eq!(table.distance_between_first_chars("", "n").unwrap(), 14);
        assert_eq!(table.distance_between_first_chars("y", "z").unwrap(), 2);
        assert_eq!(table.distance_between_first_chars("a", "y").unwrap(), 25);
        assert_eq!(
            table.distance_between_first_chars("aaaa", "zzzz").unwrap(),
            table.distance_between_first_chars("aa", "zz").unwrap()
        );
        let table = SymbolTable::from_str("12345").unwrap();
        assert_eq!(table.distance_between_first_chars("1", "2").unwrap(), 2);
        assert_eq!(table.distance_between_first_chars("1", "3").unwrap(), 3);
        assert_eq!(table.distance_between_first_chars("2", "3").unwrap(), 2);
    }
}