use crate::{UPOS, chunker::Chunker};
#[cfg(feature = "training")]
use burn::backend::Autodiff;

#[cfg(feature = "training")]
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Dropout, DropoutConfig};
#[cfg(feature = "training")]
use burn::optim::{GradientsParams, Optimizer};
use burn::record::{FullPrecisionSettings, NamedMpkBytesRecorder, NamedMpkFileRecorder, Recorder};
use burn::tensor::TensorData;
#[cfg(feature = "training")]
use burn::tensor::backend::AutodiffBackend;

use burn::{
    module::Module,
    nn::{BiLstmConfig, EmbeddingConfig, LinearConfig},
    tensor::{Int, Tensor, backend::Backend},
};
use burn_ndarray::{NdArray, NdArrayDevice};
use hashbrown::HashMap;
use std::path::Path;

const UNK_IDX: usize = 1;

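/// Recurrent noun-phrase scorer: word and UPOS embeddings are concatenated,
/// passed through a bidirectional LSTM, and projected to one score per token.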
#[derive(Module, Debug)]
struct NpModel<B: Backend> {
    embedding_words: burn::nn::Embedding<B>,
    embedding_upos: burn::nn::Embedding<B>,
    lstm: burn::nn::BiLstm<B>,
    linear_out: burn::nn::Linear<B>,
    dropout: Dropout,
}

impl<B: Backend> NpModel<B> {
    fn new(vocab: usize, word_embed_dim: usize, dropout: f32, device: &B::Device) -> Self {
        let upos_embed = 8;
        let total_embed = word_embed_dim + upos_embed;

        Self {
            embedding_words: EmbeddingConfig::new(vocab, word_embed_dim).init(device),
            embedding_upos: EmbeddingConfig::new(20, upos_embed).init(device),
            lstm: BiLstmConfig::new(total_embed, total_embed, false).init(device),
            linear_out: LinearConfig::new(total_embed * 2, 1).init(device),
            dropout: DropoutConfig::new(dropout as f64).init(),
        }
    }

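    /// Runs embeddings -> BiLSTM -> linear head and returns one raw score per
    /// token with shape `[batch, seq_len]`. Dropout is applied only when
    /// `use_dropout` is set (i.e. during training).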
    fn forward(
        &self,
        word_tens: Tensor<B, 2, Int>,
        tag_tens: Tensor<B, 2, Int>,
        use_dropout: bool,
    ) -> Tensor<B, 2> {
        let word_embed = self.embedding_words.forward(word_tens);
        let tag_embed = self.embedding_upos.forward(tag_tens);

        let mut x = Tensor::cat(vec![word_embed, tag_embed], 2);

        if use_dropout {
            x = self.dropout.forward(x);
        }

        let (mut x, _) = self.lstm.forward(x, None);

        if use_dropout {
            x = self.dropout.forward(x);
        }

        let x = self.linear_out.forward(x);
        x.squeeze_dim::<2>(2)
    }
}

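/// A [`Chunker`] backed by an [`NpModel`], together with the token vocabulary
/// and the device the model runs on.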
pub struct BurnChunker<B: Backend> {
    vocab: HashMap<String, usize>,
    model: NpModel<B>,
    device: B::Device,
}

impl<B: Backend> BurnChunker<B> {
    fn idx(&self, tok: &str) -> usize {
        *self.vocab.get(tok).unwrap_or(&UNK_IDX)
    }

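    /// Builds `1 x len` word-index and UPOS-index tensors for a sentence.
    /// Out-of-vocabulary words map to `UNK_IDX`; missing tags map to 1 and
    /// known tags are offset by 2.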
    fn to_tensors(
        &self,
        sent: &[String],
        tags: &[Option<UPOS>],
    ) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
        let idxs: Vec<_> = sent.iter().map(|t| self.idx(t) as i32).collect();

        let upos: Vec<_> = tags
            .iter()
            .map(|t| t.map(|o| o as i32 + 2).unwrap_or(1))
            .collect();

        let word_tensor =
            Tensor::<B, 1, Int>::from_data(TensorData::from(idxs.as_slice()), &self.device)
                .reshape([1, sent.len()]);

        let tag_tensor =
            Tensor::<B, 1, Int>::from_data(TensorData::from(upos.as_slice()), &self.device)
                .reshape([1, sent.len()]);

        (word_tensor, tag_tensor)
    }

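    /// Writes the model weights (`model.mpk`) and the vocabulary
    /// (`vocab.json`) into `dir`, creating the directory if necessary.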
    pub fn save_to(&self, dir: impl AsRef<Path>) {
        let dir = dir.as_ref();
        std::fs::create_dir_all(dir).unwrap();

        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
        self.model
            .clone()
            .save_file(dir.join("model.mpk"), &recorder)
            .unwrap();

        let vocab_bytes = serde_json::to_vec(&self.vocab).unwrap();
        std::fs::write(dir.join("vocab.json"), vocab_bytes).unwrap();
    }

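    /// Reconstructs a chunker from serialized model and vocabulary bytes.
    /// The network is rebuilt from `vocab.len()`, `embed_dim` and `dropout`,
    /// then populated with the recorded weights.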
    pub fn load_from_bytes(
        model_bytes: impl AsRef<[u8]>,
        vocab_bytes: impl AsRef<[u8]>,
        embed_dim: usize,
        dropout: f32,
        device: B::Device,
    ) -> Self {
        let vocab: HashMap<String, usize> = serde_json::from_slice(vocab_bytes.as_ref()).unwrap();

        let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();

        let owned_data = model_bytes.as_ref().to_vec();
        let record = recorder.load(owned_data, &device).unwrap();

        let model = NpModel::new(vocab.len(), embed_dim, dropout, &device);
        let model = model.load_record(record);

        Self {
            vocab,
            model,
            device,
        }
    }
}

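/// Training data pulled from CoNLL-U files: sentences, their UPOS tags,
/// per-token noun-phrase labels, and the token vocabulary.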
#[cfg(feature = "training")]
struct ExtractedSentences(
    Vec<Vec<String>>,
    Vec<Vec<Option<UPOS>>>,
    Vec<Vec<bool>>,
    HashMap<String, usize>,
);

#[cfg(feature = "training")]
impl<B: Backend + AutodiffBackend> BurnChunker<B> {
    fn to_label(&self, labels: &[bool]) -> Tensor<B, 2> {
        let ys: Vec<_> = labels.iter().map(|b| if *b { 1. } else { 0. }).collect();

        Tensor::<B, 1, _>::from_data(TensorData::from(ys.as_slice()), &self.device)
            .reshape([1, labels.len()])
    }

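    /// Trains a chunker with Adam and MSE loss on per-token 0/1 noun-phrase
    /// labels, one sentence per optimizer step. Training runs for at most
    /// `epochs` epochs and stops early if accuracy on `test_file` drops.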
    pub fn train(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        word_embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
        device: B::Device,
    ) -> Self {
        use burn::tensor::cast::ToElement;

        println!("Preparing datasets...");
        let ExtractedSentences(sents, tags, labs, vocab) =
            Self::extract_sents_from_files(training_files);

        println!("Preparing model and training config...");

        let mut model = NpModel::<B>::new(vocab.len(), word_embed_dim, dropout, &device);
        let opt_config = burn::optim::AdamConfig::new();
        let mut opt = opt_config.init();

        let util = BurnChunker {
            vocab: vocab.clone(),
            model: model.clone(),
            device: device.clone(),
        };

        let loss_fn = MseLoss::new();
        let mut last_score = 0.;

        println!("Training...");

        for _ in 0..epochs {
            let mut total_loss = 0.;
            let mut total_tokens = 0;
            let mut total_correct: usize = 0;

            for (i, ((x, w), y)) in sents.iter().zip(tags.iter()).zip(labs.iter()).enumerate() {
                let (word_tens, tag_tens) = util.to_tensors(x, w);
                let y_tensor = util.to_label(y);

                let logits = model.forward(word_tens, tag_tens, true);
                total_correct += logits
                    .to_data()
                    .iter()
                    .map(|p: f32| p > 0.5)
                    .zip(y)
                    .map(|(a, b)| if a == *b { 1 } else { 0 })
                    .sum::<usize>();

                let loss = loss_fn.forward(logits, y_tensor, Reduction::Mean);

                let grads = loss.backward();
                let grads = GradientsParams::from_grads(grads, &model);

                model = opt.step(lr, model, grads);

                total_loss += loss.into_scalar().to_f64();
                total_tokens += x.len();

                if i % 1000 == 0 {
                    println!("{i}/{}", sents.len());
                }
            }

            println!(
                "Average loss for epoch: {}",
                total_loss / sents.len() as f64
            );

            println!(
                "{}% correct in training dataset",
                total_correct as f32 / total_tokens as f32 * 100.
            );

            let score = util.score_model(&model, test_file);
            println!("{}% correct in test dataset", score * 100.);

            if score < last_score {
                println!("Overfitting detected. Stopping...");
                break;
            }

            last_score = score;
        }

        Self {
            vocab,
            model,
            device,
        }
    }

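    /// Token-level accuracy of `model` on `dataset`, thresholding raw
    /// outputs at 0.5.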
    fn score_model(&self, model: &NpModel<B>, dataset: &impl AsRef<Path>) -> f32 {
        let ExtractedSentences(sents, tags, labs, _) = Self::extract_sents_from_files(&[dataset]);

        let mut total_tokens = 0;
        let mut total_correct: usize = 0;

        for ((x, w), y) in sents.iter().zip(tags.iter()).zip(labs.iter()) {
            let (word_tens, tag_tens) = self.to_tensors(x, w);

            let logits = model.forward(word_tens, tag_tens, false);
            total_correct += logits
                .to_data()
                .iter()
                .map(|p: f32| p > 0.5)
                .zip(y)
                .map(|(a, b)| if a == *b { 1 } else { 0 })
                .sum::<usize>();

            total_tokens += x.len();
        }

        total_correct as f32 / total_tokens as f32
    }

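    /// Parses CoNLL-U files, marks tokens that fall inside located noun
    /// phrases, merges contraction suffixes (e.g. "n't", "'ll") back onto the
    /// preceding token, and builds the token vocabulary.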
    fn extract_sents_from_files(files: &[impl AsRef<Path>]) -> ExtractedSentences {
        use super::np_extraction::locate_noun_phrases_in_sent;
        use crate::conllu_utils::iter_sentences_in_conllu;

        let mut vocab: HashMap<String, usize> = HashMap::new();
        vocab.insert("<UNK>".into(), UNK_IDX);

        let mut sents: Vec<Vec<String>> = Vec::new();
        let mut sent_tags: Vec<Vec<Option<UPOS>>> = Vec::new();
        let mut labs: Vec<Vec<bool>> = Vec::new();

        const CONTRACTIONS: &[&str] = &["sn't", "n't", "'ll", "'ve", "'re", "'d", "'m", "'s"];

        for file in files {
            for sent in iter_sentences_in_conllu(file) {
                let spans = locate_noun_phrases_in_sent(&sent);

                let mut original_mask = vec![false; sent.tokens.len()];
                for span in spans {
                    for i in span {
                        original_mask[i] = true;
                    }
                }

                let mut toks: Vec<String> = Vec::new();
                let mut tags: Vec<Option<UPOS>> = Vec::new();
                let mut mask: Vec<bool> = Vec::new();

                for (idx, tok) in sent.tokens.iter().enumerate() {
                    let is_contraction = CONTRACTIONS.contains(&&tok.form[..]);
                    if is_contraction && !toks.is_empty() {
                        let prev_tok = toks.pop().unwrap();
                        let prev_mask = mask.pop().unwrap();
                        toks.push(format!("{prev_tok}{}", tok.form));
                        mask.push(prev_mask || original_mask[idx]);
                    } else {
                        toks.push(tok.form.clone());
                        tags.push(tok.upos.and_then(UPOS::from_conllu));
                        mask.push(original_mask[idx]);
                    }
                }

                for t in &toks {
                    if !vocab.contains_key(t) {
                        let next = vocab.len();
                        vocab.insert(t.clone(), next);
                    }
                }

                sents.push(toks);
                sent_tags.push(tags);
                labs.push(mask);
            }
        }

        ExtractedSentences(sents, sent_tags, labs, vocab)
    }
}

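// CPU chunker backed by NdArray; the backend is wrapped in `Autodiff` when
// the `training` feature is enabled so the same alias supports training.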
#[cfg(feature = "training")]
pub type BurnChunkerCpu = BurnChunker<burn::backend::Autodiff<NdArray>>;

#[cfg(not(feature = "training"))]
pub type BurnChunkerCpu = BurnChunker<NdArray>;

impl BurnChunkerCpu {
    pub fn load_from_bytes_cpu(
        model_bytes: impl AsRef<[u8]>,
        vocab_bytes: impl AsRef<[u8]>,
        embed_dim: usize,
        dropout: f32,
    ) -> Self {
        Self::load_from_bytes(
            model_bytes,
            vocab_bytes,
            embed_dim,
            dropout,
            NdArrayDevice::Cpu,
        )
    }
}

#[cfg(feature = "training")]
impl BurnChunkerCpu {
    pub fn train_cpu(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
    ) -> Self {
        BurnChunker::<Autodiff<NdArray>>::train(
            training_files,
            test_file,
            embed_dim,
            dropout,
            epochs,
            lr,
            NdArrayDevice::Cpu,
        )
    }
}

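// Inference entry point: a single forward pass, thresholding each token's
// score at 0.5 to produce the noun-phrase mask.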
impl<B: Backend> Chunker for BurnChunker<B> {
    fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
        if sentence.is_empty() {
            return Vec::new();
        }

        let (word_tens, tag_tens) = self.to_tensors(sentence, tags);
        let prob = self.model.forward(word_tens, tag_tens, false);
        prob.to_data().iter().map(|p: f32| p > 0.5).collect()
    }
}