use rand::{Rng, SeedableRng};
use smallvec::SmallVec;
use crate::semiring::{Semiring, StochasticSemiring};
use crate::wfst::{StateId, WeightedTransition, Wfst};
/// Options controlling how paths are sampled from a WFST.
///
/// Build one fluently, e.g. `SampleConfig::new().seed(42).max_length(100)`.
#[derive(Clone, Debug)]
pub struct SampleConfig {
    /// Limit on the number of transitions a sampled path may take; a walk that runs
    /// past it fails with `SampleError::MaxLengthExceeded`.
    pub max_length: usize,
    /// How the sampler chooses between stopping and the outgoing transitions.
    pub strategy: SampleStrategy,
    /// NOTE(review): not read by any sampling function visible in this module. Epsilon
    /// (`None`) labels are always recorded in `SampledPath` and skipped by
    /// `input_string`/`output_string`. Confirm the intended semantics.
    pub include_epsilon: bool,
    /// RNG seed. `Some` makes sampling reproducible; `None` uses `rand::rng()`.
    pub seed: Option<u64>,
}

impl Default for SampleConfig {
    /// `max_length = 10_000`, proportional strategy, no epsilon flag, unseeded.
    fn default() -> Self {
        Self {
            max_length: 10_000,
            strategy: SampleStrategy::Proportional,
            include_epsilon: false,
            seed: None,
        }
    }
}

impl SampleConfig {
    /// Same as [`SampleConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`SampleConfig::max_length`].
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn max_length(mut self, length: usize) -> Self {
        self.max_length = length;
        self
    }

    /// Sets [`SampleConfig::strategy`].
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn strategy(mut self, strategy: SampleStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets [`SampleConfig::include_epsilon`].
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn include_epsilon(mut self, include: bool) -> Self {
        self.include_epsilon = include;
        self
    }

    /// Fixes the RNG seed so that sampling is reproducible.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}

/// Policy for choosing between stopping (at final states) and following a transition.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SampleStrategy {
    /// Weigh each option by the `to_probability()` of its weight: the final weight for
    /// stopping, and the arc weight for each transition.
    #[default]
    Proportional,
    /// Treat every available option as equally likely.
    Uniform,
}
/// Reasons a sampling attempt can fail.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SampleError {
    /// Returned when `wfst.is_empty()` is true.
    EmptyWfst,
    /// The walk ran past `SampleConfig::max_length` transitions.
    MaxLengthExceeded,
    /// The WFST has no accepting paths. (Not constructed by the sampling code visible here.)
    NoAcceptingPaths,
    /// A non-final state without outgoing transitions was reached.
    DeadState(StateId),
    /// Every weight at the given state is zero. (Not constructed by the sampling code visible here.)
    ZeroWeights(StateId),
}

impl std::fmt::Display for SampleError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants carrying a state id are formatted immediately; the rest map to fixed text.
        let message = match self {
            Self::DeadState(state) => return write!(f, "Dead state encountered: {}", state),
            Self::ZeroWeights(state) => {
                return write!(f, "All weights are zero at state {}", state)
            }
            Self::EmptyWfst => "WFST is empty",
            Self::MaxLengthExceeded => "Maximum path length exceeded",
            Self::NoAcceptingPaths => "WFST has no accepting paths",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SampleError {}
/// A single path drawn from a WFST: the labels on its transitions, the states it
/// visited, and the semiring product of its weights.
#[derive(Clone, Debug)]
pub struct SampledPath<L, W> {
    /// Input label of every transition taken, in order; `None` entries are kept here
    /// but skipped by [`SampledPath::input_string`].
    pub input_labels: Vec<Option<L>>,
    /// Output label of every transition taken, in order; `None` entries are kept here
    /// but skipped by [`SampledPath::output_string`].
    pub output_labels: Vec<Option<L>>,
    /// `W::one()` times each transition weight and, once finalized, the final weight.
    pub weight: W,
    /// Visited states: the start state, then one entry per transition taken.
    pub states: Vec<StateId>,
    /// Number of transitions taken.
    pub length: usize,
}

impl<L, W: Semiring> SampledPath<L, W> {
    /// A zero-length path sitting at `start` with weight `W::one()`.
    fn new(start: StateId) -> Self {
        let states = vec![start];
        Self {
            input_labels: Vec::new(),
            output_labels: Vec::new(),
            weight: W::one(),
            states,
            length: 0,
        }
    }

    /// Appends `trans`: records its labels and destination, and multiplies in its weight.
    fn extend(&mut self, trans: &WeightedTransition<L, W>)
    where
        L: Clone,
    {
        self.length += 1;
        self.states.push(trans.to);
        self.input_labels.push(trans.input.clone());
        self.output_labels.push(trans.output.clone());
        self.weight = self.weight.times(&trans.weight);
    }

    /// Multiplies the accumulated weight by the final weight of the last state.
    fn finalize(&mut self, final_weight: &W) {
        let completed = self.weight.times(final_weight);
        self.weight = completed;
    }

    /// The non-`None` input labels, in path order.
    pub fn input_string(&self) -> Vec<&L> {
        self.input_labels.iter().flatten().collect()
    }

    /// The non-`None` output labels, in path order.
    pub fn output_string(&self) -> Vec<&L> {
        self.output_labels.iter().flatten().collect()
    }
}
/// Draws one random path from `wfst` according to `config`.
///
/// When `config.seed` is set, the walk uses a `StdRng` seeded from it, so results are
/// reproducible. Otherwise it uses `rand::rng()`.
///
/// # Errors
/// Returns `SampleError::EmptyWfst` if `wfst.is_empty()`. Otherwise returns whatever
/// error the walk itself produces (dead state, length limit).
pub fn sample_path<L, W, F>(
    wfst: &F,
    config: SampleConfig,
) -> Result<SampledPath<L, W>, SampleError>
where
    L: Clone,
    W: Semiring + StochasticSemiring,
    F: Wfst<L, W>,
{
    if wfst.is_empty() {
        return Err(SampleError::EmptyWfst);
    }
    // Dispatch on the seed to a concrete generator. Both generators feed the same
    // `Rng` methods as the boxed form would, so the random stream is unchanged.
    match config.seed {
        None => sample_path_with_rng(wfst, &config, &mut rand::rng()),
        Some(seed) => {
            let mut seeded = rand::rngs::StdRng::seed_from_u64(seed);
            sample_path_with_rng(wfst, &config, &mut seeded)
        }
    }
}
/// Random walk from `wfst.start()` that produces one sampled path.
///
/// At every state:
/// - With no outgoing transitions, the walk ends: successfully if the state is final,
///   otherwise with `SampleError::DeadState`.
/// - At a final state, `sample_stop_decision` decides whether to stop here.
/// - Otherwise `sample_transition` picks the next transition.
///
/// At most `config.max_length` transitions are taken. A path of exactly that length may
/// still stop at a final state; needing one more transition yields
/// `SampleError::MaxLengthExceeded`.
///
/// # Errors
/// `DeadState`, `MaxLengthExceeded`, or any error propagated from the decision helpers.
fn sample_path_with_rng<L, W, F, R>(
    wfst: &F,
    config: &SampleConfig,
    rng: &mut R,
) -> Result<SampledPath<L, W>, SampleError>
where
    L: Clone,
    W: Semiring + StochasticSemiring,
    F: Wfst<L, W>,
    R: Rng + ?Sized,
{
    let start = wfst.start();
    let mut path = SampledPath::new(start);
    let mut current = start;
    // Terminates: every iteration either returns or extends the path by one
    // transition, and extension is refused once `path.length` reaches `max_length`.
    loop {
        let transitions = wfst.transitions(current);
        let is_final = wfst.is_final(current);
        let final_weight = wfst.final_weight(current);

        if transitions.is_empty() {
            if is_final {
                path.finalize(&final_weight);
                return Ok(path);
            }
            return Err(SampleError::DeadState(current));
        }

        // Stopping is only possible at final states. The decision is made before the
        // length check so that a path of exactly `max_length` transitions can end here.
        let should_stop = is_final
            && sample_stop_decision(transitions, &final_weight, config.strategy, rng)?;
        if should_stop {
            path.finalize(&final_weight);
            return Ok(path);
        }

        if path.length >= config.max_length {
            return Err(SampleError::MaxLengthExceeded);
        }

        let trans = sample_transition(transitions, config.strategy, rng)?;
        path.extend(trans);
        current = trans.to;
    }
}
/// Decides whether the walk stops at a final state instead of following one of its
/// outgoing `transitions`.
///
/// - `Uniform`: stopping is one of `transitions.len() + 1` equally likely options.
/// - `Proportional`: stops with probability `p(final) / (p(final) + Σ p(arc))`, where
///   `p` is `to_probability()`. If that total mass is `<= 0.0`, it stops
///   deterministically without drawing a random number.
///
/// Currently always returns `Ok`.
fn sample_stop_decision<L, W, R>(
    transitions: &[WeightedTransition<L, W>],
    final_weight: &W,
    strategy: SampleStrategy,
    rng: &mut R,
) -> Result<bool, SampleError>
where
    W: Semiring + StochasticSemiring,
    R: Rng + ?Sized,
{
    let stop = match strategy {
        SampleStrategy::Proportional => {
            let stop_mass = final_weight.to_probability();
            let continue_mass: f64 = transitions
                .iter()
                .map(|t| t.weight.to_probability())
                .sum();
            let total_mass = stop_mass + continue_mass;
            if total_mass <= 0.0 {
                // Degenerate weights: end the path here rather than divide by zero.
                true
            } else {
                let stop_probability = stop_mass / total_mass;
                let draw: f64 = rng.random();
                draw < stop_probability
            }
        }
        SampleStrategy::Uniform => {
            // Index 0 stands for "stop"; the remaining indices stand for the transitions.
            let option_count = transitions.len() + 1;
            let picked: usize = rng.random_range(0..option_count);
            picked == 0
        }
    };
    Ok(stop)
}
/// Picks one of `transitions` (which must be non-empty) according to `strategy`.
///
/// - `Uniform`: each transition is equally likely.
/// - `Proportional`: each transition is chosen with probability proportional to
///   `to_probability()` of its weight. If the total mass is not positive (or is NaN),
///   this falls back to a uniform choice.
///
/// Currently always returns `Ok`.
fn sample_transition<'a, L, W, R>(
    transitions: &'a [WeightedTransition<L, W>],
    strategy: SampleStrategy,
    rng: &mut R,
) -> Result<&'a WeightedTransition<L, W>, SampleError>
where
    W: Semiring + StochasticSemiring,
    R: Rng + ?Sized,
{
    debug_assert!(!transitions.is_empty());
    match strategy {
        SampleStrategy::Uniform => {
            let idx: usize = rng.random_range(0..transitions.len());
            Ok(&transitions[idx])
        }
        SampleStrategy::Proportional => {
            let weights: SmallVec<[f64; 8]> = transitions
                .iter()
                .map(|t| t.weight.to_probability())
                .collect();
            let total: f64 = weights.iter().sum();
            // No usable distribution (all-zero, negative, or NaN mass): choose uniformly.
            if total.is_nan() || total <= 0.0 {
                let idx: usize = rng.random_range(0..transitions.len());
                return Ok(&transitions[idx]);
            }
            let r: f64 = rng.random::<f64>() * total;
            let mut cumulative = 0.0;
            for (i, &w) in weights.iter().enumerate() {
                cumulative += w;
                if r < cumulative {
                    return Ok(&transitions[i]);
                }
            }
            // Reached only when rounding makes `r` equal the final cumulative sum.
            // Fall back to the last transition that actually carries mass, so an arc
            // with zero probability is never selected.
            let fallback = weights
                .iter()
                .rposition(|&w| w > 0.0)
                .unwrap_or(transitions.len() - 1);
            Ok(&transitions[fallback])
        }
    }
}
/// Draws `count` paths from `wfst`, sharing one RNG across all draws so that a seeded
/// config yields a reproducible sequence.
///
/// Returns exactly one `Result` per draw. If `wfst.is_empty()`, every entry is
/// `Err(SampleError::EmptyWfst)`.
pub fn sample_paths<L, W, F>(
    wfst: &F,
    count: usize,
    config: SampleConfig,
) -> Vec<Result<SampledPath<L, W>, SampleError>>
where
    L: Clone,
    W: Semiring + StochasticSemiring,
    F: Wfst<L, W>,
{
    if wfst.is_empty() {
        return std::iter::repeat_with(|| Err(SampleError::EmptyWfst))
            .take(count)
            .collect();
    }
    let mut rng: Box<dyn rand::RngCore> = if let Some(seed) = config.seed {
        Box::new(rand::rngs::StdRng::seed_from_u64(seed))
    } else {
        Box::new(rand::rng())
    };
    let mut outcomes = Vec::with_capacity(count);
    for _ in 0..count {
        outcomes.push(sample_path_with_rng(wfst, &config, &mut *rng));
    }
    outcomes
}
/// Samples until `target` successful paths are collected or `max_attempts` attempts have
/// been made, whichever comes first.
///
/// Failed attempts (any `SampleError`) are discarded, so the result may contain fewer
/// than `target` paths. Returns an empty vector if `wfst.is_empty()`. A single RNG,
/// seeded from `config.seed` when present, is shared by all attempts.
pub fn sample_paths_until<L, W, F>(
    wfst: &F,
    target: usize,
    max_attempts: usize,
    config: SampleConfig,
) -> Vec<SampledPath<L, W>>
where
    L: Clone,
    W: Semiring + StochasticSemiring,
    F: Wfst<L, W>,
{
    if wfst.is_empty() {
        return Vec::new();
    }
    let mut rng: Box<dyn rand::RngCore> = match config.seed {
        Some(seed) => Box::new(rand::rngs::StdRng::seed_from_u64(seed)),
        None => Box::new(rand::rng()),
    };
    // Each attempt yields at most one path, so the result can never exceed
    // `max_attempts`. Reserving `target` blindly could over-allocate, or panic with a
    // capacity overflow, when `target` is huge.
    let mut paths = Vec::with_capacity(target.min(max_attempts));
    for _ in 0..max_attempts {
        if paths.len() >= target {
            break;
        }
        if let Ok(path) = sample_path_with_rng(wfst, &config, &mut *rng) {
            paths.push(path);
        }
    }
    paths
}
/// Monte Carlo estimate of the mean path probability under the sampling distribution.
///
/// Collects up to `num_samples` successful paths via `sample_paths_until`, allowing
/// `10 * num_samples` attempts (saturating at `usize::MAX`). It then returns the average
/// of `to_probability()` over the paths' total weights. Paths are drawn according to
/// `config.strategy`, so the result is an expectation *under that sampling scheme*.
///
/// Returns `None` if `wfst.is_empty()`, `num_samples == 0`, or no attempt succeeded.
pub fn estimate_expected_weight<L, W, F>(
    wfst: &F,
    num_samples: usize,
    config: SampleConfig,
) -> Option<f64>
where
    L: Clone,
    W: Semiring + StochasticSemiring,
    F: Wfst<L, W>,
{
    if wfst.is_empty() || num_samples == 0 {
        return None;
    }
    // `num_samples * 10` would overflow (and panic in debug builds) for very large inputs.
    let max_attempts = num_samples.saturating_mul(10);
    let paths = sample_paths_until(wfst, num_samples, max_attempts, config);
    if paths.is_empty() {
        return None;
    }
    let total: f64 = paths.iter().map(|p| p.weight.to_probability()).sum();
    Some(total / paths.len() as f64)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;
    use crate::wfst::{MutableWfst, VectorWfst};

    /// Three-state chain q0 -a:a-> q1 -b:b-> q2. Both arcs have weight 1.0, and q2 is
    /// final with weight 0.0.
    fn make_simple_wfst() -> VectorWfst<char, TropicalWeight> {
        let mut fst = VectorWfst::new();
        let q0 = fst.add_state();
        let q1 = fst.add_state();
        let q2 = fst.add_state();
        fst.set_start(q0);
        fst.set_final(q2, TropicalWeight::new(0.0));
        for (from, label, to) in [(q0, 'a', q1), (q1, 'b', q2)] {
            fst.add_arc(from, Some(label), Some(label), to, TropicalWeight::new(1.0));
        }
        fst
    }

    /// Start state q0 with two arcs: a:x (weight 1.0) to q1 and b:y (weight 2.0) to q2.
    /// Both targets are final with weight 0.0.
    fn make_branching_wfst() -> VectorWfst<char, TropicalWeight> {
        let mut fst = VectorWfst::new();
        let q0 = fst.add_state();
        let q1 = fst.add_state();
        let q2 = fst.add_state();
        fst.set_start(q0);
        fst.set_final(q1, TropicalWeight::new(0.0));
        fst.set_final(q2, TropicalWeight::new(0.0));
        fst.add_arc(q0, Some('a'), Some('x'), q1, TropicalWeight::new(1.0));
        fst.add_arc(q0, Some('b'), Some('y'), q2, TropicalWeight::new(2.0));
        fst
    }

    #[test]
    fn test_sample_simple_path() {
        let path = sample_path(&make_simple_wfst(), SampleConfig::default().seed(42))
            .expect("Should sample a path");
        assert_eq!(path.length, 2);
        assert_eq!(path.states.len(), 3);
        assert_eq!(path.input_string(), vec![&'a', &'b']);
        assert_eq!(path.output_string(), vec![&'a', &'b']);
    }

    #[test]
    fn test_sample_uniform() {
        let wfst = make_branching_wfst();
        let config = SampleConfig::default()
            .seed(42)
            .strategy(SampleStrategy::Uniform);
        let paths = sample_paths_until(&wfst, 100, 1000, config);
        // Number of sampled paths whose input string is exactly the single label `c`.
        let count_with_input = |c: char| {
            paths
                .iter()
                .filter(|p| p.input_string() == vec![&c])
                .count()
        };
        assert!(count_with_input('a') > 0, "Should sample 'a' path");
        assert!(count_with_input('b') > 0, "Should sample 'b' path");
    }

    #[test]
    fn test_sample_reproducible() {
        let wfst = make_branching_wfst();
        let [first, second] = [12345_u64, 12345_u64].map(|seed| {
            sample_path(&wfst, SampleConfig::default().seed(seed)).expect("Should sample")
        });
        assert_eq!(first.input_string(), second.input_string());
    }

    #[test]
    fn test_sample_empty_wfst() {
        let empty: VectorWfst<char, TropicalWeight> = VectorWfst::new();
        let outcome = sample_path(&empty, SampleConfig::default());
        assert!(matches!(outcome, Err(SampleError::EmptyWfst)));
    }

    #[test]
    fn test_sample_dead_state() {
        let mut fst = VectorWfst::<char, TropicalWeight>::new();
        let q0 = fst.add_state();
        let q1 = fst.add_state();
        fst.set_start(q0);
        fst.add_arc(q0, Some('a'), Some('a'), q1, TropicalWeight::new(1.0));
        let outcome = sample_path(&fst, SampleConfig::default().seed(42));
        assert!(matches!(outcome, Err(SampleError::DeadState(_))));
    }

    #[test]
    fn test_sample_multiple_paths() {
        let results = sample_paths(&make_branching_wfst(), 10, SampleConfig::default().seed(42));
        assert_eq!(results.len(), 10);
        assert!(results.iter().all(Result::is_ok));
    }

    #[test]
    fn test_estimate_expected_weight() {
        let estimate =
            estimate_expected_weight(&make_simple_wfst(), 100, SampleConfig::default().seed(42));
        assert!(estimate.is_some());
        let e = estimate.expect("algorithms/sample.rs: required value was None/Err");
        assert!(e > 0.0, "Expected weight should be positive");
    }

    #[test]
    fn test_sampled_path_methods() {
        let path = sample_path(&make_simple_wfst(), SampleConfig::default().seed(42))
            .expect("Should sample");
        let (input, output) = (path.input_string(), path.output_string());
        assert_eq!(input.len(), 2);
        assert_eq!(output.len(), 2);
        assert_eq!(*input[0], 'a');
        assert_eq!(*output[1], 'b');
    }
}