harper_pos_utils/chunker/
upos_freq_dict.rs

1#[cfg(feature = "training")]
2use std::path::Path;
3
4use hashbrown::HashMap;
5use serde::{Deserialize, Serialize};
6
7use crate::UPOS;
8
9use super::Chunker;
10
/// Tracks the number of times any given UPOS is associated with a noun phrase.
/// Used as the baseline for the chunker.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UPOSFreqDict {
    /// Net NP-association score per [`UPOS`]: the number of times the tag was
    /// observed *outside* a noun phrase subtracted from the number of times it
    /// was observed *inside* one. A positive value means the tag is more often
    /// part of an NP than not.
    pub counts: HashMap<UPOS, isize>,
}
19
20impl UPOSFreqDict {
21    pub fn is_likely_np_component(&self, upos: &UPOS) -> bool {
22        self.counts.get(upos).cloned().unwrap_or_default() > 0
23    }
24}
25
26impl Chunker for UPOSFreqDict {
27    fn chunk_sentence(&self, _sentence: &[String], tags: &[Option<UPOS>]) -> Vec<bool> {
28        tags.iter()
29            .map(|t| {
30                t.as_ref()
31                    .map(|t| self.is_likely_np_component(t))
32                    .unwrap_or(false)
33            })
34            .collect()
35    }
36}
37
#[cfg(feature = "training")]
impl UPOSFreqDict {
    /// Record one observation of `upos`: increment its counter when the tag
    /// appeared inside a noun phrase, decrement it otherwise.
    pub fn inc_is_np(&mut self, upos: UPOS, is_np: bool) {
        // Start absent entries at 0 via `or_default` so the first observation
        // is recorded with the correct sign. The previous
        // `.and_modify(..).or_insert(1)` wrongly inserted +1 for a tag whose
        // first observation was *not* in an NP.
        *self.counts.entry(upos).or_default() += if is_np { 1 } else { -1 };
    }

    /// Parse a `.conllu` file and use it to train a frequency dictionary.
    /// For error-handling purposes, this function should not be made accessible outside of training.
    pub fn inc_from_conllu_file(&mut self, path: impl AsRef<Path>) {
        use hashbrown::HashSet;

        use super::np_extraction::locate_noun_phrases_in_sent;
        use crate::conllu_utils::iter_sentences_in_conllu;

        for sent in iter_sentences_in_conllu(path) {
            // Flatten all located noun phrases into one set of token indices
            // so membership checks below are O(1).
            let flat = locate_noun_phrases_in_sent(&sent)
                .into_iter()
                .flatten()
                .collect::<HashSet<_>>();

            for (i, token) in sent.tokens.iter().enumerate() {
                // Tokens whose UPOS can't be mapped from CoNLL-U are skipped
                // rather than counted as non-NP.
                if let Some(upos) = token.upos.and_then(UPOS::from_conllu) {
                    self.inc_is_np(upos, flat.contains(&i))
                }
            }
        }
    }
}
71}