use crate::semiring::{LogWeight, Semiring};
use crate::wfst::{StateId, Wfst, NO_STATE};
use super::log_push::{compute_log_potentials, LogPushError};
/// Options passed to `build_lookahead_table`.
#[derive(Debug, Clone)]
pub struct LookaheadConfig {
    /// Caching toggle.
    ///
    /// NOTE(review): nothing in the visible part of this file reads this
    /// flag; confirm whether another module consumes it.
    pub cache: bool,
    /// When `true`, `build_lookahead_table` answers a failed potential
    /// computation with an all-zero table instead of returning the error.
    pub allow_unreachable: bool,
}
impl Default for LookaheadConfig {
fn default() -> Self {
Self {
cache: true,
allow_unreachable: true,
}
}
}
/// Per-state lookahead weights produced by `build_lookahead_table`.
#[derive(Debug, Clone)]
pub struct LookaheadTable {
    /// One weight per state, addressed by the state id converted to an index.
    potentials: Vec<LogWeight>,
    /// The entry of `potentials` belonging to the start state, or the
    /// semiring zero when no such entry exists.
    total_weight: LogWeight,
    /// Count of entries in `potentials` that are not the semiring zero.
    num_reachable: usize,
}
impl LookaheadTable {
    /// Returns a copy of the weight stored for `state`.
    ///
    /// A state id beyond the end of the table yields `LogWeight::zero()`.
    pub fn get(&self, state: StateId) -> LogWeight {
        match self.potentials.get(state as usize) {
            Some(weight) => weight.clone(),
            None => LogWeight::zero(),
        }
    }

    /// Returns the raw `f64` behind the weight stored for `state`.
    ///
    /// A state id beyond the end of the table yields `f64::INFINITY`.
    pub fn get_value(&self, state: StateId) -> f64 {
        self.potentials
            .get(state as usize)
            .map_or(f64::INFINITY, |weight| weight.value())
    }

    /// Reports whether `state` has an entry in the table and that entry is
    /// not the semiring zero.
    pub fn is_reachable(&self, state: StateId) -> bool {
        self.potentials
            .get(state as usize)
            .map_or(false, |weight| !weight.is_zero())
    }

    /// Borrows the weight recorded for the start state when the table was
    /// built.
    pub fn total_weight(&self) -> &LogWeight {
        &self.total_weight
    }

    /// Number of table entries that are not the semiring zero.
    pub fn num_reachable(&self) -> usize {
        self.num_reachable
    }

    /// Number of entries in the table.
    pub fn num_states(&self) -> usize {
        self.potentials.len()
    }

    /// Combines `accumulated` with the weight stored for `state` using the
    /// semiring `times` operation.
    ///
    /// An out-of-range `state` contributes `LogWeight::zero()`, exactly as
    /// `get` would return.
    pub fn normalize_score(&self, state: StateId, accumulated: &LogWeight) -> LogWeight {
        let lookahead = self.get(state);
        accumulated.times(&lookahead)
    }
}
/// Builds a [`LookaheadTable`] holding one weight per state of `fst`.
///
/// The weights are whatever `compute_log_potentials` returns for `fst`.
/// The table's total weight is the entry at the start state's index, and
/// the reachable count is the number of entries that are not the semiring
/// zero.
///
/// An FST with no states produces an empty table; this check runs before
/// the start-state check, so such an FST never yields an error.
///
/// # Errors
///
/// * `LogPushError::NoStartState` when the FST has states but its start
///   state is `NO_STATE`.
/// * The error from `compute_log_potentials`, but only when
///   `config.allow_unreachable` is `false`. When the flag is `true`, every
///   such error is replaced by an all-zero table with one entry per state.
///
/// NOTE(review): `config.cache` is not consulted here.
pub fn build_lookahead_table<L, F>(
    fst: &F,
    config: LookaheadConfig,
) -> Result<LookaheadTable, LogPushError>
where
    L: Clone,
    F: Wfst<L, LogWeight>,
{
    // Table in which nothing is reachable, wrapped around the given entries.
    let zero_table = |potentials: Vec<LogWeight>| LookaheadTable {
        potentials,
        total_weight: LogWeight::zero(),
        num_reachable: 0,
    };

    let state_count = fst.num_states();
    if state_count == 0 {
        return Ok(zero_table(Vec::new()));
    }

    let start_state = fst.start();
    if start_state == NO_STATE {
        return Err(LogPushError::NoStartState);
    }

    let potentials = match compute_log_potentials(fst) {
        Ok(computed) => computed,
        // Lenient mode: hide the failure behind an all-zero table.
        Err(_) if config.allow_unreachable => {
            return Ok(zero_table(vec![LogWeight::zero(); state_count]));
        }
        Err(failure) => return Err(failure),
    };

    let total_weight = potentials
        .get(start_state as usize)
        .cloned()
        .unwrap_or_else(LogWeight::zero);
    let num_reachable: usize = potentials
        .iter()
        .map(|weight| usize::from(!weight.is_zero()))
        .sum();

    Ok(LookaheadTable {
        potentials,
        total_weight,
        num_reachable,
    })
}
/// Returns the lookahead weight of one `state` of `fst`.
///
/// Every call runs `compute_log_potentials` over the whole FST and then
/// picks out the entry for `state`. The result is `LogWeight::zero()` when
/// that computation fails or when `state` has no entry in its output.
pub fn compute_lookahead_single<L, F>(fst: &F, state: StateId) -> LogWeight
where
    L: Clone,
    F: Wfst<L, LogWeight>,
{
    compute_log_potentials(fst)
        .ok()
        .and_then(|potentials| potentials.get(state as usize).cloned())
        .unwrap_or_else(LogWeight::zero)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wfst::{MutableWfst as MutableWfstTrait, VectorWfst};

    /// Concrete FST type shared by every test in this module.
    type CharFst = VectorWfst<char, LogWeight>;

    /// Builds the table with the default configuration, panicking on failure.
    fn table_for(fst: &CharFst) -> LookaheadTable {
        build_lookahead_table(fst, LookaheadConfig::default()).expect("Should build table")
    }

    /// Three states in a line: start --a/1.0--> middle --b/2.0--> final.
    /// The final state carries weight `one`.
    fn build_simple_chain() -> CharFst {
        let mut chain = CharFst::new();
        let first = chain.add_state();
        let middle = chain.add_state();
        let last = chain.add_state();
        chain.set_start(first);
        chain.set_final(last, LogWeight::one());
        chain.add_arc(first, Some('a'), Some('a'), middle, LogWeight::new(1.0));
        chain.add_arc(middle, Some('b'), Some('b'), last, LogWeight::new(2.0));
        chain
    }

    /// Two states joined by two parallel arcs, weighted 1.0 and 2.0.
    fn build_parallel_paths() -> CharFst {
        let mut diamond = CharFst::new();
        let source = diamond.add_state();
        let sink = diamond.add_state();
        diamond.set_start(source);
        diamond.set_final(sink, LogWeight::one());
        diamond.add_arc(source, Some('a'), Some('a'), sink, LogWeight::new(1.0));
        diamond.add_arc(source, Some('b'), Some('b'), sink, LogWeight::new(2.0));
        diamond
    }

    #[test]
    fn test_build_lookahead_chain() {
        let table = table_for(&build_simple_chain());
        assert_eq!(table.num_states(), 3);
        assert_eq!(table.num_reachable(), 3);

        // Checked from the final state back towards the start.
        let expectations = [
            (2, LogWeight::one()),
            (1, LogWeight::new(2.0)),
            (0, LogWeight::new(3.0)),
        ];
        for (state, expected) in expectations.iter() {
            assert!(table.get(*state).approx_eq(expected, 0.001));
        }
    }

    #[test]
    fn test_lookahead_normalize_score() {
        let table = table_for(&build_simple_chain());
        let normalized = table.normalize_score(1, &LogWeight::new(1.0));
        let matches_expected = normalized.approx_eq(&LogWeight::new(3.0), 0.001);
        assert!(
            matches_expected,
            "Normalized score should be 3.0, got {:?}",
            normalized
        );
    }

    #[test]
    fn test_lookahead_parallel() {
        let table = table_for(&build_parallel_paths());
        assert!(table.get(1).approx_eq(&LogWeight::one(), 0.001));

        // Log-semiring sum of the two arc weights: -ln(e^-1 + e^-2).
        let path_sum = (-1.0_f64).exp() + (-2.0_f64).exp();
        let expected = -path_sum.ln();
        assert!(
            table.get(0).approx_eq(&LogWeight::new(expected), 0.001),
            "State 0 lookahead should be {:?}, got {:?}",
            expected,
            table.get(0)
        );
    }

    #[test]
    fn test_lookahead_empty() {
        let empty = CharFst::new();
        let table = build_lookahead_table(&empty, LookaheadConfig::default())
            .expect("Should handle empty");
        assert_eq!(table.num_states(), 0);
        assert_eq!(table.num_reachable(), 0);
        assert!(table.total_weight().is_zero());
    }

    #[test]
    fn test_lookahead_out_of_bounds() {
        let table = table_for(&build_simple_chain());
        assert!(table.get(100).is_zero());
        assert_eq!(table.get_value(100), f64::INFINITY);
        assert!(!table.is_reachable(100));
    }

    #[test]
    fn test_compute_lookahead_single() {
        let fst = build_simple_chain();
        let expectations = [
            (0, LogWeight::new(3.0)),
            (1, LogWeight::new(2.0)),
            (2, LogWeight::one()),
        ];
        for (state, expected) in expectations.iter() {
            let lookahead = compute_lookahead_single(&fst, *state);
            assert!(lookahead.approx_eq(expected, 0.001));
        }
    }

    #[test]
    fn test_lookahead_total_weight() {
        let table = table_for(&build_simple_chain());
        let expected_total = LogWeight::new(3.0);
        assert!(table.total_weight().approx_eq(&expected_total, 0.001));
    }

    #[test]
    fn test_lookahead_unreachable_state() {
        // Same chain as `build_simple_chain`, plus a fourth state with no arcs.
        let mut fst = CharFst::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        let s3 = fst.add_state();
        fst.set_start(s0);
        fst.set_final(s2, LogWeight::one());
        fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0));
        fst.add_arc(s1, Some('b'), Some('b'), s2, LogWeight::new(2.0));

        let table = table_for(&fst);
        for &connected in &[s0, s1, s2] {
            assert!(table.is_reachable(connected));
        }
        assert!(!table.is_reachable(s3));
        assert_eq!(table.num_reachable(), 3);
    }
}