harper_pos_utils/chunker/brill_chunker/mod.rs

mod patch;

#[cfg(feature = "training")]
use std::path::Path;

#[cfg(feature = "training")]
use crate::word_counter::WordCounter;
use crate::{
    UPOS,
    chunker::{Chunker, upos_freq_dict::UPOSFreqDict},
};

use patch::Patch;
use serde::{Deserialize, Serialize};

/// A [`Chunker`] implementation based on the work by Eric Brill.
///
/// Additional reading:
///
/// - [Continuations on Transformation-based Learning](https://elijahpotter.dev/articles/more-transformation-based-learning)
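///
/// A minimal usage sketch (illustrative only: assumes a populated
/// [`UPOSFreqDict`] named `freq_dict`, and the tokens and tags shown are
/// invented for the example):
///
/// ```ignore
/// let chunker = BrillChunker::new(freq_dict);
/// let sentence: Vec<String> = ["the", "big", "dog", "barked"]
///     .iter()
///     .map(|s| s.to_string())
///     .collect();
/// let tags = vec![Some(UPOS::DET), Some(UPOS::ADJ), Some(UPOS::NOUN), Some(UPOS::VERB)];
/// // `true` marks tokens the chunker places inside a noun phrase.
/// let np_flags: Vec<bool> = chunker.chunk_sentence(&sentence, &tags);
/// ```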
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrillChunker {
    base: UPOSFreqDict,
    patches: Vec<Patch>,
}

impl BrillChunker {
    pub fn new(base: UPOSFreqDict) -> Self {
        Self {
            base,
            patches: Vec::new(),
        }
    }

    fn apply_patches(&self, sentence: &[String], tags: &[Option<UPOS>], np_states: &mut [bool]) {
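        // A patch flips the noun-phrase flag at position `i` whenever the
        // current flag matches the patch's `from` state and the patch's
        // criteria hold there. Patches are applied in the order they were
        // learned, each observing the previous patches' edits.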
        for patch in &self.patches {
            for i in 0..sentence.len() {
                if patch.from == np_states[i]
                    && patch.criteria.fulfils(sentence, tags, np_states, i)
                {
                    np_states[i] = !np_states[i];
                }
            }
        }
    }
}

impl Chunker for BrillChunker {
    fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
        let mut initial_pass = self.base.chunk_sentence(sentence, tags);

        self.apply_patches(sentence, tags, &mut initial_pass);

        initial_pass
    }
}

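/// One pre-processed training sentence: its tokens, their UPOS tags, and the
/// correct noun-phrase flag for each token.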
#[cfg(feature = "training")]
type CandidateArgs = (Vec<String>, Vec<Option<UPOS>>, Vec<bool>);

#[cfg(feature = "training")]
impl BrillChunker {
    /// Apply this chunker's patches to a sentence's base noun-phrase flags and
    /// compare the result against the correct flags (from a dataset or other
    /// source), returning the number of mismatches.
    pub fn count_patch_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        base_flags: &[bool],
        correct_np_flags: &[bool],
    ) -> usize {
        let mut flags = base_flags.to_vec();
        self.apply_patches(sentence, tags, &mut flags);

        let mut loss = 0;
        for (a, b) in flags.into_iter().zip(correct_np_flags) {
            if a != *b {
                loss += 1;
            }
        }

        loss
    }

    /// Chunk a provided sentence and compare the result against the correct
    /// noun-phrase flags (from a dataset or other source), returning the number
    /// of errors and counting each mismatched word in `relevant_words`.
    pub fn count_chunk_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        correct_np_flags: &[bool],
        relevant_words: &mut WordCounter,
    ) -> usize {
        let flags = self.chunk_sentence(sentence, tags);

        let mut loss = 0;
        for ((a, b), word) in flags.into_iter().zip(correct_np_flags).zip(sentence) {
            if a != *b {
                loss += 1;
                relevant_words.inc(word);
            }
        }

        loss
    }

    /// Run one training epoch: score candidate patches against the training
    /// files and adopt the best one. To speed up training, only a random subset
    /// of all possible candidates is tried; its size is controlled by
    /// `candidate_selection_chance`. A higher chance means a more thorough
    /// search, but a longer training time.
    fn epoch(&mut self, training_files: &[impl AsRef<Path>], candidate_selection_chance: f32) {
        use crate::conllu_utils::iter_sentences_in_conllu;
        use rs_conllu::Sentence;
        use std::time::Instant;

        assert!((0.0..=1.0).contains(&candidate_selection_chance));
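
        // One epoch: pre-process the training sentences, measure the current
        // error rate, score a random subset of candidate patches, and adopt the
        // single best patch.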

        let mut total_tokens = 0;
        let mut error_counter = 0;

        let sentences: Vec<Sentence> = training_files
            .iter()
            .flat_map(iter_sentences_in_conllu)
            .collect();
        let mut sentences_flagged: Vec<CandidateArgs> = Vec::new();

        for sent in &sentences {
            use hashbrown::HashSet;

            use crate::chunker::np_extraction::locate_noun_phrases_in_sent;

            let mut toks: Vec<String> = Vec::new();
            let mut tags = Vec::new();

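            // Re-attach clitic contractions (e.g. "n't", "'ll", "'s") to the
            // previous token, so the training tokens match the tokenization the
            // chunker is expected to see at runtime.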
            for token in &sent.tokens {
                let form = token.form.clone();
                if let Some(last) = toks.last_mut() {
                    match form.as_str() {
                        "sn't" | "n't" | "'ll" | "'ve" | "'re" | "'d" | "'m" | "'s" => {
                            last.push_str(&form);
                            continue;
                        }
                        _ => {}
                    }
                }
                toks.push(form);
                tags.push(token.upos.and_then(UPOS::from_conllu));
            }

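            // Flatten the extracted noun-phrase spans into one boolean per
            // token index: `true` where the token belongs to any noun phrase.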
            let actual = locate_noun_phrases_in_sent(sent);
            let actual_flat = actual.into_iter().fold(HashSet::new(), |mut a, b| {
                a.extend(b.into_iter());
                a
            });

            let mut actual_seq = Vec::new();

            for el in actual_flat {
                if el >= actual_seq.len() {
                    actual_seq.resize(el + 1, false);
                }
                actual_seq[el] = true;
            }

            sentences_flagged.push((toks, tags, actual_seq));
        }

        let mut relevant_words = WordCounter::default();

        for (tok_buf, tag_buf, flag_buf) in &sentences_flagged {
            total_tokens += tok_buf.len();
            error_counter += self.count_chunk_errors(
                tok_buf.as_slice(),
                tag_buf,
                flag_buf.as_slice(),
                &mut relevant_words,
            );
        }

        println!("=============");
        println!("Total tokens in training set: {total_tokens}");
        println!("Tokens incorrectly flagged: {error_counter}");
        println!(
            "Error rate: {}%",
            error_counter as f32 / total_tokens as f32 * 100.
        );

        // Cache each sentence's flags under the current patch list, so scoring
        // a candidate only has to apply that one patch on top of this base.
        let mut base_flags = Vec::new();
        for (toks, tags, _) in &sentences_flagged {
            base_flags.push(self.chunk_sentence(toks, tags));
        }

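        // Generate every candidate patch suggested by the words the chunker
        // currently gets wrong, then keep a random fraction to bound search time.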
        let all_candidates = Patch::generate_candidate_patches(&relevant_words);
        let mut pruned_candidates: Vec<Patch> = rand::seq::IndexedRandom::choose_multiple(
            all_candidates.as_slice(),
            &mut rand::rng(),
            (all_candidates.len() as f32 * candidate_selection_chance) as usize,
        )
        .cloned()
        .collect();

        let start = Instant::now();

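        // Rank candidates by how many errors remain after applying each one on
        // top of the cached base flags (ascending, so the best comes first).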
        #[cfg(feature = "threaded")]
        rayon::slice::ParallelSliceMut::par_sort_by_cached_key(
            pruned_candidates.as_mut_slice(),
            |candidate: &Patch| {
                self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
            },
        );

        #[cfg(not(feature = "threaded"))]
        pruned_candidates.sort_by_cached_key(|candidate| {
            self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
        });

        let duration = start.elapsed();
        let seconds = duration.as_secs();
        let millis = duration.subsec_millis();

        println!(
            "It took {} seconds and {} milliseconds to search through {} candidates at {} c/sec.",
            seconds,
            millis,
            pruned_candidates.len(),
            // Divide by the full fractional duration so sub-second searches
            // don't divide by zero (whole seconds would truncate to 0).
            pruned_candidates.len() as f32 / duration.as_secs_f32()
        );

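        // The sort put the best candidate first; adopt it for this epoch.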
        if let Some(best) = pruned_candidates.first() {
            self.patches.push(best.clone());
        }
    }

    /// Scores a candidate patch by the number of errors left after applying it
    /// on top of each sentence's cached base flags. Lower is better.
    fn score_candidate(
        &self,
        candidate: Patch,
        sentences_flagged: &[CandidateArgs],
        base_flags: &[Vec<bool>],
    ) -> usize {
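        // The scratch chunker's base dictionary is never consulted here, since
        // `count_patch_errors` starts from the precomputed `base_flags`; only
        // the candidate patch matters.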
        let mut tagger = BrillChunker::new(UPOSFreqDict::default());
        tagger.patches.push(candidate);

        let mut errors = 0;

        for ((toks, tags, flags), base) in sentences_flagged.iter().zip(base_flags.iter()) {
            errors += tagger.count_patch_errors(toks.as_slice(), tags.as_slice(), base, flags);
        }

        errors
    }

    /// Train a brand-new chunker on one or more `.conllu` datasets, provided
    /// via their paths. This does not do _any_ error handling and should not
    /// run in production; it is meant for training a model that _will_ be used
    /// in production.
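    ///
    /// A hedged usage sketch (the file path and hyperparameter values below are
    /// hypothetical, not prescribed):
    ///
    /// ```ignore
    /// let chunker = BrillChunker::train(
    ///     &["en_ewt-ud-train.conllu"], // training data in CoNLL-U format
    ///     10,                          // epochs: one patch is adopted per epoch
    ///     0.1,                         // try 10% of candidate patches each epoch
    /// );
    /// ```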
    pub fn train(
        training_files: &[impl AsRef<Path>],
        epochs: usize,
        candidate_selection_chance: f32,
    ) -> Self {
        let mut freq_dict = UPOSFreqDict::default();

        for file in training_files {
            freq_dict.inc_from_conllu_file(file);
        }

        let mut chunker = Self::new(freq_dict);

        for _ in 0..epochs {
            chunker.epoch(training_files, candidate_selection_chance);
        }

        chunker
    }
}