lindera_core/dictionary_builder/user_dict.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::fs::File;
4use std::io;
5use std::io::Write;
6use std::path::Path;
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use csv::StringRecord;
10use derive_builder::Builder;
11use log::debug;
12use yada::builder::DoubleArrayBuilder;
13use yada::DoubleArray;
14
15use crate::dictionary::prefix_dict::PrefixDict;
16use crate::dictionary::word_entry::{WordEntry, WordId};
17use crate::dictionary::UserDictionary;
18use crate::error::LinderaErrorKind;
19use crate::LinderaResult;
20
21type StringRecordProcessor = Option<Box<dyn Fn(&StringRecord) -> LinderaResult<Vec<String>>>>;
22
23#[derive(Builder)]
24#[builder(pattern = "owned")]
25#[builder(name = "UserDictBuilderOptions")]
26#[builder(build_fn(name = "builder"))]
27pub struct UserDictBuilder {
28    #[builder(default = "3")]
29    simple_userdic_fields_num: usize,
30    #[builder(default = "4")]
31    detailed_userdic_fields_num: usize,
32    #[builder(default = "-10000")]
33    simple_word_cost: i16,
34    #[builder(default = "0")]
35    simple_context_id: u16,
36    #[builder(default = "true")]
37    flexible_csv: bool,
38    #[builder(setter(strip_option), default = "None")]
39    simple_userdic_details_handler: StringRecordProcessor,
40}
41
42impl UserDictBuilder {
43    pub fn build(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
44        debug!("reading {:?}", input_file);
45
46        let mut rdr = csv::ReaderBuilder::new()
47            .has_headers(false)
48            .flexible(self.flexible_csv)
49            .from_path(input_file)
50            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
51
52        let mut rows: Vec<StringRecord> = vec![];
53        for result in rdr.records() {
54            let record =
55                result.map_err(|err| LinderaErrorKind::Content.with_error(anyhow::anyhow!(err)))?;
56            rows.push(record);
57        }
58        rows.sort_by_key(|row| row[0].to_string());
59
60        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
61
62        for (row_id, row) in rows.iter().enumerate() {
63            let surface = row[0].to_string();
64            let word_cost = if row.len() == self.simple_userdic_fields_num {
65                self.simple_word_cost
66            } else {
67                row[3].parse::<i16>().map_err(|_err| {
68                    LinderaErrorKind::Parse.with_error(anyhow::anyhow!("failed to parse word cost"))
69                })?
70            };
71            let (left_id, right_id) = if row.len() == self.simple_userdic_fields_num {
72                (self.simple_context_id, self.simple_context_id)
73            } else {
74                (
75                    row[1].parse::<u16>().map_err(|_err| {
76                        LinderaErrorKind::Parse
77                            .with_error(anyhow::anyhow!("failed to parse left context id"))
78                    })?,
79                    row[2].parse::<u16>().map_err(|_err| {
80                        LinderaErrorKind::Parse
81                            .with_error(anyhow::anyhow!("failed to parse left context id"))
82                    })?,
83                )
84            };
85
86            word_entry_map.entry(surface).or_default().push(WordEntry {
87                word_id: WordId(row_id as u32, false),
88                word_cost,
89                left_id,
90                right_id,
91            });
92        }
93
94        let mut words_data = Vec::<u8>::new();
95        let mut words_idx_data = Vec::<u8>::new();
96        for row in rows.iter() {
97            let word_detail = if row.len() == self.simple_userdic_fields_num {
98                if let Some(handler) = &self.simple_userdic_details_handler {
99                    handler(row)?
100                } else {
101                    row.iter()
102                        .skip(1)
103                        .map(|s| s.to_string())
104                        .collect::<Vec<String>>()
105                }
106            } else if row.len() >= self.detailed_userdic_fields_num {
107                let mut tmp_word_detail = Vec::new();
108                for item in row.iter().skip(4) {
109                    tmp_word_detail.push(item.to_string());
110                }
111                tmp_word_detail
112            } else {
113                return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
114                    "user dictionary should be a CSV with {} or {}+ fields",
115                    self.simple_userdic_fields_num,
116                    self.detailed_userdic_fields_num
117                )));
118            };
119
120            let offset = words_data.len();
121            words_idx_data
122                .write_u32::<LittleEndian>(offset as u32)
123                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
124            bincode::serialize_into(&mut words_data, &word_detail)
125                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
126        }
127
128        let mut id = 0u32;
129
130        // building double array trie
131        let mut keyset: Vec<(&[u8], u32)> = vec![];
132        for (key, word_entries) in &word_entry_map {
133            let len = word_entries.len() as u32;
134            let val = (id << 5) | len;
135            keyset.push((key.as_bytes(), val));
136            id += len;
137        }
138        let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
139            LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
140        })?;
141
142        // building values
143        let mut vals_data = Vec::<u8>::new();
144        for word_entries in word_entry_map.values() {
145            for word_entry in word_entries {
146                word_entry
147                    .serialize(&mut vals_data)
148                    .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
149            }
150        }
151
152        let dict = PrefixDict {
153            da: DoubleArray::new(da_bytes),
154            vals_data,
155            is_system: false,
156        };
157
158        Ok(UserDictionary {
159            dict,
160            words_idx_data,
161            words_data,
162        })
163    }
164}
165
166pub fn build_user_dictionary(user_dict: UserDictionary, output_file: &Path) -> LinderaResult<()> {
167    let parent_dir = match output_file.parent() {
168        Some(parent_dir) => parent_dir,
169        None => {
170            return Err(LinderaErrorKind::Io.with_error(anyhow::anyhow!(
171                "failed to get parent directory of output file"
172            )))
173        }
174    };
175    fs::create_dir_all(parent_dir)
176        .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
177
178    let mut wtr = io::BufWriter::new(
179        File::create(output_file)
180            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
181    );
182    bincode::serialize_into(&mut wtr, &user_dict)
183        .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
184    wtr.flush()
185        .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
186
187    Ok(())
188}