Skip to main content

lindera_dictionary/builder/
connection_cost_matrix.rs

1use std::borrow::Cow;
2use std::fs::File;
3use std::io::{self, Write};
4use std::path::Path;
5use std::str::FromStr;
6
7use byteorder::{LittleEndian, WriteBytesExt};
8use derive_builder::Builder;
9use log::debug;
10
11use crate::LinderaResult;
12use crate::error::LinderaErrorKind;
13use crate::util::{read_file_with_encoding, write_data};
14
/// Builder that turns a `matrix.def` connection cost file into the binary `matrix.mtx`.
///
/// The `derive_builder` attributes below generate a companion options type named
/// `ConnectionCostMatrixBuilderOptions`; its finishing method is named `builder`
/// and yields a `ConnectionCostMatrixBuilder`.
#[derive(Builder, Debug)]
#[builder(name = ConnectionCostMatrixBuilderOptions)]
#[builder(build_fn(name = "builder"))]
pub struct ConnectionCostMatrixBuilder {
    /// Encoding name handed to `read_file_with_encoding` when `matrix.def` is read.
    /// Defaults to `"UTF-8"`; the generated setter accepts anything that converts
    /// into a `Cow<'static, str>`.
    #[builder(default = "\"UTF-8\".into()", setter(into))]
    encoding: Cow<'static, str>,
}
22
23impl ConnectionCostMatrixBuilder {
24    pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
25        let matrix_data_path = input_dir.join("matrix.def");
26        debug!("reading {matrix_data_path:?}");
27        let matrix_data = read_file_with_encoding(&matrix_data_path, &self.encoding)?;
28
29        let mut lines = Vec::new();
30        for line in matrix_data.lines() {
31            let fields: Vec<i32> = line
32                .split_whitespace()
33                .map(i32::from_str)
34                .collect::<Result<_, _>>()
35                .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
36            lines.push(fields);
37        }
38        let mut lines_it = lines.into_iter();
39        let header = lines_it.next().ok_or_else(|| {
40            LinderaErrorKind::Content.with_error(anyhow::anyhow!("unknown error"))
41        })?;
42        let forward_size = header[0] as u32;
43        let backward_size = header[1] as u32;
44        let len = 3 + (forward_size * backward_size) as usize;
45        let mut costs = vec![i16::MAX; len];
46        costs[0] = -1; // Version flag for transposed layout
47        costs[1] = forward_size as i16;
48        costs[2] = backward_size as i16;
49        for fields in lines_it {
50            let forward_id = fields[0] as u32;
51            let backward_id = fields[1] as u32;
52            let cost = fields[2] as u16;
53            costs[3 + (forward_id + backward_id * forward_size) as usize] = cost as i16;
54        }
55
56        let wtr_matrix_mtx_path = output_dir.join(Path::new("matrix.mtx"));
57        let mut wtr_matrix_mtx = io::BufWriter::new(
58            File::create(wtr_matrix_mtx_path)
59                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?,
60        );
61        let mut matrix_mtx_buffer = Vec::new();
62        for cost in costs {
63            matrix_mtx_buffer
64                .write_i16::<LittleEndian>(cost)
65                .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
66        }
67
68        write_data(&matrix_mtx_buffer, &mut wtr_matrix_mtx)?;
69
70        wtr_matrix_mtx
71            .flush()
72            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
73
74        Ok(())
75    }
76}