optirs_learned/transformer_based_optimizer/
positional_encoding.rs1use crate::error::Result;
4use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
5use scirs2_core::numeric::Float;
6use serde::{Deserialize, Serialize};
7use std::f64::consts::PI;
8use std::fmt::Debug;
9
/// Strategy used to inject position information into transformer inputs.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum PositionalEncodingType {
    /// Fixed sin/cos encodings ("Attention Is All You Need" style).
    Sinusoidal,
    /// Trainable per-position embedding table.
    Learned,
    /// Rotary position embeddings (RoPE); the table stores per-dimension frequencies.
    Rotary,
    /// No positional information is added.
    None,
}
22
/// Positional encoding component for a transformer-based optimizer.
pub struct PositionalEncoding<T: Float + Debug + Send + Sync + 'static> {
    /// Which encoding strategy this instance applies.
    encoding_type: PositionalEncodingType,

    /// Longest sequence the precomputed table supports.
    max_sequence_length: usize,

    /// Model (embedding) dimensionality.
    model_dimension: usize,

    /// Precomputed table: per-position sinusoidal values, or RoPE frequencies in row 0.
    encoding_matrix: Array2<T>,

    /// Trainable embedding table; populated only for `Learned`.
    learned_embeddings: Option<Array2<T>>,

    /// Base of the RoPE frequency schedule (default 10000).
    rope_base: T,
}
43
44impl<T: Float + Debug + Send + Sync + 'static> PositionalEncoding<T> {
45 pub fn new(
47 max_sequence_length: usize,
48 model_dimension: usize,
49 encoding_type: PositionalEncodingType,
50 ) -> Result<Self> {
51 let rope_base = scirs2_core::numeric::NumCast::from(10000.0).unwrap_or_else(|| T::zero());
52 let mut encoding = Self {
53 encoding_type,
54 max_sequence_length,
55 model_dimension,
56 encoding_matrix: Array2::zeros((max_sequence_length, model_dimension)),
57 learned_embeddings: None,
58 rope_base,
59 };
60
61 encoding.initialize_encoding()?;
62 Ok(encoding)
63 }
64
65 fn initialize_encoding(&mut self) -> Result<()> {
67 match self.encoding_type {
68 PositionalEncodingType::Sinusoidal => self.initialize_sinusoidal(),
69 PositionalEncodingType::Learned => self.initialize_learned(),
70 PositionalEncodingType::Rotary => self.initialize_rotary(),
71 PositionalEncodingType::None => Ok(()),
72 }
73 }
74
75 fn initialize_sinusoidal(&mut self) -> Result<()> {
77 for pos in 0..self.max_sequence_length {
78 for i in 0..self.model_dimension {
79 let position =
80 scirs2_core::numeric::NumCast::from(pos).unwrap_or_else(|| T::zero());
81 let dimension = scirs2_core::numeric::NumCast::from(i).unwrap_or_else(|| T::zero());
82 let model_dim = scirs2_core::numeric::NumCast::from(self.model_dimension)
83 .unwrap_or_else(|| T::zero());
84
85 let angle = position
86 / scirs2_core::numeric::NumCast::from(10000.0)
87 .unwrap_or_else(|| T::zero())
88 .powf(
89 scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero())
90 * dimension
91 / model_dim,
92 );
93
94 if i % 2 == 0 {
95 self.encoding_matrix[[pos, i]] = angle.sin();
97 } else {
98 self.encoding_matrix[[pos, i]] = angle.cos();
100 }
101 }
102 }
103 Ok(())
104 }
105
106 fn initialize_learned(&mut self) -> Result<()> {
108 let learned_embeddings = Array2::zeros((self.max_sequence_length, self.model_dimension));
109 self.learned_embeddings = Some(learned_embeddings);
110 Ok(())
111 }
112
113 fn initialize_rotary(&mut self) -> Result<()> {
115 for i in (0..self.model_dimension).step_by(2) {
118 let dim_pair = scirs2_core::numeric::NumCast::from(i).unwrap_or_else(|| T::zero())
119 / scirs2_core::numeric::NumCast::from(self.model_dimension)
120 .unwrap_or_else(|| T::zero());
121 let freq = T::one() / self.rope_base.powf(dim_pair);
122
123 if i < self.model_dimension {
124 self.encoding_matrix[[0, i]] = freq;
125 }
126 if i + 1 < self.model_dimension {
127 self.encoding_matrix[[0, i + 1]] = freq;
128 }
129 }
130 Ok(())
131 }
132
133 pub fn encode(&self, input: &Array2<T>) -> Result<Array2<T>> {
135 match self.encoding_type {
136 PositionalEncodingType::None => Ok(input.clone()),
137 PositionalEncodingType::Sinusoidal => self.apply_sinusoidal(input),
138 PositionalEncodingType::Learned => self.apply_learned(input),
139 PositionalEncodingType::Rotary => self.apply_rotary(input),
140 }
141 }
142
    /// Add the sinusoidal encoding to `input`.
    ///
    /// NOTE(review): `input` is 2-D (`[batch, seq]`), yet the inner loop adds
    /// every component of the encoding row to the same scalar element — each
    /// output element receives the *sum* over the model dimension. This reads
    /// as if it were written for a 3-D `[batch, seq, dim]` tensor; confirm the
    /// intended input shape before relying on the numeric result.
    ///
    /// # Errors
    /// Fails when the sequence is longer than the precomputed table.
    fn apply_sinusoidal(&self, input: &Array2<T>) -> Result<Array2<T>> {
        let batch_size = input.shape()[0];
        let sequence_length = input.shape()[1];

        if sequence_length > self.max_sequence_length {
            return Err(crate::error::OptimError::Other(
                "Sequence length exceeds maximum".to_string(),
            ));
        }

        let mut output = input.clone();

        for batch in 0..batch_size {
            for pos in 0..sequence_length {
                for dim in 0..self.model_dimension {
                    // Accumulates all `model_dimension` encoding values into one cell.
                    output[[batch, pos]] = output[[batch, pos]] + self.encoding_matrix[[pos, dim]];
                }
            }
        }

        Ok(output)
    }
166
    /// Add the learned positional embeddings to `input`.
    ///
    /// NOTE(review): the inner loop folds the entire embedding row into a
    /// single scalar element of the 2-D input, so each element receives the
    /// sum over the model dimension. This looks written for a 3-D
    /// `[batch, seq, dim]` tensor — confirm the intended input shape.
    ///
    /// # Errors
    /// Fails when the embeddings were never initialized (non-`Learned`
    /// encoding type) or when the sequence is longer than the table.
    fn apply_learned(&self, input: &Array2<T>) -> Result<Array2<T>> {
        if let Some(ref learned) = self.learned_embeddings {
            let batch_size = input.shape()[0];
            let sequence_length = input.shape()[1];

            if sequence_length > self.max_sequence_length {
                return Err(crate::error::OptimError::Other(
                    "Sequence length exceeds maximum".to_string(),
                ));
            }

            let mut output = input.clone();

            for batch in 0..batch_size {
                for pos in 0..sequence_length {
                    for dim in 0..self.model_dimension {
                        // Sums every embedding component into one cell (see NOTE above).
                        output[[batch, pos]] = output[[batch, pos]] + learned[[pos, dim]];
                    }
                }
            }

            Ok(output)
        } else {
            Err(crate::error::OptimError::Other(
                "Learned embeddings not initialized".to_string(),
            ))
        }
    }
196
    /// Apply rotary position embedding to `input`.
    ///
    /// NOTE(review): several aspects look wrong for a 2-D input and should be
    /// checked against the presumably intended `[batch, seq, dim]` layout:
    /// - `x` and `y` read the *same* element `input[[batch, pos]]`, so the
    ///   rotation degenerates to `x * (cos - sin)` instead of rotating a
    ///   distinct (even, odd) coordinate pair.
    /// - `i + 1 < input.shape()[1]` compares a model-dimension index against
    ///   the sequence length.
    /// - `output[[batch, pos]]` is overwritten on every iteration of `i`, so
    ///   only the last frequency pair's result survives.
    fn apply_rotary(&self, input: &Array2<T>) -> Result<Array2<T>> {
        let batch_size = input.shape()[0];
        let sequence_length = input.shape()[1];
        let mut output = input.clone();

        for batch in 0..batch_size {
            for pos in 0..sequence_length {
                let position =
                    scirs2_core::numeric::NumCast::from(pos).unwrap_or_else(|| T::zero());

                for i in (0..self.model_dimension).step_by(2) {
                    if i + 1 < self.model_dimension {
                        // Frequencies were precomputed into row 0 by initialize_rotary.
                        let freq = self.encoding_matrix[[0, i]];
                        let angle = position * freq;

                        let cos_val = angle.cos();
                        let sin_val = angle.sin();

                        let x = input[[batch, pos]];
                        let y = if i + 1 < input.shape()[1] {
                            input[[batch, pos]]
                        } else {
                            T::zero()
                        };

                        output[[batch, pos]] = x * cos_val - y * sin_val;
                    }
                }
            }
        }

        Ok(output)
    }
232
233 pub fn get_position_encoding(&self, position: usize) -> Result<Array1<T>> {
235 if position >= self.max_sequence_length {
236 return Err(crate::error::OptimError::Other(
237 "Position exceeds maximum sequence length".to_string(),
238 ));
239 }
240
241 match self.encoding_type {
242 PositionalEncodingType::Sinusoidal => Ok(self.encoding_matrix.row(position).to_owned()),
243 PositionalEncodingType::Learned => {
244 if let Some(ref learned) = self.learned_embeddings {
245 Ok(learned.row(position).to_owned())
246 } else {
247 Err(crate::error::OptimError::Other(
248 "Learned embeddings not available".to_string(),
249 ))
250 }
251 }
252 PositionalEncodingType::Rotary => {
253 Ok(self.encoding_matrix.row(0).to_owned())
255 }
256 PositionalEncodingType::None => Ok(Array1::zeros(self.model_dimension)),
257 }
258 }
259
260 pub fn update_learned_embeddings(&mut self, gradients: &Array2<T>) -> Result<()> {
262 if let Some(ref mut learned) = self.learned_embeddings {
263 *learned = &*learned - gradients;
264 Ok(())
265 } else {
266 Err(crate::error::OptimError::Other(
267 "No learned embeddings to update".to_string(),
268 ))
269 }
270 }
271
272 pub fn parameter_count(&self) -> usize {
274 match self.encoding_type {
275 PositionalEncodingType::Learned => self.max_sequence_length * self.model_dimension,
276 _ => 0, }
278 }
279
    /// Re-run table initialization for the current encoding type.
    pub fn reset(&mut self) -> Result<()> {
        self.initialize_encoding()
    }

    /// The configured encoding strategy.
    pub fn get_encoding_type(&self) -> PositionalEncodingType {
        self.encoding_type
    }

    /// Maximum sequence length supported by the precomputed tables.
    pub fn get_max_sequence_length(&self) -> usize {
        self.max_sequence_length
    }

    /// Model (embedding) dimensionality.
    pub fn get_model_dimension(&self) -> usize {
        self.model_dimension
    }
299
300 pub fn sinusoidal_with_base(
302 max_sequence_length: usize,
303 model_dimension: usize,
304 base: T,
305 ) -> Result<Self> {
306 let mut encoding = Self::new(
307 max_sequence_length,
308 model_dimension,
309 PositionalEncodingType::Sinusoidal,
310 )?;
311
312 for pos in 0..max_sequence_length {
314 for i in 0..model_dimension {
315 let position =
316 scirs2_core::numeric::NumCast::from(pos).unwrap_or_else(|| T::zero());
317 let dimension = scirs2_core::numeric::NumCast::from(i).unwrap_or_else(|| T::zero());
318 let model_dim = scirs2_core::numeric::NumCast::from(model_dimension)
319 .unwrap_or_else(|| T::zero());
320
321 let angle = position
322 / base.powf(
323 scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero())
324 * dimension
325 / model_dim,
326 );
327
328 if i % 2 == 0 {
329 encoding.encoding_matrix[[pos, i]] = angle.sin();
330 } else {
331 encoding.encoding_matrix[[pos, i]] = angle.cos();
332 }
333 }
334 }
335
336 Ok(encoding)
337 }
338
339 pub fn rotary_with_base(
341 max_sequence_length: usize,
342 model_dimension: usize,
343 base: T,
344 ) -> Result<Self> {
345 let mut encoding = Self::new(
346 max_sequence_length,
347 model_dimension,
348 PositionalEncodingType::Rotary,
349 )?;
350
351 encoding.rope_base = base;
352 encoding.initialize_rotary()?;
353
354 Ok(encoding)
355 }
356}
357
/// Relative positional encoding: one vector per clipped relative distance.
pub struct RelativePositionalEncoding<T: Float + Debug + Send + Sync + 'static> {
    /// Distances beyond this magnitude are clipped to +/- this value.
    max_relative_distance: usize,

    /// Model (embedding) dimensionality.
    model_dimension: usize,

    /// Table of shape `[2 * max_relative_distance + 1, model_dimension]`;
    /// row index = relative distance + max_relative_distance.
    relative_encoding: Array2<T>,
}
369
370impl<T: Float + Debug + Send + Sync + 'static> RelativePositionalEncoding<T> {
371 pub fn new(max_relative_distance: usize, model_dimension: usize) -> Result<Self> {
373 let table_size = 2 * max_relative_distance + 1;
374 let relative_encoding = Array2::zeros((table_size, model_dimension));
375
376 Ok(Self {
377 max_relative_distance,
378 model_dimension,
379 relative_encoding,
380 })
381 }
382
383 pub fn get_relative_encoding(&self, from_pos: usize, to_pos: usize) -> Array1<T> {
385 let relative_distance = (to_pos as i32 - from_pos as i32)
386 .max(-(self.max_relative_distance as i32))
387 .min(self.max_relative_distance as i32);
388
389 let index = (relative_distance + self.max_relative_distance as i32) as usize;
390 self.relative_encoding.row(index).to_owned()
391 }
392
393 pub fn initialize_sinusoidal(&mut self) -> Result<()> {
395 let table_size = 2 * self.max_relative_distance + 1;
396
397 for i in 0..table_size {
398 let relative_pos = i as i32 - self.max_relative_distance as i32;
399 let position =
400 scirs2_core::numeric::NumCast::from(relative_pos).unwrap_or_else(|| T::zero());
401
402 for j in 0..self.model_dimension {
403 let dimension = scirs2_core::numeric::NumCast::from(j).unwrap_or_else(|| T::zero());
404 let model_dim = scirs2_core::numeric::NumCast::from(self.model_dimension)
405 .unwrap_or_else(|| T::zero());
406
407 let angle = position
408 / scirs2_core::numeric::NumCast::from(10000.0)
409 .unwrap_or_else(|| T::zero())
410 .powf(
411 scirs2_core::numeric::NumCast::from(2.0).unwrap_or_else(|| T::zero())
412 * dimension
413 / model_dim,
414 );
415
416 if j % 2 == 0 {
417 self.relative_encoding[[i, j]] = angle.sin();
418 } else {
419 self.relative_encoding[[i, j]] = angle.cos();
420 }
421 }
422 }
423
424 Ok(())
425 }
426
    /// Total trainable parameters: one vector per representable relative distance.
    pub fn parameter_count(&self) -> usize {
        (2 * self.max_relative_distance + 1) * self.model_dimension
    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Sinusoidal encoding constructs, has no trainable parameters, and
    /// encodes a 2-D input without error.
    #[test]
    fn test_sinusoidal_encoding() {
        let encoding = PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Sinusoidal);
        assert!(encoding.is_ok());

        let pe = encoding.unwrap();
        assert_eq!(pe.parameter_count(), 0);

        let input = Array2::<f32>::zeros((2, 50));
        let result = pe.encode(&input);
        assert!(result.is_ok());
    }

    /// Learned encoding reports seq_len * model_dim trainable parameters.
    #[test]
    fn test_learned_encoding() {
        let encoding = PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Learned);
        assert!(encoding.is_ok());

        let pe = encoding.unwrap();
        assert_eq!(pe.parameter_count(), 100 * 64);
    }

    /// Rotary encoding constructs and encodes a 2-D input without error.
    #[test]
    fn test_rotary_encoding() {
        let encoding = PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Rotary);
        assert!(encoding.is_ok());

        let pe = encoding.unwrap();
        let input = Array2::<f32>::ones((2, 50));
        let result = pe.encode(&input);
        assert!(result.is_ok());
    }

    /// Single-position lookup returns a vector of model_dimension length.
    #[test]
    fn test_position_encoding_retrieval() {
        let pe =
            PositionalEncoding::<f32>::new(100, 64, PositionalEncodingType::Sinusoidal).unwrap();

        let pos_encoding = pe.get_position_encoding(10);
        assert!(pos_encoding.is_ok());

        let encoding = pos_encoding.unwrap();
        assert_eq!(encoding.len(), 64);
    }

    /// Relative table initializes and yields model_dimension-length rows.
    #[test]
    fn test_relative_positional_encoding() {
        let rel_pe = RelativePositionalEncoding::<f32>::new(10, 64);
        assert!(rel_pe.is_ok());

        let mut rpe = rel_pe.unwrap();
        assert!(rpe.initialize_sinusoidal().is_ok());

        let encoding = rpe.get_relative_encoding(5, 8);
        assert_eq!(encoding.len(), 64);
    }

    /// Every encoding variant constructs successfully.
    #[test]
    fn test_encoding_types() {
        let types = [
            PositionalEncodingType::Sinusoidal,
            PositionalEncodingType::Learned,
            PositionalEncodingType::Rotary,
            PositionalEncodingType::None,
        ];

        for encoding_type in types.iter() {
            let pe = PositionalEncoding::<f32>::new(50, 32, *encoding_type);
            assert!(pe.is_ok());
        }
    }
}