use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use crate::algorithms::compose::{IntervalSet, StateReachable};
use crate::algorithms::tr_compares::{ILabelCompare, OLabelCompare};
use crate::algorithms::{fst_convert_from_ref, tr_sort};
use crate::fst_impls::VectorFst;
use crate::fst_properties::FstProperties;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, MutableFst};
use crate::semirings::Semiring;
use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED};
/// Tables backing label-reachability queries; held behind an `Arc` by
/// `LabelReachable`.
#[derive(Debug, Clone, PartialEq)]
pub struct LabelReachableData {
/// `true` when the tables are built from input labels, `false` when they are
/// built from output labels.
reach_input: bool,
/// Index associated with the `NO_LABEL` key (the redirected final weights);
/// stays `NO_LABEL` until `LabelReachable::find_intervals` assigns it.
final_label: Label,
/// Maps an original label to the index it was assigned.
label2index: HashMap<Label, Label>,
/// One interval set per state, looked up by state id in `interval_set`.
interval_sets: Vec<IntervalSet>,
}
impl LabelReachableData {
pub fn new(reach_input: bool) -> Self {
Self {
reach_input,
final_label: NO_LABEL,
label2index: HashMap::new(),
interval_sets: Vec::new(),
}
}
/// Returns the interval set stored for state `s`.
///
/// # Errors
///
/// Fails with "Missing state" when no interval set is stored at index `s`.
pub fn interval_set(&self, s: StateId) -> Result<&IntervalSet> {
    match self.interval_sets.get(s) {
        Some(intervals) => Ok(intervals),
        None => Err(format_err!("Missing state {}", s)),
    }
}
/// Returns the stored final label: `NO_LABEL` after `new`, or the index that
/// `LabelReachable::find_intervals` recorded for the `NO_LABEL` key.
pub fn final_label(&self) -> Label {
self.final_label
}
/// Borrows the map from original labels to their assigned indices.
pub fn label2index(&self) -> &HashMap<Label, Label> {
&self.label2index
}
/// Returns the `reach_input` flag this value was created with.
pub fn reach_input(&self) -> bool {
self.reach_input
}
pub fn relabel(&mut self, label: Label) -> Label {
if label == EPS_LABEL {
return EPS_LABEL;
}
let n = self.label2index.len();
*self.label2index.entry(label).or_insert_with(|| n + 1)
}
/// Relabels every transition of `fst` through [`Self::relabel`].
///
/// When `relabel_input` is `true` the input labels are rewritten, the
/// transitions are sorted with `ILabelCompare` and `take_input_symbols` is
/// called on the FST; otherwise the same is done on the output side with
/// `OLabelCompare` and `take_output_symbols`.
///
/// Always returns `Ok(())` on the paths visible here.
pub fn relabel_fst<W: Semiring, F: MutableFst<W>>(
    &mut self,
    fst: &mut F,
    relabel_input: bool,
) -> Result<()> {
    let num_states = fst.num_states();
    for state in 0..num_states {
        // SAFETY: `state` is below `num_states()` and every `pos` is below the
        // length reported by the transition iterator itself.
        unsafe {
            let mut trs = fst.tr_iter_unchecked_mut(state);
            for pos in 0..trs.len() {
                if relabel_input {
                    let relabeled = self.relabel(trs.get_unchecked(pos).ilabel);
                    trs.set_ilabel_unchecked(pos, relabeled);
                } else {
                    let relabeled = self.relabel(trs.get_unchecked(pos).olabel);
                    trs.set_olabel_unchecked(pos, relabeled);
                }
            }
        }
    }

    // Restore label ordering on the side that was just rewritten.
    if relabel_input {
        tr_sort(fst, ILabelCompare {});
        fst.take_input_symbols();
    } else {
        tr_sort(fst, OLabelCompare {});
        fst.take_output_symbols();
    }
    Ok(())
}
/// Lists the `(original label, assigned index)` pairs of the label map.
///
/// Entries whose index equals the final label are left out. When
/// `avoid_collisions` is set, every value `i` in `1..=map size` that is either
/// not a key of the map or is mapped to the final label is additionally paired
/// with `map size + 1`.
pub fn relabel_pairs(&self, avoid_collisions: bool) -> Vec<(Label, Label)> {
    let mut pairs: Vec<(Label, Label)> = self
        .label2index
        .iter()
        .filter(|(_, index)| **index != self.final_label)
        .map(|(label, index)| (*label, *index))
        .collect();

    if avoid_collisions {
        let table_size = self.label2index.len();
        let spare_label = table_size + 1;
        pairs.extend(
            (1..=table_size)
                .filter(|candidate| {
                    self.label2index
                        .get(candidate)
                        .map_or(true, |index| *index == self.final_label)
                })
                .map(|candidate| (candidate, spare_label)),
        );
    }

    pairs
}
}
/// Answers label-reachability queries against shared `LabelReachableData`.
#[derive(Debug, Clone, PartialEq)]
pub struct LabelReachable {
/// Shared reachability tables.
data: Arc<LabelReachableData>,
/// Set by `reach_init`: when `true`, `reach` and `lower_bound` compare the
/// input labels of the inspected transitions, otherwise their output labels.
reach_fst_input: bool,
}
impl LabelReachable {
/// Computes the reachability tables for `fst` and wraps them.
///
/// `reach_input` is forwarded to [`Self::compute_data`]. The returned value
/// starts with `reach_fst_input` set to `false`.
///
/// # Errors
///
/// Propagates any error from [`Self::compute_data`].
pub fn new<W: Semiring, F: Fst<W>>(fst: &F, reach_input: bool) -> Result<Self> {
    Self::compute_data(fst, reach_input).map(|tables| Self {
        data: Arc::new(tables),
        reach_fst_input: false,
    })
}
/// Builds the `LabelReachableData` for `fst`.
///
/// The FST is first copied into a `VectorFst`, the copy is rewritten by
/// `transform_fst`, its `ACYCLIC` property is computed, and `find_intervals`
/// then fills the tables. The state count handed to `find_intervals` is the
/// one measured before the rewrite.
///
/// # Errors
///
/// Propagates failures from the property computation and from
/// `find_intervals`.
pub fn compute_data<W: Semiring, F: Fst<W>>(
    fst: &F,
    reach_input: bool,
) -> Result<LabelReachableData> {
    let mut work_fst: VectorFst<_> = fst_convert_from_ref(fst);
    // Must be captured before `transform_fst` appends extra states.
    let original_state_count = work_fst.num_states();

    let mut label2state = HashMap::new();
    let mut tables = LabelReachableData::new(reach_input);

    Self::transform_fst(&mut work_fst, &mut tables, &mut label2state);
    work_fst.compute_and_update_properties(FstProperties::ACYCLIC)?;
    Self::find_intervals(
        &work_fst,
        original_state_count,
        &mut tables,
        &mut label2state,
    )?;

    Ok(tables)
}
pub fn new_from_data(data: Arc<LabelReachableData>) -> Self {
Self {
data,
reach_fst_input: false,
}
}
/// Borrows the shared handle to the reachability tables.
pub fn data(&self) -> &Arc<LabelReachableData> {
&self.data
}
/// Returns the `reach_input` flag stored in the underlying tables.
pub fn reach_input(&self) -> bool {
    self.data.reach_input()
}
/// Rewrites `fst` in place so that every label is represented by a dedicated
/// destination state.
///
/// For each original state:
/// - a transition whose selected label (input label when `data.reach_input`
///   is set, output label otherwise) is not `EPS_LABEL` is redirected to the
///   state assigned to that label; assignments are recorded in `label2state`.
///   Transitions labelled `EPS_LABEL` keep their destination.
/// - a non-zero final weight is replaced by a `NO_LABEL`/`NO_LABEL` transition
///   carrying that weight to the state assigned to the `NO_LABEL` key, and the
///   final weight is deleted.
///
/// Afterwards the assigned states are actually added to `fst` with final
/// weight `W::one()`, and a new start state is added with a `0`/`0` transition
/// of weight `W::one()` to every other state that has no incoming transition.
fn transform_fst<W: Semiring>(
fst: &mut VectorFst<W>,
data: &mut LabelReachableData,
label2state: &mut HashMap<Label, StateId>,
) {
// `ins`: number of states before the rewrite; `ons`: next state id to hand out.
let ins = fst.num_states();
let mut ons = ins;
// Incoming-transition count per state, grown as new state ids are handed out.
let mut indeg = vec![0; ins];
for s in 0..ins {
let mut it_tr = unsafe { fst.tr_iter_unchecked_mut(s) };
for idx_tr in 0..it_tr.len() {
let tr = unsafe { it_tr.get_unchecked(idx_tr) };
let label = if data.reach_input {
tr.ilabel
} else {
tr.olabel
};
// Labelled transitions go to the state of their label (allocated on first
// sight); epsilon transitions keep their original destination.
let nextstate = if label != EPS_LABEL {
match label2state.entry(label) {
Entry::Vacant(e) => {
let v = *e.insert(ons);
indeg.push(0);
ons += 1;
v
}
Entry::Occupied(e) => *e.get(),
}
} else {
tr.nextstate
};
indeg[nextstate] += 1;
unsafe { it_tr.set_nextstate_unchecked(idx_tr, nextstate) };
}
// Turn a non-zero final weight into a transition to the `NO_LABEL` state.
if let Some(final_weight) = unsafe { fst.final_weight_unchecked(s) } {
if !final_weight.is_zero() {
let nextstate = match label2state.entry(NO_LABEL) {
Entry::Vacant(e) => {
let v = *e.insert(ons);
indeg.push(0);
ons += 1;
v
}
Entry::Occupied(e) => *e.get(),
};
unsafe {
fst.add_tr_unchecked(
s,
Tr::new(NO_LABEL, NO_LABEL, final_weight, nextstate),
)
};
indeg[nextstate] += 1;
unsafe { fst.delete_final_weight_unchecked(s) }
}
}
}
// Materialise the state ids handed out above; each one is final with weight one.
while fst.num_states() < ons {
let s = fst.add_state();
unsafe { fst.set_final_unchecked(s, W::one()) };
}
// New start state, connected to every state without incoming transitions.
let start = fst.add_state();
unsafe { fst.set_start_unchecked(start) };
for s in 0..start {
if indeg[s] == 0 {
unsafe { fst.add_tr_unchecked(start, Tr::new(0, 0, W::one(), s)) };
}
}
}
/// Fills `data` from a state-reachability analysis of the transformed `fst`.
///
/// The interval sets produced by `StateReachable` replace
/// `data.interval_sets`, which is then resized to exactly `ins` entries
/// (padding with default sets if needed). Every `(label, state)` pair of
/// `label2state` is converted into a `(label, index)` entry of
/// `data.label2index` through the analysis' `state2index` table; the index
/// found for `NO_LABEL` also becomes `data.final_label`. `label2state` is
/// emptied before returning.
///
/// # Errors
///
/// Propagates any error from `StateReachable::new`.
fn find_intervals<W: Semiring>(
    fst: &VectorFst<W>,
    ins: StateId,
    data: &mut LabelReachableData,
    label2state: &mut HashMap<Label, StateId>,
) -> Result<()> {
    let state_reachable = StateReachable::new(fst)?;
    let state2index = &state_reachable.state2index;

    data.interval_sets = state_reachable.isets;
    data.interval_sets.resize_with(ins, IntervalSet::default);

    for (&label, &state) in label2state.iter() {
        let index = state2index[state];
        data.label2index.insert(label, index);
        if label == NO_LABEL {
            data.final_label = index;
        }
    }
    label2state.clear();

    Ok(())
}
/// Selects which side of the queried FST's transitions is compared and checks
/// that `fst` is sorted on that side.
///
/// `reach_input` is stored in `reach_fst_input` before the check, so it is
/// updated even when the check fails.
///
/// # Errors
///
/// Propagates errors from `properties_check`, and fails when the required
/// property (`I_LABEL_SORTED` for input, `O_LABEL_SORTED` for output) is not
/// reported for `fst`.
pub fn reach_init<W: Semiring, F: ExpandedFst<W>>(
    &mut self,
    fst: &Arc<F>,
    reach_input: bool,
) -> Result<()> {
    self.reach_fst_input = reach_input;

    let required = if reach_input {
        FstProperties::I_LABEL_SORTED
    } else {
        FstProperties::O_LABEL_SORTED
    };

    let is_sorted = fst.properties_check(required)?.contains(required);
    if is_sorted {
        return Ok(());
    }
    bail!("LabelReachable::ReachInit: Fst is not sorted")
}
/// Tells whether `label` belongs to the interval set of `current_state`.
///
/// `EPS_LABEL` is never reachable: it yields `Ok(false)` without looking the
/// state up.
///
/// # Errors
///
/// Fails when `label` is not `EPS_LABEL` and no interval set is stored for
/// `current_state`.
pub fn reach_label(&self, current_state: StateId, label: Label) -> Result<bool> {
    if label != EPS_LABEL {
        let intervals = self.data.interval_set(current_state)?;
        Ok(intervals.member(label))
    } else {
        Ok(false)
    }
}
/// Tells whether the final label belongs to the interval set of
/// `current_state`.
///
/// # Errors
///
/// Fails when no interval set is stored for `current_state`.
pub fn reach_final(&self, current_state: StateId) -> Result<bool> {
    let intervals = self.data.interval_set(current_state)?;
    let final_label = self.data.final_label();
    Ok(intervals.member(final_label))
}
pub fn reach<'a, W: Semiring + 'a, T: Trs<W>>(
&self,
current_state: StateId,
trs: T,
aiter_begin: usize,
aiter_end: usize,
compute_weight: bool,
) -> Result<Option<(usize, usize, W)>> {
let mut reach_begin = UNASSIGNED;
let mut reach_end = UNASSIGNED;
let mut reach_weight = W::zero();
let interval_set = self.data.interval_set(current_state)?;
let trs_slice = trs.trs();
if 2 * (aiter_end - aiter_begin) < interval_set.len() {
let mut reach_label = NO_LABEL;
for pos in aiter_begin..aiter_end {
let tr = unsafe { trs_slice.get_unchecked(pos) };
let label = if self.reach_fst_input {
tr.ilabel
} else {
tr.olabel
};
if label == reach_label || self.reach_label(current_state, label)? {
reach_label = label;
if reach_begin == UNASSIGNED {
reach_begin = pos;
}
reach_end = pos + 1;
if compute_weight {
reach_weight.plus_assign(&tr.weight)?;
}
}
}
} else {
let mut begin_low;
let mut end_low = aiter_begin;
for interval in interval_set.iter() {
begin_low = self.lower_bound(trs_slice, end_low, aiter_end, interval.begin);
end_low = self.lower_bound(trs_slice, begin_low, aiter_end, interval.end);
if end_low - begin_low > 0 {
if reach_begin == UNASSIGNED {
reach_begin = begin_low;
}
reach_end = end_low;
if compute_weight {
for i in begin_low..end_low {
reach_weight
.plus_assign(unsafe { &trs_slice.get_unchecked(i).weight })?;
}
}
}
}
}
if reach_begin != UNASSIGNED {
Ok(Some((reach_begin, reach_end, reach_weight)))
} else {
Ok(None)
}
}
/// Binary-searches positions `aiter_begin..aiter_end` of `trs` for the first
/// transition whose inspected label is not smaller than `match_label`.
///
/// The inspected label is the input label when `reach_fst_input` is set, the
/// output label otherwise. Returns `aiter_end` when every label in the range
/// is smaller than `match_label` (and `aiter_begin` when the range is empty).
///
/// Callers must pass `aiter_end <= trs.len()`: positions are read without
/// bounds checks.
fn lower_bound<W: Semiring>(
    &self,
    trs: &[Tr<W>],
    aiter_begin: usize,
    aiter_end: usize,
    match_label: Label,
) -> usize {
    debug_assert!(match_label != NO_LABEL);

    let label_of = |tr: &Tr<W>| {
        if self.reach_fst_input {
            tr.ilabel
        } else {
            tr.olabel
        }
    };

    let (mut lo, mut hi) = (aiter_begin, aiter_end);
    while lo < hi {
        let probe = lo + (hi - lo) / 2;
        // SAFETY: `lo <= probe < hi <= aiter_end`, which the caller guarantees
        // to be at most `trs.len()`.
        let candidate = label_of(unsafe { trs.get_unchecked(probe) });
        debug_assert!(candidate != NO_LABEL);
        if candidate >= match_label {
            hi = probe;
        } else {
            lo = probe + 1;
        }
    }
    lo
}
}