use std::collections::HashMap;
use std::marker::PhantomData;
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, VectorWfst};
use super::cascade::LexiconEntry;
use super::context::PhoneId;
use super::ngram::WordId;
/// Convention for decorating subword text so that word boundaries can be
/// recovered from a flat sequence of subwords.
///
/// Whole-word entries are never decorated, whichever style is selected.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MarkingStyle {
/// Medial and final pieces carry a leading `+`; initial pieces are bare.
LeftMarked,
/// Initial and medial pieces carry a trailing `+`; final pieces are bare.
RightMarked,
/// Initial pieces carry a leading `<w>`; medial and final pieces are bare.
BoundaryTag,
}
/// Position of a subword inside the word it belongs to.
///
/// Together with the builder's `MarkingStyle` this decides which marker, if
/// any, is attached to the subword text.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SubwordPosition {
/// The subword is a complete word; it is stored without any marker.
WholeWord,
/// First piece of a multi-piece word.
Initial,
/// Interior piece of a multi-piece word.
Medial,
/// Last piece of a multi-piece word.
Final,
}
/// One subword registered with a `SubwordLexiconBuilder`, together with its
/// pronunciation.
#[derive(Clone, Debug)]
pub struct SubwordEntry<W: Semiring> {
/// Subword text after the marking style has been applied.
pub subword: String,
/// Id assigned by the builder; ids are handed out sequentially from 0.
pub subword_id: u32,
/// Pronunciation as interned phone ids, in the order they were supplied.
pub phones: Vec<PhoneId>,
/// Position the subword was registered with.
pub position: SubwordPosition,
/// Weight supplied when the subword was registered.
pub weight: W,
/// Subword text exactly as supplied by the caller, without markers.
pub raw_subword: String,
}
/// Accumulates subwords and their pronunciations, and can turn them into a
/// lexicon transducer or a list of lexicon entries.
pub struct SubwordLexiconBuilder<W: Semiring> {
// Marker convention applied to every subword added to this builder.
marking_style: MarkingStyle,
// Marked subword text -> subword id.
subword_vocab: HashMap<String, u32>,
// Subword id -> marked subword text (indexed by id).
reverse_vocab: Vec<String>,
// One entry per distinct marked subword, in id order.
entries: Vec<SubwordEntry<W>>,
// Word id -> sequence of subword ids registered for that word.
word_decompositions: HashMap<WordId, Vec<u32>>,
// Phone name -> phone id.
phone_vocab: HashMap<String, PhoneId>,
// Phone id -> phone name (indexed by id).
reverse_phone_vocab: Vec<String>,
// Id that the next newly seen subword will receive.
next_subword_id: u32,
// NOTE(review): `W` is already used by `entries`, so this marker looks
// redundant — confirm before removing.
_weight: PhantomData<W>,
}
impl<W: Semiring + Clone> SubwordLexiconBuilder<W> {
/// Creates an empty builder that decorates subwords according to
/// `marking_style`.
///
/// Both vocabularies start empty and the first subword added receives id 0.
pub fn new(marking_style: MarkingStyle) -> Self {
Self {
marking_style,
subword_vocab: HashMap::new(),
reverse_vocab: Vec::new(),
entries: Vec::new(),
word_decompositions: HashMap::new(),
phone_vocab: HashMap::new(),
reverse_phone_vocab: Vec::new(),
next_subword_id: 0,
_weight: PhantomData,
}
}
/// Returns the marking style this builder was created with.
pub fn marking_style(&self) -> MarkingStyle {
self.marking_style
}
/// Returns the number of distinct marked subwords registered so far.
pub fn vocab_size(&self) -> usize {
self.subword_vocab.len()
}
/// Returns the number of distinct phone names interned so far.
pub fn phone_vocab_size(&self) -> usize {
self.phone_vocab.len()
}
/// Returns the id for `phone`, allocating a new one the first time the name
/// is seen.
///
/// A new id equals the current length of the reverse phone table, so ids are
/// dense and follow first-seen order.
fn intern_phone(&mut self, phone: &str) -> PhoneId {
    if let Some(existing) = self.phone_vocab.get(phone).copied() {
        return existing;
    }
    let fresh = self.reverse_phone_vocab.len() as PhoneId;
    let name = phone.to_string();
    // Keep both directions of the mapping in step.
    self.reverse_phone_vocab.push(name.clone());
    self.phone_vocab.insert(name, fresh);
    fresh
}
/// Decorates `subword` with the marker that the builder's style prescribes
/// for `position` and returns the resulting text.
///
/// Whole words are returned unchanged under every style.
///
/// NOTE(review): under `BoundaryTag`, whole words, medial pieces and final
/// pieces all come back bare, so they are textually indistinguishable —
/// confirm this is intended.
fn apply_marking(&self, subword: &str, position: SubwordPosition) -> String {
    // Select the text to put before and after the raw subword.
    let (prefix, suffix) = match (self.marking_style, position) {
        (_, SubwordPosition::WholeWord) => ("", ""),
        (MarkingStyle::LeftMarked, SubwordPosition::Initial) => ("", ""),
        (
            MarkingStyle::LeftMarked,
            SubwordPosition::Medial | SubwordPosition::Final,
        ) => ("+", ""),
        (
            MarkingStyle::RightMarked,
            SubwordPosition::Initial | SubwordPosition::Medial,
        ) => ("", "+"),
        (MarkingStyle::RightMarked, SubwordPosition::Final) => ("", ""),
        (MarkingStyle::BoundaryTag, SubwordPosition::Initial) => ("<w>", ""),
        (
            MarkingStyle::BoundaryTag,
            SubwordPosition::Medial | SubwordPosition::Final,
        ) => ("", ""),
    };
    format!("{prefix}{subword}{suffix}")
}
/// Reports whether `marked_subword` sits at a word boundary under the
/// builder's marking style.
///
/// * `LeftMarked`: true when the text has no leading `+`.
/// * `RightMarked`: true when the text has no trailing `+`.
/// * `BoundaryTag`: true when the text begins with `<w>`.
///
/// NOTE(review): bare text is reported as a non-boundary under
/// `BoundaryTag`, and whole-word entries are stored bare — confirm callers
/// expect that.
pub fn is_word_boundary(&self, marked_subword: &str) -> bool {
    match self.marking_style {
        MarkingStyle::BoundaryTag => marked_subword.strip_prefix("<w>").is_some(),
        MarkingStyle::LeftMarked => marked_subword.strip_prefix('+').is_none(),
        MarkingStyle::RightMarked => marked_subword.strip_suffix('+').is_none(),
    }
}
/// Registers `word` as a whole-word entry with pronunciation `phones` and
/// returns its subword id.
///
/// Delegates to `add_subword` with `SubwordPosition::WholeWord`, so adding a
/// word whose text is already registered returns the existing id.
pub fn add_word(&mut self, word: &str, phones: &[&str], weight: W) -> u32 {
self.add_subword(word, phones, SubwordPosition::WholeWord, weight)
}
/// Registers a subword with its pronunciation and returns its id.
///
/// The text is first decorated for `position` using the builder's marking
/// style. If the decorated text is already registered, the existing id is
/// returned and `phones`, `position` and `weight` are ignored. Otherwise the
/// next sequential id is assigned, each phone name is interned, and a new
/// entry is stored.
pub fn add_subword(
    &mut self,
    subword: &str,
    phones: &[&str],
    position: SubwordPosition,
    weight: W,
) -> u32 {
    let marked = self.apply_marking(subword, position);
    if let Some(existing) = self.subword_vocab.get(&marked).copied() {
        return existing;
    }

    let id = self.next_subword_id;
    self.next_subword_id += 1;

    let mut phone_ids: Vec<PhoneId> = Vec::with_capacity(phones.len());
    for phone in phones {
        phone_ids.push(self.intern_phone(phone));
    }

    // Record the marked text in both lookup directions before storing the entry.
    self.subword_vocab.insert(marked.clone(), id);
    self.reverse_vocab.push(marked.clone());
    self.entries.push(SubwordEntry {
        subword: marked,
        subword_id: id,
        phones: phone_ids,
        position,
        weight,
        raw_subword: subword.to_string(),
    });
    id
}
/// Records `subword_ids` as the decomposition of `word_id`.
///
/// A decomposition previously stored for the same `word_id` is replaced.
/// The ids are stored as given; they are not checked against the vocabulary.
pub fn register_decomposition(&mut self, word_id: WordId, subword_ids: Vec<u32>) {
self.word_decompositions.insert(word_id, subword_ids);
}
/// Looks up the id of a subword by its marked text.
///
/// Returns `None` when the text has not been registered. The lookup is
/// exact, so the text must already include any marker.
pub fn get_subword_id(&self, marked_subword: &str) -> Option<u32> {
self.subword_vocab.get(marked_subword).copied()
}
/// Returns the marked text registered under `id`, or `None` when no subword
/// has that id.
pub fn get_subword_text(&self, id: u32) -> Option<&str> {
    let index = id as usize;
    self.reverse_vocab.get(index).map(String::as_str)
}
/// Returns the phone name interned under `id`, or `None` when no phone has
/// that id.
pub fn get_phone_name(&self, id: PhoneId) -> Option<&str> {
    let name = self.reverse_phone_vocab.get(id as usize)?;
    Some(name.as_str())
}
/// Builds a transducer containing one cycle per registered pronunciation.
///
/// The transducer has a single start state, which is also final with weight
/// `W::one()`. Each entry with at least one phone adds a path that leaves
/// the start state and returns to it, one arc per phone, with the phone id
/// used as both labels of the arc. The first arc carries the entry's weight
/// and every other arc carries `W::one()`. A one-phone entry returns to the
/// start state through an extra arc whose labels are both `None`. Entries
/// with no phones are skipped.
///
/// NOTE(review): both labels of every arc are the phone id, so the subword
/// id does not appear in the transducer — confirm this is intended.
pub fn build_lexicon_fst(&self) -> VectorWfst<PhoneId, W> {
    let mut fst: VectorWfst<PhoneId, W> = VectorWfst::new();
    let root = fst.add_state();
    fst.set_start(root);
    fst.set_final(root, W::one());

    for entry in &self.entries {
        let Some((&head, tail)) = entry.phones.split_first() else {
            continue;
        };

        // The first phone leaves the root and carries the entry weight.
        let mut cursor = fst.add_state();
        fst.add_arc(root, Some(head), Some(head), cursor, entry.weight.clone());

        match tail.split_last() {
            // Single phone: return to the root on a label-less arc.
            None => {
                fst.add_arc(cursor, None, None, root, W::one());
            }
            // Several phones: chain the interior ones, then close the cycle
            // with the last phone.
            Some((&closing, interior)) => {
                for &phone in interior {
                    let following = fst.add_state();
                    fst.add_arc(cursor, Some(phone), Some(phone), following, W::one());
                    cursor = following;
                }
                fst.add_arc(cursor, Some(closing), Some(closing), root, W::one());
            }
        }
    }
    fst
}
/// Converts every registered subword into a `LexiconEntry`.
///
/// The subword id is used as the entry's word id, the phones and weight are
/// cloned, and the auxiliary list is left empty. Output order matches
/// registration order.
pub fn to_lexicon_entries(&self) -> Vec<LexiconEntry<W>> {
    let mut converted = Vec::with_capacity(self.entries.len());
    for entry in &self.entries {
        converted.push(LexiconEntry {
            word: entry.subword_id as WordId,
            phones: entry.phones.clone(),
            weight: entry.weight.clone(),
            auxiliaries: Vec::new(),
        });
    }
    converted
}
/// Returns all registered subword entries in registration (id) order.
pub fn entries(&self) -> &[SubwordEntry<W>] {
&self.entries
}
/// Returns the subword ids registered for `word_id`, or `None` when no
/// decomposition has been registered for it.
pub fn get_decomposition(&self, word_id: WordId) -> Option<&[u32]> {
    let ids = self.word_decompositions.get(&word_id)?;
    Some(ids.as_slice())
}
/// Joins a sequence of subword ids back into words.
///
/// Each id is resolved to its marked text and ids that are not registered
/// are skipped. The style's marker is removed from the text and the
/// remainder is appended to the word being assembled. Under `LeftMarked` and
/// `BoundaryTag` a boundary piece starts a new word, so the pending word is
/// emitted before the piece is appended. Under `RightMarked` a boundary
/// piece ends a word, so the pending word is emitted after the piece is
/// appended. Any text still pending at the end is emitted as a final word.
///
/// NOTE(review): under `BoundaryTag` a bare whole-word entry is not seen as
/// a boundary and is therefore appended to the preceding word — confirm
/// whether this is intended.
pub fn reconstruct_words(&self, subword_ids: &[u32]) -> Vec<String> {
    let mut finished: Vec<String> = Vec::new();
    let mut pending = String::new();

    for marked in subword_ids
        .iter()
        .filter_map(|&id| self.get_subword_text(id))
    {
        // Drop the marker if present; otherwise keep the text untouched.
        let stripped = match self.marking_style {
            MarkingStyle::LeftMarked => marked.strip_prefix('+'),
            MarkingStyle::RightMarked => marked.strip_suffix('+'),
            MarkingStyle::BoundaryTag => marked.strip_prefix("<w>"),
        };
        let piece = stripped.unwrap_or(marked);
        let at_boundary = self.is_word_boundary(marked);

        match self.marking_style {
            // Boundary marks the start of a word: flush first, then append.
            MarkingStyle::LeftMarked | MarkingStyle::BoundaryTag => {
                if at_boundary && !pending.is_empty() {
                    finished.push(std::mem::take(&mut pending));
                }
                pending.push_str(piece);
            }
            // Boundary marks the end of a word: append first, then flush.
            MarkingStyle::RightMarked => {
                pending.push_str(piece);
                if at_boundary && !pending.is_empty() {
                    finished.push(std::mem::take(&mut pending));
                }
            }
        }
    }

    if !pending.is_empty() {
        finished.push(pending);
    }
    finished
}
}
// The default builder uses the left-marked convention.
impl<W: Semiring + Clone> Default for SubwordLexiconBuilder<W> {
/// Equivalent to `SubwordLexiconBuilder::new(MarkingStyle::LeftMarked)`.
fn default() -> Self {
Self::new(MarkingStyle::LeftMarked)
}
}
// Unit tests for the subword lexicon builder, using the tropical semiring.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::TropicalWeight;
use crate::wfst::Wfst;
// LeftMarked: whole words and initial pieces are bare; medial and final
// pieces get a leading `+`.
#[test]
fn test_marking_style_left() {
let builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
assert_eq!(
builder.apply_marking("word", SubwordPosition::WholeWord),
"word"
);
assert_eq!(
builder.apply_marking("hel", SubwordPosition::Initial),
"hel"
);
assert_eq!(builder.apply_marking("lo", SubwordPosition::Medial), "+lo");
assert_eq!(builder.apply_marking("ing", SubwordPosition::Final), "+ing");
}
// RightMarked: initial and medial pieces get a trailing `+`; whole words
// and final pieces are bare.
#[test]
fn test_marking_style_right() {
let builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::RightMarked);
assert_eq!(
builder.apply_marking("word", SubwordPosition::WholeWord),
"word"
);
assert_eq!(
builder.apply_marking("hel", SubwordPosition::Initial),
"hel+"
);
assert_eq!(builder.apply_marking("lo", SubwordPosition::Medial), "lo+");
assert_eq!(builder.apply_marking("ing", SubwordPosition::Final), "ing");
}
// BoundaryTag: only initial pieces are tagged with `<w>`; everything else,
// including whole words, is bare.
#[test]
fn test_marking_style_boundary() {
let builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::BoundaryTag);
assert_eq!(
builder.apply_marking("word", SubwordPosition::WholeWord),
"word"
);
assert_eq!(
builder.apply_marking("hel", SubwordPosition::Initial),
"<w>hel"
);
assert_eq!(builder.apply_marking("lo", SubwordPosition::Medial), "lo");
assert_eq!(builder.apply_marking("ing", SubwordPosition::Final), "ing");
}
// Adding a word assigns id 0 and interns its phones; adding it again
// returns the same id.
#[test]
fn test_add_word() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
let id = builder.add_word("hello", &["HH", "AH", "L", "OW"], TropicalWeight::one());
assert_eq!(id, 0);
assert_eq!(builder.vocab_size(), 1);
assert_eq!(builder.phone_vocab_size(), 4);
let id2 = builder.add_word("hello", &["HH", "AH", "L", "OW"], TropicalWeight::one());
assert_eq!(id, id2);
}
// Subwords receive sequential ids and are stored under their marked text.
#[test]
fn test_add_subwords() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
let id1 = builder.add_subword(
"hel",
&["HH", "AH", "L"],
SubwordPosition::Initial,
TropicalWeight::one(),
);
let id2 = builder.add_subword(
"lo",
&["L", "OW"],
SubwordPosition::Final,
TropicalWeight::one(),
);
assert_eq!(id1, 0);
assert_eq!(id2, 1);
assert_eq!(builder.vocab_size(), 2);
assert_eq!(builder.get_subword_text(id1), Some("hel"));
assert_eq!(builder.get_subword_text(id2), Some("+lo"));
}
// Boundary detection for each marking style, on marked and unmarked text.
#[test]
fn test_is_word_boundary() {
let left_builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
assert!(left_builder.is_word_boundary("hello"));
assert!(!left_builder.is_word_boundary("+ing"));
let right_builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::RightMarked);
assert!(right_builder.is_word_boundary("hello"));
assert!(!right_builder.is_word_boundary("hel+"));
let boundary_builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::BoundaryTag);
assert!(boundary_builder.is_word_boundary("<w>hello"));
assert!(!boundary_builder.is_word_boundary("ing"));
}
// LeftMarked: "hel" + "+lo" joins into "hello"; the whole word "world"
// starts a new word.
#[test]
fn test_reconstruct_words_left_marked() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
let id1 = builder.add_subword(
"hel",
&["HH", "AH", "L"],
SubwordPosition::Initial,
TropicalWeight::one(),
);
let id2 = builder.add_subword(
"lo",
&["L", "OW"],
SubwordPosition::Final,
TropicalWeight::one(),
);
let id3 = builder.add_word("world", &["W", "ER", "L", "D"], TropicalWeight::one());
let words = builder.reconstruct_words(&[id1, id2, id3]);
assert_eq!(words, vec!["hello", "world"]);
}
// RightMarked: "hel+" + "lo" joins into "hello"; the whole word "world"
// stands alone.
#[test]
fn test_reconstruct_words_right_marked() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::RightMarked);
let id1 = builder.add_subword(
"hel",
&["HH", "AH", "L"],
SubwordPosition::Initial,
TropicalWeight::one(),
);
let id2 = builder.add_subword(
"lo",
&["L", "OW"],
SubwordPosition::Final,
TropicalWeight::one(),
);
let id3 = builder.add_word("world", &["W", "ER", "L", "D"], TropicalWeight::one());
let words = builder.reconstruct_words(&[id1, id2, id3]);
assert_eq!(words, vec!["hello", "world"]);
}
// Two two-phone words: the start state plus one intermediate state each.
#[test]
fn test_build_lexicon_fst() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
builder.add_word("hi", &["HH", "AY"], TropicalWeight::one());
builder.add_word("bye", &["B", "AY"], TropicalWeight::one());
let fst = builder.build_lexicon_fst();
assert!(fst.num_states() >= 3);
assert!(fst.is_valid_state(fst.start()));
}
// Conversion keeps the id, the phone count and the weight.
#[test]
fn test_to_lexicon_entries() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
builder.add_word("hello", &["HH", "AH", "L", "OW"], TropicalWeight::new(1.5));
let entries = builder.to_lexicon_entries();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].word, 0);
assert_eq!(entries[0].phones.len(), 4);
assert!((entries[0].weight.value() - 1.5).abs() < 0.001);
}
// A registered decomposition is returned unchanged for the same word id.
#[test]
fn test_register_decomposition() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
let id1 = builder.add_subword(
"un",
&["AH", "N"],
SubwordPosition::Initial,
TropicalWeight::one(),
);
let id2 = builder.add_subword(
"break",
&["B", "R", "EY", "K"],
SubwordPosition::Medial,
TropicalWeight::one(),
);
let id3 = builder.add_subword(
"able",
&["AH", "B", "AH", "L"],
SubwordPosition::Final,
TropicalWeight::one(),
);
let word_id: WordId = 42;
builder.register_decomposition(word_id, vec![id1, id2, id3]);
let decomp = builder.get_decomposition(word_id);
assert_eq!(decomp, Some(&[id1, id2, id3][..]));
}
// Repeated phone names share one id; ids follow first-seen order.
#[test]
fn test_phone_interning() {
let mut builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
builder.add_word("aaa", &["AH", "AH", "AH"], TropicalWeight::one());
builder.add_word("bbb", &["B", "B", "B"], TropicalWeight::one());
assert_eq!(builder.phone_vocab_size(), 2);
assert_eq!(builder.get_phone_name(0), Some("AH"));
assert_eq!(builder.get_phone_name(1), Some("B"));
}
// An empty builder has empty vocabularies and builds a one-state FST.
#[test]
fn test_empty_builder() {
let builder: SubwordLexiconBuilder<TropicalWeight> =
SubwordLexiconBuilder::new(MarkingStyle::LeftMarked);
assert_eq!(builder.vocab_size(), 0);
assert_eq!(builder.phone_vocab_size(), 0);
assert!(builder.entries().is_empty());
let fst = builder.build_lexicon_fst();
assert_eq!(fst.num_states(), 1); }
}