lindera_dictionary/builder/prefix_dictionary.rs

use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Write;
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::str::FromStr;

use anyhow::anyhow;
use byteorder::{LittleEndian, WriteBytesExt};
use csv::StringRecord;
use daachorse::DoubleArrayAhoCorasickBuilder;
use derive_builder::Builder;
use encoding_rs::{Encoding, UTF_8};
use encoding_rs_io::DecodeReaderBytesBuilder;
use glob::glob;
use log::debug;

use crate::LinderaResult;
use crate::decompress::Algorithm;
use crate::dictionary::schema::Schema;
use crate::error::LinderaErrorKind;
use crate::util::compress_write;
use crate::viterbi::WordEntry;

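/// Builds the prefix dictionary files (`dict.da`, `dict.vals`, `dict.words`,
/// and `dict.wordsidx`) from a directory of MeCab-style CSV lexicon files.
///
/// Construct it with [`PrefixDictionaryBuilder::new`] for the defaults, or via
/// the derived `PrefixDictionaryBuilderOptions` to override them. A minimal
/// usage sketch (the setters follow `derive_builder`'s conventions; the paths
/// are placeholders):
///
/// ```ignore
/// let builder = PrefixDictionaryBuilderOptions::default()
///     .encoding("EUC-JP")
///     .normalize_details(true)
///     .builder()?;
/// builder.build(Path::new("mecab-ipadic"), Path::new("out"))?;
/// ```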
#[derive(Builder)]
#[builder(name = PrefixDictionaryBuilderOptions)]
#[builder(build_fn(name = "builder"))]
pub struct PrefixDictionaryBuilder {
    #[builder(default = "true")]
    flexible_csv: bool,
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
    #[builder(default = "Algorithm::Deflate")]
    compress_algorithm: Algorithm,
    #[builder(default = "false")]
    normalize_details: bool,
    #[builder(default = "false")]
    skip_invalid_cost_or_id: bool,
    #[builder(default = "Schema::default()")]
    schema: Schema,
}

impl PrefixDictionaryBuilder {
    pub fn new(schema: Schema) -> Self {
        Self {
            flexible_csv: true,
            encoding: "UTF-8".into(),
            compress_algorithm: Algorithm::Deflate,
            normalize_details: false,
            skip_invalid_cost_or_id: false,
            schema,
        }
    }

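    /// Builds the prefix dictionary from the CSV files in `input_dir` and
    /// writes the resulting artifacts into `output_dir`, in three phases:
    /// load and sort the CSV rows, group them into a word entry map keyed by
    /// surface form, and write the dictionary files.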
    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
        let rows = self.load_csv_data(input_dir)?;

        let word_entry_map = self.build_word_entry_map(&rows)?;

        self.write_dictionary_files(output_dir, &rows, &word_entry_map)?;

        Ok(())
    }

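    /// Loads all CSV rows from `input_dir`, sorted by surface form (the first
    /// column). When `normalize_details` is set, surfaces are normalized
    /// before sorting so the order matches the normalized keys stored later.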
    fn load_csv_data(&self, input_dir: &Path) -> LinderaResult<Vec<StringRecord>> {
        let filenames = self.collect_csv_files(input_dir)?;
        let encoding = self.get_encoding()?;
        let mut rows = self.read_csv_files(&filenames, encoding)?;

        if self.normalize_details {
            rows.sort_by_key(|row| normalize(&row[0]));
        } else {
            rows.sort_by(|a, b| a[0].cmp(&b[0]));
        }

        Ok(rows)
    }

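    /// Collects the paths of all `*.csv` files directly under `input_dir`.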
    fn collect_csv_files(&self, input_dir: &Path) -> LinderaResult<Vec<PathBuf>> {
        let pattern = if let Some(path) = input_dir.to_str() {
            format!("{path}/*.csv")
        } else {
            return Err(LinderaErrorKind::Io
                .with_error(anyhow::anyhow!("Failed to convert path to &str."))
                .add_context(format!(
                    "Input directory path contains invalid characters: {input_dir:?}"
                )));
        };

        let mut filenames: Vec<PathBuf> = Vec::new();
        for entry in glob(&pattern).map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!("Failed to glob CSV files with pattern: {pattern}"))
        })? {
            match entry {
                Ok(path) => {
                    if let Some(filename) = path.file_name() {
                        filenames.push(Path::new(input_dir).join(filename));
                    } else {
                        return Err(LinderaErrorKind::Io
                            .with_error(anyhow::anyhow!("failed to get filename"))
                            .add_context(format!("Invalid filename in path: {path:?}")));
                    }
                }
                Err(err) => {
                    return Err(LinderaErrorKind::Content
                        .with_error(anyhow!(err))
                        .add_context(format!(
                            "Failed to process glob entry with pattern: {pattern}"
                        )));
                }
            }
        }

        Ok(filenames)
    }

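    /// Resolves the configured encoding label (e.g. "UTF-8", "EUC-JP") to an
    /// `encoding_rs` encoding, failing on labels it does not recognize.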
    fn get_encoding(&self) -> LinderaResult<&'static Encoding> {
        let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
        encoding.ok_or_else(|| {
            LinderaErrorKind::Decode
                .with_error(anyhow!("Invalid encoding: {}", self.encoding))
                .add_context("Failed to get encoding for CSV file reading")
        })
    }

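    /// Reads every record from the given CSV files, transcoding to UTF-8 on
    /// the fly when the source encoding differs.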
    fn read_csv_files(
        &self,
        filenames: &[PathBuf],
        encoding: &'static Encoding,
    ) -> LinderaResult<Vec<StringRecord>> {
        let mut rows: Vec<StringRecord> = vec![];

        for filename in filenames {
            debug!("reading {filename:?}");

            let file = File::open(filename).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!("Failed to open CSV file: {filename:?}"))
            })?;
            let reader: Box<dyn Read> = if encoding == UTF_8 {
                Box::new(file)
            } else {
                Box::new(
                    DecodeReaderBytesBuilder::new()
                        .encoding(Some(encoding))
                        .build(file),
                )
            };
            let mut rdr = csv::ReaderBuilder::new()
                .has_headers(false)
                .flexible(self.flexible_csv)
                .from_reader(reader);

            for result in rdr.records() {
                let record = result.map_err(|err| {
                    LinderaErrorKind::Content
                        .with_error(anyhow!(err))
                        .add_context(format!("Failed to parse CSV record in file: {filename:?}"))
                })?;
                rows.push(record);
            }
        }

        Ok(rows)
    }

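    /// Groups rows into a map from surface form to word entries. Rows missing
    /// a cost or context ID (or carrying invalid ones, when
    /// `skip_invalid_cost_or_id` is set) are skipped; each row's index becomes
    /// its system-lexicon word ID.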
    fn build_word_entry_map(
        &self,
        rows: &[StringRecord],
    ) -> LinderaResult<BTreeMap<String, Vec<WordEntry>>> {
        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();

        for (row_id, row) in rows.iter().enumerate() {
            let word_cost = self.parse_word_cost(row)?;
            let left_id = self.parse_left_id(row)?;
            let right_id = self.parse_right_id(row)?;

            if word_cost.is_none() || left_id.is_none() || right_id.is_none() {
                continue;
            }

            let key = if self.normalize_details {
                if let Some(surface) = self.get_field_value(row, "surface")? {
                    normalize(&surface)
                } else {
                    continue;
                }
            } else if let Some(surface) = self.get_field_value(row, "surface")? {
                surface
            } else {
                continue;
            };

            word_entry_map.entry(key).or_default().push(WordEntry {
                word_id: crate::viterbi::WordId::new(
                    crate::viterbi::LexType::System,
                    row_id as u32,
                ),
                word_cost: word_cost.unwrap(),
                left_id: left_id.unwrap(),
                right_id: right_id.unwrap(),
            });
        }

        Ok(word_entry_map)
    }

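    /// Looks up a field by its schema name and returns the trimmed value, or
    /// `None` when the field is unknown to the schema, out of bounds for this
    /// row, or empty.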
    fn get_field_value(
        &self,
        row: &StringRecord,
        field_name: &str,
    ) -> LinderaResult<Option<String>> {
        if let Some(index) = self.schema.get_field_index(field_name) {
            if index >= row.len() {
                return Ok(None);
            }

            let value = row[index].trim();
            Ok(if value.is_empty() {
                None
            } else {
                Some(value.to_string())
            })
        } else {
            Ok(None)
        }
    }

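    /// Parses the `cost` field as an `i16`. An unparsable value is an error
    /// unless `skip_invalid_cost_or_id` is set, in which case `None` is
    /// returned and the row is skipped upstream. `parse_left_id` and
    /// `parse_right_id` below apply the same policy to the context ID fields.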
    fn parse_word_cost(&self, row: &StringRecord) -> LinderaResult<Option<i16>> {
        let cost_str = self.get_field_value(row, "cost")?;
        match cost_str {
            Some(s) => match i16::from_str(&s) {
                Ok(cost) => Ok(Some(cost)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid cost value: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

    fn parse_left_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
        let left_id_str = self.get_field_value(row, "left_context_id")?;
        match left_id_str {
            Some(s) => match u16::from_str(&s) {
                Ok(id) => Ok(Some(id)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid left context ID: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

    fn parse_right_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
        let right_id_str = self.get_field_value(row, "right_context_id")?;
        match right_id_str {
            Some(s) => match u16::from_str(&s) {
                Ok(id) => Ok(Some(id)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid right context ID: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

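    /// Writes all output files: the word details (`dict.words` and
    /// `dict.wordsidx`), the double array (`dict.da`), and the word entry
    /// values (`dict.vals`).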
    fn write_dictionary_files(
        &self,
        output_dir: &Path,
        rows: &[StringRecord],
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        self.write_words_files(output_dir, rows)?;

        self.write_double_array_file(output_dir, word_entry_map)?;

        self.write_values_file(output_dir, word_entry_map)?;

        Ok(())
    }

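    /// Serializes word details into `dict.words` and `dict.wordsidx`. For
    /// each row, the detail columns (everything after the first four fields)
    /// are joined with NUL bytes and written length-prefixed into
    /// `dict.words`, while `dict.wordsidx` records each entry's little-endian
    /// `u32` byte offset. Both buffers are compressed on the way to disk.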
    fn write_words_files(&self, output_dir: &Path, rows: &[StringRecord]) -> LinderaResult<()> {
        let mut dict_words_buffer = Vec::new();
        let mut dict_wordsidx_buffer = Vec::new();

        for row in rows.iter() {
            let offset = dict_words_buffer.len();
            dict_wordsidx_buffer
                .write_u32::<LittleEndian>(offset as u32)
                .map_err(|err| {
                    LinderaErrorKind::Io
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word index offset to dict.wordsidx buffer")
                })?;

            let joined_details = if self.normalize_details {
                row.iter()
                    .skip(4)
                    .map(normalize)
                    .collect::<Vec<String>>()
                    .join("\0")
            } else {
                row.iter().skip(4).collect::<Vec<&str>>().join("\0")
            };
            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
                LinderaErrorKind::Serialize
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Word details length too large: {} bytes",
                        joined_details.len()
                    ))
            })?;

            dict_words_buffer
                .write_u32::<LittleEndian>(joined_details_len)
                .map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word details length to dict.words buffer")
                })?;
            dict_words_buffer
                .write_all(joined_details.as_bytes())
                .map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word details to dict.words buffer")
                })?;
        }

        let dict_words_path = output_dir.join(Path::new("dict.words"));
        let mut dict_words_writer =
            io::BufWriter::new(File::create(&dict_words_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.words file: {dict_words_path:?}"
                    ))
            })?);

        compress_write(
            &dict_words_buffer,
            self.compress_algorithm,
            &mut dict_words_writer,
        )?;

        dict_words_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.words file: {dict_words_path:?}"
                ))
        })?;

        let dict_wordsidx_path = output_dir.join(Path::new("dict.wordsidx"));
        let mut dict_wordsidx_writer =
            io::BufWriter::new(File::create(&dict_wordsidx_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.wordsidx file: {dict_wordsidx_path:?}"
                    ))
            })?);

        compress_write(
            &dict_wordsidx_buffer,
            self.compress_algorithm,
            &mut dict_wordsidx_writer,
        )?;

        dict_wordsidx_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.wordsidx file: {dict_wordsidx_path:?}"
                ))
        })?;

        Ok(())
    }

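    /// Builds a double-array Aho-Corasick automaton over all surface forms
    /// and writes it, compressed, to `dict.da`.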
    fn write_double_array_file(
        &self,
        output_dir: &Path,
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        let mut id = 0u32;
        let mut keyset: Vec<(&[u8], u32)> = vec![];
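        // Pack each surface form's value as (offset << 5) | count: the upper
        // 27 bits give the index of its first WordEntry in dict.vals and the
        // lower 5 bits the number of entries, which caps homonyms at 31 per
        // surface. The offsets stay valid because write_values_file walks the
        // same BTreeMap in the same key order.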
        for (key, word_entries) in word_entry_map {
            let len = word_entries.len() as u32;
            let val = (id << 5) | len;
            keyset.push((key.as_bytes(), val));
            id += len;
        }

        let keyset_len = keyset.len();

        let dict_da = DoubleArrayAhoCorasickBuilder::new()
            .build_with_values(keyset)
            .map_err(|err| {
                LinderaErrorKind::Build
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to build DoubleArray with {} keys for prefix dictionary",
                        keyset_len
                    ))
            })?;

        let dict_da_buffer = dict_da.serialize();

        let dict_da_path = output_dir.join(Path::new("dict.da"));
        let mut dict_da_writer =
            io::BufWriter::new(File::create(&dict_da_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!("Failed to create dict.da file: {dict_da_path:?}"))
            })?);

        compress_write(
            &dict_da_buffer,
            self.compress_algorithm,
            &mut dict_da_writer,
        )?;

        dict_da_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!("Failed to flush dict.da file: {dict_da_path:?}"))
        })?;

        Ok(())
    }

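    /// Serializes every `WordEntry` into `dict.vals`, grouped by surface form
    /// in key order so the offsets packed into `dict.da` line up.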
    fn write_values_file(
        &self,
        output_dir: &Path,
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        let mut dict_vals_buffer = Vec::new();
        for word_entries in word_entry_map.values() {
            for word_entry in word_entries {
                word_entry.serialize(&mut dict_vals_buffer).map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context(format!(
                            "Failed to serialize word entry (id: {})",
                            word_entry.word_id.id
                        ))
                })?;
            }
        }

        let dict_vals_path = output_dir.join(Path::new("dict.vals"));
        let mut dict_vals_writer =
            io::BufWriter::new(File::create(&dict_vals_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.vals file: {dict_vals_path:?}"
                    ))
            })?);

        compress_write(
            &dict_vals_buffer,
            self.compress_algorithm,
            &mut dict_vals_writer,
        )?;

        dict_vals_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.vals file: {dict_vals_path:?}"
                ))
        })?;

        Ok(())
    }
}

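/// Normalizes characters that CSV dictionary sources encode inconsistently:
/// HORIZONTAL BAR (U+2015) becomes EM DASH (U+2014), and ASCII tilde (U+007E)
/// becomes WAVE DASH (U+301C).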
fn normalize(text: &str) -> String {
    text.replace('―', "—").replace('~', "〜")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dictionary::schema::Schema;
    use csv::StringRecord;

    #[test]
    fn test_new_with_schema() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema.clone());

        assert!(builder.flexible_csv);
        assert_eq!(builder.encoding, "UTF-8");
        assert!(!builder.normalize_details);
        assert!(!builder.skip_invalid_cost_or_id);
    }

    #[test]
    fn test_get_common_field_value_empty() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["", "123", "456", "789"]);

        let surface = builder.get_field_value(&record, "surface").unwrap();
        assert_eq!(surface, None);
    }

    #[test]
    fn test_get_common_field_value_out_of_bounds() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form"]);

        let left_id = builder.get_field_value(&record, "left_context_id").unwrap();
        assert_eq!(left_id, None);
    }

    #[test]
    fn test_parse_word_cost() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, Some(789));
    }

    #[test]
    fn test_parse_word_cost_invalid() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        let result = builder.parse_word_cost(&record);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_word_cost_skip_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.skip_invalid_cost_or_id = true;

        let record = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, None);
    }

    #[test]
    fn test_parse_left_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let left_id = builder.parse_left_id(&record).unwrap();
        assert_eq!(left_id, Some(123));
    }

    #[test]
    fn test_parse_right_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let right_id = builder.parse_right_id(&record).unwrap();
        assert_eq!(right_id, Some(456));
    }

    #[test]
    fn test_normalize_function() {
        assert_eq!(normalize("test―text"), "test—text");
        assert_eq!(normalize("test~text"), "test〜text");
        assert_eq!(normalize("test―text~more"), "test—text〜more");
        assert_eq!(normalize("normal text"), "normal text");
    }

    #[test]
    fn test_get_encoding() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let encoding = builder.get_encoding().unwrap();
        assert_eq!(encoding.name(), "UTF-8");
    }

    #[test]
    fn test_get_encoding_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.encoding = "INVALID-ENCODING".into();

        let result = builder.get_encoding();
        assert!(result.is_err());
    }

    #[test]
    fn test_get_common_field_value() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["word", "123", "456", "789", "名詞"]);

        assert_eq!(
            builder.get_field_value(&record, "surface").unwrap(),
            Some("word".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "left_context_id").unwrap(),
            Some("123".to_string())
        );
        assert_eq!(
            builder
                .get_field_value(&record, "right_context_id")
                .unwrap(),
            Some("456".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "cost").unwrap(),
            Some("789".to_string())
        );

        let short_record = StringRecord::from(vec!["word", "123"]);
        assert_eq!(
            builder.get_field_value(&short_record, "cost").unwrap(),
            None
        );
    }
}