rustyml 0.11.0

A high-performance machine learning & deep learning library in pure Rust, offering ML algorithms and neural network support
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::convolution_layer::PaddingType;
use crate::neural_network::layer::convolution_layer::input_validation_function::{
    validate_filters, validate_input_shape_2d, validate_kernel_size_2d, validate_strides_2d,
};
use crate::neural_network::layer::helper_function::{
    compute_row_gradient_sum, merge_results, update_adam_conv, update_rmsprop,
};
use crate::neural_network::layer::layer_weight::{Conv2DLayerWeight, LayerWeight};
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
use crate::neural_network::optimizer::OptimizerCacheConv2D;
use crate::neural_network::optimizer::ada_grad::AdaGradStatesConv2D;
use crate::neural_network::optimizer::adam::AdamStatesConv2D;
use crate::neural_network::optimizer::rms_prop::RMSpropCacheConv2D;
use crate::neural_network::optimizer::sgd::SGD;
use ndarray::{Array2, Array3, Array4, Axis, s};
use ndarray_rand::{RandomExt, rand_distr::Uniform};
use rayon::iter::{
    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
    IntoParallelRefMutIterator, ParallelIterator,
};

/// Workload size at which the Conv2D forward convolution switches to Rayon.
///
/// The workload is measured as `batch_size * filters * output_height * output_width`.
/// Below this value the batch is processed with a plain sequential loop; at or above it,
/// each batch item is convolved on Rayon's parallel iterator.
const CONV_2D_PARALLEL_THRESHOLD: usize = 10_000;

/// A 2D convolutional layer for neural networks.
///
/// Applies a convolution operation to grid-like data such as images. Input shape is
/// \[batch_size, channels, height, width\] and output shape is
/// \[batch_size, filters, output_height, output_width\], where output dimensions depend on
/// input size, kernel size, strides, and padding.
///
/// # Output size
///
/// Per spatial axis (integer arithmetic):
///
/// - `Valid`: `output = (input - kernel) / stride + 1`; no padding is applied.
/// - `Same`: `output = ceil(input / stride)`. The input is zero-padded by
///   `(output - 1) * stride + kernel - input` elements (clamped at 0); half of that amount
///   (rounded down) goes on the top/left edge and the remainder on the bottom/right edge.
///
/// # Fields
///
/// - `filters` - Number of convolution filters (output channels).
/// - `kernel_size` - Size of the convolution kernel as (height, width).
/// - `strides` - Stride values for the convolution operation as (vertical, horizontal).
/// - `padding` - Type of padding to apply (`Valid` or `Same`).
/// - `weights` - 4D array of filter weights with shape \[filters, channels, kernel_height, kernel_width\].
/// - `bias` - 2D array of bias values with shape \[1, filters\].
/// - `activation` - Activation layer applied to the convolution output; its `backward` is run
///   first when backpropagating through this layer.
/// - `input_cache` - Copy of the most recent input passed to `forward`, used during
///   backpropagation; `None` until `forward` has run.
/// - `input_shape` - Input shape given to the constructor; used to report `output_shape()`.
/// - `weight_gradients` - Gradients for the weights, computed during backpropagation.
/// - `bias_gradients` - Gradients for the biases, computed during backpropagation.
/// - `optimizer_cache` - Lazily-initialized per-optimizer state (Adam moments, RMSprop and
///   AdaGrad accumulators), created on the first update with that optimizer.
///
/// # Examples
/// ```rust
/// use rustyml::neural_network::sequential::Sequential;
/// use rustyml::neural_network::layer::*;
/// use rustyml::neural_network::optimizer::*;
/// use rustyml::neural_network::loss_function::*;
/// use ndarray::Array4;
///
/// // Create a simple 4D input tensor: [batch_size, channels, height, width]
/// // Batch size=2, 1 input channel, 5x5 pixels
/// let x = Array4::ones((2, 1, 5, 5)).into_dyn();
///
/// // Create target tensor - assuming we'll have 3 filters with output size 3x3
/// let y = Array4::ones((2, 3, 3, 3)).into_dyn();
///
/// // Build model: add a Conv2D layer with 3 filters and 3x3 kernel
/// let mut model = Sequential::new();
/// model
///     .add(Conv2D::new(
///         3,                      // Number of filters
///         (3, 3),                 // Kernel size
///         vec![2, 1, 5, 5],       // Input shape
///         (1, 1),                 // Stride
///         PaddingType::Valid,     // No padding
///         ReLU::new(), // ReLU activation layer
///     ).unwrap())
///     .compile(RMSprop::new(0.001, 0.9, 1e-8).unwrap(), MeanSquaredError::new());
///
/// // Print model structure
/// model.summary();
///
/// // Train the model (run a few epochs)
/// model.fit(&x, &y, 3).unwrap();
///
/// // Use predict for forward propagation prediction
/// let prediction = model.predict(&x).unwrap();
/// println!("Convolution layer prediction results: {:?}", prediction);
///
/// // Check if output shape is correct - should be [2, 3, 3, 3]
/// assert_eq!(prediction.shape(), &[2, 3, 3, 3]);
/// ```
pub struct Conv2D<T: ActivationLayer> {
    filters: usize,
    kernel_size: (usize, usize),
    strides: (usize, usize),
    padding: PaddingType,
    weights: Array4<f32>,
    bias: Array2<f32>,
    activation: T,
    input_cache: Option<Tensor>,
    input_shape: Vec<usize>,
    weight_gradients: Option<Array4<f32>>,
    bias_gradients: Option<Array2<f32>>,
    optimizer_cache: OptimizerCacheConv2D,
}

impl<T: ActivationLayer> Conv2D<T> {
    /// Creates a new 2D convolutional layer with the specified parameters.
    ///
    /// Weights are initialized using Xavier (Glorot) uniform initialization.
    /// Biases are initialized to zeros.
    ///
    /// # Parameters
    ///
    /// - `filters` - Number of convolution filters (output channels).
    /// - `kernel_size` - Size of the convolution kernel as (height, width).
    /// - `input_shape` - Shape of the input tensor as \[batch_size, channels, height, width\].
    /// - `strides` - Stride values for the convolution operation as (vertical, horizontal).
    /// - `padding` - Type of padding to apply (`Valid` or `Same`).
    /// - `activation` - Activation layer from activation_layer module (ReLU, Sigmoid, Tanh, Softmax)
    ///
    /// # Returns
    ///
    /// - `Result<Self, ModelError>` - A new `Conv2D` layer instance with randomly initialized weights or an error
    ///
    /// # Errors
    ///
    /// - `ModelError::InputValidationError` - If `filters` is 0
    /// - `ModelError::InputValidationError` - If any kernel dimension or stride is 0
    /// - `ModelError::InputValidationError` - If `input_shape` is not 4D or has 0 channels
    /// - `ModelError::InputValidationError` - If input dimensions are smaller than kernel size
    pub fn new(
        filters: usize,
        kernel_size: (usize, usize),
        input_shape: Vec<usize>,
        strides: (usize, usize),
        padding: PaddingType,
        activation: T,
    ) -> Result<Self, ModelError> {
        validate_filters(filters)?;
        validate_kernel_size_2d(kernel_size)?;
        validate_strides_2d(strides)?;
        validate_input_shape_2d(&input_shape, kernel_size)?;

        // Shape is [batch_size, channels, height, width]
        let channels = input_shape[1];

        // Initialize weights using Xavier initialization for convolutional layers
        // Formula: sqrt(6 / (input_channels * kernel_area + filters * kernel_area))
        let fan_in = channels * kernel_size.0 * kernel_size.1;
        let fan_out = filters * kernel_size.0 * kernel_size.1;
        let weight_bound = (6.0 / (fan_in + fan_out) as f32).sqrt();

        let weights = Array4::random(
            (filters, channels, kernel_size.0, kernel_size.1),
            Uniform::new(-weight_bound, weight_bound).unwrap(),
        );

        // Initialize biases to zero
        let bias = Array2::zeros((1, filters));

        Ok(Conv2D {
            filters,
            kernel_size,
            strides,
            padding,
            weights,
            bias,
            activation,
            input_cache: None,
            input_shape,
            weight_gradients: None,
            bias_gradients: None,
            optimizer_cache: OptimizerCacheConv2D {
                adam_states: None,
                rmsprop_cache: None,
                ada_grad_cache: None,
            },
        })
    }

    /// Calculates the output shape of the convolutional layer based on input dimensions.
    ///
    /// `input_shape` is \[batch_size, channels, height, width\]. For `Valid` padding the
    /// spatial dimensions must be at least the kernel size; smaller values would underflow
    /// the `usize` subtraction below, so callers must validate first.
    fn calculate_output_shape(&self, input_shape: &[usize]) -> Vec<usize> {
        let batch_size = input_shape[0];
        let input_height = input_shape[2];
        let input_width = input_shape[3];

        let (output_height, output_width) = match self.padding {
            PaddingType::Valid => {
                let out_height = (input_height - self.kernel_size.0) / self.strides.0 + 1;
                let out_width = (input_width - self.kernel_size.1) / self.strides.1 + 1;
                (out_height, out_width)
            }
            PaddingType::Same => {
                // ceil(input / stride)
                let out_height = (input_height + self.strides.0 - 1) / self.strides.0;
                let out_width = (input_width + self.strides.1 - 1) / self.strides.1;
                (out_height, out_width)
            }
        };

        vec![batch_size, self.filters, output_height, output_width]
    }

    /// Sets the weights and bias for this layer.
    ///
    /// Arrays that are not in standard (row-major, contiguous) layout, such as the result
    /// of `reversed_axes()` or `permuted_axes()`, are copied into standard layout. The
    /// Adam and RMSprop updates operate on `as_slice_mut()`, which returns `None` for
    /// non-standard layouts. Storing such an array unchanged would therefore make those
    /// updates silently do nothing.
    ///
    /// Shapes are not validated here. The caller must supply weights of shape
    /// \[filters, channels, kernel_height, kernel_width\] and bias of shape \[1, filters\].
    ///
    /// # Parameters
    ///
    /// - `weights` - 4D array of filter weights with shape \[filters, channels, kernel_height, kernel_width\]
    /// - `bias` - 2D array of bias values with shape \[1, filters\]
    pub fn set_weights(&mut self, weights: Array4<f32>, bias: Array2<f32>) {
        self.weights = if weights.is_standard_layout() {
            weights
        } else {
            weights.as_standard_layout().into_owned()
        };
        self.bias = if bias.is_standard_layout() {
            bias
        } else {
            bias.as_standard_layout().into_owned()
        };
    }

    /// Zero-pads the input for `PaddingType::Same`; returns a plain copy for `Valid`.
    ///
    /// Total padding per spatial axis is `(ceil(input / stride) - 1) * stride + kernel - input`
    /// (clamped at 0). Half of it (rounded down) is placed on the top/left edge and the
    /// remainder on the bottom/right edge. The input must be 4D.
    fn apply_padding(&self, input: &Tensor) -> Tensor {
        match self.padding {
            PaddingType::Valid => input.clone(),
            PaddingType::Same => {
                let input_shape = input.shape();
                let batch_size = input_shape[0];
                let channels = input_shape[1];
                let input_height = input_shape[2];
                let input_width = input_shape[3];

                // Calculate padding amounts
                let out_height = (input_height + self.strides.0 - 1) / self.strides.0;
                let out_width = (input_width + self.strides.1 - 1) / self.strides.1;

                // `saturating_sub(1)` keeps a zero-sized spatial axis (output size 0) from
                // underflowing `usize`.
                let pad_height = (out_height.saturating_sub(1) * self.strides.0
                    + self.kernel_size.0)
                    .saturating_sub(input_height);
                let pad_width = (out_width.saturating_sub(1) * self.strides.1
                    + self.kernel_size.1)
                    .saturating_sub(input_width);

                let pad_top = pad_height / 2;
                let pad_left = pad_width / 2;

                let padded_height = input_height + pad_height;
                let padded_width = input_width + pad_width;

                let mut padded = Array4::zeros((batch_size, channels, padded_height, padded_width));
                let input_4d = input
                    .view()
                    .into_dimensionality::<ndarray::Ix4>()
                    .expect("Conv2D::apply_padding requires a 4D input tensor");

                padded
                    .slice_mut(s![
                        ..,
                        ..,
                        pad_top..pad_top + input_height,
                        pad_left..pad_left + input_width
                    ])
                    .assign(&input_4d);

                padded.into_dyn()
            }
        }
    }

    /// Computes convolution for a single batch item `b`.
    ///
    /// Returns `(b, output)` where `output` has shape
    /// \[filters, output_height, output_width\], so results can be merged in batch order.
    fn compute_batch_convolution(
        &self,
        b: usize,
        padded_input: &Tensor,
        in_channels: usize,
        output_shape: &[usize],
    ) -> (usize, Array3<f32>) {
        // Create output portion for this batch
        let mut batch_output = Array3::zeros((self.filters, output_shape[2], output_shape[3]));

        // Computation for each batch
        for f in 0..self.filters {
            for i in 0..output_shape[2] {
                let i_base = i * self.strides.0;

                for j in 0..output_shape[3] {
                    let j_base = j * self.strides.1;
                    let mut sum = 0.0;

                    // Convolution kernel calculation
                    for c in 0..in_channels {
                        for ki in 0..self.kernel_size.0 {
                            let i_pos = i_base + ki;

                            for kj in 0..self.kernel_size.1 {
                                let j_pos = j_base + kj;
                                sum += padded_input[[b, c, i_pos, j_pos]]
                                    * self.weights[[f, c, ki, kj]];
                            }
                        }
                    }

                    // Update batch output
                    sum += self.bias[[0, f]];
                    batch_output[[f, i, j]] = sum;
                }
            }
        }

        (b, batch_output)
    }

    /// Performs the convolution operation on the input tensor.
    ///
    /// The batch is processed sequentially when
    /// `batch_size * filters * output_height * output_width` is below
    /// `CONV_2D_PARALLEL_THRESHOLD`, and in parallel over batch items otherwise.
    ///
    /// # Parameters
    ///
    /// - `input` - Input tensor with shape \[batch_size, channels, height, width\]
    ///
    /// # Returns
    ///
    /// - `Tensor` - Output tensor with shape \[batch_size, filters, output_height, output_width\]
    fn convolve(&self, input: &Tensor) -> Tensor {
        // `Valid` convolution reads the input directly. Only `Same` needs a padded copy,
        // which avoids a full-tensor clone on every `Valid` forward pass.
        let padded_storage;
        let padded_input: &Tensor = match self.padding {
            PaddingType::Valid => input,
            PaddingType::Same => {
                padded_storage = self.apply_padding(input);
                &padded_storage
            }
        };
        let input_shape = padded_input.shape();
        let batch_size = input_shape[0];
        let in_channels = input_shape[1];
        let output_shape = self.calculate_output_shape(input.shape());

        // Calculate workload size to decide between parallel and sequential execution
        let workload_size = batch_size * self.filters * output_shape[2] * output_shape[3];

        // Choose execution strategy based on workload
        let results: Vec<_> = if workload_size >= CONV_2D_PARALLEL_THRESHOLD {
            // Use parallel processing for large workloads
            (0..batch_size)
                .into_par_iter()
                .map(|b| {
                    self.compute_batch_convolution(b, padded_input, in_channels, &output_shape)
                })
                .collect()
        } else {
            // Use sequential processing for small workloads
            (0..batch_size)
                .map(|b| {
                    self.compute_batch_convolution(b, padded_input, in_channels, &output_shape)
                })
                .collect()
        };

        // Merge results from each batch into final output
        merge_results(output_shape, results, self.filters)
    }
}

impl<T: ActivationLayer> Layer for Conv2D<T> {
    /// Runs the convolution followed by the activation layer, caching the input for `backward`.
    ///
    /// # Errors
    ///
    /// - `ModelError::InputValidationError` - If the input is not 4D.
    /// - `ModelError::InputValidationError` - If the input's channel count differs from the
    ///   weights' channel dimension.
    /// - `ModelError::InputValidationError` - If the spatial size cannot be convolved: smaller
    ///   than the kernel for `Valid` padding, or zero for `Same` padding.
    /// - Any error returned by the activation layer's `forward`.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        // Validate input is 4D
        if input.ndim() != 4 {
            return Err(ModelError::InputValidationError(
                "input tensor is not 4D".to_string(),
            ));
        }

        // Reject inputs the convolution loops cannot handle. Without these checks, a channel
        // mismatch either panics on weight indexing or silently ignores weight channels, and a
        // too-small spatial size underflows `usize` in the output-shape/padding arithmetic.
        let shape = input.shape();
        let expected_channels = self.weights.shape()[1];
        if shape[1] != expected_channels {
            return Err(ModelError::InputValidationError(format!(
                "input has {} channels but the layer expects {}",
                shape[1], expected_channels
            )));
        }
        let (height, width) = (shape[2], shape[3]);
        let spatial_ok = match self.padding {
            PaddingType::Valid => height >= self.kernel_size.0 && width >= self.kernel_size.1,
            PaddingType::Same => height > 0 && width > 0,
        };
        if !spatial_ok {
            return Err(ModelError::InputValidationError(format!(
                "input spatial size ({}, {}) is incompatible with kernel size ({}, {}) and the configured padding",
                height, width, self.kernel_size.0, self.kernel_size.1
            )));
        }

        // Save input for backpropagation
        self.input_cache = Some(input.clone());

        // Perform convolution operation
        let output = self.convolve(input);

        // Apply activation
        self.activation.forward(&output)
    }

    /// Backpropagates through the activation and the convolution.
    ///
    /// Stores weight and bias gradients for the optimizer update methods and returns the
    /// gradient with respect to the (unpadded) input.
    ///
    /// # Errors
    ///
    /// - Any error returned by the activation layer's `backward`.
    /// - `ModelError::ProcessingError` - If `forward` has not been run.
    /// - `ModelError::InputValidationError` - If the activation gradient's shape differs from
    ///   this layer's output shape for the cached input.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        // Apply activation backward pass
        let grad_upstream = self.activation.backward(grad_output)?;

        let Some(input) = &self.input_cache else {
            return Err(ModelError::ProcessingError(
                "Forward pass has not been run".to_string(),
            ));
        };
        let original_input_shape = input.shape();

        // The gradient must match the shape produced in `forward`. A mismatched filter
        // dimension would panic in the indexing below, and mismatched spatial dimensions would
        // silently compute gradients against the wrong geometry.
        let expected_grad_shape = self.calculate_output_shape(original_input_shape);
        if grad_upstream.shape() != expected_grad_shape.as_slice() {
            return Err(ModelError::InputValidationError(format!(
                "gradient shape {:?} does not match layer output shape {:?}",
                grad_upstream.shape(),
                expected_grad_shape
            )));
        }

        // `Valid` convolution reads the cached input directly. Only `Same` needs a padded
        // copy, built exactly as in the forward pass.
        let padded_storage;
        let padded_input: &Tensor = match self.padding {
            PaddingType::Valid => input,
            PaddingType::Same => {
                padded_storage = self.apply_padding(input);
                &padded_storage
            }
        };
        let input_shape = padded_input.shape();

        let batch_size = input_shape[0];
        let channels = input_shape[1];

        // Borrow the activation gradient rather than cloning it.
        let gradient = &grad_upstream;
        let grad_shape = gradient.shape();

        // Initialize gradients for weights and biases
        let mut weight_grads = Array4::zeros(self.weights.dim());
        let mut bias_grads = Array2::zeros((1, self.filters));

        // Bias gradient for filter f: sum of the gradient over batch and spatial positions.
        bias_grads
            .axis_iter_mut(Axis(1))
            .into_par_iter()
            .enumerate()
            .for_each(|(f, mut bias)| {
                let mut sum = 0.0;
                for b in 0..batch_size {
                    for i in 0..grad_shape[2] {
                        for j in 0..grad_shape[3] {
                            sum += gradient[[b, f, i, j]];
                        }
                    }
                }
                *bias.first_mut().unwrap() = sum;
            });

        // Weight gradients, one filter per parallel task.
        weight_grads
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(f, mut filter_grad)| {
                for c in 0..channels {
                    for h in 0..self.kernel_size.0 {
                        for w in 0..self.kernel_size.1 {
                            let mut sum = 0.0;
                            for b in 0..batch_size {
                                for i in 0..grad_shape[2] {
                                    let i_pos = i * self.strides.0 + h;
                                    if i_pos >= input_shape[2] {
                                        continue;
                                    }

                                    sum += compute_row_gradient_sum(
                                        gradient,
                                        padded_input,
                                        b,
                                        f,
                                        c,
                                        i,
                                        i_pos,
                                        w,
                                        grad_shape,
                                        input_shape,
                                        self.strides.1,
                                    );
                                }
                            }
                            filter_grad[[c, h, w]] = sum;
                        }
                    }
                }
            });

        // Save gradients for optimization
        self.weight_gradients = Some(weight_grads);
        self.bias_gradients = Some(bias_grads);

        // Input gradient over the padded input, one batch item per parallel task.
        let local_results: Vec<_> = (0..batch_size)
            .into_par_iter()
            .map(|b| {
                let mut local_gradients =
                    Array3::zeros([channels, input_shape[2], input_shape[3]]);

                for c in 0..channels {
                    for i in 0..input_shape[2] {
                        for j in 0..input_shape[3] {
                            let mut sum = 0.0;

                            for f in 0..self.filters {
                                for h in 0..self.kernel_size.0 {
                                    for w in 0..self.kernel_size.1 {
                                        // Input (i, j) contributes to output (grad_i, grad_j)
                                        // only when the offset is non-negative, stride-aligned
                                        // and inside the output grid.
                                        if i >= h && j >= w {
                                            let grad_i = (i - h) / self.strides.0;
                                            let grad_j = (j - w) / self.strides.1;

                                            if grad_i < grad_shape[2]
                                                && grad_j < grad_shape[3]
                                                && (i - h) % self.strides.0 == 0
                                                && (j - w) % self.strides.1 == 0
                                            {
                                                sum += gradient[[b, f, grad_i, grad_j]]
                                                    * self.weights[[f, c, h, w]];
                                            }
                                        }
                                    }
                                }
                            }

                            local_gradients[[c, i, j]] = sum;
                        }
                    }
                }

                (b, local_gradients)
            })
            .collect();

        // Merge padded gradients
        let padded_grad = merge_results(
            vec![batch_size, channels, input_shape[2], input_shape[3]],
            local_results,
            channels,
        );

        // Remove padding from gradients if PaddingType::Same was used; the offsets mirror
        // `apply_padding` (top/left gets the rounded-down half of the total padding).
        let final_grad = match self.padding {
            PaddingType::Valid => padded_grad,
            PaddingType::Same => {
                let pad_height = input_shape[2].saturating_sub(original_input_shape[2]);
                let pad_width = input_shape[3].saturating_sub(original_input_shape[3]);
                let pad_top = pad_height / 2;
                let pad_left = pad_width / 2;

                let padded_4d = padded_grad.into_dimensionality::<ndarray::Ix4>().unwrap();
                padded_4d
                    .slice(s![
                        ..,
                        ..,
                        pad_top..pad_top + original_input_shape[2],
                        pad_left..pad_left + original_input_shape[3]
                    ])
                    .to_owned()
                    .into_dyn()
            }
        };

        Ok(final_grad)
    }

    fn layer_type(&self) -> &str {
        "Conv2D"
    }

    fn output_shape(&self) -> String {
        let output_shape = self.calculate_output_shape(&self.input_shape);
        format!(
            "({}, {}, {}, {})",
            output_shape[0], output_shape[1], output_shape[2], output_shape[3]
        )
    }

    fn param_count(&self) -> TrainingParameters {
        TrainingParameters::Trainable(self.weights.len() + self.bias.len())
    }

    update_sgd_conv!();

    /// Applies one Adam step using the gradients stored by `backward`; no-op before `backward`.
    fn update_parameters_adam(&mut self, lr: f32, beta1: f32, beta2: f32, epsilon: f32, t: u64) {
        if let (Some(weight_grads), Some(bias_grads)) =
            (&self.weight_gradients, &self.bias_gradients)
        {
            // Initialize Adam states (if not already initialized)
            if self.optimizer_cache.adam_states.is_none() {
                self.optimizer_cache.adam_states = Some(AdamStatesConv2D {
                    m: Array4::zeros(self.weights.dim()),
                    v: Array4::zeros(self.weights.dim()),
                    m_bias: Array2::zeros(self.bias.dim()),
                    v_bias: Array2::zeros(self.bias.dim()),
                });
            }

            if let Some(adam_states) = &mut self.optimizer_cache.adam_states {
                // Bias-correction factors. `t as i32` would wrap to a negative exponent for
                // step counts above `i32::MAX`; clamp instead (for beta < 1, beta^i32::MAX
                // is effectively 0, which is the correct asymptotic correction).
                let step = i32::try_from(t).unwrap_or(i32::MAX);
                let bias_correction1 = 1.0 - beta1.powi(step);
                let bias_correction2 = 1.0 - beta2.powi(step);

                // Update weight parameters
                if let (Some(weight_slice), Some(weight_grad_slice), Some(m_slice), Some(v_slice)) = (
                    self.weights.as_slice_mut(),
                    weight_grads.as_slice(),
                    adam_states.m.as_slice_mut(),
                    adam_states.v.as_slice_mut(),
                ) {
                    update_adam_conv(
                        weight_slice,
                        weight_grad_slice,
                        m_slice,
                        v_slice,
                        lr,
                        beta1,
                        beta2,
                        epsilon,
                        bias_correction1,
                        bias_correction2,
                    );
                }

                // Update bias parameters
                if let (
                    Some(bias_slice),
                    Some(bias_grad_slice),
                    Some(m_bias_slice),
                    Some(v_bias_slice),
                ) = (
                    self.bias.as_slice_mut(),
                    bias_grads.as_slice(),
                    adam_states.m_bias.as_slice_mut(),
                    adam_states.v_bias.as_slice_mut(),
                ) {
                    update_adam_conv(
                        bias_slice,
                        bias_grad_slice,
                        m_bias_slice,
                        v_bias_slice,
                        lr,
                        beta1,
                        beta2,
                        epsilon,
                        bias_correction1,
                        bias_correction2,
                    );
                }
            }
        }
    }

    /// Applies one RMSprop step using the gradients stored by `backward`; no-op before `backward`.
    fn update_parameters_rmsprop(&mut self, lr: f32, rho: f32, epsilon: f32) {
        if let (Some(weight_grads), Some(bias_grads)) =
            (&self.weight_gradients, &self.bias_gradients)
        {
            // Initialize RMSprop cache (if not already initialized)
            if self.optimizer_cache.rmsprop_cache.is_none() {
                self.optimizer_cache.rmsprop_cache = Some(RMSpropCacheConv2D {
                    cache: Array4::zeros(self.weights.dim()),
                    bias: Array2::zeros(self.bias.dim()),
                });
            }

            if let Some(rmsprop_cache) = &mut self.optimizer_cache.rmsprop_cache {
                // Update weights
                if let (Some(weight_slice), Some(weight_grad_slice), Some(cache_slice)) = (
                    self.weights.as_slice_mut(),
                    weight_grads.as_slice(),
                    rmsprop_cache.cache.as_slice_mut(),
                ) {
                    update_rmsprop(
                        weight_slice,
                        weight_grad_slice,
                        cache_slice,
                        rho,
                        epsilon,
                        lr,
                    );
                }

                // Update biases
                if let (Some(bias_slice), Some(bias_grad_slice), Some(bias_cache_slice)) = (
                    self.bias.as_slice_mut(),
                    bias_grads.as_slice(),
                    rmsprop_cache.bias.as_slice_mut(),
                ) {
                    update_rmsprop(
                        bias_slice,
                        bias_grad_slice,
                        bias_cache_slice,
                        rho,
                        epsilon,
                        lr,
                    );
                }
            }
        }
    }

    /// Applies one AdaGrad step using the gradients stored by `backward`; no-op before `backward`.
    fn update_parameters_ada_grad(&mut self, lr: f32, epsilon: f32) {
        if let (Some(weight_gradients), Some(bias_gradients)) =
            (&self.weight_gradients, &self.bias_gradients)
        {
            // Initialize AdaGrad cache (if not already initialized)
            if self.optimizer_cache.ada_grad_cache.is_none() {
                self.optimizer_cache.ada_grad_cache = Some(AdaGradStatesConv2D {
                    accumulator: Array4::zeros(self.weights.dim()),
                    accumulator_bias: Array2::zeros(self.bias.dim()),
                });
            }

            update_adagrad_conv!(self, weight_gradients, bias_gradients, lr, epsilon);
        }
    }

    fn get_weights(&self) -> LayerWeight<'_> {
        LayerWeight::Conv2D(Conv2DLayerWeight {
            weight: &self.weights,
            bias: &self.bias,
        })
    }
}