// harper_pos_utils/chunker/brill_chunker/mod.rs
mod patch;
2
3#[cfg(feature = "training")]
4use std::path::Path;
5
6#[cfg(feature = "training")]
7use crate::word_counter::WordCounter;
8use crate::{
9 UPOS,
10 chunker::{Chunker, upos_freq_dict::UPOSFreqDict},
11};
12
13use patch::Patch;
14use serde::{Deserialize, Serialize};
15
/// A Brill-style noun-phrase chunker: a frequency-dictionary first pass
/// refined by an ordered list of learned transformation patches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrillChunker {
    /// First-pass chunker that seeds the per-token noun-phrase flags.
    base: UPOSFreqDict,
    /// Transformation rules applied in order on top of `base`'s output;
    /// grown one patch per training epoch.
    patches: Vec<Patch>,
}
26
27impl BrillChunker {
28 pub fn new(base: UPOSFreqDict) -> Self {
29 Self {
30 base,
31 patches: Vec::new(),
32 }
33 }
34
35 fn apply_patches(&self, sentence: &[String], tags: &[Option<UPOS>], np_states: &mut [bool]) {
36 for patch in &self.patches {
37 for i in 0..sentence.len() {
38 if patch.from == np_states[i]
39 && patch.criteria.fulfils(sentence, tags, np_states, i)
40 {
41 np_states[i] = !np_states[i];
42 }
43 }
44 }
45 }
46}
47
48impl Chunker for BrillChunker {
49 fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
50 let mut initial_pass = self.base.chunk_sentence(sentence, tags);
51
52 self.apply_patches(sentence, tags, &mut initial_pass);
53
54 initial_pass
55 }
56}
57
/// One prepared training sentence: (tokens, UPOS tags, gold noun-phrase flags).
#[cfg(feature = "training")]
type CandidateArgs = (Vec<String>, Vec<Option<UPOS>>, Vec<bool>);
60
#[cfg(feature = "training")]
impl BrillChunker {
    /// Count how many tokens disagree with `correct_np_flags` after this
    /// chunker's patches are applied on top of `base_flags`.
    ///
    /// `base_flags` holds the pre-patch flag per token; it is copied so the
    /// caller's slice is left untouched.
    pub fn count_patch_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        base_flags: &[bool],
        correct_np_flags: &[bool],
    ) -> usize {
        let mut flags = base_flags.to_vec();
        self.apply_patches(sentence, tags, &mut flags);

        // One unit of loss per token whose flag differs from the gold flag.
        flags
            .into_iter()
            .zip(correct_np_flags)
            .filter(|&(got, &want)| got != want)
            .count()
    }

    /// Count tokens mis-flagged by a full `chunk_sentence` pass, recording
    /// each offending word in `relevant_words` so candidate patches can be
    /// generated for the words that actually cause errors.
    pub fn count_chunk_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        correct_np_flags: &[bool],
        relevant_words: &mut WordCounter,
    ) -> usize {
        let flags = self.chunk_sentence(sentence, tags);

        let mut loss = 0;
        for ((got, want), word) in flags.into_iter().zip(correct_np_flags).zip(sentence) {
            if got != *want {
                loss += 1;
                relevant_words.inc(word);
            }
        }

        loss
    }

    /// Run one training epoch: measure the current error rate over the
    /// training files, generate candidate patches from the mis-flagged
    /// words, score a random subset of them, and adopt the best one.
    ///
    /// `candidate_selection_chance` (0.0..=1.0) is the fraction of generated
    /// candidates that get scored; lower values trade accuracy for speed.
    fn epoch(&mut self, training_files: &[impl AsRef<Path>], candidate_selection_chance: f32) {
        use crate::conllu_utils::iter_sentences_in_conllu;
        use rs_conllu::Sentence;
        use std::time::Instant;

        assert!((0.0..=1.0).contains(&candidate_selection_chance));

        let mut total_tokens = 0;
        let mut error_counter = 0;

        let sentences: Vec<Sentence> = training_files
            .iter()
            .flat_map(iter_sentences_in_conllu)
            .collect();
        let mut sentences_flagged: Vec<CandidateArgs> = Vec::new();

        for sent in &sentences {
            use hashbrown::HashSet;

            use crate::chunker::np_extraction::locate_noun_phrases_in_sent;

            let mut toks: Vec<String> = Vec::new();
            let mut tags = Vec::new();

            for token in &sent.tokens {
                let form = token.form.clone();
                if let Some(last) = toks.last_mut() {
                    // Re-attach English clitics that the treebank splits off,
                    // so tokens match how Harper sees running text.
                    // NOTE(review): "sn't" looks like a typo for a longer
                    // clitic — confirm against the treebank's tokenization.
                    match form.as_str() {
                        "sn't" | "n't" | "'ll" | "'ve" | "'re" | "'d" | "'m" | "'s" => {
                            last.push_str(&form);
                            continue;
                        }
                        _ => {}
                    }
                }
                toks.push(form);
                tags.push(token.upos.and_then(UPOS::from_conllu));
            }

            // Flatten all noun-phrase spans into the set of member indices.
            let actual = locate_noun_phrases_in_sent(sent);
            let actual_flat = actual.into_iter().fold(HashSet::new(), |mut a, b| {
                a.extend(b.into_iter());
                a
            });

            // Convert the index set into a dense per-token flag sequence.
            let mut actual_seq = Vec::new();

            for el in actual_flat {
                if el >= actual_seq.len() {
                    actual_seq.resize(el + 1, false);
                }
                actual_seq[el] = true;
            }

            sentences_flagged.push((toks, tags, actual_seq));
        }

        let mut relevant_words = WordCounter::default();

        for (tok_buf, tag_buf, flag_buf) in &sentences_flagged {
            total_tokens += tok_buf.len();
            error_counter += self.count_chunk_errors(
                tok_buf.as_slice(),
                tag_buf,
                flag_buf.as_slice(),
                &mut relevant_words,
            );
        }

        println!("=============");
        println!("Total tokens in training set: {total_tokens}");
        println!("Tokens incorrectly flagged: {error_counter}");
        println!(
            "Error rate: {}%",
            error_counter as f32 / total_tokens as f32 * 100.
        );

        // Cache the current chunker's output per sentence so each candidate
        // only has to re-run its own patch, not the whole pipeline.
        let mut base_flags = Vec::new();
        for (toks, tags, _) in &sentences_flagged {
            base_flags.push(self.chunk_sentence(toks, tags));
        }

        let all_candidates = Patch::generate_candidate_patches(&relevant_words);
        let mut pruned_candidates: Vec<Patch> = rand::seq::IndexedRandom::choose_multiple(
            all_candidates.as_slice(),
            &mut rand::rng(),
            (all_candidates.len() as f32 * candidate_selection_chance) as usize,
        )
        .cloned()
        .collect();

        let start = Instant::now();

        // Sort candidates by error count, ascending; keys are cached because
        // scoring a candidate means re-chunking the whole training set.
        #[cfg(feature = "threaded")]
        rayon::slice::ParallelSliceMut::par_sort_by_cached_key(
            pruned_candidates.as_mut_slice(),
            |candidate: &Patch| {
                self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
            },
        );

        #[cfg(not(feature = "threaded"))]
        pruned_candidates.sort_by_cached_key(|candidate| {
            self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
        });

        let duration = start.elapsed();
        let seconds = duration.as_secs();
        let millis = duration.subsec_millis();

        println!(
            "It took {} seconds and {} milliseconds to search through {} candidates at {} c/sec.",
            seconds,
            millis,
            pruned_candidates.len(),
            // Use fractional seconds: the previous integer-second division
            // produced inf/NaN for sub-second searches and truncated otherwise.
            pruned_candidates.len() as f32 / duration.as_secs_f32()
        );

        // Adopt the lowest-error candidate.
        // NOTE(review): the best candidate is adopted even when it does not
        // reduce the error count below the patch-free baseline — confirm
        // whether a "no improvement, stop" check is wanted here.
        if let Some(best) = pruned_candidates.first() {
            self.patches.push(best.clone());
        }
    }

    /// Score a single candidate patch: total tokens it leaves mis-flagged
    /// across the training set, starting from the cached `base_flags`.
    /// Lower is better.
    fn score_candidate(
        &self,
        candidate: Patch,
        sentences_flagged: &[CandidateArgs],
        base_flags: &[Vec<bool>],
    ) -> usize {
        // A throwaway chunker holding only this candidate; the base dict is
        // irrelevant because count_patch_errors starts from base_flags.
        let mut tagger = BrillChunker::new(UPOSFreqDict::default());
        tagger.patches.push(candidate);

        let mut errors = 0;

        for ((toks, tags, flags), base) in sentences_flagged.iter().zip(base_flags.iter()) {
            errors += tagger.count_patch_errors(toks.as_slice(), tags.as_slice(), base, flags);
        }

        errors
    }

    /// Train a chunker from scratch on CoNLL-U `training_files`: build the
    /// base frequency dictionary, then run `epochs` patch-learning epochs
    /// (one patch adopted per epoch).
    pub fn train(
        training_files: &[impl AsRef<Path>],
        epochs: usize,
        candidate_selection_chance: f32,
    ) -> Self {
        let mut freq_dict = UPOSFreqDict::default();

        for file in training_files {
            freq_dict.inc_from_conllu_file(file);
        }

        let mut chunker = Self::new(freq_dict);

        for _ in 0..epochs {
            chunker.epoch(training_files, candidate_selection_chance);
        }

        chunker
    }
}