lindera_filter/character_filter/
mapping.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use yada::builder::DoubleArrayBuilder;
6use yada::DoubleArray;
7
8use lindera_core::error::LinderaErrorKind;
9use lindera_core::LinderaResult;
10
11use crate::character_filter::{add_offset_diff, CharacterFilter};
12
/// Identifier of the mapping character filter; this is the value returned by
/// `MappingCharacterFilter::name`.
pub const MAPPING_CHARACTER_FILTER_NAME: &str = "mapping";
14
/// Configuration of the mapping character filter.
///
/// Serialized as a JSON object with a single `mapping` member, e.g.
/// `{"mapping": {"<source text>": "<replacement text>"}}`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct MappingCharacterFilterConfig {
    /// Source text -> replacement text. The keys are the patterns searched for
    /// in the input; the values are what gets written to the output instead.
    pub mapping: HashMap<String, String>,
}
19
20impl MappingCharacterFilterConfig {
21    pub fn new(map: HashMap<String, String>) -> Self {
22        Self { mapping: map }
23    }
24
25    pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
26        serde_json::from_slice::<MappingCharacterFilterConfig>(data)
27            .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
28    }
29
30    pub fn from_value(value: &Value) -> LinderaResult<Self> {
31        serde_json::from_value::<MappingCharacterFilterConfig>(value.clone())
32            .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
33    }
34}
35
/// Replaces text according to the configured mappings and records the offset
/// corrections caused by replacements whose byte length differs from the
/// matched text.
///
/// Matching is greedy (the longest pattern matching at a given point wins).
/// The replacement is allowed to be the empty string.
#[derive(Clone)]
pub struct MappingCharacterFilter {
    /// Mapping table; looked up with the matched text to get its replacement.
    config: MappingCharacterFilterConfig,
    /// Double-array trie built from the mapping keys, used for prefix search.
    trie: DoubleArray<Vec<u8>>,
}
46
47impl MappingCharacterFilter {
48    pub fn new(config: MappingCharacterFilterConfig) -> LinderaResult<Self> {
49        let mut keyset: Vec<(&[u8], u32)> = Vec::new();
50        let mut keys = config.mapping.keys().collect::<Vec<_>>();
51        keys.sort();
52        for (value, key) in keys.into_iter().enumerate() {
53            keyset.push((key.as_bytes(), value as u32));
54        }
55
56        let data = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
57            LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
58        })?;
59
60        let trie = DoubleArray::new(data);
61
62        Ok(Self { config, trie })
63    }
64
65    pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
66        Self::new(MappingCharacterFilterConfig::from_slice(data)?)
67    }
68}
69
70impl CharacterFilter for MappingCharacterFilter {
71    fn name(&self) -> &'static str {
72        MAPPING_CHARACTER_FILTER_NAME
73    }
74
75    fn apply(&self, text: &str) -> LinderaResult<(String, Vec<usize>, Vec<i64>)> {
76        let mut offsets: Vec<usize> = Vec::new();
77        let mut diffs: Vec<i64> = Vec::new();
78
79        let mut result = String::new();
80        let mut input_start = 0_usize;
81        let len = text.len();
82
83        while input_start < len {
84            let suffix = &text[input_start..];
85            match self
86                .trie
87                .common_prefix_search(suffix.as_bytes())
88                .last()
89                .map(|(_offset_len, prefix_len)| prefix_len)
90            {
91                Some(input_len) => {
92                    let input_text = &text[input_start..input_start + input_len];
93                    let replacement_text = &self.config.mapping[input_text];
94                    let replacement_len = replacement_text.len();
95                    let diff_len = input_len as i64 - replacement_len as i64;
96                    let input_offset = input_start + input_len;
97
98                    if diff_len != 0 {
99                        let prev_diff = *diffs.last().unwrap_or(&0);
100
101                        if diff_len > 0 {
102                            // Replacement is shorter than matched surface.
103                            let offset = (input_offset as i64 - diff_len - prev_diff) as usize;
104                            let diff = prev_diff + diff_len;
105                            add_offset_diff(&mut offsets, &mut diffs, offset, diff);
106                        } else {
107                            // Replacement is longer than matched surface.
108                            let output_offset = (input_offset as i64 + -prev_diff) as usize;
109                            for extra_idx in 0..diff_len.unsigned_abs() as usize {
110                                let offset = output_offset + extra_idx;
111                                let diff = prev_diff - extra_idx as i64 - 1;
112                                add_offset_diff(&mut offsets, &mut diffs, offset, diff);
113                            }
114                        }
115                    }
116
117                    result.push_str(replacement_text);
118
119                    // move start offset
120                    input_start += input_len;
121                }
122                None => {
123                    match suffix.chars().next() {
124                        Some(c) => {
125                            result.push(c);
126
127                            // move start offset
128                            input_start += c.len_utf8();
129                        }
130                        None => break,
131                    }
132                }
133            }
134        }
135
136        Ok((result, offsets, diffs))
137    }
138}
139
#[cfg(test)]
mod tests {
    //! Tests for the mapping character filter.
    //!
    //! Half-width katakana and full-width digits are written as escapes
    //! (`\uXXXX` inside the JSON fixtures, `\u{XXXX}` in Rust literals) so the
    //! fixtures keep their exact byte lengths even if the file passes through
    //! tooling that normalizes character width. The expected offsets depend on
    //! those byte lengths.

    use crate::character_filter::mapping::{MappingCharacterFilter, MappingCharacterFilterConfig};
    use crate::character_filter::{correct_offset, CharacterFilter};

    #[test]
    fn test_mapping_character_filter_config_from_slice() {
        // Half-width katakana a, i, u, e, o (U+FF71..U+FF75) -> full-width.
        let config_str = r#"
        {
            "mapping": {
                "\uFF71": "ア",
                "\uFF72": "イ",
                "\uFF73": "ウ",
                "\uFF74": "エ",
                "\uFF75": "オ"
            }
        }
        "#;
        let config = MappingCharacterFilterConfig::from_slice(config_str.as_bytes()).unwrap();
        assert_eq!("ア", config.mapping.get("\u{FF71}").unwrap());
    }

    #[test]
    fn test_mapping_character_filter_from_slice() {
        let config_str = r#"
        {
            "mapping": {
                "\uFF71": "ア",
                "\uFF72": "イ",
                "\uFF73": "ウ",
                "\uFF74": "エ",
                "\uFF75": "オ"
            }
        }
        "#;
        let result = MappingCharacterFilter::from_slice(config_str.as_bytes());
        assert!(result.is_ok());
    }

    #[test]
    fn test_mapping_character_filter_apply() {
        // Same byte length on both sides (3 bytes -> 3 bytes): no corrections.
        {
            let config_str = r#"
            {
                "mapping": {
                    "\uFF71": "ア",
                    "\uFF72": "イ",
                    "\uFF73": "ウ",
                    "\uFF74": "エ",
                    "\uFF75": "オ"
                }
            }
            "#;
            let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
            // Half-width "a i u e o".
            let text = "\u{FF71}\u{FF72}\u{FF73}\u{FF74}\u{FF75}";
            let (filtered_text, offsets, diffs) = filter.apply(text).unwrap();
            assert_eq!("アイウエオ", filtered_text);
            assert_eq!(Vec::<usize>::new(), offsets);
            assert_eq!(Vec::<i64>::new(), diffs);
            let start = 3;
            let end = 6;
            assert_eq!("イ", &filtered_text[start..end]);
            let correct_start = correct_offset(start, &offsets, &diffs, filtered_text.len());
            let correct_end = correct_offset(end, &offsets, &diffs, filtered_text.len());
            assert_eq!(3, correct_start);
            assert_eq!(6, correct_end);
            assert_eq!("\u{FF72}", &text[correct_start..correct_end]);
        }

        // Half-width "te" + voiced sound mark (U+FF83 U+FF9E, 6 bytes) becomes
        // a single full-width "de" (3 bytes): one correction of 3 bytes.
        {
            let config_str = r#"
            {
                "mapping": {
                    "\uFF98": "リ",
                    "\uFF9D": "ン",
                    "\uFF83\uFF9E": "デ",
                    "\uFF97": "ラ"
                }
            }
            "#;
            let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
            // Half-width "ri n de ra" (15 bytes).
            let text = "\u{FF98}\u{FF9D}\u{FF83}\u{FF9E}\u{FF97}";
            let (filtered_text, offsets, diffs) = filter.apply(text).unwrap();
            assert_eq!("リンデラ", filtered_text);
            assert_eq!(vec![9], offsets);
            assert_eq!(vec![3], diffs);
            let start = 6;
            let end = 9;
            assert_eq!("デ", &filtered_text[start..end]);
            let correct_start = correct_offset(start, &offsets, &diffs, filtered_text.len());
            let correct_end = correct_offset(end, &offsets, &diffs, filtered_text.len());
            assert_eq!(6, correct_start);
            assert_eq!(12, correct_end);
            assert_eq!("\u{FF83}\u{FF9E}", &text[correct_start..correct_end]);
        }

        // The whole half-width word (15 bytes) is a single key mapped to the
        // full-width word (12 bytes).
        {
            let config_str = r#"
            {
                "mapping": {
                    "\uFF98\uFF9D\uFF83\uFF9E\uFF97": "リンデラ"
                }
            }
            "#;
            let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
            let text = "\u{FF98}\u{FF9D}\u{FF83}\u{FF9E}\u{FF97}";
            let (filtered_text, offsets, diffs) = filter.apply(text).unwrap();
            assert_eq!("リンデラ", filtered_text);
            assert_eq!(vec![12], offsets);
            assert_eq!(vec![3], diffs);
            let start = 0;
            let end = 12;
            assert_eq!("リンデラ", &filtered_text[start..end]);
            let correct_start = correct_offset(start, &offsets, &diffs, filtered_text.len());
            let correct_end = correct_offset(end, &offsets, &diffs, filtered_text.len());
            assert_eq!(0, correct_start);
            assert_eq!(15, correct_end);
            assert_eq!(text, &text[correct_start..correct_end]);
        }

        // Replacement (7 bytes) shorter than the matched text (12 bytes) in
        // the middle of a longer sentence.
        {
            let config_str = r#"
            {
                "mapping": {
                    "リンデラ": "Lindera"
                }
            }
            "#;
            let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
            let text = "Rust製形態素解析器リンデラで日本語を形態素解析する。";
            let (filtered_text, offsets, diffs) = filter.apply(text).unwrap();
            assert_eq!(
                "Rust製形態素解析器Linderaで日本語を形態素解析する。",
                filtered_text
            );
            assert_eq!(vec![32], offsets);
            assert_eq!(vec![5], diffs);
            let start = 25;
            let end = 32;
            assert_eq!("Lindera", &filtered_text[start..end]);
            let correct_start = correct_offset(start, &offsets, &diffs, filtered_text.len());
            let correct_end = correct_offset(end, &offsets, &diffs, filtered_text.len());
            assert_eq!(25, correct_start);
            assert_eq!(37, correct_end);
            assert_eq!("リンデラ", &text[correct_start..correct_end]);
            let start = 35;
            let end = 44;
            assert_eq!("日本語", &filtered_text[start..end]);
            let correct_start = correct_offset(start, &offsets, &diffs, filtered_text.len());
            let correct_end = correct_offset(end, &offsets, &diffs, filtered_text.len());
            assert_eq!(40, correct_start);
            assert_eq!(49, correct_end);
            assert_eq!("日本語", &text[correct_start..correct_end]);
        }

        // Full-width digits (3 bytes -> 1 byte) followed by a replacement that
        // is longer than the matched text (3 bytes -> 12 bytes).
        {
            let config_str = r#"
            {
                "mapping": {
                    "\uFF11": "1",
                    "\uFF10": "0",
                    "㍑": "リットル"
                }
            }
            "#;
            let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
            // Full-width "1", full-width "0", then the litre sign (9 bytes).
            let text = "\u{FF11}\u{FF10}㍑";
            let (filtered_text, offsets, diffs) = filter.apply(text).unwrap();
            assert_eq!("10リットル", filtered_text);
            assert_eq!(vec![1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13], offsets);
            assert_eq!(vec![2, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5], diffs);
            let start = 0;
            let end = 2;
            assert_eq!("10", &filtered_text[start..end]);
            let correct_start = correct_offset(start, &offsets, &diffs, filtered_text.len());
            let correct_end = correct_offset(end, &offsets, &diffs, filtered_text.len());
            assert_eq!(0, correct_start);
            assert_eq!(6, correct_end);
            assert_eq!("\u{FF11}\u{FF10}", &text[correct_start..correct_end]);
            let start = 2;
            let end = 14;
            assert_eq!("リットル", &filtered_text[start..end]);
            let correct_start = correct_offset(start, &offsets, &diffs, filtered_text.len());
            let correct_end = correct_offset(end, &offsets, &diffs, filtered_text.len());
            assert_eq!(6, correct_start);
            assert_eq!(9, correct_end);
            assert_eq!("㍑", &text[correct_start..correct_end]);
        }
    }
}