/// Errors reported by the `std`-facing wrapper around `native_neural_network`.
///
/// Implements [`Display`](::core::fmt::Display) and [`Error`](::std::error::Error) so it can
/// be propagated with `?` into `Box<dyn Error>` and printed for users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NetworkStdError {
    /// The `layers` / `weights` / `biases` combination was rejected by the native
    /// constructor, or native layer-spec construction failed.
    InvalidShape,
    /// Converted from the native engine's `ForwardError::InvalidPlan`.
    InvalidPlan,
    /// Converted from the native engine's `ForwardError::ShapeMismatch`.
    ShapeMismatch,
    /// Converted from the native engine's `ForwardError::ScratchTooSmall`.
    ScratchTooSmall,
}

// Leading `::` paths are used because this crate defines its own module named `std`
// (see `crate::std::…`), which must not be confused with the standard library.
impl ::core::fmt::Display for NetworkStdError {
    fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
        let msg = match self {
            Self::InvalidShape => "invalid network shape",
            Self::InvalidPlan => "invalid plan",
            Self::ShapeMismatch => "shape mismatch",
            Self::ScratchTooSmall => "scratch space too small",
        };
        f.write_str(msg)
    }
}

impl ::std::error::Error for NetworkStdError {}
/// Converts a native forward-pass error into the wrapper's error type.
///
/// Every native variant maps onto the [`NetworkStdError`] variant with the same name.
impl From<native_neural_network::engine::ForwardError> for NetworkStdError {
    fn from(err: native_neural_network::engine::ForwardError) -> Self {
        use native_neural_network::engine::ForwardError as Native;
        // Exhaustive on purpose: a new native variant should fail to compile here.
        match err {
            Native::InvalidPlan => Self::InvalidPlan,
            Native::ShapeMismatch => Self::ShapeMismatch,
            Native::ScratchTooSmall => Self::ScratchTooSmall,
        }
    }
}
/// Owned network description: layer sizes plus flat weight and bias buffers.
///
/// The fields are public, so they can be changed after construction. `forward` re-validates
/// them through the native constructor on every call and reports
/// `NetworkStdError::InvalidShape` if they no longer agree.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralNetworkStd {
    /// Size of each layer, input layer first.
    pub layers: Vec<usize>,
    /// Flat weight buffer consumed by the native network.
    pub weights: Vec<f32>,
    /// Flat bias buffer consumed by the native network.
    pub biases: Vec<f32>,
}
impl NeuralNetworkStd {
    /// Builds a network from owned layer sizes, weights and biases.
    ///
    /// The parts are checked with
    /// `native_neural_network::network::NeuralNetwork::from_parts`. On success the vectors
    /// are moved into the returned value unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`NetworkStdError::InvalidShape`] when the native constructor rejects the parts.
    pub fn from_parts(
        layers: Vec<usize>,
        weights: Vec<f32>,
        biases: Vec<f32>,
    ) -> Result<Self, NetworkStdError> {
        // Only the validation outcome is needed. The borrowing native view is discarded
        // here, before the vectors are moved into `Self`.
        let is_valid =
            native_neural_network::network::NeuralNetwork::from_parts(&layers, &weights, &biases)
                .is_some();
        if !is_valid {
            return Err(NetworkStdError::InvalidShape);
        }
        Ok(NeuralNetworkStd {
            layers,
            weights,
            biases,
        })
    }

    /// Number of weights the native network expects for `layers`.
    ///
    /// Delegates to `native_neural_network::network::NeuralNetwork::expected_weights_count`.
    /// `None` presumably signals an unusable layer list (TODO: confirm against the native docs).
    pub fn expected_weights_count(layers: &[usize]) -> Option<usize> {
        native_neural_network::network::NeuralNetwork::expected_weights_count(layers)
    }

    /// Number of biases the native network expects for `layers`.
    ///
    /// Delegates to `native_neural_network::network::NeuralNetwork::expected_biases_count`.
    /// `None` presumably signals an unusable layer list (TODO: confirm against the native docs).
    pub fn expected_biases_count(layers: &[usize]) -> Option<usize> {
        native_neural_network::network::NeuralNetwork::expected_biases_count(layers)
    }

    /// Number of layer transitions, i.e. `layers.len() - 1`, or `0` when `layers` is empty.
    pub fn layer_count(&self) -> usize {
        self.layers.len().saturating_sub(1)
    }

    /// Builds the per-layer specs for this network, converted to the `std` representation.
    ///
    /// The given activations are passed to `native_neural_network::layers::build_from_layers`.
    /// Only the first `used` specs that it reports are returned.
    ///
    /// # Errors
    ///
    /// Any failure from `build_from_layers` is reported as [`NetworkStdError::InvalidShape`].
    pub fn build_layer_specs(
        &self,
        hidden_activation: native_neural_network::activations::ActivationKind,
        output_activation: native_neural_network::activations::ActivationKind,
    ) -> Result<Vec<crate::std::layers_std::LayerSpec>, NetworkStdError> {
        let mut native_specs = vec![Self::placeholder_native_spec(); self.layer_count()];
        let used = native_neural_network::layers::build_from_layers(
            &self.layers,
            hidden_activation,
            output_activation,
            self.weights.len(),
            self.biases.len(),
            &mut native_specs,
        )
        .map_err(|_| NetworkStdError::InvalidShape)?;
        native_specs.truncate(used);
        Ok(native_specs
            .into_iter()
            .map(crate::std::layers_std::LayerSpec::from)
            .collect())
    }

    /// Runs a forward pass through the native engine.
    ///
    /// `layer_scratch` is copied into a temporary native buffer, allocated on every call.
    /// `fill_native_slice_from_std` reports how many entries were converted, and only that
    /// prefix is handed to the native `forward`. The same prefix is then copied back into
    /// `layer_scratch` via `fill_std_slice_from_native`. This copy-back happens whether or
    /// not the native forward pass succeeded.
    ///
    /// # Errors
    ///
    /// - [`NetworkStdError::InvalidShape`] if `layers` / `weights` / `biases` are rejected by
    ///   the native constructor. This can happen because the fields are public and may have
    ///   been modified. In this case `layer_scratch` is left untouched.
    /// - Any native `ForwardError`, converted via `From`.
    pub fn forward(
        &self,
        input: &[f32],
        output: &mut [f32],
        scratch: &mut [f32],
        layer_scratch: &mut [crate::std::layers_std::LayerSpec],
        hidden_activation: native_neural_network::activations::ActivationKind,
        output_activation: native_neural_network::activations::ActivationKind,
    ) -> Result<(), NetworkStdError> {
        let mut native_buf = vec![Self::placeholder_native_spec(); layer_scratch.len()];
        let fill =
            crate::std::layers_std::fill_native_slice_from_std(layer_scratch, &mut native_buf);
        let nn_view = native_neural_network::network::NeuralNetwork::from_parts(
            &self.layers,
            &self.weights,
            &self.biases,
        )
        .ok_or(NetworkStdError::InvalidShape)?;
        let res = nn_view
            .forward(
                input,
                output,
                scratch,
                &mut native_buf[..fill],
                hidden_activation,
                output_activation,
            )
            .map_err(NetworkStdError::from);
        // Copy back even on failure so callers see whatever the native pass wrote.
        crate::std::layers_std::fill_std_slice_from_native(&native_buf[..fill], layer_scratch);
        res
    }

    /// Zero-sized, identity-activation dense spec, used only to pre-fill native
    /// `LayerSpec` buffers before native code overwrites them.
    fn placeholder_native_spec() -> native_neural_network::layers::LayerSpec {
        native_neural_network::layers::LayerSpec::Dense(native_neural_network::layers::LayerDesc {
            input_size: 0,
            output_size: 0,
            weight_offset: 0,
            bias_offset: 0,
            activation: native_neural_network::activations::ActivationKind::Identity,
        })
    }
}
/// Summary of a network's sizes and parameter counts, as produced by `network_stats`.
///
/// Each field is copied verbatim from the corresponding field of the native statistics.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NetworkStatsStd {
    /// Native `layer_count` value.
    pub layer_count: usize,
    /// Native `input_size` value.
    pub input_size: usize,
    /// Native `output_size` value.
    pub output_size: usize,
    /// Native `total_weights` value.
    pub total_weights: usize,
    /// Native `total_biases` value.
    pub total_biases: usize,
}
pub fn network_stats(layers: &[usize], weights: &[f32], biases: &[f32]) -> Option<NetworkStatsStd> {
let nn = native_neural_network::network::NeuralNetwork::from_parts(layers, weights, biases)?;
native_neural_network::network::network_stats(&nn).map(|s| NetworkStatsStd {
layer_count: s.layer_count,
input_size: s.input_size,
output_size: s.output_size,
total_weights: s.total_weights,
total_biases: s.total_biases,
})
}
pub type NetworkStats = NetworkStatsStd;
pub fn validate_network_parts(layers: &[usize], weights: &[f32], biases: &[f32]) -> bool {
native_neural_network::network::validate_network_parts(layers, weights, biases)
}