lindera_dictionary/dictionary_builder/
user_dictionary.rs1use std::collections::BTreeMap;
2use std::fs;
3use std::fs::File;
4use std::io;
5use std::io::Write;
6use std::path::Path;
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use csv::StringRecord;
10use derive_builder::Builder;
11use log::debug;
12use yada::builder::DoubleArrayBuilder;
13
14use crate::dictionary::prefix_dictionary::PrefixDictionary;
15use crate::dictionary::UserDictionary;
16use crate::error::LinderaErrorKind;
17use crate::viterbi::{WordEntry, WordId};
18use crate::LinderaResult;
19
20type StringRecordProcessor = Option<Box<dyn Fn(&StringRecord) -> LinderaResult<Vec<String>>>>;
21
22#[derive(Builder)]
23#[builder(pattern = "owned")]
24#[builder(name = UserDictionaryBuilderOptions)]
25#[builder(build_fn(name = "builder"))]
26pub struct UserDictionaryBuilder {
27 #[builder(default = "3")]
28 simple_userdic_fields_num: usize,
29 #[builder(default = "4")]
30 detailed_userdic_fields_num: usize,
31 #[builder(default = "-10000")]
32 simple_word_cost: i16,
33 #[builder(default = "0")]
34 simple_context_id: u16,
35 #[builder(default = "true")]
36 flexible_csv: bool,
37 #[builder(setter(strip_option), default = "None")]
38 simple_userdic_details_handler: StringRecordProcessor,
39}
40
41impl UserDictionaryBuilder {
42 pub fn build(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
43 debug!("reading {:?}", input_file);
44
45 let mut rdr = csv::ReaderBuilder::new()
46 .has_headers(false)
47 .flexible(self.flexible_csv)
48 .from_path(input_file)
49 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
50
51 let mut rows: Vec<StringRecord> = vec![];
52 for result in rdr.records() {
53 let record =
54 result.map_err(|err| LinderaErrorKind::Content.with_error(anyhow::anyhow!(err)))?;
55 rows.push(record);
56 }
57 rows.sort_by_key(|row| row[0].to_string());
58
59 let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
60
61 for (row_id, row) in rows.iter().enumerate() {
62 let surface = row[0].to_string();
63 let word_cost = if row.len() == self.simple_userdic_fields_num {
64 self.simple_word_cost
65 } else {
66 row[3].parse::<i16>().map_err(|_err| {
67 LinderaErrorKind::Parse.with_error(anyhow::anyhow!("failed to parse word cost"))
68 })?
69 };
70 let (left_id, right_id) = if row.len() == self.simple_userdic_fields_num {
71 (self.simple_context_id, self.simple_context_id)
72 } else {
73 (
74 row[1].parse::<u16>().map_err(|_err| {
75 LinderaErrorKind::Parse
76 .with_error(anyhow::anyhow!("failed to parse left context id"))
77 })?,
78 row[2].parse::<u16>().map_err(|_err| {
79 LinderaErrorKind::Parse
80 .with_error(anyhow::anyhow!("failed to parse left context id"))
81 })?,
82 )
83 };
84
85 word_entry_map.entry(surface).or_default().push(WordEntry {
86 word_id: WordId {
87 id: row_id as u32,
88 is_system: false,
89 },
90 word_cost,
91 left_id,
92 right_id,
93 });
94 }
95
96 let mut words_data = Vec::<u8>::new();
97 let mut words_idx_data = Vec::<u8>::new();
98 for row in rows.iter() {
99 let word_detail = if row.len() == self.simple_userdic_fields_num {
100 if let Some(handler) = &self.simple_userdic_details_handler {
101 handler(row)?
102 } else {
103 row.iter()
104 .skip(1)
105 .map(|s| s.to_string())
106 .collect::<Vec<String>>()
107 }
108 } else if row.len() >= self.detailed_userdic_fields_num {
109 let mut tmp_word_detail = Vec::new();
110 for item in row.iter().skip(4) {
111 tmp_word_detail.push(item.to_string());
112 }
113 tmp_word_detail
114 } else {
115 return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
116 "user dictionary should be a CSV with {} or {}+ fields",
117 self.simple_userdic_fields_num,
118 self.detailed_userdic_fields_num
119 )));
120 };
121
122 let offset = words_data.len();
123 words_idx_data
124 .write_u32::<LittleEndian>(offset as u32)
125 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
126
127 let joined_details = word_detail.join("\0");
129 let joined_details_len = u32::try_from(joined_details.len())
130 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
131
132 words_data
133 .write_u32::<LittleEndian>(joined_details_len)
134 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
135 words_data
136 .write_all(joined_details.as_bytes())
137 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
138 }
139
140 let mut id = 0u32;
141
142 let mut keyset: Vec<(&[u8], u32)> = vec![];
144 for (key, word_entries) in &word_entry_map {
145 let len = word_entries.len() as u32;
146 let val = (id << 5) | len;
147 keyset.push((key.as_bytes(), val));
148 id += len;
149 }
150 let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
151 LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
152 })?;
153
154 let mut vals_data = Vec::<u8>::new();
156 for word_entries in word_entry_map.values() {
157 for word_entry in word_entries {
158 word_entry
159 .serialize(&mut vals_data)
160 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
161 }
162 }
163
164 let dict = PrefixDictionary::load(da_bytes, vals_data, words_idx_data, words_data, false);
165
166 Ok(UserDictionary { dict })
167 }
168}
169
170pub fn build_user_dictionary(user_dict: UserDictionary, output_file: &Path) -> LinderaResult<()> {
171 let parent_dir = match output_file.parent() {
172 Some(parent_dir) => parent_dir,
173 None => {
174 return Err(LinderaErrorKind::Io.with_error(anyhow::anyhow!(
175 "failed to get parent directory of output file"
176 )))
177 }
178 };
179 fs::create_dir_all(parent_dir)
180 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
181
182 let mut wtr = io::BufWriter::new(
183 File::create(output_file)
184 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
185 );
186 bincode::serde::encode_into_std_write(&user_dict, &mut wtr, bincode::config::legacy())
187 .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
188 wtr.flush()
189 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
190
191 Ok(())
192}