mod forward_score;
mod gradient;
mod layers;
mod marginalization;
mod ngram_pruning;
mod second_order;
mod token_graphs;
mod topdown;
mod viterbi;
pub use forward_score::{forward_score, log_sum_exp_paths};
pub use gradient::{backward, ArcGradient, GradientAccumulator, GradientWfst};
pub use viterbi::{viterbi_path_with_grad, viterbi_score, ViterbiGradResult};
pub use layers::{
wfst_conv_backward, wfst_conv_forward, wfst_conv_forward_with_gradients, PaddingMode,
ReceptiveField, WfstConvConfig, WfstConvLayer, WfstConvOutput, WfstKernel,
};
pub use token_graphs::{
build_blank_graph, build_token_graph, build_vocabulary_graph, TokenGraphConfig,
TokenGraphStats, TokenGraphType, TokenId, BLANK_TOKEN,
};
pub use marginalization::{
build_character_lexicon, build_identity_lexicon, build_lexicon_transducer, build_target_graph,
marginalized_loss, GraphemeId, LexiconConfig, LexiconEntry, MarginalizationContext,
MarginalizationResult, MarginalizationStats, WordPieceId,
};
pub use ngram_pruning::{
build_pruned_bigram_graph, build_pruned_trigram_graph, NgramCounts, PrunedNgramConfig,
PrunedNgramStats,
};
pub use second_order::{
compute_diagonal_hessian, compute_fisher_information, gradient_and_hessian,
hessian_vector_product, natural_gradient, HessianMatrix, SecondOrderConfig, SecondOrderResult,
SecondOrderWfst,
};
pub use topdown::{
composed_backward, count_arcs, forward_backward, pruned_search_backward, topdown_backward,
BackwardStats, ComposedArcMap, ComposedBackwardResult, ComposedState, ForwardBackwardScores,
PrunedBackwardConfig, PrunedSearchResult, SparseGradient,
};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::{LogWeight, Semiring};
    use crate::wfst::{MutableWfst, VectorWfst};

    // Asserts that two floating-point expressions agree to within 1e-6.
    // Expands to the plain `(actual - expected).abs() < 1e-6` form so the
    // numeric type is inferred independently at every call site.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr) => {
            assert!(($actual - $expected).abs() < 1e-6)
        };
    }

    /// Builds a two-state graph `start -> end` in which `start` is the start
    /// state, `end` is final with weight `LogWeight::one()`, and each
    /// `(label, weight)` pair becomes one arc (input label == output label),
    /// added in iteration order.
    fn two_state_fst<I>(arcs: I) -> VectorWfst<char, LogWeight>
    where
        I: IntoIterator<Item = (char, LogWeight)>,
    {
        let mut fst = VectorWfst::<char, LogWeight>::new();
        let start = fst.add_state();
        let end = fst.add_state();
        fst.set_start(start);
        fst.set_final(end, LogWeight::one());
        for (label, weight) in arcs {
            fst.add_arc(start, Some(label), Some(label), end, weight);
        }
        fst
    }

    /// A lone arc of weight -1.0 into a final state of weight `one()` should
    /// give a forward score of exactly that arc weight.
    #[test]
    fn test_forward_score_single_path() {
        let fst = two_state_fst([('a', LogWeight::new(-1.0))]);
        let graph = GradientWfst::from_wfst(&fst);
        let score = forward_score(&graph);
        assert_close!(score.value(), -1.0);
    }

    /// Two parallel arcs (weights 1.0 and 2.0): the forward score is expected
    /// to be `-ln(e^-1 + e^-2)`, i.e. the log-sum over both paths.
    #[test]
    fn test_forward_score_two_paths() {
        let fst = two_state_fst([('a', LogWeight::new(1.0)), ('b', LogWeight::new(2.0))]);
        let graph = GradientWfst::from_wfst(&fst);
        let score = forward_score(&graph);
        let total_mass = (-1.0_f64).exp() + (-2.0_f64).exp();
        assert_close!(score.value(), -total_mass.ln());
    }

    /// Two parallel arcs (weights -1.0 and -2.0): the Viterbi score is
    /// expected to equal the smaller weight, -2.0.
    #[test]
    fn test_viterbi_score() {
        let fst = two_state_fst([('a', LogWeight::new(-1.0)), ('b', LogWeight::new(-2.0))]);
        let graph = GradientWfst::from_wfst(&fst);
        let score = viterbi_score(&graph);
        assert_close!(score.value(), -2.0);
    }

    /// With a single arc, `backward` should report exactly one arc gradient,
    /// and its value should be 1.0.
    #[test]
    fn test_backward_gradients() {
        let fst = two_state_fst([('a', LogWeight::new(-1.0))]);
        let graph = GradientWfst::from_wfst(&fst);
        // The forward pass is run before `backward`; presumably `backward`
        // relies on it having been computed — TODO confirm in `gradient`.
        let _ = forward_score(&graph);
        let gradients = backward(&graph);
        assert_eq!(gradients.arc_gradients.len(), 1);
        let only_grad = gradients.arc_gradients[0].gradient;
        assert_close!(only_grad, 1.0);
    }
}