use crate::crypto::{constant_time_eq, sha256_bytes};
/// Magic bytes at offset 0 of an RNN0 container (written by the encoder, checked by `header_tlv_payload`).
pub(crate) const RNN0_MAGIC: &[u8; 4] = b"RNN\x00";
/// RNN0 container version, stored as a little-endian `u16` at bytes 4..6.
pub(crate) const RNN0_VERSION: u16 = 1;
/// Magic bytes at offset 0 of an RMD1 payload (checked by `parse_payload`).
pub(crate) const RMD1_MAGIC: &[u8; 4] = b"RMD1";
/// RMD1 payload version, stored as a little-endian `u16` at bytes 4..6.
pub(crate) const RMD1_VERSION: u16 = 1;
/// Size in bytes of the fixed RMD1 header that precedes the layer metadata.
pub(crate) const RMD1_HEADER_SIZE: usize = 20;
/// Size in bytes of one layer-metadata record.
pub(crate) const LAYER_META_SIZE: usize = 20;
/// TLV type tag of the blob descriptor table in the RNN0 header.
pub(crate) const TLV_BLOB_TABLE: u8 = 0x03;
/// TLV type tag of the model-name header entry.
pub(crate) const TLV_HEADER_MODEL_NAME: u8 = 0x04;
/// TLV type tag of the network-summary header entry.
pub(crate) const TLV_HEADER_NETWORK_SUMMARY: u8 = 0x05;
/// TLV type tag of the benchmark header entry.
pub(crate) const TLV_HEADER_BENCHMARK: u8 = 0x06;
// Blob names. The first five are the only names the compact layout can encode
// (they map to blob ids 1..=5 in `encode_blob_records_impl`).
pub(crate) const BLOB_LAYER_META: &str = "layer_meta";
pub(crate) const BLOB_WEIGHTS: &str = "weights";
pub(crate) const BLOB_BIASES: &str = "biases";
pub(crate) const BLOB_NEURON_POSITIONS: &str = "neuron_positions";
pub(crate) const BLOB_RUNTIME_INPUT: &str = "runtime.input";
// Blob names that, in this file, only contribute presence flags to the network summary.
pub(crate) const TOKENIZER_VOCAB_BLOB_DATA: &str = "tokenizer.vocab";
pub(crate) const TOKENIZER_MERGES_BLOB_DATA: &str = "tokenizer.merges";
pub(crate) const GRAPH_OPS_BLOB_DATA: &str = "graph.ops";
pub(crate) const OPTIMIZER_STATE_BLOB_DATA: &str = "optimizer.state";
pub(crate) const TENSORS_SNAPSHOT_BLOB_DATA: &str = "tensors.snapshot";
/// Magic bytes at the start of the uncompressed network summary.
const SUMMARY_MAGIC: &[u8; 4] = b"S5D0";
/// Summary version; written both inside the plain summary and as the first byte of the encoded blob.
const SUMMARY_VERSION: u8 = 1;
/// Codec identifier written as the second byte of the encoded summary blob (run-length encoding).
const SUMMARY_CODEC_RLE: u8 = 1;
/// Length in bytes of the uncompressed network summary.
const SUMMARY_PLAIN_LEN: usize = 30;
/// Upper bound of the encoded summary: 2 prefix bytes plus the RLE worst case of 2 bytes per input byte.
const SUMMARY_MAX_COMPRESSED_LEN: usize = 2 + (SUMMARY_PLAIN_LEN * 2);
/// Length of the benchmark region (starting at offset 0x20) in the compact header layout.
const COMPACT_BENCHMARK_LEN: usize = 32;
/// Length of the network-summary region (starting at offset 0x40) in the compact header layout.
const COMPACT_NETWORK_SUMMARY_LEN: usize = 16;
// Bits of the summary flags word; each is set when a record with the matching blob name is present.
const SUMMARY_FLAG_HAS_LAYER_META: u32 = 1 << 0;
const SUMMARY_FLAG_HAS_WEIGHTS: u32 = 1 << 1;
const SUMMARY_FLAG_HAS_BIASES: u32 = 1 << 2;
const SUMMARY_FLAG_HAS_NEURON_POSITIONS: u32 = 1 << 3;
const SUMMARY_FLAG_HAS_TOKENIZER_VOCAB: u32 = 1 << 4;
const SUMMARY_FLAG_HAS_TOKENIZER_MERGES: u32 = 1 << 5;
const SUMMARY_FLAG_HAS_GRAPH_OPS: u32 = 1 << 6;
const SUMMARY_FLAG_HAS_OPTIMIZER_STATE: u32 = 1 << 7;
const SUMMARY_FLAG_HAS_TENSORS_SNAPSHOT: u32 = 1 << 8;
/// Errors returned by the RNN0/RMD1 parsing and encoding routines in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum RnnProtocolError {
/// The input is shorter than the fixed header or than the sizes the header declares.
Truncated,
/// The leading magic bytes do not match the expected format.
BadMagic,
/// The version field differs from the supported version.
BadVersion,
/// A header field or encoder argument is out of range, inconsistent, or overflows a size computation.
BadHeader,
/// The caller-provided output buffer is smaller than the encoded output.
CapacityTooSmall,
/// The payload is malformed, or the requested header entry is not present.
InvalidPayload,
}
/// Description of one blob to be written into an RNN0 container.
#[derive(Clone, Copy)]
pub(crate) struct BlobDesc<'a> {
/// Blob name; the encoders and summaries match it against the blob-name constants.
pub name: &'a str,
/// Data-type code, written verbatim into the blob descriptor.
pub dtype: u8,
/// Dimensions; only the first `ndim` entries are written by the TLV encoder.
pub dims: [u32; 2],
/// Number of used entries in `dims`; the encoders reject values other than 1 or 2.
pub ndim: u8,
/// Raw blob bytes (copied or replaced by zeros depending on the encoder entry point).
pub payload: &'a [u8],
}
/// Borrowed view of a parsed RMD1 payload, as produced by `parse_payload`.
#[derive(Clone, Copy)]
pub(crate) struct Payload<'a> {
/// Element type code from the header: 0 means 4-byte elements, 1 means 8-byte elements.
pub dtype: u8,
/// Number of layer-metadata records.
pub layer_count: usize,
/// Number of weight elements (not bytes).
pub weights_len: usize,
/// Number of bias elements (not bytes).
pub biases_len: usize,
/// Raw layer metadata: `layer_count * LAYER_META_SIZE` bytes.
pub layer_meta: &'a [u8],
/// Raw weight bytes: `weights_len` elements of the size implied by `dtype`.
pub weights: &'a [u8],
/// Raw bias bytes: `biases_len` elements of the size implied by `dtype`.
pub biases: &'a [u8],
}
/// Fixed-capacity buffer holding an encoded network summary and its used length.
#[derive(Clone, Copy)]
struct SummaryBlob {
// Backing storage sized for the worst-case encoded summary.
bytes: [u8; SUMMARY_MAX_COMPRESSED_LEN],
// Number of valid bytes at the start of `bytes`.
len: usize,
}
impl SummaryBlob {
/// Returns the valid prefix of the buffer.
fn as_slice(&self) -> &[u8] {
&self.bytes[..self.len]
}
/// Returns the number of valid bytes.
fn len(&self) -> usize {
self.len
}
}
/// Parses an RMD1 payload into borrowed views of its sections.
///
/// Layout: a 20-byte header (magic, `u16` version, dtype byte, one unused byte,
/// then `u32` layer count, weight count and bias count, all little-endian),
/// followed by the layer metadata, the weights and the biases.
///
/// # Errors
/// * `Truncated` if `bytes` is shorter than the header or than the declared sections.
/// * `BadMagic` / `BadVersion` if the magic or version do not match.
/// * `BadHeader` if the dtype is not 0 or 1, or a section size overflows `usize`.
pub(crate) fn parse_payload<'a>(bytes: &'a [u8]) -> Result<Payload<'a>, RnnProtocolError> {
    let header = bytes
        .get(..RMD1_HEADER_SIZE)
        .ok_or(RnnProtocolError::Truncated)?;
    if !constant_time_eq(&header[0..4], RMD1_MAGIC) {
        return Err(RnnProtocolError::BadMagic);
    }
    if u16::from_le_bytes([header[4], header[5]]) != RMD1_VERSION {
        return Err(RnnProtocolError::BadVersion);
    }

    let dtype = header[6];
    let elem_size = match dtype {
        0 => 4usize,
        1 => 8usize,
        _ => return Err(RnnProtocolError::BadHeader),
    };

    // Little-endian `u32` header field starting at `at`, widened to `usize`.
    let field = |at: usize| {
        u32::from_le_bytes([header[at], header[at + 1], header[at + 2], header[at + 3]]) as usize
    };
    let (layer_count, weights_len, biases_len) = (field(8), field(12), field(16));

    // End offset of a section of `count` items of `unit` bytes beginning at `start`.
    let section_end = |start: usize, count: usize, unit: usize| {
        count
            .checked_mul(unit)
            .and_then(|len| start.checked_add(len))
            .ok_or(RnnProtocolError::BadHeader)
    };
    let layers_end = section_end(RMD1_HEADER_SIZE, layer_count, LAYER_META_SIZE)?;
    let weights_end = section_end(layers_end, weights_len, elem_size)?;
    let biases_end = section_end(weights_end, biases_len, elem_size)?;
    if biases_end > bytes.len() {
        return Err(RnnProtocolError::Truncated);
    }

    Ok(Payload {
        dtype,
        layer_count,
        weights_len,
        biases_len,
        layer_meta: &bytes[RMD1_HEADER_SIZE..layers_end],
        weights: &bytes[layers_end..weights_end],
        biases: &bytes[weights_end..biases_end],
    })
}
/// Returns the value bytes of the header entry `wanted_type` from an RNN0 container.
///
/// Two header layouts are understood:
/// * Compact layout, recognised when the header is at least 0xA0 bytes and the
///   byte at offset 0x50 is `0xC1`. Entries live in fixed regions: model name
///   at 0x10..0x20 (returned up to the first zero byte), benchmark at 0x20,
///   network summary at 0x40 and the blob table at 0x50..0xA0.
/// * TLV layout: starting at offset 12, a sequence of entries made of a type
///   byte, a little-endian `u32` length and the value. Zero bytes where a type
///   byte is expected are skipped as padding.
///
/// # Errors
/// * `Truncated` if `bytes` is shorter than the 12-byte fixed header.
/// * `BadMagic` / `BadVersion` if the magic or version do not match.
/// * `BadHeader` if the declared header size is out of range or a TLV entry
///   runs past the header.
/// * `InvalidPayload` if no entry of `wanted_type` exists.
pub(crate) fn header_tlv_payload(bytes: &[u8], wanted_type: u8) -> Result<&[u8], RnnProtocolError> {
    if bytes.len() < 12 {
        return Err(RnnProtocolError::Truncated);
    }
    if !constant_time_eq(&bytes[0..4], RNN0_MAGIC) {
        return Err(RnnProtocolError::BadMagic);
    }
    let version = u16::from_le_bytes([bytes[4], bytes[5]]);
    if version != RNN0_VERSION {
        return Err(RnnProtocolError::BadVersion);
    }
    let header_size = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
    if header_size > bytes.len() || header_size < 12 {
        return Err(RnnProtocolError::BadHeader);
    }

    // Compact layout: fixed regions instead of TLV entries.
    if header_size >= 0xA0 && bytes.len() >= 0xA0 && bytes[0x50] == 0xC1 {
        return match wanted_type {
            TLV_HEADER_MODEL_NAME => {
                let region = &bytes[0x10..0x20];
                // The name is zero-padded; stop at the first zero byte.
                let end = region.iter().position(|&b| b == 0).unwrap_or(region.len());
                Ok(&region[..end])
            }
            TLV_HEADER_BENCHMARK => Ok(&bytes[0x20..0x20 + COMPACT_BENCHMARK_LEN]),
            TLV_HEADER_NETWORK_SUMMARY => Ok(&bytes[0x40..0x40 + COMPACT_NETWORK_SUMMARY_LEN]),
            TLV_BLOB_TABLE => Ok(&bytes[0x50..0xA0]),
            _ => Err(RnnProtocolError::InvalidPayload),
        };
    }

    // TLV layout: walk the entries until the wanted type is found.
    let mut cursor = 12usize;
    while cursor < header_size {
        let tlv_type = bytes[cursor];
        if tlv_type == 0 {
            // Padding byte.
            cursor += 1;
            continue;
        }
        // Type byte plus 4-byte length must fit inside the header.
        if cursor + 5 > header_size {
            return Err(RnnProtocolError::BadHeader);
        }
        let len_at = cursor + 1;
        let tlv_len = u32::from_le_bytes([
            bytes[len_at],
            bytes[len_at + 1],
            bytes[len_at + 2],
            bytes[len_at + 3],
        ]) as usize;
        let value_start = cursor + 5;
        let value_end = value_start
            .checked_add(tlv_len)
            .ok_or(RnnProtocolError::BadHeader)?;
        if value_end > header_size {
            return Err(RnnProtocolError::BadHeader);
        }
        if tlv_type == wanted_type {
            return Ok(&bytes[value_start..value_end]);
        }
        cursor = value_end;
    }
    Err(RnnProtocolError::InvalidPayload)
}
/// Sums the neuron count described by a layer-metadata table.
///
/// The total is the first record's `u32` at offset 0 plus every record's
/// `u32` at offset 4 (both little-endian).
///
/// # Errors
/// * `InvalidPayload` if `layer_meta` is empty or not a whole number of records.
/// * `BadHeader` if the sum overflows `usize`.
pub(crate) fn total_neurons_from_layer_meta(layer_meta: &[u8]) -> Result<usize, RnnProtocolError> {
    let byte_len = layer_meta.len();
    if byte_len < LAYER_META_SIZE || !byte_len.is_multiple_of(LAYER_META_SIZE) {
        return Err(RnnProtocolError::InvalidPayload);
    }
    let first_inputs =
        u32::from_le_bytes([layer_meta[0], layer_meta[1], layer_meta[2], layer_meta[3]]) as usize;
    layer_meta
        .chunks_exact(LAYER_META_SIZE)
        .try_fold(first_inputs, |sum, record| {
            let outputs =
                u32::from_le_bytes([record[4], record[5], record[6], record[7]]) as usize;
            sum.checked_add(outputs).ok_or(RnnProtocolError::BadHeader)
        })
}
/// Returns the byte size of the neuron-positions blob for the given layers.
///
/// Each neuron takes five elements of 4 bytes (`dtype == 0`) or 8 bytes
/// (`dtype == 1`).
///
/// # Errors
/// * `InvalidPayload` if `dtype` is not 0 or 1, or `layer_meta` is empty or
///   not a whole number of records.
/// * `BadHeader` if the neuron count or the byte size overflows `usize`.
pub(crate) fn neuron_positions_blob_size(
    dtype: u8,
    layer_meta: &[u8],
) -> Result<usize, RnnProtocolError> {
    let elem_size = match dtype {
        0 => 4usize,
        1 => 8usize,
        _ => return Err(RnnProtocolError::InvalidPayload),
    };
    let meta_len = layer_meta.len();
    if meta_len < LAYER_META_SIZE || !meta_len.is_multiple_of(LAYER_META_SIZE) {
        return Err(RnnProtocolError::InvalidPayload);
    }
    let neurons = total_neurons_from_layer_meta(layer_meta)?;
    // Five components per neuron.
    neurons
        .checked_mul(5 * elem_size)
        .ok_or(RnnProtocolError::BadHeader)
}
/// Writes the neuron-positions blob for the given layers into `out`.
///
/// Neurons are emitted layer by layer: first the input layer (size taken from
/// the first record's `u32` at offset 0), then one layer per record (size taken
/// from the record's `u32` at offset 4). Each neuron is written as five
/// little-endian floats (`f32` for `dtype == 0`, otherwise `f64`): a normalised
/// layer coordinate, a normalised neuron coordinate, and components 2..=4 of
/// the position returned by `crate::sphere5d` for a seed built from the layer
/// and neuron indices.
///
/// Returns the number of bytes written.
///
/// # Errors
/// * Any error from `neuron_positions_blob_size`.
/// * `CapacityTooSmall` if `out` is shorter than the blob.
pub(crate) fn neuron_positions_blob_into(
    dtype: u8,
    layer_meta: &[u8],
    out: &mut [u8],
) -> Result<usize, RnnProtocolError> {
    let required = neuron_positions_blob_size(dtype, layer_meta)?;
    if out.len() < required {
        return Err(RnnProtocolError::CapacityTooSmall);
    }

    let layer_count = layer_meta.len() / LAYER_META_SIZE;
    let first_inputs =
        u32::from_le_bytes([layer_meta[0], layer_meta[1], layer_meta[2], layer_meta[3]]) as usize;
    // Input layer followed by the output size of every metadata record.
    let layer_sizes = core::iter::once(first_inputs).chain(
        layer_meta
            .chunks_exact(LAYER_META_SIZE)
            .map(|record| u32::from_le_bytes([record[4], record[5], record[6], record[7]]) as usize),
    );

    let layer_denominator = (layer_count + 2) as f32;
    let mut written = 0usize;
    for (layer_idx, layer_size) in layer_sizes.enumerate() {
        let p0 = (layer_idx as f32 + 1.0) / layer_denominator;
        let neuron_denominator = (layer_size + 1) as f32;
        for neuron_idx in 0..layer_size {
            let p1 = (neuron_idx as f32 + 1.0) / neuron_denominator;
            let seed = ((layer_idx as u64) << 32) ^ (neuron_idx as u64);
            if dtype == 0 {
                let pos = crate::sphere5d::sphere_pos_from_seed_f32(seed, 1.0);
                for component in [p0, p1, pos[2], pos[3], pos[4]] {
                    let next = written + 4;
                    out[written..next].copy_from_slice(&component.to_le_bytes());
                    written = next;
                }
            } else {
                let pos = crate::sphere5d::sphere_pos_from_seed_f64(seed, 1.0);
                for component in [p0 as f64, p1 as f64, pos[2], pos[3], pos[4]] {
                    let next = written + 8;
                    out[written..next].copy_from_slice(&component.to_le_bytes());
                    written = next;
                }
            }
        }
    }
    Ok(written)
}
/// Returns the encoded size for `records` with no model-name or benchmark header.
///
/// Delegates to `encoded_size_for_blob_records_with_header_benchmark` with `None`.
#[allow(dead_code)]
pub(crate) fn encoded_size_for_blob_records(
records: &[BlobDesc<'_>],
) -> Result<usize, RnnProtocolError> {
encoded_size_for_blob_records_with_header_benchmark(records, None)
}
/// Returns the encoded size for `records` with an optional benchmark header and no model name.
///
/// Delegates to `encoded_size_for_blob_records_with_header_metadata`.
pub(crate) fn encoded_size_for_blob_records_with_header_benchmark(
records: &[BlobDesc<'_>],
benchmark_header: Option<&[u8]>,
) -> Result<usize, RnnProtocolError> {
encoded_size_for_blob_records_with_header_metadata(records, None, benchmark_header)
}
/// Returns the number of bytes `encode_blob_records_impl` produces for the
/// given records and optional header entries.
///
/// The encoder chooses between two layouts, and this function mirrors that
/// choice so a buffer of the returned size is always large enough:
/// * Compact layout (at most 5 records, each with 1 or 2 dimensions that fit
///   in `u16`, and payloads that fit in `u32`): a fixed 0xA0-byte header plus
///   the payloads, rounded up to a multiple of 16. The optional headers do not
///   change the size in this layout.
/// * TLV layout (everything else): the 12-byte fixed header, 5 bytes plus the
///   summary length reserved for the summary entry, 5 bytes plus the value
///   length for each optional header present, 5 bytes plus the descriptor
///   table, then the payloads.
///
/// # Errors
/// `BadHeader` if a record is invalid for the TLV layout (see
/// `descriptor_len`) or a size computation overflows `usize`.
pub(crate) fn encoded_size_for_blob_records_with_header_metadata(
    records: &[BlobDesc<'_>],
    model_name_header: Option<&str>,
    benchmark_header: Option<&[u8]>,
) -> Result<usize, RnnProtocolError> {
    let payload_len = payload_total_len(records)?;

    // Same eligibility rule the encoder applies before picking the compact layout.
    let compact_eligible = records.len() <= 5
        && records.iter().all(|rec| {
            rec.ndim > 0
                && rec.ndim <= 2
                && rec.dims[0] <= u16::MAX as u32
                && rec.dims[1] <= u16::MAX as u32
                && rec.payload.len() <= u32::MAX as usize
        })
        && payload_len <= u32::MAX as usize;
    if compact_eligible {
        // Fixed 0xA0-byte header, payloads, then padding up to a 16-byte boundary.
        return 0xA0usize
            .checked_add(payload_len)
            .and_then(|v| v.checked_add(15))
            .map(|v| v & !15usize)
            .ok_or(RnnProtocolError::BadHeader);
    }

    let desc_len = descriptor_len(records)?;
    let summary_len = header_network_summary_blob(records).map_or(0usize, |b| b.len());
    // An optional TLV entry costs 5 bytes of type/length plus its value.
    let optional_entry = |value_len: Option<usize>| -> Option<usize> {
        match value_len {
            Some(len) => len.checked_add(5),
            None => Some(0),
        }
    };
    let header_size = 12usize
        // The encoder reserves the summary entry's 5 bytes even when no summary is written.
        .checked_add(5)
        .and_then(|v| v.checked_add(summary_len))
        .and_then(|v| v.checked_add(optional_entry(model_name_header.map(str::len))?))
        .and_then(|v| v.checked_add(optional_entry(benchmark_header.map(<[u8]>::len))?))
        .and_then(|v| v.checked_add(5))
        .and_then(|v| v.checked_add(desc_len))
        .ok_or(RnnProtocolError::BadHeader)?;
    header_size
        .checked_add(payload_len)
        .ok_or(RnnProtocolError::BadHeader)
}
/// Encodes `records` into `out` without header metadata, zero-filling the payload regions.
///
/// Delegates to `encode_blob_template_with_header_benchmark` with `None`.
#[allow(dead_code)]
pub(crate) fn encode_blob_template(
records: &[BlobDesc<'_>],
out: &mut [u8],
) -> Result<usize, RnnProtocolError> {
encode_blob_template_with_header_benchmark(records, None, out)
}
/// Encodes `records` into `out` with an optional benchmark header and no model
/// name, zero-filling the payload regions.
///
/// Delegates to `encode_blob_template_with_header_metadata`.
pub(crate) fn encode_blob_template_with_header_benchmark(
records: &[BlobDesc<'_>],
benchmark_header: Option<&[u8]>,
out: &mut [u8],
) -> Result<usize, RnnProtocolError> {
encode_blob_template_with_header_metadata(records, None, benchmark_header, out)
}
/// Encodes the container header for `records` into `out` and zero-fills the
/// payload regions instead of copying the record payloads.
///
/// Returns the total number of bytes produced; errors are those of
/// `encode_blob_records_impl`.
pub(crate) fn encode_blob_template_with_header_metadata(
records: &[BlobDesc<'_>],
model_name_header: Option<&str>,
benchmark_header: Option<&[u8]>,
out: &mut [u8],
) -> Result<usize, RnnProtocolError> {
// `false`: do not copy payload bytes.
encode_blob_records_impl(records, model_name_header, benchmark_header, out, false)
}
/// Encodes the container header for `records` into `out` and copies every
/// record payload into its payload region.
///
/// Returns the total number of bytes produced; errors are those of
/// `encode_blob_records_impl`.
pub(crate) fn encode_blob_payloads_with_header_metadata(
records: &[BlobDesc<'_>],
model_name_header: Option<&str>,
benchmark_header: Option<&[u8]>,
out: &mut [u8],
) -> Result<usize, RnnProtocolError> {
// `true`: copy payload bytes.
encode_blob_records_impl(records, model_name_header, benchmark_header, out, true)
}
/// Encodes `records` into `out` as an RNN0 container and returns the number of
/// bytes produced.
///
/// When `copy_payloads` is `true` each record's payload is copied into its
/// payload region; otherwise the region is zero-filled.
///
/// Two layouts are produced:
///
/// * **Compact** — used when there are at most 5 records, each with 1 or 2
///   dimensions that fit in `u16`, and all payload sizes fit in `u32`. The
///   header is a fixed 0xA0 bytes: magic, version, a zero `u16` and the header
///   size at 0x00..0x0C; the model name (at most 11 bytes) at 0x10; the
///   benchmark (at most 32 bytes, `b"BMK\x01"` when absent) at 0x20; a 16-byte
///   summary at 0x40; and at 0x50 a table made of the marker `0xC1`, the byte
///   `1`, the record count and one 15-byte entry per record (blob id, dtype,
///   ndim, two `u16` dims, `u32` payload offset, `u32` payload length). Unused
///   table bytes are filled with a fixed byte pattern, and when at least 12 of
///   them remain the first 12 hold the total payload length, the header size
///   and the payload end offset. The output is padded with a fixed byte pattern
///   to a multiple of 16 bytes.
/// * **TLV** — used otherwise. After the 12-byte fixed header come the optional
///   model-name, benchmark and network-summary entries and the blob descriptor
///   table, each as a type byte, `u32` length and value. Each descriptor holds
///   the name (with `u16` length), dtype, ndim, the used dims, the `u64`
///   payload offset and length, and the SHA-256 digest of the payload.
///
/// # Errors
/// * `CapacityTooSmall` if `out` is shorter than the encoded container.
/// * `BadHeader` if a size computation overflows, a record is invalid for the
///   chosen layout (in the compact layout this includes blob names other than
///   the five known ones, a model name longer than 11 bytes or a benchmark
///   longer than 32 bytes), or the TLV header size does not fit in `u32`.
fn encode_blob_records_impl(
    records: &[BlobDesc<'_>],
    model_name_header: Option<&str>,
    benchmark_header: Option<&[u8]>,
    out: &mut [u8],
    copy_payloads: bool,
) -> Result<usize, RnnProtocolError> {
    /// Writes a TLV type byte and little-endian `u32` length at `at` and
    /// returns the offset where the value starts.
    fn put_tlv_header(out: &mut [u8], at: usize, tlv_type: u8, len: u32) -> usize {
        out[at] = tlv_type;
        out[at + 1..at + 5].copy_from_slice(&len.to_le_bytes());
        at + 5
    }

    let payload_len = payload_total_len(records)?;

    let compact_eligible = records.len() <= 5
        && records.iter().all(|rec| {
            rec.ndim > 0
                && rec.ndim <= 2
                && rec.dims[0] <= u16::MAX as u32
                && rec.dims[1] <= u16::MAX as u32
                && rec.payload.len() <= u32::MAX as usize
        });

    if compact_eligible && payload_len <= u32::MAX as usize {
        let header_size = 0xA0usize;
        let raw_total_size = header_size
            .checked_add(payload_len)
            .ok_or(RnnProtocolError::BadHeader)?;
        // Round the container up to a multiple of 16 bytes.
        let total_size = raw_total_size
            .checked_add(15)
            .map(|v| v & !15usize)
            .ok_or(RnnProtocolError::BadHeader)?;
        if out.len() < total_size {
            return Err(RnnProtocolError::CapacityTooSmall);
        }

        // Validate the fixed-region inputs before touching `out`.
        let model_name = model_name_header.unwrap_or("");
        if model_name.len() > 11 {
            return Err(RnnProtocolError::BadHeader);
        }
        let benchmark = benchmark_header.unwrap_or(b"BMK\x01");
        if benchmark.len() > 32 {
            return Err(RnnProtocolError::BadHeader);
        }

        // Fixed header; everything not written below stays zero.
        out[..header_size].fill(0);
        out[0..4].copy_from_slice(RNN0_MAGIC);
        out[4..6].copy_from_slice(&RNN0_VERSION.to_le_bytes());
        out[6..8].copy_from_slice(&0u16.to_le_bytes());
        out[8..12].copy_from_slice(&(header_size as u32).to_le_bytes());
        out[0x10..0x10 + model_name.len()].copy_from_slice(model_name.as_bytes());
        out[0x20..0x20 + benchmark.len()].copy_from_slice(benchmark);
        let visual = compact_network_summary_blob(records);
        out[0x40..0x40 + visual.len()].copy_from_slice(&visual);

        // Blob table: marker, constant 1, record count, then the entries.
        let table_start = 0x50usize;
        out[table_start] = 0xC1;
        out[table_start + 1] = 1;
        out[table_start + 2] = records.len() as u8; // at most 5 in this layout

        let entry_size = 15usize;
        let mut entry_cursor = 3usize;
        let mut payload_cursor = header_size;
        for rec in records {
            let blob_id = match rec.name {
                BLOB_NEURON_POSITIONS => 1u8,
                BLOB_LAYER_META => 2u8,
                BLOB_WEIGHTS => 3u8,
                BLOB_BIASES => 4u8,
                BLOB_RUNTIME_INPUT => 5u8,
                _ => return Err(RnnProtocolError::BadHeader),
            };
            let entry_start = table_start + entry_cursor;
            if entry_start + entry_size > 0xA0 {
                return Err(RnnProtocolError::BadHeader);
            }
            if rec.ndim == 0 || rec.ndim > 2 {
                return Err(RnnProtocolError::BadHeader);
            }
            if rec.dims[0] > u16::MAX as u32 || rec.dims[1] > u16::MAX as u32 {
                return Err(RnnProtocolError::BadHeader);
            }
            if payload_cursor > u32::MAX as usize || rec.payload.len() > u32::MAX as usize {
                return Err(RnnProtocolError::BadHeader);
            }

            // The narrowing casts below are covered by the range checks above.
            let entry = &mut out[entry_start..entry_start + entry_size];
            entry[0] = blob_id;
            entry[1] = rec.dtype;
            entry[2] = rec.ndim;
            entry[3..5].copy_from_slice(&(rec.dims[0] as u16).to_le_bytes());
            entry[5..7].copy_from_slice(&(rec.dims[1] as u16).to_le_bytes());
            entry[7..11].copy_from_slice(&(payload_cursor as u32).to_le_bytes());
            entry[11..15].copy_from_slice(&(rec.payload.len() as u32).to_le_bytes());
            entry_cursor += entry_size;

            let next_payload = payload_cursor
                .checked_add(rec.payload.len())
                .ok_or(RnnProtocolError::BadHeader)?;
            if copy_payloads {
                out[payload_cursor..next_payload].copy_from_slice(rec.payload);
            } else {
                out[payload_cursor..next_payload].fill(0);
            }
            payload_cursor = next_payload;
        }

        let table_used = records
            .len()
            .checked_mul(entry_size)
            .and_then(|entries| entries.checked_add(3))
            .ok_or(RnnProtocolError::BadHeader)?;
        if table_used > 0x50 {
            return Err(RnnProtocolError::BadHeader);
        }

        // Fill the unused tail of the table with a fixed byte pattern.
        let table_body_start = table_start + table_used;
        let table_body_end = 0xA0usize;
        for (idx, slot) in out[table_body_start..table_body_end].iter_mut().enumerate() {
            *slot = (idx as u8).wrapping_mul(0x3D).wrapping_add(0xA7);
        }
        // If there is room, the first 12 tail bytes carry a small trailer.
        // NOTE(review): `payload_cursor as u32` truncates when the payload end
        // offset exceeds `u32::MAX`; kept as in the original behaviour.
        if table_body_start + 12 <= 0xA0 {
            out[table_body_start..table_body_start + 4]
                .copy_from_slice(&(payload_len as u32).to_le_bytes());
            out[table_body_start + 4..table_body_start + 8]
                .copy_from_slice(&(header_size as u32).to_le_bytes());
            out[table_body_start + 8..table_body_start + 12]
                .copy_from_slice(&(payload_cursor as u32).to_le_bytes());
        }

        // Pad from the end of the payloads to the 16-byte boundary.
        for (i, b) in out[payload_cursor..total_size].iter_mut().enumerate() {
            *b = (i as u8).wrapping_mul(0x5B).wrapping_add(0x31);
        }
        return Ok(total_size);
    }

    // ---- TLV layout ----
    let desc_len = descriptor_len(records)?;
    let header_summary = header_network_summary_blob(records);
    let header_summary_len = header_summary.as_ref().map_or(0usize, |b| b.len());
    let header_size = 12usize
        // The summary entry's 5 header bytes are reserved even without a summary.
        .checked_add(5)
        .and_then(|v| v.checked_add(header_summary_len))
        .and_then(|v| {
            model_name_header
                .map(|name| v.checked_add(5).and_then(|n| n.checked_add(name.len())))
                .unwrap_or(Some(v))
        })
        .and_then(|v| {
            benchmark_header
                .map(|b| v.checked_add(5).and_then(|n| n.checked_add(b.len())))
                .unwrap_or(Some(v))
        })
        .and_then(|v| v.checked_add(5))
        .and_then(|v| v.checked_add(desc_len))
        .ok_or(RnnProtocolError::BadHeader)?;
    // The header size is stored as `u32`. Every TLV length written below is
    // smaller than `header_size`, so this check also makes those casts lossless.
    let header_size_field =
        u32::try_from(header_size).map_err(|_| RnnProtocolError::BadHeader)?;
    let total_size = header_size
        .checked_add(payload_len)
        .ok_or(RnnProtocolError::BadHeader)?;
    if out.len() < total_size {
        return Err(RnnProtocolError::CapacityTooSmall);
    }

    out[0..4].copy_from_slice(RNN0_MAGIC);
    out[4..6].copy_from_slice(&RNN0_VERSION.to_le_bytes());
    out[6..8].copy_from_slice(&0u16.to_le_bytes());
    out[8..12].copy_from_slice(&header_size_field.to_le_bytes());

    let mut cursor = 12usize;
    if let Some(model_name) = model_name_header {
        let name_bytes = model_name.as_bytes();
        cursor = put_tlv_header(out, cursor, TLV_HEADER_MODEL_NAME, name_bytes.len() as u32);
        out[cursor..cursor + name_bytes.len()].copy_from_slice(name_bytes);
        cursor += name_bytes.len();
    }
    if let Some(benchmark) = benchmark_header {
        cursor = put_tlv_header(out, cursor, TLV_HEADER_BENCHMARK, benchmark.len() as u32);
        out[cursor..cursor + benchmark.len()].copy_from_slice(benchmark);
        cursor += benchmark.len();
    }
    if let Some(summary) = header_summary.as_ref() {
        cursor = put_tlv_header(out, cursor, TLV_HEADER_NETWORK_SUMMARY, summary.len() as u32);
        out[cursor..cursor + summary.len()].copy_from_slice(summary.as_slice());
        cursor += summary.len();
    }
    cursor = put_tlv_header(out, cursor, TLV_BLOB_TABLE, desc_len as u32);

    let mut payload_cursor = header_size;
    for rec in records {
        let name_bytes = rec.name.as_bytes();
        if name_bytes.len() > u16::MAX as usize {
            return Err(RnnProtocolError::BadHeader);
        }
        out[cursor..cursor + 2].copy_from_slice(&(name_bytes.len() as u16).to_le_bytes());
        cursor += 2;
        out[cursor..cursor + name_bytes.len()].copy_from_slice(name_bytes);
        cursor += name_bytes.len();
        out[cursor] = rec.dtype;
        out[cursor + 1] = rec.ndim;
        cursor += 2;
        for dim in rec.dims.iter().take(rec.ndim as usize) {
            out[cursor..cursor + 4].copy_from_slice(&dim.to_le_bytes());
            cursor += 4;
        }
        out[cursor..cursor + 8].copy_from_slice(&(payload_cursor as u64).to_le_bytes());
        cursor += 8;
        out[cursor..cursor + 8].copy_from_slice(&(rec.payload.len() as u64).to_le_bytes());
        cursor += 8;
        let mut digest = [0u8; 32];
        sha256_bytes(rec.payload, &mut digest);
        out[cursor..cursor + 32].copy_from_slice(&digest);
        cursor += 32;

        let next_payload = payload_cursor
            .checked_add(rec.payload.len())
            .ok_or(RnnProtocolError::BadHeader)?;
        if copy_payloads {
            out[payload_cursor..next_payload].copy_from_slice(rec.payload);
        } else {
            out[payload_cursor..next_payload].fill(0);
        }
        payload_cursor = next_payload;
    }

    // When no summary entry was written, the 5 bytes reserved for its header
    // remain at the end of the header. Zero them so the header never contains
    // stale caller data; `header_tlv_payload` skips zero bytes as padding.
    out[cursor..header_size].fill(0);

    Ok(total_size)
}
/// Builds the encoded network-summary header value for `records`.
///
/// The plain summary is 30 bytes: magic, version, dtype (2 when no known
/// record set it), then six little-endian `u32` words — layer count, total
/// neurons, weight count, bias count, record count and presence flags. The
/// counts come from `dims[0]` of the matching records. The result is the
/// version byte, the RLE codec byte and the run-length-encoded plain summary.
///
/// Returns `None` when no `neuron_positions` record is present, or if the
/// run-length encoder runs out of space.
fn header_network_summary_blob(records: &[BlobDesc<'_>]) -> Option<SummaryBlob> {
    let mut flags = 0u32;
    let mut dtype = 2u8;
    let mut layer_count = 0u32;
    let mut total_neurons = 0u32;
    let mut weights_len = 0u32;
    let mut biases_len = 0u32;

    for rec in records {
        // Each known name yields a flag; the first four also record dtype and dims[0].
        let (flag, counter): (u32, Option<&mut u32>) = match rec.name {
            BLOB_LAYER_META => (SUMMARY_FLAG_HAS_LAYER_META, Some(&mut layer_count)),
            BLOB_WEIGHTS => (SUMMARY_FLAG_HAS_WEIGHTS, Some(&mut weights_len)),
            BLOB_BIASES => (SUMMARY_FLAG_HAS_BIASES, Some(&mut biases_len)),
            BLOB_NEURON_POSITIONS => {
                (SUMMARY_FLAG_HAS_NEURON_POSITIONS, Some(&mut total_neurons))
            }
            TOKENIZER_VOCAB_BLOB_DATA => (SUMMARY_FLAG_HAS_TOKENIZER_VOCAB, None),
            TOKENIZER_MERGES_BLOB_DATA => (SUMMARY_FLAG_HAS_TOKENIZER_MERGES, None),
            GRAPH_OPS_BLOB_DATA => (SUMMARY_FLAG_HAS_GRAPH_OPS, None),
            OPTIMIZER_STATE_BLOB_DATA => (SUMMARY_FLAG_HAS_OPTIMIZER_STATE, None),
            TENSORS_SNAPSHOT_BLOB_DATA => (SUMMARY_FLAG_HAS_TENSORS_SNAPSHOT, None),
            _ => continue,
        };
        flags |= flag;
        if let Some(counter) = counter {
            dtype = rec.dtype;
            *counter = rec.dims[0];
        }
    }

    if flags & SUMMARY_FLAG_HAS_NEURON_POSITIONS == 0 {
        return None;
    }

    let mut plain = [0u8; SUMMARY_PLAIN_LEN];
    plain[..4].copy_from_slice(SUMMARY_MAGIC);
    plain[4] = SUMMARY_VERSION;
    plain[5] = dtype;
    let words = [
        layer_count,
        total_neurons,
        weights_len,
        biases_len,
        records.len() as u32,
        flags,
    ];
    for (slot, word) in plain[6..].chunks_exact_mut(4).zip(words) {
        slot.copy_from_slice(&word.to_le_bytes());
    }

    let mut bytes = [0u8; SUMMARY_MAX_COMPRESSED_LEN];
    let (prefix, body) = bytes.split_at_mut(2);
    prefix.copy_from_slice(&[SUMMARY_VERSION, SUMMARY_CODEC_RLE]);
    let used = rle_encode_bytes(&plain, body)?;
    Some(SummaryBlob {
        bytes,
        len: 2 + used,
    })
}
/// Builds the 16-byte network summary stored in the compact header layout.
///
/// Layout: `b"VIZ\x00"`, dtype (2 when no known record set it), `u16` layer
/// count, `u32` total neurons, `u16` weight count, `u16` bias count and a `u8`
/// record count. Counts come from `dims[0]` of the matching records; values
/// that do not fit their field saturate at the field's maximum.
fn compact_network_summary_blob(records: &[BlobDesc<'_>]) -> [u8; COMPACT_NETWORK_SUMMARY_LEN] {
    let saturate_u16 = |value: u32| value.min(u32::from(u16::MAX)) as u16;

    let mut dtype = 2u8;
    let mut layer_count = 0u16;
    let mut total_neurons = 0u32;
    let mut weights_len = 0u16;
    let mut biases_len = 0u16;
    for rec in records {
        let first_dim = rec.dims[0];
        match rec.name {
            BLOB_LAYER_META => layer_count = saturate_u16(first_dim),
            BLOB_WEIGHTS => weights_len = saturate_u16(first_dim),
            BLOB_BIASES => biases_len = saturate_u16(first_dim),
            BLOB_NEURON_POSITIONS => total_neurons = first_dim,
            _ => continue,
        }
        // Only reached for the four known names above.
        dtype = rec.dtype;
    }

    let mut summary = [0u8; COMPACT_NETWORK_SUMMARY_LEN];
    summary[..4].copy_from_slice(b"VIZ\x00");
    summary[4] = dtype;
    summary[5..7].copy_from_slice(&layer_count.to_le_bytes());
    summary[7..11].copy_from_slice(&total_neurons.to_le_bytes());
    summary[11..13].copy_from_slice(&weights_len.to_le_bytes());
    summary[13..15].copy_from_slice(&biases_len.to_le_bytes());
    summary[15] = records.len().min(usize::from(u8::MAX)) as u8;
    summary
}
/// Returns the total byte length of the TLV blob descriptors for `records`.
///
/// Each descriptor takes a `u16` name length, the name, a dtype byte, an ndim
/// byte, 4 bytes per used dimension, a `u64` offset, a `u64` length and a
/// 32-byte digest.
///
/// # Errors
/// `BadHeader` if a name is longer than `u16::MAX`, `ndim` is not 1 or 2, or
/// the total overflows `usize`.
fn descriptor_len(records: &[BlobDesc<'_>]) -> Result<usize, RnnProtocolError> {
    records.iter().try_fold(0usize, |running, rec| {
        let name_len = rec.name.len();
        if name_len > u16::MAX as usize {
            return Err(RnnProtocolError::BadHeader);
        }
        if !(1..=2).contains(&rec.ndim) {
            return Err(RnnProtocolError::BadHeader);
        }
        // name length prefix, name, dtype + ndim, dims, offset + length + digest
        let parts = [2usize, name_len, 1 + 1, usize::from(rec.ndim) * 4, 8 + 8 + 32];
        parts
            .iter()
            .try_fold(running, |sum, &part| sum.checked_add(part))
            .ok_or(RnnProtocolError::BadHeader)
    })
}
/// Returns the combined byte length of all record payloads.
///
/// # Errors
/// `BadHeader` if the sum overflows `usize`.
fn payload_total_len(records: &[BlobDesc<'_>]) -> Result<usize, RnnProtocolError> {
    records
        .iter()
        .try_fold(0usize, |sum, rec| sum.checked_add(rec.payload.len()))
        .ok_or(RnnProtocolError::BadHeader)
}
/// Run-length encodes `input` into `out` as `(run length, value)` byte pairs.
///
/// Runs are capped at 255 bytes; longer runs are split into several pairs.
/// Returns the number of bytes written, or `None` if `out` is too small for
/// the next pair (pairs written before that point remain in `out`).
fn rle_encode_bytes(input: &[u8], out: &mut [u8]) -> Option<usize> {
    let mut written = 0usize;
    let mut rest = input;
    while let Some(&value) = rest.first() {
        // Length of the leading run of `value`, capped at what one byte can hold.
        let run = rest
            .iter()
            .take(255)
            .take_while(|&&byte| byte == value)
            .count();
        let pair = out.get_mut(written..written + 2)?;
        pair[0] = run as u8;
        pair[1] = value;
        written += 2;
        rest = &rest[run..];
    }
    Some(written)
}