lindera_dictionary/builder/
user_dictionary.rs1use std::collections::BTreeMap;
2use std::fs;
3use std::fs::File;
4use std::io;
5use std::io::Write;
6use std::path::Path;
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use csv::StringRecord;
10use derive_builder::Builder;
11use log::debug;
12use yada::builder::DoubleArrayBuilder;
13
14use crate::LinderaResult;
15use crate::dictionary::UserDictionary;
16use crate::dictionary::prefix_dictionary::PrefixDictionary;
17use crate::error::LinderaErrorKind;
18use crate::viterbi::WordEntry;
19
20type StringRecordProcessor = Option<Box<dyn Fn(&StringRecord) -> LinderaResult<Vec<String>>>>;
21
22#[derive(Builder)]
23#[builder(pattern = "owned")]
24#[builder(name = UserDictionaryBuilderOptions)]
25#[builder(build_fn(name = "builder"))]
26pub struct UserDictionaryBuilder {
27 #[builder(default = "3")]
28 user_dictionary_fields_num: usize,
29 #[builder(default = "12")]
30 dictionary_fields_num: usize,
31 #[builder(default = "-10000")]
32 default_word_cost: i16,
33 #[builder(default = "0")]
34 default_left_context_id: u16,
35 #[builder(default = "0")]
36 default_right_context_id: u16,
37 #[builder(default = "true")]
38 flexible_csv: bool,
39 #[builder(setter(strip_option), default = "None")]
40 user_dictionary_handler: StringRecordProcessor,
41}
42
43impl UserDictionaryBuilder {
44 pub fn build(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
45 debug!("reading {input_file:?}");
46
47 let mut rdr = csv::ReaderBuilder::new()
48 .has_headers(false)
49 .flexible(self.flexible_csv)
50 .from_path(input_file)
51 .map_err(|err| {
52 LinderaErrorKind::Io
53 .with_error(anyhow::anyhow!(err))
54 .add_context(format!(
55 "Failed to open user dictionary CSV file: {input_file:?}"
56 ))
57 })?;
58
59 let mut rows: Vec<StringRecord> = vec![];
60 for (line_num, result) in rdr.records().enumerate() {
61 let record = result.map_err(|err| {
62 LinderaErrorKind::Content
63 .with_error(anyhow::anyhow!(err))
64 .add_context(format!(
65 "Failed to parse CSV record at line {} in file: {:?}",
66 line_num + 1,
67 input_file
68 ))
69 })?;
70 rows.push(record);
71 }
72 rows.sort_by_key(|row| row[0].to_string());
73
74 let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
75
76 for (row_id, row) in rows.iter().enumerate() {
77 let surface = row[0].to_string();
78 let word_cost = if row.len() == self.user_dictionary_fields_num {
79 self.default_word_cost
80 } else {
81 row[3].parse::<i16>().map_err(|_err| {
82 LinderaErrorKind::Parse
83 .with_error(anyhow::anyhow!("failed to parse word cost"))
84 .add_context(format!(
85 "Invalid word cost '{}' at row {} (surface: '{}')",
86 &row[3],
87 row_id + 1,
88 &row[0]
89 ))
90 })?
91 };
92 let (left_id, right_id) = if row.len() == self.user_dictionary_fields_num {
93 (self.default_left_context_id, self.default_right_context_id)
94 } else {
95 (
96 row[1].parse::<u16>().map_err(|_err| {
97 LinderaErrorKind::Parse
98 .with_error(anyhow::anyhow!("failed to parse left context id"))
99 .add_context(format!(
100 "Invalid left context ID '{}' at row {} (surface: '{}')",
101 &row[1],
102 row_id + 1,
103 &row[0]
104 ))
105 })?,
106 row[2].parse::<u16>().map_err(|_err| {
107 LinderaErrorKind::Parse
108 .with_error(anyhow::anyhow!("failed to parse right context id"))
109 .add_context(format!(
110 "Invalid right context ID '{}' at row {} (surface: '{}')",
111 &row[2],
112 row_id + 1,
113 &row[0]
114 ))
115 })?,
116 )
117 };
118
119 word_entry_map.entry(surface).or_default().push(WordEntry {
120 word_id: crate::viterbi::WordId::new(crate::viterbi::LexType::User, row_id as u32),
121 word_cost,
122 left_id,
123 right_id,
124 });
125 }
126
127 let mut words_data = Vec::<u8>::new();
128 let mut words_idx_data = Vec::<u8>::new();
129 for row in rows.iter() {
130 let word_detail = if row.len() == self.user_dictionary_fields_num {
131 if let Some(handler) = &self.user_dictionary_handler {
132 handler(row)?
133 } else {
134 row.iter()
135 .skip(1)
136 .map(|s| s.to_string())
137 .collect::<Vec<String>>()
138 }
139 } else if row.len() >= self.dictionary_fields_num {
140 let mut tmp_word_detail = Vec::new();
141 for item in row.iter().skip(4) {
142 tmp_word_detail.push(item.to_string());
143 }
144 tmp_word_detail
145 } else {
146 return Err(LinderaErrorKind::Content
147 .with_error(anyhow::anyhow!(
148 "user dictionary should be a CSV with {} or {}+ fields",
149 self.user_dictionary_fields_num,
150 self.dictionary_fields_num
151 ))
152 .add_context(format!(
153 "Row {} has {} fields (surface: '{}')",
154 rows.iter().position(|r| std::ptr::eq(r, row)).unwrap_or(0) + 1,
155 row.len(),
156 row.get(0).unwrap_or("<empty>")
157 )));
158 };
159
160 let offset = words_data.len();
161 words_idx_data
162 .write_u32::<LittleEndian>(offset as u32)
163 .map_err(|err| {
164 LinderaErrorKind::Io
165 .with_error(anyhow::anyhow!(err))
166 .add_context("Failed to write word offset to user dictionary words index")
167 })?;
168
169 let joined_details = word_detail.join("\0");
171 let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
172 LinderaErrorKind::Serialize
173 .with_error(anyhow::anyhow!(err))
174 .add_context(format!(
175 "Word details length too large: {} bytes for word '{}'",
176 joined_details.len(),
177 row.get(0).unwrap_or("<unknown>")
178 ))
179 })?;
180
181 words_data
182 .write_u32::<LittleEndian>(joined_details_len)
183 .map_err(|err| {
184 LinderaErrorKind::Serialize
185 .with_error(anyhow::anyhow!(err))
186 .add_context(
187 "Failed to write word details length to user dictionary words data",
188 )
189 })?;
190 words_data
191 .write_all(joined_details.as_bytes())
192 .map_err(|err| {
193 LinderaErrorKind::Serialize
194 .with_error(anyhow::anyhow!(err))
195 .add_context("Failed to write word details to user dictionary words data")
196 })?;
197 }
198
199 let mut id = 0u32;
200
201 let mut keyset: Vec<(&[u8], u32)> = vec![];
203 for (key, word_entries) in &word_entry_map {
204 let len = word_entries.len() as u32;
205 let val = (id << 5) | len;
206 keyset.push((key.as_bytes(), val));
207 id += len;
208 }
209 let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
210 LinderaErrorKind::Build
211 .with_error(anyhow::anyhow!("DoubleArray build error."))
212 .add_context(format!(
213 "Failed to build DoubleArray with {} keys for user dictionary",
214 keyset.len()
215 ))
216 })?;
217
218 let mut vals_data = Vec::<u8>::new();
220 for word_entries in word_entry_map.values() {
221 for word_entry in word_entries {
222 word_entry.serialize(&mut vals_data).map_err(|err| {
223 LinderaErrorKind::Serialize
224 .with_error(anyhow::anyhow!(err))
225 .add_context(format!(
226 "Failed to serialize user dictionary word entry (id: {})",
227 word_entry.word_id.id
228 ))
229 })?;
230 }
231 }
232
233 let dict = PrefixDictionary::load(da_bytes, vals_data, words_idx_data, words_data, false);
234
235 Ok(UserDictionary { dict })
236 }
237}
238
239pub fn build_user_dictionary(user_dict: UserDictionary, output_file: &Path) -> LinderaResult<()> {
240 let parent_dir = match output_file.parent() {
241 Some(parent_dir) => parent_dir,
242 None => {
243 return Err(LinderaErrorKind::Io
244 .with_error(anyhow::anyhow!(
245 "failed to get parent directory of output file"
246 ))
247 .add_context(format!("Invalid output file path: {output_file:?}")));
248 }
249 };
250 fs::create_dir_all(parent_dir).map_err(|err| {
251 LinderaErrorKind::Io
252 .with_error(anyhow::anyhow!(err))
253 .add_context(format!("Failed to create parent directory: {parent_dir:?}"))
254 })?;
255
256 let mut wtr = io::BufWriter::new(File::create(output_file).map_err(|err| {
257 LinderaErrorKind::Io
258 .with_error(anyhow::anyhow!(err))
259 .add_context(format!(
260 "Failed to create user dictionary output file: {output_file:?}"
261 ))
262 })?);
263 let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&user_dict).map_err(|err| {
264 LinderaErrorKind::Serialize
265 .with_error(anyhow::anyhow!(err))
266 .add_context(format!(
267 "Failed to serialize user dictionary to file: {output_file:?}"
268 ))
269 })?;
270 wtr.write_all(&bytes).map_err(|err| {
271 LinderaErrorKind::Io
272 .with_error(anyhow::anyhow!(err))
273 .add_context(format!(
274 "Failed to write user dictionary to file: {output_file:?}"
275 ))
276 })?;
277 wtr.flush().map_err(|err| {
278 LinderaErrorKind::Io
279 .with_error(anyhow::anyhow!(err))
280 .add_context(format!(
281 "Failed to flush user dictionary output file: {output_file:?}"
282 ))
283 })?;
284
285 Ok(())
286}