1use super::cer::edit_distance;
2use super::state::{FormatOptions, NumericMetricState};
3use super::{MetricMetadata, SerializedEntry};
4use crate::metric::{
5 Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
6};
7use burn_core::tensor::backend::Backend;
8use burn_core::tensor::{Int, Tensor};
9use core::marker::PhantomData;
10use std::sync::Arc;
11
12#[derive(Clone)]
18pub struct WordErrorRate<B: Backend> {
19 name: MetricName,
20 state: NumericMetricState,
21 pad_token: Option<usize>,
22 _b: PhantomData<B>,
23}
24
/// Per-batch input of the [`WordErrorRate`] metric.
///
/// Both tensors hold integer token ids with one sequence per row.
#[derive(new)]
pub struct WerInput<B: Backend> {
    /// Predicted token ids, one row per sample (`[batch_size, seq_len]`).
    pub outputs: Tensor<B, 2, Int>,
    /// Reference token ids, one row per sample (`[batch_size, seq_len]`).
    pub targets: Tensor<B, 2, Int>,
}
33impl<B: Backend> Default for WordErrorRate<B> {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl<B: Backend> WordErrorRate<B> {
40 pub fn new() -> Self {
42 Self {
43 name: Arc::new("WER".to_string()),
44 state: NumericMetricState::default(),
45 pad_token: None,
46 _b: PhantomData,
47 }
48 }
49
50 pub fn with_pad_token(mut self, index: usize) -> Self {
52 self.pad_token = Some(index);
53 self
54 }
55}
56
57impl<B: Backend> Metric for WordErrorRate<B> {
58 type Input = WerInput<B>;
59
60 fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
61 let outputs = input.outputs.clone();
62 let targets = input.targets.clone();
63 let [batch_size, seq_len] = targets.dims();
64
65 let outputs_data = outputs.to_data().iter::<i32>().collect::<Vec<_>>();
69 let targets_data = targets.to_data().iter::<i32>().collect::<Vec<_>>();
70
71 let pad_token = self.pad_token.map(|p| p as i32);
72
73 let mut total_edit_distance = 0.0;
74 let mut total_target_length = 0.0;
75
76 for i in 0..batch_size {
78 let start = i * seq_len;
79 let end = (i + 1) * seq_len;
80 let output_seq = &outputs_data[start..end];
81 let target_seq = &targets_data[start..end];
82
83 let target_seq_no_pad: &[i32] = match pad_token {
85 Some(pad) => {
86 let len = target_seq
87 .iter()
88 .position(|&x| x == pad)
89 .unwrap_or(target_seq.len());
90 &target_seq[..len]
91 }
92 None => target_seq,
93 };
94 let output_seq_no_pad: &[i32] = match pad_token {
95 Some(pad) => {
96 let len = output_seq
97 .iter()
98 .position(|&x| x == pad)
99 .unwrap_or(output_seq.len());
100 &output_seq[..len]
101 }
102 None => output_seq,
103 };
104
105 let ed = edit_distance(target_seq_no_pad, output_seq_no_pad);
106 total_edit_distance += ed as f64;
107 total_target_length += target_seq_no_pad.len() as f64;
108 }
109
110 let value = if total_target_length > 0.0 {
112 100.0 * total_edit_distance / total_target_length
113 } else {
114 0.0
115 };
116
117 self.state.update(
118 value,
119 batch_size,
120 FormatOptions::new(self.name()).unit("%").precision(2),
121 )
122 }
123
124 fn name(&self) -> MetricName {
125 self.name.clone()
126 }
127
128 fn clear(&mut self) {
129 self.state.reset();
130 }
131
132 fn attributes(&self) -> MetricAttributes {
133 NumericAttributes {
134 unit: Some("%".to_string()),
135 higher_is_better: false,
136 }
137 .into()
138 }
139}
140
141impl<B: Backend> Numeric for WordErrorRate<B> {
142 fn value(&self) -> NumericEntry {
143 self.state.current_value()
144 }
145
146 fn running_value(&self) -> NumericEntry {
147 self.state.running_value()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_wer_without_padding() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        let preds = Tensor::from_data([[1, 2], [3, 4]], &device);
        let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(0.0, metric.value().current());
    }

    #[test]
    fn test_wer_without_padding_two_errors() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        // One substitution in each of two 2-word references: 2 / 4 = 50%.
        let preds = Tensor::from_data([[1, 2], [3, 5]], &device);
        let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    #[test]
    fn test_wer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    #[test]
    fn test_wer_with_padding_in_outputs_only() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        // The second prediction stops one word early: one deletion over four
        // reference words gives 25%.
        let preds = Tensor::from_data([[1, 2, pad], [3, pad, pad]], &device);
        let tgts = Tensor::from_data([[1, 2, pad], [3, 4, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(25.0, metric.value().current());
    }

    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        let preds = Tensor::from_data([[1, 2]], &device);
        let tgts = Tensor::from_data([[1, 3]], &device);
        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());
        assert!(metric.value().current() > 0.0);

        metric.clear();
        assert!(metric.value().current().is_nan());
    }
}