1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//! Checks if the input text contains multi-token phrases from a finite list (might contain e.g. city names) and assigns lemmas and part-of-speech tags accordingly.

use crate::types::*;
use aho_corasick::AhoCorasick;
use serde::{Deserialize, Serialize};

use super::tag::Tagger;

/// Serialized form of [`MultiwordTagger`]: only the phrase list is persisted;
/// the Aho-Corasick automaton is rebuilt on deserialization (see the `From` impl).
#[derive(Serialize, Deserialize)]
pub(crate) struct MultiwordTaggerFields {
    // Each entry pairs a (possibly multi-token, space-separated) phrase with the
    // POS id to assign to every token covered by a match of that phrase.
    pub(crate) multiwords: Vec<(String, owned::PosId)>,
}

impl From<MultiwordTaggerFields> for MultiwordTagger {
    /// Rebuilds the runtime tagger from its serialized fields, reconstructing
    /// the Aho-Corasick matcher over all multiword phrases.
    fn from(fields: MultiwordTaggerFields) -> Self {
        // The matcher patterns are exactly the phrases, in order, so that a
        // match's pattern index can be used to look up its entry in `multiwords`.
        let patterns: Vec<&String> = fields.multiwords.iter().map(|(phrase, _)| phrase).collect();
        let matcher = AhoCorasick::new_auto_configured(&patterns);

        MultiwordTagger {
            matcher,
            multiwords: fields.multiwords,
        }
    }
}

/// Tags multi-token phrases (e.g. city names) from a finite list by scanning
/// the token stream with an Aho-Corasick matcher.
///
/// Deserialization goes through [`MultiwordTaggerFields`] (`serde(from = ...)`)
/// so the matcher is rebuilt rather than stored.
#[derive(Deserialize, Serialize)]
#[serde(from = "MultiwordTaggerFields")]
pub struct MultiwordTagger {
    // Not serializable; reconstructed from `multiwords` on load.
    #[serde(skip)]
    matcher: AhoCorasick,
    // Phrase -> POS pairs; indexed by the matcher's pattern id.
    multiwords: Vec<(String, owned::PosId)>,
}

impl MultiwordTagger {
    /// Scans `tokens` for known multiword phrases and, for every token covered
    /// by a match, sets `multiword_data` to the phrase's lemma / POS pair.
    ///
    /// Tokens are joined with single spaces into one string; byte offsets into
    /// that string are mapped back to token indices so that only matches that
    /// align exactly with token boundaries are accepted.
    pub fn apply<'t>(&'t self, tokens: &mut Vec<IncompleteToken<'t>>, tagger: &'t Tagger) {
        // Maps a byte offset in `joined` at which a token starts / ends to
        // that token's index. A match is only valid if both its start and end
        // offsets fall on such boundaries.
        let mut token_at_start = DefaultHashMap::new();
        let mut token_at_end = DefaultHashMap::new();

        let mut joined = String::new();
        for (idx, token) in tokens.iter().enumerate() {
            if idx > 0 {
                joined.push(' ');
            }
            token_at_start.insert(joined.len(), idx);
            joined.push_str(token.word.text.0.as_ref());
            token_at_end.insert(joined.len(), idx);
        }

        for m in self.matcher.find_iter(&joined) {
            // Ignore matches that start or end inside a token.
            if let (Some(&first), Some(&last)) =
                (token_at_start.get(&m.start()), token_at_end.get(&m.end()))
            {
                let (word, pos) = &self.multiwords[m.pattern()];
                // `last` is inclusive: annotate every token the phrase spans.
                for token in tokens[first..=last].iter_mut() {
                    token.multiword_data = Some(WordData::new(
                        tagger.id_word(word.as_str().into()),
                        pos.as_ref_id(),
                    ));
                }
            }
        }
    }
}