// optirs_learned/transformer_based_optimizer/sequence_processor.rs
1// Sequence processing for optimization trajectories
2
3use super::config::TransformerBasedOptimizerConfig;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
6use scirs2_core::numeric::Float;
7use std::collections::{HashMap, VecDeque};
8use std::fmt::Debug;
9
/// Type alias for sequence data tuple (gradients, parameters, losses):
/// one `Vec` entry per stored trajectory, as returned by the sequence buffer.
type SequenceDataTuple<T> = (Vec<Array2<T>>, Vec<Array2<T>>, Vec<Array1<T>>);
12
/// Sequence processing strategy types
///
/// Each variant selects a different policy for reducing an arbitrarily long
/// optimization trajectory to at most `max_sequence_length` steps.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SequenceProcessingStrategy {
    /// Fixed-size overlapping windows over the trajectory
    SlidingWindow,
    /// Hierarchical chunking: repeatedly downsample until the sequence fits
    Hierarchical,
    /// Attention-based selection of the most important time steps
    AttentionBased,
    /// Adaptive segmentation at change points of the loss curve
    Adaptive,
    /// Truncated backpropagation: keep only the most recent steps
    TruncatedBPTT,
}
27
/// Optimization sequence processor
///
/// Turns raw optimization trajectories (gradients, parameters, losses) into
/// feature sequences for the transformer optimizer, using one of several
/// length-reduction strategies.
pub struct OptimizationSequenceProcessor<T: Float + Debug + Send + Sync + 'static> {
    /// Active processing strategy (see `SequenceProcessingStrategy`)
    strategy: SequenceProcessingStrategy,

    /// Maximum sequence length emitted by any strategy
    max_sequence_length: usize,

    /// Window size for sliding window (initialized to half the max length)
    window_size: usize,

    /// Overlap between windows (initialized to a quarter of the window)
    window_overlap: usize,

    /// Model dimension (caps the combined feature width)
    model_dimension: usize,

    /// FIFO buffer of recent raw trajectories
    sequence_buffer: SequenceBuffer<T>,

    /// Running statistics over observed trajectories
    statistics: SequenceStatistics<T>,

    /// Preprocessing pipeline that packs the three inputs into one matrix
    preprocessor: SequencePreprocessor<T>,

    /// Chunking strategy helper for splitting long sequences
    chunking: ChunkingStrategy<T>,
}
57
58impl<T: Float + Debug + scirs2_core::numeric::FromPrimitive + Send + Sync>
59    OptimizationSequenceProcessor<T>
60{
61    /// Create new sequence processor
62    pub fn new(config: &TransformerBasedOptimizerConfig<T>) -> Result<Self> {
63        let strategy = SequenceProcessingStrategy::SlidingWindow;
64        let max_sequence_length = config.sequence_length;
65        let window_size = max_sequence_length / 2;
66        let window_overlap = window_size / 4;
67        let model_dimension = config.model_dimension;
68
69        let sequence_buffer = SequenceBuffer::new(1000, model_dimension)?;
70        let statistics = SequenceStatistics::new();
71        let preprocessor = SequencePreprocessor::new(model_dimension)?;
72        let chunking = ChunkingStrategy::new(max_sequence_length, window_size)?;
73
74        Ok(Self {
75            strategy,
76            max_sequence_length,
77            window_size,
78            window_overlap,
79            model_dimension,
80            sequence_buffer,
81            statistics,
82            preprocessor,
83            chunking,
84        })
85    }
86
87    /// Process optimization sequence
88    pub fn process_optimization_sequence(
89        &mut self,
90        gradient_history: &Array2<T>,
91        parameter_history: &Array2<T>,
92        loss_history: &Array1<T>,
93    ) -> Result<Array2<T>> {
94        // Update statistics
95        self.statistics
96            .update(gradient_history, parameter_history, loss_history)?;
97
98        // Store in buffer
99        self.sequence_buffer
100            .add_sequence(gradient_history, parameter_history, loss_history)?;
101
102        match self.strategy {
103            SequenceProcessingStrategy::SlidingWindow => {
104                self.process_sliding_window(gradient_history, parameter_history, loss_history)
105            }
106            SequenceProcessingStrategy::Hierarchical => {
107                self.process_hierarchical(gradient_history, parameter_history, loss_history)
108            }
109            SequenceProcessingStrategy::AttentionBased => {
110                self.process_attention_based(gradient_history, parameter_history, loss_history)
111            }
112            SequenceProcessingStrategy::Adaptive => {
113                self.process_adaptive(gradient_history, parameter_history, loss_history)
114            }
115            SequenceProcessingStrategy::TruncatedBPTT => {
116                self.process_truncated_bptt(gradient_history, parameter_history, loss_history)
117            }
118        }
119    }
120
121    /// Process using sliding window strategy
122    fn process_sliding_window(
123        &mut self,
124        gradient_history: &Array2<T>,
125        parameter_history: &Array2<T>,
126        loss_history: &Array1<T>,
127    ) -> Result<Array2<T>> {
128        let sequence_length = gradient_history.shape()[0];
129
130        if sequence_length <= self.max_sequence_length {
131            // Sequence fits in one window
132            return self.combine_sequences(gradient_history, parameter_history, loss_history);
133        }
134
135        // Process in overlapping windows
136        let mut processed_chunks = Vec::new();
137        let step_size = self.window_size - self.window_overlap;
138
139        for start in (0..sequence_length).step_by(step_size) {
140            let end = (start + self.window_size).min(sequence_length);
141
142            let grad_chunk = gradient_history.slice(s![start..end, ..]);
143            let param_chunk = parameter_history.slice(s![start..end, ..]);
144            let loss_chunk = loss_history.slice(s![start..end]);
145
146            let processed_chunk = self.combine_sequences(
147                &grad_chunk.to_owned(),
148                &param_chunk.to_owned(),
149                &loss_chunk.to_owned(),
150            )?;
151
152            processed_chunks.push(processed_chunk);
153        }
154
155        // Combine processed chunks
156        self.combine_chunks(&processed_chunks)
157    }
158
159    /// Process using hierarchical strategy
160    fn process_hierarchical(
161        &mut self,
162        gradient_history: &Array2<T>,
163        parameter_history: &Array2<T>,
164        loss_history: &Array1<T>,
165    ) -> Result<Array2<T>> {
166        let sequence_length = gradient_history.shape()[0];
167
168        // Create hierarchical representation
169        let mut levels = Vec::new();
170        let mut current_level =
171            self.combine_sequences(gradient_history, parameter_history, loss_history)?;
172        levels.push(current_level.clone());
173
174        // Build hierarchy by downsampling
175        let mut current_length = sequence_length;
176        while current_length > self.max_sequence_length {
177            current_length /= 2;
178            current_level = self.downsample_sequence(&current_level, current_length)?;
179            levels.push(current_level.clone());
180        }
181
182        // Return the highest level that fits
183        Ok(levels.last().expect("unwrap failed").clone())
184    }
185
186    /// Process using attention-based selection
187    fn process_attention_based(
188        &mut self,
189        gradient_history: &Array2<T>,
190        parameter_history: &Array2<T>,
191        loss_history: &Array1<T>,
192    ) -> Result<Array2<T>> {
193        let sequence_length = gradient_history.shape()[0];
194
195        if sequence_length <= self.max_sequence_length {
196            return self.combine_sequences(gradient_history, parameter_history, loss_history);
197        }
198
199        // Compute importance scores for each time step
200        let importance_scores = self.compute_importance_scores(gradient_history, loss_history)?;
201
202        // Select most important steps
203        let selected_indices =
204            self.select_top_k_indices(&importance_scores, self.max_sequence_length)?;
205
206        // Extract selected sequences
207        let mut selected_gradients =
208            Array2::zeros((self.max_sequence_length, gradient_history.shape()[1]));
209        let mut selected_parameters =
210            Array2::zeros((self.max_sequence_length, parameter_history.shape()[1]));
211        let mut selected_losses = Array1::zeros(self.max_sequence_length);
212
213        for (i, &idx) in selected_indices.iter().enumerate() {
214            selected_gradients
215                .row_mut(i)
216                .assign(&gradient_history.row(idx));
217            selected_parameters
218                .row_mut(i)
219                .assign(&parameter_history.row(idx));
220            selected_losses[i] = loss_history[idx];
221        }
222
223        self.combine_sequences(&selected_gradients, &selected_parameters, &selected_losses)
224    }
225
226    /// Process using adaptive segmentation
227    fn process_adaptive(
228        &mut self,
229        gradient_history: &Array2<T>,
230        parameter_history: &Array2<T>,
231        loss_history: &Array1<T>,
232    ) -> Result<Array2<T>> {
233        // Detect change points in the optimization trajectory
234        let change_points = self.detect_change_points(loss_history)?;
235
236        // Segment sequence based on change points
237        let segments = self.segment_by_change_points(
238            gradient_history,
239            parameter_history,
240            loss_history,
241            &change_points,
242        )?;
243
244        // Process each segment and combine
245        let mut processed_segments = Vec::new();
246        for segment in segments {
247            let processed =
248                self.combine_sequences(&segment.gradients, &segment.parameters, &segment.losses)?;
249            processed_segments.push(processed);
250        }
251
252        self.combine_chunks(&processed_segments)
253    }
254
255    /// Process using truncated BPTT
256    fn process_truncated_bptt(
257        &mut self,
258        gradient_history: &Array2<T>,
259        parameter_history: &Array2<T>,
260        loss_history: &Array1<T>,
261    ) -> Result<Array2<T>> {
262        let sequence_length = gradient_history.shape()[0];
263
264        if sequence_length <= self.max_sequence_length {
265            return self.combine_sequences(gradient_history, parameter_history, loss_history);
266        }
267
268        // Take the most recent subsequence
269        let start_idx = sequence_length - self.max_sequence_length;
270        let grad_chunk = gradient_history.slice(s![start_idx.., ..]);
271        let param_chunk = parameter_history.slice(s![start_idx.., ..]);
272        let loss_chunk = loss_history.slice(s![start_idx..]);
273
274        self.combine_sequences(
275            &grad_chunk.to_owned(),
276            &param_chunk.to_owned(),
277            &loss_chunk.to_owned(),
278        )
279    }
280
281    /// Convert optimization trajectory to training sequences
282    pub fn trajectory_to_sequences(
283        &mut self,
284        trajectory: &super::OptimizationTrajectory<T>,
285    ) -> Result<Vec<super::TrainingSequence<T>>> {
286        let sequence_length = trajectory.gradient_sequence.shape()[0];
287        let mut sequences = Vec::new();
288
289        // Create overlapping sequences for training
290        for start in (0..sequence_length).step_by(self.window_size / 2) {
291            let end = (start + self.window_size).min(sequence_length);
292
293            if end - start < self.window_size / 2 {
294                break; // Skip sequences that are too short
295            }
296
297            let input_end = end - 1;
298            let target_start = start + 1;
299
300            let input_gradients = trajectory.gradient_sequence.slice(s![start..input_end, ..]);
301            let input_parameters = trajectory
302                .parameter_sequence
303                .slice(s![start..input_end, ..]);
304            let input_losses = trajectory.loss_sequence.slice(s![start..input_end]);
305
306            let target_gradients = trajectory
307                .gradient_sequence
308                .slice(s![target_start..end, ..]);
309            let target_parameters = trajectory
310                .parameter_sequence
311                .slice(s![target_start..end, ..]);
312            let target_losses = trajectory.loss_sequence.slice(s![target_start..end]);
313
314            let input = self.combine_sequences(
315                &input_gradients.to_owned(),
316                &input_parameters.to_owned(),
317                &input_losses.to_owned(),
318            )?;
319
320            let target = self.combine_sequences(
321                &target_gradients.to_owned(),
322                &target_parameters.to_owned(),
323                &target_losses.to_owned(),
324            )?;
325
326            sequences.push(super::TrainingSequence {
327                input,
328                target,
329                sequence_length: input_end - start,
330            });
331        }
332
333        Ok(sequences)
334    }
335
    /// Combine gradients, parameters, and losses into a unified feature
    /// sequence. Delegates to the preprocessor, which owns the layout policy.
    fn combine_sequences(
        &self,
        gradients: &Array2<T>,
        parameters: &Array2<T>,
        losses: &Array1<T>,
    ) -> Result<Array2<T>> {
        self.preprocessor
            .combine_sequences(gradients, parameters, losses)
    }
346
    /// Concatenate multiple processed chunks into one sequence, truncated to
    /// at most `max_sequence_length` rows.
    ///
    /// Errors when `chunks` is empty; a single chunk is returned unchanged.
    fn combine_chunks(&self, chunks: &[Array2<T>]) -> Result<Array2<T>> {
        if chunks.is_empty() {
            return Err(crate::error::OptimError::Other(
                "No chunks to combine".to_string(),
            ));
        }

        if chunks.len() == 1 {
            return Ok(chunks[0].clone());
        }

        // Simple concatenation strategy: stack rows until the length cap.
        let total_length: usize = chunks.iter().map(|chunk| chunk.shape()[0]).sum();
        // Feature width is taken from the first chunk; all chunks are assumed
        // to share it (produced by the same preprocessor) — TODO confirm.
        let feature_dim = chunks[0].shape()[1];

        let mut combined = Array2::zeros((total_length.min(self.max_sequence_length), feature_dim));
        let mut current_pos = 0;

        for chunk in chunks {
            let chunk_len = chunk.shape()[0];
            // Copy at most the remaining capacity before the length cap.
            let copy_len = (chunk_len).min(self.max_sequence_length - current_pos);

            if copy_len == 0 {
                break;
            }

            combined
                .slice_mut(s![current_pos..current_pos + copy_len, ..])
                .assign(&chunk.slice(s![..copy_len, ..]));

            current_pos += copy_len;

            // Output is full; any remaining chunks are dropped.
            if current_pos >= self.max_sequence_length {
                break;
            }
        }

        Ok(combined)
    }
387
388    /// Downsample sequence to target length
389    fn downsample_sequence(&self, sequence: &Array2<T>, target_length: usize) -> Result<Array2<T>> {
390        let current_length = sequence.shape()[0];
391        let feature_dim = sequence.shape()[1];
392
393        if current_length <= target_length {
394            return Ok(sequence.clone());
395        }
396
397        let mut downsampled = Array2::zeros((target_length, feature_dim));
398        let step = current_length as f64 / target_length as f64;
399
400        for i in 0..target_length {
401            let source_idx = (i as f64 * step) as usize;
402            downsampled
403                .row_mut(i)
404                .assign(&sequence.row(source_idx.min(current_length - 1)));
405        }
406
407        Ok(downsampled)
408    }
409
410    /// Compute importance scores for sequence steps
411    fn compute_importance_scores(
412        &self,
413        gradients: &Array2<T>,
414        losses: &Array1<T>,
415    ) -> Result<Array1<T>> {
416        let sequence_length = gradients.shape()[0];
417        let mut scores = Array1::zeros(sequence_length);
418
419        for i in 0..sequence_length {
420            // Combine gradient magnitude and loss change
421            let grad_norm = gradients
422                .row(i)
423                .iter()
424                .map(|&x| x * x)
425                .fold(T::zero(), |acc, x| acc + x)
426                .sqrt();
427
428            let loss_change = if i > 0 {
429                (losses[i] - losses[i - 1]).abs()
430            } else {
431                T::zero()
432            };
433
434            scores[i] = grad_norm + loss_change;
435        }
436
437        Ok(scores)
438    }
439
440    /// Select top-k indices based on importance scores
441    fn select_top_k_indices(&self, scores: &Array1<T>, k: usize) -> Result<Vec<usize>> {
442        let mut indexed_scores: Vec<(usize, T)> = scores
443            .iter()
444            .enumerate()
445            .map(|(i, &score)| (i, score))
446            .collect();
447
448        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("unwrap failed"));
449
450        let mut selected_indices: Vec<usize> = indexed_scores
451            .into_iter()
452            .take(k)
453            .map(|(idx, _)| idx)
454            .collect();
455
456        selected_indices.sort();
457        Ok(selected_indices)
458    }
459
460    /// Detect change points in loss trajectory
461    fn detect_change_points(&self, losses: &Array1<T>) -> Result<Vec<usize>> {
462        let mut change_points = vec![0]; // Always include start
463        let window_size = 5;
464        let threshold = scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero());
465
466        for i in window_size..losses.len() - window_size {
467            let before_mean = losses
468                .slice(s![i - window_size..i])
469                .mean()
470                .expect("unwrap failed");
471            let after_mean = losses
472                .slice(s![i..i + window_size])
473                .mean()
474                .expect("unwrap failed");
475
476            if (before_mean - after_mean).abs() > threshold {
477                change_points.push(i);
478            }
479        }
480
481        change_points.push(losses.len() - 1); // Always include end
482        Ok(change_points)
483    }
484
485    /// Segment sequences by change points
486    fn segment_by_change_points(
487        &self,
488        gradients: &Array2<T>,
489        parameters: &Array2<T>,
490        losses: &Array1<T>,
491        change_points: &[usize],
492    ) -> Result<Vec<SequenceSegment<T>>> {
493        let mut segments = Vec::new();
494
495        for i in 0..change_points.len() - 1 {
496            let start = change_points[i];
497            let end = change_points[i + 1];
498
499            let segment = SequenceSegment {
500                gradients: gradients.slice(s![start..end, ..]).to_owned(),
501                parameters: parameters.slice(s![start..end, ..]).to_owned(),
502                losses: losses.slice(s![start..end]).to_owned(),
503                start_index: start,
504                end_index: end,
505            };
506
507            segments.push(segment);
508        }
509
510        Ok(segments)
511    }
512
    /// Set processing strategy (takes effect on the next call to
    /// `process_optimization_sequence`)
    pub fn set_strategy(&mut self, strategy: SequenceProcessingStrategy) {
        self.strategy = strategy;
    }

    /// Get current strategy
    pub fn get_strategy(&self) -> SequenceProcessingStrategy {
        self.strategy
    }

    /// Get sequence statistics accumulated across all processed trajectories
    pub fn get_statistics(&self) -> &SequenceStatistics<T> {
        &self.statistics
    }

    /// Reset processor state: clears the trajectory buffer and statistics.
    /// Strategy and size settings are retained.
    pub fn reset(&mut self) -> Result<()> {
        self.sequence_buffer.clear();
        self.statistics.reset();
        Ok(())
    }
534}
535
/// Sequence buffer for storing optimization history
///
/// Three parallel FIFO buffers (gradients, parameters, losses) that are
/// trimmed in lock-step to at most `max_size` entries.
pub struct SequenceBuffer<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient history (one matrix per stored trajectory)
    gradient_buffer: VecDeque<Array2<T>>,

    /// Parameter history (parallel to `gradient_buffer`)
    parameter_buffer: VecDeque<Array2<T>>,

    /// Loss history (parallel to `gradient_buffer`)
    loss_buffer: VecDeque<Array1<T>>,

    /// Maximum number of stored trajectories
    max_size: usize,

    /// Model dimension; stored at construction but not read elsewhere in
    /// this file — presumably reserved for shape validation (TODO confirm)
    model_dimension: usize,
}
553
554impl<T: Float + Debug + Send + Sync + 'static> SequenceBuffer<T> {
555    pub fn new(max_size: usize, model_dimension: usize) -> Result<Self> {
556        Ok(Self {
557            gradient_buffer: VecDeque::new(),
558            parameter_buffer: VecDeque::new(),
559            loss_buffer: VecDeque::new(),
560            max_size,
561            model_dimension,
562        })
563    }
564
565    pub fn add_sequence(
566        &mut self,
567        gradients: &Array2<T>,
568        parameters: &Array2<T>,
569        losses: &Array1<T>,
570    ) -> Result<()> {
571        self.gradient_buffer.push_back(gradients.clone());
572        self.parameter_buffer.push_back(parameters.clone());
573        self.loss_buffer.push_back(losses.clone());
574
575        while self.gradient_buffer.len() > self.max_size {
576            self.gradient_buffer.pop_front();
577            self.parameter_buffer.pop_front();
578            self.loss_buffer.pop_front();
579        }
580
581        Ok(())
582    }
583
584    pub fn clear(&mut self) {
585        self.gradient_buffer.clear();
586        self.parameter_buffer.clear();
587        self.loss_buffer.clear();
588    }
589
590    pub fn get_recent_sequences(&self, count: usize) -> SequenceDataTuple<T> {
591        let actual_count = count.min(self.gradient_buffer.len());
592
593        let gradients = self
594            .gradient_buffer
595            .iter()
596            .rev()
597            .take(actual_count)
598            .cloned()
599            .collect();
600        let parameters = self
601            .parameter_buffer
602            .iter()
603            .rev()
604            .take(actual_count)
605            .cloned()
606            .collect();
607        let losses = self
608            .loss_buffer
609            .iter()
610            .rev()
611            .take(actual_count)
612            .cloned()
613            .collect();
614
615        (gradients, parameters, losses)
616    }
617}
618
/// Sequence statistics tracker
///
/// Four independent accumulators updated once per processed trajectory.
pub struct SequenceStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Element-wise gradient value statistics
    gradient_stats: StatisticsAccumulator<T>,

    /// Element-wise parameter value statistics
    parameter_stats: StatisticsAccumulator<T>,

    /// Loss value statistics
    loss_stats: StatisticsAccumulator<T>,

    /// Sequence length statistics (one sample per trajectory)
    length_stats: StatisticsAccumulator<T>,
}
633
// `Default` delegates to `new()` so the tracker works with
// `..Default::default()` struct updates.
impl<T: Float + Debug + Send + Sync + 'static> Default for SequenceStatistics<T> {
    fn default() -> Self {
        Self::new()
    }
}
639
640impl<T: Float + Debug + Send + Sync + 'static> SequenceStatistics<T> {
641    pub fn new() -> Self {
642        Self {
643            gradient_stats: StatisticsAccumulator::new(),
644            parameter_stats: StatisticsAccumulator::new(),
645            loss_stats: StatisticsAccumulator::new(),
646            length_stats: StatisticsAccumulator::new(),
647        }
648    }
649
650    pub fn update(
651        &mut self,
652        gradients: &Array2<T>,
653        parameters: &Array2<T>,
654        losses: &Array1<T>,
655    ) -> Result<()> {
656        self.gradient_stats.update_from_array2(gradients);
657        self.parameter_stats.update_from_array2(parameters);
658        self.loss_stats.update_from_array1(losses);
659        self.length_stats
660            .update(T::from(gradients.shape()[0]).expect("unwrap failed"));
661
662        Ok(())
663    }
664
665    pub fn reset(&mut self) {
666        self.gradient_stats.reset();
667        self.parameter_stats.reset();
668        self.loss_stats.reset();
669        self.length_stats.reset();
670    }
671
672    pub fn get_gradient_stats(&self) -> &StatisticsAccumulator<T> {
673        &self.gradient_stats
674    }
675
676    pub fn get_parameter_stats(&self) -> &StatisticsAccumulator<T> {
677        &self.parameter_stats
678    }
679
680    pub fn get_loss_stats(&self) -> &StatisticsAccumulator<T> {
681        &self.loss_stats
682    }
683}
684
/// Sequence preprocessor
///
/// Owns the policy for packing gradients, parameters, and losses into a
/// single feature matrix capped at `model_dimension` columns.
pub struct SequencePreprocessor<T: Float + Debug + Send + Sync + 'static> {
    /// Model dimension (upper bound on the combined feature width)
    model_dimension: usize,

    /// Per-feature (mean, std) normalization statistics; never populated in
    /// this file — presumably reserved for future use (TODO confirm)
    normalization_stats: HashMap<String, (T, T)>, // (mean, std)
}
693
impl<T: Float + Debug + Send + Sync + 'static> SequencePreprocessor<T> {
    /// Create a preprocessor targeting the given model dimension.
    pub fn new(model_dimension: usize) -> Result<Self> {
        Ok(Self {
            model_dimension,
            normalization_stats: HashMap::new(),
        })
    }

    /// Combine per-step gradients, parameters, and losses into one feature
    /// matrix of shape `(sequence_length, min(grad_dim + param_dim + 2,
    /// model_dimension))`.
    ///
    /// Per row, the layout is: up to `model_dimension / 3` gradient entries,
    /// up to `model_dimension / 3` parameter entries, the raw loss, and the
    /// loss normalized by the first loss value. Inputs wider than the budget
    /// are silently truncated.
    pub fn combine_sequences(
        &self,
        gradients: &Array2<T>,
        parameters: &Array2<T>,
        losses: &Array1<T>,
    ) -> Result<Array2<T>> {
        let sequence_length = gradients.shape()[0];
        let grad_dim = gradients.shape()[1];
        let param_dim = parameters.shape()[1];

        // Combined features: [gradients, parameters, loss, normalized_loss],
        // capped at the model dimension.
        let feature_dim = grad_dim + param_dim + 2;
        let mut combined = Array2::zeros((sequence_length, feature_dim.min(self.model_dimension)));

        for i in 0..sequence_length {
            let mut feature_idx = 0;

            // Add gradient features (budget: a third of the model dimension).
            for j in 0..grad_dim.min(self.model_dimension / 3) {
                if feature_idx < combined.shape()[1] {
                    combined[[i, feature_idx]] = gradients[[i, j]];
                    feature_idx += 1;
                }
            }

            // Add parameter features (same budget as gradients).
            for j in 0..param_dim.min(self.model_dimension / 3) {
                if feature_idx < combined.shape()[1] {
                    combined[[i, feature_idx]] = parameters[[i, j]];
                    feature_idx += 1;
                }
            }

            // Add the raw loss value.
            if feature_idx < combined.shape()[1] {
                combined[[i, feature_idx]] = losses[i];
                feature_idx += 1;
            }

            // Add the loss relative to the first loss. Note: the `i > 0`
            // guard leaves row 0's normalized-loss slot at zero (not 1.0) —
            // possibly intentional, TODO confirm.
            if feature_idx < combined.shape()[1] && i > 0 {
                let normalized_loss = if losses[0] != T::zero() {
                    losses[i] / losses[0]
                } else {
                    T::one()
                };
                combined[[i, feature_idx]] = normalized_loss;
            }
        }

        Ok(combined)
    }
}
755
/// Chunking strategy
///
/// Splits long feature sequences into overlapping fixed-size chunks.
pub struct ChunkingStrategy<T: Float + Debug + Send + Sync + 'static> {
    /// Maximum chunk size (rows per chunk)
    max_chunk_size: usize,

    /// Overlap between consecutive chunks (rows shared with the previous one)
    overlap_size: usize,

    /// Accumulator for chunk-size statistics
    chunk_stats: StatisticsAccumulator<T>,
}
767
768impl<T: Float + Debug + Send + Sync + 'static> ChunkingStrategy<T> {
769    pub fn new(max_chunk_size: usize, overlap_size: usize) -> Result<Self> {
770        Ok(Self {
771            max_chunk_size,
772            overlap_size,
773            chunk_stats: StatisticsAccumulator::new(),
774        })
775    }
776
777    pub fn create_chunks(&mut self, sequence: &Array2<T>) -> Result<Vec<Array2<T>>> {
778        let sequence_length = sequence.shape()[0];
779        let mut chunks = Vec::new();
780
781        if sequence_length <= self.max_chunk_size {
782            chunks.push(sequence.clone());
783            return Ok(chunks);
784        }
785
786        let step_size = self.max_chunk_size - self.overlap_size;
787
788        for start in (0..sequence_length).step_by(step_size) {
789            let end = (start + self.max_chunk_size).min(sequence_length);
790            let chunk = sequence.slice(s![start..end, ..]).to_owned();
791            chunks.push(chunk);
792
793            if end >= sequence_length {
794                break;
795            }
796        }
797
798        Ok(chunks)
799    }
800}
801
/// Supporting data structures
/// A contiguous slice of a trajectory produced by change-point segmentation.
#[derive(Debug, Clone)]
pub struct SequenceSegment<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient rows for this segment
    pub gradients: Array2<T>,
    /// Parameter rows for this segment
    pub parameters: Array2<T>,
    /// Loss values for this segment
    pub losses: Array1<T>,
    /// Start index into the source trajectory (inclusive)
    pub start_index: usize,
    /// End index into the source trajectory (exclusive in slicing)
    pub end_index: usize,
}
811
/// Streaming accumulator of count / sum / sum-of-squares / min / max, from
/// which mean, variance, and standard deviation are derived on demand.
pub struct StatisticsAccumulator<T: Float + Debug + Send + Sync + 'static> {
    // Number of observations folded in
    count: usize,
    // Running sum of observations
    sum: T,
    // Running sum of squared observations
    sum_sq: T,
    // Smallest observation seen (+inf when empty)
    min: T,
    // Largest observation seen (-inf when empty)
    max: T,
}
819
// `Default` delegates to `new()` for use in derived defaults.
impl<T: Float + Debug + Send + Sync + 'static> Default for StatisticsAccumulator<T> {
    fn default() -> Self {
        Self::new()
    }
}
825
826impl<T: Float + Debug + Send + Sync + 'static> StatisticsAccumulator<T> {
827    pub fn new() -> Self {
828        Self {
829            count: 0,
830            sum: T::zero(),
831            sum_sq: T::zero(),
832            min: T::infinity(),
833            max: T::neg_infinity(),
834        }
835    }
836
837    pub fn update(&mut self, value: T) {
838        self.count += 1;
839        self.sum = self.sum + value;
840        self.sum_sq = self.sum_sq + value * value;
841        self.min = self.min.min(value);
842        self.max = self.max.max(value);
843    }
844
845    pub fn update_from_array1(&mut self, array: &Array1<T>) {
846        for &value in array.iter() {
847            self.update(value);
848        }
849    }
850
851    pub fn update_from_array2(&mut self, array: &Array2<T>) {
852        for &value in array.iter() {
853            self.update(value);
854        }
855    }
856
857    pub fn mean(&self) -> T {
858        if self.count > 0 {
859            self.sum / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero())
860        } else {
861            T::zero()
862        }
863    }
864
865    pub fn variance(&self) -> T {
866        if self.count > 1 {
867            let mean = self.mean();
868            (self.sum_sq
869                / scirs2_core::numeric::NumCast::from(self.count).unwrap_or_else(|| T::zero()))
870                - (mean * mean)
871        } else {
872            T::zero()
873        }
874    }
875
876    pub fn std_dev(&self) -> T {
877        self.variance().sqrt()
878    }
879
880    pub fn min(&self) -> T {
881        self.min
882    }
883
884    pub fn max(&self) -> T {
885        self.max
886    }
887
888    pub fn reset(&mut self) {
889        self.count = 0;
890        self.sum = T::zero();
891        self.sum_sq = T::zero();
892        self.min = T::infinity();
893        self.max = T::neg_infinity();
894    }
895}
896
897// Import for slice macro
898use scirs2_core::ndarray::s;
899
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor construction succeeds from the default configuration.
    #[test]
    fn test_sequence_processor_creation() {
        let config = super::super::config::TransformerBasedOptimizerConfig::<f32>::default();
        assert!(OptimizationSequenceProcessor::new(&config).is_ok());
    }

    /// A trajectory can be stored in the buffer without error.
    #[test]
    fn test_sequence_buffer() {
        let mut buf = SequenceBuffer::<f32>::new(10, 64).expect("unwrap failed");

        let gradients = Array2::<f32>::ones((5, 64));
        let parameters = Array2::<f32>::ones((5, 64));
        let losses = Array1::<f32>::ones(5);

        assert!(buf.add_sequence(&gradients, &parameters, &losses).is_ok());
    }

    /// Updating statistics with all-ones data yields a positive gradient mean.
    #[test]
    fn test_sequence_statistics() {
        let mut stats = SequenceStatistics::<f32>::new();

        let gradients = Array2::<f32>::ones((10, 5));
        let parameters = Array2::<f32>::ones((10, 5));
        let losses = Array1::<f32>::ones(10);

        assert!(stats.update(&gradients, &parameters, &losses).is_ok());
        assert!(stats.get_gradient_stats().mean() > 0.0);
    }

    /// The accumulator reports correct mean/min/max for a small sample.
    #[test]
    fn test_statistics_accumulator() {
        let mut acc = StatisticsAccumulator::<f32>::new();
        for v in [1.0, 2.0, 3.0] {
            acc.update(v);
        }

        assert_eq!(acc.mean(), 2.0);
        assert!(acc.std_dev() > 0.0);
        assert_eq!(acc.min(), 1.0);
        assert_eq!(acc.max(), 3.0);
    }

    /// A 25-row sequence with 10-row chunks and overlap 2 yields multiple chunks.
    #[test]
    fn test_chunking_strategy() {
        let mut strategy = ChunkingStrategy::<f32>::new(10, 2).expect("unwrap failed");
        let sequence = Array2::<f32>::ones((25, 5));

        let chunks = strategy.create_chunks(&sequence).expect("unwrap failed");
        assert!(chunks.len() > 1);
    }
}
964}