/// Errors reported by the std-facing trainer wrappers in this module.
///
/// Each variant mirrors the identically named variant of
/// `native_neural_network::trainer::TrainError`; native errors are mapped onto this type
/// one-to-one by the `From` impl for `TrainerStdError`.
///
/// The enum only carries unit variants, so it is cheap to copy and can be compared and
/// hashed directly (e.g. `assert_eq!(err, TrainerStdError::LossError)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrainerStdError {
    /// Mirrors native `TrainError::InvalidShape` (presumably inconsistent layer sizes or
    /// buffer shapes — TODO confirm against the native docs).
    InvalidShape,
    /// Mirrors native `TrainError::InvalidConfig`.
    InvalidConfig,
    /// Mirrors native `TrainError::CountMismatch`.
    CountMismatch,
    /// Mirrors native `TrainError::BufferTooSmall`.
    BufferTooSmall,
    /// Mirrors native `TrainError::ForwardNaN` (presumably a NaN produced during the
    /// forward pass — TODO confirm).
    ForwardNaN,
    /// Mirrors native `TrainError::LossError`.
    LossError,
}
/// Maps a native trainer error onto the std-facing error, variant for variant.
///
/// The `match` below has no wildcard arm, so adding a variant to the native `TrainError`
/// turns into a compile error here instead of being silently mapped to something else.
impl From<native_neural_network::trainer::TrainError> for TrainerStdError {
    fn from(err: native_neural_network::trainer::TrainError) -> Self {
        match err {
            native_neural_network::trainer::TrainError::InvalidShape => Self::InvalidShape,
            native_neural_network::trainer::TrainError::InvalidConfig => Self::InvalidConfig,
            native_neural_network::trainer::TrainError::CountMismatch => Self::CountMismatch,
            native_neural_network::trainer::TrainError::BufferTooSmall => Self::BufferTooSmall,
            native_neural_network::trainer::TrainError::ForwardNaN => Self::ForwardNaN,
            native_neural_network::trainer::TrainError::LossError => Self::LossError,
        }
    }
}
pub use native_neural_network::trainer::SgdConfig;
/// Returns the buffer length the native trainer reports as required for training a
/// network whose layers are described by `layers`.
///
/// This is a thin forwarder to `native_neural_network::trainer::required_train_buffer_len`;
/// its `Option` is passed through unchanged. `None` presumably signals an invalid layer
/// list or a size overflow — TODO confirm against the native documentation.
///
/// Marked `#[must_use]` because the call is a pure query: discarding the result is always
/// a mistake.
#[must_use]
pub fn required_train_buffer_len(layers: &[usize]) -> Option<usize> {
    native_neural_network::trainer::required_train_buffer_len(layers)
}
/// Bundles every argument of [`sgd_step`] into a single value.
///
/// All slices are borrowed for `'a`; the `&mut` slices may be written by the step.
#[derive(Debug)]
pub struct SgdParams<'a> {
/// Layer description forwarded unchanged to the native `sgd_step` (presumably the neuron
/// count of each layer — TODO confirm).
pub layers: &'a [usize],
/// Weight buffer, passed mutably to the native `sgd_step` (presumably updated in place).
pub weights: &'a mut [f32],
/// Bias buffer, passed mutably to the native `sgd_step` (presumably updated in place).
pub biases: &'a mut [f32],
/// Input sample forwarded to the native `sgd_step`.
pub input: &'a [f32],
/// Target values for `input`, forwarded to the native `sgd_step`.
pub target: &'a [f32],
/// Std-side layer-spec scratch. [`sgd_step`] converts it into native specs before the
/// native call and writes the native specs back into it afterwards, even when the native
/// step returns an error.
pub layer_specs_scratch: &'a mut [crate::std::layers_std::LayerSpec],
/// Activation scratch handed to the native trainer through `SgdScratch`; its required
/// length is presumably derived from [`required_train_buffer_len`] — TODO confirm.
pub activations_scratch: &'a mut [f32],
/// Delta scratch handed to the native trainer through `SgdScratch`.
pub deltas_scratch: &'a mut [f32],
/// SGD configuration, passed by value to the native step.
pub config: SgdConfig,
}
pub fn sgd_step(params: SgdParams) -> Result<f32, TrainerStdError> {
let mut native_buf: Vec<_> = vec![
native_neural_network::layers::LayerSpec::Dense(
native_neural_network::layers::DenseLayerDesc {
input_size: 0,
output_size: 0,
weight_offset: 0,
bias_offset: 0,
activation: native_neural_network::activations::ActivationKind::Identity
}
);
params.layer_specs_scratch.len()
];
let fill = crate::std::layers_std::fill_native_slice_from_std(
params.layer_specs_scratch,
&mut native_buf,
);
let native_slice = &mut native_buf[..fill];
let mut native_scratch = native_neural_network::trainer::SgdScratch {
layer_specs_scratch: native_slice,
activations_scratch: params.activations_scratch,
deltas_scratch: params.deltas_scratch,
};
let res = native_neural_network::trainer::sgd_step(
params.layers,
params.weights,
params.biases,
params.input,
params.target,
&mut native_scratch,
params.config,
)
.map_err(Into::into);
crate::std::layers_std::fill_std_slice_from_native(native_slice, params.layer_specs_scratch);
res
}
/// Renders the error as `TrainerStdError::<Variant>`, using the `Debug` form of the
/// variant for the part after `::`.
impl core::fmt::Display for TrainerStdError {
    fn fmt(&self, out: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        out.write_fmt(format_args!("TrainerStdError::{:?}", self))
    }
}

/// Lets the error be used wherever `std::error::Error` is expected (for example inside
/// `Box<dyn std::error::Error>`); it relies on the default `source`, which returns `None`.
impl std::error::Error for TrainerStdError {}