burn_train/metric/
cer.rs

1use super::state::{FormatOptions, NumericMetricState};
2use super::{MetricEntry, MetricMetadata};
3use crate::metric::{Metric, MetricName, Numeric, NumericEntry};
4use burn_core::tensor::backend::Backend;
5use burn_core::tensor::{Int, Tensor};
6use core::marker::PhantomData;
7use std::sync::Arc;
8
/// Returns the Levenshtein (edit) distance between `reference` and `prediction`.
///
/// The distance is the smallest number of single-token insertions, deletions, or
/// substitutions that turns one sequence into the other. Only two rows of the dynamic
/// programming table are kept alive at any time, so memory use is proportional to
/// `prediction.len()` rather than to the product of both lengths.
///
pub(crate) fn edit_distance(reference: &[i32], prediction: &[i32]) -> usize {
    let width = prediction.len() + 1;

    // `above` is the finished row of the table, `row` is the one being filled in.
    // The first finished row is the cost of building each prediction prefix from nothing.
    let mut above: Vec<usize> = (0..width).collect();
    let mut row: Vec<usize> = vec![0; width];

    for (consumed, ref_token) in reference.iter().enumerate() {
        // Matching a non-empty reference prefix against an empty prediction prefix.
        row[0] = consumed + 1;

        for (col, pred_token) in prediction.iter().enumerate() {
            row[col + 1] = if ref_token == pred_token {
                // Tokens agree: carry the diagonal cost over unchanged.
                above[col]
            } else {
                // Tokens differ: pay one edit on top of the cheapest neighbouring cell.
                let diagonal = above[col];
                let up = above[col + 1];
                let left = row[col];
                diagonal.min(up).min(left) + 1
            };
        }

        // The freshly filled row becomes the finished row for the next reference token.
        core::mem::swap(&mut above, &mut row);
    }

    above[width - 1]
}
32
/// Character error rate (CER) is defined as the edit distance (e.g. Levenshtein distance) between the predicted
/// and reference character sequences, divided by the total number of characters in the reference.
/// This metric is commonly used in tasks such as speech recognition, OCR, or text generation
/// to quantify how closely the predicted output matches the ground truth at a character level.
///
/// The value is reported as a percentage.
///
#[derive(Clone)]
pub struct CharErrorRate<B: Backend> {
    /// Display name of the metric, returned by its `name()` method.
    name: MetricName,
    /// Running numeric state, updated with each batch's CER value.
    state: NumericMetricState,
    /// Token index treated as padding; `None` means no token is excluded.
    pad_token: Option<usize>,
    /// Ties the metric to a backend type without storing a value of that type.
    _b: PhantomData<B>,
}
45
/// The [character error rate metric](CharErrorRate) input type.
///
/// Both tensors are read row by row: the first dimension is the batch and the second is
/// the position of the token inside the sequence.
#[derive(new)]
pub struct CerInput<B: Backend> {
    /// The predicted token sequences (as a 2-D tensor of token indices).
    pub outputs: Tensor<B, 2, Int>,
    /// The target token sequences (as a 2-D tensor of token indices),
    /// laid out as `[batch_size, seq_len]`.
    pub targets: Tensor<B, 2, Int>,
}
54
/// `Default` is equivalent to [`CharErrorRate::new`].
impl<B: Backend> Default for CharErrorRate<B> {
    /// Creates the metric by delegating to [`CharErrorRate::new`].
    fn default() -> Self {
        Self::new()
    }
}
60
61impl<B: Backend> CharErrorRate<B> {
62    /// Creates the metric.
63    pub fn new() -> Self {
64        Self {
65            name: Arc::new("CER".to_string()),
66            state: NumericMetricState::default(),
67            pad_token: None,
68            _b: PhantomData,
69        }
70    }
71
72    /// Sets the pad token.
73    pub fn with_pad_token(mut self, index: usize) -> Self {
74        self.pad_token = Some(index);
75        self
76    }
77}
78
79/// The [character error rate metric](CharErrorRate) implementation.
80impl<B: Backend> Metric for CharErrorRate<B> {
81    type Input = CerInput<B>;
82
83    fn update(&mut self, input: &CerInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
84        let outputs = &input.outputs;
85        let targets = &input.targets;
86        let [batch_size, seq_len] = targets.dims();
87
88        let (output_lengths, target_lengths) = if let Some(pad) = self.pad_token {
89            // Create boolean masks for non-padding tokens.
90            let output_mask = outputs.clone().not_equal_elem(pad as i64);
91            let target_mask = targets.clone().not_equal_elem(pad as i64);
92
93            let output_lengths_tensor = output_mask.int().sum_dim(1);
94            let target_lengths_tensor = target_mask.int().sum_dim(1);
95
96            (
97                output_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
98                target_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
99            )
100        } else {
101            // If there's no padding, all sequences have the full length.
102            (
103                vec![seq_len as i64; batch_size],
104                vec![seq_len as i64; batch_size],
105            )
106        };
107
108        let outputs_data = outputs.to_data().to_vec::<i64>().unwrap();
109        let targets_data = targets.to_data().to_vec::<i64>().unwrap();
110
111        let total_edit_distance: usize = (0..batch_size)
112            .map(|i| {
113                let start = i * seq_len;
114
115                // Get pre-calculated lengths for the current sequence.
116                let output_len = output_lengths[i] as usize;
117                let target_len = target_lengths[i] as usize;
118
119                let output_seq_slice = &outputs_data[start..(start + output_len)];
120                let target_seq_slice = &targets_data[start..(start + target_len)];
121                let output_seq: Vec<i32> = output_seq_slice.iter().map(|&x| x as i32).collect();
122                let target_seq: Vec<i32> = target_seq_slice.iter().map(|&x| x as i32).collect();
123
124                edit_distance(&target_seq, &output_seq)
125            })
126            .sum();
127
128        let total_target_length = target_lengths.iter().map(|&x| x as f64).sum::<f64>();
129
130        let value = if total_target_length > 0.0 {
131            100.0 * total_edit_distance as f64 / total_target_length
132        } else {
133            0.0
134        };
135
136        self.state.update(
137            value,
138            batch_size,
139            FormatOptions::new(self.name()).unit("%").precision(2),
140        )
141    }
142
143    fn clear(&mut self) {
144        self.state.reset();
145    }
146
147    fn name(&self) -> MetricName {
148        self.name.clone()
149    }
150}
151
/// The [character error rate metric](CharErrorRate) implementation.
impl<B: Backend> Numeric for CharErrorRate<B> {
    /// Returns the numeric entry currently held by the running state.
    fn value(&self) -> NumericEntry {
        self.state.value()
    }
}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Identical predictions and targets must give a CER of 0 %.
    #[test]
    fn test_cer_without_padding() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        // Two sequences of two tokens each, predicted exactly.
        let outputs = Tensor::from_data([[1, 2], [3, 4]], &device);
        let targets = Tensor::from_data([[1, 2], [3, 4]], &device);
        let input = CerInput::new(outputs, targets);

        cer.update(&input, &MetricMetadata::fake());

        assert_eq!(0.0, cer.value().current());
    }

    /// One substitution per sequence over four target tokens gives 50 %.
    #[test]
    fn test_cer_without_padding_two_errors() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        // The second token of each row differs between prediction and target.
        let outputs = Tensor::from_data([[1, 2], [3, 5]], &device);
        let targets = Tensor::from_data([[1, 3], [3, 4]], &device);
        let input = CerInput::new(outputs, targets);

        cer.update(&input, &MetricMetadata::fake());

        // 2 edits over 4 target tokens.
        assert_eq!(50.0, cer.value().current());
    }

    /// Same errors as the previous test, with a trailing pad token (9) that must be ignored.
    #[test]
    fn test_cer_with_padding() {
        const PAD: i64 = 9;

        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new().with_pad_token(PAD as usize);

        // Three columns per row; the last column holds only padding.
        let outputs = Tensor::from_data([[1, 2, PAD], [3, 5, PAD]], &device);
        let targets = Tensor::from_data([[1, 3, PAD], [3, 4, PAD]], &device);
        let input = CerInput::new(outputs, targets);

        cer.update(&input, &MetricMetadata::fake());

        assert_eq!(50.0, cer.value().current());
    }

    /// After `clear()`, the running value must no longer be a number.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        // A single sequence containing one wrong token.
        let outputs = Tensor::from_data([[1, 2]], &device);
        let targets = Tensor::from_data([[1, 3]], &device);
        let input = CerInput::new(outputs, targets);

        cer.update(&input, &MetricMetadata::fake());
        assert!(cer.value().current() > 0.0);

        cer.clear();
        assert!(cer.value().current().is_nan());
    }
}