1use super::config::TransformerBasedOptimizerConfig;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
6use scirs2_core::numeric::Float;
7use std::collections::{HashMap, VecDeque};
8use std::fmt::Debug;
9
/// Bundle of recent sequence data returned by the buffer:
/// (gradient matrices, parameter matrices, loss vectors), newest first.
type SequenceDataTuple<T> = (Vec<Array2<T>>, Vec<Array2<T>>, Vec<Array1<T>>);
12
/// Strategy used to reduce an arbitrarily long optimization history to a
/// sequence that fits within the processor's maximum sequence length.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SequenceProcessingStrategy {
    /// Overlapping fixed-size windows over the whole history.
    SlidingWindow,
    /// Repeated 2x downsampling until the sequence fits.
    Hierarchical,
    /// Keep only the timesteps with the highest importance scores.
    AttentionBased,
    /// Segment at loss change points and process each segment separately.
    Adaptive,
    /// Keep only the most recent timesteps (truncated-BPTT style).
    TruncatedBPTT,
}
27
/// Processes optimization trajectories (gradient / parameter / loss
/// histories) into fixed-size feature matrices for a transformer-based
/// optimizer.
pub struct OptimizationSequenceProcessor<T: Float + Debug + Send + Sync + 'static> {
    /// Active strategy used by `process_optimization_sequence`.
    strategy: SequenceProcessingStrategy,
    /// Hard cap on the number of timesteps in any produced sequence.
    max_sequence_length: usize,
    /// Sliding-window length (half of `max_sequence_length` at construction).
    window_size: usize,
    /// Rows shared between consecutive sliding windows (quarter window).
    window_overlap: usize,
    /// Target feature width used by the preprocessor.
    model_dimension: usize,
    /// Bounded history of recently observed sequences.
    sequence_buffer: SequenceBuffer<T>,
    /// Running statistics over everything processed so far.
    statistics: SequenceStatistics<T>,
    /// Packs raw histories into combined feature matrices.
    preprocessor: SequencePreprocessor<T>,
    /// Helper for splitting long sequences into overlapping chunks.
    chunking: ChunkingStrategy<T>,
}
57
58impl<T: Float + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
59 OptimizationSequenceProcessor<T>
60{
61 pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
63 let strategy = SequenceProcessingStrategy::SlidingWindow;
64 let max_sequence_length = config.sequence_length;
65 let window_size = max_sequence_length / 2;
66 let window_overlap = window_size / 4;
67 let model_dimension = config.model_dimension;
68
69 let sequence_buffer = SequenceBuffer::new(1000, model_dimension)?;
70 let statistics = SequenceStatistics::new();
71 let preprocessor = SequencePreprocessor::new(model_dimension)?;
72 let chunking = ChunkingStrategy::new(max_sequence_length, window_size)?;
73
74 Ok(Self {
75 strategy,
76 max_sequence_length,
77 window_size,
78 window_overlap,
79 model_dimension,
80 sequence_buffer,
81 statistics,
82 preprocessor,
83 chunking,
84 })
85 }
86
87 pub fn process_optimization_sequence(
89 &mut self,
90 gradient_history: &Array2<T>,
91 parameter_history: &Array2<T>,
92 loss_history: &Array1<T>,
93 ) -> Result<Array2<T>> {
94 self.statistics
96 .update(gradient_history, parameter_history, loss_history)?;
97
98 self.sequence_buffer
100 .add_sequence(gradient_history, parameter_history, loss_history)?;
101
102 match self.strategy {
103 SequenceProcessingStrategy::SlidingWindow => {
104 self.process_sliding_window(gradient_history, parameter_history, loss_history)
105 }
106 SequenceProcessingStrategy::Hierarchical => {
107 self.process_hierarchical(gradient_history, parameter_history, loss_history)
108 }
109 SequenceProcessingStrategy::AttentionBased => {
110 self.process_attention_based(gradient_history, parameter_history, loss_history)
111 }
112 SequenceProcessingStrategy::Adaptive => {
113 self.process_adaptive(gradient_history, parameter_history, loss_history)
114 }
115 SequenceProcessingStrategy::TruncatedBPTT => {
116 self.process_truncated_bptt(gradient_history, parameter_history, loss_history)
117 }
118 }
119 }
120
121 fn process_sliding_window(
123 &mut self,
124 gradient_history: &Array2<T>,
125 parameter_history: &Array2<T>,
126 loss_history: &Array1<T>,
127 ) -> Result<Array2<T>> {
128 let sequence_length = gradient_history.shape()[0];
129
130 if sequence_length <= self.max_sequence_length {
131 return self.combine_sequences(gradient_history, parameter_history, loss_history);
133 }
134
135 let mut processed_chunks = Vec::new();
137 let step_size = self.window_size - self.window_overlap;
138
139 for start in (0..sequence_length).step_by(step_size) {
140 let end = (start + self.window_size).min(sequence_length);
141
142 let grad_chunk = gradient_history.slice(s![start..end, ..]);
143 let param_chunk = parameter_history.slice(s![start..end, ..]);
144 let loss_chunk = loss_history.slice(s![start..end]);
145
146 let processed_chunk = self.combine_sequences(
147 &grad_chunk.to_owned(),
148 ¶m_chunk.to_owned(),
149 &loss_chunk.to_owned(),
150 )?;
151
152 processed_chunks.push(processed_chunk);
153 }
154
155 self.combine_chunks(&processed_chunks)
157 }
158
159 fn process_hierarchical(
161 &mut self,
162 gradient_history: &Array2<T>,
163 parameter_history: &Array2<T>,
164 loss_history: &Array1<T>,
165 ) -> Result<Array2<T>> {
166 let sequence_length = gradient_history.shape()[0];
167
168 let mut levels = Vec::new();
170 let mut current_level =
171 self.combine_sequences(gradient_history, parameter_history, loss_history)?;
172 levels.push(current_level.clone());
173
174 let mut current_length = sequence_length;
176 while current_length > self.max_sequence_length {
177 current_length /= 2;
178 current_level = self.downsample_sequence(¤t_level, current_length)?;
179 levels.push(current_level.clone());
180 }
181
182 Ok(levels.last().expect("unwrap failed").clone())
184 }
185
186 fn process_attention_based(
188 &mut self,
189 gradient_history: &Array2<T>,
190 parameter_history: &Array2<T>,
191 loss_history: &Array1<T>,
192 ) -> Result<Array2<T>> {
193 let sequence_length = gradient_history.shape()[0];
194
195 if sequence_length <= self.max_sequence_length {
196 return self.combine_sequences(gradient_history, parameter_history, loss_history);
197 }
198
199 let importance_scores = self.compute_importance_scores(gradient_history, loss_history)?;
201
202 let selected_indices =
204 self.select_top_k_indices(&importance_scores, self.max_sequence_length)?;
205
206 let mut selected_gradients =
208 Array2::zeros((self.max_sequence_length, gradient_history.shape()[1]));
209 let mut selected_parameters =
210 Array2::zeros((self.max_sequence_length, parameter_history.shape()[1]));
211 let mut selected_losses = Array1::zeros(self.max_sequence_length);
212
213 for (i, &idx) in selected_indices.iter().enumerate() {
214 selected_gradients
215 .row_mut(i)
216 .assign(&gradient_history.row(idx));
217 selected_parameters
218 .row_mut(i)
219 .assign(¶meter_history.row(idx));
220 selected_losses[i] = loss_history[idx];
221 }
222
223 self.combine_sequences(&selected_gradients, &selected_parameters, &selected_losses)
224 }
225
226 fn process_adaptive(
228 &mut self,
229 gradient_history: &Array2<T>,
230 parameter_history: &Array2<T>,
231 loss_history: &Array1<T>,
232 ) -> Result<Array2<T>> {
233 let change_points = self.detect_change_points(loss_history)?;
235
236 let segments = self.segment_by_change_points(
238 gradient_history,
239 parameter_history,
240 loss_history,
241 &change_points,
242 )?;
243
244 let mut processed_segments = Vec::new();
246 for segment in segments {
247 let processed =
248 self.combine_sequences(&segment.gradients, &segment.parameters, &segment.losses)?;
249 processed_segments.push(processed);
250 }
251
252 self.combine_chunks(&processed_segments)
253 }
254
255 fn process_truncated_bptt(
257 &mut self,
258 gradient_history: &Array2<T>,
259 parameter_history: &Array2<T>,
260 loss_history: &Array1<T>,
261 ) -> Result<Array2<T>> {
262 let sequence_length = gradient_history.shape()[0];
263
264 if sequence_length <= self.max_sequence_length {
265 return self.combine_sequences(gradient_history, parameter_history, loss_history);
266 }
267
268 let start_idx = sequence_length - self.max_sequence_length;
270 let grad_chunk = gradient_history.slice(s![start_idx.., ..]);
271 let param_chunk = parameter_history.slice(s![start_idx.., ..]);
272 let loss_chunk = loss_history.slice(s![start_idx..]);
273
274 self.combine_sequences(
275 &grad_chunk.to_owned(),
276 ¶m_chunk.to_owned(),
277 &loss_chunk.to_owned(),
278 )
279 }
280
281 pub fn trajectory_to_sequences(
283 &mut self,
284 trajectory: &super::OptimizationTrajectory<T>,
285 ) -> Result<Vec<super::TrainingSequence<T>>> {
286 let sequence_length = trajectory.gradient_sequence.shape()[0];
287 let mut sequences = Vec::new();
288
289 for start in (0..sequence_length).step_by(self.window_size / 2) {
291 let end = (start + self.window_size).min(sequence_length);
292
293 if end - start < self.window_size / 2 {
294 break; }
296
297 let input_end = end - 1;
298 let target_start = start + 1;
299
300 let input_gradients = trajectory.gradient_sequence.slice(s![start..input_end, ..]);
301 let input_parameters = trajectory
302 .parameter_sequence
303 .slice(s![start..input_end, ..]);
304 let input_losses = trajectory.loss_sequence.slice(s![start..input_end]);
305
306 let target_gradients = trajectory
307 .gradient_sequence
308 .slice(s![target_start..end, ..]);
309 let target_parameters = trajectory
310 .parameter_sequence
311 .slice(s![target_start..end, ..]);
312 let target_losses = trajectory.loss_sequence.slice(s![target_start..end]);
313
314 let input = self.combine_sequences(
315 &input_gradients.to_owned(),
316 &input_parameters.to_owned(),
317 &input_losses.to_owned(),
318 )?;
319
320 let target = self.combine_sequences(
321 &target_gradients.to_owned(),
322 &target_parameters.to_owned(),
323 &target_losses.to_owned(),
324 )?;
325
326 sequences.push(super::TrainingSequence {
327 input,
328 target,
329 sequence_length: input_end - start,
330 });
331 }
332
333 Ok(sequences)
334 }
335
    /// Delegates to the preprocessor, which packs gradients, parameters and
    /// loss features into a single `(time, feature)` matrix.
    fn combine_sequences(
        &self,
        gradients: &Array2<T>,
        parameters: &Array2<T>,
        losses: &Array1<T>,
    ) -> Result<Array2<T>> {
        self.preprocessor
            .combine_sequences(gradients, parameters, losses)
    }
346
347 fn combine_chunks(&self, chunks: &[Array2<T>]) -> Result<Array2<T>> {
349 if chunks.is_empty() {
350 return Err(crate::error::OptimError::Other(
351 "No chunks to combine".to_string(),
352 ));
353 }
354
355 if chunks.len() == 1 {
356 return Ok(chunks[0].clone());
357 }
358
359 let total_length: usize = chunks.iter().map(|chunk| chunk.shape()[0]).sum();
361 let feature_dim = chunks[0].shape()[1];
362
363 let mut combined = Array2::zeros((total_length.min(self.max_sequence_length), feature_dim));
364 let mut current_pos = 0;
365
366 for chunk in chunks {
367 let chunk_len = chunk.shape()[0];
368 let copy_len = (chunk_len).min(self.max_sequence_length - current_pos);
369
370 if copy_len == 0 {
371 break;
372 }
373
374 combined
375 .slice_mut(s![current_pos..current_pos + copy_len, ..])
376 .assign(&chunk.slice(s![..copy_len, ..]));
377
378 current_pos += copy_len;
379
380 if current_pos >= self.max_sequence_length {
381 break;
382 }
383 }
384
385 Ok(combined)
386 }
387
388 fn downsample_sequence(&self, sequence: &Array2<T>, target_length: usize) -> Result<Array2<T>> {
390 let current_length = sequence.shape()[0];
391 let feature_dim = sequence.shape()[1];
392
393 if current_length <= target_length {
394 return Ok(sequence.clone());
395 }
396
397 let mut downsampled = Array2::zeros((target_length, feature_dim));
398 let step = current_length as f64 / target_length as f64;
399
400 for i in 0..target_length {
401 let source_idx = (i as f64 * step) as usize;
402 downsampled
403 .row_mut(i)
404 .assign(&sequence.row(source_idx.min(current_length - 1)));
405 }
406
407 Ok(downsampled)
408 }
409
410 fn compute_importance_scores(
412 &self,
413 gradients: &Array2<T>,
414 losses: &Array1<T>,
415 ) -> Result<Array1<T>> {
416 let sequence_length = gradients.shape()[0];
417 let mut scores = Array1::zeros(sequence_length);
418
419 for i in 0..sequence_length {
420 let grad_norm = gradients
422 .row(i)
423 .iter()
424 .map(|&x| x * x)
425 .fold(T::zero(), |acc, x| acc + x)
426 .sqrt();
427
428 let loss_change = if i > 0 {
429 (losses[i] - losses[i - 1]).abs()
430 } else {
431 T::zero()
432 };
433
434 scores[i] = grad_norm + loss_change;
435 }
436
437 Ok(scores)
438 }
439
440 fn select_top_k_indices(&self, scores: &Array1<T>, k: usize) -> Result<Vec<usize>> {
442 let mut indexed_scores: Vec<(usize, T)> = scores
443 .iter()
444 .enumerate()
445 .map(|(i, &score)| (i, score))
446 .collect();
447
448 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("unwrap failed"));
449
450 let mut selected_indices: Vec<usize> = indexed_scores
451 .into_iter()
452 .take(k)
453 .map(|(idx, _)| idx)
454 .collect();
455
456 selected_indices.sort();
457 Ok(selected_indices)
458 }
459
460 fn detect_change_points(&self, losses: &Array1<T>) -> Result<Vec<usize>> {
462 let mut change_points = vec![0]; let window_size = 5;
464 let threshold = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
465
466 for i in window_size..losses.len() - window_size {
467 let before_mean = losses
468 .slice(s![i - window_size..i])
469 .mean()
470 .expect("unwrap failed");
471 let after_mean = losses
472 .slice(s![i..i + window_size])
473 .mean()
474 .expect("unwrap failed");
475
476 if (before_mean - after_mean).abs() > threshold {
477 change_points.push(i);
478 }
479 }
480
481 change_points.push(losses.len() - 1); Ok(change_points)
483 }
484
485 fn segment_by_change_points(
487 &self,
488 gradients: &Array2<T>,
489 parameters: &Array2<T>,
490 losses: &Array1<T>,
491 change_points: &[usize],
492 ) -> Result<Vec<SequenceSegment<T>>> {
493 let mut segments = Vec::new();
494
495 for i in 0..change_points.len() - 1 {
496 let start = change_points[i];
497 let end = change_points[i + 1];
498
499 let segment = SequenceSegment {
500 gradients: gradients.slice(s![start..end, ..]).to_owned(),
501 parameters: parameters.slice(s![start..end, ..]).to_owned(),
502 losses: losses.slice(s![start..end]).to_owned(),
503 start_index: start,
504 end_index: end,
505 };
506
507 segments.push(segment);
508 }
509
510 Ok(segments)
511 }
512
    /// Switches the active sequence-processing strategy.
    pub fn set_strategy(&mut self, strategy: SequenceProcessingStrategy) {
        self.strategy = strategy;
    }

    /// Returns the currently active strategy.
    pub fn get_strategy(&self) -> SequenceProcessingStrategy {
        self.strategy
    }

    /// Read-only access to the accumulated sequence statistics.
    pub fn get_statistics(&self) -> &SequenceStatistics<T> {
        &self.statistics
    }

    /// Clears the replay buffer and resets all statistics accumulators.
    pub fn reset(&mut self) -> Result<()> {
        self.sequence_buffer.clear();
        self.statistics.reset();
        Ok(())
    }
534}
535
/// Bounded FIFO store of recently observed optimization sequences.
pub struct SequenceBuffer<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient matrices, oldest at the front.
    gradient_buffer: VecDeque<Array2<T>>,
    /// Parameter matrices, kept in lockstep with `gradient_buffer`.
    parameter_buffer: VecDeque<Array2<T>>,
    /// Loss vectors, kept in lockstep with the other buffers.
    loss_buffer: VecDeque<Array1<T>>,
    /// Maximum number of stored sequences before eviction.
    max_size: usize,
    /// Model feature width; stored but not read by the buffer itself.
    model_dimension: usize,
}
553
554impl<T: Float + Debug + Send + Sync + 'static> SequenceBuffer<T> {
555 pub fn new(max_size: usize, model_dimension: usize) -> Result<Self> {
556 Ok(Self {
557 gradient_buffer: VecDeque::new(),
558 parameter_buffer: VecDeque::new(),
559 loss_buffer: VecDeque::new(),
560 max_size,
561 model_dimension,
562 })
563 }
564
565 pub fn add_sequence(
566 &mut self,
567 gradients: &Array2<T>,
568 parameters: &Array2<T>,
569 losses: &Array1<T>,
570 ) -> Result<()> {
571 self.gradient_buffer.push_back(gradients.clone());
572 self.parameter_buffer.push_back(parameters.clone());
573 self.loss_buffer.push_back(losses.clone());
574
575 while self.gradient_buffer.len() > self.max_size {
576 self.gradient_buffer.pop_front();
577 self.parameter_buffer.pop_front();
578 self.loss_buffer.pop_front();
579 }
580
581 Ok(())
582 }
583
584 pub fn clear(&mut self) {
585 self.gradient_buffer.clear();
586 self.parameter_buffer.clear();
587 self.loss_buffer.clear();
588 }
589
590 pub fn get_recent_sequences(&self, count: usize) -> SequenceDataTuple<T> {
591 let actual_count = count.min(self.gradient_buffer.len());
592
593 let gradients = self
594 .gradient_buffer
595 .iter()
596 .rev()
597 .take(actual_count)
598 .cloned()
599 .collect();
600 let parameters = self
601 .parameter_buffer
602 .iter()
603 .rev()
604 .take(actual_count)
605 .cloned()
606 .collect();
607 let losses = self
608 .loss_buffer
609 .iter()
610 .rev()
611 .take(actual_count)
612 .cloned()
613 .collect();
614
615 (gradients, parameters, losses)
616 }
617}
618
/// Running statistics over all gradients, parameters, losses and sequence
/// lengths seen by the processor.
pub struct SequenceStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Element-wise statistics over observed gradient values.
    gradient_stats: StatisticsAccumulator<T>,
    /// Element-wise statistics over observed parameter values.
    parameter_stats: StatisticsAccumulator<T>,
    /// Statistics over observed loss values.
    loss_stats: StatisticsAccumulator<T>,
    /// Statistics over the lengths of processed sequences.
    length_stats: StatisticsAccumulator<T>,
}
633
impl<T: Float + Debug + Send + Sync + 'static> Default for SequenceStatistics<T> {
    /// Equivalent to [`SequenceStatistics::new`].
    fn default() -> Self {
        Self::new()
    }
}
639
640impl<T: Float + Debug + Send + Sync + 'static> SequenceStatistics<T> {
641 pub fn new() -> Self {
642 Self {
643 gradient_stats: StatisticsAccumulator::new(),
644 parameter_stats: StatisticsAccumulator::new(),
645 loss_stats: StatisticsAccumulator::new(),
646 length_stats: StatisticsAccumulator::new(),
647 }
648 }
649
650 pub fn update(
651 &mut self,
652 gradients: &Array2<T>,
653 parameters: &Array2<T>,
654 losses: &Array1<T>,
655 ) -> Result<()> {
656 self.gradient_stats.update_from_array2(gradients);
657 self.parameter_stats.update_from_array2(parameters);
658 self.loss_stats.update_from_array1(losses);
659 self.length_stats
660 .update(T::from(gradients.shape()[0]).expect("unwrap failed"));
661
662 Ok(())
663 }
664
665 pub fn reset(&mut self) {
666 self.gradient_stats.reset();
667 self.parameter_stats.reset();
668 self.loss_stats.reset();
669 self.length_stats.reset();
670 }
671
672 pub fn get_gradient_stats(&self) -> &StatisticsAccumulator<T> {
673 &self.gradient_stats
674 }
675
676 pub fn get_parameter_stats(&self) -> &StatisticsAccumulator<T> {
677 &self.parameter_stats
678 }
679
680 pub fn get_loss_stats(&self) -> &StatisticsAccumulator<T> {
681 &self.loss_stats
682 }
683}
684
/// Turns raw (gradients, parameters, losses) histories into combined
/// feature matrices sized for the model.
pub struct SequencePreprocessor<T: Float + Debug + Send + Sync + 'static> {
    /// Target feature width of produced matrices.
    model_dimension: usize,
    /// Per-feature (mean, std) normalization parameters.
    /// NOTE(review): never read or written anywhere in this file — confirm
    /// whether normalization is implemented elsewhere or this is dead state.
    normalization_stats: HashMap<String, (T, T)>,
}
693
impl<T: Float + Debug + Send + Sync + 'static> SequencePreprocessor<T> {
    /// Creates a preprocessor targeting `model_dimension`-wide embeddings.
    pub fn new(model_dimension: usize) -> Result<Self> {
        Ok(Self {
            model_dimension,
            normalization_stats: HashMap::new(),
        })
    }

    /// Packs gradients, parameters and loss features into a single
    /// `(sequence_length, feature)` matrix.
    ///
    /// Per-timestep layout: up to `model_dimension / 3` gradient components,
    /// then up to `model_dimension / 3` parameter components, then the raw
    /// loss, then the loss normalized by the first loss value. The overall
    /// width is capped at `model_dimension`; features that do not fit are
    /// silently dropped, and unfilled slots stay zero.
    pub fn combine_sequences(
        &self,
        gradients: &Array2<T>,
        parameters: &Array2<T>,
        losses: &Array1<T>,
    ) -> Result<Array2<T>> {
        let sequence_length = gradients.shape()[0];
        let grad_dim = gradients.shape()[1];
        let param_dim = parameters.shape()[1];

        // +2 for the raw-loss and normalized-loss slots.
        let feature_dim = grad_dim + param_dim + 2;
        let mut combined = Array2::zeros((sequence_length, feature_dim.min(self.model_dimension)));

        for i in 0..sequence_length {
            let mut feature_idx = 0;

            // Gradient components, capped to a third of the model width.
            for j in 0..grad_dim.min(self.model_dimension / 3) {
                if feature_idx < combined.shape()[1] {
                    combined[[i, feature_idx]] = gradients[[i, j]];
                    feature_idx += 1;
                }
            }

            // Parameter components, also capped to a third of the width.
            for j in 0..param_dim.min(self.model_dimension / 3) {
                if feature_idx < combined.shape()[1] {
                    combined[[i, feature_idx]] = parameters[[i, j]];
                    feature_idx += 1;
                }
            }

            // Raw loss value for this timestep.
            if feature_idx < combined.shape()[1] {
                combined[[i, feature_idx]] = losses[i];
                feature_idx += 1;
            }

            // Loss relative to the first step; guarded against division by
            // zero. NOTE(review): row 0's slot is deliberately skipped
            // (`i > 0`) and stays zero — confirm this is intentional.
            if feature_idx < combined.shape()[1] && i > 0 {
                let normalized_loss = if losses[0] != T::zero() {
                    losses[i] / losses[0]
                } else {
                    T::one()
                };
                combined[[i, feature_idx]] = normalized_loss;
            }
        }

        Ok(combined)
    }
}
755
/// Splits long sequences into overlapping fixed-size chunks.
pub struct ChunkingStrategy<T: Float + Debug + Send + Sync + 'static> {
    /// Upper bound on the number of rows per chunk.
    max_chunk_size: usize,
    /// Rows shared between consecutive chunks.
    overlap_size: usize,
    /// Chunk-size statistics.
    /// NOTE(review): never updated in this file — confirm intent.
    chunk_stats: StatisticsAccumulator<T>,
}
767
768impl<T: Float + Debug + Send + Sync + 'static> ChunkingStrategy<T> {
769 pub fn new(max_chunk_size: usize, overlap_size: usize) -> Result<Self> {
770 Ok(Self {
771 max_chunk_size,
772 overlap_size,
773 chunk_stats: StatisticsAccumulator::new(),
774 })
775 }
776
777 pub fn create_chunks(&mut self, sequence: &Array2<T>) -> Result<Vec<Array2<T>>> {
778 let sequence_length = sequence.shape()[0];
779 let mut chunks = Vec::new();
780
781 if sequence_length <= self.max_chunk_size {
782 chunks.push(sequence.clone());
783 return Ok(chunks);
784 }
785
786 let step_size = self.max_chunk_size - self.overlap_size;
787
788 for start in (0..sequence_length).step_by(step_size) {
789 let end = (start + self.max_chunk_size).min(sequence_length);
790 let chunk = sequence.slice(s![start..end, ..]).to_owned();
791 chunks.push(chunk);
792
793 if end >= sequence_length {
794 break;
795 }
796 }
797
798 Ok(chunks)
799 }
800}
801
/// One contiguous slice of an optimization history, produced by
/// change-point segmentation.
#[derive(Debug, Clone)]
pub struct SequenceSegment<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient rows for this segment.
    pub gradients: Array2<T>,
    /// Parameter rows for this segment.
    pub parameters: Array2<T>,
    /// Loss values for this segment.
    pub losses: Array1<T>,
    /// Inclusive start index within the original history.
    pub start_index: usize,
    /// Exclusive end index within the original history.
    pub end_index: usize,
}
811
/// Streaming accumulator of count / sum / sum-of-squares / min / max, from
/// which mean, variance and standard deviation are derived.
pub struct StatisticsAccumulator<T: Float + Debug + Send + Sync + 'static> {
    /// Number of observations folded in so far.
    count: usize,
    /// Running sum of observations.
    sum: T,
    /// Running sum of squared observations (for variance).
    sum_sq: T,
    /// Smallest observation seen (+inf when empty).
    min: T,
    /// Largest observation seen (-inf when empty).
    max: T,
}
819
impl<T: Float + Debug + Send + Sync + 'static> Default for StatisticsAccumulator<T> {
    /// Equivalent to [`StatisticsAccumulator::new`].
    fn default() -> Self {
        Self::new()
    }
}
825
826impl<T: Float + Debug + Send + Sync + 'static> StatisticsAccumulator<T> {
827 pub fn new() -> Self {
828 Self {
829 count: 0,
830 sum: T::zero(),
831 sum_sq: T::zero(),
832 min: T::infinity(),
833 max: T::neg_infinity(),
834 }
835 }
836
837 pub fn update(&mut self, value: T) {
838 self.count += 1;
839 self.sum = self.sum + value;
840 self.sum_sq = self.sum_sq + value * value;
841 self.min = self.min.min(value);
842 self.max = self.max.max(value);
843 }
844
845 pub fn update_from_array1(&mut self, array: &Array1<T>) {
846 for &value in array.iter() {
847 self.update(value);
848 }
849 }
850
851 pub fn update_from_array2(&mut self, array: &Array2<T>) {
852 for &value in array.iter() {
853 self.update(value);
854 }
855 }
856
857 pub fn mean(&self) -> T {
858 if self.count > 0 {
859 self.sum / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero())
860 } else {
861 T::zero()
862 }
863 }
864
865 pub fn variance(&self) -> T {
866 if self.count > 1 {
867 let mean = self.mean();
868 (self.sum_sq
869 / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero()))
870 - (mean * mean)
871 } else {
872 T::zero()
873 }
874 }
875
876 pub fn std_dev(&self) -> T {
877 self.variance().sqrt()
878 }
879
880 pub fn min(&self) -> T {
881 self.min
882 }
883
884 pub fn max(&self) -> T {
885 self.max
886 }
887
888 pub fn reset(&mut self) {
889 self.count = 0;
890 self.sum = T::zero();
891 self.sum_sq = T::zero();
892 self.min = T::infinity();
893 self.max = T::neg_infinity();
894 }
895}
896
897use scirs2_core::ndarray::s;
899
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a processor from the default config should succeed.
    #[test]
    fn test_sequence_processor_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let processor = OptimizationSequenceProcessor::new(&config);
        assert!(processor.is_ok());
    }

    /// Adding a shape-consistent sequence to the buffer should succeed.
    #[test]
    fn test_sequence_buffer() {
        let buffer = SequenceBuffer::<f32>::new(10, 64);
        assert!(buffer.is_ok());

        let mut buf = buffer.expect("unwrap failed");
        let gradients = Array2::<f32>::ones((5, 64));
        let parameters = Array2::<f32>::ones((5, 64));
        let losses = Array1::<f32>::ones(5);

        assert!(buf.add_sequence(&gradients, &parameters, &losses).is_ok());
    }

    /// Updating statistics with all-ones data should yield a positive mean.
    #[test]
    fn test_sequence_statistics() {
        let mut stats = SequenceStatistics::<f32>::new();

        let gradients = Array2::<f32>::ones((10, 5));
        let parameters = Array2::<f32>::ones((10, 5));
        let losses = Array1::<f32>::ones(10);

        assert!(stats.update(&gradients, &parameters, &losses).is_ok());
        assert!(stats.get_gradient_stats().mean() > 0.0);
    }

    /// Mean/min/max/std-dev over {1, 2, 3} should match hand-computed values.
    #[test]
    fn test_statistics_accumulator() {
        let mut acc = StatisticsAccumulator::<f32>::new();

        acc.update(1.0);
        acc.update(2.0);
        acc.update(3.0);

        assert_eq!(acc.mean(), 2.0);
        assert!(acc.std_dev() > 0.0);
        assert_eq!(acc.min(), 1.0);
        assert_eq!(acc.max(), 3.0);
    }

    /// A 25-row sequence with chunk size 10 must produce multiple chunks.
    #[test]
    fn test_chunking_strategy() {
        let chunking = ChunkingStrategy::<f32>::new(10, 2);
        assert!(chunking.is_ok());

        let mut strategy = chunking.expect("unwrap failed");
        let sequence = Array2::<f32>::ones((25, 5));

        let chunks = strategy.create_chunks(&sequence);
        assert!(chunks.is_ok());

        let chunk_vec = chunks.expect("unwrap failed");
        assert!(chunk_vec.len() > 1);
    }
}