use super::bounds::{all_blob_bounds_valid, all_blob_digests_valid};
use super::lookup::find_blob_index;
use super::parser::{parse_rnn_from_bytes, RnnHandle};
use crate::crypto::{
constant_time_eq, extract_rnn_crypto_timestamp, extract_rnn_issue_counter,
extract_rnn_key_version, extract_rnn_public_has_benchmark, extract_rnn_public_has_model_name,
extract_rnn_public_has_model_precision, extract_rnn_public_header_summary, is_encrypted_rnn,
RNN_MIN_ACCEPTED_KEY_VERSION,
};
use crate::model_format::{
header_tlv_payload, BLOB_BIASES, BLOB_LAYER_META, BLOB_WEIGHTS, LAYER_META_SIZE, RNN0_MAGIC,
RNN0_VERSION, TLV_HEADER_BENCHMARK, TLV_HEADER_NETWORK_SUMMARY,
};
use crate::scratch::Scratch;
/// Failure reasons returned by the container validation and header-summary decoding
/// routines in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// The input is shorter than the 12 bytes `validate_inner` requires up front.
Truncated,
/// The first four bytes do not match `RNN0_MAGIC`; also returned by
/// `validate_encrypted_container` when `is_encrypted_rnn` rejects the input.
BadMagic,
/// The little-endian `u16` at bytes 4..6 differs from `RNN0_VERSION`.
BadVersion,
/// A header summary payload failed to decode, or one of the encrypted header fields
/// (key version, issue counter, crypto timestamp) could not be extracted.
BadHeader,
/// `all_blob_bounds_valid` or `all_blob_digests_valid` returned `false` for the parsed container.
BadBounds,
/// Not returned by any code in this part of the file.
/// NOTE(review): presumably used by callers with fixed-size output buffers — confirm.
CapacityTooSmall,
/// A required blob or header TLV is missing, empty, mis-sized, or has an inconsistent or
/// unknown dtype.
InvalidPayload,
/// The container failed to parse and is not an encrypted container, or an encrypted
/// container was given empty key material, a key version below
/// `RNN_MIN_ACCEPTED_KEY_VERSION`, a zero issue counter, or a zero crypto timestamp.
InvalidContainer,
}
/// Network summary decoded from the header summary TLV (plaintext containers) or built from
/// the public header of an encrypted container.
///
/// The decoders in this file copy these values out of the payload without interpreting them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct HeaderSummary {
/// Element type code, stored as a single byte in both summary encodings.
/// NOTE(review): presumably the same code as the blob dtype (0 = 4-byte, 1 = 8-byte) — confirm.
pub dtype: u8,
/// Layer count (a `u32` in the run-length-encoded form, a `u16` in the compact `VIZ` form).
pub layer_count: u32,
/// Total neuron count as recorded in the summary.
pub total_neurons: u32,
/// Recorded weights length. NOTE(review): presumably an element count, not bytes — confirm.
pub weights_len: u32,
/// Recorded biases length. NOTE(review): presumably an element count, not bytes — confirm.
pub biases_len: u32,
/// Recorded blob count (a single byte in the compact `VIZ` form).
pub blob_count: u32,
/// Flag bits from the summary; the compact `VIZ` form has no flags field and reports 0.
pub flags: u32,
}
/// Result of a successful validation scan.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ScanReport {
/// `true` when the report came from the encrypted-container path, `false` for plaintext.
pub encrypted: bool,
/// Header summary, or `None` when it could not be extracted.
pub header: Option<HeaderSummary>,
/// Whether a benchmark section is present. Always `Some` on the plaintext path; on the
/// encrypted path this is whatever `extract_rnn_public_has_benchmark` returns.
pub has_benchmark: Option<bool>,
/// Whether a model-name entry is present. Always `Some` on the plaintext path; on the
/// encrypted path this is whatever `extract_rnn_public_has_model_name` returns.
pub has_model_name: Option<bool>,
/// Whether a model-precision entry is present. Always `Some` on the plaintext path; on the
/// encrypted path this is whatever `extract_rnn_public_has_model_precision` returns.
pub has_model_precision: Option<bool>,
/// Key version of an encrypted container; `None` for plaintext containers.
pub key_version: Option<u32>,
/// Issue counter of an encrypted container; `None` for plaintext containers.
pub issue_counter: Option<u64>,
/// Crypto timestamp of an encrypted container; `None` for plaintext containers.
pub crypto_timestamp: Option<u64>,
}
// Blob name looked up with `find_blob_index` to populate `ScanReport::has_model_name`.
const MODEL_NAME_BLOB_DATA: &str = "model.name";
// Blob name looked up with `find_blob_index` to populate `ScanReport::has_model_precision`.
const MODEL_PRECISION_BLOB_DATA: &str = "model.precision";
/// Computes the encoded payload size in bytes: a fixed 20-byte part plus the layer
/// metadata, weights and biases sections.
///
/// `dtype` selects the element width (0 = 4 bytes, 1 = 8 bytes). Returns `None` for any
/// other `dtype`, or when a multiplication or addition would overflow `usize`.
fn encoded_payload_size(
    dtype: u8,
    layer_count: usize,
    weights_len: usize,
    biases_len: usize,
) -> Option<usize> {
    const FIXED_BYTES: usize = 20;

    let bytes_per_elem: usize = match dtype {
        0 => 4,
        1 => 8,
        _ => return None,
    };

    let meta_section = layer_count.checked_mul(LAYER_META_SIZE)?;
    let weights_section = weights_len.checked_mul(bytes_per_elem)?;
    let biases_section = biases_len.checked_mul(bytes_per_elem)?;

    FIXED_BYTES
        .checked_add(meta_section)
        .and_then(|total| total.checked_add(weights_section))
        .and_then(|total| total.checked_add(biases_section))
}
pub(crate) fn extract_header_summary(bytes: &[u8]) -> Result<HeaderSummary, Error> {
let payload =
header_tlv_payload(bytes, TLV_HEADER_NETWORK_SUMMARY).map_err(|_| Error::InvalidPayload)?;
if payload.len() >= 16 && constant_time_eq(&payload[0..4], b"VIZ\x00") {
return decode_compact_viz_summary(payload);
}
decode_header_summary(payload)
}
/// Validates a serialized container and returns a [`ScanReport`] describing it.
///
/// Thin public entry point that forwards both arguments to `validate_scan_backend`.
/// In this file `device_id` is only consulted on the encrypted-container path, where a
/// supplied but empty slice is rejected.
///
/// # Errors
///
/// Returns the [`Error`] produced by the underlying validation.
pub fn validate(bytes: &[u8], device_id: Option<&[u8]>) -> Result<ScanReport, Error> {
validate_scan_backend(bytes, device_id)
}
/// Crate-internal validation entry point; forwards both arguments unchanged to
/// `validate_inner` and returns its result.
pub(crate) fn validate_scan_backend(
bytes: &[u8],
device_id: Option<&[u8]>,
) -> Result<ScanReport, Error> {
validate_inner(bytes, device_id)
}
fn validate_inner(bytes: &[u8], device_id: Option<&[u8]>) -> Result<ScanReport, Error> {
if bytes.len() < 12 {
return Err(Error::Truncated);
}
if !constant_time_eq(&bytes[0..4], RNN0_MAGIC) {
return Err(Error::BadMagic);
}
let version = u16::from_le_bytes([bytes[4], bytes[5]]);
if version != RNN0_VERSION {
return Err(Error::BadVersion);
}
let mut scratch_buf = [0u8; 16384];
let mut scratch = Scratch::new(&mut scratch_buf);
let handle = match parse_rnn_from_bytes(bytes, &mut scratch) {
Ok(handle) => handle,
Err(_) => {
if is_encrypted_rnn(bytes) {
return validate_encrypted_container(bytes, device_id);
}
return Err(Error::InvalidContainer);
}
};
if !all_blob_bounds_valid(&handle) {
return Err(Error::BadBounds);
}
if !all_blob_digests_valid(&handle) {
return Err(Error::BadBounds);
}
let layer_meta = find_layer_meta_blob(&handle).ok_or(Error::InvalidPayload)?;
let weights = find_weights_blob(&handle).ok_or(Error::InvalidPayload)?;
let biases = find_biases_blob(&handle).ok_or(Error::InvalidPayload)?;
let layer_meta_idx = find_blob_index(&handle, BLOB_LAYER_META).ok_or(Error::InvalidPayload)?;
let weights_idx = find_blob_index(&handle, BLOB_WEIGHTS).ok_or(Error::InvalidPayload)?;
let biases_idx = find_blob_index(&handle, BLOB_BIASES).ok_or(Error::InvalidPayload)?;
let layer_meta_desc = handle
.blobs
.get(layer_meta_idx)
.ok_or(Error::InvalidPayload)?;
let weights_desc = handle.blobs.get(weights_idx).ok_or(Error::InvalidPayload)?;
let biases_desc = handle.blobs.get(biases_idx).ok_or(Error::InvalidPayload)?;
if layer_meta.is_empty() || weights.is_empty() || biases.is_empty() {
return Err(Error::InvalidPayload);
}
if !layer_meta.len().is_multiple_of(LAYER_META_SIZE) {
return Err(Error::InvalidPayload);
}
let benchmark = crate::model_format::header_tlv_payload(bytes, TLV_HEADER_BENCHMARK)
.map_err(|_| Error::InvalidPayload)?;
if benchmark.is_empty() {
return Err(Error::InvalidPayload);
}
if layer_meta_desc.dtype != weights_desc.dtype || weights_desc.dtype != biases_desc.dtype {
return Err(Error::InvalidPayload);
}
let layer_count = layer_meta.len() / LAYER_META_SIZE;
match weights_desc.dtype {
0 => {
if !weights.len().is_multiple_of(4) || !biases.len().is_multiple_of(4) {
return Err(Error::InvalidPayload);
}
let weights_len = weights.len() / 4;
let biases_len = biases.len() / 4;
encoded_payload_size(weights_desc.dtype, layer_count, weights_len, biases_len)
.ok_or(Error::InvalidPayload)?;
}
1 => {
if !weights.len().is_multiple_of(8) || !biases.len().is_multiple_of(8) {
return Err(Error::InvalidPayload);
}
let weights_len = weights.len() / 8;
let biases_len = biases.len() / 8;
encoded_payload_size(weights_desc.dtype, layer_count, weights_len, biases_len)
.ok_or(Error::InvalidPayload)?;
}
_ => return Err(Error::InvalidPayload),
}
let has_benchmark =
crate::model_format::header_tlv_payload(bytes, TLV_HEADER_BENCHMARK).is_ok();
let has_model_name = find_blob_index(&handle, MODEL_NAME_BLOB_DATA).is_some();
let has_model_precision = find_blob_index(&handle, MODEL_PRECISION_BLOB_DATA).is_some();
let header = extract_header_summary(bytes).ok();
Ok(ScanReport {
encrypted: false,
header,
has_benchmark: Some(has_benchmark),
has_model_name: Some(has_model_name),
has_model_precision: Some(has_model_precision),
key_version: None,
issue_counter: None,
crypto_timestamp: None,
})
}
fn validate_encrypted_container(
bytes: &[u8],
key_material: Option<&[u8]>,
) -> Result<ScanReport, Error> {
if let Some(material) = key_material {
if material.is_empty() {
return Err(Error::InvalidContainer);
}
}
if !is_encrypted_rnn(bytes) {
return Err(Error::BadMagic);
}
let key_version = extract_rnn_key_version(bytes).ok_or(Error::BadHeader)?;
let issue_counter = extract_rnn_issue_counter(bytes).ok_or(Error::BadHeader)?;
let crypto_timestamp = extract_rnn_crypto_timestamp(bytes).ok_or(Error::BadHeader)?;
if key_version < RNN_MIN_ACCEPTED_KEY_VERSION || issue_counter == 0 || crypto_timestamp == 0 {
return Err(Error::InvalidContainer);
}
let header = extract_rnn_public_header_summary(bytes).map(
|(dtype, layer_count, total_neurons, weights_len, biases_len, blob_count, flags)| {
HeaderSummary {
dtype,
layer_count,
total_neurons,
weights_len,
biases_len,
blob_count,
flags,
}
},
);
let has_benchmark = extract_rnn_public_has_benchmark(bytes);
Ok(ScanReport {
encrypted: true,
header,
has_benchmark,
has_model_name: extract_rnn_public_has_model_name(bytes),
has_model_precision: extract_rnn_public_has_model_precision(bytes),
key_version: Some(key_version),
issue_counter: Some(issue_counter),
crypto_timestamp: Some(crypto_timestamp),
})
}
/// Returns the byte range of the blob at `index` within the container's backing bytes.
///
/// Yields `None` when `index` is out of range, when the descriptor's offset or length does
/// not fit in `usize`, when `offset + length` overflows, or when the resulting range lies
/// outside `handle.bytes`.
pub(crate) fn blob_bytes<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
    index: usize,
) -> Option<&'bytes [u8]> {
    let descriptor = handle.blobs.get(index)?;
    let begin = usize::try_from(descriptor.offset).ok()?;
    let span = usize::try_from(descriptor.length).ok()?;
    let finish = begin.checked_add(span)?;
    handle.bytes.get(begin..finish)
}
/// Looks up a blob by `name` and returns its bytes, or `None` when the name is unknown or
/// the blob's range cannot be resolved by `blob_bytes`.
fn find_blob_bytes<'bytes, 'scratch>(
    handle: &RnnHandle<'bytes, 'scratch>,
    name: &str,
) -> Option<&'bytes [u8]> {
    find_blob_index(handle, name).and_then(|position| blob_bytes(handle, position))
}
/// Returns the bytes of the blob named `BLOB_LAYER_META`, or `None` if `find_blob_bytes`
/// cannot resolve it.
fn find_layer_meta_blob<'bytes, 'scratch>(
handle: &RnnHandle<'bytes, 'scratch>,
) -> Option<&'bytes [u8]> {
find_blob_bytes(handle, BLOB_LAYER_META)
}
/// Returns the bytes of the blob named `BLOB_WEIGHTS`, or `None` if `find_blob_bytes`
/// cannot resolve it.
fn find_weights_blob<'bytes, 'scratch>(
handle: &RnnHandle<'bytes, 'scratch>,
) -> Option<&'bytes [u8]> {
find_blob_bytes(handle, BLOB_WEIGHTS)
}
/// Returns the bytes of the blob named `BLOB_BIASES`, or `None` if `find_blob_bytes`
/// cannot resolve it.
fn find_biases_blob<'bytes, 'scratch>(
handle: &RnnHandle<'bytes, 'scratch>,
) -> Option<&'bytes [u8]> {
find_blob_bytes(handle, BLOB_BIASES)
}
fn decode_header_summary(payload: &[u8]) -> Result<HeaderSummary, Error> {
if payload.len() < 2 {
return Err(Error::BadHeader);
}
let summary_version = payload[0];
let codec = payload[1];
if summary_version != 1 || codec != 1 {
return Err(Error::BadHeader);
}
let compressed = &payload[2..];
if !compressed.len().is_multiple_of(2) {
return Err(Error::BadHeader);
}
let mut plain = [0u8; 30];
let mut plain_len = 0usize;
let mut idx = 0usize;
while idx < compressed.len() {
let run = compressed[idx] as usize;
let value = compressed[idx + 1];
if run == 0 {
return Err(Error::BadHeader);
}
plain_len = plain_len.checked_add(run).ok_or(Error::BadHeader)?;
if plain_len <= 30 {
let start = plain_len - run;
let end = plain_len;
plain[start..end].fill(value);
} else if plain_len - run < 30 {
let start = plain_len - run;
let end = 30usize;
plain[start..end].fill(value);
}
idx += 2;
}
if plain_len < 30 {
return Err(Error::BadHeader);
}
if !constant_time_eq(&plain[0..4], b"S5D0") {
return Err(Error::BadHeader);
}
if plain[4] != 1 {
return Err(Error::BadHeader);
}
let dtype = plain[5];
let layer_count = u32::from_le_bytes([plain[6], plain[7], plain[8], plain[9]]);
let total_neurons = u32::from_le_bytes([plain[10], plain[11], plain[12], plain[13]]);
let weights_len = u32::from_le_bytes([plain[14], plain[15], plain[16], plain[17]]);
let biases_len = u32::from_le_bytes([plain[18], plain[19], plain[20], plain[21]]);
let blob_count = u32::from_le_bytes([plain[22], plain[23], plain[24], plain[25]]);
let flags = u32::from_le_bytes([plain[26], plain[27], plain[28], plain[29]]);
Ok(HeaderSummary {
dtype,
layer_count,
total_neurons,
weights_len,
biases_len,
blob_count,
flags,
})
}
/// Decodes the compact `VIZ` header summary.
///
/// Byte layout (little-endian): `VIZ\0` magic at 0..4, dtype at 4, layer count (`u16`) at
/// 5..7, total neurons (`u32`) at 7..11, weights length (`u16`) at 11..13, biases length
/// (`u16`) at 13..15, blob count (`u8`) at 15. The compact form carries no flags, so
/// `flags` is reported as 0. Bytes past the 16th are ignored.
///
/// # Errors
///
/// Returns [`Error::BadHeader`] when the payload is shorter than 16 bytes or does not start
/// with the `VIZ\0` magic.
fn decode_compact_viz_summary(payload: &[u8]) -> Result<HeaderSummary, Error> {
    if payload.len() < 16 {
        return Err(Error::BadHeader);
    }
    // The magic's fourth byte is zero, so this comparison also guarantees `payload[3] == 0`;
    // no separate check of that byte is needed.
    if !constant_time_eq(&payload[0..4], b"VIZ\x00") {
        return Err(Error::BadHeader);
    }

    // Reads the little-endian u16 at byte `at` and widens it losslessly to u32.
    let u16_at = |at: usize| u32::from(u16::from_le_bytes([payload[at], payload[at + 1]]));

    Ok(HeaderSummary {
        dtype: payload[4],
        layer_count: u16_at(5),
        total_neurons: u32::from_le_bytes([payload[7], payload[8], payload[9], payload[10]]),
        weights_len: u16_at(11),
        biases_len: u16_at(13),
        blob_count: u32::from(payload[15]),
        flags: 0,
    })
}