//! Word error rate (WER) metric — burn_train/metric/wer.rs

1use super::cer::edit_distance;
2use super::state::{FormatOptions, NumericMetricState};
3use super::{MetricEntry, MetricMetadata};
4use crate::metric::{Metric, MetricName, Numeric, NumericEntry};
5use burn_core::tensor::backend::Backend;
6use burn_core::tensor::{Int, Tensor};
7use core::marker::PhantomData;
8use std::sync::Arc;
9
// NOTE: the Levenshtein `edit_distance` helper is shared with the CER metric;
// here each element of a sequence is treated as one "word" (token id).
/// The word error rate (WER) metric, similar to the CER, is defined as the edit distance (e.g. Levenshtein distance) between the predicted
/// and reference word sequences, divided by the total number of words in the reference. Here, the "units" within the sequences are words.
///
#[derive(Clone)]
pub struct WordErrorRate<B: Backend> {
    /// Display name of the metric ("WER").
    name: MetricName,
    /// Running numeric aggregation across batches.
    state: NumericMetricState,
    /// Optional pad token; when set, each sequence is truncated at its first occurrence.
    pad_token: Option<usize>,
    /// Ties the metric to a backend type without storing backend data.
    _b: PhantomData<B>,
}
22
/// The [word error rate metric](WordErrorRate) input type.
// NOTE(review): `update` reads `targets.dims()` and slices both tensors with the
// same row length — assumes outputs and targets share the same [batch, seq]
// shape; confirm at call sites.
#[derive(new)]
pub struct WerInput<B: Backend> {
    /// The predicted token sequences (as a 2-D tensor of token indices).
    pub outputs: Tensor<B, 2, Int>,
    /// The target token sequences (as a 2-D tensor of token indices).
    pub targets: Tensor<B, 2, Int>,
}
31impl<B: Backend> Default for WordErrorRate<B> {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<B: Backend> WordErrorRate<B> {
38    /// Creates the metric.
39    pub fn new() -> Self {
40        Self {
41            name: Arc::new("WER".to_string()),
42            state: NumericMetricState::default(),
43            pad_token: None,
44            _b: PhantomData,
45        }
46    }
47
48    /// Sets the pad token.
49    pub fn with_pad_token(mut self, index: usize) -> Self {
50        self.pad_token = Some(index);
51        self
52    }
53}
54
55impl<B: Backend> Metric for WordErrorRate<B> {
56    type Input = WerInput<B>;
57
58    fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
59        let outputs = input.outputs.clone();
60        let targets = input.targets.clone();
61        let [batch_size, seq_len] = targets.dims();
62
63        let outputs_data = outputs
64            .to_data()
65            .to_vec::<i64>()
66            .expect("Failed to convert outputs to Vec");
67        let targets_data = targets
68            .to_data()
69            .to_vec::<i64>()
70            .expect("Failed to convert targets to Vec");
71
72        let pad_token = self.pad_token;
73
74        let mut total_edit_distance = 0.0;
75        let mut total_target_length = 0.0;
76
77        // Process each sequence in the batch
78        for i in 0..batch_size {
79            let start = i * seq_len;
80            let end = (i + 1) * seq_len;
81            let output_seq = &outputs_data[start..end];
82            let target_seq = &targets_data[start..end];
83
84            // Handle padding and map elements to i32.
85            // These sequences now represent "words" (token IDs).
86            let output_seq_no_pad = match pad_token {
87                Some(pad) => output_seq
88                    .iter()
89                    .take_while(|&&x| x != pad as i64)
90                    .map(|&x| x as i32)
91                    .collect::<Vec<_>>(),
92                None => output_seq.iter().map(|&x| x as i32).collect(),
93            };
94
95            let target_seq_no_pad = match pad_token {
96                Some(pad) => target_seq
97                    .iter()
98                    .take_while(|&&x| x != pad as i64)
99                    .map(|&x| x as i32)
100                    .collect::<Vec<_>>(),
101                None => target_seq.iter().map(|&x| x as i32).collect(),
102            };
103
104            let ed = edit_distance(&target_seq_no_pad, &output_seq_no_pad);
105            total_edit_distance += ed as f64;
106            total_target_length += target_seq_no_pad.len() as f64;
107        }
108
109        // Compute current WER value as a percentage
110        let value = if total_target_length > 0.0 {
111            100.0 * total_edit_distance / total_target_length
112        } else {
113            0.0
114        };
115
116        self.state.update(
117            value,
118            batch_size,
119            FormatOptions::new(self.name()).unit("%").precision(2),
120        )
121    }
122
123    fn name(&self) -> MetricName {
124        self.name.clone()
125    }
126
127    fn clear(&mut self) {
128        self.state.reset();
129    }
130}
131
/// The [word error rate metric](WordErrorRate) implementation.
impl<B: Backend> Numeric for WordErrorRate<B> {
    /// Returns the current aggregated WER value from the running state.
    fn value(&self) -> NumericEntry {
        self.state.value()
    }
}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Identical prediction and reference sequences yield a WER of 0 %.
    #[test]
    fn test_wer_without_padding() {
        let device = Default::default();
        let mut wer = WordErrorRate::<TestBackend>::new();

        // Batch of two sequences, two words each.
        let predictions = Tensor::from_data([[1, 2], [3, 4]], &device);
        let references = Tensor::from_data([[1, 2], [3, 4]], &device);

        wer.update(
            &WerInput::new(predictions, references),
            &MetricMetadata::fake(),
        );

        assert_eq!(0.0, wer.value().current());
    }

    /// Two substitutions across four reference words give 50 %.
    #[test]
    fn test_wer_without_padding_two_errors() {
        let device = Default::default();
        let mut wer = WordErrorRate::<TestBackend>::new();

        // Row 1: reference [1, 3] vs prediction [1, 2] -> one substitution.
        // Row 2: reference [3, 4] vs prediction [3, 5] -> one substitution.
        let predictions = Tensor::from_data([[1, 2], [3, 5]], &device);
        let references = Tensor::from_data([[1, 3], [3, 4]], &device);

        wer.update(
            &WerInput::new(predictions, references),
            &MetricMetadata::fake(),
        );

        // 2 edits / 4 reference words = 50 %.
        assert_eq!(50.0, wer.value().current());
    }

    /// Right-padding (token 9) is stripped before scoring; result stays 50 %.
    #[test]
    fn test_wer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut wer = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        // Third column of every row is the pad token and must be ignored:
        // references become [1, 3] / [3, 4]; predictions become [1, 2] / [3, 5].
        let predictions = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let references = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);

        wer.update(
            &WerInput::new(predictions, references),
            &MetricMetadata::fake(),
        );
        assert_eq!(50.0, wer.value().current());
    }

    /// `clear()` must reset the running statistics to NaN.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut wer = WordErrorRate::<TestBackend>::new();

        let predictions = Tensor::from_data([[1, 2]], &device);
        let references = Tensor::from_data([[1, 3]], &device); // one substitution

        wer.update(
            &WerInput::new(predictions.clone(), references.clone()),
            &MetricMetadata::fake(),
        );
        assert!(wer.value().current() > 0.0);

        wer.clear();
        assert!(wer.value().current().is_nan());
    }
}