use super::cer::edit_distance;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricMetadata, SerializedEntry};
use crate::metric::{
Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};
use core::marker::PhantomData;
use std::sync::Arc;
/// Word error rate (WER) metric over token-id sequences.
///
/// Reports `100 * edit_distance(target, output) / target_length` as a
/// percentage, optionally truncating each sequence at the first padding token.
#[derive(Clone)]
pub struct WordErrorRate<B: Backend> {
    /// Display name of the metric ("WER").
    name: MetricName,
    /// Running numeric state (current value and running aggregate).
    state: NumericMetricState,
    /// Token id treated as padding; sequences are cut at its first occurrence.
    pad_token: Option<usize>,
    /// Ties the metric to a backend type without storing any backend data.
    _b: PhantomData<B>,
}
/// Input for the [`WordErrorRate`] metric: a batch of token-id sequences.
#[derive(new)]
pub struct WerInput<B: Backend> {
    /// Predicted token ids, shape `[batch_size, seq_length]`.
    pub outputs: Tensor<B, 2, Int>,
    /// Target token ids, shape `[batch_size, seq_length]`.
    pub targets: Tensor<B, 2, Int>,
}
impl<B: Backend> Default for WordErrorRate<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> WordErrorRate<B> {
    /// Creates the metric with default state and no pad token.
    pub fn new() -> Self {
        Self {
            name: Arc::new(String::from("WER")),
            pad_token: None,
            state: NumericMetricState::default(),
            _b: PhantomData,
        }
    }

    /// Sets the token id that marks padding; each sequence is truncated at
    /// its first occurrence before the edit distance is computed.
    pub fn with_pad_token(self, index: usize) -> Self {
        Self {
            pad_token: Some(index),
            ..self
        }
    }
}
impl<B: Backend> Metric for WordErrorRate<B> {
    type Input = WerInput<B>;

    /// Computes the batch WER as a percentage:
    /// `100 * sum(edit_distance(target, output)) / sum(target_length)`.
    ///
    /// Sequences are truncated at the first pad token when one is configured.
    /// An all-pad (or empty) batch yields `0.0` instead of dividing by zero.
    fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
        // Truncates a sequence at the first occurrence of the pad token, if any.
        fn strip_pad(seq: &[i32], pad: Option<i32>) -> &[i32] {
            match pad {
                Some(pad) => {
                    let len = seq.iter().position(|&x| x == pad).unwrap_or(seq.len());
                    &seq[..len]
                }
                None => seq,
            }
        }

        let outputs = input.outputs.clone();
        let targets = input.targets.clone();

        let [batch_size, target_seq_len] = targets.dims();
        // Outputs may have a sequence length different from the targets
        // (e.g. a decoder emitting fewer or more tokens); read it separately
        // instead of assuming both tensors share the same `seq_len`.
        let [output_batch_size, output_seq_len] = outputs.dims();
        assert_eq!(
            batch_size, output_batch_size,
            "WER: outputs and targets must have the same batch size"
        );

        let outputs_data = outputs.to_data().iter::<i32>().collect::<Vec<_>>();
        let targets_data = targets.to_data().iter::<i32>().collect::<Vec<_>>();
        let pad_token = self.pad_token.map(|p| p as i32);

        let mut total_edit_distance = 0.0;
        let mut total_target_length = 0.0;
        for i in 0..batch_size {
            let target_seq =
                strip_pad(&targets_data[i * target_seq_len..(i + 1) * target_seq_len], pad_token);
            let output_seq =
                strip_pad(&outputs_data[i * output_seq_len..(i + 1) * output_seq_len], pad_token);

            total_edit_distance += edit_distance(target_seq, output_seq) as f64;
            total_target_length += target_seq.len() as f64;
        }

        // Guard against division by zero for empty batches / all-pad targets.
        let value = if total_target_length > 0.0 {
            100.0 * total_edit_distance / total_target_length
        } else {
            0.0
        };

        self.state.update(
            value,
            batch_size,
            FormatOptions::new(self.name()).unit("%").precision(2),
        )
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }

    fn clear(&mut self) {
        self.state.reset();
    }

    fn attributes(&self) -> MetricAttributes {
        NumericAttributes {
            unit: Some("%".to_string()),
            higher_is_better: false,
        }
        .into()
    }
}
impl<B: Backend> Numeric for WordErrorRate<B> {
    /// Returns the WER of the most recent batch.
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    /// Returns the running WER aggregated across batches.
    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Identical predictions and targets yield a WER of 0%.
    #[test]
    fn test_wer_without_padding() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();
        let predictions = Tensor::from_data([[1, 2], [3, 4]], &device);
        let targets = Tensor::from_data([[1, 2], [3, 4]], &device);
        metric.update(
            &WerInput::new(predictions, targets),
            &MetricMetadata::fake(),
        );
        assert_eq!(0.0, metric.value().current());
    }

    /// Two substitutions over four target tokens give a WER of 50%.
    #[test]
    fn test_wer_without_padding_two_errors() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();
        let predictions = Tensor::from_data([[1, 2], [3, 5]], &device);
        let targets = Tensor::from_data([[1, 3], [3, 4]], &device);
        metric.update(
            &WerInput::new(predictions, targets),
            &MetricMetadata::fake(),
        );
        assert_eq!(50.0, metric.value().current());
    }

    /// Pad tokens are stripped before scoring, so the result matches the
    /// unpadded two-error case above.
    #[test]
    fn test_wer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);
        let predictions = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let targets = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);
        metric.update(
            &WerInput::new(predictions, targets),
            &MetricMetadata::fake(),
        );
        assert_eq!(50.0, metric.value().current());
    }

    /// `clear` wipes the accumulated state; the value reads as NaN afterwards.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();
        let predictions = Tensor::from_data([[1, 2]], &device);
        let targets = Tensor::from_data([[1, 3]], &device);
        metric.update(
            &WerInput::new(predictions, targets),
            &MetricMetadata::fake(),
        );
        assert!(metric.value().current() > 0.0);
        metric.clear();
        assert!(metric.value().current().is_nan());
    }
}