lindera_dictionary/dictionary/
connection_cost_matrix.rs1use crate::{LinderaResult, error::LinderaErrorKind, util::Data};
2
3use byteorder::{ByteOrder, LittleEndian};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5
6#[derive(Clone, Archive, RkyvSerialize, RkyvDeserialize)]
7pub struct ConnectionCostMatrix {
8 pub costs_data: Vec<i16>,
12 pub backward_size: u32,
13 pub forward_size: u32,
14}
15
16impl ConnectionCostMatrix {
17 pub fn load(conn_data: impl Into<Data>) -> LinderaResult<ConnectionCostMatrix> {
29 let conn_data = conn_data.into();
30 if conn_data.len() < 4 {
31 return Err(LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(
32 "Connection cost matrix data too short: {} bytes",
33 conn_data.len()
34 )));
35 }
36
37 let first_v = LittleEndian::read_i16(&conn_data[0..2]);
38
39 if first_v == -1 {
40 if conn_data.len() < 6 {
42 return Err(LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(
43 "Connection cost matrix header too short for new format: {} bytes",
44 conn_data.len()
45 )));
46 }
47 let forward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
48 let backward_size = LittleEndian::read_i16(&conn_data[4..6]) as u32;
49 let size = conn_data.len() / 2 - 3;
50 let mut costs_data = vec![0i16; size];
51 LittleEndian::read_i16_into(&conn_data[6..], &mut costs_data);
52
53 Ok(ConnectionCostMatrix {
54 costs_data,
55 backward_size,
56 forward_size,
57 })
58 } else {
59 let forward_size = first_v as u32;
61 let backward_size = LittleEndian::read_i16(&conn_data[2..4]) as u32;
62 let size = conn_data.len() / 2 - 2;
63 let mut old_costs_data = vec![0i16; size];
64 LittleEndian::read_i16_into(&conn_data[4..], &mut old_costs_data);
65
66 let mut costs_data = vec![0i16; size];
68 for f in 0..forward_size {
69 for b in 0..backward_size {
70 let old_id = (b + f * backward_size) as usize;
71 let new_id = (f + b * forward_size) as usize;
72 costs_data[new_id] = old_costs_data[old_id];
73 }
74 }
75
76 Ok(ConnectionCostMatrix {
77 costs_data,
78 backward_size,
79 forward_size,
80 })
81 }
82 }
83
84 #[inline]
85 pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
86 let cost_id = (forward_id + backward_id * self.forward_size) as usize;
87 self.costs_data[cost_id] as i32
88 }
89}
90
91impl ArchivedConnectionCostMatrix {
92 #[inline]
93 pub fn cost(&self, forward_id: u32, backward_id: u32) -> i32 {
94 let cost_id = (forward_id + backward_id * self.forward_size) as usize;
95 self.costs_data[cost_id].to_native() as i32
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use byteorder::{LittleEndian, WriteBytesExt};

    /// Serializes `values` as consecutive little-endian `i16`s.
    fn encode(values: &[i16]) -> Vec<u8> {
        let mut bytes = Vec::new();
        for &value in values {
            bytes.write_i16::<LittleEndian>(value).unwrap();
        }
        bytes
    }

    /// Checks the 2 (forward) x 3 (backward) matrix shared by the format tests.
    fn assert_sample_matrix(matrix: &ConnectionCostMatrix) {
        assert_eq!(matrix.forward_size, 2);
        assert_eq!(matrix.backward_size, 3);

        // (forward_id, backward_id, expected cost)
        let expectations = [
            (0, 0, 10),
            (1, 0, 11),
            (0, 1, 12),
            (1, 1, 13),
            (0, 2, 14),
            (1, 2, 15),
        ];
        for (forward_id, backward_id, expected) in expectations {
            assert_eq!(matrix.cost(forward_id, backward_id), expected);
        }
    }

    #[test]
    fn test_load_transposed() {
        // New format: -1 marker, forward size, backward size, then costs laid
        // out as `forward + backward * forward_size`.
        let data = encode(&[-1, 2, 3, 10, 11, 12, 13, 14, 15]);

        let matrix = ConnectionCostMatrix::load(data).unwrap();
        assert_sample_matrix(&matrix);
    }

    #[test]
    fn test_load_old_format() {
        // Old format: forward size, backward size, then costs laid out as
        // `backward + forward * backward_size`.
        let data = encode(&[2, 3, 10, 12, 14, 11, 13, 15]);

        let matrix = ConnectionCostMatrix::load(data).unwrap();
        assert_sample_matrix(&matrix);
    }

    #[test]
    fn test_load_data_too_short() {
        // Two bytes cannot even hold the two-value header.
        let data: Vec<u8> = vec![0x01, 0x02];
        let result = ConnectionCostMatrix::load(data);
        assert!(result.is_err());
    }
}