Skip to main content

optirs_learned/transformer_based_optimizer/
state.rs

1// State management for transformer-based optimizer
2
3use super::config::TransformerBasedOptimizerConfig;
4use super::meta_learning::MetaState;
5use crate::error::Result;
6use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
7use scirs2_core::numeric::Float;
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, HashMap, VecDeque};
10use std::fmt::Debug;
11use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
12
13/// Transformer optimizer state
14pub struct TransformerOptimizerState<T: Float + Debug + Send + Sync + 'static> {
15    /// Current model parameters
16    pub current_parameters: Array1<T>,
17
18    /// Parameter history
19    parameter_history: ParameterHistory<T>,
20
21    /// Optimization state
22    optimization_state: OptimizationState<T>,
23
24    /// Learning state
25    learning_state: LearningState<T>,
26
27    /// Memory state
28    memory_state: MemoryState<T>,
29
30    /// Checkpoint manager
31    checkpoint_manager: CheckpointManager<T>,
32
33    /// State configuration
34    config: StateConfig,
35
36    /// State statistics
37    statistics: StateStatistics<T>,
38
39    /// State version for tracking changes
40    version: usize,
41
42    /// Creation timestamp
43    created_at: std::time::Instant,
44
45    /// Last update timestamp
46    last_updated: std::time::Instant,
47}
48
49impl<T: Float + Debug + Send + Sync + 'static> TransformerOptimizerState<T> {
50    /// Create new optimizer state
51    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
52        let parameter_count = config.model_dimension * config.num_transformer_layers;
53        let current_parameters = Array1::zeros(parameter_count);
54
55        let parameter_history = ParameterHistory::new(1000, parameter_count)?;
56        let optimization_state = OptimizationState::new(config)?;
57        let learning_state = LearningState::new(config)?;
58        let memory_state = MemoryState::new()?;
59        let checkpoint_manager = CheckpointManager::new(config)?;
60        let state_config = StateConfig::from_optimizer_config(config);
61        let statistics = StateStatistics::new();
62
63        let now = std::time::Instant::now();
64
65        Ok(Self {
66            current_parameters,
67            parameter_history,
68            optimization_state,
69            learning_state,
70            memory_state,
71            checkpoint_manager,
72            config: state_config,
73            statistics,
74            version: 0,
75            created_at: now,
76            last_updated: now,
77        })
78    }
79
80    /// Update state with optimization step
81    pub fn update_with_step(&mut self, update: &Array1<T>, loss: Option<T>) -> Result<()> {
82        // Apply parameter update
83        self.current_parameters = &self.current_parameters + update;
84
85        // Record parameter history
86        self.parameter_history
87            .record_parameters(&self.current_parameters)?;
88
89        // Update optimization state
90        self.optimization_state.update_with_step(update, loss)?;
91
92        // Update learning state
93        if let Some(loss_val) = loss {
94            self.learning_state.update_with_loss(loss_val)?;
95        }
96
97        // Update statistics
98        self.statistics.record_update(update, loss);
99
100        // Increment version and update timestamp
101        self.version += 1;
102        self.last_updated = std::time::Instant::now();
103
104        Ok(())
105    }
106
107    /// Create state snapshot
108    pub fn create_snapshot(&self) -> Result<OptimizerStateSnapshot<T>> {
109        Ok(OptimizerStateSnapshot {
110            parameters: self.current_parameters.clone(),
111            optimization_state: self.optimization_state.clone(),
112            learning_state: self.learning_state.clone(),
113            memory_state: self.memory_state.clone(),
114            version: self.version,
115            timestamp: self.last_updated,
116            metadata: SnapshotMetadata {
117                parameter_count: self.current_parameters.len(),
118                total_updates: self.statistics.total_updates,
119                session_duration: self.last_updated.duration_since(self.created_at),
120            },
121        })
122    }
123
124    /// Restore from snapshot
125    pub fn restore_from_snapshot(&mut self, snapshot: &OptimizerStateSnapshot<T>) -> Result<()> {
126        self.current_parameters = snapshot.parameters.clone();
127        self.optimization_state = snapshot.optimization_state.clone();
128        self.learning_state = snapshot.learning_state.clone();
129        self.memory_state = snapshot.memory_state.clone();
130        self.version = snapshot.version;
131        self.last_updated = snapshot.timestamp;
132
133        Ok(())
134    }
135
136    /// Save checkpoint
137    pub fn save_checkpoint(&mut self, name: String) -> Result<String> {
138        let snapshot = self.create_snapshot()?;
139        let checkpoint_id = self.checkpoint_manager.save_checkpoint(name, snapshot)?;
140        Ok(checkpoint_id)
141    }
142
143    /// Load checkpoint
144    pub fn load_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
145        let snapshot = self.checkpoint_manager.load_checkpoint(checkpoint_id)?;
146        self.restore_from_snapshot(&snapshot)?;
147        Ok(())
148    }
149
    /// Returns a copy of the running statistics kept by the parameter history
    /// (snapshot count and average/max/min norm).
    pub fn get_parameter_stats(&self) -> ParameterStatistics<T> {
        self.parameter_history.get_statistics()
    }
154
    /// Returns the progress report built by the optimization sub-state: step
    /// count, learning rate, last update magnitude, convergence rate and
    /// stability score.
    pub fn get_optimization_progress(&self) -> OptimizationProgress<T> {
        self.optimization_state.get_progress()
    }
159
    /// Returns the learning sub-state's summary: number of recorded losses,
    /// average and best loss, convergence rate and stability score.
    pub fn get_learning_stats(&self) -> LearningStatistics<T> {
        self.learning_state.get_statistics()
    }
164
165    /// Reset state
166    pub fn reset(&mut self) -> Result<()> {
167        self.current_parameters.fill(T::zero());
168        self.parameter_history.clear();
169        self.optimization_state.reset()?;
170        self.learning_state.reset()?;
171        self.memory_state.reset()?;
172        self.statistics.reset();
173        self.version = 0;
174        self.last_updated = std::time::Instant::now();
175        Ok(())
176    }
177
178    /// Validate state consistency
179    pub fn validate_state(&self) -> Result<StateValidationReport> {
180        let mut issues = Vec::new();
181
182        // Check parameter validity
183        if self.current_parameters.iter().any(|&x| !x.is_finite()) {
184            issues.push("Invalid parameters detected (NaN or infinity)".to_string());
185        }
186
187        // Check version consistency
188        if self.version == 0 && self.statistics.total_updates > 0 {
189            issues.push("Version mismatch with update count".to_string());
190        }
191
192        // Check timestamp consistency
193        if self.last_updated < self.created_at {
194            issues.push("Invalid timestamp ordering".to_string());
195        }
196
197        // Validate optimization state
198        let opt_validation = self.optimization_state.validate()?;
199        issues.extend(opt_validation.issues);
200
201        // Validate learning state
202        let learning_validation = self.learning_state.validate()?;
203        issues.extend(learning_validation.issues);
204
205        Ok(StateValidationReport {
206            is_valid: issues.is_empty(),
207            issues,
208            validation_timestamp: std::time::Instant::now(),
209        })
210    }
211
212    /// Get state summary
213    pub fn get_state_summary(&self) -> StateSummary<T> {
214        StateSummary {
215            version: self.version,
216            parameter_count: self.current_parameters.len(),
217            parameter_norm: self.compute_parameter_norm(),
218            total_updates: self.statistics.total_updates,
219            session_duration: self.last_updated.duration_since(self.created_at),
220            last_update_magnitude: self.statistics.last_update_magnitude,
221            average_loss: self.learning_state.get_average_loss(),
222            convergence_rate: self.learning_state.get_convergence_rate(),
223            memory_usage: self.memory_state.get_total_usage(),
224            checkpoint_count: self.checkpoint_manager.get_checkpoint_count(),
225        }
226    }
227
228    /// Compute parameter norm
229    fn compute_parameter_norm(&self) -> T {
230        self.current_parameters
231            .iter()
232            .map(|&x| x * x)
233            .fold(T::zero(), |acc, x| acc + x)
234            .sqrt()
235    }
236
237    /// Get state metadata
238    pub fn get_metadata(&self) -> StateMetadata {
239        StateMetadata {
240            version: self.version,
241            created_at: SystemTime::now(), // Convert from Instant for serialization
242            last_updated: SystemTime::now(), // Convert from Instant for serialization
243            total_updates: self.statistics.total_updates,
244            configuration: self.config.clone(),
245        }
246    }
247
248    /// Export state to serializable format
249    pub fn export_state(&self) -> Result<SerializableState<T>> {
250        Ok(SerializableState {
251            parameters: self.current_parameters.to_vec(),
252            parameter_shape: self.current_parameters.shape().to_vec(),
253            optimization_state: self.optimization_state.to_serializable()?,
254            learning_state: self.learning_state.to_serializable()?,
255            metadata: self.get_metadata(),
256            statistics: self.statistics.clone(),
257        })
258    }
259
260    /// Import state from serializable format
261    pub fn import_state(&mut self, state: SerializableState<T>) -> Result<()> {
262        // Reconstruct parameters
263        if state.parameter_shape.len() != 1 {
264            return Err(crate::error::OptimError::Other(
265                "Invalid parameter shape for 1D array".to_string(),
266            ));
267        }
268
269        self.current_parameters = Array1::from_vec(state.parameters);
270
271        // Restore other state components
272        self.optimization_state
273            .from_serializable(state.optimization_state)?;
274        self.learning_state
275            .from_serializable(state.learning_state)?;
276        self.statistics = state.statistics;
277        self.version = state.metadata.version;
278        self.last_updated = Instant::now(); // Convert from SystemTime for internal use
279
280        Ok(())
281    }
282}
283
/// Bounded, time-ordered record of parameter vectors with running
/// norm statistics.
pub struct ParameterHistory<T: Float + Debug + Send + Sync + 'static> {
    /// Recorded snapshots, oldest at the front and newest at the back.
    snapshots: VecDeque<ParameterSnapshot<T>>,

    /// Maximum number of snapshots retained; the oldest is evicted beyond it.
    max_size: usize,

    /// Parameter dimension supplied at construction. It is stored only; none
    /// of this type's methods read it.
    parameter_dimension: usize,

    /// Running statistics over every snapshot recorded since the last clear,
    /// including snapshots that have since been evicted.
    statistics: ParameterStatistics<T>,
}
298
299impl<T: Float + Debug + Send + Sync + 'static> ParameterHistory<T> {
300    pub fn new(max_size: usize, parameter_dimension: usize) -> Result<Self> {
301        Ok(Self {
302            snapshots: VecDeque::new(),
303            max_size,
304            parameter_dimension,
305            statistics: ParameterStatistics::new(),
306        })
307    }
308
309    pub fn record_parameters(&mut self, parameters: &Array1<T>) -> Result<()> {
310        let snapshot = ParameterSnapshot {
311            parameters: parameters.clone(),
312            timestamp: std::time::Instant::now(),
313            norm: parameters
314                .iter()
315                .map(|&x| x * x)
316                .fold(T::zero(), |acc, x| acc + x)
317                .sqrt(),
318        };
319
320        self.snapshots.push_back(snapshot.clone());
321        if self.snapshots.len() > self.max_size {
322            self.snapshots.pop_front();
323        }
324
325        self.statistics.update_with_snapshot(&snapshot);
326        Ok(())
327    }
328
329    pub fn get_recent_parameters(&self, count: usize) -> Vec<Array1<T>> {
330        self.snapshots
331            .iter()
332            .rev()
333            .take(count)
334            .map(|snapshot| snapshot.parameters.clone())
335            .collect()
336    }
337
338    pub fn get_statistics(&self) -> ParameterStatistics<T> {
339        self.statistics.clone()
340    }
341
342    pub fn clear(&mut self) {
343        self.snapshots.clear();
344        self.statistics = ParameterStatistics::new();
345    }
346}
347
/// Step-level bookkeeping for the optimizer: learning rate, optional
/// momentum and adaptive-moment buffers, gradient accumulation, step counter
/// and convergence tracking.
#[derive(Debug, Clone)]
pub struct OptimizationState<T: Float + Debug + Send + Sync + 'static> {
    /// Current learning rate (initialised from the optimizer configuration).
    pub learning_rate: T,

    /// Momentum buffer; `None` after construction.
    pub momentum: Option<Array1<T>>,

    /// Adaptive learning rate state (e.g., Adam); `Some` after construction.
    pub adaptive_state: Option<AdaptiveState<T>>,

    /// Gradient accumulation buffer and counter.
    pub gradient_accumulator: GradientAccumulator<T>,

    /// Number of steps recorded through `update_with_step`.
    pub step_count: usize,

    /// L2 norm of the most recent update vector.
    pub last_update_magnitude: T,

    /// Tracker of recent losses used for convergence and stability figures.
    pub convergence_tracker: ConvergenceTracker<T>,
}
372
373impl<T: Float + Debug + Send + Sync + 'static> OptimizationState<T> {
374    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
375        let parameter_count = config.model_dimension * config.num_transformer_layers;
376
377        Ok(Self {
378            learning_rate: config.learning_rate,
379            momentum: None,
380            adaptive_state: Some(AdaptiveState::new(parameter_count)?),
381            gradient_accumulator: GradientAccumulator::new(parameter_count)?,
382            step_count: 0,
383            last_update_magnitude: T::zero(),
384            convergence_tracker: ConvergenceTracker::new(),
385        })
386    }
387
388    pub fn update_with_step(&mut self, update: &Array1<T>, loss: Option<T>) -> Result<()> {
389        self.step_count += 1;
390        self.last_update_magnitude = update
391            .iter()
392            .map(|&x| x * x)
393            .fold(T::zero(), |acc, x| acc + x)
394            .sqrt();
395
396        if let Some(loss_val) = loss {
397            self.convergence_tracker.record_loss(loss_val);
398        }
399
400        // Update adaptive state
401        if let Some(ref mut adaptive) = self.adaptive_state {
402            adaptive.update_with_step(update)?;
403        }
404
405        Ok(())
406    }
407
408    pub fn get_progress(&self) -> OptimizationProgress<T> {
409        OptimizationProgress {
410            step_count: self.step_count,
411            current_learning_rate: self.learning_rate,
412            last_update_magnitude: self.last_update_magnitude,
413            convergence_rate: self.convergence_tracker.get_convergence_rate(),
414            stability_score: self.convergence_tracker.get_stability_score(),
415        }
416    }
417
418    pub fn reset(&mut self) -> Result<()> {
419        self.step_count = 0;
420        self.last_update_magnitude = T::zero();
421        self.convergence_tracker.reset();
422
423        if let Some(ref mut adaptive) = self.adaptive_state {
424            adaptive.reset()?;
425        }
426
427        self.gradient_accumulator.reset()?;
428        Ok(())
429    }
430
431    pub fn validate(&self) -> Result<ValidationResult> {
432        let mut issues = Vec::new();
433
434        if self.learning_rate <= T::zero() {
435            issues.push("Invalid learning rate".to_string());
436        }
437
438        if !self.last_update_magnitude.is_finite() {
439            issues.push("Invalid update magnitude".to_string());
440        }
441
442        Ok(ValidationResult { issues })
443    }
444
445    pub fn to_serializable(&self) -> Result<SerializableOptimizationState<T>> {
446        Ok(SerializableOptimizationState {
447            learning_rate: self.learning_rate,
448            step_count: self.step_count,
449            last_update_magnitude: self.last_update_magnitude,
450            momentum: self.momentum.as_ref().map(|m| m.to_vec()),
451            convergence_metrics: self.convergence_tracker.to_serializable(),
452        })
453    }
454
455    pub fn from_serializable(&mut self, state: SerializableOptimizationState<T>) -> Result<()> {
456        self.learning_rate = state.learning_rate;
457        self.step_count = state.step_count;
458        self.last_update_magnitude = state.last_update_magnitude;
459
460        if let Some(momentum_vec) = state.momentum {
461            self.momentum = Some(Array1::from_vec(momentum_vec));
462        }
463
464        self.convergence_tracker
465            .from_serializable(state.convergence_metrics)?;
466        Ok(())
467    }
468}
469
/// Learning-side bookkeeping: loss history, meta-learning state, task
/// adaptation records, the learning-rate schedule and performance metrics.
#[derive(Debug, Clone)]
pub struct LearningState<T: Float + Debug + Send + Sync + 'static> {
    /// Recorded losses, oldest first; `update_with_loss` evicts the oldest
    /// entry once more than 1000 are held.
    loss_history: VecDeque<T>,

    /// Meta-learning state; `Some` after construction.
    meta_state: Option<MetaState<T>>,

    /// Task adaptation history (cleared by `reset`).
    adaptation_history: VecDeque<TaskAdaptationRecord<T>>,

    /// Learning rate schedule built from the configured learning rate and
    /// warm-up steps.
    learning_schedule: LearningSchedule<T>,

    /// Performance metrics fed with every recorded loss.
    performance_metrics: LearningPerformanceMetrics<T>,
}
488
489impl<T: Float + Debug + Send + Sync + 'static> LearningState<T> {
490    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
491        let meta_state = Some(MetaState::new(config.model_dimension)?);
492        let learning_schedule = LearningSchedule::new(config.learning_rate, config.warmup_steps);
493
494        Ok(Self {
495            loss_history: VecDeque::new(),
496            meta_state,
497            adaptation_history: VecDeque::new(),
498            learning_schedule,
499            performance_metrics: LearningPerformanceMetrics::new(),
500        })
501    }
502
503    pub fn update_with_loss(&mut self, loss: T) -> Result<()> {
504        self.loss_history.push_back(loss);
505        if self.loss_history.len() > 1000 {
506            self.loss_history.pop_front();
507        }
508
509        self.performance_metrics.record_loss(loss);
510
511        if let Some(ref mut meta) = self.meta_state {
512            meta.update_loss_history(loss);
513        }
514
515        Ok(())
516    }
517
518    pub fn get_statistics(&self) -> LearningStatistics<T> {
519        LearningStatistics {
520            total_episodes: self.loss_history.len(),
521            average_loss: self.get_average_loss(),
522            best_loss: self.get_best_loss(),
523            convergence_rate: self.get_convergence_rate(),
524            learning_stability: self.performance_metrics.get_stability_score(),
525        }
526    }
527
528    pub fn get_average_loss(&self) -> T {
529        if self.loss_history.is_empty() {
530            T::zero()
531        } else {
532            self.loss_history
533                .iter()
534                .fold(T::zero(), |acc, &loss| acc + loss)
535                / T::from(self.loss_history.len()).expect("unwrap failed")
536        }
537    }
538
539    pub fn get_best_loss(&self) -> T {
540        self.loss_history
541            .iter()
542            .fold(T::infinity(), |min, &loss| min.min(loss))
543    }
544
545    pub fn get_convergence_rate(&self) -> T {
546        if self.loss_history.len() < 2 {
547            return T::zero();
548        }
549
550        let recent_losses: Vec<_> = self.loss_history.iter().rev().take(10).cloned().collect();
551        if recent_losses.len() < 2 {
552            return T::zero();
553        }
554
555        let initial = recent_losses.last().expect("unwrap failed");
556        let final_loss = recent_losses.first().expect("unwrap failed");
557
558        if *initial > T::zero() {
559            (*initial - *final_loss) / *initial
560        } else {
561            T::zero()
562        }
563    }
564
565    pub fn reset(&mut self) -> Result<()> {
566        self.loss_history.clear();
567        self.adaptation_history.clear();
568        self.performance_metrics.reset();
569
570        if let Some(ref mut meta) = self.meta_state {
571            *meta = MetaState::new(meta.get_parameters().len())?;
572        }
573
574        Ok(())
575    }
576
577    pub fn validate(&self) -> Result<ValidationResult> {
578        let mut issues = Vec::new();
579
580        if self.loss_history.iter().any(|&loss| !loss.is_finite()) {
581            issues.push("Invalid loss values detected".to_string());
582        }
583
584        Ok(ValidationResult { issues })
585    }
586
587    pub fn to_serializable(&self) -> Result<SerializableLearningState<T>> {
588        Ok(SerializableLearningState {
589            loss_history: self.loss_history.iter().cloned().collect(),
590            average_loss: self.get_average_loss(),
591            best_loss: self.get_best_loss(),
592            convergence_rate: self.get_convergence_rate(),
593        })
594    }
595
596    pub fn from_serializable(&mut self, state: SerializableLearningState<T>) -> Result<()> {
597        self.loss_history = VecDeque::from(state.loss_history);
598        Ok(())
599    }
600}
601
/// Memory-side bookkeeping: named attention caches plus usage and cache
/// statistics trackers.
#[derive(Debug, Clone)]
pub struct MemoryState<T: Float + Debug + Send + Sync + 'static> {
    /// Attention caches keyed by name; empty after construction.
    attention_caches: HashMap<String, AttentionCache<T>>,

    /// Memory usage tracking; its `total_usage` is what `get_total_usage`
    /// reports.
    memory_usage: MemoryUsageTracker,

    /// Cache statistics
    cache_statistics: CacheStatistics,
}
614
615impl<T: Float + Debug + Send + Sync + 'static> MemoryState<T> {
616    pub fn new() -> Result<Self> {
617        Ok(Self {
618            attention_caches: HashMap::new(),
619            memory_usage: MemoryUsageTracker::new(),
620            cache_statistics: CacheStatistics::new(),
621        })
622    }
623
624    pub fn get_total_usage(&self) -> usize {
625        self.memory_usage.total_usage
626    }
627
628    pub fn reset(&mut self) -> Result<()> {
629        self.attention_caches.clear();
630        self.memory_usage.reset();
631        self.cache_statistics.reset();
632        Ok(())
633    }
634}
635
/// In-memory store of optimizer state snapshots with per-checkpoint
/// metadata and a bound on how many are kept.
pub struct CheckpointManager<T: Float + Debug + Send + Sync + 'static> {
    /// Stored snapshots keyed by checkpoint id.
    checkpoints: HashMap<String, OptimizerStateSnapshot<T>>,

    /// Metadata keyed by the same checkpoint ids as `checkpoints`.
    metadata: HashMap<String, CheckpointMetadata>,

    /// Maximum checkpoints to keep; the oldest are evicted beyond this.
    max_checkpoints: usize,

    /// Auto-save configuration (set to its default at construction).
    auto_save_config: AutoSaveConfig,
}
650
651impl<T: Float + Debug + Send + Sync + 'static> CheckpointManager<T> {
    /// Creates an empty manager that keeps at most 10 checkpoints and uses
    /// the default auto-save configuration.
    ///
    /// NOTE(review): `config` is accepted but not read here.
    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
        Ok(Self {
            checkpoints: HashMap::new(),
            metadata: HashMap::new(),
            max_checkpoints: 10,
            auto_save_config: AutoSaveConfig::default(),
        })
    }
660
661    pub fn save_checkpoint(
662        &mut self,
663        name: String,
664        snapshot: OptimizerStateSnapshot<T>,
665    ) -> Result<String> {
666        let checkpoint_id = format!("{}_{}", name, snapshot.version);
667
668        // Add metadata
669        let metadata = CheckpointMetadata {
670            id: checkpoint_id.clone(),
671            name: name.clone(),
672            created_at: std::time::Instant::now(),
673            size_estimate: snapshot.parameters.len() * std::mem::size_of::<T>(),
674            description: format!("Checkpoint at version {}", snapshot.version),
675        };
676
677        self.checkpoints.insert(checkpoint_id.clone(), snapshot);
678        self.metadata.insert(checkpoint_id.clone(), metadata);
679
680        // Cleanup old checkpoints if necessary
681        if self.checkpoints.len() > self.max_checkpoints {
682            self.cleanup_old_checkpoints()?;
683        }
684
685        Ok(checkpoint_id)
686    }
687
688    pub fn load_checkpoint(&self, checkpoint_id: &str) -> Result<OptimizerStateSnapshot<T>> {
689        self.checkpoints.get(checkpoint_id).cloned().ok_or_else(|| {
690            crate::error::OptimError::Other(format!("Checkpoint {} not found", checkpoint_id))
691        })
692    }
693
    /// Returns copies of the metadata of every stored checkpoint. The order
    /// follows the hash map's iteration order and is therefore unspecified.
    pub fn list_checkpoints(&self) -> Vec<CheckpointMetadata> {
        self.metadata.values().cloned().collect()
    }
697
698    pub fn delete_checkpoint(&mut self, checkpoint_id: &str) -> Result<bool> {
699        let removed_checkpoint = self.checkpoints.remove(checkpoint_id).is_some();
700        let removed_metadata = self.metadata.remove(checkpoint_id).is_some();
701        Ok(removed_checkpoint && removed_metadata)
702    }
703
    /// Number of snapshots currently stored.
    pub fn get_checkpoint_count(&self) -> usize {
        self.checkpoints.len()
    }
707
708    fn cleanup_old_checkpoints(&mut self) -> Result<()> {
709        // Remove oldest checkpoints if over limit
710        while self.checkpoints.len() > self.max_checkpoints {
711            if let Some((oldest_id, _)) = self
712                .metadata
713                .iter()
714                .min_by_key(|(_, metadata)| metadata.created_at)
715                .map(|(id, metadata)| (id.clone(), metadata.clone()))
716            {
717                self.checkpoints.remove(&oldest_id);
718                self.metadata.remove(&oldest_id);
719            } else {
720                break;
721            }
722        }
723        Ok(())
724    }
725}
726
727/// Supporting data structures and types
728
729#[derive(Debug, Clone)]
730pub struct ParameterSnapshot<T: Float + Debug + Send + Sync + 'static> {
731    pub parameters: Array1<T>,
732    pub timestamp: std::time::Instant,
733    pub norm: T,
734}
735
/// Running statistics over the norms of recorded parameter snapshots.
#[derive(Debug, Clone)]
pub struct ParameterStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Number of snapshots folded into these statistics.
    pub total_snapshots: usize,
    /// Running mean of the snapshot norms.
    pub average_norm: T,
    /// Largest norm seen (starts at zero).
    pub max_norm: T,
    /// Smallest norm seen (starts at positive infinity).
    pub min_norm: T,
    /// Norm trend; initialised to zero and not updated by
    /// `update_with_snapshot`.
    pub norm_trend: T,
}
744
impl<T: Float + Debug + Send + Sync + 'static> Default for ParameterStatistics<T> {
    /// Same as [`ParameterStatistics::new`].
    fn default() -> Self {
        Self::new()
    }
}
750
751impl<T: Float + Debug + Send + Sync + 'static> ParameterStatistics<T> {
752    pub fn new() -> Self {
753        Self {
754            total_snapshots: 0,
755            average_norm: T::zero(),
756            max_norm: T::zero(),
757            min_norm: T::infinity(),
758            norm_trend: T::zero(),
759        }
760    }
761
762    pub fn update_with_snapshot(&mut self, snapshot: &ParameterSnapshot<T>) {
763        self.total_snapshots += 1;
764        self.average_norm = (self.average_norm
765            * scirs2_core::numeric::NumCast::from(self.total_snapshots - 1)
766                .unwrap_or_else(|| T::zero())
767            + snapshot.norm)
768            / scirs2_core::numeric::NumCast::from(self.total_snapshots)
769                .unwrap_or_else(|| T::zero());
770        self.max_norm = self.max_norm.max(snapshot.norm);
771        self.min_norm = self.min_norm.min(snapshot.norm);
772    }
773}
774
/// Buffers and hyper-parameters for an adaptive (Adam-style) update rule.
#[derive(Debug, Clone)]
pub struct AdaptiveState<T: Float + Debug + Send + Sync + 'static> {
    /// First moment estimates (zero-initialised).
    pub m: Array1<T>,
    /// Second moment estimates (zero-initialised).
    pub v: Array1<T>,
    /// Step count for bias correction
    pub step_count: usize,
    /// First-moment decay; set to 0.9 at construction.
    pub beta1: T,
    /// Second-moment decay; set to 0.999 at construction.
    pub beta2: T,
    /// Epsilon for numerical stability; set to 1e-8 at construction.
    pub epsilon: T,
}
789
790impl<T: Float + Debug + Send + Sync + 'static> AdaptiveState<T> {
791    pub fn new(parameter_count: usize) -> Result<Self> {
792        Ok(Self {
793            m: Array1::zeros(parameter_count),
794            v: Array1::zeros(parameter_count),
795            step_count: 0,
796            beta1: scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()),
797            beta2: scirs2_core::numeric::NumCast::from(0.999).unwrap_or_else(|| T::zero()),
798            epsilon: scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()),
799        })
800    }
801
802    pub fn update_with_step(&mut self, _update: &Array1<T>) -> Result<()> {
803        self.step_count += 1;
804        // Adam-style update logic would go here
805        Ok(())
806    }
807
808    pub fn reset(&mut self) -> Result<()> {
809        self.m.fill(T::zero());
810        self.v.fill(T::zero());
811        self.step_count = 0;
812        Ok(())
813    }
814}
815
/// Buffer for summing gradients across several steps.
#[derive(Debug, Clone)]
pub struct GradientAccumulator<T: Float + Debug + Send + Sync + 'static> {
    /// Accumulated gradients (zero-initialised).
    pub accumulated_gradients: Array1<T>,
    /// Accumulation count; starts at zero and is zeroed by `reset`.
    pub accumulation_count: usize,
}
823
824impl<T: Float + Debug + Send + Sync + 'static> GradientAccumulator<T> {
825    pub fn new(parameter_count: usize) -> Result<Self> {
826        Ok(Self {
827            accumulated_gradients: Array1::zeros(parameter_count),
828            accumulation_count: 0,
829        })
830    }
831
832    pub fn reset(&mut self) -> Result<()> {
833        self.accumulated_gradients.fill(T::zero());
834        self.accumulation_count = 0;
835        Ok(())
836    }
837}
838
/// Sliding window of recent losses used to derive convergence-rate and
/// stability figures.
#[derive(Debug, Clone)]
pub struct ConvergenceTracker<T: Float + Debug + Send + Sync + 'static> {
    /// Recent loss values, oldest first.
    recent_losses: VecDeque<T>,
    /// Convergence threshold; set to 1e-6 at construction and not read by
    /// the methods of this type.
    convergence_threshold: T,
    /// Maximum number of losses kept in the window (10 at construction).
    stability_window: usize,
}
848
impl<T: Float + Debug + Send + Sync + 'static> Default for ConvergenceTracker<T> {
    /// Same as [`ConvergenceTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
854
855impl<T: Float + Debug + Send + Sync + 'static> ConvergenceTracker<T> {
856    pub fn new() -> Self {
857        Self {
858            recent_losses: VecDeque::new(),
859            convergence_threshold: scirs2_core::numeric::NumCast::from(1e-6)
860                .unwrap_or_else(|| T::zero()),
861            stability_window: 10,
862        }
863    }
864
865    pub fn record_loss(&mut self, loss: T) {
866        self.recent_losses.push_back(loss);
867        if self.recent_losses.len() > self.stability_window {
868            self.recent_losses.pop_front();
869        }
870    }
871
872    pub fn get_convergence_rate(&self) -> T {
873        if self.recent_losses.len() < 2 {
874            return T::zero();
875        }
876
877        let first = self.recent_losses[0];
878        let last = *self.recent_losses.back().expect("unwrap failed");
879
880        if first > T::zero() {
881            (first - last) / first
882        } else {
883            T::zero()
884        }
885    }
886
887    pub fn get_stability_score(&self) -> T {
888        if self.recent_losses.len() < 2 {
889            return T::zero();
890        }
891
892        let mean = self.recent_losses.iter().fold(T::zero(), |acc, &x| acc + x)
893            / T::from(self.recent_losses.len()).expect("unwrap failed");
894        let variance = self
895            .recent_losses
896            .iter()
897            .map(|&x| (x - mean) * (x - mean))
898            .fold(T::zero(), |acc, x| acc + x)
899            / T::from(self.recent_losses.len()).expect("unwrap failed");
900
901        T::one() / (T::one() + variance.sqrt())
902    }
903
904    pub fn reset(&mut self) {
905        self.recent_losses.clear();
906    }
907
908    pub fn to_serializable(&self) -> SerializableConvergenceState<T> {
909        SerializableConvergenceState {
910            recent_losses: self.recent_losses.iter().cloned().collect(),
911            convergence_rate: self.get_convergence_rate(),
912            stability_score: self.get_stability_score(),
913        }
914    }
915
916    pub fn from_serializable(&mut self, state: SerializableConvergenceState<T>) -> Result<()> {
917        self.recent_losses = VecDeque::from(state.recent_losses);
918        Ok(())
919    }
920}
921
922/// State snapshots and serialization
923#[derive(Debug, Clone)]
924pub struct OptimizerStateSnapshot<T: Float + Debug + Send + Sync + 'static> {
925    pub parameters: Array1<T>,
926    pub optimization_state: OptimizationState<T>,
927    pub learning_state: LearningState<T>,
928    pub memory_state: MemoryState<T>,
929    pub version: usize,
930    pub timestamp: std::time::Instant,
931    pub metadata: SnapshotMetadata,
932}
933
/// Summary figures attached to an `OptimizerStateSnapshot`.
#[derive(Debug, Clone)]
pub struct SnapshotMetadata {
    /// Number of parameters.
    pub parameter_count: usize,
    /// Number of updates.
    pub total_updates: usize,
    /// Length of the session.
    pub session_duration: Duration,
}
940
941/// Configuration and metadata structures
942#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
943pub struct StateConfig {
944    pub max_history_size: usize,
945    pub checkpoint_frequency: usize,
946    pub auto_save_enabled: bool,
947    pub validation_enabled: bool,
948}
949
950impl StateConfig {
951    pub fn from_optimizer_config<T: Float + Debug + Send + Sync + 'static>(
952        config: &TransformerBasedOptimizerConfig<T>,
953    ) -> Self {
954        Self {
955            max_history_size: 1000,
956            checkpoint_frequency: 100,
957            auto_save_enabled: true,
958            validation_enabled: true,
959        }
960    }
961}
962
963#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
964pub struct StateMetadata {
965    pub version: usize,
966    pub created_at: std::time::SystemTime,
967    pub last_updated: std::time::SystemTime,
968    pub total_updates: usize,
969    pub configuration: StateConfig,
970}
971
972/// Statistics and tracking structures
973#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
974pub struct StateStatistics<T: Float + Debug + Send + Sync + 'static> {
975    pub total_updates: usize,
976    pub last_update_magnitude: T,
977    pub average_update_magnitude: T,
978    pub parameter_change_rate: T,
979    pub update_frequency: f64,
980}
981
982impl<T: Float + Debug + Send + Sync + 'static> Default for StateStatistics<T> {
983    fn default() -> Self {
984        Self::new()
985    }
986}
987
988impl<T: Float + Debug + Send + Sync + 'static> StateStatistics<T> {
989    pub fn new() -> Self {
990        Self {
991            total_updates: 0,
992            last_update_magnitude: T::zero(),
993            average_update_magnitude: T::zero(),
994            parameter_change_rate: T::zero(),
995            update_frequency: 0.0,
996        }
997    }
998
999    pub fn record_update(&mut self, update: &Array1<T>, _loss: Option<T>) {
1000        self.total_updates += 1;
1001        let magnitude = update
1002            .iter()
1003            .map(|&x| x * x)
1004            .fold(T::zero(), |acc, x| acc + x)
1005            .sqrt();
1006        self.last_update_magnitude = magnitude;
1007        self.average_update_magnitude = (self.average_update_magnitude
1008            * scirs2_core::numeric::NumCast::from(self.total_updates - 1)
1009                .unwrap_or_else(|| T::zero())
1010            + magnitude)
1011            / scirs2_core::numeric::NumCast::from(self.total_updates).unwrap_or_else(|| T::zero());
1012    }
1013
1014    pub fn reset(&mut self) {
1015        self.total_updates = 0;
1016        self.last_update_magnitude = T::zero();
1017        self.average_update_magnitude = T::zero();
1018        self.parameter_change_rate = T::zero();
1019        self.update_frequency = 0.0;
1020    }
1021}
1022
// Validation and reporting structures

/// Outcome of a state validation pass.
#[derive(Debug, Clone)]
pub struct StateValidationReport {
    /// Overall verdict of the validation.
    pub is_valid: bool,
    /// Descriptions of the issues that were found.
    pub issues: Vec<String>,
    /// Monotonic clock reading associated with the validation.
    pub validation_timestamp: Instant,
}
1030
/// List of issues produced by a validation step.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Descriptions of the issues that were found.
    pub issues: Vec<String>,
}
1035
1036#[derive(Debug, Clone)]
1037pub struct StateSummary<T: Float + Debug + Send + Sync + 'static> {
1038    pub version: usize,
1039    pub parameter_count: usize,
1040    pub parameter_norm: T,
1041    pub total_updates: usize,
1042    pub session_duration: Duration,
1043    pub last_update_magnitude: T,
1044    pub average_loss: T,
1045    pub convergence_rate: T,
1046    pub memory_usage: usize,
1047    pub checkpoint_count: usize,
1048}
1049
1050/// Serializable state structures
1051#[derive(Debug, Clone, Serialize, Deserialize)]
1052pub struct SerializableState<T: Float + Debug + Send + Sync + 'static> {
1053    pub parameters: Vec<T>,
1054    pub parameter_shape: Vec<usize>,
1055    pub optimization_state: SerializableOptimizationState<T>,
1056    pub learning_state: SerializableLearningState<T>,
1057    pub metadata: StateMetadata,
1058    pub statistics: StateStatistics<T>,
1059}
1060
1061#[derive(Debug, Clone, Serialize, Deserialize)]
1062pub struct SerializableOptimizationState<T: Float + Debug + Send + Sync + 'static> {
1063    pub learning_rate: T,
1064    pub step_count: usize,
1065    pub last_update_magnitude: T,
1066    pub momentum: Option<Vec<T>>,
1067    pub convergence_metrics: SerializableConvergenceState<T>,
1068}
1069
1070#[derive(Debug, Clone, Serialize, Deserialize)]
1071pub struct SerializableLearningState<T: Float + Debug + Send + Sync + 'static> {
1072    pub loss_history: Vec<T>,
1073    pub average_loss: T,
1074    pub best_loss: T,
1075    pub convergence_rate: T,
1076}
1077
1078#[derive(Debug, Clone, Serialize, Deserialize)]
1079pub struct SerializableConvergenceState<T: Float + Debug + Send + Sync + 'static> {
1080    pub recent_losses: Vec<T>,
1081    pub convergence_rate: T,
1082    pub stability_score: T,
1083}
1084
1085/// Additional supporting structures
1086#[derive(Debug, Clone)]
1087pub struct TaskAdaptationRecord<T: Float + Debug + Send + Sync + 'static> {
1088    pub task_id: String,
1089    pub adaptation_steps: usize,
1090    pub final_loss: T,
1091    pub adaptation_time: Duration,
1092}
1093
1094#[derive(Debug, Clone)]
1095pub struct LearningSchedule<T: Float + Debug + Send + Sync + 'static> {
1096    pub initial_rate: T,
1097    pub current_rate: T,
1098    pub warmup_steps: usize,
1099    pub decay_factor: T,
1100}
1101
1102impl<T: Float + Debug + Send + Sync + 'static> LearningSchedule<T> {
1103    pub fn new(initial_rate: T, warmup_steps: usize) -> Self {
1104        Self {
1105            initial_rate,
1106            current_rate: initial_rate,
1107            warmup_steps,
1108            decay_factor: scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()),
1109        }
1110    }
1111}
1112
1113#[derive(Debug, Clone)]
1114pub struct LearningPerformanceMetrics<T: Float + Debug + Send + Sync + 'static> {
1115    pub loss_trend: T,
1116    pub convergence_stability: T,
1117    pub adaptation_efficiency: T,
1118}
1119
1120impl<T: Float + Debug + Send + Sync + 'static> Default for LearningPerformanceMetrics<T> {
1121    fn default() -> Self {
1122        Self::new()
1123    }
1124}
1125
1126impl<T: Float + Debug + Send + Sync + 'static> LearningPerformanceMetrics<T> {
1127    pub fn new() -> Self {
1128        Self {
1129            loss_trend: T::zero(),
1130            convergence_stability: T::zero(),
1131            adaptation_efficiency: T::zero(),
1132        }
1133    }
1134
1135    pub fn record_loss(&mut self, _loss: T) {
1136        // Update performance metrics logic
1137    }
1138
1139    pub fn get_stability_score(&self) -> T {
1140        self.convergence_stability
1141    }
1142
1143    pub fn reset(&mut self) {
1144        self.loss_trend = T::zero();
1145        self.convergence_stability = T::zero();
1146        self.adaptation_efficiency = T::zero();
1147    }
1148}
1149
1150#[derive(Debug, Clone)]
1151pub struct OptimizationProgress<T: Float + Debug + Send + Sync + 'static> {
1152    pub step_count: usize,
1153    pub current_learning_rate: T,
1154    pub last_update_magnitude: T,
1155    pub convergence_rate: T,
1156    pub stability_score: T,
1157}
1158
1159#[derive(Debug, Clone)]
1160pub struct LearningStatistics<T: Float + Debug + Send + Sync + 'static> {
1161    pub total_episodes: usize,
1162    pub average_loss: T,
1163    pub best_loss: T,
1164    pub convergence_rate: T,
1165    pub learning_stability: T,
1166}
1167
1168#[derive(Debug, Clone)]
1169pub struct AttentionCache<T: Float + Debug + Send + Sync + 'static> {
1170    pub cached_keys: Array2<T>,
1171    pub cached_values: Array2<T>,
1172    pub cache_size: usize,
1173}
1174
/// Counters describing memory usage.
///
/// All counters start at zero, both via [`MemoryUsageTracker::new`] and via `Default`.
#[derive(Debug, Clone, Default)]
pub struct MemoryUsageTracker {
    /// Total usage counter.
    pub total_usage: usize,
    /// Peak usage counter.
    pub peak_usage: usize,
    /// Allocation counter.
    pub allocation_count: usize,
}

impl MemoryUsageTracker {
    /// Creates a tracker with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets every counter back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
1203
/// Hit, miss, and eviction counters for a cache.
///
/// All counters start at zero, both via [`CacheStatistics::new`] and via `Default`.
#[derive(Debug, Clone, Default)]
pub struct CacheStatistics {
    /// Hit counter.
    pub hit_count: usize,
    /// Miss counter.
    pub miss_count: usize,
    /// Eviction counter.
    pub eviction_count: usize,
}

impl CacheStatistics {
    /// Creates statistics with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets every counter back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
1232
/// Descriptive information about a checkpoint.
#[derive(Debug, Clone)]
pub struct CheckpointMetadata {
    /// Checkpoint identifier.
    pub id: String,
    /// Checkpoint name.
    pub name: String,
    /// Monotonic clock reading associated with the checkpoint's creation.
    pub created_at: Instant,
    /// Estimated size (unit is not established in this part of the file).
    pub size_estimate: usize,
    /// Free-form description.
    pub description: String,
}
1241
/// Auto-save settings.
#[derive(Debug, Clone)]
pub struct AutoSaveConfig {
    /// Whether auto-save is enabled (default `true`).
    pub enabled: bool,
    /// Auto-save frequency (default 100).
    pub frequency: usize,
    /// Maximum number of auto-saves (default 5).
    pub max_auto_saves: usize,
}

impl Default for AutoSaveConfig {
    /// Auto-save enabled, frequency 100, at most 5 auto-saves.
    fn default() -> Self {
        AutoSaveConfig {
            max_auto_saves: 5,
            frequency: 100,
            enabled: true,
        }
    }
}
1258
#[cfg(test)]
mod tests {
    use super::*;

    /// Default `f32` optimizer configuration shared by the tests below.
    fn default_config() -> super::super::config::TransformerBasedOptimizerConfig<f32> {
        super::super::config::TransformerBasedOptimizerConfig::<f32>::default()
    }

    /// Optimizer state built from the default configuration.
    fn default_state() -> TransformerOptimizerState<f32> {
        TransformerOptimizerState::new(&default_config()).expect("unwrap failed")
    }

    /// A new state starts at version 0 with a non-empty parameter vector.
    #[test]
    fn test_optimizer_state_creation() {
        let created = TransformerOptimizerState::new(&default_config());
        assert!(created.is_ok());

        let state = created.expect("unwrap failed");
        assert_eq!(state.version, 0);
        assert!(!state.current_parameters.is_empty());
    }

    /// Applying one step succeeds and bumps the version to 1.
    #[test]
    fn test_state_update() {
        let mut state = default_state();

        let step = Array1::<f32>::ones(state.current_parameters.len());
        let outcome = state.update_with_step(&step, Some(1.5));
        assert!(outcome.is_ok());
        assert_eq!(state.version, 1);
    }

    /// A snapshot carries the state's version and parameter count.
    #[test]
    fn test_snapshot_creation() {
        let state = default_state();

        let created = state.create_snapshot();
        assert!(created.is_ok());

        let snapshot = created.expect("unwrap failed");
        assert_eq!(snapshot.version, 0);
        assert_eq!(snapshot.parameters.len(), state.current_parameters.len());
    }

    /// A saved checkpoint can be loaded back by its id.
    #[test]
    fn test_checkpoint_management() {
        let mut state = default_state();

        let saved = state.save_checkpoint("test_checkpoint".to_string());
        assert!(saved.is_ok());

        let checkpoint_id = saved.expect("unwrap failed");
        let loaded = state.load_checkpoint(&checkpoint_id);
        assert!(loaded.is_ok());
    }

    /// A recorded parameter vector is returned by the history.
    #[test]
    fn test_parameter_history() {
        let created = ParameterHistory::<f32>::new(10, 5);
        assert!(created.is_ok());

        let mut history = created.expect("unwrap failed");
        let snapshot = Array1::<f32>::ones(5);
        assert!(history.record_parameters(&snapshot).is_ok());

        let latest = history.get_recent_parameters(1);
        assert_eq!(latest.len(), 1);
    }

    /// Decreasing losses give a positive convergence rate and a stability score in (0, 1].
    #[test]
    fn test_convergence_tracker() {
        let mut tracker = ConvergenceTracker::<f32>::new();

        for &loss in &[2.0_f32, 1.5, 1.0] {
            tracker.record_loss(loss);
        }

        let rate = tracker.get_convergence_rate();
        assert!(rate > 0.0);

        let score = tracker.get_stability_score();
        assert!(score > 0.0 && score <= 1.0);
    }

    /// A freshly created state passes validation.
    #[test]
    fn test_state_validation() {
        let state = default_state();

        let outcome = state.validate_state();
        assert!(outcome.is_ok());

        let report = outcome.expect("unwrap failed");
        assert!(report.is_valid);
    }
}