1use super::config::TransformerBasedOptimizerConfig;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
6use scirs2_core::numeric::Float;
7use std::collections::{HashMap, VecDeque};
8use std::fmt::Debug;
9
10type SequenceDataTuple<T> = (Vec<Array2<T>>, Vec<Array2<T>>, Vec<Array1<T>>);
12
/// Strategy used to reduce a long optimization history to a length the
/// downstream model can consume.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SequenceProcessingStrategy {
    /// Overlapping fixed-size windows over the full sequence.
    SlidingWindow,
    /// Repeated 2x downsampling until the sequence fits.
    Hierarchical,
    /// Keep only the time steps with the highest importance scores.
    AttentionBased,
    /// Segment at detected loss change points and process each segment.
    Adaptive,
    /// Keep only the most recent `max_sequence_length` steps.
    TruncatedBPTT,
}
27
/// Converts raw optimization histories (gradients, parameters, losses) into
/// fixed-width feature sequences for a transformer-based optimizer.
pub struct OptimizationSequenceProcessor<T: Float + Debug + Send + Sync + 'static> {
    /// Active length-reduction strategy.
    strategy: SequenceProcessingStrategy,

    /// Upper bound on the processed sequence length.
    max_sequence_length: usize,

    /// Window length used by the sliding-window strategy.
    window_size: usize,

    /// Overlap between consecutive sliding windows.
    window_overlap: usize,

    /// Feature width of the downstream model.
    model_dimension: usize,

    /// Bounded history of recently observed sequences.
    sequence_buffer: SequenceBuffer<T>,

    /// Running statistics over all observed sequences.
    statistics: SequenceStatistics<T>,

    /// Combines gradients/parameters/losses into feature rows.
    preprocessor: SequencePreprocessor<T>,

    /// Chunking helper.
    /// NOTE(review): not exercised by the visible processing paths — confirm
    /// whether it is used elsewhere or is dead state.
    chunking: ChunkingStrategy<T>,
}
57
58impl<T: Float + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
59 OptimizationSequenceProcessor<T>
60{
61 pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
63 let strategy = SequenceProcessingStrategy::SlidingWindow;
64 let max_sequence_length = config.sequence_length;
65 let window_size = max_sequence_length / 2;
66 let window_overlap = window_size / 4;
67 let model_dimension = config.model_dimension;
68
69 let sequence_buffer = SequenceBuffer::new(1000, model_dimension)?;
70 let statistics = SequenceStatistics::new();
71 let preprocessor = SequencePreprocessor::new(model_dimension)?;
72 let chunking = ChunkingStrategy::new(max_sequence_length, window_size)?;
73
74 Ok(Self {
75 strategy,
76 max_sequence_length,
77 window_size,
78 window_overlap,
79 model_dimension,
80 sequence_buffer,
81 statistics,
82 preprocessor,
83 chunking,
84 })
85 }
86
87 pub fn process_optimization_sequence(
89 &mut self,
90 gradient_history: &Array2<T>,
91 parameter_history: &Array2<T>,
92 loss_history: &Array1<T>,
93 ) -> Result<Array2<T>> {
94 self.statistics
96 .update(gradient_history, parameter_history, loss_history)?;
97
98 self.sequence_buffer
100 .add_sequence(gradient_history, parameter_history, loss_history)?;
101
102 match self.strategy {
103 SequenceProcessingStrategy::SlidingWindow => {
104 self.process_sliding_window(gradient_history, parameter_history, loss_history)
105 }
106 SequenceProcessingStrategy::Hierarchical => {
107 self.process_hierarchical(gradient_history, parameter_history, loss_history)
108 }
109 SequenceProcessingStrategy::AttentionBased => {
110 self.process_attention_based(gradient_history, parameter_history, loss_history)
111 }
112 SequenceProcessingStrategy::Adaptive => {
113 self.process_adaptive(gradient_history, parameter_history, loss_history)
114 }
115 SequenceProcessingStrategy::TruncatedBPTT => {
116 self.process_truncated_bptt(gradient_history, parameter_history, loss_history)
117 }
118 }
119 }
120
121 fn process_sliding_window(
123 &mut self,
124 gradient_history: &Array2<T>,
125 parameter_history: &Array2<T>,
126 loss_history: &Array1<T>,
127 ) -> Result<Array2<T>> {
128 let sequence_length = gradient_history.shape()[0];
129
130 if sequence_length <= self.max_sequence_length {
131 return self.combine_sequences(gradient_history, parameter_history, loss_history);
133 }
134
135 let mut processed_chunks = Vec::new();
137 let step_size = self.window_size - self.window_overlap;
138
139 for start in (0..sequence_length).step_by(step_size) {
140 let end = (start + self.window_size).min(sequence_length);
141
142 let grad_chunk = gradient_history.slice(s![start..end, ..]);
143 let param_chunk = parameter_history.slice(s![start..end, ..]);
144 let loss_chunk = loss_history.slice(s![start..end]);
145
146 let processed_chunk = self.combine_sequences(
147 &grad_chunk.to_owned(),
148 ¶m_chunk.to_owned(),
149 &loss_chunk.to_owned(),
150 )?;
151
152 processed_chunks.push(processed_chunk);
153 }
154
155 self.combine_chunks(&processed_chunks)
157 }
158
159 fn process_hierarchical(
161 &mut self,
162 gradient_history: &Array2<T>,
163 parameter_history: &Array2<T>,
164 loss_history: &Array1<T>,
165 ) -> Result<Array2<T>> {
166 let sequence_length = gradient_history.shape()[0];
167
168 let mut levels = Vec::new();
170 let mut current_level =
171 self.combine_sequences(gradient_history, parameter_history, loss_history)?;
172 levels.push(current_level.clone());
173
174 let mut current_length = sequence_length;
176 while current_length > self.max_sequence_length {
177 current_length /= 2;
178 current_level = self.downsample_sequence(¤t_level, current_length)?;
179 levels.push(current_level.clone());
180 }
181
182 Ok(levels.last().unwrap().clone())
184 }
185
186 fn process_attention_based(
188 &mut self,
189 gradient_history: &Array2<T>,
190 parameter_history: &Array2<T>,
191 loss_history: &Array1<T>,
192 ) -> Result<Array2<T>> {
193 let sequence_length = gradient_history.shape()[0];
194
195 if sequence_length <= self.max_sequence_length {
196 return self.combine_sequences(gradient_history, parameter_history, loss_history);
197 }
198
199 let importance_scores = self.compute_importance_scores(gradient_history, loss_history)?;
201
202 let selected_indices =
204 self.select_top_k_indices(&importance_scores, self.max_sequence_length)?;
205
206 let mut selected_gradients =
208 Array2::zeros((self.max_sequence_length, gradient_history.shape()[1]));
209 let mut selected_parameters =
210 Array2::zeros((self.max_sequence_length, parameter_history.shape()[1]));
211 let mut selected_losses = Array1::zeros(self.max_sequence_length);
212
213 for (i, &idx) in selected_indices.iter().enumerate() {
214 selected_gradients
215 .row_mut(i)
216 .assign(&gradient_history.row(idx));
217 selected_parameters
218 .row_mut(i)
219 .assign(¶meter_history.row(idx));
220 selected_losses[i] = loss_history[idx];
221 }
222
223 self.combine_sequences(&selected_gradients, &selected_parameters, &selected_losses)
224 }
225
226 fn process_adaptive(
228 &mut self,
229 gradient_history: &Array2<T>,
230 parameter_history: &Array2<T>,
231 loss_history: &Array1<T>,
232 ) -> Result<Array2<T>> {
233 let change_points = self.detect_change_points(loss_history)?;
235
236 let segments = self.segment_by_change_points(
238 gradient_history,
239 parameter_history,
240 loss_history,
241 &change_points,
242 )?;
243
244 let mut processed_segments = Vec::new();
246 for segment in segments {
247 let processed =
248 self.combine_sequences(&segment.gradients, &segment.parameters, &segment.losses)?;
249 processed_segments.push(processed);
250 }
251
252 self.combine_chunks(&processed_segments)
253 }
254
255 fn process_truncated_bptt(
257 &mut self,
258 gradient_history: &Array2<T>,
259 parameter_history: &Array2<T>,
260 loss_history: &Array1<T>,
261 ) -> Result<Array2<T>> {
262 let sequence_length = gradient_history.shape()[0];
263
264 if sequence_length <= self.max_sequence_length {
265 return self.combine_sequences(gradient_history, parameter_history, loss_history);
266 }
267
268 let start_idx = sequence_length - self.max_sequence_length;
270 let grad_chunk = gradient_history.slice(s![start_idx.., ..]);
271 let param_chunk = parameter_history.slice(s![start_idx.., ..]);
272 let loss_chunk = loss_history.slice(s![start_idx..]);
273
274 self.combine_sequences(
275 &grad_chunk.to_owned(),
276 ¶m_chunk.to_owned(),
277 &loss_chunk.to_owned(),
278 )
279 }
280
281 pub fn trajectory_to_sequences(
283 &mut self,
284 trajectory: &super::OptimizationTrajectory<T>,
285 ) -> Result<Vec<super::TrainingSequence<T>>> {
286 let sequence_length = trajectory.gradient_sequence.shape()[0];
287 let mut sequences = Vec::new();
288
289 for start in (0..sequence_length).step_by(self.window_size / 2) {
291 let end = (start + self.window_size).min(sequence_length);
292
293 if end - start < self.window_size / 2 {
294 break; }
296
297 let input_end = end - 1;
298 let target_start = start + 1;
299
300 let input_gradients = trajectory.gradient_sequence.slice(s![start..input_end, ..]);
301 let input_parameters = trajectory
302 .parameter_sequence
303 .slice(s![start..input_end, ..]);
304 let input_losses = trajectory.loss_sequence.slice(s![start..input_end]);
305
306 let target_gradients = trajectory
307 .gradient_sequence
308 .slice(s![target_start..end, ..]);
309 let target_parameters = trajectory
310 .parameter_sequence
311 .slice(s![target_start..end, ..]);
312 let target_losses = trajectory.loss_sequence.slice(s![target_start..end]);
313
314 let input = self.combine_sequences(
315 &input_gradients.to_owned(),
316 &input_parameters.to_owned(),
317 &input_losses.to_owned(),
318 )?;
319
320 let target = self.combine_sequences(
321 &target_gradients.to_owned(),
322 &target_parameters.to_owned(),
323 &target_losses.to_owned(),
324 )?;
325
326 sequences.push(super::TrainingSequence {
327 input,
328 target,
329 sequence_length: input_end - start,
330 });
331 }
332
333 Ok(sequences)
334 }
335
336 fn combine_sequences(
338 &self,
339 gradients: &Array2<T>,
340 parameters: &Array2<T>,
341 losses: &Array1<T>,
342 ) -> Result<Array2<T>> {
343 self.preprocessor
344 .combine_sequences(gradients, parameters, losses)
345 }
346
347 fn combine_chunks(&self, chunks: &[Array2<T>]) -> Result<Array2<T>> {
349 if chunks.is_empty() {
350 return Err(crate::error::OptimError::Other(
351 "No chunks to combine".to_string(),
352 ));
353 }
354
355 if chunks.len() == 1 {
356 return Ok(chunks[0].clone());
357 }
358
359 let total_length: usize = chunks.iter().map(|chunk| chunk.shape()[0]).sum();
361 let feature_dim = chunks[0].shape()[1];
362
363 let mut combined = Array2::zeros((total_length.min(self.max_sequence_length), feature_dim));
364 let mut current_pos = 0;
365
366 for chunk in chunks {
367 let chunk_len = chunk.shape()[0];
368 let copy_len = (chunk_len).min(self.max_sequence_length - current_pos);
369
370 if copy_len == 0 {
371 break;
372 }
373
374 combined
375 .slice_mut(s![current_pos..current_pos + copy_len, ..])
376 .assign(&chunk.slice(s![..copy_len, ..]));
377
378 current_pos += copy_len;
379
380 if current_pos >= self.max_sequence_length {
381 break;
382 }
383 }
384
385 Ok(combined)
386 }
387
388 fn downsample_sequence(&self, sequence: &Array2<T>, target_length: usize) -> Result<Array2<T>> {
390 let current_length = sequence.shape()[0];
391 let feature_dim = sequence.shape()[1];
392
393 if current_length <= target_length {
394 return Ok(sequence.clone());
395 }
396
397 let mut downsampled = Array2::zeros((target_length, feature_dim));
398 let step = current_length as f64 / target_length as f64;
399
400 for i in 0..target_length {
401 let source_idx = (i as f64 * step) as usize;
402 downsampled
403 .row_mut(i)
404 .assign(&sequence.row(source_idx.min(current_length - 1)));
405 }
406
407 Ok(downsampled)
408 }
409
410 fn compute_importance_scores(
412 &self,
413 gradients: &Array2<T>,
414 losses: &Array1<T>,
415 ) -> Result<Array1<T>> {
416 let sequence_length = gradients.shape()[0];
417 let mut scores = Array1::zeros(sequence_length);
418
419 for i in 0..sequence_length {
420 let grad_norm = gradients
422 .row(i)
423 .iter()
424 .map(|&x| x * x)
425 .fold(T::zero(), |acc, x| acc + x)
426 .sqrt();
427
428 let loss_change = if i > 0 {
429 (losses[i] - losses[i - 1]).abs()
430 } else {
431 T::zero()
432 };
433
434 scores[i] = grad_norm + loss_change;
435 }
436
437 Ok(scores)
438 }
439
440 fn select_top_k_indices(&self, scores: &Array1<T>, k: usize) -> Result<Vec<usize>> {
442 let mut indexed_scores: Vec<(usize, T)> = scores
443 .iter()
444 .enumerate()
445 .map(|(i, &score)| (i, score))
446 .collect();
447
448 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
449
450 let mut selected_indices: Vec<usize> = indexed_scores
451 .into_iter()
452 .take(k)
453 .map(|(idx, _)| idx)
454 .collect();
455
456 selected_indices.sort();
457 Ok(selected_indices)
458 }
459
460 fn detect_change_points(&self, losses: &Array1<T>) -> Result<Vec<usize>> {
462 let mut change_points = vec![0]; let window_size = 5;
464 let threshold = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
465
466 for i in window_size..losses.len() - window_size {
467 let before_mean = losses.slice(s![i - window_size..i]).mean().unwrap();
468 let after_mean = losses.slice(s![i..i + window_size]).mean().unwrap();
469
470 if (before_mean - after_mean).abs() > threshold {
471 change_points.push(i);
472 }
473 }
474
475 change_points.push(losses.len() - 1); Ok(change_points)
477 }
478
479 fn segment_by_change_points(
481 &self,
482 gradients: &Array2<T>,
483 parameters: &Array2<T>,
484 losses: &Array1<T>,
485 change_points: &[usize],
486 ) -> Result<Vec<SequenceSegment<T>>> {
487 let mut segments = Vec::new();
488
489 for i in 0..change_points.len() - 1 {
490 let start = change_points[i];
491 let end = change_points[i + 1];
492
493 let segment = SequenceSegment {
494 gradients: gradients.slice(s![start..end, ..]).to_owned(),
495 parameters: parameters.slice(s![start..end, ..]).to_owned(),
496 losses: losses.slice(s![start..end]).to_owned(),
497 start_index: start,
498 end_index: end,
499 };
500
501 segments.push(segment);
502 }
503
504 Ok(segments)
505 }
506
507 pub fn set_strategy(&mut self, strategy: SequenceProcessingStrategy) {
509 self.strategy = strategy;
510 }
511
512 pub fn get_strategy(&self) -> SequenceProcessingStrategy {
514 self.strategy
515 }
516
517 pub fn get_statistics(&self) -> &SequenceStatistics<T> {
519 &self.statistics
520 }
521
522 pub fn reset(&mut self) -> Result<()> {
524 self.sequence_buffer.clear();
525 self.statistics.reset();
526 Ok(())
527 }
528}
529
/// Bounded FIFO history of observed optimization sequences.
///
/// The three queues always move in lockstep: entry `i` in each queue belongs
/// to the same observation.
pub struct SequenceBuffer<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient matrices, oldest first.
    gradient_buffer: VecDeque<Array2<T>>,

    /// Parameter matrices, oldest first.
    parameter_buffer: VecDeque<Array2<T>>,

    /// Loss vectors, oldest first.
    loss_buffer: VecDeque<Array1<T>>,

    /// Maximum number of retained observations.
    max_size: usize,

    /// Model feature width.
    /// NOTE(review): stored but never read in this file — confirm intended use.
    model_dimension: usize,
}
547
548impl<T: Float + Debug + Send + Sync + 'static> SequenceBuffer<T> {
549 pub fn new(max_size: usize, model_dimension: usize) -> Result<Self> {
550 Ok(Self {
551 gradient_buffer: VecDeque::new(),
552 parameter_buffer: VecDeque::new(),
553 loss_buffer: VecDeque::new(),
554 max_size,
555 model_dimension,
556 })
557 }
558
559 pub fn add_sequence(
560 &mut self,
561 gradients: &Array2<T>,
562 parameters: &Array2<T>,
563 losses: &Array1<T>,
564 ) -> Result<()> {
565 self.gradient_buffer.push_back(gradients.clone());
566 self.parameter_buffer.push_back(parameters.clone());
567 self.loss_buffer.push_back(losses.clone());
568
569 while self.gradient_buffer.len() > self.max_size {
570 self.gradient_buffer.pop_front();
571 self.parameter_buffer.pop_front();
572 self.loss_buffer.pop_front();
573 }
574
575 Ok(())
576 }
577
578 pub fn clear(&mut self) {
579 self.gradient_buffer.clear();
580 self.parameter_buffer.clear();
581 self.loss_buffer.clear();
582 }
583
584 pub fn get_recent_sequences(&self, count: usize) -> SequenceDataTuple<T> {
585 let actual_count = count.min(self.gradient_buffer.len());
586
587 let gradients = self
588 .gradient_buffer
589 .iter()
590 .rev()
591 .take(actual_count)
592 .cloned()
593 .collect();
594 let parameters = self
595 .parameter_buffer
596 .iter()
597 .rev()
598 .take(actual_count)
599 .cloned()
600 .collect();
601 let losses = self
602 .loss_buffer
603 .iter()
604 .rev()
605 .take(actual_count)
606 .cloned()
607 .collect();
608
609 (gradients, parameters, losses)
610 }
611}
612
/// Running statistics over every sequence seen by the processor.
pub struct SequenceStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Element-wise statistics over gradient values.
    gradient_stats: StatisticsAccumulator<T>,

    /// Element-wise statistics over parameter values.
    parameter_stats: StatisticsAccumulator<T>,

    /// Statistics over loss values.
    loss_stats: StatisticsAccumulator<T>,

    /// Statistics over observed sequence lengths.
    length_stats: StatisticsAccumulator<T>,
}
627
impl<T: Float + Debug + Send + Sync + 'static> Default for SequenceStatistics<T> {
    /// Equivalent to [`SequenceStatistics::new`].
    fn default() -> Self {
        Self::new()
    }
}
633
634impl<T: Float + Debug + Send + Sync + 'static> SequenceStatistics<T> {
635 pub fn new() -> Self {
636 Self {
637 gradient_stats: StatisticsAccumulator::new(),
638 parameter_stats: StatisticsAccumulator::new(),
639 loss_stats: StatisticsAccumulator::new(),
640 length_stats: StatisticsAccumulator::new(),
641 }
642 }
643
644 pub fn update(
645 &mut self,
646 gradients: &Array2<T>,
647 parameters: &Array2<T>,
648 losses: &Array1<T>,
649 ) -> Result<()> {
650 self.gradient_stats.update_from_array2(gradients);
651 self.parameter_stats.update_from_array2(parameters);
652 self.loss_stats.update_from_array1(losses);
653 self.length_stats
654 .update(T::from(gradients.shape()[0]).unwrap());
655
656 Ok(())
657 }
658
659 pub fn reset(&mut self) {
660 self.gradient_stats.reset();
661 self.parameter_stats.reset();
662 self.loss_stats.reset();
663 self.length_stats.reset();
664 }
665
666 pub fn get_gradient_stats(&self) -> &StatisticsAccumulator<T> {
667 &self.gradient_stats
668 }
669
670 pub fn get_parameter_stats(&self) -> &StatisticsAccumulator<T> {
671 &self.parameter_stats
672 }
673
674 pub fn get_loss_stats(&self) -> &StatisticsAccumulator<T> {
675 &self.loss_stats
676 }
677}
678
/// Flattens per-step gradients, parameters and losses into a single feature
/// matrix with at most `model_dimension` columns.
pub struct SequencePreprocessor<T: Float + Debug + Send + Sync + 'static> {
    /// Feature-width budget for combined rows.
    model_dimension: usize,

    /// Per-key normalization statistics (pairs of values, e.g. mean/std).
    /// NOTE(review): never read or written in this file — confirm intended use.
    normalization_stats: HashMap<String, (T, T)>,
}
687
688impl<T: Float + Debug + Send + Sync + 'static> SequencePreprocessor<T> {
689 pub fn new(model_dimension: usize) -> Result<Self> {
690 Ok(Self {
691 model_dimension,
692 normalization_stats: HashMap::new(),
693 })
694 }
695
696 pub fn combine_sequences(
697 &self,
698 gradients: &Array2<T>,
699 parameters: &Array2<T>,
700 losses: &Array1<T>,
701 ) -> Result<Array2<T>> {
702 let sequence_length = gradients.shape()[0];
703 let grad_dim = gradients.shape()[1];
704 let param_dim = parameters.shape()[1];
705
706 let feature_dim = grad_dim + param_dim + 2;
708 let mut combined = Array2::zeros((sequence_length, feature_dim.min(self.model_dimension)));
709
710 for i in 0..sequence_length {
711 let mut feature_idx = 0;
712
713 for j in 0..grad_dim.min(self.model_dimension / 3) {
715 if feature_idx < combined.shape()[1] {
716 combined[[i, feature_idx]] = gradients[[i, j]];
717 feature_idx += 1;
718 }
719 }
720
721 for j in 0..param_dim.min(self.model_dimension / 3) {
723 if feature_idx < combined.shape()[1] {
724 combined[[i, feature_idx]] = parameters[[i, j]];
725 feature_idx += 1;
726 }
727 }
728
729 if feature_idx < combined.shape()[1] {
731 combined[[i, feature_idx]] = losses[i];
732 feature_idx += 1;
733 }
734
735 if feature_idx < combined.shape()[1] && i > 0 {
737 let normalized_loss = if losses[0] != T::zero() {
738 losses[i] / losses[0]
739 } else {
740 T::one()
741 };
742 combined[[i, feature_idx]] = normalized_loss;
743 }
744 }
745
746 Ok(combined)
747 }
748}
749
/// Splits long sequences into overlapping fixed-size chunks.
pub struct ChunkingStrategy<T: Float + Debug + Send + Sync + 'static> {
    /// Maximum rows per chunk.
    max_chunk_size: usize,

    /// Rows shared between consecutive chunks.
    overlap_size: usize,

    /// Statistics accumulator for chunks.
    /// NOTE(review): never updated in this file — confirm intended use.
    chunk_stats: StatisticsAccumulator<T>,
}
761
762impl<T: Float + Debug + Send + Sync + 'static> ChunkingStrategy<T> {
763 pub fn new(max_chunk_size: usize, overlap_size: usize) -> Result<Self> {
764 Ok(Self {
765 max_chunk_size,
766 overlap_size,
767 chunk_stats: StatisticsAccumulator::new(),
768 })
769 }
770
771 pub fn create_chunks(&mut self, sequence: &Array2<T>) -> Result<Vec<Array2<T>>> {
772 let sequence_length = sequence.shape()[0];
773 let mut chunks = Vec::new();
774
775 if sequence_length <= self.max_chunk_size {
776 chunks.push(sequence.clone());
777 return Ok(chunks);
778 }
779
780 let step_size = self.max_chunk_size - self.overlap_size;
781
782 for start in (0..sequence_length).step_by(step_size) {
783 let end = (start + self.max_chunk_size).min(sequence_length);
784 let chunk = sequence.slice(s![start..end, ..]).to_owned();
785 chunks.push(chunk);
786
787 if end >= sequence_length {
788 break;
789 }
790 }
791
792 Ok(chunks)
793 }
794}
795
/// A contiguous slice of an optimization trajectory, produced by
/// change-point segmentation; covers the half-open range
/// `[start_index, end_index)`.
#[derive(Debug, Clone)]
pub struct SequenceSegment<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient rows for this segment.
    pub gradients: Array2<T>,
    /// Parameter rows for this segment.
    pub parameters: Array2<T>,
    /// Loss values for this segment.
    pub losses: Array1<T>,
    /// Inclusive start index into the source trajectory.
    pub start_index: usize,
    /// Exclusive end index into the source trajectory.
    pub end_index: usize,
}
805
/// Streaming accumulator for count/sum/sum-of-squares plus min and max,
/// from which mean, variance and standard deviation are derived.
pub struct StatisticsAccumulator<T: Float + Debug + Send + Sync + 'static> {
    // Number of values folded in so far.
    count: usize,
    // Running sum of values.
    sum: T,
    // Running sum of squared values.
    sum_sq: T,
    // Smallest value seen (+inf when empty).
    min: T,
    // Largest value seen (-inf when empty).
    max: T,
}
813
impl<T: Float + Debug + Send + Sync + 'static> Default for StatisticsAccumulator<T> {
    /// Equivalent to [`StatisticsAccumulator::new`].
    fn default() -> Self {
        Self::new()
    }
}
819
820impl<T: Float + Debug + Send + Sync + 'static> StatisticsAccumulator<T> {
821 pub fn new() -> Self {
822 Self {
823 count: 0,
824 sum: T::zero(),
825 sum_sq: T::zero(),
826 min: T::infinity(),
827 max: T::neg_infinity(),
828 }
829 }
830
831 pub fn update(&mut self, value: T) {
832 self.count += 1;
833 self.sum = self.sum + value;
834 self.sum_sq = self.sum_sq + value * value;
835 self.min = self.min.min(value);
836 self.max = self.max.max(value);
837 }
838
839 pub fn update_from_array1(&mut self, array: &Array1<T>) {
840 for &value in array.iter() {
841 self.update(value);
842 }
843 }
844
845 pub fn update_from_array2(&mut self, array: &Array2<T>) {
846 for &value in array.iter() {
847 self.update(value);
848 }
849 }
850
851 pub fn mean(&self) -> T {
852 if self.count > 0 {
853 self.sum / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero())
854 } else {
855 T::zero()
856 }
857 }
858
859 pub fn variance(&self) -> T {
860 if self.count > 1 {
861 let mean = self.mean();
862 (self.sum_sq
863 / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero()))
864 - (mean * mean)
865 } else {
866 T::zero()
867 }
868 }
869
870 pub fn std_dev(&self) -> T {
871 self.variance().sqrt()
872 }
873
874 pub fn min(&self) -> T {
875 self.min
876 }
877
878 pub fn max(&self) -> T {
879 self.max
880 }
881
882 pub fn reset(&mut self) {
883 self.count = 0;
884 self.sum = T::zero();
885 self.sum_sq = T::zero();
886 self.min = T::infinity();
887 self.max = T::neg_infinity();
888 }
889}
890
891use scirs2_core::ndarray::s;
893
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sequence_processor_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let processor = OptimizationSequenceProcessor::new(&config);
        assert!(processor.is_ok());
    }

    #[test]
    fn test_sequence_buffer() {
        let buffer = SequenceBuffer::<f32>::new(10, 64);
        assert!(buffer.is_ok());

        let mut buf = buffer.unwrap();
        let gradients = Array2::<f32>::ones((5, 64));
        let parameters = Array2::<f32>::ones((5, 64));
        let losses = Array1::<f32>::ones(5);

        // BUGFIX: `&parameters` had been mojibake-corrupted to `¶meters`.
        assert!(buf.add_sequence(&gradients, &parameters, &losses).is_ok());
    }

    #[test]
    fn test_sequence_statistics() {
        let mut stats = SequenceStatistics::<f32>::new();

        let gradients = Array2::<f32>::ones((10, 5));
        let parameters = Array2::<f32>::ones((10, 5));
        let losses = Array1::<f32>::ones(10);

        // BUGFIX: `&parameters` had been mojibake-corrupted to `¶meters`.
        assert!(stats.update(&gradients, &parameters, &losses).is_ok());
        assert!(stats.get_gradient_stats().mean() > 0.0);
    }

    #[test]
    fn test_statistics_accumulator() {
        let mut acc = StatisticsAccumulator::<f32>::new();

        acc.update(1.0);
        acc.update(2.0);
        acc.update(3.0);

        assert_eq!(acc.mean(), 2.0);
        assert!(acc.std_dev() > 0.0);
        assert_eq!(acc.min(), 1.0);
        assert_eq!(acc.max(), 3.0);
    }

    #[test]
    fn test_chunking_strategy() {
        let chunking = ChunkingStrategy::<f32>::new(10, 2);
        assert!(chunking.is_ok());

        let mut strategy = chunking.unwrap();
        let sequence = Array2::<f32>::ones((25, 5));

        let chunks = strategy.create_chunks(&sequence);
        assert!(chunks.is_ok());

        let chunk_vec = chunks.unwrap();
        assert!(chunk_vec.len() > 1);
    }
}