// optirs_learned/transformer/architecture/positional_encoding.rs

1use std::fmt::Debug;
2// Positional encoding mechanisms for transformer optimization
3//
4// This module implements various positional encoding strategies used in the
5// transformer optimizer to provide position information to the attention mechanisms.
6
7#[allow(dead_code)]
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::{Random, Rng as SCRRng};
11
12use super::super::TransformerOptimizerConfig;
13use crate::error::{OptimError, Result};
14
/// Strategies for injecting position information into the transformer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncodingType {
    /// Fixed sinusoidal encoding added directly to the input sequence.
    Sinusoidal,
    /// Learnable per-position embedding added directly to the input sequence.
    Learned,
    /// Rotary position embedding (RoPE); acts on attention computation, not the input.
    Rotary,
    /// Relative position encoding; acts on attention computation, not the input.
    Relative,
    /// ALiBi (Attention with Linear Biases); biases attention scores by token distance.
    ALiBi,
}
29
/// Positional encoder for transformer inputs
#[derive(Debug, Clone)]
pub struct PositionalEncoder<T: Float + Debug + Send + Sync + 'static> {
    /// Which encoding strategy this instance applies
    encoding_type: PositionalEncodingType,

    /// Precomputed sinusoidal table (rows = position, cols = dimension);
    /// populated for `Sinusoidal` and, as a fallback, for `Rotary`/`Relative`
    cached_encodings: Option<Array2<T>>,

    /// Maximum sequence length accepted by `encode`
    max_seqlen: usize,

    /// Model (embedding) dimension expected of inputs
    modeldim: usize,

    /// Learned position embeddings (populated only for `Learned` encoding)
    position_embeddings: Option<Array2<T>>,

    /// Per-head ALiBi slopes (populated only for `ALiBi` encoding)
    alibi_slopes: Option<Array1<T>>,
}
51
52impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> PositionalEncoder<T> {
53    /// Create new positional encoder
54    pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
55        let max_seqlen = config.max_sequence_length;
56        let modeldim = config.modeldim;
57
58        let mut cached_encodings = None;
59        let mut position_embeddings = None;
60        let mut alibi_slopes = None;
61
62        match config.pos_encoding_type {
63            PositionalEncodingType::Sinusoidal => {
64                // Precompute sinusoidal encodings
65                let mut encodings = Array2::zeros((max_seqlen, modeldim));
66
67                for pos in 0..max_seqlen {
68                    for i in 0..modeldim {
69                        let angle = scirs2_core::numeric::NumCast::from(pos)
70                            .unwrap_or_else(|| T::zero())
71                            / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
72                                .unwrap();
73
74                        if i % 2 == 0 {
75                            encodings[[pos, i]] = angle.sin();
76                        } else {
77                            encodings[[pos, i]] = angle.cos();
78                        }
79                    }
80                }
81                cached_encodings = Some(encodings);
82            }
83            PositionalEncodingType::Learned => {
84                // Initialize learnable position embeddings
85                let mut rng = scirs2_core::random::thread_rng();
86                let mut embeddings = Array2::zeros((max_seqlen, modeldim));
87
88                // Xavier initialization
89                let bound = (6.0 / (max_seqlen + modeldim) as f64).sqrt();
90                for elem in embeddings.iter_mut() {
91                    *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).unwrap();
92                }
93                position_embeddings = Some(embeddings);
94            }
95            PositionalEncodingType::ALiBi => {
96                // Initialize ALiBi slopes
97                let numheads = config.numheads;
98                let mut slopes = Array1::zeros(numheads);
99
100                for h in 0..numheads {
101                    let slope =
102                        T::from(2.0_f64.powf(-8.0 * (h + 1) as f64 / numheads as f64)).unwrap();
103                    slopes[h] = slope;
104                }
105                alibi_slopes = Some(slopes);
106            }
107            _ => {
108                // Default to sinusoidal for other types
109                let mut encodings = Array2::zeros((max_seqlen, modeldim));
110
111                for pos in 0..max_seqlen {
112                    for i in 0..modeldim {
113                        let angle = scirs2_core::numeric::NumCast::from(pos)
114                            .unwrap_or_else(|| T::zero())
115                            / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
116                                .unwrap();
117
118                        if i % 2 == 0 {
119                            encodings[[pos, i]] = angle.sin();
120                        } else {
121                            encodings[[pos, i]] = angle.cos();
122                        }
123                    }
124                }
125                cached_encodings = Some(encodings);
126            }
127        }
128
129        Ok(Self {
130            encoding_type: config.pos_encoding_type,
131            cached_encodings,
132            max_seqlen,
133            modeldim,
134            position_embeddings,
135            alibi_slopes,
136        })
137    }
138
139    /// Encode input with positional information
140    pub fn encode(&self, input: &Array2<T>) -> Result<Array2<T>> {
141        let (seq_len, modeldim) = input.dim();
142
143        if seq_len > self.max_seqlen {
144            return Err(OptimError::InvalidConfig(format!(
145                "Sequence length {} exceeds maximum {}",
146                seq_len, self.max_seqlen
147            )));
148        }
149
150        if modeldim != self.modeldim {
151            return Err(OptimError::InvalidConfig(format!(
152                "Model dimension {} doesn't match expected {}",
153                modeldim, self.modeldim
154            )));
155        }
156
157        let mut output = input.clone();
158
159        match self.encoding_type {
160            PositionalEncodingType::Sinusoidal => {
161                if let Some(ref encodings) = self.cached_encodings {
162                    let pos_enc = encodings.slice(s![..seq_len, ..]);
163                    output = output + pos_enc;
164                }
165            }
166            PositionalEncodingType::Learned => {
167                if let Some(ref embeddings) = self.position_embeddings {
168                    let pos_emb = embeddings.slice(s![..seq_len, ..]);
169                    output = output + pos_emb;
170                }
171            }
172            PositionalEncodingType::Rotary => {
173                // Rotary position embedding (RoPE) doesn't add to input,
174                // it modifies attention computation
175                // For now, just return input unchanged
176            }
177            PositionalEncodingType::Relative => {
178                // Relative position encoding doesn't add to input,
179                // it modifies attention computation
180                // For now, just return input unchanged
181            }
182            PositionalEncodingType::ALiBi => {
183                // ALiBi doesn't add to input, it modifies attention scores
184                // For now, just return input unchanged
185            }
186        }
187
188        Ok(output)
189    }
190
191    /// Get ALiBi slopes for attention bias calculation
192    pub fn get_alibi_slopes(&self) -> Option<&Array1<T>> {
193        self.alibi_slopes.as_ref()
194    }
195
196    /// Get encoding type
197    pub fn encoding_type(&self) -> PositionalEncodingType {
198        self.encoding_type
199    }
200
201    /// Get maximum sequence length
202    pub fn max_sequence_length(&self) -> usize {
203        self.max_seqlen
204    }
205
206    /// Get model dimension
207    pub fn model_dimension(&self) -> usize {
208        self.modeldim
209    }
210
211    /// Update position embeddings (for learned encoding)
212    pub fn update_embeddings(&mut self, new_embeddings: Array2<T>) -> Result<()> {
213        match self.encoding_type {
214            PositionalEncodingType::Learned => {
215                let (pos_len, model_dim) = new_embeddings.dim();
216                if pos_len != self.max_seqlen || model_dim != self.modeldim {
217                    return Err(OptimError::InvalidConfig(
218                        "New embeddings dimensions don't match encoder configuration".to_string(),
219                    ));
220                }
221                self.position_embeddings = Some(new_embeddings);
222                Ok(())
223            }
224            _ => Err(OptimError::InvalidConfig(
225                "Position embeddings can only be updated for learned encoding type".to_string(),
226            )),
227        }
228    }
229
230    /// Compute sinusoidal encoding for a specific position
231    pub fn compute_sinusoidal_position(&self, position: usize) -> Result<Array1<T>> {
232        if position >= self.max_seqlen {
233            return Err(OptimError::InvalidConfig(
234                "Position exceeds maximum sequence length".to_string(),
235            ));
236        }
237
238        let mut encoding = Array1::zeros(self.modeldim);
239        for i in 0..self.modeldim {
240            let angle = scirs2_core::numeric::NumCast::from(position).unwrap_or_else(|| T::zero())
241                / T::from(10000.0_f64.powf(2.0 * (i as f64) / self.modeldim as f64)).unwrap();
242
243            if i % 2 == 0 {
244                encoding[i] = angle.sin();
245            } else {
246                encoding[i] = angle.cos();
247            }
248        }
249
250        Ok(encoding)
251    }
252
253    /// Apply ALiBi bias to attention scores
254    pub fn apply_alibi_bias(
255        &self,
256        attention_scores: &mut Array2<T>,
257        head_idx: usize,
258    ) -> Result<()> {
259        if self.encoding_type != PositionalEncodingType::ALiBi {
260            return Ok(()); // No-op for non-ALiBi encoding
261        }
262
263        if let Some(ref slopes) = self.alibi_slopes {
264            if head_idx >= slopes.len() {
265                return Err(OptimError::InvalidConfig(
266                    "Head index exceeds number of ALiBi slopes".to_string(),
267                ));
268            }
269
270            let slope = slopes[head_idx];
271            let (seq_len, _) = attention_scores.dim();
272
273            for i in 0..seq_len {
274                for j in 0..seq_len {
275                    let distance = T::from((i as i32 - j as i32).abs()).unwrap();
276                    attention_scores[[i, j]] = attention_scores[[i, j]] - slope * distance;
277                }
278            }
279        }
280
281        Ok(())
282    }
283}