1use crate::error::{Error, Result};
7use crate::io::Cursor;
8
/// Sentinel equal to `u64::MAX`.
///
/// The tests in this file expect a decoded maximum-dimension entry whose
/// bytes are all ones (with 8-byte lengths) to compare equal to this value.
/// NOTE(review): presumably this marks a dimension with no upper bound —
/// inferred from the name, confirm against callers.
pub const UNLIMITED: u64 = u64::MAX;
11
/// Kind of dataspace carried by a [`DataspaceMessage`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataspaceType {
    /// Produced by version-2 messages whose type byte is `2`.
    Null,
    /// Produced by version-2 messages whose type byte is `0`, and by
    /// version-1 messages whose rank is zero.
    Scalar,
    /// Produced by version-2 messages whose type byte is `1`, and by
    /// version-1 messages whose rank is non-zero.
    Simple,
}
22
23#[derive(Debug, Clone)]
25pub struct DataspaceMessage {
26 pub rank: u8,
28 pub dims: Vec<u64>,
30 pub max_dims: Option<Vec<u64>>,
32 pub dataspace_type: DataspaceType,
34}
35
36impl DataspaceMessage {
37 pub fn num_elements(&self) -> Result<u64> {
39 if self.dims.is_empty() {
40 return Ok(match self.dataspace_type {
41 DataspaceType::Scalar => 1,
42 _ => 0,
43 });
44 }
45 self.dims.iter().try_fold(1u64, |acc, &dim| {
46 acc.checked_mul(dim).ok_or_else(|| {
47 Error::InvalidData("dataspace element count overflows u64".to_string())
48 })
49 })
50 }
51}
52
53pub fn parse(
58 cursor: &mut Cursor<'_>,
59 _offset_size: u8,
60 length_size: u8,
61 msg_size: usize,
62) -> Result<DataspaceMessage> {
63 let start = cursor.position();
64 let version = cursor.read_u8()?;
65
66 match version {
67 1 => parse_v1(cursor, length_size),
68 2 => parse_v2(cursor, length_size),
69 v => Err(Error::UnsupportedDataspaceVersion(v)),
70 }
71 .and_then(|msg| {
72 let consumed = (cursor.position() - start) as usize;
74 if consumed < msg_size {
75 cursor.skip(msg_size - consumed)?;
76 }
77 Ok(msg)
78 })
79}
80
81fn parse_v1(cursor: &mut Cursor<'_>, length_size: u8) -> Result<DataspaceMessage> {
83 let rank = cursor.read_u8()?;
84 let flags = cursor.read_u8()?;
85 let _reserved = cursor.read_u8()?; let _reserved2 = cursor.read_u32_le()?; let has_max_dims = (flags & 0x01) != 0;
89 let has_permutation = (flags & 0x02) != 0;
92
93 let dataspace_type = if rank == 0 {
94 DataspaceType::Scalar
95 } else {
96 DataspaceType::Simple
97 };
98
99 let mut dims = Vec::with_capacity(rank as usize);
100 for _ in 0..rank {
101 dims.push(cursor.read_length(length_size)?);
102 }
103
104 let max_dims = if has_max_dims {
105 let mut md = Vec::with_capacity(rank as usize);
106 for _ in 0..rank {
107 md.push(cursor.read_length(length_size)?);
108 }
109 Some(md)
110 } else {
111 None
112 };
113
114 if has_permutation {
115 for _ in 0..rank {
117 cursor.read_length(length_size)?;
118 }
119 }
120
121 Ok(DataspaceMessage {
122 rank,
123 dims,
124 max_dims,
125 dataspace_type,
126 })
127}
128
129fn parse_v2(cursor: &mut Cursor<'_>, length_size: u8) -> Result<DataspaceMessage> {
131 let rank = cursor.read_u8()?;
132 let flags = cursor.read_u8()?;
133 let ds_type_byte = cursor.read_u8()?;
134
135 let has_max_dims = (flags & 0x01) != 0;
136
137 let dataspace_type = match ds_type_byte {
138 0 => DataspaceType::Scalar,
139 1 => DataspaceType::Simple,
140 2 => DataspaceType::Null,
141 _ => {
142 return Err(Error::InvalidData(format!(
143 "unknown dataspace type: {}",
144 ds_type_byte
145 )))
146 }
147 };
148
149 let mut dims = Vec::with_capacity(rank as usize);
150 for _ in 0..rank {
151 dims.push(cursor.read_length(length_size)?);
152 }
153
154 let max_dims = if has_max_dims {
155 let mut md = Vec::with_capacity(rank as usize);
156 for _ in 0..rank {
157 md.push(cursor.read_length(length_size)?);
158 }
159 Some(md)
160 } else {
161 None
162 };
163
164 Ok(DataspaceMessage {
165 rank,
166 dims,
167 max_dims,
168 dataspace_type,
169 })
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message buffer: `header` followed by each entry of `lengths`
    /// encoded as eight little-endian bytes.
    fn with_u64_lengths(header: &[u8], lengths: &[u64]) -> Vec<u8> {
        let mut bytes = header.to_vec();
        for length in lengths {
            bytes.extend_from_slice(&length.to_le_bytes());
        }
        bytes
    }

    #[test]
    fn parse_v1_scalar() {
        // version 1, rank 0, flags 0, followed by five zero bytes.
        let data = [0x01u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];

        let mut cursor = Cursor::new(&data);
        let decoded = parse(&mut cursor, 8, 8, data.len()).unwrap();

        assert_eq!(decoded.rank, 0);
        assert_eq!(decoded.dataspace_type, DataspaceType::Scalar);
        assert_eq!(decoded.dims, Vec::<u64>::new());
        assert_eq!(decoded.max_dims, None);
        assert_eq!(decoded.num_elements().unwrap(), 1);
    }

    #[test]
    fn parse_v1_simple_2d() {
        // version 1, rank 2, flags 0x01 (maximum sizes present).
        let data = with_u64_lengths(
            &[0x01, 0x02, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00],
            &[10, 20, 100, u64::MAX],
        );

        let mut cursor = Cursor::new(&data);
        let decoded = parse(&mut cursor, 8, 8, data.len()).unwrap();

        assert_eq!(decoded.rank, 2);
        assert_eq!(decoded.dims, vec![10, 20]);
        assert_eq!(decoded.max_dims, Some(vec![100, UNLIMITED]));
        assert_eq!(decoded.dataspace_type, DataspaceType::Simple);
        assert_eq!(decoded.num_elements().unwrap(), 200);
    }

    #[test]
    fn parse_v2_simple_1d() {
        // version 2, rank 1, flags 0, type 1; one 4-byte dimension.
        let mut data = vec![0x02u8, 0x01, 0x00, 0x01];
        data.extend_from_slice(&42u32.to_le_bytes());

        let mut cursor = Cursor::new(&data);
        let decoded = parse(&mut cursor, 4, 4, data.len()).unwrap();

        assert_eq!(decoded.rank, 1);
        assert_eq!(decoded.dims, vec![42]);
        assert_eq!(decoded.max_dims, None);
        assert_eq!(decoded.dataspace_type, DataspaceType::Simple);
    }

    #[test]
    fn parse_v2_null() {
        // version 2, rank 0, flags 0, type 2.
        let data = [0x02u8, 0x00, 0x00, 0x02];

        let mut cursor = Cursor::new(&data);
        let decoded = parse(&mut cursor, 8, 8, data.len()).unwrap();

        assert_eq!(decoded.dataspace_type, DataspaceType::Null);
        assert_eq!(decoded.num_elements().unwrap(), 0);
    }

    #[test]
    fn parse_v2_with_max_dims() {
        // version 2, rank 3, flags 0x01 (maximum sizes present), type 1.
        let data = with_u64_lengths(
            &[0x02, 0x03, 0x01, 0x01],
            &[5, 10, 15, 50, 100, u64::MAX],
        );

        let mut cursor = Cursor::new(&data);
        let decoded = parse(&mut cursor, 8, 8, data.len()).unwrap();

        assert_eq!(decoded.rank, 3);
        assert_eq!(decoded.dims, vec![5, 10, 15]);
        assert_eq!(decoded.max_dims, Some(vec![50, 100, UNLIMITED]));
        assert_eq!(decoded.num_elements().unwrap(), 750);
    }

    #[test]
    fn unsupported_version() {
        let data = [0x03u8, 0x00, 0x00, 0x00];

        let mut cursor = Cursor::new(&data);
        let outcome = parse(&mut cursor, 8, 8, data.len());

        assert!(outcome.is_err());
    }

    #[test]
    fn num_elements_rejects_overflow() {
        let oversized = DataspaceMessage {
            rank: 2,
            dims: vec![u64::MAX, 2],
            max_dims: None,
            dataspace_type: DataspaceType::Simple,
        };

        let failure = oversized.num_elements().unwrap_err();

        assert!(failure.to_string().contains("element count"));
    }
}