// lindera_dictionary/dictionary_builder/user_dictionary.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::fs::File;
4use std::io;
5use std::io::Write;
6use std::path::Path;
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use csv::StringRecord;
10use derive_builder::Builder;
11use log::debug;
12use yada::builder::DoubleArrayBuilder;
13
14use crate::LinderaResult;
15use crate::dictionary::UserDictionary;
16use crate::dictionary::prefix_dictionary::PrefixDictionary;
17use crate::error::LinderaErrorKind;
18use crate::viterbi::{WordEntry, WordId};
19
20type StringRecordProcessor = Option<Box<dyn Fn(&StringRecord) -> LinderaResult<Vec<String>>>>;
21
22#[derive(Builder)]
23#[builder(pattern = "owned")]
24#[builder(name = UserDictionaryBuilderOptions)]
25#[builder(build_fn(name = "builder"))]
26pub struct UserDictionaryBuilder {
27    #[builder(default = "3")]
28    user_dictionary_fields_num: usize,
29    #[builder(default = "12")]
30    dictionary_fields_num: usize,
31    #[builder(default = "-10000")]
32    default_word_cost: i16,
33    #[builder(default = "0")]
34    default_left_context_id: u16,
35    #[builder(default = "0")]
36    default_right_context_id: u16,
37    #[builder(default = "true")]
38    flexible_csv: bool,
39    #[builder(setter(strip_option), default = "None")]
40    user_dictionary_handler: StringRecordProcessor,
41}
42
43impl UserDictionaryBuilder {
44    pub fn build(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
45        debug!("reading {input_file:?}");
46
47        let mut rdr = csv::ReaderBuilder::new()
48            .has_headers(false)
49            .flexible(self.flexible_csv)
50            .from_path(input_file)
51            .map_err(|err| {
52                LinderaErrorKind::Io
53                    .with_error(anyhow::anyhow!(err))
54                    .add_context(format!(
55                        "Failed to open user dictionary CSV file: {input_file:?}"
56                    ))
57            })?;
58
59        let mut rows: Vec<StringRecord> = vec![];
60        for (line_num, result) in rdr.records().enumerate() {
61            let record = result.map_err(|err| {
62                LinderaErrorKind::Content
63                    .with_error(anyhow::anyhow!(err))
64                    .add_context(format!(
65                        "Failed to parse CSV record at line {} in file: {:?}",
66                        line_num + 1,
67                        input_file
68                    ))
69            })?;
70            rows.push(record);
71        }
72        rows.sort_by_key(|row| row[0].to_string());
73
74        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
75
76        for (row_id, row) in rows.iter().enumerate() {
77            let surface = row[0].to_string();
78            let word_cost = if row.len() == self.user_dictionary_fields_num {
79                self.default_word_cost
80            } else {
81                row[3].parse::<i16>().map_err(|_err| {
82                    LinderaErrorKind::Parse
83                        .with_error(anyhow::anyhow!("failed to parse word cost"))
84                        .add_context(format!(
85                            "Invalid word cost '{}' at row {} (surface: '{}')",
86                            &row[3],
87                            row_id + 1,
88                            &row[0]
89                        ))
90                })?
91            };
92            let (left_id, right_id) = if row.len() == self.user_dictionary_fields_num {
93                (self.default_left_context_id, self.default_right_context_id)
94            } else {
95                (
96                    row[1].parse::<u16>().map_err(|_err| {
97                        LinderaErrorKind::Parse
98                            .with_error(anyhow::anyhow!("failed to parse left context id"))
99                            .add_context(format!(
100                                "Invalid left context ID '{}' at row {} (surface: '{}')",
101                                &row[1],
102                                row_id + 1,
103                                &row[0]
104                            ))
105                    })?,
106                    row[2].parse::<u16>().map_err(|_err| {
107                        LinderaErrorKind::Parse
108                            .with_error(anyhow::anyhow!("failed to parse right context id"))
109                            .add_context(format!(
110                                "Invalid right context ID '{}' at row {} (surface: '{}')",
111                                &row[2],
112                                row_id + 1,
113                                &row[0]
114                            ))
115                    })?,
116                )
117            };
118
119            word_entry_map.entry(surface).or_default().push(WordEntry {
120                word_id: WordId {
121                    id: row_id as u32,
122                    is_system: false,
123                },
124                word_cost,
125                left_id,
126                right_id,
127            });
128        }
129
130        let mut words_data = Vec::<u8>::new();
131        let mut words_idx_data = Vec::<u8>::new();
132        for row in rows.iter() {
133            let word_detail = if row.len() == self.user_dictionary_fields_num {
134                if let Some(handler) = &self.user_dictionary_handler {
135                    handler(row)?
136                } else {
137                    row.iter()
138                        .skip(1)
139                        .map(|s| s.to_string())
140                        .collect::<Vec<String>>()
141                }
142            } else if row.len() >= self.dictionary_fields_num {
143                let mut tmp_word_detail = Vec::new();
144                for item in row.iter().skip(4) {
145                    tmp_word_detail.push(item.to_string());
146                }
147                tmp_word_detail
148            } else {
149                return Err(LinderaErrorKind::Content
150                    .with_error(anyhow::anyhow!(
151                        "user dictionary should be a CSV with {} or {}+ fields",
152                        self.user_dictionary_fields_num,
153                        self.dictionary_fields_num
154                    ))
155                    .add_context(format!(
156                        "Row {} has {} fields (surface: '{}')",
157                        rows.iter().position(|r| std::ptr::eq(r, row)).unwrap_or(0) + 1,
158                        row.len(),
159                        row.get(0).unwrap_or("<empty>")
160                    )));
161            };
162
163            let offset = words_data.len();
164            words_idx_data
165                .write_u32::<LittleEndian>(offset as u32)
166                .map_err(|err| {
167                    LinderaErrorKind::Io
168                        .with_error(anyhow::anyhow!(err))
169                        .add_context("Failed to write word offset to user dictionary words index")
170                })?;
171
172            // Store word details as null-separated string (like main dictionary)
173            let joined_details = word_detail.join("\0");
174            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
175                LinderaErrorKind::Serialize
176                    .with_error(anyhow::anyhow!(err))
177                    .add_context(format!(
178                        "Word details length too large: {} bytes for word '{}'",
179                        joined_details.len(),
180                        row.get(0).unwrap_or("<unknown>")
181                    ))
182            })?;
183
184            words_data
185                .write_u32::<LittleEndian>(joined_details_len)
186                .map_err(|err| {
187                    LinderaErrorKind::Serialize
188                        .with_error(anyhow::anyhow!(err))
189                        .add_context(
190                            "Failed to write word details length to user dictionary words data",
191                        )
192                })?;
193            words_data
194                .write_all(joined_details.as_bytes())
195                .map_err(|err| {
196                    LinderaErrorKind::Serialize
197                        .with_error(anyhow::anyhow!(err))
198                        .add_context("Failed to write word details to user dictionary words data")
199                })?;
200        }
201
202        let mut id = 0u32;
203
204        // building double array trie
205        let mut keyset: Vec<(&[u8], u32)> = vec![];
206        for (key, word_entries) in &word_entry_map {
207            let len = word_entries.len() as u32;
208            let val = (id << 5) | len;
209            keyset.push((key.as_bytes(), val));
210            id += len;
211        }
212        let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
213            LinderaErrorKind::Build
214                .with_error(anyhow::anyhow!("DoubleArray build error."))
215                .add_context(format!(
216                    "Failed to build DoubleArray with {} keys for user dictionary",
217                    keyset.len()
218                ))
219        })?;
220
221        // building values
222        let mut vals_data = Vec::<u8>::new();
223        for word_entries in word_entry_map.values() {
224            for word_entry in word_entries {
225                word_entry.serialize(&mut vals_data).map_err(|err| {
226                    LinderaErrorKind::Serialize
227                        .with_error(anyhow::anyhow!(err))
228                        .add_context(format!(
229                            "Failed to serialize user dictionary word entry (id: {})",
230                            word_entry.word_id.id
231                        ))
232                })?;
233            }
234        }
235
236        let dict = PrefixDictionary::load(da_bytes, vals_data, words_idx_data, words_data, false);
237
238        Ok(UserDictionary { dict })
239    }
240}
241
242pub fn build_user_dictionary(user_dict: UserDictionary, output_file: &Path) -> LinderaResult<()> {
243    let parent_dir = match output_file.parent() {
244        Some(parent_dir) => parent_dir,
245        None => {
246            return Err(LinderaErrorKind::Io
247                .with_error(anyhow::anyhow!(
248                    "failed to get parent directory of output file"
249                ))
250                .add_context(format!("Invalid output file path: {output_file:?}")));
251        }
252    };
253    fs::create_dir_all(parent_dir).map_err(|err| {
254        LinderaErrorKind::Io
255            .with_error(anyhow::anyhow!(err))
256            .add_context(format!("Failed to create parent directory: {parent_dir:?}"))
257    })?;
258
259    let mut wtr = io::BufWriter::new(File::create(output_file).map_err(|err| {
260        LinderaErrorKind::Io
261            .with_error(anyhow::anyhow!(err))
262            .add_context(format!(
263                "Failed to create user dictionary output file: {output_file:?}"
264            ))
265    })?);
266    bincode::serde::encode_into_std_write(&user_dict, &mut wtr, bincode::config::legacy())
267        .map_err(|err| {
268            LinderaErrorKind::Serialize
269                .with_error(anyhow::anyhow!(err))
270                .add_context(format!(
271                    "Failed to serialize user dictionary to file: {output_file:?}"
272                ))
273        })?;
274    wtr.flush().map_err(|err| {
275        LinderaErrorKind::Io
276            .with_error(anyhow::anyhow!(err))
277            .add_context(format!(
278                "Failed to flush user dictionary output file: {output_file:?}"
279            ))
280    })?;
281
282    Ok(())
283}