rust-mando 0.1.2

Convert Chinese characters to pinyin with jieba word segmentation
Documentation
//! Runtime pinyin lookup over `$OUT_DIR/pinyin.dat` (built by `build.rs`), which is
//! embedded into the binary at compile time and decompressed on first use.
//!
//! # Binary layout (after zstd decompression, little-endian)
//!
//! ```text
//! u32 N                                   ← entry count
//! [u64 key_hash | u32 heap_offset] × N    ← index sorted by hash
//! [u8 len, len×UTF-8 bytes] × N           ← heap: pinyin_numbers strings
//! ```
//!
//! `pinyin_marks` is derived from `pinyin_numbers` at call time via
//! [`numbers_to_marks`].

use std::sync::OnceLock;

// ── static table ──────────────────────────────────────────────────────────────

/// Decompressed pinyin table; filled in by `get_table` on first access.
static TABLE: OnceLock<Vec<u8>> = OnceLock::new();
/// Raw bytes of `$OUT_DIR/pinyin.dat`, embedded into the binary at compile
/// time. `get_table` feeds them to a zstd streaming decoder.
static PINYIN_DAT: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/pinyin.dat"));

/// Return the decompressed pinyin table, decompressing `PINYIN_DAT` the first
/// time it is requested and caching the result in `TABLE`.
///
/// # Panics
///
/// Panics if `PINYIN_DAT` is not a valid zstd stream or cannot be fully
/// decompressed.
fn get_table() -> &'static [u8] {
    use ruzstd::streaming_decoder::StreamingDecoder;
    use std::io::Read;

    TABLE.get_or_init(|| {
        let mut decoder =
            StreamingDecoder::new(PINYIN_DAT).expect("invalid zstd stream in pinyin.dat");
        let mut decompressed = Vec::new();
        decoder
            .read_to_end(&mut decompressed)
            .expect("failed to decompress pinyin.dat");
        decompressed
    })
}

// ── FNV-1a-64 (must match build.rs exactly) ──────────────────────────────────

/// Hash `bytes` with 64-bit FNV-1a: starting from the offset basis, XOR in
/// each byte and then multiply (wrapping) by the FNV prime.
#[inline]
fn fnv1a_64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    bytes.iter().fold(OFFSET_BASIS, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}

// ── index readers ─────────────────────────────────────────────────────────────

/// Read a little-endian `u32` from the four bytes of `data` starting at `off`.
///
/// # Panics
///
/// Panics if `data` does not contain four bytes at `off`.
#[inline]
fn u32_le(data: &[u8], off: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&data[off..off + 4]);
    u32::from_le_bytes(raw)
}
/// Read a little-endian `u64` from the eight bytes of `data` starting at `off`.
///
/// # Panics
///
/// Panics if `data` does not contain eight bytes at `off`.
#[inline]
fn u64_le(data: &[u8], off: usize) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[off..off + 8]);
    u64::from_le_bytes(raw)
}

// ── public API ────────────────────────────────────────────────────────────────

/// Look up `word` (traditional Chinese) and return its `pinyin_numbers` string,
/// e.g. `"bei3 jing1"`.  Returns `None` if the word is not in the table.
///
/// The table stores only the FNV-1a-64 hash of each key, not the key itself,
/// so a lookup cannot tell two colliding words apart: when several index
/// entries share the hash of `word`, the first one in index order is returned.
/// An entry whose heap bytes are not valid UTF-8 yields an empty string.
///
/// # Panics
///
/// Panics if the embedded table cannot be decompressed, or if the index or
/// heap is shorter than the entry count and offsets stored in it imply.
pub fn lookup_numbers(word: &str) -> Option<String> {
    const HEADER: usize = 4; // u32 entry count
    const ENTRY: usize = 12; // 8 (hash) + 4 (heap offset)

    let table = get_table();
    // A table too short to hold the entry count has no entries to find.
    if table.len() < HEADER {
        return None;
    }
    let n = u32_le(table, 0) as usize;
    let index_end = HEADER + n * ENTRY;
    let target = fnv1a_64(word.as_bytes());
    let hash_at = |i: usize| u64_le(table, HEADER + i * ENTRY);

    // Lower-bound binary search: afterwards `lo` is the first index whose
    // hash is >= `target`, i.e. the first of any run of equal hashes.
    let (mut lo, mut hi) = (0usize, n);
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        if hash_at(mid) < target {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    if lo == n || hash_at(lo) != target {
        return None;
    }

    // Heap record: one length byte followed by that many UTF-8 bytes.
    let heap_off = u32_le(table, HEADER + lo * ENTRY + 8) as usize;
    let start = index_end + heap_off;
    let len = table[start] as usize;
    let bytes = &table[start + 1..start + 1 + len];
    Some(std::str::from_utf8(bytes).unwrap_or("").to_string())
}

/// Convert `pinyin_numbers` (e.g. `"bei3 jing1"`) to `pinyin_marks`
/// (e.g. `"běi jīng"`).
pub fn numbers_to_marks(numbers: &str) -> String {
    numbers
        .split_whitespace()
        .map(syllable_to_marks)
        .collect::<Vec<_>>()
        .join(" ")
}

// ── tone conversion ───────────────────────────────────────────────────────────

/// Convert a single tone-number syllable like `bei3` or `Zhong1` to its
/// tone-mark form `běi` / `Zhōng`.
///
/// * Neutral tone (5) strips the digit. No digit → returned unchanged.
/// * A lone digit (nothing before it) is returned unchanged.
/// * The ASCII spelling `u:` (as used by CC-CEDICT-style data, e.g. `lu:4`)
///   is treated as `ü`, so `lu:4` → `lǜ` and `nu:5` → `nü`.
/// * For tones other than 5 the syllable is lowercased before marking; only
///   the capitalisation of its first character is restored afterwards.
fn syllable_to_marks(syl: &str) -> String {
    // Split into the part before the final character and the tone digit.
    let (base, tone) = match syl.char_indices().last() {
        Some((idx, c)) => match c.to_digit(10) {
            // `idx > 0` guarantees there is something left to put a mark on.
            Some(d) if idx > 0 => (&syl[..idx], d as u8),
            _ => return syl.to_string(),
        },
        None => return String::new(),
    };

    // Normalise the two-character ASCII spelling of ü before any tone
    // handling so that the neutral-tone path benefits as well.
    let base = if base.contains(':') {
        base.replace("u:", "ü").replace("U:", "Ü")
    } else {
        base.to_string()
    };

    // Neutral tone — just remove the digit
    if tone == 5 {
        return base;
    }

    // Remember the capitalisation of the first char; marking works on lowercase.
    let first_upper = base.chars().next().map_or(false, char::is_uppercase);
    let marked = apply_tone(&base.to_lowercase(), tone);
    if !first_upper {
        return marked;
    }

    let mut rest = marked.chars();
    match rest.next() {
        Some(head) => head.to_uppercase().chain(rest).collect(),
        None => marked,
    }
}

/// Put the diacritic for `tone` on the appropriate vowel of a *lowercase*
/// syllable and return the result.
///
/// The vowel is chosen in this order (standard Mandarin pinyin placement):
/// 1. the first `a`, otherwise the first `e`;
/// 2. the `o` of the first `ou`;
/// 3. the last vowel (`a e i o u ü`) in the syllable.
///
/// A syllable without any vowel is returned unchanged, as is any syllable
/// when `tone` is outside `1..=4`.
fn apply_tone(syl: &str, tone: u8) -> String {
    // One row per base vowel; columns are tones 1 through 4.
    const MARKS: [(char, [char; 4]); 6] = [
        ('a', ['ā', 'á', 'ǎ', 'à']),
        ('e', ['ē', 'é', 'ě', 'è']),
        ('i', ['ī', 'í', 'ǐ', 'ì']),
        ('o', ['ō', 'ó', 'ǒ', 'ò']),
        ('u', ['ū', 'ú', 'ǔ', 'ù']),
        ('ü', ['ǖ', 'ǘ', 'ǚ', 'ǜ']),
    ];

    // Byte offset of the vowel that carries the mark.
    let located = syl
        .find('a')
        .or_else(|| syl.find('e'))
        .or_else(|| syl.find("ou"))
        .or_else(|| syl.rfind(|c: char| matches!(c, 'a' | 'e' | 'i' | 'o' | 'u' | 'ü')));
    let pos = match located {
        Some(p) => p,
        None => return syl.to_string(),
    };

    let vowel = syl[pos..].chars().next().unwrap();
    // Tone 0 wraps to an out-of-range column and, like tones above 4,
    // leaves the vowel unmarked.
    let column = usize::from(tone).wrapping_sub(1);
    let replacement = MARKS
        .iter()
        .find(|(plain, _)| *plain == vowel)
        .and_then(|(_, row)| row.get(column))
        .copied()
        .unwrap_or(vowel);

    let head = &syl[..pos];
    let tail = &syl[pos + vowel.len_utf8()..];
    format!("{}{}{}", head, replacement, tail)
}

// ── tests ─────────────────────────────────────────────────────────────────────

/// Unit tests for the tone-number → tone-mark conversion helpers.
#[cfg(test)]
mod tests {
    use super::{numbers_to_marks, syllable_to_marks};

    #[test]
    fn tone_marks_basic() {
        assert_eq!(syllable_to_marks("bei3"),   "běi");
        assert_eq!(syllable_to_marks("jing1"),  "jīng");
        assert_eq!(syllable_to_marks("ni3"),    "nǐ");
        assert_eq!(syllable_to_marks("hao3"),   "hǎo");
    }

    #[test]
    fn tone_marks_capitals() {
        assert_eq!(syllable_to_marks("Zhong1"), "Zhōng");
        assert_eq!(syllable_to_marks("Bei3"),   "Běi");
    }

    #[test]
    fn tone_marks_neutral() {
        assert_eq!(syllable_to_marks("ma5"),  "ma");
        assert_eq!(syllable_to_marks("men5"), "men");
    }

    #[test]
    fn tone_marks_u_umlaut() {
        // The mark replaces the plain 'ü' with the precomposed toned letter.
        assert_eq!(syllable_to_marks("lü4"),  "lǜ");
        assert_eq!(syllable_to_marks("nü3"),  "nǚ");
    }

    #[test]
    fn tone_marks_ou_rule() {
        // 'ou' → tone on o, not u
        assert_eq!(syllable_to_marks("gou3"), "gǒu");
        assert_eq!(syllable_to_marks("dou4"), "dòu");
    }

    #[test]
    fn numbers_to_marks_phrase() {
        assert_eq!(numbers_to_marks("bei3 jing1"), "běi jīng");
        assert_eq!(numbers_to_marks("ni3 hao3"),   "nǐ hǎo");
        assert_eq!(numbers_to_marks("Zhong1 guo2"), "Zhōng guó");
    }

    #[test]
    fn no_digit_passthrough() {
        assert_eq!(syllable_to_marks("r"),  "r");  // erhua suffix
        assert_eq!(syllable_to_marks(""),   "");
    }
}