native_neural_network_std 0.2.1

Ergonomic `std` wrapper for the `no_std` `native_neural_network` crate — std-friendly re-exports and utilities.
Documentation
/// Error returned by the std training wrappers in this module (e.g. `sgd_step`).
///
/// Each variant corresponds one-to-one to the identically named variant of
/// `native_neural_network::trainer::TrainError` (see the `From` impl below), so
/// no information is lost when a native failure is surfaced.
///
/// The enum is fieldless, so it is cheap to copy and can be compared directly
/// (`assert_eq!(err, TrainerStdError::BufferTooSmall)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrainerStdError {
    /// Native `TrainError::InvalidShape`.
    InvalidShape,
    /// Native `TrainError::InvalidConfig`.
    InvalidConfig,
    /// Native `TrainError::CountMismatch`.
    CountMismatch,
    /// Native `TrainError::BufferTooSmall`.
    BufferTooSmall,
    /// Native `TrainError::ForwardNaN`.
    ForwardNaN,
    /// Native `TrainError::LossError`.
    LossError,
}

impl From<native_neural_network::trainer::TrainError> for TrainerStdError {
    /// Translates a native trainer error into the identically named std variant.
    ///
    /// The match is exhaustive on purpose: a new native variant becomes a
    /// compile error here rather than being silently mis-mapped.
    fn from(err: native_neural_network::trainer::TrainError) -> Self {
        use native_neural_network::trainer::TrainError as Native;

        match err {
            Native::InvalidShape => Self::InvalidShape,
            Native::InvalidConfig => Self::InvalidConfig,
            Native::CountMismatch => Self::CountMismatch,
            Native::BufferTooSmall => Self::BufferTooSmall,
            Native::ForwardNaN => Self::ForwardNaN,
            Native::LossError => Self::LossError,
        }
    }
}

pub use native_neural_network::trainer::SgdConfig;

/// Scratch-buffer length required by the native trainer for the given `layers`.
///
/// This is a thin pass-through to
/// `native_neural_network::trainer::required_train_buffer_len`; the result is
/// returned unchanged. The meaning of `None` (presumably an invalid layer list
/// or a size overflow) is defined by the native crate — TODO confirm there.
pub fn required_train_buffer_len(layers: &[usize]) -> Option<usize> {
    use native_neural_network::trainer::required_train_buffer_len as native_required_len;

    native_required_len(layers)
}

/// Borrowed inputs, model parameters, and scratch buffers for one `sgd_step` call.
///
/// Every slice is borrowed for `'a`. `weights` and `biases` are mutable because
/// they are handed mutably to `native_neural_network::trainer::sgd_step`
/// (presumably updated in place by the step — confirm against the native crate).
#[derive(Debug)]
pub struct SgdParams<'a> {
    /// Network topology forwarded to the native trainer (assumed: one size per
    /// layer, input layer first — TODO confirm with the native crate).
    pub layers: &'a [usize],
    /// Weight values as one flat `f32` slice; layout is defined by the native crate.
    pub weights: &'a mut [f32],
    /// Bias values as one flat `f32` slice; layout is defined by the native crate.
    pub biases: &'a mut [f32],
    /// Input sample for this step (presumably sized to the first layer).
    pub input: &'a [f32],
    /// Expected output for `input` (presumably sized to the last layer).
    pub target: &'a [f32],
    /// Std-side layer-spec scratch. `sgd_step` converts it to native specs before
    /// the step and refills it from the native specs afterwards.
    pub layer_specs_scratch: &'a mut [crate::std::layers_std::LayerSpec],
    /// Activation scratch, forwarded unchanged to the native trainer's
    /// `SgdScratch::activations_scratch` (required size: see the native crate /
    /// `required_train_buffer_len` — TODO confirm which buffer that covers).
    pub activations_scratch: &'a mut [f32],
    /// Delta scratch, forwarded unchanged to the native trainer's
    /// `SgdScratch::deltas_scratch`.
    pub deltas_scratch: &'a mut [f32],
    /// SGD hyper-parameters, passed by value to the native step.
    pub config: SgdConfig,
}

pub fn sgd_step(params: SgdParams) -> Result<f32, TrainerStdError> {
    let mut native_buf: Vec<_> = vec![
        native_neural_network::layers::LayerSpec::Dense(
            native_neural_network::layers::DenseLayerDesc {
                input_size: 0,
                output_size: 0,
                weight_offset: 0,
                bias_offset: 0,
                activation: native_neural_network::activations::ActivationKind::Identity
            }
        );
        params.layer_specs_scratch.len()
    ];
    let fill = crate::std::layers_std::fill_native_slice_from_std(
        params.layer_specs_scratch,
        &mut native_buf,
    );
    let native_slice = &mut native_buf[..fill];
    let mut native_scratch = native_neural_network::trainer::SgdScratch {
        layer_specs_scratch: native_slice,
        activations_scratch: params.activations_scratch,
        deltas_scratch: params.deltas_scratch,
    };
    let res = native_neural_network::trainer::sgd_step(
        params.layers,
        params.weights,
        params.biases,
        params.input,
        params.target,
        &mut native_scratch,
        params.config,
    )
    .map_err(Into::into);
    crate::std::layers_std::fill_std_slice_from_native(native_slice, params.layer_specs_scratch);
    res
}

/// Formats the error as `TrainerStdError::<Variant>`, e.g.
/// `TrainerStdError::BufferTooSmall`, reusing the derived `Debug` variant name.
impl core::fmt::Display for TrainerStdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("TrainerStdError::")?;
        core::fmt::Debug::fmt(self, f)
    }
}

/// Marks `TrainerStdError` as a standard error type so it can be boxed as
/// `Box<dyn std::error::Error>`; the default `source()` (no underlying cause) is used.
impl std::error::Error for TrainerStdError {}