1use crate::layers::LayerSpec;
2
3use super::RnnApiError;
4
/// Buffer lengths required to hold a dense RNN, as computed by
/// [`rnn_dense_required_buffers`] from a list of layer widths.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RnnDenseCapacity {
    /// Number of layer specs: one per adjacent pair of topology widths.
    pub layer_specs: usize,
    /// Total weight count: the sum of `width[i] * width[i + 1]` over adjacent pairs.
    pub weights: usize,
    /// Total bias count: the sum of every topology width except the first.
    pub biases: usize,
    /// Inference scratch length: twice the largest topology width.
    pub infer_scratch: usize,
}
12
13pub fn rnn_dense_required_buffers(topology: &[usize]) -> Result<RnnDenseCapacity, RnnApiError> {
14 if topology.len() < 2 {
15 return Err(RnnApiError::InvalidTopology);
16 }
17
18 let mut weights = 0usize;
19 let mut biases = 0usize;
20 let mut max_width = 0usize;
21
22 for i in 0..topology.len() {
23 let w = topology[i];
24 if w == 0 {
25 return Err(RnnApiError::InvalidTopology);
26 }
27 if w > max_width {
28 max_width = w;
29 }
30 if i + 1 < topology.len() {
31 let next = topology[i + 1];
32 weights = weights
33 .checked_add(w.checked_mul(next).ok_or(RnnApiError::CapacityTooSmall)?)
34 .ok_or(RnnApiError::CapacityTooSmall)?;
35 biases = biases.checked_add(next).ok_or(RnnApiError::CapacityTooSmall)?;
36 }
37 }
38
39 let infer_scratch = max_width
40 .checked_mul(2)
41 .ok_or(RnnApiError::CapacityTooSmall)?;
42
43 Ok(RnnDenseCapacity {
44 layer_specs: topology.len() - 1,
45 weights,
46 biases,
47 infer_scratch,
48 })
49}
50
51pub fn rnn_dense_required_infer_scratch_from_specs(layers: &[LayerSpec]) -> Result<usize, RnnApiError> {
52 if layers.is_empty() {
53 return Err(RnnApiError::InvalidTopology);
54 }
55
56 let mut max_width = 0usize;
57 for layer in layers {
58 let in_w = layer.input_size();
59 let out_w = layer.output_size();
60 if in_w == 0 || out_w == 0 {
61 return Err(RnnApiError::InvalidTopology);
62 }
63 if in_w > max_width {
64 max_width = in_w;
65 }
66 if out_w > max_width {
67 max_width = out_w;
68 }
69 }
70
71 max_width
72 .checked_mul(2)
73 .ok_or(RnnApiError::CapacityTooSmall)
74}