//! Positional encoding strategies for the learned transformer optimizer.
//!
//! Path: optirs_learned/transformer/architecture/positional_encoding.rs

use std::fmt::Debug;

#[allow(dead_code)]
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Random, Rng as SCRRng};

use super::super::TransformerOptimizerConfig;
use crate::error::{OptimError, Result};
/// Strategy used to inject position information into transformer inputs.
///
/// Which strategy is active determines what state `PositionalEncoder`
/// precomputes at construction time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncodingType {
    /// Fixed sin/cos table ("Attention Is All You Need" style).
    Sinusoidal,
    /// Trainable per-position embedding table (Xavier-initialized).
    Learned,
    /// Rotary position embeddings — presumably applied inside attention;
    /// currently falls back to the sinusoidal table (TODO confirm).
    Rotary,
    /// Relative position encoding — presumably applied inside attention;
    /// currently falls back to the sinusoidal table (TODO confirm).
    Relative,
    /// Attention with Linear Biases: per-head distance penalty added to
    /// attention scores rather than to the input embeddings.
    ALiBi,
}
29
/// Adds position information to transformer optimizer inputs.
///
/// Exactly one of the optional fields is populated, depending on
/// `encoding_type`; the others remain `None`.
#[derive(Debug, Clone)]
pub struct PositionalEncoder<T: Float + Debug + Send + Sync + 'static> {
    // Selected encoding strategy (fixed at construction).
    encoding_type: PositionalEncodingType,

    // Precomputed sinusoidal table, `max_seqlen x modeldim`
    // (Sinusoidal, and the Rotary/Relative fallback).
    cached_encodings: Option<Array2<T>>,

    // Longest sequence this encoder supports.
    max_seqlen: usize,

    // Expected model (embedding) dimension of inputs.
    modeldim: usize,

    // Trainable embeddings, `max_seqlen x modeldim` (Learned only).
    position_embeddings: Option<Array2<T>>,

    // Per-attention-head bias slopes (ALiBi only).
    alibi_slopes: Option<Array1<T>>,
}
51
52impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> PositionalEncoder<T> {
53 pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
55 let max_seqlen = config.max_sequence_length;
56 let modeldim = config.modeldim;
57
58 let mut cached_encodings = None;
59 let mut position_embeddings = None;
60 let mut alibi_slopes = None;
61
62 match config.pos_encoding_type {
63 PositionalEncodingType::Sinusoidal => {
64 let mut encodings = Array2::zeros((max_seqlen, modeldim));
66
67 for pos in 0..max_seqlen {
68 for i in 0..modeldim {
69 let angle = scirs2_core::numeric::NumCast::from(pos)
70 .unwrap_or_else(|| T::zero())
71 / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
72 .expect("unwrap failed");
73
74 if i % 2 == 0 {
75 encodings[[pos, i]] = angle.sin();
76 } else {
77 encodings[[pos, i]] = angle.cos();
78 }
79 }
80 }
81 cached_encodings = Some(encodings);
82 }
83 PositionalEncodingType::Learned => {
84 let mut rng = scirs2_core::random::thread_rng();
86 let mut embeddings = Array2::zeros((max_seqlen, modeldim));
87
88 let bound = (6.0 / (max_seqlen + modeldim) as f64).sqrt();
90 for elem in embeddings.iter_mut() {
91 *elem =
92 T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).expect("unwrap failed");
93 }
94 position_embeddings = Some(embeddings);
95 }
96 PositionalEncodingType::ALiBi => {
97 let numheads = config.numheads;
99 let mut slopes = Array1::zeros(numheads);
100
101 for h in 0..numheads {
102 let slope = T::from(2.0_f64.powf(-8.0 * (h + 1) as f64 / numheads as f64))
103 .expect("unwrap failed");
104 slopes[h] = slope;
105 }
106 alibi_slopes = Some(slopes);
107 }
108 _ => {
109 let mut encodings = Array2::zeros((max_seqlen, modeldim));
111
112 for pos in 0..max_seqlen {
113 for i in 0..modeldim {
114 let angle = scirs2_core::numeric::NumCast::from(pos)
115 .unwrap_or_else(|| T::zero())
116 / T::from(10000.0_f64.powf(2.0 * (i as f64) / modeldim as f64))
117 .expect("unwrap failed");
118
119 if i % 2 == 0 {
120 encodings[[pos, i]] = angle.sin();
121 } else {
122 encodings[[pos, i]] = angle.cos();
123 }
124 }
125 }
126 cached_encodings = Some(encodings);
127 }
128 }
129
130 Ok(Self {
131 encoding_type: config.pos_encoding_type,
132 cached_encodings,
133 max_seqlen,
134 modeldim,
135 position_embeddings,
136 alibi_slopes,
137 })
138 }
139
140 pub fn encode(&self, input: &Array2<T>) -> Result<Array2<T>> {
142 let (seq_len, modeldim) = input.dim();
143
144 if seq_len > self.max_seqlen {
145 return Err(OptimError::InvalidConfig(format!(
146 "Sequence length {} exceeds maximum {}",
147 seq_len, self.max_seqlen
148 )));
149 }
150
151 if modeldim != self.modeldim {
152 return Err(OptimError::InvalidConfig(format!(
153 "Model dimension {} doesn't match expected {}",
154 modeldim, self.modeldim
155 )));
156 }
157
158 let mut output = input.clone();
159
160 match self.encoding_type {
161 PositionalEncodingType::Sinusoidal => {
162 if let Some(ref encodings) = self.cached_encodings {
163 let pos_enc = encodings.slice(s![..seq_len, ..]);
164 output = output + pos_enc;
165 }
166 }
167 PositionalEncodingType::Learned => {
168 if let Some(ref embeddings) = self.position_embeddings {
169 let pos_emb = embeddings.slice(s![..seq_len, ..]);
170 output = output + pos_emb;
171 }
172 }
173 PositionalEncodingType::Rotary => {
174 }
178 PositionalEncodingType::Relative => {
179 }
183 PositionalEncodingType::ALiBi => {
184 }
187 }
188
189 Ok(output)
190 }
191
192 pub fn get_alibi_slopes(&self) -> Option<&Array1<T>> {
194 self.alibi_slopes.as_ref()
195 }
196
197 pub fn encoding_type(&self) -> PositionalEncodingType {
199 self.encoding_type
200 }
201
202 pub fn max_sequence_length(&self) -> usize {
204 self.max_seqlen
205 }
206
207 pub fn model_dimension(&self) -> usize {
209 self.modeldim
210 }
211
212 pub fn update_embeddings(&mut self, new_embeddings: Array2<T>) -> Result<()> {
214 match self.encoding_type {
215 PositionalEncodingType::Learned => {
216 let (pos_len, model_dim) = new_embeddings.dim();
217 if pos_len != self.max_seqlen || model_dim != self.modeldim {
218 return Err(OptimError::InvalidConfig(
219 "New embeddings dimensions don't match encoder configuration".to_string(),
220 ));
221 }
222 self.position_embeddings = Some(new_embeddings);
223 Ok(())
224 }
225 _ => Err(OptimError::InvalidConfig(
226 "Position embeddings can only be updated for learned encoding type".to_string(),
227 )),
228 }
229 }
230
231 pub fn compute_sinusoidal_position(&self, position: usize) -> Result<Array1<T>> {
233 if position >= self.max_seqlen {
234 return Err(OptimError::InvalidConfig(
235 "Position exceeds maximum sequence length".to_string(),
236 ));
237 }
238
239 let mut encoding = Array1::zeros(self.modeldim);
240 for i in 0..self.modeldim {
241 let angle = scirs2_core::numeric::NumCast::from(position).unwrap_or_else(|| T::zero())
242 / T::from(10000.0_f64.powf(2.0 * (i as f64) / self.modeldim as f64))
243 .expect("unwrap failed");
244
245 if i % 2 == 0 {
246 encoding[i] = angle.sin();
247 } else {
248 encoding[i] = angle.cos();
249 }
250 }
251
252 Ok(encoding)
253 }
254
255 pub fn apply_alibi_bias(
257 &self,
258 attention_scores: &mut Array2<T>,
259 head_idx: usize,
260 ) -> Result<()> {
261 if self.encoding_type != PositionalEncodingType::ALiBi {
262 return Ok(()); }
264
265 if let Some(ref slopes) = self.alibi_slopes {
266 if head_idx >= slopes.len() {
267 return Err(OptimError::InvalidConfig(
268 "Head index exceeds number of ALiBi slopes".to_string(),
269 ));
270 }
271
272 let slope = slopes[head_idx];
273 let (seq_len, _) = attention_scores.dim();
274
275 for i in 0..seq_len {
276 for j in 0..seq_len {
277 let distance = T::from((i as i32 - j as i32).abs()).expect("unwrap failed");
278 attention_scores[[i, j]] = attention_scores[[i, j]] - slope * distance;
279 }
280 }
281 }
282
283 Ok(())
284 }
285}