1use crate::error::{OptimError, Result};
8use crate::optimizers::Optimizer;
9use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
10use scirs2_core::numeric::Float;
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::path::Path;
14
/// Constraint applied in place to a parameter array after an update step.
///
/// See [`ParameterConstraint::apply`] for the exact projection each variant
/// performs; some variants are placeholders that currently return an error.
#[derive(Debug, Clone)]
pub enum ParameterConstraint<A: Float> {
    /// Clamp every element into the closed interval `[min, max]`.
    ValueClip {
        /// Lower bound.
        min: A,
        /// Upper bound.
        max: A,
    },
    /// Rescale the array so its L2 (Euclidean) norm does not exceed `maxnorm`.
    L2NormConstraint {
        /// Maximum allowed L2 norm.
        maxnorm: A,
    },
    /// Rescale the array so its L1 norm does not exceed `maxnorm`.
    L1NormConstraint {
        /// Maximum allowed L1 norm.
        maxnorm: A,
    },
    /// Project negative elements to zero.
    NonNegative,
    /// Normalize to unit L2 norm (arrays with zero norm are left unchanged).
    UnitSphere,
    /// Project onto the probability simplex (non-negative, summing to one).
    Simplex,
    /// Keep a matrix orthogonal within `tolerance`.
    /// NOTE(review): `apply` currently rejects this variant — it needs
    /// linear-algebra (QR/SVD) support that is not implemented here.
    Orthogonal {
        /// Orthogonality tolerance.
        tolerance: A,
    },
    /// Keep a matrix positive definite.
    /// NOTE(review): `apply` currently rejects this variant — it needs an
    /// eigendecomposition that is not implemented here.
    PositiveDefinite {
        /// Minimum allowed eigenvalue.
        mineigenvalue: A,
    },
    /// Bound the spectral norm; `apply` substitutes the Frobenius norm as an
    /// approximation.
    SpectralNorm {
        /// Maximum allowed norm.
        maxnorm: A,
    },
    /// Bound the nuclear norm; `apply` substitutes the elementwise L1 norm as
    /// an approximation.
    NuclearNorm {
        /// Maximum allowed norm.
        maxnorm: A,
    },
    /// User-defined constraint identified by name; must be dispatched by the
    /// caller (`apply` returns an error for it).
    Custom {
        /// Constraint identifier.
        name: String,
    },
}
67
impl<A: Float + Send + Sync> ParameterConstraint<A> {
    /// Applies this constraint to `params` in place.
    ///
    /// Norm-based variants rescale the whole array only when the measured
    /// norm exceeds the configured maximum; otherwise the array is untouched.
    ///
    /// # Errors
    /// Returns `OptimError::InvalidConfig` for variants that require
    /// operations not implemented here (`Orthogonal`, `PositiveDefinite`)
    /// and for `Custom` constraints.
    pub fn apply<D: Dimension>(&self, params: &mut Array<A, D>) -> Result<()>
    where
        A: ScalarOperand,
    {
        match self {
            ParameterConstraint::ValueClip { min, max } => {
                // Elementwise clamp into [min, max].
                params.mapv_inplace(|x| {
                    if x < *min {
                        *min
                    } else if x > *max {
                        *max
                    } else {
                        x
                    }
                });
            }
            ParameterConstraint::L2NormConstraint { maxnorm } => {
                // Uniform rescale so the Euclidean norm becomes maxnorm.
                let norm = params.mapv(|x| x * x).sum().sqrt();
                if norm > *maxnorm {
                    let scale = *maxnorm / norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::L1NormConstraint { maxnorm } => {
                // Uniform rescale so the sum of |x| becomes maxnorm.
                let norm = params.mapv(|x| x.abs()).sum();
                if norm > *maxnorm {
                    let scale = *maxnorm / norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::NonNegative => {
                // Project negative entries to zero.
                params.mapv_inplace(|x| if x < A::zero() { A::zero() } else { x });
            }
            ParameterConstraint::UnitSphere => {
                // Normalize to unit L2 norm; a zero-norm array is left as-is
                // (there is no well-defined direction to normalize).
                let norm = params.mapv(|x| x * x).sum().sqrt();
                if norm > A::zero() {
                    let scale = A::one() / norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::Simplex => {
                // Approximate simplex projection: clamp to non-negative, then
                // renormalize so entries sum to one. (This is not the exact
                // Euclidean projection, but it is cheap and order-preserving.)
                params.mapv_inplace(|x| if x < A::zero() { A::zero() } else { x });

                let sum = params.sum();
                if sum > A::zero() {
                    let scale = A::one() / sum;
                    params.mapv_inplace(|x| x * scale);
                } else {
                    // Degenerate case (everything clamped to zero): fall back
                    // to the uniform distribution.
                    let uniform_val = A::one() / A::from(params.len()).unwrap_or(A::one());
                    params.fill(uniform_val);
                }
            }
            ParameterConstraint::Orthogonal { tolerance: _ } => {
                // Not implemented: would require QR/SVD support.
                if params.ndim() == 2 {
                    return Err(OptimError::InvalidConfig(
                        "Orthogonal constraint requires specialized linear algebra operations"
                            .to_string(),
                    ));
                } else {
                    return Err(OptimError::InvalidConfig(
                        "Orthogonal constraint only applies to 2D arrays (matrices)".to_string(),
                    ));
                }
            }
            ParameterConstraint::PositiveDefinite { mineigenvalue: _ } => {
                // Not implemented: would require an eigendecomposition.
                return Err(OptimError::InvalidConfig(
                    "Positive definite constraint requires specialized eigenvalue operations"
                        .to_string(),
                ));
            }
            ParameterConstraint::SpectralNorm { maxnorm } => {
                // Approximation: the Frobenius norm stands in for the true
                // spectral norm (largest singular value).
                let frobenius_norm = params.mapv(|x| x * x).sum().sqrt();
                if frobenius_norm > *maxnorm {
                    let scale = *maxnorm / frobenius_norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::NuclearNorm { maxnorm } => {
                // Approximation: the elementwise L1 norm stands in for the
                // true nuclear norm (sum of singular values).
                let l1_norm = params.mapv(|x| x.abs()).sum();
                if l1_norm > *maxnorm {
                    let scale = *maxnorm / l1_norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::Custom { name } => {
                // Custom constraints must be handled by the caller.
                return Err(OptimError::InvalidConfig(format!(
                    "Custom constraint '{name}' not implemented"
                )));
            }
        }
        Ok(())
    }
}
176
/// Per-group hyperparameter overrides and constraints.
///
/// Every hyperparameter is optional; `None` means "fall back to the
/// optimizer's global default" (see the accessors on `ParameterGroup`).
#[derive(Debug, Clone)]
pub struct ParameterGroupConfig<A: Float> {
    /// Group-specific learning rate, if overridden.
    pub learning_rate: Option<A>,
    /// Group-specific weight decay, if overridden.
    pub weight_decay: Option<A>,
    /// Group-specific momentum, if overridden.
    pub momentum: Option<A>,
    /// Constraints applied to every parameter array in the group.
    pub constraints: Vec<ParameterConstraint<A>>,
    /// Free-form named hyperparameters (e.g. "beta1", "beta2").
    pub custom_params: HashMap<String, A>,
}
191
192impl<A: Float + Send + Sync> Default for ParameterGroupConfig<A> {
193 fn default() -> Self {
194 Self {
195 learning_rate: None,
196 weight_decay: None,
197 momentum: None,
198 constraints: Vec::new(),
199 custom_params: HashMap::new(),
200 }
201 }
202}
203
204impl<A: Float + Send + Sync> ParameterGroupConfig<A> {
205 pub fn new() -> Self {
207 Self::default()
208 }
209
210 pub fn with_learning_rate(mut self, lr: A) -> Self {
212 self.learning_rate = Some(lr);
213 self
214 }
215
216 pub fn with_weight_decay(mut self, wd: A) -> Self {
218 self.weight_decay = Some(wd);
219 self
220 }
221
222 pub fn with_momentum(mut self, momentum: A) -> Self {
224 self.momentum = Some(momentum);
225 self
226 }
227
228 pub fn with_custom_param(mut self, key: String, value: A) -> Self {
230 self.custom_params.insert(key, value);
231 self
232 }
233
234 pub fn with_constraint(mut self, constraint: ParameterConstraint<A>) -> Self {
236 self.constraints.push(constraint);
237 self
238 }
239
240 pub fn with_value_clip(mut self, min: A, max: A) -> Self {
242 self.constraints
243 .push(ParameterConstraint::ValueClip { min, max });
244 self
245 }
246
247 pub fn with_l2_norm_constraint(mut self, maxnorm: A) -> Self {
249 self.constraints
250 .push(ParameterConstraint::L2NormConstraint { maxnorm });
251 self
252 }
253
254 pub fn with_l1_norm_constraint(mut self, maxnorm: A) -> Self {
256 self.constraints
257 .push(ParameterConstraint::L1NormConstraint { maxnorm });
258 self
259 }
260
261 pub fn with_non_negative(mut self) -> Self {
263 self.constraints.push(ParameterConstraint::NonNegative);
264 self
265 }
266
267 pub fn with_unit_sphere(mut self) -> Self {
269 self.constraints.push(ParameterConstraint::UnitSphere);
270 self
271 }
272
273 pub fn with_simplex(mut self) -> Self {
275 self.constraints.push(ParameterConstraint::Simplex);
276 self
277 }
278
279 pub fn with_orthogonal(mut self, tolerance: A) -> Self {
281 self.constraints
282 .push(ParameterConstraint::Orthogonal { tolerance });
283 self
284 }
285
286 pub fn with_positive_definite(mut self, mineigenvalue: A) -> Self {
288 self.constraints
289 .push(ParameterConstraint::PositiveDefinite { mineigenvalue });
290 self
291 }
292
293 pub fn with_spectral_norm(mut self, maxnorm: A) -> Self {
295 self.constraints
296 .push(ParameterConstraint::SpectralNorm { maxnorm });
297 self
298 }
299
300 pub fn with_nuclear_norm(mut self, maxnorm: A) -> Self {
302 self.constraints
303 .push(ParameterConstraint::NuclearNorm { maxnorm });
304 self
305 }
306
307 pub fn with_custom_constraint(mut self, name: String) -> Self {
309 self.constraints.push(ParameterConstraint::Custom { name });
310 self
311 }
312}
313
/// A set of parameter arrays that share one [`ParameterGroupConfig`].
#[derive(Debug)]
pub struct ParameterGroup<A: Float, D: Dimension> {
    /// Identifier assigned by the owning optimizer / `GroupManager`.
    pub id: usize,
    /// The parameter arrays belonging to this group.
    pub params: Vec<Array<A, D>>,
    /// Hyperparameter overrides and constraints for this group.
    pub config: ParameterGroupConfig<A>,
    /// Named optimizer state arrays (layout is optimizer-defined —
    /// presumably one entry per parameter array; confirm with the optimizer).
    pub state: HashMap<String, Vec<Array<A, D>>>,
}
326
327impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> ParameterGroup<A, D> {
328 pub fn new(id: usize, params: Vec<Array<A, D>>, config: ParameterGroupConfig<A>) -> Self {
330 Self {
331 id,
332 params,
333 config,
334 state: HashMap::new(),
335 }
336 }
337
338 pub fn num_params(&self) -> usize {
340 self.params.len()
341 }
342
343 pub fn learning_rate(&self, default: A) -> A {
345 self.config.learning_rate.unwrap_or(default)
346 }
347
348 pub fn weight_decay(&self, default: A) -> A {
350 self.config.weight_decay.unwrap_or(default)
351 }
352
353 pub fn momentum(&self, default: A) -> A {
355 self.config.momentum.unwrap_or(default)
356 }
357
358 pub fn get_custom_param(&self, key: &str, default: A) -> A {
360 self.config
361 .custom_params
362 .get(key)
363 .copied()
364 .unwrap_or(default)
365 }
366
367 pub fn apply_constraints(&mut self) -> Result<()>
369 where
370 A: ScalarOperand + Send + Sync,
371 {
372 for constraint in &self.config.constraints {
373 for param in &mut self.params {
374 constraint.apply(param)?;
375 }
376 }
377 Ok(())
378 }
379
380 pub fn apply_constraints_to_param(&self, param: &mut Array<A, D>) -> Result<()>
382 where
383 A: ScalarOperand + Send + Sync,
384 {
385 for constraint in &self.config.constraints {
386 constraint.apply(param)?;
387 }
388 Ok(())
389 }
390
391 pub fn constraints(&self) -> &[ParameterConstraint<A>] {
393 &self.config.constraints
394 }
395}
396
/// Extension of [`Optimizer`] for optimizers that manage multiple parameter
/// groups, each with its own hyperparameters and constraints.
pub trait GroupedOptimizer<A: Float + ScalarOperand + Debug, D: Dimension>:
    Optimizer<A, D>
{
    /// Registers a new parameter group and returns its assigned id.
    fn add_group(
        &mut self,
        params: Vec<Array<A, D>>,
        config: ParameterGroupConfig<A>,
    ) -> Result<usize>;

    /// Returns the group with the given id, or an error if it does not exist.
    fn get_group(&self, groupid: usize) -> Result<&ParameterGroup<A, D>>;

    /// Mutable variant of [`GroupedOptimizer::get_group`].
    fn get_group_mut(&mut self, groupid: usize) -> Result<&mut ParameterGroup<A, D>>;

    /// All registered groups.
    fn groups(&self) -> &[ParameterGroup<A, D>];

    /// Mutable access to all registered groups.
    fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>];

    /// Performs one optimization step on a single group using `gradients`
    /// and returns the updated parameter arrays.
    fn step_group(
        &mut self,
        group_id: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>>;

    /// Overrides the learning rate for one group.
    fn set_group_learning_rate(&mut self, groupid: usize, lr: A) -> Result<()>;

    /// Overrides the weight decay for one group.
    fn set_group_weight_decay(&mut self, groupid: usize, wd: A) -> Result<()>;
}
433
/// Owns a collection of [`ParameterGroup`]s and hands out sequential ids.
#[derive(Debug)]
pub struct GroupManager<A: Float, D: Dimension> {
    /// Registered groups, in insertion order.
    groups: Vec<ParameterGroup<A, D>>,
    /// The id that `add_group` will assign next.
    next_id: usize,
}
440
441impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default for GroupManager<A, D> {
442 fn default() -> Self {
443 Self {
444 groups: Vec::new(),
445 next_id: 0,
446 }
447 }
448}
449
450impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GroupManager<A, D> {
451 pub fn new() -> Self {
453 Self::default()
454 }
455
456 pub fn add_group(
458 &mut self,
459 params: Vec<Array<A, D>>,
460 config: ParameterGroupConfig<A>,
461 ) -> usize {
462 let id = self.next_id;
463 self.next_id += 1;
464 self.groups.push(ParameterGroup::new(id, params, config));
465 id
466 }
467
468 pub fn get_group(&self, id: usize) -> Result<&ParameterGroup<A, D>> {
470 self.groups
471 .iter()
472 .find(|g| g.id == id)
473 .ok_or_else(|| OptimError::InvalidConfig(format!("Group {id} not found")))
474 }
475
476 pub fn get_group_mut(&mut self, id: usize) -> Result<&mut ParameterGroup<A, D>> {
478 self.groups
479 .iter_mut()
480 .find(|g| g.id == id)
481 .ok_or_else(|| OptimError::InvalidConfig(format!("Group {id} not found")))
482 }
483
484 pub fn groups(&self) -> &[ParameterGroup<A, D>] {
486 &self.groups
487 }
488
489 pub fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>] {
491 &mut self.groups
492 }
493
494 pub fn total_params(&self) -> usize {
496 self.groups.iter().map(|g| g.num_params()).sum()
497 }
498}
499
500pub mod checkpointing {
502 use super::*;
503
504 #[derive(Debug, Clone)]
506 pub struct OptimizerCheckpoint<A: Float, D: Dimension> {
507 pub step: usize,
509 pub groups: Vec<ParameterGroupCheckpoint<A, D>>,
511 pub global_state: HashMap<String, String>,
513 pub metadata: CheckpointMetadata,
515 }
516
517 #[derive(Debug, Clone)]
519 pub struct ParameterGroupCheckpoint<A: Float, D: Dimension> {
520 pub id: usize,
522 pub params: Vec<Array<A, D>>,
524 pub config: ParameterGroupConfig<A>,
526 pub state: HashMap<String, Vec<Array<A, D>>>,
528 }
529
530 #[derive(Debug, Clone)]
532 pub struct CheckpointMetadata {
533 pub timestamp: String,
535 pub optimizerversion: String,
537 pub custom: HashMap<String, String>,
539 }
540
541 impl CheckpointMetadata {
542 pub fn new(optimizerversion: String) -> Self {
544 use std::time::{SystemTime, UNIX_EPOCH};
545
546 let timestamp = SystemTime::now()
547 .duration_since(UNIX_EPOCH)
548 .unwrap_or_default()
549 .as_secs()
550 .to_string();
551
552 Self {
553 timestamp,
554 optimizerversion,
555 custom: HashMap::new(),
556 }
557 }
558
559 pub fn with_custom(mut self, key: String, value: String) -> Self {
561 self.custom.insert(key, value);
562 self
563 }
564 }
565
    /// Serialization support for optimizers whose state can be captured as an
    /// [`OptimizerCheckpoint`].
    ///
    /// `save_checkpoint` writes a line-oriented text format: `#`-prefixed
    /// header comments, then `[METADATA]`, `[GLOBAL_STATE]`, `[GROUPS]`
    /// sections, then one `[GROUP_<id>]` section per group with `key=value`
    /// lines and whitespace-separated parameter/state data.
    pub trait Checkpointable<
        A: Float + ToString + std::fmt::Display + std::str::FromStr,
        D: Dimension,
    >
    {
        /// Captures the optimizer's current state as an in-memory checkpoint.
        fn create_checkpoint(&self) -> Result<OptimizerCheckpoint<A, D>>;

        /// Restores the optimizer's state from an in-memory checkpoint.
        fn restore_checkpoint(&mut self, checkpoint: &OptimizerCheckpoint<A, D>) -> Result<()>;

        /// Saves the current state (via `create_checkpoint`) to `path` in the
        /// text format described on the trait.
        ///
        /// # Errors
        /// Returns `OptimError::InvalidConfig` when `create_checkpoint` fails
        /// or any file I/O fails.
        fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> Result<()> {
            use std::fs::File;
            use std::io::{BufWriter, Write};

            let checkpoint = self.create_checkpoint()?;
            let path = path.as_ref();

            let file = File::create(path).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to create checkpoint file: {e}"))
            })?;
            let mut writer = BufWriter::new(file);

            // --- Header comments (format version, timestamp, step) ---
            writeln!(writer, "# ScirS2 Optimizer Checkpoint v1.0").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write checkpoint header: {e}"))
            })?;
            writeln!(writer, "# Timestamp: {}", checkpoint.metadata.timestamp).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write timestamp: {e}"))
            })?;
            writeln!(
                writer,
                "# Optimizer Version: {}",
                checkpoint.metadata.optimizerversion
            )
            .map_err(|e| OptimError::InvalidConfig(format!("Failed to write version: {e}")))?;
            writeln!(writer, "# Step: {}", checkpoint.step)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write step: {e}")))?;
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            // --- [METADATA] section: custom key=value entries ---
            writeln!(writer, "[METADATA]").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write metadata section: {e}"))
            })?;
            for (key, value) in &checkpoint.metadata.custom {
                writeln!(writer, "{}={}", key, value).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write metadata entry: {e}"))
                })?;
            }
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            // --- [GLOBAL_STATE] section: optimizer-wide key=value state ---
            writeln!(writer, "[GLOBAL_STATE]").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write global state section: {e}"))
            })?;
            for (key, value) in &checkpoint.global_state {
                writeln!(writer, "{}={}", key, value).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write global state entry: {e}"))
                })?;
            }
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            // --- [GROUPS] section: just the group count ---
            writeln!(writer, "[GROUPS]").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write groups section: {e}"))
            })?;
            writeln!(writer, "count={}", checkpoint.groups.len()).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write group count: {e}"))
            })?;
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            // --- One [GROUP_<id>] section per group ---
            for group in &checkpoint.groups {
                writeln!(writer, "[GROUP_{}]", group.id).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write group header: {e}"))
                })?;

                // Optional hyperparameters serialize as their Display form,
                // or the literal string "None" when unset.
                writeln!(
                    writer,
                    "learning_rate={}",
                    group
                        .config
                        .learning_rate
                        .map(|lr| lr.to_string())
                        .unwrap_or_else(|| "None".to_string())
                )
                .map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write learning rate: {e}"))
                })?;
                writeln!(
                    writer,
                    "weight_decay={}",
                    group
                        .config
                        .weight_decay
                        .map(|wd| wd.to_string())
                        .unwrap_or_else(|| "None".to_string())
                )
                .map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write weight decay: {e}"))
                })?;
                writeln!(
                    writer,
                    "momentum={}",
                    group
                        .config
                        .momentum
                        .map(|m| m.to_string())
                        .unwrap_or_else(|| "None".to_string())
                )
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write momentum: {e}")))?;

                writeln!(
                    writer,
                    "custom_params_count={}",
                    group.config.custom_params.len()
                )
                .map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write custom params count: {e}"))
                })?;
                for (key, value) in &group.config.custom_params {
                    writeln!(writer, "custom_{}={}", key, value).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write custom param: {e}"))
                    })?;
                }

                // Parameter arrays: shape line, then space-separated values.
                // NOTE: the shape key is "param_<i>shape" (no underscore
                // before "shape") — load_checkpoint matches this exactly.
                writeln!(writer, "param_count={}", group.params.len()).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write param count: {e}"))
                })?;
                for (i, param) in group.params.iter().enumerate() {
                    writeln!(writer, "param_{}shape={:?}", i, param.shape()).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write param shape: {e}"))
                    })?;
                    write!(writer, "param_{}_data=", i).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write param data label: {e}"))
                    })?;

                    for (j, &val) in param.iter().enumerate() {
                        if j > 0 {
                            write!(writer, " ").map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to write space: {e}"))
                            })?;
                        }
                        write!(writer, "{}", val).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to write value: {e}"))
                        })?;
                    }
                    writeln!(writer).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write newline: {e}"))
                    })?;
                }

                // Optimizer state: name, array count, then shape/data pairs
                // in the same format as the parameters.
                writeln!(writer, "state_count={}", group.state.len()).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write state count: {e}"))
                })?;
                for (state_name, state_arrays) in &group.state {
                    writeln!(writer, "state_name={}", state_name).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write state name: {e}"))
                    })?;
                    writeln!(writer, "state_array_count={}", state_arrays.len()).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write state array count: {e}"))
                    })?;
                    for (i, array) in state_arrays.iter().enumerate() {
                        writeln!(writer, "state_{}shape={:?}", i, array.shape()).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to write state shape: {e}"))
                        })?;
                        write!(writer, "state_{}_data=", i).map_err(|e| {
                            OptimError::InvalidConfig(format!(
                                "Failed to write state data label: {}",
                                e
                            ))
                        })?;

                        for (j, &val) in array.iter().enumerate() {
                            if j > 0 {
                                write!(writer, " ").map_err(|e| {
                                    OptimError::InvalidConfig(format!(
                                        "Failed to write space: {}",
                                        e
                                    ))
                                })?;
                            }
                            write!(writer, "{}", val).map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to write value: {e}"))
                            })?;
                        }
                        writeln!(writer).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to write newline: {e}"))
                        })?;
                    }
                }

                writeln!(writer).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write newline: {e}"))
                })?;
            }

            writer.flush().map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to flush checkpoint file: {e}"))
            })?;

            Ok(())
        }

        /// Parses a checkpoint file written by `save_checkpoint`.
        ///
        /// NOTE(review): this method currently parses the whole file into a
        /// dynamically-dimensioned (`IxDyn`) checkpoint and then always
        /// returns an error, because converting `IxDyn` back to the trait's
        /// `D` is not supported in v1.0.0 — see the error message below.
        fn load_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
            use std::fs::File;
            use std::io::{BufRead, BufReader};

            let path = path.as_ref();
            let file = File::open(path).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to open checkpoint file: {e}"))
            })?;
            let reader = BufReader::new(file);
            let mut lines = reader.lines();

            // --- Header comments: step, version, timestamp ---
            let mut step = 0;
            let mut optimizerversion = String::new();
            let mut timestamp = String::new();

            while let Some(Ok(line)) = lines.next() {
                if line.starts_with("# Step: ") {
                    step = line.trim_start_matches("# Step: ").parse().map_err(|_| {
                        OptimError::InvalidConfig("Invalid step format".to_string())
                    })?;
                } else if line.starts_with("# Optimizer Version: ") {
                    optimizerversion = line.trim_start_matches("# Optimizer Version: ").to_string();
                } else if line.starts_with("# Timestamp: ") {
                    timestamp = line.trim_start_matches("# Timestamp: ").to_string();
                } else if line.starts_with("[METADATA]") {
                    break;
                }
            }

            // --- [METADATA] section until [GLOBAL_STATE] ---
            let mut custom_metadata = HashMap::new();
            while let Some(Ok(line)) = lines.next() {
                if line.is_empty() || line.starts_with("[") {
                    if line.starts_with("[GLOBAL_STATE]") {
                        break;
                    }
                    continue;
                }
                if let Some((key, value)) = line.split_once('=') {
                    custom_metadata.insert(key.to_string(), value.to_string());
                }
            }

            // --- [GLOBAL_STATE] section until [GROUPS] ---
            let mut global_state = HashMap::new();
            while let Some(Ok(line)) = lines.next() {
                if line.is_empty() || line.starts_with("[") {
                    if line.starts_with("[GROUPS]") {
                        break;
                    }
                    continue;
                }
                if let Some((key, value)) = line.split_once('=') {
                    global_state.insert(key.to_string(), value.to_string());
                }
            }

            // --- Group count ---
            let mut group_count = 0;
            while let Some(Ok(line)) = lines.next() {
                if line.starts_with("count=") {
                    group_count = line.trim_start_matches("count=").parse().map_err(|_| {
                        OptimError::InvalidConfig("Invalid group count".to_string())
                    })?;
                    break;
                }
            }

            // --- One [GROUP_<id>] section per group ---
            let mut groups = Vec::new();
            for _ in 0..group_count {
                // Scan forward to the next group header and extract its id.
                let mut group_id = 0;
                while let Some(Ok(line)) = lines.next() {
                    if line.starts_with("[GROUP_") {
                        let id_str = line.trim_start_matches("[GROUP_").trim_end_matches(']');
                        group_id = id_str.parse().map_err(|_| {
                            OptimError::InvalidConfig("Invalid group ID".to_string())
                        })?;
                        break;
                    }
                }

                // Hyperparameters and custom params, until "param_count=".
                let mut learning_rate = None;
                let mut weight_decay = None;
                let mut momentum = None;
                let mut custom_params = HashMap::new();
                let mut _custom_params_count = 0;

                while let Some(Ok(line)) = lines.next() {
                    if line.starts_with("learning_rate=") {
                        let val_str = line.trim_start_matches("learning_rate=");
                        if val_str != "None" {
                            learning_rate = Some(A::from_str(val_str).map_err(|_| {
                                OptimError::InvalidConfig("Invalid learning rate".to_string())
                            })?);
                        }
                    } else if line.starts_with("weight_decay=") {
                        let val_str = line.trim_start_matches("weight_decay=");
                        if val_str != "None" {
                            weight_decay = Some(A::from_str(val_str).map_err(|_| {
                                OptimError::InvalidConfig("Invalid weight decay".to_string())
                            })?);
                        }
                    } else if line.starts_with("momentum=") {
                        let val_str = line.trim_start_matches("momentum=");
                        if val_str != "None" {
                            momentum = Some(A::from_str(val_str).map_err(|_| {
                                OptimError::InvalidConfig("Invalid momentum".to_string())
                            })?);
                        }
                    } else if line.starts_with("custom_params_count=") {
                        _custom_params_count = line
                            .trim_start_matches("custom_params_count=")
                            .parse()
                            .map_err(|_| {
                                OptimError::InvalidConfig("Invalid custom params count".to_string())
                            })?;
                    } else if line.starts_with("custom_") {
                        if let Some((key_with_prefix, value)) = line.split_once('=') {
                            let key = key_with_prefix.trim_start_matches("custom_");
                            custom_params.insert(
                                key.to_string(),
                                A::from_str(value).map_err(|_| {
                                    OptimError::InvalidConfig(
                                        "Invalid custom param value".to_string(),
                                    )
                                })?,
                            );
                        }
                    } else if line.starts_with("param_count=") {
                        break;
                    }
                }

                // Constraints are not serialized, so the restored config has
                // an empty constraint list.
                let config = ParameterGroupConfig {
                    learning_rate,
                    weight_decay,
                    momentum,
                    constraints: Vec::new(),
                    custom_params,
                };

                // NOTE(review): the "param_count=" line was already consumed
                // by the loop above; this re-reads the NEXT line and relies on
                // `trim_start_matches` being a no-op on it — confirm the
                // on-disk layout keeps these in sync.
                let param_count: usize = lines
                    .next()
                    .ok_or_else(|| OptimError::InvalidConfig("Missing param count".to_string()))?
                    .map_err(|e| OptimError::InvalidConfig(format!("Failed to read line: {e}")))?
                    .trim_start_matches("param_count=")
                    .parse()
                    .map_err(|_| OptimError::InvalidConfig("Invalid param count".to_string()))?;

                // Parameter arrays: "param_<i>shape=[..]" then
                // "param_<i>_data=v0 v1 ...".
                let mut params = Vec::new();
                for i in 0..param_count {
                    let shape_line = lines
                        .next()
                        .ok_or_else(|| {
                            OptimError::InvalidConfig("Missing param shape".to_string())
                        })?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?;

                    let shape_str = shape_line
                        .trim_start_matches(&format!("param_{}shape=", i))
                        .trim_start_matches('[')
                        .trim_end_matches(']');

                    let shape: Vec<usize> = shape_str
                        .split(", ")
                        .map(|s| {
                            s.parse()
                                .map_err(|_| OptimError::InvalidConfig("Invalid shape".to_string()))
                        })
                        .collect::<Result<Vec<_>>>()?;

                    let data_line = lines
                        .next()
                        .ok_or_else(|| OptimError::InvalidConfig("Missing param data".to_string()))?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?;

                    let data_str = data_line.trim_start_matches(&format!("param_{}_data=", i));
                    let data: Vec<A> = data_str
                        .split(' ')
                        .filter(|s| !s.is_empty())
                        .map(|s| {
                            A::from_str(s).map_err(|_| {
                                OptimError::InvalidConfig("Invalid data value".to_string())
                            })
                        })
                        .collect::<Result<Vec<_>>>()?;

                    let array: Array<A, scirs2_core::ndarray::IxDyn> =
                        Array::from_shape_vec(shape, data).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to create array: {e}"))
                        })?;
                    params.push(array);
                }

                // Optimizer state blocks, mirroring the parameter layout.
                let state_count: usize = lines
                    .next()
                    .ok_or_else(|| OptimError::InvalidConfig("Missing state count".to_string()))?
                    .map_err(|e| OptimError::InvalidConfig(format!("Failed to read line: {e}")))?
                    .trim_start_matches("state_count=")
                    .parse()
                    .map_err(|_| OptimError::InvalidConfig("Invalid state count".to_string()))?;

                let mut state = HashMap::new();
                for _ in 0..state_count {
                    let state_name = lines
                        .next()
                        .ok_or_else(|| OptimError::InvalidConfig("Missing state name".to_string()))?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?
                        .trim_start_matches("state_name=")
                        .to_string();

                    let array_count: usize = lines
                        .next()
                        .ok_or_else(|| {
                            OptimError::InvalidConfig("Missing state array count".to_string())
                        })?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?
                        .trim_start_matches("state_array_count=")
                        .parse()
                        .map_err(|_| {
                            OptimError::InvalidConfig("Invalid state array count".to_string())
                        })?;

                    let mut state_arrays = Vec::new();
                    for i in 0..array_count {
                        let shape_line = lines
                            .next()
                            .ok_or_else(|| {
                                OptimError::InvalidConfig("Missing state shape".to_string())
                            })?
                            .map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                            })?;

                        let shape_str = shape_line
                            .trim_start_matches(&format!("state_{}shape=", i))
                            .trim_start_matches('[')
                            .trim_end_matches(']');

                        let shape: Vec<usize> = shape_str
                            .split(", ")
                            .map(|s| {
                                s.parse().map_err(|_| {
                                    OptimError::InvalidConfig("Invalid state shape".to_string())
                                })
                            })
                            .collect::<Result<Vec<_>>>()?;

                        let data_line = lines
                            .next()
                            .ok_or_else(|| {
                                OptimError::InvalidConfig("Missing state data".to_string())
                            })?
                            .map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                            })?;

                        let data_str = data_line.trim_start_matches(&format!("state_{}_data=", i));
                        let data: Vec<A> = data_str
                            .split(' ')
                            .filter(|s| !s.is_empty())
                            .map(|s| {
                                A::from_str(s).map_err(|_| {
                                    OptimError::InvalidConfig("Invalid state value".to_string())
                                })
                            })
                            .collect::<Result<Vec<_>>>()?;

                        let array = Array::from_shape_vec(shape, data).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to create state array: {e}"))
                        })?;
                        state_arrays.push(array);
                    }

                    state.insert(state_name, state_arrays);
                }

                groups.push(ParameterGroupCheckpoint {
                    id: group_id,
                    params,
                    config,
                    state,
                });
            }

            let mut metadata = CheckpointMetadata::new(optimizerversion);
            metadata.timestamp = timestamp;
            metadata.custom = custom_metadata;

            // The parsed checkpoint is IxDyn-typed; converting it to the
            // trait's `D` is unsupported, so it is dropped and an error
            // returned (see the doc comment on this method).
            let _dyn_checkpoint = OptimizerCheckpoint::<A, scirs2_core::ndarray::IxDyn> {
                step,
                groups,
                global_state,
                metadata,
            };

            Err(OptimError::InvalidConfig(
                "Checkpoint loading from file with dimension type conversion is not supported in v1.0.0. \
                 Use CheckpointManager for in-memory checkpoints, or save/load with consistent dimension types. \
                 See documentation for checkpoint best practices.".to_string(),
            ))
        }
    }
1124
1125 #[derive(Debug)]
1127 pub struct CheckpointManager<A: Float, D: Dimension> {
1128 checkpoints: HashMap<String, OptimizerCheckpoint<A, D>>,
1129 _maxcheckpoints: usize,
1130 checkpoint_keys: Vec<String>, }
1132
1133 impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> CheckpointManager<A, D> {
1134 pub fn new() -> Self {
1136 Self {
1137 checkpoints: HashMap::new(),
1138 _maxcheckpoints: 10,
1139 checkpoint_keys: Vec::new(),
1140 }
1141 }
1142
1143 pub fn with_max_checkpoints(_maxcheckpoints: usize) -> Self {
1145 Self {
1146 checkpoints: HashMap::new(),
1147 _maxcheckpoints,
1148 checkpoint_keys: Vec::new(),
1149 }
1150 }
1151
1152 pub fn store_checkpoint(&mut self, key: String, checkpoint: OptimizerCheckpoint<A, D>) {
1154 if self.checkpoints.contains_key(&key) {
1156 self.checkpoints.insert(key.clone(), checkpoint);
1157 return;
1158 }
1159
1160 if self.checkpoints.len() >= self._maxcheckpoints {
1162 if let Some(oldest_key) = self.checkpoint_keys.first().cloned() {
1163 self.checkpoints.remove(&oldest_key);
1164 self.checkpoint_keys.retain(|k| k != &oldest_key);
1165 }
1166 }
1167
1168 self.checkpoints.insert(key.clone(), checkpoint);
1170 self.checkpoint_keys.push(key);
1171 }
1172
1173 pub fn get_checkpoint(&self, key: &str) -> Option<&OptimizerCheckpoint<A, D>> {
1175 self.checkpoints.get(key)
1176 }
1177
1178 pub fn remove_checkpoint(&mut self, key: &str) -> Option<OptimizerCheckpoint<A, D>> {
1180 self.checkpoint_keys.retain(|k| k != key);
1181 self.checkpoints.remove(key)
1182 }
1183
1184 pub fn list_checkpoints(&self) -> &[String] {
1186 &self.checkpoint_keys
1187 }
1188
1189 pub fn clear(&mut self) {
1191 self.checkpoints.clear();
1192 self.checkpoint_keys.clear();
1193 }
1194
1195 pub fn len(&self) -> usize {
1197 self.checkpoints.len()
1198 }
1199
1200 pub fn is_empty(&self) -> bool {
1202 self.checkpoints.is_empty()
1203 }
1204 }
1205
1206 impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default
1207 for CheckpointManager<A, D>
1208 {
1209 fn default() -> Self {
1210 Self::new()
1211 }
1212 }
1213
1214 pub mod utils {
1216 use super::*;
1217
1218 pub fn create_checkpoint_from_groups<A: Float + ScalarOperand + Debug, D: Dimension>(
1220 step: usize,
1221 groups: &[ParameterGroup<A, D>],
1222 global_state: HashMap<String, String>,
1223 optimizerversion: String,
1224 ) -> OptimizerCheckpoint<A, D> {
1225 let group_checkpoints = groups
1226 .iter()
1227 .map(|group| ParameterGroupCheckpoint {
1228 id: group.id,
1229 params: group.params.clone(),
1230 config: group.config.clone(),
1231 state: group.state.clone(),
1232 })
1233 .collect();
1234
1235 OptimizerCheckpoint {
1236 step,
1237 groups: group_checkpoints,
1238 global_state,
1239 metadata: CheckpointMetadata::new(optimizerversion),
1240 }
1241 }
1242
1243 pub fn validate_checkpoint<A: Float, D: Dimension>(
1245 checkpoint: &OptimizerCheckpoint<A, D>,
1246 expected_groups: usize,
1247 ) -> Result<()> {
1248 if checkpoint.groups.len() != expected_groups {
1249 return Err(OptimError::InvalidConfig(format!(
1250 "Checkpoint has {} groups, expected {expected_groups}",
1251 checkpoint.groups.len()
1252 )));
1253 }
1254
1255 let mut ids = std::collections::HashSet::new();
1257 for group in &checkpoint.groups {
1258 if !ids.insert(group.id) {
1259 return Err(OptimError::InvalidConfig(format!(
1260 "Duplicate group ID {} in checkpoint",
1261 group.id
1262 )));
1263 }
1264 }
1265
1266 Ok(())
1267 }
1268
1269 pub fn checkpoint_summary<A: Float, D: Dimension>(
1271 checkpoint: &OptimizerCheckpoint<A, D>,
1272 ) -> String {
1273 let total_params: usize = checkpoint
1274 .groups
1275 .iter()
1276 .map(|g| g.params.iter().map(|p| p.len()).sum::<usize>())
1277 .sum();
1278
1279 format!(
1280 "Checkpoint at step {}: {} groups, {} total parameters, created at {}",
1281 checkpoint.step,
1282 checkpoint.groups.len(),
1283 total_params,
1284 checkpoint.metadata.timestamp
1285 )
1286 }
1287 }
1288}
1289
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_parameter_group_config() {
        // Build a config through the fluent builder and confirm every field landed.
        let cfg = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_weight_decay(0.0001)
            .with_momentum(0.9)
            .with_custom_param("beta1".to_string(), 0.9)
            .with_custom_param("beta2".to_string(), 0.999);

        assert_eq!(cfg.learning_rate, Some(0.01));
        assert_eq!(cfg.weight_decay, Some(0.0001));
        assert_eq!(cfg.momentum, Some(0.9));
        assert_eq!(cfg.custom_params.get("beta1"), Some(&0.9));
        assert_eq!(cfg.custom_params.get("beta2"), Some(&0.999));
    }

    #[test]
    fn test_parameter_group() {
        // A group with two parameter arrays and an explicit learning rate.
        let group = ParameterGroup::new(
            0,
            vec![Array1::zeros(5), Array1::ones(3)],
            ParameterGroupConfig::new().with_learning_rate(0.01),
        );

        assert_eq!(group.id, 0);
        assert_eq!(group.num_params(), 2);
        // Group-local learning rate overrides the supplied default.
        assert_eq!(group.learning_rate(0.001), 0.01);
        // No weight decay configured, so the default passes through.
        assert_eq!(group.weight_decay(0.0), 0.0);
    }

    #[test]
    fn test_group_manager() {
        let mut manager: GroupManager<f64, scirs2_core::ndarray::Ix1> = GroupManager::new();

        // Two groups with different learning rates.
        let first = manager.add_group(
            vec![Array1::zeros(5)],
            ParameterGroupConfig::new().with_learning_rate(0.01),
        );
        let second = manager.add_group(
            vec![Array1::ones(3), Array1::zeros(4)],
            ParameterGroupConfig::new().with_learning_rate(0.001),
        );

        // IDs are assigned sequentially starting at zero.
        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(manager.groups().len(), 2);
        assert_eq!(manager.total_params(), 3);

        // Each group keeps its own learning rate.
        assert_eq!(manager.get_group(first).unwrap().learning_rate(0.0), 0.01);
        assert_eq!(manager.get_group(second).unwrap().learning_rate(0.0), 0.001);
    }

    #[test]
    fn test_parameter_constraints() {
        use approx::assert_relative_eq;

        // Value clipping pins every element into [min, max].
        let mut clipped = Array1::from_vec(vec![-2.0, 0.5, 3.0]);
        ParameterConstraint::ValueClip { min: 0.0, max: 1.0 }
            .apply(&mut clipped)
            .unwrap();
        assert_eq!(clipped.as_slice().unwrap(), &[0.0, 0.5, 1.0]);

        // L2 projection rescales a norm-5 vector down onto the norm-2 ball.
        let mut rescaled = Array1::from_vec(vec![3.0, 4.0]);
        ParameterConstraint::L2NormConstraint { maxnorm: 2.0 }
            .apply(&mut rescaled)
            .unwrap();
        assert_relative_eq!(rescaled.mapv(|x| x * x).sum().sqrt(), 2.0, epsilon = 1e-6);

        // Non-negativity zeroes out everything below zero.
        let mut rectified = Array1::from_vec(vec![-1.0, 2.0, -3.0]);
        ParameterConstraint::NonNegative.apply(&mut rectified).unwrap();
        assert_eq!(rectified.as_slice().unwrap(), &[0.0, 2.0, 0.0]);

        // Unit-sphere projection normalizes the vector to length one.
        let mut normalized = Array1::from_vec(vec![3.0, 4.0]);
        ParameterConstraint::UnitSphere.apply(&mut normalized).unwrap();
        assert_relative_eq!(normalized.mapv(|x| x * x).sum().sqrt(), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_parameter_group_with_constraints() {
        let mut group = ParameterGroup::new(
            0,
            vec![Array1::from_vec(vec![-2.0, 3.0])],
            ParameterGroupConfig::new()
                .with_learning_rate(0.01)
                .with_value_clip(0.0, 1.0),
        );

        // Applying constraints clips both out-of-range values into [0, 1].
        group.apply_constraints().unwrap();
        assert_eq!(group.params[0].as_slice().unwrap(), &[0.0, 1.0]);
    }

    #[test]
    fn test_parameter_config_builder() {
        let cfg = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_l2_norm_constraint(1.0)
            .with_non_negative()
            .with_custom_param("beta".to_string(), 0.9);

        assert_eq!(cfg.learning_rate, Some(0.01));
        // The two constraint builders each append one constraint.
        assert_eq!(cfg.constraints.len(), 2);
        assert_eq!(cfg.custom_params.get("beta"), Some(&0.9));
    }

    #[test]
    fn test_simplex_constraint() {
        use approx::assert_relative_eq;

        let mut weights = Array1::from_vec(vec![2.0, 3.0, 5.0]);
        ParameterConstraint::Simplex.apply(&mut weights).unwrap();

        // Result lies on the probability simplex: non-negative, summing to one.
        let total: f64 = weights.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);
        assert!(weights.iter().all(|&w| w >= 0.0));

        // Proportions are preserved: 2/10, 3/10, 5/10.
        for (actual, expected) in weights.iter().zip([0.2, 0.3, 0.5]) {
            assert_relative_eq!(*actual, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simplex_constraint_with_negatives() {
        use approx::assert_relative_eq;

        let mut weights = Array1::from_vec(vec![-1.0, 2.0, 3.0]);
        ParameterConstraint::Simplex.apply(&mut weights).unwrap();

        let total: f64 = weights.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);
        assert!(weights.iter().all(|&w| w >= 0.0));

        // The negative entry is floored to zero, then the rest renormalized: 2/5, 3/5.
        for (actual, expected) in weights.iter().zip([0.0, 0.4, 0.6]) {
            assert_relative_eq!(*actual, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simplex_constraint_all_zeros() {
        use approx::assert_relative_eq;

        let mut weights = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        ParameterConstraint::Simplex.apply(&mut weights).unwrap();

        // A degenerate all-zero input falls back to the uniform distribution.
        let total: f64 = weights.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);
        for &w in weights.iter() {
            assert_relative_eq!(w, 1.0 / 3.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_spectral_norm_constraint() {
        use approx::assert_relative_eq;

        // For a 1-D array the spectral norm degenerates to the L2 norm.
        let mut vector = Array1::from_vec(vec![3.0, 4.0]);
        ParameterConstraint::SpectralNorm { maxnorm: 2.0 }
            .apply(&mut vector)
            .unwrap();

        assert_relative_eq!(vector.mapv(|x| x * x).sum().sqrt(), 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_nuclear_norm_constraint() {
        use approx::assert_relative_eq;

        // For a 1-D array the nuclear norm degenerates to the L1 norm (|3|+|-4|+|2| = 9).
        let mut vector = Array1::from_vec(vec![3.0, -4.0, 2.0]);
        ParameterConstraint::NuclearNorm { maxnorm: 3.0 }
            .apply(&mut vector)
            .unwrap();

        assert_relative_eq!(vector.mapv(|x| x.abs()).sum(), 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_orthogonal_constraint_error() {
        // Orthogonality is only defined for matrices; a 1-D array must be rejected.
        let mut vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let outcome = ParameterConstraint::Orthogonal { tolerance: 1e-6 }.apply(&mut vector);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("2D arrays"));
    }

    #[test]
    fn test_positive_definite_constraint_error() {
        // Positive definiteness needs an eigendecomposition; a 1-D array must be rejected.
        let mut vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let outcome = ParameterConstraint::PositiveDefinite {
            mineigenvalue: 0.01,
        }
        .apply(&mut vector);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("eigenvalue"));
    }

    #[test]
    fn test_enhanced_config_builder() {
        let cfg = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_simplex()
            .with_spectral_norm(2.0)
            .with_nuclear_norm(1.5)
            .with_custom_constraint("my_constraint".to_string());

        assert_eq!(cfg.learning_rate, Some(0.01));
        assert_eq!(cfg.constraints.len(), 4);

        // Constraints are stored in the order the builder methods were called.
        if !matches!(&cfg.constraints[0], ParameterConstraint::Simplex) {
            panic!("Expected Simplex constraint");
        }
        if let ParameterConstraint::SpectralNorm { maxnorm } = &cfg.constraints[1] {
            assert_eq!(*maxnorm, 2.0);
        } else {
            panic!("Expected SpectralNorm constraint");
        }
    }

    #[test]
    fn test_constraint_combination() {
        use approx::assert_relative_eq;

        // Non-negativity followed by the simplex projection: the negative entry
        // is zeroed, then the remainder renormalized to sum to one (2/5, 3/5).
        let mut group = ParameterGroup::new(
            0,
            vec![Array1::from_vec(vec![-1.0, 2.0, 3.0])],
            ParameterGroupConfig::new()
                .with_learning_rate(0.01)
                .with_non_negative()
                .with_simplex(),
        );

        group.apply_constraints().unwrap();

        let projected = &group.params[0];
        let total: f64 = projected.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);
        assert!(projected.iter().all(|&x| x >= 0.0));

        assert_relative_eq!(projected[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(projected[1], 0.4, epsilon = 1e-6);
        assert_relative_eq!(projected[2], 0.6, epsilon = 1e-6);
    }
}