use crate::semiring::Semiring;
use crate::transducer::{Label, BLANK};
use crate::wfst::{MutableWfst, StateId, VectorWfst, WeightedTransition, Wfst};
/// Tuning knobs for building WST (bypass-augmented) target graphs.
///
/// Graph weights act as costs in `wst_loss` (they are subtracted from acoustic log-scores),
/// so larger values make the corresponding arcs less attractive.
#[derive(Debug, Clone)]
pub struct WstConfig {
    /// Base cost of skipping a low-confidence token; scaled by `(1 - confidence)` and, for
    /// multi-token bypasses, by the span length.
    pub token_bypass_weight: f64,
    /// Cost of the BLANK self-loop placed on every target-position state.
    pub blank_bypass_weight: f64,
    /// Tokens whose confidence is strictly below this value receive bypass arcs.
    pub confidence_threshold: f64,
    /// Longest run of consecutive tokens one bypass arc may skip in `build_wst_graph`;
    /// 0 or 1 disables multi-token bypasses.
    pub max_bypass_span: usize,
    /// When `true`, `build_wst_graph` also lets tokens at or above the threshold be skipped.
    pub allow_universal_bypass: bool,
    /// Cost of skipping a confident token when `allow_universal_bypass` is enabled.
    pub universal_bypass_weight: f64,
}

impl Default for WstConfig {
    /// Threshold 0.5, spans of up to three tokens, universal bypass disabled.
    fn default() -> Self {
        let confidence_threshold = 0.5;
        WstConfig {
            confidence_threshold,
            max_bypass_span: 3,
            token_bypass_weight: 2.0,
            blank_bypass_weight: 0.5,
            allow_universal_bypass: false,
            universal_bypass_weight: 5.0,
        }
    }
}
/// A reference token together with a confidence score and optional alternative labels.
#[derive(Debug, Clone)]
pub struct ConfidentToken {
    /// The reference label at this position.
    pub label: Label,
    /// Confidence in `label`; the graph builders compare it with
    /// `WstConfig::confidence_threshold`.
    pub confidence: f64,
    /// Alternative `(label, confidence)` pairs that may stand in for `label`.
    pub alternatives: Vec<(Label, f64)>,
}

impl ConfidentToken {
    /// Creates a token without alternatives.
    pub fn new(label: Label, confidence: f64) -> Self {
        Self::with_alternatives(label, confidence, Vec::new())
    }

    /// Creates a token carrying the given alternative `(label, confidence)` pairs.
    pub fn with_alternatives(
        label: Label,
        confidence: f64,
        alternatives: Vec<(Label, f64)>,
    ) -> Self {
        ConfidentToken {
            label,
            confidence,
            alternatives,
        }
    }
}
/// Builds a WST (bypass-augmented) acceptor for `targets`.
///
/// State `i` (for `i` in `0..=targets.len()`) means "the first `i` targets have been
/// consumed". State 0 is the start state and state `targets.len()` is the only final state,
/// with weight `W::one()`. For each target `i` the following arcs are added, in this order:
///
/// - a match arc `i -> i+1` on the target label with weight `W::one()`;
/// - one arc `i -> i+1` per alternative label, weighted `max(confidence - alt_confidence, 0)`;
/// - an epsilon bypass `i -> i+1`, weighted:
///   - `token_bypass_weight * (1 - confidence)` if the confidence is below
///     `confidence_threshold`, or
///   - `universal_bypass_weight` if the confidence is not below the threshold but
///     `allow_universal_bypass` is set;
/// - epsilon arcs `i -> i+span` for every `span` in `2..=max_bypass_span` (clamped to the end
///   of the sequence) whose least-confident token is below the threshold, weighted
///   `token_bypass_weight * span * (1 - min_confidence)`.
///
/// Finally every state gets a BLANK self-loop weighted `blank_bypass_weight`.
///
/// NOTE(review): `token_bypass_weight * (1 - confidence)` grows as confidence falls, so less
/// confident tokens are *more* expensive to skip. Confirm this direction is intended.
pub fn build_wst_graph<W>(targets: &[ConfidentToken], config: &WstConfig) -> VectorWfst<Label, W>
where
    W: Semiring + From<f64>,
{
    let n = targets.len();
    let mut fst: VectorWfst<Label, W> = VectorWfst::new();
    fst.add_states(n + 1);
    fst.set_start(0);
    fst.set_final(n as StateId, W::one());

    for (pos, tok) in targets.iter().enumerate() {
        let src = pos as StateId;
        let dst = (pos + 1) as StateId;

        // Emitting the reference label is free.
        fst.add_transition(WeightedTransition {
            from: src,
            input: Some(tok.label),
            output: Some(tok.label),
            to: dst,
            weight: W::one(),
        });

        // Alternatives cost the (non-negative) confidence gap to the reference.
        for &(alt, alt_conf) in tok.alternatives.iter() {
            fst.add_transition(WeightedTransition {
                from: src,
                input: Some(alt),
                output: Some(alt),
                to: dst,
                weight: W::from((tok.confidence - alt_conf).max(0.0)),
            });
        }

        // Single-token epsilon bypass: low-confidence tokens always, others only when the
        // universal bypass is enabled.
        let skip_cost = if tok.confidence < config.confidence_threshold {
            Some(config.token_bypass_weight * (1.0 - tok.confidence))
        } else if config.allow_universal_bypass {
            Some(config.universal_bypass_weight)
        } else {
            None
        };
        if let Some(cost) = skip_cost {
            fst.add_transition(WeightedTransition {
                from: src,
                input: None,
                output: None,
                to: dst,
                weight: W::from(cost),
            });
        }

        // Multi-token bypasses. `window_min` is the running minimum confidence of
        // targets[pos..pos + span]; the range is empty when max_bypass_span <= 1.
        let longest_span = config.max_bypass_span.min(n - pos);
        let mut window_min = f64::min(f64::INFINITY, tok.confidence);
        for span in 2..=longest_span {
            window_min = window_min.min(targets[pos + span - 1].confidence);
            if window_min < config.confidence_threshold {
                fst.add_transition(WeightedTransition {
                    from: src,
                    input: None,
                    output: None,
                    to: (pos + span) as StateId,
                    weight: W::from(config.token_bypass_weight * span as f64 * (1.0 - window_min)),
                });
            }
        }
    }

    // BLANK self-loops on every position, including the final state.
    for pos in 0..=n {
        let state = pos as StateId;
        fst.add_transition(WeightedTransition {
            from: state,
            input: Some(BLANK),
            output: Some(BLANK),
            to: state,
            weight: W::from(config.blank_bypass_weight),
        });
    }
    fst
}
/// Builds a WST graph in which every target shares one confidence value.
///
/// Each label in `targets` becomes a `ConfidentToken` with `default_confidence` and no
/// alternatives; the result is whatever `build_wst_graph` produces for those tokens.
pub fn build_wst_graph_uniform<W>(
    targets: &[Label],
    default_confidence: f64,
    config: &WstConfig,
) -> VectorWfst<Label, W>
where
    W: Semiring + From<f64>,
{
    let mut tokens = Vec::with_capacity(targets.len());
    for &label in targets {
        tokens.push(ConfidentToken::new(label, default_confidence));
    }
    build_wst_graph(&tokens, config)
}
/// Builds a WST graph that also allows arbitrary label insertions at every target position.
///
/// Let `n = targets.len()`. States `0..=n` are target positions: state `i` means `i` targets
/// have been consumed, 0 is the start state and `n` is the only final state. State `n + 1 + i`
/// is the insertion state of position `i`.
///
/// For each target `i`:
/// - a match arc `i -> i+1` on the target label with weight `W::one()`;
/// - an epsilon bypass `i -> i+1` weighted `token_bypass_weight * (1 - confidence)` when the
///   confidence is below `confidence_threshold`.
///
/// For every position `i` in `0..=n`, including the final one so that trailing insertions
/// are representable:
/// - an epsilon arc to its insertion state, weighted `insertion_weight`;
/// - arcs from the insertion state back to `i` on every label in `1..vocab_size`, with
///   weight `W::one()`;
/// - a BLANK self-loop weighted `blank_bypass_weight`.
///
/// Unlike `build_wst_graph`, this builder ignores `alternatives`, `max_bypass_span` and
/// `allow_universal_bypass`.
///
/// NOTE(review): label 0 is excluded from insertions, presumably because it is BLANK. Confirm.
pub fn build_wst_graph_with_insertions<W>(
    targets: &[ConfidentToken],
    vocab_size: usize,
    insertion_weight: f64,
    config: &WstConfig,
) -> VectorWfst<Label, W>
where
    W: Semiring + From<f64>,
{
    let num_base_states = targets.len() + 1;
    let mut fst: VectorWfst<Label, W> = VectorWfst::new();
    // One insertion state per base state, the final position included; previously the
    // final position's insertion state was allocated but never connected.
    fst.add_states(num_base_states * 2);
    fst.set_start(0);
    fst.set_final(targets.len() as StateId, W::one());

    // Forward arcs: consume the target label, or bypass a low-confidence target.
    for (i, token) in targets.iter().enumerate() {
        let from_state = i as StateId;
        let to_state = (i + 1) as StateId;
        fst.add_transition(WeightedTransition {
            from: from_state,
            input: Some(token.label),
            output: Some(token.label),
            to: to_state,
            weight: W::one(),
        });
        if token.confidence < config.confidence_threshold {
            let bypass_weight = config.token_bypass_weight * (1.0 - token.confidence);
            fst.add_transition(WeightedTransition {
                from: from_state,
                input: None,
                output: None,
                to: to_state,
                weight: W::from(bypass_weight),
            });
        }
    }

    // Per-position insertion loop followed by the BLANK self-loop. Within each base state
    // the arc order stays: match, bypass, insertion-epsilon, BLANK.
    for i in 0..num_base_states {
        let state = i as StateId;
        let insert_state = (num_base_states + i) as StateId;
        fst.add_transition(WeightedTransition {
            from: state,
            input: None,
            output: None,
            to: insert_state,
            weight: W::from(insertion_weight),
        });
        for label in 1..vocab_size as Label {
            fst.add_transition(WeightedTransition {
                from: insert_state,
                input: Some(label),
                output: Some(label),
                to: state,
                weight: W::one(),
            });
        }
        fst.add_transition(WeightedTransition {
            from: state,
            input: Some(BLANK),
            output: Some(BLANK),
            to: state,
            weight: W::from(config.blank_bypass_weight),
        });
    }
    fst
}
/// Output of `wst_loss`: the forward loss plus a Viterbi alignment for diagnostics.
#[derive(Debug, Clone)]
pub struct WstLossResult {
    /// Negative log of the total (log-semiring) path score; `+inf` if no path reaches a
    /// final state.
    pub loss: f64,
    /// Steps of the best (Viterbi) path, in time order.
    pub alignment: Vec<WstAlignmentStep>,
    /// Fraction of `alignment` steps that are bypass steps (0.0 when `alignment` is empty).
    pub bypass_ratio: f64,
}
/// One arc taken by the Viterbi alignment through a WST graph.
#[derive(Debug, Clone)]
pub struct WstAlignmentStep {
    /// Source state of the arc. For graphs from `build_wst_graph` this equals the number of
    /// targets consumed before the step.
    pub target_pos: usize,
    /// Input label of the arc (`0` for bypass steps).
    pub label: Label,
    /// `true` when the arc had neither input nor output label (an epsilon bypass).
    pub is_bypass: bool,
    /// Graph weight of the arc, converted to `f64`.
    pub weight: f64,
}
/// Computes the WST loss with the forward algorithm (log semiring), plus a Viterbi alignment.
///
/// `acoustic_scores[t][label]` is the log-score of `label` at frame `t`; labels outside a
/// frame's row are treated as impossible.
///
/// Arc semantics:
/// - An emitting arc consumes one frame and contributes `acoustic - graph_weight`.
/// - An arc with neither input nor output label is an epsilon (bypass) arc. It consumes no
///   frame and contributes `-graph_weight`.
///
/// Epsilon arcs are propagated at every frame boundary, including after the last frame, so
/// paths that end in a bypass (e.g. skipping trailing targets) are counted. Epsilon arcs are
/// never treated as frame-consuming arcs; this matches `viterbi_alignment`.
///
/// The result is `loss = -ln(total score over all paths to final states)`, which is `+inf`
/// when no final state is reachable or the graph has no states.
pub fn wst_loss<W>(acoustic_scores: &[Vec<f64>], wst_graph: &VectorWfst<Label, W>) -> WstLossResult
where
    W: Semiring + From<f64> + Into<f64> + Clone,
{
    let num_frames = acoustic_scores.len();
    let num_states = wst_graph.num_states();
    if num_states == 0 {
        // No start state to index into and therefore no path at all.
        return WstLossResult {
            loss: f64::INFINITY,
            alignment: Vec::new(),
            bypass_ratio: 0.0,
        };
    }

    // alpha[t][s]: log of the summed score of all partial paths that have consumed `t`
    // frames and end in state `s`.
    let mut alpha = vec![vec![f64::NEG_INFINITY; num_states]; num_frames + 1];
    alpha[0][wst_graph.start() as usize] = 0.0;

    for (t, frame) in acoustic_scores.iter().enumerate() {
        log_epsilon_closure(&mut alpha[t], wst_graph);
        let (consumed, upcoming) = alpha.split_at_mut(t + 1);
        let current = &consumed[t];
        let next = &mut upcoming[0];
        for (s, &score) in current.iter().enumerate() {
            if score <= f64::NEG_INFINITY {
                continue;
            }
            for tr in wst_graph.transitions(s as StateId) {
                // Epsilon arcs were handled by the closure above and must not consume a frame.
                if tr.input.is_none() && tr.output.is_none() {
                    continue;
                }
                let label = tr.input.unwrap_or(0) as usize;
                let acoustic = frame.get(label).copied().unwrap_or(f64::NEG_INFINITY);
                if acoustic <= f64::NEG_INFINITY {
                    continue;
                }
                let graph_weight: f64 = tr.weight.clone().into();
                let dst = tr.to as usize;
                next[dst] = log_add(next[dst], score + (acoustic - graph_weight));
            }
        }
    }
    // Paths may still take bypass arcs after the last frame has been consumed.
    log_epsilon_closure(&mut alpha[num_frames], wst_graph);

    let mut total_log_prob = f64::NEG_INFINITY;
    for s in 0..num_states {
        let state = s as StateId;
        if wst_graph.is_final(state) {
            let final_weight: f64 = wst_graph.final_weight(state).into();
            total_log_prob = log_add(total_log_prob, alpha[num_frames][s] - final_weight);
        }
    }

    let (alignment, _viterbi_score) = viterbi_alignment(acoustic_scores, wst_graph);
    let num_bypasses = alignment.iter().filter(|step| step.is_bypass).count();
    let bypass_ratio = if alignment.is_empty() {
        0.0
    } else {
        num_bypasses as f64 / alignment.len() as f64
    };
    WstLossResult {
        loss: -total_log_prob,
        alignment,
        bypass_ratio,
    }
}

/// Propagates log-score mass along epsilon arcs (no input and no output label) within one
/// frame's row.
///
/// States are visited once in increasing order, so chains of epsilon arcs are followed
/// completely only when every epsilon arc points to a higher-numbered state. The builders in
/// this module only produce such arcs.
fn log_epsilon_closure<W>(row: &mut [f64], wst_graph: &VectorWfst<Label, W>)
where
    W: Semiring + Into<f64> + Clone,
{
    for s in 0..row.len() {
        let score = row[s];
        if score <= f64::NEG_INFINITY {
            continue;
        }
        for tr in wst_graph.transitions(s as StateId) {
            if tr.input.is_none() && tr.output.is_none() {
                let graph_weight: f64 = tr.weight.clone().into();
                let dst = tr.to as usize;
                row[dst] = log_add(row[dst], score - graph_weight);
            }
        }
    }
}
/// Finds the single best (lowest-cost) path through `wst_graph` for the given frames.
///
/// Costs are tropical:
/// - An emitting arc consumes one frame and costs `graph_weight - acoustic_scores[t][label]`.
///   Labels outside a frame's score row are impossible, and an arc whose input is `None` but
///   whose output is set is scored as label `0`.
/// - An arc with neither input nor output label (epsilon / bypass) consumes no frame and
///   costs `graph_weight`.
///
/// The final weight of the end state is added to the path cost.
///
/// Returns `(steps, cost)`, where `steps` lists the arcs of the best path in time order.
/// Bypass steps have `is_bypass == true` and label `0`. If there are no frames or no states,
/// or no final state is reachable, the step list is empty and the cost is `f64::INFINITY`.
fn viterbi_alignment<W>(
    acoustic_scores: &[Vec<f64>],
    wst_graph: &VectorWfst<Label, W>,
) -> (Vec<WstAlignmentStep>, f64)
where
    W: Semiring + Into<f64> + Clone,
{
    let num_frames = acoustic_scores.len();
    let num_states = wst_graph.num_states();
    if num_frames == 0 || num_states == 0 {
        return (Vec::new(), f64::INFINITY);
    }
    // delta[t][s]: lowest cost of any partial path that has consumed `t` frames and is in `s`.
    let mut delta = vec![vec![f64::INFINITY; num_states]; num_frames + 1];
    // backpointer[t][s]: (prev_time, prev_state, label, is_bypass, graph_weight) of the arc
    // that produced delta[t][s]; `None` where no arc ever improved the cell.
    let mut backpointer: Vec<Vec<Option<(usize, usize, Label, bool, f64)>>> =
        vec![vec![None; num_states]; num_frames + 1];
    delta[0][wst_graph.start() as usize] = 0.0;
    for t in 0..num_frames {
        // Relax epsilon arcs inside frame `t` until nothing improves, so epsilon chains are
        // followed regardless of state numbering.
        // NOTE(review): a negative-cost epsilon cycle would keep improving and never terminate.
        let mut changed = true;
        while changed {
            changed = false;
            for s in 0..num_states {
                if delta[t][s] >= f64::INFINITY {
                    continue;
                }
                let state = s as StateId;
                for tr in wst_graph.transitions(state) {
                    if tr.input.is_some() || tr.output.is_some() {
                        continue;
                    }
                    let graph_weight: f64 = tr.weight.clone().into();
                    let new_score = delta[t][s] + graph_weight;
                    let next_state = tr.to as usize;
                    if new_score < delta[t][next_state] {
                        delta[t][next_state] = new_score;
                        backpointer[t][next_state] = Some((t, s, 0, true, graph_weight));
                        changed = true;
                    }
                }
            }
        }
        // Advance every emitting arc from frame `t` to frame `t + 1`.
        for s in 0..num_states {
            if delta[t][s] >= f64::INFINITY {
                continue;
            }
            let state = s as StateId;
            for tr in wst_graph.transitions(state) {
                if tr.input.is_none() && tr.output.is_none() {
                    continue;
                }
                let label = tr.input.unwrap_or(0);
                let label_idx = label as usize;
                let acoustic = if label_idx < acoustic_scores[t].len() {
                    acoustic_scores[t][label_idx]
                } else {
                    f64::NEG_INFINITY
                };
                if acoustic <= f64::NEG_INFINITY {
                    continue;
                }
                let graph_weight: f64 = tr.weight.clone().into();
                let arc_cost = graph_weight - acoustic;
                let new_score = delta[t][s] + arc_cost;
                let next_state = tr.to as usize;
                if new_score < delta[t + 1][next_state] {
                    delta[t + 1][next_state] = new_score;
                    backpointer[t + 1][next_state] = Some((t, s, label, false, graph_weight));
                }
            }
        }
    }
    // Epsilon closure after the last frame, so paths ending in bypass arcs reach a final state.
    let mut changed = true;
    while changed {
        changed = false;
        for s in 0..num_states {
            if delta[num_frames][s] >= f64::INFINITY {
                continue;
            }
            let state = s as StateId;
            for tr in wst_graph.transitions(state) {
                if tr.input.is_some() || tr.output.is_some() {
                    continue;
                }
                let graph_weight: f64 = tr.weight.clone().into();
                let new_score = delta[num_frames][s] + graph_weight;
                let next_state = tr.to as usize;
                if new_score < delta[num_frames][next_state] {
                    delta[num_frames][next_state] = new_score;
                    backpointer[num_frames][next_state] =
                        Some((num_frames, s, 0, true, graph_weight));
                    changed = true;
                }
            }
        }
    }
    // Pick the final state with the lowest total cost, including its final weight.
    let mut best_score = f64::INFINITY;
    let mut best_final_state = None;
    for s in 0..num_states {
        let state = s as StateId;
        if wst_graph.is_final(state) {
            let final_weight: f64 = wst_graph.final_weight(state).into();
            let total = delta[num_frames][s] + final_weight;
            if total < best_score {
                best_score = total;
                best_final_state = Some(s);
            }
        }
    }
    // Walk the backpointers from the best final state. The walk ends at the first cell with
    // no backpointer (normally the start state at frame 0); steps are collected in reverse.
    let mut alignment = Vec::new();
    if let Some(mut state) = best_final_state {
        let mut time = num_frames;
        while let Some((prev_time, prev_state, label, is_bypass, weight)) = backpointer[time][state]
        {
            // The arc's source state is reported as the target position.
            let target_pos = prev_state;
            alignment.push(WstAlignmentStep {
                target_pos,
                label,
                is_bypass,
                weight,
            });
            time = prev_time;
            state = prev_state;
        }
        alignment.reverse();
    }
    (alignment, best_score)
}
/// Log-semiring addition: a numerically stable `ln(exp(a) + exp(b))`.
///
/// `f64::NEG_INFINITY` (the log of zero) is the identity: if either argument is `-inf`, the
/// other is returned unchanged.
#[inline]
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        return b;
    }
    if b == f64::NEG_INFINITY {
        return a;
    }
    let (hi, lo) = if a > b { (a, b) } else { (b, a) };
    // `ln_1p` keeps precision when `exp(lo - hi)` is tiny; `(1.0 + x).ln()` would round
    // such contributions away entirely.
    hi + (lo - hi).exp().ln_1p()
}
/// Estimates a per-token confidence for `reference` from an N-best list.
///
/// `nbest` holds `(hypothesis, log_prob)` pairs. The log-probabilities are normalised with a
/// softmax over the list. Each reference position then receives the summed normalised
/// probability of the hypotheses that `align_sequences` marks as matching at that position.
///
/// Returns `0.5` for every position when `nbest` or `reference` is empty, or when no
/// hypothesis has a finite log-probability (there is no usable evidence).
pub fn estimate_confidences_from_nbest(
    nbest: &[(Vec<Label>, f64)],
    reference: &[Label],
) -> Vec<f64> {
    if nbest.is_empty() || reference.is_empty() {
        return vec![0.5; reference.len()];
    }
    // Shift by the best log-probability before exponentiating. Without the shift, very
    // negative log-probs underflow `exp` to 0, the normaliser becomes 0, and every
    // confidence turns into NaN.
    let max_log_prob = nbest
        .iter()
        .map(|(_, lp)| *lp)
        .fold(f64::NEG_INFINITY, f64::max);
    if !max_log_prob.is_finite() {
        return vec![0.5; reference.len()];
    }
    let total: f64 = nbest
        .iter()
        .map(|(_, lp)| (lp - max_log_prob).exp())
        .sum();

    let mut confidences = vec![0.0; reference.len()];
    for (hyp, log_prob) in nbest {
        let prob = (log_prob - max_log_prob).exp() / total;
        for (ref_pos, _hyp_pos, matched) in align_sequences(reference, hyp) {
            if matched {
                confidences[ref_pos] += prob;
            }
        }
    }
    confidences
}
/// Aligns `hyp_seq` against `ref_seq` using a longest-common-subsequence alignment.
///
/// Returns exactly one `(ref_pos, hyp_pos, matched)` entry per reference position, in order:
/// - matched entries carry the index of the hypothesis token they align to;
/// - unmatched entries carry the index of the next unconsumed hypothesis token (which may
///   equal `hyp_seq.len()`).
///
/// Unlike a greedy left-to-right scan, a substituted or inserted hypothesis token does not
/// prevent later reference tokens from matching.
fn align_sequences(ref_seq: &[Label], hyp_seq: &[Label]) -> Vec<(usize, usize, bool)> {
    let (n, m) = (ref_seq.len(), hyp_seq.len());
    // lcs[i][j] = length of the LCS of ref_seq[i..] and hyp_seq[j..].
    let mut lcs = vec![vec![0usize; m + 1]; n + 1];
    for i in (0..n).rev() {
        for j in (0..m).rev() {
            lcs[i][j] = if ref_seq[i] == hyp_seq[j] {
                lcs[i + 1][j + 1] + 1
            } else {
                lcs[i + 1][j].max(lcs[i][j + 1])
            };
        }
    }

    let mut alignment = Vec::with_capacity(n);
    let (mut i, mut j) = (0, 0);
    while i < n {
        if j < m && ref_seq[i] == hyp_seq[j] {
            // Taking an equal pair is always optimal for LCS.
            alignment.push((i, j, true));
            i += 1;
            j += 1;
        } else if j < m && lcs[i][j + 1] >= lcs[i + 1][j] {
            // Skipping this hypothesis token (an insertion) loses nothing.
            j += 1;
        } else {
            // The reference token has no counterpart (deletion/substitution).
            alignment.push((i, j, false));
            i += 1;
        }
    }
    alignment
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::{LogWeight, TropicalWeight};

    #[test]
    fn test_build_wst_graph() {
        // Only the middle token (0.3) falls below the default 0.5 threshold.
        let targets = vec![
            ConfidentToken::new(1, 0.9),
            ConfidentToken::new(2, 0.3),
            ConfidentToken::new(3, 0.8),
        ];
        let graph: VectorWfst<Label, TropicalWeight> =
            build_wst_graph(&targets, &WstConfig::default());

        assert_eq!(graph.num_states(), 4);
        assert!(graph.is_final(3));
        let bypass_from_state1 = graph.transitions(1).iter().any(|t| t.input.is_none());
        assert!(bypass_from_state1);
    }

    #[test]
    fn test_wst_config() {
        // Raising the threshold to 0.7 makes the 0.5-confidence first token bypassable.
        let config = WstConfig {
            token_bypass_weight: 1.0,
            confidence_threshold: 0.7,
            ..Default::default()
        };
        let targets = vec![ConfidentToken::new(1, 0.5), ConfidentToken::new(2, 0.9)];
        let graph: VectorWfst<Label, TropicalWeight> = build_wst_graph(&targets, &config);

        let has_epsilon = graph.transitions(0).iter().any(|t| t.input.is_none());
        assert!(has_epsilon);
    }

    #[test]
    fn test_confident_token() {
        let alternatives = vec![(2, 0.1), (3, 0.05)];
        let token = ConfidentToken::with_alternatives(1, 0.8, alternatives);

        assert_eq!(token.label, 1);
        assert_eq!(token.confidence, 0.8);
        assert_eq!(token.alternatives.len(), 2);
    }

    #[test]
    fn test_wst_loss_with_alignment() {
        // Token 2 (0.3) is the only one below the threshold.
        let targets = vec![
            ConfidentToken::new(1, 0.95),
            ConfidentToken::new(2, 0.3),
            ConfidentToken::new(3, 0.9),
        ];
        let config = WstConfig {
            confidence_threshold: 0.5,
            token_bypass_weight: 2.0,
            ..Default::default()
        };
        let graph: VectorWfst<Label, LogWeight> = build_wst_graph(&targets, &config);
        // Frames 0, 1, 2 favour labels 0, 2 and 3 respectively.
        let acoustic_scores = vec![
            vec![-0.1, -1.0, -2.0, -3.0],
            vec![-3.0, -2.0, -0.1, -1.0],
            vec![-2.0, -3.0, -1.0, -0.1],
        ];

        let result = wst_loss(&acoustic_scores, &graph);

        assert!(!result.alignment.is_empty(), "Alignment should not be empty");
        assert!(result.bypass_ratio < 1.0, "Bypass ratio should be less than 1.0");
        assert!(result.loss.is_finite(), "Loss should be finite");
    }

    #[test]
    fn test_wst_loss_all_bypass() {
        // Both tokens are cheap to bypass and the acoustics carry no signal.
        let targets = vec![ConfidentToken::new(1, 0.1), ConfidentToken::new(2, 0.1)];
        let config = WstConfig {
            confidence_threshold: 0.5,
            token_bypass_weight: 0.1,
            ..Default::default()
        };
        let graph: VectorWfst<Label, LogWeight> = build_wst_graph(&targets, &config);
        let flat_frame = vec![-10.0, -10.0, -10.0];
        let acoustic_scores = vec![flat_frame.clone(), flat_frame];

        let result = wst_loss(&acoustic_scores, &graph);

        assert!(result.loss.is_finite(), "Loss should be finite");
    }
}