use crate::activations::ActivationKind;
use crate::benchmark::decode_benchmark_blob;
use crate::crypto::ASYM_PUBLIC_KEY_LEN;
use crate::layers::{build_from_layers, LayerError, LayerSpec};
use crate::model_config::ingest;
use crate::model_format::{
encode_blob_payloads_with_header_metadata, encode_model_f32, encode_model_f64,
neuron_positions_blob_into, neuron_positions_blob_size, parse_payload,
total_neurons_from_layer_meta, BlobDesc, ModelFormatError, RnnProtocolError,
};
use crate::model_format::{
BLOB_BIASES, BLOB_LAYER_META, BLOB_NEURON_POSITIONS, BLOB_RUNTIME_INPUT, BLOB_WEIGHTS,
};
use crate::rnn_format::lookup::find_blob_index;
use crate::rnn_format::parser::parse_rnn_from_bytes;
use crate::runtime::FixedSliceVec;
use crate::scratch::Scratch;
use crate::tensor::TensorView;
#[cfg(feature = "publisher-trust-service")]
use crate::trust_service::PublisherTrustService;
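/// Coarse error classes surfaced by the container-flow helpers in this
/// module: invalid topology/shape input, undersized output or scratch
/// buffers, malformed input bytes, and higher-level model validation
/// failures.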
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RnnFlowError {
InvalidTopology,
CapacityTooSmall,
BadBytes,
Model,
}
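/// Borrowed read-only view of a parsed container: the scan report plus, for
/// unencrypted containers, whichever header metadata and raw blob payloads
/// could be resolved.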
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ReadableRnnFormat<'a> {
pub scan: crate::rnn_format::scanner::ScanReport,
pub model_name: Option<&'a str>,
pub model_precision: Option<&'a str>,
pub layer_meta: Option<&'a [u8]>,
pub weights: Option<&'a [u8]>,
pub biases: Option<&'a [u8]>,
pub benchmark: Option<&'a [u8]>,
}
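/// A named blob payload to splice into an existing container in place; the
/// payload length must exactly match the target blob's current length (see
/// [`fill_blob_payloads`]).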
#[derive(Clone, Copy)]
pub struct BlobWrite<'a> {
pub name: &'a str,
pub payload: &'a [u8],
}
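/// Caller-supplied runtime input values; the variant must match the model
/// dtype (`F32` for dtype 0, `F64` for dtype 1).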
#[derive(Clone, Copy)]
pub enum RuntimeInput<'a> {
F32(&'a [f32]),
F64(&'a [f64]),
}
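/// Packs an f32 model: expands `topology` and the two activations into
/// `layer_specs_scratch` via `build_from_layers`, then encodes the layer
/// specs, weights, and biases into `out_bytes`. Returns the encoded size in
/// bytes.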
pub fn pack_rnn_f32(
topology: &[usize],
hidden_activation: ActivationKind,
output_activation: ActivationKind,
weights: &[f32],
biases: &[f32],
layer_specs_scratch: &mut [LayerSpec],
out_bytes: &mut [u8],
) -> Result<usize, RnnFlowError> {
let layer_count = build_from_layers(
topology,
hidden_activation,
output_activation,
weights.len(),
biases.len(),
layer_specs_scratch,
)
.map_err(map_layer_error)?;
encode_model_f32(
&layer_specs_scratch[..layer_count],
weights,
biases,
out_bytes,
)
.map_err(map_model_error)
}
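/// f64 counterpart of [`pack_rnn_f32`].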
pub fn pack_rnn_f64(
topology: &[usize],
hidden_activation: ActivationKind,
output_activation: ActivationKind,
weights: &[f64],
biases: &[f64],
layer_specs_scratch: &mut [LayerSpec],
out_bytes: &mut [u8],
) -> Result<usize, RnnFlowError> {
let layer_count = build_from_layers(
topology,
hidden_activation,
output_activation,
weights.len(),
biases.len(),
layer_specs_scratch,
)
.map_err(map_layer_error)?;
encode_model_f64(
&layer_specs_scratch[..layer_count],
weights,
biases,
out_bytes,
)
.map_err(map_model_error)
}
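/// Asks the model-config validator for the benchmark flags implied by
/// `topology` and `precision`; any other outcome is reported as
/// [`RnnFlowError::InvalidTopology`].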
pub fn validate_benchmark_flags(topology: &[usize], precision: &str) -> Result<u64, RnnFlowError> {
let request = crate::model_config::rnn::ValidateRequest::BuildBenchmark {
topology,
precision,
};
match crate::model_config::rnn::validate(request) {
Ok(crate::model_config::rnn::ValidateResult::BuildBenchmarkFlags(flags)) => Ok(flags),
        _ => Err(RnnFlowError::InvalidTopology),
}
}
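/// Writes the compact 32-byte benchmark blob into `out` and returns the
/// number of bytes used. `model_name` is accepted for call-site symmetry but
/// not stored; the name travels in the container header instead.
///
/// Layout (integers little-endian):
/// - `0..4`: magic `b"BMK\x01"`
/// - `4`: precision tag (`2` for `"f64"`, otherwise `1` for f32)
/// - `5..8`: reserved, zeroed
/// - `8..12`: layer count (`topology.len() - 1`)
/// - `12..16` / `16..20`: input / output dimension
/// - `20..24`: total parameter count
/// - `24..28`: output byte length (saturated to `u32::MAX`)
/// - `28..32`: low 32 bits of the benchmark flags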
pub fn build_default_benchmark_blob(
model_name: &str,
precision: &str,
topology: &[usize],
output_bytes: usize,
benchmark_flags: u64,
out: &mut [u8],
) -> Result<usize, RnnFlowError> {
let _ = model_name;
let (weights, biases) = crate::initializers::expected_parameter_counts(topology)
.ok_or(RnnFlowError::InvalidTopology)?;
if out.len() < 32 {
return Err(RnnFlowError::CapacityTooSmall);
}
let precision_tag = if precision.eq_ignore_ascii_case("f64") {
2u8
} else {
1u8
};
let layer_count = topology.len().saturating_sub(1) as u32;
let input_dim = topology.first().copied().unwrap_or(0) as u32;
let output_dim = topology.last().copied().unwrap_or(0) as u32;
    let total_params =
        u32::try_from(weights.saturating_add(biases)).unwrap_or(u32::MAX);
let output_bytes_u32 = u32::try_from(output_bytes).unwrap_or(u32::MAX);
out[0..4].copy_from_slice(b"BMK\x01");
out[4] = precision_tag;
out[5] = 0;
out[6..8].copy_from_slice(&0u16.to_le_bytes());
out[8..12].copy_from_slice(&layer_count.to_le_bytes());
out[12..16].copy_from_slice(&input_dim.to_le_bytes());
out[16..20].copy_from_slice(&output_dim.to_le_bytes());
out[20..24].copy_from_slice(&total_params.to_le_bytes());
out[24..28].copy_from_slice(&output_bytes_u32.to_le_bytes());
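    // Only the low 32 bits of the flags fit the compact layout; the full
    // benchmark blob stores all 64 bits (see `train_update_benchmark`).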
out[28..32].copy_from_slice(&(benchmark_flags as u32).to_le_bytes());
Ok(32)
}
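/// Validates the distribution envelope, collapsing the detailed error from
/// [`validate_distribution_secure`] into [`RnnFlowError::Model`].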
pub fn validate_distribution(
bytes: &[u8],
current_device_id: &[u8],
trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
) -> Result<(), RnnFlowError> {
validate_distribution_secure(bytes, current_device_id, trusted_publisher_pubkeys)
.map_err(|_| RnnFlowError::Model)
}
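/// Returns the container's dtype tag (`0` = f32, `1` = f64) as reported by
/// the model-config validator.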
pub fn rnn_dtype(bytes: &[u8]) -> Result<u8, RnnFlowError> {
let request = crate::model_config::rnn::ValidateRequest::RnnDtype { bytes };
match crate::model_config::rnn::validate(request) {
Ok(crate::model_config::rnn::ValidateResult::RnnDtype(dtype)) => Ok(dtype),
        _ => Err(RnnFlowError::BadBytes),
}
}
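/// Validates the container and its distribution envelope, then folds one
/// training step (`train_samples` samples ending at `last_loss`) into the
/// embedded benchmark blob in place. Compact 32-byte blobs only get their
/// flags merged; full (at least 116-byte) blobs also get iteration and
/// sample counters and the running average loss updated. Returns the
/// precision string (`"f32"` or `"f64"`).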
pub fn train_update_benchmark(
bytes: &mut [u8],
train_samples: u64,
last_loss: f32,
current_device_id: &[u8],
trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
) -> Result<&'static str, RnnFlowError> {
crate::rnn_format::validate(bytes, Some(current_device_id))
.map_err(|_| RnnFlowError::BadBytes)?;
validate_distribution(bytes, current_device_id, trusted_publisher_pubkeys)?;
let dtype = rnn_dtype(bytes)?;
let precision_str: &'static str = match dtype {
0 => "f32",
1 => "f64",
_ => return Err(RnnFlowError::BadBytes),
};
let bench_range =
crate::model_config::rnn::parse_header_tlv_range(bytes, ingest::TLV_HEADER_BENCHMARK)
.map_err(|_| RnnFlowError::Model)?;
let header_summary = crate::rnn_format::scanner::extract_header_summary(bytes)
.map_err(|_| RnnFlowError::BadBytes)?;
if header_summary.dtype != dtype {
return Err(RnnFlowError::BadBytes);
}
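    // Two benchmark layouts exist: the compact 32-byte blob emitted by
    // `build_default_benchmark_blob`, and the full blob understood by
    // `decode_benchmark_blob`. Handle the compact form first.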
let bench_len = bench_range.1.saturating_sub(bench_range.0);
if bench_len < 116 {
if bench_len < 32 {
return Err(RnnFlowError::BadBytes);
}
let bench_bytes = &mut bytes[bench_range.0..bench_range.1];
if &bench_bytes[0..4] != b"BMK\x01" {
return Err(RnnFlowError::BadBytes);
}
let bench_precision = bench_bytes[4];
let expected_precision = if dtype == 1 { 2u8 } else { 1u8 };
if bench_precision != expected_precision {
return Err(RnnFlowError::BadBytes);
}
let prev_layer_count = u32::from_le_bytes([
bench_bytes[8],
bench_bytes[9],
bench_bytes[10],
bench_bytes[11],
]);
let prev_input_dim = u32::from_le_bytes([
bench_bytes[12],
bench_bytes[13],
bench_bytes[14],
bench_bytes[15],
]);
let prev_output_dim = u32::from_le_bytes([
bench_bytes[16],
bench_bytes[17],
bench_bytes[18],
bench_bytes[19],
]);
let prev_total_params = u32::from_le_bytes([
bench_bytes[20],
bench_bytes[21],
bench_bytes[22],
bench_bytes[23],
]);
let prev_flags = u32::from_le_bytes([
bench_bytes[28],
bench_bytes[29],
bench_bytes[30],
bench_bytes[31],
]);
if header_summary.layer_count != prev_layer_count {
return Err(RnnFlowError::BadBytes);
}
let summary_total_params = header_summary
.weights_len
.checked_add(header_summary.biases_len)
.ok_or(RnnFlowError::BadBytes)?;
if summary_total_params != prev_total_params {
return Err(RnnFlowError::BadBytes);
}
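        // Reconstruct a topology from the blob header. Hidden layer widths
        // are not stored, so they are approximated with the output width;
        // the result only feeds flag derivation.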
let mut derived_storage = [0usize; 256];
let mut derived_topology = FixedSliceVec::new(&mut derived_storage);
derived_topology
.push(prev_input_dim as usize)
.map_err(|_| RnnFlowError::CapacityTooSmall)?;
for _ in 1..(prev_layer_count as usize) {
derived_topology
.push(prev_output_dim as usize)
.map_err(|_| RnnFlowError::CapacityTooSmall)?;
}
derived_topology
.push(prev_output_dim as usize)
.map_err(|_| RnnFlowError::CapacityTooSmall)?;
let trained_flags = validate_benchmark_flags(derived_topology.as_slice(), precision_str)?;
let merged_flags = prev_flags | (trained_flags as u32);
bench_bytes[28..32].copy_from_slice(&merged_flags.to_le_bytes());
return Ok(precision_str);
}
let (
prev_iterations,
prev_train_samples,
prev_elapsed,
prev_output_bytes,
prev_total_params,
prev_layer_count,
prev_input_dim,
prev_output_dim,
prev_benchmark_flags,
prev_weights_bytes,
prev_biases_bytes,
prev_min_loss,
prev_max_loss,
prev_loss_stddev,
prev_iterations_per_sec,
prev_samples_per_sec,
) = {
let bench_bytes = &bytes[bench_range.0..bench_range.1];
let view = decode_benchmark_blob(bench_bytes).map_err(|_| RnnFlowError::BadBytes)?;
if header_summary.layer_count != view.layer_count || header_summary.dtype != dtype {
return Err(RnnFlowError::BadBytes);
}
let summary_total_params = header_summary
.weights_len
.checked_add(header_summary.biases_len)
.ok_or(RnnFlowError::BadBytes)? as u64;
if summary_total_params != view.total_params {
return Err(RnnFlowError::BadBytes);
}
(
view.iterations,
view.train_samples,
view.elapsed_ms,
view.output_bytes,
view.total_params,
view.layer_count,
view.input_dim,
view.output_dim,
view.benchmark_flags,
view.weights_bytes,
view.biases_bytes,
view.min_loss,
view.max_loss,
view.loss_stddev,
view.iterations_per_sec,
view.samples_per_sec,
)
};
let new_iterations = prev_iterations.saturating_add(1);
let new_train_samples = prev_train_samples.saturating_add(train_samples);
let new_last_loss = last_loss;
    // The previously stored running average lives at bytes 32..36 of the
    // blob (matching the write-back below); fold it into the new mean
    // instead of discarding it.
    let b0 = bench_range.0;
    let prev_avg_loss =
        f32::from_le_bytes([bytes[b0 + 32], bytes[b0 + 33], bytes[b0 + 34], bytes[b0 + 35]]);
    let new_avg_loss = if new_train_samples == 0 {
        new_last_loss
    } else {
        // Sample-weighted running mean over previous and current batches.
        let prev_total = (prev_avg_loss as f64) * (prev_train_samples as f64);
        let add = (new_last_loss as f64) * (train_samples as f64);
        ((prev_total + add) / (new_train_samples as f64)) as f32
    };
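    // Same topology reconstruction as in the compact path: hidden widths are
    // approximated with the output width, only to derive flags.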
let mut derived_storage = [0usize; 256];
let mut derived_topology = FixedSliceVec::new(&mut derived_storage);
derived_topology
.push(prev_input_dim as usize)
.map_err(|_| RnnFlowError::CapacityTooSmall)?;
for _ in 1..(prev_layer_count as usize) {
derived_topology
.push(prev_output_dim as usize)
.map_err(|_| RnnFlowError::CapacityTooSmall)?;
}
derived_topology
.push(prev_output_dim as usize)
.map_err(|_| RnnFlowError::CapacityTooSmall)?;
let trained_flags = validate_benchmark_flags(derived_topology.as_slice(), precision_str)?;
{
let bench_bytes = &mut bytes[bench_range.0..bench_range.1];
if bench_bytes.len() < 112 {
return Err(RnnFlowError::BadBytes);
}
bench_bytes[8..16].copy_from_slice(&prev_elapsed.to_le_bytes());
bench_bytes[16..24].copy_from_slice(&new_iterations.to_le_bytes());
bench_bytes[24..32].copy_from_slice(&new_train_samples.to_le_bytes());
bench_bytes[32..36].copy_from_slice(&new_avg_loss.to_le_bytes());
bench_bytes[36..40].copy_from_slice(&new_last_loss.to_le_bytes());
bench_bytes[40..48].copy_from_slice(&(prev_output_bytes as u64).to_le_bytes());
bench_bytes[48..56].copy_from_slice(&prev_total_params.to_le_bytes());
bench_bytes[56..60].copy_from_slice(&prev_layer_count.to_le_bytes());
bench_bytes[60..64].copy_from_slice(&prev_input_dim.to_le_bytes());
bench_bytes[64..68].copy_from_slice(&prev_output_dim.to_le_bytes());
bench_bytes[68..76].copy_from_slice(&(prev_benchmark_flags | trained_flags).to_le_bytes());
bench_bytes[76..84].copy_from_slice(&prev_weights_bytes.to_le_bytes());
bench_bytes[84..92].copy_from_slice(&prev_biases_bytes.to_le_bytes());
bench_bytes[92..96].copy_from_slice(&prev_min_loss.to_le_bytes());
bench_bytes[96..100].copy_from_slice(&prev_max_loss.to_le_bytes());
bench_bytes[100..104].copy_from_slice(&prev_loss_stddev.to_le_bytes());
bench_bytes[104..108].copy_from_slice(&prev_iterations_per_sec.to_le_bytes());
bench_bytes[108..112].copy_from_slice(&prev_samples_per_sec.to_le_bytes());
}
Ok(precision_str)
}
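/// Validates `bytes` (optionally against `device_id`) and returns a
/// [`ReadableRnnFormat`]. Blob payloads and header metadata are resolved
/// only for unencrypted containers; encrypted ones yield just the scan
/// report.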
pub fn read_rnn_format<'a>(
bytes: &'a [u8],
device_id: Option<&[u8]>,
) -> Result<ReadableRnnFormat<'a>, RnnFlowError> {
let dtype = rnn_dtype(bytes)?;
if !matches!(dtype, 0 | 1) {
return Err(RnnFlowError::BadBytes);
}
if crate::model_format::parse_header(bytes).is_some()
&& !crate::model_format::payload_bounds::has_full_payload(bytes)
{
return Err(RnnFlowError::BadBytes);
}
let scan = crate::rnn_format::validate(bytes, device_id).map_err(|_| RnnFlowError::Model)?;
let mut readable = ReadableRnnFormat {
scan,
model_name: None,
model_precision: None,
layer_meta: None,
weights: None,
biases: None,
benchmark: None,
};
if !scan.encrypted {
let mut scratch_buf = [0u8; 16384];
let mut scratch = Scratch::new(&mut scratch_buf);
let handle =
parse_rnn_from_bytes(bytes, &mut scratch).map_err(|_| RnnFlowError::BadBytes)?;
readable.layer_meta = blob_payload_by_name(bytes, &handle, BLOB_LAYER_META);
readable.weights = blob_payload_by_name(bytes, &handle, BLOB_WEIGHTS);
readable.biases = blob_payload_by_name(bytes, &handle, BLOB_BIASES);
readable.benchmark =
crate::model_format::header_tlv_payload(bytes, ingest::TLV_HEADER_BENCHMARK).ok();
readable.model_name =
crate::model_format::header_tlv_payload(bytes, ingest::TLV_HEADER_MODEL_NAME)
.ok()
.and_then(|v| core::str::from_utf8(v).ok());
readable.model_precision =
blob_payload_by_name(bytes, &handle, ingest::MODEL_PRECISION_BLOB_DATA)
.and_then(|v| core::str::from_utf8(v).ok());
}
Ok(readable)
}
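/// Overwrites named blob payloads in an existing container in place. Every
/// write must match its target blob's length exactly; nothing is resized.
///
/// A minimal sketch (`ignore`d doctest; `container_bytes` and `new_weights`
/// are hypothetical caller-owned buffers):
///
/// ```ignore
/// let writes = [BlobWrite { name: BLOB_WEIGHTS, payload: &new_weights }];
/// fill_blob_payloads(&mut container_bytes, &writes)?;
/// ```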
pub fn fill_blob_payloads(bytes: &mut [u8], writes: &[BlobWrite<'_>]) -> Result<(), RnnFlowError> {
for write in writes {
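        // Re-parse per write so the scratch-backed handle's borrow ends
        // before `bytes` is mutated; only the blob's byte range escapes the
        // block.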
let (start, end, len) = {
let mut scratch_buf = [0u8; 16384];
let mut scratch = Scratch::new(&mut scratch_buf);
let handle =
parse_rnn_from_bytes(bytes, &mut scratch).map_err(|_| RnnFlowError::BadBytes)?;
let idx = find_blob_index(&handle, write.name).ok_or(RnnFlowError::Model)?;
let desc = handle.blobs.get(idx).ok_or(RnnFlowError::Model)?;
let start = usize::try_from(desc.offset).map_err(|_| RnnFlowError::BadBytes)?;
let len = usize::try_from(desc.length).map_err(|_| RnnFlowError::BadBytes)?;
let end = start.checked_add(len).ok_or(RnnFlowError::BadBytes)?;
(start, end, len)
};
if end > bytes.len() || len != write.payload.len() {
return Err(RnnFlowError::BadBytes);
}
bytes[start..end].copy_from_slice(write.payload);
}
Ok(())
}
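/// Encodes `records` plus the optional model-name header and the benchmark
/// header into `out_bytes`, returning the encoded size.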
pub(crate) fn assemble_rnn_container(
records: &[BlobDesc<'_>],
model_name_header: Option<&str>,
benchmark_header: &[u8],
out_bytes: &mut [u8],
) -> Result<usize, RnnFlowError> {
let used = encode_blob_payloads_with_header_metadata(
records,
model_name_header,
Some(benchmark_header),
out_bytes,
)
.map_err(map_protocol_error)?;
Ok(used)
}
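/// Convenience wrapper over
/// [`build_container_from_rmd1_with_runtime_input`] with no runtime input,
/// so the default ramp input blob is generated.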
pub fn build_container_from_rmd1(
topology: &[usize],
precision: &str,
model_name: &str,
rmd1_bytes: &[u8],
metadata_scratch: &mut [u8],
out_bytes: &mut [u8],
) -> Result<usize, RnnFlowError> {
build_container_from_rmd1_with_runtime_input(
topology,
precision,
model_name,
rmd1_bytes,
None,
metadata_scratch,
out_bytes,
)
}
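/// Builds a full container from a dense RMD1 payload: derives benchmark
/// flags and the compact benchmark blob, materializes the neuron-positions
/// and runtime-input blobs in `metadata_scratch`, then assembles everything
/// into `out_bytes`. Returns the container size in bytes.
///
/// A minimal sketch (`ignore`d doctest; the buffers are hypothetical and
/// must be sized by the caller):
///
/// ```ignore
/// let used = build_container_from_rmd1_with_runtime_input(
///     &[4, 8, 2],
///     "f32",
///     "demo-model",
///     &rmd1_bytes,
///     Some(RuntimeInput::F32(&[0.5, 0.1, 0.9, 0.25])),
///     &mut metadata_scratch,
///     &mut out_bytes,
/// )?;
/// ```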
pub fn build_container_from_rmd1_with_runtime_input(
topology: &[usize],
precision: &str,
model_name: &str,
rmd1_bytes: &[u8],
runtime_input: Option<RuntimeInput<'_>>,
metadata_scratch: &mut [u8],
out_bytes: &mut [u8],
) -> Result<usize, RnnFlowError> {
let benchmark_flags = validate_benchmark_flags(topology, precision)?;
let benchmark_used = build_default_benchmark_blob(
"",
precision,
topology,
rmd1_bytes.len(),
benchmark_flags,
metadata_scratch,
)?;
let dense = parse_payload(rmd1_bytes).map_err(map_protocol_error)?;
let total_neurons =
total_neurons_from_layer_meta(dense.layer_meta).map_err(map_protocol_error)?;
let neuron_positions_len =
neuron_positions_blob_size(dense.dtype, dense.layer_meta).map_err(map_protocol_error)?;
let (benchmark_area, remaining_meta) = metadata_scratch.split_at_mut(benchmark_used);
if remaining_meta.len() < neuron_positions_len {
return Err(RnnFlowError::CapacityTooSmall);
}
let used_neuron_positions =
neuron_positions_blob_into(dense.dtype, dense.layer_meta, remaining_meta)
.map_err(map_protocol_error)?;
let (neuron_positions_area, runtime_area) = remaining_meta.split_at_mut(used_neuron_positions);
let input_dim = u32::from_le_bytes([
dense.layer_meta[0],
dense.layer_meta[1],
dense.layer_meta[2],
dense.layer_meta[3],
]) as usize;
let runtime_input_len = match runtime_input {
Some(RuntimeInput::F32(values)) => {
encode_runtime_input_from_f32(dense.dtype, input_dim, values, runtime_area)?
}
Some(RuntimeInput::F64(values)) => {
encode_runtime_input_from_f64(dense.dtype, input_dim, values, runtime_area)?
}
None => encode_runtime_input_blob(dense.dtype, input_dim, runtime_area)?,
};
let benchmark_blob = &benchmark_area[..benchmark_used];
let neuron_positions_blob = &neuron_positions_area[..used_neuron_positions];
let runtime_input_blob = &runtime_area[..runtime_input_len];
let records = [
BlobDesc {
name: BLOB_NEURON_POSITIONS,
dtype: dense.dtype,
dims: [total_neurons as u32, 5],
ndim: 2,
payload: neuron_positions_blob,
},
BlobDesc {
name: BLOB_LAYER_META,
dtype: dense.dtype,
dims: [dense.layer_count as u32, 5],
ndim: 2,
payload: dense.layer_meta,
},
BlobDesc {
name: BLOB_WEIGHTS,
dtype: dense.dtype,
dims: [dense.weights_len as u32, 0],
ndim: 1,
payload: dense.weights,
},
BlobDesc {
name: BLOB_BIASES,
dtype: dense.dtype,
dims: [dense.biases_len as u32, 0],
ndim: 1,
payload: dense.biases,
},
BlobDesc {
name: BLOB_RUNTIME_INPUT,
dtype: dense.dtype,
dims: [input_dim as u32, 0],
ndim: 1,
payload: runtime_input_blob,
},
];
crate::model_config::rnn::validate_payload_records(&records)
.map_err(|_| RnnFlowError::Model)?;
let model_name_header = if model_name.is_empty() {
None
} else {
Some(model_name)
};
assemble_rnn_container(&records, model_name_header, benchmark_blob, out_bytes)
}
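/// Encodes the default runtime-input blob: a deterministic ramp
/// `1.0, 2.0, ..., input_dim as f32`, staged through an aligned
/// `TensorView` for layout validation. For f32 models the staged values are
/// compacted to the start of `out`; for f64 models the same ramp is emitted
/// widened to f64.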
fn encode_runtime_input_blob(
dtype: u8,
input_dim: usize,
out: &mut [u8],
) -> Result<usize, RnnFlowError> {
if input_dim == 0 {
return Err(RnnFlowError::CapacityTooSmall);
}
let f32_bytes = input_dim
.checked_mul(4)
.ok_or(RnnFlowError::CapacityTooSmall)?;
let align = core::mem::align_of::<f32>();
let addr = out.as_ptr() as usize;
let pad = (align - (addr % align)) % align;
if pad
.checked_add(f32_bytes)
.ok_or(RnnFlowError::CapacityTooSmall)?
> out.len()
{
return Err(RnnFlowError::CapacityTooSmall);
}
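    // SAFETY: `pad` aligns the pointer to `align_of::<f32>()`, and the bounds
    // check above guarantees `pad + f32_bytes <= out.len()`, so the raw slice
    // covers `input_dim` properly aligned, in-bounds f32 cells.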
let tensor_data: &mut [f32] = unsafe {
core::slice::from_raw_parts_mut(out.as_mut_ptr().add(pad) as *mut f32, input_dim)
};
let mut tensor = TensorView {
data: tensor_data,
shape: [1, 1, 1, 1, input_dim],
};
if !tensor.is_valid_layout() {
return Err(RnnFlowError::BadBytes);
}
for i in 0..input_dim {
let value = (i as f32) + 1.0;
if let Some(cell) = tensor.get_mut(0, 0, 0, 0, i) {
*cell = value;
} else {
return Err(RnnFlowError::BadBytes);
}
}
if dtype == 0 {
if pad != 0 {
out.copy_within(pad..pad + f32_bytes, 0);
}
Ok(f32_bytes)
    } else {
        let needed = input_dim
            .checked_mul(8)
            .ok_or(RnnFlowError::CapacityTooSmall)?;
        if needed > out.len() {
            return Err(RnnFlowError::CapacityTooSmall);
        }
        // The widened f64 output overlaps the staging f32 cells inside `out`,
        // so reading back through `tensor` here would observe values already
        // clobbered by earlier writes. Re-emit the same deterministic ramp
        // directly instead.
        for i in 0..input_dim {
            let value = ((i as f32) + 1.0) as f64;
            let start = i * 8;
            out[start..start + 8].copy_from_slice(&value.to_le_bytes());
        }
        Ok(needed)
    }
}
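/// Encodes caller-supplied f32 runtime input for an f32 (`dtype == 0`)
/// model, staging through an aligned `TensorView` and compacting the bytes
/// to the start of `out`. Rejects a mismatched dtype or value count.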
fn encode_runtime_input_from_f32(
dtype: u8,
input_dim: usize,
values: &[f32],
out: &mut [u8],
) -> Result<usize, RnnFlowError> {
if dtype != 0 || values.len() != input_dim {
return Err(RnnFlowError::BadBytes);
}
let needed = input_dim
.checked_mul(4)
.ok_or(RnnFlowError::CapacityTooSmall)?;
let align = core::mem::align_of::<f32>();
let addr = out.as_ptr() as usize;
let pad = (align - (addr % align)) % align;
if pad
.checked_add(needed)
.ok_or(RnnFlowError::CapacityTooSmall)?
> out.len()
{
return Err(RnnFlowError::CapacityTooSmall);
}
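    // SAFETY: as in `encode_runtime_input_blob`, `pad` aligns the pointer and
    // the bounds check above keeps the f32 cells in bounds.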
let tensor_data: &mut [f32] = unsafe {
core::slice::from_raw_parts_mut(out.as_mut_ptr().add(pad) as *mut f32, input_dim)
};
let mut tensor = TensorView {
data: tensor_data,
shape: [1, 1, 1, 1, input_dim],
};
if !tensor.is_valid_layout() {
return Err(RnnFlowError::BadBytes);
}
for (i, value) in values.iter().take(input_dim).enumerate() {
if let Some(cell) = tensor.get_mut(0, 0, 0, 0, i) {
*cell = *value;
} else {
return Err(RnnFlowError::BadBytes);
}
}
if pad != 0 {
out.copy_within(pad..pad + needed, 0);
}
Ok(needed)
}
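/// Encodes caller-supplied f64 runtime input for an f64 (`dtype == 1`)
/// model as little-endian bytes. Rejects a mismatched dtype or value count.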
fn encode_runtime_input_from_f64(
dtype: u8,
input_dim: usize,
values: &[f64],
out: &mut [u8],
) -> Result<usize, RnnFlowError> {
if dtype != 1 || values.len() != input_dim {
return Err(RnnFlowError::BadBytes);
}
let needed = input_dim
.checked_mul(8)
.ok_or(RnnFlowError::CapacityTooSmall)?;
if needed > out.len() {
return Err(RnnFlowError::CapacityTooSmall);
}
for i in 0..input_dim {
out[i * 8..i * 8 + 8].copy_from_slice(&values[i].to_le_bytes());
}
Ok(needed)
}
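/// Resolves a named blob's payload as a subslice of the original container
/// bytes; returns `None` if the blob is missing or its recorded range is
/// out of bounds.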
fn blob_payload_by_name<'a>(
bytes: &'a [u8],
handle: &crate::rnn_format::parser::RnnHandle<'_, '_>,
name: &str,
) -> Option<&'a [u8]> {
let idx = find_blob_index(handle, name)?;
let desc = handle.blobs.get(idx)?;
let start = usize::try_from(desc.offset).ok()?;
let len = usize::try_from(desc.length).ok()?;
let end = start.checked_add(len)?;
if end > bytes.len() {
return None;
}
Some(&bytes[start..end])
}
fn map_layer_error(err: LayerError) -> RnnFlowError {
match err {
LayerError::EmptyPlan | LayerError::InvalidShape | LayerError::IncompatibleChain => {
RnnFlowError::InvalidTopology
}
LayerError::BufferTooSmall => RnnFlowError::CapacityTooSmall,
_ => RnnFlowError::Model,
}
}
fn map_model_error(err: ModelFormatError) -> RnnFlowError {
match err {
ModelFormatError::CapacityTooSmall => RnnFlowError::CapacityTooSmall,
_ => RnnFlowError::BadBytes,
}
}
fn map_protocol_error(err: RnnProtocolError) -> RnnFlowError {
match err {
RnnProtocolError::CapacityTooSmall => RnnFlowError::CapacityTooSmall,
_ => RnnFlowError::BadBytes,
}
}
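/// Crate-internal forwarding to the model-config distribution validator,
/// preserving its detailed error type.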
pub(crate) fn validate_distribution_secure(
bytes: &[u8],
current_device_id: &[u8],
trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
) -> Result<(), crate::rnn_format::Error> {
crate::model_config::rnn::validate_distribution_secure(
bytes,
current_device_id,
trusted_publisher_pubkeys,
)
}
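/// Variant of [`validate_distribution_secure`] that additionally consults a
/// [`PublisherTrustService`] with the machine fingerprint and current time.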
#[cfg(feature = "publisher-trust-service")]
pub(crate) fn validate_distribution_with_service_secure<S: PublisherTrustService>(
bytes: &[u8],
current_device_id: &[u8],
trusted_publisher_pubkeys: &[[u8; ASYM_PUBLIC_KEY_LEN]],
machine_fingerprint: &[u8],
now_unix: u64,
service: &S,
) -> Result<(), crate::rnn_format::Error> {
crate::model_config::rnn::validate_distribution_with_service_secure(
bytes,
current_device_id,
trusted_publisher_pubkeys,
machine_fingerprint,
now_unix,
service,
)
}