// optirs_core/parameter_groups/mod.rs

//! Parameter groups for different learning rates and configurations
//!
//! This module provides support for parameter groups, allowing different
//! sets of parameters to have different hyperparameters (learning rate,
//! weight decay, etc.) within the same optimizer.
6
7use crate::error::{OptimError, Result};
8use crate::optimizers::Optimizer;
9use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
10use scirs2_core::numeric::Float;
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::path::Path;
14
/// Parameter constraints that can be applied to parameter groups
///
/// A constraint is enforced by [`ParameterConstraint::apply`], which projects
/// a parameter array back into the feasible set in place. Some variants are
/// only approximated or not yet implemented there; see the variant notes.
#[derive(Debug, Clone)]
pub enum ParameterConstraint<A: Float> {
    /// Clip values to a range [min, max]
    ValueClip {
        /// Minimum allowed value
        min: A,
        /// Maximum allowed value
        max: A,
    },
    /// Constrain L2 norm to a maximum value
    L2NormConstraint {
        /// Maximum allowed L2 norm
        maxnorm: A,
    },
    /// Constrain L1 norm to a maximum value
    L1NormConstraint {
        /// Maximum allowed L1 norm
        maxnorm: A,
    },
    /// Ensure all values are non-negative
    NonNegative,
    /// Constrain to unit sphere (normalize to unit L2 norm)
    UnitSphere,
    /// Constrain parameters to be within a probability simplex (sum to 1, all non-negative)
    Simplex,
    /// Constrain matrix parameters to be orthogonal
    ///
    /// NOTE: `apply` currently returns an error for this variant (a full
    /// implementation would need SVD/Gram-Schmidt support).
    Orthogonal {
        /// Tolerance for orthogonality check
        tolerance: A,
    },
    /// Constrain symmetric matrices to be positive definite
    ///
    /// NOTE: `apply` currently returns an error for this variant (it would
    /// require eigenvalue computation).
    PositiveDefinite {
        /// Minimum eigenvalue to ensure positive definiteness
        mineigenvalue: A,
    },
    /// Spectral norm constraint (maximum singular value)
    ///
    /// NOTE: `apply` approximates this with the Frobenius norm (no SVD).
    SpectralNorm {
        /// Maximum allowed spectral norm
        maxnorm: A,
    },
    /// Nuclear norm constraint (sum of singular values)
    ///
    /// NOTE: `apply` approximates this with the element-wise L1 norm (no SVD).
    NuclearNorm {
        /// Maximum allowed nuclear norm
        maxnorm: A,
    },
    /// Custom constraint function
    ///
    /// NOTE: `apply` currently returns an error for this variant.
    Custom {
        /// Name of the custom constraint
        name: String,
    },
}
67
impl<A: Float + Send + Sync> ParameterConstraint<A> {
    /// Apply the constraint to a parameter array
    ///
    /// Projects `params` into the constraint's feasible set, mutating the
    /// array in place. Norm-based variants rescale uniformly and only when
    /// the current norm exceeds the cap.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` for variants without an
    /// implemented projection: `Orthogonal`, `PositiveDefinite`, and
    /// `Custom`.
    pub fn apply<D: Dimension>(&self, params: &mut Array<A, D>) -> Result<()>
    where
        A: ScalarOperand,
    {
        match self {
            ParameterConstraint::ValueClip { min, max } => {
                // Element-wise clamp. NaN fails both comparisons and passes
                // through unchanged.
                params.mapv_inplace(|x| {
                    if x < *min {
                        *min
                    } else if x > *max {
                        *max
                    } else {
                        x
                    }
                });
            }
            ParameterConstraint::L2NormConstraint { maxnorm } => {
                // sqrt of the sum of squares over all elements.
                let norm = params.mapv(|x| x * x).sum().sqrt();
                if norm > *maxnorm {
                    let scale = *maxnorm / norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::L1NormConstraint { maxnorm } => {
                // Sum of absolute values over all elements.
                let norm = params.mapv(|x| x.abs()).sum();
                if norm > *maxnorm {
                    let scale = *maxnorm / norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::NonNegative => {
                // Project negatives to zero; non-negative values untouched.
                params.mapv_inplace(|x| if x < A::zero() { A::zero() } else { x });
            }
            ParameterConstraint::UnitSphere => {
                // Normalize to unit L2 norm; a zero array is left unchanged
                // (no direction to normalize to).
                let norm = params.mapv(|x| x * x).sum().sqrt();
                if norm > A::zero() {
                    let scale = A::one() / norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::Simplex => {
                // First make all values non-negative
                params.mapv_inplace(|x| if x < A::zero() { A::zero() } else { x });

                // Then normalize to sum to 1
                let sum = params.sum();
                if sum > A::zero() {
                    let scale = A::one() / sum;
                    params.mapv_inplace(|x| x * scale);
                } else {
                    // If all values are zero, set to uniform distribution
                    let uniform_val = A::one() / A::from(params.len()).unwrap_or(A::one());
                    params.fill(uniform_val);
                }
            }
            ParameterConstraint::Orthogonal { tolerance: _ } => {
                // For now, implement a simple orthogonal projection for matrices
                // This is a simplified implementation - full orthogonal constraints
                // would require SVD decomposition
                if params.ndim() == 2 {
                    // Apply Gram-Schmidt process for small matrices
                    // For large matrices, this would need SVD-based orthogonalization
                    return Err(OptimError::InvalidConfig(
                        "Orthogonal constraint requires specialized linear algebra operations"
                            .to_string(),
                    ));
                } else {
                    return Err(OptimError::InvalidConfig(
                        "Orthogonal constraint only applies to 2D arrays (matrices)".to_string(),
                    ));
                }
            }
            ParameterConstraint::PositiveDefinite { mineigenvalue: _ } => {
                // Positive definite constraint requires eigenvalue computation
                return Err(OptimError::InvalidConfig(
                    "Positive definite constraint requires specialized eigenvalue operations"
                        .to_string(),
                ));
            }
            ParameterConstraint::SpectralNorm { maxnorm } => {
                // Spectral norm constraint requires SVD computation
                // For now, approximate with Frobenius norm
                // (Frobenius norm upper-bounds the spectral norm, so this is
                // a conservative projection.)
                let frobenius_norm = params.mapv(|x| x * x).sum().sqrt();
                if frobenius_norm > *maxnorm {
                    let scale = *maxnorm / frobenius_norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::NuclearNorm { maxnorm } => {
                // Nuclear norm constraint requires SVD computation
                // For now, approximate with L1 norm
                let l1_norm = params.mapv(|x| x.abs()).sum();
                if l1_norm > *maxnorm {
                    let scale = *maxnorm / l1_norm;
                    params.mapv_inplace(|x| x * scale);
                }
            }
            ParameterConstraint::Custom { name } => {
                return Err(OptimError::InvalidConfig(format!(
                    "Custom constraint '{name}' not implemented"
                )));
            }
        }
        Ok(())
    }
}
176
/// Configuration for a parameter group
///
/// Every field is optional: a `None` hyperparameter means "fall back to the
/// optimizer's global default" (see `ParameterGroup::learning_rate` etc.).
#[derive(Debug, Clone)]
pub struct ParameterGroupConfig<A: Float> {
    /// Learning rate for this group
    pub learning_rate: Option<A>,
    /// Weight decay for this group
    pub weight_decay: Option<A>,
    /// Momentum for this group (if applicable)
    pub momentum: Option<A>,
    /// Parameter constraints for this group
    pub constraints: Vec<ParameterConstraint<A>>,
    /// Custom parameters as key-value pairs
    pub custom_params: HashMap<String, A>,
}
191
192impl<A: Float + Send + Sync> Default for ParameterGroupConfig<A> {
193    fn default() -> Self {
194        Self {
195            learning_rate: None,
196            weight_decay: None,
197            momentum: None,
198            constraints: Vec::new(),
199            custom_params: HashMap::new(),
200        }
201    }
202}
203
204impl<A: Float + Send + Sync> ParameterGroupConfig<A> {
205    /// Create a new parameter group configuration
206    pub fn new() -> Self {
207        Self::default()
208    }
209
210    /// Set learning rate
211    pub fn with_learning_rate(mut self, lr: A) -> Self {
212        self.learning_rate = Some(lr);
213        self
214    }
215
216    /// Set weight decay
217    pub fn with_weight_decay(mut self, wd: A) -> Self {
218        self.weight_decay = Some(wd);
219        self
220    }
221
222    /// Set momentum
223    pub fn with_momentum(mut self, momentum: A) -> Self {
224        self.momentum = Some(momentum);
225        self
226    }
227
228    /// Add custom parameter
229    pub fn with_custom_param(mut self, key: String, value: A) -> Self {
230        self.custom_params.insert(key, value);
231        self
232    }
233
234    /// Add a parameter constraint
235    pub fn with_constraint(mut self, constraint: ParameterConstraint<A>) -> Self {
236        self.constraints.push(constraint);
237        self
238    }
239
240    /// Add value clipping constraint
241    pub fn with_value_clip(mut self, min: A, max: A) -> Self {
242        self.constraints
243            .push(ParameterConstraint::ValueClip { min, max });
244        self
245    }
246
247    /// Add L2 norm constraint
248    pub fn with_l2_norm_constraint(mut self, maxnorm: A) -> Self {
249        self.constraints
250            .push(ParameterConstraint::L2NormConstraint { maxnorm });
251        self
252    }
253
254    /// Add L1 norm constraint
255    pub fn with_l1_norm_constraint(mut self, maxnorm: A) -> Self {
256        self.constraints
257            .push(ParameterConstraint::L1NormConstraint { maxnorm });
258        self
259    }
260
261    /// Add non-negativity constraint
262    pub fn with_non_negative(mut self) -> Self {
263        self.constraints.push(ParameterConstraint::NonNegative);
264        self
265    }
266
267    /// Add unit sphere constraint
268    pub fn with_unit_sphere(mut self) -> Self {
269        self.constraints.push(ParameterConstraint::UnitSphere);
270        self
271    }
272
273    /// Add simplex constraint (sum to 1, all non-negative)
274    pub fn with_simplex(mut self) -> Self {
275        self.constraints.push(ParameterConstraint::Simplex);
276        self
277    }
278
279    /// Add orthogonal constraint for matrices
280    pub fn with_orthogonal(mut self, tolerance: A) -> Self {
281        self.constraints
282            .push(ParameterConstraint::Orthogonal { tolerance });
283        self
284    }
285
286    /// Add positive definite constraint for symmetric matrices
287    pub fn with_positive_definite(mut self, mineigenvalue: A) -> Self {
288        self.constraints
289            .push(ParameterConstraint::PositiveDefinite { mineigenvalue });
290        self
291    }
292
293    /// Add spectral norm constraint
294    pub fn with_spectral_norm(mut self, maxnorm: A) -> Self {
295        self.constraints
296            .push(ParameterConstraint::SpectralNorm { maxnorm });
297        self
298    }
299
300    /// Add nuclear norm constraint
301    pub fn with_nuclear_norm(mut self, maxnorm: A) -> Self {
302        self.constraints
303            .push(ParameterConstraint::NuclearNorm { maxnorm });
304        self
305    }
306
307    /// Add custom constraint
308    pub fn with_custom_constraint(mut self, name: String) -> Self {
309        self.constraints.push(ParameterConstraint::Custom { name });
310        self
311    }
312}
313
/// A parameter group with its own configuration
///
/// Each group is identified by a numeric `id` and carries its own
/// hyperparameter overrides plus optimizer-specific working state.
#[derive(Debug)]
pub struct ParameterGroup<A: Float, D: Dimension> {
    /// Unique identifier for this group
    pub id: usize,
    /// Parameters in this group
    pub params: Vec<Array<A, D>>,
    /// Configuration for this group
    pub config: ParameterGroupConfig<A>,
    /// Internal state for optimization (optimizer-specific)
    /// (keyed by state name; the contents and layout of each `Vec` are
    /// defined by the optimizer that owns this group)
    pub state: HashMap<String, Vec<Array<A, D>>>,
}
326
327impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> ParameterGroup<A, D> {
328    /// Create a new parameter group
329    pub fn new(id: usize, params: Vec<Array<A, D>>, config: ParameterGroupConfig<A>) -> Self {
330        Self {
331            id,
332            params,
333            config,
334            state: HashMap::new(),
335        }
336    }
337
338    /// Get the number of parameters in this group
339    pub fn num_params(&self) -> usize {
340        self.params.len()
341    }
342
343    /// Get learning rate for this group
344    pub fn learning_rate(&self, default: A) -> A {
345        self.config.learning_rate.unwrap_or(default)
346    }
347
348    /// Get weight decay for this group
349    pub fn weight_decay(&self, default: A) -> A {
350        self.config.weight_decay.unwrap_or(default)
351    }
352
353    /// Get momentum for this group
354    pub fn momentum(&self, default: A) -> A {
355        self.config.momentum.unwrap_or(default)
356    }
357
358    /// Get custom parameter
359    pub fn get_custom_param(&self, key: &str, default: A) -> A {
360        self.config
361            .custom_params
362            .get(key)
363            .copied()
364            .unwrap_or(default)
365    }
366
367    /// Apply constraints to all parameters in this group
368    pub fn apply_constraints(&mut self) -> Result<()>
369    where
370        A: ScalarOperand + Send + Sync,
371    {
372        for constraint in &self.config.constraints {
373            for param in &mut self.params {
374                constraint.apply(param)?;
375            }
376        }
377        Ok(())
378    }
379
380    /// Apply constraints to a specific parameter
381    pub fn apply_constraints_to_param(&self, param: &mut Array<A, D>) -> Result<()>
382    where
383        A: ScalarOperand + Send + Sync,
384    {
385        for constraint in &self.config.constraints {
386            constraint.apply(param)?;
387        }
388        Ok(())
389    }
390
391    /// Get the constraints for this group
392    pub fn constraints(&self) -> &[ParameterConstraint<A>] {
393        &self.config.constraints
394    }
395}
396
/// Optimizer with parameter group support
///
/// Extends [`Optimizer`] so that separate sets of parameters can be stepped
/// with their own hyperparameters (learning rate, weight decay, ...).
pub trait GroupedOptimizer<A: Float + ScalarOperand + Debug, D: Dimension>:
    Optimizer<A, D>
{
    /// Add a parameter group
    ///
    /// Returns the ID assigned to the new group on success.
    fn add_group(
        &mut self,
        params: Vec<Array<A, D>>,
        config: ParameterGroupConfig<A>,
    ) -> Result<usize>;

    /// Get parameter group by ID
    fn get_group(&self, groupid: usize) -> Result<&ParameterGroup<A, D>>;

    /// Get mutable parameter group by ID
    fn get_group_mut(&mut self, groupid: usize) -> Result<&mut ParameterGroup<A, D>>;

    /// Get all parameter groups
    fn groups(&self) -> &[ParameterGroup<A, D>];

    /// Get all parameter groups mutably
    fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>];

    /// Step for a specific group
    ///
    /// `gradients` presumably parallels the group's parameter list (one array
    /// per parameter) — the exact contract is implementation-defined.
    fn step_group(
        &mut self,
        group_id: usize,
        gradients: &[Array<A, D>],
    ) -> Result<Vec<Array<A, D>>>;

    /// Set learning rate for a specific group
    fn set_group_learning_rate(&mut self, groupid: usize, lr: A) -> Result<()>;

    /// Set weight decay for a specific group
    fn set_group_weight_decay(&mut self, groupid: usize, wd: A) -> Result<()>;
}
433
/// Helper struct for managing parameter groups
///
/// Hands out monotonically increasing IDs and supports lookup by ID.
#[derive(Debug)]
pub struct GroupManager<A: Float, D: Dimension> {
    /// Groups in insertion order (looked up by `id`, not by vector index)
    groups: Vec<ParameterGroup<A, D>>,
    /// ID that will be assigned to the next group added
    next_id: usize,
}
440
441impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default for GroupManager<A, D> {
442    fn default() -> Self {
443        Self {
444            groups: Vec::new(),
445            next_id: 0,
446        }
447    }
448}
449
450impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GroupManager<A, D> {
451    /// Create a new group manager
452    pub fn new() -> Self {
453        Self::default()
454    }
455
456    /// Add a new parameter group
457    pub fn add_group(
458        &mut self,
459        params: Vec<Array<A, D>>,
460        config: ParameterGroupConfig<A>,
461    ) -> usize {
462        let id = self.next_id;
463        self.next_id += 1;
464        self.groups.push(ParameterGroup::new(id, params, config));
465        id
466    }
467
468    /// Get group by ID
469    pub fn get_group(&self, id: usize) -> Result<&ParameterGroup<A, D>> {
470        self.groups
471            .iter()
472            .find(|g| g.id == id)
473            .ok_or_else(|| OptimError::InvalidConfig(format!("Group {id} not found")))
474    }
475
476    /// Get mutable group by ID
477    pub fn get_group_mut(&mut self, id: usize) -> Result<&mut ParameterGroup<A, D>> {
478        self.groups
479            .iter_mut()
480            .find(|g| g.id == id)
481            .ok_or_else(|| OptimError::InvalidConfig(format!("Group {id} not found")))
482    }
483
484    /// Get all groups
485    pub fn groups(&self) -> &[ParameterGroup<A, D>] {
486        &self.groups
487    }
488
489    /// Get all groups mutably
490    pub fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>] {
491        &mut self.groups
492    }
493
494    /// Get total number of parameters across all groups
495    pub fn total_params(&self) -> usize {
496        self.groups.iter().map(|g| g.num_params()).sum()
497    }
498}
499
500/// State checkpointing for parameter management
501pub mod checkpointing {
502    use super::*;
503
    /// Checkpoint data for optimizer state
    ///
    /// A snapshot of a grouped optimizer: the step counter, every parameter
    /// group (with its per-group state), a stringly-typed global state map,
    /// and bookkeeping metadata.
    #[derive(Debug, Clone)]
    pub struct OptimizerCheckpoint<A: Float, D: Dimension> {
        /// Step number
        pub step: usize,
        /// Parameter groups
        pub groups: Vec<ParameterGroupCheckpoint<A, D>>,
        /// Global optimizer state
        /// (serialized as `key=value` lines by `save_checkpoint`)
        pub global_state: HashMap<String, String>,
        /// Metadata
        pub metadata: CheckpointMetadata,
    }
516
    /// Checkpoint data for a parameter group
    ///
    /// Mirrors [`super::ParameterGroup`]: the group's ID, parameter arrays,
    /// configuration, and optimizer-specific state buffers.
    #[derive(Debug, Clone)]
    pub struct ParameterGroupCheckpoint<A: Float, D: Dimension> {
        /// Group ID
        pub id: usize,
        /// Parameters
        pub params: Vec<Array<A, D>>,
        /// Group configuration
        pub config: ParameterGroupConfig<A>,
        /// Optimizer-specific state for this group
        pub state: HashMap<String, Vec<Array<A, D>>>,
    }
529
    /// Metadata for checkpoints
    #[derive(Debug, Clone)]
    pub struct CheckpointMetadata {
        /// Timestamp when checkpoint was created
        /// (UNIX seconds rendered as a decimal string; see `new`)
        pub timestamp: String,
        /// Version of the optimizer
        pub optimizerversion: String,
        /// Custom metadata
        pub custom: HashMap<String, String>,
    }
540
541    impl CheckpointMetadata {
542        /// Create new metadata with current timestamp
543        pub fn new(optimizerversion: String) -> Self {
544            use std::time::{SystemTime, UNIX_EPOCH};
545
546            let timestamp = SystemTime::now()
547                .duration_since(UNIX_EPOCH)
548                .unwrap_or_default()
549                .as_secs()
550                .to_string();
551
552            Self {
553                timestamp,
554                optimizerversion,
555                custom: HashMap::new(),
556            }
557        }
558
559        /// Add custom metadata
560        pub fn with_custom(mut self, key: String, value: String) -> Self {
561            self.custom.insert(key, value);
562            self
563        }
564    }
565
566    /// Trait for optimizers that support checkpointing
567    pub trait Checkpointable<
568        A: Float + ToString + std::fmt::Display + std::str::FromStr,
569        D: Dimension,
570    >
571    {
        /// Create a checkpoint of the current optimizer state
        ///
        /// # Errors
        ///
        /// Implementation-defined; returns an error when the state cannot
        /// be snapshotted.
        fn create_checkpoint(&self) -> Result<OptimizerCheckpoint<A, D>>;

        /// Restore optimizer state from a checkpoint
        ///
        /// # Errors
        ///
        /// Implementation-defined; returns an error when `checkpoint` is
        /// incompatible with this optimizer.
        fn restore_checkpoint(&mut self, checkpoint: &OptimizerCheckpoint<A, D>) -> Result<()>;
577
        /// Save checkpoint to file (simple text format)
        ///
        /// File layout, in order: a `#`-prefixed header (format banner,
        /// timestamp, optimizer version, step), then `[METADATA]`,
        /// `[GLOBAL_STATE]`, and `[GROUPS]` sections of `key=value` lines,
        /// followed by one `[GROUP_<id>]` section per parameter group
        /// containing its config, parameter arrays, and optimizer state.
        /// Array data is written as space-separated `Display` values.
        /// NOTE(review): `Display` round-tripping of floats may be lossy —
        /// TODO confirm this is acceptable for checkpoint fidelity.
        ///
        /// # Errors
        ///
        /// Any I/O failure is wrapped in `OptimError::InvalidConfig`.
        fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> Result<()> {
            use std::fs::File;
            use std::io::{BufWriter, Write};

            let checkpoint = self.create_checkpoint()?;
            let path = path.as_ref();

            // Create the file
            let file = File::create(path).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to create checkpoint file: {e}"))
            })?;
            let mut writer = BufWriter::new(file);

            // Write header
            writeln!(writer, "# ScirS2 Optimizer Checkpoint v1.0").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write checkpoint header: {e}"))
            })?;
            writeln!(writer, "# Timestamp: {}", checkpoint.metadata.timestamp).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write timestamp: {e}"))
            })?;
            writeln!(
                writer,
                "# Optimizer Version: {}",
                checkpoint.metadata.optimizerversion
            )
            .map_err(|e| OptimError::InvalidConfig(format!("Failed to write version: {e}")))?;
            writeln!(writer, "# Step: {}", checkpoint.step)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write step: {e}")))?;
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            // Write custom metadata
            writeln!(writer, "[METADATA]").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write metadata section: {e}"))
            })?;
            for (key, value) in &checkpoint.metadata.custom {
                writeln!(writer, "{}={}", key, value).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write metadata entry: {e}"))
                })?;
            }
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            // Write global state
            writeln!(writer, "[GLOBAL_STATE]").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write global state section: {e}"))
            })?;
            for (key, value) in &checkpoint.global_state {
                writeln!(writer, "{}={}", key, value).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write global state entry: {e}"))
                })?;
            }
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            // Write parameter groups
            writeln!(writer, "[GROUPS]").map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write groups section: {e}"))
            })?;
            writeln!(writer, "count={}", checkpoint.groups.len()).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to write group count: {e}"))
            })?;
            writeln!(writer)
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write newline: {e}")))?;

            for group in &checkpoint.groups {
                // Write group header
                writeln!(writer, "[GROUP_{}]", group.id).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write group header: {e}"))
                })?;

                // Write group config
                // Unset hyperparameters are serialized as the literal "None".
                writeln!(
                    writer,
                    "learning_rate={}",
                    group
                        .config
                        .learning_rate
                        .map(|lr| lr.to_string())
                        .unwrap_or_else(|| "None".to_string())
                )
                .map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write learning rate: {e}"))
                })?;
                writeln!(
                    writer,
                    "weight_decay={}",
                    group
                        .config
                        .weight_decay
                        .map(|wd| wd.to_string())
                        .unwrap_or_else(|| "None".to_string())
                )
                .map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write weight decay: {e}"))
                })?;
                writeln!(
                    writer,
                    "momentum={}",
                    group
                        .config
                        .momentum
                        .map(|m| m.to_string())
                        .unwrap_or_else(|| "None".to_string())
                )
                .map_err(|e| OptimError::InvalidConfig(format!("Failed to write momentum: {e}")))?;

                // Write custom params
                writeln!(
                    writer,
                    "custom_params_count={}",
                    group.config.custom_params.len()
                )
                .map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write custom params count: {e}"))
                })?;
                for (key, value) in &group.config.custom_params {
                    writeln!(writer, "custom_{}={}", key, value).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write custom param: {e}"))
                    })?;
                }

                // Write parameters
                writeln!(writer, "param_count={}", group.params.len()).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write param count: {e}"))
                })?;
                for (i, param) in group.params.iter().enumerate() {
                    // NOTE(review): key is "param_{i}shape" (no underscore)
                    // while the data key below is "param_{i}_data" — confirm
                    // the loader expects this asymmetry before changing it.
                    writeln!(writer, "param_{}shape={:?}", i, param.shape()).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write param shape: {e}"))
                    })?;
                    write!(writer, "param_{}_data=", i).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write param data label: {e}"))
                    })?;

                    // Write array data as space-separated values
                    for (j, &val) in param.iter().enumerate() {
                        if j > 0 {
                            write!(writer, " ").map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to write space: {e}"))
                            })?;
                        }
                        write!(writer, "{}", val).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to write value: {e}"))
                        })?;
                    }
                    writeln!(writer).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write newline: {e}"))
                    })?;
                }

                // Write optimizer state
                writeln!(writer, "state_count={}", group.state.len()).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write state count: {e}"))
                })?;
                for (state_name, state_arrays) in &group.state {
                    writeln!(writer, "state_name={}", state_name).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write state name: {e}"))
                    })?;
                    writeln!(writer, "state_array_count={}", state_arrays.len()).map_err(|e| {
                        OptimError::InvalidConfig(format!("Failed to write state array count: {e}"))
                    })?;
                    for (i, array) in state_arrays.iter().enumerate() {
                        // NOTE(review): same "state_{i}shape" / "state_{i}_data"
                        // key asymmetry as for parameters above.
                        writeln!(writer, "state_{}shape={:?}", i, array.shape()).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to write state shape: {e}"))
                        })?;
                        write!(writer, "state_{}_data=", i).map_err(|e| {
                            OptimError::InvalidConfig(format!(
                                "Failed to write state data label: {}",
                                e
                            ))
                        })?;

                        // Write array data
                        for (j, &val) in array.iter().enumerate() {
                            if j > 0 {
                                write!(writer, " ").map_err(|e| {
                                    OptimError::InvalidConfig(format!(
                                        "Failed to write space: {}",
                                        e
                                    ))
                                })?;
                            }
                            write!(writer, "{}", val).map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to write value: {e}"))
                            })?;
                        }
                        writeln!(writer).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to write newline: {e}"))
                        })?;
                    }
                }

                writeln!(writer).map_err(|e| {
                    OptimError::InvalidConfig(format!("Failed to write newline: {e}"))
                })?;
            }

            writer.flush().map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to flush checkpoint file: {e}"))
            })?;

            Ok(())
        }
782
        /// Load checkpoint from file (simple text format)
        ///
        /// Parses the line-oriented text format written by `save_checkpoint`:
        /// a `#`-prefixed header (step, optimizer version, timestamp), then
        /// `[METADATA]`, `[GLOBAL_STATE]`, and `[GROUPS]` sections of
        /// `key=value` lines, followed by one `[GROUP_<id>]` section per
        /// parameter group containing its config, parameter arrays, and
        /// optimizer state arrays.
        ///
        /// # Errors
        /// Returns `OptimError::InvalidConfig` for any I/O or parse failure —
        /// and, in v1.0.0, ALWAYS returns an error at the end even after a
        /// successful parse (see the design note before the final `Err`).
        fn load_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
            use std::fs::File;
            use std::io::{BufRead, BufReader};

            let path = path.as_ref();
            let file = File::open(path).map_err(|e| {
                OptimError::InvalidConfig(format!("Failed to open checkpoint file: {e}"))
            })?;
            let reader = BufReader::new(file);
            let mut lines = reader.lines();

            // Read header
            let mut step = 0;
            let mut optimizerversion = String::new();
            let mut timestamp = String::new();

            // NOTE(review): `while let Some(Ok(line))` silently stops at the
            // first I/O error instead of propagating it; the error then
            // surfaces indirectly as a parse failure further down (if at all).
            while let Some(Ok(line)) = lines.next() {
                if line.starts_with("# Step: ") {
                    step = line.trim_start_matches("# Step: ").parse().map_err(|_| {
                        OptimError::InvalidConfig("Invalid step format".to_string())
                    })?;
                } else if line.starts_with("# Optimizer Version: ") {
                    optimizerversion = line.trim_start_matches("# Optimizer Version: ").to_string();
                } else if line.starts_with("# Timestamp: ") {
                    timestamp = line.trim_start_matches("# Timestamp: ").to_string();
                } else if line.starts_with("[METADATA]") {
                    break;
                }
            }

            // Read metadata: free-form `key=value` pairs until [GLOBAL_STATE].
            let mut custom_metadata = HashMap::new();
            while let Some(Ok(line)) = lines.next() {
                if line.is_empty() || line.starts_with("[") {
                    if line.starts_with("[GLOBAL_STATE]") {
                        break;
                    }
                    continue;
                }
                if let Some((key, value)) = line.split_once('=') {
                    custom_metadata.insert(key.to_string(), value.to_string());
                }
            }

            // Read global state: same `key=value` shape, until [GROUPS].
            let mut global_state = HashMap::new();
            while let Some(Ok(line)) = lines.next() {
                if line.is_empty() || line.starts_with("[") {
                    if line.starts_with("[GROUPS]") {
                        break;
                    }
                    continue;
                }
                if let Some((key, value)) = line.split_once('=') {
                    global_state.insert(key.to_string(), value.to_string());
                }
            }

            // Read groups count
            let mut group_count = 0;
            while let Some(Ok(line)) = lines.next() {
                if line.starts_with("count=") {
                    group_count = line.trim_start_matches("count=").parse().map_err(|_| {
                        OptimError::InvalidConfig("Invalid group count".to_string())
                    })?;
                    break;
                }
            }

            // Read parameter groups
            let mut groups = Vec::new();
            for _ in 0..group_count {
                // Skip to group header, e.g. "[GROUP_3]" -> group_id = 3.
                let mut group_id = 0;
                while let Some(Ok(line)) = lines.next() {
                    if line.starts_with("[GROUP_") {
                        let id_str = line.trim_start_matches("[GROUP_").trim_end_matches(']');
                        group_id = id_str.parse().map_err(|_| {
                            OptimError::InvalidConfig("Invalid group ID".to_string())
                        })?;
                        break;
                    }
                }

                // Read group config. Hyperparameters serialized as "None" are
                // left unset so the optimizer's defaults apply on restore.
                let mut learning_rate = None;
                let mut weight_decay = None;
                let mut momentum = None;
                let mut custom_params = HashMap::new();
                let mut _custom_params_count = 0;

                // NOTE(review): `A::from_str` comes from a FromStr-style bound
                // on `A` declared on the enclosing impl (outside this view).
                while let Some(Ok(line)) = lines.next() {
                    if line.starts_with("learning_rate=") {
                        let val_str = line.trim_start_matches("learning_rate=");
                        if val_str != "None" {
                            learning_rate = Some(A::from_str(val_str).map_err(|_| {
                                OptimError::InvalidConfig("Invalid learning rate".to_string())
                            })?);
                        }
                    } else if line.starts_with("weight_decay=") {
                        let val_str = line.trim_start_matches("weight_decay=");
                        if val_str != "None" {
                            weight_decay = Some(A::from_str(val_str).map_err(|_| {
                                OptimError::InvalidConfig("Invalid weight decay".to_string())
                            })?);
                        }
                    } else if line.starts_with("momentum=") {
                        let val_str = line.trim_start_matches("momentum=");
                        if val_str != "None" {
                            momentum = Some(A::from_str(val_str).map_err(|_| {
                                OptimError::InvalidConfig("Invalid momentum".to_string())
                            })?);
                        }
                    } else if line.starts_with("custom_params_count=") {
                        // Count is parsed for format validation only; entries
                        // are collected by prefix match below, not by count.
                        _custom_params_count = line
                            .trim_start_matches("custom_params_count=")
                            .parse()
                            .map_err(|_| {
                                OptimError::InvalidConfig("Invalid custom params count".to_string())
                            })?;
                    } else if line.starts_with("custom_") {
                        if let Some((key_with_prefix, value)) = line.split_once('=') {
                            let key = key_with_prefix.trim_start_matches("custom_");
                            custom_params.insert(
                                key.to_string(),
                                A::from_str(value).map_err(|_| {
                                    OptimError::InvalidConfig(
                                        "Invalid custom param value".to_string(),
                                    )
                                })?,
                            );
                        }
                    } else if line.starts_with("param_count=") {
                        break;
                    }
                }

                // Create group config
                let config = ParameterGroupConfig {
                    learning_rate,
                    weight_decay,
                    momentum,
                    constraints: Vec::new(), // Constraints are not persisted in this simple format
                    custom_params,
                };

                // Read parameters
                // NOTE(review): the config loop above already consumed the
                // "param_count=" line when it hit `break`, yet this read
                // expects the NEXT line to carry that same prefix — confirm
                // against the writer's output (a non-matching prefix passes
                // through trim_start_matches unchanged and fails in parse()).
                let param_count: usize = lines
                    .next()
                    .ok_or_else(|| OptimError::InvalidConfig("Missing param count".to_string()))?
                    .map_err(|e| OptimError::InvalidConfig(format!("Failed to read line: {e}")))?
                    .trim_start_matches("param_count=")
                    .parse()
                    .map_err(|_| OptimError::InvalidConfig("Invalid param count".to_string()))?;

                let mut params = Vec::new();
                for i in 0..param_count {
                    // Read shape
                    let shape_line = lines
                        .next()
                        .ok_or_else(|| {
                            OptimError::InvalidConfig("Missing param shape".to_string())
                        })?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?;

                    // NOTE(review): prefix is "param_{i}shape=" (no underscore
                    // before "shape") while the data prefix below is
                    // "param_{i}_data=" — verify this matches save_checkpoint.
                    let shape_str = shape_line
                        .trim_start_matches(&format!("param_{}shape=", i))
                        .trim_start_matches('[')
                        .trim_end_matches(']');

                    // Shape is serialized Debug-style: "[2, 3]" -> [2, 3].
                    let shape: Vec<usize> = shape_str
                        .split(", ")
                        .map(|s| {
                            s.parse()
                                .map_err(|_| OptimError::InvalidConfig("Invalid shape".to_string()))
                        })
                        .collect::<Result<Vec<_>>>()?;

                    // Read data
                    let data_line = lines
                        .next()
                        .ok_or_else(|| OptimError::InvalidConfig("Missing param data".to_string()))?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?;

                    // Flat element list, space-separated; empty tokens skipped.
                    let data_str = data_line.trim_start_matches(&format!("param_{}_data=", i));
                    let data: Vec<A> = data_str
                        .split(' ')
                        .filter(|s| !s.is_empty())
                        .map(|s| {
                            A::from_str(s).map_err(|_| {
                                OptimError::InvalidConfig("Invalid data value".to_string())
                            })
                        })
                        .collect::<Result<Vec<_>>>()?;

                    // Create array from shape and data with dynamic dimensions
                    let array: Array<A, scirs2_core::ndarray::IxDyn> =
                        Array::from_shape_vec(shape, data).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to create array: {e}"))
                        })?;
                    params.push(array);
                }

                // Read optimizer state: named entries, each a list of arrays.
                let state_count: usize = lines
                    .next()
                    .ok_or_else(|| OptimError::InvalidConfig("Missing state count".to_string()))?
                    .map_err(|e| OptimError::InvalidConfig(format!("Failed to read line: {e}")))?
                    .trim_start_matches("state_count=")
                    .parse()
                    .map_err(|_| OptimError::InvalidConfig("Invalid state count".to_string()))?;

                let mut state = HashMap::new();
                for _ in 0..state_count {
                    let state_name = lines
                        .next()
                        .ok_or_else(|| OptimError::InvalidConfig("Missing state name".to_string()))?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?
                        .trim_start_matches("state_name=")
                        .to_string();

                    let array_count: usize = lines
                        .next()
                        .ok_or_else(|| {
                            OptimError::InvalidConfig("Missing state array count".to_string())
                        })?
                        .map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                        })?
                        .trim_start_matches("state_array_count=")
                        .parse()
                        .map_err(|_| {
                            OptimError::InvalidConfig("Invalid state array count".to_string())
                        })?;

                    let mut state_arrays = Vec::new();
                    for i in 0..array_count {
                        // Read shape
                        let shape_line = lines
                            .next()
                            .ok_or_else(|| {
                                OptimError::InvalidConfig("Missing state shape".to_string())
                            })?
                            .map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                            })?;

                        // NOTE(review): same prefix inconsistency as the
                        // params above ("state_{i}shape=" vs "state_{i}_data=").
                        let shape_str = shape_line
                            .trim_start_matches(&format!("state_{}shape=", i))
                            .trim_start_matches('[')
                            .trim_end_matches(']');

                        let shape: Vec<usize> = shape_str
                            .split(", ")
                            .map(|s| {
                                s.parse().map_err(|_| {
                                    OptimError::InvalidConfig("Invalid state shape".to_string())
                                })
                            })
                            .collect::<Result<Vec<_>>>()?;

                        // Read data
                        let data_line = lines
                            .next()
                            .ok_or_else(|| {
                                OptimError::InvalidConfig("Missing state data".to_string())
                            })?
                            .map_err(|e| {
                                OptimError::InvalidConfig(format!("Failed to read line: {e}"))
                            })?;

                        let data_str = data_line.trim_start_matches(&format!("state_{}_data=", i));
                        let data: Vec<A> = data_str
                            .split(' ')
                            .filter(|s| !s.is_empty())
                            .map(|s| {
                                A::from_str(s).map_err(|_| {
                                    OptimError::InvalidConfig("Invalid state value".to_string())
                                })
                            })
                            .collect::<Result<Vec<_>>>()?;

                        // Create array with dynamic dimensions
                        let array = Array::from_shape_vec(shape, data).map_err(|e| {
                            OptimError::InvalidConfig(format!("Failed to create state array: {e}"))
                        })?;
                        state_arrays.push(array);
                    }

                    state.insert(state_name, state_arrays);
                }

                // Create group checkpoint
                groups.push(ParameterGroupCheckpoint {
                    id: group_id,
                    params,
                    config,
                    state,
                });
            }

            // Create checkpoint metadata
            let mut metadata = CheckpointMetadata::new(optimizerversion);
            metadata.timestamp = timestamp;
            metadata.custom = custom_metadata;

            // Create the checkpoint with dynamic dimensions.
            // Bound to `_dyn_checkpoint` only to validate construction; it is
            // intentionally discarded (see the design note below).
            let _dyn_checkpoint = OptimizerCheckpoint::<A, scirs2_core::ndarray::IxDyn> {
                step,
                groups,
                global_state,
                metadata,
            };

            // Dimension conversion from IxDyn to D is a known limitation
            // Checkpoints are saved with dynamic dimensions (IxDyn) for flexibility,
            // but loading requires compile-time dimension type D.
            //
            // DESIGN NOTE: This is intentional for v1.0.0 to maintain type safety.
            // Users should use save_checkpoint() and create a new optimizer instance
            // rather than load_checkpoint() for cross-session restoration.
            //
            // For same-session checkpoint restoration, use CheckpointManager's
            // in-memory storage which preserves dimension types.
            //
            // Future enhancement (v1.1.0+): Add dimension-specific load methods
            // or provide a type-erased checkpoint interface.
            Err(OptimError::InvalidConfig(
                "Checkpoint loading from file with dimension type conversion is not supported in v1.0.0. \
                 Use CheckpointManager for in-memory checkpoints, or save/load with consistent dimension types. \
                 See documentation for checkpoint best practices.".to_string(),
            ))
        }
1123    }
1124
    /// In-memory checkpoint manager
    ///
    /// Holds up to a fixed number of `OptimizerCheckpoint`s keyed by name.
    /// When capacity is reached, the oldest entry (by first insertion) is
    /// evicted — note this is FIFO by insertion order: reads via
    /// `get_checkpoint` do not refresh an entry's position.
    #[derive(Debug)]
    pub struct CheckpointManager<A: Float, D: Dimension> {
        // Checkpoints keyed by user-supplied name.
        checkpoints: HashMap<String, OptimizerCheckpoint<A, D>>,
        // Capacity limit enforced by `store_checkpoint`. The leading
        // underscore is misleading — this field IS used; kept for API/file
        // consistency.
        _maxcheckpoints: usize,
        checkpoint_keys: Vec<String>, // To maintain order for LRU eviction
    }
1132
1133    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> CheckpointManager<A, D> {
1134        /// Create a new checkpoint manager
1135        pub fn new() -> Self {
1136            Self {
1137                checkpoints: HashMap::new(),
1138                _maxcheckpoints: 10,
1139                checkpoint_keys: Vec::new(),
1140            }
1141        }
1142
1143        /// Create a new checkpoint manager with maximum number of checkpoints
1144        pub fn with_max_checkpoints(_maxcheckpoints: usize) -> Self {
1145            Self {
1146                checkpoints: HashMap::new(),
1147                _maxcheckpoints,
1148                checkpoint_keys: Vec::new(),
1149            }
1150        }
1151
1152        /// Store a checkpoint with a given key
1153        pub fn store_checkpoint(&mut self, key: String, checkpoint: OptimizerCheckpoint<A, D>) {
1154            // If key already exists, update it
1155            if self.checkpoints.contains_key(&key) {
1156                self.checkpoints.insert(key.clone(), checkpoint);
1157                return;
1158            }
1159
1160            // If we're at capacity, remove oldest checkpoint
1161            if self.checkpoints.len() >= self._maxcheckpoints {
1162                if let Some(oldest_key) = self.checkpoint_keys.first().cloned() {
1163                    self.checkpoints.remove(&oldest_key);
1164                    self.checkpoint_keys.retain(|k| k != &oldest_key);
1165                }
1166            }
1167
1168            // Add new checkpoint
1169            self.checkpoints.insert(key.clone(), checkpoint);
1170            self.checkpoint_keys.push(key);
1171        }
1172
1173        /// Retrieve a checkpoint by key
1174        pub fn get_checkpoint(&self, key: &str) -> Option<&OptimizerCheckpoint<A, D>> {
1175            self.checkpoints.get(key)
1176        }
1177
1178        /// Remove a checkpoint by key
1179        pub fn remove_checkpoint(&mut self, key: &str) -> Option<OptimizerCheckpoint<A, D>> {
1180            self.checkpoint_keys.retain(|k| k != key);
1181            self.checkpoints.remove(key)
1182        }
1183
1184        /// List all checkpoint keys
1185        pub fn list_checkpoints(&self) -> &[String] {
1186            &self.checkpoint_keys
1187        }
1188
1189        /// Clear all checkpoints
1190        pub fn clear(&mut self) {
1191            self.checkpoints.clear();
1192            self.checkpoint_keys.clear();
1193        }
1194
1195        /// Get number of stored checkpoints
1196        pub fn len(&self) -> usize {
1197            self.checkpoints.len()
1198        }
1199
1200        /// Check if manager is empty
1201        pub fn is_empty(&self) -> bool {
1202            self.checkpoints.is_empty()
1203        }
1204    }
1205
1206    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> Default
1207        for CheckpointManager<A, D>
1208    {
1209        fn default() -> Self {
1210            Self::new()
1211        }
1212    }
1213
1214    /// Utility functions for checkpointing
1215    pub mod utils {
1216        use super::*;
1217
1218        /// Create a checkpoint from parameter groups
1219        pub fn create_checkpoint_from_groups<A: Float + ScalarOperand + Debug, D: Dimension>(
1220            step: usize,
1221            groups: &[ParameterGroup<A, D>],
1222            global_state: HashMap<String, String>,
1223            optimizerversion: String,
1224        ) -> OptimizerCheckpoint<A, D> {
1225            let group_checkpoints = groups
1226                .iter()
1227                .map(|group| ParameterGroupCheckpoint {
1228                    id: group.id,
1229                    params: group.params.clone(),
1230                    config: group.config.clone(),
1231                    state: group.state.clone(),
1232                })
1233                .collect();
1234
1235            OptimizerCheckpoint {
1236                step,
1237                groups: group_checkpoints,
1238                global_state,
1239                metadata: CheckpointMetadata::new(optimizerversion),
1240            }
1241        }
1242
1243        /// Validate checkpoint compatibility
1244        pub fn validate_checkpoint<A: Float, D: Dimension>(
1245            checkpoint: &OptimizerCheckpoint<A, D>,
1246            expected_groups: usize,
1247        ) -> Result<()> {
1248            if checkpoint.groups.len() != expected_groups {
1249                return Err(OptimError::InvalidConfig(format!(
1250                    "Checkpoint has {} groups, expected {expected_groups}",
1251                    checkpoint.groups.len()
1252                )));
1253            }
1254
1255            // Validate that all group IDs are unique
1256            let mut ids = std::collections::HashSet::new();
1257            for group in &checkpoint.groups {
1258                if !ids.insert(group.id) {
1259                    return Err(OptimError::InvalidConfig(format!(
1260                        "Duplicate group ID {} in checkpoint",
1261                        group.id
1262                    )));
1263                }
1264            }
1265
1266            Ok(())
1267        }
1268
1269        /// Get checkpoint summary information
1270        pub fn checkpoint_summary<A: Float, D: Dimension>(
1271            checkpoint: &OptimizerCheckpoint<A, D>,
1272        ) -> String {
1273            let total_params: usize = checkpoint
1274                .groups
1275                .iter()
1276                .map(|g| g.params.iter().map(|p| p.len()).sum::<usize>())
1277                .sum();
1278
1279            format!(
1280                "Checkpoint at step {}: {} groups, {} total parameters, created at {}",
1281                checkpoint.step,
1282                checkpoint.groups.len(),
1283                total_params,
1284                checkpoint.metadata.timestamp
1285            )
1286        }
1287    }
1288}
1289
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // Builder accumulates hyperparameters and named custom params.
    #[test]
    fn test_parameter_group_config() {
        let config = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_weight_decay(0.0001)
            .with_momentum(0.9)
            .with_custom_param("beta1".to_string(), 0.9)
            .with_custom_param("beta2".to_string(), 0.999);

        assert_eq!(config.learning_rate, Some(0.01));
        assert_eq!(config.weight_decay, Some(0.0001));
        assert_eq!(config.momentum, Some(0.9));
        assert_eq!(config.custom_params.get("beta1"), Some(&0.9));
        assert_eq!(config.custom_params.get("beta2"), Some(&0.999));
    }

    // Group accessors: set values win over the passed-in defaults;
    // unset weight_decay falls back to the default argument (0.0).
    #[test]
    fn test_parameter_group() {
        let params = vec![Array1::zeros(5), Array1::ones(3)];
        let config = ParameterGroupConfig::new().with_learning_rate(0.01);

        let group = ParameterGroup::new(0, params, config);

        assert_eq!(group.id, 0);
        assert_eq!(group.num_params(), 2);
        assert_eq!(group.learning_rate(0.001), 0.01);
        assert_eq!(group.weight_decay(0.0), 0.0);
    }

    // Manager assigns sequential IDs and tracks totals across groups.
    #[test]
    fn test_group_manager() {
        let mut manager: GroupManager<f64, scirs2_core::ndarray::Ix1> = GroupManager::new();

        // Add first group
        let params1 = vec![Array1::zeros(5)];
        let config1 = ParameterGroupConfig::new().with_learning_rate(0.01);
        let id1 = manager.add_group(params1, config1);

        // Add second group
        let params2 = vec![Array1::ones(3), Array1::zeros(4)];
        let config2 = ParameterGroupConfig::new().with_learning_rate(0.001);
        let id2 = manager.add_group(params2, config2);

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(manager.groups().len(), 2);
        assert_eq!(manager.total_params(), 3);

        // Test group access
        let group1 = manager.get_group(id1).unwrap();
        assert_eq!(group1.learning_rate(0.0), 0.01);

        let group2 = manager.get_group(id2).unwrap();
        assert_eq!(group2.learning_rate(0.0), 0.001);
    }

    // Each constraint variant projects parameters in place.
    #[test]
    fn test_parameter_constraints() {
        use approx::assert_relative_eq;

        // Test value clipping
        let mut params = Array1::from_vec(vec![-2.0, 0.5, 3.0]);
        let clip_constraint = ParameterConstraint::ValueClip { min: 0.0, max: 1.0 };
        clip_constraint.apply(&mut params).unwrap();
        assert_eq!(params.as_slice().unwrap(), &[0.0, 0.5, 1.0]);

        // Test L2 norm constraint
        let mut params = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
        let l2_constraint = ParameterConstraint::L2NormConstraint { maxnorm: 2.0 };
        l2_constraint.apply(&mut params).unwrap();
        let new_norm = params.mapv(|x| x * x).sum().sqrt();
        assert_relative_eq!(new_norm, 2.0, epsilon = 1e-6);

        // Test non-negativity constraint
        let mut params = Array1::from_vec(vec![-1.0, 2.0, -3.0]);
        let non_neg_constraint = ParameterConstraint::NonNegative;
        non_neg_constraint.apply(&mut params).unwrap();
        assert_eq!(params.as_slice().unwrap(), &[0.0, 2.0, 0.0]);

        // Test unit sphere constraint
        let mut params = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
        let unit_sphere_constraint = ParameterConstraint::UnitSphere;
        unit_sphere_constraint.apply(&mut params).unwrap();
        let new_norm = params.mapv(|x| x * x).sum().sqrt();
        assert_relative_eq!(new_norm, 1.0, epsilon = 1e-6);
    }

    // Constraints configured on the group are applied via apply_constraints().
    #[test]
    fn test_parameter_group_with_constraints() {
        let params = vec![Array1::from_vec(vec![-2.0, 3.0])];
        let config = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_value_clip(0.0, 1.0);

        let mut group = ParameterGroup::new(0, params, config);

        // Apply constraints
        group.apply_constraints().unwrap();

        // Check that constraints were applied
        assert_eq!(group.params[0].as_slice().unwrap(), &[0.0, 1.0]);
    }

    // Each with_* constraint helper appends one entry to `constraints`.
    #[test]
    fn test_parameter_config_builder() {
        let config = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_l2_norm_constraint(1.0)
            .with_non_negative()
            .with_custom_param("beta".to_string(), 0.9);

        assert_eq!(config.learning_rate, Some(0.01));
        assert_eq!(config.constraints.len(), 2);
        assert_eq!(config.custom_params.get("beta"), Some(&0.9));
    }

    // Simplex projection: non-negative values normalized to sum to 1.
    #[test]
    fn test_simplex_constraint() {
        use approx::assert_relative_eq;

        // Test simplex constraint with positive values
        let mut params = Array1::from_vec(vec![2.0, 3.0, 5.0]);
        let simplex_constraint = ParameterConstraint::Simplex;
        simplex_constraint.apply(&mut params).unwrap();

        // Check that values sum to 1 and are non-negative
        let sum: f64 = params.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        assert!(params.iter().all(|&x| x >= 0.0));

        // Values should be proportional to original
        assert_relative_eq!(params[0], 0.2, epsilon = 1e-6); // 2/10
        assert_relative_eq!(params[1], 0.3, epsilon = 1e-6); // 3/10
        assert_relative_eq!(params[2], 0.5, epsilon = 1e-6); // 5/10
    }

    // Negatives are clamped to zero before normalization.
    #[test]
    fn test_simplex_constraint_with_negatives() {
        use approx::assert_relative_eq;

        // Test simplex constraint with negative values
        let mut params = Array1::from_vec(vec![-1.0, 2.0, 3.0]);
        let simplex_constraint = ParameterConstraint::Simplex;
        simplex_constraint.apply(&mut params).unwrap();

        // Check that values sum to 1 and are non-negative
        let sum: f64 = params.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        assert!(params.iter().all(|&x| x >= 0.0));

        // Negative value should become 0, others normalized
        assert_relative_eq!(params[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(params[1], 0.4, epsilon = 1e-6); // 2/5
        assert_relative_eq!(params[2], 0.6, epsilon = 1e-6); // 3/5
    }

    // Degenerate all-zero input falls back to the uniform distribution.
    #[test]
    fn test_simplex_constraint_all_zeros() {
        use approx::assert_relative_eq;

        // Test simplex constraint with all zeros
        let mut params = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let simplex_constraint = ParameterConstraint::Simplex;
        simplex_constraint.apply(&mut params).unwrap();

        // Should result in uniform distribution
        let sum: f64 = params.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        for &val in params.iter() {
            assert_relative_eq!(val, 1.0 / 3.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_spectral_norm_constraint() {
        use approx::assert_relative_eq;

        // Test spectral norm constraint (approximated with Frobenius norm)
        let mut params = Array1::from_vec(vec![3.0, 4.0]); // Frobenius norm = 5
        let spectral_constraint = ParameterConstraint::SpectralNorm { maxnorm: 2.0 };
        spectral_constraint.apply(&mut params).unwrap();

        let new_norm = params.mapv(|x| x * x).sum().sqrt();
        assert_relative_eq!(new_norm, 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_nuclear_norm_constraint() {
        use approx::assert_relative_eq;

        // Test nuclear norm constraint (approximated with L1 norm)
        let mut params = Array1::from_vec(vec![3.0, -4.0, 2.0]); // L1 norm = 9
        let nuclear_constraint = ParameterConstraint::NuclearNorm { maxnorm: 3.0 };
        nuclear_constraint.apply(&mut params).unwrap();

        let new_l1_norm = params.mapv(|x| x.abs()).sum();
        assert_relative_eq!(new_l1_norm, 3.0, epsilon = 1e-6);
    }

    // Matrix-only constraints reject 1-D input with descriptive errors.
    #[test]
    fn test_orthogonal_constraint_error() {
        // Test that orthogonal constraint returns appropriate error
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let orthogonal_constraint = ParameterConstraint::Orthogonal { tolerance: 1e-6 };
        let result = orthogonal_constraint.apply(&mut params);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("2D arrays"));
    }

    #[test]
    fn test_positive_definite_constraint_error() {
        // Test that positive definite constraint returns appropriate error
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let pd_constraint = ParameterConstraint::PositiveDefinite {
            mineigenvalue: 0.01,
        };
        let result = pd_constraint.apply(&mut params);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("eigenvalue"));
    }

    // Builder preserves the order in which constraints were added.
    #[test]
    fn test_enhanced_config_builder() {
        let config = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_simplex()
            .with_spectral_norm(2.0)
            .with_nuclear_norm(1.5)
            .with_custom_constraint("my_constraint".to_string());

        assert_eq!(config.learning_rate, Some(0.01));
        assert_eq!(config.constraints.len(), 4);

        // Check that the right constraint types were added
        match &config.constraints[0] {
            ParameterConstraint::Simplex => (),
            _ => panic!("Expected Simplex constraint"),
        }

        match &config.constraints[1] {
            ParameterConstraint::SpectralNorm { maxnorm } => {
                assert_eq!(*maxnorm, 2.0);
            }
            _ => panic!("Expected SpectralNorm constraint"),
        }
    }

    // Constraints apply in configuration order: non-negative, then simplex.
    #[test]
    fn test_constraint_combination() {
        use approx::assert_relative_eq;

        // Test applying multiple constraints in sequence
        let params = vec![Array1::from_vec(vec![-1.0, 2.0, 3.0])];
        let config = ParameterGroupConfig::new()
            .with_learning_rate(0.01)
            .with_non_negative()
            .with_simplex();

        let mut group = ParameterGroup::new(0, params, config);

        // Apply constraints
        group.apply_constraints().unwrap();

        // Check that both non-negative and simplex constraints were applied
        let result = &group.params[0];
        let sum: f64 = result.sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        assert!(result.iter().all(|&x| x >= 0.0));

        // Should be [0, 0.4, 0.6] after non-negative then simplex
        assert_relative_eq!(result[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(result[1], 0.4, epsilon = 1e-6);
        assert_relative_eq!(result[2], 0.6, epsilon = 1e-6);
    }
}