Skip to main content

lindera_dictionary/builder/
prefix_dictionary.rs

1use std::borrow::Cow;
2use std::collections::BTreeMap;
3use std::fs::File;
4use std::io::Write;
5use std::io::{self, Read};
6use std::path::{Path, PathBuf};
7use std::str::FromStr;
8
9use anyhow::anyhow;
10use byteorder::{LittleEndian, WriteBytesExt};
11use csv::StringRecord;
12use daachorse::DoubleArrayAhoCorasickBuilder;
13use derive_builder::Builder;
14use encoding_rs::{Encoding, UTF_8};
15use encoding_rs_io::DecodeReaderBytesBuilder;
16use glob::glob;
17use log::debug;
18
19use crate::LinderaResult;
20use crate::decompress::Algorithm;
21use crate::dictionary::schema::Schema;
22use crate::error::LinderaErrorKind;
23use crate::util::compress_write;
24use crate::viterbi::WordEntry;
25
26#[derive(Builder)]
27#[builder(name = PrefixDictionaryBuilderOptions)]
28#[builder(build_fn(name = "builder"))]
29pub struct PrefixDictionaryBuilder {
30    #[builder(default = "true")]
31    flexible_csv: bool,
32    /* If set to UTF-8, it can also read UTF-16 files with BOM. */
33    #[builder(default = "\"UTF-8\".into()", setter(into))]
34    encoding: Cow<'static, str>,
35    #[builder(default = "Algorithm::Deflate")]
36    compress_algorithm: Algorithm,
37    #[builder(default = "false")]
38    normalize_details: bool,
39    #[builder(default = "false")]
40    skip_invalid_cost_or_id: bool,
41    #[builder(default = "Schema::default()")]
42    schema: Schema,
43}
44
45impl PrefixDictionaryBuilder {
46    /// Create a new builder with the specified schema
47    pub fn new(schema: Schema) -> Self {
48        Self {
49            flexible_csv: true,
50            encoding: "UTF-8".into(),
51            compress_algorithm: Algorithm::Deflate,
52            normalize_details: false,
53            skip_invalid_cost_or_id: false,
54            schema,
55        }
56    }
57
58    /// Main method for building the dictionary
59    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
60        // 1. Load CSV data
61        let rows = self.load_csv_data(input_dir)?;
62
63        // 2. Build word entry map
64        let word_entry_map = self.build_word_entry_map(&rows)?;
65
66        // 3. Write dictionary files
67        self.write_dictionary_files(output_dir, &rows, &word_entry_map)?;
68
69        Ok(())
70    }
71
72    /// Load data from CSV files
73    fn load_csv_data(&self, input_dir: &Path) -> LinderaResult<Vec<StringRecord>> {
74        let filenames = self.collect_csv_files(input_dir)?;
75        let encoding = self.get_encoding()?;
76        let mut rows = self.read_csv_files(&filenames, encoding)?;
77
78        // Sort dictionary entries by the first column (word)
79        // Change sorting method based on normalization settings
80        if self.normalize_details {
81            // Sort after normalizing characters (―→—, ~→〜)
82            rows.sort_by_key(|row| normalize(&row[0]));
83        } else {
84            // Sort using original strings directly
85            rows.sort_by(|a, b| a[0].cmp(&b[0]))
86        }
87
88        Ok(rows)
89    }
90
91    /// Collect .csv file paths from input directory
92    fn collect_csv_files(&self, input_dir: &Path) -> LinderaResult<Vec<PathBuf>> {
93        let pattern = if let Some(path) = input_dir.to_str() {
94            format!("{path}/*.csv")
95        } else {
96            return Err(LinderaErrorKind::Io
97                .with_error(anyhow::anyhow!("Failed to convert path to &str."))
98                .add_context(format!(
99                    "Input directory path contains invalid characters: {input_dir:?}"
100                )));
101        };
102
103        let mut filenames: Vec<PathBuf> = Vec::new();
104        for entry in glob(&pattern).map_err(|err| {
105            LinderaErrorKind::Io
106                .with_error(anyhow::anyhow!(err))
107                .add_context(format!("Failed to glob CSV files with pattern: {pattern}"))
108        })? {
109            match entry {
110                Ok(path) => {
111                    if let Some(filename) = path.file_name() {
112                        filenames.push(Path::new(input_dir).join(filename));
113                    } else {
114                        return Err(LinderaErrorKind::Io
115                            .with_error(anyhow::anyhow!("failed to get filename"))
116                            .add_context(format!("Invalid filename in path: {path:?}")));
117                    };
118                }
119                Err(err) => {
120                    return Err(LinderaErrorKind::Content
121                        .with_error(anyhow!(err))
122                        .add_context(format!(
123                            "Failed to process glob entry with pattern: {pattern}"
124                        )));
125                }
126            }
127        }
128
129        Ok(filenames)
130    }
131
132    /// Get encoding configuration
133    fn get_encoding(&self) -> LinderaResult<&'static Encoding> {
134        let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
135        encoding.ok_or_else(|| {
136            LinderaErrorKind::Decode
137                .with_error(anyhow!("Invalid encoding: {}", self.encoding))
138                .add_context("Failed to get encoding for CSV file reading")
139        })
140    }
141
142    /// Read CSV files
143    fn read_csv_files(
144        &self,
145        filenames: &[PathBuf],
146        encoding: &'static Encoding,
147    ) -> LinderaResult<Vec<StringRecord>> {
148        let mut rows: Vec<StringRecord> = vec![];
149
150        for filename in filenames {
151            debug!("reading {filename:?}");
152
153            let file = File::open(filename).map_err(|err| {
154                LinderaErrorKind::Io
155                    .with_error(anyhow::anyhow!(err))
156                    .add_context(format!("Failed to open CSV file: {filename:?}"))
157            })?;
158            let reader: Box<dyn Read> = if encoding == UTF_8 {
159                Box::new(file)
160            } else {
161                Box::new(
162                    DecodeReaderBytesBuilder::new()
163                        .encoding(Some(encoding))
164                        .build(file),
165                )
166            };
167            let mut rdr = csv::ReaderBuilder::new()
168                .has_headers(false)
169                .flexible(self.flexible_csv)
170                .from_reader(reader);
171
172            for result in rdr.records() {
173                let record = result.map_err(|err| {
174                    LinderaErrorKind::Content
175                        .with_error(anyhow!(err))
176                        .add_context(format!("Failed to parse CSV record in file: {filename:?}"))
177                })?;
178                rows.push(record);
179            }
180        }
181
182        Ok(rows)
183    }
184
185    /// Build word entry map
186    fn build_word_entry_map(
187        &self,
188        rows: &[StringRecord],
189    ) -> LinderaResult<BTreeMap<String, Vec<WordEntry>>> {
190        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
191
192        for (row_id, row) in rows.iter().enumerate() {
193            let word_cost = self.parse_word_cost(row)?;
194            let left_id = self.parse_left_id(row)?;
195            let right_id = self.parse_right_id(row)?;
196
197            // Skip if any value is invalid
198            if word_cost.is_none() || left_id.is_none() || right_id.is_none() {
199                continue;
200            }
201
202            let key = if self.normalize_details {
203                if let Some(surface) = self.get_field_value(row, "surface")? {
204                    normalize(&surface)
205                } else {
206                    continue;
207                }
208            } else if let Some(surface) = self.get_field_value(row, "surface")? {
209                surface
210            } else {
211                continue;
212            };
213
214            word_entry_map.entry(key).or_default().push(WordEntry {
215                word_id: crate::viterbi::WordId::new(
216                    crate::viterbi::LexType::System,
217                    row_id as u32,
218                ),
219                word_cost: word_cost.unwrap(),
220                left_id: left_id.unwrap(),
221                right_id: right_id.unwrap(),
222            });
223        }
224
225        Ok(word_entry_map)
226    }
227
228    /// Get field value by name
229    fn get_field_value(
230        &self,
231        row: &StringRecord,
232        field_name: &str,
233    ) -> LinderaResult<Option<String>> {
234        if let Some(index) = self.schema.get_field_index(field_name) {
235            if index >= row.len() {
236                return Ok(None);
237            }
238
239            let value = row[index].trim();
240            Ok(if value.is_empty() {
241                None
242            } else {
243                Some(value.to_string())
244            })
245        } else {
246            Ok(None)
247        }
248    }
249
250    /// Parse word cost using schema
251    fn parse_word_cost(&self, row: &StringRecord) -> LinderaResult<Option<i16>> {
252        let cost_str = self.get_field_value(row, "cost")?;
253        match cost_str {
254            Some(s) => match i16::from_str(&s) {
255                Ok(cost) => Ok(Some(cost)),
256                Err(_) => {
257                    if self.skip_invalid_cost_or_id {
258                        Ok(None)
259                    } else {
260                        Err(LinderaErrorKind::Content
261                            .with_error(anyhow!("Invalid cost value: {s}")))
262                    }
263                }
264            },
265            None => Ok(None),
266        }
267    }
268
269    /// Parse left ID using schema
270    fn parse_left_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
271        let left_id_str = self.get_field_value(row, "left_context_id")?;
272        match left_id_str {
273            Some(s) => match u16::from_str(&s) {
274                Ok(id) => Ok(Some(id)),
275                Err(_) => {
276                    if self.skip_invalid_cost_or_id {
277                        Ok(None)
278                    } else {
279                        Err(LinderaErrorKind::Content
280                            .with_error(anyhow!("Invalid left context ID: {s}")))
281                    }
282                }
283            },
284            None => Ok(None),
285        }
286    }
287
288    /// Parse right ID using schema
289    fn parse_right_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
290        let right_id_str = self.get_field_value(row, "right_context_id")?;
291        match right_id_str {
292            Some(s) => match u16::from_str(&s) {
293                Ok(id) => Ok(Some(id)),
294                Err(_) => {
295                    if self.skip_invalid_cost_or_id {
296                        Ok(None)
297                    } else {
298                        Err(LinderaErrorKind::Content
299                            .with_error(anyhow!("Invalid right context ID: {s}")))
300                    }
301                }
302            },
303            None => Ok(None),
304        }
305    }
306
307    /// Write dictionary files
308    fn write_dictionary_files(
309        &self,
310        output_dir: &Path,
311        rows: &[StringRecord],
312        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
313    ) -> LinderaResult<()> {
314        // Write dict.words and dict.wordsidx
315        self.write_words_files(output_dir, rows)?;
316
317        // Write dict.da
318        self.write_double_array_file(output_dir, word_entry_map)?;
319
320        // Write dict.vals
321        self.write_values_file(output_dir, word_entry_map)?;
322
323        Ok(())
324    }
325
326    /// Write word detail files (dict.words, dict.wordsidx)
327    fn write_words_files(&self, output_dir: &Path, rows: &[StringRecord]) -> LinderaResult<()> {
328        let mut dict_words_buffer = Vec::new();
329        let mut dict_wordsidx_buffer = Vec::new();
330
331        for row in rows.iter() {
332            let offset = dict_words_buffer.len();
333            dict_wordsidx_buffer
334                .write_u32::<LittleEndian>(offset as u32)
335                .map_err(|err| {
336                    LinderaErrorKind::Io
337                        .with_error(anyhow::anyhow!(err))
338                        .add_context("Failed to write word index offset to dict.wordsidx buffer")
339                })?;
340
341            // Create word details from the row data (5th column and beyond)
342            let joined_details = if self.normalize_details {
343                row.iter()
344                    .skip(4)
345                    .map(normalize)
346                    .collect::<Vec<String>>()
347                    .join("\0")
348            } else {
349                row.iter().skip(4).collect::<Vec<&str>>().join("\0")
350            };
351            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
352                LinderaErrorKind::Serialize
353                    .with_error(anyhow::anyhow!(err))
354                    .add_context(format!(
355                        "Word details length too large: {} bytes",
356                        joined_details.len()
357                    ))
358            })?;
359
360            // Write to dict.words buffer
361            dict_words_buffer
362                .write_u32::<LittleEndian>(joined_details_len)
363                .map_err(|err| {
364                    LinderaErrorKind::Serialize
365                        .with_error(anyhow::anyhow!(err))
366                        .add_context("Failed to write word details length to dict.words buffer")
367                })?;
368            dict_words_buffer
369                .write_all(joined_details.as_bytes())
370                .map_err(|err| {
371                    LinderaErrorKind::Serialize
372                        .with_error(anyhow::anyhow!(err))
373                        .add_context("Failed to write word details to dict.words buffer")
374                })?;
375        }
376
377        // Write dict.words file
378        let dict_words_path = output_dir.join(Path::new("dict.words"));
379        let mut dict_words_writer =
380            io::BufWriter::new(File::create(&dict_words_path).map_err(|err| {
381                LinderaErrorKind::Io
382                    .with_error(anyhow::anyhow!(err))
383                    .add_context(format!(
384                        "Failed to create dict.words file: {dict_words_path:?}"
385                    ))
386            })?);
387
388        compress_write(
389            &dict_words_buffer,
390            self.compress_algorithm,
391            &mut dict_words_writer,
392        )?;
393
394        dict_words_writer.flush().map_err(|err| {
395            LinderaErrorKind::Io
396                .with_error(anyhow::anyhow!(err))
397                .add_context(format!(
398                    "Failed to flush dict.words file: {dict_words_path:?}"
399                ))
400        })?;
401
402        // Write dict.wordsidx file
403        let dict_wordsidx_path = output_dir.join(Path::new("dict.wordsidx"));
404        let mut dict_wordsidx_writer =
405            io::BufWriter::new(File::create(&dict_wordsidx_path).map_err(|err| {
406                LinderaErrorKind::Io
407                    .with_error(anyhow::anyhow!(err))
408                    .add_context(format!(
409                        "Failed to create dict.wordsidx file: {dict_wordsidx_path:?}"
410                    ))
411            })?);
412
413        compress_write(
414            &dict_wordsidx_buffer,
415            self.compress_algorithm,
416            &mut dict_wordsidx_writer,
417        )?;
418
419        dict_wordsidx_writer.flush().map_err(|err| {
420            LinderaErrorKind::Io
421                .with_error(anyhow::anyhow!(err))
422                .add_context(format!(
423                    "Failed to flush dict.wordsidx file: {dict_wordsidx_path:?}"
424                ))
425        })?;
426
427        Ok(())
428    }
429
430    /// Write double array file (dict.da)
431    fn write_double_array_file(
432        &self,
433        output_dir: &Path,
434        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
435    ) -> LinderaResult<()> {
436        let mut id = 0u32;
437        let mut keyset: Vec<(&[u8], u32)> = vec![];
438
439        for (key, word_entries) in word_entry_map {
440            let len = word_entries.len() as u32;
441            let val = (id << 5) | len; // 27bit for word ID, 5bit for different parts of speech on the same surface.
442            keyset.push((key.as_bytes(), val));
443            id += len;
444        }
445
446        let keyset_len = keyset.len();
447
448        let dict_da = DoubleArrayAhoCorasickBuilder::new()
449            .build_with_values(keyset)
450            .map_err(|err| {
451                LinderaErrorKind::Build
452                    .with_error(anyhow::anyhow!(err))
453                    .add_context(format!(
454                        "Failed to build DoubleArray with {} keys for prefix dictionary",
455                        keyset_len
456                    ))
457            })?;
458
459        let dict_da_buffer = dict_da.serialize();
460
461        let dict_da_path = output_dir.join(Path::new("dict.da"));
462        let mut dict_da_writer =
463            io::BufWriter::new(File::create(&dict_da_path).map_err(|err| {
464                LinderaErrorKind::Io
465                    .with_error(anyhow::anyhow!(err))
466                    .add_context(format!("Failed to create dict.da file: {dict_da_path:?}"))
467            })?);
468
469        compress_write(
470            &dict_da_buffer,
471            self.compress_algorithm,
472            &mut dict_da_writer,
473        )?;
474
475        Ok(())
476    }
477
478    /// Write values file (dict.vals)
479    fn write_values_file(
480        &self,
481        output_dir: &Path,
482        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
483    ) -> LinderaResult<()> {
484        let mut dict_vals_buffer = Vec::new();
485        for word_entries in word_entry_map.values() {
486            for word_entry in word_entries {
487                word_entry.serialize(&mut dict_vals_buffer).map_err(|err| {
488                    LinderaErrorKind::Serialize
489                        .with_error(anyhow::anyhow!(err))
490                        .add_context(format!(
491                            "Failed to serialize word entry (id: {})",
492                            word_entry.word_id.id
493                        ))
494                })?;
495            }
496        }
497
498        let dict_vals_path = output_dir.join(Path::new("dict.vals"));
499        let mut dict_vals_writer =
500            io::BufWriter::new(File::create(&dict_vals_path).map_err(|err| {
501                LinderaErrorKind::Io
502                    .with_error(anyhow::anyhow!(err))
503                    .add_context(format!(
504                        "Failed to create dict.vals file: {dict_vals_path:?}"
505                    ))
506            })?);
507
508        compress_write(
509            &dict_vals_buffer,
510            self.compress_algorithm,
511            &mut dict_vals_writer,
512        )?;
513
514        dict_vals_writer.flush().map_err(|err| {
515            LinderaErrorKind::Io
516                .with_error(anyhow::anyhow!(err))
517                .add_context(format!(
518                    "Failed to flush dict.vals file: {dict_vals_path:?}"
519                ))
520        })?;
521
522        Ok(())
523    }
524}
525
526fn normalize(text: &str) -> String {
527    text.to_string().replace('―', "—").replace('~', "〜")
528}
529
530#[cfg(test)]
531mod tests {
532    use super::*;
533    use crate::dictionary::schema::Schema;
534    use csv::StringRecord;
535
536    #[test]
537    fn test_new_with_schema() {
538        let schema = Schema::default();
539        let builder = PrefixDictionaryBuilder::new(schema.clone());
540
541        // Schema no longer has name field
542        // Schema no longer has version field
543        assert!(builder.flexible_csv);
544        assert_eq!(builder.encoding, "UTF-8");
545        assert!(!builder.normalize_details);
546        assert!(!builder.skip_invalid_cost_or_id);
547    }
548
549    #[test]
550    fn test_get_common_field_value_empty() {
551        let schema = Schema::default();
552        let builder = PrefixDictionaryBuilder::new(schema);
553
554        let record = StringRecord::from(vec![
555            "",    // Empty surface
556            "123", // LeftContextId
557            "456", // RightContextId
558            "789", // Cost
559        ]);
560
561        let surface = builder.get_field_value(&record, "surface").unwrap();
562        assert_eq!(surface, None);
563    }
564
565    #[test]
566    fn test_get_common_field_value_out_of_bounds() {
567        let schema = Schema::default();
568        let builder = PrefixDictionaryBuilder::new(schema);
569
570        let record = StringRecord::from(vec![
571            "surface_form", // Surface only
572        ]);
573
574        let left_id = builder.get_field_value(&record, "left_context_id").unwrap();
575        assert_eq!(left_id, None);
576    }
577
578    #[test]
579    fn test_parse_word_cost() {
580        let schema = Schema::default();
581        let builder = PrefixDictionaryBuilder::new(schema);
582
583        let record = StringRecord::from(vec![
584            "surface_form", // Surface
585            "123",          // LeftContextId
586            "456",          // RightContextId
587            "789",          // Cost
588        ]);
589
590        let cost = builder.parse_word_cost(&record).unwrap();
591        assert_eq!(cost, Some(789));
592    }
593
594    #[test]
595    fn test_parse_word_cost_invalid() {
596        let schema = Schema::default();
597        let builder = PrefixDictionaryBuilder::new(schema);
598
599        let record = StringRecord::from(vec![
600            "surface_form", // Surface
601            "123",          // LeftContextId
602            "456",          // RightContextId
603            "invalid",      // Invalid cost
604        ]);
605
606        let result = builder.parse_word_cost(&record);
607        assert!(result.is_err());
608    }
609
610    #[test]
611    fn test_parse_word_cost_skip_invalid() {
612        let schema = Schema::default();
613        let mut builder = PrefixDictionaryBuilder::new(schema);
614        builder.skip_invalid_cost_or_id = true;
615
616        let record = StringRecord::from(vec![
617            "surface_form", // Surface
618            "123",          // LeftContextId
619            "456",          // RightContextId
620            "invalid",      // Invalid cost
621        ]);
622
623        let cost = builder.parse_word_cost(&record).unwrap();
624        assert_eq!(cost, None);
625    }
626
627    #[test]
628    fn test_parse_left_id() {
629        let schema = Schema::default();
630        let builder = PrefixDictionaryBuilder::new(schema);
631
632        let record = StringRecord::from(vec![
633            "surface_form", // Surface
634            "123",          // LeftContextId
635            "456",          // RightContextId
636            "789",          // Cost
637        ]);
638
639        let left_id = builder.parse_left_id(&record).unwrap();
640        assert_eq!(left_id, Some(123));
641    }
642
643    #[test]
644    fn test_parse_right_id() {
645        let schema = Schema::default();
646        let builder = PrefixDictionaryBuilder::new(schema);
647
648        let record = StringRecord::from(vec![
649            "surface_form", // Surface
650            "123",          // LeftContextId
651            "456",          // RightContextId
652            "789",          // Cost
653        ]);
654
655        let right_id = builder.parse_right_id(&record).unwrap();
656        assert_eq!(right_id, Some(456));
657    }
658
659    #[test]
660    fn test_normalize_function() {
661        assert_eq!(normalize("test―text"), "test—text");
662        assert_eq!(normalize("test~text"), "test〜text");
663        assert_eq!(normalize("test―text~more"), "test—text〜more");
664        assert_eq!(normalize("normal text"), "normal text");
665    }
666
667    #[test]
668    fn test_get_encoding() {
669        let schema = Schema::default();
670        let builder = PrefixDictionaryBuilder::new(schema);
671
672        let encoding = builder.get_encoding().unwrap();
673        assert_eq!(encoding.name(), "UTF-8");
674    }
675
676    #[test]
677    fn test_get_encoding_invalid() {
678        let schema = Schema::default();
679        let mut builder = PrefixDictionaryBuilder::new(schema);
680        builder.encoding = "INVALID-ENCODING".into();
681
682        let result = builder.get_encoding();
683        assert!(result.is_err());
684    }
685
686    #[test]
687    fn test_get_common_field_value() {
688        let schema = Schema::default();
689        let builder = PrefixDictionaryBuilder::new(schema);
690
691        let record = StringRecord::from(vec![
692            "word", // Surface
693            "123",  // LeftContextId
694            "456",  // RightContextId
695            "789",  // Cost
696            "名詞", // MajorPos
697        ]);
698
699        // Test common fields
700        assert_eq!(
701            builder.get_field_value(&record, "surface").unwrap(),
702            Some("word".to_string())
703        );
704        assert_eq!(
705            builder.get_field_value(&record, "left_context_id").unwrap(),
706            Some("123".to_string())
707        );
708        assert_eq!(
709            builder
710                .get_field_value(&record, "right_context_id")
711                .unwrap(),
712            Some("456".to_string())
713        );
714        assert_eq!(
715            builder.get_field_value(&record, "cost").unwrap(),
716            Some("789".to_string())
717        );
718
719        // Test case where field is out of bounds - should return None, not an error
720        let short_record = StringRecord::from(vec!["word", "123"]);
721        assert_eq!(
722            builder.get_field_value(&short_record, "cost").unwrap(),
723            None
724        );
725    }
726}