1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3
4#[cfg(not(feature = "ahash"))]
5use std::collections::HashMap;
6
7#[cfg(feature = "ahash")]
8use ahash::HashMap;
9
/// Incremental collector of n-gram statistics over a stream of `G` items.
///
/// `G` is the unit that n-grams are built from and defaults to `char`.
/// Every implementor is also `Default`, so an empty collector can always be
/// created without arguments.
pub trait Ngrams<G: Copy + Default = char>: Default {
    /// Consumes `iter` and records the n-grams formed by its items.
    fn feed_from(&mut self, iter: impl IntoIterator<Item = G>);

    /// Discards everything recorded so far.
    fn clear(&mut self);

    /// Implementation detail shared by the collectors in this crate.
    ///
    /// `count` is the number of items consumed so far and `buffer` is a
    /// window over the most recent items, newest last.
    #[doc(hidden)]
    fn _feed_impl<const N: usize>(&mut self, count: usize, buffer: [G; N]);

    /// Implementation detail shared by the collectors in this crate.
    ///
    /// Returns the sum of the per-order scores of `tl` against `refs`
    /// together with the number of orders that contributed to that sum.
    #[doc(hidden)]
    fn _chrf_impl(beta: f64, tl: &Self, refs: &Self) -> (f64, usize);
}
27
/// Zero-width terminator of the collector chain; it stores no n-grams.
///
/// The `PhantomData` field only ties the type to the item type `G`, so that
/// `N0<G>` can serve as the `next` level of the narrowest real collector.
#[derive(Debug, Default)]
struct N0<G>(core::marker::PhantomData<G>);
30
31impl<G> Ngrams<G> for N0<G>
32where
33 G: Copy + Default,
34{
35 fn _feed_impl<const N: usize>(&mut self, _count: usize, _buffer: [G; N]) {}
36 fn _chrf_impl(_beta: f64, _tl: &Self, _refs: &Self) -> (f64, usize) {
37 (0.0, 0)
38 }
39 fn feed_from(&mut self, _iter: impl IntoIterator<Item = G>) {}
40 fn clear(&mut self) {}
41}
42
// Generates one n-gram collector type per `(Name = width, Next, [Inner, ...])`
// entry:
//
// * `Name`  - the type to define; it counts n-grams of exactly `width` items.
// * `Next`  - the type of its `next` field, which handles the narrower widths.
// * `Inner` - every level nested below `Next`; `AsRef`/`AsMut` impls that
//             forward through `next` are generated for each of them.
macro_rules! impl_ngrams {
    ($(($name:ident = $width:expr, $next:ident, [$($inner:ident),*]))*) => {
        $(
            #[doc = concat!(
                "N-gram collector of width ",
                stringify!($width),
                ": counts the n-grams of exactly that width and delegates every ",
                "narrower width to its `next` level."
            )]
            #[derive(Default, Debug)]
            pub struct $name<G = char> {
                /// Occurrence count of every distinct n-gram of exactly this width.
                ngrams: HashMap<[G; $width], u32>,
                /// Collector responsible for the narrower widths.
                next: $next<G>,
            }

            // Compile-time guard: the code below indexes `[$width - 1]`, which
            // would underflow for a zero-width level.
            const _: () = assert!($width != 0);

            impl From<&str> for $name<char> {
                /// Builds a collector from `text`, ignoring space characters.
                fn from(text: &str) -> Self {
                    let mut out = Self::default();
                    out.feed(text);
                    out
                }
            }

            impl $name<char> {
                /// Feeds the characters of `text`, skipping every `' '`.
                ///
                /// Only the plain space character is removed; any other
                /// whitespace is counted like a regular character.
                fn feed(&mut self, text: &str) {
                    self.feed_from(text.chars().filter(|&ch| ch != ' '))
                }
            }

            impl<G> AsRef<$name<G>> for $name<G> {
                #[inline(always)]
                fn as_ref(&self) -> &$name<G> {
                    self
                }
            }

            impl<G> AsMut<$name<G>> for $name<G> {
                #[inline(always)]
                fn as_mut(&mut self) -> &mut $name<G> {
                    self
                }
            }

            impl<G> AsRef<$next<G>> for $name<G> {
                #[inline(always)]
                fn as_ref(&self) -> &$next<G> {
                    &self.next
                }
            }

            impl<G> AsMut<$next<G>> for $name<G> {
                #[inline(always)]
                fn as_mut(&mut self) -> &mut $next<G> {
                    &mut self.next
                }
            }

            // Deeper levels are reached by forwarding through `next`; the
            // return type selects which of `next`'s `AsRef`/`AsMut` impls is used.
            $(
                impl<G> AsRef<$inner<G>> for $name<G> {
                    #[inline(always)]
                    fn as_ref(&self) -> &$inner<G> {
                        self.next.as_ref()
                    }
                }

                impl<G> AsMut<$inner<G>> for $name<G> {
                    #[inline(always)]
                    fn as_mut(&mut self) -> &mut $inner<G> {
                        self.next.as_mut()
                    }
                }
            )*

            impl<G> Ngrams<G> for $name<G>
            where
                G: Copy + Default + PartialEq + Eq + core::hash::Hash,
            {
                /// Records the newest n-gram of this width, then lets the
                /// narrower levels do the same with the same window.
                ///
                /// `buffer` holds the most recent `N` items (newest last) and
                /// `count` is the number of items consumed so far.
                ///
                /// # Panics
                ///
                /// Panics if the window is narrower than this level (`N < width`).
                #[inline(always)]
                fn _feed_impl<const N: usize>(&mut self, count: usize, buffer: [G; N]) {
                    assert!(N >= $width);
                    // Until `width` items have been seen the window still
                    // contains default-filled slots, so nothing is recorded.
                    if count >= $width {
                        let mut ngram = [G::default(); $width];
                        ngram.copy_from_slice(&buffer[N - $width..]);
                        *self.ngrams.entry(ngram).or_default() += 1;
                    }
                    self.next._feed_impl(count, buffer);
                }

                /// Adds this width's F-score of `tl` against `refs` to the sum
                /// produced by the narrower levels.
                ///
                /// Returns `(sum of scores, number of levels summed)`.
                #[inline(always)]
                fn _chrf_impl(beta: f64, tl: &Self, refs: &Self) -> (f64, usize) {
                    // Totals are accumulated in `u64`: the sum of many `u32`
                    // counts may exceed `u32::MAX` on very large inputs and
                    // would otherwise wrap silently in release builds.
                    let total_tl: u64 = tl.ngrams.values().map(|&n| u64::from(n)).sum();

                    let mut matching: u64 = 0;
                    let mut total_ref: u64 = 0;
                    for (ngram, &count_ref) in &refs.ngrams {
                        total_ref += u64::from(count_ref);
                        if let Some(&count_tl) = tl.ngrams.get(ngram) {
                            // An n-gram matches at most as often as it occurs
                            // on the side where it is rarer.
                            matching += u64::from(count_ref.min(count_tl));
                        }
                    }

                    // Share of matching n-grams; a side without any n-gram of
                    // this width gets a tiny non-zero placeholder instead.
                    let ratio = |total: u64| -> f64 {
                        if total > 0 {
                            matching as f64 / total as f64
                        } else {
                            1e-16
                        }
                    };
                    let chr_tl = ratio(total_tl);
                    let chr_ref = ratio(total_ref);

                    let beta2 = beta.powi(2);
                    let numerator = (1.0 + beta2) * (chr_tl * chr_ref);
                    // Clamp so that the division below never divides by zero.
                    let denominator = (beta2 * chr_tl + chr_ref).max(1e-16);

                    let score = numerator / denominator;
                    let (next_score, next_count) =
                        <$next<G> as Ngrams<G>>::_chrf_impl(beta, &tl.next, &refs.next);
                    (score + next_score, next_count + 1)
                }

                /// Forgets all counts of this level and of every narrower one.
                fn clear(&mut self) {
                    self.ngrams.clear();
                    self.next.clear();
                }

                /// Slides a window of `width` items over `iter`, recording the
                /// n-grams of this and every narrower width after each item.
                fn feed_from(&mut self, iter: impl IntoIterator<Item = G>) {
                    let mut window = [G::default(); $width];
                    let mut seen: usize = 0;
                    for item in iter {
                        // Shift left by one slot (a no-op for width 1) and
                        // append the new item at the end.
                        window.copy_within(1.., 0);
                        window[$width - 1] = item;
                        seen += 1;
                        self._feed_impl(seen, window);
                    }
                }
            }
        )*
    }
}
189
190impl_ngrams! {
191 (N1 = 1, N0, [])
192 (N2 = 2, N1, [N0])
193 (N3 = 3, N2, [N1, N0])
194 (N4 = 4, N3, [N2, N1, N0])
195 (N5 = 5, N4, [N3, N2, N1, N0])
196 (N6 = 6, N5, [N4, N3, N2, N1, N0])
197 (N7 = 7, N6, [N5, N4, N3, N2, N1, N0])
198 (N8 = 8, N7, [N6, N5, N4, N3, N2, N1, N0])
199 (N9 = 9, N8, [N7, N6, N5, N4, N3, N2, N1, N0])
200 (N10 = 10, N9, [N8, N7, N6, N5, N4, N3, N2, N1, N0])
201 (N11 = 11, N10, [N9, N8, N7, N6, N5, N4, N3, N2, N1, N0])
202 (N12 = 12, N11, [N10, N9, N8, N7, N6, N5, N4, N3, N2, N1, N0])
203}
204
205pub fn chrf<T>(beta: f64, translation: &T, reference: &T) -> f64
209where
210 T: Ngrams,
211{
212 let (sum, count) = Ngrams::_chrf_impl(beta, translation, reference);
213 sum / count as f64
214}
215
216pub fn chrf3(translation: &N6, reference: &N6) -> f64 {
218 chrf(3.0, translation, reference) * 100.0
219}
220
/// Checks `chrf3` against known scores for a short and a long sentence pair.
#[test]
fn test_chrf3() {
    // (label, translation, reference, expected chrF3 score)
    let cases = [
        ("test 1", "aoeu33", "axeu33", 37.7778),
        (
            "test 2",
            "Recent offers of evacuating residents from the Syrian regime and Russia sound like only thinly veiled threats, pediatricians, surgeons and other doctors have said.",
            "Recent offers of evacuation form the regime and Russia had sounded like thinly-veiled threats, said the surgeons paediatricians and other doctors.",
            69.8328,
        ),
    ];

    for (label, tl, refs, expected) in cases {
        let score = chrf3(&tl.into(), &refs.into());
        // The label names the failing case; previously both cases reported
        // themselves as "test 1".
        assert!(
            (score - expected).abs() < 0.0001,
            "unexpected score: {score} ({label})"
        );
    }
}