// lindera_dictionary/dictionary/connection_cost_matrix.rs

1use crate::{LinderaResult, error::LinderaErrorKind, util::Data};
2
3use byteorder::{ByteOrder, LittleEndian};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5
6#[derive(Clone, Archive, RkyvSerialize, RkyvDeserialize)]
7pub struct ConnectionCostMatrix {
8    /// The connection cost matrix data.
9    /// Previously, this was `Data` (byte array) and costs were read using `LittleEndian::read_i16` at runtime.
10    /// Changed to `Vec<i16>` to enable direct array indexing and avoid deserialization overhead during tokenization.
11    pub costs_data: Vec<i16>,
12    pub backward_size: u32,
13    pub forward_size: u32,
14}
15
16impl ConnectionCostMatrix {
17    /// Load a `ConnectionCostMatrix` from raw binary data.
18    ///
19    /// Supports both the new transposed format (header marker `-1`) and the old format.
20    ///
21    /// # Arguments
22    ///
23    /// * `conn_data` - Raw binary data for the connection cost matrix.
24    ///
25    /// # Returns
26    ///
27    /// A `ConnectionCostMatrix`, or an error if the data is too short or malformed.
28    pub fn load(conn_data: impl Into<Data>) -> LinderaResult<ConnectionCostMatrix> {
29        let conn_data = conn_data.into();
30        if conn_data.len() < 4 {
31            return Err(LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(
32                "Connection cost matrix data too short: {} bytes",
33                conn_data.len()
34            )));
35        }
36
37        let first_v = LittleEndian::read_i16(&conn_data[0..2]);
38
39        if first_v == -1 {
40            // New format (transposed)
41            if conn_data.len() < 6 {
42                return Err(LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(
43                    "Connection cost matrix header too short for new format: {} bytes",
44                    conn_data.len()
45                )));
46            }
47            let forward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
48            let backward_size = LittleEndian::read_i16(&conn_data[4..6]) as u32;
49            let size = conn_data.len() / 2 - 3;
50            let mut costs_data = vec![0i16; size];
51            LittleEndian::read_i16_into(&conn_data[6..], &mut costs_data);
52
53            Ok(ConnectionCostMatrix {
54                costs_data,
55                backward_size,
56                forward_size,
57            })
58        } else {
59            // Old format
60            let forward_size = first_v as u32;
61            let backward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
62            let size = conn_data.len() / 2 - 2;
63            let mut old_costs_data = vec![0i16; size];
64            LittleEndian::read_i16_into(&conn_data[4..], &mut old_costs_data);
65
66            // Transpose to new layout in memory
67            let mut costs_data = vec![0i16; size];
68            for f in 0..forward_size {
69                for b in 0..backward_size {
70                    let old_id = (b + f * backward_size) as usize;
71                    let new_id = (f + b * forward_size) as usize;
72                    costs_data[new_id] = old_costs_data[old_id];
73                }
74            }
75
76            Ok(ConnectionCostMatrix {
77                costs_data,
78                backward_size,
79                forward_size,
80            })
81        }
82    }
83
84    #[inline]
85    pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
86        let cost_id = (forward_id + backward_id * self.forward_size) as usize;
87        self.costs_data[cost_id] as i32
88    }
89}
90
91impl ArchivedConnectionCostMatrix {
92    #[inline]
93    pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
94        let cost_id = (forward_id + backward_id * self.forward_size) as usize;
95        self.costs_data[cost_id].to_native() as i32
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use byteorder::{LittleEndian, WriteBytesExt};

    /// Serializes `values` as consecutive little-endian `i16`s.
    fn encode_i16s(values: &[i16]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(values.len() * 2);
        for &value in values {
            bytes.write_i16::<LittleEndian>(value).unwrap();
        }
        bytes
    }

    /// Checks the 2 x 3 fixture shared by the format tests: `cost(f, b) == 10 + f + 2 * b`.
    fn assert_two_by_three(matrix: &ConnectionCostMatrix) {
        assert_eq!(matrix.forward_size, 2);
        assert_eq!(matrix.backward_size, 3);
        for backward_id in 0..3u32 {
            for forward_id in 0..2u32 {
                let expected = 10 + forward_id as i32 + 2 * backward_id as i32;
                assert_eq!(matrix.cost(forward_id, backward_id), expected);
            }
        }
    }

    #[test]
    fn test_load_transposed() {
        // Header: version marker -1, forward_size 2, backward_size 3.
        // Body in [forward_id + backward_id * forward_size] order:
        // [0][0], [1][0], [0][1], [1][1], [0][2], [1][2]
        let data = encode_i16s(&[-1, 2, 3, 10, 11, 12, 13, 14, 15]);

        let matrix = ConnectionCostMatrix::load(data).unwrap();
        assert_two_by_three(&matrix);
    }

    #[test]
    fn test_load_old_format() {
        // Header: forward_size 2, backward_size 3.
        // Body in old [backward_id + forward_id * backward_size] order:
        // [0][0], [1][0], [2][0], [0][1], [1][1], [2][1]
        let data = encode_i16s(&[2, 3, 10, 12, 14, 11, 13, 15]);

        let matrix = ConnectionCostMatrix::load(data).unwrap();
        assert_two_by_three(&matrix);
    }

    #[test]
    fn test_load_data_too_short() {
        let data: Vec<u8> = vec![0x01, 0x02];
        assert!(ConnectionCostMatrix::load(data).is_err());
    }
}