1use super::cer::edit_distance;
2use super::state::{FormatOptions, NumericMetricState};
3use super::{MetricMetadata, SerializedEntry};
4use crate::metric::{
5 Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
6};
7use burn_core::tensor::backend::Backend;
8use burn_core::tensor::{Int, Tensor};
9use core::marker::PhantomData;
10use std::sync::Arc;
11
12#[derive(Clone)]
18pub struct WordErrorRate<B: Backend> {
19 name: MetricName,
20 state: NumericMetricState,
21 pad_token: Option<usize>,
22 _b: PhantomData<B>,
23}
24
25#[derive(new)]
27pub struct WerInput<B: Backend> {
28 pub outputs: Tensor<B, 2, Int>,
30 pub targets: Tensor<B, 2, Int>,
32}
33impl<B: Backend> Default for WordErrorRate<B> {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl<B: Backend> WordErrorRate<B> {
40 pub fn new() -> Self {
42 Self {
43 name: Arc::new("WER".to_string()),
44 state: NumericMetricState::default(),
45 pad_token: None,
46 _b: PhantomData,
47 }
48 }
49
50 pub fn with_pad_token(mut self, index: usize) -> Self {
52 self.pad_token = Some(index);
53 self
54 }
55}
56
57impl<B: Backend> Metric for WordErrorRate<B> {
58 type Input = WerInput<B>;
59
60 fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
61 let outputs = input.outputs.clone();
62 let targets = input.targets.clone();
63 let [batch_size, seq_len] = targets.dims();
64
65 let outputs_data = outputs
66 .to_data()
67 .to_vec::<i64>()
68 .expect("Failed to convert outputs to Vec");
69 let targets_data = targets
70 .to_data()
71 .to_vec::<i64>()
72 .expect("Failed to convert targets to Vec");
73
74 let pad_token = self.pad_token;
75
76 let mut total_edit_distance = 0.0;
77 let mut total_target_length = 0.0;
78
79 for i in 0..batch_size {
81 let start = i * seq_len;
82 let end = (i + 1) * seq_len;
83 let output_seq = &outputs_data[start..end];
84 let target_seq = &targets_data[start..end];
85
86 let output_seq_no_pad = match pad_token {
89 Some(pad) => output_seq
90 .iter()
91 .take_while(|&&x| x != pad as i64)
92 .map(|&x| x as i32)
93 .collect::<Vec<_>>(),
94 None => output_seq.iter().map(|&x| x as i32).collect(),
95 };
96
97 let target_seq_no_pad = match pad_token {
98 Some(pad) => target_seq
99 .iter()
100 .take_while(|&&x| x != pad as i64)
101 .map(|&x| x as i32)
102 .collect::<Vec<_>>(),
103 None => target_seq.iter().map(|&x| x as i32).collect(),
104 };
105
106 let ed = edit_distance(&target_seq_no_pad, &output_seq_no_pad);
107 total_edit_distance += ed as f64;
108 total_target_length += target_seq_no_pad.len() as f64;
109 }
110
111 let value = if total_target_length > 0.0 {
113 100.0 * total_edit_distance / total_target_length
114 } else {
115 0.0
116 };
117
118 self.state.update(
119 value,
120 batch_size,
121 FormatOptions::new(self.name()).unit("%").precision(2),
122 )
123 }
124
125 fn name(&self) -> MetricName {
126 self.name.clone()
127 }
128
129 fn clear(&mut self) {
130 self.state.reset();
131 }
132
133 fn attributes(&self) -> MetricAttributes {
134 NumericAttributes {
135 unit: Some("%".to_string()),
136 higher_is_better: false,
137 }
138 .into()
139 }
140}
141
142impl<B: Backend> Numeric for WordErrorRate<B> {
143 fn value(&self) -> NumericEntry {
144 self.state.current_value()
145 }
146
147 fn running_value(&self) -> NumericEntry {
148 self.state.running_value()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Identical predictions and targets produce a WER of 0 %.
    #[test]
    fn test_wer_without_padding() {
        let device = Default::default();
        let mut wer = WordErrorRate::<TestBackend>::new();

        let outputs = Tensor::from_data([[1, 2], [3, 4]], &device);
        let targets = Tensor::from_data([[1, 2], [3, 4]], &device);

        wer.update(&WerInput::new(outputs, targets), &MetricMetadata::fake());

        assert_eq!(0.0, wer.value().current());
    }

    /// One substitution in each of two 2-token rows: 2 edits / 4 tokens = 50 %.
    #[test]
    fn test_wer_without_padding_two_errors() {
        let device = Default::default();
        let mut wer = WordErrorRate::<TestBackend>::new();

        let outputs = Tensor::from_data([[1, 2], [3, 5]], &device);
        let targets = Tensor::from_data([[1, 3], [3, 4]], &device);

        wer.update(&WerInput::new(outputs, targets), &MetricMetadata::fake());

        assert_eq!(50.0, wer.value().current());
    }

    /// Trailing padding tokens are ignored, so the result matches the
    /// unpadded two-error case: 2 edits / 4 tokens = 50 %.
    #[test]
    fn test_wer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut wer = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        let outputs = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let targets = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);

        wer.update(&WerInput::new(outputs, targets), &MetricMetadata::fake());

        assert_eq!(50.0, wer.value().current());
    }

    /// After `clear`, the current value is NaN.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut wer = WordErrorRate::<TestBackend>::new();

        let outputs = Tensor::from_data([[1, 2]], &device);
        let targets = Tensor::from_data([[1, 3]], &device);

        wer.update(&WerInput::new(outputs, targets), &MetricMetadata::fake());
        assert!(wer.value().current() > 0.0);

        wer.clear();
        assert!(wer.value().current().is_nan());
    }
}