chrf/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3
4#[cfg(not(feature = "ahash"))]
5use std::collections::HashMap;
6
7#[cfg(feature = "ahash")]
8use ahash::HashMap;
9
/// A container that accumulates ngram counts over items of type `G`
/// (`char` unless stated otherwise).
///
/// Two containers of the same type are compared with [`chrf`].
pub trait Ngrams<G: Copy + Default = char>: Default {
    /// Adds all of the items from `iter`.
    fn feed_from(&mut self, iter: impl IntoIterator<Item = G>);

    /// Removes every stored ngram.
    fn clear(&mut self);

    /// Internal hook: receives the running item `count` and the current
    /// sliding-window `buffer` while items are being fed.
    #[doc(hidden)]
    fn _feed_impl<const N: usize>(&mut self, count: usize, buffer: [G; N]);

    /// Internal hook: returns `(sum of per-order scores, number of orders)`,
    /// which [`chrf`] divides to obtain the average.
    #[doc(hidden)]
    fn _chrf_impl(beta: f64, tl: &Self, refs: &Self) -> (f64, usize);
}
27
28#[derive(Default, Debug)]
29struct N0<G>(core::marker::PhantomData<G>);
30
31impl<G> Ngrams<G> for N0<G>
32where
33    G: Copy + Default,
34{
35    fn _feed_impl<const N: usize>(&mut self, _count: usize, _buffer: [G; N]) {}
36    fn _chrf_impl(_beta: f64, _tl: &Self, _refs: &Self) -> (f64, usize) {
37        (0.0, 0)
38    }
39    fn feed_from(&mut self, _iter: impl IntoIterator<Item = G>) {}
40    fn clear(&mut self) {}
41}
42
/// Defines one ngram container type per `(Name = width, Next, [Lower...])` entry.
///
/// Every generated `Name<G>` counts the ngrams of exactly `width` items and owns a
/// `Next<G>` container (the next-lower order), so the types form a chain ending at
/// `N0`. `[Lower...]` must list every type below `Next` in that chain; it is used to
/// generate `AsRef`/`AsMut` accessors that reach through `next` to any lower order.
macro_rules! impl_ngrams {
    ($(($name:ident = $width:expr, $next:ident, [$($inner:ident),*]))*) => {
        $(
            #[doc = concat!(
                "Ngram counts of width ",
                stringify!($width),
                ", plus the counts of every lower order."
            )]
            #[derive(Default, Debug)]
            pub struct $name<G = char> {
                /// How many times each ngram of exactly this width was seen.
                ngrams: HashMap<[G; $width], u32>,
                /// Counts for the next-lower order.
                next: $next<G>,
            }

            // A zero width is rejected at compile time: `feed_from` indexes `width - 1`.
            const _: () = {
                assert!($width != 0);
            };

            impl From<&str> for $name<char> {
                fn from(text: &str) -> Self {
                    let mut out = Self::default();
                    out.feed(text);
                    out
                }
            }

            impl $name<char> {
                /// Adds all of the ngrams from `text` except spaces.
                ///
                /// Only the `' '` character is skipped; other whitespace (tabs,
                /// newlines, ...) is counted like any other character.
                fn feed(&mut self, text: &str) {
                    self.feed_from(text.chars().filter(|&ch| ch != ' '))
                }
            }

            // Accessors that let a container be viewed as itself or as any lower order.
            impl<G> AsRef<$name<G>> for $name<G> {
                #[inline(always)]
                fn as_ref(&self) -> &$name<G> {
                    self
                }
            }

            impl<G> AsMut<$name<G>> for $name<G> {
                #[inline(always)]
                fn as_mut(&mut self) -> &mut $name<G> {
                    self
                }
            }

            impl<G> AsRef<$next<G>> for $name<G> {
                #[inline(always)]
                fn as_ref(&self) -> &$next<G> {
                    &self.next
                }
            }

            impl<G> AsMut<$next<G>> for $name<G> {
                #[inline(always)]
                fn as_mut(&mut self) -> &mut $next<G> {
                    &mut self.next
                }
            }

            $(
                impl<G> AsRef<$inner<G>> for $name<G> {
                    #[inline(always)]
                    fn as_ref(&self) -> &$inner<G> {
                        self.next.as_ref()
                    }
                }

                impl<G> AsMut<$inner<G>> for $name<G> {
                    #[inline(always)]
                    fn as_mut(&mut self) -> &mut $inner<G> {
                        self.next.as_mut()
                    }
                }
            )*

            impl<G> Ngrams<G> for $name<G>
            where
                G: Copy + Default + Eq + core::hash::Hash,
            {
                /// Records the ngram that ends at the newest item of `buffer`, then
                /// forwards the same buffer to the lower orders.
                ///
                /// `buffer` is the sliding window of the outermost container, so it may
                /// be wider than this order. `count` is how many items have been fed so
                /// far; no ngram is recorded until at least `width` items exist.
                #[inline(always)]
                fn _feed_impl<const N: usize>(&mut self, count: usize, buffer: [G; N]) {
                    assert!(N >= $width);
                    if count >= $width {
                        let mut ngram = [G::default(); $width];
                        ngram.copy_from_slice(&buffer[N - $width..]);
                        *self.ngrams.entry(ngram).or_default() += 1;
                    }
                    self.next._feed_impl(count, buffer);
                }

                /// Returns `(sum of per-order F-beta scores, number of orders)` for this
                /// order and every lower one.
                ///
                /// Precision is `matching / total_tl` and recall `matching / total_ref`;
                /// an order with no ngrams on a side uses `1e-16` for that ratio.
                #[inline(always)]
                fn _chrf_impl(beta: f64, tl: &Self, refs: &Self) -> (f64, usize) {
                    // Sum in u64 so that many u32 per-ngram counts cannot overflow.
                    let total_tl: u64 = tl.ngrams.values().map(|&count| u64::from(count)).sum();

                    let mut total_ref: u64 = 0;
                    let mut matching: u64 = 0;
                    for (ngram, &count_ref) in &refs.ngrams {
                        total_ref += u64::from(count_ref);
                        if let Some(&count_tl) = tl.ngrams.get(ngram) {
                            matching += u64::from(count_ref.min(count_tl));
                        }
                    }

                    let ratio = |total: u64| {
                        if total > 0 {
                            matching as f64 / total as f64
                        } else {
                            1e-16
                        }
                    };
                    let chr_tl = ratio(total_tl);
                    let chr_ref = ratio(total_ref);

                    let beta2 = beta.powi(2);
                    let numerator = (1.0 + beta2) * (chr_tl * chr_ref);
                    // Clamp so that a zero precision/recall pair cannot divide by zero.
                    let denominator = (beta2 * chr_tl + chr_ref).max(1e-16);
                    let score = numerator / denominator;

                    let (next_score, next_count) = Ngrams::_chrf_impl(beta, &tl.next, &refs.next);
                    (score + next_score, next_count + 1)
                }

                fn clear(&mut self) {
                    self.ngrams.clear();
                    self.next.clear();
                }

                /// Counts every ngram of the sequence produced by `iter`.
                ///
                /// Each call starts a fresh window, so ngrams never span two calls.
                fn feed_from(&mut self, iter: impl IntoIterator<Item = G>) {
                    // Sliding window over the most recent `width` items. Slots not filled
                    // yet hold `G::default()`, but `_feed_impl` only reads as many
                    // trailing slots as `count` says are real.
                    let mut window = [G::default(); $width];
                    for (index, item) in iter.into_iter().enumerate() {
                        window.copy_within(1.., 0);
                        window[$width - 1] = item;
                        self._feed_impl(index + 1, window);
                    }
                }
            }
        )*
    }
}
189
190impl_ngrams! {
191    (N1 = 1, N0, [])
192    (N2 = 2, N1, [N0])
193    (N3 = 3, N2, [N1, N0])
194    (N4 = 4, N3, [N2, N1, N0])
195    (N5 = 5, N4, [N3, N2, N1, N0])
196    (N6 = 6, N5, [N4, N3, N2, N1, N0])
197    (N7 = 7, N6, [N5, N4, N3, N2, N1, N0])
198    (N8 = 8, N7, [N6, N5, N4, N3, N2, N1, N0])
199    (N9 = 9, N8, [N7, N6, N5, N4, N3, N2, N1, N0])
200    (N10 = 10, N9, [N8, N7, N6, N5, N4, N3, N2, N1, N0])
201    (N11 = 11, N10, [N9, N8, N7, N6, N5, N4, N3, N2, N1, N0])
202    (N12 = 12, N11, [N10, N9, N8, N7, N6, N5, N4, N3, N2, N1, N0])
203}
204
205/// Calculates a custom chrF score.
206///
207/// NOTE: Unlike [chrf3] the score returned by this function is *not* multiplied by 100.
208pub fn chrf<T>(beta: f64, translation: &T, reference: &T) -> f64
209where
210    T: Ngrams,
211{
212    let (sum, count) = Ngrams::_chrf_impl(beta, translation, reference);
213    sum / count as f64
214}
215
216/// Calculates a chrF3 score.
217pub fn chrf3(translation: &N6, reference: &N6) -> f64 {
218    chrf(3.0, translation, reference) * 100.0
219}
220
/// Checks [`chrf3`] against known expected scores.
///
/// Each case reports its own 1-based index on failure (the original labelled both
/// cases "test 1").
#[test]
fn test_chrf3() {
    // (translation, reference, expected chrF3 score)
    let cases = [
        ("aoeu33", "axeu33", 37.7778),
        (
            "Recent offers of evacuating residents from the Syrian regime and Russia sound like only thinly veiled threats, pediatricians, surgeons and other doctors have said.",
            "Recent offers of evacuation form the regime and Russia had sounded like thinly-veiled threats, said the surgeons paediatricians and other doctors.",
            69.8328,
        ),
    ];

    for (index, &(tl, refs, expected)) in cases.iter().enumerate() {
        let score = chrf3(&tl.into(), &refs.into());
        let test = index + 1;
        assert!(
            (score - expected).abs() < 0.0001,
            "unexpected score: {score} (test {test})"
        );
    }
}