//! Positional encoding for the transformer optimizer
//! (optirs_learned/transformer/architecture/positional_encoding.rs).
use std::fmt::Debug;
2#[allow(dead_code)]
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::{Random, Rng as SCRRng};
11
12use super::super::TransformerOptimizerConfig;
13use crate::error::{OptimError, Result};
14
/// Strategy used to inject position information into transformer inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncodingType {
    /// Fixed sin/cos table added to the input embeddings.
    Sinusoidal,
    /// Trainable per-position embedding table added to the inputs.
    Learned,
    /// Rotary position embeddings — applied at attention time, so
    /// `encode` is a pass-through for this variant.
    Rotary,
    /// Relative position bias — applied at attention time, so
    /// `encode` is a pass-through for this variant.
    Relative,
    /// Attention with Linear Biases: per-head distance penalty subtracted
    /// from attention scores (see `apply_alibi_bias`).
    ALiBi,
}
29
/// Adds positional information to transformer inputs.
///
/// Only the state needed by the configured `encoding_type` is populated;
/// the other `Option` fields stay `None`.
#[derive(Debug, Clone)]
pub struct PositionalEncoder<T: Float + Debug + Send + Sync + 'static> {
    /// Which encoding strategy this instance applies.
    encoding_type: PositionalEncodingType,

    /// Precomputed sinusoidal table (`max_seqlen x modeldim`); `Some` for
    /// `Sinusoidal` (and the non-Learned/ALiBi fallback path in `new`).
    cached_encodings: Option<Array2<T>>,

    /// Maximum sequence length supported by the precomputed tables.
    max_seqlen: usize,

    /// Model (embedding) dimension expected from inputs.
    modeldim: usize,

    /// Trainable position embeddings; `Some` only for `Learned`.
    position_embeddings: Option<Array2<T>>,

    /// Per-attention-head slopes; `Some` only for `ALiBi`.
    alibi_slopes: Option<Array1<T>>,
}
51
52impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> PositionalEncoder<T> {
53 pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
55 let max_seqlen = config.max_sequence_length;
56 let modeldim = config.modeldim;
57
58 let mut cached_encodings = None;
59 let mut position_embeddings = None;
60 let mut alibi_slopes = None;
61
62 match config.pos_encoding_type {
63 PositionalEncodingType::Sinusoidal => {
64 let mut encodings = Array2::zeros((max_seqlen, modeldim));
66
67 for pos in 0..max_seqlen {
68 for i in 0..modeldim {
69 let angle = scirs2_core::numeric::NumCast::from(pos)
70 .unwrap_or_else(|| T::zero())
71 / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
72 .unwrap();
73
74 if i % 2 == 0 {
75 encodings[[pos, i]] = angle.sin();
76 } else {
77 encodings[[pos, i]] = angle.cos();
78 }
79 }
80 }
81 cached_encodings = Some(encodings);
82 }
83 PositionalEncodingType::Learned => {
84 let mut rng = scirs2_core::random::thread_rng();
86 let mut embeddings = Array2::zeros((max_seqlen, modeldim));
87
88 let bound = (6.0 / (max_seqlen + modeldim) as f64).sqrt();
90 for elem in embeddings.iter_mut() {
91 *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).unwrap();
92 }
93 position_embeddings = Some(embeddings);
94 }
95 PositionalEncodingType::ALiBi => {
96 let numheads = config.numheads;
98 let mut slopes = Array1::zeros(numheads);
99
100 for h in 0..numheads {
101 let slope =
102 T::from(2.0_f64.powf(-8.0 * (h + 1) as f64 / numheads as f64)).unwrap();
103 slopes[h] = slope;
104 }
105 alibi_slopes = Some(slopes);
106 }
107 _ => {
108 let mut encodings = Array2::zeros((max_seqlen, modeldim));
110
111 for pos in 0..max_seqlen {
112 for i in 0..modeldim {
113 let angle = scirs2_core::numeric::NumCast::from(pos)
114 .unwrap_or_else(|| T::zero())
115 / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
116 .unwrap();
117
118 if i % 2 == 0 {
119 encodings[[pos, i]] = angle.sin();
120 } else {
121 encodings[[pos, i]] = angle.cos();
122 }
123 }
124 }
125 cached_encodings = Some(encodings);
126 }
127 }
128
129 Ok(Self {
130 encoding_type: config.pos_encoding_type,
131 cached_encodings,
132 max_seqlen,
133 modeldim,
134 position_embeddings,
135 alibi_slopes,
136 })
137 }
138
139 pub fn encode(&self, input: &Array2<T>) -> Result<Array2<T>> {
141 let (seq_len, modeldim) = input.dim();
142
143 if seq_len > self.max_seqlen {
144 return Err(OptimError::InvalidConfig(format!(
145 "Sequence length {} exceeds maximum {}",
146 seq_len, self.max_seqlen
147 )));
148 }
149
150 if modeldim != self.modeldim {
151 return Err(OptimError::InvalidConfig(format!(
152 "Model dimension {} doesn't match expected {}",
153 modeldim, self.modeldim
154 )));
155 }
156
157 let mut output = input.clone();
158
159 match self.encoding_type {
160 PositionalEncodingType::Sinusoidal => {
161 if let Some(ref encodings) = self.cached_encodings {
162 let pos_enc = encodings.slice(s![..seq_len, ..]);
163 output = output + pos_enc;
164 }
165 }
166 PositionalEncodingType::Learned => {
167 if let Some(ref embeddings) = self.position_embeddings {
168 let pos_emb = embeddings.slice(s![..seq_len, ..]);
169 output = output + pos_emb;
170 }
171 }
172 PositionalEncodingType::Rotary => {
173 }
177 PositionalEncodingType::Relative => {
178 }
182 PositionalEncodingType::ALiBi => {
183 }
186 }
187
188 Ok(output)
189 }
190
191 pub fn get_alibi_slopes(&self) -> Option<&Array1<T>> {
193 self.alibi_slopes.as_ref()
194 }
195
196 pub fn encoding_type(&self) -> PositionalEncodingType {
198 self.encoding_type
199 }
200
201 pub fn max_sequence_length(&self) -> usize {
203 self.max_seqlen
204 }
205
206 pub fn model_dimension(&self) -> usize {
208 self.modeldim
209 }
210
211 pub fn update_embeddings(&mut self, new_embeddings: Array2<T>) -> Result<()> {
213 match self.encoding_type {
214 PositionalEncodingType::Learned => {
215 let (pos_len, model_dim) = new_embeddings.dim();
216 if pos_len != self.max_seqlen || model_dim != self.modeldim {
217 return Err(OptimError::InvalidConfig(
218 "New embeddings dimensions don't match encoder configuration".to_string(),
219 ));
220 }
221 self.position_embeddings = Some(new_embeddings);
222 Ok(())
223 }
224 _ => Err(OptimError::InvalidConfig(
225 "Position embeddings can only be updated for learned encoding type".to_string(),
226 )),
227 }
228 }
229
230 pub fn compute_sinusoidal_position(&self, position: usize) -> Result<Array1<T>> {
232 if position >= self.max_seqlen {
233 return Err(OptimError::InvalidConfig(
234 "Position exceeds maximum sequence length".to_string(),
235 ));
236 }
237
238 let mut encoding = Array1::zeros(self.modeldim);
239 for i in 0..self.modeldim {
240 let angle = scirs2_core::numeric::NumCast::from(position).unwrap_or_else(|| T::zero())
241 / T::from(10000.0_f64.powf(2.0 * (i as f64) / self.modeldim as f64)).unwrap();
242
243 if i % 2 == 0 {
244 encoding[i] = angle.sin();
245 } else {
246 encoding[i] = angle.cos();
247 }
248 }
249
250 Ok(encoding)
251 }
252
253 pub fn apply_alibi_bias(
255 &self,
256 attention_scores: &mut Array2<T>,
257 head_idx: usize,
258 ) -> Result<()> {
259 if self.encoding_type != PositionalEncodingType::ALiBi {
260 return Ok(()); }
262
263 if let Some(ref slopes) = self.alibi_slopes {
264 if head_idx >= slopes.len() {
265 return Err(OptimError::InvalidConfig(
266 "Head index exceeds number of ALiBi slopes".to_string(),
267 ));
268 }
269
270 let slope = slopes[head_idx];
271 let (seq_len, _) = attention_scores.dim();
272
273 for i in 0..seq_len {
274 for j in 0..seq_len {
275 let distance = T::from((i as i32 - j as i32).abs()).unwrap();
276 attention_scores[[i, j]] = attention_scores[[i, j]] - slope * distance;
277 }
278 }
279 }
280
281 Ok(())
282 }
283}