lindera_dictionary/dictionary/
connection_cost_matrix.rs1use crate::util::Data;
2
3use byteorder::{ByteOrder, LittleEndian};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5
6#[derive(Clone, Archive, RkyvSerialize, RkyvDeserialize)]
7pub struct ConnectionCostMatrix {
8 pub costs_data: Vec<i16>,
12 pub backward_size: u32,
13 pub forward_size: u32,
14}
15
16impl ConnectionCostMatrix {
17 pub fn load(conn_data: impl Into<Data>) -> ConnectionCostMatrix {
18 let conn_data = conn_data.into();
19 let first_v = LittleEndian::read_i16(&conn_data[0..2]);
20
21 if first_v == -1 {
22 let forward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
24 let backward_size = LittleEndian::read_i16(&conn_data[4..6]) as u32;
25 let size = conn_data.len() / 2 - 3;
26 let mut costs_data = vec![0i16; size];
27 LittleEndian::read_i16_into(&conn_data[6..], &mut costs_data);
28
29 ConnectionCostMatrix {
30 costs_data,
31 backward_size,
32 forward_size,
33 }
34 } else {
35 let forward_size = first_v as u32;
37 let backward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
38 let size = conn_data.len() / 2 - 2;
39 let mut old_costs_data = vec![0i16; size];
40 LittleEndian::read_i16_into(&conn_data[4..], &mut old_costs_data);
41
42 let mut costs_data = vec![0i16; size];
44 for f in 0..forward_size {
45 for b in 0..backward_size {
46 let old_id = (b + f * backward_size) as usize;
47 let new_id = (f + b * forward_size) as usize;
48 costs_data[new_id] = old_costs_data[old_id];
49 }
50 }
51
52 ConnectionCostMatrix {
53 costs_data,
54 backward_size,
55 forward_size,
56 }
57 }
58 }
59
60 #[inline]
61 pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
62 let cost_id = (forward_id + backward_id * self.forward_size) as usize;
63 self.costs_data[cost_id] as i32
64 }
65}
66
67impl ArchivedConnectionCostMatrix {
68 #[inline]
69 pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
70 let cost_id = (forward_id + backward_id * self.forward_size) as usize;
71 self.costs_data[cost_id].to_native() as i32
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use byteorder::{LittleEndian, WriteBytesExt};

    /// Encodes `values` as consecutive little-endian 16-bit integers, the
    /// encoding `ConnectionCostMatrix::load` parses.
    fn encode(values: &[i16]) -> Vec<u8> {
        let mut bytes = Vec::new();
        for &value in values {
            bytes.write_i16::<LittleEndian>(value).unwrap();
        }
        bytes
    }

    /// Verifies the 2 x 3 sample matrix shared by both tests, in which the cost
    /// of `(forward_id, backward_id)` is `10 + forward_id + 2 * backward_id`.
    fn check_sample(matrix: &ConnectionCostMatrix) {
        assert_eq!(matrix.forward_size, 2);
        assert_eq!(matrix.backward_size, 3);
        for backward_id in 0..3u32 {
            for forward_id in 0..2u32 {
                let expected = 10 + forward_id as i32 + 2 * backward_id as i32;
                assert_eq!(matrix.cost(forward_id, backward_id), expected);
            }
        }
    }

    #[test]
    fn test_load_transposed() {
        // -1 marker, forward_size = 2, backward_size = 3, then the costs in
        // `forward_id + backward_id * forward_size` order.
        let data = encode(&[-1, 2, 3, 10, 11, 12, 13, 14, 15]);
        check_sample(&ConnectionCostMatrix::load(data));
    }

    #[test]
    fn test_load_old_format() {
        // forward_size = 2, backward_size = 3, then the costs in
        // `backward_id + forward_id * backward_size` order.
        let data = encode(&[2, 3, 10, 12, 14, 11, 13, 15]);
        check_sample(&ConnectionCostMatrix::load(data));
    }
}