#[cfg(test)]
mod tests {
    //! Unit tests for `VectorFst`: adding states and transitions, start and
    //! final states, state deletion, text parsing and symbol-table attachment.

    use std::sync::Arc;

    use anyhow::Result;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use crate::fst_impls::VectorFst;
    use crate::fst_traits::{
        CoreFst, ExpandedFst, Fst, MutableFst, SerializableFst, StateIterator,
    };
    use crate::semirings::{ProbabilityWeight, Semiring, TropicalWeight};
    use crate::tr::Tr;
    use crate::{SymbolTable, Trs};

    /// Transitions added to a state are counted and returned in insertion
    /// order; a state with no outgoing transitions reports none.
    #[test]
    fn test_small_fst() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        fst.set_start(s1)?;

        let tr_1 = Tr::new(3, 5, 10.0, s2);
        fst.add_tr(s1, tr_1.clone())?;
        assert_eq!(fst.num_trs(s1)?, 1);

        let tr_2 = Tr::new(5, 7, 18.0, s2);
        fst.add_tr(s1, tr_2.clone())?;
        assert_eq!(fst.num_trs(s1)?, 2);
        assert_eq!(fst.get_trs(s1)?.trs().iter().count(), 2);

        let it_s1 = fst.get_trs(s1)?;
        assert_eq!(it_s1.len(), 2);
        assert_eq!(tr_1, it_s1.trs()[0]);
        assert_eq!(tr_2, it_s1.trs()[1]);

        let it_s2 = fst.get_trs(s2)?;
        assert_eq!(it_s2.len(), 0);
        Ok(())
    }

    /// Overwriting a transition through `tr_iter_mut` replaces only that slot
    /// and leaves the other transitions and the count untouched.
    #[test]
    fn test_mutable_iter_trs_small() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        fst.set_start(s1)?;

        let tr_1 = Tr::new(3, 5, 10.0, s2);
        fst.add_tr(s1, tr_1.clone())?;
        let tr_2 = Tr::new(5, 7, 18.0, s2);
        fst.add_tr(s1, tr_2.clone())?;

        // The target `s2 + 55` is not an existing state; this test relies on
        // `set_tr` accepting it anyway.
        let new_tr_1 = Tr::new(15, 29, 33.0, s2 + 55);
        {
            // Scope the mutable iterator so its borrow of `fst` visibly ends
            // before the transitions are read back.
            let mut tr_it = fst.tr_iter_mut(s1)?;
            tr_it.set_tr(0, new_tr_1.clone())?;
        }

        let it_s1 = fst.get_trs(s1)?;
        assert_eq!(new_tr_1, it_s1[0]);
        assert_eq!(tr_2, it_s1[1]);
        assert_eq!(it_s1.len(), 2);
        Ok(())
    }

    /// A new FST has no start state, and `set_start` can be called repeatedly,
    /// with the last call winning.
    #[test]
    fn test_start_states() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let n_states = 1000;
        let states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();

        assert_eq!(fst.start(), None);

        fst.set_start(states[18])?;
        assert_eq!(fst.start(), Some(states[18]));

        fst.set_start(states[32])?;
        assert_eq!(fst.start(), Some(states[32]));
        Ok(())
    }

    /// Marking every state final makes `final_states_iter` yield all of them.
    #[test]
    fn test_only_final_states() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let n_states = 1000;
        let states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();

        assert_eq!(fst.final_states_iter().count(), 0);

        // Propagate `set_final` errors instead of panicking inside a closure.
        for &state in &states {
            fst.set_final(state, ProbabilityWeight::one())?;
        }

        assert_eq!(fst.final_states_iter().count(), n_states);
        Ok(())
    }

    /// Final weights set on a random subset of states are stored per state, and
    /// the other states are unaffected.
    #[test]
    fn test_final_weight() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let n_states = 1000;
        let n_final_states = 300;
        let mut states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();

        // Freshly added states have no final weight.
        for state_id in fst.states_iter() {
            assert!(fst.final_weight(state_id)?.is_none());
        }

        // A fixed seed keeps the chosen subset reproducible across runs.
        let mut rg = StdRng::from_seed([53; 32]);
        rg.shuffle(&mut states);
        let final_states: Vec<_> = states.into_iter().take(n_final_states).collect();

        // Each selected state gets a distinct weight: its index in `final_states`.
        for (idx, &state_id) in final_states.iter().enumerate() {
            fst.set_final(state_id, ProbabilityWeight::new(idx as f32))?;
        }

        // Per-state assertions report exactly which state or weight mismatched.
        for (idx, &state_id) in final_states.iter().enumerate() {
            assert!(fst.is_final(state_id)?);
            assert_eq!(
                fst.final_weight(state_id)?,
                Some(ProbabilityWeight::new(idx as f32))
            );
        }
        Ok(())
    }

    /// Deleting a state also drops the transitions that pointed to it from the
    /// surviving state.
    #[test]
    fn test_del_state_trs() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        fst.add_tr(s1, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;
        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s1))?;
        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;

        assert_eq!(fst.num_trs(s1)?, 1);
        assert_eq!(fst.num_trs(s2)?, 2);
        assert_eq!(fst.get_trs(s1)?.len(), 1);
        assert_eq!(fst.get_trs(s2)?.len(), 2);

        fst.del_state(s1)?;
        assert_eq!(fst.num_states(), 1);

        // The surviving state (formerly `s2`) is addressed by id 0 afterwards;
        // only its self-loop remains.
        let only_state = fst
            .states_iter()
            .next()
            .expect("one state remains after deleting s1");
        assert_eq!(only_state, 0);
        assert_eq!(fst.num_trs(only_state)?, 1);
        assert_eq!(fst.get_trs(only_state)?.len(), 1);
        Ok(())
    }

    /// Deleting the same state twice fails the second time.
    #[test]
    fn test_deleting_twice_same_state() -> Result<()> {
        let mut fst1 = VectorFst::<ProbabilityWeight>::new();
        let s = fst1.add_state();

        fst1.del_state(s)?;
        assert!(fst1.del_state(s).is_err());
        Ok(())
    }

    /// After one `del_state`, only one state remains and `s2`'s old id is no
    /// longer valid. Deleting both ids in a single `del_states` call succeeds.
    #[test]
    fn test_del_multiple_states() -> Result<()> {
        let mut fst1 = VectorFst::<ProbabilityWeight>::new();
        let s1 = fst1.add_state();
        let s2 = fst1.add_state();
        let mut fst2 = fst1.clone();

        fst1.del_state(s1)?;
        assert_eq!(fst1.num_states(), 1);
        assert!(fst1.del_state(s2).is_err());

        fst2.del_states(vec![s1, s2])?;
        assert_eq!(fst2.num_states(), 0);
        Ok(())
    }

    /// Bulk deletion of a random subset shrinks the FST by exactly that many
    /// states.
    #[test]
    fn test_del_states_big() -> Result<()> {
        let n_states = 1000;
        let n_states_to_delete = 300;

        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let mut states: Vec<_> = (0..n_states).map(|_| fst.add_state()).collect();
        assert_eq!(fst.num_states(), n_states);

        // A fixed seed keeps the deleted subset reproducible across runs.
        let mut rg = StdRng::from_seed([53; 32]);
        rg.shuffle(&mut states);
        let states_to_delete: Vec<_> = states.into_iter().take(n_states_to_delete).collect();

        fst.del_states(states_to_delete)?;
        assert_eq!(fst.num_states(), n_states - n_states_to_delete);
        Ok(())
    }

    /// The text line `0\tInfinity` declares state 0 with final weight
    /// `Infinity`. The expected FST has a single start state and no final
    /// weight, so despite the test's name the parsed state is not final.
    /// (Presumably `Infinity` is the tropical semiring's zero — confirm.)
    #[test]
    fn test_parse_single_final_state() -> Result<()> {
        let parsed_fst = VectorFst::<TropicalWeight>::from_text_string("0\tInfinity\n")?;

        let mut fst_ref: VectorFst<TropicalWeight> = VectorFst::new();
        fst_ref.add_state();
        fst_ref.set_start(0)?;

        assert_eq!(fst_ref, parsed_fst);
        Ok(())
    }

    /// `del_all_states` empties an FST that has transitions, a start state and
    /// a final state.
    #[test]
    fn test_del_all_states() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        fst.add_tr(s1, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;
        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s1))?;
        fst.add_tr(s2, Tr::new(0, 0, ProbabilityWeight::one(), s2))?;
        fst.set_start(s1)?;
        fst.set_final(s2, ProbabilityWeight::one())?;

        assert_eq!(fst.num_states(), 2);
        fst.del_all_states();
        assert_eq!(fst.num_states(), 0);
        Ok(())
    }

    /// Input and output symbol tables attached to the FST can be read back.
    #[test]
    fn test_attach_symt() -> Result<()> {
        let mut fst = VectorFst::<ProbabilityWeight>::new();
        let s1 = fst.add_state();
        let s2 = fst.add_state();
        fst.add_tr(s1, Tr::new(1, 0, ProbabilityWeight::one(), s2))?;
        fst.add_tr(s2, Tr::new(2, 0, ProbabilityWeight::one(), s1))?;
        fst.add_tr(s2, Tr::new(3, 0, ProbabilityWeight::one(), s2))?;
        fst.set_start(s1)?;
        fst.set_final(s2, ProbabilityWeight::one())?;

        {
            let mut symt = SymbolTable::new();
            symt.add_symbol("a");
            symt.add_symbol("b");
            symt.add_symbol("c");
            fst.set_input_symbols(Arc::new(symt));
        }
        // A fresh table already holds one entry (presumably epsilon), so three
        // additions give a length of 4.
        let input_symt = fst
            .input_symbols()
            .expect("input symbols were attached above");
        assert_eq!(input_symt.len(), 4);

        fst.set_output_symbols(Arc::new(SymbolTable::new()));
        let output_symt = fst
            .output_symbols()
            .expect("output symbols were attached above");
        assert_eq!(output_symt.len(), 1);
        Ok(())
    }
}