1use super::cer::edit_distance;
2use super::state::{FormatOptions, NumericMetricState};
3use super::{MetricEntry, MetricMetadata};
4use crate::metric::{Metric, MetricName, Numeric, NumericEntry};
5use burn_core::tensor::backend::Backend;
6use burn_core::tensor::{Int, Tensor};
7use core::marker::PhantomData;
8use std::sync::Arc;
9
10#[derive(Clone)]
16pub struct WordErrorRate<B: Backend> {
17 name: MetricName,
18 state: NumericMetricState,
19 pad_token: Option<usize>,
20 _b: PhantomData<B>,
21}
22
/// Input to [`WordErrorRate`]: a batch of predicted and reference token sequences.
///
/// `#[derive(new)]` generates `WerInput::new(outputs, targets)`; the field order
/// below is that constructor's parameter order, so it must not be changed.
#[derive(new)]
pub struct WerInput<B: Backend> {
    /// Predicted token ids, one sequence per row.
    pub outputs: Tensor<B, 2, Int>,
    /// Reference (ground-truth) token ids, one sequence per row.
    pub targets: Tensor<B, 2, Int>,
}
31impl<B: Backend> Default for WordErrorRate<B> {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<B: Backend> WordErrorRate<B> {
38 pub fn new() -> Self {
40 Self {
41 name: Arc::new("WER".to_string()),
42 state: NumericMetricState::default(),
43 pad_token: None,
44 _b: PhantomData,
45 }
46 }
47
48 pub fn with_pad_token(mut self, index: usize) -> Self {
50 self.pad_token = Some(index);
51 self
52 }
53}
54
55impl<B: Backend> Metric for WordErrorRate<B> {
56 type Input = WerInput<B>;
57
58 fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
59 let outputs = input.outputs.clone();
60 let targets = input.targets.clone();
61 let [batch_size, seq_len] = targets.dims();
62
63 let outputs_data = outputs
64 .to_data()
65 .to_vec::<i64>()
66 .expect("Failed to convert outputs to Vec");
67 let targets_data = targets
68 .to_data()
69 .to_vec::<i64>()
70 .expect("Failed to convert targets to Vec");
71
72 let pad_token = self.pad_token;
73
74 let mut total_edit_distance = 0.0;
75 let mut total_target_length = 0.0;
76
77 for i in 0..batch_size {
79 let start = i * seq_len;
80 let end = (i + 1) * seq_len;
81 let output_seq = &outputs_data[start..end];
82 let target_seq = &targets_data[start..end];
83
84 let output_seq_no_pad = match pad_token {
87 Some(pad) => output_seq
88 .iter()
89 .take_while(|&&x| x != pad as i64)
90 .map(|&x| x as i32)
91 .collect::<Vec<_>>(),
92 None => output_seq.iter().map(|&x| x as i32).collect(),
93 };
94
95 let target_seq_no_pad = match pad_token {
96 Some(pad) => target_seq
97 .iter()
98 .take_while(|&&x| x != pad as i64)
99 .map(|&x| x as i32)
100 .collect::<Vec<_>>(),
101 None => target_seq.iter().map(|&x| x as i32).collect(),
102 };
103
104 let ed = edit_distance(&target_seq_no_pad, &output_seq_no_pad);
105 total_edit_distance += ed as f64;
106 total_target_length += target_seq_no_pad.len() as f64;
107 }
108
109 let value = if total_target_length > 0.0 {
111 100.0 * total_edit_distance / total_target_length
112 } else {
113 0.0
114 };
115
116 self.state.update(
117 value,
118 batch_size,
119 FormatOptions::new(self.name()).unit("%").precision(2),
120 )
121 }
122
123 fn name(&self) -> MetricName {
124 self.name.clone()
125 }
126
127 fn clear(&mut self) {
128 self.state.reset();
129 }
130}
131
132impl<B: Backend> Numeric for WordErrorRate<B> {
134 fn value(&self) -> NumericEntry {
135 self.state.value()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Identical predictions and targets give 0% WER.
    #[test]
    fn test_wer_without_padding() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        let preds = Tensor::from_data([[1, 2], [3, 4]], &device);
        let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(0.0, metric.value().current());
    }

    /// Two substitutions over four target tokens give 50% WER.
    #[test]
    fn test_wer_without_padding_two_errors() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        let preds = Tensor::from_data([[1, 2], [3, 5]], &device);
        let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    /// Pad positions count toward neither the edit distance nor the target length.
    #[test]
    fn test_wer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    /// Tokens after the first pad token are ignored even when they differ.
    #[test]
    fn test_wer_ignores_tokens_after_first_pad() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        let preds = Tensor::from_data([[1, 2, pad, 7]], &device);
        let tgts = Tensor::from_data([[1, 2, pad, 8]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(0.0, metric.value().current());
    }

    /// Targets made only of padding report 0% instead of dividing by zero.
    #[test]
    fn test_wer_all_padding_targets() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        let preds = Tensor::from_data([[pad, pad]], &device);
        let tgts = Tensor::from_data([[pad, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(0.0, metric.value().current());
    }

    /// `clear` discards the accumulated state.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        let preds = Tensor::from_data([[1, 2]], &device);
        let tgts = Tensor::from_data([[1, 3]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());
        assert!(metric.value().current() > 0.0);

        metric.clear();
        assert!(metric.value().current().is_nan());
    }
}