use std::collections::VecDeque;
use failure::Fallible;
use crate::fst_traits::{CoreFst, ExpandedFst};
use crate::semirings::Semiring;
use crate::StateId;
pub fn single_source_shortest_distance<F: ExpandedFst>(
fst: &F,
state_id: StateId,
) -> Fallible<Vec<<F as CoreFst>::W>> {
let num_states = fst.num_states();
let mut d = vec![];
let mut r = vec![];
d.resize(num_states, <F as CoreFst>::W::zero());
r.resize(num_states, <F as CoreFst>::W::zero());
if state_id < fst.num_states() {
d[state_id] = <F as CoreFst>::W::one();
r[state_id] = <F as CoreFst>::W::one();
let mut queue = VecDeque::new();
queue.push_back(state_id);
while !queue.is_empty() {
let state_cour = queue.pop_front().unwrap();
let r2 = &r[state_cour].clone();
r[state_cour] = <F as CoreFst>::W::zero();
for arc in fst.arcs_iter(state_cour)? {
let nextstate = arc.nextstate;
if d[nextstate] != d[nextstate].plus(&r2.times(&arc.weight)?)? {
d[nextstate] = d[nextstate].plus(&r2.times(&arc.weight)?)?;
r[nextstate] = r[nextstate].plus(&r2.times(&arc.weight)?)?;
if !queue.contains(&nextstate) {
queue.push_back(nextstate);
}
}
}
}
}
Ok(d)
}
pub fn shortest_distance<F: ExpandedFst>(fst: &F) -> Fallible<Vec<<F as CoreFst>::W>> {
if let Some(start_state) = fst.start() {
return single_source_shortest_distance(fst, start_state);
}
Ok(vec![<F as CoreFst>::W::zero(); fst.num_states()])
}
#[cfg(test)]
mod tests {
use super::*;
use crate::fst_traits::StateIterator;
use crate::semirings::{IntegerWeight, Semiring};
use crate::test_data::vector_fst::get_vector_fsts_for_tests;
#[test]
fn test_single_source_shortest_distance_generic() -> Fallible<()> {
for data in get_vector_fsts_for_tests() {
let fst = data.fst;
let d_ref = data.all_distances;
for state in fst.states_iter() {
let d = single_source_shortest_distance(&fst, state)?;
assert_eq!(
d, d_ref[state],
"Test failing for single source shortest distance on wFST {:?} at state {:?}",
data.name, state
);
}
let d = single_source_shortest_distance(&fst, fst.num_states())?;
assert_eq!(
d,
vec![IntegerWeight::zero(); fst.num_states()],
"Test failing for single source shortest distance on wFST {:?} at state {:?}",
data.name,
fst.num_states()
);
}
Ok(())
}
#[test]
fn test_shortest_distance_generic() -> Fallible<()> {
for data in get_vector_fsts_for_tests() {
let fst = data.fst;
let d_ref = data.all_distances;
let d = shortest_distance(&fst)?;
if let Some(start_state) = fst.start() {
assert_eq!(
d, d_ref[start_state],
"Test failing for all shortest distance on wFST : {:?}",
data.name
);
} else {
assert_eq!(
d,
vec![IntegerWeight::zero(); fst.num_states()],
"Test failing for all shortest distance on wFST : {:?}",
data.name
);
}
}
Ok(())
}
}