use std::collections::HashMap;
use crate::semiring::Semiring;
use crate::wfst::{StateId, VectorWfst, Wfst};
/// Sparse map from arc index to gradient value.
///
/// Only explicitly stored arcs occupy memory; every other arc reads back as
/// `0.0`. `num_arcs` is the nominal length of the equivalent dense gradient
/// and is used by [`SparseGradient::sparsity`] and [`SparseGradient::to_dense`].
#[derive(Debug, Clone, Default)]
pub struct SparseGradient {
    gradients: HashMap<usize, f64>,
    num_arcs: usize,
}

impl SparseGradient {
    /// Magnitude at or below which [`SparseGradient::set`] treats a value as zero.
    const ZERO_THRESHOLD: f64 = 1e-10;

    /// Creates an empty gradient for a graph with `num_arcs` arcs.
    pub fn new(num_arcs: usize) -> Self {
        Self {
            gradients: HashMap::new(),
            num_arcs,
        }
    }

    /// Sets the gradient of `arc_id` to `value`.
    ///
    /// Values whose magnitude does not exceed `1e-10` (including NaN) are
    /// treated as zero: nothing is stored, and any entry previously stored for
    /// `arc_id` is removed so that `get` reflects the latest write.
    #[inline]
    pub fn set(&mut self, arc_id: usize, value: f64) {
        if value.abs() > Self::ZERO_THRESHOLD {
            self.gradients.insert(arc_id, value);
        } else {
            // Without this removal a stale non-zero entry would survive an
            // explicit overwrite with (near-)zero.
            self.gradients.remove(&arc_id);
        }
    }

    /// Returns the gradient stored for `arc_id`, or `0.0` if none is stored.
    #[inline]
    pub fn get(&self, arc_id: usize) -> f64 {
        self.gradients.get(&arc_id).copied().unwrap_or(0.0)
    }

    /// Adds `value` to the gradient of `arc_id`, creating the entry if needed.
    ///
    /// Unlike `set`, no threshold is applied, so an entry is created even when
    /// `value` is zero.
    #[inline]
    pub fn add(&mut self, arc_id: usize, value: f64) {
        *self.gradients.entry(arc_id).or_insert(0.0) += value;
    }

    /// Number of explicitly stored entries.
    pub fn nnz(&self) -> usize {
        self.gradients.len()
    }

    /// Nominal number of arcs this gradient was created for.
    pub fn num_arcs(&self) -> usize {
        self.num_arcs
    }

    /// Fraction of arcs without a stored entry: `1 - nnz / num_arcs`.
    ///
    /// Returns `0.0` when `num_arcs` is zero.
    pub fn sparsity(&self) -> f64 {
        if self.num_arcs == 0 {
            0.0
        } else {
            1.0 - (self.gradients.len() as f64 / self.num_arcs as f64)
        }
    }

    /// Expands into a dense vector of length `num_arcs`.
    ///
    /// Entries whose arc index is `>= num_arcs` are silently skipped.
    pub fn to_dense(&self) -> Vec<f64> {
        let mut dense = vec![0.0; self.num_arcs];
        for (&arc_id, &value) in &self.gradients {
            if let Some(slot) = dense.get_mut(arc_id) {
                *slot = value;
            }
        }
        dense
    }

    /// Iterates over the stored `(arc_id, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (usize, f64)> + '_ {
        self.gradients.iter().map(|(&k, &v)| (k, v))
    }

    /// Multiplies every stored value by `factor` in place.
    pub fn scale(&mut self, factor: f64) {
        for value in self.gradients.values_mut() {
            *value *= factor;
        }
    }

    /// Accumulates every entry of `other` into `self`.
    pub fn add_sparse(&mut self, other: &SparseGradient) {
        for (&arc_id, &value) in &other.gradients {
            self.add(arc_id, value);
        }
    }
}
/// Return value of [`composed_backward`]: one sparse gradient per input FST
/// plus bookkeeping statistics.
#[derive(Debug)]
pub struct ComposedBackwardResult {
/// Gradient indexed by the `arc1` origins of the composed arcs (sized by the arc count of `fst1`).
pub grad1: SparseGradient,
/// Gradient indexed by the `arc2` origins of the composed arcs (sized by the arc count of `fst2`).
pub grad2: SparseGradient,
/// Counters collected while accumulating the gradients.
pub stats: BackwardStats,
}
/// Counters describing one backward pass; all fields default to zero.
///
/// The field notes below describe how [`composed_backward`] fills them in.
#[derive(Debug, Clone, Default)]
pub struct BackwardStats {
/// Set to the length of the composed `alpha` vector.
pub states_visited: usize,
/// Incremented once per gradient contribution written to `grad1` or `grad2`
/// (the same arc may therefore be counted more than once).
pub nonzero_arcs: usize,
/// Sum of the absolute values of the per-composed-arc gradient values.
pub total_gradient_mass: f64,
}
/// Log-domain forward/backward scores for the states of an FST.
///
/// Unreached states hold `f64::NEG_INFINITY`.
#[derive(Debug, Clone)]
pub struct ForwardBackwardScores {
    /// Forward score per state.
    pub alpha: Vec<f64>,
    /// Backward score per state.
    pub beta: Vec<f64>,
    /// Total log score used as the normaliser in [`ForwardBackwardScores::arc_posterior`].
    pub total_log_prob: f64,
}

impl ForwardBackwardScores {
    /// Creates scores for `num_states` states with every entry, and the total,
    /// initialised to `f64::NEG_INFINITY`.
    pub fn new(num_states: usize) -> Self {
        Self {
            alpha: vec![f64::NEG_INFINITY; num_states],
            beta: vec![f64::NEG_INFINITY; num_states],
            total_log_prob: f64::NEG_INFINITY,
        }
    }

    /// Posterior of an arc: `exp(from_alpha + arc_weight + to_beta - total_log_prob)`.
    ///
    /// Returns `0.0` whenever the log posterior is not finite. This covers an
    /// unreachable arc (`-inf`) as well as a degenerate normaliser: with
    /// `total_log_prob == -inf` the difference would be `+inf` or NaN, and
    /// exponentiating it would yield an infinite "probability".
    #[inline]
    pub fn arc_posterior(&self, from_alpha: f64, arc_weight: f64, to_beta: f64) -> f64 {
        let log_posterior = from_alpha + arc_weight + to_beta - self.total_log_prob;
        if log_posterior.is_finite() {
            log_posterior.exp()
        } else {
            0.0
        }
    }
}
/// Computes log-domain forward (`alpha`) and backward (`beta`) scores for `fst`.
///
/// Arc and final weights are converted to `f64` and combined additively along
/// a path and with `log_add` across paths. `alpha[start]` is initialised to
/// `0.0`, `beta` of each final state to its final weight, and
/// `total_log_prob` is `alpha[start] + beta[start]`.
///
/// States are expanded in a Kahn-style order: a state is expanded once all of
/// its incoming arcs (forward pass) or outgoing arcs (backward pass) have been
/// accounted for. States that lie on a cycle never reach that point and are
/// therefore not expanded, so exact results are only obtained for acyclic
/// graphs.
///
/// Returns scores filled with `f64::NEG_INFINITY` when the FST has no states
/// or when its start state id is not a valid state index (instead of
/// panicking on the out-of-range index).
pub fn forward_backward<L, W>(fst: &VectorWfst<L, W>) -> ForwardBackwardScores
where
    L: Clone + Eq + std::hash::Hash + Send + Sync,
    W: Semiring + Into<f64> + Clone,
{
    use std::collections::VecDeque;

    let num_states = fst.num_states();
    let mut scores = ForwardBackwardScores::new(num_states);
    if num_states == 0 {
        return scores;
    }

    let start = fst.start();
    let start_idx = start as usize;
    if start_idx >= num_states {
        // No usable start state: there are no paths to score.
        return scores;
    }
    scores.alpha[start_idx] = 0.0;

    // ---- Forward pass -------------------------------------------------
    let mut remaining_in = vec![0usize; num_states];
    for state in 0..num_states as StateId {
        for tr in fst.transitions(state) {
            remaining_in[tr.to as usize] += 1;
        }
    }

    let mut queue: VecDeque<StateId> = VecDeque::new();
    let mut queued = vec![false; num_states];
    // The start state is always expanded first, whatever its in-degree.
    queue.push_back(start);
    queued[start_idx] = true;
    // Other source states are seeded too so that the in-degree counters of
    // their successors can still drain to zero.
    for state in 0..num_states as StateId {
        if remaining_in[state as usize] == 0 && state != start {
            queue.push_back(state);
            queued[state as usize] = true;
        }
    }

    while let Some(state) = queue.pop_front() {
        let src = state as usize;
        // Unreached states contribute no score but still release their successors.
        let reachable = scores.alpha[src] != f64::NEG_INFINITY;
        for tr in fst.transitions(state) {
            let dst = tr.to as usize;
            if reachable {
                let arc_weight: f64 = tr.weight.clone().into();
                // Read alpha[src] afresh for each arc: a self-loop on the start
                // state updates it while its arcs are being expanded.
                let candidate = scores.alpha[src] + arc_weight;
                scores.alpha[dst] = log_add(scores.alpha[dst], candidate);
            }
            remaining_in[dst] = remaining_in[dst].saturating_sub(1);
            if remaining_in[dst] == 0 && !queued[dst] {
                queue.push_back(tr.to);
                queued[dst] = true;
            }
        }
    }

    // ---- Backward pass ------------------------------------------------
    for state in 0..num_states as StateId {
        if fst.is_final(state) {
            let final_weight: f64 = fst.final_weight(state).into();
            scores.beta[state as usize] = final_weight;
        }
    }

    // Out-degree counters and reversed adjacency (destination -> [(source, weight)]).
    let mut remaining_out = vec![0usize; num_states];
    let mut reverse_adj: Vec<Vec<(StateId, f64)>> = vec![Vec::new(); num_states];
    for state in 0..num_states as StateId {
        remaining_out[state as usize] = fst.transitions(state).len();
        for tr in fst.transitions(state) {
            let arc_weight: f64 = tr.weight.clone().into();
            reverse_adj[tr.to as usize].push((state, arc_weight));
        }
    }

    let mut reverse_queue: VecDeque<StateId> = VecDeque::new();
    let mut reverse_queued = vec![false; num_states];
    for state in 0..num_states as StateId {
        if remaining_out[state as usize] == 0 {
            reverse_queue.push_back(state);
            reverse_queued[state as usize] = true;
        }
    }

    while let Some(state) = reverse_queue.pop_front() {
        // A popped state has had all of its outgoing arcs processed, so its
        // beta is final and can be read once.
        let beta_s = scores.beta[state as usize];
        for &(from_state, arc_weight) in &reverse_adj[state as usize] {
            let from = from_state as usize;
            if beta_s > f64::NEG_INFINITY {
                scores.beta[from] = log_add(scores.beta[from], arc_weight + beta_s);
            }
            remaining_out[from] = remaining_out[from].saturating_sub(1);
            if remaining_out[from] == 0 && !reverse_queued[from] {
                reverse_queue.push_back(from_state);
                reverse_queued[from] = true;
            }
        }
    }

    scores.total_log_prob = scores.alpha[start_idx] + scores.beta[start_idx];
    scores
}
/// Turns forward/backward scores into a sparse per-arc gradient.
///
/// Arcs are numbered consecutively in state order, then in transition order
/// within a state. For every arc whose source has a reachable `alpha` and
/// whose destination has a reachable `beta`, the arc posterior is computed
/// with [`ForwardBackwardScores::arc_posterior`]; posteriors above `1e-10`
/// are stored negated under the arc's index.
///
/// The returned gradient is sized by the total number of arcs in `fst`.
pub fn topdown_backward<L, W>(
    fst: &VectorWfst<L, W>,
    fb_scores: &ForwardBackwardScores,
) -> SparseGradient
where
    L: Clone + Eq + std::hash::Hash + Send + Sync,
    W: Semiring + Into<f64> + Clone,
{
    let state_count = fst.num_states();
    let arc_total: usize = (0..state_count as StateId)
        .map(|state| fst.transitions(state).len())
        .sum();
    let mut gradients = SparseGradient::new(arc_total);

    // Index of the first arc of the state currently being visited.
    let mut next_arc = 0usize;
    for state in 0..state_count as StateId {
        let alpha_s = fb_scores.alpha[state as usize];
        let base = next_arc;
        // Advance past this state's arcs up front so skipped arcs keep their indices.
        next_arc += fst.transitions(state).len();
        if alpha_s <= f64::NEG_INFINITY {
            continue;
        }

        for (offset, tr) in fst.transitions(state).iter().enumerate() {
            let beta_t = fb_scores.beta[tr.to as usize];
            if beta_t <= f64::NEG_INFINITY {
                continue;
            }
            let arc_weight: f64 = tr.weight.clone().into();
            let posterior = fb_scores.arc_posterior(alpha_s, arc_weight, beta_t);
            if posterior > 1e-10 {
                gradients.set(base + offset, -posterior);
            }
        }
    }
    gradients
}
/// Options for [`pruned_search_backward`].
#[derive(Debug, Clone)]
pub struct PrunedBackwardConfig {
    /// Beam width. NOTE(review): `pruned_search_backward` in this file does not
    /// read this field — confirm whether other callers rely on it.
    pub beam: f64,
    /// When `true`, the resulting gradients are rescaled so that their
    /// absolute values sum to one.
    pub normalize: bool,
    /// Arc posteriors at or below this value are not stored.
    pub min_gradient: f64,
}

impl Default for PrunedBackwardConfig {
    /// Defaults: `beam = 10.0`, `normalize = true`, `min_gradient = 1e-10`.
    fn default() -> Self {
        let beam = 10.0;
        let normalize = true;
        let min_gradient = 1e-10;
        PrunedBackwardConfig {
            beam,
            normalize,
            min_gradient,
        }
    }
}
/// Record of which states and arcs survived a pruned search, with the forward
/// score of each surviving state.
///
/// Surviving states and arcs are renumbered densely in insertion order;
/// `forward_scores` is indexed by the pruned state id.
#[derive(Debug)]
pub struct PrunedSearchResult<W: Semiring> {
    /// Original state id -> pruned (dense) state id.
    pub surviving_states: HashMap<StateId, StateId>,
    /// Original arc index -> pruned (dense) arc index.
    pub surviving_arcs: HashMap<usize, usize>,
    /// Forward score of each surviving state, indexed by pruned state id.
    pub forward_scores: Vec<f64>,
    /// Largest forward score passed to [`PrunedSearchResult::add_state`] so far.
    pub best_score: f64,
    /// Beam value supplied at construction.
    pub beam: f64,
    _phantom: std::marker::PhantomData<W>,
}

impl<W: Semiring> PrunedSearchResult<W> {
    /// Creates an empty result with the given `beam` and a `best_score` of
    /// `f64::NEG_INFINITY`.
    pub fn new(beam: f64) -> Self {
        Self {
            surviving_states: HashMap::new(),
            surviving_arcs: HashMap::new(),
            forward_scores: Vec::new(),
            best_score: f64::NEG_INFINITY,
            beam,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Registers `original` as a surviving state with the given forward score.
    ///
    /// A new state receives the next dense id. Re-registering a state keeps
    /// its existing id and overwrites its score; assigning a fresh id here
    /// would hand the same id to the next new state and leave
    /// `forward_scores` out of step with the id map. `best_score` is raised
    /// if `forward_score` exceeds it and is never lowered.
    pub fn add_state(&mut self, original: StateId, forward_score: f64) {
        let existing = self.surviving_states.get(&original).copied();
        match existing {
            Some(pruned_id) => {
                if let Some(slot) = self.forward_scores.get_mut(pruned_id as usize) {
                    *slot = forward_score;
                }
            }
            None => {
                let pruned_id = self.surviving_states.len() as StateId;
                self.surviving_states.insert(original, pruned_id);
                self.forward_scores.push(forward_score);
            }
        }
        if forward_score > self.best_score {
            self.best_score = forward_score;
        }
    }

    /// Registers `original` as a surviving arc, assigning the next dense arc id.
    ///
    /// Re-registering an arc keeps its existing id, so ids stay unique.
    pub fn add_arc(&mut self, original: usize) {
        let next_id = self.surviving_arcs.len();
        self.surviving_arcs.entry(original).or_insert(next_id);
    }

    /// Returns `true` if the original state id was registered.
    pub fn state_survived(&self, state: StateId) -> bool {
        self.surviving_states.contains_key(&state)
    }

    /// Returns `true` if the original arc index was registered.
    pub fn arc_survived(&self, arc_id: usize) -> bool {
        self.surviving_arcs.contains_key(&arc_id)
    }
}
/// Computes arc gradients restricted to the part of `fst` that survived a
/// pruned search.
///
/// Original arcs are numbered consecutively in state order, then transition
/// order. An arc takes part only if its index, its source state and its
/// destination state are all registered in `search_result`. Forward scores
/// come from `search_result.forward_scores`; backward scores are computed
/// here over the surviving arcs. Each arc whose posterior exceeds
/// `config.min_gradient` is stored as `-posterior * output_grad` under its
/// pruned arc id. With `config.normalize`, the stored values are finally
/// rescaled so their absolute values sum to one. `config.beam` is not read.
///
/// Backward scores are obtained by repeated sweeps, each of which rebuilds
/// every `beta` from the final weights plus the previous sweep's successor
/// values. Rebuilding (rather than accumulating into the running value) is
/// required because `log_add` is not idempotent: accumulating would count
/// the same path once per sweep. At most `surviving states + 1` sweeps are
/// run, which is enough to converge on an acyclic surviving graph.
///
/// Returns an empty gradient when the total log score is not finite (for
/// example when the start state did not survive), since no posterior is
/// defined in that case.
pub fn pruned_search_backward<L, W>(
    fst: &VectorWfst<L, W>,
    search_result: &PrunedSearchResult<W>,
    output_grad: f64,
    config: &PrunedBackwardConfig,
) -> SparseGradient
where
    L: Clone + Eq + std::hash::Hash + Send + Sync,
    W: Semiring + Into<f64> + Clone,
{
    let num_surviving_arcs = search_result.surviving_arcs.len();
    let num_surviving_states = search_result.surviving_states.len();
    let mut gradients = SparseGradient::new(num_surviving_arcs);

    // Gather the surviving arcs once as
    // (pruned source, pruned destination, log weight, pruned arc id).
    let mut arcs: Vec<(usize, usize, f64, usize)> = Vec::with_capacity(num_surviving_arcs);
    let mut arc_id = 0usize;
    for state in 0..fst.num_states() as StateId {
        let pruned_from = match search_result.surviving_states.get(&state) {
            Some(&id) => id as usize,
            None => {
                // Skip the whole state but keep the arc numbering aligned.
                arc_id += fst.transitions(state).len();
                continue;
            }
        };
        for tr in fst.transitions(state) {
            if let Some(&pruned_arc_id) = search_result.surviving_arcs.get(&arc_id) {
                if let Some(&pruned_to) = search_result.surviving_states.get(&tr.to) {
                    let arc_weight: f64 = tr.weight.clone().into();
                    arcs.push((pruned_from, pruned_to as usize, arc_weight, pruned_arc_id));
                }
            }
            arc_id += 1;
        }
    }

    // Backward scores before any arc is considered: final weights only.
    let mut base_beta = vec![f64::NEG_INFINITY; num_surviving_states];
    for (&orig_state, &pruned_id) in &search_result.surviving_states {
        if fst.is_final(orig_state) {
            let final_weight: f64 = fst.final_weight(orig_state).into();
            base_beta[pruned_id as usize] = final_weight;
        }
    }

    let mut beta = base_beta.clone();
    for _ in 0..=num_surviving_states {
        let mut next = base_beta.clone();
        for &(from, to, weight, _) in &arcs {
            if beta[to] > f64::NEG_INFINITY {
                next[from] = log_add(next[from], weight + beta[to]);
            }
        }
        // `-inf - -inf` is NaN, which compares false and so counts as unchanged.
        let changed = next
            .iter()
            .zip(&beta)
            .any(|(new, old)| (new - old).abs() > 1e-10);
        beta = next;
        if !changed {
            break;
        }
    }

    let forward_score = |pruned: usize| {
        search_result
            .forward_scores
            .get(pruned)
            .copied()
            .unwrap_or(f64::NEG_INFINITY)
    };

    let total_log_prob = match search_result.surviving_states.get(&fst.start()) {
        Some(&pruned_start) => {
            forward_score(pruned_start as usize) + beta[pruned_start as usize]
        }
        None => f64::NEG_INFINITY,
    };
    if !total_log_prob.is_finite() {
        // Dividing by a zero total would yield infinite posteriors and, after
        // normalisation, NaN gradients.
        return gradients;
    }

    for &(from, to, weight, pruned_arc_id) in &arcs {
        let log_posterior = forward_score(from) + weight + beta[to] - total_log_prob;
        let posterior = if log_posterior > f64::NEG_INFINITY {
            log_posterior.exp()
        } else {
            0.0
        };
        if posterior > config.min_gradient {
            gradients.set(pruned_arc_id, -posterior * output_grad);
        }
    }

    if config.normalize && gradients.nnz() > 0 {
        let sum: f64 = gradients.iter().map(|(_, g)| g.abs()).sum();
        if sum > 1e-10 {
            gradients.scale(1.0 / sum);
        }
    }
    gradients
}
/// A pair of state ids, usable as a hash-map key.
///
/// NOTE(review): presumably identifies a state of a composed FST by its
/// component states in the first and second operand; nothing else in this
/// file references it — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ComposedState {
/// First component state id.
pub s1: StateId,
/// Second component state id.
pub s2: StateId,
}
/// Description of one composed arc, as consumed by [`composed_backward`].
#[derive(Debug, Clone, Copy)]
pub struct ComposedArcInfo {
/// Source state; used as an index into the composed `alpha` scores.
pub source: StateId,
/// Destination state; used as an index into the composed `beta` scores.
pub dest: StateId,
/// Arc weight passed to `arc_posterior` together with `alpha` and `beta`.
pub log_weight: f64,
/// Index of the originating arc for the first gradient (`grad1`), if any.
pub arc1: Option<usize>,
/// Index of the originating arc for the second gradient (`grad2`), if any.
pub arc2: Option<usize>,
}
/// Maps composed arcs back to the arcs they originate from.
///
/// Two kinds of records are kept: a plain `composed arc -> (arc1, arc2)`
/// table, and an optional ordered list of [`ComposedArcInfo`] entries that
/// additionally carry endpoints and weight.
#[derive(Debug, Clone, Default)]
pub struct ComposedArcMap {
    arc_origins: HashMap<usize, (Option<usize>, Option<usize>)>,
    arc_info: Vec<ComposedArcInfo>,
}

impl ComposedArcMap {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records the origins of `composed_arc`, replacing any earlier record
    /// for the same index. No [`ComposedArcInfo`] entry is created.
    pub fn add(&mut self, composed_arc: usize, arc1: Option<usize>, arc2: Option<usize>) {
        let origins = (arc1, arc2);
        self.arc_origins.insert(composed_arc, origins);
    }

    /// Appends a full arc description.
    ///
    /// The composed arc index is the number of descriptions stored so far;
    /// the origins are recorded under that index as well.
    pub fn add_with_info(
        &mut self,
        source: StateId,
        dest: StateId,
        log_weight: f64,
        arc1: Option<usize>,
        arc2: Option<usize>,
    ) {
        let composed_arc = self.arc_info.len();
        let info = ComposedArcInfo {
            source,
            dest,
            log_weight,
            arc1,
            arc2,
        };
        self.arc_origins.insert(composed_arc, (arc1, arc2));
        self.arc_info.push(info);
    }

    /// Looks up the recorded origins of `composed_arc`.
    pub fn get(&self, composed_arc: usize) -> Option<(Option<usize>, Option<usize>)> {
        match self.arc_origins.get(&composed_arc) {
            Some(&origins) => Some(origins),
            None => None,
        }
    }

    /// Iterates over the stored arc descriptions in insertion order.
    pub fn arc_infos(&self) -> impl Iterator<Item = &ComposedArcInfo> {
        self.arc_info.iter()
    }

    /// Returns `true` if at least one arc description has been stored.
    pub fn has_arc_info(&self) -> bool {
        self.arc_info.len() > 0
    }
}
/// Distributes the gradient of a composed FST back onto the arcs of its two
/// input FSTs.
///
/// `fst1` and `fst2` are only used to size the two result gradients.
///
/// * When `arc_map` carries arc descriptions, each composed arc's posterior is
///   computed from `composed_fb`; arcs whose endpoints fall outside the score
///   vectors or whose posterior is `<= 0` are skipped. The value
///   `-posterior * output_grad` is added to the originating arc of each input
///   FST that the composed arc names.
/// * Otherwise (legacy path) every recorded origin pair receives an equal
///   share `-output_grad / n`, where `n` is the number of recorded pairs.
///
/// `stats.nonzero_arcs` counts individual contributions,
/// `stats.total_gradient_mass` sums the absolute per-composed-arc values, and
/// `stats.states_visited` is the number of composed states.
pub fn composed_backward<L, W>(
    fst1: &VectorWfst<L, W>,
    fst2: &VectorWfst<L, W>,
    composed_fb: &ForwardBackwardScores,
    arc_map: &ComposedArcMap,
    output_grad: f64,
) -> ComposedBackwardResult
where
    L: Clone + Eq + std::hash::Hash + Send + Sync,
    W: Semiring + Into<f64> + Clone,
{
    let arc_count1: usize = (0..fst1.num_states() as StateId)
        .map(|state| fst1.transitions(state).len())
        .sum();
    let arc_count2: usize = (0..fst2.num_states() as StateId)
        .map(|state| fst2.transitions(state).len())
        .sum();

    let mut grad1 = SparseGradient::new(arc_count1);
    let mut grad2 = SparseGradient::new(arc_count2);
    let mut stats = BackwardStats::default();

    // Adds one composed arc's value to whichever input arcs it came from.
    let mut deposit = |origin1: Option<usize>, origin2: Option<usize>, value: f64| {
        if let Some(index) = origin1 {
            grad1.add(index, value);
            stats.nonzero_arcs += 1;
        }
        if let Some(index) = origin2 {
            grad2.add(index, value);
            stats.nonzero_arcs += 1;
        }
        stats.total_gradient_mass += value.abs();
    };

    if !arc_map.has_arc_info() {
        // Legacy path: no per-arc scores available, spread the gradient evenly.
        let origin_count = arc_map.arc_origins.len();
        let share = if origin_count > 0 {
            1.0 / origin_count as f64
        } else {
            0.0
        };
        let value = -output_grad * share;
        for &(origin1, origin2) in arc_map.arc_origins.values() {
            deposit(origin1, origin2, value);
        }
    } else {
        let alpha = &composed_fb.alpha;
        let beta = &composed_fb.beta;
        for info in arc_map.arc_infos() {
            let (src, dst) = (info.source as usize, info.dest as usize);
            if src >= alpha.len() || dst >= beta.len() {
                continue;
            }
            let posterior = composed_fb.arc_posterior(alpha[src], info.log_weight, beta[dst]);
            if posterior <= 0.0 {
                continue;
            }
            deposit(info.arc1, info.arc2, -posterior * output_grad);
        }
    }

    stats.states_visited = composed_fb.alpha.len();
    ComposedBackwardResult {
        grad1,
        grad2,
        stats,
    }
}
/// Returns `ln(exp(a) + exp(b))` without leaving the log domain.
///
/// `f64::NEG_INFINITY` acts as the identity. The larger operand is factored
/// out so the exponential never overflows, and `ln_1p` is used so that a
/// contribution much smaller than the larger operand is not rounded away
/// before the logarithm is taken.
#[inline]
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    hi + (lo - hi).exp().ln_1p()
}
/// Returns the total number of transitions in `fst`, summed over all states.
pub fn count_arcs<L, W>(fst: &VectorWfst<L, W>) -> usize
where
    L: Clone + Eq + std::hash::Hash + Send + Sync,
    W: Semiring,
{
    (0..fst.num_states() as StateId)
        .map(|state| fst.transitions(state).len())
        .sum()
}
// Unit and property tests for the sparse gradient container, the
// forward/backward scores, and the backward (gradient) passes in this file.
#[cfg(test)]
mod tests {
use super::*;
use crate::semiring::LogWeight;
use crate::wfst::MutableWfst;
// `set` stores values, `get` returns them, and unset arcs read as 0.0.
#[test]
fn test_sparse_gradient_basic() {
let mut grad = SparseGradient::new(10);
grad.set(0, 0.5);
grad.set(5, -0.3);
assert_eq!(grad.nnz(), 2);
assert!((grad.get(0) - 0.5).abs() < 1e-10);
assert!((grad.get(5) - (-0.3)).abs() < 1e-10);
assert!((grad.get(3) - 0.0).abs() < 1e-10);
}
// Repeated `add` calls on the same arc accumulate.
#[test]
fn test_sparse_gradient_add() {
let mut grad = SparseGradient::new(10);
grad.add(0, 0.5);
grad.add(0, 0.3);
assert!((grad.get(0) - 0.8).abs() < 1e-10);
}
// Two stored entries out of 100 arcs give a sparsity of 0.98.
#[test]
fn test_sparse_gradient_sparsity() {
let mut grad = SparseGradient::new(100);
grad.set(0, 0.5);
grad.set(50, 0.3);
assert!((grad.sparsity() - 0.98).abs() < 1e-10);
}
// `to_dense` yields `num_arcs` slots with stored values in place and zeros elsewhere.
#[test]
fn test_sparse_gradient_to_dense() {
let mut grad = SparseGradient::new(5);
grad.set(1, 0.5);
grad.set(3, -0.3);
let dense = grad.to_dense();
assert_eq!(dense.len(), 5);
assert!((dense[0] - 0.0).abs() < 1e-10);
assert!((dense[1] - 0.5).abs() < 1e-10);
assert!((dense[3] - (-0.3)).abs() < 1e-10);
}
// Two states joined by one arc of weight 1.0: expects alpha = [0, 1],
// beta = [1, 0] and a total of 1.0.
#[test]
fn test_forward_backward_single_path() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0));
let fb = forward_backward(&fst);
assert!((fb.alpha[0] - 0.0).abs() < 1e-6);
assert!((fb.alpha[1] - 1.0).abs() < 1e-6);
assert!((fb.beta[1] - 0.0).abs() < 1e-6);
assert!((fb.beta[0] - 1.0).abs() < 1e-6);
assert!((fb.total_log_prob - 1.0).abs() < 1e-6);
}
// Two parallel arcs (weights 1.0 and 2.0): alpha of the destination is their log-sum.
#[test]
fn test_forward_backward_two_paths() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0));
fst.add_arc(s0, Some('b'), Some('b'), s1, LogWeight::new(2.0));
let fb = forward_backward(&fst);
let expected_alpha1 = log_add(1.0, 2.0);
assert!((fb.alpha[1] - expected_alpha1).abs() < 1e-6);
}
// A single path carries all the posterior mass, so its arc gradient is -1.0.
#[test]
fn test_topdown_backward_single_path() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0));
let fb = forward_backward(&fst);
let grads = topdown_backward(&fst, &fb);
assert_eq!(grads.nnz(), 1);
assert!((grads.get(0) - (-1.0)).abs() < 1e-6);
}
// Survival lookups and best-score tracking on `PrunedSearchResult`.
#[test]
fn test_pruned_search_result() {
let mut result = PrunedSearchResult::<LogWeight>::new(10.0);
result.add_state(0, -5.0);
result.add_state(1, -3.0);
result.add_arc(0);
assert!(result.state_survived(0));
assert!(result.state_survived(1));
assert!(!result.state_survived(2));
assert!(result.arc_survived(0));
assert!(!result.arc_survived(1));
assert!((result.best_score - (-3.0)).abs() < 1e-10);
}
// `add` / `get` round trip for origin pairs, including one-sided origins and a missing key.
#[test]
fn test_composed_arc_map() {
let mut map = ComposedArcMap::new();
map.add(0, Some(0), Some(1));
map.add(1, Some(2), None);
map.add(2, None, Some(3));
assert_eq!(map.get(0), Some((Some(0), Some(1))));
assert_eq!(map.get(1), Some((Some(2), None)));
assert_eq!(map.get(2), Some((None, Some(3))));
assert_eq!(map.get(3), None);
}
// log_add(0, 0) is about ln 2, and negative infinity is the identity on either side.
#[test]
fn test_log_add() {
assert!((log_add(0.0, 0.0) - 0.693).abs() < 0.01); assert!((log_add(f64::NEG_INFINITY, 0.0) - 0.0).abs() < 0.001);
assert!((log_add(0.0, f64::NEG_INFINITY) - 0.0).abs() < 0.001);
}
// `count_arcs` sums transitions over all states.
#[test]
fn test_count_arcs() {
let mut fst = VectorWfst::<char, LogWeight>::new();
let s0 = fst.add_state();
let s1 = fst.add_state();
fst.set_start(s0);
fst.set_final(s1, LogWeight::one());
fst.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0));
fst.add_arc(s0, Some('b'), Some('b'), s1, LogWeight::new(2.0));
assert_eq!(count_arcs(&fst), 2);
}
// `add_with_info` stores full arc descriptions in insertion order.
#[test]
fn test_composed_arc_map_with_info() {
let mut map = ComposedArcMap::new();
map.add_with_info(0, 1, 1.5, Some(0), Some(0));
map.add_with_info(1, 2, 2.0, Some(1), None);
assert!(map.has_arc_info());
let infos: Vec<_> = map.arc_infos().collect();
assert_eq!(infos.len(), 2);
assert_eq!(infos[0].source, 0);
assert_eq!(infos[0].dest, 1);
assert!((infos[0].log_weight - 1.5).abs() < 1e-10);
assert_eq!(infos[0].arc1, Some(0));
assert_eq!(infos[0].arc2, Some(0));
assert_eq!(infos[1].source, 1);
assert_eq!(infos[1].dest, 2);
assert!((infos[1].log_weight - 2.0).abs() < 1e-10);
assert_eq!(infos[1].arc1, Some(1));
assert_eq!(infos[1].arc2, None);
}
// One composed arc with posterior 1.0: both input arcs receive a gradient of -1.0.
// The composed scores are written by hand rather than computed.
#[test]
fn test_composed_backward_with_posteriors() {
let mut fst1 = VectorWfst::<char, LogWeight>::new();
let s0 = fst1.add_state();
let s1 = fst1.add_state();
fst1.set_start(s0);
fst1.set_final(s1, LogWeight::one());
fst1.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0));
let mut fst2 = VectorWfst::<char, LogWeight>::new();
let t0 = fst2.add_state();
let t1 = fst2.add_state();
fst2.set_start(t0);
fst2.set_final(t1, LogWeight::one());
fst2.add_arc(t0, Some('a'), Some('a'), t1, LogWeight::new(0.5));
let mut fb = ForwardBackwardScores::new(2);
fb.alpha[0] = 0.0; fb.alpha[1] = 1.5; fb.beta[1] = 0.0; fb.beta[0] = 1.5; fb.total_log_prob = 1.5;
let mut arc_map = ComposedArcMap::new();
arc_map.add_with_info(0, 1, 1.5, Some(0), Some(0));
let result = composed_backward(&fst1, &fst2, &fb, &arc_map, 1.0);
assert!(
(result.grad1.get(0) - (-1.0)).abs() < 1e-6,
"grad1[0] = {}, expected -1.0",
result.grad1.get(0)
);
assert!(
(result.grad2.get(0) - (-1.0)).abs() < 1e-6,
"grad2[0] = {}, expected -1.0",
result.grad2.get(0)
);
}
// Two parallel composed arcs: their posteriors sum to about 1 and each input
// arc receives the negated posterior of its composed arc.
#[test]
fn test_composed_backward_two_paths() {
let mut fst1 = VectorWfst::<char, LogWeight>::new();
let s0 = fst1.add_state();
let s1 = fst1.add_state();
fst1.set_start(s0);
fst1.set_final(s1, LogWeight::one());
fst1.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0)); fst1.add_arc(s0, Some('b'), Some('b'), s1, LogWeight::new(2.0));
let mut fst2 = VectorWfst::<char, LogWeight>::new();
let t0 = fst2.add_state();
let t1 = fst2.add_state();
fst2.set_start(t0);
fst2.set_final(t1, LogWeight::one());
fst2.add_arc(t0, Some('a'), Some('a'), t1, LogWeight::new(0.0)); fst2.add_arc(t0, Some('b'), Some('b'), t1, LogWeight::new(0.0));
let total = log_add(1.0, 2.0);
let mut fb = ForwardBackwardScores::new(2);
fb.alpha[0] = 0.0;
fb.alpha[1] = total;
fb.beta[1] = 0.0;
fb.beta[0] = total;
fb.total_log_prob = total;
let mut arc_map = ComposedArcMap::new();
arc_map.add_with_info(0, 1, 1.0, Some(0), Some(0)); arc_map.add_with_info(0, 1, 2.0, Some(1), Some(1));
let result = composed_backward(&fst1, &fst2, &fb, &arc_map, 1.0);
let posterior_a = fb.arc_posterior(0.0, 1.0, 0.0);
let posterior_b = fb.arc_posterior(0.0, 2.0, 0.0);
let sum = posterior_a + posterior_b;
assert!(
(sum - 1.0).abs() < 0.01,
"Posteriors sum to {}, expected ~1.0",
sum
);
assert!(
(result.grad1.get(0) - (-posterior_a)).abs() < 1e-6,
"grad1[0] = {}, expected {}",
result.grad1.get(0),
-posterior_a
);
assert!(
(result.grad1.get(1) - (-posterior_b)).abs() < 1e-6,
"grad1[1] = {}, expected {}",
result.grad1.get(1),
-posterior_b
);
}
// Without arc descriptions (`add` only) the uniform fallback still produces a gradient.
#[test]
fn test_composed_backward_legacy_fallback() {
let mut fst1 = VectorWfst::<char, LogWeight>::new();
let s0 = fst1.add_state();
let s1 = fst1.add_state();
fst1.set_start(s0);
fst1.set_final(s1, LogWeight::one());
fst1.add_arc(s0, Some('a'), Some('a'), s1, LogWeight::new(1.0));
let fst2 = fst1.clone();
let fb = ForwardBackwardScores::new(2);
let mut arc_map = ComposedArcMap::new();
arc_map.add(0, Some(0), Some(0));
let result = composed_backward(&fst1, &fst2, &fb, &arc_map, 1.0);
assert!(result.grad1.nnz() > 0 || result.grad2.nnz() > 0);
}
// Randomised checks of `SparseGradient` invariants (64 cases per property).
mod property_tests {
use super::*;
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(64))]
// A value well above the storage threshold reads back unchanged.
#[test]
fn sparse_gradient_set_get_roundtrip(
arc_id in 0usize..128,
value in (-1.0e3f64..1.0e3).prop_filter("non-tiny", |v| v.abs() > 1e-6)
) {
let mut grad = SparseGradient::new(128);
grad.set(arc_id, value);
prop_assert!((grad.get(arc_id) - value).abs() < 1e-9);
}
// The dense form always has exactly `num_arcs` entries.
#[test]
fn sparse_gradient_dense_len_equals_num_arcs(
num_arcs in 0usize..256,
writes in proptest::collection::vec(
(0usize..256, -10.0f64..10.0),
0..16,
),
) {
let mut grad = SparseGradient::new(num_arcs);
for (idx, val) in writes {
if idx < num_arcs {
grad.set(idx, val);
}
}
prop_assert_eq!(grad.to_dense().len(), num_arcs);
}
// Scaling by zero makes every written arc read back as (near) zero.
#[test]
fn sparse_gradient_scale_zero_zeros(
writes in proptest::collection::vec(
(0usize..64, -100.0f64..100.0),
1..16,
),
) {
let mut grad = SparseGradient::new(64);
for (idx, val) in &writes {
grad.set(*idx, *val);
}
grad.scale(0.0);
for (idx, _) in &writes {
prop_assert!(grad.get(*idx).abs() < 1e-12);
}
}
}
}
}