rustyml 0.11.0

A high-performance machine learning & deep learning library in pure Rust, offering ML algorithms and neural network support
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
use crate::error::ModelError;
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::convolution_layer::PaddingType;
use crate::neural_network::layer::convolution_layer::input_validation_function::{
    validate_filters, validate_input_shape_1d, validate_kernel_size_1d, validate_strides_1d,
};
use crate::neural_network::layer::helper_function::update_adam_conv;
use crate::neural_network::layer::layer_weight::{Conv1DLayerWeight, LayerWeight};
use crate::neural_network::neural_network_trait::{ActivationLayer, Layer};
use crate::neural_network::optimizer::OptimizerCacheConv1D;
use crate::neural_network::optimizer::ada_grad::AdaGradStatesConv1D;
use crate::neural_network::optimizer::adam::AdamStatesConv1D;
use crate::neural_network::optimizer::rms_prop::RMSpropCacheConv1D;
use crate::neural_network::optimizer::sgd::SGD;
use ndarray::{Array2, Array3, Axis, s};
use ndarray_rand::{RandomExt, rand_distr::Uniform};
use rayon::iter::{
    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
    IntoParallelRefMutIterator, ParallelBridge, ParallelIterator,
};

/// Workload size at which Conv1D switches from sequential to parallel computation.
///
/// Both the forward convolution and the backward pass compute
/// `batch_size * filters * output_length` and compare it with this value: when the product
/// is below the threshold the work runs sequentially to avoid parallelization overhead,
/// otherwise it is distributed with rayon.
const CONV_1D_PARALLEL_THRESHOLD: usize = 1000;

/// A 1D convolutional layer for neural networks.
///
/// Applies a convolution operation to sequential data such as time series, audio signals,
/// or text. Input shape is \[batch_size, channels, length\] and output shape is
/// \[batch_size, filters, output_length\], where output_length depends on input length,
/// kernel size, stride, and padding.
///
/// With `Valid` padding no padding is added and
/// `output_length = (length - kernel_size) / stride + 1`. With `Same` padding the input is
/// zero-padded along the length axis and the layer's `output_shape` reports
/// `ceil(length / stride)`.
///
/// # Fields
///
/// - `filters` - Number of convolution filters (output channels).
/// - `kernel_size` - Size of the convolution kernel.
/// - `stride` - Stride value for the convolution operation.
/// - `padding` - Type of padding to apply (`Valid` or `Same`).
/// - `weights` - 3D array of filter weights with shape \[filters, channels, kernel_size\].
/// - `bias` - 2D array of bias values with shape \[1, filters\].
/// - `activation` - Activation layer applied to the convolution output.
/// - `input_cache` - Cached input from the forward pass, used during backpropagation.
/// - `input_shape` - Shape of the input tensor given at construction time.
/// - `weight_gradients` - Gradients for the weights, computed during backpropagation.
/// - `bias_gradients` - Gradients for the biases, computed during backpropagation.
/// - `optimizer_cache` - Lazily initialized per-optimizer state (Adam, RMSprop, AdaGrad).
///
/// # Examples
/// ```rust
/// use rustyml::neural_network::sequential::Sequential;
/// use rustyml::neural_network::layer::*;
/// use rustyml::neural_network::optimizer::*;
/// use rustyml::neural_network::loss_function::*;
/// use ndarray::Array3;
///
/// // Create a simple 3D input tensor: [batch_size, channels, length]
/// // Batch size=2, 1 input channel, 10 time steps
/// let x = Array3::ones((2, 1, 10)).into_dyn();
///
/// // Create target tensor - assuming we'll have 3 filters with output length 8
/// let y = Array3::ones((2, 3, 8)).into_dyn();
///
/// // Build model: add a Conv1D layer with 3 filters and kernel size 3
/// let mut model = Sequential::new();
/// model
///     .add(Conv1D::new(
///         3,                      // Number of filters
///         3,                      // Kernel size
///         vec![2, 1, 10],         // Input shape
///         1,                      // Stride
///         PaddingType::Valid,     // No padding
///         ReLU::new(), // ReLU activation layer
///     ).unwrap())
///     .compile(RMSprop::new(0.001, 0.9, 1e-8).unwrap(), MeanSquaredError::new());
///
/// // Print model structure
/// model.summary();
///
/// // Train the model (run a few epochs)
/// model.fit(&x, &y, 3).unwrap();
///
/// // Use predict for forward propagation prediction
/// let prediction = model.predict(&x).unwrap();
/// println!("Convolution layer prediction results: {:?}", prediction);
///
/// // Check if output shape is correct - should be [2, 3, 8]
/// assert_eq!(prediction.shape(), &[2, 3, 8]);
/// ```
pub struct Conv1D<T: ActivationLayer> {
    filters: usize,
    kernel_size: usize,
    stride: usize,
    padding: PaddingType,
    weights: Array3<f32>,
    bias: Array2<f32>,
    activation: T,
    input_cache: Option<Tensor>,
    input_shape: Vec<usize>,
    weight_gradients: Option<Array3<f32>>,
    bias_gradients: Option<Array2<f32>>,
    optimizer_cache: OptimizerCacheConv1D,
}

impl<T: ActivationLayer> Conv1D<T> {
    /// Builds a `Conv1D` layer with randomly initialized weights and a zero bias.
    ///
    /// # Parameters
    ///
    /// - `filters` - Number of output filters (channels)
    /// - `kernel_size` - Size of the convolution kernel
    /// - `input_shape` - Shape of input tensor \[batch_size, channels, length\]
    /// - `stride` - Stride for the convolution operation
    /// - `padding` - Padding type (Valid or Same)
    /// - `activation` - Activation layer applied after the convolution (e.g. ReLU, Sigmoid,
    ///   Tanh, Softmax)
    ///
    /// # Returns
    ///
    /// - `Result<Self, ModelError>` - A new `Conv1D` layer instance or an error
    ///
    /// # Errors
    ///
    /// - `ModelError::InputValidationError` - If `filters`, `kernel_size`, or `stride` is 0
    /// - `ModelError::InputValidationError` - If `input_shape` is not 3D or has 0 channels
    /// - `ModelError::InputValidationError` - If input length is less than kernel size
    pub fn new(
        filters: usize,
        kernel_size: usize,
        input_shape: Vec<usize>,
        stride: usize,
        padding: PaddingType,
        activation: T,
    ) -> Result<Self, ModelError> {
        // Reject invalid hyper-parameters before allocating anything
        validate_filters(filters)?;
        validate_kernel_size_1d(kernel_size)?;
        validate_strides_1d(stride)?;
        validate_input_shape_1d(&input_shape, kernel_size)?;

        let input_channels = input_shape[1];

        // Xavier (Glorot) uniform initialization:
        // limit = sqrt(6 / (fan_in + fan_out)), with fan_in = input_channels * kernel_size
        // and fan_out = filters * kernel_size.
        let fan_sum = (input_channels + filters) * kernel_size;
        let limit = (6.0 / fan_sum as f32).sqrt();
        let init_distribution = Uniform::new(-limit, limit).unwrap();
        let weights = Array3::random((filters, input_channels, kernel_size), init_distribution);

        // Optimizer state is created lazily on the first parameter update
        let optimizer_cache = OptimizerCacheConv1D {
            adam_states: None,
            rmsprop_cache: None,
            ada_grad_cache: None,
        };

        Ok(Self {
            filters,
            kernel_size,
            stride,
            padding,
            weights,
            // Biases start at zero
            bias: Array2::zeros((1, filters)),
            activation,
            input_cache: None,
            input_shape,
            weight_gradients: None,
            bias_gradients: None,
            optimizer_cache,
        })
    }

    /// Calculates the output length after convolution.
    ///
    /// # Parameters
    ///
    /// - `input_length` - Length of the input sequence
    ///
    /// # Returns
    ///
    /// - `usize` - Output length after convolution:
    ///   - `Valid`: `(input_length - kernel_size) / stride + 1`, or `0` when the input is
    ///     shorter than the kernel (no window fits).
    ///   - `Same`: `ceil(input_length / stride)`.
    fn calculate_output_length(&self, input_length: usize) -> usize {
        match self.padding {
            // `checked_sub` prevents a `usize` underflow (panic in debug builds, wrap-around in
            // release builds) when the input is shorter than the kernel.
            PaddingType::Valid => input_length
                .checked_sub(self.kernel_size)
                .map_or(0, |span| span / self.stride + 1),
            PaddingType::Same => (input_length + self.stride - 1) / self.stride,
        }
    }

    /// Returns the input prepared for convolution according to the layer's padding mode.
    ///
    /// # Parameters
    ///
    /// - `input` - Input tensor to pad; must be 3D (\[batch_size, channels, length\])
    ///
    /// # Returns
    ///
    /// - `Tensor` - A copy of the input for `Valid` padding, or a zero-padded copy (along the
    ///   length axis) for `Same` padding
    ///
    /// # Panics
    ///
    /// Panics under `Same` padding if `input` is not 3-dimensional.
    fn apply_padding(&self, input: &Tensor) -> Tensor {
        match self.padding {
            PaddingType::Valid => input.clone(),
            PaddingType::Same => {
                let unpadded = input.view().into_dimensionality::<ndarray::Ix3>().unwrap();
                let (batch_size, channels, input_length) = unpadded.dim();

                let (pad_total, pad_left) = self.calculate_padding_params(input_length);

                // Start from an all-zero buffer and copy the data into its interior; whatever
                // is not overwritten stays zero and acts as the left/right padding.
                let mut padded = Array3::zeros((batch_size, channels, input_length + pad_total));
                padded
                    .slice_mut(s![.., .., pad_left..pad_left + input_length])
                    .assign(&unpadded);

                padded.into_dyn()
            }
        }
    }

    /// Calculates the padding parameters for `Same` padding mode.
    ///
    /// # Parameters
    ///
    /// - `input_length` - Length of the unpadded input sequence
    ///
    /// # Returns
    ///
    /// - `(usize, usize)` - `(pad_total, pad_left)`: the total number of zeros to add along
    ///   the length axis and how many of them go on the left side (`pad_total / 2`, so any
    ///   odd extra element ends up on the right). Returns `(0, 0)` for an empty input.
    fn calculate_padding_params(&self, input_length: usize) -> (usize, usize) {
        let output_length = (input_length + self.stride - 1) / self.stride;

        // An empty input produces no output positions, so there is nothing to pad. Returning
        // early also avoids the `output_length - 1` underflow below.
        if output_length == 0 {
            return (0, 0);
        }

        // Span covered by all windows: the start of the last window plus one kernel width.
        let covered = (output_length - 1) * self.stride + self.kernel_size;
        let pad_total = covered.saturating_sub(input_length);
        let pad_left = pad_total / 2;
        (pad_total, pad_left)
    }

    /// Replaces the layer's filter weights and bias with the given arrays.
    ///
    /// No shape validation is performed here; the caller is responsible for passing arrays
    /// with the shapes listed below.
    ///
    /// # Parameters
    ///
    /// - `weights` - 3D array of filter weights with shape \[filters, channels, kernel_size\]
    /// - `bias` - 2D array of bias values with shape \[1, filters\]
    pub fn set_weights(&mut self, weights: Array3<f32>, bias: Array2<f32>) {
        (self.weights, self.bias) = (weights, bias);
    }

    /// Maps a position in the padded input back to the matching position in the original input.
    ///
    /// # Parameters
    ///
    /// - `padded_pos` - Position along the length axis of the padded input
    /// - `input_length` - Length of the original (unpadded) input
    ///
    /// # Returns
    ///
    /// - `Option<usize>` - The original position, or `None` when `padded_pos` lies in the
    ///   padding area or beyond the end of the input
    fn get_original_input_pos(&self, padded_pos: usize, input_length: usize) -> Option<usize> {
        // Number of padded elements that precede the first real input element
        let left_offset = match self.padding {
            PaddingType::Valid => 0,
            PaddingType::Same => self.calculate_padding_params(input_length).1,
        };

        // `checked_sub` rejects positions inside the left padding; the filter rejects
        // positions in the right padding (or past the end of the input).
        padded_pos
            .checked_sub(left_offset)
            .filter(|&original_pos| original_pos < input_length)
    }

    /// Computes one output element: the dot product of a filter with its input window, plus bias.
    ///
    /// # Parameters
    ///
    /// - `input_3d` - Input array indexed as \[batch, channel, position\]
    /// - `batch` - Index of the sample within the batch
    /// - `filter` - Index of the filter (output channel)
    /// - `out_pos` - Output position; the window starts at `out_pos * stride`
    /// - `input_length` - Length of `input_3d` along the position axis; kernel taps that would
    ///   fall at or beyond it are skipped
    ///
    /// # Returns
    ///
    /// - `f32` - The convolution result for this (batch, filter, position), before activation
    fn compute_conv_output(
        &self,
        input_3d: &Array3<f32>,
        batch: usize,
        filter: usize,
        out_pos: usize,
        input_length: usize,
    ) -> f32 {
        let window_start = out_pos * self.stride;

        // Only the kernel taps that still land inside the input contribute to the sum
        let taps = self
            .kernel_size
            .min(input_length.saturating_sub(window_start));

        let mut acc = 0.0;
        for channel in 0..self.input_shape[1] {
            for tap in 0..taps {
                acc += input_3d[[batch, channel, window_start + tap]]
                    * self.weights[[filter, channel, tap]];
            }
        }

        acc + self.bias[[0, filter]]
    }

    /// Computes the gradient contributions of a single sample during backpropagation.
    ///
    /// # Parameters
    ///
    /// - `batch` - Index of the sample within the batch
    /// - `grad_output_3d` - Upstream gradient indexed as \[batch, filter, output position\]
    /// - `input_3d` - Input indexed as \[batch, channel, position\]; `backward` passes the
    ///   padded input here
    /// - `input_channels` - Number of input channels
    /// - `input_length` - Length of the original (unpadded) input
    /// - `output_length` - Number of output positions to iterate over
    ///
    /// # Returns
    ///
    /// - `(Array3<f32>, Array2<f32>, Array2<f32>)` - Weight gradients (same shape as the
    ///   weights), bias gradients (same shape as the bias) and this sample's input gradients
    ///   with shape \[input_channels, input_length\]
    fn compute_batch_gradients(
        &self,
        batch: usize,
        grad_output_3d: &Array3<f32>,
        input_3d: &Array3<f32>,
        input_channels: usize,
        input_length: usize,
        output_length: usize,
    ) -> (Array3<f32>, Array2<f32>, Array2<f32>) {
        let mut weight_grads = Array3::zeros(self.weights.dim());
        let mut bias_grads = Array2::zeros(self.bias.dim());
        let mut input_grads = Array2::zeros((input_channels, input_length));

        let padded_length = input_3d.shape()[2];

        for filter in 0..self.filters {
            for out_pos in 0..output_length {
                let upstream = grad_output_3d[[batch, filter, out_pos]];
                let window_start = out_pos * self.stride;

                // The bias feeds every output position of its filter directly
                bias_grads[[0, filter]] += upstream;

                // Kernel taps reaching past the end of the input take no part in the pass
                let taps = self
                    .kernel_size
                    .min(padded_length.saturating_sub(window_start));

                for channel in 0..input_channels {
                    for tap in 0..taps {
                        let padded_pos = window_start + tap;

                        weight_grads[[filter, channel, tap]] +=
                            upstream * input_3d[[batch, channel, padded_pos]];

                        // Padding cells have no counterpart in the original input, so they
                        // receive no input gradient.
                        if let Some(original_pos) =
                            self.get_original_input_pos(padded_pos, input_length)
                        {
                            input_grads[[channel, original_pos]] +=
                                upstream * self.weights[[filter, channel, tap]];
                        }
                    }
                }
            }
        }

        (weight_grads, bias_grads, input_grads)
    }

    /// Performs 1D convolution operation with adaptive parallel/sequential processing.
    ///
    /// The input is padded according to the layer's padding mode, then every output element
    /// is computed independently. Workloads of at least `CONV_1D_PARALLEL_THRESHOLD`
    /// output elements are processed in parallel; smaller ones run sequentially.
    ///
    /// # Parameters
    ///
    /// - `input` - Input tensor with shape \[batch_size, channels, length\]
    ///
    /// # Returns
    ///
    /// - `Tensor` - Output tensor after convolution with shape
    ///   \[batch_size, filters, output_length\]
    ///
    /// # Panics
    ///
    /// Panics if `input` is not 3-dimensional.
    fn conv1d(&self, input: &Tensor) -> Tensor {
        // The output length must be derived from the ORIGINAL input length. Feeding the padded
        // length into the `Same` formula would yield `ceil((length + pad) / stride)` and
        // produce more output positions than `output_shape()` reports.
        let output_length = self.calculate_output_length(input.shape()[2]);

        let padded_input = self.apply_padding(input);
        let padded_shape = padded_input.shape();
        let batch_size = padded_shape[0];
        let padded_length = padded_shape[2];

        let mut output = Array3::zeros((batch_size, self.filters, output_length));

        let input_3d = padded_input.into_dimensionality::<ndarray::Ix3>().unwrap();

        // Determine whether to use parallel or sequential execution
        let total_ops = batch_size * self.filters * output_length;

        if total_ops >= CONV_1D_PARALLEL_THRESHOLD {
            // Parallel processing for large workloads
            output
                .axis_iter_mut(Axis(0))
                .into_par_iter()
                .enumerate()
                .for_each(|(batch, mut batch_output)| {
                    batch_output
                        .axis_iter_mut(Axis(0))
                        .into_par_iter()
                        .enumerate()
                        .for_each(|(filter, mut filter_output)| {
                            filter_output.indexed_iter_mut().par_bridge().for_each(
                                |(out_pos, output_val)| {
                                    *output_val = self.compute_conv_output(
                                        &input_3d,
                                        batch,
                                        filter,
                                        out_pos,
                                        padded_length,
                                    );
                                },
                            );
                        });
                });
        } else {
            // Sequential processing for small workloads
            for batch in 0..batch_size {
                for filter in 0..self.filters {
                    for out_pos in 0..output_length {
                        output[[batch, filter, out_pos]] = self.compute_conv_output(
                            &input_3d,
                            batch,
                            filter,
                            out_pos,
                            padded_length,
                        );
                    }
                }
            }
        }

        output.into_dyn()
    }
}

impl<T: ActivationLayer> Layer for Conv1D<T> {
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, ModelError> {
        // Validate input is 3D
        if input.ndim() != 3 {
            return Err(ModelError::InputValidationError(
                "input tensor is not 3D".to_string(),
            ));
        }

        // Cache input for backpropagation
        self.input_cache = Some(input.clone());

        // Perform convolution
        let output = self.conv1d(input);

        // Apply activation
        self.activation.forward(&output.into_dyn())
    }

    /// Runs the backward pass and stores the weight and bias gradients on the layer.
    ///
    /// # Parameters
    ///
    /// - `grad_output` - Gradient of the loss with respect to this layer's output
    ///
    /// # Returns
    ///
    /// - `Result<Tensor, ModelError>` - Gradient with respect to the layer input, with the
    ///   same shape as the cached forward input
    ///
    /// # Errors
    ///
    /// - Any error returned by the activation layer's backward pass
    /// - `ModelError::ProcessingError` - If no input was cached by a previous forward pass
    /// - `ModelError::ProcessingError` - If the gradient or the cached input cannot be viewed
    ///   as a 3D array
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor, ModelError> {
        // Propagate the incoming gradient through the activation first
        let activation_grad = self.activation.backward(grad_output)?;

        // The forward pass must have stored its input
        let Some(cached_input) = self.input_cache.as_ref() else {
            return Err(ModelError::ProcessingError(
                "No cached input for backward pass".to_string(),
            ));
        };

        let dims = cached_input.shape();
        let (batch_size, input_channels, input_length) = (dims[0], dims[1], dims[2]);

        let grad_3d = match activation_grad.into_dimensionality::<ndarray::Ix3>() {
            Ok(grad) => grad,
            Err(e) => {
                return Err(ModelError::ProcessingError(format!(
                    "Failed to convert gradient output: {}",
                    e
                )));
            }
        };

        // Gradient accumulators
        let mut weight_gradients = Array3::zeros(self.weights.dim());
        let mut bias_gradients = Array2::zeros(self.bias.dim());
        let mut input_gradients = Array3::zeros((batch_size, input_channels, input_length));

        // Gradients are computed against the same padded input the forward pass used
        let padded_3d = match self
            .apply_padding(cached_input)
            .into_dimensionality::<ndarray::Ix3>()
        {
            Ok(padded) => padded,
            Err(e) => {
                return Err(ModelError::ProcessingError(format!(
                    "Failed to convert input: {}",
                    e
                )));
            }
        };

        let output_length = grad_3d.shape()[2];
        let total_ops = batch_size * self.filters * output_length;

        // Per-sample gradient computation, shared by the parallel and sequential paths
        let sample_gradients = |batch: usize| {
            self.compute_batch_gradients(
                batch,
                &grad_3d,
                &padded_3d,
                input_channels,
                input_length,
                output_length,
            )
        };

        // Large workloads are spread across threads; small ones stay on the current thread.
        // Either way the results come back in batch order.
        let per_sample: Vec<_> = if total_ops >= CONV_1D_PARALLEL_THRESHOLD {
            (0..batch_size)
                .into_par_iter()
                .map(sample_gradients)
                .collect()
        } else {
            (0..batch_size).map(sample_gradients).collect()
        };

        // Weight and bias gradients are summed over the batch; input gradients are per sample
        for (batch, (sample_weight_grads, sample_bias_grads, sample_input_grads)) in
            per_sample.into_iter().enumerate()
        {
            weight_gradients += &sample_weight_grads;
            bias_gradients += &sample_bias_grads;
            input_gradients
                .slice_mut(s![batch, .., ..])
                .assign(&sample_input_grads);
        }

        // Keep the parameter gradients for the next optimizer step
        self.weight_gradients = Some(weight_gradients);
        self.bias_gradients = Some(bias_gradients);

        Ok(input_gradients.into_dyn())
    }

    /// Returns the name identifying this layer kind.
    fn layer_type(&self) -> &str {
        const LAYER_NAME: &str = "Conv1D";
        LAYER_NAME
    }

    /// Formats the output shape as `"(batch_size, filters, output_length)"`, derived from the
    /// input shape supplied at construction time.
    fn output_shape(&self) -> String {
        let output_length = self.calculate_output_length(self.input_shape[2]);
        let batch_size = self.input_shape[0];
        format!("({batch_size}, {}, {output_length})", self.filters)
    }

    /// Reports the number of trainable parameters: all weight elements plus all bias elements.
    fn param_count(&self) -> TrainingParameters {
        let weight_params = self.weights.len();
        let bias_params = self.bias.len();
        TrainingParameters::Trainable(weight_params + bias_params)
    }

    // Expands the `update_sgd_conv!` macro in place. Judging by its name it supplies the SGD
    // parameter-update method for this layer; the macro body is defined elsewhere in the
    // crate — TODO confirm.
    update_sgd_conv!();

    /// Applies one Adam update step to the weights and bias using the stored gradients.
    ///
    /// Does nothing when no gradients are stored (i.e. `backward` has not run yet). The Adam
    /// moment estimates are created, zero-initialized, on the first call.
    ///
    /// # Parameters
    ///
    /// - `lr` - Learning rate
    /// - `beta1` - Exponential decay rate for the first moment estimates
    /// - `beta2` - Exponential decay rate for the second moment estimates
    /// - `epsilon` - Small constant for numerical stability
    /// - `t` - Current time step, used for bias correction
    ///
    /// # Panics
    ///
    /// Panics if the weights, bias, gradients or Adam states are not contiguous in memory.
    fn update_parameters_adam(&mut self, lr: f32, beta1: f32, beta2: f32, epsilon: f32, t: u64) {
        let (Some(weight_gradients), Some(bias_gradients)) =
            (&self.weight_gradients, &self.bias_gradients)
        else {
            return;
        };

        // Initialize Adam states (if not already initialized)
        let weight_dim = self.weights.dim();
        let bias_dim = self.bias.dim();
        let adam_states = self
            .optimizer_cache
            .adam_states
            .get_or_insert_with(|| AdamStatesConv1D {
                m: Array3::zeros(weight_dim),
                v: Array3::zeros(weight_dim),
                m_bias: Array2::zeros(bias_dim),
                v_bias: Array2::zeros(bias_dim),
            });

        // `powi` takes an `i32`. A plain `t as i32` would wrap to a negative exponent for very
        // large step counts and corrupt the correction factors, so clamp instead.
        let step = t.min(i32::MAX as u64) as i32;

        // Compute bias correction factors
        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);

        // Update weight parameters
        update_adam_conv(
            self.weights.as_slice_mut().unwrap(),
            weight_gradients.as_slice().unwrap(),
            adam_states.m.as_slice_mut().unwrap(),
            adam_states.v.as_slice_mut().unwrap(),
            lr,
            beta1,
            beta2,
            epsilon,
            bias_correction1,
            bias_correction2,
        );

        // Update bias parameters
        update_adam_conv(
            self.bias.as_slice_mut().unwrap(),
            bias_gradients.as_slice().unwrap(),
            adam_states.m_bias.as_slice_mut().unwrap(),
            adam_states.v_bias.as_slice_mut().unwrap(),
            lr,
            beta1,
            beta2,
            epsilon,
            bias_correction1,
            bias_correction2,
        );
    }

    /// Applies one RMSprop update step to the weights and bias using the stored gradients.
    ///
    /// Does nothing when no gradients are stored. The running averages of squared gradients
    /// are created, zero-initialized, on the first call.
    ///
    /// # Parameters
    ///
    /// - `lr` - Learning rate
    /// - `rho` - Decay rate of the moving average of squared gradients
    /// - `epsilon` - Small constant for numerical stability
    ///
    /// # Panics
    ///
    /// Panics if the weights, bias, gradients or caches are not contiguous in memory.
    fn update_parameters_rmsprop(&mut self, lr: f32, rho: f32, epsilon: f32) {
        /// Performs the RMSprop update on flat slices: refreshes the squared-gradient moving
        /// average first, then steps the parameters with the refreshed average.
        fn rmsprop_step(
            params: &mut [f32],
            cache: &mut [f32],
            grads: &[f32],
            lr: f32,
            rho: f32,
            epsilon: f32,
        ) {
            cache
                .par_iter_mut()
                .zip(grads.par_iter())
                .for_each(|(avg, &grad)| {
                    *avg = rho * *avg + (1.0 - rho) * grad * grad;
                });

            params
                .par_iter_mut()
                .zip(grads.par_iter())
                .zip(cache.par_iter())
                .for_each(|((param, &grad), &avg)| {
                    *param -= lr * grad / (avg.sqrt() + epsilon);
                });
        }

        let (Some(weight_gradients), Some(bias_gradients)) =
            (&self.weight_gradients, &self.bias_gradients)
        else {
            return;
        };

        // Create the RMSprop state on first use
        let weight_dim = self.weights.dim();
        let bias_dim = self.bias.dim();
        let state = self
            .optimizer_cache
            .rmsprop_cache
            .get_or_insert_with(|| RMSpropCacheConv1D {
                cache: Some(Array3::zeros(weight_dim)),
                bias: Some(Array2::zeros(bias_dim)),
            });

        // Weights
        if let Some(weight_cache) = state.cache.as_mut() {
            rmsprop_step(
                self.weights.as_slice_mut().unwrap(),
                weight_cache.as_slice_mut().unwrap(),
                weight_gradients.as_slice().unwrap(),
                lr,
                rho,
                epsilon,
            );
        }

        // Bias
        if let Some(bias_cache) = state.bias.as_mut() {
            rmsprop_step(
                self.bias.as_slice_mut().unwrap(),
                bias_cache.as_slice_mut().unwrap(),
                bias_gradients.as_slice().unwrap(),
                lr,
                rho,
                epsilon,
            );
        }
    }

    /// Applies one AdaGrad update step to the weights and bias using the stored gradients.
    ///
    /// Does nothing when no gradients are stored. The gradient accumulators are created,
    /// zero-initialized, on the first call; the update itself is delegated to the
    /// `update_adagrad_conv!` macro.
    ///
    /// # Parameters
    ///
    /// - `lr` - Learning rate
    /// - `epsilon` - Small constant for numerical stability
    fn update_parameters_ada_grad(&mut self, lr: f32, epsilon: f32) {
        let (Some(weight_gradients), Some(bias_gradients)) =
            (&self.weight_gradients, &self.bias_gradients)
        else {
            return;
        };

        // Create the AdaGrad accumulators on first use
        if self.optimizer_cache.ada_grad_cache.is_none() {
            let accumulator = Array3::zeros(self.weights.dim());
            let accumulator_bias = Array2::zeros(self.bias.dim());
            self.optimizer_cache.ada_grad_cache = Some(AdaGradStatesConv1D {
                accumulator,
                accumulator_bias,
            });
        }

        update_adagrad_conv!(self, weight_gradients, bias_gradients, lr, epsilon);
    }

    /// Returns borrowed views of the layer's weights and bias wrapped as `LayerWeight::Conv1D`.
    fn get_weights(&self) -> LayerWeight<'_> {
        let conv_weights = Conv1DLayerWeight {
            weight: &self.weights,
            bias: &self.bias,
        };
        LayerWeight::Conv1D(conv_weights)
    }
}