harper_pos_utils/chunker/brill_chunker/
mod.rs1mod patch;
2
3#[cfg(feature = "training")]
4use std::path::Path;
5
6#[cfg(feature = "training")]
7use crate::word_counter::WordCounter;
8use crate::{
9 UPOS,
10 chunker::{Chunker, upos_freq_dict::UPOSFreqDict},
11};
12
13use patch::Patch;
14use serde::{Deserialize, Serialize};
15
/// A noun-phrase chunker based on Brill-style transformation-based learning:
/// a baseline frequency-dictionary chunker produces an initial pass, which an
/// ordered list of learned patches then corrects in place.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrillChunker {
    // Baseline chunker producing the initial per-token noun-phrase flags.
    base: UPOSFreqDict,
    // Learned correction rules, applied in order on top of the baseline output.
    patches: Vec<Patch>,
}
21
22impl BrillChunker {
23 pub fn new(base: UPOSFreqDict) -> Self {
24 Self {
25 base,
26 patches: Vec::new(),
27 }
28 }
29
30 fn apply_patches(&self, sentence: &[String], tags: &[Option<UPOS>], np_states: &mut [bool]) {
31 for patch in &self.patches {
32 for i in 0..sentence.len() {
33 if patch.from == np_states[i]
34 && patch.criteria.fulfils(sentence, tags, np_states, i)
35 {
36 np_states[i] = !np_states[i];
37 }
38 }
39 }
40 }
41}
42
43impl Chunker for BrillChunker {
44 fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
45 let mut initial_pass = self.base.chunk_sentence(sentence, tags);
46
47 self.apply_patches(sentence, tags, &mut initial_pass);
48
49 initial_pass
50 }
51}
52
// One pre-processed training sentence: (tokens, UPOS tags, gold noun-phrase flags).
#[cfg(feature = "training")]
type CandidateArgs = (Vec<String>, Vec<Option<UPOS>>, Vec<bool>);
55
#[cfg(feature = "training")]
impl BrillChunker {
    /// Apply this chunker's patches on top of `base_flags` and count how many
    /// positions disagree with `correct_np_flags`.
    ///
    /// The baseline pass is *not* recomputed; `base_flags` is expected to be
    /// the cached output of the baseline (or of a previous patch stack) for
    /// `sentence`.
    pub fn count_patch_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        base_flags: &[bool],
        correct_np_flags: &[bool],
    ) -> usize {
        let mut flags = base_flags.to_vec();
        self.apply_patches(sentence, tags, &mut flags);

        // Count disagreements with the gold flags; `zip` truncates to the
        // shorter sequence if the lengths differ.
        flags
            .into_iter()
            .zip(correct_np_flags)
            .filter(|&(a, &b)| a != b)
            .count()
    }

    /// Run the full chunker over `sentence` and count positions that disagree
    /// with `correct_np_flags`. Every word at a mismatching position is also
    /// recorded in `relevant_words`, which is later used to generate patch
    /// candidates targeted at the words the chunker currently gets wrong.
    pub fn count_chunk_errors(
        &self,
        sentence: &[String],
        tags: &[Option<UPOS>],
        correct_np_flags: &[bool],
        relevant_words: &mut WordCounter,
    ) -> usize {
        let flags = self.chunk_sentence(sentence, tags);

        let mut loss = 0;
        for ((predicted, &expected), word) in flags.into_iter().zip(correct_np_flags).zip(sentence)
        {
            if predicted != expected {
                loss += 1;
                relevant_words.inc(word);
            }
        }

        loss
    }

    /// Run one epoch of Brill training: score the current chunker over the
    /// training files, generate candidate patches from the words it got
    /// wrong, and append the single candidate that leaves the fewest errors.
    ///
    /// `candidate_selection_chance` (in `0.0..=1.0`) is the fraction of
    /// generated candidates randomly kept for scoring; values below 1.0
    /// trade training quality for speed.
    fn epoch(&mut self, training_files: &[impl AsRef<Path>], candidate_selection_chance: f32) {
        use crate::conllu_utils::iter_sentences_in_conllu;
        use rs_conllu::Sentence;
        use std::time::Instant;

        assert!((0.0..=1.0).contains(&candidate_selection_chance));

        let mut total_tokens = 0;
        let mut error_counter = 0;

        let sentences: Vec<Sentence> = training_files
            .iter()
            .flat_map(iter_sentences_in_conllu)
            .collect();
        let mut sentences_flagged: Vec<CandidateArgs> = Vec::new();

        for sent in &sentences {
            use hashbrown::HashSet;

            use crate::chunker::np_extraction::locate_noun_phrases_in_sent;

            let mut toks: Vec<String> = Vec::new();
            let mut tags = Vec::new();

            for token in &sent.tokens {
                let form = token.form.clone();
                // Merge clitic contraction pieces (e.g. "n't", "'ll") into
                // the previous token; the merged token keeps the previous
                // token's tag. NOTE(review): "sn't" looks unusual — confirm
                // it corresponds to a real corpus tokenization.
                if let Some(last) = toks.last_mut() {
                    match form.as_str() {
                        "sn't" | "n't" | "'ll" | "'ve" | "'re" | "'d" | "'m" | "'s" => {
                            last.push_str(&form);
                            continue;
                        }
                        _ => {}
                    }
                }
                toks.push(form);
                tags.push(token.upos.and_then(UPOS::from_conllu));
            }

            // Flatten the located noun phrases into the set of member token
            // indices…
            let actual = locate_noun_phrases_in_sent(sent);
            let actual_flat = actual.into_iter().fold(HashSet::new(), |mut a, b| {
                a.extend(b.into_iter());
                a
            });

            // …then into a per-token boolean sequence. The vector only grows
            // to the highest noun-phrase index, so it may be shorter than
            // `toks`; downstream `zip`s truncate accordingly.
            let mut actual_seq = Vec::new();

            for el in actual_flat {
                if el >= actual_seq.len() {
                    actual_seq.resize(el + 1, false);
                }
                actual_seq[el] = true;
            }

            sentences_flagged.push((toks, tags, actual_seq));
        }

        let mut relevant_words = WordCounter::default();

        for (tok_buf, tag_buf, flag_buf) in &sentences_flagged {
            total_tokens += tok_buf.len();
            error_counter += self.count_chunk_errors(
                tok_buf.as_slice(),
                tag_buf,
                flag_buf.as_slice(),
                &mut relevant_words,
            );
        }

        println!("=============");
        println!("Total tokens in training set: {total_tokens}");
        println!("Tokens incorrectly flagged: {error_counter}");
        println!(
            "Error rate: {}%",
            error_counter as f32 / total_tokens as f32 * 100.
        );

        // Cache the current chunker's full output per sentence so that every
        // candidate only has to run its own single patch on top of it rather
        // than re-running the whole stack.
        let mut base_flags = Vec::new();

        for (toks, tags, _) in &sentences_flagged {
            base_flags.push(self.chunk_sentence(toks, tags));
        }

        let all_candidates = Patch::generate_candidate_patches(&relevant_words);
        let mut pruned_candidates: Vec<Patch> = rand::seq::IndexedRandom::choose_multiple(
            all_candidates.as_slice(),
            &mut rand::rng(),
            (all_candidates.len() as f32 * candidate_selection_chance) as usize,
        )
        .cloned()
        .collect();

        let start = Instant::now();

        // Sort ascending by remaining error count so the best candidate ends
        // up first; `sort_by_cached_key` computes the expensive score exactly
        // once per candidate.
        #[cfg(feature = "threaded")]
        rayon::slice::ParallelSliceMut::par_sort_by_cached_key(
            pruned_candidates.as_mut_slice(),
            |candidate: &Patch| {
                self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
            },
        );

        #[cfg(not(feature = "threaded"))]
        pruned_candidates.sort_by_cached_key(|candidate| {
            self.score_candidate(candidate.clone(), &sentences_flagged, &base_flags)
        });

        let duration = start.elapsed();
        let seconds = duration.as_secs();
        let millis = duration.subsec_millis();

        println!(
            "It took {} seconds and {} milliseconds to search through {} candidates at {} c/sec.",
            seconds,
            millis,
            pruned_candidates.len(),
            // Use the fractional elapsed time: dividing by the whole-second
            // count would divide by zero (printing `inf`/`NaN`) whenever the
            // search finishes in under a second.
            pruned_candidates.len() as f32 / duration.as_secs_f32()
        );

        if let Some(best) = pruned_candidates.first() {
            self.patches.push(best.clone());
        }
    }

    /// Score one candidate patch: the total number of tokens, across all
    /// training sentences, that would still be wrong if `candidate` were
    /// applied on top of the cached `base_flags`. Lower is better.
    fn score_candidate(
        &self,
        candidate: Patch,
        sentences_flagged: &[CandidateArgs],
        base_flags: &[Vec<bool>],
    ) -> usize {
        // A throwaway chunker holding only this candidate; its baseline is
        // irrelevant because `count_patch_errors` starts from `base_flags`.
        let mut tagger = BrillChunker::new(UPOSFreqDict::default());
        tagger.patches.push(candidate);

        sentences_flagged
            .iter()
            .zip(base_flags.iter())
            .map(|((toks, tags, flags), base)| {
                tagger.count_patch_errors(toks.as_slice(), tags.as_slice(), base, flags)
            })
            .sum()
    }

    /// Train a chunker from CoNLL-U files: build the baseline frequency
    /// dictionary from all files, then run `epochs` rounds of patch
    /// learning. Each epoch appends at most one patch.
    pub fn train(
        training_files: &[impl AsRef<Path>],
        epochs: usize,
        candidate_selection_chance: f32,
    ) -> Self {
        let mut freq_dict = UPOSFreqDict::default();

        for file in training_files {
            freq_dict.inc_from_conllu_file(file);
        }

        let mut chunker = Self::new(freq_dict);

        for _ in 0..epochs {
            chunker.epoch(training_files, candidate_selection_chance);
        }

        chunker
    }
}