use std::{collections::HashMap, fs, path::Path};
use smol_str::format_smolstr;
use crate::{
audio::tts::g2p::{
arpabet,
types::{Lexicon, LexiconEntry},
},
error::{Error, FileIoPayload, FileOp, MissingKeyPayload, OutOfRangePayload, Result},
};
/// One dictionary line in textual form, before any phoneme conversion.
///
/// Stores the headword, the pronunciation tokens exactly as supplied, the
/// optional `(N)` variant index and the line number the entry was read from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RawEntry {
    /// Headword, stored exactly as handed to [`RawEntry::new`].
    word: String,
    /// Pronunciation tokens, one string per symbol, in source order.
    arpabet: Vec<String>,
    /// Variant index taken from a `WORD(N)` token; `None` when there was no suffix.
    variant: Option<u32>,
    /// Line number supplied by the caller, kept for diagnostics.
    line_number: usize,
}

impl RawEntry {
    /// Assembles an entry from its parts, converting `word` into an owned `String`.
    pub fn new(
        word: impl Into<String>,
        arpabet: Vec<String>,
        variant: Option<u32>,
        line_number: usize,
    ) -> Self {
        let word = word.into();
        Self {
            word,
            arpabet,
            variant,
            line_number,
        }
    }

    /// The headword.
    #[inline(always)]
    pub fn word(&self) -> &str {
        self.word.as_str()
    }

    /// The pronunciation tokens, in source order.
    #[inline(always)]
    pub fn arpabet(&self) -> &[String] {
        self.arpabet.as_slice()
    }

    /// The variant index, if the word token carried one.
    #[inline(always)]
    pub fn variant(&self) -> Option<u32> {
        self.variant
    }

    /// The line number this entry was created with.
    #[inline(always)]
    pub fn line_number(&self) -> usize {
        self.line_number
    }
}
fn malformed_word(word_token: &str, line_number: usize, reason: &'static str) -> Error {
Error::OutOfRange(OutOfRangePayload::new(
"CMUDict parse: malformed word token (expected WORD or WORD(N))",
reason,
format_smolstr!("line {line_number}: '{word_token}'"),
))
}
/// Splits a word token into its base word and optional variant index.
///
/// A token without `(` is returned unchanged with no variant. Any other token
/// must have the exact shape `BASE(N)`: a non-empty base, then one or more ASCII
/// digits between a single pair of parentheses that closes the token, with `N`
/// fitting in a `u32`.
///
/// # Errors
/// Returns the error built by `malformed_word` naming the first rule the token
/// violates.
fn parse_word_and_variant(word_token: &str, line_number: usize) -> Result<(&str, Option<u32>)> {
    // Every failure reports the same token and line; only the reason differs.
    let reject = |reason: &'static str| malformed_word(word_token, line_number, reason);

    let Some((base, after_paren)) = word_token.split_once('(') else {
        return Ok((word_token, None));
    };

    // The closing paren must be the very last character of the token.
    let Some(digits) = after_paren.strip_suffix(')') else {
        return Err(reject(
            "trailing characters after closing paren (or missing closing paren)",
        ));
    };

    if base.is_empty() {
        return Err(reject("empty base word before opening paren"));
    }
    if digits.is_empty() {
        return Err(reject("empty variant index between parens"));
    }
    // Checked explicitly because `u32::from_str` would also accept a leading `+`.
    if digits.chars().any(|c| !c.is_ascii_digit()) {
        return Err(reject("variant index must be 1+ ASCII digits"));
    }

    match digits.parse::<u32>() {
        Ok(variant) => Ok((base, Some(variant))),
        // Only digits remain at this point, so the sole failure mode is overflow.
        Err(_) => Err(reject("variant index overflows u32 (>4_294_967_295)")),
    }
}
/// Parses a single dictionary line into a [`RawEntry`].
///
/// Returns `Ok(None)` for blank lines and for comment lines starting with `;;;`.
/// Any other line must be `WORD PHONE...` or `WORD(N) PHONE...`. The word and the
/// pronunciation tokens may be separated by any run of ASCII whitespace (spaces
/// or tabs). The word is lowercased; pronunciation tokens are kept verbatim.
///
/// `line_number` is stored in the entry and used in error messages.
///
/// # Errors
/// Returns `Error::OutOfRange` when the line has no whitespace separating word
/// and pronunciation, or when the word token is malformed (see
/// `parse_word_and_variant`).
pub fn parse_line(line: &str, line_number: usize) -> Result<Option<RawEntry>> {
    let trimmed = line.trim();
    if trimmed.is_empty() || trimmed.starts_with(";;;") {
        return Ok(None);
    }

    // Split on the first ASCII whitespace character rather than on ' ' only, so
    // tab-separated dictionaries do not fold the first phone into the word token.
    let Some((word_part, rest)) = trimmed.split_once(|c: char| c.is_ascii_whitespace()) else {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "CMUDict line",
            "must contain whitespace between word and pronunciation",
            format_smolstr!("{line_number}"),
        )));
    };
    let pron_part = rest.trim();

    // Defensive guard: `trimmed` starts and ends with a non-whitespace character,
    // so neither side of the split can be empty in practice.
    if word_part.is_empty() || pron_part.is_empty() {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "CMUDict line",
            "word and pronunciation must both be non-empty",
            format_smolstr!("{line_number}"),
        )));
    }

    let (word_str, variant) = parse_word_and_variant(word_part, line_number)?;
    let word = word_str.to_lowercase();

    // `split_ascii_whitespace` skips runs of separators and never yields empty tokens.
    let arpabet: Vec<String> = pron_part
        .split_ascii_whitespace()
        .map(String::from)
        .collect();
    // Defensive guard for the same reason as above.
    if arpabet.is_empty() {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "CMUDict line",
            "pronunciation must be non-empty",
            format_smolstr!("{line_number}"),
        )));
    }

    Ok(Some(RawEntry::new(word, arpabet, variant, line_number)))
}
/// Parses dictionary text line by line into raw entries.
///
/// Lines are numbered from 1. Lines for which `parse_line` yields nothing are
/// skipped. When `primary_only` is `true`, entries that carry a variant index are
/// dropped, so only un-suffixed headwords remain.
///
/// # Errors
/// Stops at, and returns, the first error produced by `parse_line`.
pub fn parse(text: &str, primary_only: bool) -> Result<Vec<RawEntry>> {
    let mut entries = Vec::new();
    for (line_number, line) in (1usize..).zip(text.lines()) {
        let Some(entry) = parse_line(line, line_number)? else {
            continue;
        };
        if primary_only && entry.variant().is_some() {
            continue;
        }
        entries.push(entry);
    }
    Ok(entries)
}
/// In-memory pronunciation lexicon holding one [`LexiconEntry`] per key.
#[derive(Debug, Clone)]
pub struct CMUDict {
/// Entries keyed by lowercased grapheme (the key is built in `from_entries`),
/// which is what makes `lookup` case-insensitive.
entries: HashMap<String, LexiconEntry>,
}
impl CMUDict {
    /// Builds a dictionary keyed by the lowercased grapheme of each entry.
    ///
    /// When several entries lowercase to the same key, the one yielded last
    /// replaces the earlier ones.
    #[must_use]
    pub fn from_entries(entries: impl IntoIterator<Item = LexiconEntry>) -> Self {
        let entries = entries
            .into_iter()
            .map(|entry| (entry.grapheme().to_lowercase(), entry))
            .collect::<HashMap<_, _>>();
        Self { entries }
    }

    /// Converts raw entries to lexicon entries and builds a dictionary from them.
    ///
    /// Each entry's tokens go through `arpabet::try_convert_sequence_strict`. The
    /// variant index of a raw entry is not consulted here, so duplicates resolve
    /// as described on [`CMUDict::from_entries`].
    ///
    /// # Errors
    /// Returns `Error::OutOfRange` for the first entry whose tokens fail the
    /// strict conversion, or whose converted pronunciation is empty.
    pub fn from_raw_entries(raw: impl IntoIterator<Item = RawEntry>) -> Result<Self> {
        let converted = raw
            .into_iter()
            .map(|item| -> Result<LexiconEntry> {
                let phonemes = match arpabet::try_convert_sequence_strict(item.arpabet()) {
                    Ok(phonemes) => phonemes,
                    Err(bad) => {
                        return Err(Error::OutOfRange(OutOfRangePayload::new(
                            "CMUDict ARPAbet token",
                            "must be a known ARPAbet symbol",
                            format_smolstr!(
                                "line {}: word '{}' token '{}'",
                                item.line_number(),
                                item.word(),
                                bad.token(),
                            ),
                        )));
                    }
                };
                if phonemes.is_empty() {
                    return Err(Error::OutOfRange(OutOfRangePayload::new(
                        "CMUDict line: pronunciation after ARPAbet → IPA conversion",
                        "must be non-empty",
                        format_smolstr!("line {}: word '{}'", item.line_number(), item.word()),
                    )));
                }
                Ok(LexiconEntry::new(item.word().to_owned(), phonemes))
            })
            // Collecting into `Result` stops at the first failing entry.
            .collect::<Result<Vec<_>>>()?;
        Ok(Self::from_entries(converted))
    }

    /// Number of distinct keys stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when the dictionary holds no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
impl Lexicon for CMUDict {
    /// Looks up `grapheme` case-insensitively by lowercasing it before querying
    /// the map; returns `None` when no entry is stored under that key.
    fn lookup(&self, grapheme: &str) -> Option<&LexiconEntry> {
        let key = grapheme.to_lowercase();
        self.entries.get(key.as_str())
    }
}
/// Stateless entry point for loading a [`CMUDict`] from disk; see [`CMUDictLoader::load`].
pub struct CMUDictLoader;
impl CMUDictLoader {
pub fn load(directory: &Path) -> Result<CMUDict> {
let path = directory.join("cmudict.dict");
if !path.exists() {
return Err(Error::MissingKey(MissingKeyPayload::new(
"CMUDictLoader::load: required file not found",
format_smolstr!("{}", path.display()),
)));
}
let bytes = fs::read(&path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"read",
FileOp::Read,
path.to_path_buf(),
e,
))
})?;
let text = match std::str::from_utf8(&bytes) {
Ok(s) => s.to_owned(),
Err(_) => bytes.iter().map(|&b| b as char).collect(),
};
let raw = parse(&text, true)?;
CMUDict::from_raw_entries(raw)
}
}
#[cfg(test)]
mod tests;