use std::collections::VecDeque;
use std::hash::{Hash, Hasher};
use anyhow::Result;
use crate::fst_traits::Fst;
use crate::semirings::Semiring;
use crate::{Label, StateId, Trs, EPS_LABEL};
/// A path through an FST: the input labels it reads, the output labels it writes and the
/// weight attached to it.
///
/// [`FstPath::add_to_path`] skips epsilon labels; [`FstPath::new`] stores labels as given.
/// Comparison (`PartialEq` / `PartialOrd`) is field by field, in declaration order.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct FstPath<W: Semiring> {
    /// Input labels, in the order they are read.
    pub ilabels: Vec<Label>,
    /// Output labels, in the order they are written.
    pub olabels: Vec<Label>,
    /// Weight of the path; `add_to_path` / `add_weight` multiply into it with `times`.
    pub weight: W,
}
impl<W: Semiring> FstPath<W> {
    /// Creates a path from explicit input labels, output labels and weight.
    ///
    /// The label vectors are stored exactly as given; unlike [`FstPath::add_to_path`], this
    /// does not filter out epsilon labels.
    pub fn new(ilabels: Vec<Label>, olabels: Vec<Label>, weight: W) -> Self {
        Self {
            ilabels,
            olabels,
            weight,
        }
    }

    /// Extends the path with one transition.
    ///
    /// `ilabel` / `olabel` are appended to `ilabels` / `olabels` unless they equal
    /// `EPS_LABEL`. `weight` is then multiplied (`times`) into the path weight.
    ///
    /// # Errors
    ///
    /// Returns the error of the semiring multiplication, if any. In that case the labels
    /// have already been appended.
    pub fn add_to_path(&mut self, ilabel: Label, olabel: Label, weight: &W) -> Result<()> {
        self.ilabels
            .extend(Some(ilabel).filter(|&label| label != EPS_LABEL));
        self.olabels
            .extend(Some(olabel).filter(|&label| label != EPS_LABEL));
        self.add_weight(weight)
    }

    /// Multiplies (`times`) `weight` into the path weight without touching the labels.
    ///
    /// # Errors
    ///
    /// Propagates the error of `Semiring::times_assign`.
    pub fn add_weight(&mut self, weight: &W) -> Result<()> {
        self.weight.times_assign(weight)
    }

    /// Appends `other` to this path: its labels follow ours and its weight is multiplied in.
    ///
    /// # Errors
    ///
    /// Propagates the error of `Semiring::times_assign`. In that case the labels have
    /// already been appended.
    pub fn concat(&mut self, other: FstPath<W>) -> Result<()> {
        let FstPath {
            ilabels,
            olabels,
            weight,
        } = other;
        self.ilabels.extend(ilabels);
        self.olabels.extend(olabels);
        self.weight.times_assign(weight)
    }

    /// Returns `true` when the path has no labels on either side and a weight equal to one.
    pub fn is_empty(&self) -> bool {
        let has_no_labels = self.ilabels.is_empty() && self.olabels.is_empty();
        has_no_labels && self.weight.is_one()
    }
}
/// The default path is empty: no input or output labels and weight `W::one()`.
impl<W: Semiring> Default for FstPath<W> {
    fn default() -> Self {
        Self::new(Vec::new(), Vec::new(), W::one())
    }
}
// `PartialEq` is derived on `FstPath`, while `Hash` is written by hand so that it carries the
// `W: Semiring + Hash + Eq` bounds. The `allow` silences clippy's `derive_hash_xor_eq` lint,
// which targets mixing a derived and a hand-written `PartialEq`/`Hash` pair. The contract
// `a == b => hash(a) == hash(b)` still holds because exactly the fields compared by the
// derived `PartialEq` are hashed.
#[allow(clippy::derive_hash_xor_eq)]
impl<W> Hash for FstPath<W>
where
    W: Semiring + Hash + Eq,
{
    fn hash<H: Hasher>(&self, state: &mut H) {
        // A tuple hashes its elements one after the other, so this feeds `ilabels`,
        // `olabels` and `weight` to `state` in declaration order.
        (&self.ilabels, &self.olabels, &self.weight).hash(state);
    }
}

/// Path equality is a full equivalence relation whenever the weight's equality is (`W: Eq`).
impl<W> Eq for FstPath<W> where W: Semiring + Hash + Eq {}
/// One configuration of the breadth-first search run by [`check_path_in_fst`].
struct BfsState<W> {
    /// FST state reached by the partial path.
    state: StateId,
    /// `times`-product of the weights of the transitions taken so far.
    weight_curr: W,
    /// Index in `FstPath::ilabels` of the next input label still to be matched.
    next_ilabel_idx: usize,
    /// Index in `FstPath::olabels` of the next output label still to be matched.
    next_olabel_idx: usize,
}

/// Checks whether `fst` accepts `fst_path`.
///
/// Returns `true` if there is a path from the start state of `fst` to a final state that:
/// - reads exactly `fst_path.ilabels`,
/// - writes exactly `fst_path.olabels`,
/// - has a weight (transition weights `times` the final weight) equal to `fst_path.weight`
///   according to `W`'s `PartialEq`.
///
/// The input and output sides are matched independently, each with its own cursor. A
/// transition label equal to `EPS_LABEL` consumes nothing on its side. A non-epsilon label
/// must equal the next unmatched label on that side of `fst_path`.
///
/// If `fst` has no start state, the result is `fst_path.is_empty()`.
///
/// The search is breadth-first and keeps no visited set. If no matching path exists and a
/// cycle of `EPS_LABEL:EPS_LABEL` transitions is reachable, it does not terminate.
///
/// # Panics
///
/// Panics if `Semiring::times` returns an error while accumulating weights.
pub fn check_path_in_fst<W: Semiring, F: Fst<W>>(fst: &F, fst_path: &FstPath<W>) -> bool {
    let start = match fst.start() {
        Some(start) => start,
        None => return fst_path.is_empty(),
    };
    let n_ilabels = fst_path.ilabels.len();
    let n_olabels = fst_path.olabels.len();

    let mut queue = VecDeque::new();
    queue.push_back(BfsState {
        state: start,
        weight_curr: W::one(),
        next_ilabel_idx: 0,
        next_olabel_idx: 0,
    });

    while let Some(BfsState {
        state,
        weight_curr,
        next_ilabel_idx,
        next_olabel_idx,
    }) = queue.pop_front()
    {
        // Every label of the path has been matched: accept if this state is final and the
        // total weight matches.
        if next_ilabel_idx >= n_ilabels && next_olabel_idx >= n_olabels {
            // SAFETY: `state` is either the start state or the `nextstate` of a transition of
            // `fst`, so it is a valid state id as long as `fst` is well-formed.
            if let Some(final_weight) = unsafe { fst.final_weight_unchecked(state) } {
                let total = weight_curr
                    .times(final_weight)
                    .expect("Semiring::times failed on a final weight");
                if total == fst_path.weight {
                    return true;
                }
            }
        }

        // SAFETY: same argument as for `final_weight_unchecked` above.
        let trs = unsafe { fst.get_trs_unchecked(state) };
        for tr in trs.trs() {
            let consumes_ilabel = tr.ilabel != EPS_LABEL;
            let consumes_olabel = tr.olabel != EPS_LABEL;
            // A non-epsilon label must match the next unmatched label on ITS OWN side. The
            // output side is indexed with the output cursor; the two cursors can diverge
            // after `x:eps` / `eps:y` transitions.
            if consumes_ilabel && fst_path.ilabels.get(next_ilabel_idx) != Some(&tr.ilabel) {
                continue;
            }
            if consumes_olabel && fst_path.olabels.get(next_olabel_idx) != Some(&tr.olabel) {
                continue;
            }
            queue.push_back(BfsState {
                state: tr.nextstate,
                weight_curr: weight_curr
                    .times(&tr.weight)
                    .expect("Semiring::times failed on a transition weight"),
                next_ilabel_idx: next_ilabel_idx + usize::from(consumes_ilabel),
                next_olabel_idx: next_olabel_idx + usize::from(consumes_olabel),
            });
        }
    }
    false
}
/// Builds an `FstPath` from a comma-separated list of label expressions, with an optional
/// weight.
///
/// Four forms are accepted. Arms are tried in this order:
///
/// - `fst_path![a, b]`: labels `a, b` on both the input and the output side, weight
///   `W::one()`.
/// - `fst_path![a, b => x, y]`: input labels `a, b`, output labels `x, y`, weight `W::one()`.
/// - `fst_path![a, b; w]`: labels `a, b` on both sides, weight `W::new(w)`.
/// - `fst_path![a, b => x, y; w]`: input labels `a, b`, output labels `x, y`, weight
///   `W::new(w)`.
///
/// The semiring type `W` is never named by the macro; it is inferred from the context in which
/// the expansion is used.
///
/// In the single-list forms every label expression appears twice in the expansion (once per
/// label vector), so it is evaluated twice. The expansion refers to `FstPath` and `Semiring`
/// without a path, so both must be in scope where the macro is invoked.
#[macro_export]
macro_rules! fst_path {
// Same labels on both sides, weight one.
( $( $x:expr ),*) => {
{
fn semiring_one<W: Semiring>() -> W {
W::one()
}
FstPath::new(
vec![$($x),*],
vec![$($x),*],
semiring_one()
)
}
};
// Distinct input / output labels, weight one.
( $( $x:expr ),* => $( $y:expr ),* ) => {
{
fn semiring_one<W: Semiring>() -> W {
W::one()
}
FstPath::new(
vec![$($x),*],
vec![$($y),*],
semiring_one()
)
}
};
// Same labels on both sides, explicit weight value passed to `W::new`.
( $( $x:expr ),* ; $weight:expr) => {
{
fn semiring_new<W: Semiring>(v: W::Type) -> W {
W::new(v)
}
FstPath::new(
vec![$($x),*],
vec![$($x),*],
semiring_new($weight)
)
}
};
// Distinct input / output labels, explicit weight value passed to `W::new`.
( $( $x:expr ),* => $( $y:expr ),* ; $weight:expr) => {
{
fn semiring_new<W: Semiring>(v: W::Type) -> W {
W::new(v)
}
FstPath::new(
vec![$($x),*],
vec![$($y),*],
semiring_new($weight)
)
}
};
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::fst_impls::VectorFst;
    use crate::fst_traits::MutableFst;
    use crate::semirings::TropicalWeight;

    #[test]
    fn test_check_path_in_fst() -> Result<()> {
        // Three states. 0 is the start state; 2 is the only final state (final weight 3.2).
        let mut fst = VectorFst::<TropicalWeight>::new();
        fst.add_states(3);
        fst.set_start(0)?;
        fst.emplace_tr(0, 1, 2, 1.2, 1)?;
        fst.emplace_tr(0, 4, 6, 1.1, 1)?;
        fst.emplace_tr(1, 2, 3, 0.3, 2)?;
        fst.emplace_tr(1, 6, 7, 0.5, 2)?;
        fst.emplace_tr(0, 10, 12, 3.0, 2)?;
        fst.set_final(2, 3.2)?;

        let accepts = |ilabels: Vec<Label>, olabels: Vec<Label>, weight: TropicalWeight| {
            check_path_in_fst(&fst, &FstPath::new(ilabels, olabels, weight))
        };

        // The empty path ends in the start state, which is not final.
        assert!(!accepts(vec![], vec![], TropicalWeight::one()));
        // Ends in state 1, which is not final.
        assert!(!accepts(vec![1], vec![2], TropicalWeight::new(1.2)));
        // Right labels, but the weight leaves out the final weight of state 2.
        assert!(!accepts(vec![1, 2], vec![2, 3], TropicalWeight::new(1.5)));
        // 1.2 + 0.3 + final 3.2.
        assert!(accepts(vec![1, 2], vec![2, 3], TropicalWeight::new(4.7)));
        // The 0 -> 2 arc writes 12, not 10.
        assert!(!accepts(vec![10], vec![10], TropicalWeight::new(3.0)));
        // The 0 -> 2 arc reads 10, not 12.
        assert!(!accepts(vec![12], vec![12], TropicalWeight::new(6.2)));
        assert!(!accepts(vec![10], vec![10], TropicalWeight::new(6.2)));
        // 3.0 + final 3.2.
        assert!(accepts(vec![10], vec![12], TropicalWeight::new(6.2)));
        Ok(())
    }
}