use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::convolution_layer::PaddingType;
use crate::neural_network::layer::convolution_layer::input_validation_function::{
validate_filters, validate_kernel_size_2d, validate_strides_2d,
};
use crate::neural_network::layer::helper_function::{
calculate_output_shape_2d, pad_tensor_2d, update_adam_conv, update_rmsprop,
};
use crate::neural_network::layer::layer_weight::{DepthwiseConv2DLayerWeight, LayerWeight};
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
use crate::neural_network::optimizer::{
OptimizerCacheConv2D, ada_grad::AdaGradStatesConv2D, adam::AdamStatesConv2D,
rms_prop::RMSpropCacheConv2D, sgd::SGD,
};
use ndarray::{Array1, Array2, Array4, ArrayView2, ArrayViewD, Axis, s};
use ndarray_rand::rand::random;
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
// Minimum number of output elements (batch * channels * out_h * out_w) at
// which `forward`/`backward` switch from plain sequential loops to
// rayon-parallel execution; below this the scheduling overhead dominates.
const DEPTHWISE_CONV_2D_PARALLEL_THRESHOLD: usize = 1500;
/// Depthwise 2-D convolution layer: applies one independent kernel per input
/// channel (so `filters` must equal the input channel count), then runs the
/// activation `T` over the result.
pub struct DepthwiseConv2D<T: ActivationLayer> {
// Number of depthwise kernels; must equal the number of input channels.
filters: usize,
// (height, width) of each kernel.
kernel_size: (usize, usize),
// (vertical, horizontal) stride.
strides: (usize, usize),
// `Valid` (no padding) or `Same` (zero-pad so striding covers the input).
padding: PaddingType,
// Kernel tensor laid out as (filters, 1, kernel_h, kernel_w).
weights: Array4<f32>,
// One bias per channel/filter.
bias: Array1<f32>,
// Activation applied to the convolution output.
activation: T,
// Input cached by `forward` for use in `backward`.
input: Option<Tensor>,
// Shape of the last forward input; empty before the first forward pass.
input_shape: Vec<usize>,
// Parameter gradients produced by `backward`, consumed by optimizer updates.
weight_gradients: Option<Array4<f32>>,
bias_gradients: Option<Array1<f32>>,
// Lazily-initialised per-optimizer state (Adam / RMSprop / AdaGrad).
optimizer_cache: OptimizerCacheConv2D,
}
impl<T: ActivationLayer> DepthwiseConv2D<T> {
/// Creates a depthwise convolution layer with zero-initialised parameters.
///
/// Weights are allocated as (filters, 1, kernel_h, kernel_w) — one
/// single-channel kernel per filter — and are expected to be given real
/// values later via `initialize_weights` or `set_weights`.
///
/// # Errors
/// Returns `ModelError` when `filters`, `kernel_size`, or `strides` fail
/// their respective validation checks.
pub fn new(
    filters: usize,
    kernel_size: (usize, usize),
    strides: (usize, usize),
    padding: PaddingType,
    activation: T,
) -> Result<Self, ModelError> {
    validate_filters(filters)?;
    validate_kernel_size_2d(kernel_size)?;
    validate_strides_2d(strides)?;
    let (kh, kw) = kernel_size;
    Ok(Self {
        filters,
        kernel_size,
        strides,
        padding,
        weights: Array4::zeros((filters, 1, kh, kw)),
        bias: Array1::zeros(filters),
        activation,
        input: None,
        input_shape: Vec::new(),
        weight_gradients: None,
        bias_gradients: None,
        optimizer_cache: OptimizerCacheConv2D {
            adam_states: None,
            rmsprop_cache: None,
            ada_grad_cache: None,
        },
    })
}
/// Fills the kernels with Glorot/Xavier-uniform values and the biases with
/// small uniform noise in [-0.05, 0.05).
///
/// # Panics
/// Panics if `input_channels` differs from `self.filters` — depthwise
/// convolution requires exactly one kernel per input channel.
pub fn initialize_weights(&mut self, input_channels: usize) {
    assert_eq!(
        self.filters, input_channels,
        "For depthwise convolution, number of filters must equal input channels"
    );
    let (kernel_height, kernel_width) = self.kernel_size;
    // For a depthwise kernel both fan-in and fan-out are kh * kw: each
    // kernel reads from, and writes to, a single channel.
    let fan_in = kernel_height * kernel_width;
    let fan_out = kernel_height * kernel_width;
    let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
    // The weight tensor is (filters, 1, kh, kw), so filling every element
    // in one pass is equivalent to initialising each filter's [0, .., ..]
    // slice individually.
    self.weights
        .par_mapv_inplace(|_| (random::<f32>() - 0.5) * 2.0 * limit);
    self.bias
        .par_mapv_inplace(|_| (random::<f32>() - 0.5) * 0.1);
}
/// Returns the TOTAL (height, width) zero-padding required by the current
/// padding mode.
///
/// `Valid` never pads. `Same` solves
/// `pad = (output - 1) * stride + kernel - input` so that striding the
/// padded input yields exactly `output` window positions; the result is the
/// combined padding across both edges of each dimension.
///
/// Uses saturating arithmetic throughout: the previous `output_height - 1`
/// would underflow (panicking in debug builds, wrapping to a huge pad in
/// release) if a degenerate zero-sized output dimension were ever passed.
fn calculate_padding(
    &self,
    input_height: usize,
    input_width: usize,
    output_height: usize,
    output_width: usize,
) -> (usize, usize) {
    match self.padding {
        PaddingType::Valid => (0, 0),
        PaddingType::Same => {
            let pad_h = (output_height.saturating_sub(1) * self.strides.0 + self.kernel_size.0)
                .saturating_sub(input_height);
            let pad_w = (output_width.saturating_sub(1) * self.strides.1 + self.kernel_size.1)
                .saturating_sub(input_width);
            (pad_h, pad_w)
        }
    }
}
/// Replaces the layer's kernel weights and biases wholesale (e.g. when
/// loading saved parameters).
///
/// NOTE(review): no shape validation is performed — callers are trusted to
/// supply `weights` of shape (filters, 1, kernel_h, kernel_w) and `bias`
/// of length `filters`; confirm whether validation should be added here.
pub fn set_weights(&mut self, weights: Array4<f32>, bias: Array1<f32>) {
self.weights = weights;
self.bias = bias;
}
/// Convolves one input channel with its dedicated kernel.
///
/// For `PaddingType::Same` the channel is zero-padded first with
/// `pad_h`/`pad_w` via `pad_tensor_2d`; for `Valid` it is used as-is.
/// `bias` is added to every computed output element.
///
/// Output positions whose kernel window would run past the (padded) input
/// are skipped and stay 0.0 — note the bias is not added there either.
fn convolve_channel(
    input_channel: &ArrayView2<f32>,
    kernel: &ArrayView2<f32>,
    bias: f32,
    output_shape: (usize, usize),
    strides: (usize, usize),
    kernel_size: (usize, usize),
    padding: &PaddingType,
    pad_h: usize,
    pad_w: usize,
) -> Array2<f32> {
    let padded_input = if *padding == PaddingType::Same {
        pad_tensor_2d(&input_channel.to_owned(), pad_h, pad_w)
    } else {
        input_channel.to_owned()
    };
    // Hoist the bounds once instead of re-reading `shape()` per pixel.
    let (padded_height, padded_width) = (padded_input.shape()[0], padded_input.shape()[1]);
    let (output_height, output_width) = output_shape;
    let mut channel_output = Array2::zeros(output_shape);
    for oh in 0..output_height {
        for ow in 0..output_width {
            let start_h = oh * strides.0;
            let start_w = ow * strides.1;
            let end_h = start_h + kernel_size.0;
            let end_w = start_w + kernel_size.1;
            if end_h <= padded_height && end_w <= padded_width {
                let input_patch = padded_input.slice(s![start_h..end_h, start_w..end_w]);
                // Accumulate directly in logical (row-major) order instead
                // of materialising `&input_patch * kernel`, which allocated
                // a temporary Array2 for every single output pixel.
                let conv_result: f32 = input_patch
                    .iter()
                    .zip(kernel.iter())
                    .map(|(x, w)| x * w)
                    .sum();
                channel_output[[oh, ow]] = conv_result + bias;
            }
        }
    }
    channel_output
}
/// Computes the weight and input gradients contributed by ONE batch element.
///
/// Returns `(weight_grads, input_grads)`: `weight_grads` matches
/// `self.weights`' shape (this element's contribution, summed over the batch
/// by `backward`), and `input_grads` is (1, channels, input_h, input_w) for
/// `batch_idx` alone. Called per batch element, optionally in parallel.
fn compute_batch_gradients(
&self,
input_array: &ArrayViewD<f32>,
grad_upstream: &Tensor,
batch_idx: usize,
channels: usize,
input_height: usize,
input_width: usize,
output_height: usize,
output_width: usize,
pad_h: usize,
pad_w: usize,
) -> (Array4<f32>, Array4<f32>) {
let mut batch_weight_grads = Array4::zeros(self.weights.raw_dim());
let mut batch_input_grads = Array4::zeros((1, channels, input_height, input_width));
for c in 0..channels {
let input_channel = input_array.slice(s![batch_idx, c, .., ..]);
let grad_channel = grad_upstream.slice(s![batch_idx, c, .., ..]);
// Re-pad exactly as the forward pass did so window positions line up.
let padded_input = if self.padding == PaddingType::Same {
pad_tensor_2d(&input_channel.to_owned(), pad_h, pad_w)
} else {
input_channel.to_owned()
};
// dL/dW[kh,kw]: correlate the (padded) input with the upstream gradient.
for kh in 0..self.kernel_size.0 {
for kw in 0..self.kernel_size.1 {
let mut weight_grad = 0.0;
for oh in 0..output_height {
for ow in 0..output_width {
let ih = oh * self.strides.0 + kh;
let iw = ow * self.strides.1 + kw;
if ih < padded_input.shape()[0] && iw < padded_input.shape()[1] {
weight_grad += padded_input[[ih, iw]] * grad_channel[[oh, ow]];
}
}
}
batch_weight_grads[[c, 0, kh, kw]] = weight_grad;
}
}
// dL/dInput: scatter each upstream gradient back through the kernel.
// NOTE(review): `ih`/`iw` below are coordinates in the PADDED input,
// yet they index the unpadded `batch_input_grads` directly. If
// `pad_tensor_2d` places any padding on the leading (top/left) edges,
// input gradients under `Same` padding are shifted by that offset —
// verify against `pad_tensor_2d`'s convention.
for oh in 0..output_height {
for ow in 0..output_width {
let grad_val = grad_channel[[oh, ow]];
for kh in 0..self.kernel_size.0 {
for kw in 0..self.kernel_size.1 {
let ih = oh * self.strides.0 + kh;
let iw = ow * self.strides.1 + kw;
if ih < input_height && iw < input_width {
batch_input_grads[[0, c, ih, iw]] +=
self.weights[[c, 0, kh, kw]] * grad_val;
}
}
}
}
}
}
(batch_weight_grads, batch_input_grads)
}
}
impl<T: ActivationLayer> Layer for DepthwiseConv2D<T> {
/// Forward pass: depthwise convolution followed by the activation.
///
/// Expects a 4-D input laid out as (batch, channels, height, width) whose
/// channel count equals `self.filters` (one kernel per channel). Caches the
/// input and its shape for `backward`. Work is parallelised over batches
/// and channels once the output element count reaches
/// `DEPTHWISE_CONV_2D_PARALLEL_THRESHOLD`.
///
/// # Errors
/// Returns `ModelError::InputValidationError` if the input is not 4-D or
/// its channel count does not match the filter count. (The channel check
/// previously panicked via `assert_eq!`, inconsistent with this method's
/// `Result`-based contract.)
fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
    if input.ndim() != 4 {
        return Err(ModelError::InputValidationError(
            "input tensor is not 4D".to_string(),
        ));
    }
    // ndim was checked above, so the dimensionality cast cannot fail.
    let input_array = input.view().into_dimensionality::<ndarray::Ix4>().unwrap();
    let (batch_size, channels, height, width) = (
        input_array.shape()[0],
        input_array.shape()[1],
        input_array.shape()[2],
        input_array.shape()[3],
    );
    // Depthwise convolution applies exactly one kernel per input channel.
    if channels != self.filters {
        return Err(ModelError::InputValidationError(format!(
            "Input channels ({}) must equal number of filters ({}) for depthwise convolution",
            channels, self.filters
        )));
    }
    // Cache the input only after validation, so a failed call leaves no
    // stale state behind.
    self.input = Some(input.clone());
    self.input_shape = input.shape().to_vec();
    let output_shape = calculate_output_shape_2d(
        &self.input_shape,
        self.kernel_size,
        self.strides,
        &self.padding,
    );
    let (output_height, output_width) = (output_shape[2], output_shape[3]);
    let (pad_h, pad_w) = self.calculate_padding(height, width, output_height, output_width);
    let mut output = Array4::zeros((batch_size, channels, output_height, output_width));
    // One shared closure keeps the parallel and sequential paths identical.
    let convolve = |b: usize, c: usize| {
        Self::convolve_channel(
            &input_array.slice(s![b, c, .., ..]),
            &self.weights.slice(s![c, 0, .., ..]),
            self.bias[c],
            (output_height, output_width),
            self.strides,
            self.kernel_size,
            &self.padding,
            pad_h,
            pad_w,
        )
    };
    let total_elements = batch_size * channels * output_height * output_width;
    if total_elements >= DEPTHWISE_CONV_2D_PARALLEL_THRESHOLD {
        output
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(b, mut batch_output)| {
                batch_output
                    .axis_iter_mut(Axis(0))
                    .into_par_iter()
                    .enumerate()
                    .for_each(|(c, mut channel_output)| {
                        channel_output.assign(&convolve(b, c));
                    });
            });
    } else {
        for b in 0..batch_size {
            for c in 0..channels {
                output.slice_mut(s![b, c, .., ..]).assign(&convolve(b, c));
            }
        }
    }
    // A single `into_dyn` suffices; the old code converted twice.
    self.activation.forward(&output.into_dyn())
}
/// Backward pass: propagates `grad_output` through the activation, then
/// computes bias, weight, and input gradients for the depthwise convolution.
///
/// Weight/bias gradients are stored on `self` for the optimizer-update
/// methods; the returned tensor is the gradient w.r.t. the layer input
/// (same shape as the cached forward input).
///
/// # Errors
/// Returns `ModelError::ProcessingError` if `forward` has not been run.
fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
// Undo the activation first: d(loss)/d(conv output).
let grad_upstream = self.activation.backward(grad_output)?;
let input = self.input.as_ref().ok_or_else(|| {
ModelError::ProcessingError("Forward pass has not been run".to_string())
})?;
let input_array = input.view();
let (batch_size, channels, input_height, input_width) = (
input_array.shape()[0],
input_array.shape()[1],
input_array.shape()[2],
input_array.shape()[3],
);
// NOTE(review): assumes `grad_upstream` is 4-D; a malformed gradient
// would panic on this shape indexing rather than return an error.
let (_, _, output_height, output_width) = (
grad_upstream.shape()[0],
grad_upstream.shape()[1],
grad_upstream.shape()[2],
grad_upstream.shape()[3],
);
let mut weight_grads = Array4::zeros(self.weights.raw_dim());
let mut bias_grads = Array1::zeros(self.bias.raw_dim());
let mut input_grads = Array4::zeros((batch_size, channels, input_height, input_width));
// Bias gradient: sum of upstream gradients over batch and spatial dims.
for c in 0..channels {
let mut channel_sum = 0.0;
for b in 0..batch_size {
channel_sum += grad_upstream.slice(s![b, c, .., ..]).sum();
}
bias_grads[c] = channel_sum;
}
let (pad_h, pad_w) =
self.calculate_padding(input_height, input_width, output_height, output_width);
// Large workloads: compute per-batch gradients in parallel, then reduce.
// Weight gradients are summed over the batch; input gradients are
// written into their per-batch slice.
let total_elements = batch_size * channels * output_height * output_width;
if total_elements >= DEPTHWISE_CONV_2D_PARALLEL_THRESHOLD {
let batch_results: Vec<(Array4<f32>, Array4<f32>)> = (0..batch_size)
.into_par_iter()
.map(|b| {
self.compute_batch_gradients(
&input_array,
&grad_upstream,
b,
channels,
input_height,
input_width,
output_height,
output_width,
pad_h,
pad_w,
)
})
.collect();
for (b, (batch_weight_grads, batch_input_grads)) in
batch_results.into_iter().enumerate()
{
weight_grads += &batch_weight_grads;
input_grads
.slice_mut(s![b, .., .., ..])
.assign(&batch_input_grads.slice(s![0, .., .., ..]));
}
} else {
// Small workloads: same reduction, sequentially.
for b in 0..batch_size {
let (batch_weight_grads, batch_input_grads) = self.compute_batch_gradients(
&input_array,
&grad_upstream,
b,
channels,
input_height,
input_width,
output_height,
output_width,
pad_h,
pad_w,
);
weight_grads += &batch_weight_grads;
input_grads
.slice_mut(s![b, .., .., ..])
.assign(&batch_input_grads.slice(s![0, .., .., ..]));
}
}
// Cache parameter gradients for the subsequent optimizer step.
self.weight_gradients = Some(weight_grads);
self.bias_gradients = Some(bias_grads);
Ok(input_grads.into_dyn())
}
/// Identifies this layer kind for model summaries and serialization.
fn layer_type(&self) -> &str {
"DepthwiseConv2D"
}
/// Formats the output shape as "(batch, channels, height, width)", or
/// "Unknown" when no forward pass has recorded an input shape yet.
fn output_shape(&self) -> String {
    // Guard clause: the shape is only known once `forward` has run.
    if self.input_shape.is_empty() {
        return String::from("Unknown");
    }
    let dims = calculate_output_shape_2d(
        &self.input_shape,
        self.kernel_size,
        self.strides,
        &self.padding,
    );
    format!("({}, {}, {}, {})", dims[0], dims[1], dims[2], dims[3])
}
/// Total number of trainable parameters: every kernel weight plus one bias
/// per filter.
fn param_count(&self) -> TrainingParameters {
TrainingParameters::Trainable(self.weights.len() + self.bias.len())
}
// Pulls in the shared SGD parameter-update implementation from the
// `update_sgd_conv!` macro defined elsewhere in the crate — presumably it
// expands to `fn update_parameters_sgd(...)` over `weights`/`bias` and the
// cached gradients; confirm against the macro definition.
update_sgd_conv!();
/// Applies one Adam update step to the kernel weights and biases.
///
/// `t` is the 1-based timestep used for bias correction. The moment state
/// is allocated lazily on first use. A no-op unless `backward` has stored
/// both weight and bias gradients.
fn update_parameters_adam(&mut self, lr: f32, beta1: f32, beta2: f32, epsilon: f32, t: u64) {
    // Nothing to do until a backward pass has produced gradients.
    let (weight_grads, bias_grads) =
        match (&self.weight_gradients, &self.bias_gradients) {
            (Some(w), Some(b)) => (w, b),
            _ => return,
        };
    let weight_dim = self.weights.raw_dim();
    let bias_len = self.bias.len();
    // Lazily create zeroed first/second-moment accumulators.
    let states = self
        .optimizer_cache
        .adam_states
        .get_or_insert_with(|| AdamStatesConv2D {
            m: Array4::zeros(weight_dim.clone()),
            v: Array4::zeros(weight_dim),
            m_bias: Array2::zeros((1, bias_len)),
            v_bias: Array2::zeros((1, bias_len)),
        });
    // Standard Adam bias-correction terms for timestep `t`.
    let bias_correction1 = 1.0 - beta1.powi(t as i32);
    let bias_correction2 = 1.0 - beta2.powi(t as i32);
    if let (Some(weights), Some(grads), Some(m), Some(v)) = (
        self.weights.as_slice_mut(),
        weight_grads.as_slice(),
        states.m.as_slice_mut(),
        states.v.as_slice_mut(),
    ) {
        update_adam_conv(
            weights,
            grads,
            m,
            v,
            lr,
            beta1,
            beta2,
            epsilon,
            bias_correction1,
            bias_correction2,
        );
    }
    if let (Some(bias), Some(grads), Some(m), Some(v)) = (
        self.bias.as_slice_mut(),
        bias_grads.as_slice(),
        states.m_bias.as_slice_mut(),
        states.v_bias.as_slice_mut(),
    ) {
        update_adam_conv(
            bias,
            grads,
            m,
            v,
            lr,
            beta1,
            beta2,
            epsilon,
            bias_correction1,
            bias_correction2,
        );
    }
}
/// Applies one RMSprop update step to the kernel weights and biases.
///
/// The moving-average cache is allocated lazily on first use. A no-op
/// unless `backward` has stored both weight and bias gradients.
fn update_parameters_rmsprop(&mut self, lr: f32, rho: f32, epsilon: f32) {
    // Nothing to do until a backward pass has produced gradients.
    let (weight_grads, bias_grads) =
        match (&self.weight_gradients, &self.bias_gradients) {
            (Some(w), Some(b)) => (w, b),
            _ => return,
        };
    let weight_dim = self.weights.raw_dim();
    let bias_len = self.bias.len();
    // Lazily create the zeroed squared-gradient caches.
    let rms_cache = self
        .optimizer_cache
        .rmsprop_cache
        .get_or_insert_with(|| RMSpropCacheConv2D {
            cache: Array4::zeros(weight_dim),
            bias: Array2::zeros((1, bias_len)),
        });
    if let (Some(weights), Some(grads), Some(cache)) = (
        self.weights.as_slice_mut(),
        weight_grads.as_slice(),
        rms_cache.cache.as_slice_mut(),
    ) {
        update_rmsprop(weights, grads, cache, rho, epsilon, lr);
    }
    if let (Some(bias), Some(grads), Some(cache)) = (
        self.bias.as_slice_mut(),
        bias_grads.as_slice(),
        rms_cache.bias.as_slice_mut(),
    ) {
        update_rmsprop(bias, grads, cache, rho, epsilon, lr);
    }
}
/// Applies one AdaGrad update step to the kernel weights and biases.
///
/// The gradient accumulators are allocated lazily on first use; the actual
/// update is delegated to the shared `update_adagrad_conv!` macro. A no-op
/// unless `backward` has stored both weight and bias gradients.
fn update_parameters_ada_grad(&mut self, lr: f32, epsilon: f32) {
    // Nothing to do until a backward pass has produced gradients.
    let (weight_gradients, bias_gradients) =
        match (&self.weight_gradients, &self.bias_gradients) {
            (Some(w), Some(b)) => (w, b),
            _ => return,
        };
    let weight_dim = self.weights.dim();
    let bias_len = self.bias.len();
    // Lazily create the zeroed accumulators before the macro uses them.
    self.optimizer_cache
        .ada_grad_cache
        .get_or_insert_with(|| AdaGradStatesConv2D {
            accumulator: Array4::zeros(weight_dim),
            accumulator_bias: Array2::zeros((1, bias_len)),
        });
    update_adagrad_conv!(self, weight_gradients, bias_gradients, lr, epsilon);
}
/// Exposes borrowed views of the kernel weights and biases, wrapped in the
/// crate's `LayerWeight` enum (e.g. for serialization or inspection).
fn get_weights(&self) -> LayerWeight<'_> {
LayerWeight::DepthwiseConv2DLayer(DepthwiseConv2DLayerWeight {
weight: &self.weights,
bias: &self.bias,
})
}
}