ferrous_opencc/dictionary/
fst_dict.rs1use std::{
2 fs::File,
3 io::{
4 BufReader,
5 Read,
6 Write,
7 },
8 path::Path,
9};
10
11use ferrous_opencc_compiler::{
12 ArchivedDelta,
13 ArchivedSerializableFstDict,
14};
15use fst::Map;
16
17use crate::{
18 dictionary::Dictionary,
19 error::Result,
20};
21
22#[derive(Debug)]
23pub struct FstDict {
24 map: Map<Vec<u8>>,
25 metadata_bytes: Vec<u8>,
26}
27
28fn apply_delta(key: &str, delta: &ArchivedDelta) -> String {
29 match delta {
30 ArchivedDelta::FullReplacement(s) => s.as_str().to_string(),
31 ArchivedDelta::CharDiffs(diffs) => {
32 let mut chars: Vec<char> = key.chars().collect();
33 for diff in diffs.iter() {
34 let (index, new_char): (u16, char) =
35 rkyv::deserialize::<_, rkyv::rancor::Error>(diff).unwrap();
36
37 if let Some(c) = chars.get_mut(index as usize) {
38 *c = new_char;
39 }
40 }
41 chars.into_iter().collect()
42 }
43 }
44}
45
46impl FstDict {
47 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
48 let path = path.as_ref();
49 let compiled_path = path.with_extension("ocb");
50
51 if compiled_path.is_file() {
52 if let Ok(text_meta) = path.metadata() {
53 if let Ok(compiled_meta) = compiled_path.metadata() {
54 let text_modified = text_meta.modified()?;
55 let compiled_modified = compiled_meta.modified()?;
56 if compiled_modified > text_modified {
57 return Self::from_ocb_file(&compiled_path);
58 }
59 }
60 } else {
61 return Self::from_ocb_file(&compiled_path);
62 }
63 }
64
65 let dict = Self::from_text(path)?;
66 let _ = dict.serialize_to_file(&compiled_path);
67 Ok(dict)
68 }
69
70 fn from_ocb_file(path: &Path) -> Result<Self> {
71 let file = File::open(path)?;
72 let reader = BufReader::new(file);
73 Self::from_reader(reader)
74 }
75
76 pub fn serialize_to_file(&self, path: &Path) -> Result<()> {
77 let mut file = File::create(path)?;
78 let mut final_bytes = Vec::new();
79 final_bytes.write_all(&(self.metadata_bytes.len() as u64).to_le_bytes())?;
80 final_bytes.write_all(&self.metadata_bytes)?;
81 final_bytes.write_all(self.map.as_fst().as_bytes())?;
82
83 file.write_all(&final_bytes)?;
84 Ok(())
85 }
86
87 pub fn from_text(path: &Path) -> Result<Self> {
88 let ocb_bytes = ferrous_opencc_compiler::compile_dictionary(path)?;
89 Self::from_ocb_bytes(&ocb_bytes)
90 }
91
92 pub fn from_ocb_bytes(bytes: &[u8]) -> Result<Self> {
93 Self::from_reader(bytes)
94 }
95
96 fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
97 let mut len_bytes = [0u8; 8];
98 reader.read_exact(&mut len_bytes)?;
99 let metadata_len = u64::from_le_bytes(len_bytes) as usize;
100
101 let mut metadata_bytes = vec![0; metadata_len];
102 reader.read_exact(&mut metadata_bytes)?;
103
104 rkyv::access::<ArchivedSerializableFstDict, rkyv::rancor::Error>(&metadata_bytes)?;
105
106 let mut fst_bytes: Vec<u8> = Vec::new();
107 reader.read_to_end(&mut fst_bytes)?;
108
109 let map = Map::new(fst_bytes)?;
110
111 Ok(Self {
112 map,
113 metadata_bytes,
114 })
115 }
116}
117
118impl Dictionary for FstDict {
119 fn match_prefix<'a>(&self, word: &'a str) -> Option<(&'a str, Vec<String>)> {
120 let fst = self.map.as_fst();
121 let mut node = fst.root();
122
123 let mut last_match: Option<(usize, u64)> = None;
124
125 let mut current_output: u64 = 0;
126
127 for (i, byte) in word.as_bytes().iter().enumerate() {
128 if let Some(trans_idx) = node.find_input(*byte) {
129 let t = node.transition(trans_idx);
130 current_output += t.out.value();
131 node = fst.node(t.addr);
132
133 if node.is_final() {
134 let final_value = current_output + node.final_output().value();
135 last_match = Some((i + 1, final_value));
136 }
137 } else {
138 break;
139 }
140 }
141
142 if let Some((len, value_index)) = last_match {
143 let metadata = unsafe {
144 rkyv::access_unchecked::<ArchivedSerializableFstDict>(&self.metadata_bytes)
145 };
146
147 if let Some(deltas) = metadata.values.get(value_index as usize) {
148 let key = &word[..len];
149 let result_values: Vec<String> =
150 deltas.iter().map(|delta| apply_delta(key, delta)).collect();
151 return Some((key, result_values));
152 }
153 }
154
155 None
156 }
157
158 fn max_key_length(&self) -> usize {
159 let metadata =
160 unsafe { rkyv::access_unchecked::<ArchivedSerializableFstDict>(&self.metadata_bytes) };
161
162 rkyv::deserialize::<u32, rkyv::rancor::Error>(&metadata.max_key_length).unwrap() as usize
163 }
164}
165
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use tempfile::tempdir;

    use super::*;

    /// Writes `content` (plus a trailing newline) to `test_dict.txt` inside `dir` and returns
    /// the file's path.
    fn create_test_dict_file(dir: &tempfile::TempDir, content: &str) -> PathBuf {
        let dict_path = dir.path().join("test_dict.txt");
        let mut file = File::create(&dict_path).unwrap();
        writeln!(file, "{content}").unwrap();
        dict_path
    }

    /// Asserts that the longest key of `dict` prefixing `input` is `expected_key` and that it
    /// maps to exactly `expected_values`.
    fn assert_prefix_match(
        dict: &FstDict,
        input: &str,
        expected_key: &str,
        expected_values: &[&str],
    ) {
        let (key, values) = dict
            .match_prefix(input)
            .unwrap_or_else(|| panic!("expected a prefix match for {input:?}"));
        assert_eq!(key, expected_key);
        let values: Vec<&str> = values.iter().map(String::as_str).collect();
        assert_eq!(values, expected_values);
    }

    #[test]
    fn test_from_text_and_match_prefix() {
        let dir = tempdir().unwrap();
        let dict_content = "一\t一\n一个\t一個\n一个半\t一個半\n世纪\t世紀";
        let dict_path = create_test_dict_file(&dir, dict_content);

        let dict = FstDict::from_text(&dict_path).unwrap();

        // Exact key.
        assert_prefix_match(&dict, "一个", "一个", &["一個"]);
        // Longest key wins when several keys prefix the input.
        assert_prefix_match(&dict, "一个半小时", "一个半", &["一個半"]);
        // Key followed by unrelated text.
        assert_prefix_match(&dict, "世纪之交", "世纪", &["世紀"]);
        // Shortest key on its own.
        assert_prefix_match(&dict, "一", "一", &["一"]);

        // Inputs with no matching prefix yield no result.
        assert!(dict.match_prefix("xyz").is_none());
        assert!(dict.match_prefix("").is_none());
    }

    #[test]
    fn test_serialization_and_deserialization() {
        let dir = tempdir().unwrap();
        let dict_content = "你好\tHello\n世界\tWorld";
        let txt_path = create_test_dict_file(&dir, dict_content);
        let ocb_path = txt_path.with_extension("ocb");

        let dict_from_text = FstDict::from_text(&txt_path).unwrap();
        dict_from_text.serialize_to_file(&ocb_path).unwrap();

        assert!(ocb_path.exists());

        // `new` may recompile from text when both files share a modification time, so also
        // load the compiled file directly to make sure the round trip is exercised.
        let dict_from_ocb = FstDict::from_ocb_file(&ocb_path).unwrap();
        assert_prefix_match(&dict_from_ocb, "你好世界", "你好", &["Hello"]);
        assert_prefix_match(&dict_from_ocb, "世界", "世界", &["World"]);

        let dict_from_new = FstDict::new(&txt_path).unwrap();
        assert_prefix_match(&dict_from_new, "你好世界", "你好", &["Hello"]);
        assert_prefix_match(&dict_from_new, "世界", "世界", &["World"]);
    }
}