harper_pos_utils/tagger/brill_tagger/
mod.rs1mod patch;
2
3#[cfg(feature = "training")]
4use std::path::Path;
5
6use patch::Patch;
7use serde::{Deserialize, Serialize};
8
9#[cfg(feature = "training")]
10use super::FreqDict;
11#[cfg(feature = "training")]
12use super::error_counter::{ErrorCounter, ErrorKind};
13
14use crate::{Tagger, UPOS};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BrillTagger<B>
18where
19    B: Tagger,
20{
21    base: B,
22    patches: Vec<Patch>,
23}
24
25impl<B> BrillTagger<B>
26where
27    B: Tagger,
28{
29    pub fn new(base: B) -> Self {
30        Self {
31            base,
32            patches: Vec::new(),
33        }
34    }
35
36    fn apply_patches(&self, sentence: &[String], tags: &mut [Option<UPOS>]) {
37        for patch in &self.patches {
38            for i in 0..sentence.len() {
39                let Some(i_tag) = tags.get(i).copied().flatten() else {
40                    continue;
41                };
42
43                if patch.from == i_tag && patch.criteria.fulfils(sentence, tags, &[], i) {
44                    tags[i] = Some(patch.to);
45                }
46            }
47        }
48    }
49}
50
51impl<B> Tagger for BrillTagger<B>
52where
53    B: Tagger,
54{
55    fn tag_sentence(&self, sentence: &[String]) -> Vec<Option<UPOS>> {
58        let mut tags = self.base.tag_sentence(sentence);
59        self.apply_patches(sentence, &mut tags);
60
61        tags
62    }
63}
64
65#[cfg(feature = "training")]
66impl BrillTagger<FreqDict> {
67    pub fn locate_patch_errors(
70        &self,
71        sentence: &[String],
72        correct_tags: &[Option<UPOS>],
73        base_tags: &[Option<UPOS>],
74        errors: &mut ErrorCounter,
75    ) {
76        let mut base_tags = base_tags.to_vec();
77        self.apply_patches(sentence, &mut base_tags);
78
79        for ((tag, correct_tag), word) in base_tags.iter().zip(correct_tags.iter()).zip(sentence) {
80            if let Some(tag) = tag {
81                if let Some(correct_tag) = correct_tag {
82                    if tag != correct_tag {
83                        errors.inc(
84                            ErrorKind {
85                                was_tagged: *tag,
86                                correct_tag: *correct_tag,
87                            },
88                            word.as_str(),
89                        )
90                    }
91                }
92            }
93        }
94    }
95
96    pub fn locate_tag_errors(
99        &self,
100        sentence: &[String],
101        correct_tags: &[Option<UPOS>],
102    ) -> ErrorCounter {
103        let tags = self.tag_sentence(sentence);
104
105        let mut errors = ErrorCounter::new();
106
107        for ((tag, correct_tag), word) in tags.iter().zip(correct_tags.iter()).zip(sentence) {
108            if let Some(tag) = tag {
109                if let Some(correct_tag) = correct_tag {
110                    if tag != correct_tag {
111                        errors.inc(
112                            ErrorKind {
113                                was_tagged: *tag,
114                                correct_tag: *correct_tag,
115                            },
116                            word.as_str(),
117                        )
118                    }
119                }
120            }
121        }
122
123        errors
124    }
125
126    fn epoch(&mut self, training_files: &[impl AsRef<Path>], candidate_selection_chance: f32) {
130        use crate::conllu_utils::iter_sentences_in_conllu;
131        use rs_conllu::Sentence;
132        use std::time::Instant;
133
134        assert!((0.0..=1.0).contains(&candidate_selection_chance));
135
136        let mut total_tokens = 0;
137        let mut error_counter = ErrorCounter::new();
138
139        let sentences: Vec<Sentence> = training_files
140            .iter()
141            .flat_map(iter_sentences_in_conllu)
142            .collect();
143        let mut sentences_tagged: Vec<(Vec<String>, Vec<Option<UPOS>>)> = Vec::new();
144
145        for sent in &sentences {
146            let mut toks: Vec<String> = Vec::new();
147            let mut tags = Vec::new();
148
149            for token in &sent.tokens {
150                let form = token.form.clone();
151                if let Some(last) = toks.last_mut() {
152                    match form.as_str() {
153                        "sn't" | "n't" | "'ll" | "'ve" | "'re" | "'d" | "'m" | "'s" => {
154                            last.push_str(&form);
155                            continue;
156                        }
157                        _ => {}
158                    }
159                }
160                toks.push(form);
161                tags.push(token.upos.and_then(UPOS::from_conllu));
162            }
163
164            sentences_tagged.push((toks, tags));
165        }
166
167        for (tok_buf, tag_buf) in &sentences_tagged {
168            total_tokens += tok_buf.len();
169            error_counter
170                .merge_from(self.locate_tag_errors(tok_buf.as_slice(), tag_buf.as_slice()));
171        }
172
173        println!("=============");
174        println!("Total tokens in training set: {total_tokens}");
175        println!(
176            "Tokens incorrectly tagged: {}",
177            error_counter.total_errors()
178        );
179        println!(
180            "Error rate: {}%",
181            error_counter.total_errors() as f32 / total_tokens as f32 * 100.
182        );
183
184        let mut base_tags = Vec::new();
186        for (toks, _) in &sentences_tagged {
187            base_tags.push(self.tag_sentence(toks));
188        }
189
190        let all_candidates = Patch::generate_candidate_patches(&error_counter);
191        let mut pruned_candidates: Vec<Patch> = rand::seq::IndexedRandom::choose_multiple(
192            all_candidates.as_slice(),
193            &mut rand::rng(),
194            (all_candidates.len() as f32 * candidate_selection_chance) as usize,
195        )
196        .cloned()
197        .collect();
198
199        let start = Instant::now();
200
201        #[cfg(feature = "threaded")]
202        rayon::slice::ParallelSliceMut::par_sort_by_cached_key(
203            pruned_candidates.as_mut_slice(),
204            |candidate: &Patch| {
205                self.score_candidate(candidate.clone(), &sentences_tagged, &base_tags)
206            },
207        );
208
209        #[cfg(not(feature = "threaded"))]
210        pruned_candidates.sort_by_cached_key(|candidate| {
211            self.score_candidate(candidate.clone(), &sentences_tagged, &base_tags)
212        });
213
214        let duration = start.elapsed();
215        let seconds = duration.as_secs();
216        let millis = duration.subsec_millis();
217
218        println!(
219            "It took {} seconds and {} milliseconds to search through {} candidates at {} c/sec.",
220            seconds,
221            millis,
222            pruned_candidates.len(),
223            pruned_candidates.len() as f32 / seconds as f32
224        );
225
226        if let Some(best) = pruned_candidates.first() {
227            self.patches.push(best.clone());
228        }
229    }
230
231    fn score_candidate(
233        &self,
234        candidate: Patch,
235        sentences_tagged: &[(Vec<String>, Vec<Option<UPOS>>)],
236        base_tags: &[Vec<Option<UPOS>>],
237    ) -> usize {
238        let mut tagger = BrillTagger::new(FreqDict::default());
239        tagger.patches.push(candidate);
240
241        let mut candidate_errors = ErrorCounter::new();
242
243        for ((toks, tags), base) in sentences_tagged.iter().zip(base_tags.iter()) {
244            tagger.locate_patch_errors(
245                toks.as_slice(),
246                tags.as_slice(),
247                base,
248                &mut candidate_errors,
249            );
250        }
251
252        candidate_errors.total_errors()
253    }
254
255    pub fn train(
259        training_files: &[impl AsRef<Path>],
260        epochs: usize,
261        candidate_selection_chance: f32,
262    ) -> Self {
263        use crate::FreqDictBuilder;
264
265        let mut freq_dict_builder = FreqDictBuilder::new();
266
267        for file in training_files {
268            freq_dict_builder.inc_from_conllu_file(file);
269        }
270
271        let freq_dict = freq_dict_builder.build();
272
273        let mut tagger = Self::new(freq_dict);
274
275        for _ in 0..epochs {
276            tagger.epoch(training_files, candidate_selection_chance);
277        }
278
279        tagger
280    }
281}