1use crate::{UPOS, chunker::Chunker};
2#[cfg(feature = "training")]
3use burn::backend::Autodiff;
4
5#[cfg(feature = "training")]
6use burn::nn::loss::{MseLoss, Reduction};
7use burn::nn::{Dropout, DropoutConfig};
8#[cfg(feature = "training")]
9use burn::optim::{GradientsParams, Optimizer};
10use burn::record::{FullPrecisionSettings, NamedMpkBytesRecorder, NamedMpkFileRecorder, Recorder};
11use burn::tensor::TensorData;
12#[cfg(feature = "training")]
13use burn::tensor::backend::AutodiffBackend;
14
15use burn::{
16 module::Module,
17 nn::{BiLstmConfig, EmbeddingConfig, LinearConfig},
18 tensor::{Int, Tensor, backend::Backend},
19};
20use burn_ndarray::{NdArray, NdArrayDevice};
21use hashbrown::HashMap;
22use std::path::Path;
23
/// Vocabulary index reserved for unknown (out-of-vocabulary) tokens.
// NOTE(review): index 0 appears to be left unassigned — presumably reserved
// for padding; confirm against the vocabulary built during training.
const UNK_IDX: usize = 1;
25
/// BiLSTM noun-phrase chunking network: word and UPOS embeddings are
/// concatenated, run through a bidirectional LSTM, and projected down to a
/// single per-token score.
#[derive(Module, Debug)]
struct NpModel<B: Backend> {
    // Learned embedding for word (token) indices.
    embedding_words: burn::nn::Embedding<B>,
    // Learned embedding for UPOS tag indices.
    embedding_upos: burn::nn::Embedding<B>,
    // Bidirectional LSTM over the concatenated embeddings.
    lstm: burn::nn::BiLstm<B>,
    // Projects each BiLSTM output vector to one scalar score per token.
    linear_out: burn::nn::Linear<B>,
    // Applied only when `forward` is called with `use_dropout = true`.
    dropout: Dropout,
}
34
impl<B: Backend> NpModel<B> {
    /// Build a fresh (untrained) model.
    ///
    /// * `vocab` — number of word-embedding rows (vocabulary size).
    /// * `word_embed_dim` — dimensionality of the word embedding.
    /// * `dropout` — dropout probability used between layers during training.
    fn new(vocab: usize, word_embed_dim: usize, dropout: f32, device: &B::Device) -> Self {
        // UPOS tags get a small fixed-size embedding of their own.
        let upos_embed = 8;
        let total_embed = word_embed_dim + upos_embed;

        Self {
            embedding_words: EmbeddingConfig::new(vocab, word_embed_dim).init(device),
            // 20 slots for tag indices; tags are encoded as `upos + 2` with
            // 1 meaning "no/unknown tag" (see `BurnChunker::to_tensors`).
            embedding_upos: EmbeddingConfig::new(20, upos_embed).init(device),
            lstm: BiLstmConfig::new(total_embed, total_embed, false).init(device),
            // BiLSTM output width is 2 * hidden (forward + backward passes).
            linear_out: LinearConfig::new(total_embed * 2, 1).init(device),
            dropout: DropoutConfig::new(dropout as f64).init(),
        }
    }

    /// Score each token of a single sentence.
    ///
    /// `word_tens` / `tag_tens` are `[1, seq_len]` index tensors; the result
    /// is a `[1, seq_len]` tensor of raw scores. No sigmoid is applied —
    /// callers threshold at 0.5 and training targets are 0/1 under MSE.
    fn forward(
        &self,
        word_tens: Tensor<B, 2, Int>,
        tag_tens: Tensor<B, 2, Int>,
        use_dropout: bool,
    ) -> Tensor<B, 2> {
        let word_embed = self.embedding_words.forward(word_tens);
        let tag_embed = self.embedding_upos.forward(tag_tens);

        // Concatenate word and tag embeddings along the feature axis.
        let mut x = Tensor::cat(vec![word_embed, tag_embed], 2);

        if use_dropout {
            x = self.dropout.forward(x);
        }

        let (mut x, _) = self.lstm.forward(x, None);

        if use_dropout {
            x = self.dropout.forward(x);
        }

        let x = self.linear_out.forward(x);
        // [1, seq_len, 1] -> [1, seq_len]
        x.squeeze::<2>(2)
    }
}
75
/// A noun-phrase chunker backed by a burn `NpModel` together with the token
/// vocabulary it was trained with.
pub struct BurnChunker<B: Backend> {
    // Token form -> embedding index; unknown tokens fall back to `UNK_IDX`.
    vocab: HashMap<String, usize>,
    model: NpModel<B>,
    device: B::Device,
}
81
82impl<B: Backend> BurnChunker<B> {
83 fn idx(&self, tok: &str) -> usize {
84 *self.vocab.get(tok).unwrap_or(&UNK_IDX)
85 }
86
87 fn to_tensors(
88 &self,
89 sent: &[String],
90 tags: &[Option<UPOS>],
91 ) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>) {
92 let idxs: Vec<_> = sent.iter().map(|t| self.idx(t) as i32).collect();
94
95 let upos: Vec<_> = tags
96 .iter()
97 .map(|t| t.map(|o| o as i32 + 2).unwrap_or(1))
98 .collect();
99
100 let word_tensor =
101 Tensor::<B, 1, Int>::from_data(TensorData::from(idxs.as_slice()), &self.device)
102 .reshape([1, sent.len()]);
103
104 let tag_tensor =
105 Tensor::<B, 1, Int>::from_data(TensorData::from(upos.as_slice()), &self.device)
106 .reshape([1, sent.len()]);
107
108 (word_tensor, tag_tensor)
109 }
110
111 pub fn save_to(&self, dir: impl AsRef<Path>) {
112 let dir = dir.as_ref();
113 std::fs::create_dir_all(dir).unwrap();
114
115 let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
116 self.model
117 .clone()
118 .save_file(dir.join("model.mpk"), &recorder)
119 .unwrap();
120
121 let vocab_bytes = serde_json::to_vec(&self.vocab).unwrap();
122 std::fs::write(dir.join("vocab.json"), vocab_bytes).unwrap();
123 }
124
125 pub fn load_from_bytes(
126 model_bytes: impl AsRef<[u8]>,
127 vocab_bytes: impl AsRef<[u8]>,
128 embed_dim: usize,
129 dropout: f32,
130 device: B::Device,
131 ) -> Self {
132 let vocab: HashMap<String, usize> = serde_json::from_slice(vocab_bytes.as_ref()).unwrap();
133
134 let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();
135
136 let owned_data = model_bytes.as_ref().to_vec();
137 let record = recorder.load(owned_data, &device).unwrap();
138
139 let model = NpModel::new(vocab.len(), embed_dim, dropout, &device);
140 let model = model.load_record(record);
141
142 Self {
143 vocab,
144 model,
145 device,
146 }
147 }
148}
149
/// Training data extracted from CoNLL-U files.
#[cfg(feature = "training")]
struct ExtractedSentences(
    // Sentences as lists of token forms (contraction tokens merged back
    // onto their host word).
    Vec<Vec<String>>,
    // One optional UPOS tag per (merged) token.
    Vec<Vec<Option<UPOS>>>,
    // One "token is inside a noun phrase" label per (merged) token.
    Vec<Vec<bool>>,
    // Token form -> embedding index.
    HashMap<String, usize>,
);
157
#[cfg(feature = "training")]
impl<B: Backend + AutodiffBackend> BurnChunker<B> {
    /// Turn per-token boolean chunk labels into a `[1, len]` float tensor
    /// (1.0 = inside a noun phrase) used as the MSE target.
    fn to_label(&self, labels: &[bool]) -> Tensor<B, 2> {
        let ys: Vec<_> = labels.iter().map(|b| if *b { 1. } else { 0. }).collect();

        Tensor::<B, 1, _>::from_data(TensorData::from(ys.as_slice()), &self.device)
            .reshape([1, labels.len()])
    }

    /// Number of per-token predictions (score > 0.5) that agree with the
    /// gold labels. Shared by the training and evaluation loops.
    fn count_correct(outputs: &Tensor<B, 2>, labels: &[bool]) -> usize {
        outputs
            .to_data()
            .iter()
            .map(|p: f32| p > 0.5)
            .zip(labels)
            .map(|(pred, gold)| usize::from(pred == *gold))
            .sum()
    }

    /// Train a chunker on `training_files`, evaluating after every epoch on
    /// `test_file` and stopping early as soon as the test score drops
    /// (a simple overfitting guard).
    ///
    /// * `word_embed_dim` — word-embedding dimensionality.
    /// * `dropout` — dropout probability during training.
    /// * `epochs` — maximum number of passes over the training data.
    /// * `lr` — Adam learning rate.
    pub fn train(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        word_embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
        device: B::Device,
    ) -> Self {
        use burn::tensor::cast::ToElement;

        println!("Preparing datasets...");
        let ExtractedSentences(sents, tags, labs, vocab) =
            Self::extract_sents_from_files(training_files);

        println!("Preparing model and training config...");

        let mut model = NpModel::<B>::new(vocab.len(), word_embed_dim, dropout, &device);
        let opt_config = burn::optim::AdamConfig::new();
        let mut opt = opt_config.init();

        // Used only for its vocab/device (tensor encoding); its `model`
        // clone is never consulted for predictions.
        let util = BurnChunker {
            vocab: vocab.clone(),
            model: model.clone(),
            device: device.clone(),
        };

        let loss_fn = MseLoss::new();
        let mut last_score = 0.;

        println!("Training...");

        for _ in 0..epochs {
            let mut total_loss = 0.;
            let mut total_tokens = 0;
            let mut total_correct: usize = 0;

            // One optimizer step per sentence (batch size 1).
            for (i, ((x, w), y)) in sents.iter().zip(tags.iter()).zip(labs.iter()).enumerate() {
                let (word_tens, tag_tens) = util.to_tensors(x, w);
                let y_tensor = util.to_label(y);

                let logits = model.forward(word_tens, tag_tens, true);
                total_correct += Self::count_correct(&logits, y);

                let loss = loss_fn.forward(logits, y_tensor, Reduction::Mean);

                let grads = loss.backward();
                let grads = GradientsParams::from_grads(grads, &model);

                model = opt.step(lr, model, grads);

                total_loss += loss.into_scalar().to_f64();
                total_tokens += x.len();

                if i % 1000 == 0 {
                    println!("{i}/{}", sents.len());
                }
            }

            // Report the true mean loss (this figure was previously scaled
            // by 100, which made the printed "average loss" misleading).
            println!(
                "Average loss for epoch: {}",
                total_loss / sents.len() as f64
            );

            println!(
                "{}% correct in training dataset",
                total_correct as f32 / total_tokens as f32 * 100.
            );

            let score = util.score_model(&model, test_file);
            println!("{}% correct in test dataset", score * 100.);

            if score < last_score {
                println!("Overfitting detected. Stopping...");
                break;
            }

            last_score = score;
        }

        Self {
            vocab,
            model,
            device,
        }
    }

    /// Token-level accuracy of `model` on a CoNLL-U dataset, in `[0, 1]`.
    fn score_model(&self, model: &NpModel<B>, dataset: &impl AsRef<Path>) -> f32 {
        // The dataset's own vocabulary is discarded; tokens unseen during
        // training fall back to `UNK_IDX` inside `self.to_tensors`.
        let ExtractedSentences(sents, tags, labs, _) = Self::extract_sents_from_files(&[dataset]);

        let mut total_tokens = 0;
        let mut total_correct: usize = 0;

        for ((x, w), y) in sents.iter().zip(tags.iter()).zip(labs.iter()) {
            let (word_tens, tag_tens) = self.to_tensors(x, w);

            let logits = model.forward(word_tens, tag_tens, false);
            total_correct += Self::count_correct(&logits, y);

            total_tokens += x.len();
        }

        total_correct as f32 / total_tokens as f32
    }

    /// Read CoNLL-U files, derive per-token noun-phrase labels, merge
    /// contraction tokens back onto their host word, and build the token
    /// vocabulary.
    fn extract_sents_from_files(files: &[impl AsRef<Path>]) -> ExtractedSentences {
        use super::np_extraction::locate_noun_phrases_in_sent;
        use crate::conllu_utils::iter_sentences_in_conllu;

        let mut vocab: HashMap<String, usize> = HashMap::new();
        // Index 0 is reserved for padding and index 1 for unknown tokens.
        // The padding entry must actually occupy the map: new tokens are
        // assigned `vocab.len()` below, so without it the first real token
        // would also receive index 1 and share the <UNK> embedding row.
        vocab.insert("<PAD>".into(), 0);
        vocab.insert("<UNK>".into(), UNK_IDX);

        let mut sents: Vec<Vec<String>> = Vec::new();
        let mut sent_tags: Vec<Vec<Option<UPOS>>> = Vec::new();
        let mut labs: Vec<Vec<bool>> = Vec::new();

        // Suffixes that CoNLL-U tokenizes separately but which we glue back
        // onto the preceding token. NOTE(review): "sn't" presumably catches
        // splits like "doe|sn't" — confirm against the corpora in use.
        const CONTRACTIONS: &[&str] = &["sn't", "n't", "'ll", "'ve", "'re", "'d", "'m", "'s"];

        for file in files {
            for sent in iter_sentences_in_conllu(file) {
                let spans = locate_noun_phrases_in_sent(&sent);

                // Mark every token index covered by some noun-phrase span.
                let mut original_mask = vec![false; sent.tokens.len()];
                for span in spans {
                    for i in span {
                        original_mask[i] = true;
                    }
                }

                let mut toks: Vec<String> = Vec::new();
                let mut tags: Vec<Option<UPOS>> = Vec::new();
                let mut mask: Vec<bool> = Vec::new();

                for (idx, tok) in sent.tokens.iter().enumerate() {
                    let is_contraction = CONTRACTIONS.contains(&&tok.form[..]);
                    if is_contraction && !toks.is_empty() {
                        // Merge the contraction into the previous token; the
                        // merged token keeps the host's tag and is labelled
                        // "in a noun phrase" if either piece was.
                        let prev_tok = toks.pop().unwrap();
                        let prev_mask = mask.pop().unwrap();
                        toks.push(format!("{prev_tok}{}", tok.form));
                        mask.push(prev_mask || original_mask[idx]);
                    } else {
                        toks.push(tok.form.clone());
                        tags.push(tok.upos.and_then(UPOS::from_conllu));
                        mask.push(original_mask[idx]);
                    }
                }

                for t in &toks {
                    if !vocab.contains_key(t) {
                        let next = vocab.len();
                        vocab.insert(t.clone(), next);
                    }
                }

                sents.push(toks);
                sent_tags.push(tags);
                labs.push(mask);
            }
        }

        ExtractedSentences(sents, sent_tags, labs, vocab)
    }
}
342
// CPU-backed chunker aliases: with the `training` feature the NdArray
// backend is wrapped in `Autodiff` so gradients can flow; otherwise the
// plain inference-only backend is used.
#[cfg(feature = "training")]
pub type BurnChunkerCpu = BurnChunker<burn::backend::Autodiff<NdArray>>;

#[cfg(not(feature = "training"))]
pub type BurnChunkerCpu = BurnChunker<NdArray>;
348
impl BurnChunkerCpu {
    /// Convenience wrapper around [`BurnChunker::load_from_bytes`] that
    /// fixes the device to the CPU.
    ///
    /// `embed_dim` and `dropout` must match the values used at training time.
    pub fn load_from_bytes_cpu(
        model_bytes: impl AsRef<[u8]>,
        vocab_bytes: impl AsRef<[u8]>,
        embed_dim: usize,
        dropout: f32,
    ) -> Self {
        Self::load_from_bytes(
            model_bytes,
            vocab_bytes,
            embed_dim,
            dropout,
            NdArrayDevice::Cpu,
        )
    }
}
365
#[cfg(feature = "training")]
impl BurnChunkerCpu {
    /// Convenience wrapper around [`BurnChunker::train`] that fixes the
    /// device to the CPU (autodiff-enabled NdArray backend).
    pub fn train_cpu(
        training_files: &[impl AsRef<Path>],
        test_file: &impl AsRef<Path>,
        embed_dim: usize,
        dropout: f32,
        epochs: usize,
        lr: f64,
    ) -> Self {
        BurnChunker::<Autodiff<NdArray>>::train(
            training_files,
            test_file,
            embed_dim,
            dropout,
            epochs,
            lr,
            NdArrayDevice::Cpu,
        )
    }
}
387
388impl<B: Backend> Chunker for BurnChunker<B> {
389 fn chunk_sentence(&self, sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
390 if sentence.is_empty() {
392 return Vec::new();
393 }
394
395 let (word_tens, tag_tens) = self.to_tensors(sentence, tags);
396 let prob = self.model.forward(word_tens, tag_tens, false);
397 prob.to_data().iter().map(|p: f32| p > 0.5).collect()
398 }
399}