use std::sync::Arc;
use anyhow::Result;
use crate::algorithms::compose::compose_filters::{ComposeFilter, ComposeFilterBuilder};
use crate::algorithms::compose::filter_states::FilterState;
use crate::algorithms::compose::lookahead_filters::lookahead_selector::Selector;
use crate::algorithms::compose::matchers::{IterItemMatcher, MatcherFlags};
use crate::algorithms::compose::matchers::{MatchType, Matcher, REQUIRE_PRIORITY};
use crate::algorithms::compose::{ComposeFstOpOptions, ComposeStateTuple};
use crate::algorithms::lazy_fst_revamp::{FstOp, StateTable};
use crate::fst_properties::mutable_properties::compose_properties;
use crate::fst_properties::FstProperties;
use crate::fst_traits::CoreFst;
use crate::semirings::Semiring;
use crate::{StateId, Tr, Trs, TrsVec, EPS_LABEL, NO_LABEL};
/// Operation computing the composition of `fst1` with `fst2` through the
/// `FstOp` interface (start state, transitions and final weights are each
/// computed by a dedicated method).
///
/// Every state of the composed FST corresponds to a `ComposeStateTuple`
/// (a state of `fst1`, a state of `fst2` and a filter state) registered in
/// `state_table`.
#[derive(Debug, Clone)]
pub struct ComposeFstOp<W: Semiring, CFB: ComposeFilterBuilder<W>> {
/// Builder from which a fresh compose filter is created for each computation.
compose_filter_builder: CFB,
/// Mapping between composed state ids and their `ComposeStateTuple`.
state_table: StateTable<ComposeStateTuple<<CFB::CF as ComposeFilter<W>>::FS>>,
/// Matching mode derived from both matchers when the operation is created.
match_type: MatchType,
/// Properties of the composed FST, computed when the operation is created.
properties: FstProperties,
/// Left operand of the composition.
fst1: Arc<<<CFB::CF as ComposeFilter<W>>::M1 as Matcher<W>>::F>,
/// Right operand of the composition.
fst2: Arc<<<CFB::CF as ComposeFilter<W>>::M2 as Matcher<W>>::F>,
}
impl<W: Semiring, CFB: ComposeFilterBuilder<W>> ComposeFstOp<W, CFB> {
/// Creates the composition operation of `fst1` with `fst2`.
///
/// The filter builder and the state table are taken from `opts` when they
/// are provided. Otherwise a filter builder is created from both FSTs and
/// the optional matchers stored in `opts`, and a new state table is used.
///
/// # Errors
///
/// Returns an error when the default filter builder cannot be created, when
/// the compose filter cannot be built, or when the match type of the
/// composition cannot be determined from the two matchers.
pub fn new(
    fst1: Arc<<<CFB::CF as ComposeFilter<W>>::M1 as Matcher<W>>::F>,
    fst2: Arc<<<CFB::CF as ComposeFilter<W>>::M2 as Matcher<W>>::F>,
    opts: ComposeFstOpOptions<
        CFB::M1,
        CFB::M2,
        CFB,
        StateTable<ComposeStateTuple<<CFB::CF as ComposeFilter<W>>::FS>>,
    >,
) -> Result<Self> {
    // A failure to create the default builder is reported to the caller
    // through the returned `Result` instead of panicking.
    let compose_filter_builder = match opts.filter_builder {
        Some(builder) => builder,
        None => CFB::new(
            Arc::clone(&fst1),
            Arc::clone(&fst2),
            opts.matcher1,
            opts.matcher2,
        )?,
    };

    // This filter instance is only used to derive the match type and the
    // properties; the computation methods build their own instances.
    let compose_filter = compose_filter_builder.build()?;
    let match_type = Self::match_type(compose_filter.matcher1(), compose_filter.matcher2())?;

    let fprops1 = fst1.properties();
    let fprops2 = fst2.properties();
    let cprops = compose_properties(fprops1, fprops2);
    let properties = compose_filter.properties(cprops);

    Ok(Self {
        compose_filter_builder,
        state_table: opts.state_table.unwrap_or_else(StateTable::new),
        match_type,
        properties,
        fst1,
        fst2,
    })
}
/// Determines the match type of the composition from the two matchers.
///
/// A matcher flagged with `REQUIRE_MATCH` must report the expected match
/// type when queried with `true` (output for the first matcher, input for
/// the second), otherwise an error is returned. The result is then chosen
/// from the match types reported with `false`, falling back to the ones
/// reported with `true` when neither side qualifies.
///
/// # Errors
///
/// Returns an error when a required match cannot be performed, when no side
/// can be matched, or when a matcher fails to report its match type.
fn match_type(
    matcher1: &<CFB::CF as ComposeFilter<W>>::M1,
    matcher2: &<CFB::CF as ComposeFilter<W>>::M2,
) -> Result<MatchType> {
    let first_requires_match = matcher1.flags().contains(MatcherFlags::REQUIRE_MATCH);
    if first_requires_match && matcher1.match_type(true)? != MatchType::MatchOutput {
        bail!("ComposeFst: 1st argument cannot perform required matching (sort?)")
    }

    let second_requires_match = matcher2.flags().contains(MatcherFlags::REQUIRE_MATCH);
    if second_requires_match && matcher2.match_type(true)? != MatchType::MatchInput {
        bail!("ComposeFst: 2nd argument cannot perform required matching (sort?)")
    }

    let first_on_output = matcher1.match_type(false)? == MatchType::MatchOutput;
    let second_on_input = matcher2.match_type(false)? == MatchType::MatchInput;
    match (first_on_output, second_on_input) {
        (true, true) => return Ok(MatchType::MatchBoth),
        (true, false) => return Ok(MatchType::MatchOutput),
        (false, true) => return Ok(MatchType::MatchInput),
        // Neither side qualifies: fall back to the `true` queries below.
        (false, false) => {}
    }

    if matcher1.match_type(true)? == MatchType::MatchOutput {
        return Ok(MatchType::MatchOutput);
    }
    if matcher2.match_type(true)? == MatchType::MatchInput {
        return Ok(MatchType::MatchInput);
    }
    bail!("ComposeFst: 1st argument cannot match on output labels and 2nd argument cannot match on input labels (sort?).")
}
/// Tells whether the composed state `(s1, s2)` must be expanded by matching
/// on the input side (`true`) or on the output side (`false`).
///
/// When the stored match type is neither `MatchInput` nor `MatchOutput`,
/// the decision is taken from the priorities reported by both matchers: a
/// side whose priority is `REQUIRE_PRIORITY` wins, otherwise the priorities
/// are compared.
///
/// # Errors
///
/// Returns an error when both matchers report `REQUIRE_PRIORITY` or when a
/// priority cannot be retrieved.
fn match_input(&self, s1: StateId, s2: StateId, compose_filter: &CFB::CF) -> Result<bool> {
    if self.match_type == MatchType::MatchInput {
        return Ok(true);
    }
    if self.match_type == MatchType::MatchOutput {
        return Ok(false);
    }

    let prio_first = compose_filter.matcher1().priority(s1)?;
    let prio_second = compose_filter.matcher2().priority(s2)?;
    let first_required = prio_first == REQUIRE_PRIORITY;
    let second_required = prio_second == REQUIRE_PRIORITY;

    match (first_required, second_required) {
        (true, true) => bail!("Both sides can't require match"),
        (true, false) => Ok(false),
        (false, true) => Ok(true),
        (false, false) => Ok(prio_first <= prio_second),
    }
}
/// Computes the outgoing transitions of a composed state.
///
/// A loop transition targeting `sb` with weight one is matched first; its
/// labels are `(EPS_LABEL, NO_LABEL)` when `match_input` is set and
/// `(NO_LABEL, EPS_LABEL)` otherwise. Every transition leaving `sb` in the
/// FST designated by `selector` is then matched against state `sa`.
///
/// # Errors
///
/// Returns an error when the transitions of `sb` cannot be retrieved or
/// when matching one of them fails.
fn ordered_expand(
    &self,
    sa: StateId,
    sb: StateId,
    match_input: bool,
    mut compose_filter: CFB::CF,
    selector: Selector,
) -> Result<TrsVec<W>> {
    let (loop_ilabel, loop_olabel) = if match_input {
        (EPS_LABEL, NO_LABEL)
    } else {
        (NO_LABEL, EPS_LABEL)
    };
    let tr_loop = Tr::new(loop_ilabel, loop_olabel, W::one(), sb);

    let mut composed = Vec::new();

    // The loop transition is handled before any real transition of `sb`,
    // whichever FST the selector designates.
    self.match_tr(
        sa,
        &tr_loop,
        match_input,
        &mut compose_filter,
        selector,
        &mut composed,
    )?;

    match selector {
        Selector::Fst1Matcher2 => {
            let outgoing = self.fst1.get_trs(sb)?;
            for tr in outgoing.trs() {
                self.match_tr(
                    sa,
                    tr,
                    match_input,
                    &mut compose_filter,
                    selector,
                    &mut composed,
                )?;
            }
        }
        Selector::Fst2Matcher1 => {
            let outgoing = self.fst2.get_trs(sb)?;
            for tr in outgoing.trs() {
                self.match_tr(
                    sa,
                    tr,
                    match_input,
                    &mut compose_filter,
                    selector,
                    &mut composed,
                )?;
            }
        }
    }

    Ok(TrsVec(Arc::new(composed)))
}
/// Builds the composed transition resulting from the pair `(arc1, arc2)`.
///
/// The result carries the input label of `arc1`, the output label of
/// `arc2` and the product of both weights. Its destination is the id given
/// by the state table to the tuple made of both next states and `fs`.
///
/// # Errors
///
/// Returns an error when the product of the weights fails.
fn add_tr(
    &self,
    mut arc1: Tr<W>,
    arc2: Tr<W>,
    fs: <CFB::CF as ComposeFilter<W>>::FS,
) -> Result<Tr<W>> {
    let destination = ComposeStateTuple {
        s1: arc1.nextstate,
        s2: arc2.nextstate,
        fs,
    };

    // The weights are combined before the destination is looked up, so a
    // failing product never reaches the state table.
    arc1.weight.times_assign(arc2.weight)?;
    let nextstate = self.state_table.find_id(destination);

    Ok(Tr::new(arc1.ilabel, arc2.olabel, arc1.weight, nextstate))
}
/// Combines `tr` with every item yielded by the matcher iterator `it` and
/// appends the accepted composed transitions to `trs`.
///
/// Each item is converted into a transition leaving `sa`, then the pair is
/// submitted to the compose filter: `tr` comes first when `match_input` is
/// set, second otherwise. Pairs for which the filter returns the "no state"
/// filter state are dropped.
///
/// # Errors
///
/// Returns an error when an item cannot be converted, when the filter
/// fails, or when the composed transition cannot be built.
fn match_tr_selected(
    &self,
    sa: StateId,
    tr: &Tr<W>,
    match_input: bool,
    compose_filter: &mut CFB::CF,
    it: impl Iterator<Item = IterItemMatcher<W>>,
    trs: &mut Vec<Tr<W>>,
) -> Result<()> {
    let match_type = match match_input {
        true => MatchType::MatchInput,
        false => MatchType::MatchOutput,
    };

    for item in it {
        let mut matched = item.into_tr(sa, match_type)?;
        let mut current = tr.clone();

        // The argument order given to the filter is also the order used to
        // build the composed transition.
        let (fs, first, second) = if match_input {
            let fs = compose_filter.filter_tr(&mut current, &mut matched)?;
            (fs, current, matched)
        } else {
            let fs = compose_filter.filter_tr(&mut matched, &mut current)?;
            (fs, matched, current)
        };

        if fs != <CFB::CF as ComposeFilter<W>>::FS::new_no_state() {
            trs.push(self.add_tr(first, second, fs)?);
        }
    }

    Ok(())
}
/// Matches `tr` against state `sa` using the matcher designated by
/// `selector` and appends the resulting composed transitions to `trs`.
///
/// The label looked up in the matcher is the output label of `tr` when
/// `match_input` is set, its input label otherwise.
///
/// # Errors
///
/// Returns an error when the matcher iterator cannot be created or when
/// processing the matches fails.
fn match_tr(
    &self,
    sa: StateId,
    tr: &Tr<W>,
    match_input: bool,
    compose_filter: &mut CFB::CF,
    selector: Selector,
    trs: &mut Vec<Tr<W>>,
) -> Result<()> {
    let label = match match_input {
        true => tr.olabel,
        false => tr.ilabel,
    };

    match selector {
        Selector::Fst2Matcher1 => {
            let candidates = compose_filter.matcher1().iter(sa, label)?;
            self.match_tr_selected(sa, tr, match_input, compose_filter, candidates, trs)
        }
        Selector::Fst1Matcher2 => {
            let candidates = compose_filter.matcher2().iter(sa, label)?;
            self.match_tr_selected(sa, tr, match_input, compose_filter, candidates, trs)
        }
    }
}
}
impl<W: Semiring, CFB: ComposeFilterBuilder<W>> FstOp<W> for ComposeFstOp<W, CFB> {
/// Computes the start state of the composed FST.
///
/// Returns `None` when either operand has no start state. Otherwise the
/// tuple made of both start states and the start filter state is looked up
/// in the state table and its id is returned.
///
/// # Errors
///
/// Returns an error when the compose filter cannot be built.
fn compute_start(&self) -> Result<Option<usize>> {
    let compose_filter = self.compose_filter_builder.build()?;

    let s1 = match self.fst1.start() {
        Some(start) => start,
        None => return Ok(None),
    };
    let s2 = match self.fst2.start() {
        Some(start) => start,
        None => return Ok(None),
    };

    let start_tuple = ComposeStateTuple {
        s1,
        s2,
        fs: compose_filter.start(),
    };
    Ok(Some(self.state_table.find_id(start_tuple)))
}
/// Computes the outgoing transitions of the composed state `state`.
///
/// The tuple behind `state` is fetched from the state table and a fresh
/// compose filter is positioned on it. Depending on `match_input`, the
/// expansion either iterates over the first FST while matching in the
/// second one, or the other way around.
///
/// # Errors
///
/// Returns an error when the filter cannot be built or positioned, or when
/// the expansion fails.
fn compute_trs(&self, state: usize) -> Result<TrsVec<W>> {
    let tuple = self.state_table.find_tuple(state);
    let (s1, s2) = (tuple.s1, tuple.s2);

    let mut compose_filter = self.compose_filter_builder.build()?;
    compose_filter.set_state(s1, s2, &tuple.fs)?;

    let match_input = self.match_input(s1, s2, &compose_filter)?;
    let (sa, sb, selector) = if match_input {
        (s2, s1, Selector::Fst1Matcher2)
    } else {
        (s1, s2, Selector::Fst2Matcher1)
    };

    self.ordered_expand(sa, sb, match_input, compose_filter, selector)
}
/// Computes the final weight of the composed state `state`.
///
/// Returns `None` when either underlying state has no final weight, or when
/// the product of the two filtered final weights is zero.
///
/// # Errors
///
/// Returns an error when the filter cannot be built or positioned, when a
/// final weight cannot be retrieved, or when filtering or multiplying the
/// weights fails.
fn compute_final_weight(&self, state: usize) -> Result<Option<W>> {
    let tuple = self.state_table.find_tuple(state);
    let mut compose_filter = self.compose_filter_builder.build()?;
    let (s1, s2) = (tuple.s1, tuple.s2);

    let mut weight1 = match compose_filter.matcher1().final_weight(s1)? {
        Some(weight) => weight,
        None => return Ok(None),
    };
    let mut weight2 = match compose_filter.matcher2().final_weight(s2)? {
        Some(weight) => weight,
        None => return Ok(None),
    };

    // The filter may adjust both weights before they are multiplied.
    compose_filter.set_state(s1, s2, &tuple.fs)?;
    compose_filter.filter_final(&mut weight1, &mut weight2)?;
    weight1.times_assign(&weight2)?;

    Ok(if weight1.is_zero() {
        None
    } else {
        Some(weight1)
    })
}
/// Returns the properties stored in this operation; they are computed once
/// when the operation is created and are not updated afterwards by the
/// methods of this type.
fn properties(&self) -> FstProperties {
self.properties
}
}