
burn_train/metric/wer.rs

use super::cer::edit_distance;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricMetadata, SerializedEntry};
use crate::metric::{
    Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};
use core::marker::PhantomData;
use std::sync::Arc;

// `edit_distance` is shared with the CER metric: it computes the Levenshtein
// distance between two sequences. Here, each unit within a sequence is a word.
/// The word error rate (WER) metric. Like the CER, it is defined as the edit distance
/// (the Levenshtein distance) between the predicted and reference word sequences,
/// divided by the total number of words in the reference.
///
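/// # Example
///
/// A minimal usage sketch (it assumes a backend `B`, two `[batch_size, seq_len]`
/// integer tensors of token indices, and metric metadata are already in scope):
///
/// ```ignore
/// let mut metric = WordErrorRate::<B>::new().with_pad_token(0);
/// let entry = metric.update(&WerInput::new(outputs, targets), &metadata);
/// ```
///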
#[derive(Clone)]
pub struct WordErrorRate<B: Backend> {
    name: MetricName,
    state: NumericMetricState,
    pad_token: Option<usize>,
    _b: PhantomData<B>,
}

/// The [word error rate metric](WordErrorRate) input type.
#[derive(new)]
pub struct WerInput<B: Backend> {
    /// The predicted token sequences (as a 2-D tensor of token indices).
    pub outputs: Tensor<B, 2, Int>,
    /// The target token sequences (as a 2-D tensor of token indices).
    pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> Default for WordErrorRate<B> {
    fn default() -> Self {
        Self::new()
    }
}

impl<B: Backend> WordErrorRate<B> {
    /// Creates the metric.
    pub fn new() -> Self {
        Self {
            name: Arc::new("WER".to_string()),
            state: NumericMetricState::default(),
            pad_token: None,
            _b: PhantomData,
        }
    }

    /// Sets the pad token. Sequences are assumed to be right-padded: both the
    /// predicted and target sequences are truncated at the first occurrence of
    /// this token before the edit distance is computed.
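    ///
    /// # Example
    ///
    /// A sketch of the truncation behaviour (the backend `B` is assumed):
    ///
    /// ```ignore
    /// // With pad token 0, the row [5, 7, 0, 0] is scored as the two words [5, 7].
    /// let metric = WordErrorRate::<B>::new().with_pad_token(0);
    /// ```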
    pub fn with_pad_token(mut self, index: usize) -> Self {
        self.pad_token = Some(index);
        self
    }
}

impl<B: Backend> Metric for WordErrorRate<B> {
    type Input = WerInput<B>;

    fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
        let outputs = input.outputs.clone();
        let targets = input.targets.clone();
        let [batch_size, seq_len] = targets.dims();

        let outputs_data = outputs
            .to_data()
            .to_vec::<i64>()
            .expect("Failed to convert outputs to Vec");
        let targets_data = targets
            .to_data()
            .to_vec::<i64>()
            .expect("Failed to convert targets to Vec");

        let pad_token = self.pad_token;

        let mut total_edit_distance = 0.0;
        let mut total_target_length = 0.0;

        // Process each sequence in the batch.
        for i in 0..batch_size {
            let start = i * seq_len;
            let end = (i + 1) * seq_len;
            let output_seq = &outputs_data[start..end];
            let target_seq = &targets_data[start..end];

            // Truncate each sequence at its first pad token (sequences are
            // assumed to be right-padded) and map the token IDs to i32.
            // Each remaining token ID is treated as one "word".
            let output_seq_no_pad = match pad_token {
                Some(pad) => output_seq
                    .iter()
                    .take_while(|&&x| x != pad as i64)
                    .map(|&x| x as i32)
                    .collect::<Vec<_>>(),
                None => output_seq.iter().map(|&x| x as i32).collect(),
            };

            let target_seq_no_pad = match pad_token {
                Some(pad) => target_seq
                    .iter()
                    .take_while(|&&x| x != pad as i64)
                    .map(|&x| x as i32)
                    .collect::<Vec<_>>(),
                None => target_seq.iter().map(|&x| x as i32).collect(),
            };

            let ed = edit_distance(&target_seq_no_pad, &output_seq_no_pad);
            total_edit_distance += ed as f64;
            total_target_length += target_seq_no_pad.len() as f64;
        }

        // Compute the WER for this batch as a percentage: for example, 2 edits
        // over 4 reference words gives 100 * 2 / 4 = 50.0. An empty reference
        // (zero words after padding is removed) yields 0.0.
        let value = if total_target_length > 0.0 {
            100.0 * total_edit_distance / total_target_length
        } else {
            0.0
        };

        self.state.update(
            value,
            batch_size,
            FormatOptions::new(self.name()).unit("%").precision(2),
        )
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }

    fn clear(&mut self) {
        self.state.reset();
    }

    fn attributes(&self) -> MetricAttributes {
        NumericAttributes {
            unit: Some("%".to_string()),
            higher_is_better: false,
        }
        .into()
    }
}

impl<B: Backend> Numeric for WordErrorRate<B> {
    fn value(&self) -> NumericEntry {
        self.state.current_value()
    }

    fn running_value(&self) -> NumericEntry {
        self.state.running_value()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Perfect match => WER = 0 %.
    #[test]
    fn test_wer_without_padding() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        // Batch size = 2, sequence length = 2.
        let preds = Tensor::from_data([[1, 2], [3, 4]], &device);
        let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        assert_eq!(0.0, metric.value().current());
    }

    /// Two word edits in four target words => 50 %.
    #[test]
    fn test_wer_without_padding_two_errors() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        // One substitution in each sequence:
        // Sequence 1: target [1, 3], pred [1, 2] -> 1 error (3 vs 2).
        // Sequence 2: target [3, 4], pred [3, 5] -> 1 error (4 vs 5).
        let preds = Tensor::from_data([[1, 2], [3, 5]], &device);
        let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());

        // Total errors = 2, total target words = 4. WER = (2/4) * 100 = 50 %.
        assert_eq!(50.0, metric.value().current());
    }

    /// Same scenario as above, but with right-padding (token 9) ignored.
    #[test]
    fn test_wer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        // Each row has three columns; the last one is the pad token.
        // Target sequences after removing the pad: [1, 3] and [3, 4] (total length 4).
        // Predicted sequences after removing the pad: [1, 2] and [3, 5].
        let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());
        assert_eq!(50.0, metric.value().current());
    }

    /// `clear()` must reset the running statistics to NaN.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut metric = WordErrorRate::<TestBackend>::new();

        let preds = Tensor::from_data([[1, 2]], &device);
        let tgts = Tensor::from_data([[1, 3]], &device); // one error

        metric.update(
            &WerInput::new(preds.clone(), tgts.clone()),
            &MetricMetadata::fake(),
        );
        assert!(metric.value().current() > 0.0);

        metric.clear();
        assert!(metric.value().current().is_nan());
    }
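
    /// A sketch of the empty-reference edge case: when every target token is
    /// a pad token, the reference is empty after truncation and the metric
    /// falls back to 0 %.
    #[test]
    fn test_wer_all_padding_targets() {
        let device = Default::default();
        let pad = 9_i64;
        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        // Both rows are entirely padding, so total_target_length stays 0.0.
        let preds = Tensor::from_data([[pad, pad]], &device);
        let tgts = Tensor::from_data([[pad, pad]], &device);

        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());
        assert_eq!(0.0, metric.value().current());
    }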
}