use std::collections::HashMap;
use smallvec::SmallVec;
use crate::error::{HanziSortError, Result};
use crate::r#override::PinyinOverride;
/// Inline capacity of the `SmallVec` that holds the encoded syllables of one phrase override.
const INLINE_OVERRIDE_LEN: usize = 8;
/// Maximum byte length of a pinyin syllable the encoder accepts; each byte takes 8 bits of the
/// `u128` key, so 16 bytes fill it exactly.
const MAX_ENCODED_PINYIN_LEN: usize = 16;
/// Override tables whose pinyin syllables are stored pre-encoded as `u128` keys.
///
/// Instances are built from a `PinyinOverride` through the `TryFrom` implementation in this
/// file, which encodes every syllable with `encode_primary_pinyin`.
#[derive(Debug, Clone, Default)]
pub(crate) struct EncodedOverride {
/// Encoded pinyin key for each overridden character.
char_override: HashMap<char, u128>,
/// Encoded pinyin keys for each overridden phrase, in the order the source override lists
/// them; up to `INLINE_OVERRIDE_LEN` keys are stored inline by the `SmallVec`.
phrase_override: HashMap<String, SmallVec<[u128; INLINE_OVERRIDE_LEN]>>,
}
impl EncodedOverride {
pub(crate) fn phrase_override(&self, phrase: &str) -> Option<&[u128]> {
self.phrase_override.get(phrase).map(SmallVec::as_slice)
}
pub(crate) fn char_override(&self, character: char) -> Option<u128> {
self.char_override.get(&character).copied()
}
}
/// Builds the encoded lookup tables from a `PinyinOverride`.
///
/// Every pinyin string in the character and phrase override maps is run through
/// `encode_primary_pinyin`. The first string encountered that cannot be encoded aborts the
/// whole conversion.
impl TryFrom<&PinyinOverride> for EncodedOverride {
    type Error = HanziSortError;

    /// Encodes all character overrides first, then all phrase overrides.
    ///
    /// # Errors
    ///
    /// Returns `HanziSortError::InvalidOverride` with a message naming the offending entry,
    /// the pinyin string that failed and the encoder's reason.
    fn try_from(value: &PinyinOverride) -> Result<Self> {
        let mut char_override = HashMap::with_capacity(value.char_override.len());
        for (character, pinyin) in &value.char_override {
            let key = match encode_primary_pinyin(pinyin) {
                Ok(key) => key,
                Err(reason) => {
                    return Err(HanziSortError::InvalidOverride(format!(
                        "char_override entry '{character}' has unencodable pinyin '{pinyin}': {reason}"
                    )));
                }
            };
            char_override.insert(*character, key);
        }

        let mut phrase_override = HashMap::with_capacity(value.phrase_override.len());
        for (phrase, pinyins) in &value.phrase_override {
            // One key per syllable, kept in the order the override lists them.
            let mut keys =
                SmallVec::<[u128; INLINE_OVERRIDE_LEN]>::with_capacity(pinyins.len());
            for pinyin in pinyins {
                match encode_primary_pinyin(pinyin) {
                    Ok(key) => keys.push(key),
                    Err(reason) => {
                        return Err(HanziSortError::InvalidOverride(format!(
                            "phrase_override entry '{phrase}' has unencodable pinyin '{pinyin}': {reason}"
                        )));
                    }
                }
            }
            phrase_override.insert(phrase.clone(), keys);
        }

        Ok(Self {
            char_override,
            phrase_override,
        })
    }
}
/// Encodes an ASCII pinyin syllable as a left-aligned, zero-padded, big-endian `u128` key.
///
/// The syllable's bytes occupy the most significant bytes of the key and the unused low bytes
/// are zero, so comparing two keys as integers gives the same result as comparing the
/// original strings byte by byte.
///
/// # Errors
///
/// Returns a static description of the problem when `pinyin`
/// - is empty,
/// - is longer than `MAX_ENCODED_PINYIN_LEN` bytes,
/// - contains a non-ASCII character, or
/// - contains a NUL byte.
pub(crate) fn encode_primary_pinyin(pinyin: &str) -> std::result::Result<u128, &'static str> {
    let bytes = pinyin.as_bytes();
    if bytes.is_empty() {
        return Err("pinyin syllable is empty");
    }
    if bytes.len() > MAX_ENCODED_PINYIN_LEN {
        return Err("pinyin syllable exceeds 16 bytes");
    }
    if !bytes.is_ascii() {
        return Err("pinyin syllable must be ASCII");
    }
    // A NUL byte is indistinguishable from the zero padding: "a" and "a\0" would produce the
    // same key, which would break both uniqueness and order preservation.
    if bytes.contains(&0) {
        return Err("pinyin syllable must not contain NUL bytes");
    }

    // Left-align the syllable in a zeroed buffer and read it as one big-endian integer.
    // `from_be_bytes` needs exactly 16 bytes, so this stops compiling if the constant ever
    // drifts away from the width of `u128`.
    let mut padded = [0_u8; MAX_ENCODED_PINYIN_LEN];
    padded[..bytes.len()].copy_from_slice(bytes);
    Ok(u128::from_be_bytes(padded))
}
/// Encodes a pinyin syllable as a left-aligned `u128` key without validating the input.
///
/// The caller must guarantee that `pinyin` is non-empty, ASCII and at most
/// `MAX_ENCODED_PINYIN_LEN` bytes long. These invariants are only checked by a
/// `debug_assert!`, so builds without debug assertions skip the check and the result for an
/// input that violates them is not meaningful.
pub(crate) fn encode_primary_pinyin_unchecked(pinyin: &str) -> u128 {
    let byte_len = pinyin.len();
    debug_assert!(
        byte_len > 0 && byte_len <= MAX_ENCODED_PINYIN_LEN && pinyin.is_ascii(),
        "encode_primary_pinyin_unchecked invariants violated for {pinyin:?}"
    );

    // Pack the bytes big-endian into the low end of the integer...
    let packed = pinyin
        .bytes()
        .fold(0_u128, |acc, byte| (acc << 8) | u128::from(byte));
    // ...then shift them up so the first byte lands in the most significant position.
    let padding_bits = (MAX_ENCODED_PINYIN_LEN - byte_len) * 8;
    packed << padding_bits
}
#[cfg(test)]
mod tests {
    use super::{encode_primary_pinyin, encode_primary_pinyin_unchecked};

    /// Integer order of the encoded keys must agree with string order of the syllables.
    #[test]
    fn encode_primary_pinyin_preserves_lexicographic_order() {
        // (smaller, larger): a strict prefix, different initials, different tone digit.
        let ordered_pairs = [("a", "aa"), ("chong2", "qing4"), ("zhong4", "zhong5")];
        for (smaller, larger) in ordered_pairs {
            let low = encode_primary_pinyin(smaller).unwrap();
            let high = encode_primary_pinyin(larger).unwrap();
            assert!(low < high);
        }
    }

    /// A syllable containing a non-ASCII character is rejected with an ASCII-related reason.
    #[test]
    fn encode_primary_pinyin_rejects_non_ascii() {
        let reason = encode_primary_pinyin("nü").expect_err("non-ASCII should fail");
        assert!(reason.contains("ASCII"));
    }

    /// Seventeen bytes is one more than the encoder can fit.
    #[test]
    fn encode_primary_pinyin_rejects_oversized_input() {
        let seventeen_bytes = "a".repeat(17);
        let reason = encode_primary_pinyin(&seventeen_bytes).expect_err("oversize should fail");
        assert!(reason.contains("16 bytes"));
    }

    /// The unchecked encoder must produce the same key as the checked one on valid input.
    #[test]
    fn encode_primary_pinyin_unchecked_matches_checked_for_valid_input() {
        for syllable in ["a", "chong2", "zhong5", "shuang1"] {
            let unchecked = encode_primary_pinyin_unchecked(syllable);
            let checked = encode_primary_pinyin(syllable).expect("valid input");
            assert_eq!(checked, unchecked, "mismatch for {syllable}");
        }
    }

    /// The empty string is not a syllable and must be rejected.
    #[test]
    fn encode_primary_pinyin_rejects_empty() {
        let reason = encode_primary_pinyin("").expect_err("empty should fail");
        assert!(reason.contains("empty"));
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy yielding 1 to 15 lowercase letters followed by a single tone digit `1` to `5`.
    fn tone3_syllable() -> impl Strategy<Value = String> {
        ("[a-z]{1,15}", "[1-5]").prop_map(|(mut syllable, tone)| {
            syllable.push_str(&tone);
            syllable
        })
    }

    /// Strategy yielding 1 to 16 characters from the ASCII range 0x21..=0x7e.
    fn ascii_string_up_to_16() -> impl Strategy<Value = String> {
        let pattern = "[\\x21-\\x7e]{1,16}";
        prop::string::string_regex(pattern).unwrap()
    }

    proptest! {
        // Key order must equal string order for any pair of generated ASCII strings.
        #[test]
        fn encoding_preserves_lex_order_for_arbitrary_ascii(
            a in ascii_string_up_to_16(),
            b in ascii_string_up_to_16(),
        ) {
            let key_a = encode_primary_pinyin(&a).expect("ASCII ≤16 bytes is valid");
            let key_b = encode_primary_pinyin(&b).expect("ASCII ≤16 bytes is valid");
            let string_order = a.cmp(&b);
            prop_assert_eq!(key_a.cmp(&key_b), string_order);
        }

        // Same ordering property, restricted to letters-plus-tone-digit syllables.
        #[test]
        fn encoding_preserves_lex_order_for_tone3(
            a in tone3_syllable(),
            b in tone3_syllable(),
        ) {
            let key_a = encode_primary_pinyin(&a).expect("valid tone3");
            let key_b = encode_primary_pinyin(&b).expect("valid tone3");
            let string_order = a.cmp(&b);
            prop_assert_eq!(key_a.cmp(&key_b), string_order);
        }

        // The unchecked encoder must agree with the checked one on every valid syllable.
        #[test]
        fn unchecked_agrees_with_checked_on_valid_input(s in tone3_syllable()) {
            let unchecked = encode_primary_pinyin_unchecked(&s);
            let checked = encode_primary_pinyin(&s).expect("valid tone3");
            prop_assert_eq!(checked, unchecked);
        }
    }
}