use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst};
/// Identifier of a (context-independent) phone.
pub type PhoneId = u32;

/// Arc label meaning "no symbol" (epsilon); real symbols are written as `Some(phone)`.
pub const EPSILON: Option<PhoneId> = None;

/// The left-phone history that identifies one state of the context-dependency transducer.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ContextState {
    /// Previously seen phones, oldest first (the most recent phone is last).
    pub left_context: Vec<PhoneId>,
}

impl ContextState {
    /// Returns the start state, which has an empty left context.
    pub fn initial() -> Self {
        Self {
            left_context: Vec::new(),
        }
    }

    /// Returns a state with the given left context (oldest phone first).
    pub fn with_context(context: Vec<PhoneId>) -> Self {
        Self {
            left_context: context,
        }
    }

    /// Returns the state reached after observing `phone`.
    ///
    /// `phone` becomes the most recent context phone. The oldest phones are dropped so that
    /// at most `max_context` phones remain. With `max_context == 0`, the result is always the
    /// empty context.
    pub fn extend(&self, phone: PhoneId, max_context: usize) -> Self {
        let mut left_context = Vec::with_capacity(self.left_context.len() + 1);
        left_context.extend_from_slice(&self.left_context);
        left_context.push(phone);
        // Trim to the window size rather than dropping a single element. This also brings a
        // state built via `with_context` with an over-long history back within bounds.
        let excess = left_context.len().saturating_sub(max_context);
        left_context.drain(..excess);
        Self { left_context }
    }
}
/// Optional features consulted by `ContextDependencyBuilder::build`.
///
/// The derived `Default` turns every feature off: `deterministic == false`,
/// no boundary symbol, `auxiliary_self_loops == false`, no auxiliary symbols.
#[derive(Clone, Debug, Default)]
pub struct ContextDependencyConfig {
/// Enables boundary-symbol arcs; only takes effect when `boundary_symbol` is also `Some`.
pub deterministic: bool,
/// Input label of the boundary arcs; ignored unless `deterministic` is `true`.
pub boundary_symbol: Option<PhoneId>,
/// Enables per-state self-loops; only takes effect when `auxiliary_symbols` is also `Some`.
pub auxiliary_self_loops: bool,
/// Symbols used for the `aux:aux` self-loops; ignored unless `auxiliary_self_loops` is `true`.
pub auxiliary_symbols: Option<std::ops::Range<PhoneId>>,
}
/// Builds a transducer that maps phone sequences to context-dependent labels, each label
/// encoding a phone together with its left context (see `compute_cd_label`).
pub struct ContextDependencyBuilder<W: Semiring> {
// Number of phones; `build` emits arcs for phone ids `0..num_phones`.
num_phones: usize,
// Maximum number of left-context phones remembered by each state.
left_context_size: usize,
// Configured right-context size. NOTE(review): `build` does not read it as written.
right_context_size: usize,
// Optional features (boundary handling, auxiliary self-loops).
config: ContextDependencyConfig,
// Ties the builder to weight type `W` without storing a value of it.
_weight: std::marker::PhantomData<W>,
}
impl<W: Semiring> ContextDependencyBuilder<W> {
    /// Creates a builder for phones `0..num_phones` with the given context window sizes.
    ///
    /// Uses a default [`ContextDependencyConfig`] (no boundary handling, no auxiliary loops).
    ///
    /// NOTE(review): `right_context_size` is stored and returned by
    /// [`Self::right_context_size`], but [`Self::build`] only uses the left context.
    pub fn new(num_phones: usize, left_context_size: usize, right_context_size: usize) -> Self {
        Self {
            num_phones,
            left_context_size,
            right_context_size,
            config: ContextDependencyConfig::default(),
            _weight: std::marker::PhantomData,
        }
    }

    /// Replaces the builder's configuration wholesale.
    pub fn config(mut self, config: ContextDependencyConfig) -> Self {
        self.config = config;
        self
    }

    /// Maximum number of left-context phones remembered by each state.
    pub fn left_context_size(&self) -> usize {
        self.left_context_size
    }

    /// Right-context size passed to [`Self::new`].
    pub fn right_context_size(&self) -> usize {
        self.right_context_size
    }

    /// Enables boundary handling in [`Self::build`].
    ///
    /// `boundary_symbol` becomes the input label of the boundary arcs.
    pub fn deterministic(mut self, boundary_symbol: PhoneId) -> Self {
        self.config.deterministic = true;
        self.config.boundary_symbol = Some(boundary_symbol);
        self
    }

    /// Makes [`Self::build`] add `aux:aux` self-loops.
    ///
    /// One self-loop is added on every context state for each symbol in `range`.
    pub fn with_auxiliary_symbols(mut self, range: std::ops::Range<PhoneId>) -> Self {
        self.config.auxiliary_self_loops = true;
        self.config.auxiliary_symbols = Some(range);
        self
    }

    /// Builds the context-dependency transducer.
    ///
    /// The states are all left contexts of length `0..=left_context_size` reachable from the
    /// empty context. Each state has one arc per phone `p`:
    /// - input: `p`
    /// - output: `compute_cd_label(state, p)`
    /// - destination: the state extended by `p`
    ///
    /// Boundary arcs and auxiliary self-loops are added according to the config. Every arc
    /// weight and final weight is `W::one()`, and every state is final. State numbering is
    /// deterministic across runs.
    ///
    /// # Panics
    ///
    /// Panics if `num_phones` does not fit in `PhoneId`, or if an encoded label overflows it.
    pub fn build(&self) -> VectorWfst<PhoneId, W> {
        let num_phones = self.phone_count();
        let mut fst: VectorWfst<PhoneId, W> = VectorWfst::new();
        let mut state_map: HashMap<ContextState, StateId> = HashMap::new();

        let initial = ContextState::initial();
        let start_id = fst.add_state();
        fst.set_start(start_id);
        state_map.insert(initial.clone(), start_id);

        // Depth-first exploration; each work item carries its already-assigned state id.
        let mut stack = vec![(initial, start_id)];
        while let Some((current_state, current_id)) = stack.pop() {
            for phone in 0..num_phones {
                let next_state = current_state.extend(phone, self.left_context_size);
                let next_id = match state_map.entry(next_state) {
                    std::collections::hash_map::Entry::Occupied(entry) => *entry.get(),
                    std::collections::hash_map::Entry::Vacant(entry) => {
                        let id = fst.add_state();
                        stack.push((entry.key().clone(), id));
                        entry.insert(id);
                        id
                    }
                };
                let output_label = self.compute_cd_label(&current_state, phone);
                fst.add_arc(current_id, Some(phone), Some(output_label), next_id, W::one());
            }

            if self.config.auxiliary_self_loops {
                if let Some(range) = &self.config.auxiliary_symbols {
                    for aux in range.clone() {
                        fst.add_arc(current_id, Some(aux), Some(aux), current_id, W::one());
                    }
                }
            }
        }

        if self.config.deterministic {
            if let Some(boundary) = self.config.boundary_symbol {
                self.add_boundary_handling(&mut fst, &state_map, boundary);
            }
        }

        // Every state, including any boundary exit states, accepts with weight one.
        for id in 0..fst.num_states() as StateId {
            if !fst.is_final(id) {
                fst.set_final(id, W::one());
            }
        }
        fst
    }

    /// Returns `num_phones` as a `PhoneId`.
    ///
    /// # Panics
    ///
    /// Panics if `num_phones` does not fit in `PhoneId`.
    fn phone_count(&self) -> PhoneId {
        use std::convert::TryFrom;
        PhoneId::try_from(self.num_phones).expect("num_phones must fit in PhoneId")
    }

    /// Encodes `(left context, center phone)` as one label in base `num_phones + 1`.
    ///
    /// The center phone is the lowest digit. The context phone at distance `k` from the
    /// center (`k = 1` is the most recent) contributes `(phone + 1) * base^k`. The `+ 1`
    /// offset keeps phone 0 distinct from "no context".
    ///
    /// The encoding is injective when context phones are `< num_phones` and
    /// `center_phone <= num_phones`.
    ///
    /// # Panics
    ///
    /// Panics if the label does not fit in `PhoneId`; the original arithmetic silently
    /// wrapped in release builds, producing colliding labels.
    fn compute_cd_label(&self, state: &ContextState, center_phone: PhoneId) -> PhoneId {
        const OVERFLOW: &str =
            "context-dependent label overflows PhoneId; reduce num_phones or the context size";
        let base = self.phone_count().checked_add(1).expect(OVERFLOW);
        let mut label = center_phone;
        let mut place: PhoneId = 1;
        for &ctx_phone in state.left_context.iter().rev() {
            place = place.checked_mul(base).expect(OVERFLOW);
            let digit = ctx_phone.checked_add(1).expect(OVERFLOW);
            label = digit
                .checked_mul(place)
                .and_then(|term| label.checked_add(term))
                .expect(OVERFLOW);
        }
        label
    }

    /// Adds boundary-symbol arcs to every context state discovered by `build`.
    ///
    /// - The empty-context state gets a `boundary:boundary` self-loop.
    /// - Every other state gets an arc on `boundary` to a fresh final state. Its output is
    ///   `compute_cd_label(state, boundary)`.
    ///
    /// NOTE(review): outputs cannot collide with regular phone labels only when
    /// `boundary == num_phones`. Smaller values coincide with a real phone; larger values
    /// carry into the context digits.
    fn add_boundary_handling(
        &self,
        fst: &mut VectorWfst<PhoneId, W>,
        state_map: &HashMap<ContextState, StateId>,
        boundary: PhoneId,
    ) {
        // Visit states in id order. Iterating the HashMap directly would number the new exit
        // states in a per-process random order, making the built FST non-reproducible.
        let mut states: Vec<(&ContextState, StateId)> = state_map
            .iter()
            .map(|(context, &id)| (context, id))
            .collect();
        states.sort_unstable_by_key(|&(_, id)| id);

        for (context_state, state_id) in states {
            if context_state.left_context.is_empty() {
                fst.add_arc(state_id, Some(boundary), Some(boundary), state_id, W::one());
                continue;
            }
            let boundary_label = self.compute_cd_label(context_state, boundary);
            let exit_state = fst.add_state();
            fst.set_final(exit_state, W::one());
            fst.add_arc(
                state_id,
                Some(boundary),
                Some(boundary_label),
                exit_state,
                W::one(),
            );
        }
    }
}
/// Convenience wrapper around [`ContextDependencyBuilder`] with one phone of left context
/// (and a configured right-context size of one).
pub struct TriphoneBuilder<W: Semiring> {
    inner: ContextDependencyBuilder<W>,
}

impl<W: Semiring> TriphoneBuilder<W> {
    /// Creates a triphone builder for phones `0..num_phones` with the default configuration.
    pub fn new(num_phones: usize) -> Self {
        Self {
            inner: ContextDependencyBuilder::new(num_phones, 1, 1),
        }
    }

    /// Replaces the configuration of the underlying builder.
    pub fn config(mut self, config: ContextDependencyConfig) -> Self {
        self.inner.config = config;
        self
    }

    /// Enables boundary handling with `boundary_symbol` as the boundary input label.
    pub fn deterministic(mut self, boundary_symbol: PhoneId) -> Self {
        self.inner = self.inner.deterministic(boundary_symbol);
        self
    }

    /// Adds an `aux:aux` self-loop on every context state for each symbol in `range`.
    pub fn with_auxiliary_symbols(mut self, range: std::ops::Range<PhoneId>) -> Self {
        self.inner = self.inner.with_auxiliary_symbols(range);
        self
    }

    /// Builds the triphone context-dependency transducer.
    pub fn build(&self) -> VectorWfst<PhoneId, W> {
        self.inner.build()
    }

    /// Number of states [`Self::build`] produces for the current configuration.
    ///
    /// There are `1 + n` context states. Boundary handling adds one exit state per non-empty
    /// context, i.e. `n` more.
    pub fn expected_states(&self) -> usize {
        let context_states = self.context_states();
        if self.boundary_enabled() {
            2 * context_states - 1
        } else {
            context_states
        }
    }

    /// Number of arcs [`Self::build`] produces for the current configuration.
    ///
    /// Each context state has `n` phone arcs, plus one self-loop per auxiliary symbol, plus
    /// one boundary arc when boundary handling is enabled.
    pub fn expected_arcs(&self) -> usize {
        let per_state = self.inner.num_phones
            + self.auxiliary_loops_per_state()
            + usize::from(self.boundary_enabled());
        self.context_states() * per_state
    }

    /// Count of left contexts of length 0 or 1.
    fn context_states(&self) -> usize {
        1 + self.inner.num_phones
    }

    /// Whether `build` will add boundary arcs (needs both the flag and a symbol).
    fn boundary_enabled(&self) -> bool {
        let config = &self.inner.config;
        config.deterministic && config.boundary_symbol.is_some()
    }

    /// Number of auxiliary self-loops `build` puts on each context state.
    fn auxiliary_loops_per_state(&self) -> usize {
        let config = &self.inner.config;
        if config.auxiliary_self_loops {
            config.auxiliary_symbols.as_ref().map_or(0, |range| range.len())
        } else {
            0
        }
    }
}
/// Convenience wrapper around [`ContextDependencyBuilder`] with two phones of left context
/// (and a configured right-context size of two).
///
/// The type name is misspelled; [`TetraphoneBuilder`] is a correctly spelled alias.
pub struct TetraploneBuilder<W: Semiring> {
    inner: ContextDependencyBuilder<W>,
}

/// Correctly spelled alias for [`TetraploneBuilder`].
pub type TetraphoneBuilder<W> = TetraploneBuilder<W>;

impl<W: Semiring> TetraploneBuilder<W> {
    /// Creates a tetraphone builder for phones `0..num_phones` with the default configuration.
    pub fn new(num_phones: usize) -> Self {
        Self {
            inner: ContextDependencyBuilder::new(num_phones, 2, 2),
        }
    }

    /// Replaces the configuration of the underlying builder.
    pub fn config(mut self, config: ContextDependencyConfig) -> Self {
        self.inner.config = config;
        self
    }

    /// Enables boundary handling with `boundary_symbol` as the boundary input label.
    pub fn deterministic(mut self, boundary_symbol: PhoneId) -> Self {
        self.inner = self.inner.deterministic(boundary_symbol);
        self
    }

    /// Adds an `aux:aux` self-loop on every context state for each symbol in `range`.
    pub fn with_auxiliary_symbols(mut self, range: std::ops::Range<PhoneId>) -> Self {
        self.inner = self.inner.with_auxiliary_symbols(range);
        self
    }

    /// Builds the tetraphone context-dependency transducer.
    pub fn build(&self) -> VectorWfst<PhoneId, W> {
        self.inner.build()
    }

    /// Number of states [`Self::build`] produces for the current configuration.
    ///
    /// There are `1 + n + n^2` context states. Boundary handling adds one exit state per
    /// non-empty context.
    pub fn expected_states(&self) -> usize {
        let context_states = self.context_states();
        if self.boundary_enabled() {
            2 * context_states - 1
        } else {
            context_states
        }
    }

    /// Number of arcs [`Self::build`] produces for the current configuration.
    ///
    /// Each context state has `n` phone arcs, plus one self-loop per auxiliary symbol, plus
    /// one boundary arc when boundary handling is enabled.
    pub fn expected_arcs(&self) -> usize {
        let per_state = self.inner.num_phones
            + self.auxiliary_loops_per_state()
            + usize::from(self.boundary_enabled());
        self.context_states() * per_state
    }

    /// Count of left contexts of length 0, 1 or 2.
    fn context_states(&self) -> usize {
        let n = self.inner.num_phones;
        1 + n + n * n
    }

    /// Whether `build` will add boundary arcs (needs both the flag and a symbol).
    fn boundary_enabled(&self) -> bool {
        let config = &self.inner.config;
        config.deterministic && config.boundary_symbol.is_some()
    }

    /// Number of auxiliary self-loops `build` puts on each context state.
    fn auxiliary_loops_per_state(&self) -> usize {
        let config = &self.inner.config;
        if config.auxiliary_self_loops {
            config.auxiliary_symbols.as_ref().map_or(0, |range| range.len())
        } else {
            0
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::LogWeight;
    use crate::wfst::{Wfst, NO_STATE};

    /// Sums the outgoing-arc counts of every state.
    fn total_arcs(fst: &VectorWfst<PhoneId, LogWeight>) -> usize {
        (0..fst.num_states() as StateId)
            .map(|s| fst.transitions(s).len())
            .sum()
    }

    #[test]
    fn test_context_state_initial() {
        let state = ContextState::initial();
        assert!(state.left_context.is_empty());
    }

    #[test]
    fn test_context_state_extend() {
        let state = ContextState::initial();
        let state1 = state.extend(1, 2);
        assert_eq!(state1.left_context, vec![1]);
        let state2 = state1.extend(2, 2);
        assert_eq!(state2.left_context, vec![1, 2]);
        let state3 = state2.extend(3, 2);
        assert_eq!(state3.left_context, vec![2, 3]);
    }

    #[test]
    fn test_triphone_builder() {
        let builder = TriphoneBuilder::<LogWeight>::new(5);
        let fst = builder.build();
        assert!(fst.num_states() >= 1);
        assert!(fst.start() != NO_STATE);
    }

    #[test]
    fn test_triphone_state_count() {
        let builder = TriphoneBuilder::<LogWeight>::new(3);
        let fst = builder.build();
        assert_eq!(fst.num_states(), 4);
    }

    #[test]
    fn test_triphone_arc_count() {
        let builder = TriphoneBuilder::<LogWeight>::new(3);
        let fst = builder.build();
        assert_eq!(total_arcs(&fst), 12);
    }

    #[test]
    fn test_tetraphone_state_count() {
        let builder = TetraploneBuilder::<LogWeight>::new(3);
        let fst = builder.build();
        assert_eq!(fst.num_states(), 13);
    }

    #[test]
    fn test_cd_label_encoding() {
        let builder = ContextDependencyBuilder::<LogWeight>::new(10, 1, 1);
        let empty = ContextState::initial();
        let with_ctx = ContextState::with_context(vec![5]);
        let label1 = builder.compute_cd_label(&empty, 3);
        let label2 = builder.compute_cd_label(&with_ctx, 3);
        assert_ne!(label1, label2);
    }

    #[test]
    fn test_cd_label_exact_values() {
        // base = num_phones + 1 = 11; context phones are offset by one.
        let builder = ContextDependencyBuilder::<LogWeight>::new(10, 2, 1);
        assert_eq!(builder.compute_cd_label(&ContextState::initial(), 7), 7);
        assert_eq!(
            builder.compute_cd_label(&ContextState::with_context(vec![5]), 3),
            3 + 6 * 11
        );
        // Most recent phone (2) is the 11^1 digit, older phone (1) the 11^2 digit.
        assert_eq!(
            builder.compute_cd_label(&ContextState::with_context(vec![1, 2]), 0),
            3 * 11 + 2 * 121
        );
    }

    #[test]
    fn test_all_states_final() {
        let builder = TriphoneBuilder::<LogWeight>::new(3);
        let fst = builder.build();
        for id in 0..fst.num_states() as StateId {
            assert!(fst.is_final(id));
        }
    }

    #[test]
    fn test_deterministic_triphone_counts() {
        let builder = TriphoneBuilder::<LogWeight>::new(3).deterministic(3);
        let fst = builder.build();
        // 4 context states + 3 exit states (one per non-empty context).
        assert_eq!(fst.num_states(), 7);
        // 12 phone arcs + one boundary arc per context state.
        assert_eq!(total_arcs(&fst), 16);
        for id in 0..fst.num_states() as StateId {
            assert!(fst.is_final(id));
        }
    }

    #[test]
    fn test_auxiliary_self_loop_counts() {
        let builder =
            ContextDependencyBuilder::<LogWeight>::new(3, 1, 1).with_auxiliary_symbols(10..12);
        let fst = builder.build();
        assert_eq!(fst.num_states(), 4);
        // 12 phone arcs + 2 auxiliary self-loops on each of the 4 states.
        assert_eq!(total_arcs(&fst), 20);
    }

    #[test]
    fn test_phone_0_contributes_to_label() {
        let builder = ContextDependencyBuilder::<LogWeight>::new(10, 2, 1);
        let empty = ContextState::initial();
        let with_zero = ContextState::with_context(vec![0]);
        let label_empty = builder.compute_cd_label(&empty, 5);
        let label_with_zero = builder.compute_cd_label(&with_zero, 5);
        assert_ne!(
            label_empty, label_with_zero,
            "Phone 0 in context must produce different label than empty context. \
            Empty: {}, With [0]: {}",
            label_empty, label_with_zero
        );
    }

    #[test]
    fn test_different_phone_0_positions() {
        let builder = ContextDependencyBuilder::<LogWeight>::new(10, 2, 1);
        let ctx_01 = ContextState::with_context(vec![0, 1]);
        let ctx_10 = ContextState::with_context(vec![1, 0]);
        let label_01 = builder.compute_cd_label(&ctx_01, 5);
        let label_10 = builder.compute_cd_label(&ctx_10, 5);
        assert_ne!(
            label_01, label_10,
            "Different phone 0 positions must produce different labels. \
            [0,1]: {}, [1,0]: {}",
            label_01, label_10
        );
    }

    #[test]
    fn test_cd_label_injectivity_with_phone_0() {
        let builder = ContextDependencyBuilder::<LogWeight>::new(5, 2, 1);
        let mut seen_labels: std::collections::HashMap<PhoneId, (Vec<PhoneId>, PhoneId)> =
            std::collections::HashMap::new();
        let phones: Vec<PhoneId> = vec![0, 1, 2];
        for center in phones.iter().copied() {
            let empty = ContextState::initial();
            let label = builder.compute_cd_label(&empty, center);
            if let Some((prev_ctx, prev_center)) = seen_labels.insert(label, (vec![], center)) {
                panic!(
                    "Label collision: {} produced by both ({:?}, {}) and ({:?}, {})",
                    label,
                    prev_ctx,
                    prev_center,
                    vec![] as Vec<PhoneId>,
                    center
                );
            }
            for &ctx0 in &phones {
                let ctx = ContextState::with_context(vec![ctx0]);
                let label = builder.compute_cd_label(&ctx, center);
                if let Some((prev_ctx, prev_center)) =
                    seen_labels.insert(label, (vec![ctx0], center))
                {
                    panic!(
                        "Label collision: {} produced by both ({:?}, {}) and ({:?}, {})",
                        label,
                        prev_ctx,
                        prev_center,
                        vec![ctx0],
                        center
                    );
                }
            }
            for &ctx0 in &phones {
                for &ctx1 in &phones {
                    let ctx = ContextState::with_context(vec![ctx0, ctx1]);
                    let label = builder.compute_cd_label(&ctx, center);
                    if let Some((prev_ctx, prev_center)) =
                        seen_labels.insert(label, (vec![ctx0, ctx1], center))
                    {
                        panic!(
                            "Label collision: {} produced by both ({:?}, {}) and ({:?}, {})",
                            label,
                            prev_ctx,
                            prev_center,
                            vec![ctx0, ctx1],
                            center
                        );
                    }
                }
            }
        }
    }
}
// Property-based tests (proptest) for ContextState, ContextDependencyConfig,
// ContextDependencyBuilder and the triphone/tetraphone wrappers.
#[cfg(test)]
mod property_tests {
use super::*;
use crate::semiring::LogWeight;
use crate::wfst::{Wfst, NO_STATE};
use proptest::prelude::*;
// ContextState invariants: empty start state, sliding-window bound, equality by content.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn initial_state_empty(_seed in any::<u64>()) {
let state = ContextState::initial();
prop_assert!(state.left_context.is_empty());
}
#[test]
fn extend_adds_phone(phone in 0u32..100, max_ctx in 1usize..5) {
let state = ContextState::initial();
let extended = state.extend(phone, max_ctx);
prop_assert!(extended.left_context.contains(&phone));
}
#[test]
fn extend_respects_max_context(
phones in prop::collection::vec(0u32..100, 1..20),
max_ctx in 1usize..5
) {
let mut state = ContextState::initial();
for &phone in &phones {
state = state.extend(phone, max_ctx);
prop_assert!(state.left_context.len() <= max_ctx);
}
}
#[test]
fn extend_removes_oldest_when_full(max_ctx in 1usize..5) {
// Fill the window with 0..max_ctx, then push one more phone: phone 0 must be evicted.
let mut state = ContextState::initial();
for i in 0..max_ctx as u32 {
state = state.extend(i, max_ctx);
}
prop_assert_eq!(state.left_context.len(), max_ctx);
let new_phone = max_ctx as u32 + 100;
state = state.extend(new_phone, max_ctx);
prop_assert_eq!(state.left_context.len(), max_ctx);
prop_assert!(!state.left_context.contains(&0));
prop_assert!(state.left_context.contains(&new_phone));
}
#[test]
fn with_context_preserves(context in prop::collection::vec(0u32..100, 0..5)) {
let state = ContextState::with_context(context.clone());
prop_assert_eq!(state.left_context, context);
}
#[test]
fn context_state_equality(context in prop::collection::vec(0u32..50, 0..4)) {
let state1 = ContextState::with_context(context.clone());
let state2 = ContextState::with_context(context);
prop_assert_eq!(state1, state2);
}
#[test]
fn different_contexts_different_states(
// Disjoint phone ranges guarantee the two non-empty contexts differ.
ctx1 in prop::collection::vec(0u32..50, 1..3),
ctx2 in prop::collection::vec(50u32..100, 1..3)
) {
let state1 = ContextState::with_context(ctx1);
let state2 = ContextState::with_context(ctx2);
prop_assert_ne!(state1, state2);
}
}
// Default configuration disables every optional feature.
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn default_config_non_deterministic(_seed in any::<u64>()) {
let config = ContextDependencyConfig::default();
prop_assert!(!config.deterministic);
prop_assert!(config.boundary_symbol.is_none());
}
#[test]
fn default_config_no_aux(_seed in any::<u64>()) {
let config = ContextDependencyConfig::default();
prop_assert!(!config.auxiliary_self_loops);
prop_assert!(config.auxiliary_symbols.is_none());
}
}
// Label encoding sensitivity and builder setter behaviour.
proptest! {
#![proptest_config(ProptestConfig::with_cases(30))]
#[test]
fn cd_label_deterministic(
num_phones in 2usize..10,
context in prop::collection::vec(0u32..10, 0..2),
center in 0u32..10
) {
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 2, 1);
let state = ContextState::with_context(context);
let label1 = builder.compute_cd_label(&state, center);
let label2 = builder.compute_cd_label(&state, center);
prop_assert_eq!(label1, label2);
}
#[test]
fn cd_label_context_sensitivity(
num_phones in 5usize..15,
center in 0u32..5
) {
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 1, 1);
let empty = ContextState::initial();
let with_ctx = ContextState::with_context(vec![1]);
let label1 = builder.compute_cd_label(&empty, center);
let label2 = builder.compute_cd_label(&with_ctx, center);
prop_assert_ne!(label1, label2);
}
#[test]
fn cd_label_center_sensitivity(
// With an empty context the label is the center phone itself, so distinct centers differ.
num_phones in 5usize..15,
center1 in 0u32..5,
center2 in 5u32..10
) {
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 1, 1);
let state = ContextState::initial();
let label1 = builder.compute_cd_label(&state, center1);
let label2 = builder.compute_cd_label(&state, center2);
prop_assert_ne!(label1, label2);
}
#[test]
fn builder_config_updates(
num_phones in 2usize..10,
deterministic in any::<bool>()
) {
let config = ContextDependencyConfig {
deterministic,
..Default::default()
};
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 1, 1)
.config(config);
prop_assert_eq!(builder.config.deterministic, deterministic);
}
#[test]
fn builder_deterministic_sets_fields(
num_phones in 2usize..10,
boundary in 0u32..100
) {
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 1, 1)
.deterministic(boundary);
prop_assert!(builder.config.deterministic);
prop_assert_eq!(builder.config.boundary_symbol, Some(boundary));
}
#[test]
fn builder_aux_symbols_sets_fields(
num_phones in 2usize..10,
start in 100u32..200,
end in 200u32..300
) {
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 1, 1)
.with_auxiliary_symbols(start..end);
prop_assert!(builder.config.auxiliary_self_loops);
prop_assert_eq!(builder.config.auxiliary_symbols, Some(start..end));
}
}
// Triphone topology with the default config: 1 + n states, n arcs per state, all final.
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn triphone_state_count(num_phones in 2usize..8) {
let builder = TriphoneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
prop_assert_eq!(fst.num_states(), 1 + num_phones);
}
#[test]
fn triphone_arc_count(num_phones in 2usize..8) {
let builder = TriphoneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
let total_arcs: usize = (0..fst.num_states() as StateId)
.map(|s| fst.transitions(s).len())
.sum();
prop_assert_eq!(total_arcs, (1 + num_phones) * num_phones);
}
#[test]
fn triphone_expected_states_accurate(num_phones in 2usize..8) {
let builder = TriphoneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
prop_assert_eq!(fst.num_states(), builder.expected_states());
}
#[test]
fn triphone_expected_arcs_accurate(num_phones in 2usize..8) {
let builder = TriphoneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
let total_arcs: usize = (0..fst.num_states() as StateId)
.map(|s| fst.transitions(s).len())
.sum();
prop_assert_eq!(total_arcs, builder.expected_arcs());
}
#[test]
fn triphone_all_states_final(num_phones in 2usize..8) {
let builder = TriphoneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
for id in 0..fst.num_states() as StateId {
prop_assert!(fst.is_final(id));
}
}
#[test]
fn triphone_has_start(num_phones in 2usize..8) {
let builder = TriphoneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
prop_assert!(fst.start() != NO_STATE);
}
}
// Tetraphone topology with the default config: 1 + n + n^2 states; compared against triphone.
proptest! {
#![proptest_config(ProptestConfig::with_cases(15))]
#[test]
fn tetraphone_state_count(num_phones in 2usize..5) {
let builder = TetraploneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
let expected = 1 + num_phones + num_phones * num_phones;
prop_assert_eq!(fst.num_states(), expected);
}
#[test]
fn tetraphone_expected_states_accurate(num_phones in 2usize..5) {
let builder = TetraploneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
prop_assert_eq!(fst.num_states(), builder.expected_states());
}
#[test]
fn tetraphone_expected_arcs_accurate(num_phones in 2usize..5) {
let builder = TetraploneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
let total_arcs: usize = (0..fst.num_states() as StateId)
.map(|s| fst.transitions(s).len())
.sum();
prop_assert_eq!(total_arcs, builder.expected_arcs());
}
#[test]
fn tetraphone_all_states_final(num_phones in 2usize..5) {
let builder = TetraploneBuilder::<LogWeight>::new(num_phones);
let fst = builder.build();
for id in 0..fst.num_states() as StateId {
prop_assert!(fst.is_final(id));
}
}
#[test]
fn tetraphone_more_states_than_triphone(num_phones in 3usize..6) {
let tri = TriphoneBuilder::<LogWeight>::new(num_phones);
let tetra = TetraploneBuilder::<LogWeight>::new(num_phones);
prop_assert!(tetra.expected_states() > tri.expected_states());
}
#[test]
fn tetraphone_more_arcs_than_triphone(num_phones in 3usize..6) {
let tri = TriphoneBuilder::<LogWeight>::new(num_phones);
let tetra = TetraploneBuilder::<LogWeight>::new(num_phones);
prop_assert!(tetra.expected_arcs() > tri.expected_arcs());
}
}
// Injectivity of the label encoding, especially around phone 0 (guarded by the +1 offset).
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn cd_label_injective(
num_phones in 3usize..8,
ctx1 in prop::collection::vec(0u32..3, 0..2),
ctx2 in prop::collection::vec(0u32..3, 0..2),
center1 in 0u32..3,
center2 in 0u32..3
) {
// Reduce every generated phone into the valid range 0..num_phones.
let ctx1: Vec<u32> = ctx1.into_iter().map(|p| p % num_phones as u32).collect();
let ctx2: Vec<u32> = ctx2.into_iter().map(|p| p % num_phones as u32).collect();
let center1 = center1 % num_phones as u32;
let center2 = center2 % num_phones as u32;
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 2, 1);
let state1 = ContextState::with_context(ctx1.clone());
let state2 = ContextState::with_context(ctx2.clone());
let label1 = builder.compute_cd_label(&state1, center1);
let label2 = builder.compute_cd_label(&state2, center2);
// Equal labels must imply equal (context, center) pairs.
if label1 == label2 {
prop_assert_eq!(ctx1, ctx2);
prop_assert_eq!(center1, center2);
}
}
#[test]
fn phone_0_context_differs_from_empty(
num_phones in 3usize..10,
center in 0u32..5
) {
let center = center % num_phones as u32;
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 2, 1);
let empty = ContextState::initial();
let with_zero = ContextState::with_context(vec![0]);
let label_empty = builder.compute_cd_label(&empty, center);
let label_with_zero = builder.compute_cd_label(&with_zero, center);
prop_assert_ne!(
label_empty,
label_with_zero,
"Phone 0 in context must produce different label than empty context"
);
}
#[test]
fn phone_0_position_matters(
num_phones in 3usize..10,
center in 0u32..5
) {
let center = center % num_phones as u32;
let builder = ContextDependencyBuilder::<LogWeight>::new(num_phones, 2, 1);
let ctx_01 = ContextState::with_context(vec![0, 1]);
let ctx_10 = ContextState::with_context(vec![1, 0]);
let label_01 = builder.compute_cd_label(&ctx_01, center);
let label_10 = builder.compute_cd_label(&ctx_10, center);
prop_assert_ne!(
label_01,
label_10,
"Different phone 0 positions must produce different labels"
);
}
}
}