Skip to main content

ferrous_opencc/dictionary/
fst_dict.rs

1use std::{
2    fs::File,
3    io::{
4        BufReader,
5        Read,
6        Write,
7    },
8    path::Path,
9};
10
11use ferrous_opencc_compiler::{
12    ArchivedDelta,
13    ArchivedSerializableFstDict,
14};
15use fst::Map;
16
17use crate::{
18    dictionary::Dictionary,
19    error::Result,
20};
21
/// A dictionary backed by an `fst::Map` whose outputs index into rkyv-archived metadata.
#[derive(Debug)]
pub struct FstDict {
    /// FST mapping each key's bytes to an index into the archived metadata's `values` list.
    map: Map<Vec<u8>>,
    /// Archived `SerializableFstDict` bytes. They are validated with `rkyv::access` when loaded
    /// and later read with `rkyv::access_unchecked`, so they must not be mutated after loading.
    metadata_bytes: Vec<u8>,
}
27
28fn apply_delta(key: &str, delta: &ArchivedDelta) -> String {
29    match delta {
30        ArchivedDelta::FullReplacement(s) => s.as_str().to_string(),
31        ArchivedDelta::CharDiffs(diffs) => {
32            let mut chars: Vec<char> = key.chars().collect();
33            for diff in diffs.iter() {
34                let (index, new_char): (u16, char) =
35                    rkyv::deserialize::<_, rkyv::rancor::Error>(diff).unwrap();
36
37                if let Some(c) = chars.get_mut(index as usize) {
38                    *c = new_char;
39                }
40            }
41            chars.into_iter().collect()
42        }
43    }
44}
45
46impl FstDict {
47    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
48        let path = path.as_ref();
49        let compiled_path = path.with_extension("ocb");
50
51        if compiled_path.is_file() {
52            if let Ok(text_meta) = path.metadata() {
53                if let Ok(compiled_meta) = compiled_path.metadata() {
54                    let text_modified = text_meta.modified()?;
55                    let compiled_modified = compiled_meta.modified()?;
56                    if compiled_modified > text_modified {
57                        return Self::from_ocb_file(&compiled_path);
58                    }
59                }
60            } else {
61                return Self::from_ocb_file(&compiled_path);
62            }
63        }
64
65        let dict = Self::from_text(path)?;
66        let _ = dict.serialize_to_file(&compiled_path);
67        Ok(dict)
68    }
69
70    fn from_ocb_file(path: &Path) -> Result<Self> {
71        let file = File::open(path)?;
72        let reader = BufReader::new(file);
73        Self::from_reader(reader)
74    }
75
76    pub fn serialize_to_file(&self, path: &Path) -> Result<()> {
77        let mut file = File::create(path)?;
78        let mut final_bytes = Vec::new();
79        final_bytes.write_all(&(self.metadata_bytes.len() as u64).to_le_bytes())?;
80        final_bytes.write_all(&self.metadata_bytes)?;
81        final_bytes.write_all(self.map.as_fst().as_bytes())?;
82
83        file.write_all(&final_bytes)?;
84        Ok(())
85    }
86
87    pub fn from_text(path: &Path) -> Result<Self> {
88        let ocb_bytes = ferrous_opencc_compiler::compile_dictionary(path)?;
89        Self::from_ocb_bytes(&ocb_bytes)
90    }
91
92    pub fn from_ocb_bytes(bytes: &[u8]) -> Result<Self> {
93        Self::from_reader(bytes)
94    }
95
96    fn from_reader<R: Read>(mut reader: R) -> Result<Self> {
97        let mut len_bytes = [0u8; 8];
98        reader.read_exact(&mut len_bytes)?;
99        let metadata_len = u64::from_le_bytes(len_bytes) as usize;
100
101        let mut metadata_bytes = vec![0; metadata_len];
102        reader.read_exact(&mut metadata_bytes)?;
103
104        rkyv::access::<ArchivedSerializableFstDict, rkyv::rancor::Error>(&metadata_bytes)?;
105
106        let mut fst_bytes: Vec<u8> = Vec::new();
107        reader.read_to_end(&mut fst_bytes)?;
108
109        let map = Map::new(fst_bytes)?;
110
111        Ok(Self {
112            map,
113            metadata_bytes,
114        })
115    }
116}
117
118impl Dictionary for FstDict {
119    fn match_prefix<'a>(&self, word: &'a str) -> Option<(&'a str, Vec<String>)> {
120        let fst = self.map.as_fst();
121        let mut node = fst.root();
122
123        let mut last_match: Option<(usize, u64)> = None;
124
125        let mut current_output: u64 = 0;
126
127        for (i, byte) in word.as_bytes().iter().enumerate() {
128            if let Some(trans_idx) = node.find_input(*byte) {
129                let t = node.transition(trans_idx);
130                current_output += t.out.value();
131                node = fst.node(t.addr);
132
133                if node.is_final() {
134                    let final_value = current_output + node.final_output().value();
135                    last_match = Some((i + 1, final_value));
136                }
137            } else {
138                break;
139            }
140        }
141
142        if let Some((len, value_index)) = last_match {
143            let metadata = unsafe {
144                rkyv::access_unchecked::<ArchivedSerializableFstDict>(&self.metadata_bytes)
145            };
146
147            if let Some(deltas) = metadata.values.get(value_index as usize) {
148                let key = &word[..len];
149                let result_values: Vec<String> =
150                    deltas.iter().map(|delta| apply_delta(key, delta)).collect();
151                return Some((key, result_values));
152            }
153        }
154
155        None
156    }
157
158    fn max_key_length(&self) -> usize {
159        let metadata =
160            unsafe { rkyv::access_unchecked::<ArchivedSerializableFstDict>(&self.metadata_bytes) };
161
162        rkyv::deserialize::<u32, rkyv::rancor::Error>(&metadata.max_key_length).unwrap() as usize
163    }
164}
165
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use tempfile::tempdir;

    use super::*;

    /// Writes `content` plus a trailing newline to `test_dict.txt` in `dir` and returns its path.
    fn create_test_dict_file(dir: &tempfile::TempDir, content: &str) -> PathBuf {
        let dict_path = dir.path().join("test_dict.txt");
        let mut file = File::create(&dict_path).unwrap();
        writeln!(file, "{content}").unwrap();
        dict_path
    }

    /// Asserts that `word` matches exactly `expected_key` with exactly `expected_values`.
    fn assert_match(dict: &FstDict, word: &str, expected_key: &str, expected_values: &[&str]) {
        let (key, values) = dict
            .match_prefix(word)
            .unwrap_or_else(|| panic!("expected a match for {word:?}"));
        assert_eq!(key, expected_key);
        let values_str: Vec<&str> = values.iter().map(String::as_str).collect();
        assert_eq!(values_str, expected_values);
    }

    #[test]
    fn test_from_text_and_match_prefix() {
        let dir = tempdir().unwrap();
        let dict_content = "一\t一\n一个\t一個\n一个半\t一個半\n世纪\t世紀";
        let dict_path = create_test_dict_file(&dir, dict_content);

        let dict = FstDict::from_text(&dict_path).unwrap();

        assert_match(&dict, "一个", "一个", &["一個"]);
        // The longest matching key wins over its shorter prefixes ("一", "一个").
        assert_match(&dict, "一个半小时", "一个半", &["一個半"]);
        assert_match(&dict, "世纪之交", "世纪", &["世紀"]);
        assert_match(&dict, "一", "一", &["一"]);
    }

    #[test]
    fn test_match_prefix_without_match() {
        let dir = tempdir().unwrap();
        let dict_path = create_test_dict_file(&dir, "一个\t一個\n世纪\t世紀");

        let dict = FstDict::from_text(&dict_path).unwrap();

        assert!(dict.match_prefix("").is_none());
        assert!(dict.match_prefix("二").is_none());
        // "世" is only a proper prefix of a key, not a key itself.
        assert!(dict.match_prefix("世").is_none());
    }

    #[test]
    fn test_serialization_and_deserialization() {
        let dir = tempdir().unwrap();
        let dict_content = "你好\tHello\n世界\tWorld";
        let txt_path = create_test_dict_file(&dir, dict_content);
        let ocb_path = txt_path.with_extension("ocb");

        let dict_from_text = FstDict::from_text(&txt_path).unwrap();
        dict_from_text.serialize_to_file(&ocb_path).unwrap();
        assert!(ocb_path.exists());

        // Remove the text source so `new` can only succeed by loading the compiled cache;
        // otherwise equal mtimes could make it silently recompile from text instead.
        std::fs::remove_file(&txt_path).unwrap();
        let dict_from_ocb = FstDict::new(&txt_path).unwrap();

        assert_match(&dict_from_ocb, "你好世界", "你好", &["Hello"]);
        assert_match(&dict_from_ocb, "世界", "世界", &["World"]);
    }

    #[test]
    fn test_from_ocb_bytes_round_trip() {
        let dir = tempdir().unwrap();
        let txt_path = create_test_dict_file(&dir, "你好\tHello\n世界\tWorld");
        let ocb_path = txt_path.with_extension("ocb");

        let original = FstDict::from_text(&txt_path).unwrap();
        original.serialize_to_file(&ocb_path).unwrap();

        let bytes = std::fs::read(&ocb_path).unwrap();
        let reloaded = FstDict::from_ocb_bytes(&bytes).unwrap();

        assert_match(&reloaded, "你好世界", "你好", &["Hello"]);
        assert_eq!(reloaded.max_key_length(), original.max_key_length());
    }

    #[test]
    fn test_from_ocb_bytes_rejects_truncated_input() {
        // Shorter than the 8-byte length header.
        assert!(FstDict::from_ocb_bytes(&[1, 2, 3]).is_err());

        // Header declares 100 metadata bytes, but only 4 follow.
        let mut truncated = 100u64.to_le_bytes().to_vec();
        truncated.extend_from_slice(&[0, 0, 0, 0]);
        assert!(FstDict::from_ocb_bytes(&truncated).is_err());
    }
}