use failure::Fallible;
use crate::algorithms::ReplaceFst;
use crate::arc::Arc;
use crate::fst_traits::{
AllocableFst, ArcIterator, CoreFst, ExpandedFst, Fst, MutableFst, StateIterator,
};
use crate::semirings::Semiring;
use crate::{SymbolTable, EPS_LABEL};
use std::rc::Rc;
/// Concatenates `fst_2` onto the end of `fst_1`, modifying `fst_1` in place.
///
/// Every state of `fst_2` is appended to `fst_1`, with its final weight and its
/// arcs. Arc targets are shifted by the original `fst_1.num_states()` so that
/// they point at the copied states. Then each final state of the original
/// `fst_1` loses its final weight. In exchange it gains an epsilon arc that
/// carries that weight to the (shifted) start state of `fst_2`.
///
/// Special cases:
/// * If `fst_1` has no start state, it is left untouched.
/// * If `fst_2` has no start state, its states are still appended and the
///   original final weights of `fst_1` are still removed, but no connecting
///   arcs are added.
///
/// # Errors
///
/// The current implementation always returns `Ok(())`. The `Fallible` return
/// type is part of the public signature.
pub fn concat<W, F1, F2>(fst_1: &mut F1, fst_2: &F2) -> Fallible<()>
where
    W: Semiring,
    F1: ExpandedFst<W = W> + MutableFst<W = W> + AllocableFst<W = W>,
    F2: ExpandedFst<W = W>,
{
    // Without a start state in `fst_1` there is nothing to concatenate onto.
    if fst_1.start().is_none() {
        return Ok(());
    }

    let numstates1 = fst_1.num_states();
    let numstates2 = fst_2.num_states();

    // Copy every state of `fst_2` into `fst_1`. Arc targets are offset by
    // `numstates1` so they refer to the copied states.
    fst_1.reserve_states(numstates2);
    for s2 in 0..numstates2 {
        let s1 = fst_1.add_state();
        // SAFETY: `s2 < fst_2.num_states()` by the loop range, and `s1` was
        // just returned by `add_state`, so both ids are valid.
        if let Some(final_weight) = unsafe { fst_2.final_weight_unchecked(s2) } {
            // SAFETY: `s1` is a freshly added state of `fst_1`.
            unsafe { fst_1.set_final_unchecked(s1, final_weight.clone()) };
        }
        // SAFETY: same ids as above.
        unsafe { fst_1.reserve_arcs_unchecked(s1, fst_2.num_arcs_unchecked(s2)) };
        // SAFETY: `s2` is a valid state of `fst_2` (loop range).
        for arc in unsafe { fst_2.arcs_iter_unchecked(s2) } {
            let mut new_arc = arc.clone();
            new_arc.nextstate += numstates1;
            // SAFETY: `s1` is a freshly added state of `fst_1`.
            unsafe { fst_1.add_arc_unchecked(s1, new_arc) };
        }
    }

    // Location of `fst_2`'s start state inside `fst_1`. It is loop-invariant,
    // so it is computed once here.
    let shifted_start2 = fst_2.start().map(|start2| start2 + numstates1);

    // Redirect every final state of the original `fst_1` into the copy of `fst_2`.
    for s1 in 0..numstates1 {
        // SAFETY: `s1 < numstates1`. Since `numstates1` was read, this function
        // has only appended states, so `s1` is still a valid id.
        let final_weight = match unsafe { fst_1.final_weight_unchecked(s1) } {
            Some(weight) => weight,
            None => continue,
        };
        if let Some(start2) = shifted_start2 {
            // Clone before mutating `fst_1`: `final_weight` borrows from it.
            let weight = final_weight.clone();
            // SAFETY: `s1` is valid (see above).
            unsafe {
                fst_1.add_arc_unchecked(s1, Arc::new(EPS_LABEL, EPS_LABEL, weight, start2))
            };
        }
        // SAFETY: `s1` is valid (see above).
        unsafe { fst_1.delete_final_weight_unchecked(s1) };
    }
    Ok(())
}
/// Concatenation of two FSTs, represented as a [`ReplaceFst`].
///
/// The wrapped `ReplaceFst` is built by [`ConcatFst::new`] from a small
/// three-state root FST whose two arcs are nonterminals referring to the first
/// and the second operand.
#[derive(Debug, Clone, PartialEq)]
pub struct ConcatFst<F>(ReplaceFst<F, F>)
where
    F: Fst + 'static,
    F::W: 'static;
impl<F: Fst + MutableFst + AllocableFst> ConcatFst<F>
where
    F::W: 'static,
{
    /// Builds the concatenation of `fst1` followed by `fst2`.
    ///
    /// A root FST with states `0 -> 1 -> 2` is created:
    /// * state 0 is the start state;
    /// * state 2 is final with weight `one()`;
    /// * the arc `0 -> 1` carries output label `FIRST_NONTERMINAL`, which maps
    ///   to `fst1`;
    /// * the arc `1 -> 2` carries output label `SECOND_NONTERMINAL`, which maps
    ///   to `fst2`.
    ///
    /// The root copies `fst1`'s input/output symbol tables when present. The
    /// three FSTs are then handed to `ReplaceFst::new` with root label `0`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ReplaceFst::new`.
    pub fn new(fst1: F, fst2: F) -> Fallible<Self> {
        // Nonterminal labels for the two operands, taken from the top of the
        // `usize` range (presumably to stay clear of ordinary labels).
        const FIRST_NONTERMINAL: usize = std::usize::MAX;
        const SECOND_NONTERMINAL: usize = std::usize::MAX - 1;

        let mut root = F::new();
        root.add_states(3);
        // SAFETY: states 0 and 2 were created by `add_states(3)` just above.
        unsafe {
            root.set_start_unchecked(0);
            root.set_final_unchecked(2, F::W::one());
        }

        if let Some(input_table) = fst1.input_symbols() {
            root.set_input_symbols(input_table);
        }
        if let Some(output_table) = fst1.output_symbols() {
            root.set_output_symbols(output_table);
        }

        // SAFETY: states 0, 1 and 2 all exist (created by `add_states(3)`).
        unsafe {
            root.add_arc_unchecked(0, Arc::new(EPS_LABEL, FIRST_NONTERMINAL, F::W::one(), 1));
            root.add_arc_unchecked(1, Arc::new(EPS_LABEL, SECOND_NONTERMINAL, F::W::one(), 2));
        }

        let fst_tuples = vec![
            (0, root),
            (FIRST_NONTERMINAL, fst1),
            (SECOND_NONTERMINAL, fst2),
        ];
        let replace_fst = ReplaceFst::new(fst_tuples, 0, false)?;
        Ok(ConcatFst(replace_fst))
    }
}
impl<F: Fst> CoreFst for ConcatFst<F>
where
F::W: 'static,
{
type W = F::W;
fn start(&self) -> Option<usize> {
self.0.start()
}
fn final_weight(&self, state_id: usize) -> Fallible<Option<&Self::W>> {
self.0.final_weight(state_id)
}
unsafe fn final_weight_unchecked(&self, state_id: usize) -> Option<&Self::W> {
self.0.final_weight_unchecked(state_id)
}
fn num_arcs(&self, s: usize) -> Fallible<usize> {
self.0.num_arcs(s)
}
unsafe fn num_arcs_unchecked(&self, s: usize) -> usize {
self.0.num_arcs_unchecked(s)
}
}
impl<'a, F: Fst + 'static> StateIterator<'a> for ConcatFst<F>
where
F::W: 'static,
{
type Iter = <ReplaceFst<F, F> as StateIterator<'a>>::Iter;
fn states_iter(&'a self) -> Self::Iter {
self.0.states_iter()
}
}
impl<'a, F: Fst + 'static> ArcIterator<'a> for ConcatFst<F>
where
F::W: 'static,
{
type Iter = <ReplaceFst<F, F> as ArcIterator<'a>>::Iter;
fn arcs_iter(&'a self, state_id: usize) -> Fallible<Self::Iter> {
self.0.arcs_iter(state_id)
}
unsafe fn arcs_iter_unchecked(&'a self, state_id: usize) -> Self::Iter {
self.0.arcs_iter_unchecked(state_id)
}
}
impl<F: Fst + 'static> Fst for ConcatFst<F>
where
F::W: 'static,
{
fn input_symbols(&self) -> Option<Rc<SymbolTable>> {
self.0.input_symbols()
}
fn output_symbols(&self) -> Option<Rc<SymbolTable>> {
self.0.output_symbols()
}
}