//! `harper_pos_utils/chunker/brill_chunker/mod.rs`

1mod patch;
2
3#[cfg(feature = "training")]
4use std::path::Path;
5
6#[cfg(feature = "training")]
7use crate::word_counter::WordCounter;
8use crate::{
9    UPOS,
10    chunker::{Chunker, upos_freq_dict::UPOSFreqDict},
11};
12
13use patch::Patch;
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BrillChunker {
18    base: UPOSFreqDict,
19    patches: Vec<Patch>,
20}
21
22impl BrillChunker {
23    pub fn new(base: UPOSFreqDict) -> Self {
24        Self {
25            base,
26            patches: Vec::new(),
27        }
28    }
29
30    fn apply_patches(&self, sentence: &[String], tags: &[Option<UPOS>], np_states: &mut [bool]) {
31        for patch in &self.patches {
32            for i in 0..sentence.len() {
33                if patch.from == np_states[i]
34                    && patch.criteria.fulfils(sentence, tags, np_states, i)
35                {
36                    np_states[i] = !np_states[i];
37                }
38            }
39        }
40    }
41}
42
43impl Chunker for BrillChunker {
44    fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
45        let mut initial_pass = self.base.chunk_sentence(sentence, tags);
46
47        self.apply_patches(sentence, tags, &mut initial_pass);
48
49        initial_pass
50    }
51}
52
/// One training sentence: (tokens, UPOS tags, correct noun-phrase flags).
#[cfg(feature = "training")]
type CandidateArgs = (Vec<String>, Vec<Option<UPOS>>, Vec<bool>);
55
#[cfg(feature = "training")]
impl BrillChunker {
    /// Apply only this chunker's patches on top of `base_flags` and count how
    /// many positions disagree with `correct_np_flags`.
    ///
    /// Used during training to score a candidate patch against a precomputed
    /// baseline without re-running the (expensive) baseline chunker.
    pub fn count_patch_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        base_flags: &[bool],
        correct_np_flags: &[bool],
    ) -> usize {
        let mut flags = base_flags.to_vec();
        self.apply_patches(sentence, tags, &mut flags);

        flags
            .into_iter()
            .zip(correct_np_flags)
            .filter(|&(a, &b)| a != b)
            .count()
    }

    /// Run the full chunker on a sentence and count positions that disagree
    /// with `correct_np_flags`, incrementing `relevant_words` for every
    /// misflagged token so candidate patches can focus on them.
    pub fn count_chunk_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        correct_np_flags: &[bool],
        relevant_words: &mut WordCounter,
    ) -> usize {
        let flags = self.chunk_sentence(sentence, tags);

        let mut loss = 0;
        for ((a, b), word) in flags.into_iter().zip(correct_np_flags).zip(sentence) {
            if a != *b {
                loss += 1;
                relevant_words.inc(word);
            }
        }

        loss
    }

    /// Run one training epoch: score the current chunker, generate candidate
    /// patches from the misflagged words, and append the best-scoring one.
    ///
    /// To speed up training, only a random subset of all possible candidates is
    /// tried. How many to select is given by `candidate_selection_chance`; a
    /// higher chance means a longer training time.
    fn epoch(&mut self, training_files: &[impl AsRef<Path>], candidate_selection_chance: f32) {
        use crate::conllu_utils::iter_sentences_in_conllu;
        use rs_conllu::Sentence;
        use std::time::Instant;

        assert!((0.0..=1.0).contains(&candidate_selection_chance));

        let mut total_tokens = 0;
        let mut error_counter = 0;

        let sentences: Vec<Sentence> = training_files
            .iter()
            .flat_map(iter_sentences_in_conllu)
            .collect();
        let mut sentences_flagged: Vec<CandidateArgs> = Vec::new();

        for sent in &sentences {
            use hashbrown::HashSet;

            use crate::chunker::np_extraction::locate_noun_phrases_in_sent;

            let mut toks: Vec<String> = Vec::new();
            let mut tags = Vec::new();

            for token in &sent.tokens {
                let form = token.form.clone();
                // Re-attach contraction suffixes that the CoNLL-U tokenizer
                // splits off, so tokens match what the chunker sees at runtime.
                // NOTE(review): "sn't" looks odd alongside "n't" — confirm the
                // upstream tokenizer really emits that form.
                if let Some(last) = toks.last_mut() {
                    match form.as_str() {
                        "sn't" | "n't" | "'ll" | "'ve" | "'re" | "'d" | "'m" | "'s" => {
                            last.push_str(&form);
                            continue;
                        }
                        _ => {}
                    }
                }
                toks.push(form);
                tags.push(token.upos.and_then(UPOS::from_conllu));
            }

            // Flatten the noun-phrase index sets into one per-token bool
            // sequence; positions past the last noun phrase stay `false`.
            let actual = locate_noun_phrases_in_sent(sent);
            let actual_flat = actual.into_iter().fold(HashSet::new(), |mut a, b| {
                a.extend(b.into_iter());
                a
            });

            let mut actual_seq = Vec::new();

            for el in actual_flat {
                if el >= actual_seq.len() {
                    actual_seq.resize(el + 1, false);
                }
                actual_seq[el] = true;
            }

            sentences_flagged.push((toks, tags, actual_seq));
        }

        let mut relevant_words = WordCounter::default();

        for (tok_buf, tag_buf, flag_buf) in &sentences_flagged {
            total_tokens += tok_buf.len();
            error_counter += self.count_chunk_errors(
                tok_buf.as_slice(),
                tag_buf,
                flag_buf.as_slice(),
                &mut relevant_words,
            );
        }

        println!("=============");
        println!("Total tokens in training set: {total_tokens}");
        println!("Tokens incorrectly flagged: {error_counter}");
        println!(
            "Error rate: {}%",
            error_counter as f32 / total_tokens as f32 * 100.
        );

        // Before adding any patches, let's get a good base.
        let mut base_flags = Vec::new();
        for (toks, tags, _) in &sentences_flagged {
            base_flags.push(self.chunk_sentence(toks, tags));
        }

        // Sample a random subset of the candidate space to keep epochs fast.
        let all_candidates = Patch::generate_candidate_patches(&relevant_words);
        let mut pruned_candidates: Vec<Patch> = rand::seq::IndexedRandom::choose_multiple(
            all_candidates.as_slice(),
            &mut rand::rng(),
            (all_candidates.len() as f32 * candidate_selection_chance) as usize,
        )
        .cloned()
        .collect();

        let start = Instant::now();

        // Sort candidates best-first (lower score = fewer errors). Scoring is
        // the expensive step, so cache keys and parallelize when available.
        #[cfg(feature = "threaded")]
        rayon::slice::ParallelSliceMut::par_sort_by_cached_key(
            pruned_candidates.as_mut_slice(),
            |candidate: &Patch| {
                self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
            },
        );

        #[cfg(not(feature = "threaded"))]
        pruned_candidates.sort_by_cached_key(|candidate| {
            self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
        });

        let duration = start.elapsed();
        let seconds = duration.as_secs();
        let millis = duration.subsec_millis();

        println!(
            "It took {} seconds and {} milliseconds to search through {} candidates at {} c/sec.",
            seconds,
            millis,
            pruned_candidates.len(),
            // Use the fractional duration: dividing by the truncated whole
            // seconds printed `inf` for any sub-second search.
            pruned_candidates.len() as f32 / duration.as_secs_f32()
        );

        if let Some(best) = pruned_candidates.first() {
            self.patches.push(best.clone());
        }
    }

    /// Score a candidate patch by the number of remaining errors when it is
    /// applied alone on top of the precomputed `base_flags`. Lower is better.
    fn score_candidate(
        &self,
        candidate: Patch,
        sentences_flagged: &[CandidateArgs],
        base_flags: &[Vec<bool>],
    ) -> usize {
        // A throwaway chunker holding only this candidate; the baseline is
        // irrelevant because scoring starts from `base_flags`.
        let mut tagger = BrillChunker::new(UPOSFreqDict::default());
        tagger.patches.push(candidate);

        let mut errors = 0;

        for ((toks, tags, flags), base) in sentences_flagged.iter().zip(base_flags.iter()) {
            errors += tagger.count_patch_errors(toks.as_slice(), tags.as_slice(), base, flags);
        }

        errors
    }

    /// Train a brand-new tagger on a `.conllu` dataset, provided via a path.
    /// This does not do _any_ error handling, and should not run in production.
    /// It should be used for training a model that _will_ be used in production.
    pub fn train(
        training_files: &[impl AsRef<Path>],
        epochs: usize,
        candidate_selection_chance: f32,
    ) -> Self {
        let mut freq_dict = UPOSFreqDict::default();

        for file in training_files {
            freq_dict.inc_from_conllu_file(file);
        }

        let mut chunker = Self::new(freq_dict);

        // One patch is (at most) appended per epoch.
        for _ in 0..epochs {
            chunker.epoch(training_files, candidate_selection_chance);
        }

        chunker
    }
}