Skip to main content

lindera_dictionary/dictionary/
connection_cost_matrix.rs

1use crate::util::Data;
2
3use byteorder::{ByteOrder, LittleEndian};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5
/// Matrix of connection costs addressed by a `(forward_id, backward_id)` pair.
///
/// The costs are kept in one flat vector; the entry for a pair lives at index
/// `forward_id + backward_id * forward_size`.
#[derive(Clone, Archive, RkyvSerialize, RkyvDeserialize)]
pub struct ConnectionCostMatrix {
    /// The connection cost matrix data, one `i16` per `(forward_id, backward_id)` pair.
    /// Previously, this was `Data` (byte array) and costs were read using `LittleEndian::read_i16` at runtime.
    /// Changed to `Vec<i16>` to enable direct array indexing and avoid deserialization overhead during tokenization.
    pub costs_data: Vec<i16>,
    /// Size of the backward-id dimension, as read from the serialized header.
    pub backward_size: u32,
    /// Size of the forward-id dimension, as read from the serialized header.
    /// It is also the stride applied to `backward_id` when indexing `costs_data`.
    pub forward_size: u32,
}
15
16impl ConnectionCostMatrix {
17    pub fn load(conn_data: impl Into<Data>) -> ConnectionCostMatrix {
18        let conn_data = conn_data.into();
19        let first_v = LittleEndian::read_i16(&conn_data[0..2]);
20
21        if first_v == -1 {
22            // New format (transposed)
23            let forward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
24            let backward_size = LittleEndian::read_i16(&conn_data[4..6]) as u32;
25            let size = conn_data.len() / 2 - 3;
26            let mut costs_data = vec![0i16; size];
27            LittleEndian::read_i16_into(&conn_data[6..], &mut costs_data);
28
29            ConnectionCostMatrix {
30                costs_data,
31                backward_size,
32                forward_size,
33            }
34        } else {
35            // Old format
36            let forward_size = first_v as u32;
37            let backward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
38            let size = conn_data.len() / 2 - 2;
39            let mut old_costs_data = vec![0i16; size];
40            LittleEndian::read_i16_into(&conn_data[4..], &mut old_costs_data);
41
42            // Transpose to new layout in memory
43            let mut costs_data = vec![0i16; size];
44            for f in 0..forward_size {
45                for b in 0..backward_size {
46                    let old_id = (b + f * backward_size) as usize;
47                    let new_id = (f + b * forward_size) as usize;
48                    costs_data[new_id] = old_costs_data[old_id];
49                }
50            }
51
52            ConnectionCostMatrix {
53                costs_data,
54                backward_size,
55                forward_size,
56            }
57        }
58    }
59
60    #[inline]
61    pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
62        let cost_id = (forward_id + backward_id * self.forward_size) as usize;
63        self.costs_data[cost_id] as i32
64    }
65}
66
67impl ArchivedConnectionCostMatrix {
68    #[inline]
69    pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
70        let cost_id = (forward_id + backward_id * self.forward_size) as usize;
71        self.costs_data[cost_id].to_native() as i32
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use byteorder::{LittleEndian, WriteBytesExt};

    /// Serializes `words` as consecutive little-endian 16-bit values.
    fn encode(words: &[i16]) -> Vec<u8> {
        let mut bytes = Vec::new();
        for &word in words {
            bytes.write_i16::<LittleEndian>(word).unwrap();
        }
        bytes
    }

    /// Verifies the 2 x 3 fixture both tests load: sizes first, then every cell in
    /// `(forward_id, backward_id, expected_cost)` form.
    fn assert_fixture(matrix: &ConnectionCostMatrix) {
        assert_eq!(matrix.forward_size, 2);
        assert_eq!(matrix.backward_size, 3);

        let expected = [
            (0, 0, 10),
            (1, 0, 11),
            (0, 1, 12),
            (1, 1, 13),
            (0, 2, 14),
            (1, 2, 15),
        ];
        for &(forward_id, backward_id, cost) in &expected {
            assert_eq!(matrix.cost(forward_id, backward_id), cost);
        }
    }

    #[test]
    fn test_load_transposed() {
        // Header: version marker, forward_size, backward_size.
        // Payload is indexed by [forward_id + backward_id * forward_size]:
        // [0][0], [1][0], [0][1], [1][1], [0][2], [1][2]
        let data = encode(&[-1, 2, 3, 10, 11, 12, 13, 14, 15]);

        let matrix = ConnectionCostMatrix::load(data);
        assert_fixture(&matrix);
    }

    #[test]
    fn test_load_old_format() {
        // Header: forward_size, backward_size.
        // Old layout is indexed by [backward_id + forward_id * backward_size]:
        // [0][0], [1][0], [2][0], [0][1], [1][1], [2][1]
        let data = encode(&[2, 3, 10, 12, 14, 11, 13, 15]);

        let matrix = ConnectionCostMatrix::load(data);
        assert_fixture(&matrix);
    }
}
130}