// optirs_learned/transformer_based_optimizer/sequence_processor.rs

1// Sequence processing for optimization trajectories
2
3use super::config::TransformerBasedOptimizerConfig;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
6use scirs2_core::numeric::Float;
7use std::collections::{HashMap, VecDeque};
8use std::fmt::Debug;
9
/// Type alias for sequence data tuple (gradients, parameters, losses)
///
/// Returned by `SequenceBuffer::get_recent_sequences`; the three `Vec`s are
/// index-aligned and ordered newest-first.
type SequenceDataTuple<T> = (Vec<Array2<T>>, Vec<Array2<T>>, Vec<Array1<T>>);
12
/// Sequence processing strategy types
///
/// Selects how `OptimizationSequenceProcessor` reduces an optimization
/// trajectory to at most `max_sequence_length` steps.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SequenceProcessingStrategy {
    /// Sliding window approach (overlapping fixed-size windows)
    SlidingWindow,
    /// Hierarchical chunking (repeated 2x downsampling until it fits)
    Hierarchical,
    /// Attention-based selection (keep the highest-importance steps)
    AttentionBased,
    /// Adaptive segmentation (split at detected loss change points)
    Adaptive,
    /// Truncated backpropagation (keep only the most recent steps)
    TruncatedBPTT,
}
27
/// Optimization sequence processor
///
/// Reduces arbitrarily long optimization trajectories (gradient /
/// parameter / loss histories) to model-sized feature sequences using a
/// configurable reduction strategy.
pub struct OptimizationSequenceProcessor<T: Float + Debug + Send + Sync + 'static> {
    /// Processing strategy used by `process_optimization_sequence`
    strategy: SequenceProcessingStrategy,

    /// Maximum sequence length accepted by the downstream model
    max_sequence_length: usize,

    /// Window size for sliding window (half of `max_sequence_length`)
    window_size: usize,

    /// Overlap between consecutive windows (quarter of `window_size`)
    window_overlap: usize,

    /// Model dimension (feature width of produced sequences)
    model_dimension: usize,

    /// Sequence history buffer (bounded FIFO of past trajectories)
    sequence_buffer: SequenceBuffer<T>,

    /// Running statistics over all processed sequences
    statistics: SequenceStatistics<T>,

    /// Preprocessing pipeline that fuses histories into feature matrices
    preprocessor: SequencePreprocessor<T>,

    /// Chunking strategy
    // NOTE(review): stored but not referenced by any method in this file —
    // confirm it is used elsewhere or consider removing.
    chunking: ChunkingStrategy<T>,
}
57
58impl<T: Float + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
59    OptimizationSequenceProcessor<T>
60{
61    /// Create new sequence processor
62    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
63        let strategy = SequenceProcessingStrategy::SlidingWindow;
64        let max_sequence_length = config.sequence_length;
65        let window_size = max_sequence_length / 2;
66        let window_overlap = window_size / 4;
67        let model_dimension = config.model_dimension;
68
69        let sequence_buffer = SequenceBuffer::new(1000, model_dimension)?;
70        let statistics = SequenceStatistics::new();
71        let preprocessor = SequencePreprocessor::new(model_dimension)?;
72        let chunking = ChunkingStrategy::new(max_sequence_length, window_size)?;
73
74        Ok(Self {
75            strategy,
76            max_sequence_length,
77            window_size,
78            window_overlap,
79            model_dimension,
80            sequence_buffer,
81            statistics,
82            preprocessor,
83            chunking,
84        })
85    }
86
87    /// Process optimization sequence
88    pub fn process_optimization_sequence(
89        &mut self,
90        gradient_history: &Array2<T>,
91        parameter_history: &Array2<T>,
92        loss_history: &Array1<T>,
93    ) -> Result<Array2<T>> {
94        // Update statistics
95        self.statistics
96            .update(gradient_history, parameter_history, loss_history)?;
97
98        // Store in buffer
99        self.sequence_buffer
100            .add_sequence(gradient_history, parameter_history, loss_history)?;
101
102        match self.strategy {
103            SequenceProcessingStrategy::SlidingWindow => {
104                self.process_sliding_window(gradient_history, parameter_history, loss_history)
105            }
106            SequenceProcessingStrategy::Hierarchical => {
107                self.process_hierarchical(gradient_history, parameter_history, loss_history)
108            }
109            SequenceProcessingStrategy::AttentionBased => {
110                self.process_attention_based(gradient_history, parameter_history, loss_history)
111            }
112            SequenceProcessingStrategy::Adaptive => {
113                self.process_adaptive(gradient_history, parameter_history, loss_history)
114            }
115            SequenceProcessingStrategy::TruncatedBPTT => {
116                self.process_truncated_bptt(gradient_history, parameter_history, loss_history)
117            }
118        }
119    }
120
121    /// Process using sliding window strategy
122    fn process_sliding_window(
123        &mut self,
124        gradient_history: &Array2<T>,
125        parameter_history: &Array2<T>,
126        loss_history: &Array1<T>,
127    ) -> Result<Array2<T>> {
128        let sequence_length = gradient_history.shape()[0];
129
130        if sequence_length <= self.max_sequence_length {
131            // Sequence fits in one window
132            return self.combine_sequences(gradient_history, parameter_history, loss_history);
133        }
134
135        // Process in overlapping windows
136        let mut processed_chunks = Vec::new();
137        let step_size = self.window_size - self.window_overlap;
138
139        for start in (0..sequence_length).step_by(step_size) {
140            let end = (start + self.window_size).min(sequence_length);
141
142            let grad_chunk = gradient_history.slice(s![start..end, ..]);
143            let param_chunk = parameter_history.slice(s![start..end, ..]);
144            let loss_chunk = loss_history.slice(s![start..end]);
145
146            let processed_chunk = self.combine_sequences(
147                &grad_chunk.to_owned(),
148                &param_chunk.to_owned(),
149                &loss_chunk.to_owned(),
150            )?;
151
152            processed_chunks.push(processed_chunk);
153        }
154
155        // Combine processed chunks
156        self.combine_chunks(&processed_chunks)
157    }
158
159    /// Process using hierarchical strategy
160    fn process_hierarchical(
161        &mut self,
162        gradient_history: &Array2<T>,
163        parameter_history: &Array2<T>,
164        loss_history: &Array1<T>,
165    ) -> Result<Array2<T>> {
166        let sequence_length = gradient_history.shape()[0];
167
168        // Create hierarchical representation
169        let mut levels = Vec::new();
170        let mut current_level =
171            self.combine_sequences(gradient_history, parameter_history, loss_history)?;
172        levels.push(current_level.clone());
173
174        // Build hierarchy by downsampling
175        let mut current_length = sequence_length;
176        while current_length > self.max_sequence_length {
177            current_length /= 2;
178            current_level = self.downsample_sequence(&current_level, current_length)?;
179            levels.push(current_level.clone());
180        }
181
182        // Return the highest level that fits
183        Ok(levels.last().unwrap().clone())
184    }
185
186    /// Process using attention-based selection
187    fn process_attention_based(
188        &mut self,
189        gradient_history: &Array2<T>,
190        parameter_history: &Array2<T>,
191        loss_history: &Array1<T>,
192    ) -> Result<Array2<T>> {
193        let sequence_length = gradient_history.shape()[0];
194
195        if sequence_length <= self.max_sequence_length {
196            return self.combine_sequences(gradient_history, parameter_history, loss_history);
197        }
198
199        // Compute importance scores for each time step
200        let importance_scores = self.compute_importance_scores(gradient_history, loss_history)?;
201
202        // Select most important steps
203        let selected_indices =
204            self.select_top_k_indices(&importance_scores, self.max_sequence_length)?;
205
206        // Extract selected sequences
207        let mut selected_gradients =
208            Array2::zeros((self.max_sequence_length, gradient_history.shape()[1]));
209        let mut selected_parameters =
210            Array2::zeros((self.max_sequence_length, parameter_history.shape()[1]));
211        let mut selected_losses = Array1::zeros(self.max_sequence_length);
212
213        for (i, &idx) in selected_indices.iter().enumerate() {
214            selected_gradients
215                .row_mut(i)
216                .assign(&gradient_history.row(idx));
217            selected_parameters
218                .row_mut(i)
219                .assign(&parameter_history.row(idx));
220            selected_losses[i] = loss_history[idx];
221        }
222
223        self.combine_sequences(&selected_gradients, &selected_parameters, &selected_losses)
224    }
225
226    /// Process using adaptive segmentation
227    fn process_adaptive(
228        &mut self,
229        gradient_history: &Array2<T>,
230        parameter_history: &Array2<T>,
231        loss_history: &Array1<T>,
232    ) -> Result<Array2<T>> {
233        // Detect change points in the optimization trajectory
234        let change_points = self.detect_change_points(loss_history)?;
235
236        // Segment sequence based on change points
237        let segments = self.segment_by_change_points(
238            gradient_history,
239            parameter_history,
240            loss_history,
241            &change_points,
242        )?;
243
244        // Process each segment and combine
245        let mut processed_segments = Vec::new();
246        for segment in segments {
247            let processed =
248                self.combine_sequences(&segment.gradients, &segment.parameters, &segment.losses)?;
249            processed_segments.push(processed);
250        }
251
252        self.combine_chunks(&processed_segments)
253    }
254
255    /// Process using truncated BPTT
256    fn process_truncated_bptt(
257        &mut self,
258        gradient_history: &Array2<T>,
259        parameter_history: &Array2<T>,
260        loss_history: &Array1<T>,
261    ) -> Result<Array2<T>> {
262        let sequence_length = gradient_history.shape()[0];
263
264        if sequence_length <= self.max_sequence_length {
265            return self.combine_sequences(gradient_history, parameter_history, loss_history);
266        }
267
268        // Take the most recent subsequence
269        let start_idx = sequence_length - self.max_sequence_length;
270        let grad_chunk = gradient_history.slice(s![start_idx.., ..]);
271        let param_chunk = parameter_history.slice(s![start_idx.., ..]);
272        let loss_chunk = loss_history.slice(s![start_idx..]);
273
274        self.combine_sequences(
275            &grad_chunk.to_owned(),
276            &param_chunk.to_owned(),
277            &loss_chunk.to_owned(),
278        )
279    }
280
281    /// Convert optimization trajectory to training sequences
282    pub fn trajectory_to_sequences(
283        &mut self,
284        trajectory: &super::OptimizationTrajectory<T>,
285    ) -> Result<Vec<super::TrainingSequence<T>>> {
286        let sequence_length = trajectory.gradient_sequence.shape()[0];
287        let mut sequences = Vec::new();
288
289        // Create overlapping sequences for training
290        for start in (0..sequence_length).step_by(self.window_size / 2) {
291            let end = (start + self.window_size).min(sequence_length);
292
293            if end - start < self.window_size / 2 {
294                break; // Skip sequences that are too short
295            }
296
297            let input_end = end - 1;
298            let target_start = start + 1;
299
300            let input_gradients = trajectory.gradient_sequence.slice(s![start..input_end, ..]);
301            let input_parameters = trajectory
302                .parameter_sequence
303                .slice(s![start..input_end, ..]);
304            let input_losses = trajectory.loss_sequence.slice(s![start..input_end]);
305
306            let target_gradients = trajectory
307                .gradient_sequence
308                .slice(s![target_start..end, ..]);
309            let target_parameters = trajectory
310                .parameter_sequence
311                .slice(s![target_start..end, ..]);
312            let target_losses = trajectory.loss_sequence.slice(s![target_start..end]);
313
314            let input = self.combine_sequences(
315                &input_gradients.to_owned(),
316                &input_parameters.to_owned(),
317                &input_losses.to_owned(),
318            )?;
319
320            let target = self.combine_sequences(
321                &target_gradients.to_owned(),
322                &target_parameters.to_owned(),
323                &target_losses.to_owned(),
324            )?;
325
326            sequences.push(super::TrainingSequence {
327                input,
328                target,
329                sequence_length: input_end - start,
330            });
331        }
332
333        Ok(sequences)
334    }
335
    /// Combine gradients, parameters, and losses into unified sequence
    ///
    /// Thin delegation to the preprocessor; see
    /// `SequencePreprocessor::combine_sequences` for the feature layout.
    fn combine_sequences(
        &self,
        gradients: &Array2<T>,
        parameters: &Array2<T>,
        losses: &Array1<T>,
    ) -> Result<Array2<T>> {
        self.preprocessor
            .combine_sequences(gradients, parameters, losses)
    }
346
347    /// Combine multiple processed chunks
348    fn combine_chunks(&self, chunks: &[Array2<T>]) -> Result<Array2<T>> {
349        if chunks.is_empty() {
350            return Err(crate::error::OptimError::Other(
351                "No chunks to combine".to_string(),
352            ));
353        }
354
355        if chunks.len() == 1 {
356            return Ok(chunks[0].clone());
357        }
358
359        // Simple concatenation strategy
360        let total_length: usize = chunks.iter().map(|chunk| chunk.shape()[0]).sum();
361        let feature_dim = chunks[0].shape()[1];
362
363        let mut combined = Array2::zeros((total_length.min(self.max_sequence_length), feature_dim));
364        let mut current_pos = 0;
365
366        for chunk in chunks {
367            let chunk_len = chunk.shape()[0];
368            let copy_len = (chunk_len).min(self.max_sequence_length - current_pos);
369
370            if copy_len == 0 {
371                break;
372            }
373
374            combined
375                .slice_mut(s![current_pos..current_pos + copy_len, ..])
376                .assign(&chunk.slice(s![..copy_len, ..]));
377
378            current_pos += copy_len;
379
380            if current_pos >= self.max_sequence_length {
381                break;
382            }
383        }
384
385        Ok(combined)
386    }
387
388    /// Downsample sequence to target length
389    fn downsample_sequence(&self, sequence: &Array2<T>, target_length: usize) -> Result<Array2<T>> {
390        let current_length = sequence.shape()[0];
391        let feature_dim = sequence.shape()[1];
392
393        if current_length <= target_length {
394            return Ok(sequence.clone());
395        }
396
397        let mut downsampled = Array2::zeros((target_length, feature_dim));
398        let step = current_length as f64 / target_length as f64;
399
400        for i in 0..target_length {
401            let source_idx = (i as f64 * step) as usize;
402            downsampled
403                .row_mut(i)
404                .assign(&sequence.row(source_idx.min(current_length - 1)));
405        }
406
407        Ok(downsampled)
408    }
409
410    /// Compute importance scores for sequence steps
411    fn compute_importance_scores(
412        &self,
413        gradients: &Array2<T>,
414        losses: &Array1<T>,
415    ) -> Result<Array1<T>> {
416        let sequence_length = gradients.shape()[0];
417        let mut scores = Array1::zeros(sequence_length);
418
419        for i in 0..sequence_length {
420            // Combine gradient magnitude and loss change
421            let grad_norm = gradients
422                .row(i)
423                .iter()
424                .map(|&x| x * x)
425                .fold(T::zero(), |acc, x| acc + x)
426                .sqrt();
427
428            let loss_change = if i > 0 {
429                (losses[i] - losses[i - 1]).abs()
430            } else {
431                T::zero()
432            };
433
434            scores[i] = grad_norm + loss_change;
435        }
436
437        Ok(scores)
438    }
439
440    /// Select top-k indices based on importance scores
441    fn select_top_k_indices(&self, scores: &Array1<T>, k: usize) -> Result<Vec<usize>> {
442        let mut indexed_scores: Vec<(usize, T)> = scores
443            .iter()
444            .enumerate()
445            .map(|(i, &score)| (i, score))
446            .collect();
447
448        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
449
450        let mut selected_indices: Vec<usize> = indexed_scores
451            .into_iter()
452            .take(k)
453            .map(|(idx, _)| idx)
454            .collect();
455
456        selected_indices.sort();
457        Ok(selected_indices)
458    }
459
460    /// Detect change points in loss trajectory
461    fn detect_change_points(&self, losses: &Array1<T>) -> Result<Vec<usize>> {
462        let mut change_points = vec![0]; // Always include start
463        let window_size = 5;
464        let threshold = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
465
466        for i in window_size..losses.len() - window_size {
467            let before_mean = losses.slice(s![i - window_size..i]).mean().unwrap();
468            let after_mean = losses.slice(s![i..i + window_size]).mean().unwrap();
469
470            if (before_mean - after_mean).abs() > threshold {
471                change_points.push(i);
472            }
473        }
474
475        change_points.push(losses.len() - 1); // Always include end
476        Ok(change_points)
477    }
478
479    /// Segment sequences by change points
480    fn segment_by_change_points(
481        &self,
482        gradients: &Array2<T>,
483        parameters: &Array2<T>,
484        losses: &Array1<T>,
485        change_points: &[usize],
486    ) -> Result<Vec<SequenceSegment<T>>> {
487        let mut segments = Vec::new();
488
489        for i in 0..change_points.len() - 1 {
490            let start = change_points[i];
491            let end = change_points[i + 1];
492
493            let segment = SequenceSegment {
494                gradients: gradients.slice(s![start..end, ..]).to_owned(),
495                parameters: parameters.slice(s![start..end, ..]).to_owned(),
496                losses: losses.slice(s![start..end]).to_owned(),
497                start_index: start,
498                end_index: end,
499            };
500
501            segments.push(segment);
502        }
503
504        Ok(segments)
505    }
506
    /// Set processing strategy
    ///
    /// Takes effect on the next call to `process_optimization_sequence`.
    pub fn set_strategy(&mut self, strategy: SequenceProcessingStrategy) {
        self.strategy = strategy;
    }
511
    /// Get current strategy
    ///
    /// Returns the strategy that will be used for the next processing call.
    pub fn get_strategy(&self) -> SequenceProcessingStrategy {
        self.strategy
    }
516
    /// Get sequence statistics
    ///
    /// Statistics accumulate over every call to
    /// `process_optimization_sequence` until `reset` is invoked.
    pub fn get_statistics(&self) -> &SequenceStatistics<T> {
        &self.statistics
    }
521
    /// Reset processor state
    ///
    /// Clears the sequence buffer and all accumulated statistics; the
    /// strategy and window configuration are left unchanged.
    pub fn reset(&mut self) -> Result<()> {
        self.sequence_buffer.clear();
        self.statistics.reset();
        Ok(())
    }
528}
529
/// Sequence buffer for storing optimization history
///
/// A bounded FIFO of (gradients, parameters, losses) triples; the three
/// deques are kept index-aligned by `add_sequence`.
pub struct SequenceBuffer<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient history (one Array2 per stored trajectory)
    gradient_buffer: VecDeque<Array2<T>>,

    /// Parameter history (aligned with `gradient_buffer`)
    parameter_buffer: VecDeque<Array2<T>>,

    /// Loss history (aligned with `gradient_buffer`)
    loss_buffer: VecDeque<Array1<T>>,

    /// Maximum buffer size (oldest entries are evicted beyond this)
    max_size: usize,

    /// Model dimension
    // NOTE(review): stored but not read by any method in this file.
    model_dimension: usize,
}
547
548impl<T: Float + Debug + Send + Sync + 'static> SequenceBuffer<T> {
549    pub fn new(max_size: usize, model_dimension: usize) -> Result<Self> {
550        Ok(Self {
551            gradient_buffer: VecDeque::new(),
552            parameter_buffer: VecDeque::new(),
553            loss_buffer: VecDeque::new(),
554            max_size,
555            model_dimension,
556        })
557    }
558
559    pub fn add_sequence(
560        &mut self,
561        gradients: &Array2<T>,
562        parameters: &Array2<T>,
563        losses: &Array1<T>,
564    ) -> Result<()> {
565        self.gradient_buffer.push_back(gradients.clone());
566        self.parameter_buffer.push_back(parameters.clone());
567        self.loss_buffer.push_back(losses.clone());
568
569        while self.gradient_buffer.len() > self.max_size {
570            self.gradient_buffer.pop_front();
571            self.parameter_buffer.pop_front();
572            self.loss_buffer.pop_front();
573        }
574
575        Ok(())
576    }
577
578    pub fn clear(&mut self) {
579        self.gradient_buffer.clear();
580        self.parameter_buffer.clear();
581        self.loss_buffer.clear();
582    }
583
584    pub fn get_recent_sequences(&self, count: usize) -> SequenceDataTuple<T> {
585        let actual_count = count.min(self.gradient_buffer.len());
586
587        let gradients = self
588            .gradient_buffer
589            .iter()
590            .rev()
591            .take(actual_count)
592            .cloned()
593            .collect();
594        let parameters = self
595            .parameter_buffer
596            .iter()
597            .rev()
598            .take(actual_count)
599            .cloned()
600            .collect();
601        let losses = self
602            .loss_buffer
603            .iter()
604            .rev()
605            .take(actual_count)
606            .cloned()
607            .collect();
608
609        (gradients, parameters, losses)
610    }
611}
612
/// Sequence statistics tracker
///
/// Four independent accumulators fed by `update`; values persist across
/// calls until `reset`.
pub struct SequenceStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Statistics over all observed gradient entries
    gradient_stats: StatisticsAccumulator<T>,

    /// Statistics over all observed parameter entries
    parameter_stats: StatisticsAccumulator<T>,

    /// Statistics over all observed loss values
    loss_stats: StatisticsAccumulator<T>,

    /// Statistics over the lengths of processed sequences
    length_stats: StatisticsAccumulator<T>,
}
627
impl<T: Float + Debug + Send + Sync + 'static> Default for SequenceStatistics<T> {
    /// Equivalent to [`SequenceStatistics::new`].
    fn default() -> Self {
        Self::new()
    }
}
633
impl<T: Float + Debug + Send + Sync + 'static> SequenceStatistics<T> {
    /// Create a tracker with all accumulators empty.
    pub fn new() -> Self {
        Self {
            gradient_stats: StatisticsAccumulator::new(),
            parameter_stats: StatisticsAccumulator::new(),
            loss_stats: StatisticsAccumulator::new(),
            length_stats: StatisticsAccumulator::new(),
        }
    }

    /// Fold one trajectory's gradients, parameters, losses, and sequence
    /// length into the running accumulators.
    pub fn update(
        &mut self,
        gradients: &Array2<T>,
        parameters: &Array2<T>,
        losses: &Array1<T>,
    ) -> Result<()> {
        self.gradient_stats.update_from_array2(gradients);
        self.parameter_stats.update_from_array2(parameters);
        self.loss_stats.update_from_array1(losses);
        // Record the sequence length. NOTE(review): `T::from(usize)` is
        // assumed to succeed for a float `T` — the unwrap would panic for an
        // exotic `T` that cannot represent the length.
        self.length_stats
            .update(T::from(gradients.shape()[0]).unwrap());

        Ok(())
    }

    /// Clear all four accumulators.
    pub fn reset(&mut self) {
        self.gradient_stats.reset();
        self.parameter_stats.reset();
        self.loss_stats.reset();
        self.length_stats.reset();
    }

    /// Accumulated statistics over observed gradient values.
    pub fn get_gradient_stats(&self) -> &StatisticsAccumulator<T> {
        &self.gradient_stats
    }

    /// Accumulated statistics over observed parameter values.
    pub fn get_parameter_stats(&self) -> &StatisticsAccumulator<T> {
        &self.parameter_stats
    }

    /// Accumulated statistics over observed loss values.
    pub fn get_loss_stats(&self) -> &StatisticsAccumulator<T> {
        &self.loss_stats
    }
}
678
/// Sequence preprocessor
///
/// Fuses raw (gradients, parameters, losses) histories into fixed-width
/// feature matrices suitable for the transformer.
pub struct SequencePreprocessor<T: Float + Debug + Send + Sync + 'static> {
    /// Model dimension (upper bound on the produced feature width)
    model_dimension: usize,

    /// Normalization statistics
    // NOTE(review): never written in this file — presumably reserved for
    // future per-feature normalization support.
    normalization_stats: HashMap<String, (T, T)>, // (mean, std)
}
687
impl<T: Float + Debug + Send + Sync + 'static> SequencePreprocessor<T> {
    /// Create a preprocessor that emits feature vectors of at most
    /// `model_dimension` columns.
    pub fn new(model_dimension: usize) -> Result<Self> {
        Ok(Self {
            model_dimension,
            normalization_stats: HashMap::new(),
        })
    }

    /// Fuse per-step gradients, parameters, and losses into one feature
    /// matrix of shape (sequence_length, feature_dim).
    ///
    /// Per-row layout: up to `model_dimension / 3` gradient components, up
    /// to `model_dimension / 3` parameter components, the raw loss, and the
    /// loss normalized by the first loss of the sequence. The overall width
    /// is capped at `model_dimension`; features that do not fit are
    /// silently dropped.
    pub fn combine_sequences(
        &self,
        gradients: &Array2<T>,
        parameters: &Array2<T>,
        losses: &Array1<T>,
    ) -> Result<Array2<T>> {
        let sequence_length = gradients.shape()[0];
        let grad_dim = gradients.shape()[1];
        let param_dim = parameters.shape()[1];

        // Combined features: [gradients, parameters, loss, normalized_loss]
        let feature_dim = grad_dim + param_dim + 2;
        let mut combined = Array2::zeros((sequence_length, feature_dim.min(self.model_dimension)));

        for i in 0..sequence_length {
            let mut feature_idx = 0;

            // Add gradient features (truncated to a third of the model width)
            for j in 0..grad_dim.min(self.model_dimension / 3) {
                if feature_idx < combined.shape()[1] {
                    combined[[i, feature_idx]] = gradients[[i, j]];
                    feature_idx += 1;
                }
            }

            // Add parameter features (same per-section cap as gradients)
            for j in 0..param_dim.min(self.model_dimension / 3) {
                if feature_idx < combined.shape()[1] {
                    combined[[i, feature_idx]] = parameters[[i, j]];
                    feature_idx += 1;
                }
            }

            // Add the raw loss value
            if feature_idx < combined.shape()[1] {
                combined[[i, feature_idx]] = losses[i];
                feature_idx += 1;
            }

            // Add normalized loss (loss relative to first loss).
            // NOTE(review): the `i > 0` guard leaves this feature at zero
            // for the first row rather than 1.0 — confirm this is intended.
            if feature_idx < combined.shape()[1] && i > 0 {
                let normalized_loss = if losses[0] != T::zero() {
                    losses[i] / losses[0]
                } else {
                    T::one()
                };
                combined[[i, feature_idx]] = normalized_loss;
            }
        }

        Ok(combined)
    }
}
749
/// Chunking strategy
///
/// Splits long feature sequences into overlapping fixed-size chunks.
pub struct ChunkingStrategy<T: Float + Debug + Send + Sync + 'static> {
    /// Maximum chunk size (rows per chunk)
    max_chunk_size: usize,

    /// Overlap between consecutive chunks (rows shared)
    overlap_size: usize,

    /// Chunk statistics
    // NOTE(review): not updated by any method in this file.
    chunk_stats: StatisticsAccumulator<T>,
}
761
762impl<T: Float + Debug + Send + Sync + 'static> ChunkingStrategy<T> {
763    pub fn new(max_chunk_size: usize, overlap_size: usize) -> Result<Self> {
764        Ok(Self {
765            max_chunk_size,
766            overlap_size,
767            chunk_stats: StatisticsAccumulator::new(),
768        })
769    }
770
771    pub fn create_chunks(&mut self, sequence: &Array2<T>) -> Result<Vec<Array2<T>>> {
772        let sequence_length = sequence.shape()[0];
773        let mut chunks = Vec::new();
774
775        if sequence_length <= self.max_chunk_size {
776            chunks.push(sequence.clone());
777            return Ok(chunks);
778        }
779
780        let step_size = self.max_chunk_size - self.overlap_size;
781
782        for start in (0..sequence_length).step_by(step_size) {
783            let end = (start + self.max_chunk_size).min(sequence_length);
784            let chunk = sequence.slice(s![start..end, ..]).to_owned();
785            chunks.push(chunk);
786
787            if end >= sequence_length {
788                break;
789            }
790        }
791
792        Ok(chunks)
793    }
794}
795
/// Supporting data structures
///
/// One contiguous slice of an optimization trajectory produced by
/// change-point segmentation; `start_index..end_index` is the half-open
/// range the segment covers in the original sequence.
#[derive(Debug, Clone)]
pub struct SequenceSegment<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient rows for this segment (one row per time step)
    pub gradients: Array2<T>,
    /// Parameter rows for this segment
    pub parameters: Array2<T>,
    /// Loss value per time step
    pub losses: Array1<T>,
    /// Inclusive start of the segment in the source sequence
    pub start_index: usize,
    /// Exclusive end of the segment in the source sequence
    pub end_index: usize,
}
805
/// Running scalar statistics (count, mean, variance, min, max) computed
/// incrementally without storing individual samples.
pub struct StatisticsAccumulator<T: Float + Debug + Send + Sync + 'static> {
    // Number of observed values
    count: usize,
    // Sum of observed values
    sum: T,
    // Sum of squared observed values
    sum_sq: T,
    // Smallest observed value (+inf when empty)
    min: T,
    // Largest observed value (-inf when empty)
    max: T,
}
813
impl<T: Float + Debug + Send + Sync + 'static> Default for StatisticsAccumulator<T> {
    /// Equivalent to [`StatisticsAccumulator::new`].
    fn default() -> Self {
        Self::new()
    }
}
819
820impl<T: Float + Debug + Send + Sync + 'static> StatisticsAccumulator<T> {
821    pub fn new() -> Self {
822        Self {
823            count: 0,
824            sum: T::zero(),
825            sum_sq: T::zero(),
826            min: T::infinity(),
827            max: T::neg_infinity(),
828        }
829    }
830
831    pub fn update(&mut self, value: T) {
832        self.count += 1;
833        self.sum = self.sum + value;
834        self.sum_sq = self.sum_sq + value * value;
835        self.min = self.min.min(value);
836        self.max = self.max.max(value);
837    }
838
839    pub fn update_from_array1(&mut self, array: &Array1<T>) {
840        for &value in array.iter() {
841            self.update(value);
842        }
843    }
844
845    pub fn update_from_array2(&mut self, array: &Array2<T>) {
846        for &value in array.iter() {
847            self.update(value);
848        }
849    }
850
851    pub fn mean(&self) -> T {
852        if self.count > 0 {
853            self.sum / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero())
854        } else {
855            T::zero()
856        }
857    }
858
859    pub fn variance(&self) -> T {
860        if self.count > 1 {
861            let mean = self.mean();
862            (self.sum_sq
863                / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero()))
864                - (mean * mean)
865        } else {
866            T::zero()
867        }
868    }
869
870    pub fn std_dev(&self) -> T {
871        self.variance().sqrt()
872    }
873
874    pub fn min(&self) -> T {
875        self.min
876    }
877
878    pub fn max(&self) -> T {
879        self.max
880    }
881
882    pub fn reset(&mut self) {
883        self.count = 0;
884        self.sum = T::zero();
885        self.sum_sq = T::zero();
886        self.min = T::infinity();
887        self.max = T::neg_infinity();
888    }
889}
890
891// Import for slice macro
892use scirs2_core::ndarray::s;
893
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction from the default config should succeed.
    #[test]
    fn test_sequence_processor_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        let processor = OptimizationSequenceProcessor::new(&config);
        assert!(processor.is_ok());
    }

    /// Sequences can be appended to a freshly created buffer.
    #[test]
    fn test_sequence_buffer() {
        let buffer = SequenceBuffer::<f32>::new(10, 64);
        assert!(buffer.is_ok());

        let mut buf = buffer.unwrap();
        let gradients = Array2::<f32>::ones((5, 64));
        let parameters = Array2::<f32>::ones((5, 64));
        let losses = Array1::<f32>::ones(5);

        assert!(buf.add_sequence(&gradients, &parameters, &losses).is_ok());
    }

    /// Updating with all-ones histories records a positive gradient mean.
    #[test]
    fn test_sequence_statistics() {
        let mut stats = SequenceStatistics::<f32>::new();

        let gradients = Array2::<f32>::ones((10, 5));
        let parameters = Array2::<f32>::ones((10, 5));
        let losses = Array1::<f32>::ones(10);

        assert!(stats.update(&gradients, &parameters, &losses).is_ok());
        assert!(stats.get_gradient_stats().mean() > 0.0);
    }

    /// Mean/min/max of {1, 2, 3} are exact; std-dev must be positive.
    #[test]
    fn test_statistics_accumulator() {
        let mut acc = StatisticsAccumulator::<f32>::new();

        acc.update(1.0);
        acc.update(2.0);
        acc.update(3.0);

        assert_eq!(acc.mean(), 2.0);
        assert!(acc.std_dev() > 0.0);
        assert_eq!(acc.min(), 1.0);
        assert_eq!(acc.max(), 3.0);
    }

    /// A 25-row sequence with chunk size 10 must split into multiple chunks.
    #[test]
    fn test_chunking_strategy() {
        let chunking = ChunkingStrategy::<f32>::new(10, 2);
        assert!(chunking.is_ok());

        let mut strategy = chunking.unwrap();
        let sequence = Array2::<f32>::ones((25, 5));

        let chunks = strategy.create_chunks(&sequence);
        assert!(chunks.is_ok());

        let chunk_vec = chunks.unwrap();
        assert!(chunk_vec.len() > 1);
    }
}