Skip to main content

optirs_learned/transformer_based_optimizer/
positional_encoding.rs

1// Positional encoding implementations for transformer sequences
2
3use crate::error::Result;
4use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
5use scirs2_core::numeric::Float;
6use serde::{Deserialize, Serialize};
7use std::f64::consts::PI;
8use std::fmt::Debug;
9
10/// Positional encoding types
11#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
12pub enum PositionalEncodingType {
13    /// Sinusoidal positional encoding (original Transformer)
14    Sinusoidal,
15    /// Learned positional encoding
16    Learned,
17    /// Rotary Position Embedding (RoPE)
18    Rotary,
19    /// No positional encoding
20    None,
21}
22
/// Positional encoding implementation
///
/// Precomputes (or learns) position-dependent values that are combined with
/// transformer sequence inputs, depending on `encoding_type`.
pub struct PositionalEncoding<T: Float + Debug + Send + Sync + 'static> {
    /// Type of positional encoding
    encoding_type: PositionalEncodingType,

    /// Maximum sequence length supported by the precomputed table
    max_sequence_length: usize,

    /// Model dimension (width of each encoding row)
    model_dimension: usize,

    /// Precomputed encoding matrix, shape `[max_sequence_length, model_dimension]`.
    /// For the rotary variant only row 0 is populated (inverse frequencies).
    encoding_matrix: Array2<T>,

    /// Learned parameters (populated only for the `Learned` encoding type)
    learned_embeddings: Option<Array2<T>>,

    /// RoPE frequency base (initialized to 10000.0 in `new`)
    rope_base: T,
}
43
44impl<T: Float + Debug + Send + Sync + 'static> PositionalEncoding<T> {
45    /// Create new positional encoding
46    pub fn new(
47        max_sequence_length: usize,
48        model_dimension: usize,
49        encoding_type: PositionalEncodingType,
50    ) -> Result<Self> {
51        let rope_base = scirs2_core::numeric::NumCast::from(10000.0).unwrap_or_else(|| T::zero());
52        let mut encoding = Self {
53            encoding_type,
54            max_sequence_length,
55            model_dimension,
56            encoding_matrix: Array2::zeros((max_sequence_length, model_dimension)),
57            learned_embeddings: None,
58            rope_base,
59        };
60
61        encoding.initialize_encoding()?;
62        Ok(encoding)
63    }
64
65    /// Initialize the encoding based on type
66    fn initialize_encoding(&mut self) -> Result<()> {
67        match self.encoding_type {
68            PositionalEncodingType::Sinusoidal => self.initialize_sinusoidal(),
69            PositionalEncodingType::Learned => self.initialize_learned(),
70            PositionalEncodingType::Rotary => self.initialize_rotary(),
71            PositionalEncodingType::None => Ok(()),
72        }
73    }
74
75    /// Initialize sinusoidal positional encoding
76    fn initialize_sinusoidal(&mut self) -> Result<()> {
77        for pos in 0..self.max_sequence_length {
78            for i in 0..self.model_dimension {
79                let position =
80                    scirs2_core::numeric::NumCast::from(pos).unwrap_or_else(|| T::zero());
81                let dimension = scirs2_core::numeric::NumCast::from(i).unwrap_or_else(|| T::zero());
82                let model_dim = scirs2_core::numeric::NumCast::from(self.model_dimension)
83                    .unwrap_or_else(|| T::zero());
84
85                let angle = position
86                    / scirs2_core::numeric::NumCast::from(10000.0)
87                        .unwrap_or_else(|| T::zero())
88                        .powf(
89                            scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero())
90                                * dimension
91                                / model_dim,
92                        );
93
94                if i % 2 == 0 {
95                    // Even dimensions: sin
96                    self.encoding_matrix[[pos, i]] = angle.sin();
97                } else {
98                    // Odd dimensions: cos
99                    self.encoding_matrix[[pos, i]] = angle.cos();
100                }
101            }
102        }
103        Ok(())
104    }
105
106    /// Initialize learned positional encoding
107    fn initialize_learned(&mut self) -> Result<()> {
108        let learned_embeddings = Array2::zeros((self.max_sequence_length, self.model_dimension));
109        self.learned_embeddings = Some(learned_embeddings);
110        Ok(())
111    }
112
113    /// Initialize rotary positional encoding
114    fn initialize_rotary(&mut self) -> Result<()> {
115        // RoPE doesn't use a precomputed matrix, it's applied during attention
116        // We'll store frequency inverse for each dimension pair
117        for i in (0..self.model_dimension).step_by(2) {
118            let dim_pair = scirs2_core::numeric::NumCast::from(i).unwrap_or_else(|| T::zero())
119                / scirs2_core::numeric::NumCast::from(self.model_dimension)
120                    .unwrap_or_else(|| T::zero());
121            let freq = T::one() / self.rope_base.powf(dim_pair);
122
123            if i < self.model_dimension {
124                self.encoding_matrix[[0, i]] = freq;
125            }
126            if i + 1 < self.model_dimension {
127                self.encoding_matrix[[0, i + 1]] = freq;
128            }
129        }
130        Ok(())
131    }
132
133    /// Apply positional encoding to input
134    pub fn encode(&self, input: &Array2<T>) -> Result<Array2<T>> {
135        match self.encoding_type {
136            PositionalEncodingType::None => Ok(input.clone()),
137            PositionalEncodingType::Sinusoidal => self.apply_sinusoidal(input),
138            PositionalEncodingType::Learned => self.apply_learned(input),
139            PositionalEncodingType::Rotary => self.apply_rotary(input),
140        }
141    }
142
143    /// Apply sinusoidal encoding
144    fn apply_sinusoidal(&self, input: &Array2<T>) -> Result<Array2<T>> {
145        let batch_size = input.shape()[0];
146        let sequence_length = input.shape()[1];
147
148        if sequence_length > self.max_sequence_length {
149            return Err(crate::error::OptimError::Other(
150                "Sequence length exceeds maximum".to_string(),
151            ));
152        }
153
154        let mut output = input.clone();
155
156        for batch in 0..batch_size {
157            for pos in 0..sequence_length {
158                for dim in 0..self.model_dimension {
159                    output[[batch, pos]] = output[[batch, pos]] + self.encoding_matrix[[pos, dim]];
160                }
161            }
162        }
163
164        Ok(output)
165    }
166
167    /// Apply learned encoding
168    fn apply_learned(&self, input: &Array2<T>) -> Result<Array2<T>> {
169        if let Some(ref learned) = self.learned_embeddings {
170            let batch_size = input.shape()[0];
171            let sequence_length = input.shape()[1];
172
173            if sequence_length > self.max_sequence_length {
174                return Err(crate::error::OptimError::Other(
175                    "Sequence length exceeds maximum".to_string(),
176                ));
177            }
178
179            let mut output = input.clone();
180
181            for batch in 0..batch_size {
182                for pos in 0..sequence_length {
183                    for dim in 0..self.model_dimension {
184                        output[[batch, pos]] = output[[batch, pos]] + learned[[pos, dim]];
185                    }
186                }
187            }
188
189            Ok(output)
190        } else {
191            Err(crate::error::OptimError::Other(
192                "Learned embeddings not initialized".to_string(),
193            ))
194        }
195    }
196
197    /// Apply rotary encoding (simplified implementation)
198    fn apply_rotary(&self, input: &Array2<T>) -> Result<Array2<T>> {
199        let batch_size = input.shape()[0];
200        let sequence_length = input.shape()[1];
201        let mut output = input.clone();
202
203        for batch in 0..batch_size {
204            for pos in 0..sequence_length {
205                let position =
206                    scirs2_core::numeric::NumCast::from(pos).unwrap_or_else(|| T::zero());
207
208                for i in (0..self.model_dimension).step_by(2) {
209                    if i + 1 < self.model_dimension {
210                        let freq = self.encoding_matrix[[0, i]];
211                        let angle = position * freq;
212
213                        let cos_val = angle.cos();
214                        let sin_val = angle.sin();
215
216                        let x = input[[batch, pos]]; // Simplified: treating as single value
217                        let y = if i + 1 < input.shape()[1] {
218                            input[[batch, pos]]
219                        } else {
220                            T::zero()
221                        };
222
223                        output[[batch, pos]] = x * cos_val - y * sin_val;
224                        // For complete RoPE, we'd need to handle the paired dimension
225                    }
226                }
227            }
228        }
229
230        Ok(output)
231    }
232
233    /// Get encoding for specific position
234    pub fn get_position_encoding(&self, position: usize) -> Result<Array1<T>> {
235        if position >= self.max_sequence_length {
236            return Err(crate::error::OptimError::Other(
237                "Position exceeds maximum sequence length".to_string(),
238            ));
239        }
240
241        match self.encoding_type {
242            PositionalEncodingType::Sinusoidal => Ok(self.encoding_matrix.row(position).to_owned()),
243            PositionalEncodingType::Learned => {
244                if let Some(ref learned) = self.learned_embeddings {
245                    Ok(learned.row(position).to_owned())
246                } else {
247                    Err(crate::error::OptimError::Other(
248                        "Learned embeddings not available".to_string(),
249                    ))
250                }
251            }
252            PositionalEncodingType::Rotary => {
253                // Return frequency information for RoPE
254                Ok(self.encoding_matrix.row(0).to_owned())
255            }
256            PositionalEncodingType::None => Ok(Array1::zeros(self.model_dimension)),
257        }
258    }
259
260    /// Update learned embeddings (for training)
261    pub fn update_learned_embeddings(&mut self, gradients: &Array2<T>) -> Result<()> {
262        if let Some(ref mut learned) = self.learned_embeddings {
263            *learned = &*learned - gradients;
264            Ok(())
265        } else {
266            Err(crate::error::OptimError::Other(
267                "No learned embeddings to update".to_string(),
268            ))
269        }
270    }
271
272    /// Get parameter count
273    pub fn parameter_count(&self) -> usize {
274        match self.encoding_type {
275            PositionalEncodingType::Learned => self.max_sequence_length * self.model_dimension,
276            _ => 0, // Sinusoidal, RoPE, and None have no learnable parameters
277        }
278    }
279
280    /// Reset parameters
281    pub fn reset(&mut self) -> Result<()> {
282        self.initialize_encoding()
283    }
284
285    /// Get encoding type
286    pub fn get_encoding_type(&self) -> PositionalEncodingType {
287        self.encoding_type
288    }
289
290    /// Get maximum sequence length
291    pub fn get_max_sequence_length(&self) -> usize {
292        self.max_sequence_length
293    }
294
295    /// Get model dimension
296    pub fn get_model_dimension(&self) -> usize {
297        self.model_dimension
298    }
299
300    /// Create sinusoidal encoding with custom base
301    pub fn sinusoidal_with_base(
302        max_sequence_length: usize,
303        model_dimension: usize,
304        base: T,
305    ) -> Result<Self> {
306        let mut encoding = Self::new(
307            max_sequence_length,
308            model_dimension,
309            PositionalEncodingType::Sinusoidal,
310        )?;
311
312        // Reinitialize with custom base
313        for pos in 0..max_sequence_length {
314            for i in 0..model_dimension {
315                let position =
316                    scirs2_core::numeric::NumCast::from(pos).unwrap_or_else(|| T::zero());
317                let dimension = scirs2_core::numeric::NumCast::from(i).unwrap_or_else(|| T::zero());
318                let model_dim = scirs2_core::numeric::NumCast::from(model_dimension)
319                    .unwrap_or_else(|| T::zero());
320
321                let angle = position
322                    / base.powf(
323                        scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero())
324                            * dimension
325                            / model_dim,
326                    );
327
328                if i % 2 == 0 {
329                    encoding.encoding_matrix[[pos, i]] = angle.sin();
330                } else {
331                    encoding.encoding_matrix[[pos, i]] = angle.cos();
332                }
333            }
334        }
335
336        Ok(encoding)
337    }
338
339    /// Create rotary encoding with custom base
340    pub fn rotary_with_base(
341        max_sequence_length: usize,
342        model_dimension: usize,
343        base: T,
344    ) -> Result<Self> {
345        let mut encoding = Self::new(
346            max_sequence_length,
347            model_dimension,
348            PositionalEncodingType::Rotary,
349        )?;
350
351        encoding.rope_base = base;
352        encoding.initialize_rotary()?;
353
354        Ok(encoding)
355    }
356}
357
/// Relative positional encoding for local attention patterns
///
/// Stores one encoding row per clamped relative distance in
/// `[-max_relative_distance, max_relative_distance]`.
pub struct RelativePositionalEncoding<T: Float + Debug + Send + Sync + 'static> {
    /// Maximum relative distance; larger distances are clamped to this value
    max_relative_distance: usize,

    /// Model dimension (width of each encoding row)
    model_dimension: usize,

    /// Relative encoding table, shape
    /// `[2 * max_relative_distance + 1, model_dimension]`
    relative_encoding: Array2<T>,
}
369
370impl<T: Float + Debug + Send + Sync + 'static> RelativePositionalEncoding<T> {
371    /// Create new relative positional encoding
372    pub fn new(max_relative_distance: usize, model_dimension: usize) -> Result<Self> {
373        let table_size = 2 * max_relative_distance + 1;
374        let relative_encoding = Array2::zeros((table_size, model_dimension));
375
376        Ok(Self {
377            max_relative_distance,
378            model_dimension,
379            relative_encoding,
380        })
381    }
382
383    /// Get relative encoding between two positions
384    pub fn get_relative_encoding(&self, from_pos: usize, to_pos: usize) -> Array1<T> {
385        let relative_distance = (to_pos as i32 - from_pos as i32)
386            .max(-(self.max_relative_distance as i32))
387            .min(self.max_relative_distance as i32);
388
389        let index = (relative_distance + self.max_relative_distance as i32) as usize;
390        self.relative_encoding.row(index).to_owned()
391    }
392
393    /// Initialize with sinusoidal patterns
394    pub fn initialize_sinusoidal(&mut self) -> Result<()> {
395        let table_size = 2 * self.max_relative_distance + 1;
396
397        for i in 0..table_size {
398            let relative_pos = i as i32 - self.max_relative_distance as i32;
399            let position =
400                scirs2_core::numeric::NumCast::from(relative_pos).unwrap_or_else(|| T::zero());
401
402            for j in 0..self.model_dimension {
403                let dimension = scirs2_core::numeric::NumCast::from(j).unwrap_or_else(|| T::zero());
404                let model_dim = scirs2_core::numeric::NumCast::from(self.model_dimension)
405                    .unwrap_or_else(|| T::zero());
406
407                let angle = position
408                    / scirs2_core::numeric::NumCast::from(10000.0)
409                        .unwrap_or_else(|| T::zero())
410                        .powf(
411                            scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero())
412                                * dimension
413                                / model_dim,
414                        );
415
416                if j % 2 == 0 {
417                    self.relative_encoding[[i, j]] = angle.sin();
418                } else {
419                    self.relative_encoding[[i, j]] = angle.cos();
420                }
421            }
422        }
423
424        Ok(())
425    }
426
427    /// Get parameter count
428    pub fn parameter_count(&self) -> usize {
429        (2 * self.max_relative_distance + 1) * self.model_dimension
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sinusoidal_encoding() {
        let encoding = PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Sinusoidal);
        assert!(encoding.is_ok());

        let pe = encoding.expect("unwrap failed");
        assert_eq!(pe.parameter_count(), 0);

        // Encoding a sequence shorter than the maximum must succeed.
        let input = Array2::<f32>::zeros((2, 50));
        assert!(pe.encode(&input).is_ok());
    }

    #[test]
    fn test_learned_encoding() {
        let encoding = PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Learned);
        assert!(encoding.is_ok());

        // Learned encoding exposes one parameter per table cell.
        assert_eq!(encoding.expect("unwrap failed").parameter_count(), 100 * 64);
    }

    #[test]
    fn test_rotary_encoding() {
        let encoding = PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Rotary);
        assert!(encoding.is_ok());

        let pe = encoding.expect("unwrap failed");
        assert!(pe.encode(&Array2::<f32>::ones((2, 50))).is_ok());
    }

    #[test]
    fn test_position_encoding_retrieval() {
        let pe = PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Sinusoidal)
            .expect("unwrap failed");

        let pos_encoding = pe.get_position_encoding(10);
        assert!(pos_encoding.is_ok());

        // Retrieved row must span the full model dimension.
        assert_eq!(pos_encoding.expect("unwrap failed").len(), 64);
    }

    #[test]
    fn test_relative_positional_encoding() {
        let rel_pe = RelativePositionalEncoding::<f32>::new(10, 64);
        assert!(rel_pe.is_ok());

        let mut rpe = rel_pe.expect("unwrap failed");
        assert!(rpe.initialize_sinusoidal().is_ok());
        assert_eq!(rpe.get_relative_encoding(5, 8).len(), 64);
    }

    #[test]
    fn test_encoding_types() {
        // Construction must succeed for every encoding variant.
        for encoding_type in [
            PositionalEncodingType::Sinusoidal,
            PositionalEncodingType::Learned,
            PositionalEncodingType::Rotary,
            PositionalEncodingType::None,
        ] {
            assert!(PositionalEncoding::<f32>::new(50, 32, encoding_type).is_ok());
        }
    }
}