use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{s, Array, Array2, Array3, Axis, IxDyn, Zip};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Rng, RngExt};
use std::f64::consts::PI;
use std::fmt::Debug;
pub trait PositionalEncoding<F: Float + Debug + NumAssign> {
fn encode(&self, seq_len: usize) -> Array2<F>;
fn apply(&self, input: &Array3<F>) -> Result<Array3<F>>;
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
if input.ndim() != 3 {
return Err(NeuralError::InvalidArchitecture(format!(
"Expected 3D input, got {}D",
input.ndim()
)));
}
let shape = input.shape();
let input_3d = input
.view()
.into_dimensionality::<scirs2_core::ndarray::Ix3>()
.map_err(|e| {
NeuralError::InvalidArchitecture(format!("Failed to convert to 3D: {}", e))
})?;
let output_3d = self.apply(&input_3d.to_owned())?;
Ok(output_3d.into_dyn())
}
fn update(&mut self, _learning_rate: F) -> Result<()> {
Ok(())
}
fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
where
F: Send + Sync + 'static;
fn d_model(&self) -> usize;
fn max_len(&self) -> usize;
}
#[derive(Debug, Clone)]
pub struct SinusoidalPositionalEncoding<F: Float + Debug + NumAssign> {
d_model: usize,
max_len: usize,
encodings: Array2<F>,
dropout: Option<F>,
}
impl<F: Float + Debug + NumAssign> SinusoidalPositionalEncoding<F> {
pub fn new(d_model: usize, max_len: usize) -> Self {
assert!(
d_model.is_multiple_of(2),
"d_model must be even for sinusoidal PE"
);
let encodings = Self::compute_encodings(d_model, max_len);
Self {
d_model,
max_len,
encodings,
dropout: None,
}
}
pub fn with_dropout(d_model: usize, max_len: usize, dropout: F) -> Self {
let mut pe = Self::new(d_model, max_len);
pe.dropout = Some(dropout);
pe
}
fn compute_encodings(d_model: usize, max_len: usize) -> Array2<F> {
let mut encodings = Array2::zeros((max_len, d_model));
for pos in 0..max_len {
for i in 0..(d_model / 2) {
let exponent = (2 * i) as f64 / d_model as f64;
let div_term = (10000.0_f64).powf(exponent);
let angle = pos as f64 / div_term;
let sin_val = F::from(angle.sin()).unwrap_or(F::zero());
let cos_val = F::from(angle.cos()).unwrap_or(F::zero());
encodings[[pos, 2 * i]] = sin_val;
encodings[[pos, 2 * i + 1]] = cos_val;
}
}
encodings
}
pub fn params(&self) -> Vec<&Array<F, IxDyn>> {
Vec::new()
}
pub fn set_training(&mut self, _training: bool) {
}
}
impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for SinusoidalPositionalEncoding<F> {
fn encode(&self, seq_len: usize) -> Array2<F> {
assert!(
seq_len <= self.max_len,
"seq_len {} exceeds max_len {}",
seq_len,
self.max_len
);
self.encodings.slice(s![..seq_len, ..]).to_owned()
}
fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
let seq_len = input.shape()[1];
if seq_len > self.max_len {
return Err(NeuralError::InvalidArchitecture(format!(
"Sequence length {} exceeds max_len {}",
seq_len, self.max_len
)));
}
let encoding = self.encode(seq_len);
let mut output = input.clone();
for mut batch in output.axis_iter_mut(Axis(0)) {
Zip::from(&mut batch)
.and(&encoding)
.for_each(|b, &e| *b += e);
}
Ok(output)
}
fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
where
F: Send + Sync + 'static,
{
Box::new(self.clone())
}
fn d_model(&self) -> usize {
self.d_model
}
fn max_len(&self) -> usize {
self.max_len
}
}
#[derive(Debug, Clone)]
pub struct LearnedPositionalEncoding<F: Float + Debug + NumAssign> {
d_model: usize,
max_len: usize,
embeddings: Array2<F>,
}
impl<F: Float + Debug + NumAssign> LearnedPositionalEncoding<F> {
pub fn new<R: Rng>(d_model: usize, max_len: usize, rng: &mut R) -> Self {
let std = (2.0 / (max_len + d_model) as f64).sqrt();
let mut embeddings = Array2::zeros((max_len, d_model));
for elem in embeddings.iter_mut() {
let u1: f64 = rng.random_range(0.0001..1.0);
let u2: f64 = rng.random_range(0.0..1.0);
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
*elem = F::from(z * std).unwrap_or(F::zero());
}
Self {
d_model,
max_len,
embeddings,
}
}
pub fn from_embeddings(embeddings: Array2<F>) -> Self {
let shape = embeddings.shape();
Self {
d_model: shape[1],
max_len: shape[0],
embeddings,
}
}
pub fn embeddings_mut(&mut self) -> &mut Array2<F> {
&mut self.embeddings
}
pub fn embeddings(&self) -> &Array2<F> {
&self.embeddings
}
}
impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for LearnedPositionalEncoding<F> {
fn encode(&self, seq_len: usize) -> Array2<F> {
assert!(
seq_len <= self.max_len,
"seq_len {} exceeds max_len {}",
seq_len,
self.max_len
);
self.embeddings.slice(s![..seq_len, ..]).to_owned()
}
fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
let seq_len = input.shape()[1];
if seq_len > self.max_len {
return Err(NeuralError::InvalidArchitecture(format!(
"Sequence length {} exceeds max_len {}",
seq_len, self.max_len
)));
}
let encoding = self.encode(seq_len);
let mut output = input.clone();
for mut batch in output.axis_iter_mut(Axis(0)) {
Zip::from(&mut batch)
.and(&encoding)
.for_each(|b, &e| *b += e);
}
Ok(output)
}
fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
where
F: Send + Sync + 'static,
{
Box::new(self.clone())
}
fn d_model(&self) -> usize {
self.d_model
}
fn max_len(&self) -> usize {
self.max_len
}
}
#[derive(Debug, Clone)]
pub struct RotaryPositionalEncoding<F: Float + Debug + NumAssign> {
d_model: usize,
max_len: usize,
base: f64,
sin_cached: Array2<F>,
cos_cached: Array2<F>,
}
impl<F: Float + Debug + NumAssign> RotaryPositionalEncoding<F> {
pub fn new(d_model: usize, max_len: usize, base: f64) -> Self {
assert!(d_model.is_multiple_of(2), "d_model must be even for RoPE");
let (sin_cached, cos_cached) = Self::compute_rope_cache(d_model, max_len, base);
Self {
d_model,
max_len,
base,
sin_cached,
cos_cached,
}
}
pub fn default_base(d_model: usize, max_len: usize) -> Self {
Self::new(d_model, max_len, 10000.0)
}
fn compute_rope_cache(d_model: usize, max_len: usize, base: f64) -> (Array2<F>, Array2<F>) {
let half_dim = d_model / 2;
let mut sin_cached = Array2::zeros((max_len, half_dim));
let mut cos_cached = Array2::zeros((max_len, half_dim));
for pos in 0..max_len {
for i in 0..half_dim {
let freq = 1.0 / base.powf((2 * i) as f64 / d_model as f64);
let angle = pos as f64 * freq;
sin_cached[[pos, i]] = F::from(angle.sin()).unwrap_or(F::zero());
cos_cached[[pos, i]] = F::from(angle.cos()).unwrap_or(F::zero());
}
}
(sin_cached, cos_cached)
}
pub fn rotate(&self, x: &Array3<F>, offset: usize) -> Result<Array3<F>> {
let seq_len = x.shape()[1];
if seq_len + offset > self.max_len {
return Err(NeuralError::InvalidArchitecture(format!(
"Position {} exceeds max_len {}",
seq_len + offset,
self.max_len
)));
}
let batch_size = x.shape()[0];
let half_dim = self.d_model / 2;
let mut output = Array3::zeros(x.raw_dim());
for b in 0..batch_size {
for pos in 0..seq_len {
let abs_pos = pos + offset;
for i in 0..half_dim {
let x1 = x[[b, pos, 2 * i]];
let x2 = x[[b, pos, 2 * i + 1]];
let cos = self.cos_cached[[abs_pos, i]];
let sin = self.sin_cached[[abs_pos, i]];
output[[b, pos, 2 * i]] = x1 * cos - x2 * sin;
output[[b, pos, 2 * i + 1]] = x1 * sin + x2 * cos;
}
}
}
Ok(output)
}
pub fn sin_cache(&self) -> &Array2<F> {
&self.sin_cached
}
pub fn cos_cache(&self) -> &Array2<F> {
&self.cos_cached
}
}
impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for RotaryPositionalEncoding<F> {
fn encode(&self, seq_len: usize) -> Array2<F> {
let half_dim = self.d_model / 2;
let mut encoding = Array2::zeros((seq_len, self.d_model));
for pos in 0..seq_len {
for i in 0..half_dim {
encoding[[pos, 2 * i]] = self.sin_cached[[pos, i]];
encoding[[pos, 2 * i + 1]] = self.cos_cached[[pos, i]];
}
}
encoding
}
fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
self.rotate(input, 0)
}
fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
where
F: Send + Sync + 'static,
{
Box::new(self.clone())
}
fn d_model(&self) -> usize {
self.d_model
}
fn max_len(&self) -> usize {
self.max_len
}
}
#[derive(Debug, Clone)]
pub struct RelativePositionalEncoding<F: Float + Debug + NumAssign> {
d_model: usize,
max_len: usize,
rel_embeddings: Array2<F>,
}
impl<F: Float + Debug + NumAssign> RelativePositionalEncoding<F> {
pub fn new<R: Rng>(d_model: usize, max_len: usize, rng: &mut R) -> Self {
let num_positions = 2 * max_len - 1;
let std = (1.0 / d_model as f64).sqrt();
let mut rel_embeddings = Array2::zeros((num_positions, d_model));
for elem in rel_embeddings.iter_mut() {
let u1: f64 = rng.random_range(0.0001..1.0);
let u2: f64 = rng.random_range(0.0..1.0);
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
*elem = F::from(z * std).unwrap_or(F::zero());
}
Self {
d_model,
max_len,
rel_embeddings,
}
}
pub fn get_relative_embedding(&self, rel_pos: i64) -> Option<Array<F, IxDyn>> {
let max_rel = self.max_len as i64 - 1;
if rel_pos < -max_rel || rel_pos > max_rel {
return None;
}
let idx = (rel_pos + max_rel) as usize;
Some(self.rel_embeddings.slice(s![idx, ..]).to_owned().into_dyn())
}
pub fn get_attention_bias(&self, query_len: usize, key_len: usize) -> Result<Array3<F>> {
if query_len > self.max_len || key_len > self.max_len {
return Err(NeuralError::InvalidArchitecture(format!(
"Sequence length exceeds max_len {}",
self.max_len
)));
}
let mut bias = Array3::zeros((query_len, key_len, self.d_model));
let max_rel = self.max_len as i64 - 1;
for q in 0..query_len {
for k in 0..key_len {
let rel_pos = k as i64 - q as i64;
let idx = (rel_pos + max_rel) as usize;
for d in 0..self.d_model {
bias[[q, k, d]] = self.rel_embeddings[[idx, d]];
}
}
}
Ok(bias)
}
pub fn embeddings_mut(&mut self) -> &mut Array2<F> {
&mut self.rel_embeddings
}
}
impl<F: Float + Debug + NumAssign> PositionalEncoding<F> for RelativePositionalEncoding<F> {
fn encode(&self, seq_len: usize) -> Array2<F> {
let start = self.max_len - 1;
self.rel_embeddings
.slice(s![start..(start + seq_len), ..])
.to_owned()
}
fn apply(&self, input: &Array3<F>) -> Result<Array3<F>> {
let seq_len = input.shape()[1];
if seq_len > self.max_len {
return Err(NeuralError::InvalidArchitecture(format!(
"Sequence length {} exceeds max_len {}",
seq_len, self.max_len
)));
}
let encoding = self.encode(seq_len);
let mut output = input.clone();
for mut batch in output.axis_iter_mut(Axis(0)) {
Zip::from(&mut batch)
.and(&encoding)
.for_each(|b, &e| *b += e);
}
Ok(output)
}
fn clone_box(&self) -> Box<dyn PositionalEncoding<F> + Send + Sync>
where
F: Send + Sync + 'static,
{
Box::new(self.clone())
}
fn d_model(&self) -> usize {
self.d_model
}
fn max_len(&self) -> usize {
self.max_len
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncodingType {
Sinusoidal,
Learned,
Rotary,
Relative,
}
pub struct PositionalEncodingFactory;
impl PositionalEncodingFactory {
pub fn create<F, R>(
pe_type: PositionalEncodingType,
d_model: usize,
max_len: usize,
rng: &mut R,
) -> Box<dyn PositionalEncoding<F> + Send + Sync>
where
F: Float + Debug + NumAssign + Send + Sync + 'static,
R: Rng,
{
match pe_type {
PositionalEncodingType::Sinusoidal => {
Box::new(SinusoidalPositionalEncoding::new(d_model, max_len))
}
PositionalEncodingType::Learned => {
Box::new(LearnedPositionalEncoding::new(d_model, max_len, rng))
}
PositionalEncodingType::Rotary => {
Box::new(RotaryPositionalEncoding::default_base(d_model, max_len))
}
PositionalEncodingType::Relative => {
Box::new(RelativePositionalEncoding::new(d_model, max_len, rng))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array3;
use scirs2_core::random::SeedableRng;
#[test]
fn test_sinusoidal_encoding_shape() {
let pe = SinusoidalPositionalEncoding::<f32>::new(64, 100);
let encoding = pe.encode(10);
assert_eq!(encoding.shape(), &[10, 64]);
let encoding = pe.encode(50);
assert_eq!(encoding.shape(), &[50, 64]);
}
#[test]
fn test_sinusoidal_encoding_values() {
let pe = SinusoidalPositionalEncoding::<f64>::new(4, 10);
let encoding = pe.encode(3);
assert!((encoding[[0, 0]] - 0.0).abs() < 1e-6); assert!((encoding[[0, 1]] - 1.0).abs() < 1e-6);
assert!((encoding[[0, 0]] - encoding[[1, 0]]).abs() > 1e-10);
}
#[test]
fn test_sinusoidal_apply() {
let pe = SinusoidalPositionalEncoding::<f32>::new(8, 20);
let input = Array3::zeros((2, 10, 8)); let output = pe.apply(&input).expect("Operation failed");
assert_eq!(output.shape(), input.shape());
let encoding = pe.encode(10);
for b in 0..2 {
for s in 0..10 {
for d in 0..8 {
assert!((output[[b, s, d]] - encoding[[s, d]]).abs() < 1e-6);
}
}
}
}
#[test]
fn test_learned_encoding() {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
let pe = LearnedPositionalEncoding::<f32>::new(32, 50, &mut rng);
let encoding = pe.encode(10);
assert_eq!(encoding.shape(), &[10, 32]);
let sum: f32 = encoding.iter().map(|x| x.abs()).sum();
assert!(sum > 0.1);
}
#[test]
fn test_learned_from_embeddings() {
let embeddings = Array2::ones((20, 16));
let pe = LearnedPositionalEncoding::<f32>::from_embeddings(embeddings);
assert_eq!(pe.d_model(), 16);
assert_eq!(pe.max_len(), 20);
}
#[test]
fn test_rope_encoding() {
let pe = RotaryPositionalEncoding::<f32>::default_base(64, 100);
let encoding = pe.encode(10);
assert_eq!(encoding.shape(), &[10, 64]);
}
#[test]
fn test_rope_rotate() {
let pe = RotaryPositionalEncoding::<f64>::default_base(8, 20);
let input = Array3::ones((1, 5, 8));
let rotated = pe.rotate(&input, 0).expect("Operation failed");
assert_eq!(rotated.shape(), input.shape());
let mut different = false;
for pos in 1..5 {
for i in 0..8 {
if (rotated[[0, pos, i]] - input[[0, pos, i]]).abs() > 1e-6 {
different = true;
break;
}
}
if different {
break;
}
}
assert!(
different,
"RoPE should modify input values at non-zero positions"
);
}
#[test]
fn test_rope_with_offset() {
let pe = RotaryPositionalEncoding::<f32>::default_base(8, 100);
let input = Array3::ones((1, 10, 8));
let rotated_0 = pe.rotate(&input, 0).expect("Operation failed");
let rotated_5 = pe.rotate(&input, 5).expect("Operation failed");
let mut different = false;
for s in 0..10 {
for d in 0..8 {
if (rotated_0[[0, s, d]] - rotated_5[[0, s, d]]).abs() > 1e-6 {
different = true;
break;
}
}
}
assert!(different, "Different offsets should give different results");
}
#[test]
fn test_relative_encoding() {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
let pe = RelativePositionalEncoding::<f32>::new(16, 30, &mut rng);
let rel_0 = pe.get_relative_embedding(0);
assert!(rel_0.is_some());
let rel_pos = pe.get_relative_embedding(5);
assert!(rel_pos.is_some());
let rel_neg = pe.get_relative_embedding(-5);
assert!(rel_neg.is_some());
let out_of_range = pe.get_relative_embedding(100);
assert!(out_of_range.is_none());
}
#[test]
fn test_relative_attention_bias() {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
let pe = RelativePositionalEncoding::<f32>::new(8, 20, &mut rng);
let bias = pe.get_attention_bias(10, 10).expect("Operation failed");
assert_eq!(bias.shape(), &[10, 10, 8]);
let rel_0 = pe.get_relative_embedding(0).expect("Operation failed");
for i in 0..10 {
for d in 0..8 {
assert!((bias[[i, i, d]] - rel_0[[d]]).abs() < 1e-6);
}
}
}
#[test]
fn test_factory() {
let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
let sinusoidal = PositionalEncodingFactory::create::<f32, _>(
PositionalEncodingType::Sinusoidal,
32,
100,
&mut rng,
);
assert_eq!(sinusoidal.d_model(), 32);
let learned = PositionalEncodingFactory::create::<f32, _>(
PositionalEncodingType::Learned,
32,
100,
&mut rng,
);
assert_eq!(learned.d_model(), 32);
let rotary = PositionalEncodingFactory::create::<f32, _>(
PositionalEncodingType::Rotary,
32,
100,
&mut rng,
);
assert_eq!(rotary.d_model(), 32);
let relative = PositionalEncodingFactory::create::<f32, _>(
PositionalEncodingType::Relative,
32,
100,
&mut rng,
);
assert_eq!(relative.d_model(), 32);
}
#[test]
fn test_sinusoidal_properties() {
let pe = SinusoidalPositionalEncoding::<f64>::new(64, 1000);
let encoding = pe.encode(100);
for i in 0..99 {
let mut same = true;
for d in 0..64 {
if (encoding[[i, d]] - encoding[[i + 1, d]]).abs() > 1e-10 {
same = false;
break;
}
}
assert!(!same, "Adjacent positions should be different");
}
}
#[test]
fn test_max_len_error() {
let pe = SinusoidalPositionalEncoding::<f32>::new(16, 10);
let input = Array3::zeros((1, 20, 16)); let result = pe.apply(&input);
assert!(result.is_err());
}
}