lindera_filter/token_filter/mapping.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use yada::builder::DoubleArrayBuilder;
use yada::DoubleArray;

use lindera_core::error::LinderaErrorKind;
use lindera_core::LinderaResult;

use crate::token::Token;
use crate::token_filter::TokenFilter;

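/// Registered name of this token filter.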
pub const MAPPING_TOKEN_FILTER_NAME: &str = "mapping";

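/// Configuration for the mapping token filter: each key of `mapping` is a
/// substring to search for in token text, mapped to its replacement string.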
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct MappingTokenFilterConfig {
    pub mapping: HashMap<String, String>,
}

impl MappingTokenFilterConfig {
    pub fn new(map: HashMap<String, String>) -> Self {
        Self { mapping: map }
    }

    /// Deserializes a configuration from raw JSON bytes.
    pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
        serde_json::from_slice::<MappingTokenFilterConfig>(data)
            .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
    }

    /// Deserializes a configuration from an already-parsed JSON value.
    pub fn from_value(value: &serde_json::Value) -> LinderaResult<Self> {
        serde_json::from_value::<MappingTokenFilterConfig>(value.clone())
            .map_err(|err| LinderaErrorKind::Deserialize.with_error(err))
    }
}

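/// Token filter that rewrites token text by replacing occurrences of the
/// configured keys with their mapped values. Lookup uses a double-array trie
/// and always takes the longest key matching at the current position.
///
/// A minimal usage sketch (illustrative, assuming the crate layout above):
///
/// ```ignore
/// let json = r#"{"mapping": {"籠": "篭"}}"#;
/// let filter = MappingTokenFilter::from_slice(json.as_bytes())?;
/// ```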
#[derive(Clone)]
pub struct MappingTokenFilter {
    config: MappingTokenFilterConfig,
    trie: DoubleArray<Vec<u8>>,
}

impl MappingTokenFilter {
    pub fn new(config: MappingTokenFilterConfig) -> LinderaResult<Self> {
        // DoubleArrayBuilder::build requires the keyset to be sorted by key,
        // so sort the mapping keys before assigning their trie values.
        let mut keyset: Vec<(&[u8], u32)> = Vec::new();
        let mut keys = config.mapping.keys().collect::<Vec<_>>();
        keys.sort();
        for (value, key) in keys.into_iter().enumerate() {
            keyset.push((key.as_bytes(), value as u32));
        }

        let data = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
            LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
        })?;

        let trie = DoubleArray::new(data);

        Ok(Self { config, trie })
    }

    pub fn from_slice(data: &[u8]) -> LinderaResult<Self> {
        Self::new(MappingTokenFilterConfig::from_slice(data)?)
    }
}

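// How `apply` works: scan the token text left to right; at each position, run
// a common-prefix search against the trie, keep the longest hit, and emit its
// replacement; if no key matches, copy one character verbatim and move on.
// For example (hypothetical mapping, not from this file): with
// {"ab": "X", "abc": "Y"}, the text "abcd" becomes "Yd", since "abc" wins.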
impl TokenFilter for MappingTokenFilter {
    fn name(&self) -> &'static str {
        MAPPING_TOKEN_FILTER_NAME
    }

    fn apply<'a>(&self, tokens: &mut Vec<Token>) -> LinderaResult<()> {
        for token in tokens.iter_mut() {
            let mut result = String::new();
            let mut start = 0_usize;
            let len = token.text.len();

            while start < len {
                let suffix = &token.text[start..];
                // Each hit from common_prefix_search is a (trie value, matched
                // byte length) pair; `last()` keeps the longest match.
                match self
                    .trie
                    .common_prefix_search(suffix.as_bytes())
                    .last()
                    .map(|(_value, prefix_len)| prefix_len)
                {
                    Some(prefix_len) => {
                        // A key matched: append its replacement and skip past it.
                        let surface = &token.text[start..start + prefix_len];
                        let replacement = &self.config.mapping[surface];

                        result.push_str(replacement);

                        start += prefix_len;
                    }
                    None => {
                        // No key matches here: copy a single character verbatim.
                        match suffix.chars().next() {
                            Some(c) => {
                                result.push(c);

                                start += c.len_utf8();
                            }
                            None => break,
                        }
                    }
                }
            }

            token.text = result;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    #[cfg(all(feature = "ipadic", feature = "filter"))]
    use lindera_core::word_entry::WordId;

    use crate::token_filter::mapping::{MappingTokenFilter, MappingTokenFilterConfig};
    #[cfg(all(feature = "ipadic", feature = "filter"))]
    use crate::{token::Token, token_filter::TokenFilter};

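    // The test mapping converts half-width katakana to full-width katakana.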
    #[test]
    fn test_mapping_token_filter_config_from_slice() {
        let config_str = r#"
        {
            "mapping": {
                "ｱ": "ア",
                "ｲ": "イ",
                "ｳ": "ウ",
                "ｴ": "エ",
                "ｵ": "オ"
            }
        }
        "#;
        let config = MappingTokenFilterConfig::from_slice(config_str.as_bytes()).unwrap();
        assert_eq!("ア", config.mapping.get("ｱ").unwrap());
    }

    #[test]
    fn test_mapping_token_filter_from_slice() {
        let config_str = r#"
        {
            "mapping": {
                "ｱ": "ア",
                "ｲ": "イ",
                "ｳ": "ウ",
                "ｴ": "エ",
                "ｵ": "オ"
            }
        }
        "#;
        let result = MappingTokenFilter::from_slice(config_str.as_bytes());
        assert!(result.is_ok());
    }

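    // 籠 and 篭 are variant forms of the same kanji; the filter should rewrite
    // the token "籠原" (Kagohara) to "篭原" while leaving "駅" untouched.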
    #[test]
    #[cfg(all(feature = "ipadic", feature = "filter"))]
    fn test_mapping_token_filter_apply_ipadic() {
        let config_str = r#"
        {
            "mapping": {
                "籠": "篭"
            }
        }
        "#;
        let filter = MappingTokenFilter::from_slice(config_str.as_bytes()).unwrap();

        let mut tokens: Vec<Token> = vec![
            Token {
                text: "籠原".to_string(),
                byte_start: 0,
                byte_end: 6,
                position: 0,
                position_length: 1,
                word_id: WordId(312630, true),
                details: vec![
                    "名詞".to_string(),
                    "固有名詞".to_string(),
                    "一般".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "籠原".to_string(),
                    "カゴハラ".to_string(),
                    "カゴハラ".to_string(),
                ],
            },
            Token {
                text: "駅".to_string(),
                byte_start: 6,
                byte_end: 9,
                position: 1,
                position_length: 1,
                word_id: WordId(383791, true),
                details: vec![
                    "名詞".to_string(),
                    "接尾".to_string(),
                    "地域".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "*".to_string(),
                    "駅".to_string(),
                    "エキ".to_string(),
                    "エキ".to_string(),
                ],
            },
        ];

        filter.apply(&mut tokens).unwrap();

        assert_eq!(tokens.len(), 2);
        assert_eq!(&tokens[0].text, "篭原");
        assert_eq!(&tokens[1].text, "駅");
    }
}