use crate::semiring::{Semiring, StarSemiring};
use crate::wfst::{StateId, Wfst, NO_STATE};
use super::queue::{
FifoQueue, QueueType, ShortestDistanceQueue, ShortestFirstQueue, TopologicalQueue,
};
/// Options shared by the shortest-distance routines in this module.
#[derive(Clone, Debug)]
pub struct ShortestDistanceConfig {
/// Queue discipline that `single_source_shortest_distance` dispatches on.
pub queue_type: QueueType,
/// Upper bound on the number of states popped by the single-source
/// relaxation loop; once exceeded the search returns `None`.
/// `None` means no bound.
pub max_iterations: Option<usize>,
/// Caller-supplied acyclicity hint; the `acyclic()` preset sets it to `Some(true)`.
/// NOTE(review): no code visible in this file reads this field - confirm
/// whether it is consumed elsewhere.
pub is_acyclic: Option<bool>,
/// Tolerance handed to `approx_eq` when deciding whether a relaxation
/// changed a distance.
pub epsilon: f64,
}
impl Default for ShortestDistanceConfig {
fn default() -> Self {
Self {
queue_type: QueueType::Auto,
max_iterations: None,
is_acyclic: None,
epsilon: 1e-10,
}
}
}
impl ShortestDistanceConfig {
    /// Preset that requests the topological queue and records the
    /// acyclicity hint as `Some(true)`; everything else is the default.
    pub fn acyclic() -> Self {
        let base = Self::default();
        Self {
            is_acyclic: Some(true),
            queue_type: QueueType::Topological,
            ..base
        }
    }

    /// Preset that requests the shortest-first queue; everything else is
    /// the default.
    pub fn tropical() -> Self {
        let base = Self::default();
        Self {
            queue_type: QueueType::ShortestFirst,
            ..base
        }
    }

    /// Preset that requests the FIFO queue; everything else is the default.
    pub fn general() -> Self {
        let base = Self::default();
        Self {
            queue_type: QueueType::Fifo,
            ..base
        }
    }
}
/// Computes, for every state, the semiring distance from the start state.
///
/// Returns an empty vector for an automaton without states, and a vector of
/// `W::zero()` when the start state is `NO_STATE` or out of range. Otherwise
/// the queue named by `config.queue_type` drives the relaxation loop:
///
/// * `Fifo` - a FIFO queue;
/// * `ShortestFirst` - a shortest-first queue with its distances initialised;
/// * `Topological` / `Auto` - a topological queue when a topological order
///   exists, falling back to the FIFO queue otherwise.
///
/// The result is `None` when the relaxation loop exceeds
/// `config.max_iterations`.
pub fn single_source_shortest_distance<L, W, F>(
    fst: &F,
    config: ShortestDistanceConfig,
) -> Option<Vec<W>>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let state_count = fst.num_states();
    if state_count == 0 {
        return Some(Vec::new());
    }

    let source = fst.start();
    let source_is_valid = source != NO_STATE && (source as usize) < state_count;
    if !source_is_valid {
        // Nothing is reachable without a usable start state.
        return Some(vec![W::zero(); state_count]);
    }

    match config.queue_type {
        QueueType::ShortestFirst => {
            let mut heap = ShortestFirstQueue::with_capacity(state_count);
            heap.init_distances(state_count);
            single_source_shortest_distance_impl(fst, heap, &config)
        }
        QueueType::Fifo => {
            let fifo = FifoQueue::with_capacity(state_count);
            single_source_shortest_distance_impl(fst, fifo, &config)
        }
        // Both variants try the topological order first and fall back to
        // FIFO when the automaton cannot be ordered.
        QueueType::Topological | QueueType::Auto => match compute_topological_order(fst) {
            Some(order) => {
                let topo = TopologicalQueue::from_order(order);
                single_source_shortest_distance_impl(fst, topo, &config)
            }
            None => {
                let fifo = FifoQueue::with_capacity(state_count);
                single_source_shortest_distance_impl(fst, fifo, &config)
            }
        },
    }
}
pub fn single_source_shortest_distance_with_queue<L, W, F, Q>(fst: &F, queue: Q) -> Option<Vec<W>>
where
L: Clone,
W: Semiring,
F: Wfst<L, W>,
Q: ShortestDistanceQueue<W>,
{
single_source_shortest_distance_impl(fst, queue, &ShortestDistanceConfig::default())
}
/// Queue-driven relaxation loop shared by the single-source entry points.
///
/// `distance[q]` accumulates the total weight found so far for state `q`;
/// `remainder[q]` holds the part of that weight that has not yet been
/// propagated along the outgoing transitions of `q`.
///
/// Precondition: `fst.start()` must be a valid index below
/// `fst.num_states()`; the start state is used to index the tables directly.
///
/// Returns `None` when more than `config.max_iterations` states have been
/// popped from the queue, otherwise the distance table.
fn single_source_shortest_distance_impl<L, W, F, Q>(
    fst: &F,
    mut queue: Q,
    config: &ShortestDistanceConfig,
) -> Option<Vec<W>>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
    Q: ShortestDistanceQueue<W>,
{
    let num_states = fst.num_states();
    let start = fst.start();
    let mut distance: Vec<W> = vec![W::zero(); num_states];
    let mut remainder: Vec<W> = vec![W::zero(); num_states];
    distance[start as usize] = W::one();
    remainder[start as usize] = W::one();
    queue.insert(start, &distance[start as usize]);

    let max_iterations = config.max_iterations.unwrap_or(usize::MAX);
    let mut iterations = 0;
    while let Some(state) = queue.pop() {
        iterations += 1;
        if iterations > max_iterations {
            return None;
        }

        // Take the pending weight for this state and reset it.
        let state_idx = state as usize;
        let r = remainder[state_idx];
        remainder[state_idx] = W::zero();
        if r.is_zero() {
            // Stale queue entry: nothing left to propagate.
            continue;
        }

        for transition in fst.transitions(state) {
            let next_state = transition.to;
            let next_idx = next_state as usize;
            // Skip dangling arcs, as the topological-order and reverse
            // passes in this file do, instead of panicking on the index.
            if next_idx >= num_states {
                continue;
            }
            let contribution = r.times(&transition.weight);
            let old_distance = distance[next_idx];
            let new_distance = old_distance.plus(&contribution);
            // Only re-queue the target when its distance moved by more than
            // the configured tolerance.
            if !new_distance.approx_eq(&old_distance, config.epsilon) {
                remainder[next_idx] = remainder[next_idx].plus(&contribution);
                distance[next_idx] = new_distance;
                queue.update(next_state, &distance[next_idx]);
            }
        }
    }
    Some(distance)
}
/// Computes the distance between every ordered pair of states with the
/// semiring generalisation of Floyd-Warshall.
///
/// The table starts with the `plus`-sum of the direct arc weights. For each
/// pivot `k` it applies
/// `d[i][j] <- d[i][j] + d[i][k] * star(d[k][k]) * d[k][j]`,
/// reading the pivot row, pivot column and `d[k][k]` from a snapshot taken
/// before the pivot step so that no operand is overwritten mid-step. The
/// empty path (`W::one()` on the diagonal) is added only after the closure,
/// so `star` is never applied to a value that already contains the
/// identity - doing so diverges in non-idempotent semirings.
///
/// Arcs whose target is not a valid state index are ignored.
///
/// Returns `None` when `star` is undefined for some pivot entry, and an
/// empty table for an automaton without states.
pub fn all_pairs_shortest_distance<L, W, F>(fst: &F) -> Option<Vec<Vec<W>>>
where
    L: Clone,
    W: StarSemiring,
    F: Wfst<L, W>,
{
    let n = fst.num_states();
    if n == 0 {
        return Some(Vec::new());
    }

    // Direct arc weights; parallel arcs are combined with `plus`.
    let mut d: Vec<Vec<W>> = vec![vec![W::zero(); n]; n];
    for state in 0..n as StateId {
        let from = state as usize;
        for transition in fst.transitions(state) {
            let to = transition.to as usize;
            if to < n {
                d[from][to] = d[from][to].plus(&transition.weight);
            }
        }
    }

    for k in 0..n {
        // star(zero) is one by definition; avoid relying on `star` for it.
        let star_kk = if d[k][k].is_zero() {
            W::one()
        } else {
            d[k][k].star()?
        };
        // Snapshot the operands of this pivot step before mutating `d`.
        let pivot_row: Vec<W> = d[k].clone();
        let pivot_col: Vec<W> = d.iter().map(|row| row[k].clone()).collect();

        for (i, row) in d.iter_mut().enumerate() {
            if pivot_col[i].is_zero() {
                continue;
            }
            // Loop-invariant prefix d[i][k] * star(d[k][k]).
            let into_k = pivot_col[i].times(&star_kk);
            for (j, cell) in row.iter_mut().enumerate() {
                if pivot_row[j].is_zero() {
                    continue;
                }
                let through_k = into_k.times(&pivot_row[j]);
                *cell = cell.plus(&through_k);
            }
        }
    }

    // Account for the empty path from every state to itself.
    for (i, row) in d.iter_mut().enumerate() {
        row[i] = row[i].plus(&W::one());
    }
    Some(d)
}
/// Attempts to order all states so that every arc goes from an earlier to a
/// later state.
///
/// In-degrees are counted over arcs whose target is a valid state index;
/// states with in-degree zero are then repeatedly removed (taken from the
/// end of the ready list) while the in-degrees of their successors are
/// decremented.
///
/// Returns `Some(order)` containing every state when such an order exists,
/// `None` when some states could never reach in-degree zero (the automaton
/// has a cycle), and an empty order for an automaton without states.
fn compute_topological_order<L, W, F>(fst: &F) -> Option<Vec<StateId>>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let n = fst.num_states();
    if n == 0 {
        return Some(Vec::new());
    }

    let mut in_degree: Vec<usize> = vec![0; n];
    for state in 0..n as StateId {
        for transition in fst.transitions(state) {
            let to = transition.to as usize;
            if to < n {
                in_degree[to] += 1;
            }
        }
    }

    // States that currently have no unprocessed predecessor.
    let mut ready: Vec<StateId> = Vec::with_capacity(n);
    let mut result: Vec<StateId> = Vec::with_capacity(n);
    for (state, &deg) in in_degree.iter().enumerate() {
        if deg == 0 {
            ready.push(state as StateId);
        }
    }

    while let Some(state) = ready.pop() {
        result.push(state);
        for transition in fst.transitions(state) {
            let next = transition.to as usize;
            if next < n {
                in_degree[next] -= 1;
                if in_degree[next] == 0 {
                    ready.push(next as StateId);
                }
            }
        }
    }

    // Any state left out is part of, or only reachable through, a cycle.
    if result.len() == n {
        Some(result)
    } else {
        None
    }
}
/// Combines the single-source distances into one total weight over all
/// final states.
///
/// For every final state the distance from the start is multiplied
/// (`times`) by the state's final weight, and the products are summed with
/// `plus`, starting from `W::zero()`. Returns `None` when the underlying
/// single-source computation returns `None`.
pub fn shortest_distance_to_final<L, W, F>(fst: &F, config: ShortestDistanceConfig) -> Option<W>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let distances = single_source_shortest_distance(fst, config)?;
    let total = distances
        .iter()
        .enumerate()
        .map(|(idx, dist)| (idx as StateId, dist))
        .filter(|&(state, _)| fst.is_final(state))
        .fold(W::zero(), |acc, (state, dist)| {
            acc.plus(&dist.times(&fst.final_weight(state)))
        });
    Some(total)
}
/// Computes, for every state, the distance from that state to the final
/// states by relaxing over the reversed arcs.
///
/// Every final state is seeded with its final weight and queued (FIFO).
/// Pending weight is then pushed backwards along incoming arcs as
/// `arc_weight * pending`, until no distance changes by more than
/// `config.epsilon`. Arcs whose target is not a valid state index are
/// ignored.
///
/// Returns an empty vector for an automaton without states, and `None`
/// when more than `config.max_iterations` states have been popped from the
/// queue - the same bound the forward search honours. With the default
/// configuration (`max_iterations == None`) the loop is unbounded.
pub fn reverse_shortest_distance<L, W, F>(fst: &F, config: ShortestDistanceConfig) -> Option<Vec<W>>
where
    L: Clone,
    W: Semiring,
    F: Wfst<L, W>,
{
    let n = fst.num_states();
    if n == 0 {
        return Some(Vec::new());
    }

    // reverse_adj[q] lists (source, weight) for every arc that ends in q.
    let mut reverse_adj: Vec<Vec<(StateId, W)>> = vec![Vec::new(); n];
    for state in 0..n as StateId {
        for transition in fst.transitions(state) {
            let to = transition.to as usize;
            if to < n {
                reverse_adj[to].push((state, transition.weight));
            }
        }
    }

    let mut distance: Vec<W> = vec![W::zero(); n];
    let mut queue = FifoQueue::with_capacity(n);
    for state in 0..n as StateId {
        if fst.is_final(state) {
            distance[state as usize] = fst.final_weight(state);
            queue.insert(state, &distance[state as usize]);
        }
    }
    // Initially the whole seeded distance is still waiting to be propagated.
    let mut remainder: Vec<W> = distance.clone();

    let max_iterations = config.max_iterations.unwrap_or(usize::MAX);
    let mut iterations: usize = 0;
    while let Some(state) = queue.pop() {
        iterations += 1;
        if iterations > max_iterations {
            return None;
        }

        let state_idx = state as usize;
        let r = remainder[state_idx];
        remainder[state_idx] = W::zero();
        if r.is_zero() {
            // Stale queue entry: nothing left to propagate.
            continue;
        }

        for &(prev_state, ref weight) in &reverse_adj[state_idx] {
            let prev_idx = prev_state as usize;
            // Arc weight first, then the suffix weight, to respect the
            // left-to-right order of the path.
            let contribution = weight.times(&r);
            let old_distance = distance[prev_idx];
            let new_distance = old_distance.plus(&contribution);
            if !new_distance.approx_eq(&old_distance, config.epsilon) {
                remainder[prev_idx] = remainder[prev_idx].plus(&contribution);
                distance[prev_idx] = new_distance;
                queue.update(prev_state, &distance[prev_idx]);
            }
        }
    }
    Some(distance)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::{LogWeight, TropicalWeight};
    use crate::wfst::{MutableWfst, VectorWfst, VectorWfstBuilder};

    /// Randomised invariants checked on generated acyclic tropical automata.
    mod property_tests {
        use super::*;
        use crate::test_utils::arb_acyclic_wfst_tropical;
        use proptest::prelude::*;

        proptest! {
            // FIFO, topological and automatic queue selection must agree on
            // every state of an acyclic automaton.
            #[test]
            fn queue_independence_acyclic(
                fst in arb_acyclic_wfst_tropical(10, 3)
            ) {
                if fst.num_states() == 0 {
                    return Ok(());
                }
                let config_with = |queue_type: QueueType| ShortestDistanceConfig {
                    queue_type,
                    ..Default::default()
                };
                let fifo_result =
                    single_source_shortest_distance(&fst, config_with(QueueType::Fifo));
                let topo_result =
                    single_source_shortest_distance(&fst, config_with(QueueType::Topological));
                let auto_result =
                    single_source_shortest_distance(&fst, config_with(QueueType::Auto));

                match (fifo_result, topo_result, auto_result) {
                    (None, None, None) => {}
                    (Some(fifo), Some(topo), Some(auto)) => {
                        prop_assert_eq!(fifo.len(), topo.len());
                        prop_assert_eq!(fifo.len(), auto.len());
                        for (i, fifo_dist) in fifo.iter().enumerate() {
                            prop_assert!(
                                fifo_dist.approx_eq(&topo[i], 1e-6),
                                "FIFO[{}]={:?} != Topo[{}]={:?}",
                                i, fifo_dist, i, topo[i]
                            );
                            prop_assert!(
                                fifo_dist.approx_eq(&auto[i], 1e-6),
                                "FIFO[{}]={:?} != Auto[{}]={:?}",
                                i, fifo_dist, i, auto[i]
                            );
                        }
                    }
                    _ => {
                        prop_assert!(false, "Inconsistent results across queue types");
                    }
                }
            }

            // Every reachable state must have a non-negative distance.
            #[test]
            fn shortest_distance_non_negative(
                fst in arb_acyclic_wfst_tropical(8, 3)
            ) {
                if let Some(distances) =
                    single_source_shortest_distance(&fst, ShortestDistanceConfig::default())
                {
                    for d in distances.iter().filter(|d| !d.is_zero()) {
                        prop_assert!(d.value() >= 0.0, "Negative distance: {:?}", d);
                    }
                }
            }

            // The start state is at distance `one` from itself.
            #[test]
            fn shortest_distance_start_is_one(
                fst in arb_acyclic_wfst_tropical(8, 3)
            ) {
                if fst.num_states() == 0 || fst.start() == NO_STATE {
                    return Ok(());
                }
                if let Some(distances) =
                    single_source_shortest_distance(&fst, ShortestDistanceConfig::default())
                {
                    let start_distance = &distances[fst.start() as usize];
                    prop_assert!(
                        start_distance.approx_eq(&TropicalWeight::one(), 1e-10),
                        "Distance to start should be one, got {:?}",
                        start_distance
                    );
                }
            }

            // The all-pairs table has `one` on its diagonal.
            #[test]
            fn all_pairs_diagonal_is_one(
                fst in arb_acyclic_wfst_tropical(5, 2)
            ) {
                if fst.num_states() == 0 {
                    return Ok(());
                }
                if let Some(distances) = all_pairs_shortest_distance(&fst) {
                    for (i, row) in distances.iter().enumerate() {
                        prop_assert!(
                            row[i].approx_eq(&TropicalWeight::one(), 1e-10),
                            "Diagonal d[{}][{}] should be one, got {:?}",
                            i, i, row[i]
                        );
                    }
                }
            }

            // A final state's reverse distance never exceeds its final
            // weight, and equals it when the state has no outgoing arcs.
            #[test]
            fn reverse_distance_final_states(
                fst in arb_acyclic_wfst_tropical(6, 2)
            ) {
                if fst.num_states() == 0 {
                    return Ok(());
                }
                if let Some(distances) =
                    reverse_shortest_distance(&fst, ShortestDistanceConfig::default())
                {
                    for state in 0..fst.num_states() {
                        let state_id = state as StateId;
                        if !fst.is_final(state_id) {
                            continue;
                        }
                        let final_w = fst.final_weight(state_id);
                        let reached = &distances[state];
                        if fst.transitions(state_id).is_empty() {
                            prop_assert!(
                                reached.approx_eq(&final_w, 1e-6),
                                "Leaf final state {} distance {:?} != final_weight {:?}",
                                state, reached, final_w
                            );
                        } else {
                            prop_assert!(
                                reached.value() <= final_w.value() + 1e-6,
                                "Final state {} distance {:?} > final_weight {:?}",
                                state, reached, final_w
                            );
                        }
                    }
                }
            }
        }
    }

    /// Asserts that `actual[i]` matches `expected[i]` (tolerance `1e-10`)
    /// for every index of `expected`.
    fn assert_tropical_prefix(actual: &[TropicalWeight], expected: &[TropicalWeight]) {
        for (i, want) in expected.iter().enumerate() {
            assert!(actual[i].approx_eq(want, 1e-10));
        }
    }

    /// Chain `0 -> 1 -> ... -> n` with unit-cost arcs; state 0 is the start
    /// and state `n` is final with weight `one`.
    fn build_linear_fst(n: usize) -> VectorWfst<char, TropicalWeight> {
        let state_count = n + 1;
        let mut chain: VectorWfst<char, TropicalWeight> = VectorWfst::with_capacity(state_count);
        chain.reserve_states(state_count);
        for _ in 0..state_count {
            chain.add_state();
        }
        chain.set_start(0);
        chain.set_final(n as StateId, TropicalWeight::one());
        for (src, dst) in (0..n).zip(1..=n) {
            chain.add_arc(
                src as StateId,
                Some('a'),
                Some('a'),
                dst as StateId,
                TropicalWeight::new(1.0),
            );
        }
        chain
    }

    /// Diamond `0 -> {1, 2} -> 3`: the branch through state 1 costs 2.0 in
    /// total, the branch through state 2 costs 3.0.
    fn build_diamond_fst() -> VectorWfst<char, TropicalWeight> {
        VectorWfstBuilder::new()
            .add_states(4)
            .start(0)
            .final_state(3, TropicalWeight::one())
            .arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0))
            .arc(0, Some('b'), Some('b'), 2, TropicalWeight::new(2.0))
            .arc(1, Some('c'), Some('c'), 3, TropicalWeight::new(1.0))
            .arc(2, Some('d'), Some('d'), 3, TropicalWeight::new(1.0))
            .build()
    }

    #[test]
    fn test_single_source_linear() {
        let fst = build_linear_fst(3);
        let distances = single_source_shortest_distance(&fst, ShortestDistanceConfig::default())
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert_eq!(distances.len(), 4);
        assert_tropical_prefix(
            &distances,
            &[
                TropicalWeight::one(),
                TropicalWeight::new(1.0),
                TropicalWeight::new(2.0),
                TropicalWeight::new(3.0),
            ],
        );
    }

    #[test]
    fn test_single_source_diamond() {
        let fst = build_diamond_fst();
        let distances = single_source_shortest_distance(&fst, ShortestDistanceConfig::default())
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert_eq!(distances.len(), 4);
        assert_tropical_prefix(
            &distances,
            &[
                TropicalWeight::one(),
                TropicalWeight::new(1.0),
                TropicalWeight::new(2.0),
                TropicalWeight::new(2.0),
            ],
        );
    }

    #[test]
    fn test_single_source_with_topological_queue() {
        let fst = build_linear_fst(5);
        let distances = single_source_shortest_distance(&fst, ShortestDistanceConfig::acyclic())
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert_eq!(distances.len(), 6);
        // On the unit-cost chain the distance equals the number of hops.
        for (hops, reached) in distances.iter().enumerate() {
            assert!(reached.approx_eq(&TropicalWeight::new(hops as f64), 1e-10));
        }
    }

    #[test]
    fn test_shortest_distance_to_final() {
        let fst = build_diamond_fst();
        let total = shortest_distance_to_final(&fst, ShortestDistanceConfig::default())
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert!(total.approx_eq(&TropicalWeight::new(2.0), 1e-10));
    }

    #[test]
    fn test_all_pairs_simple() {
        let fst: VectorWfst<char, TropicalWeight> = VectorWfstBuilder::new()
            .add_states(3)
            .start(0)
            .final_state(2, TropicalWeight::one())
            .arc(0, Some('a'), Some('a'), 1, TropicalWeight::new(1.0))
            .arc(1, Some('b'), Some('b'), 2, TropicalWeight::new(2.0))
            .build();
        let distances = all_pairs_shortest_distance(&fst)
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert_eq!(distances.len(), 3);

        let expected_entries = [
            (0usize, 0usize, TropicalWeight::one()),
            (0, 1, TropicalWeight::new(1.0)),
            (0, 2, TropicalWeight::new(3.0)),
            (1, 2, TropicalWeight::new(2.0)),
        ];
        for (from, to, want) in expected_entries.iter() {
            assert!(distances[*from][*to].approx_eq(want, 1e-10));
        }
        // No path leads back from the last state to the first.
        assert!(distances[2][0].is_zero());
    }

    #[test]
    fn test_reverse_shortest_distance() {
        let fst = build_linear_fst(3);
        let distances = reverse_shortest_distance(&fst, ShortestDistanceConfig::default())
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert_tropical_prefix(
            &distances,
            &[
                TropicalWeight::new(3.0),
                TropicalWeight::new(2.0),
                TropicalWeight::new(1.0),
                TropicalWeight::one(),
            ],
        );
    }

    #[test]
    fn test_empty_fst() {
        let builder: VectorWfstBuilder<char, TropicalWeight> = VectorWfstBuilder::new();
        let fst = builder.build();

        let single_source =
            single_source_shortest_distance(&fst, ShortestDistanceConfig::default())
                .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert!(single_source.is_empty());

        let all_pairs = all_pairs_shortest_distance(&fst)
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        assert!(all_pairs.is_empty());
    }

    #[test]
    fn test_log_semiring_shortest_distance() {
        let fst: VectorWfst<char, LogWeight> = VectorWfstBuilder::new()
            .add_states(4)
            .start(0)
            .final_state(3, LogWeight::one())
            .arc(0, Some('a'), Some('a'), 1, LogWeight::new(1.0))
            .arc(0, Some('b'), Some('b'), 2, LogWeight::new(2.0))
            .arc(1, Some('c'), Some('c'), 3, LogWeight::new(1.0))
            .arc(2, Some('d'), Some('d'), 3, LogWeight::new(1.0))
            .build();
        let distances = single_source_shortest_distance(&fst, ShortestDistanceConfig::default())
            .expect("algorithms/shortest_distance.rs: required value was None/Err");
        // Both paths contribute, so the combined weight is below that of
        // the better path alone (2.0).
        assert!(distances[3].value() < 2.0);
    }
}