use crate::error::{Result, ZantetsuError};
use crate::parser::bio_tags::BioTag;
/// Viterbi decoder that selects the highest-scoring sequence of tag indices
/// from per-position emission scores and a tag-to-tag transition matrix,
/// skipping transitions rejected by `BioTag::is_valid_transition`.
#[derive(Debug, Clone)]
pub struct ViterbiDecoder {
/// Size of the tag set; `decode` compares the emission row width against
/// this value and every tag index it returns is taken from `0..num_tags`.
num_tags: usize,
}
/// One cell of the dynamic-programming table used by `ViterbiDecoder::decode`:
/// the best path ending in a given tag at a given position.
#[derive(Debug, Clone, Copy)]
struct PathState {
/// Best accumulated score of any path ending in this cell; cells start at
/// `f32::NEG_INFINITY` until a score is written.
score: f32,
/// Tag index chosen at the previous position on that best path; stays `None`
/// at the first position and whenever no predecessor was selected.
prev_tag: Option<usize>,
}
impl ViterbiDecoder {
pub fn new(num_tags: usize) -> Self {
Self { num_tags }
}
pub fn decode(
&self,
emission_scores: &[Vec<f32>],
transition_matrix: &[Vec<f32>],
) -> Result<Vec<usize>> {
let seq_len = emission_scores.len();
if seq_len == 0 {
return Ok(Vec::new());
}
if emission_scores[0].len() != self.num_tags {
return Err(ZantetsuError::NeuralParser(format!(
"Emission score dimension mismatch: expected {}, got {}",
self.num_tags,
emission_scores[0].len()
)));
}
let mut dp: Vec<Vec<PathState>> = vec![
vec![
PathState {
score: f32::NEG_INFINITY,
prev_tag: None
};
self.num_tags
];
seq_len
];
for tag in 0..self.num_tags {
dp[0][tag].score = emission_scores[0][tag];
}
for pos in 1..seq_len {
for curr_tag in 0..self.num_tags {
let curr_bio_tag = BioTag::from_index(curr_tag).ok_or_else(|| {
ZantetsuError::NeuralParser(format!("Invalid tag index: {}", curr_tag))
})?;
let mut best_score = f32::NEG_INFINITY;
let mut best_prev = None;
for prev_tag in 0..self.num_tags {
let prev_bio_tag = BioTag::from_index(prev_tag).ok_or_else(|| {
ZantetsuError::NeuralParser(format!("Invalid tag index: {}", prev_tag))
})?;
if !BioTag::is_valid_transition(prev_bio_tag, curr_bio_tag) {
continue;
}
let score = dp[pos - 1][prev_tag].score
+ transition_matrix[prev_tag][curr_tag]
+ emission_scores[pos][curr_tag];
if score > best_score {
best_score = score;
best_prev = Some(prev_tag);
}
}
dp[pos][curr_tag].score = best_score;
dp[pos][curr_tag].prev_tag = best_prev;
}
}
let mut path = Vec::with_capacity(seq_len);
let mut best_final_tag = 0;
let mut best_final_score = f32::NEG_INFINITY;
for (tag, cell) in dp[seq_len - 1].iter().enumerate().take(self.num_tags) {
if cell.score > best_final_score {
best_final_score = cell.score;
best_final_tag = tag;
}
}
path.push(best_final_tag);
let mut curr_tag = best_final_tag;
for pos in (1..seq_len).rev() {
curr_tag = dp[pos][curr_tag].prev_tag.unwrap_or(0);
path.push(curr_tag);
}
path.reverse();
Ok(path)
}
pub fn decode_constrained(
&self,
emission_scores: &[Vec<f32>],
transition_matrix: &[Vec<f32>],
) -> Result<Vec<usize>> {
let mut valid_transitions: Vec<Vec<bool>> = vec![vec![false; self.num_tags]; self.num_tags];
for (prev_idx, row) in valid_transitions.iter_mut().enumerate().take(self.num_tags) {
if let Some(prev_tag) = BioTag::from_index(prev_idx) {
for (curr_idx, cell) in row.iter_mut().enumerate().take(self.num_tags) {
if let Some(curr_tag) = BioTag::from_index(curr_idx) {
*cell = BioTag::is_valid_transition(prev_tag, curr_tag);
}
}
}
}
let seq_len = emission_scores.len();
if seq_len == 0 {
return Ok(Vec::new());
}
let mut dp: Vec<Vec<f32>> = vec![vec![f32::NEG_INFINITY; self.num_tags]; seq_len];
let mut backptr: Vec<Vec<Option<usize>>> = vec![vec![None; self.num_tags]; seq_len];
for tag in 0..self.num_tags {
dp[0][tag] = emission_scores[0][tag];
}
for pos in 1..seq_len {
for curr_tag in 0..self.num_tags {
let mut best_score = f32::NEG_INFINITY;
let mut best_prev = None;
for prev_tag in 0..self.num_tags {
if !valid_transitions[prev_tag][curr_tag] {
continue;
}
let score = dp[pos - 1][prev_tag]
+ transition_matrix[prev_tag][curr_tag]
+ emission_scores[pos][curr_tag];
if score > best_score {
best_score = score;
best_prev = Some(prev_tag);
}
}
dp[pos][curr_tag] = best_score;
backptr[pos][curr_tag] = best_prev;
}
}
let mut best_final_tag = 0;
let mut best_final_score = f32::NEG_INFINITY;
for (tag, &score) in dp[seq_len - 1].iter().enumerate().take(self.num_tags) {
if score > best_final_score {
best_final_score = score;
best_final_tag = tag;
}
}
let mut path = vec![best_final_tag];
let mut curr_tag = best_final_tag;
for pos in (1..seq_len).rev() {
curr_tag = backptr[pos][curr_tag].unwrap_or(0);
path.push(curr_tag);
}
path.reverse();
Ok(path)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a square `num_tags` x `num_tags` matrix that scores every
    /// transition accepted by `BioTag::is_valid_transition` at 0.1 and every
    /// rejected transition at -1000.0.
    fn create_simple_transition_matrix(num_tags: usize) -> Vec<Vec<f32>> {
        (0..num_tags)
            .map(|from| {
                (0..num_tags)
                    .map(|to| {
                        let permitted = BioTag::is_valid_transition(
                            BioTag::from_index(from).unwrap(),
                            BioTag::from_index(to).unwrap(),
                        );
                        if permitted {
                            0.1
                        } else {
                            -1000.0
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Two one-hot emission rows decode to a two-element path starting at tag 0.
    #[test]
    fn test_viterbi_simple() {
        let weights = create_simple_transition_matrix(BioTag::NUM_TAGS);
        let viterbi = ViterbiDecoder::new(BioTag::NUM_TAGS);
        let scores = vec![
            vec![
                1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            ],
            vec![
                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            ],
        ];
        let tags = viterbi.decode(&scores, &weights).unwrap();
        assert_eq!(tags.len(), 2);
        assert_eq!(tags[0], 0);
    }

    /// An empty emission sequence decodes to an empty path.
    #[test]
    fn test_viterbi_empty() {
        let weights = create_simple_transition_matrix(BioTag::NUM_TAGS);
        let viterbi = ViterbiDecoder::new(BioTag::NUM_TAGS);
        let scores: Vec<Vec<f32>> = Vec::new();
        let tags = viterbi.decode(&scores, &weights).unwrap();
        assert!(tags.is_empty());
    }

    /// Constrained decoding of two uniform emission rows yields one tag per row.
    #[test]
    fn test_decode_constrained() {
        let weights = create_simple_transition_matrix(BioTag::NUM_TAGS);
        let viterbi = ViterbiDecoder::new(BioTag::NUM_TAGS);
        let scores = vec![vec![1.0; BioTag::NUM_TAGS]; 2];
        let tags = viterbi.decode_constrained(&scores, &weights).unwrap();
        assert_eq!(tags.len(), 2);
    }
}