// lindera_dictionary/builder/prefix_dictionary.rs

use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Write;
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::str::FromStr;

use anyhow::anyhow;
use byteorder::{LittleEndian, WriteBytesExt};
use csv::StringRecord;
use daachorse::DoubleArrayAhoCorasickBuilder;
use derive_builder::Builder;
use encoding_rs::{Encoding, UTF_8};
use encoding_rs_io::DecodeReaderBytesBuilder;
use glob::glob;
use log::debug;

use crate::LinderaResult;
use crate::dictionary::schema::Schema;
use crate::error::LinderaErrorKind;
use crate::util::write_data;
use crate::viterbi::WordEntry;

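/// Builds the prefix dictionary files (`dict.da`, `dict.vals`, `dict.words`,
/// `dict.wordsidx`) from lexicon CSV files found in an input directory.
///
/// A minimal usage sketch; the input/output paths below are illustrative, not part of
/// this crate:
///
/// ```ignore
/// use std::path::Path;
///
/// let builder = PrefixDictionaryBuilder::new(Schema::default());
/// builder.build(Path::new("./lexicon-csv-dir"), Path::new("./dict-out"))?;
/// ```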
#[derive(Builder)]
#[builder(name = PrefixDictionaryBuilderOptions)]
#[builder(build_fn(name = "builder"))]
pub struct PrefixDictionaryBuilder {
    #[builder(default = "true")]
    flexible_csv: bool,
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
    #[builder(default = "false")]
    normalize_details: bool,
    #[builder(default = "false")]
    skip_invalid_cost_or_id: bool,
    #[builder(default = "Schema::default()")]
    schema: Schema,
}

impl PrefixDictionaryBuilder {
    pub fn new(schema: Schema) -> Self {
        Self {
            flexible_csv: true,
            encoding: "UTF-8".into(),
            normalize_details: false,
            skip_invalid_cost_or_id: false,
            schema,
        }
    }

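    /// Runs the full build: loads and sorts the lexicon CSV rows, derives the word
    /// entry map, and writes the dictionary files into `output_dir`.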
    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
        let rows = self.load_csv_data(input_dir)?;

        let word_entry_map = self.build_word_entry_map(&rows)?;

        self.write_dictionary_files(output_dir, &rows, &word_entry_map)?;

        Ok(())
    }

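    /// Reads all CSV rows from `input_dir` and sorts them by surface form (the first
    /// column), normalized first when `normalize_details` is set.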
    fn load_csv_data(&self, input_dir: &Path) -> LinderaResult<Vec<StringRecord>> {
        let filenames = self.collect_csv_files(input_dir)?;
        let encoding = self.get_encoding()?;
        let mut rows = self.read_csv_files(&filenames, encoding)?;

        if self.normalize_details {
            rows.sort_by_key(|row| normalize(&row[0]));
        } else {
            rows.sort_by(|a, b| a[0].cmp(&b[0]));
        }

        Ok(rows)
    }

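    /// Globs `*.csv` under `input_dir` and returns the matching file paths.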
    fn collect_csv_files(&self, input_dir: &Path) -> LinderaResult<Vec<PathBuf>> {
        let pattern = if let Some(path) = input_dir.to_str() {
            format!("{path}/*.csv")
        } else {
            return Err(LinderaErrorKind::Io
                .with_error(anyhow::anyhow!("Failed to convert path to &str."))
                .add_context(format!(
                    "Input directory path contains invalid characters: {input_dir:?}"
                )));
        };

        let mut filenames: Vec<PathBuf> = Vec::new();
        for entry in glob(&pattern).map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!("Failed to glob CSV files with pattern: {pattern}"))
        })? {
            match entry {
                Ok(path) => {
                    if let Some(filename) = path.file_name() {
                        filenames.push(Path::new(input_dir).join(filename));
                    } else {
                        return Err(LinderaErrorKind::Io
                            .with_error(anyhow::anyhow!("failed to get filename"))
                            .add_context(format!("Invalid filename in path: {path:?}")));
                    };
                }
                Err(err) => {
                    return Err(LinderaErrorKind::Content
                        .with_error(anyhow!(err))
                        .add_context(format!(
                            "Failed to process glob entry with pattern: {pattern}"
                        )));
                }
            }
        }

        Ok(filenames)
    }

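    /// Resolves the configured encoding label to an `encoding_rs` encoding.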
    fn get_encoding(&self) -> LinderaResult<&'static Encoding> {
        let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
        encoding.ok_or_else(|| {
            LinderaErrorKind::Decode
                .with_error(anyhow!("Invalid encoding: {}", self.encoding))
                .add_context("Failed to get encoding for CSV file reading")
        })
    }

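    /// Reads every CSV file, transcoding to UTF-8 on the fly when the source encoding
    /// is not UTF-8, and collects all records into a single vector.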
    fn read_csv_files(
        &self,
        filenames: &[PathBuf],
        encoding: &'static Encoding,
    ) -> LinderaResult<Vec<StringRecord>> {
        let mut rows: Vec<StringRecord> = vec![];

        for filename in filenames {
            debug!("reading {filename:?}");

            let file = File::open(filename).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!("Failed to open CSV file: {filename:?}"))
            })?;
            let reader: Box<dyn Read> = if encoding == UTF_8 {
                Box::new(file)
            } else {
                Box::new(
                    DecodeReaderBytesBuilder::new()
                        .encoding(Some(encoding))
                        .build(file),
                )
            };
            let mut rdr = csv::ReaderBuilder::new()
                .has_headers(false)
                .flexible(self.flexible_csv)
                .from_reader(reader);

            for result in rdr.records() {
                let record = result.map_err(|err| {
                    LinderaErrorKind::Content
                        .with_error(anyhow!(err))
                        .add_context(format!("Failed to parse CSV record in file: {filename:?}"))
                })?;
                rows.push(record);
            }
        }

        Ok(rows)
    }

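    /// Groups rows by (optionally normalized) surface form and converts each row into
    /// a `WordEntry`. Rows with a missing or invalid cost/context ID are skipped or
    /// rejected, depending on `skip_invalid_cost_or_id`.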
    fn build_word_entry_map(
        &self,
        rows: &[StringRecord],
    ) -> LinderaResult<BTreeMap<String, Vec<WordEntry>>> {
        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();

        for (row_id, row) in rows.iter().enumerate() {
            let word_cost = self.parse_word_cost(row)?;
            let left_id = self.parse_left_id(row)?;
            let right_id = self.parse_right_id(row)?;

            if word_cost.is_none() || left_id.is_none() || right_id.is_none() {
                continue;
            }

            let key = if self.normalize_details {
                if let Some(surface) = self.get_field_value(row, "surface")? {
                    normalize(&surface)
                } else {
                    continue;
                }
            } else if let Some(surface) = self.get_field_value(row, "surface")? {
                surface
            } else {
                continue;
            };

            word_entry_map.entry(key).or_default().push(WordEntry {
                word_id: crate::viterbi::WordId::new(
                    crate::viterbi::LexType::System,
                    row_id as u32,
                ),
                word_cost: word_cost.unwrap(),
                left_id: left_id.unwrap(),
                right_id: right_id.unwrap(),
            });
        }

        Ok(word_entry_map)
    }

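    /// Looks up a column by its schema field name and returns the trimmed value,
    /// or `None` when the field is unknown, out of bounds, or empty.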
    fn get_field_value(
        &self,
        row: &StringRecord,
        field_name: &str,
    ) -> LinderaResult<Option<String>> {
        if let Some(index) = self.schema.get_field_index(field_name) {
            if index >= row.len() {
                return Ok(None);
            }

            let value = row[index].trim();
            Ok(if value.is_empty() {
                None
            } else {
                Some(value.to_string())
            })
        } else {
            Ok(None)
        }
    }

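    /// Parses the `cost` column as `i16`; an invalid value is an error unless
    /// `skip_invalid_cost_or_id` is set, in which case `None` is returned.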
    fn parse_word_cost(&self, row: &StringRecord) -> LinderaResult<Option<i16>> {
        let cost_str = self.get_field_value(row, "cost")?;
        match cost_str {
            Some(s) => match i16::from_str(&s) {
                Ok(cost) => Ok(Some(cost)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid cost value: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

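    /// Parses the `left_context_id` column as `u16`, with the same invalid-value
    /// handling as `parse_word_cost`.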
    fn parse_left_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
        let left_id_str = self.get_field_value(row, "left_context_id")?;
        match left_id_str {
            Some(s) => match u16::from_str(&s) {
                Ok(id) => Ok(Some(id)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid left context ID: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

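    /// Parses the `right_context_id` column as `u16`, with the same invalid-value
    /// handling as `parse_word_cost`.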
    fn parse_right_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
        let right_id_str = self.get_field_value(row, "right_context_id")?;
        match right_id_str {
            Some(s) => match u16::from_str(&s) {
                Ok(id) => Ok(Some(id)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid right context ID: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

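    /// Writes all output files: the word details (`dict.words`/`dict.wordsidx`), the
    /// double-array trie (`dict.da`), and the word entry values (`dict.vals`).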
    fn write_dictionary_files(
        &self,
        output_dir: &Path,
        rows: &[StringRecord],
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        self.write_words_files(output_dir, rows)?;

        self.write_double_array_file(output_dir, word_entry_map)?;

        self.write_values_file(output_dir, word_entry_map)?;

        Ok(())
    }

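    /// Serializes each row's detail columns (everything after the first four fields,
    /// joined with NUL bytes and length-prefixed) into `dict.words`, and writes the
    /// byte offset of each entry into `dict.wordsidx`.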
    fn write_words_files(&self, output_dir: &Path, rows: &[StringRecord]) -> LinderaResult<()> {
        let mut dict_words_buffer = Vec::new();
        let mut dict_wordsidx_buffer = Vec::new();

        for row in rows.iter() {
            let offset = dict_words_buffer.len();
            dict_wordsidx_buffer
                .write_u32::<LittleEndian>(offset as u32)
                .map_err(|err| {
                    LinderaErrorKind::Io
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word index offset to dict.wordsidx buffer")
                })?;

            let joined_details = if self.normalize_details {
                row.iter()
                    .skip(4)
                    .map(normalize)
                    .collect::<Vec<String>>()
                    .join("\0")
            } else {
                row.iter().skip(4).collect::<Vec<&str>>().join("\0")
            };
            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
                LinderaErrorKind::Serialize
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Word details length too large: {} bytes",
                        joined_details.len()
                    ))
            })?;

            dict_words_buffer
                .write_u32::<LittleEndian>(joined_details_len)
                .map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word details length to dict.words buffer")
                })?;
            dict_words_buffer
                .write_all(joined_details.as_bytes())
                .map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word details to dict.words buffer")
                })?;
        }

        let dict_words_path = output_dir.join(Path::new("dict.words"));
        let mut dict_words_writer =
            io::BufWriter::new(File::create(&dict_words_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.words file: {dict_words_path:?}"
                    ))
            })?);

        write_data(&dict_words_buffer, &mut dict_words_writer)?;

        dict_words_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.words file: {dict_words_path:?}"
                ))
        })?;

        let dict_wordsidx_path = output_dir.join(Path::new("dict.wordsidx"));
        let mut dict_wordsidx_writer =
            io::BufWriter::new(File::create(&dict_wordsidx_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.wordsidx file: {dict_wordsidx_path:?}"
                    ))
            })?);

        write_data(&dict_wordsidx_buffer, &mut dict_wordsidx_writer)?;

        dict_wordsidx_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.wordsidx file: {dict_wordsidx_path:?}"
                ))
        })?;

        Ok(())
    }

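    /// Builds a `daachorse` double-array automaton over all surface forms and writes
    /// its serialized form to `dict.da`. Each key's value packs the index of its first
    /// `WordEntry` in the upper bits and the number of entries in the low 8 bits.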
    fn write_double_array_file(
        &self,
        output_dir: &Path,
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        let mut id = 0u32;
        let mut keyset: Vec<(&[u8], u32)> = vec![];

        for (key, word_entries) in word_entry_map {
            let len = word_entries.len() as u32;
            let val = (id << 8) | len;
            keyset.push((key.as_bytes(), val));
            id += len;
        }

        let keyset_len = keyset.len();

        let dict_da = DoubleArrayAhoCorasickBuilder::new()
            .build_with_values(keyset)
            .map_err(|err| {
                LinderaErrorKind::Build
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to build DoubleArray with {} keys for prefix dictionary",
                        keyset_len
                    ))
            })?;

        let dict_da_buffer = dict_da.serialize();

        let dict_da_path = output_dir.join(Path::new("dict.da"));
        let mut dict_da_writer =
            io::BufWriter::new(File::create(&dict_da_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!("Failed to create dict.da file: {dict_da_path:?}"))
            })?);

        write_data(&dict_da_buffer, &mut dict_da_writer)?;

        Ok(())
    }

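    /// Serializes every `WordEntry`, in key order, into `dict.vals`.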
    fn write_values_file(
        &self,
        output_dir: &Path,
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        let mut dict_vals_buffer = Vec::new();
        for word_entries in word_entry_map.values() {
            for word_entry in word_entries {
                word_entry.serialize(&mut dict_vals_buffer).map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context(format!(
                            "Failed to serialize word entry (id: {})",
                            word_entry.word_id.id
                        ))
                })?;
            }
        }

        let dict_vals_path = output_dir.join(Path::new("dict.vals"));
        let mut dict_vals_writer =
            io::BufWriter::new(File::create(&dict_vals_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.vals file: {dict_vals_path:?}"
                    ))
            })?);

        write_data(&dict_vals_buffer, &mut dict_vals_writer)?;

        dict_vals_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.vals file: {dict_vals_path:?}"
                ))
        })?;

        Ok(())
    }
}

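/// Normalizes characters that commonly vary between dictionary sources: the horizontal
/// bar (U+2015) becomes an em dash (U+2014) and the ASCII tilde becomes a wave dash
/// (U+301C).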
fn normalize(text: &str) -> String {
    text.replace('―', "—").replace('~', "〜")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dictionary::schema::Schema;
    use csv::StringRecord;

    #[test]
    fn test_new_with_schema() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema.clone());

        assert!(builder.flexible_csv);
        assert_eq!(builder.encoding, "UTF-8");
        assert!(!builder.normalize_details);
        assert!(!builder.skip_invalid_cost_or_id);
    }

    #[test]
    fn test_get_common_field_value_empty() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["", "123", "456", "789"]);

        let surface = builder.get_field_value(&record, "surface").unwrap();
        assert_eq!(surface, None);
    }

    #[test]
    fn test_get_common_field_value_out_of_bounds() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form"]);

        let left_id = builder.get_field_value(&record, "left_context_id").unwrap();
        assert_eq!(left_id, None);
    }

    #[test]
    fn test_parse_word_cost() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, Some(789));
    }

    #[test]
    fn test_parse_word_cost_invalid() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        let result = builder.parse_word_cost(&record);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_word_cost_skip_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.skip_invalid_cost_or_id = true;

        let record = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, None);
    }

    #[test]
    fn test_parse_left_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let left_id = builder.parse_left_id(&record).unwrap();
        assert_eq!(left_id, Some(123));
    }

    #[test]
    fn test_parse_right_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let right_id = builder.parse_right_id(&record).unwrap();
        assert_eq!(right_id, Some(456));
    }

    #[test]
    fn test_normalize_function() {
        assert_eq!(normalize("test―text"), "test—text");
        assert_eq!(normalize("test~text"), "test〜text");
        assert_eq!(normalize("test―text~more"), "test—text〜more");
        assert_eq!(normalize("normal text"), "normal text");
    }

    #[test]
    fn test_get_encoding() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let encoding = builder.get_encoding().unwrap();
        assert_eq!(encoding.name(), "UTF-8");
    }

    #[test]
    fn test_get_encoding_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.encoding = "INVALID-ENCODING".into();

        let result = builder.get_encoding();
        assert!(result.is_err());
    }

    #[test]
    fn test_get_common_field_value() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["word", "123", "456", "789", "名詞"]);

        assert_eq!(
            builder.get_field_value(&record, "surface").unwrap(),
            Some("word".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "left_context_id").unwrap(),
            Some("123".to_string())
        );
        assert_eq!(
            builder
                .get_field_value(&record, "right_context_id")
                .unwrap(),
            Some("456".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "cost").unwrap(),
            Some("789".to_string())
        );

        let short_record = StringRecord::from(vec!["word", "123"]);
        assert_eq!(
            builder.get_field_value(&short_record, "cost").unwrap(),
            None
        );
    }
}