use super::module::Module;
use crate::tensor::{GraphContext, Tensor};
use ndarray::{ArrayD, IxDyn};
use std::cell::RefCell;
use std::rc::Rc;
/// Fixed (non-trainable) sine/cosine positional encoding.
///
/// The whole `[max_len, d_model]` table is computed once at construction and
/// registered in the graph as a literal, so the module has no parameters.
#[derive(Debug, Clone)]
pub struct SinusoidalPositionalEncoding {
    /// Width of each positional vector (number of table columns).
    pub d_model: usize,
    /// Number of precomputed positions (number of table rows).
    pub max_len: usize,
    /// Literal node holding the `[max_len, d_model]` table.
    pub encoding: Tensor,
}

impl SinusoidalPositionalEncoding {
    /// Precomputes the encoding table and registers it in `context` as a
    /// literal named `name`.
    pub fn new(
        context: &Rc<RefCell<GraphContext>>,
        d_model: usize,
        max_len: usize,
        name: &str,
    ) -> Self {
        let table = Self::compute_encoding(d_model, max_len);
        Self {
            d_model,
            max_len,
            encoding: Tensor::new_literal(context, table, name),
        }
    }

    /// Builds the `[max_len, d_model]` table.
    ///
    /// Columns `2k` and `2k + 1` share the divisor `10000^(2k / d_model)`:
    /// even columns hold `sin(pos / divisor)` and odd columns hold
    /// `cos(pos / divisor)`. When `d_model` is odd, the final (even-indexed)
    /// column is a lone sine term.
    fn compute_encoding(d_model: usize, max_len: usize) -> ArrayD<f32> {
        ArrayD::from_shape_fn(IxDyn(&[max_len, d_model]), |ix| {
            let (pos, col) = (ix[0], ix[1]);
            // Round the column down to the even index of its sin/cos pair.
            let pair_start = col - col % 2;
            let div_term = 10000.0_f32.powf(pair_start as f32 / d_model as f32);
            let angle = pos as f32 / div_term;
            if col % 2 == 0 {
                angle.sin()
            } else {
                angle.cos()
            }
        })
    }

    /// Returns the full precomputed table node.
    ///
    /// NOTE(review): `_seq_len` is ignored; the whole `[max_len, d_model]`
    /// table is always returned and callers must slice it themselves.
    pub fn get_encoding(&self, _seq_len: usize) -> &Tensor {
        &self.encoding
    }
}

impl Module for SinusoidalPositionalEncoding {
    /// Adds the full positional table to `input`.
    ///
    /// NOTE(review): the table is not sliced to the input's sequence length,
    /// so this presumably relies on the graph's broadcasting rules and on the
    /// input spanning exactly `max_len` positions. TODO confirm.
    fn forward(&self, input: &Tensor) -> Tensor {
        let table = &self.encoding;
        input + table
    }

    /// The encoding is a fixed literal, so there is nothing to train.
    fn parameters(&self) -> Vec<Tensor> {
        Vec::new()
    }
}
/// Trainable absolute positional embedding: one learned row per position.
#[derive(Debug, Clone)]
pub struct LearnedPositionalEmbedding {
    /// Number of positions (rows of `weight`).
    pub max_len: usize,
    /// Width of each embedding row (columns of `weight`).
    pub embedding_dim: usize,
    /// Parameter of shape `[max_len, embedding_dim]`, registered as `"{name}_weight"`.
    pub weight: Tensor,
}

impl LearnedPositionalEmbedding {
    /// Registers a `[max_len, embedding_dim]` parameter named `"{name}_weight"`,
    /// initialised from a normal distribution with mean 0.0 and std 0.02.
    pub fn new(
        context: &Rc<RefCell<GraphContext>>,
        max_len: usize,
        embedding_dim: usize,
        name: &str,
    ) -> Self {
        let param_name = format!("{}_weight", name);
        let init = crate::nn::init::Initializer::Normal {
            mean: 0.0,
            std: 0.02,
        };
        let shape = vec![max_len, embedding_dim];
        let weight = Tensor::new_parameter_with_shape(context, &param_name, shape, init);
        Self {
            max_len,
            embedding_dim,
            weight,
        }
    }

    /// Builds a 1-D literal `[0.0, 1.0, ..., (seq_len - 1) as f32]` named
    /// `"position_ids"`.
    pub fn create_position_ids(context: &Rc<RefCell<GraphContext>>, seq_len: usize) -> Tensor {
        let ids = ArrayD::from_shape_fn(IxDyn(&[seq_len]), |ix| ix[0] as f32);
        Tensor::new_literal(context, ids, "position_ids")
    }
}

impl Module for LearnedPositionalEmbedding {
    /// Looks up the row of `weight` selected by each id in `position_ids`.
    fn forward(&self, position_ids: &Tensor) -> Tensor {
        let table = &self.weight;
        position_ids.embedding(table)
    }

    /// The single trainable tensor: the embedding table.
    fn parameters(&self) -> Vec<Tensor> {
        vec![Tensor::clone(&self.weight)]
    }
}
/// Builds a `[batch_size, seq_len]` literal named `"position_ids"` in which
/// every row is `0.0, 1.0, ..., (seq_len - 1) as f32`.
pub fn create_position_ids(
    context: &Rc<RefCell<GraphContext>>,
    batch_size: usize,
    seq_len: usize,
) -> Tensor {
    // Each element's value is simply its column index, so the batch row is irrelevant.
    let ids = ArrayD::from_shape_fn(IxDyn(&[batch_size, seq_len]), |ix| ix[1] as f32);
    Tensor::new_literal(context, ids, "position_ids")
}
/// Rotary position embedding (RoPE) for query/key tensors.
///
/// Cosine and sine tables for positions `0..max_len` are precomputed once at
/// construction. They become graph literals when [`RotaryPositionEmbedding::apply`]
/// is called.
#[derive(Debug, Clone)]
pub struct RotaryPositionEmbedding {
    /// Feature width of one attention head; must be even.
    pub head_dim: usize,
    /// Number of positions covered by the cached tables.
    pub max_len: usize,
    /// Frequency base (`10000.0` when built through [`RotaryPositionEmbedding::new`]).
    pub base: f32,
    /// `[max_len, head_dim / 2]` table of `cos(pos * inv_freq[i])`.
    cos_cached: ArrayD<f32>,
    /// `[max_len, head_dim / 2]` table of `sin(pos * inv_freq[i])`.
    sin_cached: ArrayD<f32>,
    /// Graph that receives the rotation literals.
    context: Rc<RefCell<GraphContext>>,
}

impl RotaryPositionEmbedding {
    /// Creates a RoPE module with the conventional frequency base `10000.0`.
    ///
    /// `_name` is currently unused.
    ///
    /// # Panics
    /// Panics if `head_dim` is odd.
    pub fn new(
        context: &Rc<RefCell<GraphContext>>,
        head_dim: usize,
        max_len: usize,
        _name: &str,
    ) -> Self {
        Self::with_base(context, head_dim, max_len, 10000.0, _name)
    }

    /// Creates a RoPE module with a custom frequency `base`.
    ///
    /// `_name` is currently unused.
    ///
    /// # Panics
    /// Panics if `head_dim` is odd.
    pub fn with_base(
        context: &Rc<RefCell<GraphContext>>,
        head_dim: usize,
        max_len: usize,
        base: f32,
        _name: &str,
    ) -> Self {
        assert!(head_dim % 2 == 0, "head_dim must be even for RoPE");
        let (cos_cached, sin_cached) = Self::precompute_freqs(head_dim, max_len, base);
        Self {
            head_dim,
            max_len,
            base,
            cos_cached,
            sin_cached,
            context: Rc::clone(context),
        }
    }

    /// Computes `[max_len, head_dim / 2]` cosine and sine tables.
    ///
    /// Entry `[pos, i]` is the cos/sin of `pos * base^(-2i / head_dim)`.
    fn precompute_freqs(head_dim: usize, max_len: usize, base: f32) -> (ArrayD<f32>, ArrayD<f32>) {
        let half_dim = head_dim / 2;
        let inv_freq: Vec<f32> = (0..half_dim)
            .map(|i| 1.0 / base.powf(2.0 * i as f32 / head_dim as f32))
            .collect();
        let mut cos_data = Vec::with_capacity(max_len * half_dim);
        let mut sin_data = Vec::with_capacity(max_len * half_dim);
        // Row-major fill: position-major, then frequency index.
        for pos in 0..max_len {
            for &freq in &inv_freq {
                let angle = pos as f32 * freq;
                cos_data.push(angle.cos());
                sin_data.push(angle.sin());
            }
        }
        let shape = IxDyn(&[max_len, half_dim]);
        let cos_arr = ArrayD::from_shape_vec(shape.clone(), cos_data)
            .expect("cos buffer holds exactly max_len * half_dim values");
        let sin_arr = ArrayD::from_shape_vec(shape, sin_data)
            .expect("sin buffer holds exactly max_len * half_dim values");
        (cos_arr, sin_arr)
    }

    /// Rotates `query` and `key` with the cached position tables.
    ///
    /// Both tensors are split on axis 3 into halves `x1 = x[..half]` and
    /// `x2 = x[half..]`, and the result is `[x1*cos - x2*sin, x1*sin + x2*cos]`
    /// concatenated on axis 3. Dimension `i` is therefore paired with
    /// `i + head_dim / 2`, not with its adjacent neighbour.
    ///
    /// The cos/sin literals are registered once per call and shared by the
    /// query and key rotations. Previously each tensor cloned both tables and
    /// emitted its own copy of the literals.
    ///
    /// NOTE(review): the full `0..max_len` tables are always used, so
    /// `seq_offset` currently has no effect. The literals have shape
    /// `[1, 1, max_len, head_dim / 2]`, so this presumably assumes inputs laid
    /// out as `(batch, heads, seq, head_dim)` with `seq == max_len` (or
    /// broadcastable). TODO confirm. Honouring `seq_offset`, for example in
    /// KV-cache decoding, needs the runtime sequence length.
    pub fn apply(&self, query: &Tensor, key: &Tensor, seq_offset: usize) -> (Tensor, Tensor) {
        let _ = seq_offset;
        let cos = self.table_literal(&self.cos_cached, "rope_cos");
        let sin = self.table_literal(&self.sin_cached, "rope_sin");
        let q_rot = self.rotate(query, &cos, &sin);
        let k_rot = self.rotate(key, &cos, &sin);
        (q_rot, k_rot)
    }

    /// Registers `table` (shape `[max_len, head_dim / 2]`) as a
    /// `[1, 1, max_len, head_dim / 2]` literal named `name`.
    fn table_literal(&self, table: &ArrayD<f32>, name: &str) -> Tensor {
        let data = table
            .clone()
            .into_shape_with_order(IxDyn(&[1, 1, self.max_len, self.head_dim / 2]))
            .expect("cached RoPE table always has max_len * head_dim/2 elements");
        Tensor::new_literal(&self.context, data, name)
    }

    /// Applies the half-split rotation to `x` using the given cos/sin literals.
    fn rotate(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Tensor {
        let half_dim = self.head_dim / 2;
        let x1 = x.slice(3, 0, half_dim);
        let x2 = x.slice(3, half_dim, self.head_dim);
        let rot1 = &(&x1 * cos) - &(&x2 * sin);
        let rot2 = &(&x1 * sin) + &(&x2 * cos);
        rot1.concat(&[&rot2], 3)
    }

    /// The cached `[max_len, head_dim / 2]` cosine table.
    pub fn get_cos(&self) -> &ArrayD<f32> {
        &self.cos_cached
    }

    /// The cached `[max_len, head_dim / 2]` sine table.
    pub fn get_sin(&self) -> &ArrayD<f32> {
        &self.sin_cached
    }
}
/// Attention with Linear Biases (ALiBi).
///
/// Each head `h` gets a slope `m_h`. The bias added to attention scores for
/// query `i` and key `j` is `-m_h * |i - j|`.
#[derive(Debug, Clone)]
pub struct ALiBi {
    /// Number of attention heads (one slope per head).
    pub num_heads: usize,
    /// Per-head slopes, `num_heads` long.
    slopes: Vec<f32>,
    /// Graph that receives the bias literals.
    context: Rc<RefCell<GraphContext>>,
}

impl ALiBi {
    /// Creates an ALiBi bias generator for `num_heads` heads.
    pub fn new(context: &Rc<RefCell<GraphContext>>, num_heads: usize) -> Self {
        let slopes = Self::compute_slopes(num_heads);
        Self {
            num_heads,
            slopes,
            context: Rc::clone(context),
        }
    }

    /// Computes the per-head slopes using the reference ALiBi schedule.
    ///
    /// - For a power-of-two head count `n`, the slopes are `2^(-8k/n)` for
    ///   `k = 1..=n`.
    /// - For any other `n`, the reference construction takes the slopes for the
    ///   largest power of two `p < n`. It then appends the first `n - p` slopes
    ///   at even (0-based) indices of the `2p` sequence, so the extra heads
    ///   interleave with the existing ones. Applying `2^(-8k/n)` directly to
    ///   non-power-of-two `n` is not what the reference does.
    /// - Zero heads yields an empty vector.
    fn compute_slopes(num_heads: usize) -> Vec<f32> {
        if num_heads == 0 {
            return Vec::new();
        }
        if num_heads.is_power_of_two() {
            return Self::power_of_two_slopes(num_heads);
        }
        let closest = num_heads.next_power_of_two() / 2;
        let mut slopes = Self::power_of_two_slopes(closest);
        slopes.extend(
            Self::power_of_two_slopes(2 * closest)
                .into_iter()
                .step_by(2)
                .take(num_heads - closest),
        );
        slopes
    }

    /// Geometric slope sequence `2^(-8k/n)` for `k = 1..=n`, where `n` is a
    /// power of two.
    fn power_of_two_slopes(n: usize) -> Vec<f32> {
        let ratio = 8.0 / n as f32;
        (1..=n).map(|i| 2.0_f32.powf(-ratio * i as f32)).collect()
    }

    /// Symmetric bias literal `"alibi_bias"` of shape `[1, num_heads, seq_len, seq_len]`,
    /// where entry `[0, h, i, j]` is `-slope_h * |i - j|`.
    pub fn get_bias(&self, seq_len: usize) -> Tensor {
        self.bias_literal(seq_len, false, "alibi_bias")
    }

    /// Causal bias literal `"alibi_causal_bias"` of shape `[1, num_heads, seq_len, seq_len]`.
    ///
    /// Entries with `j > i` (future keys) are `-1e9`. All others are
    /// `-slope_h * (i - j)`.
    pub fn get_causal_bias(&self, seq_len: usize) -> Tensor {
        self.bias_literal(seq_len, true, "alibi_causal_bias")
    }

    /// Builds the bias array and registers it as a literal named `name`.
    fn bias_literal(&self, seq_len: usize, causal: bool, name: &str) -> Tensor {
        let mut bias_data = Vec::with_capacity(self.num_heads * seq_len * seq_len);
        // Row-major order: head, then query position i, then key position j.
        for &slope in &self.slopes {
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let value = if causal && j > i {
                        -1e9_f32
                    } else {
                        -slope * i.abs_diff(j) as f32
                    };
                    bias_data.push(value);
                }
            }
        }
        let bias_arr =
            ArrayD::from_shape_vec(IxDyn(&[1, self.num_heads, seq_len, seq_len]), bias_data)
                .expect("bias buffer holds num_heads * seq_len^2 values");
        Tensor::new_literal(&self.context, bias_arr, name)
    }

    /// The per-head slopes, in head order.
    pub fn get_slopes(&self) -> &[f32] {
        &self.slopes
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the positional-encoding modules. Most tests build a graph
    // in a fresh `GraphContext` and read back literal nodes or run the CPU backend.
    use super::*;
    use crate::asg::{NodeType, Value};
    use crate::runtime::{backend::Backend, cpu_backend::CpuBackend};
    use std::collections::HashMap;

    // Constructor stores the configuration and the default base of 10000.
    #[test]
    fn test_rope_creation() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let rope = RotaryPositionEmbedding::new(&context, 64, 2048, "rope");
        assert_eq!(rope.head_dim, 64);
        assert_eq!(rope.max_len, 2048);
        assert_eq!(rope.base, 10000.0);
    }

    // Cached tables are [max_len, head_dim / 2]; at position 0 every angle is 0,
    // so cos == 1 and sin == 0.
    #[test]
    fn test_rope_precompute_freqs() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let rope = RotaryPositionEmbedding::new(&context, 4, 10, "rope");
        let cos = rope.get_cos();
        let sin = rope.get_sin();
        assert_eq!(cos.shape(), &[10, 2]);
        assert_eq!(sin.shape(), &[10, 2]);
        assert!((cos[[0, 0]] - 1.0).abs() < 1e-5);
        assert!((cos[[0, 1]] - 1.0).abs() < 1e-5);
        assert!((sin[[0, 0]] - 0.0).abs() < 1e-5);
        assert!((sin[[0, 1]] - 0.0).abs() < 1e-5);
    }

    // `with_base` keeps the caller-supplied frequency base.
    #[test]
    fn test_rope_with_custom_base() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let rope = RotaryPositionEmbedding::with_base(&context, 64, 1024, 500000.0, "rope");
        assert_eq!(rope.base, 500000.0);
    }

    // One slope per head.
    #[test]
    fn test_alibi_creation() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let alibi = ALiBi::new(&context, 8);
        assert_eq!(alibi.num_heads, 8);
        assert_eq!(alibi.slopes.len(), 8);
    }

    // With 8 heads the slopes are the geometric sequence 2^-1, 2^-2, ..., 2^-8.
    #[test]
    fn test_alibi_slopes() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let alibi = ALiBi::new(&context, 8);
        let slopes = alibi.get_slopes();
        // Head 0: 2^-1.
        assert!((slopes[0] - 0.5).abs() < 1e-5);
        // Head 1: 2^-2.
        assert!((slopes[1] - 0.25).abs() < 1e-5);
        // Head 7 (last): 2^-8.
        assert!((slopes[7] - 2.0_f32.powf(-8.0)).abs() < 1e-6);
    }

    // Symmetric bias: shape [1, heads, seq, seq], zero on the diagonal, and
    // -slope at distance 1 (head 0 of 4 heads has slope 0.25).
    #[test]
    fn test_alibi_bias_shape() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let alibi = ALiBi::new(&context, 4);
        let bias = alibi.get_bias(16);
        // Read the literal array straight from the graph node.
        let graph = context.borrow();
        let main_graph = graph.main_graph();
        let node = main_graph.get_node(bias.node_id).unwrap();
        if let NodeType::Literal(Value::Tensor(arr)) = &node.node_type {
            assert_eq!(arr.shape(), &[1, 4, 16, 16]);
            assert_eq!(arr[[0, 0, 0, 0]], 0.0);
            assert_eq!(arr[[0, 0, 5, 5]], 0.0);
            assert!((arr[[0, 0, 0, 1]] + 0.25).abs() < 1e-5);
        } else {
            panic!("Expected Literal tensor");
        }
    }

    // Causal bias: future keys (j > i) are masked with a large negative value;
    // the diagonal is zero and past keys get a negative linear penalty.
    #[test]
    fn test_alibi_causal_bias() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let alibi = ALiBi::new(&context, 2);
        let bias = alibi.get_causal_bias(4);
        let graph = context.borrow();
        let main_graph = graph.main_graph();
        let node = main_graph.get_node(bias.node_id).unwrap();
        if let NodeType::Literal(Value::Tensor(arr)) = &node.node_type {
            // Masked future positions.
            assert!(arr[[0, 0, 0, 1]] < -1e8);
            assert!(arr[[0, 0, 0, 2]] < -1e8);
            assert!(arr[[0, 0, 1, 2]] < -1e8);
            // Diagonal: no penalty.
            assert_eq!(arr[[0, 0, 0, 0]], 0.0);
            // One step into the past: negative penalty.
            assert!(arr[[0, 0, 1, 0]] < 0.0);
        } else {
            panic!("Expected Literal tensor");
        }
    }

    // Sinusoidal encoding stores its configuration and has no trainable parameters.
    #[test]
    fn test_sinusoidal_encoding_shape() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let pos_enc = SinusoidalPositionalEncoding::new(&context, 64, 100, "pos_enc");
        assert_eq!(pos_enc.d_model, 64);
        assert_eq!(pos_enc.max_len, 100);
        assert!(pos_enc.parameters().is_empty());
    }

    // Position 0 alternates sin(0) = 0 / cos(0) = 1; position 1, column pair 0
    // uses divisor 1, giving sin(1) and cos(1).
    #[test]
    fn test_sinusoidal_encoding_values() {
        let encoding = SinusoidalPositionalEncoding::compute_encoding(4, 3);
        assert!((encoding[[0, 0]] - 0.0).abs() < 1e-5);
        assert!((encoding[[0, 1]] - 1.0).abs() < 1e-5);
        assert!((encoding[[0, 2]] - 0.0).abs() < 1e-5);
        assert!((encoding[[0, 3]] - 1.0).abs() < 1e-5);
        assert!((encoding[[1, 0]] - 1.0_f32.sin()).abs() < 1e-5);
        assert!((encoding[[1, 1]] - 1.0_f32.cos()).abs() < 1e-5);
    }

    // Learned embedding exposes exactly one parameter (its weight table).
    #[test]
    fn test_learned_positional_embedding_creation() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let pos_emb = LearnedPositionalEmbedding::new(&context, 512, 256, "pos_emb");
        assert_eq!(pos_emb.max_len, 512);
        assert_eq!(pos_emb.embedding_dim, 256);
        assert_eq!(pos_emb.parameters().len(), 1);
    }

    // The free `create_position_ids` yields [batch, seq] with each row 0..seq.
    #[test]
    fn test_create_position_ids() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let pos_ids = create_position_ids(&context, 2, 4);
        context
            .borrow_mut()
            .main_graph_mut()
            .set_output(pos_ids.node_id);
        let backend = CpuBackend::new();
        let graph = context.borrow().main_graph().clone();
        let (results, _) = backend.run(&graph, HashMap::new()).unwrap();
        if let Value::Tensor(arr) = &results[0] {
            assert_eq!(arr.shape(), &[2, 4]);
            assert!((arr[[0, 0]] - 0.0).abs() < 1e-5);
            assert!((arr[[0, 1]] - 1.0).abs() < 1e-5);
            assert!((arr[[0, 2]] - 2.0).abs() < 1e-5);
            assert!((arr[[0, 3]] - 3.0).abs() < 1e-5);
            assert!((arr[[1, 0]] - 0.0).abs() < 1e-5);
            assert!((arr[[1, 1]] - 1.0).abs() < 1e-5);
        } else {
            panic!("Expected tensor");
        }
    }

    // End-to-end lookup: weight row r is [3r, 3r+1, 3r+2], so ids [0, 2, 4, 1]
    // give first-column values 0, 6, 12, 3.
    #[test]
    fn test_learned_positional_embedding_forward() {
        let context = Rc::new(RefCell::new(GraphContext::new()));
        let pos_emb = LearnedPositionalEmbedding::new(&context, 5, 3, "pos_emb");
        let pos_ids = Tensor::new_input(&context, "pos_ids");
        let output = pos_emb.forward(&pos_ids);
        context
            .borrow_mut()
            .main_graph_mut()
            .set_output(output.node_id);
        let weight_data =
            ArrayD::from_shape_vec(IxDyn(&[5, 3]), (0..15).map(|x| x as f32).collect()).unwrap();
        let pos_ids_data = ArrayD::from_shape_vec(IxDyn(&[4]), vec![0.0, 2.0, 4.0, 1.0]).unwrap();
        // Inputs are keyed by node name; the weight uses the "{name}_weight" naming.
        let mut inputs = HashMap::new();
        inputs.insert("pos_ids".to_string(), Value::Tensor(pos_ids_data));
        inputs.insert("pos_emb_weight".to_string(), Value::Tensor(weight_data));
        let backend = CpuBackend::new();
        let device_data = backend.load_data(&inputs).unwrap();
        // Map each named input/parameter value onto its node id to seed the run.
        // NOTE(review): the memo key is (0, node_id); presumably 0 identifies the
        // main graph. TODO confirm.
        let mut memo = HashMap::new();
        for (name, value) in device_data {
            let node_id = context
                .borrow()
                .main_graph()
                .nodes
                .iter()
                .find(|(_, node)| {
                    matches!(&node.node_type,
                        NodeType::Input { name: n } |
                        NodeType::Parameter { name: n } if n == &name
                    )
                })
                .map(|(id, _)| *id);
            if let Some(id) = node_id {
                memo.insert((0, id), value);
            }
        }
        let graph = context.borrow().main_graph().clone();
        let (results, _) = backend.run(&graph, memo).unwrap();
        if let Value::Tensor(arr) = &results[0] {
            assert_eq!(arr.shape(), &[4, 3]);
            assert!((arr[[0, 0]] - 0.0).abs() < 1e-5);
            assert!((arr[[1, 0]] - 6.0).abs() < 1e-5);
            assert!((arr[[2, 0]] - 12.0).abs() < 1e-5);
            assert!((arr[[3, 0]] - 3.0).abs() < 1e-5);
        } else {
            panic!("Expected tensor");
        }
    }
}