1use super::config::TransformerBasedOptimizerConfig;
4use super::meta_learning::MetaState;
5use crate::error::Result;
6use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
7use scirs2_core::numeric::Float;
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, HashMap, VecDeque};
10use std::fmt::Debug;
11use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
12
/// Complete mutable state of a transformer-based learned optimizer:
/// the parameter vector being optimized plus all bookkeeping sub-states
/// (history, checkpoints, statistics, timing).
pub struct TransformerOptimizerState<T: Float + Debug + Send + Sync + 'static> {
    /// Current parameter vector; sized `model_dimension * num_transformer_layers`
    /// (see `new`).
    pub current_parameters: Array1<T>,

    /// Bounded rolling history of parameter snapshots with running norms.
    parameter_history: ParameterHistory<T>,

    /// Optimizer-side state: learning rate, adaptive moments, convergence tracking.
    optimization_state: OptimizationState<T>,

    /// Loss history, meta-learning state, and learning-rate schedule.
    learning_state: LearningState<T>,

    /// Attention caches and memory-usage bookkeeping.
    memory_state: MemoryState<T>,

    /// In-memory named snapshots for save/restore.
    checkpoint_manager: CheckpointManager<T>,

    /// State-management configuration (history size, checkpoint frequency, ...).
    config: StateConfig,

    /// Running update statistics (counts, magnitudes).
    statistics: StateStatistics<T>,

    /// Monotonically increasing revision, bumped on every `update_with_step`.
    version: usize,

    /// Monotonic creation time of this state.
    created_at: std::time::Instant,

    /// Monotonic time of the most recent mutation.
    last_updated: std::time::Instant,
}
48
49impl<T: Float + Debug + Send + Sync + 'static> TransformerOptimizerState<T> {
50 pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
52 let parameter_count = config.model_dimension * config.num_transformer_layers;
53 let current_parameters = Array1::zeros(parameter_count);
54
55 let parameter_history = ParameterHistory::new(1000, parameter_count)?;
56 let optimization_state = OptimizationState::new(config)?;
57 let learning_state = LearningState::new(config)?;
58 let memory_state = MemoryState::new()?;
59 let checkpoint_manager = CheckpointManager::new(config)?;
60 let state_config = StateConfig::from_optimizer_config(config);
61 let statistics = StateStatistics::new();
62
63 let now = std::time::Instant::now();
64
65 Ok(Self {
66 current_parameters,
67 parameter_history,
68 optimization_state,
69 learning_state,
70 memory_state,
71 checkpoint_manager,
72 config: state_config,
73 statistics,
74 version: 0,
75 created_at: now,
76 last_updated: now,
77 })
78 }
79
80 pub fn update_with_step(&mut self, update: &Array1<T>, loss: Option<T>) -> Result<()> {
82 self.current_parameters = &self.current_parameters + update;
84
85 self.parameter_history
87 .record_parameters(&self.current_parameters)?;
88
89 self.optimization_state.update_with_step(update, loss)?;
91
92 if let Some(loss_val) = loss {
94 self.learning_state.update_with_loss(loss_val)?;
95 }
96
97 self.statistics.record_update(update, loss);
99
100 self.version += 1;
102 self.last_updated = std::time::Instant::now();
103
104 Ok(())
105 }
106
107 pub fn create_snapshot(&self) -> Result<OptimizerStateSnapshot<T>> {
109 Ok(OptimizerStateSnapshot {
110 parameters: self.current_parameters.clone(),
111 optimization_state: self.optimization_state.clone(),
112 learning_state: self.learning_state.clone(),
113 memory_state: self.memory_state.clone(),
114 version: self.version,
115 timestamp: self.last_updated,
116 metadata: SnapshotMetadata {
117 parameter_count: self.current_parameters.len(),
118 total_updates: self.statistics.total_updates,
119 session_duration: self.last_updated.duration_since(self.created_at),
120 },
121 })
122 }
123
124 pub fn restore_from_snapshot(&mut self, snapshot: &OptimizerStateSnapshot<T>) -> Result<()> {
126 self.current_parameters = snapshot.parameters.clone();
127 self.optimization_state = snapshot.optimization_state.clone();
128 self.learning_state = snapshot.learning_state.clone();
129 self.memory_state = snapshot.memory_state.clone();
130 self.version = snapshot.version;
131 self.last_updated = snapshot.timestamp;
132
133 Ok(())
134 }
135
136 pub fn save_checkpoint(&mut self, name: String) -> Result<String> {
138 let snapshot = self.create_snapshot()?;
139 let checkpoint_id = self.checkpoint_manager.save_checkpoint(name, snapshot)?;
140 Ok(checkpoint_id)
141 }
142
143 pub fn load_checkpoint(&mut self, checkpoint_id: &str) -> Result<()> {
145 let snapshot = self.checkpoint_manager.load_checkpoint(checkpoint_id)?;
146 self.restore_from_snapshot(&snapshot)?;
147 Ok(())
148 }
149
150 pub fn get_parameter_stats(&self) -> ParameterStatistics<T> {
152 self.parameter_history.get_statistics()
153 }
154
155 pub fn get_optimization_progress(&self) -> OptimizationProgress<T> {
157 self.optimization_state.get_progress()
158 }
159
160 pub fn get_learning_stats(&self) -> LearningStatistics<T> {
162 self.learning_state.get_statistics()
163 }
164
165 pub fn reset(&mut self) -> Result<()> {
167 self.current_parameters.fill(T::zero());
168 self.parameter_history.clear();
169 self.optimization_state.reset()?;
170 self.learning_state.reset()?;
171 self.memory_state.reset()?;
172 self.statistics.reset();
173 self.version = 0;
174 self.last_updated = std::time::Instant::now();
175 Ok(())
176 }
177
178 pub fn validate_state(&self) -> Result<StateValidationReport> {
180 let mut issues = Vec::new();
181
182 if self.current_parameters.iter().any(|&x| !x.is_finite()) {
184 issues.push("Invalid parameters detected (NaN or infinity)".to_string());
185 }
186
187 if self.version == 0 && self.statistics.total_updates > 0 {
189 issues.push("Version mismatch with update count".to_string());
190 }
191
192 if self.last_updated < self.created_at {
194 issues.push("Invalid timestamp ordering".to_string());
195 }
196
197 let opt_validation = self.optimization_state.validate()?;
199 issues.extend(opt_validation.issues);
200
201 let learning_validation = self.learning_state.validate()?;
203 issues.extend(learning_validation.issues);
204
205 Ok(StateValidationReport {
206 is_valid: issues.is_empty(),
207 issues,
208 validation_timestamp: std::time::Instant::now(),
209 })
210 }
211
212 pub fn get_state_summary(&self) -> StateSummary<T> {
214 StateSummary {
215 version: self.version,
216 parameter_count: self.current_parameters.len(),
217 parameter_norm: self.compute_parameter_norm(),
218 total_updates: self.statistics.total_updates,
219 session_duration: self.last_updated.duration_since(self.created_at),
220 last_update_magnitude: self.statistics.last_update_magnitude,
221 average_loss: self.learning_state.get_average_loss(),
222 convergence_rate: self.learning_state.get_convergence_rate(),
223 memory_usage: self.memory_state.get_total_usage(),
224 checkpoint_count: self.checkpoint_manager.get_checkpoint_count(),
225 }
226 }
227
228 fn compute_parameter_norm(&self) -> T {
230 self.current_parameters
231 .iter()
232 .map(|&x| x * x)
233 .fold(T::zero(), |acc, x| acc + x)
234 .sqrt()
235 }
236
237 pub fn get_metadata(&self) -> StateMetadata {
239 StateMetadata {
240 version: self.version,
241 created_at: SystemTime::now(), last_updated: SystemTime::now(), total_updates: self.statistics.total_updates,
244 configuration: self.config.clone(),
245 }
246 }
247
248 pub fn export_state(&self) -> Result<SerializableState<T>> {
250 Ok(SerializableState {
251 parameters: self.current_parameters.to_vec(),
252 parameter_shape: self.current_parameters.shape().to_vec(),
253 optimization_state: self.optimization_state.to_serializable()?,
254 learning_state: self.learning_state.to_serializable()?,
255 metadata: self.get_metadata(),
256 statistics: self.statistics.clone(),
257 })
258 }
259
260 pub fn import_state(&mut self, state: SerializableState<T>) -> Result<()> {
262 if state.parameter_shape.len() != 1 {
264 return Err(crate::error::OptimError::Other(
265 "Invalid parameter shape for 1D array".to_string(),
266 ));
267 }
268
269 self.current_parameters = Array1::from_vec(state.parameters);
270
271 self.optimization_state
273 .from_serializable(state.optimization_state)?;
274 self.learning_state
275 .from_serializable(state.learning_state)?;
276 self.statistics = state.statistics;
277 self.version = state.metadata.version;
278 self.last_updated = Instant::now(); Ok(())
281 }
282}
283
/// Bounded ring buffer of parameter snapshots with running norm statistics.
pub struct ParameterHistory<T: Float + Debug + Send + Sync + 'static> {
    /// Snapshots in insertion order; oldest evicted when `max_size` is exceeded.
    snapshots: VecDeque<ParameterSnapshot<T>>,

    /// Maximum number of retained snapshots.
    max_size: usize,

    /// Expected parameter vector length.
    /// NOTE(review): stored but not read by any visible method — confirm callers.
    parameter_dimension: usize,

    /// Running statistics over recorded snapshot norms.
    statistics: ParameterStatistics<T>,
}
298
299impl<T: Float + Debug + Send + Sync + 'static> ParameterHistory<T> {
300 pub fn new(max_size: usize, parameter_dimension: usize) -> Result<Self> {
301 Ok(Self {
302 snapshots: VecDeque::new(),
303 max_size,
304 parameter_dimension,
305 statistics: ParameterStatistics::new(),
306 })
307 }
308
309 pub fn record_parameters(&mut self, parameters: &Array1<T>) -> Result<()> {
310 let snapshot = ParameterSnapshot {
311 parameters: parameters.clone(),
312 timestamp: std::time::Instant::now(),
313 norm: parameters
314 .iter()
315 .map(|&x| x * x)
316 .fold(T::zero(), |acc, x| acc + x)
317 .sqrt(),
318 };
319
320 self.snapshots.push_back(snapshot.clone());
321 if self.snapshots.len() > self.max_size {
322 self.snapshots.pop_front();
323 }
324
325 self.statistics.update_with_snapshot(&snapshot);
326 Ok(())
327 }
328
329 pub fn get_recent_parameters(&self, count: usize) -> Vec<Array1<T>> {
330 self.snapshots
331 .iter()
332 .rev()
333 .take(count)
334 .map(|snapshot| snapshot.parameters.clone())
335 .collect()
336 }
337
338 pub fn get_statistics(&self) -> ParameterStatistics<T> {
339 self.statistics.clone()
340 }
341
342 pub fn clear(&mut self) {
343 self.snapshots.clear();
344 self.statistics = ParameterStatistics::new();
345 }
346}
347
#[derive(Debug, Clone)]
/// Optimizer-side state: learning rate, momentum/adaptive buffers and
/// convergence tracking for the current optimization run.
pub struct OptimizationState<T: Float + Debug + Send + Sync + 'static> {
    /// Current learning rate.
    pub learning_rate: T,

    /// Optional momentum buffer (same length as the parameter vector).
    pub momentum: Option<Array1<T>>,

    /// Optional Adam-style first/second moment state.
    pub adaptive_state: Option<AdaptiveState<T>>,

    /// Accumulated gradients for batched application.
    pub gradient_accumulator: GradientAccumulator<T>,

    /// Number of optimization steps taken.
    pub step_count: usize,

    /// L2 norm of the most recent update.
    pub last_update_magnitude: T,

    /// Rolling-window loss tracker for convergence/stability metrics.
    pub convergence_tracker: ConvergenceTracker<T>,
}
372
373impl<T: Float + Debug + Send + Sync + 'static> OptimizationState<T> {
374 pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
375 let parameter_count = config.model_dimension * config.num_transformer_layers;
376
377 Ok(Self {
378 learning_rate: config.learning_rate,
379 momentum: None,
380 adaptive_state: Some(AdaptiveState::new(parameter_count)?),
381 gradient_accumulator: GradientAccumulator::new(parameter_count)?,
382 step_count: 0,
383 last_update_magnitude: T::zero(),
384 convergence_tracker: ConvergenceTracker::new(),
385 })
386 }
387
388 pub fn update_with_step(&mut self, update: &Array1<T>, loss: Option<T>) -> Result<()> {
389 self.step_count += 1;
390 self.last_update_magnitude = update
391 .iter()
392 .map(|&x| x * x)
393 .fold(T::zero(), |acc, x| acc + x)
394 .sqrt();
395
396 if let Some(loss_val) = loss {
397 self.convergence_tracker.record_loss(loss_val);
398 }
399
400 if let Some(ref mut adaptive) = self.adaptive_state {
402 adaptive.update_with_step(update)?;
403 }
404
405 Ok(())
406 }
407
408 pub fn get_progress(&self) -> OptimizationProgress<T> {
409 OptimizationProgress {
410 step_count: self.step_count,
411 current_learning_rate: self.learning_rate,
412 last_update_magnitude: self.last_update_magnitude,
413 convergence_rate: self.convergence_tracker.get_convergence_rate(),
414 stability_score: self.convergence_tracker.get_stability_score(),
415 }
416 }
417
418 pub fn reset(&mut self) -> Result<()> {
419 self.step_count = 0;
420 self.last_update_magnitude = T::zero();
421 self.convergence_tracker.reset();
422
423 if let Some(ref mut adaptive) = self.adaptive_state {
424 adaptive.reset()?;
425 }
426
427 self.gradient_accumulator.reset()?;
428 Ok(())
429 }
430
431 pub fn validate(&self) -> Result<ValidationResult> {
432 let mut issues = Vec::new();
433
434 if self.learning_rate <= T::zero() {
435 issues.push("Invalid learning rate".to_string());
436 }
437
438 if !self.last_update_magnitude.is_finite() {
439 issues.push("Invalid update magnitude".to_string());
440 }
441
442 Ok(ValidationResult { issues })
443 }
444
445 pub fn to_serializable(&self) -> Result<SerializableOptimizationState<T>> {
446 Ok(SerializableOptimizationState {
447 learning_rate: self.learning_rate,
448 step_count: self.step_count,
449 last_update_magnitude: self.last_update_magnitude,
450 momentum: self.momentum.as_ref().map(|m| m.to_vec()),
451 convergence_metrics: self.convergence_tracker.to_serializable(),
452 })
453 }
454
455 pub fn from_serializable(&mut self, state: SerializableOptimizationState<T>) -> Result<()> {
456 self.learning_rate = state.learning_rate;
457 self.step_count = state.step_count;
458 self.last_update_magnitude = state.last_update_magnitude;
459
460 if let Some(momentum_vec) = state.momentum {
461 self.momentum = Some(Array1::from_vec(momentum_vec));
462 }
463
464 self.convergence_tracker
465 .from_serializable(state.convergence_metrics)?;
466 Ok(())
467 }
468}
469
#[derive(Debug, Clone)]
/// Learning-side state: loss history, meta-learning state and schedule.
pub struct LearningState<T: Float + Debug + Send + Sync + 'static> {
    /// Most recent losses (bounded to 1000 entries by `update_with_loss`).
    loss_history: VecDeque<T>,

    /// Optional meta-learning state fed with each observed loss.
    meta_state: Option<MetaState<T>>,

    /// Per-task adaptation records.
    /// NOTE(review): populated by callers outside this file — not written here.
    adaptation_history: VecDeque<TaskAdaptationRecord<T>>,

    /// Learning-rate schedule (initial rate, warmup, decay factor).
    learning_schedule: LearningSchedule<T>,

    /// Aggregate learning performance metrics.
    performance_metrics: LearningPerformanceMetrics<T>,
}
488
489impl<T: Float + Debug + Send + Sync + 'static> LearningState<T> {
490 pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
491 let meta_state = Some(MetaState::new(config.model_dimension)?);
492 let learning_schedule = LearningSchedule::new(config.learning_rate, config.warmup_steps);
493
494 Ok(Self {
495 loss_history: VecDeque::new(),
496 meta_state,
497 adaptation_history: VecDeque::new(),
498 learning_schedule,
499 performance_metrics: LearningPerformanceMetrics::new(),
500 })
501 }
502
503 pub fn update_with_loss(&mut self, loss: T) -> Result<()> {
504 self.loss_history.push_back(loss);
505 if self.loss_history.len() > 1000 {
506 self.loss_history.pop_front();
507 }
508
509 self.performance_metrics.record_loss(loss);
510
511 if let Some(ref mut meta) = self.meta_state {
512 meta.update_loss_history(loss);
513 }
514
515 Ok(())
516 }
517
518 pub fn get_statistics(&self) -> LearningStatistics<T> {
519 LearningStatistics {
520 total_episodes: self.loss_history.len(),
521 average_loss: self.get_average_loss(),
522 best_loss: self.get_best_loss(),
523 convergence_rate: self.get_convergence_rate(),
524 learning_stability: self.performance_metrics.get_stability_score(),
525 }
526 }
527
528 pub fn get_average_loss(&self) -> T {
529 if self.loss_history.is_empty() {
530 T::zero()
531 } else {
532 self.loss_history
533 .iter()
534 .fold(T::zero(), |acc, &loss| acc + loss)
535 / T::from(self.loss_history.len()).unwrap()
536 }
537 }
538
539 pub fn get_best_loss(&self) -> T {
540 self.loss_history
541 .iter()
542 .fold(T::infinity(), |min, &loss| min.min(loss))
543 }
544
545 pub fn get_convergence_rate(&self) -> T {
546 if self.loss_history.len() < 2 {
547 return T::zero();
548 }
549
550 let recent_losses: Vec<_> = self.loss_history.iter().rev().take(10).cloned().collect();
551 if recent_losses.len() < 2 {
552 return T::zero();
553 }
554
555 let initial = recent_losses.last().unwrap();
556 let final_loss = recent_losses.first().unwrap();
557
558 if *initial > T::zero() {
559 (*initial - *final_loss) / *initial
560 } else {
561 T::zero()
562 }
563 }
564
565 pub fn reset(&mut self) -> Result<()> {
566 self.loss_history.clear();
567 self.adaptation_history.clear();
568 self.performance_metrics.reset();
569
570 if let Some(ref mut meta) = self.meta_state {
571 *meta = MetaState::new(meta.get_parameters().len())?;
572 }
573
574 Ok(())
575 }
576
577 pub fn validate(&self) -> Result<ValidationResult> {
578 let mut issues = Vec::new();
579
580 if self.loss_history.iter().any(|&loss| !loss.is_finite()) {
581 issues.push("Invalid loss values detected".to_string());
582 }
583
584 Ok(ValidationResult { issues })
585 }
586
587 pub fn to_serializable(&self) -> Result<SerializableLearningState<T>> {
588 Ok(SerializableLearningState {
589 loss_history: self.loss_history.iter().cloned().collect(),
590 average_loss: self.get_average_loss(),
591 best_loss: self.get_best_loss(),
592 convergence_rate: self.get_convergence_rate(),
593 })
594 }
595
596 pub fn from_serializable(&mut self, state: SerializableLearningState<T>) -> Result<()> {
597 self.loss_history = VecDeque::from(state.loss_history);
598 Ok(())
599 }
600}
601
#[derive(Debug, Clone)]
/// Memory-related bookkeeping: attention caches plus usage/cache counters.
pub struct MemoryState<T: Float + Debug + Send + Sync + 'static> {
    /// Attention key/value caches, keyed by name.
    attention_caches: HashMap<String, AttentionCache<T>>,

    /// Usage counters (units as maintained by `MemoryUsageTracker`).
    memory_usage: MemoryUsageTracker,

    /// Cache hit/miss/eviction counters.
    cache_statistics: CacheStatistics,
}
614
615impl<T: Float + Debug + Send + Sync + 'static> MemoryState<T> {
616 pub fn new() -> Result<Self> {
617 Ok(Self {
618 attention_caches: HashMap::new(),
619 memory_usage: MemoryUsageTracker::new(),
620 cache_statistics: CacheStatistics::new(),
621 })
622 }
623
624 pub fn get_total_usage(&self) -> usize {
625 self.memory_usage.total_usage
626 }
627
628 pub fn reset(&mut self) -> Result<()> {
629 self.attention_caches.clear();
630 self.memory_usage.reset();
631 self.cache_statistics.reset();
632 Ok(())
633 }
634}
635
/// In-memory store of named state snapshots with size-bounded retention.
pub struct CheckpointManager<T: Float + Debug + Send + Sync + 'static> {
    /// Snapshots keyed by checkpoint id.
    checkpoints: HashMap<String, OptimizerStateSnapshot<T>>,

    /// Per-checkpoint metadata (creation time, size estimate, ...).
    metadata: HashMap<String, CheckpointMetadata>,

    /// Retention limit; oldest checkpoints are evicted beyond this.
    max_checkpoints: usize,

    /// Auto-save policy (not acted on by the methods in this file).
    auto_save_config: AutoSaveConfig,
}
650
651impl<T: Float + Debug + Send + Sync + 'static> CheckpointManager<T> {
652 pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
653 Ok(Self {
654 checkpoints: HashMap::new(),
655 metadata: HashMap::new(),
656 max_checkpoints: 10,
657 auto_save_config: AutoSaveConfig::default(),
658 })
659 }
660
661 pub fn save_checkpoint(
662 &mut self,
663 name: String,
664 snapshot: OptimizerStateSnapshot<T>,
665 ) -> Result<String> {
666 let checkpoint_id = format!("{}_{}", name, snapshot.version);
667
668 let metadata = CheckpointMetadata {
670 id: checkpoint_id.clone(),
671 name: name.clone(),
672 created_at: std::time::Instant::now(),
673 size_estimate: snapshot.parameters.len() * std::mem::size_of::<T>(),
674 description: format!("Checkpoint at version {}", snapshot.version),
675 };
676
677 self.checkpoints.insert(checkpoint_id.clone(), snapshot);
678 self.metadata.insert(checkpoint_id.clone(), metadata);
679
680 if self.checkpoints.len() > self.max_checkpoints {
682 self.cleanup_old_checkpoints()?;
683 }
684
685 Ok(checkpoint_id)
686 }
687
688 pub fn load_checkpoint(&self, checkpoint_id: &str) -> Result<OptimizerStateSnapshot<T>> {
689 self.checkpoints.get(checkpoint_id).cloned().ok_or_else(|| {
690 crate::error::OptimError::Other(format!("Checkpoint {} not found", checkpoint_id))
691 })
692 }
693
694 pub fn list_checkpoints(&self) -> Vec<CheckpointMetadata> {
695 self.metadata.values().cloned().collect()
696 }
697
698 pub fn delete_checkpoint(&mut self, checkpoint_id: &str) -> Result<bool> {
699 let removed_checkpoint = self.checkpoints.remove(checkpoint_id).is_some();
700 let removed_metadata = self.metadata.remove(checkpoint_id).is_some();
701 Ok(removed_checkpoint && removed_metadata)
702 }
703
704 pub fn get_checkpoint_count(&self) -> usize {
705 self.checkpoints.len()
706 }
707
708 fn cleanup_old_checkpoints(&mut self) -> Result<()> {
709 while self.checkpoints.len() > self.max_checkpoints {
711 if let Some((oldest_id, _)) = self
712 .metadata
713 .iter()
714 .min_by_key(|(_, metadata)| metadata.created_at)
715 .map(|(id, metadata)| (id.clone(), metadata.clone()))
716 {
717 self.checkpoints.remove(&oldest_id);
718 self.metadata.remove(&oldest_id);
719 } else {
720 break;
721 }
722 }
723 Ok(())
724 }
725}
726
#[derive(Debug, Clone)]
/// One recorded parameter vector with its capture time and L2 norm.
pub struct ParameterSnapshot<T: Float + Debug + Send + Sync + 'static> {
    /// The parameter vector at capture time.
    pub parameters: Array1<T>,
    /// Monotonic capture time.
    pub timestamp: std::time::Instant,
    /// L2 norm of `parameters`, computed at capture time.
    pub norm: T,
}
735
#[derive(Debug, Clone)]
/// Running statistics over recorded parameter-snapshot norms.
pub struct ParameterStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Number of snapshots folded in.
    pub total_snapshots: usize,
    /// Running mean of snapshot norms.
    pub average_norm: T,
    /// Largest snapshot norm seen.
    pub max_norm: T,
    /// Smallest snapshot norm seen (starts at +inf).
    pub min_norm: T,
    /// Norm trend. NOTE(review): never updated by `update_with_snapshot` — stays zero.
    pub norm_trend: T,
}
744
745impl<T: Float + Debug + Send + Sync + 'static> Default for ParameterStatistics<T> {
746 fn default() -> Self {
747 Self::new()
748 }
749}
750
751impl<T: Float + Debug + Send + Sync + 'static> ParameterStatistics<T> {
752 pub fn new() -> Self {
753 Self {
754 total_snapshots: 0,
755 average_norm: T::zero(),
756 max_norm: T::zero(),
757 min_norm: T::infinity(),
758 norm_trend: T::zero(),
759 }
760 }
761
762 pub fn update_with_snapshot(&mut self, snapshot: &ParameterSnapshot<T>) {
763 self.total_snapshots += 1;
764 self.average_norm = (self.average_norm
765 * scirs2_core::numeric::NumCast::from(self.total_snapshots - 1)
766 .unwrap_or_else(|| T::zero())
767 + snapshot.norm)
768 / scirs2_core::numeric::NumCast::from(self.total_snapshots)
769 .unwrap_or_else(|| T::zero());
770 self.max_norm = self.max_norm.max(snapshot.norm);
771 self.min_norm = self.min_norm.min(snapshot.norm);
772 }
773}
774
#[derive(Debug, Clone)]
/// Adam-style adaptive moment state.
pub struct AdaptiveState<T: Float + Debug + Send + Sync + 'static> {
    /// First moment estimate (per parameter).
    pub m: Array1<T>,
    /// Second moment estimate (per parameter).
    pub v: Array1<T>,
    /// Steps folded into the moments.
    pub step_count: usize,
    /// Exponential decay rate for the first moment (0.9 by default).
    pub beta1: T,
    /// Exponential decay rate for the second moment (0.999 by default).
    pub beta2: T,
    /// Numerical-stability constant (1e-8 by default).
    pub epsilon: T,
}
789
790impl<T: Float + Debug + Send + Sync + 'static> AdaptiveState<T> {
791 pub fn new(parameter_count: usize) -> Result<Self> {
792 Ok(Self {
793 m: Array1::zeros(parameter_count),
794 v: Array1::zeros(parameter_count),
795 step_count: 0,
796 beta1: scirs2_core::numeric::NumCast::from(0.9).unwrap_or_else(|| T::zero()),
797 beta2: scirs2_core::numeric::NumCast::from(0.999).unwrap_or_else(|| T::zero()),
798 epsilon: scirs2_core::numeric::NumCast::from(1e-8).unwrap_or_else(|| T::zero()),
799 })
800 }
801
802 pub fn update_with_step(&mut self, _update: &Array1<T>) -> Result<()> {
803 self.step_count += 1;
804 Ok(())
806 }
807
808 pub fn reset(&mut self) -> Result<()> {
809 self.m.fill(T::zero());
810 self.v.fill(T::zero());
811 self.step_count = 0;
812 Ok(())
813 }
814}
815
#[derive(Debug, Clone)]
/// Accumulates gradients across steps for batched application.
pub struct GradientAccumulator<T: Float + Debug + Send + Sync + 'static> {
    /// Sum of accumulated gradients (per parameter).
    pub accumulated_gradients: Array1<T>,
    /// Number of accumulations folded in since the last reset.
    pub accumulation_count: usize,
}
823
824impl<T: Float + Debug + Send + Sync + 'static> GradientAccumulator<T> {
825 pub fn new(parameter_count: usize) -> Result<Self> {
826 Ok(Self {
827 accumulated_gradients: Array1::zeros(parameter_count),
828 accumulation_count: 0,
829 })
830 }
831
832 pub fn reset(&mut self) -> Result<()> {
833 self.accumulated_gradients.fill(T::zero());
834 self.accumulation_count = 0;
835 Ok(())
836 }
837}
838
#[derive(Debug, Clone)]
/// Rolling-window loss tracker used to derive convergence/stability scores.
pub struct ConvergenceTracker<T: Float + Debug + Send + Sync + 'static> {
    /// Most recent losses, bounded to `stability_window` entries.
    recent_losses: VecDeque<T>,
    /// Threshold for convergence detection (1e-6 by default).
    /// NOTE(review): not read by any visible method — confirm intended use.
    convergence_threshold: T,
    /// Window size for the loss buffer (10 by default).
    stability_window: usize,
}
848
849impl<T: Float + Debug + Send + Sync + 'static> Default for ConvergenceTracker<T> {
850 fn default() -> Self {
851 Self::new()
852 }
853}
854
855impl<T: Float + Debug + Send + Sync + 'static> ConvergenceTracker<T> {
856 pub fn new() -> Self {
857 Self {
858 recent_losses: VecDeque::new(),
859 convergence_threshold: scirs2_core::numeric::NumCast::from(1e-6)
860 .unwrap_or_else(|| T::zero()),
861 stability_window: 10,
862 }
863 }
864
865 pub fn record_loss(&mut self, loss: T) {
866 self.recent_losses.push_back(loss);
867 if self.recent_losses.len() > self.stability_window {
868 self.recent_losses.pop_front();
869 }
870 }
871
872 pub fn get_convergence_rate(&self) -> T {
873 if self.recent_losses.len() < 2 {
874 return T::zero();
875 }
876
877 let first = self.recent_losses[0];
878 let last = *self.recent_losses.back().unwrap();
879
880 if first > T::zero() {
881 (first - last) / first
882 } else {
883 T::zero()
884 }
885 }
886
887 pub fn get_stability_score(&self) -> T {
888 if self.recent_losses.len() < 2 {
889 return T::zero();
890 }
891
892 let mean = self.recent_losses.iter().fold(T::zero(), |acc, &x| acc + x)
893 / T::from(self.recent_losses.len()).unwrap();
894 let variance = self
895 .recent_losses
896 .iter()
897 .map(|&x| (x - mean) * (x - mean))
898 .fold(T::zero(), |acc, x| acc + x)
899 / T::from(self.recent_losses.len()).unwrap();
900
901 T::one() / (T::one() + variance.sqrt())
902 }
903
904 pub fn reset(&mut self) {
905 self.recent_losses.clear();
906 }
907
908 pub fn to_serializable(&self) -> SerializableConvergenceState<T> {
909 SerializableConvergenceState {
910 recent_losses: self.recent_losses.iter().cloned().collect(),
911 convergence_rate: self.get_convergence_rate(),
912 stability_score: self.get_stability_score(),
913 }
914 }
915
916 pub fn from_serializable(&mut self, state: SerializableConvergenceState<T>) -> Result<()> {
917 self.recent_losses = VecDeque::from(state.recent_losses);
918 Ok(())
919 }
920}
921
#[derive(Debug, Clone)]
/// Deep copy of the optimizer state at one point in time, used for
/// checkpointing and restore.
pub struct OptimizerStateSnapshot<T: Float + Debug + Send + Sync + 'static> {
    /// Parameter vector at snapshot time.
    pub parameters: Array1<T>,
    /// Cloned optimization sub-state.
    pub optimization_state: OptimizationState<T>,
    /// Cloned learning sub-state.
    pub learning_state: LearningState<T>,
    /// Cloned memory sub-state.
    pub memory_state: MemoryState<T>,
    /// State version at snapshot time.
    pub version: usize,
    /// `last_updated` of the source state at snapshot time.
    pub timestamp: std::time::Instant,
    /// Lightweight descriptive metadata.
    pub metadata: SnapshotMetadata,
}
933
#[derive(Debug, Clone)]
/// Descriptive metadata attached to an [`OptimizerStateSnapshot`].
pub struct SnapshotMetadata {
    /// Number of parameters in the snapshot.
    pub parameter_count: usize,
    /// Total updates applied to the source state before the snapshot.
    pub total_updates: usize,
    /// Elapsed time from state creation to the snapshot's timestamp.
    pub session_duration: Duration,
}
940
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Configuration for state management (history retention, checkpointing).
pub struct StateConfig {
    /// Maximum parameter-history entries to retain.
    pub max_history_size: usize,
    /// Steps between automatic checkpoints.
    pub checkpoint_frequency: usize,
    /// Whether automatic saving is enabled.
    pub auto_save_enabled: bool,
    /// Whether state validation is enabled.
    pub validation_enabled: bool,
}
949
950impl StateConfig {
951 pub fn from_optimizer_config<T: Float + Debug + Send + Sync + 'static>(
952 config: &TransformerBasedOptimizerConfig<T>,
953 ) -> Self {
954 Self {
955 max_history_size: 1000,
956 checkpoint_frequency: 100,
957 auto_save_enabled: true,
958 validation_enabled: true,
959 }
960 }
961}
962
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Serializable description of a state, included in exports.
pub struct StateMetadata {
    /// State version at export time.
    pub version: usize,
    /// Wall-clock creation time.
    pub created_at: std::time::SystemTime,
    /// Wall-clock time of the last mutation.
    pub last_updated: std::time::SystemTime,
    /// Total updates applied.
    pub total_updates: usize,
    /// State-management configuration in effect.
    pub configuration: StateConfig,
}
971
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Running statistics over applied parameter updates.
pub struct StateStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Number of updates recorded.
    pub total_updates: usize,
    /// L2 norm of the most recent update.
    pub last_update_magnitude: T,
    /// Running mean of update L2 norms.
    pub average_update_magnitude: T,
    /// Parameter change rate. NOTE(review): never written by `record_update` — stays zero.
    pub parameter_change_rate: T,
    /// Update frequency. NOTE(review): never written by `record_update` — stays zero.
    pub update_frequency: f64,
}
981
982impl<T: Float + Debug + Send + Sync + 'static> Default for StateStatistics<T> {
983 fn default() -> Self {
984 Self::new()
985 }
986}
987
988impl<T: Float + Debug + Send + Sync + 'static> StateStatistics<T> {
989 pub fn new() -> Self {
990 Self {
991 total_updates: 0,
992 last_update_magnitude: T::zero(),
993 average_update_magnitude: T::zero(),
994 parameter_change_rate: T::zero(),
995 update_frequency: 0.0,
996 }
997 }
998
999 pub fn record_update(&mut self, update: &Array1<T>, _loss: Option<T>) {
1000 self.total_updates += 1;
1001 let magnitude = update
1002 .iter()
1003 .map(|&x| x * x)
1004 .fold(T::zero(), |acc, x| acc + x)
1005 .sqrt();
1006 self.last_update_magnitude = magnitude;
1007 self.average_update_magnitude = (self.average_update_magnitude
1008 * scirs2_core::numeric::NumCast::from(self.total_updates - 1)
1009 .unwrap_or_else(|| T::zero())
1010 + magnitude)
1011 / scirs2_core::numeric::NumCast::from(self.total_updates).unwrap_or_else(|| T::zero());
1012 }
1013
1014 pub fn reset(&mut self) {
1015 self.total_updates = 0;
1016 self.last_update_magnitude = T::zero();
1017 self.average_update_magnitude = T::zero();
1018 self.parameter_change_rate = T::zero();
1019 self.update_frequency = 0.0;
1020 }
1021}
1022
#[derive(Debug, Clone)]
/// Result of a full-state validation pass.
pub struct StateValidationReport {
    /// True when no issues were found.
    pub is_valid: bool,
    /// Human-readable descriptions of every detected issue.
    pub issues: Vec<String>,
    /// When the validation ran (monotonic clock).
    pub validation_timestamp: Instant,
}
1030
#[derive(Debug, Clone)]
/// Issues reported by a single sub-state validator.
pub struct ValidationResult {
    /// Human-readable issue descriptions; empty means valid.
    pub issues: Vec<String>,
}
1035
#[derive(Debug, Clone)]
/// Aggregate snapshot-in-numbers of the optimizer state.
pub struct StateSummary<T: Float + Debug + Send + Sync + 'static> {
    /// Current state version.
    pub version: usize,
    /// Number of parameters.
    pub parameter_count: usize,
    /// L2 norm of the parameter vector.
    pub parameter_norm: T,
    /// Total updates applied.
    pub total_updates: usize,
    /// Time from creation to the last update.
    pub session_duration: Duration,
    /// L2 norm of the most recent update.
    pub last_update_magnitude: T,
    /// Mean recorded loss.
    pub average_loss: T,
    /// Relative loss improvement over the recent window.
    pub convergence_rate: T,
    /// Total tracked memory usage.
    pub memory_usage: usize,
    /// Number of stored checkpoints.
    pub checkpoint_count: usize,
}
1049
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Fully serializable export of the optimizer state.
pub struct SerializableState<T: Float + Debug + Send + Sync + 'static> {
    /// Flattened parameter values.
    pub parameters: Vec<T>,
    /// Shape of the parameter array (expected to be 1-D).
    pub parameter_shape: Vec<usize>,
    /// Serializable optimization sub-state.
    pub optimization_state: SerializableOptimizationState<T>,
    /// Serializable learning sub-state.
    pub learning_state: SerializableLearningState<T>,
    /// Export metadata (version, timestamps, configuration).
    pub metadata: StateMetadata,
    /// Update statistics at export time.
    pub statistics: StateStatistics<T>,
}
1060
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Serializable subset of [`OptimizationState`].
pub struct SerializableOptimizationState<T: Float + Debug + Send + Sync + 'static> {
    /// Learning rate at export time.
    pub learning_rate: T,
    /// Steps taken at export time.
    pub step_count: usize,
    /// L2 norm of the last update.
    pub last_update_magnitude: T,
    /// Flattened momentum buffer, if one existed.
    pub momentum: Option<Vec<T>>,
    /// Serialized convergence-tracker window and derived metrics.
    pub convergence_metrics: SerializableConvergenceState<T>,
}
1069
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Serializable subset of [`LearningState`].
pub struct SerializableLearningState<T: Float + Debug + Send + Sync + 'static> {
    /// Recorded losses at export time.
    pub loss_history: Vec<T>,
    /// Mean of the recorded losses (derived, informational).
    pub average_loss: T,
    /// Minimum recorded loss (derived, informational).
    pub best_loss: T,
    /// Recent relative improvement (derived, informational).
    pub convergence_rate: T,
}
1077
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Serializable view of a [`ConvergenceTracker`]'s window and derived scores.
pub struct SerializableConvergenceState<T: Float + Debug + Send + Sync + 'static> {
    /// Loss window at export time.
    pub recent_losses: Vec<T>,
    /// Derived relative improvement across the window.
    pub convergence_rate: T,
    /// Derived stability score.
    pub stability_score: T,
}
1084
#[derive(Debug, Clone)]
/// Outcome of adapting the optimizer to one task.
pub struct TaskAdaptationRecord<T: Float + Debug + Send + Sync + 'static> {
    /// Identifier of the adapted task.
    pub task_id: String,
    /// Number of adaptation steps performed.
    pub adaptation_steps: usize,
    /// Loss after adaptation.
    pub final_loss: T,
    /// Wall time spent adapting.
    pub adaptation_time: Duration,
}
1093
#[derive(Debug, Clone)]
/// Learning-rate schedule parameters (warmup + decay).
pub struct LearningSchedule<T: Float + Debug + Send + Sync + 'static> {
    /// Rate at schedule start.
    pub initial_rate: T,
    /// Rate currently in effect.
    pub current_rate: T,
    /// Number of warmup steps.
    pub warmup_steps: usize,
    /// Multiplicative decay factor (0.95 by default).
    pub decay_factor: T,
}
1101
1102impl<T: Float + Debug + Send + Sync + 'static> LearningSchedule<T> {
1103 pub fn new(initial_rate: T, warmup_steps: usize) -> Self {
1104 Self {
1105 initial_rate,
1106 current_rate: initial_rate,
1107 warmup_steps,
1108 decay_factor: scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()),
1109 }
1110 }
1111}
1112
#[derive(Debug, Clone)]
/// Aggregate learning performance metrics.
/// NOTE(review): none of these fields are updated by `record_loss` (a no-op);
/// they only change via `reset` — confirm whether population is planned.
pub struct LearningPerformanceMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Trend of the loss over time.
    pub loss_trend: T,
    /// Stability of convergence; returned by `get_stability_score`.
    pub convergence_stability: T,
    /// Efficiency of task adaptation.
    pub adaptation_efficiency: T,
}
1119
1120impl<T: Float + Debug + Send + Sync + 'static> Default for LearningPerformanceMetrics<T> {
1121 fn default() -> Self {
1122 Self::new()
1123 }
1124}
1125
1126impl<T: Float + Debug + Send + Sync + 'static> LearningPerformanceMetrics<T> {
1127 pub fn new() -> Self {
1128 Self {
1129 loss_trend: T::zero(),
1130 convergence_stability: T::zero(),
1131 adaptation_efficiency: T::zero(),
1132 }
1133 }
1134
1135 pub fn record_loss(&mut self, _loss: T) {
1136 }
1138
1139 pub fn get_stability_score(&self) -> T {
1140 self.convergence_stability
1141 }
1142
1143 pub fn reset(&mut self) {
1144 self.loss_trend = T::zero();
1145 self.convergence_stability = T::zero();
1146 self.adaptation_efficiency = T::zero();
1147 }
1148}
1149
#[derive(Debug, Clone)]
/// Progress report derived from [`OptimizationState`].
pub struct OptimizationProgress<T: Float + Debug + Send + Sync + 'static> {
    /// Steps taken so far.
    pub step_count: usize,
    /// Learning rate currently in effect.
    pub current_learning_rate: T,
    /// L2 norm of the most recent update.
    pub last_update_magnitude: T,
    /// Relative loss improvement across the tracked window.
    pub convergence_rate: T,
    /// Stability score in (0, 1] (or zero if the window is too small).
    pub stability_score: T,
}
1158
#[derive(Debug, Clone)]
/// Summary statistics derived from [`LearningState`].
pub struct LearningStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Number of recorded loss observations.
    pub total_episodes: usize,
    /// Mean recorded loss.
    pub average_loss: T,
    /// Minimum recorded loss (+inf when no losses recorded).
    pub best_loss: T,
    /// Recent relative loss improvement.
    pub convergence_rate: T,
    /// Stability score from the performance metrics.
    pub learning_stability: T,
}
1167
#[derive(Debug, Clone)]
/// Cached attention keys/values for reuse across steps.
pub struct AttentionCache<T: Float + Debug + Send + Sync + 'static> {
    /// Cached attention keys.
    pub cached_keys: Array2<T>,
    /// Cached attention values.
    pub cached_values: Array2<T>,
    /// Number of cached entries.
    pub cache_size: usize,
}
1174
#[derive(Debug, Clone)]
/// Simple counters for tracked memory usage.
pub struct MemoryUsageTracker {
    /// Current total usage.
    pub total_usage: usize,
    /// Highest usage observed.
    pub peak_usage: usize,
    /// Number of allocations tracked.
    pub allocation_count: usize,
}
1181
1182impl Default for MemoryUsageTracker {
1183 fn default() -> Self {
1184 Self::new()
1185 }
1186}
1187
1188impl MemoryUsageTracker {
1189 pub fn new() -> Self {
1190 Self {
1191 total_usage: 0,
1192 peak_usage: 0,
1193 allocation_count: 0,
1194 }
1195 }
1196
1197 pub fn reset(&mut self) {
1198 self.total_usage = 0;
1199 self.peak_usage = 0;
1200 self.allocation_count = 0;
1201 }
1202}
1203
/// Hit/miss/eviction counters for the attention cache.
#[derive(Debug, Clone)]
pub struct CacheStatistics {
    /// Lookups that found a cached entry.
    pub hit_count: usize,
    /// Lookups that missed the cache.
    pub miss_count: usize,
    /// Entries evicted to make room.
    pub eviction_count: usize,
}

impl Default for CacheStatistics {
    fn default() -> Self {
        Self::new()
    }
}

impl CacheStatistics {
    /// Creates a statistics record with zeroed counters.
    pub fn new() -> Self {
        CacheStatistics {
            hit_count: 0,
            miss_count: 0,
            eviction_count: 0,
        }
    }

    /// Zeroes all counters.
    pub fn reset(&mut self) {
        *self = CacheStatistics {
            hit_count: 0,
            miss_count: 0,
            eviction_count: 0,
        };
    }
}
1232
/// Descriptive metadata attached to a saved optimizer checkpoint.
#[derive(Debug, Clone)]
pub struct CheckpointMetadata {
    /// Unique identifier used to look the checkpoint up.
    pub id: String,
    /// Human-readable checkpoint name.
    pub name: String,
    /// Creation time; `Instant` is process-relative, so this does not
    /// survive serialization across runs — NOTE(review): confirm intent.
    pub created_at: Instant,
    /// Approximate in-memory size of the checkpoint, in bytes (estimate only).
    pub size_estimate: usize,
    /// Free-form description supplied at save time.
    pub description: String,
}
1241
/// Configuration for automatic checkpoint saving.
#[derive(Debug, Clone)]
pub struct AutoSaveConfig {
    /// Whether auto-save is active.
    pub enabled: bool,
    /// Save every this many optimization steps.
    pub frequency: usize,
    /// Oldest auto-saves are dropped beyond this count.
    pub max_auto_saves: usize,
}

impl Default for AutoSaveConfig {
    /// Defaults: enabled, save every 100 steps, keep the 5 most recent.
    fn default() -> Self {
        AutoSaveConfig {
            enabled: true,
            frequency: 100,
            max_auto_saves: 5,
        }
    }
}
1258
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created state starts at version 0 with non-empty parameters.
    #[test]
    fn test_optimizer_state_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let state = TransformerOptimizerState::new(&config);
        assert!(state.is_ok());

        let s = state.unwrap();
        assert_eq!(s.version, 0);
        assert!(!s.current_parameters.is_empty());
    }

    /// Applying one update step bumps the state version.
    #[test]
    fn test_state_update() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let mut state = TransformerOptimizerState::new(&config).unwrap();

        let update = Array1::<f32>::ones(state.current_parameters.len());
        let result = state.update_with_step(&update, Some(1.5));
        assert!(result.is_ok());
        assert_eq!(state.version, 1);
    }

    /// A snapshot mirrors the state's version and parameter length.
    #[test]
    fn test_snapshot_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let state = TransformerOptimizerState::new(&config).unwrap();

        let snapshot = state.create_snapshot();
        assert!(snapshot.is_ok());

        let snap = snapshot.unwrap();
        assert_eq!(snap.version, 0);
        assert_eq!(snap.parameters.len(), state.current_parameters.len());
    }

    /// A checkpoint can be saved and then loaded back by its returned id.
    #[test]
    fn test_checkpoint_management() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let mut state = TransformerOptimizerState::new(&config).unwrap();

        let checkpoint_id = state.save_checkpoint("test_checkpoint".to_string());
        assert!(checkpoint_id.is_ok());

        let id = checkpoint_id.unwrap();
        let load_result = state.load_checkpoint(&id);
        assert!(load_result.is_ok());
    }

    /// Recorded parameters are retrievable from the history buffer.
    #[test]
    fn test_parameter_history() {
        let history = ParameterHistory::<f32>::new(10, 5);
        assert!(history.is_ok());

        let mut h = history.unwrap();
        let params = Array1::<f32>::ones(5);
        // Bug fix: `&params` had been mangled into `¶ms` by a bad
        // encoding pass (`&para` -> `¶`), which did not compile.
        assert!(h.record_parameters(&params).is_ok());

        let recent = h.get_recent_parameters(1);
        assert_eq!(recent.len(), 1);
    }

    /// A strictly decreasing loss series yields positive convergence
    /// and a stability score in (0, 1].
    #[test]
    fn test_convergence_tracker() {
        let mut tracker = ConvergenceTracker::<f32>::new();

        tracker.record_loss(2.0);
        tracker.record_loss(1.5);
        tracker.record_loss(1.0);

        let convergence = tracker.get_convergence_rate();
        assert!(convergence > 0.0);

        let stability = tracker.get_stability_score();
        assert!(stability > 0.0 && stability <= 1.0);
    }

    /// A pristine state passes validation with a valid report.
    #[test]
    fn test_state_validation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let state = TransformerOptimizerState::new(&config).unwrap();

        let validation = state.validate_state();
        assert!(validation.is_ok());

        let report = validation.unwrap();
        assert!(report.is_valid);
    }
}