1use crate::activations::ActivationKind;
2use crate::engine::{forward_dense_plan, ForwardError};
3use crate::initializers::expected_parameter_counts;
4use crate::layers::{build_dense_specs_from_layers, LayerError, LayerPlan, LayerSpec};
5use crate::model_format::{
6 decode_dense_model_v1,
7 encode_dense_model_v1,
8 encoded_size_v1,
9 DecodedCounts,
10 ModelFormatError,
11};
12
/// File magic identifying the dense-model container ("RMD1").
const MAGIC: &[u8; 4] = b"RMD1";
/// Only format version this module knows how to parse.
const VERSION: u16 = 1;
/// Fixed header size in bytes: magic (4) + version (2) + 2 more bytes
/// + layer count (4) + weight count (4) + bias count (4).
/// NOTE(review): bytes 6..8 are never read by the v1 parser below —
/// presumably reserved/flags written by the encoder; confirm in `model_format`.
const HEADER_SIZE: usize = 4 + 2 + 2 + 4 + 4 + 4;
16
/// Coarse-grained error codes surfaced by the public `rnn_*` entry points.
///
/// Lower-level errors (`LayerError`, `ModelFormatError`, `ForwardError`) are
/// collapsed into the `Layer`/`Model`/`Forward` variants by the `map_*_error`
/// helpers at the bottom of this file, discarding their detail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RnnApiError {
    /// Topology slice is too short or describes an impossible network.
    InvalidTopology,
    /// A caller-provided scratch buffer is smaller than the decoded model needs.
    CapacityTooSmall,
    /// Byte buffer is too short, has the wrong magic/version, or lies about its size.
    BadBytes,
    /// Wrapped `LayerError` from spec construction.
    Layer,
    /// Wrapped `ModelFormatError` from encoding/decoding.
    Model,
    /// Wrapped `ForwardError` from inference.
    Forward,
}
26
27pub fn rnn_required_dense_from_topology(topology: &[usize]) -> Result<DecodedCounts, RnnApiError> {
28 if topology.len() < 2 {
29 return Err(RnnApiError::InvalidTopology);
30 }
31
32 let (weights, biases) = expected_parameter_counts(topology).ok_or(RnnApiError::InvalidTopology)?;
33 Ok(DecodedCounts {
34 layers: topology.len() - 1,
35 weights,
36 biases,
37 })
38}
39
40pub fn rnn_required_dense_from_bytes_v1(bytes: &[u8]) -> Result<DecodedCounts, RnnApiError> {
41 if bytes.len() < HEADER_SIZE {
42 return Err(RnnApiError::BadBytes);
43 }
44 if &bytes[0..4] != MAGIC {
45 return Err(RnnApiError::BadBytes);
46 }
47
48 let version = u16::from_le_bytes([bytes[4], bytes[5]]);
49 if version != VERSION {
50 return Err(RnnApiError::BadBytes);
51 }
52
53 let layer_count = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
54 let weights_len = u32::from_le_bytes([bytes[12], bytes[13], bytes[14], bytes[15]]) as usize;
55 let biases_len = u32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]) as usize;
56
57 let expected_size = encoded_size_v1(layer_count, weights_len, biases_len).ok_or(RnnApiError::BadBytes)?;
58 if bytes.len() < expected_size {
59 return Err(RnnApiError::BadBytes);
60 }
61
62 Ok(DecodedCounts {
63 layers: layer_count,
64 weights: weights_len,
65 biases: biases_len,
66 })
67}
68
69pub fn rnn_pack_dense_v1(
70 topology: &[usize],
71 hidden_activation: ActivationKind,
72 output_activation: ActivationKind,
73 weights: &[f32],
74 biases: &[f32],
75 layer_specs_scratch: &mut [LayerSpec],
76 out_bytes: &mut [u8],
77) -> Result<usize, RnnApiError> {
78 let layer_count = build_dense_specs_from_layers(
79 topology,
80 hidden_activation,
81 output_activation,
82 weights.len(),
83 biases.len(),
84 layer_specs_scratch,
85 )
86 .map_err(map_layer_error)?;
87
88 encode_dense_model_v1(&layer_specs_scratch[..layer_count], weights, biases, out_bytes).map_err(map_model_error)
89}
90
91pub fn rnn_run_dense_v1(
92 bytes: &[u8],
93 input: &[f32],
94 output: &mut [f32],
95 layer_specs_scratch: &mut [LayerSpec],
96 weights_scratch: &mut [f32],
97 biases_scratch: &mut [f32],
98 infer_scratch: &mut [f32],
99) -> Result<(), RnnApiError> {
100 let counts = rnn_required_dense_from_bytes_v1(bytes)?;
101 if layer_specs_scratch.len() < counts.layers || weights_scratch.len() < counts.weights || biases_scratch.len() < counts.biases {
102 return Err(RnnApiError::CapacityTooSmall);
103 }
104
105 let decoded = decode_dense_model_v1(
106 bytes,
107 &mut layer_specs_scratch[..counts.layers],
108 &mut weights_scratch[..counts.weights],
109 &mut biases_scratch[..counts.biases],
110 )
111 .map_err(map_model_error)?;
112
113 let plan = LayerPlan {
114 layers: &layer_specs_scratch[..decoded.layers],
115 weights: &weights_scratch[..decoded.weights],
116 biases: &biases_scratch[..decoded.biases],
117 };
118
119 forward_dense_plan(&plan, input, output, infer_scratch).map_err(map_forward_error)
120}
121
122fn map_layer_error(_err: LayerError) -> RnnApiError {
123 RnnApiError::Layer
124}
125
126pub(crate) fn map_model_error(_err: ModelFormatError) -> RnnApiError {
127 RnnApiError::Model
128}
129
130fn map_forward_error(_err: ForwardError) -> RnnApiError {
131 RnnApiError::Forward
132}