use crate::layers::LayerSpec;
use super::RnnApiError;
/// Buffer capacities required for a dense network, as computed by
/// [`rnn_dense_required_buffers`] from a list of layer widths.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RnnDenseCapacity {
    /// Number of layer specs: one per pair of adjacent widths in the
    /// topology (`topology.len() - 1`).
    pub layer_specs: usize,
    /// Sum of `in_width * out_width` over every pair of adjacent widths.
    pub weights: usize,
    /// Sum of the output width of every layer (every topology entry except
    /// the first).
    pub biases: usize,
    /// Scratch length for inference: twice the widest width in the topology.
    pub infer_scratch: usize,
}
/// Computes the buffer capacities needed for a dense network with the given
/// layer widths.
///
/// `topology` lists every layer width in order. Each pair of adjacent entries
/// `(in_width, out_width)` describes one layer. That layer contributes
/// `in_width * out_width` to `weights` and `out_width` to `biases`.
/// `infer_scratch` is twice the widest entry.
///
/// # Errors
///
/// * [`RnnApiError::InvalidTopology`] if `topology` has fewer than two
///   entries or any width is zero. Widths are validated before any size
///   arithmetic, so an invalid topology is never misreported as an overflow.
/// * [`RnnApiError::CapacityTooSmall`] if any required size overflows `usize`.
pub fn rnn_dense_required_buffers(topology: &[usize]) -> Result<RnnDenseCapacity, RnnApiError> {
    // Validate the whole topology up front. Otherwise an overflow in an early
    // layer could mask a zero width further along.
    if topology.len() < 2 || topology.contains(&0) {
        return Err(RnnApiError::InvalidTopology);
    }

    let mut weights = 0usize;
    let mut biases = 0usize;
    for pair in topology.windows(2) {
        let (in_width, out_width) = (pair[0], pair[1]);
        let layer_weights = in_width
            .checked_mul(out_width)
            .ok_or(RnnApiError::CapacityTooSmall)?;
        weights = weights
            .checked_add(layer_weights)
            .ok_or(RnnApiError::CapacityTooSmall)?;
        biases = biases
            .checked_add(out_width)
            .ok_or(RnnApiError::CapacityTooSmall)?;
    }

    // `topology` is non-empty here (len >= 2), so `max` always yields a value.
    let max_width = topology.iter().copied().max().unwrap_or(0);
    let infer_scratch = max_width
        .checked_mul(2)
        .ok_or(RnnApiError::CapacityTooSmall)?;

    Ok(RnnDenseCapacity {
        layer_specs: topology.len() - 1,
        weights,
        biases,
        infer_scratch,
    })
}
/// Computes the inference scratch length for an existing list of layer specs.
///
/// The result is twice the largest input or output width found across all
/// `layers`.
///
/// # Errors
///
/// * [`RnnApiError::InvalidTopology`] if `layers` is empty or any layer
///   reports a zero input or output size.
/// * [`RnnApiError::CapacityTooSmall`] if doubling the widest width
///   overflows `usize`.
pub fn rnn_dense_required_infer_scratch_from_specs(
    layers: &[LayerSpec],
) -> Result<usize, RnnApiError> {
    if layers.is_empty() {
        return Err(RnnApiError::InvalidTopology);
    }

    // Track the widest side seen so far; a zero-sized side aborts the scan.
    let widest = layers.iter().try_fold(0usize, |widest, spec| {
        let input = spec.input_size();
        let output = spec.output_size();
        if input == 0 || output == 0 {
            return Err(RnnApiError::InvalidTopology);
        }
        Ok(widest.max(input).max(output))
    })?;

    widest.checked_mul(2).ok_or(RnnApiError::CapacityTooSmall)
}