lindera_dictionary_builder/
user_dict.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::fs::File;
4use std::io;
5use std::io::Write;
6use std::path::Path;
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use csv::StringRecord;
10use derive_builder::Builder;
11use lindera_core::dictionary::UserDictionary;
12use lindera_core::error::LinderaErrorKind;
13use lindera_core::prefix_dict::PrefixDict;
14use lindera_core::word_entry::{WordEntry, WordId};
15use lindera_core::LinderaResult;
16use log::debug;
17use yada::builder::DoubleArrayBuilder;
18use yada::DoubleArray;
19
20#[derive(Builder)]
21#[builder(pattern = "owned")]
22#[builder(name = "UserDictBuilderOptions")]
23#[builder(build_fn(name = "builder"))]
24pub struct UserDictBuilder {
25    #[builder(default = "3")]
26    simple_userdic_fields_num: usize,
27    #[builder(default = "4")]
28    detailed_userdic_fields_num: usize,
29    #[builder(default = "-10000")]
30    simple_word_cost: i16,
31    #[builder(default = "0")]
32    simple_context_id: u16,
33    #[builder(default = "true")]
34    flexible_csv: bool,
35    #[builder(setter(strip_option), default = "None")]
36    simple_userdic_details_handler:
37        Option<Box<dyn Fn(&StringRecord) -> LinderaResult<Vec<String>>>>,
38}
39
40impl UserDictBuilder {
41    pub fn build(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
42        debug!("reading {:?}", input_file);
43
44        let mut rdr = csv::ReaderBuilder::new()
45            .has_headers(false)
46            .flexible(self.flexible_csv)
47            .from_path(input_file)
48            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
49
50        let mut rows: Vec<StringRecord> = vec![];
51        for result in rdr.records() {
52            let record =
53                result.map_err(|err| LinderaErrorKind::Content.with_error(anyhow::anyhow!(err)))?;
54            rows.push(record);
55        }
56        rows.sort_by_key(|row| row[0].to_string());
57
58        let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
59
60        for (row_id, row) in rows.iter().enumerate() {
61            let surface = row[0].to_string();
62            let word_cost = if row.len() == self.simple_userdic_fields_num {
63                self.simple_word_cost
64            } else {
65                row[3].parse::<i16>().map_err(|_err| {
66                    LinderaErrorKind::Parse.with_error(anyhow::anyhow!("failed to parse word cost"))
67                })?
68            };
69            let (left_id, right_id) = if row.len() == self.simple_userdic_fields_num {
70                (self.simple_context_id, self.simple_context_id)
71            } else {
72                (
73                    row[1].parse::<u16>().map_err(|_err| {
74                        LinderaErrorKind::Parse
75                            .with_error(anyhow::anyhow!("failed to parse left context id"))
76                    })?,
77                    row[2].parse::<u16>().map_err(|_err| {
78                        LinderaErrorKind::Parse
79                            .with_error(anyhow::anyhow!("failed to parse left context id"))
80                    })?,
81                )
82            };
83
84            word_entry_map.entry(surface).or_default().push(WordEntry {
85                word_id: WordId(row_id as u32, false),
86                word_cost,
87                left_id,
88                right_id,
89            });
90        }
91
92        let mut words_data = Vec::<u8>::new();
93        let mut words_idx_data = Vec::<u8>::new();
94        for row in rows.iter() {
95            let word_detail = if row.len() == self.simple_userdic_fields_num {
96                if let Some(handler) = &self.simple_userdic_details_handler {
97                    handler(row)?
98                } else {
99                    row.iter()
100                        .skip(1)
101                        .map(|s| s.to_string())
102                        .collect::<Vec<String>>()
103                }
104            } else if row.len() >= self.detailed_userdic_fields_num {
105                let mut tmp_word_detail = Vec::new();
106                for item in row.iter().skip(4) {
107                    tmp_word_detail.push(item.to_string());
108                }
109                tmp_word_detail
110            } else {
111                return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
112                    "user dictionary should be a CSV with {} or {}+ fields",
113                    self.simple_userdic_fields_num,
114                    self.detailed_userdic_fields_num
115                )));
116            };
117
118            let offset = words_data.len();
119            words_idx_data
120                .write_u32::<LittleEndian>(offset as u32)
121                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
122            bincode::serialize_into(&mut words_data, &word_detail)
123                .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
124        }
125
126        let mut id = 0u32;
127
128        // building double array trie
129        let mut keyset: Vec<(&[u8], u32)> = vec![];
130        for (key, word_entries) in &word_entry_map {
131            let len = word_entries.len() as u32;
132            let val = (id << 5) | len;
133            keyset.push((key.as_bytes(), val));
134            id += len;
135        }
136        let da_bytes = DoubleArrayBuilder::build(&keyset).ok_or_else(|| {
137            LinderaErrorKind::Io.with_error(anyhow::anyhow!("DoubleArray build error."))
138        })?;
139
140        // building values
141        let mut vals_data = Vec::<u8>::new();
142        for word_entries in word_entry_map.values() {
143            for word_entry in word_entries {
144                word_entry
145                    .serialize(&mut vals_data)
146                    .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
147            }
148        }
149
150        let dict = PrefixDict {
151            da: DoubleArray::new(da_bytes),
152            vals_data,
153            is_system: false,
154        };
155
156        Ok(UserDictionary {
157            dict,
158            words_idx_data,
159            words_data,
160        })
161    }
162}
163
164pub fn build_user_dictionary(user_dict: UserDictionary, output_file: &Path) -> LinderaResult<()> {
165    let parent_dir = match output_file.parent() {
166        Some(parent_dir) => parent_dir,
167        None => {
168            return Err(LinderaErrorKind::Io.with_error(anyhow::anyhow!(
169                "failed to get parent directory of output file"
170            )))
171        }
172    };
173    fs::create_dir_all(parent_dir)
174        .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
175
176    let mut wtr = io::BufWriter::new(
177        File::create(output_file)
178            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
179    );
180    bincode::serialize_into(&mut wtr, &user_dict)
181        .map_err(|err| LinderaErrorKind::Serialize.with_error(anyhow::anyhow!(err)))?;
182    wtr.flush()
183        .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
184
185    Ok(())
186}