burn_train/metric/wer.rs

use super::cer::edit_distance;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricMetadata, SerializedEntry};
use crate::metric::{
    Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};
use core::marker::PhantomData;
use std::sync::Arc;

// `edit_distance` (shared with the CER metric) computes the Levenshtein
// distance between two sequences; here the units of the sequences are words.
/// The word error rate (WER) metric. Like the CER, it is defined as the edit
/// distance (e.g. Levenshtein distance) between the predicted and reference
/// word sequences, divided by the total number of words in the reference;
/// the "units" within the sequences are words rather than characters.
///
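/// For example, if the reference is the 3-word sequence `[7, 8, 9]` and the
/// prediction is `[7, 8, 9, 4]`, the edit distance is 1 (one insertion), so
/// WER = 1 / 3 ≈ 33.33 %.
///
/// # Example
///
/// A minimal usage sketch (the backend `B`, the input tensors, and the pad
/// index `0` are illustrative assumptions):
///
/// ```ignore
/// let mut metric = WordErrorRate::<B>::new().with_pad_token(0);
/// let entry = metric.update(&WerInput::new(outputs, targets), &metadata);
/// ```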
#[derive(Clone)]
pub struct WordErrorRate<B: Backend> {
    name: MetricName,
    state: NumericMetricState,
    pad_token: Option<usize>,
    _b: PhantomData<B>,
}

/// The [word error rate metric](WordErrorRate) input type.
#[derive(new)]
pub struct WerInput<B: Backend> {
    /// The predicted token sequences (as a 2-D tensor of token indices).
    pub outputs: Tensor<B, 2, Int>,
    /// The target token sequences (as a 2-D tensor of token indices).
    pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> Default for WordErrorRate<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> WordErrorRate<B> {
    /// Creates the metric.
    pub fn new() -> Self {
        Self {
            name: Arc::new("WER".to_string()),
            state: NumericMetricState::default(),
            pad_token: None,
            _b: PhantomData,
        }
    }

    /// Sets the pad token used to strip right-padding from both the predicted and target sequences.
    pub fn with_pad_token(mut self, index: usize) -> Self {
        self.pad_token = Some(index);
        self
    }
}

impl<B: Backend> Metric for WordErrorRate<B> {
    type Input = WerInput<B>;

    fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
        let outputs = input.outputs.clone();
        let targets = input.targets.clone();
        let [batch_size, seq_len] = targets.dims();

        // `TensorData::iter::<i32>()` dispatches on the stored DType and
        // narrows to i32 per element; token IDs in any reasonable vocabulary
        // fit in i32 regardless of the backend's native IntElem.
        let outputs_data = outputs.to_data().iter::<i32>().collect::<Vec<_>>();
        let targets_data = targets.to_data().iter::<i32>().collect::<Vec<_>>();

        let pad_token = self.pad_token.map(|p| p as i32);

        let mut total_edit_distance = 0.0;
        let mut total_target_length = 0.0;

        // Process each sequence in the batch.
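        // The flattened data is row-major, so row `i` occupies the half-open
        // index range `[i * seq_len, (i + 1) * seq_len)`.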
        for i in 0..batch_size {
            let start = i * seq_len;
            let end = (i + 1) * seq_len;
            let output_seq = &outputs_data[start..end];
            let target_seq = &targets_data[start..end];

            // Strip right-padding if a pad token is configured.
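            // For example, with pad = 0 a row [5, 7, 0, 0] is scored as
            // [5, 7]; a row that never contains the pad token is kept whole.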
            let target_seq_no_pad: &[i32] = match pad_token {
                Some(pad) => {
                    let len = target_seq
                        .iter()
                        .position(|&x| x == pad)
                        .unwrap_or(target_seq.len());
                    &target_seq[..len]
                }
                None => target_seq,
            };
            let output_seq_no_pad: &[i32] = match pad_token {
                Some(pad) => {
                    let len = output_seq
                        .iter()
                        .position(|&x| x == pad)
                        .unwrap_or(output_seq.len());
                    &output_seq[..len]
                }
                None => output_seq,
            };

            let ed = edit_distance(target_seq_no_pad, output_seq_no_pad);
            total_edit_distance += ed as f64;
            total_target_length += target_seq_no_pad.len() as f64;
        }

        // Compute the current WER value as a percentage.
        let value = if total_target_length > 0.0 {
            100.0 * total_edit_distance / total_target_length
        } else {
            0.0
        };

        self.state.update(
            value,
            batch_size,
            FormatOptions::new(self.name()).unit("%").precision(2),
        )
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }

    fn clear(&mut self) {
        self.state.reset();
    }

    fn attributes(&self) -> MetricAttributes {
        NumericAttributes {
            unit: Some("%".to_string()),
            higher_is_better: false,
        }
        .into()
    }
}

impl<B: Backend> Numeric for WordErrorRate<B> {
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Perfect match => WER = 0 %.
    #[test]
    fn test_wer_without_padding() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        // Batch size = 2, sequence length = 2.
        let preds = Tensor::from_data([[1, 2], [3, 4]], &device);
        let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(0.0, metric.value().current());
    }

    /// Two word edits in four target words => 50 %.
    #[test]
    fn test_wer_without_padding_two_errors() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        // One substitution in each sequence.
        // Sequence 1: target [1, 3], pred [1, 2] -> 1 error (3 vs 2)
        // Sequence 2: target [3, 4], pred [3, 5] -> 1 error (4 vs 5)
        let preds = Tensor::from_data([[1, 2], [3, 5]], &device);
        let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        // Total errors = 2, total target words = 4. WER = (2 / 4) * 100 = 50 %.
        assert_eq!(50.0, metric.value().current());
    }

    /// Same scenario as above, but with right-padding (token 9) ignored.
    #[test]
    fn test_wer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        // Each row has three columns; the last one is the pad token.
        // Target sequences after removing pad: [1, 3] and [3, 4] (total length 4).
        // Predicted sequences after removing pad: [1, 2] and [3, 5].
        let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());
        assert_eq!(50.0, metric.value().current());
    }

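    /// A sketch exercising the zero-length guard in `update`: a batch that is
    /// entirely padding has zero reference words, so the metric should report
    /// 0 % rather than divide by zero.
    #[test]
    fn test_wer_all_padding_reports_zero() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        // Both sequences strip down to empty, so total_target_length is 0.
        let preds = Tensor::from_data([[pad, pad]], &device);
        let tgts = Tensor::from_data([[pad, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());
        assert_eq!(0.0, metric.value().current());
    }
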
    /// `clear()` must reset the running statistics to NaN.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        let preds = Tensor::from_data([[1, 2]], &device);
        let tgts = Tensor::from_data([[1, 3]], &device); // one error

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());
        assert!(metric.value().current() > 0.0);

        metric.clear();
        assert!(metric.value().current().is_nan());
    }
}