// rnn/rnn_api/rnn_api.rs

1use crate::activations::ActivationKind;
2use crate::engine::{forward_dense_plan, ForwardError};
3use crate::initializers::expected_parameter_counts;
4use crate::layers::{build_dense_specs_from_layers, LayerError, LayerPlan, LayerSpec};
5use crate::model_format::{
6    decode_dense_model_v1,
7    encode_dense_model_v1,
8    encoded_size_v1,
9    DecodedCounts,
10    ModelFormatError,
11};
12
/// Four-byte tag the v1 header check expects at the very start of a model buffer.
const MAGIC: &[u8; 4] = b"RMD1";

/// The only header version accepted by `rnn_required_dense_from_bytes_v1`.
const VERSION: u16 = 1;

/// Length in bytes of the fixed v1 header read by `rnn_required_dense_from_bytes_v1`.
const HEADER_SIZE: usize = MAGIC.len() // magic tag
    + 2 // version (`u16`)
    + 2 // bytes 6..8, not read in this module
    + 4 // layer count (`u32`)
    + 4 // weight count (`u32`)
    + 4; // bias count (`u32`)

/// Coarse error codes returned by the `rnn_*` API functions in this module.
///
/// Errors coming from lower layers (layer-spec construction, model encoding/decoding,
/// and the forward pass) are each collapsed into a single variant; their detailed cause
/// is not preserved.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RnnApiError {
    /// The topology has fewer than two entries, or its parameter counts could not be
    /// computed.
    InvalidTopology,
    /// A caller-supplied scratch buffer is shorter than the model requires.
    CapacityTooSmall,
    /// The byte buffer failed header validation (length, magic, version, or declared
    /// sizes).
    BadBytes,
    /// Building the per-layer specifications failed.
    Layer,
    /// Encoding or decoding the serialized model failed.
    Model,
    /// The forward pass failed.
    Forward,
}

// `core::fmt` rather than `std::fmt` so the impl also works if the crate is `no_std`.
impl core::fmt::Display for RnnApiError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            RnnApiError::InvalidTopology => "invalid network topology",
            RnnApiError::CapacityTooSmall => "scratch buffer too small for model",
            RnnApiError::BadBytes => "malformed model bytes",
            RnnApiError::Layer => "layer specification error",
            RnnApiError::Model => "model format error",
            RnnApiError::Forward => "forward pass error",
        };
        f.write_str(message)
    }
}

27pub fn rnn_required_dense_from_topology(topology: &[usize]) -> Result<DecodedCounts, RnnApiError> {
28    if topology.len() < 2 {
29        return Err(RnnApiError::InvalidTopology);
30    }
31
32    let (weights, biases) = expected_parameter_counts(topology).ok_or(RnnApiError::InvalidTopology)?;
33    Ok(DecodedCounts {
34        layers: topology.len() - 1,
35        weights,
36        biases,
37    })
38}
39
40pub fn rnn_required_dense_from_bytes_v1(bytes: &[u8]) -> Result<DecodedCounts, RnnApiError> {
41    if bytes.len() < HEADER_SIZE {
42        return Err(RnnApiError::BadBytes);
43    }
44    if &bytes[0..4] != MAGIC {
45        return Err(RnnApiError::BadBytes);
46    }
47
48    let version = u16::from_le_bytes([bytes[4], bytes[5]]);
49    if version != VERSION {
50        return Err(RnnApiError::BadBytes);
51    }
52
53    let layer_count = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
54    let weights_len = u32::from_le_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]) as usize;
55    let biases_len = u32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]) as usize;
56
57    let expected_size = encoded_size_v1(layer_count, weights_len, biases_len).ok_or(RnnApiError::BadBytes)?;
58    if bytes.len() < expected_size {
59        return Err(RnnApiError::BadBytes);
60    }
61
62    Ok(DecodedCounts {
63        layers: layer_count,
64        weights: weights_len,
65        biases: biases_len,
66    })
67}
68
69pub fn rnn_pack_dense_v1(
70    topology: &[usize],
71    hidden_activation: ActivationKind,
72    output_activation: ActivationKind,
73    weights: &[f32],
74    biases: &[f32],
75    layer_specs_scratch: &mut [LayerSpec],
76    out_bytes: &mut [u8],
77) -> Result<usize, RnnApiError> {
78    let layer_count = build_dense_specs_from_layers(
79        topology,
80        hidden_activation,
81        output_activation,
82        weights.len(),
83        biases.len(),
84        layer_specs_scratch,
85    )
86    .map_err(map_layer_error)?;
87
88    encode_dense_model_v1(&layer_specs_scratch[..layer_count], weights, biases, out_bytes).map_err(map_model_error)
89}
90
91pub fn rnn_run_dense_v1(
92    bytes: &[u8],
93    input: &[f32],
94    output: &mut [f32],
95    layer_specs_scratch: &mut [LayerSpec],
96    weights_scratch: &mut [f32],
97    biases_scratch: &mut [f32],
98    infer_scratch: &mut [f32],
99) -> Result<(), RnnApiError> {
100    let counts = rnn_required_dense_from_bytes_v1(bytes)?;
101    if layer_specs_scratch.len() < counts.layers || weights_scratch.len() < counts.weights || biases_scratch.len() < counts.biases {
102        return Err(RnnApiError::CapacityTooSmall);
103    }
104
105    let decoded = decode_dense_model_v1(
106        bytes,
107        &mut layer_specs_scratch[..counts.layers],
108        &mut weights_scratch[..counts.weights],
109        &mut biases_scratch[..counts.biases],
110    )
111    .map_err(map_model_error)?;
112
113    let plan = LayerPlan {
114        layers: &layer_specs_scratch[..decoded.layers],
115        weights: &weights_scratch[..decoded.weights],
116        biases: &biases_scratch[..decoded.biases],
117    };
118
119    forward_dense_plan(&plan, input, output, infer_scratch).map_err(map_forward_error)
120}
121
122fn map_layer_error(_err: LayerError) -> RnnApiError {
123    RnnApiError::Layer
124}
125
126pub(crate) fn map_model_error(_err: ModelFormatError) -> RnnApiError {
127    RnnApiError::Model
128}
129
130fn map_forward_error(_err: ForwardError) -> RnnApiError {
131    RnnApiError::Forward
132}