// lindera_dictionary_builder/dict.rs
use std::borrow::Cow;
2use std::collections::BTreeMap;
3use std::fs::File;
4use std::io::Write;
5use std::io::{self, Read};
6use std::path::{Path, PathBuf};
7use std::str::FromStr;
8
9use anyhow::anyhow;
10use byteorder::{LittleEndian, WriteBytesExt};
11use csv::StringRecord;
12use derive_builder::Builder;
13use encoding_rs::{Encoding, UTF_8};
14use encoding_rs_io::DecodeReaderBytesBuilder;
15use glob::glob;
16use log::{debug, warn};
17use yada::builder::DoubleArrayBuilder;
18
19use lindera_core::error::LinderaErrorKind;
20use lindera_core::word_entry::{WordEntry, WordId};
21use lindera_core::LinderaResult;
22use lindera_decompress::Algorithm;
23
24use crate::utils::compress_write;
25
/// Configuration for building a Lindera dictionary from CSV source files.
///
/// Construct via the derived `DictBuilderOptions` builder. Note the
/// `build_fn` rename: the generated finalizer is `builder()`, leaving the
/// name `build()` free for the dictionary-building entry point in the
/// `impl` block below.
#[derive(Builder, Debug)]
#[builder(name = "DictBuilderOptions")]
#[builder(build_fn(name = "builder"))]
pub struct DictBuilder {
    /// Allow CSV records with a varying number of fields (default: `true`).
    #[builder(default = "true")]
    flexible_csv: bool,
    /// Encoding label of the input CSV files (default: `"UTF-8"`), resolved
    /// with `Encoding::for_label_no_replacement` when building.
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
    /// Compression applied to every output file (default: `Deflate`).
    #[builder(default = "Algorithm::Deflate")]
    compress_algorithm: Algorithm,
    /// Normalize surface forms and detail columns via `normalize`
    /// (default: `false`).
    #[builder(default = "false")]
    normalize_details: bool,
    /// Skip rows whose cost/id columns fail to parse (with a warning)
    /// instead of returning an error (default: `false`).
    #[builder(default = "false")]
    skip_invalid_cost_or_id: bool,
}
42
43impl DictBuilder {
44 pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
45 let pattern = if let Some(path) = input_dir.to_str() {
46 format!("{}/*.csv", path)
47 } else {
48 return Err(
49 LinderaErrorKind::Io.with_error(anyhow::anyhow!("Failed to convert path to &str."))
50 );
51 };
52
53 let mut filenames: Vec<PathBuf> = Vec::new();
54 for entry in
55 glob(&pattern).map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?
56 {
57 match entry {
58 Ok(path) => {
59 if let Some(filename) = path.file_name() {
60 filenames.push(Path::new(input_dir).join(filename));
61 } else {
62 return Err(LinderaErrorKind::Io
63 .with_error(anyhow::anyhow!("failed to get filename")));
64 };
65 }
66 Err(err) => return Err(LinderaErrorKind::Content.with_error(anyhow!(err))),
67 }
68 }
69
70 let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
71 let encoding = encoding.ok_or_else(|| {
72 LinderaErrorKind::Decode.with_error(anyhow!("Invalid encoding: {}", self.encoding))
73 })?;
74
75 let mut rows: Vec<StringRecord> = vec![];
76 for filename in filenames {
77 debug!("reading {:?}", filename);
78
79 let file = File::open(filename)
80 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
81 let reader: Box<dyn Read> = if encoding == UTF_8 {
82 Box::new(file)
83 } else {
84 Box::new(
85 DecodeReaderBytesBuilder::new()
86 .encoding(Some(encoding))
87 .build(file),
88 )
89 };
90 let mut rdr = csv::ReaderBuilder::new()
91 .has_headers(false)
92 .flexible(self.flexible_csv)
93 .from_reader(reader);
94
95 for result in rdr.records() {
96 let record =
97 result.map_err(|err| LinderaErrorKind::Content.with_error(anyhow!(err)))?;
98 rows.push(record);
99 }
100 }
101
102 if self.normalize_details {
103 rows.sort_by_key(|row| normalize(&row[0]));
104 } else {
105 rows.sort_by(|a, b| a[0].cmp(&b[0]))
106 }
107
108 let wtr_da_path = output_dir.join(Path::new("dict.da"));
109 let mut wtr_da = io::BufWriter::new(
110 File::create(wtr_da_path)
111 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
112 );
113
114 let wtr_vals_path = output_dir.join(Path::new("dict.vals"));
115 let mut wtr_vals = io::BufWriter::new(
116 File::create(wtr_vals_path)
117 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
118 );
119
120 let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
121
122 for (row_id, row) in rows.iter().enumerate() {
123 let word_cost = match i16::from_str(row[3].trim()) {
124 Ok(wc) => wc,
125 Err(_err) => {
126 if self.skip_invalid_cost_or_id {
127 warn!("failed to parse word_cost: {:?}", row);
128 continue;
129 } else {
130 return Err(LinderaErrorKind::Parse
131 .with_error(anyhow::anyhow!("failed to parse word_cost")));
132 }
133 }
134 };
135 let left_id = match u16::from_str(row[1].trim()) {
136 Ok(lid) => lid,
137 Err(_err) => {
138 if self.skip_invalid_cost_or_id {
139 warn!("failed to parse left_id: {:?}", row);
140 continue;
141 } else {
142 return Err(LinderaErrorKind::Parse
143 .with_error(anyhow::anyhow!("failed to parse left_id")));
144 }
145 }
146 };
147 let right_id = match u16::from_str(row[2].trim()) {
148 Ok(rid) => rid,
149 Err(_err) => {
150 if self.skip_invalid_cost_or_id {
151 warn!("failed to parse right_id: {:?}", row);
152 continue;
153 } else {
154 return Err(LinderaErrorKind::Parse
155 .with_error(anyhow::anyhow!("failed to parse right_id")));
156 }
157 }
158 };
159 let key = if self.normalize_details {
160 normalize(&row[0])
161 } else {
162 row[0].to_string()
163 };
164 word_entry_map.entry(key).or_default().push(WordEntry {
165 word_id: WordId(row_id as u32, true),
166 word_cost,
167 left_id,
168 right_id,
169 });
170 }
171
172 let wtr_words_path = output_dir.join(Path::new("dict.words"));
173 let mut wtr_words = io::BufWriter::new(
174 File::create(wtr_words_path)
175 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
176 );
177
178 let wtr_words_idx_path = output_dir.join(Path::new("dict.wordsidx"));
179 let mut wtr_words_idx = io::BufWriter::new(
180 File::create(wtr_words_idx_path)
181 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
182 );
183
184 let mut words_buffer = Vec::new();
185 let mut words_idx_buffer = Vec::new();
186 for row in rows.iter() {
187 let offset = words_buffer.len();
188 words_idx_buffer
189 .write_u32::<LittleEndian>(offset as u32)
190 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
191
192 let joined_details = if self.normalize_details {
193 row.iter()
194 .skip(4)
195 .map(|item| normalize(item))
196 .collect::<Vec<String>>()
197 .join("\0")
198 } else {
199 row.iter().skip(4).collect::<Vec<&str>>().join("\0")
200 };
201 let joined_details_len = u32::try_from(joined_details.as_bytes().len())
202 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
203 words_buffer
204 .write_u32::<LittleEndian>(joined_details_len)
205 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
206 words_buffer
207 .write_all(joined_details.as_bytes())
208 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
209 }
210
211 compress_write(&words_buffer, self.compress_algorithm, &mut wtr_words)?;
212 compress_write(
213 &words_idx_buffer,
214 self.compress_algorithm,
215 &mut wtr_words_idx,
216 )?;
217
218 wtr_words
219 .flush()
220 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
221 wtr_words_idx
222 .flush()
223 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
224
225 let mut id = 0u32;
226
227 let mut keyset: Vec<(&[u8], u32)> = vec![];
228 for (key, word_entries) in &word_entry_map {
229 let len = word_entries.len() as u32;
230 let val = (id << 5) | len; keyset.push((key.as_bytes(), val));
232 id += len;
233 }
234
235 let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
236 LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
237 })?;
238
239 compress_write(&da_bytes, self.compress_algorithm, &mut wtr_da)?;
240
241 let mut vals_buffer = Vec::new();
242 for word_entries in word_entry_map.values() {
243 for word_entry in word_entries {
244 word_entry
245 .serialize(&mut vals_buffer)
246 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
247 }
248 }
249
250 compress_write(&vals_buffer, self.compress_algorithm, &mut wtr_vals)?;
251
252 wtr_vals
253 .flush()
254 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
255
256 Ok(())
257 }
258}
259
/// Normalizes dictionary text by mapping variant punctuation to canonical
/// forms: '―' (U+2015 horizontal bar) becomes '—' (U+2014 em dash) and
/// '~' (U+007E ASCII tilde) becomes '〜' (U+301C wave dash).
fn normalize(text: &str) -> String {
    // `str::replace` takes `&str` directly; the original's intermediate
    // `to_string()` was an unnecessary extra allocation.
    text.replace('―', "—").replace('~', "〜")
}