lindera_filter/token_filter/mapping.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use yada::builder::DoubleArrayBuilder;
use yada::DoubleArray;

use lindera_core::error::LinderaErrorKind;
use lindera_core::LinderaResult;

use crate::token::Token;
use crate::token_filter::TokenFilter;

pub const MAPPING_TOKEN_FILTER_NAME: &str = "mapping";

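/// Configuration for [`MappingTokenFilter`].
///
/// Deserialized from JSON of the form `{"mapping": {"<from>": "<to>", ...}}`,
/// as exercised by the tests at the bottom of this file.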
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct MappingTokenFilterConfig {
    pub mapping: HashMap<String, String>,
}

impl MappingTokenFilterConfig {
    pub fn new(map: HashMap<String, String>) -> Self {
        Self { mapping: map }
    }

    pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
        serde_json::from_slice::<MappingTokenFilterConfig>(data)
            .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
    }

    pub fn from_value(value: &serde_json::Value) -> LinderaResult<Self> {
        serde_json::from_value::<MappingTokenFilterConfig>(value.clone())
            .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
    }
}

/// Replace characters with the specified character mappings.
///
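/// # Example
///
/// A minimal sketch of building the filter from a JSON mapping; it assumes
/// this crate is consumed under the name `lindera_filter` (marked `ignore` so
/// the doctest is not run against a different crate layout).
///
/// ```ignore
/// use lindera_filter::token_filter::mapping::MappingTokenFilter;
///
/// let json = r#"{ "mapping": { "籠": "篭" } }"#;
/// let filter = MappingTokenFilter::from_slice(json.as_bytes()).unwrap();
/// ```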
#[derive(Clone)]
pub struct MappingTokenFilter {
    config: MappingTokenFilterConfig,
    trie: DoubleArray<Vec<u8>>,
}

impl MappingTokenFilter {
    pub fn new(config: MappingTokenFilterConfig) -> LinderaResult<Self> {
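        // Build the keyset for the double-array trie: keys are sorted
        // lexicographically (the builder expects a sorted keyset) and each
        // key's value is simply its index in that order.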
        let mut keyset: Vec<(&[u8], u32)> = Vec::new();
        let mut keys = config.mapping.keys().collect::<Vec<_>>();
        keys.sort();
        for (value, key) in keys.into_iter().enumerate() {
            keyset.push((key.as_bytes(), value as u32));
        }

        let data = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
            LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
        })?;

        let trie = DoubleArray::new(data);

        Ok(Self { config, trie })
    }

    pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
        Self::new(MappingTokenFilterConfig::from_slice(data)?)
    }
}

impl TokenFilter for MappingTokenFilter {
    fn name(&self) -> &'static str {
        MAPPING_TOKEN_FILTER_NAME
    }

    fn apply<'a>(&self, tokens: &mut Vec<Token>) -> LinderaResult<()> {
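        // For each token, scan its text left to right. At every position, run a
        // common-prefix search on the trie over the remaining text and take the
        // last (longest) match. If a mapping key matches, append its replacement
        // and jump past the matched bytes; otherwise copy the next character
        // through unchanged.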
        for token in tokens.iter_mut() {
            let mut result = String::new();
            let mut start = 0_usize;
            let len = token.text.len();

            while start < len {
                let suffix = &token.text[start..];
                match self
                    .trie
                    .common_prefix_search(suffix.as_bytes())
                    .last()
                    .map(|(_offset_len, prefix_len)| prefix_len)
                {
                    Some(prefix_len) => {
                        let surface = &token.text[start..start + prefix_len];
                        let replacement = &self.config.mapping[surface];

                        result.push_str(replacement);

                        // move start offset
                        start += prefix_len;
                    }
                    None => {
                        match suffix.chars().next() {
                            Some(c) => {
                                result.push(c);

                                // move start offset
                                start += c.len_utf8();
                            }
                            None => break,
                        }
                    }
                }
            }

            token.text = result;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    #[cfg(all(feature = "ipadic", feature = "filter",))]
    use lindera_core::word_entry::WordId;

    use crate::token_filter::mapping::{MappingTokenFilter, MappingTokenFilterConfig};
    #[cfg(all(feature = "ipadic", feature = "filter",))]
    use crate::{token::Token, token_filter::TokenFilter};

    #[test]
    fn test_mapping_token_filter_config_from_slice() {
        let config_str = r#"
        {
            "mapping": {
                "ア": "ア",
                "イ": "イ",
                "ウ": "ウ",
                "エ": "エ",
                "オ": "オ"
            }
        }
        "#;
        let config = MappingTokenFilterConfig::from_slice(config_str.as_bytes()).unwrap();
        assert_eq!("ア", config.mapping.get("ア").unwrap());
    }

    #[test]
    fn test_mapping_token_filter_from_slice() {
        let config_str = r#"
        {
            "mapping": {
                "ア": "ア",
                "イ": "イ",
                "ウ": "ウ",
                "エ": "エ",
                "オ": "オ"
            }
        }
        "#;
        let result = MappingTokenFilter::from_slice(config_str.as_bytes());
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(all(feature = "ipadic", feature = "filter",))]
    fn test_mapping_token_filter_apply_ipadic() {
        let config_str = r#"
        {
            "mapping": {
                "籠": "篭"
            }
        }
        "#;
        let filter = MappingTokenFilter::from_slice(config_str.as_bytes()).unwrap();

        let mut tokens: Vec<Token> = vec![
            Token {
                text: "籠原".to_string(),
                byte_start: 0,
                byte_end: 6,
                position: 0,
                position_length: 1,
                word_id: WordId(312630, true),
                details: vec![
                    "名詞".to_string(),
                    "固有名詞".to_string(),
                    "一般".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "籠原".to_string(),
                    "カゴハラ".to_string(),
                    "カゴハラ".to_string(),
                ],
            },
            Token {
                text: "駅".to_string(),
                byte_start: 6,
                byte_end: 9,
                position: 1,
                position_length: 1,
                word_id: WordId(383791, true),
                details: vec![
                    "名詞".to_string(),
                    "接尾".to_string(),
                    "地域".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "駅".to_string(),
                    "エキ".to_string(),
                    "エキ".to_string(),
                ],
            },
        ];

        filter.apply(&mut tokens).unwrap();

        assert_eq!(tokens.len(), 2);
        assert_eq!(&tokens[0].text, "篭原");
        assert_eq!(&tokens[1].text, "駅");
    }
}