use std::borrow::Borrow;
use std::collections::BTreeMap;
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;
use anyhow::Result;
use crate::algorithms::randgen::rand_state::RandState;
use crate::algorithms::randgen::TrSelector;
use crate::prelude::Fst;
use crate::Semiring;
/// Draws a batch of transition samples from one state of an FST.
///
/// The FST is held through any `B: Borrow<F>` (owned value, reference, smart pointer, ...).
/// The actual choice of transition is delegated to a [`TrSelector`].
pub struct TrSampler<W, F, B, S>
where
    W: Semiring,
    F: Fst<W>,
    B: Borrow<F>,
    S: TrSelector,
{
    /// Path length at which sampling stops: `sample` returns `false` once
    /// `rstate.length` equals this value.
    max_length: usize,
    /// Strategy used to pick a transition index for a given state.
    selector: S,
    /// The FST being sampled (possibly borrowed).
    fst: B,
    /// Index returned by the selector -> number of times it was drawn in the last
    /// `sample` call. NOTE(review): whether an index may denote "stop at final state"
    /// depends on the selector implementation — confirm against `TrSelector`.
    sample_map: BTreeMap<usize, usize>,
    /// Ties the otherwise-unused `W` and `F` type parameters to the struct.
    ghost: PhantomData<(W, F)>,
}
/// Debug output lists every field except the phantom marker, with the FST
/// shown through its borrowed form.
impl<W, F, B, S> Debug for TrSampler<W, F, B, S>
where
    W: Semiring,
    F: Fst<W>,
    B: Borrow<F>,
    S: TrSelector,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            max_length,
            selector,
            fst,
            sample_map,
            ..
        } = self;
        // Format the underlying FST rather than the `B` wrapper.
        let borrowed_fst: &F = fst.borrow();
        write!(
            f,
            "TrSampler {{ max_length : {:?}, selector : {:?}, fst : {:?}, sample_map : {:?} }}",
            max_length, selector, borrowed_fst, sample_map
        )
    }
}
impl<W, F, B, S> TrSampler<W, F, B, S>
where
    W: Semiring,
    F: Fst<W>,
    B: Borrow<F>,
    S: TrSelector,
{
    /// Builds a sampler over `fst` that uses `selector` to choose transitions.
    ///
    /// Sampling halts for any random state whose path length has reached `max_length`.
    pub fn new(fst: B, selector: S, max_length: usize) -> Self {
        Self {
            max_length,
            selector,
            fst,
            ghost: PhantomData,
            sample_map: BTreeMap::default(),
        }
    }

    /// Draws `rstate.nsamples` transition indices from `rstate.state_id`.
    ///
    /// The previous results are discarded first. No samples are drawn and `Ok(false)`
    /// is returned when either:
    /// - the state has no outgoing transitions and is not final, or
    /// - `rstate.length` equals the configured maximum length.
    ///
    /// Otherwise the per-index draw counts are recorded (see [`Self::iter`]) and
    /// `Ok(true)` is returned.
    ///
    /// # Errors
    /// Propagates any error from querying the FST or from the selector.
    pub fn sample(&mut self, rstate: &RandState) -> Result<bool> {
        self.sample_map.clear();

        let fst: &F = self.fst.borrow();
        let state = rstate.state_id;

        // `is_final` is only consulted when the state has no transitions
        // (short-circuit preserved on purpose).
        let is_dead_end = fst.num_trs(state)? == 0 && !fst.is_final(state)?;
        if is_dead_end || rstate.length == self.max_length {
            return Ok(false);
        }

        for _ in 0..rstate.nsamples {
            let tr_idx = self.selector.select_tr(fst, state)?;
            *self.sample_map.entry(tr_idx).or_default() += 1;
        }
        Ok(true)
    }

    /// Iterates over `(transition index, draw count)` pairs from the most recent
    /// `sample` call, in ascending index order.
    pub fn iter(&self) -> std::collections::btree_map::Iter<'_, usize, usize> {
        self.sample_map.iter()
    }
}