harper_pos_utils/chunker/
burn_chunker.rs

use crate::{UPOS, chunker::Chunker};
#[cfg(feature = "training")]
use burn::backend::Autodiff;

#[cfg(feature = "training")]
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Dropout, DropoutConfig};
#[cfg(feature = "training")]
use burn::optim::{GradientsParams, Optimizer};
use burn::record::{FullPrecisionSettings, NamedMpkBytesRecorder, NamedMpkFileRecorder, Recorder};
use burn::tensor::TensorData;
#[cfg(feature = "training")]
use burn::tensor::backend::AutodiffBackend;

use burn::{
    module::Module,
    nn::{BiLstmConfig, EmbeddingConfig, LinearConfig},
    tensor::{Int, Tensor, backend::Backend},
};
use burn_ndarray::{NdArray, NdArrayDevice};
use hashbrown::HashMap;
use std::path::Path;

/// Vocabulary index reserved for unknown (out-of-vocabulary) tokens.
const UNK_IDX: usize = 1;

/// The underlying network: word and UPOS embeddings are concatenated, fed through a
/// bidirectional LSTM, then projected to a single per-token score.
#[derive(Module, Debug)]
struct NpModel<B: Backend> {
    embedding_words: burn::nn::Embedding<B>,
    embedding_upos: burn::nn::Embedding<B>,
    lstm: burn::nn::BiLstm<B>,
    linear_out: burn::nn::Linear<B>,
    dropout: Dropout,
}

impl<B: Backend> NpModel<B> {
    fn new(vocab: usize, word_embed_dim: usize, dropout: f32, device: &B::Device) -> Self {
        let upos_embed = 8;
        let total_embed = word_embed_dim + upos_embed;

        Self {
            embedding_words: EmbeddingConfig::new(vocab, word_embed_dim).init(device),
            // 20 rows comfortably cover the 17 UPOS tags after the +2 offset applied in `to_tensors`.
            embedding_upos: EmbeddingConfig::new(20, upos_embed).init(device),
            lstm: BiLstmConfig::new(total_embed, total_embed, false).init(device),
            // Multiply by two because the BiLSTM concatenates the forward and backward hidden states
            linear_out: LinearConfig::new(total_embed * 2, 1).init(device),
            dropout: DropoutConfig::new(dropout as f64).init(),
        }
    }

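    /// Score each token: inputs are `[1, seq_len]` index tensors and the output is a
    /// `[1, seq_len]` tensor of per-token noun-phrase scores.
    ///
    /// Worked shape example, assuming `word_embed_dim = 64` (so `total_embed = 72`)
    /// and a ten-token sentence:
    ///
    /// ```text
    /// word_tens:  [1, 10]          tag_tens: [1, 10]
    /// embeddings: [1, 10, 64]  +   [1, 10, 8]   -> cat -> [1, 10, 72]
    /// BiLSTM:     [1, 10, 144]     (forward and backward states concatenated)
    /// linear_out: [1, 10, 1]       -> squeeze -> [1, 10]
    /// ```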
    fn forward(
        &self,
        word_tens: Tensor<B, 2, Int>,
        tag_tens: Tensor<B, 2, Int>,
        use_dropout: bool,
    ) -> Tensor<B, 2> {
        let word_embed = self.embedding_words.forward(word_tens);
        let tag_embed = self.embedding_upos.forward(tag_tens);

        let mut x = Tensor::cat(vec![word_embed, tag_embed], 2);

        if use_dropout {
            x = self.dropout.forward(x);
        }

        let (mut x, _) = self.lstm.forward(x, None);

        if use_dropout {
            x = self.dropout.forward(x);
        }

        let x = self.linear_out.forward(x);
        x.squeeze_dim::<2>(2)
    }
}

/// A [`Chunker`] that uses a BiLSTM and the Burn machine learning framework.
///
/// Additional details are available in this [talk](https://elijahpotter.dev/articles/i-spoke-at-wordcamp-u.s.-in-2025).
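///
/// # Example
///
/// A minimal sketch of loading a pretrained chunker and running it on a single sentence.
/// The import paths, file locations, and hyper-parameters here are illustrative; they must
/// match however the model was actually trained and exported.
///
/// ```ignore
/// use harper_pos_utils::UPOS;
/// use harper_pos_utils::chunker::{BurnChunkerCpu, Chunker};
///
/// let model_bytes = std::fs::read("chunker/model.mpk").unwrap();
/// let vocab_bytes = std::fs::read("chunker/vocab.json").unwrap();
/// let chunker = BurnChunkerCpu::load_from_bytes_cpu(model_bytes, vocab_bytes, 64, 0.3);
///
/// let sentence: Vec<String> = ["The", "big", "dog", "barked"]
///     .iter()
///     .map(|s| s.to_string())
///     .collect();
/// let tags: Vec<Option<UPOS>> = vec![None; sentence.len()];
///
/// // `true` marks tokens that the model places inside a noun phrase.
/// let mask = chunker.chunk_sentence(&sentence, &tags);
/// assert_eq!(mask.len(), sentence.len());
/// ```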
pub struct BurnChunker<B: Backend> {
    vocab: HashMap<String, usize>,
    model: NpModel<B>,
    device: B::Device,
}

impl<B: Backend> BurnChunker<B> {
    /// Look up a token's vocabulary index, falling back to [`UNK_IDX`] for unseen tokens.
    fn idx(&self, tok: &str) -> usize {
        *self.vocab.get(tok).unwrap_or(&UNK_IDX)
    }

    fn to_tensors(
        &self,
        sent: &[String],
        tags: &[Option<UPOS>],
    ) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
        // Map tokens to vocabulary indices and UPOS tags to embedding indices
        // (offset by 2 so that 0 stays unused and missing tags map to 1).
        let idxs: Vec<_> = sent.iter().map(|t| self.idx(t) as i32).collect();

        let upos: Vec<_> = tags
            .iter()
            .map(|t| t.map(|o| o as i32 + 2).unwrap_or(1))
            .collect();

        let word_tensor =
            Tensor::<B, 1, Int>::from_data(TensorData::from(idxs.as_slice()), &self.device)
                .reshape([1, sent.len()]);

        let tag_tensor =
            Tensor::<B, 1, Int>::from_data(TensorData::from(upos.as_slice()), &self.device)
                .reshape([1, sent.len()]);

        (word_tensor, tag_tensor)
    }

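    /// Write the model weights (`model.mpk`) and vocabulary (`vocab.json`) into `dir`,
    /// creating the directory if it does not exist.
    ///
    /// A rough sketch of the save/load round trip; the directory name is illustrative and
    /// the embedding dimension passed to `load_from_bytes_cpu` must match the trained model:
    ///
    /// ```ignore
    /// chunker.save_to("trained_chunker");
    ///
    /// let model_bytes = std::fs::read("trained_chunker/model.mpk").unwrap();
    /// let vocab_bytes = std::fs::read("trained_chunker/vocab.json").unwrap();
    /// let reloaded = BurnChunkerCpu::load_from_bytes_cpu(model_bytes, vocab_bytes, 64, 0.3);
    /// ```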
    pub fn save_to(&self, dir: impl AsRef<Path>) {
        let dir = dir.as_ref();
        std::fs::create_dir_all(dir).unwrap();

        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
        self.model
            .clone()
            .save_file(dir.join("model.mpk"), &recorder)
            .unwrap();

        let vocab_bytes = serde_json::to_vec(&self.vocab).unwrap();
        std::fs::write(dir.join("vocab.json"), vocab_bytes).unwrap();
    }

    /// Reconstruct a chunker from serialized model weights and a JSON vocabulary.
    ///
    /// `embed_dim` must match the word-embedding dimension the model was trained with.
    pub fn load_from_bytes(
        model_bytes: impl AsRef<[u8]>,
        vocab_bytes: impl AsRef<[u8]>,
        embed_dim: usize,
        dropout: f32,
        device: B::Device,
    ) -> Self {
        let vocab: HashMap<String, usize> = serde_json::from_slice(vocab_bytes.as_ref()).unwrap();

        let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();

        let owned_data = model_bytes.as_ref().to_vec();
        let record = recorder.load(owned_data, &device).unwrap();

        let model = NpModel::new(vocab.len(), embed_dim, dropout, &device);
        let model = model.load_record(record);

        Self {
            vocab,
            model,
            device,
        }
    }
}

/// Sentences, per-token UPOS tags, per-token noun-phrase labels, and the vocabulary
/// extracted from a set of CoNLL-U training files.
#[cfg(feature = "training")]
struct ExtractedSentences(
    Vec<Vec<String>>,
    Vec<Vec<Option<UPOS>>>,
    Vec<Vec<bool>>,
    HashMap<String, usize>,
);

#[cfg(feature = "training")]
impl<B: Backend + AutodiffBackend> BurnChunker<B> {
    /// Convert boolean noun-phrase labels into a `[1, len]` float tensor of zeros and ones.
    fn to_label(&self, labels: &[bool]) -> Tensor<B, 2> {
        let ys: Vec<_> = labels.iter().map(|b| if *b { 1. } else { 0. }).collect();

        Tensor::<B, 1, _>::from_data(TensorData::from(ys.as_slice()), &self.device)
            .reshape([1, labels.len()])
    }

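    /// Train a chunker from scratch on CoNLL-U treebank files.
    ///
    /// `training_files` drive the gradient updates; `test_file` is held out and scored after
    /// every epoch. Training stops early as soon as held-out accuracy drops from one epoch to
    /// the next, which is treated as a sign of overfitting.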
    pub fn train(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        word_embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
        device: B::Device,
    ) -> Self {
        use burn::tensor::cast::ToElement;

        println!("Preparing datasets...");
        let ExtractedSentences(sents, tags, labs, vocab) =
            Self::extract_sents_from_files(training_files);

        println!("Preparing model and training config...");

        let mut model = NpModel::<B>::new(vocab.len(), word_embed_dim, dropout, &device);
        let opt_config = burn::optim::AdamConfig::new();
        let mut opt = opt_config.init();

        // A helper chunker built from clones of the vocabulary and model, used only for its
        // tensor-conversion and scoring helpers during training.
        let util = BurnChunker {
            vocab: vocab.clone(),
            model: model.clone(),
            device: device.clone(),
        };

        let loss_fn = MseLoss::new();
        let mut last_score = 0.;

        println!("Training...");

        for _ in 0..epochs {
            let mut total_loss = 0.;
            let mut total_tokens = 0;
            let mut total_correct: usize = 0;

            for (i, ((x, w), y)) in sents.iter().zip(tags.iter()).zip(labs.iter()).enumerate() {
                let (word_tens, tag_tens) = util.to_tensors(x, w);
                let y_tensor = util.to_label(y);

                let logits = model.forward(word_tens, tag_tens, true);

                // Count tokens whose thresholded prediction matches the gold label.
                total_correct += logits
                    .to_data()
                    .iter()
                    .map(|p: f32| p > 0.5)
                    .zip(y)
                    .map(|(a, b)| if a == *b { 1 } else { 0 })
                    .sum::<usize>();

                let loss = loss_fn.forward(logits, y_tensor, Reduction::Mean);

                let grads = loss.backward();
                let grads = GradientsParams::from_grads(grads, &model);

                model = opt.step(lr, model, grads);

                total_loss += loss.into_scalar().to_f64();
                total_tokens += x.len();

                if i % 1000 == 0 {
                    println!("{i}/{}", sents.len());
                }
            }

            println!(
                "Average loss for epoch: {}",
                total_loss / sents.len() as f64
            );

            println!(
                "{}% correct in training dataset",
                total_correct as f32 / total_tokens as f32 * 100.
            );

            let score = util.score_model(&model, test_file);
            println!("{}% correct in test dataset", score * 100.);

            // Stop as soon as held-out accuracy regresses.
            if score < last_score {
                println!("Overfitting detected. Stopping...");
                break;
            }

            last_score = score;
        }

        Self {
            vocab,
            model,
            device,
        }
    }

    /// Run the given model over a held-out CoNLL-U file and return token-level accuracy
    /// in the range `0.0..=1.0`.
    fn score_model(&self, model: &NpModel<B>, dataset: &impl AsRef<Path>) -> f32 {
        let ExtractedSentences(sents, tags, labs, _) = Self::extract_sents_from_files(&[dataset]);

        let mut total_tokens = 0;
        let mut total_correct: usize = 0;

        for ((x, w), y) in sents.iter().zip(tags.iter()).zip(labs.iter()) {
            let (word_tens, tag_tens) = self.to_tensors(x, w);

            let logits = model.forward(word_tens, tag_tens, false);
            total_correct += logits
                .to_data()
                .iter()
                .map(|p: f32| p > 0.5)
                .zip(y)
                .map(|(a, b)| if a == *b { 1 } else { 0 })
                .sum::<usize>();

            total_tokens += x.len();
        }

        total_correct as f32 / total_tokens as f32
    }

    /// Read every sentence from the given CoNLL-U files, mark which tokens belong to a
    /// noun phrase, merge contraction fragments back onto the preceding token, and build
    /// the vocabulary.
    fn extract_sents_from_files(files: &[impl AsRef<Path>]) -> ExtractedSentences {
        use super::np_extraction::locate_noun_phrases_in_sent;
        use crate::conllu_utils::iter_sentences_in_conllu;

        let mut vocab: HashMap<String, usize> = HashMap::new();
        vocab.insert("<UNK>".into(), UNK_IDX);

        let mut sents: Vec<Vec<String>> = Vec::new();
        let mut sent_tags: Vec<Vec<Option<UPOS>>> = Vec::new();
        let mut labs: Vec<Vec<bool>> = Vec::new();

        const CONTRACTIONS: &[&str] = &["sn't", "n't", "'ll", "'ve", "'re", "'d", "'m", "'s"];

        for file in files {
            for sent in iter_sentences_in_conllu(file) {
                let spans = locate_noun_phrases_in_sent(&sent);

                // Mark every token covered by a noun-phrase span.
                let mut original_mask = vec![false; sent.tokens.len()];
                for span in spans {
                    for i in span {
                        original_mask[i] = true;
                    }
                }

                let mut toks: Vec<String> = Vec::new();
                let mut tags: Vec<Option<UPOS>> = Vec::new();
                let mut mask: Vec<bool> = Vec::new();

                for (idx, tok) in sent.tokens.iter().enumerate() {
                    let is_contraction = CONTRACTIONS.contains(&&tok.form[..]);
                    if is_contraction && !toks.is_empty() {
                        // Glue the contraction back onto the previous token; the merged token
                        // is labeled positive if either half sat inside a noun phrase.
                        let prev_tok = toks.pop().unwrap();
                        let prev_mask = mask.pop().unwrap();
                        toks.push(format!("{prev_tok}{}", tok.form));
                        mask.push(prev_mask || original_mask[idx]);
                    } else {
                        toks.push(tok.form.clone());
                        tags.push(tok.upos.and_then(UPOS::from_conllu));
                        mask.push(original_mask[idx]);
                    }
                }

                // Grow the vocabulary with any tokens we haven't seen before.
                for t in &toks {
                    if !vocab.contains_key(t) {
                        let next = vocab.len();
                        vocab.insert(t.clone(), next);
                    }
                }

                sents.push(toks);
                sent_tags.push(tags);
                labs.push(mask);
            }
        }

        ExtractedSentences(sents, sent_tags, labs, vocab)
    }
}

// With the `training` feature enabled, the CPU chunker wraps NdArray in Autodiff so that
// gradients can be computed; otherwise it uses the plain NdArray backend.
#[cfg(feature = "training")]
pub type BurnChunkerCpu = BurnChunker<burn::backend::Autodiff<NdArray>>;

#[cfg(not(feature = "training"))]
pub type BurnChunkerCpu = BurnChunker<NdArray>;

impl BurnChunkerCpu {
    /// Convenience wrapper around [`BurnChunker::load_from_bytes`] that selects the CPU device.
    pub fn load_from_bytes_cpu(
        model_bytes: impl AsRef<[u8]>,
        vocab_bytes: impl AsRef<[u8]>,
        embed_dim: usize,
        dropout: f32,
    ) -> Self {
        Self::load_from_bytes(
            model_bytes,
            vocab_bytes,
            embed_dim,
            dropout,
            NdArrayDevice::Cpu,
        )
    }
}

#[cfg(feature = "training")]
impl BurnChunkerCpu {
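    /// Convenience wrapper around [`BurnChunker::train`] that runs the whole training loop
    /// on the CPU via the `NdArray` backend.
    ///
    /// A rough sketch of a training run; the file paths and hyper-parameters are
    /// illustrative, not recommended values:
    ///
    /// ```ignore
    /// let chunker = BurnChunkerCpu::train_cpu(
    ///     &["en_ewt-ud-train.conllu"],
    ///     &"en_ewt-ud-dev.conllu",
    ///     64,   // word embedding dimension
    ///     0.3,  // dropout
    ///     10,   // maximum number of epochs
    ///     1e-3, // learning rate
    /// );
    /// chunker.save_to("trained_chunker");
    /// ```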
    pub fn train_cpu(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
    ) -> Self {
        BurnChunker::<Autodiff<NdArray>>::train(
            training_files,
            test_file,
            embed_dim,
            dropout,
            epochs,
            lr,
            NdArrayDevice::Cpu,
        )
    }
}

impl<B: Backend> Chunker for BurnChunker<B> {
    fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
        // Avoids a divide-by-zero error in the linear layer.
        if sentence.is_empty() {
            return Vec::new();
        }

        let (word_tens, tag_tens) = self.to_tensors(sentence, tags);
        let prob = self.model.forward(word_tens, tag_tens, false);

        // A token is considered part of a noun phrase when its score exceeds 0.5.
        prob.to_data().iter().map(|p: f32| p > 0.5).collect()
    }
}