1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use yada::builder::DoubleArrayBuilder;
6use yada::DoubleArray;
7
8use lindera_core::error::LinderaErrorKind;
9use lindera_core::LinderaResult;
10
11use crate::character_filter::{add_offset_diff, CharacterFilter};
12
/// Registry name under which this character filter is exposed.
pub const MAPPING_CHARACTER_FILTER_NAME: &str = "mapping";
14
/// Configuration for the "mapping" character filter: a table of source
/// substrings to their replacement strings.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct MappingCharacterFilterConfig {
    // Each matched key substring in the input is replaced by its value.
    pub mapping: HashMap<String, String>,
}
19
20impl MappingCharacterFilterConfig {
21 pub fn new(map: HashMap<String, String>) -> Self {
22 Self { mapping: map }
23 }
24
25 pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
26 serde_json::from_slice::<MappingCharacterFilterConfig>(data)
27 .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
28 }
29
30 pub fn from_value(value: &Value) -> LinderaResult<Self> {
31 serde_json::from_value::<MappingCharacterFilterConfig>(value.clone())
32 .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
33 }
34}
35
/// Character filter that replaces occurrences of configured substrings,
/// preferring the longest configured key that matches at each position.
#[derive(Clone)]
pub struct MappingCharacterFilter {
    // Replacement table; also the source of truth when looking up the
    // replacement text for a matched key.
    config: MappingCharacterFilterConfig,
    // Double-array trie built over the mapping keys, used for fast
    // common-prefix search during `apply`.
    trie: DoubleArray<Vec<u8>>,
}
46
47impl MappingCharacterFilter {
48 pub fn new(config: MappingCharacterFilterConfig) -> LinderaResult<Self> {
49 let mut keyset: Vec<(&[u8], u32)> = Vec::new();
50 let mut keys = config.mapping.keys().collect::<Vec<_>>();
51 keys.sort();
52 for (value, key) in keys.into_iter().enumerate() {
53 keyset.push((key.as_bytes(), value as u32));
54 }
55
56 let data = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
57 LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
58 })?;
59
60 let trie = DoubleArray::new(data);
61
62 Ok(Self { config, trie })
63 }
64
65 pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
66 Self::new(MappingCharacterFilterConfig::from_slice(data)?)
67 }
68}
69
impl CharacterFilter for MappingCharacterFilter {
    /// Returns the registered name of this filter ("mapping").
    fn name(&self) -> &'static str {
        MAPPING_CHARACTER_FILTER_NAME
    }

    /// Rewrites `text` by replacing every occurrence of a mapping key with its
    /// configured replacement, scanning left to right and taking the longest
    /// key that matches at each byte position.
    ///
    /// Returns `(filtered_text, offsets, diffs)`, where `offsets`/`diffs` are
    /// parallel vectors recording, at byte offsets in the *output* text, the
    /// cumulative byte-length difference versus the input; they are consumed
    /// by `correct_offset` (see tests) to map output offsets back to input
    /// offsets.
    fn apply(&self, text: &str) -> LinderaResult<(String, Vec<usize>, Vec<i64>)> {
        let mut offsets: Vec<usize> = Vec::new();
        let mut diffs: Vec<i64> = Vec::new();

        let mut result = String::new();
        let mut input_start = 0_usize;
        let len = text.len();

        while input_start < len {
            let suffix = &text[input_start..];
            match self
                .trie
                // Matches are yielded shortest-to-longest, so `last()` keeps
                // only the longest key matching at this position.
                .common_prefix_search(suffix.as_bytes())
                .last()
                // NOTE(review): the first tuple element is the value stored in
                // the trie (the key's sorted index from `new`), not a length —
                // ignored here; the second is the matched prefix length in
                // bytes.
                .map(|(_offset_len, prefix_len)| prefix_len)
            {
                Some(input_len) => {
                    // Slice out the matched key and look up its replacement.
                    let input_text = &text[input_start..input_start + input_len];
                    let replacement_text = &self.config.mapping[input_text];
                    let replacement_len = replacement_text.len();
                    // Positive when the replacement shrinks the text,
                    // negative when it grows it (byte lengths, not chars).
                    let diff_len = input_len as i64 - replacement_len as i64;
                    let input_offset = input_start + input_len;

                    if diff_len != 0 {
                        // Diffs are cumulative; continue from the last one.
                        let prev_diff = *diffs.last().unwrap_or(&0);

                        if diff_len > 0 {
                            // Replacement is shorter: one correction entry at
                            // the output offset just past the replacement
                            // (input offset shifted by the accumulated diff).
                            let offset = (input_offset as i64 - diff_len - prev_diff) as usize;
                            let diff = prev_diff + diff_len;
                            add_offset_diff(&mut offsets, &mut diffs, offset, diff);
                        } else {
                            // Replacement is longer: one correction entry per
                            // extra output byte, each stepping the cumulative
                            // diff down by one.
                            let output_offset = (input_offset as i64 + -prev_diff) as usize;
                            for extra_idx in 0..diff_len.unsigned_abs() as usize {
                                let offset = output_offset + extra_idx;
                                let diff = prev_diff - extra_idx as i64 - 1;
                                add_offset_diff(&mut offsets, &mut diffs, offset, diff);
                            }
                        }
                    }

                    result.push_str(replacement_text);

                    input_start += input_len;
                }
                None => {
                    // No key matches at this position: copy the next char
                    // through unchanged and advance by its UTF-8 width.
                    match suffix.chars().next() {
                        Some(c) => {
                            result.push(c);

                            input_start += c.len_utf8();
                        }
                        None => break,
                    }
                }
            }
        }

        Ok((result, offsets, diffs))
    }
}
139
140#[cfg(test)]
141mod tests {
142 use crate::character_filter::mapping::{MappingCharacterFilter, MappingCharacterFilterConfig};
143 use crate::character_filter::{correct_offset, CharacterFilter};
144
145 #[test]
146 fn test_mapping_character_filter_config_from_slice() {
147 let config_str = r#"
148 {
149 "mapping": {
150 "ア": "ア",
151 "イ": "イ",
152 "ウ": "ウ",
153 "エ": "エ",
154 "オ": "オ"
155 }
156 }
157 "#;
158 let config = MappingCharacterFilterConfig::from_slice(config_str.as_bytes()).unwrap();
159 assert_eq!("ア", config.mapping.get("ア").unwrap());
160 }
161
162 #[test]
163 fn test_mapping_character_filter_from_slice() {
164 let config_str = r#"
165 {
166 "mapping": {
167 "ア": "ア",
168 "イ": "イ",
169 "ウ": "ウ",
170 "エ": "エ",
171 "オ": "オ"
172 }
173 }
174 "#;
175 let result = MappingCharacterFilter::from_slice(config_str.as_bytes());
176 assert_eq!(true, result.is_ok());
177 }
178
179 #[test]
180 fn test_mapping_character_filter_apply() {
181 {
182 let config_str = r#"
183 {
184 "mapping": {
185 "ア": "ア",
186 "イ": "イ",
187 "ウ": "ウ",
188 "エ": "エ",
189 "オ": "オ"
190 }
191 }
192 "#;
193 let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
194 let text = "アイウエオ";
195 let (filterd_text, offsets, diffs) = filter.apply(text).unwrap();
196 assert_eq!("アイウエオ", filterd_text);
197 assert_eq!(Vec::<usize>::new(), offsets);
198 assert_eq!(Vec::<i64>::new(), diffs);
199 let start = 3;
200 let end = 6;
201 assert_eq!("イ", &filterd_text[start..end]);
202 let correct_start = correct_offset(start, &offsets, &diffs, filterd_text.len());
203 let correct_end = correct_offset(end, &offsets, &diffs, filterd_text.len());
204 assert_eq!(3, correct_start);
205 assert_eq!(6, correct_end);
206 assert_eq!("イ", &text[correct_start..correct_end]);
207 }
208
209 {
210 let config_str = r#"
211 {
212 "mapping": {
213 "リ": "リ",
214 "ン": "ン",
215 "デ": "デ",
216 "ラ": "ラ"
217 }
218 }
219 "#;
220 let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
221 let text = "リンデラ";
222 let (filterd_text, offsets, diffs) = filter.apply(&text).unwrap();
223 assert_eq!("リンデラ", filterd_text);
224 assert_eq!(vec![9], offsets);
225 assert_eq!(vec![3], diffs);
226 let start = 6;
227 let end = 9;
228 assert_eq!("デ", &filterd_text[start..end]);
229 let correct_start = correct_offset(start, &offsets, &diffs, filterd_text.len());
230 let correct_end = correct_offset(end, &offsets, &diffs, filterd_text.len());
231 assert_eq!(6, correct_start);
232 assert_eq!(12, correct_end);
233 assert_eq!("デ", &text[correct_start..correct_end]);
234 }
235
236 {
237 let config_str = r#"
238 {
239 "mapping": {
240 "リンデラ": "リンデラ"
241 }
242 }
243 "#;
244 let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
245 let text = "リンデラ";
246 let (filterd_text, offsets, diffs) = filter.apply(text).unwrap();
247 assert_eq!("リンデラ", filterd_text);
248 assert_eq!(vec![12], offsets);
249 assert_eq!(vec![3], diffs);
250 let start = 0;
251 let end = 12;
252 assert_eq!("リンデラ", &filterd_text[start..end]);
253 let correct_start = correct_offset(start, &offsets, &diffs, filterd_text.len());
254 let correct_end = correct_offset(end, &offsets, &diffs, filterd_text.len());
255 assert_eq!(0, correct_start);
256 assert_eq!(15, correct_end);
257 assert_eq!("リンデラ", &text[correct_start..correct_end]);
258 }
259
260 {
261 let config_str = r#"
262 {
263 "mapping": {
264 "リンデラ": "Lindera"
265 }
266 }
267 "#;
268 let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
269 let text = "Rust製形態素解析器リンデラで日本語を形態素解析する。";
270 let (filterd_text, offsets, diffs) = filter.apply(text).unwrap();
271 assert_eq!(
272 "Rust製形態素解析器Linderaで日本語を形態素解析する。",
273 filterd_text
274 );
275 assert_eq!(vec![32], offsets);
276 assert_eq!(vec![5], diffs);
277 let start = 25;
278 let end = 32;
279 assert_eq!("Lindera", &filterd_text[start..end]);
280 let correct_start = correct_offset(start, &offsets, &diffs, filterd_text.len());
281 let correct_end = correct_offset(end, &offsets, &diffs, filterd_text.len());
282 assert_eq!(25, correct_start);
283 assert_eq!(37, correct_end);
284 assert_eq!("リンデラ", &text[correct_start..correct_end]);
285 let start = 35;
286 let end = 44;
287 assert_eq!("日本語", &filterd_text[start..end]);
288 let correct_start = correct_offset(start, &offsets, &diffs, filterd_text.len());
289 let correct_end = correct_offset(end, &offsets, &diffs, filterd_text.len());
290 assert_eq!(40, correct_start);
291 assert_eq!(49, correct_end);
292 assert_eq!("日本語", &text[correct_start..correct_end]);
293 }
294
295 {
296 let config_str = r#"
297 {
298 "mapping": {
299 "1": "1",
300 "0": "0",
301 "㍑": "リットル"
302 }
303 }
304 "#;
305 let filter = MappingCharacterFilter::from_slice(config_str.as_bytes()).unwrap();
306 let text = "10㍑";
307 let (filterd_text, offsets, diffs) = filter.apply(text).unwrap();
308 assert_eq!("10リットル", filterd_text);
309 assert_eq!(vec![1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13], offsets);
310 assert_eq!(vec![2, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5], diffs);
311 let start = 0;
312 let end = 2;
313 assert_eq!("10", &filterd_text[start..end]);
314 let correct_start = correct_offset(start, &offsets, &diffs, filterd_text.len());
315 let correct_end = correct_offset(end, &offsets, &diffs, filterd_text.len());
316 assert_eq!(0, correct_start);
317 assert_eq!(6, correct_end);
318 assert_eq!("10", &text[correct_start..correct_end]);
319 let start = 2;
320 let end = 14;
321 assert_eq!("リットル", &filterd_text[start..end]);
322 let correct_start = correct_offset(start, &offsets, &diffs, filterd_text.len());
323 let correct_end = correct_offset(end, &offsets, &diffs, filterd_text.len());
324 assert_eq!(6, correct_start);
325 assert_eq!(9, correct_end);
326 assert_eq!("㍑", &text[correct_start..correct_end]);
327 }
328 }
329}