/// Padding strategy used by the convolution layers in this module when
/// sliding a kernel over the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingType {
/// No padding: the kernel only visits positions fully inside the input,
/// so each spatial dimension of the output shrinks by `kernel - 1`
/// (for stride 1).
Valid,
/// Pad the input so the output keeps the same spatial size as the input
/// (for stride 1).
Same,
}
/// Generates an `update_parameters_sgd(&mut self, lr: f32)` method body for a
/// convolution layer.
///
/// The expanding type must provide `weights` / `bias` (mutable ndarray-style
/// buffers), `weight_gradients` / `bias_gradients` (`Option` gradient
/// buffers), and have `SGD::update_sgd_parameters` in scope at the expansion
/// site. Invoked from the `conv_*` submodules declared below.
macro_rules! update_sgd_conv {
() => {
// `self` below is bound by the generated method's own `&mut self`
// receiver, which is produced in the same expansion, so macro hygiene
// is satisfied.
fn update_parameters_sgd(&mut self, lr: f32) {
// Silently skip the step if either gradient buffer is missing
// (e.g. before the first backward pass has populated them).
if let (Some(weight_grads), Some(bias_grads)) =
(&self.weight_gradients, &self.bias_gradients)
{
// Delegate the actual `param -= lr * grad` loop to the shared
// SGD helper, operating on flat contiguous slices.
// NOTE(review): `as_slice_mut().unwrap()` assumes the parameter
// arrays are contiguous in standard layout — TODO confirm.
SGD::update_sgd_parameters(
self.weights.as_slice_mut().unwrap(),
weight_grads.as_slice().unwrap(),
self.bias.as_slice_mut().unwrap(),
bias_grads.as_slice().unwrap(),
lr,
)
}
}
};
}
/// Applies one AdaGrad step to a convolution layer's weights and bias.
///
/// `$self` must expose `weights` / `bias` parameter buffers and an
/// `optimizer_cache.ada_grad_cache` holding `accumulator` /
/// `accumulator_bias` buffers shaped like the corresponding parameters.
/// If the cache is `None`, the macro is a no-op. `$lr` and `$epsilon` are
/// evaluated exactly once. Requires rayon's parallel-iterator prelude at the
/// expansion site.
macro_rules! update_adagrad_conv {
    ($self:expr, $weight_gradients:expr, $bias_gradients:expr, $lr:expr, $epsilon:expr) => {
        if let Some(ada_grad_cache) = &mut $self.optimizer_cache.ada_grad_cache {
            // Bind the caller-supplied expressions once. They are `expr`
            // fragments: substituting them directly into the per-element
            // closure would re-evaluate them for every parameter.
            let lr = $lr;
            let epsilon = $epsilon;
            // Fused single pass: updating the accumulator and the parameter
            // of each element together is numerically identical to the
            // two-pass formulation (updates are element-wise independent),
            // but traverses the buffers once instead of twice.
            let update_parameters = |params: &mut [f32], accumulator: &mut [f32], grads: &[f32]| {
                params
                    .par_iter_mut()
                    .zip(accumulator.par_iter_mut())
                    .zip(grads.par_iter())
                    .for_each(|((param, acc), &grad)| {
                        // AdaGrad: G += g^2; p -= lr * g / (sqrt(G) + eps)
                        *acc += grad * grad;
                        *param -= lr * grad / (acc.sqrt() + epsilon);
                    });
            };
            // NOTE(review): `as_slice_mut().unwrap()` assumes contiguous
            // standard-layout arrays — TODO confirm, as in the SGD path.
            update_parameters(
                $self.weights.as_slice_mut().unwrap(),
                ada_grad_cache.accumulator.as_slice_mut().unwrap(),
                $weight_gradients.as_slice().unwrap(),
            );
            update_parameters(
                $self.bias.as_slice_mut().unwrap(),
                ada_grad_cache.accumulator_bias.as_slice_mut().unwrap(),
                $bias_gradients.as_slice().unwrap(),
            );
        }
    };
}
pub mod conv_1d;
pub mod conv_2d;
pub mod conv_3d;
pub mod depthwise_conv_2d;
mod input_validation_function;
pub mod separable_conv_2d;
pub use conv_1d::Conv1D;
pub use conv_2d::Conv2D;
pub use conv_3d::Conv3D;
pub use depthwise_conv_2d::DepthwiseConv2D;
pub use separable_conv_2d::SeparableConv2D;