// lindera_dictionary/builder/prefix_dictionary.rs

use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Write;
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::str::FromStr;

use anyhow::anyhow;
use byteorder::{LittleEndian, WriteBytesExt};
use csv::StringRecord;
use derive_builder::Builder;
use encoding_rs::{Encoding, UTF_8};
use encoding_rs_io::DecodeReaderBytesBuilder;
use glob::glob;
use log::debug;
use yada::builder::DoubleArrayBuilder;

use crate::LinderaResult;
use crate::decompress::Algorithm;
use crate::dictionary::schema::Schema;
use crate::error::LinderaErrorKind;
use crate::util::compress_write;
use crate::viterbi::WordEntry;

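/// Builds the prefix dictionary files (`dict.da`, `dict.vals`, `dict.words`,
/// and `dict.wordsidx`) from a directory of CSV lexicon files.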
#[derive(Builder)]
#[builder(name = PrefixDictionaryBuilderOptions)]
#[builder(build_fn(name = "builder"))]
pub struct PrefixDictionaryBuilder {
    #[builder(default = "true")]
    flexible_csv: bool,
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
    #[builder(default = "Algorithm::Deflate")]
    compress_algorithm: Algorithm,
    #[builder(default = "false")]
    normalize_details: bool,
    #[builder(default = "false")]
    skip_invalid_cost_or_id: bool,
    #[builder(default = "Schema::default()")]
    schema: Schema,
}

impl PrefixDictionaryBuilder {
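    /// Creates a builder with default options and the given schema.
    ///
    /// A minimal usage sketch (the input/output paths are hypothetical):
    ///
    /// ```ignore
    /// let builder = PrefixDictionaryBuilder::new(Schema::default());
    /// builder.build(Path::new("csv_input_dir"), Path::new("dict_output_dir"))?;
    /// ```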
    pub fn new(schema: Schema) -> Self {
        Self {
            flexible_csv: true,
            encoding: "UTF-8".into(),
            compress_algorithm: Algorithm::Deflate,
            normalize_details: false,
            skip_invalid_cost_or_id: false,
            schema,
        }
    }

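    /// Builds the prefix dictionary: loads and sorts the CSV rows, groups
    /// them into a surface-form-to-word-entry map, and writes the dictionary
    /// files to `output_dir`.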
    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
        let rows = self.load_csv_data(input_dir)?;

        let word_entry_map = self.build_word_entry_map(&rows)?;

        self.write_dictionary_files(output_dir, &rows, &word_entry_map)?;

        Ok(())
    }

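    /// Loads every `*.csv` file under `input_dir`, decoding with the
    /// configured encoding, and returns the rows sorted by surface form
    /// (the first column).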
    fn load_csv_data(&self, input_dir: &Path) -> LinderaResult<Vec<StringRecord>> {
        let filenames = self.collect_csv_files(input_dir)?;
        let encoding = self.get_encoding()?;
        let mut rows = self.read_csv_files(&filenames, encoding)?;

        if self.normalize_details {
            rows.sort_by_key(|row| normalize(&row[0]));
        } else {
            rows.sort_by(|a, b| a[0].cmp(&b[0]));
        }

        Ok(rows)
    }

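    /// Globs `input_dir` for `*.csv` files and returns their paths.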
    fn collect_csv_files(&self, input_dir: &Path) -> LinderaResult<Vec<PathBuf>> {
        let pattern = if let Some(path) = input_dir.to_str() {
            format!("{path}/*.csv")
        } else {
            return Err(LinderaErrorKind::Io
                .with_error(anyhow::anyhow!("Failed to convert path to &str."))
                .add_context(format!(
                    "Input directory path contains invalid characters: {input_dir:?}"
                )));
        };

        let mut filenames: Vec<PathBuf> = Vec::new();
        for entry in glob(&pattern).map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!("Failed to glob CSV files with pattern: {pattern}"))
        })? {
            match entry {
                Ok(path) => {
                    if let Some(filename) = path.file_name() {
                        filenames.push(Path::new(input_dir).join(filename));
                    } else {
                        return Err(LinderaErrorKind::Io
                            .with_error(anyhow::anyhow!("failed to get filename"))
                            .add_context(format!("Invalid filename in path: {path:?}")));
                    }
                }
                Err(err) => {
                    return Err(LinderaErrorKind::Content
                        .with_error(anyhow!(err))
                        .add_context(format!(
                            "Failed to process glob entry with pattern: {pattern}"
                        )));
                }
            }
        }

        Ok(filenames)
    }

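    /// Resolves the configured encoding label (e.g. "UTF-8" or "EUC-JP") to
    /// an `encoding_rs` encoding, failing on unknown labels.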
    fn get_encoding(&self) -> LinderaResult<&'static Encoding> {
        let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
        encoding.ok_or_else(|| {
            LinderaErrorKind::Decode
                .with_error(anyhow!("Invalid encoding: {}", self.encoding))
                .add_context("Failed to get encoding for CSV file reading")
        })
    }

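    /// Reads all records from the given CSV files, transcoding to UTF-8 first
    /// when the source encoding is not already UTF-8.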
    fn read_csv_files(
        &self,
        filenames: &[PathBuf],
        encoding: &'static Encoding,
    ) -> LinderaResult<Vec<StringRecord>> {
        let mut rows: Vec<StringRecord> = vec![];

        for filename in filenames {
            debug!("reading {filename:?}");

            let file = File::open(filename).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!("Failed to open CSV file: {filename:?}"))
            })?;
            let reader: Box<dyn Read> = if encoding == UTF_8 {
                Box::new(file)
            } else {
                Box::new(
                    DecodeReaderBytesBuilder::new()
                        .encoding(Some(encoding))
                        .build(file),
                )
            };
            let mut rdr = csv::ReaderBuilder::new()
                .has_headers(false)
                .flexible(self.flexible_csv)
                .from_reader(reader);

            for result in rdr.records() {
                let record = result.map_err(|err| {
                    LinderaErrorKind::Content
                        .with_error(anyhow!(err))
                        .add_context(format!("Failed to parse CSV record in file: {filename:?}"))
                })?;
                rows.push(record);
            }
        }

        Ok(rows)
    }

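    /// Groups rows into word entries keyed by surface form (normalized when
    /// `normalize_details` is set). Rows with a missing cost or context ID
    /// are skipped, as are unparsable ones when `skip_invalid_cost_or_id`
    /// is set.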
    fn build_word_entry_map(
        &self,
        rows: &[StringRecord],
    ) -> LinderaResult<BTreeMap<String, Vec<WordEntry>>> {
        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();

        for (row_id, row) in rows.iter().enumerate() {
            let word_cost = self.parse_word_cost(row)?;
            let left_id = self.parse_left_id(row)?;
            let right_id = self.parse_right_id(row)?;

            if word_cost.is_none() || left_id.is_none() || right_id.is_none() {
                continue;
            }

            let key = if self.normalize_details {
                if let Some(surface) = self.get_field_value(row, "surface")? {
                    normalize(&surface)
                } else {
                    continue;
                }
            } else if let Some(surface) = self.get_field_value(row, "surface")? {
                surface
            } else {
                continue;
            };
            word_entry_map.entry(key).or_default().push(WordEntry {
                word_id: crate::viterbi::WordId::new(
                    crate::viterbi::LexType::System,
                    row_id as u32,
                ),
                word_cost: word_cost.unwrap(),
                left_id: left_id.unwrap(),
                right_id: right_id.unwrap(),
            });
        }

        Ok(word_entry_map)
    }

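    /// Looks up `field_name` in the schema and returns the trimmed value from
    /// `row`, or `None` if the field is unknown, out of bounds, or empty.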
    fn get_field_value(
        &self,
        row: &StringRecord,
        field_name: &str,
    ) -> LinderaResult<Option<String>> {
        if let Some(index) = self.schema.get_field_index(field_name) {
            if index >= row.len() {
                return Ok(None);
            }

            let value = row[index].trim();
            Ok(if value.is_empty() {
                None
            } else {
                Some(value.to_string())
            })
        } else {
            Ok(None)
        }
    }

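    /// Parses the `cost` field as `i16`. An unparsable value is an error
    /// unless `skip_invalid_cost_or_id` is set, in which case `None` is
    /// returned; a missing field always yields `None`.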
    fn parse_word_cost(&self, row: &StringRecord) -> LinderaResult<Option<i16>> {
        let cost_str = self.get_field_value(row, "cost")?;
        match cost_str {
            Some(s) => match i16::from_str(&s) {
                Ok(cost) => Ok(Some(cost)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid cost value: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

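    /// Parses the `left_context_id` field as `u16`, with the same
    /// invalid-value handling as `parse_word_cost`.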
    fn parse_left_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
        let left_id_str = self.get_field_value(row, "left_context_id")?;
        match left_id_str {
            Some(s) => match u16::from_str(&s) {
                Ok(id) => Ok(Some(id)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid left context ID: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

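    /// Parses the `right_context_id` field as `u16`, with the same
    /// invalid-value handling as `parse_word_cost`.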
    fn parse_right_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
        let right_id_str = self.get_field_value(row, "right_context_id")?;
        match right_id_str {
            Some(s) => match u16::from_str(&s) {
                Ok(id) => Ok(Some(id)),
                Err(_) => {
                    if self.skip_invalid_cost_or_id {
                        Ok(None)
                    } else {
                        Err(LinderaErrorKind::Content
                            .with_error(anyhow!("Invalid right context ID: {s}")))
                    }
                }
            },
            None => Ok(None),
        }
    }

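    /// Writes the four dictionary files: the word details (`dict.words` and
    /// `dict.wordsidx`), the double-array trie (`dict.da`), and the word
    /// entry values (`dict.vals`).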
    fn write_dictionary_files(
        &self,
        output_dir: &Path,
        rows: &[StringRecord],
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        self.write_words_files(output_dir, rows)?;

        self.write_double_array_file(output_dir, word_entry_map)?;

        self.write_values_file(output_dir, word_entry_map)?;

        Ok(())
    }

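    /// Serializes each row's detail fields (everything after the first four
    /// columns) into `dict.words` as length-prefixed, NUL-joined strings, and
    /// records each entry's byte offset in `dict.wordsidx`.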
    fn write_words_files(&self, output_dir: &Path, rows: &[StringRecord]) -> LinderaResult<()> {
        let mut dict_words_buffer = Vec::new();
        let mut dict_wordsidx_buffer = Vec::new();

        for row in rows.iter() {
            let offset = dict_words_buffer.len();
            dict_wordsidx_buffer
                .write_u32::<LittleEndian>(offset as u32)
                .map_err(|err| {
                    LinderaErrorKind::Io
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word index offset to dict.wordsidx buffer")
                })?;

            let joined_details = if self.normalize_details {
                row.iter()
                    .skip(4)
                    .map(normalize)
                    .collect::<Vec<String>>()
                    .join("\0")
            } else {
                row.iter().skip(4).collect::<Vec<&str>>().join("\0")
            };
            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
                LinderaErrorKind::Serialize
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Word details length too large: {} bytes",
                        joined_details.len()
                    ))
            })?;

            dict_words_buffer
                .write_u32::<LittleEndian>(joined_details_len)
                .map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word details length to dict.words buffer")
                })?;
            dict_words_buffer
                .write_all(joined_details.as_bytes())
                .map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context("Failed to write word details to dict.words buffer")
                })?;
        }

        let dict_words_path = output_dir.join(Path::new("dict.words"));
        let mut dict_words_writer =
            io::BufWriter::new(File::create(&dict_words_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.words file: {dict_words_path:?}"
                    ))
            })?);

        compress_write(
            &dict_words_buffer,
            self.compress_algorithm,
            &mut dict_words_writer,
        )?;

        dict_words_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.words file: {dict_words_path:?}"
                ))
        })?;

        let dict_wordsidx_path = output_dir.join(Path::new("dict.wordsidx"));
        let mut dict_wordsidx_writer =
            io::BufWriter::new(File::create(&dict_wordsidx_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.wordsidx file: {dict_wordsidx_path:?}"
                    ))
            })?);

        compress_write(
            &dict_wordsidx_buffer,
            self.compress_algorithm,
            &mut dict_wordsidx_writer,
        )?;

        dict_wordsidx_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.wordsidx file: {dict_wordsidx_path:?}"
                ))
        })?;

        Ok(())
    }

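    /// Builds the double-array trie over all surface forms and writes it to
    /// `dict.da`. Each key's value packs the index of its first `WordEntry`
    /// and its entry count into a single `u32`.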
    fn write_double_array_file(
        &self,
        output_dir: &Path,
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        let mut id = 0u32;
        let mut keyset: Vec<(&[u8], u32)> = vec![];

        for (key, word_entries) in word_entry_map {
            let len = word_entries.len() as u32;
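            // Pack the index of this key's first entry into the upper bits
            // and the entry count into the lower 5 bits of the trie value.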
            let val = (id << 5) | len;
            keyset.push((key.as_bytes(), val));
            id += len;
        }

        let dict_da_buffer = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
            LinderaErrorKind::Build
                .with_error(anyhow::anyhow!("DoubleArray build error."))
                .add_context(format!(
                    "Failed to build DoubleArray with {} keys for prefix dictionary",
                    keyset.len()
                ))
        })?;

        let dict_da_path = output_dir.join(Path::new("dict.da"));
        let mut dict_da_writer = io::BufWriter::new(File::create(&dict_da_path).map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!("Failed to create dict.da file: {dict_da_path:?}"))
        })?);

        compress_write(
            &dict_da_buffer,
            self.compress_algorithm,
            &mut dict_da_writer,
        )?;

        // Flush explicitly so write errors surface here, matching the other
        // dictionary writers instead of being swallowed on drop.
        dict_da_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!("Failed to flush dict.da file: {dict_da_path:?}"))
        })?;

        Ok(())
    }

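    /// Serializes every `WordEntry`, in surface-form order, into `dict.vals`.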
    fn write_values_file(
        &self,
        output_dir: &Path,
        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
    ) -> LinderaResult<()> {
        let mut dict_vals_buffer = Vec::new();
        for word_entries in word_entry_map.values() {
            for word_entry in word_entries {
                word_entry.serialize(&mut dict_vals_buffer).map_err(|err| {
                    LinderaErrorKind::Serialize
                        .with_error(anyhow::anyhow!(err))
                        .add_context(format!(
                            "Failed to serialize word entry (id: {})",
                            word_entry.word_id.id
                        ))
                })?;
            }
        }

        let dict_vals_path = output_dir.join(Path::new("dict.vals"));
        let mut dict_vals_writer =
            io::BufWriter::new(File::create(&dict_vals_path).map_err(|err| {
                LinderaErrorKind::Io
                    .with_error(anyhow::anyhow!(err))
                    .add_context(format!(
                        "Failed to create dict.vals file: {dict_vals_path:?}"
                    ))
            })?);

        compress_write(
            &dict_vals_buffer,
            self.compress_algorithm,
            &mut dict_vals_writer,
        )?;

        dict_vals_writer.flush().map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to flush dict.vals file: {dict_vals_path:?}"
                ))
        })?;

        Ok(())
    }
}

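/// Normalizes text by replacing the horizontal bar (U+2015) with an em dash
/// (U+2014) and the ASCII tilde with a wave dash (U+301C).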
fn normalize(text: &str) -> String {
    text.replace('―', "—").replace('~', "〜")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dictionary::schema::Schema;
    use csv::StringRecord;

    #[test]
    fn test_new_with_schema() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema.clone());

        assert!(builder.flexible_csv);
        assert_eq!(builder.encoding, "UTF-8");
        assert!(!builder.normalize_details);
        assert!(!builder.skip_invalid_cost_or_id);
    }

    #[test]
    fn test_get_common_field_value_empty() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["", "123", "456", "789"]);

        let surface = builder.get_field_value(&record, "surface").unwrap();
        assert_eq!(surface, None);
    }

    #[test]
    fn test_get_common_field_value_out_of_bounds() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form"]);

        let left_id = builder.get_field_value(&record, "left_context_id").unwrap();
        assert_eq!(left_id, None);
    }

    #[test]
    fn test_parse_word_cost() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, Some(789));
    }

    #[test]
    fn test_parse_word_cost_invalid() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        let result = builder.parse_word_cost(&record);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_word_cost_skip_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.skip_invalid_cost_or_id = true;

        let record = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, None);
    }

    #[test]
    fn test_parse_left_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let left_id = builder.parse_left_id(&record).unwrap();
        assert_eq!(left_id, Some(123));
    }

    #[test]
    fn test_parse_right_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["surface_form", "123", "456", "789"]);

        let right_id = builder.parse_right_id(&record).unwrap();
        assert_eq!(right_id, Some(456));
    }

    #[test]
    fn test_normalize_function() {
        assert_eq!(normalize("test―text"), "test—text");
        assert_eq!(normalize("test~text"), "test〜text");
        assert_eq!(normalize("test―text~more"), "test—text〜more");
        assert_eq!(normalize("normal text"), "normal text");
    }

    #[test]
    fn test_get_encoding() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let encoding = builder.get_encoding().unwrap();
        assert_eq!(encoding.name(), "UTF-8");
    }

    #[test]
    fn test_get_encoding_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.encoding = "INVALID-ENCODING".into();

        let result = builder.get_encoding();
        assert!(result.is_err());
    }

    #[test]
    fn test_get_common_field_value() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec!["word", "123", "456", "789", "名詞"]);

        assert_eq!(
            builder.get_field_value(&record, "surface").unwrap(),
            Some("word".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "left_context_id").unwrap(),
            Some("123".to_string())
        );
        assert_eq!(
            builder
                .get_field_value(&record, "right_context_id")
                .unwrap(),
            Some("456".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "cost").unwrap(),
            Some("789".to_string())
        );

        let short_record = StringRecord::from(vec!["word", "123"]);
        assert_eq!(
            builder.get_field_value(&short_record, "cost").unwrap(),
            None
        );
    }
}