openigtlink_rust/protocol/types/
ndarray.rs

//! NDARRAY message type implementation
//!
//! The NDARRAY message type is used to transfer N-dimensional numerical arrays.
4
5use crate::protocol::message::Message;
6use crate::error::{IgtlError, Result};
7use bytes::{Buf, BufMut};
8
/// Element type of the values carried by an NDARRAY message.
///
/// The numeric value of each variant is the code stored in the one-byte
/// SCALAR_TYPE field of the message body.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScalarType {
    /// Signed 8-bit integer (code 2).
    Int8 = 2,
    /// Unsigned 8-bit integer (code 3).
    Uint8 = 3,
    /// Signed 16-bit integer (code 4).
    Int16 = 4,
    /// Unsigned 16-bit integer (code 5).
    Uint16 = 5,
    /// Signed 32-bit integer (code 6).
    Int32 = 6,
    /// Unsigned 32-bit integer (code 7).
    Uint32 = 7,
    /// 32-bit floating point (code 10).
    Float32 = 10,
    /// 64-bit floating point (code 11).
    Float64 = 11,
}
22
23impl ScalarType {
24    fn from_u8(value: u8) -> Result<Self> {
25        match value {
26            2 => Ok(ScalarType::Int8),
27            3 => Ok(ScalarType::Uint8),
28            4 => Ok(ScalarType::Int16),
29            5 => Ok(ScalarType::Uint16),
30            6 => Ok(ScalarType::Int32),
31            7 => Ok(ScalarType::Uint32),
32            10 => Ok(ScalarType::Float32),
33            11 => Ok(ScalarType::Float64),
34            _ => Err(IgtlError::InvalidHeader(format!(
35                "Invalid scalar type: {}",
36                value
37            ))),
38        }
39    }
40
41    /// Get the size in bytes of this scalar type
42    pub fn size(&self) -> usize {
43        match self {
44            ScalarType::Int8 | ScalarType::Uint8 => 1,
45            ScalarType::Int16 | ScalarType::Uint16 => 2,
46            ScalarType::Int32 | ScalarType::Uint32 | ScalarType::Float32 => 4,
47            ScalarType::Float64 => 8,
48        }
49    }
50}
51
52/// NDARRAY message containing an N-dimensional numerical array
53///
54/// # OpenIGTLink Specification
55/// - Message type: "NDARRAY"
56/// - Body format: SCALAR_TYPE (uint8) + DIM (uint8) + SIZE (uint16[DIM]) + DATA (bytes)
57/// - Data layout: Row-major order (C-style)
58#[derive(Debug, Clone, PartialEq)]
59pub struct NdArrayMessage {
60    /// Scalar data type
61    pub scalar_type: ScalarType,
62    /// Array dimensions
63    pub size: Vec<u16>,
64    /// Raw array data in network byte order
65    pub data: Vec<u8>,
66}
67
68impl NdArrayMessage {
69    /// Create a new NDARRAY message
70    pub fn new(scalar_type: ScalarType, size: Vec<u16>, data: Vec<u8>) -> Result<Self> {
71        if size.is_empty() || size.len() > 255 {
72            return Err(IgtlError::InvalidHeader(format!(
73                "Invalid dimension count: {}",
74                size.len()
75            )));
76        }
77
78        // Calculate expected data size
79        let expected_size: usize = size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
80
81        if data.len() != expected_size {
82            return Err(IgtlError::InvalidSize {
83                expected: expected_size,
84                actual: data.len(),
85            });
86        }
87
88        Ok(NdArrayMessage {
89            scalar_type,
90            size,
91            data,
92        })
93    }
94
95    /// Create a 1D array
96    pub fn new_1d(scalar_type: ScalarType, data: Vec<u8>) -> Result<Self> {
97        let element_count = data.len() / scalar_type.size();
98        Self::new(scalar_type, vec![element_count as u16], data)
99    }
100
101    /// Create a 2D array
102    pub fn new_2d(scalar_type: ScalarType, rows: u16, cols: u16, data: Vec<u8>) -> Result<Self> {
103        Self::new(scalar_type, vec![rows, cols], data)
104    }
105
106    /// Create a 3D array
107    pub fn new_3d(scalar_type: ScalarType, dim1: u16, dim2: u16, dim3: u16, data: Vec<u8>) -> Result<Self> {
108        Self::new(scalar_type, vec![dim1, dim2, dim3], data)
109    }
110
111    /// Get the number of dimensions
112    pub fn ndim(&self) -> usize {
113        self.size.len()
114    }
115
116    /// Get total number of elements
117    pub fn element_count(&self) -> usize {
118        self.size.iter().map(|&s| s as usize).product()
119    }
120
121    /// Get total data size in bytes
122    pub fn data_size(&self) -> usize {
123        self.data.len()
124    }
125}
126
127impl Message for NdArrayMessage {
128    fn message_type() -> &'static str {
129        "NDARRAY"
130    }
131
132    fn encode_content(&self) -> Result<Vec<u8>> {
133        let dim = self.size.len();
134        if dim == 0 || dim > 255 {
135            return Err(IgtlError::InvalidHeader(format!(
136                "Invalid dimension count: {}",
137                dim
138            )));
139        }
140
141        let mut buf = Vec::with_capacity(2 + dim * 2 + self.data.len());
142
143        // Encode SCALAR_TYPE (uint8)
144        buf.put_u8(self.scalar_type as u8);
145
146        // Encode DIM (uint8)
147        buf.put_u8(dim as u8);
148
149        // Encode SIZE (uint16[DIM])
150        for &s in &self.size {
151            buf.put_u16(s);
152        }
153
154        // Encode DATA
155        buf.extend_from_slice(&self.data);
156
157        Ok(buf)
158    }
159
160    fn decode_content(mut data: &[u8]) -> Result<Self> {
161        if data.len() < 2 {
162            return Err(IgtlError::InvalidSize {
163                expected: 2,
164                actual: data.len(),
165            });
166        }
167
168        // Decode SCALAR_TYPE
169        let scalar_type = ScalarType::from_u8(data.get_u8())?;
170
171        // Decode DIM
172        let dim = data.get_u8() as usize;
173
174        if dim == 0 {
175            return Err(IgtlError::InvalidHeader("Dimension cannot be zero".to_string()));
176        }
177
178        // Check we have enough data for SIZE array
179        if data.len() < dim * 2 {
180            return Err(IgtlError::InvalidSize {
181                expected: dim * 2,
182                actual: data.len(),
183            });
184        }
185
186        // Decode SIZE
187        let mut size = Vec::with_capacity(dim);
188        for _ in 0..dim {
189            size.push(data.get_u16());
190        }
191
192        // Calculate expected data size
193        let expected_data_size: usize = size.iter().map(|&s| s as usize).product::<usize>() * scalar_type.size();
194
195        if data.len() < expected_data_size {
196            return Err(IgtlError::InvalidSize {
197                expected: expected_data_size,
198                actual: data.len(),
199            });
200        }
201
202        // Decode DATA
203        let array_data = data[..expected_data_size].to_vec();
204
205        Ok(NdArrayMessage {
206            scalar_type,
207            size,
208            data: array_data,
209        })
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// The wire-level type name is fixed.
    #[test]
    fn test_message_type() {
        let name = NdArrayMessage::message_type();
        assert_eq!(name, "NDARRAY");
    }

    /// Every scalar type reports its element width in bytes.
    #[test]
    fn test_scalar_type_size() {
        let widths = [
            (ScalarType::Int8, 1),
            (ScalarType::Uint8, 1),
            (ScalarType::Int16, 2),
            (ScalarType::Uint16, 2),
            (ScalarType::Int32, 4),
            (ScalarType::Uint32, 4),
            (ScalarType::Float32, 4),
            (ScalarType::Float64, 8),
        ];
        for &(scalar, bytes) in widths.iter() {
            assert_eq!(scalar.size(), bytes);
        }
    }

    /// `new_1d` derives the single dimension from the data length.
    #[test]
    fn test_new_1d() {
        let array = NdArrayMessage::new_1d(ScalarType::Uint8, vec![1u8, 2, 3, 4]).unwrap();

        assert_eq!(array.ndim(), 1);
        assert_eq!(array.size[0], 4);
        assert_eq!(array.element_count(), 4);
    }

    /// A 3x4 matrix of uint8 needs exactly 12 bytes.
    #[test]
    fn test_new_2d() {
        let array = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, vec![0u8; 12]).unwrap();

        assert_eq!(array.ndim(), 2);
        assert_eq!(array.size, vec![3, 4]);
        assert_eq!(array.element_count(), 12);
    }

    /// A 2x3x4 array of uint8 needs exactly 24 bytes.
    #[test]
    fn test_new_3d() {
        let array = NdArrayMessage::new_3d(ScalarType::Uint8, 2, 3, 4, vec![0u8; 24]).unwrap();

        assert_eq!(array.ndim(), 3);
        assert_eq!(array.size, vec![2, 3, 4]);
        assert_eq!(array.element_count(), 24);
    }

    /// Data whose length disagrees with the shape is rejected.
    #[test]
    fn test_invalid_data_size() {
        // 10 bytes cannot fill a 3x4 uint8 array.
        let outcome = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, vec![0u8; 10]);
        assert!(outcome.is_err());
    }

    /// Byte-level layout of an encoded 1D uint8 array.
    #[test]
    fn test_encode_1d_uint8() {
        let array = NdArrayMessage::new_1d(ScalarType::Uint8, vec![1u8, 2, 3]).unwrap();
        let wire = array.encode_content().unwrap();

        assert_eq!(wire[0], 3); // SCALAR_TYPE = Uint8
        assert_eq!(wire[1], 1); // DIM = 1
        assert_eq!(u16::from_be_bytes([wire[2], wire[3]]), 3); // SIZE[0] = 3
        assert_eq!(&wire[4..], &[1, 2, 3]);
    }

    /// Encoding then decoding a 1D array preserves type, shape and data.
    #[test]
    fn test_roundtrip_1d() {
        let payload = vec![10u8, 20, 30, 40];
        let wire = NdArrayMessage::new_1d(ScalarType::Uint8, payload.clone())
            .unwrap()
            .encode_content()
            .unwrap();
        let restored = NdArrayMessage::decode_content(&wire).unwrap();

        assert_eq!(restored.scalar_type, ScalarType::Uint8);
        assert_eq!(restored.size, vec![4]);
        assert_eq!(restored.data, payload);
    }

    /// Encoding then decoding a 2x3 matrix preserves shape and data.
    #[test]
    fn test_roundtrip_2d() {
        let payload = vec![1u8, 2, 3, 4, 5, 6];
        let wire = NdArrayMessage::new_2d(ScalarType::Uint8, 2, 3, payload.clone())
            .unwrap()
            .encode_content()
            .unwrap();
        let restored = NdArrayMessage::decode_content(&wire).unwrap();

        assert_eq!(restored.size, vec![2, 3]);
        assert_eq!(restored.data, payload);
    }

    /// A 2x2 float32 array survives a round trip byte-for-byte.
    #[test]
    fn test_roundtrip_float32() {
        let payload: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|value| value.to_be_bytes().to_vec())
            .collect();

        let wire = NdArrayMessage::new_2d(ScalarType::Float32, 2, 2, payload.clone())
            .unwrap()
            .encode_content()
            .unwrap();
        let restored = NdArrayMessage::decode_content(&wire).unwrap();

        assert_eq!(restored.scalar_type, ScalarType::Float32);
        assert_eq!(restored.size, vec![2, 2]);
        assert_eq!(restored.data, payload);
    }

    /// A body shorter than the two header bytes cannot be decoded.
    #[test]
    fn test_decode_invalid_header() {
        let wire = [0u8];
        assert!(NdArrayMessage::decode_content(&wire).is_err());
    }

    /// A body whose data is shorter than the shape requires is rejected.
    #[test]
    fn test_decode_truncated_data() {
        // SCALAR_TYPE=Uint8, DIM=1, SIZE[0]=5 (big-endian), then only 3 data bytes.
        let wire = [3u8, 1, 0, 5, 1, 2, 3];
        assert!(NdArrayMessage::decode_content(&wire).is_err());
    }
}