Skip to main content

rnn/rnn_api/
capacity.rs

1use crate::layers::LayerSpec;
2
3use super::RnnApiError;
4
/// Buffer sizes needed to build and run a dense network, as computed by
/// [`rnn_dense_required_buffers`].
///
/// Counts are numbers of values (NOTE(review): units presumably elements, not
/// bytes — confirm against the buffer allocation code).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RnnDenseCapacity {
    /// Number of dense layer specs: one per adjacent pair of layer widths.
    pub layer_specs: usize,
    /// Total weight count: the sum of `in_width * out_width` over all layers.
    pub weights: usize,
    /// Total bias count: the sum of every layer's output width.
    pub biases: usize,
    /// Inference scratch size: twice the widest layer width.
    pub infer_scratch: usize,
}
12
13pub fn rnn_dense_required_buffers(topology: &[usize]) -> Result<RnnDenseCapacity, RnnApiError> {
14    if topology.len() < 2 {
15        return Err(RnnApiError::InvalidTopology);
16    }
17
18    let mut weights = 0usize;
19    let mut biases = 0usize;
20    let mut max_width = 0usize;
21
22    for i in 0..topology.len() {
23        let w = topology[i];
24        if w == 0 {
25            return Err(RnnApiError::InvalidTopology);
26        }
27        if w > max_width {
28            max_width = w;
29        }
30        if i + 1 < topology.len() {
31            let next = topology[i + 1];
32            weights = weights
33                .checked_add(w.checked_mul(next).ok_or(RnnApiError::CapacityTooSmall)?)
34                .ok_or(RnnApiError::CapacityTooSmall)?;
35            biases = biases.checked_add(next).ok_or(RnnApiError::CapacityTooSmall)?;
36        }
37    }
38
39    let infer_scratch = max_width
40        .checked_mul(2)
41        .ok_or(RnnApiError::CapacityTooSmall)?;
42
43    Ok(RnnDenseCapacity {
44        layer_specs: topology.len() - 1,
45        weights,
46        biases,
47        infer_scratch,
48    })
49}
50
51pub fn rnn_dense_required_infer_scratch_from_specs(layers: &[LayerSpec]) -> Result<usize, RnnApiError> {
52    if layers.is_empty() {
53        return Err(RnnApiError::InvalidTopology);
54    }
55
56    let mut max_width = 0usize;
57    for layer in layers {
58        let in_w = layer.input_size();
59        let out_w = layer.output_size();
60        if in_w == 0 || out_w == 0 {
61            return Err(RnnApiError::InvalidTopology);
62        }
63        if in_w > max_width {
64            max_width = in_w;
65        }
66        if out_w > max_width {
67            max_width = out_w;
68        }
69    }
70
71    max_width
72        .checked_mul(2)
73        .ok_or(RnnApiError::CapacityTooSmall)
74}