use crate::error::{Error, Result};
use crate::io::Cursor;
pub const UNLIMITED: u64 = u64::MAX;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataspaceType {
Null,
Scalar,
Simple,
}
#[derive(Debug, Clone)]
pub struct DataspaceMessage {
pub rank: u8,
pub dims: Vec<u64>,
pub max_dims: Option<Vec<u64>>,
pub dataspace_type: DataspaceType,
}
impl DataspaceMessage {
pub fn num_elements(&self) -> u64 {
if self.dims.is_empty() {
return match self.dataspace_type {
DataspaceType::Scalar => 1,
_ => 0,
};
}
self.dims.iter().product()
}
}
pub fn parse(
cursor: &mut Cursor<'_>,
_offset_size: u8,
length_size: u8,
msg_size: usize,
) -> Result<DataspaceMessage> {
let start = cursor.position();
let version = cursor.read_u8()?;
match version {
1 => parse_v1(cursor, length_size),
2 => parse_v2(cursor, length_size),
v => Err(Error::UnsupportedDataspaceVersion(v)),
}
.and_then(|msg| {
let consumed = (cursor.position() - start) as usize;
if consumed < msg_size {
cursor.skip(msg_size - consumed)?;
}
Ok(msg)
})
}
fn parse_v1(cursor: &mut Cursor<'_>, length_size: u8) -> Result<DataspaceMessage> {
let rank = cursor.read_u8()?;
let flags = cursor.read_u8()?;
let _reserved = cursor.read_u8()?; let _reserved2 = cursor.read_u32_le()?;
let has_max_dims = (flags & 0x01) != 0;
let has_permutation = (flags & 0x02) != 0;
let dataspace_type = if rank == 0 {
DataspaceType::Scalar
} else {
DataspaceType::Simple
};
let mut dims = Vec::with_capacity(rank as usize);
for _ in 0..rank {
dims.push(cursor.read_length(length_size)?);
}
let max_dims = if has_max_dims {
let mut md = Vec::with_capacity(rank as usize);
for _ in 0..rank {
md.push(cursor.read_length(length_size)?);
}
Some(md)
} else {
None
};
if has_permutation {
for _ in 0..rank {
cursor.read_length(length_size)?;
}
}
Ok(DataspaceMessage {
rank,
dims,
max_dims,
dataspace_type,
})
}
fn parse_v2(cursor: &mut Cursor<'_>, length_size: u8) -> Result<DataspaceMessage> {
let rank = cursor.read_u8()?;
let flags = cursor.read_u8()?;
let ds_type_byte = cursor.read_u8()?;
let has_max_dims = (flags & 0x01) != 0;
let dataspace_type = match ds_type_byte {
0 => DataspaceType::Scalar,
1 => DataspaceType::Simple,
2 => DataspaceType::Null,
_ => {
return Err(Error::InvalidData(format!(
"unknown dataspace type: {}",
ds_type_byte
)))
}
};
let mut dims = Vec::with_capacity(rank as usize);
for _ in 0..rank {
dims.push(cursor.read_length(length_size)?);
}
let max_dims = if has_max_dims {
let mut md = Vec::with_capacity(rank as usize);
for _ in 0..rank {
md.push(cursor.read_length(length_size)?);
}
Some(md)
} else {
None
};
Ok(DataspaceMessage {
rank,
dims,
max_dims,
dataspace_type,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_v1_scalar() {
let data = [
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ];
let mut cursor = Cursor::new(&data);
let msg = parse(&mut cursor, 8, 8, data.len()).unwrap();
assert_eq!(msg.rank, 0);
assert_eq!(msg.dataspace_type, DataspaceType::Scalar);
assert!(msg.dims.is_empty());
assert!(msg.max_dims.is_none());
assert_eq!(msg.num_elements(), 1);
}
#[test]
fn test_parse_v1_simple_2d() {
let mut data = vec![
0x01, 0x02, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, ];
data.extend_from_slice(&10u64.to_le_bytes());
data.extend_from_slice(&20u64.to_le_bytes());
data.extend_from_slice(&100u64.to_le_bytes());
data.extend_from_slice(&u64::MAX.to_le_bytes());
let mut cursor = Cursor::new(&data);
let msg = parse(&mut cursor, 8, 8, data.len()).unwrap();
assert_eq!(msg.rank, 2);
assert_eq!(msg.dims, vec![10, 20]);
assert_eq!(msg.max_dims.as_ref().unwrap(), &vec![100, UNLIMITED]);
assert_eq!(msg.dataspace_type, DataspaceType::Simple);
assert_eq!(msg.num_elements(), 200);
}
#[test]
fn test_parse_v2_simple_1d() {
let mut data = vec![
0x02, 0x01, 0x00, 0x01, ];
data.extend_from_slice(&42u32.to_le_bytes());
let mut cursor = Cursor::new(&data);
let msg = parse(&mut cursor, 4, 4, data.len()).unwrap();
assert_eq!(msg.rank, 1);
assert_eq!(msg.dims, vec![42]);
assert!(msg.max_dims.is_none());
assert_eq!(msg.dataspace_type, DataspaceType::Simple);
}
#[test]
fn test_parse_v2_null() {
let data = [
0x02, 0x00, 0x00, 0x02, ];
let mut cursor = Cursor::new(&data);
let msg = parse(&mut cursor, 8, 8, data.len()).unwrap();
assert_eq!(msg.dataspace_type, DataspaceType::Null);
assert_eq!(msg.num_elements(), 0);
}
#[test]
fn test_parse_v2_with_max_dims() {
let mut data = vec![
0x02, 0x03, 0x01, 0x01, ];
for &d in &[5u64, 10, 15] {
data.extend_from_slice(&d.to_le_bytes());
}
for &d in &[50u64, 100, u64::MAX] {
data.extend_from_slice(&d.to_le_bytes());
}
let mut cursor = Cursor::new(&data);
let msg = parse(&mut cursor, 8, 8, data.len()).unwrap();
assert_eq!(msg.rank, 3);
assert_eq!(msg.dims, vec![5, 10, 15]);
let md = msg.max_dims.clone().unwrap();
assert_eq!(md, vec![50, 100, UNLIMITED]);
assert_eq!(msg.num_elements(), 750);
}
#[test]
fn test_unsupported_version() {
let data = [0x03, 0x00, 0x00, 0x00];
let mut cursor = Cursor::new(&data);
assert!(parse(&mut cursor, 8, 8, data.len()).is_err());
}
}