rnn/network/
network_stats.rs1use super::NeuralNetwork;
2
/// Shape summary of a [`NeuralNetwork`], as filled in by [`network_stats`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NetworkStats {
    /// Value reported by `NeuralNetwork::layer_count()`.
    pub layer_count: usize,
    /// First entry of the network's `layers` list.
    pub input_size: usize,
    /// Last entry of the network's `layers` list.
    pub output_size: usize,
    /// Number of weight values stored in the network.
    pub total_weights: usize,
    /// Number of bias values stored in the network.
    pub total_biases: usize,
}
11
12pub fn network_stats(network: &NeuralNetwork<'_>) -> Option<NetworkStats> {
13 let input_size = *network.layers.first()?;
14 let output_size = *network.layers.last()?;
15
16 Some(NetworkStats {
17 layer_count: network.layer_count(),
18 input_size,
19 output_size,
20 total_weights: network.weights.len(),
21 total_biases: network.biases.len(),
22 })
23}
24
25pub fn validate_network_parts(layers: &[usize], weights: &[f32], biases: &[f32]) -> bool {
26 let expected_w = NeuralNetwork::expected_weights_count(layers);
27 let expected_b = NeuralNetwork::expected_biases_count(layers);
28 match (expected_w, expected_b) {
29 (Some(w), Some(b)) => w == weights.len() && b == biases.len(),
30 _ => false,
31 }
32}