1use super::state::{FormatOptions, NumericMetricState};
2use super::{MetricEntry, MetricMetadata};
3use crate::metric::{Metric, MetricName, Numeric, NumericEntry};
4use burn_core::tensor::backend::Backend;
5use burn_core::tensor::{Int, Tensor};
6use core::marker::PhantomData;
7use std::sync::Arc;
8
/// Computes the Levenshtein distance between two token sequences.
///
/// The result is the minimum number of single-token insertions, deletions and
/// substitutions needed to turn `prediction` into `reference`. Only two rows of
/// the dynamic-programming table are kept alive at any time, so memory use is
/// proportional to `prediction.len()`.
pub(crate) fn edit_distance(reference: &[i32], prediction: &[i32]) -> usize {
    let width = prediction.len() + 1;

    // `above` is the previously completed row; it starts as the cost of building
    // each prediction prefix from an empty reference (pure insertions).
    let mut above: Vec<usize> = (0..width).collect();
    let mut row: Vec<usize> = vec![0; width];

    for (ref_idx, ref_token) in reference.iter().copied().enumerate() {
        // Matching the first `ref_idx + 1` reference tokens against an empty
        // prediction costs one deletion per token.
        row[0] = ref_idx + 1;

        for (col, pred_token) in prediction.iter().copied().enumerate() {
            let cost = if ref_token == pred_token {
                // Tokens agree: carry the diagonal cost over unchanged.
                above[col]
            } else {
                // Cheapest of substitution (diagonal), deletion (up) and insertion (left).
                let cheapest = above[col].min(above[col + 1]).min(row[col]);
                cheapest + 1
            };
            row[col + 1] = cost;
        }

        // The row just filled becomes the "previous" row of the next iteration.
        core::mem::swap(&mut above, &mut row);
    }

    above[width - 1]
}
32
33#[derive(Clone)]
39pub struct CharErrorRate<B: Backend> {
40 name: MetricName,
41 state: NumericMetricState,
42 pad_token: Option<usize>,
43 _b: PhantomData<B>,
44}
45
46#[derive(new)]
48pub struct CerInput<B: Backend> {
49 pub outputs: Tensor<B, 2, Int>,
51 pub targets: Tensor<B, 2, Int>,
53}
54
55impl<B: Backend> Default for CharErrorRate<B> {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl<B: Backend> CharErrorRate<B> {
62 pub fn new() -> Self {
64 Self {
65 name: Arc::new("CER".to_string()),
66 state: NumericMetricState::default(),
67 pad_token: None,
68 _b: PhantomData,
69 }
70 }
71
72 pub fn with_pad_token(mut self, index: usize) -> Self {
74 self.pad_token = Some(index);
75 self
76 }
77}
78
79impl<B: Backend> Metric for CharErrorRate<B> {
81 type Input = CerInput<B>;
82
83 fn update(&mut self, input: &CerInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
84 let outputs = &input.outputs;
85 let targets = &input.targets;
86 let [batch_size, seq_len] = targets.dims();
87
88 let (output_lengths, target_lengths) = if let Some(pad) = self.pad_token {
89 let output_mask = outputs.clone().not_equal_elem(pad as i64);
91 let target_mask = targets.clone().not_equal_elem(pad as i64);
92
93 let output_lengths_tensor = output_mask.int().sum_dim(1);
94 let target_lengths_tensor = target_mask.int().sum_dim(1);
95
96 (
97 output_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
98 target_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
99 )
100 } else {
101 (
103 vec![seq_len as i64; batch_size],
104 vec![seq_len as i64; batch_size],
105 )
106 };
107
108 let outputs_data = outputs.to_data().to_vec::<i64>().unwrap();
109 let targets_data = targets.to_data().to_vec::<i64>().unwrap();
110
111 let total_edit_distance: usize = (0..batch_size)
112 .map(|i| {
113 let start = i * seq_len;
114
115 let output_len = output_lengths[i] as usize;
117 let target_len = target_lengths[i] as usize;
118
119 let output_seq_slice = &outputs_data[start..(start + output_len)];
120 let target_seq_slice = &targets_data[start..(start + target_len)];
121 let output_seq: Vec<i32> = output_seq_slice.iter().map(|&x| x as i32).collect();
122 let target_seq: Vec<i32> = target_seq_slice.iter().map(|&x| x as i32).collect();
123
124 edit_distance(&target_seq, &output_seq)
125 })
126 .sum();
127
128 let total_target_length = target_lengths.iter().map(|&x| x as f64).sum::<f64>();
129
130 let value = if total_target_length > 0.0 {
131 100.0 * total_edit_distance as f64 / total_target_length
132 } else {
133 0.0
134 };
135
136 self.state.update(
137 value,
138 batch_size,
139 FormatOptions::new(self.name()).unit("%").precision(2),
140 )
141 }
142
143 fn clear(&mut self) {
144 self.state.reset();
145 }
146
147 fn name(&self) -> MetricName {
148 self.name.clone()
149 }
150}
151
152impl<B: Backend> Numeric for CharErrorRate<B> {
154 fn value(&self) -> NumericEntry {
155 self.state.value()
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Identical predictions and targets give a 0 % error rate.
    #[test]
    fn test_cer_without_padding() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        let predictions = Tensor::from_data([[1, 2], [3, 4]], &device);
        let references = Tensor::from_data([[1, 2], [3, 4]], &device);
        let input = CerInput::new(predictions, references);

        cer.update(&input, &MetricMetadata::fake());

        assert_eq!(0.0, cer.value().current());
    }

    /// One substitution per two-token row: 2 edits over 4 target tokens is 50 %.
    #[test]
    fn test_cer_without_padding_two_errors() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        let predictions = Tensor::from_data([[1, 2], [3, 5]], &device);
        let references = Tensor::from_data([[1, 3], [3, 4]], &device);
        let input = CerInput::new(predictions, references);

        cer.update(&input, &MetricMetadata::fake());

        assert_eq!(50.0, cer.value().current());
    }

    /// Trailing pad tokens are excluded from both the edits and the target length.
    #[test]
    fn test_cer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut cer = CharErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        let predictions = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let references = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);
        let input = CerInput::new(predictions, references);

        cer.update(&input, &MetricMetadata::fake());

        assert_eq!(50.0, cer.value().current());
    }

    /// After `clear`, the current value is NaN again.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        let predictions = Tensor::from_data([[1, 2]], &device);
        let references = Tensor::from_data([[1, 3]], &device);
        let input = CerInput::new(predictions, references);

        cer.update(&input, &MetricMetadata::fake());
        assert!(cer.value().current() > 0.0);

        cer.clear();
        assert!(cer.value().current().is_nan());
    }
}