use crate::activations::ActivationKind;
use crate::engine::{forward_dense_plan, ForwardError};
use crate::initializers::expected_parameter_counts;
use crate::layers::{build_dense_specs_from_layers, LayerError, LayerPlan, LayerSpec};
use crate::model_format::{
decode_dense_model_v1,
encode_dense_model_v1,
encoded_size_v1,
DecodedCounts,
ModelFormatError,
};
/// File-format magic for v1 dense model blobs ("RMD1").
const MAGIC: &[u8; 4] = b"RMD1";
/// Serialization version accepted by the v1 encode/decode paths below.
const VERSION: u16 = 1;
/// Byte length of the v1 header: magic (4) + version (2) + 2-byte field at
/// offset 6 (not read in this module — presumably reserved or activation
/// info; confirm against `model_format`) + layer count (4) + weight count (4)
/// + bias count (4). All multi-byte fields are little-endian.
const HEADER_SIZE: usize = 4 + 2 + 2 + 4 + 4 + 4;
/// Error codes surfaced by the public dense-RNN API in this module.
///
/// Lower-level errors from the layer builder, model codec, and forward
/// engine are collapsed onto these variants by the `map_*_error` helpers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RnnApiError {
    /// Topology is too short (< 2 sizes) or its parameter counts are invalid.
    InvalidTopology,
    /// A caller-provided scratch/output buffer is smaller than required.
    CapacityTooSmall,
    /// Input bytes are not a valid v1 model blob (magic/version/size/decode).
    BadBytes,
    /// Uncategorized layer-construction failure.
    Layer,
    /// Uncategorized model-format failure.
    Model,
    /// Uncategorized forward-pass failure.
    Forward,
}
/// Reports the dense parameter counts implied by `topology` (one dense layer
/// per consecutive pair of sizes).
///
/// # Errors
///
/// Returns [`RnnApiError::InvalidTopology`] when `topology` has fewer than
/// two entries or when the parameter counts cannot be derived from it.
pub fn rnn_required_dense_from_topology(topology: &[usize]) -> Result<DecodedCounts, RnnApiError> {
    // A valid topology has at least an input and an output size, i.e. at
    // least one layer; `checked_sub` handles the empty slice, `filter` the
    // single-entry one.
    let layers = topology
        .len()
        .checked_sub(1)
        .filter(|&n| n > 0)
        .ok_or(RnnApiError::InvalidTopology)?;
    let (weights, biases) =
        expected_parameter_counts(topology).ok_or(RnnApiError::InvalidTopology)?;
    Ok(DecodedCounts { layers, weights, biases })
}
/// Reads the v1 header out of `bytes` and reports the scratch sizes
/// (layer/weight/bias counts) a caller must provide to decode the model.
///
/// Validates the magic, the version, and that the buffer is at least as long
/// as the encoded size implied by the header counts. Trailing bytes beyond
/// that size are tolerated (`<` check, not `==`).
///
/// # Errors
///
/// Returns [`RnnApiError::BadBytes`] on any header/size validation failure.
pub fn rnn_required_dense_from_bytes_v1(bytes: &[u8]) -> Result<DecodedCounts, RnnApiError> {
    // Little-endian u32 at byte offset `at`; only called after the header
    // length check below, so the indexing cannot go out of bounds.
    let read_u32 = |at: usize| -> usize {
        u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]]) as usize
    };

    // Header sanity: long enough, right magic, right version. The 2-byte
    // field at offsets 6..8 is intentionally skipped here.
    let header_ok = bytes.len() >= HEADER_SIZE
        && &bytes[0..4] == MAGIC
        && u16::from_le_bytes([bytes[4], bytes[5]]) == VERSION;
    if !header_ok {
        return Err(RnnApiError::BadBytes);
    }

    let layers = read_u32(8);
    let weights = read_u32(12);
    let biases = read_u32(16);

    // The counts must describe an encodable model, and the buffer must cover
    // the whole encoded payload.
    let needed = encoded_size_v1(layers, weights, biases).ok_or(RnnApiError::BadBytes)?;
    if bytes.len() < needed {
        return Err(RnnApiError::BadBytes);
    }

    Ok(DecodedCounts { layers, weights, biases })
}
/// Serializes a dense model (topology, activations, parameters) into
/// `out_bytes` using the v1 format.
///
/// `layer_specs_scratch` is caller-provided working space for the per-layer
/// descriptors; if it (or `out_bytes`) is too small, the underlying errors
/// surface as [`RnnApiError::CapacityTooSmall`] via the mapping helpers.
///
/// # Errors
///
/// Layer-construction failures map through [`map_layer_error`], encoding
/// failures through [`map_model_error`].
pub fn rnn_pack_dense_v1(
    topology: &[usize],
    hidden_activation: ActivationKind,
    output_activation: ActivationKind,
    weights: &[f32],
    biases: &[f32],
    layer_specs_scratch: &mut [LayerSpec],
    out_bytes: &mut [u8],
) -> Result<usize, RnnApiError> {
    // Fill the scratch buffer with one spec per layer; `used` is how many
    // entries were actually written.
    let used = build_dense_specs_from_layers(
        topology,
        hidden_activation,
        output_activation,
        weights.len(),
        biases.len(),
        layer_specs_scratch,
    )
    .map_err(map_layer_error)?;

    let specs = &layer_specs_scratch[..used];
    encode_dense_model_v1(specs, weights, biases, out_bytes).map_err(map_model_error)
}
/// Decodes a v1 model blob into caller-provided scratch buffers and runs a
/// forward pass over `input`, writing the result into `output`.
///
/// The three scratch slices must each be large enough for the counts the
/// blob's header declares (see [`rnn_required_dense_from_bytes_v1`]);
/// `infer_scratch` is handed straight to the forward engine.
///
/// # Errors
///
/// [`RnnApiError::BadBytes`] for malformed blobs,
/// [`RnnApiError::CapacityTooSmall`] for undersized buffers, and mapped
/// decode/forward errors otherwise.
pub fn rnn_run_dense_v1(
    bytes: &[u8],
    input: &[f32],
    output: &mut [f32],
    layer_specs_scratch: &mut [LayerSpec],
    weights_scratch: &mut [f32],
    biases_scratch: &mut [f32],
    infer_scratch: &mut [f32],
) -> Result<(), RnnApiError> {
    let counts = rnn_required_dense_from_bytes_v1(bytes)?;

    // Every scratch buffer must hold what the header says we will decode.
    let fits = counts.layers <= layer_specs_scratch.len()
        && counts.weights <= weights_scratch.len()
        && counts.biases <= biases_scratch.len();
    if !fits {
        return Err(RnnApiError::CapacityTooSmall);
    }

    let decoded = decode_dense_model_v1(
        bytes,
        &mut layer_specs_scratch[..counts.layers],
        &mut weights_scratch[..counts.weights],
        &mut biases_scratch[..counts.biases],
    )
    .map_err(map_model_error)?;

    // Re-borrow immutably at the decoded extents to assemble the plan.
    let plan = LayerPlan {
        layers: &layer_specs_scratch[..decoded.layers],
        weights: &weights_scratch[..decoded.weights],
        biases: &biases_scratch[..decoded.biases],
    };
    forward_dense_plan(&plan, input, output, infer_scratch).map_err(map_forward_error)
}
/// Collapses layer-construction failures onto the public API error codes.
fn map_layer_error(err: LayerError) -> RnnApiError {
    match err {
        // Shape/chain problems all mean the requested topology was unusable.
        LayerError::EmptyPlan => RnnApiError::InvalidTopology,
        LayerError::InvalidShape => RnnApiError::InvalidTopology,
        LayerError::IncompatibleChain => RnnApiError::InvalidTopology,
        // Scratch buffer handed in by the caller was too short.
        LayerError::BufferTooSmall => RnnApiError::CapacityTooSmall,
        // Anything else is reported as a generic layer failure.
        _ => RnnApiError::Layer,
    }
}
/// Collapses model-format failures onto the public API error codes: capacity
/// problems keep their identity, everything else is treated as bad bytes.
pub(crate) fn map_model_error(err: ModelFormatError) -> RnnApiError {
    if matches!(err, ModelFormatError::CapacityTooSmall) {
        RnnApiError::CapacityTooSmall
    } else {
        RnnApiError::BadBytes
    }
}
/// Collapses forward-pass failures onto the public API error codes: scratch
/// sizing problems keep their identity, everything else is a generic failure.
fn map_forward_error(err: ForwardError) -> RnnApiError {
    if matches!(err, ForwardError::ScratchTooSmall) {
        RnnApiError::CapacityTooSmall
    } else {
        RnnApiError::Forward
    }
}