use bytes::{Buf, BufMut};

use crate::error::{IgtlError, Result};
use crate::protocol::message::Message;
8
/// Element type of an NDARRAY payload.
///
/// The discriminant of each variant is the type code written as the first byte of the
/// encoded message body (`self.scalar_type as u8` in `encode_content`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ScalarType {
    /// Signed 8-bit integer (1 byte per element).
    Int8 = 2,
    /// Unsigned 8-bit integer (1 byte per element).
    Uint8 = 3,
    /// Signed 16-bit integer (2 bytes per element).
    Int16 = 4,
    /// Unsigned 16-bit integer (2 bytes per element).
    Uint16 = 5,
    /// Signed 32-bit integer (4 bytes per element).
    Int32 = 6,
    /// Unsigned 32-bit integer (4 bytes per element).
    Uint32 = 7,
    /// 32-bit IEEE-754 float (4 bytes per element).
    Float32 = 10,
    /// 64-bit IEEE-754 float (8 bytes per element).
    Float64 = 11,
}
22
23impl ScalarType {
24 fn from_u8(value: u8) -> Result<Self> {
25 match value {
26 2 => Ok(ScalarType::Int8),
27 3 => Ok(ScalarType::Uint8),
28 4 => Ok(ScalarType::Int16),
29 5 => Ok(ScalarType::Uint16),
30 6 => Ok(ScalarType::Int32),
31 7 => Ok(ScalarType::Uint32),
32 10 => Ok(ScalarType::Float32),
33 11 => Ok(ScalarType::Float64),
34 _ => Err(IgtlError::InvalidHeader(format!(
35 "Invalid scalar type: {}",
36 value
37 ))),
38 }
39 }
40
41 pub fn size(&self) -> usize {
43 match self {
44 ScalarType::Int8 | ScalarType::Uint8 => 1,
45 ScalarType::Int16 | ScalarType::Uint16 => 2,
46 ScalarType::Int32 | ScalarType::Uint32 | ScalarType::Float32 => 4,
47 ScalarType::Float64 => 8,
48 }
49 }
50}
51
52#[derive(Debug, Clone, PartialEq)]
59pub struct NdArrayMessage {
60 pub scalar_type: ScalarType,
62 pub size: Vec<u16>,
64 pub data: Vec<u8>,
66}
67
68impl NdArrayMessage {
69 pub fn new(scalar_type: ScalarType, size: Vec<u16>, data: Vec<u8>) -> Result<Self> {
71 if size.is_empty() || size.len() > 255 {
72 return Err(IgtlError::InvalidHeader(format!(
73 "Invalid dimension count: {}",
74 size.len()
75 )));
76 }
77
78 let expected_size: usize = size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
80
81 if data.len() != expected_size {
82 return Err(IgtlError::InvalidSize {
83 expected: expected_size,
84 actual: data.len(),
85 });
86 }
87
88 Ok(NdArrayMessage {
89 scalar_type,
90 size,
91 data,
92 })
93 }
94
95 pub fn new_1d(scalar_type: ScalarType, data: Vec<u8>) -> Result<Self> {
97 let element_count = data.len() / scalar_type.size();
98 Self::new(scalar_type, vec![element_count as u16], data)
99 }
100
101 pub fn new_2d(scalar_type: ScalarType, rows: u16, cols: u16, data: Vec<u8>) -> Result<Self> {
103 Self::new(scalar_type, vec![rows, cols], data)
104 }
105
106 pub fn new_3d(scalar_type: ScalarType, dim1: u16, dim2: u16, dim3: u16, data: Vec<u8>) -> Result<Self> {
108 Self::new(scalar_type, vec![dim1, dim2, dim3], data)
109 }
110
111 pub fn ndim(&self) -> usize {
113 self.size.len()
114 }
115
116 pub fn element_count(&self) -> usize {
118 self.size.iter().map(|&s| s as usize).product()
119 }
120
121 pub fn data_size(&self) -> usize {
123 self.data.len()
124 }
125}
126
127impl Message for NdArrayMessage {
128 fn message_type() -> &'static str {
129 "NDARRAY"
130 }
131
132 fn encode_content(&self) -> Result<Vec<u8>> {
133 let dim = self.size.len();
134 if dim == 0 || dim > 255 {
135 return Err(IgtlError::InvalidHeader(format!(
136 "Invalid dimension count: {}",
137 dim
138 )));
139 }
140
141 let mut buf = Vec::with_capacity(2 + dim * 2 + self.data.len());
142
143 buf.put_u8(self.scalar_type as u8);
145
146 buf.put_u8(dim as u8);
148
149 for &s in &self.size {
151 buf.put_u16(s);
152 }
153
154 buf.extend_from_slice(&self.data);
156
157 Ok(buf)
158 }
159
160 fn decode_content(mut data: &[u8]) -> Result<Self> {
161 if data.len() < 2 {
162 return Err(IgtlError::InvalidSize {
163 expected: 2,
164 actual: data.len(),
165 });
166 }
167
168 let scalar_type = ScalarType::from_u8(data.get_u8())?;
170
171 let dim = data.get_u8() as usize;
173
174 if dim == 0 {
175 return Err(IgtlError::InvalidHeader("Dimension cannot be zero".to_string()));
176 }
177
178 if data.len() < dim * 2 {
180 return Err(IgtlError::InvalidSize {
181 expected: dim * 2,
182 actual: data.len(),
183 });
184 }
185
186 let mut size = Vec::with_capacity(dim);
188 for _ in 0..dim {
189 size.push(data.get_u16());
190 }
191
192 let expected_data_size: usize = size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
194
195 if data.len() < expected_data_size {
196 return Err(IgtlError::InvalidSize {
197 expected: expected_data_size,
198 actual: data.len(),
199 });
200 }
201
202 let array_data = data[..expected_data_size].to_vec();
204
205 Ok(NdArrayMessage {
206 scalar_type,
207 size,
208 data: array_data,
209 })
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs `values` as consecutive big-endian `f32` byte groups.
    fn f32_be_bytes(values: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(values.len() * 4);
        for v in values {
            bytes.extend_from_slice(&v.to_be_bytes());
        }
        bytes
    }

    #[test]
    fn test_message_type() {
        assert_eq!(NdArrayMessage::message_type(), "NDARRAY");
    }

    #[test]
    fn test_scalar_type_size() {
        let table = [
            (ScalarType::Int8, 1),
            (ScalarType::Uint8, 1),
            (ScalarType::Int16, 2),
            (ScalarType::Uint16, 2),
            (ScalarType::Int32, 4),
            (ScalarType::Uint32, 4),
            (ScalarType::Float32, 4),
            (ScalarType::Float64, 8),
        ];
        for &(scalar_type, bytes) in table.iter() {
            assert_eq!(scalar_type.size(), bytes, "{:?}", scalar_type);
        }
    }

    #[test]
    fn test_new_1d() {
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, vec![1u8, 2, 3, 4]).unwrap();

        assert_eq!(msg.ndim(), 1);
        assert_eq!(msg.size[0], 4);
        assert_eq!(msg.element_count(), 4);
    }

    #[test]
    fn test_new_2d() {
        // 3 x 4 elements of one byte each.
        let payload = vec![0u8; 12];
        let msg = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, payload).unwrap();

        assert_eq!(msg.ndim(), 2);
        assert_eq!(msg.size, vec![3, 4]);
        assert_eq!(msg.element_count(), 12);
    }

    #[test]
    fn test_new_3d() {
        // 2 x 3 x 4 elements of one byte each.
        let payload = vec![0u8; 24];
        let msg = NdArrayMessage::new_3d(ScalarType::Uint8, 2, 3, 4, payload).unwrap();

        assert_eq!(msg.ndim(), 3);
        assert_eq!(msg.size, vec![2, 3, 4]);
        assert_eq!(msg.element_count(), 24);
    }

    #[test]
    fn test_invalid_data_size() {
        // A 3 x 4 shape needs 12 bytes; only 10 are supplied.
        let payload = vec![0u8; 10];
        let outcome = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, payload);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_encode_1d_uint8() {
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, vec![1u8, 2, 3]).unwrap();
        let encoded = msg.encode_content().unwrap();

        assert_eq!(encoded[0], 3); // scalar type code for Uint8
        assert_eq!(encoded[1], 1); // one dimension
        assert_eq!(u16::from_be_bytes([encoded[2], encoded[3]]), 3); // extent of dim 0
        assert_eq!(&encoded[4..], &[1, 2, 3]); // raw payload
    }

    #[test]
    fn test_roundtrip_1d() {
        let original_data = vec![10u8, 20, 30, 40];
        let original = NdArrayMessage::new_1d(ScalarType::Uint8, original_data.clone()).unwrap();

        let wire = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&wire).unwrap();

        assert_eq!(decoded.scalar_type, ScalarType::Uint8);
        assert_eq!(decoded.size, vec![4]);
        assert_eq!(decoded.data, original_data);
    }

    #[test]
    fn test_roundtrip_2d() {
        // 2 rows x 3 columns.
        let payload = vec![1u8, 2, 3, 4, 5, 6];
        let original = NdArrayMessage::new_2d(ScalarType::Uint8, 2, 3, payload.clone()).unwrap();

        let wire = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&wire).unwrap();

        assert_eq!(decoded.size, vec![2, 3]);
        assert_eq!(decoded.data, payload);
    }

    #[test]
    fn test_roundtrip_float32() {
        let payload = f32_be_bytes(&[1.0, 2.0, 3.0, 4.0]);

        let original =
            NdArrayMessage::new_2d(ScalarType::Float32, 2, 2, payload.clone()).unwrap();
        let wire = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&wire).unwrap();

        assert_eq!(decoded.scalar_type, ScalarType::Float32);
        assert_eq!(decoded.size, vec![2, 2]);
        assert_eq!(decoded.data, payload);
    }

    #[test]
    fn test_decode_invalid_header() {
        // A single byte cannot hold both the type code and the dimension count.
        let wire = vec![0u8];
        let outcome = NdArrayMessage::decode_content(&wire);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_decode_truncated_data() {
        // Uint8, 1 dimension of extent 5, but only 3 payload bytes follow.
        let mut wire = vec![3u8, 1];
        wire.extend_from_slice(&5u16.to_be_bytes());
        wire.extend_from_slice(&[1, 2, 3]);

        let outcome = NdArrayMessage::decode_content(&wire);
        assert!(outcome.is_err());
    }
}