use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;

/// Identifier used to address a layer within a model.
pub type LayerId = String;

/// Identifier used to address an individual parameter tensor.
pub type ParamId = String;

/// Metadata describing a registered parameter.
#[derive(Debug, Clone)]
pub struct ParameterMetadata {
    pub layername: LayerId,
    pub param_name: ParamId,
    pub shape: Vec<usize>,
    pub requires_grad: bool,
    pub paramtype: ParameterType,
    pub sharing_group: Option<String>,
    pub tags: Vec<String>,
}

/// Coarse classification of a parameter's role.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParameterType {
    Weight,
    Bias,
    Normalization,
    Embedding,
    Attention,
    Custom,
}

/// Static description of a layer's shape and configuration.
#[derive(Debug, Clone)]
pub struct LayerArchitecture {
    pub layer_type: String,
    pub input_dims: Vec<usize>,
    pub output_dims: Vec<usize>,
    pub config: HashMap<String, LayerConfig>,
    pub trainable: bool,
}

/// Loosely typed configuration value attached to a layer.
#[derive(Debug, Clone)]
pub enum LayerConfig {
    Int(i64),
    Float(f64),
    String(String),
    Bool(bool),
    List(Vec<LayerConfig>),
}
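
// Illustrative sketch (not part of the public API surface): how a layer might be
// described before registration. The id "block_0" is hypothetical; the config key
// "activation" is the one read by the forward helpers further below.
//
//     let mut config = HashMap::new();
//     config.insert("activation".to_string(), LayerConfig::String("relu".to_string()));
//     let arch = LayerArchitecture {
//         layer_type: "activation".to_string(),
//         input_dims: vec![128],
//         output_dims: vec![128],
//         config,
//         trainable: false,
//     };
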
/// Interface expected of optimizers that integrate with the parameter manager.
pub trait ParameterOptimizer<A: Float, D: Dimension> {
    fn register_parameter(
        &mut self,
        paramid: ParamId,
        parameter: &Array<A, D>,
        metadata: ParameterMetadata,
    ) -> Result<()>;

    fn step(
        &mut self,
        gradients: HashMap<ParamId, Array<A, D>>,
        parameters: &mut HashMap<ParamId, Array<A, D>>,
    ) -> Result<()>;

    fn get_learning_rate(&self, paramid: &ParamId) -> Option<A>;

    fn set_learning_rate(&mut self, paramid: &ParamId, lr: A) -> Result<()>;

    fn get_parameter_state(&self, paramid: &ParamId) -> Option<&HashMap<String, Array<A, D>>>;

    fn reset_state(&mut self);

    fn registered_parameters(&self) -> Vec<ParamId>;
}
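
// Illustrative sketch (assumed caller code, not defined in this module): a training
// loop would typically register each parameter once and then call `step` with the
// gradients gathered for the current batch. `my_optimizer`, `named_parameters`,
// `params`, and `compute_gradients` are hypothetical.
//
//     for (id, (param, meta)) in named_parameters {
//         my_optimizer.register_parameter(id, &param, meta)?;
//     }
//     loop {
//         let grads = compute_gradients(&params);
//         my_optimizer.step(grads, &mut params)?;
//     }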

/// Central registry for parameters, per-layer rules, and optimizer state.
#[derive(Debug)]
pub struct ParameterManager<A: Float, D: Dimension> {
    parameters: HashMap<ParamId, ParameterMetadata>,
    optimizer_states: HashMap<ParamId, HashMap<String, Array<A, D>>>,
    layer_architectures: HashMap<LayerId, LayerArchitecture>,
    sharing_groups: HashMap<String, Vec<ParamId>>,
    layer_rules: HashMap<LayerId, LayerOptimizationRule<A>>,
    global_config: OptimizationConfig<A>,
    lazy_mode: bool,
    pending_registrations: Vec<(ParamId, ParameterMetadata)>,
}

/// Per-layer overrides applied on top of the global optimization config.
#[derive(Debug, Clone)]
pub struct LayerOptimizationRule<A: Float> {
    pub lr_multiplier: A,
    pub weight_decay_multiplier: A,
    pub frozen: bool,
    pub custom_settings: HashMap<String, LayerConfig>,
}

/// Global optimization settings shared by all parameters.
#[derive(Debug, Clone)]
pub struct OptimizationConfig<A: Float> {
    pub base_learning_rate: A,
    pub weight_decay: A,
    pub gradient_clip: Option<A>,
    pub mixed_precision: bool,
    pub architecture_optimizations: HashMap<String, bool>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
    ParameterManager<A, D>
{
    pub fn new(config: OptimizationConfig<A>) -> Self {
        Self {
            parameters: HashMap::new(),
            optimizer_states: HashMap::new(),
            layer_architectures: HashMap::new(),
            sharing_groups: HashMap::new(),
            layer_rules: HashMap::new(),
            global_config: config,
            lazy_mode: false,
            pending_registrations: Vec::new(),
        }
    }

    pub fn enable_lazy_mode(&mut self) {
        self.lazy_mode = true;
    }

    pub fn disable_lazy_mode(&mut self) -> Result<()> {
        self.lazy_mode = false;

        let pending = std::mem::take(&mut self.pending_registrations);
        for (paramid, metadata) in pending {
            self.register_parameter_impl(paramid, metadata)?;
        }

        Ok(())
    }

    pub fn register_layer(&mut self, layerid: LayerId, architecture: LayerArchitecture) {
        self.layer_architectures.insert(layerid, architecture);
    }

    pub fn set_layer_rule(&mut self, layerid: LayerId, rule: LayerOptimizationRule<A>) {
        self.layer_rules.insert(layerid, rule);
    }

    pub fn register_parameter(
        &mut self,
        paramid: ParamId,
        metadata: ParameterMetadata,
    ) -> Result<()> {
        if self.lazy_mode {
            self.pending_registrations.push((paramid, metadata));
            Ok(())
        } else {
            self.register_parameter_impl(paramid, metadata)
        }
    }

    fn register_parameter_impl(
        &mut self,
        paramid: ParamId,
        metadata: ParameterMetadata,
    ) -> Result<()> {
        if let Some(sharing_group) = &metadata.sharing_group {
            self.sharing_groups
                .entry(sharing_group.clone())
                .or_default()
                .push(paramid.clone());
        }

        self.optimizer_states
            .insert(paramid.clone(), HashMap::new());

        self.parameters.insert(paramid, metadata);

        Ok(())
    }

    pub fn get_effective_learning_rate(&self, paramid: &ParamId) -> A {
        let base_lr = self.global_config.base_learning_rate;

        if let Some(metadata) = self.parameters.get(paramid) {
            if let Some(rule) = self.layer_rules.get(&metadata.layername) {
                return base_lr * rule.lr_multiplier;
            }
        }

        base_lr
    }
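
    // Worked example (values assumed): with `base_learning_rate = 0.001` and a layer
    // rule of `lr_multiplier = 2.0`, the effective rate returned here is 0.002;
    // parameters whose layer has no rule fall back to the base rate unchanged.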

    pub fn get_effective_weight_decay(&self, paramid: &ParamId) -> A {
        let base_decay = self.global_config.weight_decay;

        if let Some(metadata) = self.parameters.get(paramid) {
            if let Some(rule) = self.layer_rules.get(&metadata.layername) {
                return base_decay * rule.weight_decay_multiplier;
            }
        }

        base_decay
    }

    pub fn is_parameter_frozen(&self, paramid: &ParamId) -> bool {
        if let Some(metadata) = self.parameters.get(paramid) {
            if let Some(rule) = self.layer_rules.get(&metadata.layername) {
                return rule.frozen;
            }
        }
        false
    }

    pub fn get_sharing_group(&self, groupname: &str) -> Option<&[ParamId]> {
        self.sharing_groups.get(groupname).map(|v| v.as_slice())
    }

    pub fn get_all_parameters(&self) -> &HashMap<ParamId, ParameterMetadata> {
        &self.parameters
    }

    pub fn get_layer_architecture(&self, layerid: &LayerId) -> Option<&LayerArchitecture> {
        self.layer_architectures.get(layerid)
    }

    pub fn get_parameter_metadata(&self, paramid: &ParamId) -> Option<&ParameterMetadata> {
        self.parameters.get(paramid)
    }

    pub fn update_config(&mut self, config: OptimizationConfig<A>) {
        self.global_config = config;
    }

    pub fn get_optimizer_state(&self, paramid: &ParamId) -> Option<&HashMap<String, Array<A, D>>> {
        self.optimizer_states.get(paramid)
    }

    pub fn get_optimizer_state_mut(
        &mut self,
        paramid: &ParamId,
    ) -> Option<&mut HashMap<String, Array<A, D>>> {
        self.optimizer_states.get_mut(paramid)
    }

    pub fn init_optimizer_state(
        &mut self,
        paramid: &ParamId,
        state_name: &str,
        state: Array<A, D>,
    ) -> Result<()> {
        if let Some(states) = self.optimizer_states.get_mut(paramid) {
            states.insert(state_name.to_string(), state);
            Ok(())
        } else {
            Err(OptimError::InvalidConfig(format!(
                "Parameter {} not registered",
                paramid
            )))
        }
    }

    pub fn reset_optimizer_states(&mut self) {
        for states in self.optimizer_states.values_mut() {
            states.clear();
        }
    }

    pub fn get_parameters_by_layer(&self, layerid: &LayerId) -> Vec<&ParamId> {
        self.parameters
            .iter()
            .filter(|(_, metadata)| &metadata.layername == layerid)
            .map(|(paramid, _)| paramid)
            .collect()
    }

    pub fn get_parameters_by_type(&self, paramtype: ParameterType) -> Vec<&ParamId> {
        self.parameters
            .iter()
            .filter(|(_, metadata)| metadata.paramtype == paramtype)
            .map(|(paramid, _)| paramid)
            .collect()
    }

    pub fn get_trainable_parameters(&self) -> Vec<&ParamId> {
        self.parameters
            .iter()
            .filter(|(paramid, metadata)| {
                metadata.requires_grad && !self.is_parameter_frozen(paramid)
            })
            .map(|(paramid, _)| paramid)
            .collect()
    }
}
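
// Illustrative sketch (assumed usage; the "encoder_0" ids and `weight_metadata` are
// hypothetical): a layer rule applies to every parameter whose metadata names that
// layer, whether it was registered before or after the rule.
//
//     let mut manager: ParameterManager<f64, Ix1> =
//         ParameterManager::new(OptimizationConfig::default());
//     manager.set_layer_rule(
//         "encoder_0".to_string(),
//         LayerOptimizationRule { lr_multiplier: 0.5, ..Default::default() },
//     );
//     manager.register_parameter("encoder_0.weight".to_string(), weight_metadata)?;
//     assert_eq!(manager.get_trainable_parameters().len(), 1);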

impl<A: Float + Send + Sync> Default for OptimizationConfig<A> {
    fn default() -> Self {
        Self {
            base_learning_rate: A::from(0.001).unwrap(),
            weight_decay: A::zero(),
            gradient_clip: None,
            mixed_precision: false,
            architecture_optimizations: HashMap::new(),
        }
    }
}

impl<A: Float + Send + Sync> Default for LayerOptimizationRule<A> {
    fn default() -> Self {
        Self {
            lr_multiplier: A::one(),
            weight_decay_multiplier: A::one(),
            frozen: false,
            custom_settings: HashMap::new(),
        }
    }
}

/// Forward/backward integration utilities: hook registration, simplified layer
/// computations, gradient accumulation, and gradient clipping.
pub mod forward_backward {
    use super::*;

    /// Hook invoked around a layer's forward computation.
    pub trait ForwardHook<A: Float, D: Dimension> {
        fn pre_forward(&mut self, layerid: &LayerId, inputs: &[Array<A, D>]) -> Result<()>;

        fn post_forward(&mut self, layerid: &LayerId, outputs: &[Array<A, D>]) -> Result<()>;
    }

    /// Hook invoked around a layer's backward computation.
    pub trait BackwardHook<A: Float, D: Dimension> {
        fn pre_backward(&mut self, layerid: &LayerId, gradoutputs: &[Array<A, D>]) -> Result<()>;

        fn post_backward(&mut self, layerid: &LayerId, gradinputs: &[Array<A, D>]) -> Result<()>;
    }
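
    // Illustrative sketch (assumed, not part of this module): a hook that simply
    // records how many tensors flow through a layer. `CountingHook` is hypothetical.
    //
    //     struct CountingHook { seen: usize }
    //
    //     impl<A: Float, D: Dimension> ForwardHook<A, D> for CountingHook {
    //         fn pre_forward(&mut self, _layer: &LayerId, inputs: &[Array<A, D>]) -> Result<()> {
    //             self.seen += inputs.len();
    //             Ok(())
    //         }
    //         fn post_forward(&mut self, _layer: &LayerId, _outputs: &[Array<A, D>]) -> Result<()> {
    //             Ok(())
    //         }
    //     }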

    /// Coordinates the parameter manager with per-layer hooks and optional
    /// gradient accumulation.
    pub struct NeuralIntegration<A: Float, D: Dimension> {
        param_manager: ParameterManager<A, D>,
        forward_hooks: HashMap<LayerId, Box<dyn ForwardHook<A, D>>>,
        backward_hooks: HashMap<LayerId, Box<dyn BackwardHook<A, D>>>,
        gradient_accumulation: bool,
        accumulated_gradients: HashMap<ParamId, Array<A, D>>,
        accumulation_count: usize,
    }

    impl<
        A: Float
            + ScalarOperand
            + Debug
            + 'static
            + scirs2_core::numeric::FromPrimitive
            + std::iter::Sum
            + Send
            + Sync,
        D: Dimension + 'static,
    > NeuralIntegration<A, D>
    {
        pub fn new(config: OptimizationConfig<A>) -> Self {
            Self {
                param_manager: ParameterManager::new(config),
                forward_hooks: HashMap::new(),
                backward_hooks: HashMap::new(),
                gradient_accumulation: false,
                accumulated_gradients: HashMap::new(),
                accumulation_count: 0,
            }
        }

        pub fn register_forward_hook<H>(&mut self, layerid: LayerId, hook: H)
        where
            H: ForwardHook<A, D> + 'static,
        {
            self.forward_hooks.insert(layerid, Box::new(hook));
        }

        pub fn register_backward_hook<H>(&mut self, layerid: LayerId, hook: H)
        where
            H: BackwardHook<A, D> + 'static,
        {
            self.backward_hooks.insert(layerid, Box::new(hook));
        }

        pub fn enable_gradient_accumulation(&mut self) {
            self.gradient_accumulation = true;
        }

        pub fn disable_gradient_accumulation(&mut self) -> HashMap<ParamId, Array<A, D>> {
            self.gradient_accumulation = false;
            let result = std::mem::take(&mut self.accumulated_gradients);
            self.accumulation_count = 0;
            result
        }

        pub fn forward_pass(
            &mut self,
            layerid: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            // Run any registered pre-forward hooks first.
            if let Some(hook) = self.forward_hooks.get_mut(layerid) {
                hook.pre_forward(layerid, inputs)?;
            }

            let layer_arch = self
                .param_manager
                .get_layer_architecture(layerid)
                .ok_or_else(|| {
                    OptimError::InvalidConfig(format!("Layer {} not registered", layerid))
                })?
                .clone();

            // Dispatch on the declared layer type; unknown types pass inputs through.
            let outputs = match layer_arch.layer_type.as_str() {
                "linear" | "dense" | "fc" => self.compute_linear_forward(layerid, inputs)?,
                "conv" | "conv2d" => self.compute_conv_forward(layerid, inputs)?,
                "activation" => self.compute_activation_forward(layerid, inputs, &layer_arch)?,
                "normalization" | "batchnorm" | "layernorm" => {
                    self.compute_normalization_forward(layerid, inputs)?
                }
                "dropout" => self.compute_dropout_forward(layerid, inputs, &layer_arch)?,
                "pooling" | "maxpool" | "avgpool" => {
                    self.compute_pooling_forward(layerid, inputs, &layer_arch)?
                }
                _ => inputs.to_vec(),
            };

            if let Some(hook) = self.forward_hooks.get_mut(layerid) {
                hook.post_forward(layerid, &outputs)?;
            }

            Ok(outputs)
        }
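
        // Illustrative sketch (assumed usage): dispatch is driven purely by the
        // registered `LayerArchitecture`, so a layer must be registered before the
        // first call. `relu_arch`, `inputs`, and the id "act_0" are hypothetical.
        //
        //     integration
        //         .parameter_manager_mut()
        //         .register_layer("act_0".to_string(), relu_arch);
        //     let outputs = integration.forward_pass(&"act_0".to_string(), &inputs)?;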

        fn compute_linear_forward(
            &self,
            layerid: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            if inputs.is_empty() {
                return Err(OptimError::InvalidConfig(
                    "Linear layer requires input".to_string(),
                ));
            }

            let layer_params = self.param_manager.get_parameters_by_layer(layerid);

            let lr =
                self.param_manager
                    .get_effective_learning_rate(layer_params.first().ok_or_else(|| {
                        OptimError::InvalidConfig("No parameters for linear layer".to_string())
                    })?);

            // Simplified placeholder: rather than applying a weight matrix, scale the
            // inputs by the layer's effective learning rate.
            let outputs: Vec<Array<A, D>> =
                inputs.iter().map(|input| input.mapv(|x| x * lr)).collect();

            Ok(outputs)
        }

        fn compute_conv_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            // Placeholder: convolution is not modeled here; inputs pass through unchanged.
            Ok(inputs.to_vec())
        }

        fn compute_activation_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let activation_type = layer_arch
                .config
                .get("activation")
                .and_then(|v| match v {
                    LayerConfig::String(s) => Some(s.as_str()),
                    _ => None,
                })
                .unwrap_or("relu");

            let outputs: Vec<Array<A, D>> = inputs
                .iter()
                .map(|input| match activation_type {
                    "relu" => input.mapv(|x| if x > A::zero() { x } else { A::zero() }),
                    "sigmoid" => input.mapv(|x| A::one() / (A::one() + (-x).exp())),
                    "tanh" => input.mapv(|x| x.tanh()),
                    "leaky_relu" => {
                        let alpha = A::from(0.01).unwrap();
                        input.mapv(|x| if x > A::zero() { x } else { alpha * x })
                    }
                    _ => input.clone(),
                })
                .collect();

            Ok(outputs)
        }

        fn compute_normalization_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            // Standardize each tensor: (x - mean) / (std + eps) with eps = 1e-5.
            let outputs: Vec<Array<A, D>> = inputs
                .iter()
                .map(|input| {
                    let mean = input.iter().copied().sum::<A>()
                        / A::from(input.len()).unwrap_or(A::zero());
                    let variance = input
                        .mapv(|x| (x - mean).powi(2))
                        .mean()
                        .unwrap_or(A::one());
                    let std_dev = variance.sqrt();
                    let epsilon = A::from(1e-5).unwrap();

                    input.mapv(|x| (x - mean) / (std_dev + epsilon))
                })
                .collect();

            Ok(outputs)
        }

        fn compute_dropout_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let dropout_rate = layer_arch
                .config
                .get("dropout_rate")
                .and_then(|v| match v {
                    LayerConfig::Float(f) => Some(A::from(*f).unwrap()),
                    _ => None,
                })
                .unwrap_or(A::from(0.5).unwrap());

            // Deterministic, inference-style dropout: scale by the keep probability
            // instead of randomly masking elements.
            let scale = A::one() - dropout_rate;
            let outputs: Vec<Array<A, D>> = inputs
                .iter()
                .map(|input| input.mapv(|x| x * scale))
                .collect();

            Ok(outputs)
        }

        fn compute_pooling_forward(
            &self,
            _layer_id: &LayerId,
            inputs: &[Array<A, D>],
            _layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            // Placeholder: pooling is not modeled here; inputs pass through unchanged.
            Ok(inputs.to_vec())
        }

        pub fn backward_pass(
            &mut self,
            layerid: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            // Run any registered pre-backward hooks first.
            if let Some(hook) = self.backward_hooks.get_mut(layerid) {
                hook.pre_backward(layerid, grad_outputs)?;
            }

            let layer_arch = self
                .param_manager
                .get_layer_architecture(layerid)
                .ok_or_else(|| {
                    OptimError::InvalidConfig(format!("Layer {} not registered", layerid))
                })?
                .clone();

            // Dispatch on the declared layer type; unknown types pass gradients through.
            let grad_inputs = match layer_arch.layer_type.as_str() {
                "linear" | "dense" | "fc" => self.compute_linear_backward(layerid, grad_outputs)?,
                "conv" | "conv2d" => self.compute_conv_backward(layerid, grad_outputs)?,
                "activation" => {
                    self.compute_activation_backward(layerid, grad_outputs, &layer_arch)?
                }
                "normalization" | "batchnorm" | "layernorm" => {
                    self.compute_normalization_backward(layerid, grad_outputs)?
                }
                "dropout" => self.compute_dropout_backward(layerid, grad_outputs, &layer_arch)?,
                "pooling" | "maxpool" | "avgpool" => {
                    self.compute_pooling_backward(layerid, grad_outputs, &layer_arch)?
                }
                _ => grad_outputs.to_vec(),
            };

            // Apply global gradient clipping, if configured, before the post hooks run.
            let clipped_grads =
                if let Some(clipvalue) = self.param_manager.global_config.gradient_clip {
                    self.apply_gradient_clipping(grad_inputs, clipvalue)?
                } else {
                    grad_inputs
                };

            if let Some(hook) = self.backward_hooks.get_mut(layerid) {
                hook.post_backward(layerid, &clipped_grads)?;
            }

            Ok(clipped_grads)
        }

        fn compute_linear_backward(
            &mut self,
            layerid: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            if grad_outputs.is_empty() {
                return Err(OptimError::InvalidConfig(
                    "Linear layer backward requires gradients".to_string(),
                ));
            }

            let layer_params = self.param_manager.get_parameters_by_layer(layerid);

            // When accumulation is enabled, stash the per-parameter gradients for this layer.
            if self.gradient_accumulation {
                let mut param_grads = HashMap::new();
                for (i, paramid) in layer_params.iter().enumerate() {
                    if i < grad_outputs.len() {
                        param_grads.insert((*paramid).clone(), grad_outputs[i].clone());
                    }
                }
                self.accumulate_gradients(param_grads)?;
            }

            // Simplified backward: attenuate the incoming gradients by a fixed factor.
            let lr_decay = A::from(0.9).unwrap();
            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| grad.mapv(|x| x * lr_decay))
                .collect();

            Ok(grad_inputs)
        }

        fn compute_conv_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            Ok(grad_outputs.to_vec())
        }

        fn compute_activation_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let activation_type = layer_arch
                .config
                .get("activation")
                .and_then(|v| match v {
                    LayerConfig::String(s) => Some(s.as_str()),
                    _ => None,
                })
                .unwrap_or("relu");

            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| match activation_type {
                    // Simplified: gates on the gradient's own sign rather than the
                    // saved forward input.
                    "relu" => grad.mapv(|g| if g > A::zero() { g } else { A::zero() }),
                    // Constant-factor approximations of the true derivatives.
                    "sigmoid" => {
                        let factor = A::from(0.25).unwrap();
                        grad.mapv(|g| g * factor)
                    }
                    "tanh" => {
                        let factor = A::from(0.5).unwrap();
                        grad.mapv(|g| g * factor)
                    }
                    "leaky_relu" => {
                        let alpha = A::from(0.01).unwrap();
                        grad.mapv(|g| if g > A::zero() { g } else { alpha * g })
                    }
                    _ => grad.clone(),
                })
                .collect();

            Ok(grad_inputs)
        }

        fn compute_normalization_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
        ) -> Result<Vec<Array<A, D>>> {
            // Simplified backward: apply a fixed attenuation in place of the full
            // normalization gradient.
            let scale_factor = A::from(0.9).unwrap();
            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| grad.mapv(|g| g * scale_factor))
                .collect();

            Ok(grad_inputs)
        }

        fn compute_dropout_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
            layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            let dropout_rate = layer_arch
                .config
                .get("dropout_rate")
                .and_then(|v| match v {
                    LayerConfig::Float(f) => Some(A::from(*f).unwrap()),
                    _ => None,
                })
                .unwrap_or(A::from(0.5).unwrap());

            let scale = A::one() - dropout_rate;
            let grad_inputs: Vec<Array<A, D>> = grad_outputs
                .iter()
                .map(|grad| grad.mapv(|g| g * scale))
                .collect();

            Ok(grad_inputs)
        }

        fn compute_pooling_backward(
            &self,
            _layer_id: &LayerId,
            grad_outputs: &[Array<A, D>],
            _layer_arch: &LayerArchitecture,
        ) -> Result<Vec<Array<A, D>>> {
            Ok(grad_outputs.to_vec())
        }

        fn apply_gradient_clipping(
            &self,
            gradients: Vec<Array<A, D>>,
            clipvalue: A,
        ) -> Result<Vec<Array<A, D>>> {
            let clipped: Vec<Array<A, D>> = gradients
                .into_iter()
                .map(|grad| {
                    // Per-tensor L2 norm; rescale only when it exceeds the clip value.
                    let norm = grad.mapv(|x| x * x).sum().sqrt();

                    if norm > clipvalue {
                        let scale = clipvalue / norm;
                        grad.mapv(|x| x * scale)
                    } else {
                        grad
                    }
                })
                .collect();

            Ok(clipped)
        }
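
        // Worked example (values assumed): a gradient tensor with L2 norm 10.0 and
        // `gradient_clip = Some(5.0)` is rescaled by 5.0 / 10.0 = 0.5, so its norm
        // becomes exactly the clip value; smaller gradients pass through untouched.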

        pub fn accumulate_gradients(
            &mut self,
            gradients: HashMap<ParamId, Array<A, D>>,
        ) -> Result<()> {
            if !self.gradient_accumulation {
                return Err(OptimError::InvalidConfig(
                    "Gradient accumulation not enabled".to_string(),
                ));
            }

            self.accumulation_count += 1;

            for (paramid, grad) in gradients {
                if let Some(acc_grad) = self.accumulated_gradients.get_mut(&paramid) {
                    *acc_grad = acc_grad.clone() + grad;
                } else {
                    self.accumulated_gradients.insert(paramid, grad);
                }
            }

            Ok(())
        }
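
        // Illustrative sketch (assumed usage): accumulate gradients over several
        // micro-batches, then drain the summed result in one go. `micro_batches`
        // is hypothetical.
        //
        //     integration.enable_gradient_accumulation();
        //     for grads in micro_batches {
        //         integration.accumulate_gradients(grads)?;
        //     }
        //     let summed = integration.disable_gradient_accumulation();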

        pub fn parameter_manager(&self) -> &ParameterManager<A, D> {
            &self.param_manager
        }

        pub fn parameter_manager_mut(&mut self) -> &mut ParameterManager<A, D> {
            &mut self.param_manager
        }

        pub fn accumulation_count(&self) -> usize {
            self.accumulation_count
        }
    }
}

/// Architecture-specific optimization strategies layered on top of the
/// parameter manager.
pub mod architecture_aware {
    use super::*;

    /// Strategy describing which architecture-level adjustments to apply.
    #[derive(Debug, Clone)]
    pub enum ArchitectureStrategy {
        Transformer {
            component_specific_lr: bool,
            layer_wise_decay: bool,
            attention_warmup: usize,
        },
        ConvolutionalNet {
            layer_type_lr: bool,
            depth_scaling: bool,
            bn_special_handling: bool,
        },
        RecurrentNet {
            rnn_gradient_clip: Option<f64>,
            weight_type_lr: bool,
        },
        Custom {
            rules: HashMap<String, LayerConfig>,
        },
    }
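
    // Illustrative sketch (assumed usage): a strategy is plain data and can be built
    // directly; the field values below are arbitrary.
    //
    //     let strategy = ArchitectureStrategy::Transformer {
    //         component_specific_lr: true,
    //         layer_wise_decay: true,
    //         attention_warmup: 4_000,
    //     };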

    #[derive(Debug)]
    pub struct ArchitectureAwareOptimizer<A: Float, D: Dimension> {
        param_manager: ParameterManager<A, D>,
        strategy: ArchitectureStrategy,
        step_count: usize,
    }

    impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync>
        ArchitectureAwareOptimizer<A, D>
    {
        pub fn new(config: OptimizationConfig<A>, strategy: ArchitectureStrategy) -> Self {
            Self {
                param_manager: ParameterManager::new(config),
                strategy,
                step_count: 0,
            }
        }

        pub fn apply_architecture_optimizations(&mut self) -> Result<()> {
            // Clone the strategy so the match does not hold a borrow of `self`
            // across the `&mut self` helper calls below.
            let strategy = self.strategy.clone();
            match strategy {
                ArchitectureStrategy::Transformer {
                    component_specific_lr,
                    layer_wise_decay,
                    attention_warmup,
                } => {
                    self.apply_transformer_optimizations(
                        component_specific_lr,
                        layer_wise_decay,
                        attention_warmup,
                    )?;
                }
                ArchitectureStrategy::ConvolutionalNet {
                    layer_type_lr,
                    depth_scaling,
                    bn_special_handling,
                } => {
                    self.apply_cnn_optimizations(
                        layer_type_lr,
                        depth_scaling,
                        bn_special_handling,
                    )?;
                }
                ArchitectureStrategy::RecurrentNet {
                    rnn_gradient_clip,
                    weight_type_lr,
                } => {
                    self.apply_rnn_optimizations(rnn_gradient_clip, weight_type_lr)?;
                }
                ArchitectureStrategy::Custom { rules } => {
                    self.apply_custom_optimizations(&rules)?;
                }
            }
            Ok(())
        }

        fn apply_transformer_optimizations(
            &mut self,
            component_specific_lr: bool,
            layer_wise_decay: bool,
            attention_warmup: usize,
        ) -> Result<()> {
            if component_specific_lr {
                self.set_component_learning_rates()?;
            }

            if layer_wise_decay {
                self.apply_layer_wise_decay()?;
            }

            if attention_warmup > 0 && self.step_count < attention_warmup {
                self.apply_attention_warmup(attention_warmup)?;
            }

            Ok(())
        }

        fn apply_cnn_optimizations(
            &mut self,
            layer_type_lr: bool,
            depth_scaling: bool,
            bn_special_handling: bool,
        ) -> Result<()> {
            if layer_type_lr {
                self.set_layer_type_learning_rates()?;
            }

            if depth_scaling {
                self.apply_depth_scaling()?;
            }

            if bn_special_handling {
                self.apply_bn_optimizations()?;
            }

            Ok(())
        }

        fn apply_rnn_optimizations(
            &mut self,
            rnn_gradient_clip: Option<f64>,
            weight_type_lr: bool,
        ) -> Result<()> {
            if let Some(clipvalue) = rnn_gradient_clip {
                self.apply_rnn_gradient_clipping(A::from(clipvalue).unwrap())?;
            }

            if weight_type_lr {
                self.set_weight_type_learning_rates()?;
            }

            Ok(())
        }

        fn apply_custom_optimizations(
            &mut self,
            rules: &HashMap<String, LayerConfig>,
        ) -> Result<()> {
            let rule_entries: Vec<(String, LayerConfig)> = rules
                .iter()
                .map(|(name, config)| (name.clone(), config.clone()))
                .collect();

            for (rule_name, config) in rule_entries {
                self.apply_custom_rule(&rule_name, &config)?;
            }
            Ok(())
        }

        fn set_component_learning_rates(&mut self) -> Result<()> {
            let layer_rules: Vec<(LayerId, LayerOptimizationRule<A>)> = self
                .param_manager
                .get_all_parameters()
                .values()
                .map(|metadata| {
                    let mut rule = LayerOptimizationRule::default();

                    if metadata.tags.contains(&"attention".to_string()) {
                        rule.lr_multiplier = A::from(1.2).unwrap();
                    } else if metadata.tags.contains(&"ffn".to_string()) {
                        rule.lr_multiplier = A::from(1.0).unwrap();
                    } else if metadata.tags.contains(&"normalization".to_string()) {
                        rule.lr_multiplier = A::from(0.8).unwrap();
                    }

                    (metadata.layername.clone(), rule)
                })
                .collect();

            for (layername, rule) in layer_rules {
                self.param_manager.set_layer_rule(layername, rule);
            }
            Ok(())
        }

        fn apply_layer_wise_decay(&mut self) -> Result<()> {
            for (layerid, _) in self.param_manager.layer_architectures.clone() {
                if let Some(layer_num) = self.extract_layer_number(&layerid) {
                    let decay_factor = A::from(0.95_f64.powi(layer_num as i32)).unwrap();
                    let mut rule = self
                        .param_manager
                        .layer_rules
                        .get(&layerid)
                        .cloned()
                        .unwrap_or_default();
                    rule.lr_multiplier = rule.lr_multiplier * decay_factor;
                    self.param_manager.set_layer_rule(layerid, rule);
                }
            }
            Ok(())
        }
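
        // Worked example (values assumed): with the 0.95 decay base, a layer named
        // "encoder_12" has its lr_multiplier scaled by 0.95^12 ≈ 0.54, while
        // "encoder_0" is left at 0.95^0 = 1.0.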

        fn apply_attention_warmup(&mut self, warmupsteps: usize) -> Result<()> {
            let warmup_factor = A::from(self.step_count as f64 / warmupsteps as f64).unwrap();

            let attention_layers: Vec<LayerId> = self
                .param_manager
                .get_all_parameters()
                .iter()
                .filter_map(|(_param_id, metadata)| {
                    if metadata.tags.contains(&"attention".to_string()) {
                        Some(metadata.layername.clone())
                    } else {
                        None
                    }
                })
                .collect();

            for layername in attention_layers {
                let mut rule = self
                    .param_manager
                    .layer_rules
                    .get(&layername)
                    .cloned()
                    .unwrap_or_default();
                rule.lr_multiplier = rule.lr_multiplier * warmup_factor;
                self.param_manager.set_layer_rule(layername, rule);
            }
            Ok(())
        }
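
        // Worked example (values assumed): at step 250 of a 1000-step warmup, the
        // factor is 250 / 1000 = 0.25, so attention layers temporarily train at a
        // quarter of their configured multiplier.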

        fn set_layer_type_learning_rates(&mut self) -> Result<()> {
            for (layerid, architecture) in self.param_manager.layer_architectures.clone() {
                let mut rule = LayerOptimizationRule::default();

                match architecture.layer_type.as_str() {
                    "conv" | "conv2d" | "conv3d" => {
                        rule.lr_multiplier = A::from(1.0).unwrap();
                    }
                    "linear" | "dense" | "fc" => {
                        rule.lr_multiplier = A::from(0.8).unwrap();
                    }
                    _ => {
                        rule.lr_multiplier = A::from(1.0).unwrap();
                    }
                }

                self.param_manager.set_layer_rule(layerid, rule);
            }
            Ok(())
        }

        fn apply_depth_scaling(&mut self) -> Result<()> {
            let total_layers = self.param_manager.layer_architectures.len();

            // Shrink the multiplier linearly with depth: factor = 1.0 - 0.1 * (i / total).
            // Note that iteration order over the HashMap is unspecified.
            for (i, (layerid, _)) in self
                .param_manager
                .layer_architectures
                .clone()
                .iter()
                .enumerate()
            {
                let depth_factor = A::from(1.0 - 0.1 * (i as f64 / total_layers as f64)).unwrap();
                let mut rule = self
                    .param_manager
                    .layer_rules
                    .get(layerid)
                    .cloned()
                    .unwrap_or_default();
                rule.lr_multiplier = rule.lr_multiplier * depth_factor;
                self.param_manager.set_layer_rule(layerid.clone(), rule);
            }
            Ok(())
        }

        fn apply_bn_optimizations(&mut self) -> Result<()> {
            let norm_layers: Vec<LayerId> = self
                .param_manager
                .get_all_parameters()
                .iter()
                .filter_map(|(_param_id, metadata)| {
                    if metadata.paramtype == ParameterType::Normalization {
                        Some(metadata.layername.clone())
                    } else {
                        None
                    }
                })
                .collect();

            // Normalization parameters get a higher learning rate and are exempted
            // from weight decay.
            for layername in norm_layers {
                let mut rule = self
                    .param_manager
                    .layer_rules
                    .get(&layername)
                    .cloned()
                    .unwrap_or_default();
                rule.lr_multiplier = A::from(2.0).unwrap();
                rule.weight_decay_multiplier = A::zero();
                self.param_manager.set_layer_rule(layername, rule);
            }
            Ok(())
        }

        fn apply_rnn_gradient_clipping(&mut self, clipvalue: A) -> Result<()> {
            self.param_manager.global_config.gradient_clip = Some(clipvalue);
            Ok(())
        }

        fn set_weight_type_learning_rates(&mut self) -> Result<()> {
            let layer_rules: Vec<(LayerId, LayerOptimizationRule<A>)> = self
                .param_manager
                .get_all_parameters()
                .values()
                .map(|metadata| {
                    let mut rule = LayerOptimizationRule::default();

                    if metadata.tags.contains(&"recurrent".to_string()) {
                        rule.lr_multiplier = A::from(0.5).unwrap();
                    } else if metadata.tags.contains(&"linear".to_string()) {
                        rule.lr_multiplier = A::from(1.0).unwrap();
                    }

                    (metadata.layername.clone(), rule)
                })
                .collect();

            for (layername, rule) in layer_rules {
                self.param_manager.set_layer_rule(layername, rule);
            }
            Ok(())
        }

        // Placeholder: custom rules are currently accepted but not interpreted.
        fn apply_custom_rule(&mut self, _rule_name: &str, _config: &LayerConfig) -> Result<()> {
            Ok(())
        }

        // Parses the trailing "_<n>" suffix of a layer name, e.g. "encoder_3" -> Some(3).
        fn extract_layer_number(&self, layername: &str) -> Option<usize> {
            layername.split('_').next_back()?.parse().ok()
        }

        pub fn step(&mut self) -> Result<()> {
            self.step_count += 1;
            self.apply_architecture_optimizations()
        }

        pub fn parameter_manager(&self) -> &ParameterManager<A, D> {
            &self.param_manager
        }

        pub fn parameter_manager_mut(&mut self) -> &mut ParameterManager<A, D> {
            &mut self.param_manager
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_parameter_manager_basic() {
        let config = OptimizationConfig::default();
        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);

        let metadata = ParameterMetadata {
            layername: "layer1".to_string(),
            param_name: "weight".to_string(),
            shape: vec![10, 5],
            requires_grad: true,
            paramtype: ParameterType::Weight,
            sharing_group: None,
            tags: vec!["dense".to_string()],
        };

        manager
            .register_parameter("param1".to_string(), metadata)
            .unwrap();

        assert!(manager
            .get_parameter_metadata(&"param1".to_string())
            .is_some());
        assert_eq!(manager.get_all_parameters().len(), 1);
        assert!(!manager.is_parameter_frozen(&"param1".to_string()));
    }

    #[test]
    fn test_lazy_registration() {
        let config = OptimizationConfig::default();
        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);

        manager.enable_lazy_mode();

        let metadata = ParameterMetadata {
            layername: "layer1".to_string(),
            param_name: "weight".to_string(),
            shape: vec![10, 5],
            requires_grad: true,
            paramtype: ParameterType::Weight,
            sharing_group: None,
            tags: vec![],
        };

        manager
            .register_parameter("param1".to_string(), metadata)
            .unwrap();

        assert_eq!(manager.get_all_parameters().len(), 0);

        manager.disable_lazy_mode().unwrap();

        assert_eq!(manager.get_all_parameters().len(), 1);
    }

    #[test]
    fn test_layer_specific_rules() {
        let config = OptimizationConfig {
            base_learning_rate: 0.01,
            weight_decay: 0.001,
            gradient_clip: None,
            mixed_precision: false,
            architecture_optimizations: HashMap::new(),
        };
        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);

        let rule = LayerOptimizationRule {
            lr_multiplier: 2.0,
            weight_decay_multiplier: 0.5,
            frozen: false,
            custom_settings: HashMap::new(),
        };

        manager.set_layer_rule("layer1".to_string(), rule);

        let metadata = ParameterMetadata {
            layername: "layer1".to_string(),
            param_name: "weight".to_string(),
            shape: vec![10, 5],
            requires_grad: true,
            paramtype: ParameterType::Weight,
            sharing_group: None,
            tags: vec![],
        };

        manager
            .register_parameter("param1".to_string(), metadata)
            .unwrap();

        let effective_lr = manager.get_effective_learning_rate(&"param1".to_string());
        assert_relative_eq!(effective_lr, 0.02, epsilon = 1e-6);

        let effective_decay = manager.get_effective_weight_decay(&"param1".to_string());
        assert_relative_eq!(effective_decay, 0.0005, epsilon = 1e-6);
    }

    #[test]
    fn test_parameter_sharing() {
        let config = OptimizationConfig::default();
        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);

        let metadata1 = ParameterMetadata {
            layername: "layer1".to_string(),
            param_name: "weight".to_string(),
            shape: vec![10, 5],
            requires_grad: true,
            paramtype: ParameterType::Weight,
            sharing_group: Some("shared_weights".to_string()),
            tags: vec![],
        };

        let metadata2 = ParameterMetadata {
            layername: "layer2".to_string(),
            param_name: "weight".to_string(),
            shape: vec![10, 5],
            requires_grad: true,
            paramtype: ParameterType::Weight,
            sharing_group: Some("shared_weights".to_string()),
            tags: vec![],
        };

        manager
            .register_parameter("param1".to_string(), metadata1)
            .unwrap();
        manager
            .register_parameter("param2".to_string(), metadata2)
            .unwrap();

        let sharing_group = manager.get_sharing_group("shared_weights").unwrap();
        assert_eq!(sharing_group.len(), 2);
        assert!(sharing_group.contains(&"param1".to_string()));
        assert!(sharing_group.contains(&"param2".to_string()));
    }

    #[test]
    fn test_parameter_filtering() {
        let config = OptimizationConfig::default();
        let mut manager = ParameterManager::<f64, scirs2_core::ndarray::Ix1>::new(config);

        let weight_metadata = ParameterMetadata {
            layername: "layer1".to_string(),
            param_name: "weight".to_string(),
            shape: vec![10, 5],
            requires_grad: true,
            paramtype: ParameterType::Weight,
            sharing_group: None,
            tags: vec![],
        };

        let bias_metadata = ParameterMetadata {
            layername: "layer1".to_string(),
            param_name: "bias".to_string(),
            shape: vec![5],
            requires_grad: true,
            paramtype: ParameterType::Bias,
            sharing_group: None,
            tags: vec![],
        };

        manager
            .register_parameter("weight".to_string(), weight_metadata)
            .unwrap();
        manager
            .register_parameter("bias".to_string(), bias_metadata)
            .unwrap();

        let weights = manager.get_parameters_by_type(ParameterType::Weight);
        assert_eq!(weights.len(), 1);
        assert_eq!(weights[0], &"weight".to_string());

        let biases = manager.get_parameters_by_type(ParameterType::Bias);
        assert_eq!(biases.len(), 1);
        assert_eq!(biases[0], &"bias".to_string());

        let layer_params = manager.get_parameters_by_layer(&"layer1".to_string());
        assert_eq!(layer_params.len(), 2);

        let trainable = manager.get_trainable_parameters();
        assert_eq!(trainable.len(), 2);
    }

    #[test]
    fn test_architecture_aware_transformer() {
        use crate::neural_integration::architecture_aware::*;

        let config = OptimizationConfig::default();
        let strategy = ArchitectureStrategy::Transformer {
            component_specific_lr: true,
            layer_wise_decay: true,
            attention_warmup: 1000,
        };

        let mut optimizer =
            ArchitectureAwareOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(config, strategy);

        let layer_arch = LayerArchitecture {
            layer_type: "transformer_block".to_string(),
            input_dims: vec![512],
            output_dims: vec![512],
            config: HashMap::new(),
            trainable: true,
        };

        optimizer
            .parameter_manager_mut()
            .register_layer("layer_0".to_string(), layer_arch);

        let attention_metadata = ParameterMetadata {
            layername: "layer_0".to_string(),
            param_name: "attention_weight".to_string(),
            shape: vec![512, 512],
            requires_grad: true,
            paramtype: ParameterType::Attention,
            sharing_group: None,
            tags: vec!["attention".to_string()],
        };

        optimizer
            .parameter_manager_mut()
            .register_parameter("attn_param".to_string(), attention_metadata)
            .unwrap();

        optimizer.apply_architecture_optimizations().unwrap();

        assert!(optimizer
            .parameter_manager()
            .get_parameter_metadata(&"attn_param".to_string())
            .is_some());
    }
}