vibrato/
mecab.rs

1//! Utilities to support MeCab models.
2
3use std::io::{BufRead, BufReader, BufWriter, Read, Write};
4
5use hashbrown::HashMap;
6use regex::Regex;
7
8use crate::errors::{Result, VibratoError};
9use crate::trainer::TrainerConfig;
10use crate::utils;
11
12/// Generates bi-gram feature information from MeCab model.
13///
14/// This function is useful to create a small dictionary from an existing MeCab model.
15///
16/// # Arguments
17///
18/// * `feature_def_rdr` - A reader of the feature definition file `feature.def`.
19/// * `left_id_def_rdr` - A reader of the left-id and feature mapping file `left-id.def`.
20/// * `right_id_def_rdr` - A reader of the right-id and feature mapping file `right-id.def`
21/// * `model_def_rdr` - A reader of the model file `model.def`.
22/// * `cost_factor` - A factor to be multiplied when casting costs to integers.
23/// * `bigram_left_wtr` - A writer of the left-id and feature mapping file `bi-gram.left`.
24/// * `bigram_right_wtr` - A writer of the right-id and feature mapping file `bi-gram.right`.
25/// * `bigram_cost_wtr` - A writer of the bi-gram cost file `bi-gram.cost`.
26///
27/// # Errors
28///
29/// [`VibratoError`] is returned when the convertion failed.
30#[allow(clippy::too_many_arguments)]
31pub fn generate_bigram_info(
32    feature_def_rdr: impl Read,
33    right_id_def_rdr: impl Read,
34    left_id_def_rdr: impl Read,
35    model_def_rdr: impl Read,
36    cost_factor: f64,
37    bigram_right_wtr: impl Write,
38    bigram_left_wtr: impl Write,
39    bigram_cost_wtr: impl Write,
40) -> Result<()> {
41    let mut left_features = HashMap::new();
42    let mut right_features = HashMap::new();
43
44    let mut feature_extractor = TrainerConfig::parse_feature_config(feature_def_rdr)?;
45
46    let id_feature_re = Regex::new(r"^([0-9]+) (.*)$").unwrap();
47    let model_re = Regex::new(r"^([0-9\-\.]+)\t(.*)$").unwrap();
48
49    // right-id.def contains the right hand ID of the left context, and left-id.def contains the
50    // left hand ID of the right context. The left-right naming in this code is based on position
51    // between words, so these names are the opposite of left-id.def and right-id.def files.
52
53    // left features
54    let right_id_def_rdr = BufReader::new(right_id_def_rdr);
55    for line in right_id_def_rdr.lines() {
56        let line = line?;
57        if let Some(cap) = id_feature_re.captures(&line) {
58            let id = cap.get(1).unwrap().as_str().parse::<usize>()?;
59            let feature_str = cap.get(2).unwrap().as_str();
60            let feature_spl = utils::parse_csv_row(feature_str);
61            if id == 0 && feature_spl.first().is_some_and(|s| s != "BOS/EOS") {
62                return Err(VibratoError::invalid_format(
63                    "right_id_def_rdr",
64                    "ID 0 must be BOS/EOS",
65                ));
66            }
67            let feature_ids = feature_extractor.extract_left_feature_ids(&feature_spl);
68            left_features.insert(id, feature_ids);
69        } else {
70            return Err(VibratoError::invalid_format(
71                "right_id_def_rdr",
72                "each line must be a pair of an ID and features",
73            ));
74        }
75    }
76    // right features
77    let left_id_def_rdr = BufReader::new(left_id_def_rdr);
78    for line in left_id_def_rdr.lines() {
79        let line = line?;
80        if let Some(cap) = id_feature_re.captures(&line) {
81            let id = cap.get(1).unwrap().as_str().parse::<usize>()?;
82            let feature_str = cap.get(2).unwrap().as_str();
83            let feature_spl = utils::parse_csv_row(feature_str);
84            if id == 0 && feature_spl.first().is_some_and(|s| s != "BOS/EOS") {
85                return Err(VibratoError::invalid_format(
86                    "left_id_def_rdr",
87                    "ID 0 must be BOS/EOS",
88                ));
89            }
90            let feature_ids = feature_extractor.extract_right_feature_ids(&feature_spl);
91            right_features.insert(id, feature_ids);
92        } else {
93            return Err(VibratoError::invalid_format(
94                "left_id_def_rdr",
95                "each line must be a pair of an ID and features",
96            ));
97        }
98    }
99    // weights
100    let model_def_rdr = BufReader::new(model_def_rdr);
101    let mut bigram_cost_wtr = BufWriter::new(bigram_cost_wtr);
102    for line in model_def_rdr.lines() {
103        let line = line?;
104        if let Some(cap) = model_re.captures(&line) {
105            let weight = cap.get(1).unwrap().as_str().parse::<f64>()?;
106            let cost = -(weight * cost_factor) as i32;
107            if cost == 0 {
108                continue;
109            }
110            let feature_str = cap.get(2).unwrap().as_str().replace("BOS/EOS", "");
111            let mut spl = feature_str.split('/');
112            let left_feat_str = spl.next();
113            let right_feat_str = spl.next();
114            if let (Some(left_feat_str), Some(right_feat_str)) = (left_feat_str, right_feat_str) {
115                let left_id = if left_feat_str.is_empty() {
116                    String::new()
117                } else if let Some(id) = feature_extractor.left_feature_ids().get(left_feat_str) {
118                    id.to_string()
119                } else {
120                    continue;
121                };
122                let right_id = if right_feat_str.is_empty() {
123                    String::new()
124                } else if let Some(id) = feature_extractor.right_feature_ids().get(right_feat_str) {
125                    id.to_string()
126                } else {
127                    continue;
128                };
129                writeln!(&mut bigram_cost_wtr, "{left_id}/{right_id}\t{cost}")?;
130            }
131        }
132    }
133
134    let mut bigram_right_wtr = BufWriter::new(bigram_right_wtr);
135    for id in 1..left_features.len() {
136        write!(&mut bigram_right_wtr, "{id}\t")?;
137        if let Some(features) = left_features.get(&id) {
138            for (i, feat_id) in features.iter().enumerate() {
139                if i != 0 {
140                    write!(&mut bigram_right_wtr, ",")?;
141                }
142                if let Some(feat_id) = feat_id {
143                    write!(&mut bigram_right_wtr, "{}", feat_id.get())?;
144                } else {
145                    write!(&mut bigram_right_wtr, "*")?;
146                }
147            }
148        } else {
149            return Err(VibratoError::invalid_format(
150                "right_id_def_rdr",
151                format!("feature ID {id} is undefined"),
152            ));
153        }
154        writeln!(&mut bigram_right_wtr)?;
155    }
156
157    let mut bigram_left_wtr = BufWriter::new(bigram_left_wtr);
158    for id in 1..right_features.len() {
159        write!(&mut bigram_left_wtr, "{id}\t")?;
160        if let Some(features) = right_features.get(&id) {
161            for (i, feat_id) in features.iter().enumerate() {
162                if i != 0 {
163                    write!(&mut bigram_left_wtr, ",")?;
164                }
165                if let Some(feat_id) = feat_id {
166                    write!(&mut bigram_left_wtr, "{}", feat_id.get())?;
167                } else {
168                    write!(&mut bigram_left_wtr, "*")?;
169                }
170            }
171            writeln!(&mut bigram_left_wtr)?;
172        } else {
173            return Err(VibratoError::invalid_format(
174                "left_id_def_rdr",
175                format!("feature ID {id} is undefined"),
176            ));
177        }
178    }
179
180    Ok(())
181}