// openigtlink_rust/protocol/types/ndarray.rs
//! NDARRAY message: an N-dimensional array of numeric elements.
use std::convert::TryFrom;

use bytes::{Buf, BufMut};

use crate::error::{IgtlError, Result};
use crate::protocol::message::Message;

/// Element type of an NDARRAY payload.
///
/// Each discriminant is the numeric type code used on the wire: encoding
/// writes it with `as u8`, and `ScalarType::from_u8` maps it back.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum ScalarType {
    /// Signed 8-bit integer (code 2).
    Int8 = 2,
    /// Unsigned 8-bit integer (code 3).
    Uint8 = 3,
    /// Signed 16-bit integer (code 4).
    Int16 = 4,
    /// Unsigned 16-bit integer (code 5).
    Uint16 = 5,
    /// Signed 32-bit integer (code 6).
    Int32 = 6,
    /// Unsigned 32-bit integer (code 7).
    Uint32 = 7,
    /// 32-bit floating point (code 10).
    Float32 = 10,
    /// 64-bit floating point (code 11).
    Float64 = 11,
}

23impl ScalarType {
24 fn from_u8(value: u8) -> Result<Self> {
25 match value {
26 2 => Ok(ScalarType::Int8),
27 3 => Ok(ScalarType::Uint8),
28 4 => Ok(ScalarType::Int16),
29 5 => Ok(ScalarType::Uint16),
30 6 => Ok(ScalarType::Int32),
31 7 => Ok(ScalarType::Uint32),
32 10 => Ok(ScalarType::Float32),
33 11 => Ok(ScalarType::Float64),
34 _ => Err(IgtlError::InvalidHeader(format!(
35 "Invalid scalar type: {}",
36 value
37 ))),
38 }
39 }
40
41 pub fn size(&self) -> usize {
43 match self {
44 ScalarType::Int8 | ScalarType::Uint8 => 1,
45 ScalarType::Int16 | ScalarType::Uint16 => 2,
46 ScalarType::Int32 | ScalarType::Uint32 | ScalarType::Float32 => 4,
47 ScalarType::Float64 => 8,
48 }
49 }
50}
51
52#[derive(Debug, Clone, PartialEq)]
59pub struct NdArrayMessage {
60 pub scalar_type: ScalarType,
62 pub size: Vec<u16>,
64 pub data: Vec<u8>,
66}
67
68impl NdArrayMessage {
69 pub fn new(scalar_type: ScalarType, size: Vec<u16>, data: Vec<u8>) -> Result<Self> {
71 if size.is_empty() || size.len() > 255 {
72 return Err(IgtlError::InvalidHeader(format!(
73 "Invalid dimension count: {}",
74 size.len()
75 )));
76 }
77
78 let expected_size: usize =
80 size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
81
82 if data.len() != expected_size {
83 return Err(IgtlError::InvalidSize {
84 expected: expected_size,
85 actual: data.len(),
86 });
87 }
88
89 Ok(NdArrayMessage {
90 scalar_type,
91 size,
92 data,
93 })
94 }
95
96 pub fn new_1d(scalar_type: ScalarType, data: Vec<u8>) -> Result<Self> {
98 let element_count = data.len() / scalar_type.size();
99 Self::new(scalar_type, vec![element_count as u16], data)
100 }
101
102 pub fn new_2d(scalar_type: ScalarType, rows: u16, cols: u16, data: Vec<u8>) -> Result<Self> {
104 Self::new(scalar_type, vec![rows, cols], data)
105 }
106
107 pub fn new_3d(
109 scalar_type: ScalarType,
110 dim1: u16,
111 dim2: u16,
112 dim3: u16,
113 data: Vec<u8>,
114 ) -> Result<Self> {
115 Self::new(scalar_type, vec![dim1, dim2, dim3], data)
116 }
117
118 pub fn ndim(&self) -> usize {
120 self.size.len()
121 }
122
123 pub fn element_count(&self) -> usize {
125 self.size.iter().map(|&s| s as usize).product()
126 }
127
128 pub fn data_size(&self) -> usize {
130 self.data.len()
131 }
132}
133
134impl Message for NdArrayMessage {
135 fn message_type() -> &'static str {
136 "NDARRAY"
137 }
138
139 fn encode_content(&self) -> Result<Vec<u8>> {
140 let dim = self.size.len();
141 if dim == 0 || dim > 255 {
142 return Err(IgtlError::InvalidHeader(format!(
143 "Invalid dimension count: {}",
144 dim
145 )));
146 }
147
148 let mut buf = Vec::with_capacity(2 + dim * 2 + self.data.len());
149
150 buf.put_u8(self.scalar_type as u8);
152
153 buf.put_u8(dim as u8);
155
156 for &s in &self.size {
158 buf.put_u16(s);
159 }
160
161 buf.extend_from_slice(&self.data);
163
164 Ok(buf)
165 }
166
167 fn decode_content(mut data: &[u8]) -> Result<Self> {
168 if data.len() < 2 {
169 return Err(IgtlError::InvalidSize {
170 expected: 2,
171 actual: data.len(),
172 });
173 }
174
175 let scalar_type = ScalarType::from_u8(data.get_u8())?;
177
178 let dim = data.get_u8() as usize;
180
181 if dim == 0 {
182 return Err(IgtlError::InvalidHeader(
183 "Dimension cannot be zero".to_string(),
184 ));
185 }
186
187 if data.len() < dim * 2 {
189 return Err(IgtlError::InvalidSize {
190 expected: dim * 2,
191 actual: data.len(),
192 });
193 }
194
195 let mut size = Vec::with_capacity(dim);
197 for _ in 0..dim {
198 size.push(data.get_u16());
199 }
200
201 let expected_data_size: usize =
203 size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
204
205 if data.len() < expected_data_size {
206 return Err(IgtlError::InvalidSize {
207 expected: expected_data_size,
208 actual: data.len(),
209 });
210 }
211
212 let array_data = data[..expected_data_size].to_vec();
214
215 Ok(NdArrayMessage {
216 scalar_type,
217 size,
218 data: array_data,
219 })
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type() {
        assert_eq!(NdArrayMessage::message_type(), "NDARRAY");
    }

    #[test]
    fn test_scalar_type_size() {
        assert_eq!(ScalarType::Int8.size(), 1);
        assert_eq!(ScalarType::Uint8.size(), 1);
        assert_eq!(ScalarType::Int16.size(), 2);
        assert_eq!(ScalarType::Uint16.size(), 2);
        assert_eq!(ScalarType::Int32.size(), 4);
        assert_eq!(ScalarType::Uint32.size(), 4);
        assert_eq!(ScalarType::Float32.size(), 4);
        assert_eq!(ScalarType::Float64.size(), 8);
    }

    #[test]
    fn test_scalar_type_from_u8() {
        let all = [
            ScalarType::Int8,
            ScalarType::Uint8,
            ScalarType::Int16,
            ScalarType::Uint16,
            ScalarType::Int32,
            ScalarType::Uint32,
            ScalarType::Float32,
            ScalarType::Float64,
        ];
        // Every variant's discriminant decodes back to that variant.
        for ty in all {
            assert_eq!(ScalarType::from_u8(ty as u8).unwrap(), ty);
        }
        // Codes outside the supported set are rejected.
        for code in [0u8, 1, 8, 9, 12, 255] {
            assert!(ScalarType::from_u8(code).is_err());
        }
    }

    #[test]
    fn test_new_1d() {
        let data = vec![1u8, 2, 3, 4];
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, data).unwrap();

        assert_eq!(msg.ndim(), 1);
        assert_eq!(msg.size[0], 4);
        assert_eq!(msg.element_count(), 4);
    }

    #[test]
    fn test_new_1d_misaligned_data() {
        // 5 bytes is not a whole number of 4-byte Float32 elements.
        assert!(NdArrayMessage::new_1d(ScalarType::Float32, vec![0u8; 5]).is_err());
    }

    #[test]
    fn test_new_rejects_empty_dimensions() {
        assert!(NdArrayMessage::new(ScalarType::Uint8, vec![], vec![]).is_err());
    }

    #[test]
    fn test_new_2d() {
        // 3 x 4 Uint8 elements = 12 bytes.
        let data = vec![0u8; 12];
        let msg = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, data).unwrap();

        assert_eq!(msg.ndim(), 2);
        assert_eq!(msg.size, vec![3, 4]);
        assert_eq!(msg.element_count(), 12);
    }

    #[test]
    fn test_new_3d() {
        // 2 x 3 x 4 Uint8 elements = 24 bytes.
        let data = vec![0u8; 24];
        let msg = NdArrayMessage::new_3d(ScalarType::Uint8, 2, 3, 4, data).unwrap();

        assert_eq!(msg.ndim(), 3);
        assert_eq!(msg.size, vec![2, 3, 4]);
        assert_eq!(msg.element_count(), 24);
    }

    #[test]
    fn test_invalid_data_size() {
        // A 3 x 4 Uint8 array needs 12 bytes, not 10.
        let data = vec![0u8; 10];
        let result = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, data);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_1d_uint8() {
        let data = vec![1u8, 2, 3];
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, data).unwrap();
        let encoded = msg.encode_content().unwrap();

        assert_eq!(encoded[0], 3); // scalar type code for Uint8
        assert_eq!(encoded[1], 1); // one dimension
        assert_eq!(u16::from_be_bytes([encoded[2], encoded[3]]), 3); // dimension length
        assert_eq!(&encoded[4..], &[1, 2, 3]); // payload
    }

    #[test]
    fn test_encode_dimensions_big_endian() {
        let msg = NdArrayMessage::new_2d(ScalarType::Uint8, 300, 1, vec![0u8; 300]).unwrap();
        let encoded = msg.encode_content().unwrap();

        // 300 == 0x012C, 1 == 0x0001, both big-endian.
        assert_eq!(&encoded[2..6], &[0x01, 0x2C, 0x00, 0x01]);
        assert_eq!(encoded.len(), 2 + 4 + 300);
    }

    #[test]
    fn test_encode_rejects_empty_dimensions() {
        // Public fields allow bypassing `new`, so encoding must re-validate.
        let msg = NdArrayMessage {
            scalar_type: ScalarType::Uint8,
            size: vec![],
            data: vec![],
        };
        assert!(msg.encode_content().is_err());
    }

    #[test]
    fn test_roundtrip_1d() {
        let original_data = vec![10u8, 20, 30, 40];
        let original = NdArrayMessage::new_1d(ScalarType::Uint8, original_data.clone()).unwrap();

        let encoded = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&encoded).unwrap();

        assert_eq!(decoded.scalar_type, ScalarType::Uint8);
        assert_eq!(decoded.size, vec![4]);
        assert_eq!(decoded.data, original_data);
    }

    #[test]
    fn test_roundtrip_2d() {
        // 2 x 3 Uint8 elements.
        let data = vec![1u8, 2, 3, 4, 5, 6];
        let original = NdArrayMessage::new_2d(ScalarType::Uint8, 2, 3, data.clone()).unwrap();

        let encoded = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&encoded).unwrap();

        assert_eq!(decoded.size, vec![2, 3]);
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_roundtrip_float32() {
        let mut data = Vec::new();
        for val in [1.0f32, 2.0, 3.0, 4.0] {
            data.extend_from_slice(&val.to_be_bytes());
        }

        let original = NdArrayMessage::new_2d(ScalarType::Float32, 2, 2, data.clone()).unwrap();
        let encoded = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&encoded).unwrap();

        assert_eq!(decoded.scalar_type, ScalarType::Float32);
        assert_eq!(decoded.size, vec![2, 2]);
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_decode_invalid_header() {
        // Shorter than the 2-byte (type, dimension count) prefix.
        let data = vec![0u8];
        let result = NdArrayMessage::decode_content(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_zero_dimensions() {
        assert!(NdArrayMessage::decode_content(&[3u8, 0]).is_err());
    }

    #[test]
    fn test_decode_unknown_scalar_type() {
        let mut data = vec![1u8, 1];
        data.extend_from_slice(&1u16.to_be_bytes());
        data.push(0);
        assert!(NdArrayMessage::decode_content(&data).is_err());
    }

    #[test]
    fn test_decode_truncated_dimensions() {
        // Declares two dimensions but supplies only one.
        let mut data = vec![3u8, 2];
        data.extend_from_slice(&1u16.to_be_bytes());
        assert!(NdArrayMessage::decode_content(&data).is_err());
    }

    #[test]
    fn test_decode_truncated_data() {
        // Uint8, one dimension of length 5, but only 3 payload bytes.
        let mut data = vec![3u8, 1];
        data.extend_from_slice(&5u16.to_be_bytes());
        data.extend_from_slice(&[1, 2, 3]);
        let result = NdArrayMessage::decode_content(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_ignores_trailing_bytes() {
        // Uint8, one dimension of length 2, followed by one extra byte.
        let data = [3u8, 1, 0, 2, 7, 8, 9];
        let decoded = NdArrayMessage::decode_content(&data).unwrap();
        assert_eq!(decoded.size, vec![2]);
        assert_eq!(decoded.data, vec![7, 8]);
    }
}