Skip to main content

optirs_learned/transformer/architecture/
positional_encoding.rs

1use std::fmt::Debug;
2// Positional encoding mechanisms for transformer optimization
3//
4// This module implements various positional encoding strategies used in the
5// transformer optimizer to provide position information to the attention mechanisms.
6
7#[allow(dead_code)]
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::{Random, Rng as SCRRng};
11
12use super::super::TransformerOptimizerConfig;
13use crate::error::{OptimError, Result};
14
15/// Types of positional encoding
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub enum PositionalEncodingType {
18    /// Sinusoidal position encoding
19    Sinusoidal,
20    /// Learned position embedding
21    Learned,
22    /// Rotary position embedding (RoPE)
23    Rotary,
24    /// Relative position encoding
25    Relative,
26    /// ALiBi (Attention with Linear Biases)
27    ALiBi,
28}
29
30/// Positional encoder for transformer inputs
/// Positional encoder for transformer inputs.
///
/// Holds whatever state the configured [`PositionalEncodingType`] needs:
/// a precomputed table for sinusoidal encoding, a trainable embedding
/// matrix for learned encoding, or per-head slopes for ALiBi. Exactly one
/// of the three optional fields is populated, depending on `encoding_type`.
#[derive(Debug, Clone)]
pub struct PositionalEncoder<T: Float + Debug + Send + Sync + 'static> {
    /// Which encoding strategy this instance implements.
    encoding_type: PositionalEncodingType,

    /// Precomputed (max_seqlen x modeldim) sinusoidal table; `Some` for
    /// sinusoidal encoding (and the fall-back used by Rotary/Relative).
    cached_encodings: Option<Array2<T>>,

    /// Longest sequence this encoder can handle; inputs beyond this are rejected.
    max_seqlen: usize,

    /// Expected model (feature) dimension of every input row.
    modeldim: usize,

    /// Trainable (max_seqlen x modeldim) embedding table; `Some` only for `Learned`.
    position_embeddings: Option<Array2<T>>,

    /// Per-attention-head slopes; `Some` only for `ALiBi`.
    alibi_slopes: Option<Array1<T>>,
}
51
52impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> PositionalEncoder<T> {
53    /// Create new positional encoder
54    pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
55        let max_seqlen = config.max_sequence_length;
56        let modeldim = config.modeldim;
57
58        let mut cached_encodings = None;
59        let mut position_embeddings = None;
60        let mut alibi_slopes = None;
61
62        match config.pos_encoding_type {
63            PositionalEncodingType::Sinusoidal => {
64                // Precompute sinusoidal encodings
65                let mut encodings = Array2::zeros((max_seqlen, modeldim));
66
67                for pos in 0..max_seqlen {
68                    for i in 0..modeldim {
69                        let angle = scirs2_core::numeric::NumCast::from(pos)
70                            .unwrap_or_else(|| T::zero())
71                            / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
72                                .expect("unwrap failed");
73
74                        if i % 2 == 0 {
75                            encodings[[pos, i]] = angle.sin();
76                        } else {
77                            encodings[[pos, i]] = angle.cos();
78                        }
79                    }
80                }
81                cached_encodings = Some(encodings);
82            }
83            PositionalEncodingType::Learned => {
84                // Initialize learnable position embeddings
85                let mut rng = scirs2_core::random::thread_rng();
86                let mut embeddings = Array2::zeros((max_seqlen, modeldim));
87
88                // Xavier initialization
89                let bound = (6.0 / (max_seqlen + modeldim) as f64).sqrt();
90                for elem in embeddings.iter_mut() {
91                    *elem =
92                        T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).expect("unwrap failed");
93                }
94                position_embeddings = Some(embeddings);
95            }
96            PositionalEncodingType::ALiBi => {
97                // Initialize ALiBi slopes
98                let numheads = config.numheads;
99                let mut slopes = Array1::zeros(numheads);
100
101                for h in 0..numheads {
102                    let slope = T::from(2.0_f64.powf(-8.0 * (h + 1) as f64 / numheads as f64))
103                        .expect("unwrap failed");
104                    slopes[h] = slope;
105                }
106                alibi_slopes = Some(slopes);
107            }
108            _ => {
109                // Default to sinusoidal for other types
110                let mut encodings = Array2::zeros((max_seqlen, modeldim));
111
112                for pos in 0..max_seqlen {
113                    for i in 0..modeldim {
114                        let angle = scirs2_core::numeric::NumCast::from(pos)
115                            .unwrap_or_else(|| T::zero())
116                            / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
117                                .expect("unwrap failed");
118
119                        if i % 2 == 0 {
120                            encodings[[pos, i]] = angle.sin();
121                        } else {
122                            encodings[[pos, i]] = angle.cos();
123                        }
124                    }
125                }
126                cached_encodings = Some(encodings);
127            }
128        }
129
130        Ok(Self {
131            encoding_type: config.pos_encoding_type,
132            cached_encodings,
133            max_seqlen,
134            modeldim,
135            position_embeddings,
136            alibi_slopes,
137        })
138    }
139
140    /// Encode input with positional information
141    pub fn encode(&self, input: &Array2<T>) -> Result<Array2<T>> {
142        let (seq_len, modeldim) = input.dim();
143
144        if seq_len > self.max_seqlen {
145            return Err(OptimError::InvalidConfig(format!(
146                "Sequence length {} exceeds maximum {}",
147                seq_len, self.max_seqlen
148            )));
149        }
150
151        if modeldim != self.modeldim {
152            return Err(OptimError::InvalidConfig(format!(
153                "Model dimension {} doesn't match expected {}",
154                modeldim, self.modeldim
155            )));
156        }
157
158        let mut output = input.clone();
159
160        match self.encoding_type {
161            PositionalEncodingType::Sinusoidal => {
162                if let Some(ref encodings) = self.cached_encodings {
163                    let pos_enc = encodings.slice(s![..seq_len, ..]);
164                    output = output + pos_enc;
165                }
166            }
167            PositionalEncodingType::Learned => {
168                if let Some(ref embeddings) = self.position_embeddings {
169                    let pos_emb = embeddings.slice(s![..seq_len, ..]);
170                    output = output + pos_emb;
171                }
172            }
173            PositionalEncodingType::Rotary => {
174                // Rotary position embedding (RoPE) doesn't add to input,
175                // it modifies attention computation
176                // For now, just return input unchanged
177            }
178            PositionalEncodingType::Relative => {
179                // Relative position encoding doesn't add to input,
180                // it modifies attention computation
181                // For now, just return input unchanged
182            }
183            PositionalEncodingType::ALiBi => {
184                // ALiBi doesn't add to input, it modifies attention scores
185                // For now, just return input unchanged
186            }
187        }
188
189        Ok(output)
190    }
191
192    /// Get ALiBi slopes for attention bias calculation
193    pub fn get_alibi_slopes(&self) -> Option<&Array1<T>> {
194        self.alibi_slopes.as_ref()
195    }
196
197    /// Get encoding type
198    pub fn encoding_type(&self) -> PositionalEncodingType {
199        self.encoding_type
200    }
201
202    /// Get maximum sequence length
203    pub fn max_sequence_length(&self) -> usize {
204        self.max_seqlen
205    }
206
207    /// Get model dimension
208    pub fn model_dimension(&self) -> usize {
209        self.modeldim
210    }
211
212    /// Update position embeddings (for learned encoding)
213    pub fn update_embeddings(&mut self, new_embeddings: Array2<T>) -> Result<()> {
214        match self.encoding_type {
215            PositionalEncodingType::Learned => {
216                let (pos_len, model_dim) = new_embeddings.dim();
217                if pos_len != self.max_seqlen || model_dim != self.modeldim {
218                    return Err(OptimError::InvalidConfig(
219                        "New embeddings dimensions don't match encoder configuration".to_string(),
220                    ));
221                }
222                self.position_embeddings = Some(new_embeddings);
223                Ok(())
224            }
225            _ => Err(OptimError::InvalidConfig(
226                "Position embeddings can only be updated for learned encoding type".to_string(),
227            )),
228        }
229    }
230
231    /// Compute sinusoidal encoding for a specific position
232    pub fn compute_sinusoidal_position(&self, position: usize) -> Result<Array1<T>> {
233        if position >= self.max_seqlen {
234            return Err(OptimError::InvalidConfig(
235                "Position exceeds maximum sequence length".to_string(),
236            ));
237        }
238
239        let mut encoding = Array1::zeros(self.modeldim);
240        for i in 0..self.modeldim {
241            let angle = scirs2_core::numeric::NumCast::from(position).unwrap_or_else(|| T::zero())
242                / T::from(10000.0_f64.powf(2.0 * (i as f64) / self.modeldim as f64))
243                    .expect("unwrap failed");
244
245            if i % 2 == 0 {
246                encoding[i] = angle.sin();
247            } else {
248                encoding[i] = angle.cos();
249            }
250        }
251
252        Ok(encoding)
253    }
254
255    /// Apply ALiBi bias to attention scores
256    pub fn apply_alibi_bias(
257        &self,
258        attention_scores: &mut Array2<T>,
259        head_idx: usize,
260    ) -> Result<()> {
261        if self.encoding_type != PositionalEncodingType::ALiBi {
262            return Ok(()); // No-op for non-ALiBi encoding
263        }
264
265        if let Some(ref slopes) = self.alibi_slopes {
266            if head_idx >= slopes.len() {
267                return Err(OptimError::InvalidConfig(
268                    "Head index exceeds number of ALiBi slopes".to_string(),
269                ));
270            }
271
272            let slope = slopes[head_idx];
273            let (seq_len, _) = attention_scores.dim();
274
275            for i in 0..seq_len {
276                for j in 0..seq_len {
277                    let distance = T::from((i as i32 - j as i32).abs()).expect("unwrap failed");
278                    attention_scores[[i, j]] = attention_scores[[i, j]] - slope * distance;
279                }
280            }
281        }
282
283        Ok(())
284    }
285}