openigtlink_rust/protocol/types/
ndarray.rs

1//! NDARRAY message type implementation
2//!
3//! The NDARRAY message type is used to transfer N-dimensional numerical arrays.
4
5use crate::error::{IgtlError, Result};
6use crate::protocol::message::Message;
7use bytes::{Buf, BufMut};
8
9/// Scalar data type for NDARRAY
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11#[repr(u8)]
12pub enum ScalarType {
13    Int8 = 2,
14    Uint8 = 3,
15    Int16 = 4,
16    Uint16 = 5,
17    Int32 = 6,
18    Uint32 = 7,
19    Float32 = 10,
20    Float64 = 11,
21}
22
23impl ScalarType {
24    fn from_u8(value: u8) -> Result<Self> {
25        match value {
26            2 => Ok(ScalarType::Int8),
27            3 => Ok(ScalarType::Uint8),
28            4 => Ok(ScalarType::Int16),
29            5 => Ok(ScalarType::Uint16),
30            6 => Ok(ScalarType::Int32),
31            7 => Ok(ScalarType::Uint32),
32            10 => Ok(ScalarType::Float32),
33            11 => Ok(ScalarType::Float64),
34            _ => Err(IgtlError::InvalidHeader(format!(
35                "Invalid scalar type: {}",
36                value
37            ))),
38        }
39    }
40
41    /// Get the size in bytes of this scalar type
42    pub fn size(&self) -> usize {
43        match self {
44            ScalarType::Int8 | ScalarType::Uint8 => 1,
45            ScalarType::Int16 | ScalarType::Uint16 => 2,
46            ScalarType::Int32 | ScalarType::Uint32 | ScalarType::Float32 => 4,
47            ScalarType::Float64 => 8,
48        }
49    }
50}
51
52/// NDARRAY message containing an N-dimensional numerical array
53///
54/// # OpenIGTLink Specification
55/// - Message type: "NDARRAY"
56/// - Body format: SCALAR_TYPE (uint8) + DIM (uint8) + SIZE (`uint16[DIM]`) + DATA (bytes)
57/// - Data layout: Row-major order (C-style)
58#[derive(Debug, Clone, PartialEq)]
59pub struct NdArrayMessage {
60    /// Scalar data type
61    pub scalar_type: ScalarType,
62    /// Array dimensions
63    pub size: Vec<u16>,
64    /// Raw array data in network byte order
65    pub data: Vec<u8>,
66}
67
68impl NdArrayMessage {
69    /// Create a new NDARRAY message
70    pub fn new(scalar_type: ScalarType, size: Vec<u16>, data: Vec<u8>) -> Result<Self> {
71        if size.is_empty() || size.len() > 255 {
72            return Err(IgtlError::InvalidHeader(format!(
73                "Invalid dimension count: {}",
74                size.len()
75            )));
76        }
77
78        // Calculate expected data size
79        let expected_size: usize =
80            size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
81
82        if data.len() != expected_size {
83            return Err(IgtlError::InvalidSize {
84                expected: expected_size,
85                actual: data.len(),
86            });
87        }
88
89        Ok(NdArrayMessage {
90            scalar_type,
91            size,
92            data,
93        })
94    }
95
96    /// Create a 1D array
97    pub fn new_1d(scalar_type: ScalarType, data: Vec<u8>) -> Result<Self> {
98        let element_count = data.len() / scalar_type.size();
99        Self::new(scalar_type, vec![element_count as u16], data)
100    }
101
102    /// Create a 2D array
103    pub fn new_2d(scalar_type: ScalarType, rows: u16, cols: u16, data: Vec<u8>) -> Result<Self> {
104        Self::new(scalar_type, vec![rows, cols], data)
105    }
106
107    /// Create a 3D array
108    pub fn new_3d(
109        scalar_type: ScalarType,
110        dim1: u16,
111        dim2: u16,
112        dim3: u16,
113        data: Vec<u8>,
114    ) -> Result<Self> {
115        Self::new(scalar_type, vec![dim1, dim2, dim3], data)
116    }
117
118    /// Get the number of dimensions
119    pub fn ndim(&self) -> usize {
120        self.size.len()
121    }
122
123    /// Get total number of elements
124    pub fn element_count(&self) -> usize {
125        self.size.iter().map(|&s| s as usize).product()
126    }
127
128    /// Get total data size in bytes
129    pub fn data_size(&self) -> usize {
130        self.data.len()
131    }
132}
133
134impl Message for NdArrayMessage {
135    fn message_type() -> &'static str {
136        "NDARRAY"
137    }
138
139    fn encode_content(&self) -> Result<Vec<u8>> {
140        let dim = self.size.len();
141        if dim == 0 || dim > 255 {
142            return Err(IgtlError::InvalidHeader(format!(
143                "Invalid dimension count: {}",
144                dim
145            )));
146        }
147
148        let mut buf = Vec::with_capacity(2 + dim * 2 + self.data.len());
149
150        // Encode SCALAR_TYPE (uint8)
151        buf.put_u8(self.scalar_type as u8);
152
153        // Encode DIM (uint8)
154        buf.put_u8(dim as u8);
155
156        // Encode SIZE (`uint16[DIM]`)
157        for &s in &self.size {
158            buf.put_u16(s);
159        }
160
161        // Encode DATA
162        buf.extend_from_slice(&self.data);
163
164        Ok(buf)
165    }
166
167    fn decode_content(mut data: &[u8]) -> Result<Self> {
168        if data.len() < 2 {
169            return Err(IgtlError::InvalidSize {
170                expected: 2,
171                actual: data.len(),
172            });
173        }
174
175        // Decode SCALAR_TYPE
176        let scalar_type = ScalarType::from_u8(data.get_u8())?;
177
178        // Decode DIM
179        let dim = data.get_u8() as usize;
180
181        if dim == 0 {
182            return Err(IgtlError::InvalidHeader(
183                "Dimension cannot be zero".to_string(),
184            ));
185        }
186
187        // Check we have enough data for SIZE array
188        if data.len() < dim * 2 {
189            return Err(IgtlError::InvalidSize {
190                expected: dim * 2,
191                actual: data.len(),
192            });
193        }
194
195        // Decode SIZE
196        let mut size = Vec::with_capacity(dim);
197        for _ in 0..dim {
198            size.push(data.get_u16());
199        }
200
201        // Calculate expected data size
202        let expected_data_size: usize =
203            size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
204
205        if data.len() < expected_data_size {
206            return Err(IgtlError::InvalidSize {
207                expected: expected_data_size,
208                actual: data.len(),
209            });
210        }
211
212        // Decode DATA
213        let array_data = data[..expected_data_size].to_vec();
214
215        Ok(NdArrayMessage {
216            scalar_type,
217            size,
218            data: array_data,
219        })
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type() {
        assert_eq!(NdArrayMessage::message_type(), "NDARRAY");
    }

    #[test]
    fn test_scalar_type_size() {
        assert_eq!(ScalarType::Int8.size(), 1);
        assert_eq!(ScalarType::Uint8.size(), 1);
        assert_eq!(ScalarType::Int16.size(), 2);
        assert_eq!(ScalarType::Uint16.size(), 2);
        assert_eq!(ScalarType::Int32.size(), 4);
        assert_eq!(ScalarType::Uint32.size(), 4);
        assert_eq!(ScalarType::Float32.size(), 4);
        assert_eq!(ScalarType::Float64.size(), 8);
    }

    #[test]
    fn test_scalar_type_from_u8_roundtrip() {
        for ty in [
            ScalarType::Int8,
            ScalarType::Uint8,
            ScalarType::Int16,
            ScalarType::Uint16,
            ScalarType::Int32,
            ScalarType::Uint32,
            ScalarType::Float32,
            ScalarType::Float64,
        ] {
            assert_eq!(ScalarType::from_u8(ty as u8).unwrap(), ty);
        }
        // Codes with no corresponding variant must be rejected
        assert!(ScalarType::from_u8(0).is_err());
        assert!(ScalarType::from_u8(12).is_err());
    }

    #[test]
    fn test_new_1d() {
        let data = vec![1u8, 2, 3, 4];
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, data).unwrap();

        assert_eq!(msg.ndim(), 1);
        assert_eq!(msg.size[0], 4);
        assert_eq!(msg.element_count(), 4);
    }

    #[test]
    fn test_new_2d() {
        let data = vec![0u8; 12]; // 3x4 matrix of uint8
        let msg = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, data).unwrap();

        assert_eq!(msg.ndim(), 2);
        assert_eq!(msg.size, vec![3, 4]);
        assert_eq!(msg.element_count(), 12);
    }

    #[test]
    fn test_new_3d() {
        let data = vec![0u8; 24]; // 2x3x4 array of uint8
        let msg = NdArrayMessage::new_3d(ScalarType::Uint8, 2, 3, 4, data).unwrap();

        assert_eq!(msg.ndim(), 3);
        assert_eq!(msg.size, vec![2, 3, 4]);
        assert_eq!(msg.element_count(), 24);
    }

    #[test]
    fn test_invalid_data_size() {
        let data = vec![0u8; 10]; // Wrong size for 3x4 array
        let result = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, data);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_empty_dimensions() {
        let result = NdArrayMessage::new(ScalarType::Uint8, vec![], vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_1d_rejects_partial_element() {
        // 5 bytes cannot hold a whole number of float32 elements
        let result = NdArrayMessage::new_1d(ScalarType::Float32, vec![0u8; 5]);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_1d_rejects_too_many_elements() {
        // SIZE entries are uint16, so a 1D array cannot exceed u16::MAX elements
        let result = NdArrayMessage::new_1d(ScalarType::Uint8, vec![0u8; 70_000]);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_1d_uint8() {
        let data = vec![1u8, 2, 3];
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, data).unwrap();
        let encoded = msg.encode_content().unwrap();

        assert_eq!(encoded[0], 3); // SCALAR_TYPE = Uint8
        assert_eq!(encoded[1], 1); // DIM = 1
        assert_eq!(u16::from_be_bytes([encoded[2], encoded[3]]), 3); // SIZE[0] = 3
        assert_eq!(&encoded[4..], &[1, 2, 3]);
    }

    #[test]
    fn test_roundtrip_1d() {
        let original_data = vec![10u8, 20, 30, 40];
        let original = NdArrayMessage::new_1d(ScalarType::Uint8, original_data.clone()).unwrap();

        let encoded = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&encoded).unwrap();

        assert_eq!(decoded.scalar_type, ScalarType::Uint8);
        assert_eq!(decoded.size, vec![4]);
        assert_eq!(decoded.data, original_data);
    }

    #[test]
    fn test_roundtrip_2d() {
        let data = vec![1u8, 2, 3, 4, 5, 6]; // 2x3 matrix
        let original = NdArrayMessage::new_2d(ScalarType::Uint8, 2, 3, data.clone()).unwrap();

        let encoded = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&encoded).unwrap();

        assert_eq!(decoded.size, vec![2, 3]);
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_roundtrip_3d_int16() {
        // 2x3x2 int16 array, big-endian payload
        let mut data = Vec::new();
        for val in [-3i16, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8] {
            data.extend_from_slice(&val.to_be_bytes());
        }

        let original = NdArrayMessage::new_3d(ScalarType::Int16, 2, 3, 2, data).unwrap();
        let encoded = original.encode_content().unwrap();
        // 2-byte header + 3 uint16 SIZE entries + 12 int16 elements
        assert_eq!(encoded.len(), 2 + 3 * 2 + 24);

        let decoded = NdArrayMessage::decode_content(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_roundtrip_float32() {
        // Create 2x2 float32 array
        let mut data = Vec::new();
        for val in [1.0f32, 2.0, 3.0, 4.0] {
            data.extend_from_slice(&val.to_be_bytes());
        }

        let original = NdArrayMessage::new_2d(ScalarType::Float32, 2, 2, data.clone()).unwrap();
        let encoded = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&encoded).unwrap();

        assert_eq!(decoded.scalar_type, ScalarType::Float32);
        assert_eq!(decoded.size, vec![2, 2]);
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_decode_invalid_header() {
        let data = vec![0u8]; // Too short
        let result = NdArrayMessage::decode_content(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_zero_dimension() {
        let data = vec![3u8, 0]; // SCALAR_TYPE=Uint8, DIM=0
        assert!(NdArrayMessage::decode_content(&data).is_err());
    }

    #[test]
    fn test_decode_unknown_scalar_type() {
        let data = vec![99u8, 1, 0, 1, 42]; // SCALAR_TYPE=99 is not defined
        assert!(NdArrayMessage::decode_content(&data).is_err());
    }

    #[test]
    fn test_decode_truncated_size_array() {
        let data = vec![3u8, 2, 0, 1]; // DIM=2 needs 4 SIZE bytes, only 2 present
        assert!(NdArrayMessage::decode_content(&data).is_err());
    }

    #[test]
    fn test_decode_truncated_data() {
        let mut data = vec![3u8, 1]; // SCALAR_TYPE=Uint8, DIM=1
        data.extend_from_slice(&5u16.to_be_bytes()); // SIZE[0]=5
        data.extend_from_slice(&[1, 2, 3]); // Only 3 bytes instead of 5

        let result = NdArrayMessage::decode_content(&data);
        assert!(result.is_err());
    }
}