use crate::spec::types::DataType;
use crate::spec::value::error::ValueError;
use crate::spec::value::tensor::TensorValue;
use crate::spec::value::value_enum::Value;
/// Decodes a tensor from its binary encoding.
///
/// Layout (multi-byte integers are little-endian):
///
/// * byte `0`: `rank`, the number of dimensions;
/// * bytes `1 .. 1 + rank * 8`: the shape, one `u64` per dimension;
/// * byte `1 + rank * 8`: the element type tag;
/// * the remaining bytes: the raw element payload, copied verbatim into the result.
///
/// The payload must be exactly `product(shape) * element_type.min_bytes()` bytes long.
/// A rank-0 input (empty shape) therefore carries exactly one element.
///
/// # Errors
///
/// * [`ValueError::TruncatedInput`] is returned in three cases:
///   * the input is shorter than the header;
///   * the payload length differs from the size implied by the shape (this includes
///     *trailing* bytes, not only missing ones);
///   * that size does not fit in `usize`, which is reported as `expected: usize::MAX`.
/// * [`ValueError::UnknownElementTypeTag`] if the tag byte is not one of the mapped tags.
/// * [`ValueError::UnsupportedElementType`] if the element type reports a zero byte width.
fn parse_tensor(bytes: &[u8]) -> Result<TensorValue, ValueError> {
    if bytes.len() < 2 {
        return Err(ValueError::TruncatedInput {
            expected: 2,
            actual: bytes.len(),
        });
    }

    // `rank` comes from a single byte (<= 255), so the header is at most
    // 2 + 255 * 8 = 2042 bytes long and this arithmetic cannot overflow.
    let rank = usize::from(bytes[0]);
    let shape_end = 1 + rank * 8;
    let header_len = shape_end + 1;
    if bytes.len() < header_len {
        return Err(ValueError::TruncatedInput {
            expected: header_len,
            actual: bytes.len(),
        });
    }

    let shape: Vec<u64> = bytes[1..shape_end]
        .chunks_exact(8)
        .map(|chunk| {
            let mut word = [0u8; 8];
            word.copy_from_slice(chunk);
            u64::from_le_bytes(word)
        })
        .collect();

    // Tags not listed here (e.g. 3, 4, 6, 7) are rejected.
    let element_type = match bytes[shape_end] {
        0 => DataType::U32,
        1 => DataType::I32,
        2 => DataType::U64,
        5 => DataType::Bool,
        8 => DataType::F16,
        9 => DataType::BF16,
        10 => DataType::F32,
        11 => DataType::F64,
        _ => return Err(ValueError::UnknownElementTypeTag),
    };

    let payload = &bytes[header_len..];
    let element_size = element_type.min_bytes();
    if element_size == 0 {
        return Err(ValueError::UnsupportedElementType);
    }

    // A payload size that is not representable in `usize` can never be satisfied;
    // it is reported as an expected length of `usize::MAX`.
    let overflow = || ValueError::TruncatedInput {
        expected: usize::MAX,
        actual: payload.len(),
    };
    let element_count = shape.iter().try_fold(1usize, |acc, &dim| {
        let dim = usize::try_from(dim).map_err(|_| overflow())?;
        acc.checked_mul(dim).ok_or_else(overflow)
    })?;
    let expected_payload = element_count
        .checked_mul(element_size)
        .ok_or_else(overflow)?;

    if payload.len() != expected_payload {
        return Err(ValueError::TruncatedInput {
            // `expected_payload` may be close to `usize::MAX`; saturate instead of
            // overflowing (which would panic in debug builds).
            expected: header_len.saturating_add(expected_payload),
            actual: bytes.len(),
        });
    }

    Ok(TensorValue {
        shape,
        element_type,
        data: payload.to_vec(),
    })
}