// lindera_dictionary/dictionary_builder/user_dictionary.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::fs::File;
4use std::io;
5use std::io::Write;
6use std::path::Path;
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use csv::StringRecord;
10use derive_builder::Builder;
11use log::debug;
12use yada::builder::DoubleArrayBuilder;
13
14use crate::dictionary::prefix_dictionary::PrefixDictionary;
15use crate::dictionary::UserDictionary;
16use crate::error::LinderaErrorKind;
17use crate::viterbi::{WordEntry, WordId};
18use crate::LinderaResult;
19
20type StringRecordProcessor = Option<Box<dyn Fn(&StringRecord) -> LinderaResult<Vec<String>>>>;
21
22#[derive(Builder)]
23#[builder(pattern = "owned")]
24#[builder(name = UserDictionaryBuilderOptions)]
25#[builder(build_fn(name = "builder"))]
26pub struct UserDictionaryBuilder {
27    #[builder(default = "3")]
28    simple_userdic_fields_num: usize,
29    #[builder(default = "4")]
30    detailed_userdic_fields_num: usize,
31    #[builder(default = "-10000")]
32    simple_word_cost: i16,
33    #[builder(default = "0")]
34    simple_context_id: u16,
35    #[builder(default = "true")]
36    flexible_csv: bool,
37    #[builder(setter(strip_option), default = "None")]
38    simple_userdic_details_handler: StringRecordProcessor,
39}
40
41impl UserDictionaryBuilder {
42    pub fn build(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
43        debug!("reading {:?}", input_file);
44
45        let mut rdr = csv::ReaderBuilder::new()
46            .has_headers(false)
47            .flexible(self.flexible_csv)
48            .from_path(input_file)
49            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
50
51        let mut rows: Vec<StringRecord> = vec![];
52        for result in rdr.records() {
53            let record =
54                result.map_err(|err| LinderaErrorKind::Content.with_error(anyhow::anyhow!(err)))?;
55            rows.push(record);
56        }
57        rows.sort_by_key(|row| row[0].to_string());
58
59        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
60
61        for (row_id, row) in rows.iter().enumerate() {
62            let surface = row[0].to_string();
63            let word_cost = if row.len() == self.simple_userdic_fields_num {
64                self.simple_word_cost
65            } else {
66                row[3].parse::<i16>().map_err(|_err| {
67                    LinderaErrorKind::Parse.with_error(anyhow::anyhow!("failed to parse word cost"))
68                })?
69            };
70            let (left_id, right_id) = if row.len() == self.simple_userdic_fields_num {
71                (self.simple_context_id, self.simple_context_id)
72            } else {
73                (
74                    row[1].parse::<u16>().map_err(|_err| {
75                        LinderaErrorKind::Parse
76                            .with_error(anyhow::anyhow!("failed to parse left context id"))
77                    })?,
78                    row[2].parse::<u16>().map_err(|_err| {
79                        LinderaErrorKind::Parse
80                            .with_error(anyhow::anyhow!("failed to parse left context id"))
81                    })?,
82                )
83            };
84
85            word_entry_map.entry(surface).or_default().push(WordEntry {
86                word_id: WordId {
87                    id: row_id as u32,
88                    is_system: false,
89                },
90                word_cost,
91                left_id,
92                right_id,
93            });
94        }
95
96        let mut words_data = Vec::<u8>::new();
97        let mut words_idx_data = Vec::<u8>::new();
98        for row in rows.iter() {
99            let word_detail = if row.len() == self.simple_userdic_fields_num {
100                if let Some(handler) = &self.simple_userdic_details_handler {
101                    handler(row)?
102                } else {
103                    row.iter()
104                        .skip(1)
105                        .map(|s| s.to_string())
106                        .collect::<Vec<String>>()
107                }
108            } else if row.len() >= self.detailed_userdic_fields_num {
109                let mut tmp_word_detail = Vec::new();
110                for item in row.iter().skip(4) {
111                    tmp_word_detail.push(item.to_string());
112                }
113                tmp_word_detail
114            } else {
115                return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
116                    "user dictionary should be a CSV with {} or {}+ fields",
117                    self.simple_userdic_fields_num,
118                    self.detailed_userdic_fields_num
119                )));
120            };
121
122            let offset = words_data.len();
123            words_idx_data
124                .write_u32::<LittleEndian>(offset as u32)
125                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
126
127            // Store word details as null-separated string (like main dictionary)
128            let joined_details = word_detail.join("\0");
129            let joined_details_len = u32::try_from(joined_details.len())
130                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
131
132            words_data
133                .write_u32::<LittleEndian>(joined_details_len)
134                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
135            words_data
136                .write_all(joined_details.as_bytes())
137                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
138        }
139
140        let mut id = 0u32;
141
142        // building double array trie
143        let mut keyset: Vec<(&[u8], u32)> = vec![];
144        for (key, word_entries) in &word_entry_map {
145            let len = word_entries.len() as u32;
146            let val = (id << 5) | len;
147            keyset.push((key.as_bytes(), val));
148            id += len;
149        }
150        let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
151            LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
152        })?;
153
154        // building values
155        let mut vals_data = Vec::<u8>::new();
156        for word_entries in word_entry_map.values() {
157            for word_entry in word_entries {
158                word_entry
159                    .serialize(&mut vals_data)
160                    .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
161            }
162        }
163
164        let dict = PrefixDictionary::load(da_bytes, vals_data, words_idx_data, words_data, false);
165
166        Ok(UserDictionary { dict })
167    }
168}
169
170pub fn build_user_dictionary(user_dict: UserDictionary, output_file: &Path) -> LinderaResult<()> {
171    let parent_dir = match output_file.parent() {
172        Some(parent_dir) => parent_dir,
173        None => {
174            return Err(LinderaErrorKind::Io.with_error(anyhow::anyhow!(
175                "failed to get parent directory of output file"
176            )))
177        }
178    };
179    fs::create_dir_all(parent_dir)
180        .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
181
182    let mut wtr = io::BufWriter::new(
183        File::create(output_file)
184            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
185    );
186    bincode::serde::encode_into_std_write(&user_dict, &mut wtr, bincode::config::legacy())
187        .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
188    wtr.flush()
189        .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
190
191    Ok(())
192}