use crate::semiring::Semiring;
use crate::wfst::{MutableWfst, StateId, VectorWfst, Wfst, NO_STATE};
/// Options accepted by the lattice rescoring entry points.
///
/// NOTE(review): in the code visible in this file the options are only
/// stored and passed along; `rescore_lattice` takes a config but does not
/// read any field yet. The per-field notes below describe the declared
/// shape and defaults, not behavior.
#[derive(Debug, Clone)]
pub struct RescoreConfig {
    /// Determinization switch. Defaults to `true`.
    pub determinize: bool,
    /// Minimization switch. Defaults to `true`.
    pub minimize: bool,
    /// Optional pruning threshold. Defaults to `None`.
    pub pruning_threshold: Option<f64>,
    /// Optional cap on the number of states. Defaults to `None`.
    pub max_states: Option<usize>,
    /// Interpolation coefficient. Defaults to `1.0`.
    pub interpolation_alpha: f64,
}

impl Default for RescoreConfig {
    /// Builds the default configuration: both boolean switches on, no
    /// pruning threshold, no state cap, and an interpolation alpha of `1.0`.
    fn default() -> Self {
        RescoreConfig {
            interpolation_alpha: 1.0,
            max_states: None,
            pruning_threshold: None,
            minimize: true,
            determinize: true,
        }
    }
}
/// Label identifying which rescoring pass a lattice came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RescorePass {
    /// The first pass.
    FirstPass,
    /// The second pass.
    SecondPass,
    /// Any further pass, carrying a caller-chosen pass number.
    /// `multi_pass_rescore` in this file numbers these from 1.
    AdditionalPass(u32),
}
/// A lattice FST bundled with the pass label it was created under and
/// size statistics about it.
#[derive(Clone, Debug)]
pub struct LatticeGrammar<L: Clone, W: Semiring> {
/// The lattice itself.
pub fst: VectorWfst<L, W>,
/// Pass label supplied by whoever constructed this value.
pub source_pass: RescorePass,
/// Size statistics; `LatticeGrammar::new` fills these in from `fst`.
pub stats: LatticeStats,
}
/// Size statistics describing a lattice.
///
/// The derived `Default` yields zero counts, a zero average and no density.
#[derive(Debug, Clone, Default)]
pub struct LatticeStats {
    /// Number of states in the lattice.
    pub num_states: usize,
    /// Total number of arcs over all states.
    pub num_arcs: usize,
    /// `num_arcs / num_states`, or `0.0` for a lattice without states
    /// (as computed by `LatticeGrammar::new`).
    pub avg_arcs_per_state: f64,
    /// Optional density figure. `LatticeGrammar::new` leaves it `None`;
    /// `LatticeGrammar::with_density` stores a caller-supplied value.
    pub density: Option<f64>,
}
impl<L: Clone + Send + Sync, W: Semiring + Clone> LatticeGrammar<L, W> {
    /// Wraps `fst` together with its pass label and freshly computed
    /// size statistics.
    ///
    /// The state count comes from `fst.num_states()`, the arc count is the
    /// sum of `fst.transitions(s).len()` over every state id in
    /// `0..num_states`, and the average is arcs divided by states (`0.0`
    /// when there are no states). `density` starts out as `None`.
    pub fn new(fst: VectorWfst<L, W>, source_pass: RescorePass) -> Self {
        let state_count = fst.num_states();
        let arc_count = (0..state_count as StateId)
            .fold(0usize, |acc, sid| acc + fst.transitions(sid).len());

        // Guard the division so an empty lattice reports 0.0 instead of NaN.
        let mean_arcs = match state_count {
            0 => 0.0,
            n => arc_count as f64 / n as f64,
        };

        Self {
            fst,
            source_pass,
            stats: LatticeStats {
                num_states: state_count,
                num_arcs: arc_count,
                avg_arcs_per_state: mean_arcs,
                density: None,
            },
        }
    }

    /// Builder-style setter: returns `self` with `stats.density` set to
    /// `Some(density)`. All other fields are left untouched.
    pub fn with_density(self, density: f64) -> Self {
        let mut updated = self;
        updated.stats.density = Some(density);
        updated
    }
}
/// Output of a single rescoring call: the resulting lattice plus
/// before/after size statistics.
#[derive(Clone, Debug)]
pub struct RescoreResult<L: Clone, W: Semiring> {
/// The lattice produced by the rescoring call.
pub lattice: VectorWfst<L, W>,
/// State and arc counts of the input and output lattices, with the
/// derived reduction ratios.
pub stats: RescoreStats,
}
/// Size statistics comparing a lattice before and after rescoring.
#[derive(Clone, Debug, Default)]
pub struct RescoreStats {
    /// Number of states in the input lattice.
    pub input_states: usize,
    /// Number of states in the output lattice.
    pub output_states: usize,
    /// Number of arcs in the input lattice.
    pub input_arcs: usize,
    /// Number of arcs in the output lattice.
    pub output_arcs: usize,
    /// `1 - output_states / input_states`; negative when the output has
    /// more states than the input. Set by [`RescoreStats::compute_reductions`].
    pub state_reduction: f64,
    /// `1 - output_arcs / input_arcs`; negative when the output has more
    /// arcs than the input. Set by [`RescoreStats::compute_reductions`].
    pub arc_reduction: f64,
}

impl RescoreStats {
    /// Recomputes `state_reduction` and `arc_reduction` from the four
    /// count fields.
    ///
    /// Both ratios are always overwritten. When an input count is zero the
    /// corresponding ratio is reset to `0.0`, so a value left over from an
    /// earlier computation can never survive a later call with empty input.
    pub fn compute_reductions(&mut self) {
        self.state_reduction = Self::reduction(self.input_states, self.output_states);
        self.arc_reduction = Self::reduction(self.input_arcs, self.output_arcs);
    }

    /// Fraction by which `output` is smaller than `input`; `0.0` when
    /// `input` is zero, which avoids dividing by zero.
    fn reduction(input: usize, output: usize) -> f64 {
        if input == 0 {
            0.0
        } else {
            1.0 - (output as f64 / input as f64)
        }
    }
}
/// Rescores `lattice` and reports before/after size statistics.
///
/// NOTE(review): the current body is a pass-through. `_new_lm` and
/// `_config` are accepted but never read, and the returned lattice is a
/// copy of `lattice.fst` made by `clone_lattice`.
///
/// # Returns
///
/// A [`RescoreResult`] holding the copied lattice and a [`RescoreStats`]
/// whose input counts come from `lattice.fst`, whose output counts come
/// from the copy, and whose reduction ratios have been computed.
pub fn rescore_lattice<L, W>(
    lattice: &LatticeGrammar<L, W>,
    _new_lm: &VectorWfst<L, W>,
    _config: &RescoreConfig,
) -> RescoreResult<L, W>
where
    L: Clone + Eq + std::hash::Hash + Default + Send + Sync,
    W: Semiring + Clone,
{
    let source = &lattice.fst;
    let rescored = clone_lattice(source);

    // Gather all four counts in one literal, then derive the ratios.
    let mut stats = RescoreStats {
        input_states: source.num_states(),
        input_arcs: count_arcs(source),
        output_states: rescored.num_states(),
        output_arcs: count_arcs(&rescored),
        ..Default::default()
    };
    stats.compute_reductions();

    RescoreResult {
        lattice: rescored,
        stats,
    }
}
/// Returns the total number of arcs in `fst`: the sum of
/// `fst.transitions(s).len()` for every state id `s` in
/// `0..fst.num_states()`.
fn count_arcs<L, W>(fst: &VectorWfst<L, W>) -> usize
where
    L: Clone + Send + Sync,
    W: Semiring,
{
    let state_count = fst.num_states() as StateId;
    let mut total = 0usize;
    for sid in 0..state_count {
        total += fst.transitions(sid).len();
    }
    total
}
/// Builds a new `VectorWfst` that mirrors `fst` state by state.
///
/// The copy receives the same number of states, the same start state
/// (unless the source start is `NO_STATE`), every arc with cloned labels
/// and weight, and the final weight of every final state.
///
/// NOTE(review): arcs and final weights are attached using the source's
/// state ids, so this assumes `add_state` on a fresh `VectorWfst` hands out
/// ids `0, 1, 2, ...` in order. Confirm against `VectorWfst`.
fn clone_lattice<L, W>(fst: &VectorWfst<L, W>) -> VectorWfst<L, W>
where
    L: Clone + Send + Sync,
    W: Semiring + Clone,
{
    let state_count = fst.num_states();
    let mut copy: VectorWfst<L, W> = VectorWfst::new();

    // Allocate all states up front so arc targets exist before arcs are added.
    (0..state_count).for_each(|_| {
        copy.add_state();
    });

    let source_start = fst.start();
    if source_start != NO_STATE {
        copy.set_start(source_start);
    }

    for sid in 0..state_count as StateId {
        for arc in fst.transitions(sid) {
            copy.add_arc(
                sid,
                arc.input.clone(),
                arc.output.clone(),
                arc.to,
                arc.weight.clone(),
            );
        }
        if fst.is_final(sid) {
            copy.set_final(sid, fst.final_weight(sid).clone());
        }
    }

    copy
}
/// Runs `rescore_lattice` once per language model, chaining the passes.
///
/// The first pass starts from a clone of `initial_lattice`. After each
/// pass, the lattice it produced is wrapped in a new [`LatticeGrammar`]
/// labeled `RescorePass::AdditionalPass(i + 1)` (where `i` is the
/// zero-based index of the language model) and becomes the input of the
/// next pass.
///
/// # Returns
///
/// One [`RescoreResult`] per entry of `lm_sequence`, in the same order.
/// An empty `lm_sequence` yields an empty vector.
pub fn multi_pass_rescore<L, W>(
    initial_lattice: &LatticeGrammar<L, W>,
    lm_sequence: &[VectorWfst<L, W>],
    config: &RescoreConfig,
) -> Vec<RescoreResult<L, W>>
where
    L: Clone + Eq + std::hash::Hash + Default + Send + Sync,
    W: Semiring + Clone,
{
    let mut results = Vec::with_capacity(lm_sequence.len());
    let mut current_lattice = initial_lattice.clone();
    for (i, lm) in lm_sequence.iter().enumerate() {
        // Borrow the lattice from the previous pass. The original text here
        // was the mis-encoded token `¤t_lattice`, which is not valid Rust.
        let result = rescore_lattice(&current_lattice, lm, config);
        // The lattice is cloned because `result` itself is returned to the
        // caller while the next pass still needs its own copy as input.
        current_lattice = LatticeGrammar::new(
            result.lattice.clone(),
            RescorePass::AdditionalPass(i as u32 + 1),
        );
        results.push(result);
    }
    results
}
// Unit tests for the configuration defaults, pass labels, lattice
// statistics and the rescoring entry points.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::LogWeight;
// The default config enables both switches, has no pruning threshold and
// uses an interpolation alpha of 1.0.
#[test]
fn test_rescore_config_default() {
let config = RescoreConfig::default();
assert!(config.determinize);
assert!(config.minimize);
assert!(config.pruning_threshold.is_none());
assert_eq!(config.interpolation_alpha, 1.0);
}
// Pass labels compare by variant and, for `AdditionalPass`, by number.
#[test]
fn test_rescore_pass() {
assert_eq!(RescorePass::FirstPass, RescorePass::FirstPass);
assert_ne!(RescorePass::FirstPass, RescorePass::SecondPass);
assert_eq!(
RescorePass::AdditionalPass(1),
RescorePass::AdditionalPass(1)
);
}
// Two states joined by one arc: expects 2 states, 1 arc and an average
// of 0.5 arcs per state.
#[test]
fn test_lattice_grammar_new() {
let mut fst: VectorWfst<u32, LogWeight> = VectorWfst::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some(1), Some(1), s1, LogWeight::one());
let grammar = LatticeGrammar::new(fst, RescorePass::FirstPass);
assert_eq!(grammar.stats.num_states, 2);
assert_eq!(grammar.stats.num_arcs, 1);
assert_eq!(grammar.stats.avg_arcs_per_state, 0.5);
}
// `with_density` stores the supplied value in `stats.density`.
#[test]
fn test_lattice_grammar_with_density() {
let fst = VectorWfst::<u32, LogWeight>::new();
let grammar = LatticeGrammar::new(fst, RescorePass::FirstPass).with_density(5.0);
assert_eq!(grammar.stats.density, Some(5.0));
}
// Halving both states and arcs gives a reduction of 0.5 for each.
#[test]
fn test_rescore_stats_compute_reductions() {
let mut stats = RescoreStats {
input_states: 100,
output_states: 50,
input_arcs: 200,
output_arcs: 100,
..Default::default()
};
stats.compute_reductions();
assert_eq!(stats.state_reduction, 0.5);
assert_eq!(stats.arc_reduction, 0.5);
}
// Rescoring an empty lattice reports zero input and output states.
#[test]
fn test_rescore_lattice_empty() {
let fst = VectorWfst::<u32, LogWeight>::new();
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lm = VectorWfst::<u32, LogWeight>::new();
let config = RescoreConfig::default();
let result = rescore_lattice(&lattice, &lm, &config);
assert_eq!(result.stats.input_states, 0);
assert_eq!(result.stats.output_states, 0);
}
// Rescoring a two-state, one-arc lattice leaves the state and arc counts
// unchanged between input and output.
#[test]
fn test_rescore_lattice_simple() {
let mut fst: VectorWfst<u32, LogWeight> = VectorWfst::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some(1), Some(1), s1, LogWeight::new(0.5));
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lm = VectorWfst::<u32, LogWeight>::new();
let config = RescoreConfig::default();
let result = rescore_lattice(&lattice, &lm, &config);
assert_eq!(result.stats.input_states, 2);
assert_eq!(result.stats.output_states, 2);
assert_eq!(result.stats.input_arcs, 1);
assert_eq!(result.stats.output_arcs, 1);
}
// Two language models produce two results.
#[test]
fn test_multi_pass_rescore() {
let mut fst: VectorWfst<u32, LogWeight> = VectorWfst::new();
let s0 = fst.add_state();
fst.set_start(s0);
fst.set_final(s0, LogWeight::one());
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lm1 = VectorWfst::<u32, LogWeight>::new();
let lm2 = VectorWfst::<u32, LogWeight>::new();
let lm_sequence = vec![lm1, lm2];
let config = RescoreConfig::default();
let results = multi_pass_rescore(&lattice, &lm_sequence, &config);
assert_eq!(results.len(), 2);
}
}
// Property-based tests (proptest) covering the same surface as the unit
// tests. Several properties take an unused `_seed` input; for those, every
// generated case repeats the same deterministic check.
#[cfg(test)]
mod property_tests {
use super::*;
use crate::semiring::LogWeight;
use crate::wfst::Wfst;
use proptest::prelude::*;
// --- RescoreConfig defaults ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn default_config_determinize(_seed in any::<u64>()) {
let config = RescoreConfig::default();
prop_assert!(config.determinize);
}
#[test]
fn default_config_minimize(_seed in any::<u64>()) {
let config = RescoreConfig::default();
prop_assert!(config.minimize);
}
#[test]
fn default_config_no_pruning(_seed in any::<u64>()) {
let config = RescoreConfig::default();
prop_assert!(config.pruning_threshold.is_none());
}
#[test]
fn default_config_no_max_states(_seed in any::<u64>()) {
let config = RescoreConfig::default();
prop_assert!(config.max_states.is_none());
}
#[test]
fn default_config_alpha(_seed in any::<u64>()) {
let config = RescoreConfig::default();
prop_assert!((config.interpolation_alpha - 1.0).abs() < 1e-10);
}
}
// --- RescorePass equality ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn pass_equality_reflexive(pass_num in 0u32..100) {
let pass = RescorePass::AdditionalPass(pass_num);
prop_assert_eq!(pass, pass);
}
#[test]
fn first_ne_second(_seed in any::<u64>()) {
prop_assert_ne!(RescorePass::FirstPass, RescorePass::SecondPass);
}
// The two ranges are disjoint, so `a` and `b` always differ.
#[test]
fn different_additional_passes(a in 0u32..100, b in 100u32..200) {
prop_assert_ne!(
RescorePass::AdditionalPass(a),
RescorePass::AdditionalPass(b)
);
}
}
// --- LatticeStats defaults ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(50))]
#[test]
fn default_lattice_stats(_seed in any::<u64>()) {
let stats = LatticeStats::default();
prop_assert_eq!(stats.num_states, 0);
prop_assert_eq!(stats.num_arcs, 0);
prop_assert!((stats.avg_arcs_per_state - 0.0).abs() < 1e-10);
prop_assert!(stats.density.is_none());
}
}
// --- LatticeGrammar construction ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(30))]
#[test]
fn lattice_grammar_state_count(num_states in 0usize..10) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
for _ in 0..num_states {
fst.add_state();
}
let grammar = LatticeGrammar::new(fst, RescorePass::FirstPass);
prop_assert_eq!(grammar.stats.num_states, num_states);
}
// A linear chain of `num_states` states has `num_states - 1` arcs.
#[test]
fn lattice_grammar_arc_count(num_states in 2usize..6) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
let states: Vec<_> = (0..num_states).map(|_| fst.add_state()).collect();
for i in 0..states.len() - 1 {
fst.add_arc(states[i], Some(i as u32), Some(i as u32), states[i + 1], LogWeight::one());
}
let grammar = LatticeGrammar::new(fst, RescorePass::FirstPass);
prop_assert_eq!(grammar.stats.num_arcs, num_states - 1);
}
// Same linear chain; the average is arcs divided by states.
#[test]
fn lattice_grammar_avg_arcs(num_states in 2usize..6) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
let states: Vec<_> = (0..num_states).map(|_| fst.add_state()).collect();
for i in 0..states.len() - 1 {
fst.add_arc(states[i], Some(i as u32), Some(i as u32), states[i + 1], LogWeight::one());
}
let grammar = LatticeGrammar::new(fst, RescorePass::FirstPass);
let expected_avg = (num_states - 1) as f64 / num_states as f64;
prop_assert!((grammar.stats.avg_arcs_per_state - expected_avg).abs() < 1e-10);
}
#[test]
fn lattice_grammar_source_pass(pass_num in 0u32..100) {
let fst = VectorWfst::<u32, LogWeight>::new();
let pass = RescorePass::AdditionalPass(pass_num);
let grammar = LatticeGrammar::new(fst, pass);
prop_assert_eq!(grammar.source_pass, RescorePass::AdditionalPass(pass_num));
}
#[test]
fn lattice_grammar_density(density in 0.0f64..100.0) {
let fst = VectorWfst::<u32, LogWeight>::new();
let grammar = LatticeGrammar::new(fst, RescorePass::FirstPass)
.with_density(density);
prop_assert_eq!(grammar.stats.density, Some(density));
}
// An FST without states must report an average of 0.0.
#[test]
fn empty_fst_zero_avg(_seed in any::<u64>()) {
let fst = VectorWfst::<u32, LogWeight>::new();
let grammar = LatticeGrammar::new(fst, RescorePass::FirstPass);
prop_assert!((grammar.stats.avg_arcs_per_state - 0.0).abs() < 1e-10);
}
}
// --- RescoreStats ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn default_rescore_stats(_seed in any::<u64>()) {
let stats = RescoreStats::default();
prop_assert_eq!(stats.input_states, 0);
prop_assert_eq!(stats.output_states, 0);
prop_assert_eq!(stats.input_arcs, 0);
prop_assert_eq!(stats.output_arcs, 0);
prop_assert!((stats.state_reduction - 0.0).abs() < 1e-10);
prop_assert!((stats.arc_reduction - 0.0).abs() < 1e-10);
}
#[test]
fn state_reduction_correct(
input_states in 1usize..1000,
output_states in 0usize..1000
) {
// Clamp so the output never exceeds the input.
let output_states = output_states.min(input_states);
let mut stats = RescoreStats {
input_states,
output_states,
..Default::default()
};
stats.compute_reductions();
let expected = 1.0 - (output_states as f64 / input_states as f64);
prop_assert!((stats.state_reduction - expected).abs() < 1e-10);
}
#[test]
fn arc_reduction_correct(
input_arcs in 1usize..1000,
output_arcs in 0usize..1000
) {
// Clamp so the output never exceeds the input.
let output_arcs = output_arcs.min(input_arcs);
let mut stats = RescoreStats {
input_arcs,
output_arcs,
..Default::default()
};
stats.compute_reductions();
let expected = 1.0 - (output_arcs as f64 / input_arcs as f64);
prop_assert!((stats.arc_reduction - expected).abs() < 1e-10);
}
// With zero input states, the state reduction stays at 0.0.
#[test]
fn zero_input_no_reduction(_seed in any::<u64>()) {
let mut stats = RescoreStats {
input_states: 0,
output_states: 0,
..Default::default()
};
stats.compute_reductions();
prop_assert!((stats.state_reduction - 0.0).abs() < 1e-10);
}
}
// --- rescore_lattice ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn rescore_empty_lattice(_seed in any::<u64>()) {
let fst = VectorWfst::<u32, LogWeight>::new();
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lm = VectorWfst::<u32, LogWeight>::new();
let config = RescoreConfig::default();
let result = rescore_lattice(&lattice, &lm, &config);
prop_assert_eq!(result.stats.input_states, 0);
prop_assert_eq!(result.stats.output_states, 0);
}
#[test]
fn rescore_preserves_states(num_states in 1usize..10) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
for _ in 0..num_states {
fst.add_state();
}
// Always true for the generated range 1..10.
if num_states > 0 {
fst.set_start(0);
}
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lm = VectorWfst::<u32, LogWeight>::new();
let config = RescoreConfig::default();
let result = rescore_lattice(&lattice, &lm, &config);
prop_assert_eq!(result.stats.input_states, num_states);
prop_assert_eq!(result.stats.output_states, num_states);
}
// Linear chain: arc count is `num_states - 1` before and after rescoring.
#[test]
fn rescore_preserves_arcs(num_states in 2usize..6) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
let states: Vec<_> = (0..num_states).map(|_| fst.add_state()).collect();
fst.set_start(states[0]);
for i in 0..states.len() - 1 {
fst.add_arc(states[i], Some(i as u32), Some(i as u32), states[i + 1], LogWeight::one());
}
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lm = VectorWfst::<u32, LogWeight>::new();
let config = RescoreConfig::default();
let result = rescore_lattice(&lattice, &lm, &config);
let expected_arcs = num_states - 1;
prop_assert_eq!(result.stats.input_arcs, expected_arcs);
prop_assert_eq!(result.stats.output_arcs, expected_arcs);
}
}
// --- multi_pass_rescore ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(15))]
// One result is returned per language model, including zero models.
#[test]
fn multi_pass_result_count(num_passes in 0usize..5) {
let fst = VectorWfst::<u32, LogWeight>::new();
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lms: Vec<_> = (0..num_passes)
.map(|_| VectorWfst::<u32, LogWeight>::new())
.collect();
let config = RescoreConfig::default();
let results = multi_pass_rescore(&lattice, &lms, &config);
prop_assert_eq!(results.len(), num_passes);
}
#[test]
fn multi_pass_empty_lms(_seed in any::<u64>()) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
let s = fst.add_state();
fst.set_start(s);
fst.set_final(s, LogWeight::one());
let lattice = LatticeGrammar::new(fst, RescorePass::FirstPass);
let lms: Vec<VectorWfst<u32, LogWeight>> = vec![];
let config = RescoreConfig::default();
let results = multi_pass_rescore(&lattice, &lms, &config);
prop_assert!(results.is_empty());
}
}
// --- clone_lattice ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn clone_lattice_states(num_states in 0usize..10) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
for _ in 0..num_states {
fst.add_state();
}
let cloned = clone_lattice(&fst);
prop_assert_eq!(cloned.num_states(), num_states);
}
#[test]
fn clone_lattice_start(num_states in 1usize..10) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
for _ in 0..num_states {
fst.add_state();
}
fst.set_start(0);
let cloned = clone_lattice(&fst);
prop_assert_eq!(cloned.start(), 0);
}
// Every other state is marked final; the copy must agree state by state.
#[test]
fn clone_lattice_finals(num_states in 1usize..5) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
for i in 0..num_states {
let s = fst.add_state();
if i % 2 == 0 {
fst.set_final(s, LogWeight::new(1.0));
}
}
let cloned = clone_lattice(&fst);
for i in 0..num_states as StateId {
prop_assert_eq!(cloned.is_final(i), fst.is_final(i));
}
}
}
// --- count_arcs ---
proptest! {
#![proptest_config(ProptestConfig::with_cases(30))]
// Linear chain: `num_states - 1` arcs.
#[test]
fn count_arcs_linear(num_states in 2usize..10) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
let states: Vec<_> = (0..num_states).map(|_| fst.add_state()).collect();
for i in 0..states.len() - 1 {
fst.add_arc(states[i], Some(i as u32), Some(i as u32), states[i + 1], LogWeight::one());
}
prop_assert_eq!(count_arcs(&fst), num_states - 1);
}
#[test]
fn count_arcs_empty(_seed in any::<u64>()) {
let fst = VectorWfst::<u32, LogWeight>::new();
prop_assert_eq!(count_arcs(&fst), 0);
}
// States without any arcs contribute nothing to the count.
#[test]
fn count_arcs_no_arcs(num_states in 1usize..10) {
let mut fst = VectorWfst::<u32, LogWeight>::new();
for _ in 0..num_states {
fst.add_state();
}
prop_assert_eq!(count_arcs(&fst), 0);
}
}
}