optirs_core/neural_integration/mod.rs

1// Neural network integration for optimizers
2//
3// This module provides interfaces and utilities for integrating optimizers with neural networks,
4// including generic parameter optimization, lazy registration, and architecture-aware optimizations.
5
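// Example (illustrative sketch): the typical flow is to build an
// `OptimizationConfig`, create a `ParameterManager`, and register parameters
// together with their metadata. The identifiers used here ("encoder",
// "encoder.weight") are hypothetical.
//
//     use scirs2_core::ndarray::Ix1;
//
//     let mut manager = ParameterManager::<f64, Ix1>::new(OptimizationConfig::default());
//     let metadata = ParameterMetadata {
//         layername: "encoder".to_string(),
//         param_name: "weight".to_string(),
//         shape: vec![128, 64],
//         requires_grad: true,
//         paramtype: ParameterType::Weight,
//         sharing_group: None,
//         tags: vec![],
//     };
//     manager.register_parameter("encoder.weight".to_string(), metadata).unwrap();
//     assert_eq!(manager.get_all_parameters().len(), 1);
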
6use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11// use statrs::statistics::Statistics; // statrs not available
12
13/// Type alias for layer identifiers
14pub type LayerId = String;
15
16/// Type alias for parameter identifiers
17pub type ParamId = String;
18
19/// Parameter metadata for neural network parameters
20#[derive(Debug, Clone)]
21pub struct ParameterMetadata {
22    /// Layer name this parameter belongs to
23    pub layername: LayerId,
24    /// Parameter name within the layer
25    pub param_name: ParamId,
26    /// Parameter shape
27    pub shape: Vec<usize>,
28    /// Whether parameter requires gradients
29    pub requires_grad: bool,
30    /// Parameter type (weights, bias, etc.)
31    pub paramtype: ParameterType,
32    /// Sharing group for parameter sharing
33    pub sharing_group: Option<String>,
34    /// Custom tags for architecture-specific optimizations
35    pub tags: Vec<String>,
36}
37
38/// Types of neural network parameters
39#[derive(Debug, Clone, Copy, PartialEq, Eq)]
40pub enum ParameterType {
41    /// Weight matrices
42    Weight,
43    /// Bias vectors
44    Bias,
45    /// Normalization parameters (scale/shift)
46    Normalization,
47    /// Embedding parameters
48    Embedding,
49    /// Attention parameters
50    Attention,
51    /// Custom parameter type
52    Custom,
53}
54
55/// Layer architecture information
56#[derive(Debug, Clone)]
57pub struct LayerArchitecture {
58    /// Layer type name
59    pub layer_type: String,
60    /// Input dimensions
61    pub input_dims: Vec<usize>,
62    /// Output dimensions
63    pub output_dims: Vec<usize>,
64    /// Layer-specific configuration
65    pub config: HashMap<String, LayerConfig>,
66    /// Whether layer is trainable
67    pub trainable: bool,
68}
69
70/// Layer configuration values
71#[derive(Debug, Clone)]
72pub enum LayerConfig {
73    /// Integer value
74    Int(i64),
75    /// Float value
76    Float(f64),
77    /// String value
78    String(String),
79    /// Boolean value
80    Bool(bool),
81    /// List of values
82    List(Vec<LayerConfig>),
83}
84
85/// Generic parameter optimization interface
86pub trait ParameterOptimizer<A: Float, D: Dimension> {
87    /// Register a parameter for optimization
88    fn register_parameter(
89        &mut self,
90        paramid: ParamId,
91        parameter: &Array<A, D>,
92        metadata: ParameterMetadata,
93    ) -> Result<()>;
94
95    /// Update registered parameters with gradients
96    fn step(
97        &mut self,
98        gradients: HashMap<ParamId, Array<A, D>>,
99        parameters: &mut HashMap<ParamId, Array<A, D>>,
100    ) -> Result<()>;
101
102    /// Get parameter-specific learning rate
103    fn get_learning_rate(&self, paramid: &ParamId) -> Option<A>;
104
105    /// Set parameter-specific learning rate
106    fn set_learning_rate(&mut self, paramid: &ParamId, lr: A) -> Result<()>;
107
108    /// Get optimizer state for a parameter
109    fn get_parameter_state(&self, paramid: &ParamId) -> Option<&HashMap<String, Array<A, D>>>;
110
111    /// Reset optimizer state
112    fn reset_state(&mut self);
113
114    /// Get all registered parameter IDs
115    fn registered_parameters(&self) -> Vec<ParamId>;
116}
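
// Illustrative sketch of how a training loop might drive a `ParameterOptimizer`
// implementation. `MyOptimizer`, `compute_gradients`, `params`, and `metadata`
// are hypothetical; the trait does not prescribe how gradients are produced.
//
//     let mut opt = MyOptimizer::new();
//     opt.register_parameter("w".to_string(), &params["w"], metadata)?;
//     for _ in 0..num_steps {
//         let grads: HashMap<ParamId, Array<f64, Ix1>> = compute_gradients(&params);
//         opt.step(grads, &mut params)?;
//     }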
117
118/// Neural network parameter manager with lazy registration
119#[derive(Debug)]
120pub struct ParameterManager<A: Float, D: Dimension> {
121    /// Registered parameters with metadata
122    parameters: HashMap<ParamId, ParameterMetadata>,
123    /// Parameter optimizer states
124    optimizer_states: HashMap<ParamId, HashMap<String, Array<A, D>>>,
125    /// Layer architectures
126    layer_architectures: HashMap<LayerId, LayerArchitecture>,
127    /// Parameter sharing groups
128    sharing_groups: HashMap<String, Vec<ParamId>>,
129    /// Layer-specific optimization rules
130    layer_rules: HashMap<LayerId, LayerOptimizationRule<A>>,
131    /// Global optimization configuration
132    global_config: OptimizationConfig<A>,
133    /// Lazy registration mode
134    lazy_mode: bool,
135    /// Pending registrations (for lazy mode)
136    pending_registrations: Vec<(ParamId, ParameterMetadata)>,
137}
138
139/// Layer-specific optimization rules
140#[derive(Debug, Clone)]
141pub struct LayerOptimizationRule<A: Float> {
142    /// Learning rate multiplier for this layer
143    pub lr_multiplier: A,
144    /// Weight decay multiplier
145    pub weight_decay_multiplier: A,
146    /// Whether to freeze this layer
147    pub frozen: bool,
148    /// Custom optimizer settings
149    pub custom_settings: HashMap<String, LayerConfig>,
150}
151
152/// Global optimization configuration
153#[derive(Debug, Clone)]
154pub struct OptimizationConfig<A: Float> {
155    /// Base learning rate
156    pub base_learning_rate: A,
157    /// Global weight decay
158    pub weight_decay: A,
159    /// Gradient clipping threshold
160    pub gradient_clip: Option<A>,
161    /// Whether to use mixed precision
162    pub mixed_precision: bool,
163    /// Architecture-specific optimizations
164    pub architecture_optimizations: HashMap<String, bool>,
165}
166
167impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
168    ParameterManager<A, D>
169{
170    /// Create a new parameter manager
171    pub fn new(config: OptimizationConfig<A>) -> Self {
172        Self {
173            parameters: HashMap::new(),
174            optimizer_states: HashMap::new(),
175            layer_architectures: HashMap::new(),
176            sharing_groups: HashMap::new(),
177            layer_rules: HashMap::new(),
178            global_config: config,
179            lazy_mode: false,
180            pending_registrations: Vec::new(),
181        }
182    }
183
184    /// Enable lazy registration mode
185    pub fn enable_lazy_mode(&mut self) {
186        self.lazy_mode = true;
187    }
188
189    /// Disable lazy registration mode and process pending registrations
190    pub fn disable_lazy_mode(&mut self) -> Result<()> {
191        self.lazy_mode = false;
192
193        // Process all pending registrations
194        let pending = std::mem::take(&mut self.pending_registrations);
195        for (paramid, metadata) in pending {
196            self.register_parameter_impl(paramid, metadata)?;
197        }
198
199        Ok(())
200    }
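
    // Illustrative sketch of lazy registration (assumes a `manager` and `metadata`
    // as in the module-level example above): registrations queued while lazy mode
    // is enabled are only applied once lazy mode is disabled.
    //
    //     manager.enable_lazy_mode();
    //     manager.register_parameter("decoder.weight".to_string(), metadata).unwrap();
    //     assert_eq!(manager.get_all_parameters().len(), 0); // still pending
    //     manager.disable_lazy_mode().unwrap();
    //     assert_eq!(manager.get_all_parameters().len(), 1); // now registered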
201
202    /// Register a layer architecture
203    pub fn register_layer(&mut self, layerid: LayerId, architecture: LayerArchitecture) {
204        self.layer_architectures.insert(layerid, architecture);
205    }
206
207    /// Set layer-specific optimization rule
208    pub fn set_layer_rule(&mut self, layerid: LayerId, rule: LayerOptimizationRule<A>) {
209        self.layer_rules.insert(layerid, rule);
210    }
211
212    /// Register a parameter
213    pub fn register_parameter(
214        &mut self,
215        paramid: ParamId,
216        metadata: ParameterMetadata,
217    ) -> Result<()> {
218        if self.lazy_mode {
219            self.pending_registrations.push((paramid, metadata));
220            Ok(())
221        } else {
222            self.register_parameter_impl(paramid, metadata)
223        }
224    }
225
226    /// Internal parameter registration implementation
227    fn register_parameter_impl(
228        &mut self,
229        paramid: ParamId,
230        metadata: ParameterMetadata,
231    ) -> Result<()> {
232        // Handle parameter sharing
233        if let Some(sharing_group) = &metadata.sharing_group {
234            self.sharing_groups
235                .entry(sharing_group.clone())
236                .or_default()
237                .push(paramid.clone());
238        }
239
240        // Initialize optimizer state for this parameter
241        self.optimizer_states
242            .insert(paramid.clone(), HashMap::new());
243
244        // Store parameter metadata
245        self.parameters.insert(paramid, metadata);
246
247        Ok(())
248    }
249
250    /// Get effective learning rate for a parameter
251    pub fn get_effective_learning_rate(&self, paramid: &ParamId) -> A {
252        let base_lr = self.global_config.base_learning_rate;
253
254        if let Some(metadata) = self.parameters.get(paramid) {
255            if let Some(rule) = self.layer_rules.get(&metadata.layername) {
256                return base_lr * rule.lr_multiplier;
257            }
258        }
259
260        base_lr
261    }
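
    // Worked example: with `base_learning_rate = 0.01` and a layer rule whose
    // `lr_multiplier = 2.0`, the effective learning rate is 0.01 * 2.0 = 0.02.
    // Parameters whose layer has no rule simply fall back to the base rate.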
262
263    /// Get effective weight decay for a parameter
264    pub fn get_effective_weight_decay(&self, paramid: &ParamId) -> A {
265        let base_decay = self.global_config.weight_decay;
266
267        if let Some(metadata) = self.parameters.get(paramid) {
268            if let Some(rule) = self.layer_rules.get(&metadata.layername) {
269                return base_decay * rule.weight_decay_multiplier;
270            }
271        }
272
273        base_decay
274    }
275
276    /// Check if parameter is frozen
277    pub fn is_parameter_frozen(&self, paramid: &ParamId) -> bool {
278        if let Some(metadata) = self.parameters.get(paramid) {
279            if let Some(rule) = self.layer_rules.get(&metadata.layername) {
280                return rule.frozen;
281            }
282        }
283        false
284    }
285
286    /// Get parameters in a sharing group
287    pub fn get_sharing_group(&self, groupname: &str) -> Option<&[ParamId]> {
288        self.sharing_groups.get(groupname).map(|v| v.as_slice())
289    }
290
291    /// Get all registered parameters
292    pub fn get_all_parameters(&self) -> &HashMap<ParamId, ParameterMetadata> {
293        &self.parameters
294    }
295
296    /// Get layer architecture
297    pub fn get_layer_architecture(&self, layerid: &LayerId) -> Option<&LayerArchitecture> {
298        self.layer_architectures.get(layerid)
299    }
300
301    /// Get parameter metadata
302    pub fn get_parameter_metadata(&self, paramid: &ParamId) -> Option<&ParameterMetadata> {
303        self.parameters.get(paramid)
304    }
305
306    /// Update global configuration
307    pub fn update_config(&mut self, config: OptimizationConfig<A>) {
308        self.global_config = config;
309    }
310
311    /// Get optimizer state for parameter
312    pub fn get_optimizer_state(&self, paramid: &ParamId) -> Option<&HashMap<String, Array<A, D>>> {
313        self.optimizer_states.get(paramid)
314    }
315
316    /// Get mutable optimizer state for parameter
317    pub fn get_optimizer_state_mut(
318        &mut self,
319        paramid: &ParamId,
320    ) -> Option<&mut HashMap<String, Array<A, D>>> {
321        self.optimizer_states.get_mut(paramid)
322    }
323
324    /// Initialize optimizer state for parameter
325    pub fn init_optimizer_state(
326        &mut self,
327        paramid: &ParamId,
328        state_name: &str,
329        state: Array<A, D>,
330    ) -> Result<()> {
331        if let Some(states) = self.optimizer_states.get_mut(paramid) {
332            states.insert(state_name.to_string(), state);
333            Ok(())
334        } else {
335            Err(OptimError::InvalidConfig(format!(
336                "Parameter {} not registered",
337                paramid
338            )))
339        }
340    }
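
    // Illustrative sketch: an optimizer might lazily create a per-parameter
    // "momentum" buffer the first time it sees a gradient. The state name
    // "momentum" and the variables `grad` / `param_id` are hypothetical.
    //
    //     let zeros = Array::zeros(grad.raw_dim());
    //     manager.init_optimizer_state(&param_id, "momentum", zeros)?;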
341
342    /// Reset all optimizer states
343    pub fn reset_optimizer_states(&mut self) {
344        for states in self.optimizer_states.values_mut() {
345            states.clear();
346        }
347    }
348
349    /// Get parameters by layer
350    pub fn get_parameters_by_layer(&self, layerid: &LayerId) -> Vec<&ParamId> {
351        self.parameters
352            .iter()
353            .filter(|(_, metadata)| &metadata.layername == layerid)
354            .map(|(paramid, _)| paramid)
355            .collect()
356    }
357
358    /// Get parameters by type
359    pub fn get_parameters_by_type(&self, paramtype: ParameterType) -> Vec<&ParamId> {
360        self.parameters
361            .iter()
362            .filter(|(_, metadata)| metadata.paramtype == paramtype)
363            .map(|(paramid, _)| paramid)
364            .collect()
365    }
366
367    /// Get trainable parameters
368    pub fn get_trainable_parameters(&self) -> Vec<&ParamId> {
369        self.parameters
370            .iter()
371            .filter(|(paramid, metadata)| {
372                metadata.requires_grad && !self.is_parameter_frozen(paramid)
373            })
374            .map(|(paramid, _)| paramid)
375            .collect()
376    }
377}
378
379impl<A: Float + Send + Sync> Default for OptimizationConfig<A> {
380    fn default() -> Self {
381        Self {
382            base_learning_rate: A::from(0.001).unwrap(),
383            weight_decay: A::zero(),
384            gradient_clip: None,
385            mixed_precision: false,
386            architecture_optimizations: HashMap::new(),
387        }
388    }
389}
390
391impl<A: Float + Send + Sync> Default for LayerOptimizationRule<A> {
392    fn default() -> Self {
393        Self {
394            lr_multiplier: A::one(),
395            weight_decay_multiplier: A::one(),
396            frozen: false,
397            custom_settings: HashMap::new(),
398        }
399    }
400}
401
402/// Forward/backward pass integration
403pub mod forward_backward {
404    use super::*;
405
406    /// Forward pass hook for parameter tracking
407    pub trait ForwardHook<A: Float, D: Dimension> {
408        /// Called before layer forward pass
409        fn pre_forward(&mut self, layerid: &LayerId, inputs: &[Array<A, D>]) -> Result<()>;
410
411        /// Called after layer forward pass
412        fn post_forward(&mut self, layerid: &LayerId, outputs: &[Array<A, D>]) -> Result<()>;
413    }
414
415    /// Backward pass hook for gradient processing
416    pub trait BackwardHook<A: Float, D: Dimension> {
417        /// Called before layer backward pass
418        fn pre_backward(&mut self, layerid: &LayerId, gradoutputs: &[Array<A, D>]) -> Result<()>;
419
420        /// Called after layer backward pass
421        fn post_backward(&mut self, layerid: &LayerId, gradinputs: &[Array<A, D>]) -> Result<()>;
422    }
423
424    /// Neural network integration manager
425    pub struct NeuralIntegration<A: Float, D: Dimension> {
426        /// Parameter manager
427        param_manager: ParameterManager<A, D>,
428        /// Forward hooks
429        forward_hooks: HashMap<LayerId, Box<dyn ForwardHook<A, D>>>,
430        /// Backward hooks
431        backward_hooks: HashMap<LayerId, Box<dyn BackwardHook<A, D>>>,
432        /// Gradient accumulation mode
433        gradient_accumulation: bool,
434        /// Accumulated gradients
435        accumulated_gradients: HashMap<ParamId, Array<A, D>>,
436        /// Accumulation count
437        accumulation_count: usize,
438    }
439
440    impl<
441            A: Float
442                + ScalarOperand
443                + Debug
444                + 'static
445                + scirs2_core::numeric::FromPrimitive
446                + std::iter::Sum
447                + Send
448                + Sync,
449            D: Dimension + 'static,
450        > NeuralIntegration<A, D>
451    {
452        /// Create a new neural integration manager
453        pub fn new(config: OptimizationConfig<A>) -> Self {
454            Self {
455                param_manager: ParameterManager::new(config),
456                forward_hooks: HashMap::new(),
457                backward_hooks: HashMap::new(),
458                gradient_accumulation: false,
459                accumulated_gradients: HashMap::new(),
460                accumulation_count: 0,
461            }
462        }
463
464        /// Register a forward hook for a layer
465        pub fn register_forward_hook<H>(&mut self, layerid: LayerId, hook: H)
466        where
467            H: ForwardHook<A, D> + 'static,
468        {
469            self.forward_hooks.insert(layerid, Box::new(hook));
470        }
471
472        /// Register a backward hook for a layer
473        pub fn register_backward_hook<H>(&mut self, layerid: LayerId, hook: H)
474        where
475            H: BackwardHook<A, D> + 'static,
476        {
477            self.backward_hooks.insert(layerid, Box::new(hook));
478        }
479
480        /// Enable gradient accumulation
481        pub fn enable_gradient_accumulation(&mut self) {
482            self.gradient_accumulation = true;
483        }
484
485        /// Disable gradient accumulation and return accumulated gradients
486        pub fn disable_gradient_accumulation(&mut self) -> HashMap<ParamId, Array<A, D>> {
487            self.gradient_accumulation = false;
488            let result = std::mem::take(&mut self.accumulated_gradients);
489            self.accumulation_count = 0;
490            result
491        }
492
493        /// Execute forward pass with hooks
494        pub fn forward_pass(
495            &mut self,
496            layerid: &LayerId,
497            inputs: &[Array<A, D>],
498        ) -> Result<Vec<Array<A, D>>> {
499            // Execute pre-forward hook
500            if let Some(hook) = self.forward_hooks.get_mut(layerid) {
501                hook.pre_forward(layerid, inputs)?;
502            }
503
504            // Get layer architecture and parameters
505            let layer_arch = self
506                .param_manager
507                .get_layer_architecture(layerid)
508                .ok_or_else(|| {
509                    OptimError::InvalidConfig(format!("Layer {} not registered", layerid))
510                })?
511                .clone();
512
513            // Compute outputs based on layer type
514            let outputs = match layer_arch.layer_type.as_str() {
515                "linear" | "dense" | "fc" => {
516                    // Linear layer: output = input @ weight^T + bias
517                    self.compute_linear_forward(layerid, inputs)?
518                }
519                "conv" | "conv2d" => {
520                    // Convolutional layer: simplified computation
521                    self.compute_conv_forward(layerid, inputs)?
522                }
523                "activation" => {
524                    // Activation layer: apply activation function
525                    self.compute_activation_forward(layerid, inputs, &layer_arch)?
526                }
527                "normalization" | "batchnorm" | "layernorm" => {
528                    // Normalization layer
529                    self.compute_normalization_forward(layerid, inputs)?
530                }
531                "dropout" => {
532                    // Dropout layer: apply dropout mask
533                    self.compute_dropout_forward(layerid, inputs, &layer_arch)?
534                }
535                "pooling" | "maxpool" | "avgpool" => {
536                    // Pooling layer
537                    self.compute_pooling_forward(layerid, inputs, &layer_arch)?
538                }
539                _ => {
540                    // Default: pass through for unknown layer types
541                    inputs.to_vec()
542                }
543            };
544
545            // Execute post-forward hook
546            if let Some(hook) = self.forward_hooks.get_mut(layerid) {
547                hook.post_forward(layerid, &outputs)?;
548            }
549
550            Ok(outputs)
551        }
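
        // Illustrative sketch: a caller registers a layer architecture and then
        // routes activations through `forward_pass`, which dispatches on
        // `layer_type`. The layer id "act_0" and the input `x` are hypothetical.
        //
        //     integration.parameter_manager_mut().register_layer(
        //         "act_0".to_string(),
        //         LayerArchitecture {
        //             layer_type: "activation".to_string(),
        //             input_dims: vec![8],
        //             output_dims: vec![8],
        //             config: HashMap::from([(
        //                 "activation".to_string(),
        //                 LayerConfig::String("relu".to_string()),
        //             )]),
        //             trainable: false,
        //         },
        //     );
        //     let outputs = integration.forward_pass(&"act_0".to_string(), &[x])?;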
552
553        /// Compute linear layer forward pass
554        fn compute_linear_forward(
555            &self,
556            layerid: &LayerId,
557            inputs: &[Array<A, D>],
558        ) -> Result<Vec<Array<A, D>>> {
559            // For demonstration, this applies a simple placeholder scaling
560            // In a real implementation, this would multiply by the weight matrix and add the bias
561            if inputs.is_empty() {
562                return Err(OptimError::InvalidConfig(
563                    "Linear layer requires input".to_string(),
564                ));
565            }
566
567            // Get parameters for this layer
568            let layer_params = self.param_manager.get_parameters_by_layer(layerid);
569
570            // Simple transformation: scale input by learning rate (as a placeholder)
571            let lr =
572                self.param_manager
573                    .get_effective_learning_rate(layer_params.first().ok_or_else(|| {
574                        OptimError::InvalidConfig("No parameters for linear layer".to_string())
575                    })?);
576
577            let outputs: Vec<Array<A, D>> =
578                inputs.iter().map(|input| input.mapv(|x| x * lr)).collect();
579
580            Ok(outputs)
581        }
582
583        /// Compute convolutional layer forward pass
584        fn compute_conv_forward(
585            &self,
586            _layer_id: &LayerId,
587            inputs: &[Array<A, D>],
588        ) -> Result<Vec<Array<A, D>>> {
589            // Simplified convolution: just pass through
590            // Real implementation would apply convolution kernels
591            Ok(inputs.to_vec())
592        }
593
594        /// Compute activation forward pass
595        fn compute_activation_forward(
596            &self,
597            _layer_id: &LayerId,
598            inputs: &[Array<A, D>],
599            layer_arch: &LayerArchitecture,
600        ) -> Result<Vec<Array<A, D>>> {
601            let activation_type = layer_arch
602                .config
603                .get("activation")
604                .and_then(|v| match v {
605                    LayerConfig::String(s) => Some(s.as_str()),
606                    _ => None,
607                })
608                .unwrap_or("relu");
609
610            let outputs: Vec<Array<A, D>> = inputs
611                .iter()
612                .map(|input| {
613                    match activation_type {
614                        "relu" => input.mapv(|x| if x > A::zero() { x } else { A::zero() }),
615                        "sigmoid" => input.mapv(|x| A::one() / (A::one() + (-x).exp())),
616                        "tanh" => input.mapv(|x| x.tanh()),
617                        "leaky_relu" => {
618                            let alpha = A::from(0.01).unwrap();
619                            input.mapv(|x| if x > A::zero() { x } else { alpha * x })
620                        }
621                        _ => input.clone(), // Unknown activation, pass through
622                    }
623                })
624                .collect();
625
626            Ok(outputs)
627        }
628
629        /// Compute normalization forward pass
630        fn compute_normalization_forward(
631            &self,
632            _layer_id: &LayerId,
633            inputs: &[Array<A, D>],
634        ) -> Result<Vec<Array<A, D>>> {
635            // Simplified normalization: normalize to zero mean and unit variance
636            let outputs: Vec<Array<A, D>> = inputs
637                .iter()
638                .map(|input| {
639                    let mean = input.iter().copied().sum::<A>()
640                        / A::from(input.len().max(1)).unwrap_or(A::one());
641                    let variance = input
642                        .mapv(|x| (x - mean).powi(2))
643                        .mean()
644                        .unwrap_or(A::one());
645                    let std_dev = variance.sqrt();
646                    let epsilon = A::from(1e-5).unwrap();
647
648                    input.mapv(|x| (x - mean) / (std_dev + epsilon))
649                })
650                .collect();
651
652            Ok(outputs)
653        }
654
655        /// Compute dropout forward pass
656        fn compute_dropout_forward(
657            &self,
658            _layer_id: &LayerId,
659            inputs: &[Array<A, D>],
660            layer_arch: &LayerArchitecture,
661        ) -> Result<Vec<Array<A, D>>> {
662            let dropout_rate = layer_arch
663                .config
664                .get("dropout_rate")
665                .and_then(|v| match v {
666                    LayerConfig::Float(f) => Some(A::from(*f).unwrap()),
667                    _ => None,
668                })
669                .unwrap_or(A::from(0.5).unwrap());
670
671            // During training, we would apply dropout mask
672            // For now, scale by (1 - dropout_rate) to maintain expected value
673            let scale = A::one() - dropout_rate;
674            let outputs: Vec<Array<A, D>> = inputs
675                .iter()
676                .map(|input| input.mapv(|x| x * scale))
677                .collect();
678
679            Ok(outputs)
680        }
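
        // Worked example: with `dropout_rate = 0.5` every element is scaled by
        // 1 - 0.5 = 0.5, so an input value of 2.0 becomes 1.0. This preserves the
        // expected activation that random masking would produce during training;
        // no actual mask is sampled in this simplified version.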
681
682        /// Compute pooling forward pass
683        fn compute_pooling_forward(
684            &self,
685            _layer_id: &LayerId,
686            inputs: &[Array<A, D>],
687            _layer_arch: &LayerArchitecture,
688        ) -> Result<Vec<Array<A, D>>> {
689            // Simplified pooling: just pass through
690            // Real implementation would downsample the input
691            Ok(inputs.to_vec())
692        }
693
694        /// Execute backward pass with hooks
695        pub fn backward_pass(
696            &mut self,
697            layerid: &LayerId,
698            grad_outputs: &[Array<A, D>],
699        ) -> Result<Vec<Array<A, D>>> {
700            // Execute pre-backward hook
701            if let Some(hook) = self.backward_hooks.get_mut(layerid) {
702                hook.pre_backward(layerid, grad_outputs)?;
703            }
704
705            // Get layer architecture
706            let layer_arch = self
707                .param_manager
708                .get_layer_architecture(layerid)
709                .ok_or_else(|| {
710                    OptimError::InvalidConfig(format!("Layer {} not registered", layerid))
711                })?
712                .clone();
713
714            // Compute gradients based on layer type
715            let grad_inputs = match layer_arch.layer_type.as_str() {
716                "linear" | "dense" | "fc" => {
717                    // Linear layer gradient computation
718                    self.compute_linear_backward(layerid, grad_outputs)?
719                }
720                "conv" | "conv2d" => {
721                    // Convolutional layer gradient computation
722                    self.compute_conv_backward(layerid, grad_outputs)?
723                }
724                "activation" => {
725                    // Activation gradient computation
726                    self.compute_activation_backward(layerid, grad_outputs, &layer_arch)?
727                }
728                "normalization" | "batchnorm" | "layernorm" => {
729                    // Normalization gradient computation
730                    self.compute_normalization_backward(layerid, grad_outputs)?
731                }
732                "dropout" => {
733                    // Dropout gradient computation
734                    self.compute_dropout_backward(layerid, grad_outputs, &layer_arch)?
735                }
736                "pooling" | "maxpool" | "avgpool" => {
737                    // Pooling gradient computation
738                    self.compute_pooling_backward(layerid, grad_outputs, &layer_arch)?
739                }
740                _ => {
741                    // Default: pass through gradients for unknown layer types
742                    grad_outputs.to_vec()
743                }
744            };
745
746            // Apply gradient clipping if configured
747            let clipped_grads =
748                if let Some(clipvalue) = self.param_manager.global_config.gradient_clip {
749                    self.apply_gradient_clipping(grad_inputs, clipvalue)?
750                } else {
751                    grad_inputs
752                };
753
754            // Execute post-backward hook
755            if let Some(hook) = self.backward_hooks.get_mut(layerid) {
756                hook.post_backward(layerid, &clipped_grads)?;
757            }
758
759            Ok(clipped_grads)
760        }
761
762        /// Compute linear layer backward pass
763        fn compute_linear_backward(
764            &mut self,
765            layerid: &LayerId,
766            grad_outputs: &[Array<A, D>],
767        ) -> Result<Vec<Array<A, D>>> {
768            if grad_outputs.is_empty() {
769                return Err(OptimError::InvalidConfig(
770                    "Linear layer backward requires gradients".to_string(),
771                ));
772            }
773
774            // Get parameters for this layer
775            let layer_params = self.param_manager.get_parameters_by_layer(layerid);
776
777            // Store gradients for weight update
778            if self.gradient_accumulation {
779                let mut param_grads = HashMap::new();
780                for (i, paramid) in layer_params.iter().enumerate() {
781                    if i < grad_outputs.len() {
782                        param_grads.insert((*paramid).clone(), grad_outputs[i].clone());
783                    }
784                }
785                self.accumulate_gradients(param_grads)?;
786            }
787
788            // Simple gradient transformation: scale by learning rate decay
789            let lr_decay = A::from(0.9).unwrap();
790            let grad_inputs: Vec<Array<A, D>> = grad_outputs
791                .iter()
792                .map(|grad| grad.mapv(|x| x * lr_decay))
793                .collect();
794
795            Ok(grad_inputs)
796        }
797
798        /// Compute convolutional layer backward pass
799        fn compute_conv_backward(
800            &self,
801            _layer_id: &LayerId,
802            grad_outputs: &[Array<A, D>],
803        ) -> Result<Vec<Array<A, D>>> {
804            // Simplified convolution backward: pass through gradients
805            // Real implementation would compute gradients w.r.t. kernels and input
806            Ok(grad_outputs.to_vec())
807        }
808
809        /// Compute activation backward pass
810        fn compute_activation_backward(
811            &self,
812            _layer_id: &LayerId,
813            grad_outputs: &[Array<A, D>],
814            layer_arch: &LayerArchitecture,
815        ) -> Result<Vec<Array<A, D>>> {
816            let activation_type = layer_arch
817                .config
818                .get("activation")
819                .and_then(|v| match v {
820                    LayerConfig::String(s) => Some(s.as_str()),
821                    _ => None,
822                })
823                .unwrap_or("relu");
824
825            // Note: This is simplified - real implementation would need the forward pass inputs
826            let grad_inputs: Vec<Array<A, D>> = grad_outputs
827                .iter()
828                .map(|grad| {
829                    match activation_type {
830                        "relu" => {
831                            // ReLU gradient: 1 if x > 0, 0 otherwise
832                            // Since we don't have the original input, we approximate
833                            grad.mapv(|g| if g > A::zero() { g } else { A::zero() })
834                        }
835                        "sigmoid" => {
836                            // Sigmoid gradient: sigmoid(x) * (1 - sigmoid(x))
837                            // Approximation without original input
838                            let factor = A::from(0.25).unwrap(); // Max gradient of sigmoid
839                            grad.mapv(|g| g * factor)
840                        }
841                        "tanh" => {
842                            // Tanh gradient: 1 - tanh(x)^2
843                            // Approximation without original input
844                            let factor = A::from(0.5).unwrap();
845                            grad.mapv(|g| g * factor)
846                        }
847                        "leaky_relu" => {
848                            let alpha = A::from(0.01).unwrap();
849                            grad.mapv(|g| if g > A::zero() { g } else { alpha * g })
850                        }
851                        _ => grad.clone(), // Unknown activation, pass through
852                    }
853                })
854                .collect();
855
856            Ok(grad_inputs)
857        }
858
859        /// Compute normalization backward pass
860        fn compute_normalization_backward(
861            &self,
862            _layer_id: &LayerId,
863            grad_outputs: &[Array<A, D>],
864        ) -> Result<Vec<Array<A, D>>> {
865            // Simplified normalization backward
866            // Real implementation would compute gradients considering mean and variance
867            let scale_factor = A::from(0.9).unwrap();
868            let grad_inputs: Vec<Array<A, D>> = grad_outputs
869                .iter()
870                .map(|grad| grad.mapv(|g| g * scale_factor))
871                .collect();
872
873            Ok(grad_inputs)
874        }
875
876        /// Compute dropout backward pass
877        fn compute_dropout_backward(
878            &self,
879            _layer_id: &LayerId,
880            grad_outputs: &[Array<A, D>],
881            layer_arch: &LayerArchitecture,
882        ) -> Result<Vec<Array<A, D>>> {
883            let dropout_rate = layer_arch
884                .config
885                .get("dropout_rate")
886                .and_then(|v| match v {
887                    LayerConfig::Float(f) => Some(A::from(*f).unwrap()),
888                    _ => None,
889                })
890                .unwrap_or(A::from(0.5).unwrap());
891
892            // Scale gradients by (1 - dropout_rate) to match forward pass
893            let scale = A::one() - dropout_rate;
894            let grad_inputs: Vec<Array<A, D>> = grad_outputs
895                .iter()
896                .map(|grad| grad.mapv(|g| g * scale))
897                .collect();
898
899            Ok(grad_inputs)
900        }
901
902        /// Compute pooling backward pass
903        fn compute_pooling_backward(
904            &self,
905            _layer_id: &LayerId,
906            grad_outputs: &[Array<A, D>],
907            _layer_arch: &LayerArchitecture,
908        ) -> Result<Vec<Array<A, D>>> {
909            // Simplified pooling backward: pass through gradients
910            // Real implementation would upsample gradients to match input size
911            Ok(grad_outputs.to_vec())
912        }
913
914        /// Apply gradient clipping
915        fn apply_gradient_clipping(
916            &self,
917            gradients: Vec<Array<A, D>>,
918            clipvalue: A,
919        ) -> Result<Vec<Array<A, D>>> {
920            let clipped: Vec<Array<A, D>> = gradients
921                .into_iter()
922                .map(|grad| {
923                    // Compute L2 norm of gradient
924                    let norm = grad.mapv(|x| x * x).sum().sqrt();
925
926                    if norm > clipvalue {
927                        // Scale gradient to have norm = clipvalue
928                        let scale = clipvalue / norm;
929                        grad.mapv(|x| x * scale)
930                    } else {
931                        grad
932                    }
933                })
934                .collect();
935
936            Ok(clipped)
937        }
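
        // Worked example of the clipping rule above: for a gradient [3.0, 4.0]
        // the L2 norm is sqrt(9 + 16) = 5.0. With `clipvalue = 1.0` the scale is
        // 1.0 / 5.0 = 0.2, giving the clipped gradient [0.6, 0.8], whose norm now
        // equals the clip value. Gradients with norm <= `clipvalue` pass through
        // unchanged.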
938
939        /// Accumulate gradients for parameters
940        pub fn accumulate_gradients(
941            &mut self,
942            gradients: HashMap<ParamId, Array<A, D>>,
943        ) -> Result<()> {
944            if !self.gradient_accumulation {
945                return Err(OptimError::InvalidConfig(
946                    "Gradient accumulation not enabled".to_string(),
947                ));
948            }
949
950            self.accumulation_count += 1;
951
952            for (paramid, grad) in gradients {
953                if let Some(acc_grad) = self.accumulated_gradients.get_mut(&paramid) {
954                    // Add to existing accumulated gradient
955                    *acc_grad = acc_grad.clone() + grad;
956                } else {
957                    // First gradient for this parameter
958                    self.accumulated_gradients.insert(paramid, grad);
959                }
960            }
961
962            Ok(())
963        }
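
        // Illustrative sketch: accumulate gradients over several micro-batches and
        // average them before an optimizer update. The gradient maps
        // (`micro_batch_1_grads`, ...) are hypothetical; dividing by
        // `accumulation_count()` is the caller's responsibility, since this type
        // only sums.
        //
        //     integration.enable_gradient_accumulation();
        //     integration.accumulate_gradients(micro_batch_1_grads)?;
        //     integration.accumulate_gradients(micro_batch_2_grads)?;
        //     let n = integration.accumulation_count() as f64;
        //     let summed = integration.disable_gradient_accumulation();
        //     let averaged: HashMap<_, _> = summed
        //         .into_iter()
        //         .map(|(id, g)| (id, g.mapv(|x| x / n)))
        //         .collect();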
964
965        /// Get parameter manager
966        pub fn parameter_manager(&self) -> &ParameterManager<A, D> {
967            &self.param_manager
968        }
969
970        /// Get mutable parameter manager
971        pub fn parameter_manager_mut(&mut self) -> &mut ParameterManager<A, D> {
972            &mut self.param_manager
973        }
974
975        /// Get accumulation count
976        pub fn accumulation_count(&self) -> usize {
977            self.accumulation_count
978        }
979    }
980}
981
982/// Architecture-aware optimization utilities
983pub mod architecture_aware {
984    use super::*;
985
986    /// Architecture-specific optimization strategy
987    #[derive(Debug, Clone)]
988    pub enum ArchitectureStrategy {
989        /// Transformer-specific optimizations
990        Transformer {
991            /// Use different learning rates for different components
992            component_specific_lr: bool,
993            /// Apply layer-wise learning rate decay
994            layer_wise_decay: bool,
995            /// Warmup steps for attention parameters
996            attention_warmup: usize,
997        },
998        /// CNN-specific optimizations
999        ConvolutionalNet {
1000            /// Use different learning rates for conv vs fc layers
1001            layer_type_lr: bool,
1002            /// Apply depth-wise learning rate scaling
1003            depth_scaling: bool,
1004            /// Batch norm parameter handling
1005            bn_special_handling: bool,
1006        },
1007        /// RNN-specific optimizations
1008        RecurrentNet {
1009            /// Gradient clipping specifically for RNNs
1010            rnn_gradient_clip: Option<f64>,
1011            /// Different learning rates for recurrent vs linear weights
1012            weight_type_lr: bool,
1013        },
1014        /// Custom architecture strategy
1015        Custom {
1016            /// Custom optimization rules
1017            rules: HashMap<String, LayerConfig>,
1018        },
1019    }
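
    // Illustrative sketch: constructing a Transformer strategy and wiring it into
    // an `ArchitectureAwareOptimizer` (this mirrors the test at the bottom of the
    // file; `Ix1` is assumed to be imported from `scirs2_core::ndarray`).
    //
    //     let strategy = ArchitectureStrategy::Transformer {
    //         component_specific_lr: true,
    //         layer_wise_decay: true,
    //         attention_warmup: 1000,
    //     };
    //     let mut opt = ArchitectureAwareOptimizer::<f64, Ix1>::new(
    //         OptimizationConfig::default(),
    //         strategy,
    //     );
    //     opt.step()?; // increments the step count and re-applies the strategy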
1020
1021    /// Architecture-aware optimizer
1022    #[derive(Debug)]
1023    pub struct ArchitectureAwareOptimizer<A: Float, D: Dimension> {
1024        /// Parameter manager
1025        param_manager: ParameterManager<A, D>,
1026        /// Architecture strategy
1027        strategy: ArchitectureStrategy,
1028        /// Step count
1029        step_count: usize,
1030    }
1031
1032    impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
1033        ArchitectureAwareOptimizer<A, D>
1034    {
1035        /// Create a new architecture-aware optimizer
1036        pub fn new(config: OptimizationConfig<A>, strategy: ArchitectureStrategy) -> Self {
1037            Self {
1038                param_manager: ParameterManager::new(config),
1039                strategy,
1040                step_count: 0,
1041            }
1042        }
1043
1044        /// Apply architecture-specific optimizations
1045        pub fn apply_architecture_optimizations(&mut self) -> Result<()> {
1046            // Clone the strategy to avoid borrowing conflicts
1047            let strategy = self.strategy.clone();
1048            match strategy {
1049                ArchitectureStrategy::Transformer {
1050                    component_specific_lr,
1051                    layer_wise_decay,
1052                    attention_warmup,
1053                } => {
1054                    self.apply_transformer_optimizations(
1055                        component_specific_lr,
1056                        layer_wise_decay,
1057                        attention_warmup,
1058                    )?;
1059                }
1060                ArchitectureStrategy::ConvolutionalNet {
1061                    layer_type_lr,
1062                    depth_scaling,
1063                    bn_special_handling,
1064                } => {
1065                    self.apply_cnn_optimizations(
1066                        layer_type_lr,
1067                        depth_scaling,
1068                        bn_special_handling,
1069                    )?;
1070                }
1071                ArchitectureStrategy::RecurrentNet {
1072                    rnn_gradient_clip,
1073                    weight_type_lr,
1074                } => {
1075                    self.apply_rnn_optimizations(rnn_gradient_clip, weight_type_lr)?;
1076                }
1077                ArchitectureStrategy::Custom { rules } => {
1078                    self.apply_custom_optimizations(&rules)?;
1079                }
1080            }
1081            Ok(())
1082        }
1083
1084        /// Apply Transformer-specific optimizations
1085        fn apply_transformer_optimizations(
1086            &mut self,
1087            component_specific_lr: bool,
1088            layer_wise_decay: bool,
1089            attention_warmup: usize,
1090        ) -> Result<()> {
1091            if component_specific_lr {
1092                // Different learning rates for attention, ffn, and normalization
1093                self.set_component_learning_rates()?;
1094            }
1095
1096            if layer_wise_decay {
1097                // Apply layer-wise learning rate decay
1098                self.apply_layer_wise_decay()?;
1099            }
1100
1101            if attention_warmup > 0 && self.step_count < attention_warmup {
1102                // Apply warmup to attention parameters
1103                self.apply_attention_warmup(attention_warmup)?;
1104            }
1105
1106            Ok(())
1107        }
1108
1109        /// Apply CNN-specific optimizations
1110        fn apply_cnn_optimizations(
1111            &mut self,
1112            layer_type_lr: bool,
1113            depth_scaling: bool,
1114            bn_special_handling: bool,
1115        ) -> Result<()> {
1116            if layer_type_lr {
1117                // Different learning rates for conv vs fully connected layers
1118                self.set_layer_type_learning_rates()?;
1119            }
1120
1121            if depth_scaling {
1122                // Scale learning rates based on network depth
1123                self.apply_depth_scaling()?;
1124            }
1125
1126            if bn_special_handling {
1127                // Special handling for batch normalization parameters
1128                self.apply_bn_optimizations()?;
1129            }
1130
1131            Ok(())
1132        }
1133
1134        /// Apply RNN-specific optimizations
1135        fn apply_rnn_optimizations(
1136            &mut self,
1137            rnn_gradient_clip: Option<f64>,
1138            weight_type_lr: bool,
1139        ) -> Result<()> {
1140            if let Some(clipvalue) = rnn_gradient_clip {
1141                // Apply RNN-specific gradient clipping
1142                self.apply_rnn_gradient_clipping(A::from(clipvalue).unwrap())?;
1143            }
1144
1145            if weight_type_lr {
1146                // Different learning rates for recurrent vs linear weights
1147                self.set_weight_type_learning_rates()?;
1148            }
1149
1150            Ok(())
1151        }
1152
1153        /// Apply custom optimizations
1154        fn apply_custom_optimizations(
1155            &mut self,
1156            rules: &HashMap<String, LayerConfig>,
1157        ) -> Result<()> {
1158            // Collect rules first to avoid borrowing conflicts
1159            let rule_entries: Vec<(String, LayerConfig)> = rules
1160                .iter()
1161                .map(|(name, config)| (name.clone(), config.clone()))
1162                .collect();
1163
1164            for (rule_name, config) in rule_entries {
1165                self.apply_custom_rule(&rule_name, &config)?;
1166            }
1167            Ok(())
1168        }
1169
1170        /// Set component-specific learning rates for Transformers
1171        fn set_component_learning_rates(&mut self) -> Result<()> {
1172            // Collect the data first to avoid borrowing conflicts
1173            let layer_rules: Vec<(LayerId, LayerOptimizationRule<A>)> = self
1174                .param_manager
1175                .get_all_parameters()
1176                .values()
1177                .map(|metadata| {
1178                    let mut rule = LayerOptimizationRule::default();
1179
1180                    // Determine learning rate multiplier based on parameter tags
1181                    if metadata.tags.contains(&"attention".to_string()) {
1182                        rule.lr_multiplier = A::from(1.2).unwrap(); // Higher LR for attention
1183                    } else if metadata.tags.contains(&"ffn".to_string()) {
1184                        rule.lr_multiplier = A::from(1.0).unwrap(); // Standard LR for FFN
1185                    } else if metadata.tags.contains(&"normalization".to_string()) {
1186                        rule.lr_multiplier = A::from(0.8).unwrap(); // Lower LR for normalization
1187                    }
1188
1189                    (metadata.layername.clone(), rule)
1190                })
1191                .collect();
1192
1193            // Now apply the rules
1194            for (layername, rule) in layer_rules {
1195                self.param_manager.set_layer_rule(layername, rule);
1196            }
1197            Ok(())
1198        }
1199
1200        /// Apply layer-wise learning rate decay
1201        fn apply_layer_wise_decay(&mut self) -> Result<()> {
1202            // Extract layer numbers from layer names and apply decay
1203            for (layerid, _) in self.param_manager.layer_architectures.clone() {
1204                if let Some(layer_num) = self.extract_layer_number(&layerid) {
1205                    let decay_factor = A::from(0.95_f64.powi(layer_num as i32)).unwrap();
1206                    let mut rule = self
1207                        .param_manager
1208                        .layer_rules
1209                        .get(&layerid)
1210                        .cloned()
1211                        .unwrap_or_default();
1212                    rule.lr_multiplier = rule.lr_multiplier * decay_factor;
1213                    self.param_manager.set_layer_rule(layerid, rule);
1214                }
1215            }
1216            Ok(())
1217        }
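
        // Worked example: for a layer named "layer_12" the extracted index is 12,
        // so its multiplier is scaled by 0.95^12 ≈ 0.54; deeper layers therefore
        // receive progressively smaller effective learning rates.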
1218
1219        /// Apply attention parameter warmup
1220        fn apply_attention_warmup(&mut self, warmupsteps: usize) -> Result<()> {
1221            let warmup_factor = A::from(self.step_count as f64 / warmupsteps as f64).unwrap();
1222
1223            // Collect attention layers first
1224            let attention_layers: Vec<LayerId> = self
1225                .param_manager
1226                .get_all_parameters()
1227                .iter()
1228                .filter_map(|(_param_id, metadata)| {
1229                    if metadata.tags.contains(&"attention".to_string()) {
1230                        Some(metadata.layername.clone())
1231                    } else {
1232                        None
1233                    }
1234                })
1235                .collect();
1236
1237            // Apply warmup to attention layers
1238            for layername in attention_layers {
1239                let mut rule = self
1240                    .param_manager
1241                    .layer_rules
1242                    .get(&layername)
1243                    .cloned()
1244                    .unwrap_or_default();
1245                rule.lr_multiplier = rule.lr_multiplier * warmup_factor;
1246                self.param_manager.set_layer_rule(layername, rule);
1247            }
1248            Ok(())
1249        }
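
        // Worked example: at `step_count = 250` with `warmupsteps = 1000` the
        // warmup factor is 250 / 1000 = 0.25, so attention layers temporarily run
        // at a quarter of their configured learning-rate multiplier.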
1250
1251        /// Set learning rates based on layer type (conv vs fc)
1252        fn set_layer_type_learning_rates(&mut self) -> Result<()> {
1253            for (layerid, architecture) in self.param_manager.layer_architectures.clone() {
1254                let mut rule = LayerOptimizationRule::default();
1255
1256                match architecture.layer_type.as_str() {
1257                    "conv" | "conv2d" | "conv3d" => {
1258                        rule.lr_multiplier = A::from(1.0).unwrap(); // Standard LR for conv
1259                    }
1260                    "linear" | "dense" | "fc" => {
1261                        rule.lr_multiplier = A::from(0.8).unwrap(); // Lower LR for FC
1262                    }
1263                    _ => {
1264                        rule.lr_multiplier = A::from(1.0).unwrap(); // Default
1265                    }
1266                }
1267
1268                self.param_manager.set_layer_rule(layerid, rule);
1269            }
1270            Ok(())
1271        }
1272
1273        /// Apply depth-based scaling (note: the layer map has no guaranteed iteration order, so the depth index used here is approximate)
1274        fn apply_depth_scaling(&mut self) -> Result<()> {
1275            // Count total layers
1276            let total_layers = self.param_manager.layer_architectures.len();
1277
1278            for (i, (layerid, _)) in self
1279                .param_manager
1280                .layer_architectures
1281                .clone()
1282                .iter()
1283                .enumerate()
1284            {
1285                let depth_factor = A::from(1.0 - 0.1 * (i as f64 / total_layers as f64)).unwrap();
1286                let mut rule = self
1287                    .param_manager
1288                    .layer_rules
1289                    .get(layerid)
1290                    .cloned()
1291                    .unwrap_or_default();
1292                rule.lr_multiplier = rule.lr_multiplier * depth_factor;
1293                self.param_manager.set_layer_rule(layerid.clone(), rule);
1294            }
1295            Ok(())
1296        }
1297
1298        /// Apply batch normalization optimizations
1299        fn apply_bn_optimizations(&mut self) -> Result<()> {
1300            // Collect normalization layers first
1301            let norm_layers: Vec<LayerId> = self
1302                .param_manager
1303                .get_all_parameters()
1304                .iter()
1305                .filter_map(|(_param_id, metadata)| {
1306                    if metadata.paramtype == ParameterType::Normalization {
1307                        Some(metadata.layername.clone())
1308                    } else {
1309                        None
1310                    }
1311                })
1312                .collect();
1313
1314            // Apply optimization to normalization layers
1315            for layername in norm_layers {
1316                let mut rule = self
1317                    .param_manager
1318                    .layer_rules
1319                    .get(&layername)
1320                    .cloned()
1321                    .unwrap_or_default();
1322                // Higher learning rate and no weight decay for BN parameters
1323                rule.lr_multiplier = A::from(2.0).unwrap();
1324                rule.weight_decay_multiplier = A::zero();
1325                self.param_manager.set_layer_rule(layername, rule);
1326            }
1327            Ok(())
1328        }
1329
1330        /// Apply RNN-specific gradient clipping
1331        fn apply_rnn_gradient_clipping(&mut self, clipvalue: A) -> Result<()> {
1332            // This would be implemented in coordination with the gradient processing system
1333            // For now, we'll store the clip value in the global config
1334            self.param_manager.global_config.gradient_clip = Some(clipvalue);
1335            Ok(())
1336        }
1337
1338        /// Set learning rates based on weight type (recurrent vs linear)
1339        fn set_weight_type_learning_rates(&mut self) -> Result<()> {
1340            // Collect weight type layers first
1341            let layer_rules: Vec<(LayerId, LayerOptimizationRule<A>)> = self
1342                .param_manager
1343                .get_all_parameters()
1344                .values()
1345                .map(|metadata| {
1346                    let mut rule = LayerOptimizationRule::default();
1347
1348                    if metadata.tags.contains(&"recurrent".to_string()) {
1349                        rule.lr_multiplier = A::from(0.5).unwrap(); // Lower LR for recurrent weights
1350                    } else if metadata.tags.contains(&"linear".to_string()) {
1351                        rule.lr_multiplier = A::from(1.0).unwrap(); // Standard LR for linear weights
1352                    }
1353
1354                    (metadata.layername.clone(), rule)
1355                })
1356                .collect();
1357
1358            // Apply the rules
1359            for (layername, rule) in layer_rules {
1360                self.param_manager.set_layer_rule(layername, rule);
1361            }
1362            Ok(())
1363        }
1364
1365        /// Apply custom optimization rule
1366        fn apply_custom_rule(&mut self, _rule_name: &str, _config: &LayerConfig) -> Result<()> {
1367            // Custom rule implementation would depend on the specific rule
1368            // This is a placeholder for extensibility
1369            Ok(())
1370        }
1371
1372        /// Extract layer number from layer name (e.g., "layer_12" -> 12)
1373        fn extract_layer_number(&self, layername: &str) -> Option<usize> {
1374            layername.split('_').next_back()?.parse().ok()
1375        }
1376
1377        /// Step the optimizer
1378        pub fn step(&mut self) -> Result<()> {
1379            self.step_count += 1;
1380            self.apply_architecture_optimizations()
1381        }
1382
1383        /// Get parameter manager
1384        pub fn parameter_manager(&self) -> &ParameterManager<A, D> {
1385            &self.param_manager
1386        }
1387
1388        /// Get mutable parameter manager
1389        pub fn parameter_manager_mut(&mut self) -> &mut ParameterManager<A, D> {
1390            &mut self.param_manager
1391        }
1392    }
1393}
1394
1395#[cfg(test)]
1396mod tests {
1397    use super::*;
1398    use approx::assert_relative_eq;
1399
1400    #[test]
1401    fn test_parameter_manager_basic() {
1402        let config = OptimizationConfig::default();
1403        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);
1404
1405        let metadata = ParameterMetadata {
1406            layername: "layer1".to_string(),
1407            param_name: "weight".to_string(),
1408            shape: vec![10, 5],
1409            requires_grad: true,
1410            paramtype: ParameterType::Weight,
1411            sharing_group: None,
1412            tags: vec!["dense".to_string()],
1413        };
1414
1415        manager
1416            .register_parameter("param1".to_string(), metadata)
1417            .unwrap();
1418
1419        assert!(manager
1420            .get_parameter_metadata(&"param1".to_string())
1421            .is_some());
1422        assert_eq!(manager.get_all_parameters().len(), 1);
1423        assert!(!manager.is_parameter_frozen(&"param1".to_string()));
1424    }
1425
1426    #[test]
1427    fn test_lazy_registration() {
1428        let config = OptimizationConfig::default();
1429        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);
1430
1431        manager.enable_lazy_mode();
1432
1433        let metadata = ParameterMetadata {
1434            layername: "layer1".to_string(),
1435            param_name: "weight".to_string(),
1436            shape: vec![10, 5],
1437            requires_grad: true,
1438            paramtype: ParameterType::Weight,
1439            sharing_group: None,
1440            tags: vec![],
1441        };
1442
1443        manager
1444            .register_parameter("param1".to_string(), metadata)
1445            .unwrap();
1446
1447        // Parameter should not be registered yet
1448        assert_eq!(manager.get_all_parameters().len(), 0);
1449
1450        // Disable lazy mode to process pending registrations
1451        manager.disable_lazy_mode().unwrap();
1452
1453        // Now parameter should be registered
1454        assert_eq!(manager.get_all_parameters().len(), 1);
1455    }
1456
1457    #[test]
1458    fn test_layer_specific_rules() {
1459        let config = OptimizationConfig {
1460            base_learning_rate: 0.01,
1461            weight_decay: 0.001,
1462            gradient_clip: None,
1463            mixed_precision: false,
1464            architecture_optimizations: HashMap::new(),
1465        };
1466        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);
1467
1468        let rule = LayerOptimizationRule {
1469            lr_multiplier: 2.0,
1470            weight_decay_multiplier: 0.5,
1471            frozen: false,
1472            custom_settings: HashMap::new(),
1473        };
1474
1475        manager.set_layer_rule("layer1".to_string(), rule);
1476
1477        let metadata = ParameterMetadata {
1478            layername: "layer1".to_string(),
1479            param_name: "weight".to_string(),
1480            shape: vec![10, 5],
1481            requires_grad: true,
1482            paramtype: ParameterType::Weight,
1483            sharing_group: None,
1484            tags: vec![],
1485        };
1486
1487        manager
1488            .register_parameter("param1".to_string(), metadata)
1489            .unwrap();
1490
1491        // Test effective learning rate
1492        let effective_lr = manager.get_effective_learning_rate(&"param1".to_string());
1493        assert_relative_eq!(effective_lr, 0.02, epsilon = 1e-6); // 0.01 * 2.0
1494
1495        // Test effective weight decay
1496        let effective_decay = manager.get_effective_weight_decay(&"param1".to_string());
1497        assert_relative_eq!(effective_decay, 0.0005, epsilon = 1e-6); // 0.001 * 0.5
1498    }
1499
1500    #[test]
1501    fn test_parameter_sharing() {
1502        let config = OptimizationConfig::default();
1503        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);
1504
1505        let metadata1 = ParameterMetadata {
1506            layername: "layer1".to_string(),
1507            param_name: "weight".to_string(),
1508            shape: vec![10, 5],
1509            requires_grad: true,
1510            paramtype: ParameterType::Weight,
1511            sharing_group: Some("shared_weights".to_string()),
1512            tags: vec![],
1513        };
1514
1515        let metadata2 = ParameterMetadata {
1516            layername: "layer2".to_string(),
1517            param_name: "weight".to_string(),
1518            shape: vec![10, 5],
1519            requires_grad: true,
1520            paramtype: ParameterType::Weight,
1521            sharing_group: Some("shared_weights".to_string()),
1522            tags: vec![],
1523        };
1524
1525        manager
1526            .register_parameter("param1".to_string(), metadata1)
1527            .unwrap();
1528        manager
1529            .register_parameter("param2".to_string(), metadata2)
1530            .unwrap();
1531
1532        let sharing_group = manager.get_sharing_group("shared_weights").unwrap();
1533        assert_eq!(sharing_group.len(), 2);
1534        assert!(sharing_group.contains(&"param1".to_string()));
1535        assert!(sharing_group.contains(&"param2".to_string()));
1536    }
1537
1538    #[test]
1539    fn test_parameter_filtering() {
1540        let config = OptimizationConfig::default();
1541        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);
1542
1543        let weight_metadata = ParameterMetadata {
1544            layername: "layer1".to_string(),
1545            param_name: "weight".to_string(),
1546            shape: vec![10, 5],
1547            requires_grad: true,
1548            paramtype: ParameterType::Weight,
1549            sharing_group: None,
1550            tags: vec![],
1551        };
1552
1553        let bias_metadata = ParameterMetadata {
1554            layername: "layer1".to_string(),
1555            param_name: "bias".to_string(),
1556            shape: vec![5],
1557            requires_grad: true,
1558            paramtype: ParameterType::Bias,
1559            sharing_group: None,
1560            tags: vec![],
1561        };
1562
1563        manager
1564            .register_parameter("weight".to_string(), weight_metadata)
1565            .unwrap();
1566        manager
1567            .register_parameter("bias".to_string(), bias_metadata)
1568            .unwrap();
1569
1570        // Test filtering by type
1571        let weights = manager.get_parameters_by_type(ParameterType::Weight);
1572        assert_eq!(weights.len(), 1);
1573        assert_eq!(weights[0], &"weight".to_string());
1574
1575        let biases = manager.get_parameters_by_type(ParameterType::Bias);
1576        assert_eq!(biases.len(), 1);
1577        assert_eq!(biases[0], &"bias".to_string());
1578
1579        // Test filtering by layer
1580        let layer_params = manager.get_parameters_by_layer(&"layer1".to_string());
1581        assert_eq!(layer_params.len(), 2);
1582
1583        // Test trainable parameters
1584        let trainable = manager.get_trainable_parameters();
1585        assert_eq!(trainable.len(), 2);
1586    }
1587
1588    #[test]
1589    fn test_architecture_aware_transformer() {
1590        use crate::neural_integration::architecture_aware::*;
1591
1592        let config = OptimizationConfig::default();
1593        let strategy = ArchitectureStrategy::Transformer {
1594            component_specific_lr: true,
1595            layer_wise_decay: true,
1596            attention_warmup: 1000,
1597        };
1598
1599        let mut optimizer =
1600            ArchitectureAwareOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(config, strategy);
1601
1602        // Register a layer architecture
1603        let layer_arch = LayerArchitecture {
1604            layer_type: "transformer_block".to_string(),
1605            input_dims: vec![512],
1606            output_dims: vec![512],
1607            config: HashMap::new(),
1608            trainable: true,
1609        };
1610
1611        optimizer
1612            .parameter_manager_mut()
1613            .register_layer("layer_0".to_string(), layer_arch);
1614
1615        // Register parameters with different tags
1616        let attention_metadata = ParameterMetadata {
1617            layername: "layer_0".to_string(),
1618            param_name: "attention_weight".to_string(),
1619            shape: vec![512, 512],
1620            requires_grad: true,
1621            paramtype: ParameterType::Attention,
1622            sharing_group: None,
1623            tags: vec!["attention".to_string()],
1624        };
1625
1626        optimizer
1627            .parameter_manager_mut()
1628            .register_parameter("attn_param".to_string(), attention_metadata)
1629            .unwrap();
1630
1631        // Apply optimizations
1632        optimizer.apply_architecture_optimizations().unwrap();
1633
1634        // Verify that attention parameters get special treatment
1635        assert!(optimizer
1636            .parameter_manager()
1637            .get_parameter_metadata(&"attn_param".to_string())
1638            .is_some());
1639    }
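
    // Additional illustrative test (a minimal sketch using only APIs defined in
    // this module): a frozen layer rule should exclude its parameters from the
    // trainable set.
    #[test]
    fn test_frozen_layer_excluded_from_trainable() {
        let config = OptimizationConfig::default();
        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);

        let rule = LayerOptimizationRule {
            lr_multiplier: 1.0,
            weight_decay_multiplier: 1.0,
            frozen: true,
            custom_settings: HashMap::new(),
        };
        manager.set_layer_rule("frozen_layer".to_string(), rule);

        let metadata = ParameterMetadata {
            layername: "frozen_layer".to_string(),
            param_name: "weight".to_string(),
            shape: vec![4, 4],
            requires_grad: true,
            paramtype: ParameterType::Weight,
            sharing_group: None,
            tags: vec![],
        };
        manager
            .register_parameter("frozen_param".to_string(), metadata)
            .unwrap();

        assert!(manager.is_parameter_frozen(&"frozen_param".to_string()));
        assert!(manager.get_trainable_parameters().is_empty());
    }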
1640}