lindera_dictionary_builder/
dict.rs

1use std::borrow::Cow;
2use std::collections::BTreeMap;
3use std::fs::File;
4use std::io::Write;
5use std::io::{self, Read};
6use std::path::{Path, PathBuf};
7use std::str::FromStr;
8
9use anyhow::anyhow;
10use byteorder::{LittleEndian, WriteBytesExt};
11use csv::StringRecord;
12use derive_builder::Builder;
13use encoding_rs::{Encoding, UTF_8};
14use encoding_rs_io::DecodeReaderBytesBuilder;
15use glob::glob;
16use log::{debug, warn};
17use yada::builder::DoubleArrayBuilder;
18
19use lindera_core::error::LinderaErrorKind;
20use lindera_core::word_entry::{WordEntry, WordId};
21use lindera_core::LinderaResult;
22use lindera_decompress::Algorithm;
23
24use crate::utils::compress_write;
25
/// Builds the prefix-dictionary files (`dict.da`, `dict.vals`, `dict.words`,
/// `dict.wordsidx`) from a directory of CSV sources.
///
/// Instances are created through the generated `DictBuilderOptions` builder, whose
/// build function is named `builder`.
#[derive(Builder, Debug)]
#[builder(name = "DictBuilderOptions")]
#[builder(build_fn(name = "builder"))]
pub struct DictBuilder {
    /// Whether CSV rows may have differing numbers of fields (passed to
    /// `csv::ReaderBuilder::flexible`). Defaults to `true`.
    #[builder(default = "true")]
    flexible_csv: bool,
    /// Encoding label of the CSV sources, resolved with
    /// `Encoding::for_label_no_replacement`. Defaults to `"UTF-8"`.
    ///
    /// If set to UTF-8, it is intended to also read UTF-16 files with a BOM.
    /// NOTE(review): `build` hands UTF-8 input to the CSV reader without a decoder,
    /// so confirm that UTF-16-with-BOM input is actually accepted.
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
    /// Compression algorithm applied to every output file. Defaults to
    /// `Algorithm::Deflate`.
    #[builder(default = "Algorithm::Deflate")]
    compress_algorithm: Algorithm,
    /// When `true`, surfaces and detail fields are passed through `normalize`
    /// before sorting and writing. Defaults to `false`.
    #[builder(default = "false")]
    normalize_details: bool,
    /// When `true`, rows whose word cost, left ID or right ID fail to parse are
    /// logged with `warn!` and skipped instead of aborting the build.
    /// Defaults to `false`.
    #[builder(default = "false")]
    skip_invalid_cost_or_id: bool,
}
42
43impl DictBuilder {
44    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
45        let pattern = if let Some(path) = input_dir.to_str() {
46            format!("{}/*.csv", path)
47        } else {
48            return Err(
49                LinderaErrorKind::Io.with_error(anyhow::anyhow!("Failed to convert path to &str."))
50            );
51        };
52
53        let mut filenames: Vec<PathBuf> = Vec::new();
54        for entry in
55            glob(&pattern).map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?
56        {
57            match entry {
58                Ok(path) => {
59                    if let Some(filename) = path.file_name() {
60                        filenames.push(Path::new(input_dir).join(filename));
61                    } else {
62                        return Err(LinderaErrorKind::Io
63                            .with_error(anyhow::anyhow!("failed to get filename")));
64                    };
65                }
66                Err(err) => return Err(LinderaErrorKind::Content.with_error(anyhow!(err))),
67            }
68        }
69
70        let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
71        let encoding = encoding.ok_or_else(|| {
72            LinderaErrorKind::Decode.with_error(anyhow!("Invalid encoding: {}", self.encoding))
73        })?;
74
75        let mut rows: Vec<StringRecord> = vec![];
76        for filename in filenames {
77            debug!("reading {:?}", filename);
78
79            let file = File::open(filename)
80                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
81            let reader: Box<dyn Read> = if encoding == UTF_8 {
82                Box::new(file)
83            } else {
84                Box::new(
85                    DecodeReaderBytesBuilder::new()
86                        .encoding(Some(encoding))
87                        .build(file),
88                )
89            };
90            let mut rdr = csv::ReaderBuilder::new()
91                .has_headers(false)
92                .flexible(self.flexible_csv)
93                .from_reader(reader);
94
95            for result in rdr.records() {
96                let record =
97                    result.map_err(|err| LinderaErrorKind::Content.with_error(anyhow!(err)))?;
98                rows.push(record);
99            }
100        }
101
102        if self.normalize_details {
103            rows.sort_by_key(|row| normalize(&row[0]));
104        } else {
105            rows.sort_by(|a, b| a[0].cmp(&b[0]))
106        }
107
108        let wtr_da_path = output_dir.join(Path::new("dict.da"));
109        let mut wtr_da = io::BufWriter::new(
110            File::create(wtr_da_path)
111                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
112        );
113
114        let wtr_vals_path = output_dir.join(Path::new("dict.vals"));
115        let mut wtr_vals = io::BufWriter::new(
116            File::create(wtr_vals_path)
117                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
118        );
119
120        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
121
122        for (row_id, row) in rows.iter().enumerate() {
123            let word_cost = match i16::from_str(row[3].trim()) {
124                Ok(wc) => wc,
125                Err(_err) => {
126                    if self.skip_invalid_cost_or_id {
127                        warn!("failed to parse word_cost: {:?}", row);
128                        continue;
129                    } else {
130                        return Err(LinderaErrorKind::Parse
131                            .with_error(anyhow::anyhow!("failed to parse word_cost")));
132                    }
133                }
134            };
135            let left_id = match u16::from_str(row[1].trim()) {
136                Ok(lid) => lid,
137                Err(_err) => {
138                    if self.skip_invalid_cost_or_id {
139                        warn!("failed to parse left_id: {:?}", row);
140                        continue;
141                    } else {
142                        return Err(LinderaErrorKind::Parse
143                            .with_error(anyhow::anyhow!("failed to parse left_id")));
144                    }
145                }
146            };
147            let right_id = match u16::from_str(row[2].trim()) {
148                Ok(rid) => rid,
149                Err(_err) => {
150                    if self.skip_invalid_cost_or_id {
151                        warn!("failed to parse right_id: {:?}", row);
152                        continue;
153                    } else {
154                        return Err(LinderaErrorKind::Parse
155                            .with_error(anyhow::anyhow!("failed to parse right_id")));
156                    }
157                }
158            };
159            let key = if self.normalize_details {
160                normalize(&row[0])
161            } else {
162                row[0].to_string()
163            };
164            word_entry_map.entry(key).or_default().push(WordEntry {
165                word_id: WordId(row_id as u32, true),
166                word_cost,
167                left_id,
168                right_id,
169            });
170        }
171
172        let wtr_words_path = output_dir.join(Path::new("dict.words"));
173        let mut wtr_words = io::BufWriter::new(
174            File::create(wtr_words_path)
175                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
176        );
177
178        let wtr_words_idx_path = output_dir.join(Path::new("dict.wordsidx"));
179        let mut wtr_words_idx = io::BufWriter::new(
180            File::create(wtr_words_idx_path)
181                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
182        );
183
184        let mut words_buffer = Vec::new();
185        let mut words_idx_buffer = Vec::new();
186        for row in rows.iter() {
187            let offset = words_buffer.len();
188            words_idx_buffer
189                .write_u32::<LittleEndian>(offset as u32)
190                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
191
192            let joined_details = if self.normalize_details {
193                row.iter()
194                    .skip(4)
195                    .map(|item| normalize(item))
196                    .collect::<Vec<String>>()
197                    .join("\0")
198            } else {
199                row.iter().skip(4).collect::<Vec<&str>>().join("\0")
200            };
201            let joined_details_len = u32::try_from(joined_details.as_bytes().len())
202                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
203            words_buffer
204                .write_u32::<LittleEndian>(joined_details_len)
205                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
206            words_buffer
207                .write_all(joined_details.as_bytes())
208                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
209        }
210
211        compress_write(&words_buffer, self.compress_algorithm, &mut wtr_words)?;
212        compress_write(
213            &words_idx_buffer,
214            self.compress_algorithm,
215            &mut wtr_words_idx,
216        )?;
217
218        wtr_words
219            .flush()
220            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
221        wtr_words_idx
222            .flush()
223            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
224
225        let mut id = 0u32;
226
227        let mut keyset: Vec<(&[u8], u32)> = vec![];
228        for (key, word_entries) in &word_entry_map {
229            let len = word_entries.len() as u32;
230            let val = (id << 5) | len; // 27bit for word ID, 5bit for different parts of speech on the same surface.
231            keyset.push((key.as_bytes(), val));
232            id += len;
233        }
234
235        let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
236            LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
237        })?;
238
239        compress_write(&da_bytes, self.compress_algorithm, &mut wtr_da)?;
240
241        let mut vals_buffer = Vec::new();
242        for word_entries in word_entry_map.values() {
243            for word_entry in word_entries {
244                word_entry
245                    .serialize(&mut vals_buffer)
246                    .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
247            }
248        }
249
250        compress_write(&vals_buffer, self.compress_algorithm, &mut wtr_vals)?;
251
252        wtr_vals
253            .flush()
254            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
255
256        Ok(())
257    }
258}
259
/// Normalizes a surface or detail field by mapping `―` to `—` and `~` to `〜`.
///
/// Both substitutions are single-character and neither output is an input of the
/// other. A single pass over the characters therefore gives the same result as
/// applying them one after another, with one allocation instead of three.
fn normalize(text: &str) -> String {
    text.chars()
        .map(|c| match c {
            '―' => '—',
            '~' => '〜',
            other => other,
        })
        .collect()
}