use crate::spec::types::DataType;
use super::error::ValueError;
use super::tensor::TensorValue;
use super::value_enum::Value;
impl Value {
    /// Decodes a single element of type `ty` from its little-endian wire bytes.
    ///
    /// When `ty.min_bytes()` is non-zero, `bytes` must be exactly that long.
    /// Per type:
    /// - `U32` / `I32` / `U64` / `F32` / `F64`: little-endian scalars
    ///   (floats are widened to `f64`).
    /// - `Bool`: a 4-byte little-endian word; any non-zero value is `true`.
    /// - `F16`: IEEE-754 binary16, widened exactly to `f64`.
    /// - `BF16`: bfloat16 (the upper 16 bits of a binary32), widened to `f64`.
    /// - `Bytes` / `Vec2U32` / `Vec4U32`: copied verbatim into [`Value::Bytes`].
    /// - `Tensor` / `Array`: delegated to the tensor / array parsers.
    ///
    /// # Errors
    ///
    /// Returns `ValueError::TruncatedInput` when the input length does not
    /// match the type's width. It also propagates any error from the tensor or
    /// array parsers.
    #[inline]
    pub fn from_element_bytes(ty: DataType, bytes: &[u8]) -> Result<Self, ValueError> {
        // Types reporting `min_bytes() == 0` skip this up-front check and rely
        // on the per-type decoder below to validate the length.
        if ty.min_bytes() > 0 && bytes.len() != ty.min_bytes() {
            return Err(ValueError::TruncatedInput {
                expected: ty.min_bytes(),
                actual: bytes.len(),
            });
        }
        Ok(match ty {
            DataType::U32 => Self::U32(u32::from_le_bytes(read_array4(bytes)?)),
            DataType::I32 => Self::I32(i32::from_le_bytes(read_array4(bytes)?)),
            DataType::U64 => Self::U64(u64::from_le_bytes(read_array8(bytes)?)),
            DataType::Bool => Self::Bool(u32::from_le_bytes(read_array4(bytes)?) != 0),
            DataType::Bytes | DataType::Vec2U32 | DataType::Vec4U32 => Self::Bytes(bytes.to_vec()),
            DataType::F16 => {
                // binary16 has its own exponent bias and field widths, so its
                // bits cannot be reinterpreted as binary32 directly.
                let bits = u16::from_le_bytes(read_array2(bytes)?);
                Self::Float(f64::from(f16_bits_to_f32(bits)))
            }
            DataType::BF16 => {
                // bfloat16 is the high half of a binary32, so widening is a
                // plain shift into the upper 16 bits.
                let bits = u16::from_le_bytes(read_array2(bytes)?);
                Self::Float(f64::from(f32::from_bits(u32::from(bits) << 16)))
            }
            DataType::F32 => Self::Float(f64::from(f32::from_le_bytes(read_array4(bytes)?))),
            DataType::F64 => Self::Float(f64::from_le_bytes(read_array8(bytes)?)),
            DataType::Tensor => Self::Tensor(parse_tensor(bytes)?),
            DataType::Array { element_size } => Self::Array(parse_array(bytes, element_size)?),
        })
    }

    /// Decodes an element whose type is identified by a numeric wire tag.
    ///
    /// Tags `0..=6` and `8..=13` map to data types. Tag 7 and anything above
    /// 13 are rejected. Tag 13 decodes an `Array` with `element_size: 0`, so
    /// the array parser takes the element size from the wire header.
    ///
    /// # Errors
    ///
    /// Returns `ValueError::UnknownElementTypeTag` for an unmapped tag.
    /// Otherwise it returns any error from [`Value::from_element_bytes`].
    #[inline]
    pub fn from_element_type_tag(tag: u8, bytes: &[u8]) -> Result<Self, ValueError> {
        let ty = match tag {
            0 => DataType::U32,
            1 => DataType::I32,
            2 => DataType::U64,
            3 => DataType::Vec2U32,
            4 => DataType::Vec4U32,
            5 => DataType::Bool,
            6 => DataType::Bytes,
            8 => DataType::F16,
            9 => DataType::BF16,
            10 => DataType::F32,
            11 => DataType::F64,
            12 => DataType::Tensor,
            13 => DataType::Array { element_size: 0 },
            _ => return Err(ValueError::UnknownElementTypeTag),
        };
        Self::from_element_bytes(ty, bytes)
    }
}

/// Widens IEEE-754 binary16 bits to an `f32`.
///
/// Every binary16 value is exactly representable in binary32, so this
/// conversion is lossless. That includes signed zeros, subnormals, infinities
/// and NaN payloads.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let negative = bits & 0x8000 != 0;
    let exponent = u32::from((bits >> 10) & 0x1f);
    let mantissa = bits & 0x03ff;
    let sign: u32 = if negative { 0x8000_0000 } else { 0 };
    match exponent {
        // Zero and subnormals: value = mantissa * 2^-24, which is exact in f32.
        0 => {
            let magnitude = f32::from(mantissa) / 16_777_216.0;
            if negative {
                -magnitude
            } else {
                magnitude
            }
        }
        // Infinity (mantissa 0) or NaN (payload moved to the high mantissa bits).
        0x1f => f32::from_bits(sign | 0x7f80_0000 | (u32::from(mantissa) << 13)),
        // Normal numbers: rebias the exponent from 15 (binary16) to 127 (binary32).
        _ => f32::from_bits(sign | ((exponent + 112) << 23) | (u32::from(mantissa) << 13)),
    }
}
fn read_array2(bytes: &[u8]) -> Result<[u8; 2], ValueError> {
if bytes.len() != 2 {
return Err(ValueError::TruncatedInput {
expected: 2,
actual: bytes.len(),
});
}
let mut arr = [0u8; 2];
arr.copy_from_slice(bytes);
Ok(arr)
}
fn read_array4(bytes: &[u8]) -> Result<[u8; 4], ValueError> {
if bytes.len() != 4 {
return Err(ValueError::TruncatedInput {
expected: 4,
actual: bytes.len(),
});
}
let mut arr = [0u8; 4];
arr.copy_from_slice(bytes);
Ok(arr)
}
fn read_array8(bytes: &[u8]) -> Result<[u8; 8], ValueError> {
if bytes.len() != 8 {
return Err(ValueError::TruncatedInput {
expected: 8,
actual: bytes.len(),
});
}
let mut arr = [0u8; 8];
arr.copy_from_slice(bytes);
Ok(arr)
}
/// Decodes a serialized tensor.
///
/// Wire layout, as read below:
/// - byte `0`: rank `r`;
/// - bytes `1..1 + 8r`: `r` little-endian `u64` dimensions;
/// - byte `1 + 8r`: element type tag. Only the scalar tags
///   0, 1, 2, 5, 8, 9, 10 and 11 are accepted, numbered as in
///   `Value::from_element_type_tag`;
/// - the rest: exactly `product(shape) * element_size` payload bytes.
///
/// # Errors
///
/// - `ValueError::TruncatedInput` when the header or payload length is wrong.
///   It is also returned, with `expected: usize::MAX`, when a dimension, the
///   element count or the payload size does not fit in `usize`.
/// - `ValueError::UnknownElementTypeTag` for any other element tag.
/// - `ValueError::UnsupportedElementType` if the element type reports a zero
///   byte width.
fn parse_tensor(bytes: &[u8]) -> Result<TensorValue, ValueError> {
    if bytes.len() < 2 {
        return Err(ValueError::TruncatedInput {
            expected: 2,
            actual: bytes.len(),
        });
    }
    // The rank is a single byte (<= 255), so these offsets (at most
    // 2 + 255 * 8) cannot overflow.
    let rank = usize::from(bytes[0]);
    let dims_end = 1 + rank * 8;
    let header_len = dims_end + 1;
    if bytes.len() < header_len {
        return Err(ValueError::TruncatedInput {
            expected: header_len,
            actual: bytes.len(),
        });
    }

    let shape: Vec<u64> = bytes[1..dims_end]
        .chunks_exact(8)
        .map(|chunk| {
            let mut dim = [0u8; 8];
            dim.copy_from_slice(chunk);
            u64::from_le_bytes(dim)
        })
        .collect();

    let element_type = match bytes[dims_end] {
        0 => DataType::U32,
        1 => DataType::I32,
        2 => DataType::U64,
        5 => DataType::Bool,
        8 => DataType::F16,
        9 => DataType::BF16,
        10 => DataType::F32,
        11 => DataType::F64,
        _ => return Err(ValueError::UnknownElementTypeTag),
    };
    let payload = &bytes[header_len..];
    let element_size = element_type.min_bytes();
    if element_size == 0 {
        return Err(ValueError::UnsupportedElementType);
    }

    // Size arithmetic that cannot fit in `usize` is reported as an
    // unsatisfiable length requirement.
    let overflow = || ValueError::TruncatedInput {
        expected: usize::MAX,
        actual: payload.len(),
    };
    let element_count = shape.iter().try_fold(1usize, |acc, &dim| {
        let dim = usize::try_from(dim).map_err(|_| overflow())?;
        acc.checked_mul(dim).ok_or_else(overflow)
    })?;
    let expected_payload = element_count
        .checked_mul(element_size)
        .ok_or_else(overflow)?;
    if payload.len() != expected_payload {
        return Err(ValueError::TruncatedInput {
            // `expected_payload` is attacker-controlled and may sit just below
            // `usize::MAX`; saturate rather than overflow while reporting.
            expected: header_len.saturating_add(expected_payload),
            actual: bytes.len(),
        });
    }
    Ok(TensorValue {
        shape,
        element_type,
        data: payload.to_vec(),
    })
}
fn parse_array(bytes: &[u8], element_size: usize) -> Result<Vec<Value>, ValueError> {
if bytes.len() < 16 {
return Err(ValueError::TruncatedInput {
expected: 16,
actual: bytes.len(),
});
}
let wire_element_size = u64::from_le_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]) as usize;
let effective_size = if element_size == 0 {
wire_element_size
} else {
if wire_element_size != element_size {
return Err(ValueError::MismatchedElementSize {
declared: element_size,
wire: wire_element_size,
});
}
element_size
};
let count = u64::from_le_bytes([
bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
]) as usize;
let payload = &bytes[16..];
let expected_payload = count.checked_mul(effective_size).unwrap_or(0);
if payload.len() != expected_payload {
return Err(ValueError::TruncatedInput {
expected: 16 + expected_payload,
actual: bytes.len(),
});
}
let mut elements = Vec::with_capacity(count);
for i in 0..count {
let start = i * effective_size;
let end = start + effective_size;
elements.push(Value::Bytes(payload[start..end].to_vec()));
}
Ok(elements)
}