// harper_pos_utils/chunker/burn_chunker.rs

1use crate::{UPOS, chunker::Chunker};
2#[cfg(feature = "training")]
3use burn::backend::Autodiff;
4
5#[cfg(feature = "training")]
6use burn::nn::loss::{MseLoss, Reduction};
7use burn::nn::{Dropout, DropoutConfig};
8#[cfg(feature = "training")]
9use burn::optim::{GradientsParams, Optimizer};
10use burn::record::{FullPrecisionSettings, NamedMpkBytesRecorder, NamedMpkFileRecorder, Recorder};
11use burn::tensor::TensorData;
12#[cfg(feature = "training")]
13use burn::tensor::backend::AutodiffBackend;
14
15use burn::{
16    module::Module,
17    nn::{BiLstmConfig, EmbeddingConfig, LinearConfig},
18    tensor::{Int, Tensor, backend::Backend},
19};
20use burn_ndarray::{NdArray, NdArrayDevice};
21use hashbrown::HashMap;
22use std::path::Path;
23
// Embedding index reserved for out-of-vocabulary words.
// (Index 0 is left unassigned by `idx`, presumably for padding.)
const UNK_IDX: usize = 1;

/// BiLSTM-based sequence model that produces one score per token,
/// used to decide noun-phrase membership.
#[derive(Module, Debug)]
struct NpModel<B: Backend> {
    // Learned embedding for word indices (vocabulary-sized).
    embedding_words: burn::nn::Embedding<B>,
    // Learned embedding for UPOS tag indices.
    embedding_upos: burn::nn::Embedding<B>,
    // Bidirectional LSTM over the concatenated word+tag embeddings.
    lstm: burn::nn::BiLstm<B>,
    // Projects each BiLSTM output vector down to a single per-token score.
    linear_out: burn::nn::Linear<B>,
    // Applied only when `forward` is called with `use_dropout = true`.
    dropout: Dropout,
}
34
35impl<B: Backend> NpModel<B> {
36    fn new(vocab: usize, word_embed_dim: usize, dropout: f32, device: &B::Device) -> Self {
37        let upos_embed = 8;
38        let total_embed = word_embed_dim + upos_embed;
39
40        Self {
41            embedding_words: EmbeddingConfig::new(vocab, word_embed_dim).init(device),
42            embedding_upos: EmbeddingConfig::new(20, upos_embed).init(device),
43            lstm: BiLstmConfig::new(total_embed, total_embed, false).init(device),
44            // Multiply by two because the BiLSTM emits double the hidden parameters
45            linear_out: LinearConfig::new(total_embed * 2, 1).init(device),
46            dropout: DropoutConfig::new(dropout as f64).init(),
47        }
48    }
49
50    fn forward(
51        &self,
52        word_tens: Tensor<B, 2, Int>,
53        tag_tens: Tensor<B, 2, Int>,
54        use_dropout: bool,
55    ) -> Tensor<B, 2> {
56        let word_embed = self.embedding_words.forward(word_tens);
57        let tag_embed = self.embedding_upos.forward(tag_tens);
58
59        let mut x = Tensor::cat(vec![word_embed, tag_embed], 2);
60
61        if use_dropout {
62            x = self.dropout.forward(x);
63        }
64
65        let (mut x, _) = self.lstm.forward(x, None);
66
67        if use_dropout {
68            x = self.dropout.forward(x);
69        }
70
71        let x = self.linear_out.forward(x);
72        x.squeeze::<2>(2)
73    }
74}
75
/// A noun-phrase chunker backed by the BiLSTM `NpModel`, together with the
/// word-to-index vocabulary and the device its tensors live on.
pub struct BurnChunker<B: Backend> {
    // Maps token text to its embedding index; lookups for unseen tokens
    // fall back to `UNK_IDX` (see `idx`).
    vocab: HashMap<String, usize>,
    model: NpModel<B>,
    device: B::Device,
}
81
82impl<B: Backend> BurnChunker<B> {
83    fn idx(&self, tok: &str) -> usize {
84        *self.vocab.get(tok).unwrap_or(&UNK_IDX)
85    }
86
87    fn to_tensors(
88        &self,
89        sent: &[String],
90        tags: &[Option<UPOS>],
91    ) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
92        // Interleave with UPOS tags
93        let idxs: Vec<_> = sent.iter().map(|t| self.idx(t) as i32).collect();
94
95        let upos: Vec<_> = tags
96            .iter()
97            .map(|t| t.map(|o| o as i32 + 2).unwrap_or(1))
98            .collect();
99
100        let word_tensor =
101            Tensor::<B, 1, Int>::from_data(TensorData::from(idxs.as_slice()), &self.device)
102                .reshape([1, sent.len()]);
103
104        let tag_tensor =
105            Tensor::<B, 1, Int>::from_data(TensorData::from(upos.as_slice()), &self.device)
106                .reshape([1, sent.len()]);
107
108        (word_tensor, tag_tensor)
109    }
110
111    pub fn save_to(&self, dir: impl AsRef<Path>) {
112        let dir = dir.as_ref();
113        std::fs::create_dir_all(dir).unwrap();
114
115        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
116        self.model
117            .clone()
118            .save_file(dir.join("model.mpk"), &recorder)
119            .unwrap();
120
121        let vocab_bytes = serde_json::to_vec(&self.vocab).unwrap();
122        std::fs::write(dir.join("vocab.json"), vocab_bytes).unwrap();
123    }
124
125    pub fn load_from_bytes(
126        model_bytes: impl AsRef<[u8]>,
127        vocab_bytes: impl AsRef<[u8]>,
128        embed_dim: usize,
129        dropout: f32,
130        device: B::Device,
131    ) -> Self {
132        let vocab: HashMap<String, usize> = serde_json::from_slice(vocab_bytes.as_ref()).unwrap();
133
134        let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();
135
136        let owned_data = model_bytes.as_ref().to_vec();
137        let record = recorder.load(owned_data, &device).unwrap();
138
139        let model = NpModel::new(vocab.len(), embed_dim, dropout, &device);
140        let model = model.load_record(record);
141
142        Self {
143            vocab,
144            model,
145            device,
146        }
147    }
148}
149
/// Intermediate training data pulled out of CoNLL-U files: sentences (as
/// contraction-merged token strings), their per-token UPOS tags, the
/// per-token noun-phrase membership labels, and the vocabulary built while
/// reading.
#[cfg(feature = "training")]
struct ExtractedSentences(
    Vec<Vec<String>>,
    Vec<Vec<Option<UPOS>>>,
    Vec<Vec<bool>>,
    HashMap<String, usize>,
);
157
#[cfg(feature = "training")]
impl<B: Backend + AutodiffBackend> BurnChunker<B> {
    /// Turn per-token boolean labels into a `[1, len]` tensor of 0.0/1.0
    /// regression targets for the MSE loss.
    fn to_label(&self, labels: &[bool]) -> Tensor<B, 2> {
        let ys: Vec<_> = labels.iter().map(|b| if *b { 1. } else { 0. }).collect();

        Tensor::<B, 1, _>::from_data(TensorData::from(ys.as_slice()), &self.device)
            .reshape([1, labels.len()])
    }

    /// Train a chunker from scratch.
    ///
    /// * `training_files` / `test_file` — CoNLL-U datasets.
    /// * `word_embed_dim` — width of the word embedding.
    /// * `dropout` — dropout probability used during training.
    /// * `epochs` — maximum number of passes over the training data.
    /// * `lr` — Adam learning rate.
    ///
    /// The model takes one Adam step per sentence and is scored on
    /// `test_file` after every epoch. Training stops early when the test
    /// score drops relative to the previous epoch; note that the returned
    /// model is the one from the epoch in which the drop was observed, not
    /// the best-scoring one.
    pub fn train(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        word_embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
        device: B::Device,
    ) -> Self {
        use burn::tensor::cast::ToElement;

        println!("Preparing datasets...");
        let ExtractedSentences(sents, tags, labs, vocab) =
            Self::extract_sents_from_files(training_files);

        println!("Preparing model and training config...");

        let mut model = NpModel::<B>::new(vocab.len(), word_embed_dim, dropout, &device);
        let opt_config = burn::optim::AdamConfig::new();
        let mut opt = opt_config.init();

        // Holds the vocab/device so `to_tensors`, `to_label`, and
        // `score_model` can be reused during training; its `model` field is
        // a stale clone and is never used for inference.
        let util = BurnChunker {
            vocab: vocab.clone(),
            model: model.clone(),
            device: device.clone(),
        };

        let loss_fn = MseLoss::new();
        let mut last_score = 0.;

        println!("Training...");

        for _ in 0..epochs {
            let mut total_loss = 0.;
            let mut total_tokens = 0;
            let mut total_correct: usize = 0;

            for (i, ((x, w), y)) in sents.iter().zip(tags.iter()).zip(labs.iter()).enumerate() {
                let (word_tens, tag_tens) = util.to_tensors(x, w);
                let y_tensor = util.to_label(y);

                let logits = model.forward(word_tens, tag_tens, true);

                // Track raw training accuracy by thresholding each
                // per-token score at 0.5.
                total_correct += logits
                    .to_data()
                    .iter()
                    .map(|p: f32| p > 0.5)
                    .zip(y)
                    .map(|(a, b)| if a == *b { 1 } else { 0 })
                    .sum::<usize>();

                let loss = loss_fn.forward(logits, y_tensor, Reduction::Mean);

                let grads = loss.backward();
                let grads = GradientsParams::from_grads(grads, &model);

                model = opt.step(lr, model, grads);

                total_loss += loss.into_scalar().to_f64();
                total_tokens += x.len();

                if i % 1000 == 0 {
                    println!("{i}/{}", sents.len());
                }
            }

            println!(
                "Average loss for epoch: {}",
                total_loss / sents.len() as f64 * 100.
            );

            println!(
                "{}% correct in training dataset",
                total_correct as f32 / total_tokens as f32 * 100.
            );

            let score = util.score_model(&model, test_file);
            println!("{}% correct in test dataset", score * 100.);

            // Naive early stopping: any drop in test accuracy ends training.
            if score < last_score {
                println!("Overfitting detected. Stopping...");
                break;
            }

            last_score = score;
        }

        Self {
            vocab,
            model,
            device,
        }
    }

    /// Token-level accuracy of `model` on a held-out CoNLL-U dataset,
    /// evaluated with dropout disabled. Uses `self` only for its vocab
    /// and device (via `to_tensors`).
    fn score_model(&self, model: &NpModel<B>, dataset: &impl AsRef<Path>) -> f32 {
        let ExtractedSentences(sents, tags, labs, _) = Self::extract_sents_from_files(&[dataset]);

        let mut total_tokens = 0;
        let mut total_correct: usize = 0;

        for ((x, w), y) in sents.iter().zip(tags.iter()).zip(labs.iter()) {
            let (word_tens, tag_tens) = self.to_tensors(x, w);

            let logits = model.forward(word_tens, tag_tens, false);
            total_correct += logits
                .to_data()
                .iter()
                .map(|p: f32| p > 0.5)
                .zip(y)
                .map(|(a, b)| if a == *b { 1 } else { 0 })
                .sum::<usize>();

            total_tokens += x.len();
        }

        // Guard against 0/0 -> NaN when the dataset contains no tokens.
        if total_tokens == 0 {
            return 0.;
        }

        total_correct as f32 / total_tokens as f32
    }

    /// Read the given CoNLL-U files, fold contraction tokens back onto
    /// their preceding token, label every token with noun-phrase
    /// membership, and build the training vocabulary.
    fn extract_sents_from_files(files: &[impl AsRef<Path>]) -> ExtractedSentences {
        use super::np_extraction::locate_noun_phrases_in_sent;
        use crate::conllu_utils::iter_sentences_in_conllu;

        let mut vocab: HashMap<String, usize> = HashMap::new();
        // Reserve index 0 for padding and index 1 for unknown words. The
        // explicit `<PAD>` entry matters: new words below are indexed with
        // `vocab.len()`, so without it the first real word would also
        // receive index 1 and share an embedding row with `<UNK>`.
        vocab.insert("<PAD>".into(), 0);
        vocab.insert("<UNK>".into(), UNK_IDX);

        let mut sents: Vec<Vec<String>> = Vec::new();
        let mut sent_tags: Vec<Vec<Option<UPOS>>> = Vec::new();
        let mut labs: Vec<Vec<bool>> = Vec::new();

        // Suffixes that CoNLL-U tokenizes separately but which we merge back
        // into the preceding token ("do" + "n't" -> "don't").
        const CONTRACTIONS: &[&str] = &["sn't", "n't", "'ll", "'ve", "'re", "'d", "'m", "'s"];

        for file in files {
            for sent in iter_sentences_in_conllu(file) {
                let spans = locate_noun_phrases_in_sent(&sent);

                // Mark every token covered by a noun-phrase span.
                let mut original_mask = vec![false; sent.tokens.len()];
                for span in spans {
                    for i in span {
                        original_mask[i] = true;
                    }
                }

                let mut toks: Vec<String> = Vec::new();
                let mut tags: Vec<Option<UPOS>> = Vec::new();
                let mut mask: Vec<bool> = Vec::new();

                for (idx, tok) in sent.tokens.iter().enumerate() {
                    let is_contraction = CONTRACTIONS.contains(&&tok.form[..]);
                    if is_contraction && !toks.is_empty() {
                        // Fold the contraction into the previous token; the
                        // merged token is inside a noun phrase if either part
                        // was, and keeps the base word's UPOS tag.
                        let prev_tok = toks.pop().unwrap();
                        let prev_mask = mask.pop().unwrap();
                        toks.push(format!("{prev_tok}{}", tok.form));
                        mask.push(prev_mask || original_mask[idx]);
                    } else {
                        toks.push(tok.form.clone());
                        tags.push(tok.upos.and_then(UPOS::from_conllu));
                        mask.push(original_mask[idx]);
                    }
                }

                // Assign the next free index to each previously unseen token.
                for t in &toks {
                    if !vocab.contains_key(t) {
                        let next = vocab.len();
                        vocab.insert(t.clone(), next);
                    }
                }

                sents.push(toks);
                sent_tags.push(tags);
                labs.push(mask);
            }
        }

        ExtractedSentences(sents, sent_tags, labs, vocab)
    }
}
342
// CPU-backed chunker type: runs on the ndarray backend, wrapped in
// `Autodiff` when the `training` feature is enabled so gradients can flow.
#[cfg(feature = "training")]
pub type BurnChunkerCpu = BurnChunker<burn::backend::Autodiff<NdArray>>;

#[cfg(not(feature = "training"))]
pub type BurnChunkerCpu = BurnChunker<NdArray>;
348
349impl BurnChunkerCpu {
350    pub fn load_from_bytes_cpu(
351        model_bytes: impl AsRef<[u8]>,
352        vocab_bytes: impl AsRef<[u8]>,
353        embed_dim: usize,
354        dropout: f32,
355    ) -> Self {
356        Self::load_from_bytes(
357            model_bytes,
358            vocab_bytes,
359            embed_dim,
360            dropout,
361            NdArrayDevice::Cpu,
362        )
363    }
364}
365
#[cfg(feature = "training")]
impl BurnChunkerCpu {
    /// Convenience wrapper around [`BurnChunker::train`] that always trains
    /// on the ndarray CPU backend with autodiff enabled.
    pub fn train_cpu(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
    ) -> Self {
        let device = NdArrayDevice::Cpu;
        BurnChunker::<Autodiff<NdArray>>::train(
            training_files,
            test_file,
            embed_dim,
            dropout,
            epochs,
            lr,
            device,
        )
    }
}
387
388impl<B: Backend> Chunker for BurnChunker<B> {
389    fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
390        // Solves a divide-by-zero error in the linear layer.
391        if sentence.is_empty() {
392            return Vec::new();
393        }
394
395        let (word_tens, tag_tens) = self.to_tensors(sentence, tags);
396        let prob = self.model.forward(word_tens, tag_tens, false);
397        prob.to_data().iter().map(|p: f32| p > 0.5).collect()
398    }
399}