1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
/// Identifier for a layer within a model (currently a plain `String`).
pub type LayerId = String;

/// Identifier for a single parameter within a model (currently a plain `String`).
pub type ParamId = String;
18
/// Descriptive metadata attached to a registered parameter.
#[derive(Debug, Clone)]
pub struct ParameterMetadata {
    /// Layer this parameter belongs to.
    pub layername: LayerId,
    /// Name of the parameter within its layer (e.g. "weight", "bias").
    pub param_name: ParamId,
    /// Tensor shape of the parameter.
    pub shape: Vec<usize>,
    /// Whether gradients should be tracked for this parameter.
    pub requires_grad: bool,
    /// Semantic category of the parameter.
    pub paramtype: ParameterType,
    /// Optional sharing-group name; parameters with the same group are
    /// collected together (see `ParameterManager::get_sharing_group`).
    pub sharing_group: Option<String>,
    /// Free-form tags (e.g. "attention", "ffn") consulted by the
    /// architecture-aware optimization strategies.
    pub tags: Vec<String>,
}
37
/// Semantic category of a parameter, used to apply type-specific rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParameterType {
    /// Dense/convolutional weight matrices or kernels.
    Weight,
    /// Bias vectors.
    Bias,
    /// Normalization-layer parameters (scale/shift of batch/layer norm).
    Normalization,
    /// Embedding tables.
    Embedding,
    /// Attention-specific parameters (projections, etc.).
    Attention,
    /// Anything that does not fit the categories above.
    Custom,
}
54
/// Static description of a layer's structure and configuration.
#[derive(Debug, Clone)]
pub struct LayerArchitecture {
    /// Layer kind as a string key (e.g. "linear", "conv2d", "activation");
    /// dispatched on in `forward_backward::NeuralIntegration`.
    pub layer_type: String,
    /// Expected input dimensions.
    pub input_dims: Vec<usize>,
    /// Produced output dimensions.
    pub output_dims: Vec<usize>,
    /// Arbitrary per-layer settings (e.g. "activation", "dropout_rate").
    pub config: HashMap<String, LayerConfig>,
    /// Whether the layer's parameters are trainable.
    pub trainable: bool,
}
69
/// Dynamically-typed configuration value for a layer setting.
#[derive(Debug, Clone)]
pub enum LayerConfig {
    /// Integer-valued setting.
    Int(i64),
    /// Floating-point setting (e.g. a dropout rate).
    Float(f64),
    /// String setting (e.g. an activation name).
    String(String),
    /// Boolean flag.
    Bool(bool),
    /// Nested list of settings.
    List(Vec<LayerConfig>),
}
84
/// Interface for optimizers that operate on named parameters with metadata.
pub trait ParameterOptimizer<A: Float, D: Dimension> {
    /// Registers `parameter` under `paramid` together with its metadata.
    fn register_parameter(
        &mut self,
        paramid: ParamId,
        parameter: &Array<A, D>,
        metadata: ParameterMetadata,
    ) -> Result<()>;

    /// Performs one optimization step, updating `parameters` in place
    /// from the supplied `gradients`.
    fn step(
        &mut self,
        gradients: HashMap<ParamId, Array<A, D>>,
        parameters: &mut HashMap<ParamId, Array<A, D>>,
    ) -> Result<()>;

    /// Returns the learning rate currently used for `paramid`, if known.
    fn get_learning_rate(&self, paramid: &ParamId) -> Option<A>;

    /// Sets the learning rate for `paramid`.
    fn set_learning_rate(&mut self, paramid: &ParamId, lr: A) -> Result<()>;

    /// Returns the per-parameter optimizer state (e.g. momentum buffers), if any.
    fn get_parameter_state(&self, paramid: &ParamId) -> Option<&HashMap<String, Array<A, D>>>;

    /// Clears all internal optimizer state.
    fn reset_state(&mut self);

    /// Lists the ids of every registered parameter.
    fn registered_parameters(&self) -> Vec<ParamId>;
}
117
/// Central registry of parameters, their optimizer state, per-layer rules,
/// and the global optimization configuration.
#[derive(Debug)]
pub struct ParameterManager<A: Float, D: Dimension> {
    // Metadata for every registered parameter, keyed by id.
    parameters: HashMap<ParamId, ParameterMetadata>,
    // Named optimizer-state arrays (e.g. momentum) per parameter.
    optimizer_states: HashMap<ParamId, HashMap<String, Array<A, D>>>,
    // Registered layer descriptions, keyed by layer id.
    layer_architectures: HashMap<LayerId, LayerArchitecture>,
    // Sharing-group name -> member parameter ids.
    sharing_groups: HashMap<String, Vec<ParamId>>,
    // Per-layer optimization overrides (lr/decay multipliers, freezing).
    layer_rules: HashMap<LayerId, LayerOptimizationRule<A>>,
    // Global defaults applied when no layer rule matches.
    global_config: OptimizationConfig<A>,
    // When true, registrations are queued instead of applied immediately.
    lazy_mode: bool,
    // Registrations deferred while in lazy mode.
    pending_registrations: Vec<(ParamId, ParameterMetadata)>,
}
138
/// Per-layer overrides applied on top of the global optimization config.
#[derive(Debug, Clone)]
pub struct LayerOptimizationRule<A: Float> {
    /// Multiplier applied to the base learning rate for this layer.
    pub lr_multiplier: A,
    /// Multiplier applied to the base weight decay for this layer.
    pub weight_decay_multiplier: A,
    /// When true, parameters of this layer are excluded from training.
    pub frozen: bool,
    /// Arbitrary extra settings for custom optimizers.
    pub custom_settings: HashMap<String, LayerConfig>,
}
151
/// Global optimization settings shared by all layers unless overridden.
#[derive(Debug, Clone)]
pub struct OptimizationConfig<A: Float> {
    /// Base learning rate before per-layer multipliers.
    pub base_learning_rate: A,
    /// Base weight-decay coefficient before per-layer multipliers.
    pub weight_decay: A,
    /// Optional per-array L2-norm gradient clipping threshold.
    pub gradient_clip: Option<A>,
    /// Whether mixed-precision training is enabled (flag only; not read here).
    pub mixed_precision: bool,
    /// Named architecture-specific optimization toggles.
    pub architecture_optimizations: HashMap<String, bool>,
}
166
167impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
168 ParameterManager<A, D>
169{
170 pub fn new(config: OptimizationConfig<A>) -> Self {
172 Self {
173 parameters: HashMap::new(),
174 optimizer_states: HashMap::new(),
175 layer_architectures: HashMap::new(),
176 sharing_groups: HashMap::new(),
177 layer_rules: HashMap::new(),
178 global_config: config,
179 lazy_mode: false,
180 pending_registrations: Vec::new(),
181 }
182 }
183
184 pub fn enable_lazy_mode(&mut self) {
186 self.lazy_mode = true;
187 }
188
189 pub fn disable_lazy_mode(&mut self) -> Result<()> {
191 self.lazy_mode = false;
192
193 let pending = std::mem::take(&mut self.pending_registrations);
195 for (paramid, metadata) in pending {
196 self.register_parameter_impl(paramid, metadata)?;
197 }
198
199 Ok(())
200 }
201
202 pub fn register_layer(&mut self, layerid: LayerId, architecture: LayerArchitecture) {
204 self.layer_architectures.insert(layerid, architecture);
205 }
206
207 pub fn set_layer_rule(&mut self, layerid: LayerId, rule: LayerOptimizationRule<A>) {
209 self.layer_rules.insert(layerid, rule);
210 }
211
212 pub fn register_parameter(
214 &mut self,
215 paramid: ParamId,
216 metadata: ParameterMetadata,
217 ) -> Result<()> {
218 if self.lazy_mode {
219 self.pending_registrations.push((paramid, metadata));
220 Ok(())
221 } else {
222 self.register_parameter_impl(paramid, metadata)
223 }
224 }
225
226 fn register_parameter_impl(
228 &mut self,
229 paramid: ParamId,
230 metadata: ParameterMetadata,
231 ) -> Result<()> {
232 if let Some(sharing_group) = &metadata.sharing_group {
234 self.sharing_groups
235 .entry(sharing_group.clone())
236 .or_default()
237 .push(paramid.clone());
238 }
239
240 self.optimizer_states
242 .insert(paramid.clone(), HashMap::new());
243
244 self.parameters.insert(paramid, metadata);
246
247 Ok(())
248 }
249
250 pub fn get_effective_learning_rate(&self, paramid: &ParamId) -> A {
252 let base_lr = self.global_config.base_learning_rate;
253
254 if let Some(metadata) = self.parameters.get(paramid) {
255 if let Some(rule) = self.layer_rules.get(&metadata.layername) {
256 return base_lr * rule.lr_multiplier;
257 }
258 }
259
260 base_lr
261 }
262
263 pub fn get_effective_weight_decay(&self, paramid: &ParamId) -> A {
265 let base_decay = self.global_config.weight_decay;
266
267 if let Some(metadata) = self.parameters.get(paramid) {
268 if let Some(rule) = self.layer_rules.get(&metadata.layername) {
269 return base_decay * rule.weight_decay_multiplier;
270 }
271 }
272
273 base_decay
274 }
275
276 pub fn is_parameter_frozen(&self, paramid: &ParamId) -> bool {
278 if let Some(metadata) = self.parameters.get(paramid) {
279 if let Some(rule) = self.layer_rules.get(&metadata.layername) {
280 return rule.frozen;
281 }
282 }
283 false
284 }
285
286 pub fn get_sharing_group(&self, groupname: &str) -> Option<&[ParamId]> {
288 self.sharing_groups.get(groupname).map(|v| v.as_slice())
289 }
290
291 pub fn get_all_parameters(&self) -> &HashMap<ParamId, ParameterMetadata> {
293 &self.parameters
294 }
295
296 pub fn get_layer_architecture(&self, layerid: &LayerId) -> Option<&LayerArchitecture> {
298 self.layer_architectures.get(layerid)
299 }
300
301 pub fn get_parameter_metadata(&self, paramid: &ParamId) -> Option<&ParameterMetadata> {
303 self.parameters.get(paramid)
304 }
305
306 pub fn update_config(&mut self, config: OptimizationConfig<A>) {
308 self.global_config = config;
309 }
310
311 pub fn get_optimizer_state(&self, paramid: &ParamId) -> Option<&HashMap<String, Array<A, D>>> {
313 self.optimizer_states.get(paramid)
314 }
315
316 pub fn get_optimizer_state_mut(
318 &mut self,
319 paramid: &ParamId,
320 ) -> Option<&mut HashMap<String, Array<A, D>>> {
321 self.optimizer_states.get_mut(paramid)
322 }
323
324 pub fn init_optimizer_state(
326 &mut self,
327 paramid: &ParamId,
328 state_name: &str,
329 state: Array<A, D>,
330 ) -> Result<()> {
331 if let Some(states) = self.optimizer_states.get_mut(paramid) {
332 states.insert(state_name.to_string(), state);
333 Ok(())
334 } else {
335 Err(OptimError::InvalidConfig(format!(
336 "Parameter {} not registered",
337 paramid
338 )))
339 }
340 }
341
342 pub fn reset_optimizer_states(&mut self) {
344 for states in self.optimizer_states.values_mut() {
345 states.clear();
346 }
347 }
348
349 pub fn get_parameters_by_layer(&self, layerid: &LayerId) -> Vec<&ParamId> {
351 self.parameters
352 .iter()
353 .filter(|(_, metadata)| &metadata.layername == layerid)
354 .map(|(paramid, _)| paramid)
355 .collect()
356 }
357
358 pub fn get_parameters_by_type(&self, paramtype: ParameterType) -> Vec<&ParamId> {
360 self.parameters
361 .iter()
362 .filter(|(_, metadata)| metadata.paramtype == paramtype)
363 .map(|(paramid, _)| paramid)
364 .collect()
365 }
366
367 pub fn get_trainable_parameters(&self) -> Vec<&ParamId> {
369 self.parameters
370 .iter()
371 .filter(|(paramid, metadata)| {
372 metadata.requires_grad && !self.is_parameter_frozen(paramid)
373 })
374 .map(|(paramid, _)| paramid)
375 .collect()
376 }
377}
378
379impl<A: Float + Send + Sync> Default for OptimizationConfig<A> {
380 fn default() -> Self {
381 Self {
382 base_learning_rate: A::from(0.001).expect("unwrap failed"),
383 weight_decay: A::zero(),
384 gradient_clip: None,
385 mixed_precision: false,
386 architecture_optimizations: HashMap::new(),
387 }
388 }
389}
390
391impl<A: Float + Send + Sync> Default for LayerOptimizationRule<A> {
392 fn default() -> Self {
393 Self {
394 lr_multiplier: A::one(),
395 weight_decay_multiplier: A::one(),
396 frozen: false,
397 custom_settings: HashMap::new(),
398 }
399 }
400}
401
402pub mod forward_backward {
404 use super::*;
405
    /// Callbacks invoked around a layer's forward computation.
    pub trait ForwardHook<A: Float, D: Dimension> {
        /// Called before the layer computes its outputs.
        fn pre_forward(&mut self, layerid: &LayerId, inputs: &[Array<A, D>]) -> Result<()>;

        /// Called after the layer has computed its outputs.
        fn post_forward(&mut self, layerid: &LayerId, outputs: &[Array<A, D>]) -> Result<()>;
    }
414
    /// Callbacks invoked around a layer's backward (gradient) computation.
    pub trait BackwardHook<A: Float, D: Dimension> {
        /// Called before input gradients are computed from `gradoutputs`.
        fn pre_backward(&mut self, layerid: &LayerId, gradoutputs: &[Array<A, D>]) -> Result<()>;

        /// Called after input gradients have been computed (post-clipping).
        fn post_backward(&mut self, layerid: &LayerId, gradinputs: &[Array<A, D>]) -> Result<()>;
    }
423
    /// Bridges a neural network's forward/backward passes with the
    /// parameter manager, hook system, and gradient accumulation.
    pub struct NeuralIntegration<A: Float, D: Dimension> {
        // Parameter registry plus per-layer optimization settings.
        param_manager: ParameterManager<A, D>,
        // Forward hooks keyed by layer id (at most one per layer).
        forward_hooks: HashMap<LayerId, Box<dyn ForwardHook<A, D>>>,
        // Backward hooks keyed by layer id (at most one per layer).
        backward_hooks: HashMap<LayerId, Box<dyn BackwardHook<A, D>>>,
        // When true, backward passes also accumulate per-parameter gradients.
        gradient_accumulation: bool,
        // Running gradient sums keyed by parameter id.
        accumulated_gradients: HashMap<ParamId, Array<A, D>>,
        // Number of accumulation calls since the last reset.
        accumulation_count: usize,
    }
439
    impl<
        A: Float
            + ScalarOperand
            + Debug
            + 'static
            + scirs2_core::numeric::FromPrimitive
            + std::iter::Sum
            + Send
            + Sync,
        D: Dimension + 'static,
    > NeuralIntegration<A, D>
    {
        /// Creates an integration wrapper with a fresh `ParameterManager`
        /// and no hooks or accumulated gradients.
        pub fn new(config: OptimizationConfig<A>) -> Self {
            Self {
                param_manager: ParameterManager::new(config),
                forward_hooks: HashMap::new(),
                backward_hooks: HashMap::new(),
                gradient_accumulation: false,
                accumulated_gradients: HashMap::new(),
                accumulation_count: 0,
            }
        }

        /// Installs `hook` for `layerid`, replacing any existing forward hook.
        pub fn register_forward_hook<H>(&mut self, layerid: LayerId, hook: H)
        where
            H: ForwardHook<A, D> + 'static,
        {
            self.forward_hooks.insert(layerid, Box::new(hook));
        }

        /// Installs `hook` for `layerid`, replacing any existing backward hook.
        pub fn register_backward_hook<H>(&mut self, layerid: LayerId, hook: H)
        where
            H: BackwardHook<A, D> + 'static,
        {
            self.backward_hooks.insert(layerid, Box::new(hook));
        }

        /// Turns on gradient accumulation for subsequent backward passes.
        pub fn enable_gradient_accumulation(&mut self) {
            self.gradient_accumulation = true;
        }

        /// Turns off gradient accumulation and returns (draining) the
        /// accumulated gradients; the accumulation counter is reset to zero.
        pub fn disable_gradient_accumulation(&mut self) -> HashMap<ParamId, Array<A, D>> {
            self.gradient_accumulation = false;
            let result = std::mem::take(&mut self.accumulated_gradients);
            self.accumulation_count = 0;
            result
        }

        /// Runs a forward pass for `layerid`: pre-hook, layer-type dispatch,
        /// then post-hook, in that order.
        ///
        /// # Errors
        /// Fails when the layer is not registered or a hook fails.
        pub fn forward_pass(
            &mut self,
            layerid: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            if let Some(hook) = self.forward_hooks.get_mut(layerid) {
                hook.pre_forward(layerid, inputs)?;
            }

            // Clone the architecture so the borrow of `param_manager` ends
            // before the `&self` compute methods below run.
            let layer_arch = self
                .param_manager
                .get_layer_architecture(layerid)
                .ok_or_else(|| {
                    OptimError::InvalidConfig(format!("Layer {} not registered", layerid))
                })?
                .clone();

            // Dispatch on the string layer type; unknown types pass inputs through.
            let outputs = match layer_arch.layer_type.as_str() {
                "linear" | "dense" | "fc" => {
                    self.compute_linear_forward(layerid, inputs)?
                }
                "conv" | "conv2d" => {
                    self.compute_conv_forward(layerid, inputs)?
                }
                "activation" => {
                    self.compute_activation_forward(layerid, inputs, &layer_arch)?
                }
                "normalization" | "batchnorm" | "layernorm" => {
                    self.compute_normalization_forward(layerid, inputs)?
                }
                "dropout" => {
                    self.compute_dropout_forward(layerid, inputs, &layer_arch)?
                }
                "pooling" | "maxpool" | "avgpool" => {
                    self.compute_pooling_forward(layerid, inputs, &layer_arch)?
                }
                _ => {
                    inputs.to_vec()
                }
            };

            if let Some(hook) = self.forward_hooks.get_mut(layerid) {
                hook.post_forward(layerid, &outputs)?;
            }

            Ok(outputs)
        }

        /// "Linear" forward: scales each input by the layer's effective
        /// learning rate.
        ///
        /// NOTE(review): this is not a real linear layer (no W·x + b);
        /// it looks like a placeholder — confirm intended semantics.
        fn compute_linear_forward(
            &self,
            layerid: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            if inputs.is_empty() {
                return Err(OptimError::InvalidConfig(
                    "Linear layer requires input".to_string(),
                ));
            }

            let layer_params = self.param_manager.get_parameters_by_layer(layerid);

            // Effective LR of the layer's first parameter; errors when the
            // layer has no registered parameters at all.
            let lr =
                self.param_manager
                    .get_effective_learning_rate(layer_params.first().ok_or_else(|| {
                        OptimError::InvalidConfig("No parameters for linear layer".to_string())
                    })?);

            let outputs: Vec<Array<A, D>> =
                inputs.iter().map(|input| input.mapv(|x| x * lr)).collect();

            Ok(outputs)
        }

        /// Convolution forward: identity pass-through (placeholder).
        fn compute_conv_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            Ok(inputs.to_vec())
        }

        /// Applies the activation named in the layer config key "activation"
        /// (defaulting to "relu"); unknown names pass values through.
        fn compute_activation_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let activation_type = layer_arch
                .config
                .get("activation")
                .and_then(|v| match v {
                    LayerConfig::String(s) => Some(s.as_str()),
                    _ => None,
                })
                .unwrap_or("relu");

            let outputs: Vec<Array<A, D>> = inputs
                .iter()
                .map(|input| {
                    match activation_type {
                        "relu" => input.mapv(|x| if x > A::zero() { x } else { A::zero() }),
                        "sigmoid" => input.mapv(|x| A::one() / (A::one() + (-x).exp())),
                        "tanh" => input.mapv(|x| x.tanh()),
                        "leaky_relu" => {
                            let alpha = A::from(0.01).expect("unwrap failed");
                            input.mapv(|x| if x > A::zero() { x } else { alpha * x })
                        }
                        _ => input.clone(), }
                })
                .collect();

            Ok(outputs)
        }

        /// Standardizes each input array to zero mean / unit variance
        /// (plus epsilon for numerical stability).
        ///
        /// NOTE(review): an empty input array makes `input.len()` zero and
        /// the mean divides by zero — confirm inputs are always non-empty.
        fn compute_normalization_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            let outputs: Vec<Array<A, D>> = inputs
                .iter()
                .map(|input| {
                    let mean = input.iter().copied().sum::<A>()
                        / A::from(input.len()).unwrap_or(A::zero());
                    let variance = input
                        .mapv(|x| (x - mean).powi(2))
                        .mean()
                        .unwrap_or(A::one());
                    let std_dev = variance.sqrt();
                    let epsilon = A::from(1e-5).expect("unwrap failed");

                    input.mapv(|x| (x - mean) / (std_dev + epsilon))
                })
                .collect();

            Ok(outputs)
        }

        /// Dropout forward: deterministically scales inputs by
        /// `1 - dropout_rate` (config key "dropout_rate", default 0.5).
        ///
        /// NOTE(review): no random masking is applied — this resembles an
        /// inference-time expectation scaling, not training dropout.
        fn compute_dropout_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let dropout_rate = layer_arch
                .config
                .get("dropout_rate")
                .and_then(|v| match v {
                    LayerConfig::Float(f) => Some(A::from(*f).expect("unwrap failed")),
                    _ => None,
                })
                .unwrap_or(A::from(0.5).expect("unwrap failed"));

            let scale = A::one() - dropout_rate;
            let outputs: Vec<Array<A, D>> = inputs
                .iter()
                .map(|input| input.mapv(|x| x * scale))
                .collect();

            Ok(outputs)
        }

        /// Pooling forward: identity pass-through (placeholder).
        fn compute_pooling_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
            _layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            Ok(inputs.to_vec())
        }

        /// Runs a backward pass for `layerid`: pre-hook, layer-type dispatch,
        /// optional global gradient clipping, then post-hook (which therefore
        /// observes the *clipped* gradients).
        ///
        /// # Errors
        /// Fails when the layer is not registered or a hook fails.
        pub fn backward_pass(
            &mut self,
            layerid: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            if let Some(hook) = self.backward_hooks.get_mut(layerid) {
                hook.pre_backward(layerid, grad_outputs)?;
            }

            // Clone to end the borrow of `param_manager` before dispatch.
            let layer_arch = self
                .param_manager
                .get_layer_architecture(layerid)
                .ok_or_else(|| {
                    OptimError::InvalidConfig(format!("Layer {} not registered", layerid))
                })?
                .clone();

            let grad_inputs = match layer_arch.layer_type.as_str() {
                "linear" | "dense" | "fc" => {
                    self.compute_linear_backward(layerid, grad_outputs)?
                }
                "conv" | "conv2d" => {
                    self.compute_conv_backward(layerid, grad_outputs)?
                }
                "activation" => {
                    self.compute_activation_backward(layerid, grad_outputs, &layer_arch)?
                }
                "normalization" | "batchnorm" | "layernorm" => {
                    self.compute_normalization_backward(layerid, grad_outputs)?
                }
                "dropout" => {
                    self.compute_dropout_backward(layerid, grad_outputs, &layer_arch)?
                }
                "pooling" | "maxpool" | "avgpool" => {
                    self.compute_pooling_backward(layerid, grad_outputs, &layer_arch)?
                }
                _ => {
                    grad_outputs.to_vec()
                }
            };

            // Apply global per-array L2 clipping if configured.
            let clipped_grads =
                if let Some(clipvalue) = self.param_manager.global_config.gradient_clip {
                    self.apply_gradient_clipping(grad_inputs, clipvalue)?
                } else {
                    grad_inputs
                };

            if let Some(hook) = self.backward_hooks.get_mut(layerid) {
                hook.post_backward(layerid, &clipped_grads)?;
            }

            Ok(clipped_grads)
        }

        /// Linear backward: optionally accumulates the i-th output gradient
        /// for the layer's i-th parameter, then scales gradients by 0.9.
        ///
        /// NOTE(review): the 0.9 factor and the positional param<->grad
        /// pairing look like placeholders — confirm intended semantics.
        fn compute_linear_backward(
            &mut self,
            layerid: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            if grad_outputs.is_empty() {
                return Err(OptimError::InvalidConfig(
                    "Linear layer backward requires gradients".to_string(),
                ));
            }

            let layer_params = self.param_manager.get_parameters_by_layer(layerid);

            if self.gradient_accumulation {
                // Pair parameters with gradients by index; extra gradients
                // (beyond the parameter count) are ignored.
                let mut param_grads = HashMap::new();
                for (i, paramid) in layer_params.iter().enumerate() {
                    if i < grad_outputs.len() {
                        param_grads.insert((*paramid).clone(), grad_outputs[i].clone());
                    }
                }
                self.accumulate_gradients(param_grads)?;
            }

            let lr_decay = A::from(0.9).expect("unwrap failed");
            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| grad.mapv(|x| x * lr_decay))
                .collect();

            Ok(grad_inputs)
        }

        /// Convolution backward: identity pass-through (placeholder).
        fn compute_conv_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            Ok(grad_outputs.to_vec())
        }

        /// Activation backward using rough per-activation approximations.
        ///
        /// NOTE(review): these gate/scale on the *gradient* value, not the
        /// forward input, and sigmoid/tanh use constant factors (0.25 / 0.5)
        /// instead of the true derivatives — placeholder math, confirm.
        fn compute_activation_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let activation_type = layer_arch
                .config
                .get("activation")
                .and_then(|v| match v {
                    LayerConfig::String(s) => Some(s.as_str()),
                    _ => None,
                })
                .unwrap_or("relu");

            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| {
                    match activation_type {
                        "relu" => {
                            grad.mapv(|g| if g > A::zero() { g } else { A::zero() })
                        }
                        "sigmoid" => {
                            // 0.25 is the maximum of sigmoid's derivative.
                            let factor = A::from(0.25).expect("unwrap failed"); grad.mapv(|g| g * factor)
                        }
                        "tanh" => {
                            let factor = A::from(0.5).expect("unwrap failed");
                            grad.mapv(|g| g * factor)
                        }
                        "leaky_relu" => {
                            let alpha = A::from(0.01).expect("unwrap failed");
                            grad.mapv(|g| if g > A::zero() { g } else { alpha * g })
                        }
                        _ => grad.clone(), }
                })
                .collect();

            Ok(grad_inputs)
        }

        /// Normalization backward: uniform 0.9 scaling (placeholder).
        fn compute_normalization_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            let scale_factor = A::from(0.9).expect("unwrap failed");
            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| grad.mapv(|g| g * scale_factor))
                .collect();

            Ok(grad_inputs)
        }

        /// Dropout backward: scales gradients by `1 - dropout_rate`,
        /// mirroring the deterministic scaling used in the forward pass.
        fn compute_dropout_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let dropout_rate = layer_arch
                .config
                .get("dropout_rate")
                .and_then(|v| match v {
                    LayerConfig::Float(f) => Some(A::from(*f).expect("unwrap failed")),
                    _ => None,
                })
                .unwrap_or(A::from(0.5).expect("unwrap failed"));

            let scale = A::one() - dropout_rate;
            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| grad.mapv(|g| g * scale))
                .collect();

            Ok(grad_inputs)
        }

        /// Pooling backward: identity pass-through (placeholder).
        fn compute_pooling_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
            _layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            Ok(grad_outputs.to_vec())
        }

        /// Clips each gradient array independently to an L2 norm of at most
        /// `clipvalue`; arrays already within the bound are returned as-is.
        fn apply_gradient_clipping(
            &self,
            gradients: Vec<Array<A, D>>,
            clipvalue: A,
        ) -> Result<Vec<Array<A, D>>> {
            let clipped: Vec<Array<A, D>> = gradients
                .into_iter()
                .map(|grad| {
                    let norm = grad.mapv(|x| x * x).sum().sqrt();

                    if norm > clipvalue {
                        let scale = clipvalue / norm;
                        grad.mapv(|x| x * scale)
                    } else {
                        grad
                    }
                })
                .collect();

            Ok(clipped)
        }

        /// Adds `gradients` element-wise into the running accumulators and
        /// bumps the accumulation counter.
        ///
        /// # Errors
        /// Fails when gradient accumulation has not been enabled.
        pub fn accumulate_gradients(
            &mut self,
            gradients: HashMap<ParamId, Array<A, D>>,
        ) -> Result<()> {
            if !self.gradient_accumulation {
                return Err(OptimError::InvalidConfig(
                    "Gradient accumulation not enabled".to_string(),
                ));
            }

            self.accumulation_count += 1;

            for (paramid, grad) in gradients {
                if let Some(acc_grad) = self.accumulated_gradients.get_mut(&paramid) {
                    // Element-wise sum; clone needed because `+` consumes its operands.
                    *acc_grad = acc_grad.clone() + grad;
                } else {
                    self.accumulated_gradients.insert(paramid, grad);
                }
            }

            Ok(())
        }

        /// Read-only access to the underlying parameter manager.
        pub fn parameter_manager(&self) -> &ParameterManager<A, D> {
            &self.param_manager
        }

        /// Mutable access to the underlying parameter manager.
        pub fn parameter_manager_mut(&mut self) -> &mut ParameterManager<A, D> {
            &mut self.param_manager
        }

        /// Number of `accumulate_gradients` calls since the last reset.
        pub fn accumulation_count(&self) -> usize {
            self.accumulation_count
        }
    }
980}
981
982pub mod architecture_aware {
984 use super::*;
985
    /// Architecture-specific optimization strategy selected for a model.
    #[derive(Debug, Clone)]
    pub enum ArchitectureStrategy {
        /// Transformer-style models.
        Transformer {
            /// Use different lr multipliers for attention / ffn / norm params.
            component_specific_lr: bool,
            /// Decay lr multipliers with layer depth (0.95^layer).
            layer_wise_decay: bool,
            /// Number of warmup steps for attention layers (0 disables warmup).
            attention_warmup: usize,
        },
        /// Convolutional networks.
        ConvolutionalNet {
            /// Use different lr multipliers per layer type (conv vs. dense).
            layer_type_lr: bool,
            /// Scale lr down with network depth.
            depth_scaling: bool,
            /// Boost lr and zero weight decay for normalization parameters.
            bn_special_handling: bool,
        },
        /// Recurrent networks.
        RecurrentNet {
            /// Global gradient-clip threshold to install, if any.
            rnn_gradient_clip: Option<f64>,
            /// Use different lr multipliers for recurrent vs. linear weights.
            weight_type_lr: bool,
        },
        /// User-defined rules keyed by name.
        Custom {
            /// Rule name -> configuration value.
            rules: HashMap<String, LayerConfig>,
        },
    }
1020
    /// Applies an `ArchitectureStrategy` to a `ParameterManager` on each step.
    #[derive(Debug)]
    pub struct ArchitectureAwareOptimizer<A: Float, D: Dimension> {
        // Parameter registry and per-layer rules this optimizer mutates.
        param_manager: ParameterManager<A, D>,
        // Strategy describing which optimizations to apply.
        strategy: ArchitectureStrategy,
        // Number of `step` calls so far (drives warmup).
        step_count: usize,
    }
1031
1032 impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
1033 ArchitectureAwareOptimizer<A, D>
1034 {
1035 pub fn new(config: OptimizationConfig<A>, strategy: ArchitectureStrategy) -> Self {
1037 Self {
1038 param_manager: ParameterManager::new(config),
1039 strategy,
1040 step_count: 0,
1041 }
1042 }
1043
1044 pub fn apply_architecture_optimizations(&mut self) -> Result<()> {
1046 let strategy = self.strategy.clone();
1048 match strategy {
1049 ArchitectureStrategy::Transformer {
1050 component_specific_lr,
1051 layer_wise_decay,
1052 attention_warmup,
1053 } => {
1054 self.apply_transformer_optimizations(
1055 component_specific_lr,
1056 layer_wise_decay,
1057 attention_warmup,
1058 )?;
1059 }
1060 ArchitectureStrategy::ConvolutionalNet {
1061 layer_type_lr,
1062 depth_scaling,
1063 bn_special_handling,
1064 } => {
1065 self.apply_cnn_optimizations(
1066 layer_type_lr,
1067 depth_scaling,
1068 bn_special_handling,
1069 )?;
1070 }
1071 ArchitectureStrategy::RecurrentNet {
1072 rnn_gradient_clip,
1073 weight_type_lr,
1074 } => {
1075 self.apply_rnn_optimizations(rnn_gradient_clip, weight_type_lr)?;
1076 }
1077 ArchitectureStrategy::Custom { rules } => {
1078 self.apply_custom_optimizations(&rules)?;
1079 }
1080 }
1081 Ok(())
1082 }
1083
1084 fn apply_transformer_optimizations(
1086 &mut self,
1087 component_specific_lr: bool,
1088 layer_wise_decay: bool,
1089 attention_warmup: usize,
1090 ) -> Result<()> {
1091 if component_specific_lr {
1092 self.set_component_learning_rates()?;
1094 }
1095
1096 if layer_wise_decay {
1097 self.apply_layer_wise_decay()?;
1099 }
1100
1101 if attention_warmup > 0 && self.step_count < attention_warmup {
1102 self.apply_attention_warmup(attention_warmup)?;
1104 }
1105
1106 Ok(())
1107 }
1108
1109 fn apply_cnn_optimizations(
1111 &mut self,
1112 layer_type_lr: bool,
1113 depth_scaling: bool,
1114 bn_special_handling: bool,
1115 ) -> Result<()> {
1116 if layer_type_lr {
1117 self.set_layer_type_learning_rates()?;
1119 }
1120
1121 if depth_scaling {
1122 self.apply_depth_scaling()?;
1124 }
1125
1126 if bn_special_handling {
1127 self.apply_bn_optimizations()?;
1129 }
1130
1131 Ok(())
1132 }
1133
1134 fn apply_rnn_optimizations(
1136 &mut self,
1137 rnn_gradient_clip: Option<f64>,
1138 weight_type_lr: bool,
1139 ) -> Result<()> {
1140 if let Some(clipvalue) = rnn_gradient_clip {
1141 self.apply_rnn_gradient_clipping(A::from(clipvalue).expect("unwrap failed"))?;
1143 }
1144
1145 if weight_type_lr {
1146 self.set_weight_type_learning_rates()?;
1148 }
1149
1150 Ok(())
1151 }
1152
1153 fn apply_custom_optimizations(
1155 &mut self,
1156 rules: &HashMap<String, LayerConfig>,
1157 ) -> Result<()> {
1158 let rule_entries: Vec<(String, LayerConfig)> = rules
1160 .iter()
1161 .map(|(name, config)| (name.clone(), config.clone()))
1162 .collect();
1163
1164 for (rule_name, config) in rule_entries {
1165 self.apply_custom_rule(&rule_name, &config)?;
1166 }
1167 Ok(())
1168 }
1169
1170 fn set_component_learning_rates(&mut self) -> Result<()> {
1172 let layer_rules: Vec<(LayerId, LayerOptimizationRule<A>)> = self
1174 .param_manager
1175 .get_all_parameters()
1176 .values()
1177 .map(|metadata| {
1178 let mut rule = LayerOptimizationRule::default();
1179
1180 if metadata.tags.contains(&"attention".to_string()) {
1182 rule.lr_multiplier = A::from(1.2).expect("unwrap failed");
1183 } else if metadata.tags.contains(&"ffn".to_string()) {
1185 rule.lr_multiplier = A::from(1.0).expect("unwrap failed");
1186 } else if metadata.tags.contains(&"normalization".to_string()) {
1188 rule.lr_multiplier = A::from(0.8).expect("unwrap failed");
1189 }
1191
1192 (metadata.layername.clone(), rule)
1193 })
1194 .collect();
1195
1196 for (layername, rule) in layer_rules {
1198 self.param_manager.set_layer_rule(layername, rule);
1199 }
1200 Ok(())
1201 }
1202
1203 fn apply_layer_wise_decay(&mut self) -> Result<()> {
1205 for (layerid, _) in self.param_manager.layer_architectures.clone() {
1207 if let Some(layer_num) = self.extract_layer_number(&layerid) {
1208 let decay_factor =
1209 A::from(0.95_f64.powi(layer_num as i32)).expect("unwrap failed");
1210 let mut rule = self
1211 .param_manager
1212 .layer_rules
1213 .get(&layerid)
1214 .cloned()
1215 .unwrap_or_default();
1216 rule.lr_multiplier = rule.lr_multiplier * decay_factor;
1217 self.param_manager.set_layer_rule(layerid, rule);
1218 }
1219 }
1220 Ok(())
1221 }
1222
1223 fn apply_attention_warmup(&mut self, warmupsteps: usize) -> Result<()> {
1225 let warmup_factor =
1226 A::from(self.step_count as f64 / warmupsteps as f64).expect("unwrap failed");
1227
1228 let attention_layers: Vec<LayerId> = self
1230 .param_manager
1231 .get_all_parameters()
1232 .values()
1233 .filter_map(|metadata| {
1234 if metadata.tags.contains(&"attention".to_string()) {
1235 Some(metadata.layername.clone())
1236 } else {
1237 None
1238 }
1239 })
1240 .collect();
1241
1242 for layername in attention_layers {
1244 let mut rule = self
1245 .param_manager
1246 .layer_rules
1247 .get(&layername)
1248 .cloned()
1249 .unwrap_or_default();
1250 rule.lr_multiplier = rule.lr_multiplier * warmup_factor;
1251 self.param_manager.set_layer_rule(layername, rule);
1252 }
1253 Ok(())
1254 }
1255
1256 fn set_layer_type_learning_rates(&mut self) -> Result<()> {
1258 for (layerid, architecture) in self.param_manager.layer_architectures.clone() {
1259 let mut rule = LayerOptimizationRule::default();
1260
1261 match architecture.layer_type.as_str() {
1262 "conv" | "conv2d" | "conv3d" => {
1263 rule.lr_multiplier = A::from(1.0).expect("unwrap failed");
1264 }
1266 "linear" | "dense" | "fc" => {
1267 rule.lr_multiplier = A::from(0.8).expect("unwrap failed");
1268 }
1270 _ => {
1271 rule.lr_multiplier = A::from(1.0).expect("unwrap failed");
1272 }
1274 }
1275
1276 self.param_manager.set_layer_rule(layerid, rule);
1277 }
1278 Ok(())
1279 }
1280
1281 fn apply_depth_scaling(&mut self) -> Result<()> {
1283 let total_layers = self.param_manager.layer_architectures.len();
1285
1286 for (i, (layerid, _)) in self
1287 .param_manager
1288 .layer_architectures
1289 .clone()
1290 .iter()
1291 .enumerate()
1292 {
1293 let depth_factor =
1294 A::from(1.0 - 0.1 * (i as f64 / total_layers as f64)).expect("unwrap failed");
1295 let mut rule = self
1296 .param_manager
1297 .layer_rules
1298 .get(layerid)
1299 .cloned()
1300 .unwrap_or_default();
1301 rule.lr_multiplier = rule.lr_multiplier * depth_factor;
1302 self.param_manager.set_layer_rule(layerid.clone(), rule);
1303 }
1304 Ok(())
1305 }
1306
1307 fn apply_bn_optimizations(&mut self) -> Result<()> {
1309 let norm_layers: Vec<LayerId> = self
1311 .param_manager
1312 .get_all_parameters()
1313 .values()
1314 .filter_map(|metadata| {
1315 if metadata.paramtype == ParameterType::Normalization {
1316 Some(metadata.layername.clone())
1317 } else {
1318 None
1319 }
1320 })
1321 .collect();
1322
1323 for layername in norm_layers {
1325 let mut rule = self
1326 .param_manager
1327 .layer_rules
1328 .get(&layername)
1329 .cloned()
1330 .unwrap_or_default();
1331 rule.lr_multiplier = A::from(2.0).expect("unwrap failed");
1333 rule.weight_decay_multiplier = A::zero();
1334 self.param_manager.set_layer_rule(layername, rule);
1335 }
1336 Ok(())
1337 }
1338
1339 fn apply_rnn_gradient_clipping(&mut self, clipvalue: A) -> Result<()> {
1341 self.param_manager.global_config.gradient_clip = Some(clipvalue);
1344 Ok(())
1345 }
1346
1347 fn set_weight_type_learning_rates(&mut self) -> Result<()> {
1349 let layer_rules: Vec<(LayerId, LayerOptimizationRule<A>)> = self
1351 .param_manager
1352 .get_all_parameters()
1353 .values()
1354 .map(|metadata| {
1355 let mut rule = LayerOptimizationRule::default();
1356
1357 if metadata.tags.contains(&"recurrent".to_string()) {
1358 rule.lr_multiplier = A::from(0.5).expect("unwrap failed");
1359 } else if metadata.tags.contains(&"linear".to_string()) {
1361 rule.lr_multiplier = A::from(1.0).expect("unwrap failed");
1362 }
1364
1365 (metadata.layername.clone(), rule)
1366 })
1367 .collect();
1368
1369 for (layername, rule) in layer_rules {
1371 self.param_manager.set_layer_rule(layername, rule);
1372 }
1373 Ok(())
1374 }
1375
1376 fn apply_custom_rule(&mut self, _rule_name: &str, config: &LayerConfig) -> Result<()> {
1378 Ok(())
1381 }
1382
1383 fn extract_layer_number(&self, layername: &str) -> Option<usize> {
1385 layername.split('_').next_back()?.parse().ok()
1386 }
1387
1388 pub fn step(&mut self) -> Result<()> {
1390 self.step_count += 1;
1391 self.apply_architecture_optimizations()
1392 }
1393
1394 pub fn parameter_manager(&self) -> &ParameterManager<A, D> {
1396 &self.param_manager
1397 }
1398
1399 pub fn parameter_manager_mut(&mut self) -> &mut ParameterManager<A, D> {
1401 &mut self.param_manager
1402 }
1403 }
1404}
1405
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Convenience constructor for the `ParameterMetadata` used across these
    /// tests. Every test parameter is trainable (`requires_grad: true`).
    fn make_metadata(
        layer: &str,
        param: &str,
        shape: Vec<usize>,
        paramtype: ParameterType,
        sharing_group: Option<&str>,
        tags: Vec<&str>,
    ) -> ParameterMetadata {
        ParameterMetadata {
            layername: layer.to_string(),
            param_name: param.to_string(),
            shape,
            requires_grad: true,
            paramtype,
            sharing_group: sharing_group.map(str::to_string),
            tags: tags.into_iter().map(str::to_string).collect(),
        }
    }

    #[test]
    fn test_parameter_manager_basic() {
        let mut mgr =
            ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(OptimizationConfig::default());

        mgr.register_parameter(
            "param1".to_string(),
            make_metadata(
                "layer1",
                "weight",
                vec![10, 5],
                ParameterType::Weight,
                None,
                vec!["dense"],
            ),
        )
        .expect("unwrap failed");

        assert!(mgr.get_parameter_metadata(&"param1".to_string()).is_some());
        assert_eq!(mgr.get_all_parameters().len(), 1);
        assert!(!mgr.is_parameter_frozen(&"param1".to_string()));
    }

    #[test]
    fn test_lazy_registration() {
        let mut mgr =
            ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(OptimizationConfig::default());

        mgr.enable_lazy_mode();

        mgr.register_parameter(
            "param1".to_string(),
            make_metadata(
                "layer1",
                "weight",
                vec![10, 5],
                ParameterType::Weight,
                None,
                vec![],
            ),
        )
        .expect("unwrap failed");

        // While lazy mode is active, registration is deferred.
        assert_eq!(mgr.get_all_parameters().len(), 0);

        // Leaving lazy mode flushes the pending registrations.
        mgr.disable_lazy_mode().expect("unwrap failed");
        assert_eq!(mgr.get_all_parameters().len(), 1);
    }

    #[test]
    fn test_layer_specific_rules() {
        let config = OptimizationConfig {
            base_learning_rate: 0.01,
            weight_decay: 0.001,
            gradient_clip: None,
            mixed_precision: false,
            architecture_optimizations: HashMap::new(),
        };
        let mut mgr = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);

        mgr.set_layer_rule(
            "layer1".to_string(),
            LayerOptimizationRule {
                lr_multiplier: 2.0,
                weight_decay_multiplier: 0.5,
                frozen: false,
                custom_settings: HashMap::new(),
            },
        );

        mgr.register_parameter(
            "param1".to_string(),
            make_metadata(
                "layer1",
                "weight",
                vec![10, 5],
                ParameterType::Weight,
                None,
                vec![],
            ),
        )
        .expect("unwrap failed");

        // base lr 0.01 * layer multiplier 2.0
        let effective_lr = mgr.get_effective_learning_rate(&"param1".to_string());
        assert_relative_eq!(effective_lr, 0.02, epsilon = 1e-6);

        // base decay 0.001 * layer multiplier 0.5
        let effective_decay = mgr.get_effective_weight_decay(&"param1".to_string());
        assert_relative_eq!(effective_decay, 0.0005, epsilon = 1e-6);
    }

    #[test]
    fn test_parameter_sharing() {
        let mut mgr =
            ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(OptimizationConfig::default());

        // Two parameters in different layers joined by one sharing group.
        for (id, layer) in [("param1", "layer1"), ("param2", "layer2")] {
            mgr.register_parameter(
                id.to_string(),
                make_metadata(
                    layer,
                    "weight",
                    vec![10, 5],
                    ParameterType::Weight,
                    Some("shared_weights"),
                    vec![],
                ),
            )
            .expect("unwrap failed");
        }

        let sharing_group = mgr
            .get_sharing_group("shared_weights")
            .expect("unwrap failed");
        assert_eq!(sharing_group.len(), 2);
        assert!(sharing_group.contains(&"param1".to_string()));
        assert!(sharing_group.contains(&"param2".to_string()));
    }

    #[test]
    fn test_parameter_filtering() {
        let mut mgr =
            ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(OptimizationConfig::default());

        mgr.register_parameter(
            "weight".to_string(),
            make_metadata(
                "layer1",
                "weight",
                vec![10, 5],
                ParameterType::Weight,
                None,
                vec![],
            ),
        )
        .expect("unwrap failed");
        mgr.register_parameter(
            "bias".to_string(),
            make_metadata("layer1", "bias", vec![5], ParameterType::Bias, None, vec![]),
        )
        .expect("unwrap failed");

        let weights = mgr.get_parameters_by_type(ParameterType::Weight);
        assert_eq!(weights.len(), 1);
        assert_eq!(weights[0], &"weight".to_string());

        let biases = mgr.get_parameters_by_type(ParameterType::Bias);
        assert_eq!(biases.len(), 1);
        assert_eq!(biases[0], &"bias".to_string());

        assert_eq!(mgr.get_parameters_by_layer(&"layer1".to_string()).len(), 2);
        assert_eq!(mgr.get_trainable_parameters().len(), 2);
    }

    #[test]
    fn test_architecture_aware_transformer() {
        use crate::neural_integration::architecture_aware::*;

        let strategy = ArchitectureStrategy::Transformer {
            component_specific_lr: true,
            layer_wise_decay: true,
            attention_warmup: 1000,
        };
        let mut optimizer = ArchitectureAwareOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(
            OptimizationConfig::default(),
            strategy,
        );

        optimizer.parameter_manager_mut().register_layer(
            "layer_0".to_string(),
            LayerArchitecture {
                layer_type: "transformer_block".to_string(),
                input_dims: vec![512],
                output_dims: vec![512],
                config: HashMap::new(),
                trainable: true,
            },
        );

        optimizer
            .parameter_manager_mut()
            .register_parameter(
                "attn_param".to_string(),
                make_metadata(
                    "layer_0",
                    "attention_weight",
                    vec![512, 512],
                    ParameterType::Attention,
                    None,
                    vec!["attention"],
                ),
            )
            .expect("unwrap failed");

        optimizer
            .apply_architecture_optimizations()
            .expect("unwrap failed");

        assert!(optimizer
            .parameter_manager()
            .get_parameter_metadata(&"attn_param".to_string())
            .is_some());
    }
}