// lindera_dictionary/builder/user_dictionary.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::fs::File;
4use std::io;
5use std::io::Write;
6use std::path::Path;
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use csv::StringRecord;
10use derive_builder::Builder;
11use log::debug;
12use yada::builder::DoubleArrayBuilder;
13
14use crate::LinderaResult;
15use crate::dictionary::UserDictionary;
16use crate::dictionary::prefix_dictionary::PrefixDictionary;
17use crate::error::LinderaErrorKind;
18use crate::viterbi::WordEntry;
19
20type StringRecordProcessor = Option<Box<dyn Fn(&StringRecord) -> LinderaResult<Vec<String>>>>;
21
22#[derive(Builder)]
23#[builder(pattern = "owned")]
24#[builder(name = UserDictionaryBuilderOptions)]
25#[builder(build_fn(name = "builder"))]
26pub struct UserDictionaryBuilder {
27    #[builder(default = "3")]
28    user_dictionary_fields_num: usize,
29    #[builder(default = "12")]
30    dictionary_fields_num: usize,
31    #[builder(default = "-10000")]
32    default_word_cost: i16,
33    #[builder(default = "0")]
34    default_left_context_id: u16,
35    #[builder(default = "0")]
36    default_right_context_id: u16,
37    #[builder(default = "true")]
38    flexible_csv: bool,
39    #[builder(setter(strip_option), default = "None")]
40    user_dictionary_handler: StringRecordProcessor,
41}
42
43impl UserDictionaryBuilder {
44    pub fn build(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
45        debug!("reading {input_file:?}");
46
47        let mut rdr = csv::ReaderBuilder::new()
48            .has_headers(false)
49            .flexible(self.flexible_csv)
50            .from_path(input_file)
51            .map_err(|err| {
52                LinderaErrorKind::Io
53                    .with_error(anyhow::anyhow!(err))
54                    .add_context(format!(
55                        "Failed to open user dictionary CSV file: {input_file:?}"
56                    ))
57            })?;
58
59        let mut rows: Vec<StringRecord> = vec![];
60        for (line_num, result) in rdr.records().enumerate() {
61            let record = result.map_err(|err| {
62                LinderaErrorKind::Content
63                    .with_error(anyhow::anyhow!(err))
64                    .add_context(format!(
65                        "Failed to parse CSV record at line {} in file: {:?}",
66                        line_num + 1,
67                        input_file
68                    ))
69            })?;
70            rows.push(record);
71        }
72        rows.sort_by_key(|row| row[0].to_string());
73
74        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
75
76        for (row_id, row) in rows.iter().enumerate() {
77            let surface = row[0].to_string();
78            let word_cost = if row.len() == self.user_dictionary_fields_num {
79                self.default_word_cost
80            } else {
81                row[3].parse::<i16>().map_err(|_err| {
82                    LinderaErrorKind::Parse
83                        .with_error(anyhow::anyhow!("failed to parse word cost"))
84                        .add_context(format!(
85                            "Invalid word cost '{}' at row {} (surface: '{}')",
86                            &row[3],
87                            row_id + 1,
88                            &row[0]
89                        ))
90                })?
91            };
92            let (left_id, right_id) = if row.len() == self.user_dictionary_fields_num {
93                (self.default_left_context_id, self.default_right_context_id)
94            } else {
95                (
96                    row[1].parse::<u16>().map_err(|_err| {
97                        LinderaErrorKind::Parse
98                            .with_error(anyhow::anyhow!("failed to parse left context id"))
99                            .add_context(format!(
100                                "Invalid left context ID '{}' at row {} (surface: '{}')",
101                                &row[1],
102                                row_id + 1,
103                                &row[0]
104                            ))
105                    })?,
106                    row[2].parse::<u16>().map_err(|_err| {
107                        LinderaErrorKind::Parse
108                            .with_error(anyhow::anyhow!("failed to parse right context id"))
109                            .add_context(format!(
110                                "Invalid right context ID '{}' at row {} (surface: '{}')",
111                                &row[2],
112                                row_id + 1,
113                                &row[0]
114                            ))
115                    })?,
116                )
117            };
118
119            word_entry_map.entry(surface).or_default().push(WordEntry {
120                word_id: crate::viterbi::WordId::new(crate::viterbi::LexType::User, row_id as u32),
121                word_cost,
122                left_id,
123                right_id,
124            });
125        }
126
127        let mut words_data = Vec::<u8>::new();
128        let mut words_idx_data = Vec::<u8>::new();
129        for row in rows.iter() {
130            let word_detail = if row.len() == self.user_dictionary_fields_num {
131                if let Some(handler) = &self.user_dictionary_handler {
132                    handler(row)?
133                } else {
134                    row.iter()
135                        .skip(1)
136                        .map(|s| s.to_string())
137                        .collect::<Vec<String>>()
138                }
139            } else if row.len() >= self.dictionary_fields_num {
140                let mut tmp_word_detail = Vec::new();
141                for item in row.iter().skip(4) {
142                    tmp_word_detail.push(item.to_string());
143                }
144                tmp_word_detail
145            } else {
146                return Err(LinderaErrorKind::Content
147                    .with_error(anyhow::anyhow!(
148                        "user dictionary should be a CSV with {} or {}+ fields",
149                        self.user_dictionary_fields_num,
150                        self.dictionary_fields_num
151                    ))
152                    .add_context(format!(
153                        "Row {} has {} fields (surface: '{}')",
154                        rows.iter().position(|r| std::ptr::eq(r, row)).unwrap_or(0) + 1,
155                        row.len(),
156                        row.get(0).unwrap_or("<empty>")
157                    )));
158            };
159
160            let offset = words_data.len();
161            words_idx_data
162                .write_u32::<LittleEndian>(offset as u32)
163                .map_err(|err| {
164                    LinderaErrorKind::Io
165                        .with_error(anyhow::anyhow!(err))
166                        .add_context("Failed to write word offset to user dictionary words index")
167                })?;
168
169            // Store word details as null-separated string (like main dictionary)
170            let joined_details = word_detail.join("\0");
171            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
172                LinderaErrorKind::Serialize
173                    .with_error(anyhow::anyhow!(err))
174                    .add_context(format!(
175                        "Word details length too large: {} bytes for word '{}'",
176                        joined_details.len(),
177                        row.get(0).unwrap_or("<unknown>")
178                    ))
179            })?;
180
181            words_data
182                .write_u32::<LittleEndian>(joined_details_len)
183                .map_err(|err| {
184                    LinderaErrorKind::Serialize
185                        .with_error(anyhow::anyhow!(err))
186                        .add_context(
187                            "Failed to write word details length to user dictionary words data",
188                        )
189                })?;
190            words_data
191                .write_all(joined_details.as_bytes())
192                .map_err(|err| {
193                    LinderaErrorKind::Serialize
194                        .with_error(anyhow::anyhow!(err))
195                        .add_context("Failed to write word details to user dictionary words data")
196                })?;
197        }
198
199        let mut id = 0u32;
200
201        // building double array trie
202        let mut keyset: Vec<(&[u8], u32)> = vec![];
203        for (key, word_entries) in &word_entry_map {
204            let len = word_entries.len() as u32;
205            let val = (id << 5) | len;
206            keyset.push((key.as_bytes(), val));
207            id += len;
208        }
209        let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
210            LinderaErrorKind::Build
211                .with_error(anyhow::anyhow!("DoubleArray build error."))
212                .add_context(format!(
213                    "Failed to build DoubleArray with {} keys for user dictionary",
214                    keyset.len()
215                ))
216        })?;
217
218        // building values
219        let mut vals_data = Vec::<u8>::new();
220        for word_entries in word_entry_map.values() {
221            for word_entry in word_entries {
222                word_entry.serialize(&mut vals_data).map_err(|err| {
223                    LinderaErrorKind::Serialize
224                        .with_error(anyhow::anyhow!(err))
225                        .add_context(format!(
226                            "Failed to serialize user dictionary word entry (id: {})",
227                            word_entry.word_id.id
228                        ))
229                })?;
230            }
231        }
232
233        let dict = PrefixDictionary::load(da_bytes, vals_data, words_idx_data, words_data, false);
234
235        Ok(UserDictionary { dict })
236    }
237}
238
239pub fn build_user_dictionary(user_dict: UserDictionary, output_file: &Path) -> LinderaResult<()> {
240    let parent_dir = match output_file.parent() {
241        Some(parent_dir) => parent_dir,
242        None => {
243            return Err(LinderaErrorKind::Io
244                .with_error(anyhow::anyhow!(
245                    "failed to get parent directory of output file"
246                ))
247                .add_context(format!("Invalid output file path: {output_file:?}")));
248        }
249    };
250    fs::create_dir_all(parent_dir).map_err(|err| {
251        LinderaErrorKind::Io
252            .with_error(anyhow::anyhow!(err))
253            .add_context(format!("Failed to create parent directory: {parent_dir:?}"))
254    })?;
255
256    let mut wtr = io::BufWriter::new(File::create(output_file).map_err(|err| {
257        LinderaErrorKind::Io
258            .with_error(anyhow::anyhow!(err))
259            .add_context(format!(
260                "Failed to create user dictionary output file: {output_file:?}"
261            ))
262    })?);
263    let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&user_dict).map_err(|err| {
264        LinderaErrorKind::Serialize
265            .with_error(anyhow::anyhow!(err))
266            .add_context(format!(
267                "Failed to serialize user dictionary to file: {output_file:?}"
268            ))
269    })?;
270    wtr.write_all(&bytes).map_err(|err| {
271        LinderaErrorKind::Io
272            .with_error(anyhow::anyhow!(err))
273            .add_context(format!(
274                "Failed to write user dictionary to file: {output_file:?}"
275            ))
276    })?;
277    wtr.flush().map_err(|err| {
278        LinderaErrorKind::Io
279            .with_error(anyhow::anyhow!(err))
280            .add_context(format!(
281                "Failed to flush user dictionary output file: {output_file:?}"
282            ))
283    })?;
284
285    Ok(())
286}