use std::marker::PhantomData;
use anyhow::Result;
use unsafe_unwrap::UnsafeUnwrap;
use crate::algorithms::dfs_visit::{dfs_visit, Visitor};
use crate::algorithms::tr_filters::AnyTrFilter;
use crate::fst_properties::FstProperties;
use crate::fst_traits::Fst;
use crate::fst_traits::{ExpandedFst, MutableFst};
use crate::semirings::Semiring;
use crate::StateId;
use crate::Tr;
use crate::NO_STATE_ID;
/// Trims `fst` in place so that only states that are both accessible and
/// coaccessible remain.
///
/// A single DFS over `fst` (using [`ConnectVisitor`]) marks each state as
/// accessible (visited from the start state) and/or coaccessible (a final
/// state can be reached from it). Every state missing either mark is passed to
/// `MutableFst::del_states`. Afterwards the `ACCESSIBLE` and `COACCESSIBLE`
/// property bits are both set on `fst`.
///
/// If `fst` has no start state, no state is marked accessible, so every state
/// is deleted.
///
/// # Errors
///
/// Propagates any error returned by `MutableFst::del_states`.
pub fn connect<W: Semiring, F: ExpandedFst<W> + MutableFst<W>>(fst: &mut F) -> Result<()> {
    let mut visitor = ConnectVisitor::new(fst);
    dfs_visit(fst, &mut visitor, &AnyTrFilter {}, false);

    // A state is useless as soon as it lacks either property.
    let dstates: Vec<StateId> = visitor
        .access
        .iter()
        .zip(&visitor.coaccess)
        .enumerate()
        .filter_map(|(s, (&accessible, &coaccessible))| {
            if accessible && coaccessible {
                None
            } else {
                Some(s as StateId)
            }
        })
        .collect();

    fst.del_states(dstates)?;
    fst.set_properties_with_mask(
        FstProperties::ACCESSIBLE | FstProperties::COACCESSIBLE,
        FstProperties::ACCESSIBLE | FstProperties::COACCESSIBLE,
    );
    Ok(())
}
/// DFS visitor used by `connect` to compute, for every state, whether it is
/// accessible and whether it is coaccessible.
///
/// Coaccessibility is propagated through strongly connected components using
/// Tarjan-style bookkeeping (`dfnumber`, `lowlink`, `onstack`, `scc_stack`).
struct ConnectVisitor<'a, W: Semiring, F: Fst<W>> {
    /// FST being visited; used to query final states.
    fst: &'a F,
    /// Start state of `fst`, or `NO_STATE_ID` when it has none.
    start: StateId,

    /// `access[s]` is set when the DFS tree that discovered `s` is rooted at `start`.
    access: Vec<bool>,
    /// `coaccess[s]` is set when `s` is final, or reaches a coaccessible state.
    coaccess: Vec<bool>,

    /// Number of states discovered so far (the next DFS number to hand out).
    nstates: usize,
    /// DFS discovery index of each state; -1 until the state is discovered.
    dfnumber: Vec<i32>,
    /// Smallest DFS number reachable from each state's subtree; -1 until discovered.
    lowlink: Vec<i32>,
    /// Whether each state is currently on `scc_stack`.
    onstack: Vec<bool>,
    /// Stack of discovered states whose SCC has not been completed yet.
    scc_stack: Vec<StateId>,

    /// Ties the weight type `W` to the visitor without storing a value of it.
    w: PhantomData<W>,
}
impl<'a, W: Semiring, F: 'a + ExpandedFst<W>> ConnectVisitor<'a, W, F> {
    /// Builds a visitor whose per-state tables are sized to `fst.num_states()`.
    ///
    /// Every state starts neither accessible nor coaccessible and off the SCC
    /// stack. Its DFS number and lowlink are initialised to -1. The start state
    /// is recorded, falling back to `NO_STATE_ID` when `fst` has none.
    pub fn new(fst: &'a F) -> Self {
        let num_states = fst.num_states();
        let start = match fst.start() {
            Some(state) => state,
            None => NO_STATE_ID,
        };
        Self {
            fst,
            start,
            access: vec![false; num_states],
            coaccess: vec![false; num_states],
            nstates: 0,
            dfnumber: vec![-1; num_states],
            lowlink: vec![-1; num_states],
            onstack: vec![false; num_states],
            scc_stack: Vec::new(),
            w: PhantomData,
        }
    }
}
impl<'a, W: Semiring, F: 'a + ExpandedFst<W>> Visitor<'a, W, F> for ConnectVisitor<'a, W, F> {
    fn init_visit(&mut self, _fst: &'a F) {}

    /// Discovers `s`: assigns its DFS number and pushes it on the SCC stack.
    ///
    /// `s` is marked accessible iff the current DFS tree is rooted at the
    /// start state.
    fn init_state(&mut self, s: StateId, root: StateId) -> bool {
        self.scc_stack.push(s);
        let s = s as usize;
        // NOTE(review): `as i32` would truncate for more than i32::MAX states.
        let number = self.nstates as i32;
        self.dfnumber[s] = number;
        self.lowlink[s] = number;
        self.onstack[s] = true;
        self.access[s] = root == self.start;
        self.nstates += 1;
        true
    }

    fn tree_tr(&mut self, _s: StateId, _tr: &Tr<W>) -> bool {
        true
    }

    /// Back edge `s -> t`: `t` is an ancestor of `s`, so `s` may lower its
    /// lowlink to `t`'s DFS number. `s` also inherits `t`'s coaccessibility.
    fn back_tr(&mut self, s: StateId, tr: &Tr<W>) -> bool {
        let t = tr.nextstate as usize;
        let s = s as usize;
        self.lowlink[s] = self.lowlink[s].min(self.dfnumber[t]);
        self.coaccess[s] |= self.coaccess[t];
        true
    }

    /// Forward or cross edge `s -> t`. Only a cross edge into a state still on
    /// the SCC stack (i.e. one in an SCC not yet completed) can lower the
    /// lowlink of `s`. Coaccessibility always flows back from `t` to `s`.
    fn forward_or_cross_tr(&mut self, s: StateId, tr: &Tr<W>) -> bool {
        let t = tr.nextstate as usize;
        let s = s as usize;
        if self.dfnumber[t] < self.dfnumber[s] && self.onstack[t] {
            self.lowlink[s] = self.lowlink[s].min(self.dfnumber[t]);
        }
        self.coaccess[s] |= self.coaccess[t];
        true
    }

    /// Finishes `s`. This marks final states coaccessible, pops a completed
    /// SCC when `s` is its root, and propagates coaccessibility and lowlink to
    /// the DFS parent.
    #[inline]
    fn finish_state(&mut self, s: StateId, parent: Option<StateId>, _tr: Option<&Tr<W>>) {
        // SAFETY: `s` is reported by the DFS over the FST this visitor was built
        // from (see `connect`), so it is a valid state id of `self.fst`.
        if unsafe { self.fst.is_final_unchecked(s) } {
            self.coaccess[s as usize] = true;
        }
        let s = s as usize;

        if self.dfnumber[s] == self.lowlink[s] {
            // `s` roots an SCC made of `s` and every state above it on the stack.
            let root_pos = self
                .scc_stack
                .iter()
                .rposition(|&t| t as usize == s)
                .expect("SCC root must still be on the SCC stack");
            // The whole SCC is coaccessible as soon as any member is.
            let scc_coaccess = self.scc_stack[root_pos..]
                .iter()
                .any(|&t| self.coaccess[t as usize]);
            loop {
                // SAFETY: `s` is at `root_pos` and is popped last, so the stack is
                // non-empty on every iteration of this loop.
                let t = unsafe { self.scc_stack.pop().unsafe_unwrap() } as usize;
                if scc_coaccess {
                    self.coaccess[t] = true;
                }
                self.onstack[t] = false;
                if t == s {
                    break;
                }
            }
        }

        if let Some(p) = parent {
            let p = p as usize;
            self.coaccess[p] |= self.coaccess[s];
            self.lowlink[p] = self.lowlink[p].min(self.lowlink[s]);
        }
    }

    #[inline]
    fn finish_visit(&mut self) {}
}
#[cfg(test)]
mod tests {
    use ::proptest::prelude::*;

    use crate::fst_properties::FstProperties;
    use crate::prelude::*;

    use super::*;

    proptest! {
        #[test]
        fn test_connect_proptest(mut fst in any::<VectorFst::<TropicalWeight>>()) {
            connect(&mut fst).unwrap();
            // `connect` sets BOTH flags. `intersects` would accept either one
            // alone, so `contains` is required to check the full contract.
            prop_assert!(fst.properties().contains(
                FstProperties::ACCESSIBLE | FstProperties::COACCESSIBLE
            ));
        }
    }

    /// An unreachable state and a reachable dead-end state must both be removed.
    #[test]
    fn test_connect_removes_useless_states() -> anyhow::Result<()> {
        let mut fst = VectorFst::<TropicalWeight>::new();
        let s0 = fst.add_state();
        let s1 = fst.add_state();
        let s2 = fst.add_state(); // not reachable from the start state
        let s3 = fst.add_state(); // reachable, but cannot reach a final state
        fst.set_start(s0)?;
        fst.set_final(s1, TropicalWeight::one())?;
        fst.add_tr(s0, Tr::new(1, 1, TropicalWeight::one(), s1))?;
        fst.add_tr(s0, Tr::new(2, 2, TropicalWeight::one(), s3))?;
        fst.add_tr(s2, Tr::new(3, 3, TropicalWeight::one(), s1))?;

        connect(&mut fst)?;

        assert_eq!(fst.num_states(), 2);
        assert!(fst.start().is_some());
        Ok(())
    }
}