// rnn/network/network_stats.rs

use super::NeuralNetwork;
2
/// Summary of a [`NeuralNetwork`]'s shape, as produced by [`network_stats`].
///
/// All fields are plain counts, so the struct is cheap to copy, compare and
/// hash (e.g. to deduplicate or key networks by shape).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NetworkStats {
    /// Value reported by the network's `layer_count()` method.
    pub layer_count: usize,
    /// First entry of the network's `layers` list (the input width).
    pub input_size: usize,
    /// Last entry of the network's `layers` list (the output width).
    pub output_size: usize,
    /// Number of elements in the network's `weights` field.
    pub total_weights: usize,
    /// Number of elements in the network's `biases` field.
    pub total_biases: usize,
}
11
12pub fn network_stats(network: &NeuralNetwork<'_>) -> Option<NetworkStats> {
13    let input_size = *network.layers.first()?;
14    let output_size = *network.layers.last()?;
15
16    Some(NetworkStats {
17        layer_count: network.layer_count(),
18        input_size,
19        output_size,
20        total_weights: network.weights.len(),
21        total_biases: network.biases.len(),
22    })
23}
24
25pub fn validate_network_parts(layers: &[usize], weights: &[f32], biases: &[f32]) -> bool {
26    let expected_w = NeuralNetwork::expected_weights_count(layers);
27    let expected_b = NeuralNetwork::expected_biases_count(layers);
28    match (expected_w, expected_b) {
29        (Some(w), Some(b)) => w == weights.len() && b == biases.len(),
30        _ => false,
31    }
32}