// lindera_dictionary/builder/prefix_dictionary.rs

use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Write;
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::str::FromStr;

use anyhow::anyhow;
use byteorder::{LittleEndian, WriteBytesExt};
use csv::StringRecord;
use daachorse::DoubleArrayAhoCorasickBuilder;
use derive_builder::Builder;
use encoding_rs::{Encoding, UTF_8};
use encoding_rs_io::DecodeReaderBytesBuilder;
use glob::glob;
use log::debug;

use crate::LinderaResult;
use crate::dictionary::schema::Schema;
use crate::error::LinderaErrorKind;
use crate::util::write_data;
use crate::viterbi::WordEntry;

/// Builder for the prefix (system lexicon) dictionary files
/// (`dict.words`, `dict.wordsidx`, `dict.da`, `dict.vals`).
///
/// The derived `PrefixDictionaryBuilderOptions` type provides an
/// options-style constructor with the defaults declared below.
#[derive(Builder)]
#[builder(name = PrefixDictionaryBuilderOptions)]
#[builder(build_fn(name = "builder"))]
pub struct PrefixDictionaryBuilder {
    // Allow CSV records with a varying number of fields (csv "flexible" mode).
    #[builder(default = "true")]
    flexible_csv: bool,
    /* If set to UTF-8, it can also read UTF-16 files with BOM. */
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
    // When true, surface/detail text is normalized (―→—, ~→〜) before use.
    #[builder(default = "false")]
    normalize_details: bool,
    // When true, rows whose cost / context IDs fail to parse are skipped
    // instead of aborting the build with an error.
    #[builder(default = "false")]
    skip_invalid_cost_or_id: bool,
    // Field layout of the source CSV; used to map field names to column
    // indices via `Schema::get_field_index`.
    #[builder(default = "Schema::default()")]
    schema: Schema,
}

42impl PrefixDictionaryBuilder {
43    /// Create a new builder with the specified schema
44    pub fn new(schema: Schema) -> Self {
45        Self {
46            flexible_csv: true,
47            encoding: "UTF-8".into(),
48            normalize_details: false,
49            skip_invalid_cost_or_id: false,
50            schema,
51        }
52    }
53
54    /// Main method for building the dictionary
55    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
56        // 1. Load CSV data
57        let rows = self.load_csv_data(input_dir)?;
58
59        // 2. Build word entry map
60        let word_entry_map = self.build_word_entry_map(&rows)?;
61
62        // 3. Write dictionary files
63        self.write_dictionary_files(output_dir, &rows, &word_entry_map)?;
64
65        Ok(())
66    }
67
68    /// Load data from CSV files
69    fn load_csv_data(&self, input_dir: &Path) -> LinderaResult<Vec<StringRecord>> {
70        let filenames = self.collect_csv_files(input_dir)?;
71        let encoding = self.get_encoding()?;
72        let mut rows = self.read_csv_files(&filenames, encoding)?;
73
74        // Sort dictionary entries by the first column (word)
75        // Change sorting method based on normalization settings
76        if self.normalize_details {
77            // Sort after normalizing characters (―→—, ~→〜)
78            rows.sort_by_key(|row| normalize(&row[0]));
79        } else {
80            // Sort using original strings directly
81            rows.sort_by(|a, b| a[0].cmp(&b[0]))
82        }
83
84        Ok(rows)
85    }
86
87    /// Collect .csv file paths from input directory
88    fn collect_csv_files(&self, input_dir: &Path) -> LinderaResult<Vec<PathBuf>> {
89        let pattern = if let Some(path) = input_dir.to_str() {
90            format!("{path}/*.csv")
91        } else {
92            return Err(LinderaErrorKind::Io
93                .with_error(anyhow::anyhow!("Failed to convert path to &str."))
94                .add_context(format!(
95                    "Input directory path contains invalid characters: {input_dir:?}"
96                )));
97        };
98
99        let mut filenames: Vec<PathBuf> = Vec::new();
100        for entry in glob(&pattern).map_err(|err| {
101            LinderaErrorKind::Io
102                .with_error(anyhow::anyhow!(err))
103                .add_context(format!("Failed to glob CSV files with pattern: {pattern}"))
104        })? {
105            match entry {
106                Ok(path) => {
107                    if let Some(filename) = path.file_name() {
108                        filenames.push(Path::new(input_dir).join(filename));
109                    } else {
110                        return Err(LinderaErrorKind::Io
111                            .with_error(anyhow::anyhow!("failed to get filename"))
112                            .add_context(format!("Invalid filename in path: {path:?}")));
113                    };
114                }
115                Err(err) => {
116                    return Err(LinderaErrorKind::Content
117                        .with_error(anyhow!(err))
118                        .add_context(format!(
119                            "Failed to process glob entry with pattern: {pattern}"
120                        )));
121                }
122            }
123        }
124
125        Ok(filenames)
126    }
127
128    /// Get encoding configuration
129    fn get_encoding(&self) -> LinderaResult<&'static Encoding> {
130        let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
131        encoding.ok_or_else(|| {
132            LinderaErrorKind::Decode
133                .with_error(anyhow!("Invalid encoding: {}", self.encoding))
134                .add_context("Failed to get encoding for CSV file reading")
135        })
136    }
137
138    /// Read CSV files
139    fn read_csv_files(
140        &self,
141        filenames: &[PathBuf],
142        encoding: &'static Encoding,
143    ) -> LinderaResult<Vec<StringRecord>> {
144        let mut rows: Vec<StringRecord> = vec![];
145
146        for filename in filenames {
147            debug!("reading {filename:?}");
148
149            let file = File::open(filename).map_err(|err| {
150                LinderaErrorKind::Io
151                    .with_error(anyhow::anyhow!(err))
152                    .add_context(format!("Failed to open CSV file: {filename:?}"))
153            })?;
154            let reader: Box<dyn Read> = if encoding == UTF_8 {
155                Box::new(file)
156            } else {
157                Box::new(
158                    DecodeReaderBytesBuilder::new()
159                        .encoding(Some(encoding))
160                        .build(file),
161                )
162            };
163            let mut rdr = csv::ReaderBuilder::new()
164                .has_headers(false)
165                .flexible(self.flexible_csv)
166                .from_reader(reader);
167
168            for result in rdr.records() {
169                let record = result.map_err(|err| {
170                    LinderaErrorKind::Content
171                        .with_error(anyhow!(err))
172                        .add_context(format!("Failed to parse CSV record in file: {filename:?}"))
173                })?;
174                rows.push(record);
175            }
176        }
177
178        Ok(rows)
179    }
180
181    /// Build word entry map
182    fn build_word_entry_map(
183        &self,
184        rows: &[StringRecord],
185    ) -> LinderaResult<BTreeMap<String, Vec<WordEntry>>> {
186        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
187
188        for (row_id, row) in rows.iter().enumerate() {
189            let word_cost = self.parse_word_cost(row)?;
190            let left_id = self.parse_left_id(row)?;
191            let right_id = self.parse_right_id(row)?;
192
193            // Skip if any value is invalid
194            if word_cost.is_none() || left_id.is_none() || right_id.is_none() {
195                continue;
196            }
197
198            let key = if self.normalize_details {
199                if let Some(surface) = self.get_field_value(row, "surface")? {
200                    normalize(&surface)
201                } else {
202                    continue;
203                }
204            } else if let Some(surface) = self.get_field_value(row, "surface")? {
205                surface
206            } else {
207                continue;
208            };
209
210            word_entry_map.entry(key).or_default().push(WordEntry {
211                word_id: crate::viterbi::WordId::new(
212                    crate::viterbi::LexType::System,
213                    row_id as u32,
214                ),
215                word_cost: word_cost.unwrap(),
216                left_id: left_id.unwrap(),
217                right_id: right_id.unwrap(),
218            });
219        }
220
221        Ok(word_entry_map)
222    }
223
224    /// Get field value by name
225    fn get_field_value(
226        &self,
227        row: &StringRecord,
228        field_name: &str,
229    ) -> LinderaResult<Option<String>> {
230        if let Some(index) = self.schema.get_field_index(field_name) {
231            if index >= row.len() {
232                return Ok(None);
233            }
234
235            let value = row[index].trim();
236            Ok(if value.is_empty() {
237                None
238            } else {
239                Some(value.to_string())
240            })
241        } else {
242            Ok(None)
243        }
244    }
245
246    /// Parse word cost using schema
247    fn parse_word_cost(&self, row: &StringRecord) -> LinderaResult<Option<i16>> {
248        let cost_str = self.get_field_value(row, "cost")?;
249        match cost_str {
250            Some(s) => match i16::from_str(&s) {
251                Ok(cost) => Ok(Some(cost)),
252                Err(_) => {
253                    if self.skip_invalid_cost_or_id {
254                        Ok(None)
255                    } else {
256                        Err(LinderaErrorKind::Content
257                            .with_error(anyhow!("Invalid cost value: {s}")))
258                    }
259                }
260            },
261            None => Ok(None),
262        }
263    }
264
265    /// Parse left ID using schema
266    fn parse_left_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
267        let left_id_str = self.get_field_value(row, "left_context_id")?;
268        match left_id_str {
269            Some(s) => match u16::from_str(&s) {
270                Ok(id) => Ok(Some(id)),
271                Err(_) => {
272                    if self.skip_invalid_cost_or_id {
273                        Ok(None)
274                    } else {
275                        Err(LinderaErrorKind::Content
276                            .with_error(anyhow!("Invalid left context ID: {s}")))
277                    }
278                }
279            },
280            None => Ok(None),
281        }
282    }
283
284    /// Parse right ID using schema
285    fn parse_right_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
286        let right_id_str = self.get_field_value(row, "right_context_id")?;
287        match right_id_str {
288            Some(s) => match u16::from_str(&s) {
289                Ok(id) => Ok(Some(id)),
290                Err(_) => {
291                    if self.skip_invalid_cost_or_id {
292                        Ok(None)
293                    } else {
294                        Err(LinderaErrorKind::Content
295                            .with_error(anyhow!("Invalid right context ID: {s}")))
296                    }
297                }
298            },
299            None => Ok(None),
300        }
301    }
302
303    /// Write dictionary files
304    fn write_dictionary_files(
305        &self,
306        output_dir: &Path,
307        rows: &[StringRecord],
308        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
309    ) -> LinderaResult<()> {
310        // Write dict.words and dict.wordsidx
311        self.write_words_files(output_dir, rows)?;
312
313        // Write dict.da
314        self.write_double_array_file(output_dir, word_entry_map)?;
315
316        // Write dict.vals
317        self.write_values_file(output_dir, word_entry_map)?;
318
319        Ok(())
320    }
321
322    /// Write word detail files (dict.words, dict.wordsidx)
323    fn write_words_files(&self, output_dir: &Path, rows: &[StringRecord]) -> LinderaResult<()> {
324        let mut dict_words_buffer = Vec::new();
325        let mut dict_wordsidx_buffer = Vec::new();
326
327        for row in rows.iter() {
328            let offset = dict_words_buffer.len();
329            dict_wordsidx_buffer
330                .write_u32::<LittleEndian>(offset as u32)
331                .map_err(|err| {
332                    LinderaErrorKind::Io
333                        .with_error(anyhow::anyhow!(err))
334                        .add_context("Failed to write word index offset to dict.wordsidx buffer")
335                })?;
336
337            // Create word details from the row data (5th column and beyond)
338            let joined_details = if self.normalize_details {
339                row.iter()
340                    .skip(4)
341                    .map(normalize)
342                    .collect::<Vec<String>>()
343                    .join("\0")
344            } else {
345                row.iter().skip(4).collect::<Vec<&str>>().join("\0")
346            };
347            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
348                LinderaErrorKind::Serialize
349                    .with_error(anyhow::anyhow!(err))
350                    .add_context(format!(
351                        "Word details length too large: {} bytes",
352                        joined_details.len()
353                    ))
354            })?;
355
356            // Write to dict.words buffer
357            dict_words_buffer
358                .write_u32::<LittleEndian>(joined_details_len)
359                .map_err(|err| {
360                    LinderaErrorKind::Serialize
361                        .with_error(anyhow::anyhow!(err))
362                        .add_context("Failed to write word details length to dict.words buffer")
363                })?;
364            dict_words_buffer
365                .write_all(joined_details.as_bytes())
366                .map_err(|err| {
367                    LinderaErrorKind::Serialize
368                        .with_error(anyhow::anyhow!(err))
369                        .add_context("Failed to write word details to dict.words buffer")
370                })?;
371        }
372
373        // Write dict.words file
374        let dict_words_path = output_dir.join(Path::new("dict.words"));
375        let mut dict_words_writer =
376            io::BufWriter::new(File::create(&dict_words_path).map_err(|err| {
377                LinderaErrorKind::Io
378                    .with_error(anyhow::anyhow!(err))
379                    .add_context(format!(
380                        "Failed to create dict.words file: {dict_words_path:?}"
381                    ))
382            })?);
383
384        write_data(&dict_words_buffer, &mut dict_words_writer)?;
385
386        dict_words_writer.flush().map_err(|err| {
387            LinderaErrorKind::Io
388                .with_error(anyhow::anyhow!(err))
389                .add_context(format!(
390                    "Failed to flush dict.words file: {dict_words_path:?}"
391                ))
392        })?;
393
394        // Write dict.wordsidx file
395        let dict_wordsidx_path = output_dir.join(Path::new("dict.wordsidx"));
396        let mut dict_wordsidx_writer =
397            io::BufWriter::new(File::create(&dict_wordsidx_path).map_err(|err| {
398                LinderaErrorKind::Io
399                    .with_error(anyhow::anyhow!(err))
400                    .add_context(format!(
401                        "Failed to create dict.wordsidx file: {dict_wordsidx_path:?}"
402                    ))
403            })?);
404
405        write_data(&dict_wordsidx_buffer, &mut dict_wordsidx_writer)?;
406
407        dict_wordsidx_writer.flush().map_err(|err| {
408            LinderaErrorKind::Io
409                .with_error(anyhow::anyhow!(err))
410                .add_context(format!(
411                    "Failed to flush dict.wordsidx file: {dict_wordsidx_path:?}"
412                ))
413        })?;
414
415        Ok(())
416    }
417
418    /// Write double array file (dict.da)
419    fn write_double_array_file(
420        &self,
421        output_dir: &Path,
422        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
423    ) -> LinderaResult<()> {
424        let mut id = 0u32;
425        let mut keyset: Vec<(&[u8], u32)> = vec![];
426
427        for (key, word_entries) in word_entry_map {
428            let len = word_entries.len() as u32;
429            let val = (id << 8) | len; // 24bit for word ID, 8bit for variant count (up to 255 per surface).
430            keyset.push((key.as_bytes(), val));
431            id += len;
432        }
433
434        let keyset_len = keyset.len();
435
436        let dict_da = DoubleArrayAhoCorasickBuilder::new()
437            .build_with_values(keyset)
438            .map_err(|err| {
439                LinderaErrorKind::Build
440                    .with_error(anyhow::anyhow!(err))
441                    .add_context(format!(
442                        "Failed to build DoubleArray with {} keys for prefix dictionary",
443                        keyset_len
444                    ))
445            })?;
446
447        let dict_da_buffer = dict_da.serialize();
448
449        let dict_da_path = output_dir.join(Path::new("dict.da"));
450        let mut dict_da_writer =
451            io::BufWriter::new(File::create(&dict_da_path).map_err(|err| {
452                LinderaErrorKind::Io
453                    .with_error(anyhow::anyhow!(err))
454                    .add_context(format!("Failed to create dict.da file: {dict_da_path:?}"))
455            })?);
456
457        write_data(&dict_da_buffer, &mut dict_da_writer)?;
458
459        Ok(())
460    }
461
462    /// Write values file (dict.vals)
463    fn write_values_file(
464        &self,
465        output_dir: &Path,
466        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
467    ) -> LinderaResult<()> {
468        let mut dict_vals_buffer = Vec::new();
469        for word_entries in word_entry_map.values() {
470            for word_entry in word_entries {
471                word_entry.serialize(&mut dict_vals_buffer).map_err(|err| {
472                    LinderaErrorKind::Serialize
473                        .with_error(anyhow::anyhow!(err))
474                        .add_context(format!(
475                            "Failed to serialize word entry (id: {})",
476                            word_entry.word_id.id
477                        ))
478                })?;
479            }
480        }
481
482        let dict_vals_path = output_dir.join(Path::new("dict.vals"));
483        let mut dict_vals_writer =
484            io::BufWriter::new(File::create(&dict_vals_path).map_err(|err| {
485                LinderaErrorKind::Io
486                    .with_error(anyhow::anyhow!(err))
487                    .add_context(format!(
488                        "Failed to create dict.vals file: {dict_vals_path:?}"
489                    ))
490            })?);
491
492        write_data(&dict_vals_buffer, &mut dict_vals_writer)?;
493
494        dict_vals_writer.flush().map_err(|err| {
495            LinderaErrorKind::Io
496                .with_error(anyhow::anyhow!(err))
497                .add_context(format!(
498                    "Failed to flush dict.vals file: {dict_vals_path:?}"
499                ))
500        })?;
501
502        Ok(())
503    }
504}
505
/// Normalize dictionary text: map the horizontal bar `―` (U+2015) to the
/// em dash `—` (U+2014) and the ASCII tilde `~` to the wave dash `〜` (U+301C).
fn normalize(text: &str) -> String {
    // `str::replace` already returns an owned String; the previous
    // `text.to_string().replace(...)` allocated a useless intermediate copy.
    text.replace('―', "—").replace('~', "〜")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dictionary::schema::Schema;
    use csv::StringRecord;

    /// Builder backed by the default schema.
    fn default_builder() -> PrefixDictionaryBuilder {
        PrefixDictionaryBuilder::new(Schema::default())
    }

    /// Well-formed record: surface, left context ID, right context ID, cost.
    fn sample_record() -> StringRecord {
        StringRecord::from(vec!["surface_form", "123", "456", "789"])
    }

    #[test]
    fn test_new_with_schema() {
        let builder = default_builder();

        // `new` must install the documented defaults.
        assert!(builder.flexible_csv);
        assert_eq!(builder.encoding, "UTF-8");
        assert!(!builder.normalize_details);
        assert!(!builder.skip_invalid_cost_or_id);
    }

    #[test]
    fn test_get_common_field_value_empty() {
        let builder = default_builder();
        let row = StringRecord::from(vec!["", "123", "456", "789"]);

        // An empty surface field is reported as absent, not as "".
        assert_eq!(builder.get_field_value(&row, "surface").unwrap(), None);
    }

    #[test]
    fn test_get_common_field_value_out_of_bounds() {
        let builder = default_builder();
        let row = StringRecord::from(vec!["surface_form"]);

        // An index past the record's end yields None rather than an error.
        assert_eq!(
            builder.get_field_value(&row, "left_context_id").unwrap(),
            None
        );
    }

    #[test]
    fn test_parse_word_cost() {
        let builder = default_builder();
        let cost = builder.parse_word_cost(&sample_record()).unwrap();
        assert_eq!(cost, Some(789));
    }

    #[test]
    fn test_parse_word_cost_invalid() {
        let builder = default_builder();
        let row = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        // A non-numeric cost is an error in strict (default) mode.
        assert!(builder.parse_word_cost(&row).is_err());
    }

    #[test]
    fn test_parse_word_cost_skip_invalid() {
        let mut builder = default_builder();
        builder.skip_invalid_cost_or_id = true;
        let row = StringRecord::from(vec!["surface_form", "123", "456", "invalid"]);

        // With the skip flag on, an unparsable cost becomes None.
        assert_eq!(builder.parse_word_cost(&row).unwrap(), None);
    }

    #[test]
    fn test_parse_left_id() {
        let builder = default_builder();
        let left_id = builder.parse_left_id(&sample_record()).unwrap();
        assert_eq!(left_id, Some(123));
    }

    #[test]
    fn test_parse_right_id() {
        let builder = default_builder();
        let right_id = builder.parse_right_id(&sample_record()).unwrap();
        assert_eq!(right_id, Some(456));
    }

    #[test]
    fn test_normalize_function() {
        assert_eq!(normalize("test―text"), "test—text");
        assert_eq!(normalize("test~text"), "test〜text");
        assert_eq!(normalize("test―text~more"), "test—text〜more");
        assert_eq!(normalize("normal text"), "normal text");
    }

    #[test]
    fn test_get_encoding() {
        let builder = default_builder();
        assert_eq!(builder.get_encoding().unwrap().name(), "UTF-8");
    }

    #[test]
    fn test_get_encoding_invalid() {
        let mut builder = default_builder();
        builder.encoding = "INVALID-ENCODING".into();
        assert!(builder.get_encoding().is_err());
    }

    #[test]
    fn test_get_common_field_value() {
        let builder = default_builder();
        let row = StringRecord::from(vec!["word", "123", "456", "789", "名詞"]);

        // Each common field resolves to its column's value.
        for (field, expected) in [
            ("surface", "word"),
            ("left_context_id", "123"),
            ("right_context_id", "456"),
            ("cost", "789"),
        ] {
            assert_eq!(
                builder.get_field_value(&row, field).unwrap(),
                Some(expected.to_string())
            );
        }

        // A field beyond the record's length is None, not an error.
        let short_row = StringRecord::from(vec!["word", "123"]);
        assert_eq!(builder.get_field_value(&short_row, "cost").unwrap(), None);
    }
}