use crate::protocol::message::Message;
use crate::error::{IgtlError, Result};
use bytes::{Buf, BufMut};
/// Element type of an NDARRAY payload.
///
/// The explicit discriminants are the numeric codes exchanged on the wire:
/// they are written with `as u8` when encoding and parsed back by
/// [`ScalarType::from_u8`] when decoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ScalarType {
    Int8 = 2,
    Uint8 = 3,
    Int16 = 4,
    Uint16 = 5,
    Int32 = 6,
    Uint32 = 7,
    Float32 = 10,
    Float64 = 11,
}

impl ScalarType {
    /// Maps a wire scalar-type code back to its variant.
    ///
    /// # Errors
    ///
    /// Returns `IgtlError::InvalidHeader` for any code that is not one of the
    /// discriminants declared on [`ScalarType`].
    fn from_u8(value: u8) -> Result<Self> {
        let scalar_type = match value {
            2 => Self::Int8,
            3 => Self::Uint8,
            4 => Self::Int16,
            5 => Self::Uint16,
            6 => Self::Int32,
            7 => Self::Uint32,
            10 => Self::Float32,
            11 => Self::Float64,
            unknown => {
                return Err(IgtlError::InvalidHeader(format!(
                    "Invalid scalar type: {}",
                    unknown
                )));
            }
        };
        Ok(scalar_type)
    }

    /// Width in bytes of one element of this type.
    pub fn size(&self) -> usize {
        use std::mem::size_of;

        match *self {
            Self::Float64 => size_of::<f64>(),
            Self::Float32 => size_of::<f32>(),
            Self::Int32 | Self::Uint32 => size_of::<u32>(),
            Self::Int16 | Self::Uint16 => size_of::<u16>(),
            Self::Int8 | Self::Uint8 => size_of::<u8>(),
        }
    }
}
/// An NDARRAY message: an N-dimensional array whose elements all share one
/// [`ScalarType`].
///
/// `data` holds the element bytes exactly as they are sent; no byte-order
/// conversion is performed by this type.
///
/// Invariant established by [`NdArrayMessage::new`]: `size` has 1..=255 entries and
/// `data.len() == product(size) * scalar_type.size()`. Because the fields are
/// public, direct mutation can break this invariant.
#[derive(Debug, Clone, PartialEq)]
pub struct NdArrayMessage {
    /// Element type of every entry in `data`.
    pub scalar_type: ScalarType,
    /// Extent of each dimension, in wire order; `size.len()` is the dimension count.
    pub size: Vec<u16>,
    /// Raw element bytes.
    pub data: Vec<u8>,
}

impl NdArrayMessage {
    /// Creates a message after checking that `data` exactly fills the shape `size`.
    ///
    /// # Errors
    ///
    /// - `IgtlError::InvalidHeader` if `size` is empty, has more than 255 entries,
    ///   or describes an array whose byte length does not fit in `usize`.
    /// - `IgtlError::InvalidSize` if `data.len()` differs from
    ///   `product(size) * scalar_type.size()`.
    pub fn new(scalar_type: ScalarType, size: Vec<u16>, data: Vec<u8>) -> Result<Self> {
        // The dimension count is serialized as a single byte.
        if size.is_empty() || size.len() > 255 {
            return Err(IgtlError::InvalidHeader(format!(
                "Invalid dimension count: {}",
                size.len()
            )));
        }
        // Up to 255 dimensions of up to 65535 each can exceed usize::MAX; an
        // unchecked product would panic in debug builds and wrap in release.
        let expected_size = size
            .iter()
            .try_fold(scalar_type.size(), |acc, &dim| {
                acc.checked_mul(usize::from(dim))
            })
            .ok_or_else(|| {
                IgtlError::InvalidHeader(format!(
                    "Array byte size overflows usize for dimensions {:?}",
                    size
                ))
            })?;
        if data.len() != expected_size {
            return Err(IgtlError::InvalidSize {
                expected: expected_size,
                actual: data.len(),
            });
        }
        Ok(NdArrayMessage {
            scalar_type,
            size,
            data,
        })
    }

    /// Creates a 1-D message whose length is inferred from `data`.
    ///
    /// # Errors
    ///
    /// - `IgtlError::InvalidSize` if `data` holds more than `u16::MAX` elements
    ///   (`expected` is then the largest representable byte length), or if
    ///   `data.len()` is not a multiple of `scalar_type.size()`.
    pub fn new_1d(scalar_type: ScalarType, data: Vec<u8>) -> Result<Self> {
        let element_size = scalar_type.size();
        let element_count = data.len() / element_size;
        let max_elements = usize::from(u16::MAX);
        // A dimension is a u16 on the wire; truncating with `as` would only
        // surface later as a confusing size mismatch.
        if element_count > max_elements {
            return Err(IgtlError::InvalidSize {
                expected: max_elements * element_size,
                actual: data.len(),
            });
        }
        Self::new(scalar_type, vec![element_count as u16], data)
    }

    /// Creates a 2-D message of shape `[rows, cols]`; validation as in [`NdArrayMessage::new`].
    pub fn new_2d(scalar_type: ScalarType, rows: u16, cols: u16, data: Vec<u8>) -> Result<Self> {
        Self::new(scalar_type, vec![rows, cols], data)
    }

    /// Creates a 3-D message of shape `[dim1, dim2, dim3]`; validation as in [`NdArrayMessage::new`].
    pub fn new_3d(
        scalar_type: ScalarType,
        dim1: u16,
        dim2: u16,
        dim3: u16,
        data: Vec<u8>,
    ) -> Result<Self> {
        Self::new(scalar_type, vec![dim1, dim2, dim3], data)
    }

    /// Number of dimensions (`size.len()`).
    pub fn ndim(&self) -> usize {
        self.size.len()
    }

    /// Number of elements: the product of all entries in `size`.
    ///
    /// For a message built through [`NdArrayMessage::new`] this product is known
    /// to fit in `usize`.
    pub fn element_count(&self) -> usize {
        self.size.iter().map(|&s| s as usize).product()
    }

    /// Length of the raw payload in bytes (`data.len()`).
    pub fn data_size(&self) -> usize {
        self.data.len()
    }
}
impl Message for NdArrayMessage {
    /// Message type name used for NDARRAY messages.
    fn message_type() -> &'static str {
        "NDARRAY"
    }

    /// Serializes the message body.
    ///
    /// Layout: scalar-type code (u8), dimension count (u8), each dimension as a
    /// big-endian u16, then the raw `data` bytes.
    ///
    /// # Errors
    ///
    /// - `IgtlError::InvalidHeader` if the dimension count is 0 or above 255, or
    ///   the array byte size implied by `size` overflows `usize`.
    /// - `IgtlError::InvalidSize` if `data.len()` does not match the shape. This
    ///   can happen because the fields are public and may have been mutated
    ///   after construction.
    fn encode_content(&self) -> Result<Vec<u8>> {
        let dim = self.size.len();
        if dim == 0 || dim > 255 {
            return Err(IgtlError::InvalidHeader(format!(
                "Invalid dimension count: {}",
                dim
            )));
        }
        // Refuse to emit a body whose header disagrees with its payload; the
        // receiver would otherwise reject or misread it.
        let expected_data_size = self
            .size
            .iter()
            .try_fold(self.scalar_type.size(), |acc, &s| {
                acc.checked_mul(usize::from(s))
            })
            .ok_or_else(|| {
                IgtlError::InvalidHeader(format!(
                    "Array byte size overflows usize for dimensions {:?}",
                    self.size
                ))
            })?;
        if self.data.len() != expected_data_size {
            return Err(IgtlError::InvalidSize {
                expected: expected_data_size,
                actual: self.data.len(),
            });
        }
        let mut buf = Vec::with_capacity(2 + dim * 2 + self.data.len());
        buf.put_u8(self.scalar_type as u8);
        // dim <= 255 was checked above, so this cast cannot truncate.
        buf.put_u8(dim as u8);
        for &s in &self.size {
            buf.put_u16(s);
        }
        buf.extend_from_slice(&self.data);
        Ok(buf)
    }

    /// Parses a body laid out as described on `encode_content`.
    ///
    /// Bytes following the expected payload are ignored.
    ///
    /// # Errors
    ///
    /// - `IgtlError::InvalidSize` if the body is shorter than the 2-byte header,
    ///   the dimension list, or the payload the dimensions describe.
    /// - `IgtlError::InvalidHeader` for an unknown scalar-type code, a zero
    ///   dimension count, or dimensions whose byte size overflows `usize`.
    fn decode_content(mut data: &[u8]) -> Result<Self> {
        if data.len() < 2 {
            return Err(IgtlError::InvalidSize {
                expected: 2,
                actual: data.len(),
            });
        }
        let scalar_type = ScalarType::from_u8(data.get_u8())?;
        let dim = usize::from(data.get_u8());
        if dim == 0 {
            return Err(IgtlError::InvalidHeader("Dimension cannot be zero".to_string()));
        }
        if data.len() < dim * 2 {
            return Err(IgtlError::InvalidSize {
                expected: dim * 2,
                actual: data.len(),
            });
        }
        let mut size = Vec::with_capacity(dim);
        for _ in 0..dim {
            size.push(data.get_u16());
        }
        // The dimensions come from the peer: checked arithmetic keeps a hostile
        // header from panicking (debug) or wrapping to a small length that would
        // then be accepted (release).
        let expected_data_size = size
            .iter()
            .try_fold(scalar_type.size(), |acc, &s| acc.checked_mul(usize::from(s)))
            .ok_or_else(|| {
                IgtlError::InvalidHeader(format!(
                    "Array byte size overflows usize for dimensions {:?}",
                    size
                ))
            })?;
        if data.len() < expected_data_size {
            return Err(IgtlError::InvalidSize {
                expected: expected_data_size,
                actual: data.len(),
            });
        }
        let array_data = data[..expected_data_size].to_vec();
        Ok(NdArrayMessage {
            scalar_type,
            size,
            data: array_data,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Concatenates the big-endian byte representations of `values`.
    fn f32_be_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_be_bytes()).collect()
    }

    #[test]
    fn test_message_type() {
        assert_eq!(NdArrayMessage::message_type(), "NDARRAY");
    }

    #[test]
    fn test_scalar_type_size() {
        let expected = [
            (ScalarType::Int8, 1),
            (ScalarType::Uint8, 1),
            (ScalarType::Int16, 2),
            (ScalarType::Uint16, 2),
            (ScalarType::Int32, 4),
            (ScalarType::Uint32, 4),
            (ScalarType::Float32, 4),
            (ScalarType::Float64, 8),
        ];
        for (scalar_type, width) in expected {
            assert_eq!(scalar_type.size(), width, "{:?}", scalar_type);
        }
    }

    #[test]
    fn test_new_1d() {
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, vec![1u8, 2, 3, 4]).unwrap();
        assert_eq!(msg.ndim(), 1);
        assert_eq!(msg.size[0], 4);
        assert_eq!(msg.element_count(), 4);
    }

    #[test]
    fn test_new_2d() {
        let msg = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, vec![0u8; 12]).unwrap();
        assert_eq!(msg.ndim(), 2);
        assert_eq!(msg.size, vec![3, 4]);
        assert_eq!(msg.element_count(), 12);
    }

    #[test]
    fn test_new_3d() {
        let msg = NdArrayMessage::new_3d(ScalarType::Uint8, 2, 3, 4, vec![0u8; 24]).unwrap();
        assert_eq!(msg.ndim(), 3);
        assert_eq!(msg.size, vec![2, 3, 4]);
        assert_eq!(msg.element_count(), 24);
    }

    #[test]
    fn test_invalid_data_size() {
        // A 3x4 Uint8 array needs 12 bytes; 10 must be rejected.
        let result = NdArrayMessage::new_2d(ScalarType::Uint8, 3, 4, vec![0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_1d_uint8() {
        let msg = NdArrayMessage::new_1d(ScalarType::Uint8, vec![1u8, 2, 3]).unwrap();
        let encoded = msg.encode_content().unwrap();
        let (header, payload) = encoded.split_at(4);
        // scalar type Uint8 (3), one dimension, extent 3 as big-endian u16.
        assert_eq!(header, [3u8, 1, 0, 3]);
        assert_eq!(payload, [1u8, 2, 3]);
    }

    #[test]
    fn test_roundtrip_1d() {
        let original_data = vec![10u8, 20, 30, 40];
        let decoded = NdArrayMessage::decode_content(
            &NdArrayMessage::new_1d(ScalarType::Uint8, original_data.clone())
                .unwrap()
                .encode_content()
                .unwrap(),
        )
        .unwrap();
        assert_eq!(decoded.scalar_type, ScalarType::Uint8);
        assert_eq!(decoded.size, vec![4]);
        assert_eq!(decoded.data, original_data);
    }

    #[test]
    fn test_roundtrip_2d() {
        let data = vec![1u8, 2, 3, 4, 5, 6];
        let original = NdArrayMessage::new_2d(ScalarType::Uint8, 2, 3, data.clone()).unwrap();
        let wire = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&wire).unwrap();
        assert_eq!(decoded.size, vec![2, 3]);
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_roundtrip_float32() {
        let data = f32_be_bytes(&[1.0, 2.0, 3.0, 4.0]);
        let original =
            NdArrayMessage::new_2d(ScalarType::Float32, 2, 2, data.clone()).unwrap();
        let wire = original.encode_content().unwrap();
        let decoded = NdArrayMessage::decode_content(&wire).unwrap();
        assert_eq!(decoded.scalar_type, ScalarType::Float32);
        assert_eq!(decoded.size, vec![2, 2]);
        assert_eq!(decoded.data, data);
    }

    #[test]
    fn test_decode_invalid_header() {
        // One byte is shorter than the two-byte header.
        assert!(NdArrayMessage::decode_content(&[0u8]).is_err());
    }

    #[test]
    fn test_decode_truncated_data() {
        // Uint8, 1 dimension of extent 5 (big-endian 0x0005), but only 3 payload bytes.
        let data = [3u8, 1, 0, 5, 1, 2, 3];
        assert!(NdArrayMessage::decode_content(&data).is_err());
    }
}