//! Rotary position embeddings (RoPE): a tensor-op implementation with an
//! optional AVX2 fast path, plus a variant that precomputes the cos/sin table.
use crate::tensor::Tensor;
use anyhow::Result;
// `std::arch::x86_64` exists only on 64-bit x86, so gate the import (and the
// AVX2 code below) on `x86_64` alone; including `target_arch = "x86"` here
// would fail to compile on 32-bit targets.
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
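/// RoPE with a runtime-selected fast path: plain tensor ops everywhere, plus
/// an AVX2 option on x86_64 when the feature is detected at construction.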
pub struct VectorizedRoPE {
dim: usize,
#[allow(dead_code)]
max_seq_len: usize,
base: f32,
#[allow(dead_code)]
inv_freq: Tensor,
use_simd: bool,
}
impl VectorizedRoPE {
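    /// `dim` is the rotary dimension and `base` the frequency base
    /// (typically 10000.0); the inverse frequencies `1 / base^(2i/dim)` are
    /// cached in `inv_freq`.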
pub fn new(dim: usize, max_seq_len: usize, base: f32) -> Result<Self> {
let half_dim = dim / 2;
let inv_freq_vec: Vec<f32> =
(0..half_dim).map(|i| 1.0 / base.powf(2.0 * i as f32 / dim as f32)).collect();
let inv_freq = Tensor::new(inv_freq_vec)?;
        // Take the AVX2 path only when the CPU supports it and `dim` is both
        // large enough to amortize the setup and a multiple of the 8-lane
        // f32 vector width.
        #[cfg(target_arch = "x86_64")]
        let use_simd = is_x86_feature_detected!("avx2") && dim >= 128 && dim % 8 == 0;
        #[cfg(not(target_arch = "x86_64"))]
        let use_simd = false;
Ok(Self {
dim,
max_seq_len,
base,
inv_freq,
use_simd,
})
}
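    /// Applies the rotary embedding to `x` (shape `[batch, seq, dim]`),
    /// dispatching to the SIMD-eligible path when it was enabled in `new`.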
pub fn forward(&self, x: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
if self.use_simd {
self.forward_simd(x, position_ids)
} else {
self.forward_standard(x, position_ids)
}
}
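    /// Tensor-op fallback using the "rotate-half" formulation of RoPE.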
    fn forward_standard(&self, x: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
        let hidden_size = x.shape()[2];
        let half_dim = hidden_size / 2;
        // Per-position angles: cos at trailing index 0, sin at index 1.
        let cos_sin = self.get_cos_sin_embeddings(position_ids)?;
        let cos_vals = cos_sin.slice(3, 0, 1)?.squeeze(3)?;
        let sin_vals = cos_sin.slice(3, 1, 2)?.squeeze(3)?;
        let x1 = x.slice(2, 0, half_dim)?;
        let x2 = x.slice(2, half_dim, hidden_size)?;
        // Rotate-half: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos).
        let out1 = x1.mul(&cos_vals)?.sub(&x2.mul(&sin_vals)?)?;
        let out2 = x1.mul(&sin_vals)?.add(&x2.mul(&cos_vals)?)?;
        Ok(Tensor::concat(&[out1, out2], 2)?)
    }
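    /// Raw AVX2 kernel over contiguous `f32` buffers laid out as
    /// `[batch, seq, heads, head_dim]`; `cos_data`/`sin_data` hold one value
    /// per `(batch, seq, freq)` triple, shared across heads.
    ///
    /// # Safety
    /// The caller must guarantee AVX2 support and that every slice is large
    /// enough for the indices implied by the shape arguments.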
    #[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(dead_code)]
unsafe fn forward_simd_inner(
&self,
x_data: &[f32],
cos_data: &[f32],
sin_data: &[f32],
output_data: &mut [f32],
batch_size: usize,
seq_len: usize,
num_heads: usize,
head_dim: usize,
) {
let half_dim = head_dim / 2;
for b in 0..batch_size {
for s in 0..seq_len {
for h in 0..num_heads {
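                    // Flat offsets into the contiguous [batch, seq, heads, head_dim]
                    // buffer; each rotate-half pair lives at (i, i + half_dim)
                    // within a head.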
let x_base = b * seq_len * num_heads * head_dim
+ s * num_heads * head_dim
+ h * head_dim;
let cos_sin_base = b * seq_len * half_dim + s * half_dim;
let x1_start = x_base;
let x2_start = x_base + half_dim;
let out1_start = x_base;
let out2_start = x_base + half_dim;
let mut i = 0;
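                    // Main loop: process 8 f32 lanes per iteration with AVX2.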
while i + 8 <= half_dim {
let x1_vals = _mm256_loadu_ps(&x_data[x1_start + i]);
let x2_vals = _mm256_loadu_ps(&x_data[x2_start + i]);
let cos_vals = _mm256_loadu_ps(&cos_data[cos_sin_base + i]);
let sin_vals = _mm256_loadu_ps(&sin_data[cos_sin_base + i]);
let x1_cos = _mm256_mul_ps(x1_vals, cos_vals);
let x2_sin = _mm256_mul_ps(x2_vals, sin_vals);
let rotated_x1 = _mm256_sub_ps(x1_cos, x2_sin);
let x2_cos = _mm256_mul_ps(x2_vals, cos_vals);
let x1_sin = _mm256_mul_ps(x1_vals, sin_vals);
let rotated_x2 = _mm256_add_ps(x2_cos, x1_sin);
_mm256_storeu_ps(&mut output_data[out1_start + i], rotated_x1);
_mm256_storeu_ps(&mut output_data[out2_start + i], rotated_x2);
i += 8;
}
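                    // Scalar tail for any remainder when half_dim % 8 != 0.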
while i < half_dim {
let x1 = x_data[x1_start + i];
let x2 = x_data[x2_start + i];
let cos_val = cos_data[cos_sin_base + i];
let sin_val = sin_data[cos_sin_base + i];
output_data[out1_start + i] = x1 * cos_val - x2 * sin_val;
output_data[out2_start + i] = x2 * cos_val + x1 * sin_val;
i += 1;
}
}
}
}
}
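    /// SIMD entry point. This currently uses the same tensor ops as
    /// `forward_standard`; the raw AVX2 kernel (`forward_simd_inner`) is kept
    /// for callers that can provide contiguous `f32` buffers directly.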
    fn forward_simd(&self, x: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
        let hidden_size = x.shape()[2];
        let half_dim = hidden_size / 2;
        let cos_sin = self.get_cos_sin_embeddings(position_ids)?;
        // Squeeze once instead of re-slicing per use.
        let cos_embed = cos_sin.slice(3, 0, 1)?.squeeze(3)?;
        let sin_embed = cos_sin.slice(3, 1, 2)?.squeeze(3)?;
        let x_even = x.slice(2, 0, half_dim)?;
        let x_odd = x.slice(2, half_dim, hidden_size)?;
        let rotated_even = x_even.mul(&cos_embed)?.sub(&x_odd.mul(&sin_embed)?)?;
        let rotated_odd = x_even.mul(&sin_embed)?.add(&x_odd.mul(&cos_embed)?)?;
        Ok(Tensor::concat(&[rotated_even, rotated_odd], 2)?)
    }
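    /// Builds the cos/sin table for the given positions with shape
    /// `[batch, seq_len, half_dim, 2]` (cos at trailing index 0, sin at 1).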
    fn get_cos_sin_embeddings(&self, position_ids: &Tensor) -> Result<Tensor> {
        let half_dim = self.dim / 2;
        // Inverse frequencies, mirroring `inv_freq` from `new`:
        // freq_i = base^(-2i / dim).
        let freqs: Vec<f32> = (0..half_dim)
            .map(|i| 1.0 / self.base.powf(2.0 * i as f32 / self.dim as f32))
            .collect();
        let freq_tensor = Tensor::from_vec(freqs, &[half_dim])?;
        // Positions as f32 so they can be multiplied against the frequencies.
        let position_f32: Vec<f32> = position_ids.data()?.into_iter().collect();
        let position_ids_f32 = Tensor::from_vec(position_f32, &position_ids.shape())?;
        // Outer product via broadcasting: angle[b, s, i] = pos[b, s] * freq[i].
        let freq_expanded = freq_tensor.unsqueeze(0)?.unsqueeze(0)?;
        let pos_expanded = position_ids_f32.unsqueeze(position_ids_f32.shape().len())?;
        let pos_freqs = pos_expanded.mul(&freq_expanded)?;
        // Stack cos and sin along a new trailing axis of size 2.
        let cos_expanded = pos_freqs.cos()?.unsqueeze(3)?;
        let sin_expanded = pos_freqs.sin()?.unsqueeze(3)?;
        Ok(Tensor::concat(&[cos_expanded, sin_expanded], 3)?)
    }
}
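/// RoPE variant that can trade memory for speed by precomputing the whole
/// cos/sin table for `max_seq_len` positions up front.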
pub struct OptimizedRoPE {
config: RoPEConfig,
precomputed_freqs: Option<Tensor>,
}
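/// Configuration for `OptimizedRoPE`, including context-extension scaling.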
#[derive(Debug, Clone)]
pub struct RoPEConfig {
pub dim: usize,
pub max_seq_len: usize,
pub base: f32,
pub scaling_factor: f32,
pub scaling_type: RoPEScalingType,
}
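/// Frequency-scaling strategy for extended contexts: `Linear` divides every
/// frequency by `scaling_factor` (position-interpolation style); `Dynamic`
/// here rescales only the fast-rotating half of the band.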
#[derive(Debug, Clone, Copy)]
pub enum RoPEScalingType {
None,
Linear,
Dynamic,
}
impl OptimizedRoPE {
pub fn new(config: RoPEConfig) -> Result<Self> {
Ok(Self {
config,
precomputed_freqs: None,
})
}
pub fn with_precomputed_freqs(config: RoPEConfig) -> Result<Self> {
let precomputed = Self::precompute_frequencies(&config)?;
Ok(Self {
config,
precomputed_freqs: Some(precomputed),
})
}
pub fn enable_precomputation(&mut self) -> Result<()> {
if self.precomputed_freqs.is_none() {
self.precomputed_freqs = Some(Self::precompute_frequencies(&self.config)?);
}
Ok(())
}
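    /// Builds a `[max_seq_len, 2 * half_dim]` table with one row per
    /// position: `half_dim` cosines followed by `half_dim` sines.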
    fn precompute_frequencies(config: &RoPEConfig) -> Result<Tensor> {
        let half_dim = config.dim / 2;
        // Base inverse frequencies: freq_i = base^(-2i / dim).
        let freq_values: Vec<f32> = (0..half_dim)
            .map(|i| 1.0 / config.base.powf(2.0 * i as f32 / config.dim as f32))
            .collect();
        let scaled_freqs: Vec<f32> = match config.scaling_type {
            RoPEScalingType::None => freq_values,
            // Slow every frequency down uniformly (position interpolation).
            RoPEScalingType::Linear => {
                freq_values.into_iter().map(|f| f / config.scaling_factor).collect()
            },
            // Rescale only the first (fast-rotating) half of the band.
            RoPEScalingType::Dynamic => {
                freq_values
                    .into_iter()
                    .enumerate()
                    .map(|(i, f)| {
                        let scale = if i < half_dim / 2 { config.scaling_factor } else { 1.0 };
                        f / scale
                    })
                    .collect()
            },
        };
        // Row-per-position layout: [cos_0..cos_{h-1}, sin_0..sin_{h-1}], so
        // `slice(1, 0, half_dim)` recovers the cosines for each position and
        // `slice(1, half_dim, 2 * half_dim)` the sines. (Appending all cosines
        // before all sines in one flat vector would scramble the rows.)
        let mut all_vals = Vec::with_capacity(config.max_seq_len * half_dim * 2);
        for pos in 0..config.max_seq_len {
            let angles: Vec<f32> = scaled_freqs.iter().map(|&f| pos as f32 * f).collect();
            all_vals.extend(angles.iter().map(|a| a.cos()));
            all_vals.extend(angles.iter().map(|a| a.sin()));
        }
        Ok(Tensor::from_vec(all_vals, &[config.max_seq_len, half_dim * 2])?)
    }
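    /// Uses the precomputed table when present, otherwise computes the
    /// embeddings on the fly via `VectorizedRoPE`.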
pub fn forward(&self, x: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
if let Some(ref precomputed) = self.precomputed_freqs {
self.forward_with_precomputed(x, position_ids, precomputed)
} else {
self.forward_dynamic(x, position_ids)
}
}
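    /// Table-lookup path; assumes positions are the contiguous range
    /// `0..seq_len`, so the lookup is a row slice rather than a gather.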
    fn forward_with_precomputed(
        &self,
        x: &Tensor,
        _position_ids: &Tensor,
        precomputed: &Tensor,
    ) -> Result<Tensor> {
        let batch_size = x.shape()[0];
        let seq_len = x.shape()[1];
        let hidden_size = x.shape()[2];
        let half_dim = hidden_size / 2;
        // Slice the first `seq_len` rows of the [max_seq_len, 2 * half_dim]
        // table (positions are assumed contiguous from 0) and split each row
        // into its cosine and sine halves.
        let cos_rows = precomputed
            .slice(0, 0, seq_len)?
            .slice(1, 0, half_dim)?
            .reshape(&[1, seq_len, half_dim])?;
        let sin_rows = precomputed
            .slice(0, 0, seq_len)?
            .slice(1, half_dim, half_dim * 2)?
            .reshape(&[1, seq_len, half_dim])?;
        // Tile across the batch so the shapes match `x` exactly.
        let cos_embed = Tensor::concat(&vec![cos_rows; batch_size], 0)?;
        let sin_embed = Tensor::concat(&vec![sin_rows; batch_size], 0)?;
        let x_first_half = x.slice(2, 0, half_dim)?;
        let x_second_half = x.slice(2, half_dim, hidden_size)?;
        // Same rotate-half arithmetic as the dynamic path.
        let cos_result = x_first_half.mul(&cos_embed)?.sub(&x_second_half.mul(&sin_embed)?)?;
        let sin_result = x_first_half.mul(&sin_embed)?.add(&x_second_half.mul(&cos_embed)?)?;
        Ok(Tensor::concat(&[cos_result, sin_result], 2)?)
    }
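    /// Fallback that constructs a `VectorizedRoPE` per call; fine for
    /// occasional use, but prefer `with_precomputed_freqs` on hot paths.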
fn forward_dynamic(&self, x: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
let vectorized_rope =
VectorizedRoPE::new(self.config.dim, self.config.max_seq_len, self.config.base)?;
vectorized_rope.forward(x, position_ids)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_vectorized_rope_creation() -> Result<()> {
let rope = VectorizedRoPE::new(128, 2048, 10000.0)?;
assert_eq!(rope.dim, 128);
assert_eq!(rope.max_seq_len, 2048);
Ok(())
}
#[test]
fn test_optimized_rope_creation() -> Result<()> {
let config = RoPEConfig {
dim: 128,
max_seq_len: 1024,
base: 10000.0,
scaling_factor: 1.0,
scaling_type: RoPEScalingType::None,
};
let rope = OptimizedRoPE::new(config)?;
assert_eq!(rope.config.dim, 128);
Ok(())
}
#[test]
fn test_rope_config() {
let config = RoPEConfig {
dim: 64,
max_seq_len: 512,
base: 10000.0,
scaling_factor: 1.0,
scaling_type: RoPEScalingType::Linear,
};
assert_eq!(config.dim, 64);
assert_eq!(config.max_seq_len, 512);
}
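    // A sketch of a precomputation round-trip using only the APIs defined in
    // this module; the expected table shape follows the row-per-position
    // layout documented in `precompute_frequencies`.
    #[test]
    fn test_precomputed_table_shape() -> Result<()> {
        let config = RoPEConfig {
            dim: 64,
            max_seq_len: 256,
            base: 10000.0,
            scaling_factor: 2.0,
            scaling_type: RoPEScalingType::Linear,
        };
        let mut rope = OptimizedRoPE::new(config.clone())?;
        assert!(rope.precomputed_freqs.is_none());
        rope.enable_precomputation()?;
        let table = rope.precomputed_freqs.as_ref().expect("table was just built");
        // One row per position, holding 32 cosines then 32 sines.
        assert_eq!(table.shape()[0], 256);
        assert_eq!(table.shape()[1], 64);
        let eager = OptimizedRoPE::with_precomputed_freqs(config)?;
        assert!(eager.precomputed_freqs.is_some());
        Ok(())
    }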
}