// lindera_dictionary/builder/prefix_dictionary.rs

1use std::borrow::Cow;
2use std::collections::BTreeMap;
3use std::fs::File;
4use std::io::Write;
5use std::io::{self, Read};
6use std::path::{Path, PathBuf};
7use std::str::FromStr;
8
9use anyhow::anyhow;
10use byteorder::{LittleEndian, WriteBytesExt};
11use csv::StringRecord;
12use derive_builder::Builder;
13use encoding_rs::{Encoding, UTF_8};
14use encoding_rs_io::DecodeReaderBytesBuilder;
15use glob::glob;
16use log::debug;
17use yada::builder::DoubleArrayBuilder;
18
19use crate::LinderaResult;
20use crate::decompress::Algorithm;
21use crate::dictionary::schema::Schema;
22use crate::error::LinderaErrorKind;
23use crate::util::compress_write;
24use crate::viterbi::WordEntry;
25
/// Builder for constructing a prefix dictionary from a directory of CSV
/// lexicon files. Configuration is settable either via `new(schema)` or the
/// derived `PrefixDictionaryBuilderOptions` builder.
#[derive(Builder)]
#[builder(name = PrefixDictionaryBuilderOptions)]
#[builder(build_fn(name = "builder"))]
pub struct PrefixDictionaryBuilder {
    // Allow CSV rows with a variable number of fields when true.
    #[builder(default = "true")]
    flexible_csv: bool,
    /* If set to UTF-8, it can also read UTF-16 files with BOM. */
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
    // Compression applied to every emitted dictionary file.
    #[builder(default = "Algorithm::Deflate")]
    compress_algorithm: Algorithm,
    // Normalize characters (― → —, ~ → 〜) in surfaces and details when true.
    #[builder(default = "false")]
    normalize_details: bool,
    // Skip rows whose cost/context IDs fail to parse instead of erroring.
    #[builder(default = "false")]
    skip_invalid_cost_or_id: bool,
    // Column layout used to locate surface/cost/context-id fields in rows.
    #[builder(default = "Schema::default()")]
    schema: Schema,
}
44
45impl PrefixDictionaryBuilder {
46    /// Create a new builder with the specified schema
47    pub fn new(schema: Schema) -> Self {
48        Self {
49            flexible_csv: true,
50            encoding: "UTF-8".into(),
51            compress_algorithm: Algorithm::Deflate,
52            normalize_details: false,
53            skip_invalid_cost_or_id: false,
54            schema,
55        }
56    }
57
58    /// Main method for building the dictionary
59    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
60        // 1. Load CSV data
61        let rows = self.load_csv_data(input_dir)?;
62
63        // 2. Build word entry map
64        let word_entry_map = self.build_word_entry_map(&rows)?;
65
66        // 3. Write dictionary files
67        self.write_dictionary_files(output_dir, &rows, &word_entry_map)?;
68
69        Ok(())
70    }
71
72    /// Load data from CSV files
73    fn load_csv_data(&self, input_dir: &Path) -> LinderaResult<Vec<StringRecord>> {
74        let filenames = self.collect_csv_files(input_dir)?;
75        let encoding = self.get_encoding()?;
76        let mut rows = self.read_csv_files(&filenames, encoding)?;
77
78        // Sort dictionary entries by the first column (word)
79        // Change sorting method based on normalization settings
80        if self.normalize_details {
81            // Sort after normalizing characters (―→—, ~→〜)
82            rows.sort_by_key(|row| normalize(&row[0]));
83        } else {
84            // Sort using original strings directly
85            rows.sort_by(|a, b| a[0].cmp(&b[0]))
86        }
87
88        Ok(rows)
89    }
90
91    /// Collect .csv file paths from input directory
92    fn collect_csv_files(&self, input_dir: &Path) -> LinderaResult<Vec<PathBuf>> {
93        let pattern = if let Some(path) = input_dir.to_str() {
94            format!("{path}/*.csv")
95        } else {
96            return Err(LinderaErrorKind::Io
97                .with_error(anyhow::anyhow!("Failed to convert path to &str."))
98                .add_context(format!(
99                    "Input directory path contains invalid characters: {input_dir:?}"
100                )));
101        };
102
103        let mut filenames: Vec<PathBuf> = Vec::new();
104        for entry in glob(&pattern).map_err(|err| {
105            LinderaErrorKind::Io
106                .with_error(anyhow::anyhow!(err))
107                .add_context(format!("Failed to glob CSV files with pattern: {pattern}"))
108        })? {
109            match entry {
110                Ok(path) => {
111                    if let Some(filename) = path.file_name() {
112                        filenames.push(Path::new(input_dir).join(filename));
113                    } else {
114                        return Err(LinderaErrorKind::Io
115                            .with_error(anyhow::anyhow!("failed to get filename"))
116                            .add_context(format!("Invalid filename in path: {path:?}")));
117                    };
118                }
119                Err(err) => {
120                    return Err(LinderaErrorKind::Content
121                        .with_error(anyhow!(err))
122                        .add_context(format!(
123                            "Failed to process glob entry with pattern: {pattern}"
124                        )));
125                }
126            }
127        }
128
129        Ok(filenames)
130    }
131
132    /// Get encoding configuration
133    fn get_encoding(&self) -> LinderaResult<&'static Encoding> {
134        let encoding = Encoding::for_label_no_replacement(self.encoding.as_bytes());
135        encoding.ok_or_else(|| {
136            LinderaErrorKind::Decode
137                .with_error(anyhow!("Invalid encoding: {}", self.encoding))
138                .add_context("Failed to get encoding for CSV file reading")
139        })
140    }
141
142    /// Read CSV files
143    fn read_csv_files(
144        &self,
145        filenames: &[PathBuf],
146        encoding: &'static Encoding,
147    ) -> LinderaResult<Vec<StringRecord>> {
148        let mut rows: Vec<StringRecord> = vec![];
149
150        for filename in filenames {
151            debug!("reading {filename:?}");
152
153            let file = File::open(filename).map_err(|err| {
154                LinderaErrorKind::Io
155                    .with_error(anyhow::anyhow!(err))
156                    .add_context(format!("Failed to open CSV file: {filename:?}"))
157            })?;
158            let reader: Box<dyn Read> = if encoding == UTF_8 {
159                Box::new(file)
160            } else {
161                Box::new(
162                    DecodeReaderBytesBuilder::new()
163                        .encoding(Some(encoding))
164                        .build(file),
165                )
166            };
167            let mut rdr = csv::ReaderBuilder::new()
168                .has_headers(false)
169                .flexible(self.flexible_csv)
170                .from_reader(reader);
171
172            for result in rdr.records() {
173                let record = result.map_err(|err| {
174                    LinderaErrorKind::Content
175                        .with_error(anyhow!(err))
176                        .add_context(format!("Failed to parse CSV record in file: {filename:?}"))
177                })?;
178                rows.push(record);
179            }
180        }
181
182        Ok(rows)
183    }
184
185    /// Build word entry map
186    fn build_word_entry_map(
187        &self,
188        rows: &[StringRecord],
189    ) -> LinderaResult<BTreeMap<String, Vec<WordEntry>>> {
190        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
191
192        for (row_id, row) in rows.iter().enumerate() {
193            let word_cost = self.parse_word_cost(row)?;
194            let left_id = self.parse_left_id(row)?;
195            let right_id = self.parse_right_id(row)?;
196
197            // Skip if any value is invalid
198            if word_cost.is_none() || left_id.is_none() || right_id.is_none() {
199                continue;
200            }
201
202            let key = if self.normalize_details {
203                if let Some(surface) = self.get_field_value(row, "surface")? {
204                    normalize(&surface)
205                } else {
206                    continue;
207                }
208            } else if let Some(surface) = self.get_field_value(row, "surface")? {
209                surface
210            } else {
211                continue;
212            };
213
214            word_entry_map.entry(key).or_default().push(WordEntry {
215                word_id: crate::viterbi::WordId::new(
216                    crate::viterbi::LexType::System,
217                    row_id as u32,
218                ),
219                word_cost: word_cost.unwrap(),
220                left_id: left_id.unwrap(),
221                right_id: right_id.unwrap(),
222            });
223        }
224
225        Ok(word_entry_map)
226    }
227
228    /// Get field value by name
229    fn get_field_value(
230        &self,
231        row: &StringRecord,
232        field_name: &str,
233    ) -> LinderaResult<Option<String>> {
234        if let Some(index) = self.schema.get_field_index(field_name) {
235            if index >= row.len() {
236                return Ok(None);
237            }
238
239            let value = row[index].trim();
240            Ok(if value.is_empty() {
241                None
242            } else {
243                Some(value.to_string())
244            })
245        } else {
246            Ok(None)
247        }
248    }
249
250    /// Parse word cost using schema
251    fn parse_word_cost(&self, row: &StringRecord) -> LinderaResult<Option<i16>> {
252        let cost_str = self.get_field_value(row, "cost")?;
253        match cost_str {
254            Some(s) => match i16::from_str(&s) {
255                Ok(cost) => Ok(Some(cost)),
256                Err(_) => {
257                    if self.skip_invalid_cost_or_id {
258                        Ok(None)
259                    } else {
260                        Err(LinderaErrorKind::Content
261                            .with_error(anyhow!("Invalid cost value: {s}")))
262                    }
263                }
264            },
265            None => Ok(None),
266        }
267    }
268
269    /// Parse left ID using schema
270    fn parse_left_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
271        let left_id_str = self.get_field_value(row, "left_context_id")?;
272        match left_id_str {
273            Some(s) => match u16::from_str(&s) {
274                Ok(id) => Ok(Some(id)),
275                Err(_) => {
276                    if self.skip_invalid_cost_or_id {
277                        Ok(None)
278                    } else {
279                        Err(LinderaErrorKind::Content
280                            .with_error(anyhow!("Invalid left context ID: {s}")))
281                    }
282                }
283            },
284            None => Ok(None),
285        }
286    }
287
288    /// Parse right ID using schema
289    fn parse_right_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
290        let right_id_str = self.get_field_value(row, "right_context_id")?;
291        match right_id_str {
292            Some(s) => match u16::from_str(&s) {
293                Ok(id) => Ok(Some(id)),
294                Err(_) => {
295                    if self.skip_invalid_cost_or_id {
296                        Ok(None)
297                    } else {
298                        Err(LinderaErrorKind::Content
299                            .with_error(anyhow!("Invalid right context ID: {s}")))
300                    }
301                }
302            },
303            None => Ok(None),
304        }
305    }
306
307    /// Write dictionary files
308    fn write_dictionary_files(
309        &self,
310        output_dir: &Path,
311        rows: &[StringRecord],
312        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
313    ) -> LinderaResult<()> {
314        // Write dict.words and dict.wordsidx
315        self.write_words_files(output_dir, rows)?;
316
317        // Write dict.da
318        self.write_double_array_file(output_dir, word_entry_map)?;
319
320        // Write dict.vals
321        self.write_values_file(output_dir, word_entry_map)?;
322
323        Ok(())
324    }
325
326    /// Write word detail files (dict.words, dict.wordsidx)
327    fn write_words_files(&self, output_dir: &Path, rows: &[StringRecord]) -> LinderaResult<()> {
328        let mut dict_words_buffer = Vec::new();
329        let mut dict_wordsidx_buffer = Vec::new();
330
331        for row in rows.iter() {
332            let offset = dict_words_buffer.len();
333            dict_wordsidx_buffer
334                .write_u32::<LittleEndian>(offset as u32)
335                .map_err(|err| {
336                    LinderaErrorKind::Io
337                        .with_error(anyhow::anyhow!(err))
338                        .add_context("Failed to write word index offset to dict.wordsidx buffer")
339                })?;
340
341            // Create word details from the row data (5th column and beyond)
342            let joined_details = if self.normalize_details {
343                row.iter()
344                    .skip(4)
345                    .map(normalize)
346                    .collect::<Vec<String>>()
347                    .join("\0")
348            } else {
349                row.iter().skip(4).collect::<Vec<&str>>().join("\0")
350            };
351            let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
352                LinderaErrorKind::Serialize
353                    .with_error(anyhow::anyhow!(err))
354                    .add_context(format!(
355                        "Word details length too large: {} bytes",
356                        joined_details.len()
357                    ))
358            })?;
359
360            // Write to dict.words buffer
361            dict_words_buffer
362                .write_u32::<LittleEndian>(joined_details_len)
363                .map_err(|err| {
364                    LinderaErrorKind::Serialize
365                        .with_error(anyhow::anyhow!(err))
366                        .add_context("Failed to write word details length to dict.words buffer")
367                })?;
368            dict_words_buffer
369                .write_all(joined_details.as_bytes())
370                .map_err(|err| {
371                    LinderaErrorKind::Serialize
372                        .with_error(anyhow::anyhow!(err))
373                        .add_context("Failed to write word details to dict.words buffer")
374                })?;
375        }
376
377        // Write dict.words file
378        let dict_words_path = output_dir.join(Path::new("dict.words"));
379        let mut dict_words_writer =
380            io::BufWriter::new(File::create(&dict_words_path).map_err(|err| {
381                LinderaErrorKind::Io
382                    .with_error(anyhow::anyhow!(err))
383                    .add_context(format!(
384                        "Failed to create dict.words file: {dict_words_path:?}"
385                    ))
386            })?);
387
388        compress_write(
389            &dict_words_buffer,
390            self.compress_algorithm,
391            &mut dict_words_writer,
392        )?;
393
394        dict_words_writer.flush().map_err(|err| {
395            LinderaErrorKind::Io
396                .with_error(anyhow::anyhow!(err))
397                .add_context(format!(
398                    "Failed to flush dict.words file: {dict_words_path:?}"
399                ))
400        })?;
401
402        // Write dict.wordsidx file
403        let dict_wordsidx_path = output_dir.join(Path::new("dict.wordsidx"));
404        let mut dict_wordsidx_writer =
405            io::BufWriter::new(File::create(&dict_wordsidx_path).map_err(|err| {
406                LinderaErrorKind::Io
407                    .with_error(anyhow::anyhow!(err))
408                    .add_context(format!(
409                        "Failed to create dict.wordsidx file: {dict_wordsidx_path:?}"
410                    ))
411            })?);
412
413        compress_write(
414            &dict_wordsidx_buffer,
415            self.compress_algorithm,
416            &mut dict_wordsidx_writer,
417        )?;
418
419        dict_wordsidx_writer.flush().map_err(|err| {
420            LinderaErrorKind::Io
421                .with_error(anyhow::anyhow!(err))
422                .add_context(format!(
423                    "Failed to flush dict.wordsidx file: {dict_wordsidx_path:?}"
424                ))
425        })?;
426
427        Ok(())
428    }
429
430    /// Write double array file (dict.da)
431    fn write_double_array_file(
432        &self,
433        output_dir: &Path,
434        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
435    ) -> LinderaResult<()> {
436        let mut id = 0u32;
437        let mut keyset: Vec<(&[u8], u32)> = vec![];
438
439        for (key, word_entries) in word_entry_map {
440            let len = word_entries.len() as u32;
441            let val = (id << 5) | len; // 27bit for word ID, 5bit for different parts of speech on the same surface.
442            keyset.push((key.as_bytes(), val));
443            id += len;
444        }
445
446        let dict_da_buffer = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
447            LinderaErrorKind::Build
448                .with_error(anyhow::anyhow!("DoubleArray build error."))
449                .add_context(format!(
450                    "Failed to build DoubleArray with {} keys for prefix dictionary",
451                    keyset.len()
452                ))
453        })?;
454
455        let dict_da_path = output_dir.join(Path::new("dict.da"));
456        let mut dict_da_writer =
457            io::BufWriter::new(File::create(&dict_da_path).map_err(|err| {
458                LinderaErrorKind::Io
459                    .with_error(anyhow::anyhow!(err))
460                    .add_context(format!("Failed to create dict.da file: {dict_da_path:?}"))
461            })?);
462
463        compress_write(
464            &dict_da_buffer,
465            self.compress_algorithm,
466            &mut dict_da_writer,
467        )?;
468
469        Ok(())
470    }
471
472    /// Write values file (dict.vals)
473    fn write_values_file(
474        &self,
475        output_dir: &Path,
476        word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
477    ) -> LinderaResult<()> {
478        let mut dict_vals_buffer = Vec::new();
479        for word_entries in word_entry_map.values() {
480            for word_entry in word_entries {
481                word_entry.serialize(&mut dict_vals_buffer).map_err(|err| {
482                    LinderaErrorKind::Serialize
483                        .with_error(anyhow::anyhow!(err))
484                        .add_context(format!(
485                            "Failed to serialize word entry (id: {})",
486                            word_entry.word_id.id
487                        ))
488                })?;
489            }
490        }
491
492        let dict_vals_path = output_dir.join(Path::new("dict.vals"));
493        let mut dict_vals_writer =
494            io::BufWriter::new(File::create(&dict_vals_path).map_err(|err| {
495                LinderaErrorKind::Io
496                    .with_error(anyhow::anyhow!(err))
497                    .add_context(format!(
498                        "Failed to create dict.vals file: {dict_vals_path:?}"
499                    ))
500            })?);
501
502        compress_write(
503            &dict_vals_buffer,
504            self.compress_algorithm,
505            &mut dict_vals_writer,
506        )?;
507
508        dict_vals_writer.flush().map_err(|err| {
509            LinderaErrorKind::Io
510                .with_error(anyhow::anyhow!(err))
511                .add_context(format!(
512                    "Failed to flush dict.vals file: {dict_vals_path:?}"
513                ))
514        })?;
515
516        Ok(())
517    }
518}
519
/// Normalize surface text: replace the horizontal bar `―` (U+2015) with the
/// em dash `—` (U+2014) and the ASCII tilde `~` with the wave dash `〜`
/// (U+301C).
fn normalize(text: &str) -> String {
    // `str::replace` already returns an owned String, so the intermediate
    // `to_string()` allocation is unnecessary.
    text.replace('―', "—").replace('~', "〜")
}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dictionary::schema::Schema;
    use csv::StringRecord;

    // `new` should expose the documented defaults regardless of schema.
    #[test]
    fn test_new_with_schema() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema.clone());

        // Schema no longer has name field
        // Schema no longer has version field
        assert!(builder.flexible_csv);
        assert_eq!(builder.encoding, "UTF-8");
        assert!(!builder.normalize_details);
        assert!(!builder.skip_invalid_cost_or_id);
    }

    // An empty surface column yields None rather than an empty string.
    #[test]
    fn test_get_common_field_value_empty() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec![
            "",    // Empty surface
            "123", // LeftContextId
            "456", // RightContextId
            "789", // Cost
        ]);

        let surface = builder.get_field_value(&record, "surface").unwrap();
        assert_eq!(surface, None);
    }

    // A field whose index is beyond the row length yields None, not an error.
    #[test]
    fn test_get_common_field_value_out_of_bounds() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec![
            "surface_form", // Surface only
        ]);

        let left_id = builder.get_field_value(&record, "left_context_id").unwrap();
        assert_eq!(left_id, None);
    }

    // Happy path: a numeric cost column parses to Some(i16).
    #[test]
    fn test_parse_word_cost() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec![
            "surface_form", // Surface
            "123",          // LeftContextId
            "456",          // RightContextId
            "789",          // Cost
        ]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, Some(789));
    }

    // With strict parsing (default), a non-numeric cost is an error.
    #[test]
    fn test_parse_word_cost_invalid() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec![
            "surface_form", // Surface
            "123",          // LeftContextId
            "456",          // RightContextId
            "invalid",      // Invalid cost
        ]);

        let result = builder.parse_word_cost(&record);
        assert!(result.is_err());
    }

    // With skip_invalid_cost_or_id set, a non-numeric cost becomes None.
    #[test]
    fn test_parse_word_cost_skip_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.skip_invalid_cost_or_id = true;

        let record = StringRecord::from(vec![
            "surface_form", // Surface
            "123",          // LeftContextId
            "456",          // RightContextId
            "invalid",      // Invalid cost
        ]);

        let cost = builder.parse_word_cost(&record).unwrap();
        assert_eq!(cost, None);
    }

    // Left context ID is read from the second column.
    #[test]
    fn test_parse_left_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec![
            "surface_form", // Surface
            "123",          // LeftContextId
            "456",          // RightContextId
            "789",          // Cost
        ]);

        let left_id = builder.parse_left_id(&record).unwrap();
        assert_eq!(left_id, Some(123));
    }

    // Right context ID is read from the third column.
    #[test]
    fn test_parse_right_id() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec![
            "surface_form", // Surface
            "123",          // LeftContextId
            "456",          // RightContextId
            "789",          // Cost
        ]);

        let right_id = builder.parse_right_id(&record).unwrap();
        assert_eq!(right_id, Some(456));
    }

    // normalize replaces ― with — and ~ with 〜, leaving other text intact.
    #[test]
    fn test_normalize_function() {
        assert_eq!(normalize("test―text"), "test—text");
        assert_eq!(normalize("test~text"), "test〜text");
        assert_eq!(normalize("test―text~more"), "test—text〜more");
        assert_eq!(normalize("normal text"), "normal text");
    }

    // The default "UTF-8" label resolves to the UTF-8 encoding.
    #[test]
    fn test_get_encoding() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let encoding = builder.get_encoding().unwrap();
        assert_eq!(encoding.name(), "UTF-8");
    }

    // An unknown encoding label produces an error.
    #[test]
    fn test_get_encoding_invalid() {
        let schema = Schema::default();
        let mut builder = PrefixDictionaryBuilder::new(schema);
        builder.encoding = "INVALID-ENCODING".into();

        let result = builder.get_encoding();
        assert!(result.is_err());
    }

    // Field lookup by schema name returns the trimmed per-column values.
    #[test]
    fn test_get_common_field_value() {
        let schema = Schema::default();
        let builder = PrefixDictionaryBuilder::new(schema);

        let record = StringRecord::from(vec![
            "word", // Surface
            "123",  // LeftContextId
            "456",  // RightContextId
            "789",  // Cost
            "名詞", // MajorPos
        ]);

        // Test common fields
        assert_eq!(
            builder.get_field_value(&record, "surface").unwrap(),
            Some("word".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "left_context_id").unwrap(),
            Some("123".to_string())
        );
        assert_eq!(
            builder
                .get_field_value(&record, "right_context_id")
                .unwrap(),
            Some("456".to_string())
        );
        assert_eq!(
            builder.get_field_value(&record, "cost").unwrap(),
            Some("789".to_string())
        );

        // Test case where field is out of bounds - should return None, not an error
        let short_record = StringRecord::from(vec!["word", "123"]);
        assert_eq!(
            builder.get_field_value(&short_record, "cost").unwrap(),
            None
        );
    }
}