use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::convolution_layer::PaddingType;
use crate::neural_network::layer::convolution_layer::input_validation_function::{
validate_filters, validate_input_shape_1d, validate_kernel_size_1d, validate_strides_1d,
};
use crate::neural_network::layer::helper_function::update_adam_conv;
use crate::neural_network::layer::layer_weight::{Conv1DLayerWeight, LayerWeight};
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
use crate::neural_network::optimizer::OptimizerCacheConv1D;
use crate::neural_network::optimizer::ada_grad::AdaGradStatesConv1D;
use crate::neural_network::optimizer::adam::AdamStatesConv1D;
use crate::neural_network::optimizer::rms_prop::RMSpropCacheConv1D;
use crate::neural_network::optimizer::sgd::SGD;
use ndarray::{Array2, Array3, Axis, s};
use ndarray_rand::{RandomExt, rand_distr::Uniform};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelBridge, ParallelIterator,
};
/// Minimum total work size (batch * filters * output_length) before the
/// forward/backward passes dispatch to rayon worker threads instead of
/// running sequentially.
const CONV_1D_PARALLEL_THRESHOLD: usize = 1000;
/// 1D convolution layer with a configurable activation applied after the
/// convolution. Expects rank-3 input tensors shaped `(batch, channels, length)`.
pub struct Conv1D<T: ActivationLayer> {
/// Number of convolution filters (output channels).
filters: usize,
/// Length of each 1D kernel window.
kernel_size: usize,
/// Step between consecutive kernel windows along the length axis.
stride: usize,
/// `Valid` (no padding) or `Same` (zero-padding; see `apply_padding`).
padding: PaddingType,
/// Kernel weights, shape `(filters, input_channels, kernel_size)`.
weights: Array3<f32>,
/// Per-filter biases, shape `(1, filters)`.
bias: Array2<f32>,
/// Activation layer applied to the convolution output in `forward`.
activation: T,
/// Input saved by `forward`, required by `backward`.
input_cache: Option<Tensor>,
/// Configured input shape `[batch, channels, length]` (indices 0/1/2 are
/// used as batch/channels/length throughout this file).
input_shape: Vec<usize>,
/// Weight gradients produced by the most recent `backward` call.
weight_gradients: Option<Array3<f32>>,
/// Bias gradients produced by the most recent `backward` call.
bias_gradients: Option<Array2<f32>>,
/// Lazily-initialized per-optimizer state (Adam / RMSprop / AdaGrad).
optimizer_cache: OptimizerCacheConv1D,
}
impl<T: ActivationLayer> Conv1D<T> {
/// Builds a `Conv1D` layer with Glorot/Xavier-uniform initialized weights
/// and zero biases.
///
/// # Arguments
/// * `filters` - number of output channels.
/// * `kernel_size` - length of each kernel window.
/// * `input_shape` - expected input shape `[batch, channels, length]`.
/// * `stride` - step between consecutive windows.
/// * `padding` - `Valid` or `Same` padding strategy.
/// * `activation` - activation applied after the convolution.
///
/// # Errors
/// Returns a `ModelError` when any structural parameter fails validation.
pub fn new(
    filters: usize,
    kernel_size: usize,
    input_shape: Vec<usize>,
    stride: usize,
    padding: PaddingType,
    activation: T,
) -> Result<Self, ModelError> {
    validate_filters(filters)?;
    validate_kernel_size_1d(kernel_size)?;
    validate_strides_1d(stride)?;
    validate_input_shape_1d(&input_shape, kernel_size)?;
    let in_channels = input_shape[1];
    // Glorot uniform bound: sqrt(6 / (fan_in + fan_out)).
    let fan_in = in_channels * kernel_size;
    let fan_out = filters * kernel_size;
    let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
    let weights = Array3::random(
        (filters, in_channels, kernel_size),
        Uniform::new(-limit, limit).unwrap(),
    );
    Ok(Self {
        filters,
        kernel_size,
        stride,
        padding,
        weights,
        bias: Array2::zeros((1, filters)),
        activation,
        input_cache: None,
        input_shape,
        weight_gradients: None,
        bias_gradients: None,
        optimizer_cache: OptimizerCacheConv1D {
            adam_states: None,
            rmsprop_cache: None,
            ada_grad_cache: None,
        },
    })
}
/// Number of output positions along the length axis for the given input
/// length (padded length in the `Same` case is handled by `apply_padding`).
fn calculate_output_length(&self, input_length: usize) -> usize {
    if matches!(self.padding, PaddingType::Valid) {
        // `Valid`: count only windows that fit entirely inside the input.
        (input_length - self.kernel_size) / self.stride + 1
    } else {
        // `Same`: ceil(input_length / stride) positions.
        (input_length + self.stride - 1) / self.stride
    }
}
/// Returns the input with zero padding applied along the length axis.
///
/// `Valid` padding is a no-op (the tensor is cloned unchanged). `Same`
/// padding zero-fills both ends; the extra zeros are split with the floor
/// half on the left (see `calculate_padding_params`).
fn apply_padding(&self, input: &Tensor) -> Tensor {
    if matches!(self.padding, PaddingType::Valid) {
        return input.clone();
    }
    let shape = input.shape();
    let (batch, channels, length) = (shape[0], shape[1], shape[2]);
    let (pad_total, pad_left) = self.calculate_padding_params(length);
    let mut padded = Array3::zeros((batch, channels, length + pad_total));
    let source = input.view().into_dimensionality::<ndarray::Ix3>().unwrap();
    padded
        .slice_mut(s![.., .., pad_left..pad_left + length])
        .assign(&source);
    padded.into_dyn()
}
/// Computes `(total_padding, left_padding)` for `Same` padding at the given
/// input length. Left padding takes the floor half, so an odd leftover zero
/// goes to the right side.
fn calculate_padding_params(&self, input_length: usize) -> (usize, usize) {
    // ceil(input_length / stride) output positions must all be reachable.
    let out_len = (input_length + self.stride - 1) / self.stride;
    let needed = (out_len - 1) * self.stride + self.kernel_size;
    let total = needed.saturating_sub(input_length);
    (total, total / 2)
}
/// Overwrites the layer's trainable parameters.
///
/// `weights` should be shaped `(filters, input_channels, kernel_size)` and
/// `bias` `(1, filters)` to match the forward pass.
/// NOTE(review): shapes are not validated here — confirm callers guarantee
/// them before relying on this for untrusted input.
pub fn set_weights(&mut self, weights: Array3<f32>, bias: Array2<f32>) {
self.weights = weights;
self.bias = bias;
}
/// Maps a position in the padded input back to the corresponding position
/// in the original (unpadded) input, or `None` when it falls inside the
/// zero padding.
fn get_original_input_pos(&self, padded_pos: usize, input_length: usize) -> Option<usize> {
    match self.padding {
        // No padding: the position is valid iff it is inside the input.
        PaddingType::Valid => (padded_pos < input_length).then_some(padded_pos),
        PaddingType::Same => {
            let (_, pad_left) = self.calculate_padding_params(input_length);
            // Underflow (left pad) yields None; filter rejects the right pad.
            padded_pos
                .checked_sub(pad_left)
                .filter(|&pos| pos < input_length)
        }
    }
}
/// Dot product of one kernel window against the padded input, plus bias.
///
/// `out_pos * stride` selects the window start in `input_3d`; window
/// positions past the end of the padded input contribute nothing.
fn compute_conv_output(
    &self,
    input_3d: &Array3<f32>,
    batch: usize,
    filter: usize,
    out_pos: usize,
    input_length: usize,
) -> f32 {
    let window_start = out_pos * self.stride;
    let mut acc = 0.0;
    for channel in 0..self.input_shape[1] {
        for k in 0..self.kernel_size {
            let pos = window_start + k;
            if pos >= input_length {
                continue;
            }
            acc += input_3d[[batch, channel, pos]] * self.weights[[filter, channel, k]];
        }
    }
    acc + self.bias[[0, filter]]
}
/// Computes the gradients contributed by a single batch element.
///
/// Returns `(weight_grads, bias_grads, input_grads)`, where `input_grads`
/// is shaped `(input_channels, input_length)` in *unpadded* coordinates:
/// `get_original_input_pos` maps each padded position back, and gradient
/// flowing into the zero padding is discarded.
fn compute_batch_gradients(
&self,
batch: usize,
grad_output_3d: &Array3<f32>,
input_3d: &Array3<f32>,
input_channels: usize,
input_length: usize,
output_length: usize,
) -> (Array3<f32>, Array2<f32>, Array2<f32>) {
let mut local_weight_gradients = Array3::zeros(self.weights.dim());
let mut local_bias_gradients = Array2::zeros(self.bias.dim());
let mut local_input_gradients = Array2::zeros((input_channels, input_length));
for filter in 0..self.filters {
for out_pos in 0..output_length {
let grad_val = grad_output_3d[[batch, filter, out_pos]];
let start_pos = out_pos * self.stride;
// dL/db: each output position adds the raw upstream gradient.
local_bias_gradients[[0, filter]] += grad_val;
for in_channel in 0..input_channels {
for kernel_pos in 0..self.kernel_size {
let input_pos = start_pos + kernel_pos;
// Skip window positions past the end of the padded input.
if input_pos < input_3d.shape()[2] {
// dL/dW: upstream gradient times the input under the window.
local_weight_gradients[[filter, in_channel, kernel_pos]] +=
grad_val * input_3d[[batch, in_channel, input_pos]];
// dL/dX: upstream gradient times the weight, mapped back
// to unpadded coordinates (None means it hit padding).
if let Some(original_input_pos) =
self.get_original_input_pos(input_pos, input_length)
{
local_input_gradients[[in_channel, original_input_pos]] +=
grad_val * self.weights[[filter, in_channel, kernel_pos]];
}
}
}
}
}
}
(
local_weight_gradients,
local_bias_gradients,
local_input_gradients,
)
}
/// Performs the 1D convolution over a zero-padded copy of `input`.
///
/// Output shape is `(batch, filters, output_length)`. When the total work
/// exceeds `CONV_1D_PARALLEL_THRESHOLD`, the batch axis is processed in
/// parallel; each batch slice is then filled sequentially. The previous
/// version nested three levels of rayon parallelism and used `par_bridge`
/// to spawn one task per output *scalar* — the scheduling overhead far
/// exceeds the per-element dot product, so parallelism is now kept at the
/// batch level only. Every element is computed independently, so results
/// are bit-identical to the sequential path.
fn conv1d(&self, input: &Tensor) -> Tensor {
    let padded_input = self.apply_padding(input);
    let input_shape = padded_input.shape();
    let batch_size = input_shape[0];
    let input_length = input_shape[2];
    let output_length = self.calculate_output_length(input_length);
    let mut output = Array3::zeros((batch_size, self.filters, output_length));
    let input_3d = padded_input.into_dimensionality::<ndarray::Ix3>().unwrap();
    let total_ops = batch_size * self.filters * output_length;
    if total_ops >= CONV_1D_PARALLEL_THRESHOLD {
        // Parallelize over batches only; each worker owns one output slice.
        output
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(batch, mut batch_output)| {
                for filter in 0..self.filters {
                    for out_pos in 0..output_length {
                        batch_output[[filter, out_pos]] = self.compute_conv_output(
                            &input_3d,
                            batch,
                            filter,
                            out_pos,
                            input_length,
                        );
                    }
                }
            });
    } else {
        for batch in 0..batch_size {
            for filter in 0..self.filters {
                for out_pos in 0..output_length {
                    output[[batch, filter, out_pos]] = self.compute_conv_output(
                        &input_3d,
                        batch,
                        filter,
                        out_pos,
                        input_length,
                    );
                }
            }
        }
    }
    output.into_dyn()
}
}
impl<T: ActivationLayer> Layer for Conv1D<T> {
/// Runs the convolution and then the activation function.
///
/// Caches the raw input for the subsequent `backward` pass.
///
/// # Errors
/// Returns `ModelError::InputValidationError` when `input` is not rank-3.
fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
    match input.ndim() {
        3 => {
            self.input_cache = Some(input.clone());
            let pre_activation = self.conv1d(input);
            self.activation.forward(&pre_activation.into_dyn())
        }
        _ => Err(ModelError::InputValidationError(
            "input tensor is not 3D".to_string(),
        )),
    }
}
/// Backpropagates `grad_output` through the activation and the convolution.
///
/// Stores weight/bias gradients on `self` for the optimizer update methods
/// and returns the gradient with respect to the (unpadded) layer input,
/// shaped like the cached forward input.
///
/// # Errors
/// Returns `ModelError::ProcessingError` when `forward` was never called
/// (no cached input) or a tensor cannot be viewed as rank-3.
fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
// Undo the activation first: chain rule through the nonlinearity.
let grad_upstream = self.activation.backward(grad_output)?;
let input = self.input_cache.as_ref().ok_or_else(|| {
ModelError::ProcessingError("No cached input for backward pass".to_string())
})?;
let input_shape = input.shape();
let batch_size = input_shape[0];
let input_channels = input_shape[1];
let input_length = input_shape[2];
let grad_upstream_3d = grad_upstream
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| {
ModelError::ProcessingError(format!("Failed to convert gradient output: {}", e))
})?;
let mut weight_gradients = Array3::zeros(self.weights.dim());
let mut bias_gradients = Array2::zeros(self.bias.dim());
let mut input_gradients = Array3::zeros((batch_size, input_channels, input_length));
// Gradients are computed against the padded input, matching forward.
let padded_input = self.apply_padding(input);
let input_3d = padded_input
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| ModelError::ProcessingError(format!("Failed to convert input: {}", e)))?;
let output_length = grad_upstream_3d.shape()[2];
let total_ops = batch_size * self.filters * output_length;
if total_ops >= CONV_1D_PARALLEL_THRESHOLD {
// Parallel path: each batch's gradients are computed independently,
// then reduced sequentially so weight/bias sums stay deterministic.
let batch_results: Vec<_> = (0..batch_size)
.into_par_iter()
.map(|batch| {
self.compute_batch_gradients(
batch,
&grad_upstream_3d,
&input_3d,
input_channels,
input_length,
output_length,
)
})
.collect();
for (batch_idx, (local_weight_grads, local_bias_grads, local_input_grads)) in
batch_results.into_iter().enumerate()
{
weight_gradients += &local_weight_grads;
bias_gradients += &local_bias_grads;
input_gradients
.slice_mut(s![batch_idx, .., ..])
.assign(&local_input_grads);
}
} else {
// Sequential path: same math, no intermediate Vec allocation.
for batch in 0..batch_size {
let (local_weight_grads, local_bias_grads, local_input_grads) = self
.compute_batch_gradients(
batch,
&grad_upstream_3d,
&input_3d,
input_channels,
input_length,
output_length,
);
weight_gradients += &local_weight_grads;
bias_gradients += &local_bias_grads;
input_gradients
.slice_mut(s![batch, .., ..])
.assign(&local_input_grads);
}
}
// Stash parameter gradients for the optimizer update methods.
self.weight_gradients = Some(weight_gradients);
self.bias_gradients = Some(bias_gradients);
Ok(input_gradients.into_dyn())
}
/// Human-readable layer identifier used in model summaries.
fn layer_type(&self) -> &str {
"Conv1D"
}
/// Formats the layer's output shape as `(batch, filters, output_length)`,
/// derived from the configured input shape.
fn output_shape(&self) -> String {
    let out_len = self.calculate_output_length(self.input_shape[2]);
    format!("({}, {}, {})", self.input_shape[0], self.filters, out_len)
}
/// Total trainable parameter count: all kernel weights plus all biases.
fn param_count(&self) -> TrainingParameters {
TrainingParameters::Trainable(self.weights.len() + self.bias.len())
}
// SGD parameter update is generated by a shared macro — presumably using
// the `SGD` optimizer imported at the top of the file; confirm in the
// macro's definition.
update_sgd_conv!();
/// Applies one Adam optimization step to the weights and biases.
///
/// Lazily allocates the first/second-moment state on the first call, then
/// delegates the element-wise update to `update_adam_conv` with the
/// standard bias corrections `1 - beta^t`. No-op when `backward` has not
/// produced gradients yet. Replaces the previous `is_none()` check +
/// insert + re-match with the equivalent `Option::get_or_insert_with`,
/// removing a level of nesting.
fn update_parameters_adam(&mut self, lr: f32, beta1: f32, beta2: f32, epsilon: f32, t: u64) {
    if let (Some(weight_gradients), Some(bias_gradients)) =
        (&self.weight_gradients, &self.bias_gradients)
    {
        // Copy dims out first so the init closure does not borrow `self`.
        let weight_dim = self.weights.dim();
        let bias_dim = self.bias.dim();
        // Allocate zeroed moment estimates on the first update only.
        let adam_states = self
            .optimizer_cache
            .adam_states
            .get_or_insert_with(|| AdamStatesConv1D {
                m: Array3::zeros(weight_dim),
                v: Array3::zeros(weight_dim),
                m_bias: Array2::zeros(bias_dim),
                v_bias: Array2::zeros(bias_dim),
            });
        let bias_correction1 = 1.0 - beta1.powi(t as i32);
        let bias_correction2 = 1.0 - beta2.powi(t as i32);
        update_adam_conv(
            self.weights.as_slice_mut().unwrap(),
            weight_gradients.as_slice().unwrap(),
            adam_states.m.as_slice_mut().unwrap(),
            adam_states.v.as_slice_mut().unwrap(),
            lr,
            beta1,
            beta2,
            epsilon,
            bias_correction1,
            bias_correction2,
        );
        update_adam_conv(
            self.bias.as_slice_mut().unwrap(),
            bias_gradients.as_slice().unwrap(),
            adam_states.m_bias.as_slice_mut().unwrap(),
            adam_states.v_bias.as_slice_mut().unwrap(),
            lr,
            beta1,
            beta2,
            epsilon,
            bias_correction1,
            bias_correction2,
        );
    }
}
/// Applies one RMSprop optimization step to the weights and biases.
///
/// Lazily allocates the squared-gradient caches on the first call; the
/// element-wise rule is `cache = rho*cache + (1-rho)*g^2`, then
/// `param -= lr * g / (sqrt(cache) + epsilon)`, parallelized with rayon.
/// No-op when `backward` has not produced gradients yet. Replaces the
/// previous `is_none()` check + insert + re-match with the equivalent
/// `Option::get_or_insert_with`, removing a level of nesting.
fn update_parameters_rmsprop(&mut self, lr: f32, rho: f32, epsilon: f32) {
    if let (Some(weight_gradients), Some(bias_gradients)) =
        (&self.weight_gradients, &self.bias_gradients)
    {
        // Copy dims out first so the init closure does not borrow `self`.
        let weight_dim = self.weights.dim();
        let bias_dim = self.bias.dim();
        let rmsprop_cache = self
            .optimizer_cache
            .rmsprop_cache
            .get_or_insert_with(|| RMSpropCacheConv1D {
                cache: Some(Array3::zeros(weight_dim)),
                bias: Some(Array2::zeros(bias_dim)),
            });
        // Shared element-wise update applied to both weight and bias slices.
        let update_parameters = |params: &mut [f32], cache: &mut [f32], grads: &[f32]| {
            cache
                .par_iter_mut()
                .zip(grads.par_iter())
                .for_each(|(c, &grad)| {
                    *c = rho * *c + (1.0 - rho) * grad * grad;
                });
            params
                .par_iter_mut()
                .zip(grads.par_iter())
                .zip(cache.par_iter())
                .for_each(|((param, &grad), &cache_val)| {
                    *param -= lr * grad / (cache_val.sqrt() + epsilon);
                });
        };
        if let Some(weight_cache) = &mut rmsprop_cache.cache {
            update_parameters(
                self.weights.as_slice_mut().unwrap(),
                weight_cache.as_slice_mut().unwrap(),
                weight_gradients.as_slice().unwrap(),
            );
        }
        if let Some(bias_cache) = &mut rmsprop_cache.bias {
            update_parameters(
                self.bias.as_slice_mut().unwrap(),
                bias_cache.as_slice_mut().unwrap(),
                bias_gradients.as_slice().unwrap(),
            );
        }
    }
}
/// Applies one AdaGrad optimization step to the weights and biases.
///
/// Allocates zeroed squared-gradient accumulators on first use; the actual
/// element-wise update is generated by the `update_adagrad_conv!` macro —
/// defined elsewhere, presumably reading `optimizer_cache.ada_grad_cache`;
/// confirm against the macro's definition. No-op when `backward` has not
/// produced gradients yet.
fn update_parameters_ada_grad(&mut self, lr: f32, epsilon: f32) {
if let (Some(weight_gradients), Some(bias_gradients)) =
(&self.weight_gradients, &self.bias_gradients)
{
if self.optimizer_cache.ada_grad_cache.is_none() {
self.optimizer_cache.ada_grad_cache = Some(AdaGradStatesConv1D {
accumulator: Array3::zeros(self.weights.dim()),
accumulator_bias: Array2::zeros(self.bias.dim()),
});
}
update_adagrad_conv!(self, weight_gradients, bias_gradients, lr, epsilon);
}
}
/// Exposes borrowed views of the kernel weights and biases, wrapped in the
/// `LayerWeight::Conv1D` variant for inspection/serialization.
fn get_weights(&self) -> LayerWeight<'_> {
LayerWeight::Conv1D(Conv1DLayerWeight {
weight: &self.weights,
bias: &self.bias,
})
}
}