#[cfg(feature = "publisher-trust-service")]
use crate::crypto::sha256_bytes;
use crate::crypto::ASYM_PUBLIC_KEY_LEN;
use crate::crypto::{constant_time_eq, ed25519_verify, is_encrypted_rnn, ASYM_SIGNATURE_LEN};
use crate::model_config::ingest::{
encoded_payload_size, AUTH_DISTRIBUTION_POLICY_BLOB_DATA, AUTH_ED25519_PUBKEY_BLOB_DATA,
AUTH_ED25519_SIG_BLOB_DATA, AUTH_HMAC_SHA256_BLOB_DATA, BLOB_BIASES, BLOB_LAYER_META,
BLOB_WEIGHTS, DISTRIBUTION_POLICY_PUBLISHER_SHARED, DISTRIBUTION_POLICY_USER_DEVICE_LOCKED,
LAYER_META_SIZE, RNN0_MAGIC, RNN0_VERSION, TLV_HEADER_BENCHMARK, TLV_HEADER_NETWORK_SUMMARY,
};
use crate::model_format::BLOB_NEURON_POSITIONS;
use crate::model_format::BLOB_RUNTIME_INPUT;
use crate::rnn_format::lookup::find_blob_index;
use crate::rnn_format::parser::parse_rnn_from_bytes;
use crate::scratch::Scratch;
#[cfg(feature = "publisher-trust-service")]
use crate::trust_service::PublisherTrustService;
#[cfg(feature = "publisher-trust-service")]
use crate::trust_service::{validate_activation_decision_placeholder, ActivationChallenge};
/// A validation job for [`validate`]; each variant selects exactly one check.
///
/// Every field is a borrow, so the request is cheap to copy and can be logged with `{:?}`.
#[derive(Debug, Clone, Copy)]
pub(crate) enum ValidateRequest<'a> {
    /// Build a benchmark configuration sized from `topology` and compute capability flags.
    BuildBenchmark {
        /// Layer topology; only its length influences the preset that is chosen.
        topology: &'a [usize],
        /// `"f16"` (ASCII case-insensitive) selects an FP16 runtime profile; any other value
        /// selects FP32.
        precision: &'a str,
    },
    /// Check a container against its embedded distribution policy and signature.
    #[cfg(feature = "publisher-trust-service")]
    Distribution {
        /// Raw container bytes.
        bytes: &'a [u8],
        /// Identifier of the current device; required for device-locked containers.
        current_device_id: &'a [u8],
        /// Ed25519 public keys accepted for publisher-shared containers.
        trusted_publisher_pubkeys: &'a [[u8; ASYM_PUBLIC_KEY_LEN]],
    },
    /// Read the dtype code of the container's `layer_meta` blob.
    RnnDtype {
        /// Raw container bytes.
        bytes: &'a [u8],
    },
}
/// Successful outcome of [`validate`], one variant per [`ValidateRequest`] kind.
///
/// Derives `Debug`/`PartialEq` so callers can use `unwrap_err`, `{:?}` and `assert_eq!` on
/// `Result<ValidateResult, ValidateError>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ValidateResult {
    /// Benchmark capability bits: `1` = the runtime budget fits, `2` = the FLOP estimate is
    /// non-zero, `4` = the throughput estimate is non-negative.
    BuildBenchmarkFlags(u64),
    /// The distribution policy and signature checks passed.
    #[cfg(feature = "publisher-trust-service")]
    DistributionValidated,
    /// The dtype code of the container's `layer_meta` blob.
    RnnDtype(u8),
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ValidateError {
Config(crate::model_config::ConfigError),
Rnn(crate::rnn_format::Error),
}
impl From<crate::model_config::ConfigError> for ValidateError {
fn from(value: crate::model_config::ConfigError) -> Self {
ValidateError::Config(value)
}
}
impl From<crate::rnn_format::Error> for ValidateError {
fn from(value: crate::rnn_format::Error) -> Self {
ValidateError::Rnn(value)
}
}
/// Runs the check selected by `request` and wraps its outcome in a [`ValidateResult`].
///
/// # Errors
/// Configuration errors surface as [`ValidateError::Config`], container errors as
/// [`ValidateError::Rnn`] (both via `From` and `?`).
pub(crate) fn validate(request: ValidateRequest<'_>) -> Result<ValidateResult, ValidateError> {
    let outcome = match request {
        ValidateRequest::BuildBenchmark {
            topology,
            precision,
        } => ValidateResult::BuildBenchmarkFlags(build_benchmark_flags(topology, precision)?),
        #[cfg(feature = "publisher-trust-service")]
        ValidateRequest::Distribution {
            bytes,
            current_device_id,
            trusted_publisher_pubkeys,
        } => {
            validate_distribution_internal(bytes, current_device_id, trusted_publisher_pubkeys)?;
            ValidateResult::DistributionValidated
        }
        ValidateRequest::RnnDtype { bytes } => ValidateResult::RnnDtype(extract_rnn_dtype(bytes)?),
    };
    Ok(outcome)
}
pub(crate) fn parse_header_tlv(
bytes: &[u8],
tlv_type: u8,
) -> Result<&[u8], crate::rnn_format::Error> {
crate::model_format::header_tlv_payload(bytes, tlv_type).map_err(|e| match e {
crate::model_format::RnnProtocolError::Truncated => crate::rnn_format::Error::Truncated,
crate::model_format::RnnProtocolError::BadMagic => crate::rnn_format::Error::BadMagic,
crate::model_format::RnnProtocolError::BadVersion => crate::rnn_format::Error::BadVersion,
crate::model_format::RnnProtocolError::BadHeader => crate::rnn_format::Error::BadHeader,
crate::model_format::RnnProtocolError::CapacityTooSmall => {
crate::rnn_format::Error::CapacityTooSmall
}
crate::model_format::RnnProtocolError::InvalidPayload => {
crate::rnn_format::Error::InvalidPayload
}
})
}
/// Locates header TLV `tlv_type` and returns its payload as a byte range within `bytes`.
///
/// On success `start <= end <= bytes.len()`, so `&bytes[start..end]` is exactly the payload
/// that [`parse_header_tlv`] returns.
///
/// # Errors
/// Propagates errors from [`parse_header_tlv`]. Returns `BadHeader` if the payload does not
/// lie inside `bytes`.
pub(crate) fn parse_header_tlv_range(
    bytes: &[u8],
    tlv_type: u8,
) -> Result<(usize, usize), crate::rnn_format::Error> {
    let payload = parse_header_tlv(bytes, tlv_type)?;
    // The offset is recovered from the address difference. `checked_sub` guards against a
    // payload that does not point into `bytes` (e.g. an empty slice not derived from it). A
    // plain `-` would panic on overflow in debug builds.
    let start = (payload.as_ptr() as usize)
        .checked_sub(bytes.as_ptr() as usize)
        .ok_or(crate::rnn_format::Error::BadHeader)?;
    let end = start
        .checked_add(payload.len())
        .filter(|&end| end <= bytes.len())
        .ok_or(crate::rnn_format::Error::BadHeader)?;
    Ok((start, end))
}
pub(crate) fn validate_payload_records(
records: &[crate::model_format::BlobDesc<'_>],
) -> Result<(), crate::rnn_format::Error> {
let mut layer_meta = None;
let mut weights = None;
let mut biases = None;
let mut neuron_positions = None;
let mut runtime_input = None;
for rec in records {
match rec.name {
BLOB_LAYER_META => layer_meta = Some(rec),
BLOB_WEIGHTS => weights = Some(rec),
BLOB_BIASES => biases = Some(rec),
BLOB_NEURON_POSITIONS => neuron_positions = Some(rec),
BLOB_RUNTIME_INPUT => runtime_input = Some(rec),
_ => {}
}
}
let layer_meta = layer_meta.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let weights = weights.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let biases = biases.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let neuron_positions = neuron_positions.ok_or(crate::rnn_format::Error::InvalidPayload)?;
if layer_meta.payload.is_empty()
|| weights.payload.is_empty()
|| biases.payload.is_empty()
|| neuron_positions.payload.is_empty()
{
return Err(crate::rnn_format::Error::InvalidPayload);
}
if let Some(runtime_input) = runtime_input {
if runtime_input.payload.is_empty() {
return Err(crate::rnn_format::Error::InvalidPayload);
}
}
if layer_meta.dtype != weights.dtype
|| weights.dtype != biases.dtype
|| biases.dtype != neuron_positions.dtype
{
return Err(crate::rnn_format::Error::InvalidPayload);
}
if let Some(runtime_input) = runtime_input {
if neuron_positions.dtype != runtime_input.dtype {
return Err(crate::rnn_format::Error::InvalidPayload);
}
}
if layer_meta.ndim != 2 || layer_meta.dims[1] != 5 {
return Err(crate::rnn_format::Error::InvalidPayload);
}
if neuron_positions.ndim != 2 || neuron_positions.dims[1] != 5 {
return Err(crate::rnn_format::Error::InvalidPayload);
}
if weights.ndim != 1 || biases.ndim != 1 {
return Err(crate::rnn_format::Error::InvalidPayload);
}
if let Some(runtime_input) = runtime_input {
if runtime_input.ndim != 1 {
return Err(crate::rnn_format::Error::InvalidPayload);
}
}
let elem_size = match layer_meta.dtype {
0 => 4usize,
1 => 8usize,
_ => return Err(crate::rnn_format::Error::InvalidPayload),
};
let expected_layer_meta = (layer_meta.dims[0] as usize)
.checked_mul(20)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let expected_neuron_positions = (neuron_positions.dims[0] as usize)
.checked_mul(5)
.and_then(|v| v.checked_mul(elem_size))
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let expected_weights = (weights.dims[0] as usize)
.checked_mul(elem_size)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let expected_biases = (biases.dims[0] as usize)
.checked_mul(elem_size)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
if layer_meta.payload.len() != expected_layer_meta
|| neuron_positions.payload.len() != expected_neuron_positions
|| weights.payload.len() != expected_weights
|| biases.payload.len() != expected_biases
{
return Err(crate::rnn_format::Error::InvalidPayload);
}
if let Some(runtime_input) = runtime_input {
let expected_runtime_input = (runtime_input.dims[0] as usize)
.checked_mul(elem_size)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
if runtime_input.payload.len() != expected_runtime_input {
return Err(crate::rnn_format::Error::InvalidPayload);
}
}
Ok(())
}
fn extract_rnn_dtype(bytes: &[u8]) -> Result<u8, crate::rnn_format::Error> {
if bytes.len() < 12 {
return Err(crate::rnn_format::Error::Truncated);
}
if !constant_time_eq(&bytes[0..4], RNN0_MAGIC) {
return Err(crate::rnn_format::Error::BadMagic);
}
let version = u16::from_le_bytes([bytes[4], bytes[5]]);
if version != RNN0_VERSION {
return Err(crate::rnn_format::Error::BadVersion);
}
let header_size = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
if header_size > bytes.len() || header_size < 12 {
return Err(crate::rnn_format::Error::BadHeader);
}
let mut scratch_buf = [0u8; 16384];
let mut scratch = Scratch::new(&mut scratch_buf);
let handle = parse_rnn_from_bytes(bytes, &mut scratch)
.map_err(|_| crate::rnn_format::Error::InvalidContainer)?;
let layer_meta_idx = find_blob_index(&handle, crate::model_config::ingest::BLOB_LAYER_META)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let layer_meta_desc = handle
.blobs
.get(layer_meta_idx)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
Ok(layer_meta_desc.dtype)
}
/// Builds a preset transformer config sized from `topology` and returns benchmark capability
/// bits.
///
/// Preset choice by `topology.len()`: up to 2 gives tiny (with at least 1 layer), up to 6
/// gives small, otherwise base. Every config helper is run once, and any `ConfigError` aborts.
/// `precision == "f16"` (ASCII case-insensitive) selects an FP16 runtime profile; anything
/// else selects FP32.
///
/// Returned bits: `1` = both budget checks fit, `2` = the FLOP estimate is non-zero,
/// `4` = the throughput estimate is non-negative.
fn build_benchmark_flags(
    topology: &[usize],
    precision: &str,
) -> Result<u64, crate::model_config::ConfigError> {
    const FLAG_FITS_BUDGET: u64 = 1 << 0;
    const FLAG_HAS_FLOPS: u64 = 1 << 1;
    const FLAG_NON_NEGATIVE_THROUGHPUT: u64 = 1 << 2;

    let layer_count = topology.len();
    let cfg = match layer_count {
        0..=2 => super::presets::tiny_transformer(1024, layer_count.max(1)),
        3..=6 => super::presets::small_transformer(1024, layer_count),
        _ => super::presets::base_transformer(1024, layer_count),
    };

    cfg.validate()?;
    cfg.attention_head_dim()?;
    cfg.approximate_parameter_count()?;
    cfg.parameters_per_layer()?;
    cfg.embedding_parameter_count()?;
    cfg.total_token_elements(1, 1)?;
    cfg.activation_elements(1, 1)?;
    cfg.kv_cache_elements(1, 1)?;
    cfg.validate_runtime_shape(1, 1)?;

    let profile = if precision.eq_ignore_ascii_case("f16") {
        crate::runtime::RuntimeProfile::fp16(1, 1)
    } else {
        crate::runtime::RuntimeProfile::fp32(1, 1)
    };
    let budget_limit = usize::MAX / 4;
    let memory = cfg.estimate_runtime_memory_via_runtime(profile)?;
    let estimated_fit = super::config_core::fit_from_estimate(memory, budget_limit);
    let flops = cfg.estimate_runtime_flops_via_runtime(profile)?;
    let throughput = cfg.estimate_tokens_per_second_via_runtime(profile, 1)?;
    let checked_fit = cfg.check_runtime_budget_via_runtime(profile, budget_limit)?;

    let checks = [
        (estimated_fit.fits && checked_fit.fits, FLAG_FITS_BUDGET),
        (flops.total_flops > 0, FLAG_HAS_FLOPS),
        (
            throughput.estimated_tokens_per_second >= 0.0,
            FLAG_NON_NEGATIVE_THROUGHPUT,
        ),
    ];
    Ok(checks
        .iter()
        .fold(0u64, |bits, &(passed, flag)| if passed { bits | flag } else { bits }))
}
/// Dispatch target for `ValidateRequest::Distribution`.
///
/// Runs the offline checks of `validate_distribution_secure` unchanged; no activation service
/// is contacted on this path.
#[cfg(feature = "publisher-trust-service")]
fn validate_distribution_internal(
bytes: &[u8],
current_device_id: &[u8],
trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
) -> Result<(), crate::rnn_format::Error> {
validate_distribution_secure(bytes, current_device_id, trusted_publisher_pubkeys)
}
/// Validates a container offline and, for device-locked models, requests activation from
/// `service`.
///
/// Thin forwarding wrapper over `validate_distribution_with_service_secure`; all arguments are
/// passed through unchanged, and so are its errors.
#[cfg(feature = "publisher-trust-service")]
pub(crate) fn validate_distribution_with_service<S: PublisherTrustService>(
bytes: &[u8],
current_device_id: &[u8],
trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
machine_fingerprint: &[u8],
now_unix: u64,
service: &S,
) -> Result<(), crate::rnn_format::Error> {
validate_distribution_with_service_secure(
bytes,
current_device_id,
trusted_publisher_pubkeys,
machine_fingerprint,
now_unix,
service,
)
}
/// Validates a container against the distribution policy embedded in it, without contacting
/// any activation service.
///
/// Sequence:
/// 1. `validate_model_contract_for_engine` checks the header and core blob layout; its errors
///    are propagated unchanged.
/// 2. If `is_encrypted_rnn(bytes)` is true, `Ok(())` is returned immediately and none of the
///    policy/signature checks below run.
/// 3. The container is parsed (16 KiB stack scratch). The distribution-policy blob must be
///    exactly one byte:
///    - `DISTRIBUTION_POLICY_PUBLISHER_SHARED`: `trusted_publisher_pubkeys` must be non-empty,
///      the Ed25519 signature must verify with an empty device id, and the embedded public key
///      must equal (constant-time compare) one of `trusted_publisher_pubkeys`.
///    - `DISTRIBUTION_POLICY_USER_DEVICE_LOCKED`: `current_device_id` must be non-empty and the
///      signature must verify when bound to it.
///    - any other value is rejected.
///
/// # Errors
/// Contract and signature-verification errors are propagated; every other failure raised in
/// this function is `InvalidContainer`.
pub(crate) fn validate_distribution_secure(
bytes: &[u8],
current_device_id: &[u8],
trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
) -> Result<(), crate::rnn_format::Error> {
validate_model_contract_for_engine(bytes)?;
// NOTE(review): encrypted containers skip every policy/signature check below. Confirm they
// are validated again after decryption.
if is_encrypted_rnn(bytes) {
return Ok(());
}
let mut scratch_buf = [0u8; 16384];
let mut scratch = Scratch::new(&mut scratch_buf);
let handle = parse_rnn_from_bytes(bytes, &mut scratch)
.map_err(|_| crate::rnn_format::Error::InvalidContainer)?;
// The policy blob is a single byte selecting which check applies.
let policy_payload = find_blob_bytes_by_name(&handle, AUTH_DISTRIBUTION_POLICY_BLOB_DATA)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
if policy_payload.len() != 1 {
return Err(crate::rnn_format::Error::InvalidContainer);
}
match policy_payload[0] {
DISTRIBUTION_POLICY_PUBLISHER_SHARED => {
if trusted_publisher_pubkeys.is_empty() {
return Err(crate::rnn_format::Error::InvalidContainer);
}
// An empty device id: publisher-shared signatures are not bound to any device.
verify_auth_signature_for_device(&handle, &[])?;
// The signature was just checked against this embedded key; the key itself is only
// accepted if it is one of the caller-supplied trust anchors.
let pubkey = find_blob_bytes_by_name(&handle, AUTH_ED25519_PUBKEY_BLOB_DATA)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
if pubkey.len() != ASYM_PUBLIC_KEY_LEN {
return Err(crate::rnn_format::Error::InvalidContainer);
}
let is_trusted = trusted_publisher_pubkeys
.iter()
.any(|trusted| constant_time_eq(pubkey, trusted));
if !is_trusted {
return Err(crate::rnn_format::Error::InvalidContainer);
}
}
DISTRIBUTION_POLICY_USER_DEVICE_LOCKED => {
if current_device_id.is_empty() {
return Err(crate::rnn_format::Error::InvalidContainer);
}
// NOTE(review): only the key embedded in the container is used here. It is never
// compared with `trusted_publisher_pubkeys`, so a container re-signed with any key pair
// for this device id passes. Confirm this is the intended trust model.
verify_auth_signature_for_device(&handle, current_device_id)?;
}
_ => return Err(crate::rnn_format::Error::InvalidContainer),
}
Ok(())
}
/// Fixed-capacity byte buffer with inline `[u8; N]` storage (no heap allocation), used to
/// assemble signed messages.
///
/// Appends are all-or-nothing: an append that would exceed `N` bytes fails with
/// `CapacityTooSmall` and leaves the buffer unchanged.
struct FixedMessageBuffer<const N: usize> {
    /// Backing storage; only `data[..len]` holds written bytes.
    data: [u8; N],
    /// Number of bytes written so far (never more than `N`).
    len: usize,
}
impl<const N: usize> FixedMessageBuffer<N> {
    /// Creates an empty, zero-filled buffer.
    fn new() -> Self {
        FixedMessageBuffer {
            len: 0,
            data: [0; N],
        }
    }

    /// Appends all of `bytes`, or nothing if they do not fit.
    fn extend_from_slice(&mut self, bytes: &[u8]) -> Result<(), crate::rnn_format::Error> {
        let end = self
            .len
            .checked_add(bytes.len())
            .filter(|&end| end <= N)
            .ok_or(crate::rnn_format::Error::CapacityTooSmall)?;
        self.data[self.len..end].copy_from_slice(bytes);
        self.len = end;
        Ok(())
    }

    /// Appends a single byte, failing if the buffer is already full.
    fn push(&mut self, byte: u8) -> Result<(), crate::rnn_format::Error> {
        let slot = self
            .data
            .get_mut(self.len)
            .ok_or(crate::rnn_format::Error::CapacityTooSmall)?;
        *slot = byte;
        self.len += 1;
        Ok(())
    }

    /// Returns the bytes written so far.
    fn as_slice(&self) -> &[u8] {
        &self.data[..self.len]
    }
}
/// Runs the offline distribution checks, then asks `service` to approve activation of
/// device-locked models on this machine.
///
/// After `validate_distribution_secure` succeeds:
/// - encrypted containers (per `is_encrypted_rnn`) are accepted as-is;
/// - an empty `machine_fingerprint` is rejected, for every policy, before the policy is read;
/// - containers whose policy byte is not `DISTRIBUTION_POLICY_USER_DEVICE_LOCKED` are accepted;
/// - otherwise an `ActivationChallenge` is sent to `service`, and the returned decision must
///   pass `validate_activation_decision_placeholder`. The challenge carries the SHA-256 of the
///   whole container, the embedded signing key, the policy byte, the fingerprint, the device
///   id and `now_unix`.
///
/// # Errors
/// Errors from the offline checks are propagated. Every failure raised here, including
/// service and decision errors, becomes `InvalidContainer`.
#[cfg(feature = "publisher-trust-service")]
pub(crate) fn validate_distribution_with_service_secure<S: PublisherTrustService>(
    bytes: &[u8],
    current_device_id: &[u8],
    trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
    machine_fingerprint: &[u8],
    now_unix: u64,
    service: &S,
) -> Result<(), crate::rnn_format::Error> {
    use crate::rnn_format::Error;

    validate_distribution_secure(bytes, current_device_id, trusted_publisher_pubkeys)?;
    if is_encrypted_rnn(bytes) {
        return Ok(());
    }
    if machine_fingerprint.is_empty() {
        return Err(Error::InvalidContainer);
    }

    let mut scratch_buf = [0u8; 16384];
    let mut scratch = Scratch::new(&mut scratch_buf);
    let handle = parse_rnn_from_bytes(bytes, &mut scratch).map_err(|_| Error::InvalidContainer)?;
    // The policy blob must be exactly one byte long.
    let policy = match find_blob_bytes_by_name(&handle, AUTH_DISTRIBUTION_POLICY_BLOB_DATA) {
        Some(&[policy]) => policy,
        _ => return Err(Error::InvalidContainer),
    };
    if policy != DISTRIBUTION_POLICY_USER_DEVICE_LOCKED {
        // Activation is only requested for device-locked models.
        return Ok(());
    }

    let model_signing_pubkey = find_blob_bytes_by_name(&handle, AUTH_ED25519_PUBKEY_BLOB_DATA)
        .and_then(|raw| <[u8; ASYM_PUBLIC_KEY_LEN]>::try_from(raw).ok())
        .ok_or(Error::InvalidContainer)?;
    let mut model_sha256 = [0u8; 32];
    sha256_bytes(bytes, &mut model_sha256);

    let challenge = ActivationChallenge {
        model_sha256,
        model_signing_pubkey,
        distribution_policy: policy,
        machine_fingerprint,
        device_id: current_device_id,
        issued_at_unix: now_unix,
    };
    let decision = service
        .request_activation(&challenge)
        .map_err(|_| Error::InvalidContainer)?;
    validate_activation_decision_placeholder(&decision, now_unix)
        .map_err(|_| Error::InvalidContainer)
}
/// Returns the payload bytes of the blob called `name`, borrowed from the container bytes.
///
/// Returns `None` in any of these cases:
/// - no blob has that name;
/// - its offset or length does not fit in `usize`;
/// - `offset + length` overflows;
/// - the range lies outside `handle.bytes`.
///
/// With duplicate names, the blob used is whichever index `find_blob_index` reports.
fn find_blob_bytes_by_name<'bytes, 'scratch>(
    handle: &crate::rnn_format::parser::RnnHandle<'bytes, 'scratch>,
    name: &str,
) -> Option<&'bytes [u8]> {
    let meta = find_blob_index(handle, name).and_then(|idx| handle.blobs.get(idx))?;
    let start = usize::try_from(meta.offset).ok()?;
    let end = usize::try_from(meta.length)
        .ok()
        .and_then(|len| start.checked_add(len))?;
    handle.bytes.get(start..end)
}
fn verify_auth_signature_for_device(
handle: &crate::rnn_format::parser::RnnHandle<'_, '_>,
device_id: &[u8],
) -> Result<(), crate::rnn_format::Error> {
let sig = find_blob_bytes_by_name(handle, AUTH_ED25519_SIG_BLOB_DATA)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
let pubkey = find_blob_bytes_by_name(handle, AUTH_ED25519_PUBKEY_BLOB_DATA)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
if sig.len() != ASYM_SIGNATURE_LEN || pubkey.len() != ASYM_PUBLIC_KEY_LEN {
return Err(crate::rnn_format::Error::InvalidContainer);
}
let mut message = FixedMessageBuffer::<262144>::new();
if device_id.len() > u16::MAX as usize {
return Err(crate::rnn_format::Error::BadHeader);
}
message.extend_from_slice(&(device_id.len() as u16).to_le_bytes())?;
message.extend_from_slice(device_id)?;
for idx in 0..handle.blobs.len() {
let name = handle
.blob_name(idx)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
if name == AUTH_HMAC_SHA256_BLOB_DATA
|| name == AUTH_ED25519_SIG_BLOB_DATA
|| name == AUTH_ED25519_PUBKEY_BLOB_DATA
{
continue;
}
let blob = handle
.blobs
.get(idx)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
let payload = find_blob_bytes_by_name(handle, name)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
let name_bytes = name.as_bytes();
if name_bytes.len() > u16::MAX as usize {
return Err(crate::rnn_format::Error::BadHeader);
}
message.extend_from_slice(&(name_bytes.len() as u16).to_le_bytes())?;
message.extend_from_slice(name_bytes)?;
message.push(blob.dtype)?;
message.push(blob.ndim)?;
let shape_start = blob.shape_offset;
let dims_len = (blob.ndim as usize)
.checked_mul(4)
.ok_or(crate::rnn_format::Error::BadHeader)?;
let shape_end = shape_start
.checked_add(dims_len)
.ok_or(crate::rnn_format::Error::BadHeader)?;
let shape_bytes = handle
.scratch
.get(shape_start..shape_end)
.ok_or(crate::rnn_format::Error::InvalidContainer)?;
for chunk in shape_bytes.chunks_exact(4) {
let dim = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
message.extend_from_slice(&dim.to_le_bytes())?;
}
message.extend_from_slice(&(payload.len() as u64).to_le_bytes())?;
message.extend_from_slice(payload)?;
}
let pubkey_ref: &[u8; ASYM_PUBLIC_KEY_LEN] = pubkey
.try_into()
.map_err(|_| crate::rnn_format::Error::InvalidContainer)?;
let sig_ref: &[u8; ASYM_SIGNATURE_LEN] = sig
.try_into()
.map_err(|_| crate::rnn_format::Error::InvalidContainer)?;
ed25519_verify(pubkey_ref, message.as_slice(), sig_ref)
.map_err(|_| crate::rnn_format::Error::InvalidContainer)
}
fn validate_model_contract_for_engine(bytes: &[u8]) -> Result<(), crate::rnn_format::Error> {
if is_encrypted_rnn(bytes) {
return Ok(());
}
if bytes.len() < 12 {
return Err(crate::rnn_format::Error::Truncated);
}
if !constant_time_eq(&bytes[0..4], RNN0_MAGIC) {
return Err(crate::rnn_format::Error::BadMagic);
}
let version = u16::from_le_bytes([bytes[4], bytes[5]]);
if version != RNN0_VERSION {
return Err(crate::rnn_format::Error::BadVersion);
}
crate::model_format::header_tlv_payload(bytes, TLV_HEADER_NETWORK_SUMMARY)
.map_err(|_| crate::rnn_format::Error::InvalidPayload)?;
crate::model_format::header_tlv_payload(bytes, TLV_HEADER_BENCHMARK)
.map_err(|_| crate::rnn_format::Error::InvalidPayload)?;
let mut scratch_buf = [0u8; 16384];
let mut scratch = Scratch::new(&mut scratch_buf);
let handle = parse_rnn_from_bytes(bytes, &mut scratch)
.map_err(|_| crate::rnn_format::Error::InvalidContainer)?;
let layer_meta_idx = find_blob_index(&handle, crate::model_config::ingest::BLOB_LAYER_META)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let weights_idx =
find_blob_index(&handle, BLOB_WEIGHTS).ok_or(crate::rnn_format::Error::InvalidPayload)?;
let biases_idx =
find_blob_index(&handle, BLOB_BIASES).ok_or(crate::rnn_format::Error::InvalidPayload)?;
let layer_meta_desc = handle
.blobs
.get(layer_meta_idx)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let weights_desc = handle
.blobs
.get(weights_idx)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let biases_desc = handle
.blobs
.get(biases_idx)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
let layer_meta_len = usize::try_from(layer_meta_desc.length)
.map_err(|_| crate::rnn_format::Error::InvalidPayload)?;
let weights_len_bytes = usize::try_from(weights_desc.length)
.map_err(|_| crate::rnn_format::Error::InvalidPayload)?;
let biases_len_bytes = usize::try_from(biases_desc.length)
.map_err(|_| crate::rnn_format::Error::InvalidPayload)?;
if layer_meta_len == 0 || !layer_meta_len.is_multiple_of(LAYER_META_SIZE) {
return Err(crate::rnn_format::Error::InvalidPayload);
}
let layer_count = layer_meta_len / LAYER_META_SIZE;
if weights_desc.dtype != biases_desc.dtype {
return Err(crate::rnn_format::Error::InvalidPayload);
}
let elem_size = match weights_desc.dtype {
0 => 4usize,
1 => 8usize,
_ => return Err(crate::rnn_format::Error::InvalidPayload),
};
if !weights_len_bytes.is_multiple_of(elem_size) || !biases_len_bytes.is_multiple_of(elem_size) {
return Err(crate::rnn_format::Error::InvalidPayload);
}
let weights_len = weights_len_bytes / elem_size;
let biases_len = biases_len_bytes / elem_size;
encoded_payload_size(weights_desc.dtype, layer_count, weights_len, biases_len)
.ok_or(crate::rnn_format::Error::InvalidPayload)?;
Ok(())
}