use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, EinsumNode};
use crate::error::{Result, TrustformerError};
/// Configuration shared by every position-encoding variant in this module.
///
/// Instances are normally built through the named constructors
/// ([`PositionEncodingConfig::sinusoidal`], [`PositionEncodingConfig::learned`],
/// [`PositionEncodingConfig::relative`], [`PositionEncodingConfig::rotary`],
/// [`PositionEncodingConfig::alibi`]) and checked with
/// [`PositionEncodingConfig::validate`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PositionEncodingConfig {
    /// Model dimension. Validation rejects zero, and additionally rejects odd
    /// values when `encoding_type` is the rotary variant.
    pub d_model: usize,
    /// Maximum sequence length. Validation requires a non-zero value for the
    /// learned and rotary variants; the `relative` constructor stores `0` here.
    pub max_seq_len: usize,
    /// Selected encoding scheme together with its scheme-specific parameters.
    pub encoding_type: PositionEncodingType,
    /// Dropout probability. Validation requires it to lie in `[0, 1]`; every
    /// named constructor initialises it to `0.0`.
    pub dropout: f64,
}
/// The encoding scheme selected by a [`PositionEncodingConfig`], together with
/// the parameters that only make sense for that scheme.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PositionEncodingType {
    /// Sinusoidal position encoding.
    Sinusoidal {
        /// Base constant. `PositionEncodingConfig::validate` requires it to be
        /// positive; `PositionEncodingConfig::sinusoidal` sets it to `10000.0`.
        base: f64,
    },
    /// Learned position encoding; carries no parameters of its own.
    Learned,
    /// Relative position encoding.
    Relative {
        /// Number of buckets; validation rejects zero.
        num_buckets: usize,
        /// Maximum distance; validation rejects zero.
        max_distance: usize,
    },
    /// Rotary position encoding (called "RoPE" in the validation messages).
    Rotary {
        /// Base constant. Validation requires it to be positive;
        /// `PositionEncodingConfig::rotary` sets it to `10000.0`.
        base: f64,
        /// Scaling factor. Validation requires it to be positive;
        /// `PositionEncodingConfig::rotary` sets it to `1.0`.
        scaling_factor: f64,
    },
    /// ALiBi attention bias.
    Alibi {
        /// Number of attention heads; validation rejects zero.
        n_heads: usize,
        /// Maximum sequence length; validation rejects zero. This is stored
        /// separately from `PositionEncodingConfig::max_seq_len`; the `alibi`
        /// constructor writes the same value to both.
        max_seq_len: usize,
    },
}
impl PositionEncodingConfig {
pub fn sinusoidal(d_model: usize, max_seq_len: usize) -> Self {
Self {
d_model,
max_seq_len,
encoding_type: PositionEncodingType::Sinusoidal { base: 10000.0 },
dropout: 0.0,
}
}
pub fn learned(d_model: usize, max_seq_len: usize) -> Self {
Self {
d_model,
max_seq_len,
encoding_type: PositionEncodingType::Learned,
dropout: 0.0,
}
}
pub fn relative(d_model: usize, num_buckets: usize, max_distance: usize) -> Self {
Self {
d_model,
max_seq_len: 0, encoding_type: PositionEncodingType::Relative {
num_buckets,
max_distance,
},
dropout: 0.0,
}
}
pub fn rotary(d_model: usize, max_seq_len: usize) -> Self {
Self {
d_model,
max_seq_len,
encoding_type: PositionEncodingType::Rotary {
base: 10000.0,
scaling_factor: 1.0,
},
dropout: 0.0,
}
}
pub fn rotary_scaled(
d_model: usize,
max_seq_len: usize,
base: f64,
scaling_factor: f64,
) -> Self {
Self {
d_model,
max_seq_len,
encoding_type: PositionEncodingType::Rotary {
base,
scaling_factor,
},
dropout: 0.0,
}
}
pub fn alibi(d_model: usize, n_heads: usize, max_seq_len: usize) -> Self {
Self {
d_model,
max_seq_len,
encoding_type: PositionEncodingType::Alibi {
n_heads,
max_seq_len,
},
dropout: 0.0,
}
}
pub fn with_dropout(mut self, dropout: f64) -> Self {
self.dropout = dropout;
self
}
pub fn validate(&self) -> Result<()> {
if self.d_model == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "d_model must be positive".to_string(),
});
}
if !(0.0..=1.0).contains(&self.dropout) {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: format!("dropout must be in [0,1], got {}", self.dropout),
});
}
match &self.encoding_type {
PositionEncodingType::Sinusoidal { base } => {
if *base <= 0.0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "base must be positive".to_string(),
});
}
}
PositionEncodingType::Relative {
num_buckets,
max_distance,
} => {
if *num_buckets == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "num_buckets must be positive".to_string(),
});
}
if *max_distance == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "max_distance must be positive".to_string(),
});
}
}
PositionEncodingType::Learned => {
if self.max_seq_len == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "max_seq_len must be positive for learned encoding".to_string(),
});
}
}
PositionEncodingType::Rotary {
base,
scaling_factor,
} => {
if *base <= 0.0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "RoPE base must be positive".to_string(),
});
}
if *scaling_factor <= 0.0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "RoPE scaling_factor must be positive".to_string(),
});
}
if self.max_seq_len == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "max_seq_len must be positive for RoPE".to_string(),
});
}
if !self.d_model.is_multiple_of(2) {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "d_model must be even for RoPE".to_string(),
});
}
}
PositionEncodingType::Alibi {
n_heads,
max_seq_len,
} => {
if *n_heads == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "n_heads must be positive for ALiBi".to_string(),
});
}
if *max_seq_len == 0 {
return Err(TrustformerError::InvalidDimension {
expected: 1,
got: 0,
context: "max_seq_len must be positive for ALiBi".to_string(),
});
}
}
}
Ok(())
}
}
/// Graph builder for the sinusoidal position-encoding variant.
#[derive(Debug, Clone)]
pub struct SinusoidalPositionEncoding {
    /// Configuration that passed validation in [`SinusoidalPositionEncoding::new`].
    /// The field is public, so later mutation is not re-validated.
    pub config: PositionEncodingConfig,
}

impl SinusoidalPositionEncoding {
    /// Validates `config` and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `PositionEncodingConfig::validate`, and
    /// returns `TrustformerError::InvalidDimension` when the encoding type is
    /// not `Sinusoidal`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        if !matches!(
            &config.encoding_type,
            PositionEncodingType::Sinusoidal { .. }
        ) {
            return Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: String::from("Expected Sinusoidal encoding type"),
            });
        }
        Ok(Self { config })
    }

    /// Appends the encoding sub-graph to `graph` and returns the index of its
    /// final tensor in a one-element vector.
    ///
    /// Registers the tensors `sinusoidal_pe` and `x_with_pe`, then an `add`
    /// node over tensor index `0` and `sinusoidal_pe`. When `dropout > 0.0` a
    /// `dropout_<rate>` node and the tensor `pe_dropout_output` follow.
    ///
    /// NOTE(review): index `0` is presumably the input tensor the caller
    /// registered first (the tests in this file add `x` before calling) —
    /// confirm against callers.
    ///
    /// # Errors
    ///
    /// Propagates failures from `EinsumGraph::add_node`.
    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let encoding = graph.add_tensor("sinusoidal_pe");
        let summed = graph.add_tensor("x_with_pe");
        graph.add_node(EinsumNode::elem_binary("add", 0, encoding, summed))?;

        let rate = self.config.dropout;
        let mut result = summed;
        if rate > 0.0 {
            let dropped = graph.add_tensor("pe_dropout_output");
            let dropout_op = EinsumNode::elem_unary(format!("dropout_{}", rate), summed, dropped);
            graph.add_node(dropout_op)?;
            result = dropped;
        }
        Ok(vec![result])
    }

    /// Returns the configured `base`, or `10000.0` if the (public) config was
    /// changed to a non-sinusoidal variant after construction.
    pub fn base(&self) -> f64 {
        if let PositionEncodingType::Sinusoidal { base } = self.config.encoding_type {
            base
        } else {
            10000.0
        }
    }
}
/// Graph builder for the learned position-encoding variant.
#[derive(Debug, Clone)]
pub struct LearnedPositionEncoding {
    /// Configuration that passed validation in [`LearnedPositionEncoding::new`].
    pub config: PositionEncodingConfig,
}

impl LearnedPositionEncoding {
    /// Validates `config` and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `PositionEncodingConfig::validate`, and
    /// returns `TrustformerError::InvalidDimension` when the encoding type is
    /// not `Learned`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        // `Learned` is a unit variant, so equality is the same as a variant check.
        if config.encoding_type != PositionEncodingType::Learned {
            return Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: String::from("Expected Learned encoding type"),
            });
        }
        Ok(Self { config })
    }

    /// Appends the encoding sub-graph to `graph` and returns the index of its
    /// final tensor in a one-element vector.
    ///
    /// Registers `pe_lookup` with a `gather_pos_emb` node reading tensor index
    /// `1`, then `x_with_learned_pe` with an `add` node over tensor index `0`
    /// and `pe_lookup`. When `dropout > 0.0` a `dropout_<rate>` node and the
    /// tensor `learned_pe_dropout_output` follow.
    ///
    /// NOTE(review): indices `0` and `1` are presumably the input and the
    /// embedding table registered by the caller in that order (the tests in
    /// this file add `x` then `position_embeddings`) — confirm against callers.
    ///
    /// # Errors
    ///
    /// Propagates failures from `EinsumGraph::add_node`.
    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let gathered = graph.add_tensor("pe_lookup");
        graph.add_node(EinsumNode::elem_unary("gather_pos_emb", 1, gathered))?;

        let summed = graph.add_tensor("x_with_learned_pe");
        graph.add_node(EinsumNode::elem_binary("add", 0, gathered, summed))?;

        if self.config.dropout > 0.0 {
            let dropped = graph.add_tensor("learned_pe_dropout_output");
            let op_name = format!("dropout_{}", self.config.dropout);
            graph.add_node(EinsumNode::elem_unary(op_name, summed, dropped))?;
            return Ok(vec![dropped]);
        }
        Ok(vec![summed])
    }

    /// Returns `config.max_seq_len`.
    pub fn max_seq_len(&self) -> usize {
        self.config.max_seq_len
    }
}
/// Graph builder for the relative position-bias variant.
#[derive(Debug, Clone)]
pub struct RelativePositionEncoding {
    /// Configuration that passed validation in [`RelativePositionEncoding::new`].
    pub config: PositionEncodingConfig,
}

impl RelativePositionEncoding {
    /// Validates `config` and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `PositionEncodingConfig::validate`, and
    /// returns `TrustformerError::InvalidDimension` when the encoding type is
    /// not `Relative`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        let is_relative = matches!(
            &config.encoding_type,
            PositionEncodingType::Relative { .. }
        );
        if is_relative {
            Ok(Self { config })
        } else {
            Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: String::from("Expected Relative encoding type"),
            })
        }
    }

    /// Appends the bias sub-graph to `graph` and returns the index of its
    /// final tensor in a one-element vector.
    ///
    /// Registers `rel_pos_bias_lookup` with a `gather_rel_bias` node reading
    /// tensor index `1`, then `scores_with_rel_bias` with an `add` node over
    /// tensor index `0` and the lookup result.
    ///
    /// NOTE(review): indices `0` and `1` are presumably the attention scores
    /// and the bias table registered by the caller (the tests in this file add
    /// `attention_scores`, `relative_position_bias`,
    /// `relative_position_indices` in that order) — confirm against callers.
    ///
    /// # Errors
    ///
    /// Propagates failures from `EinsumGraph::add_node`.
    pub fn build_bias_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let gathered = graph.add_tensor("rel_pos_bias_lookup");
        graph.add_node(EinsumNode::elem_unary("gather_rel_bias", 1, gathered))?;

        let biased = graph.add_tensor("scores_with_rel_bias");
        graph.add_node(EinsumNode::elem_binary("add", 0, gathered, biased))?;

        Ok(vec![biased])
    }

    /// Returns the configured `num_buckets`, or `0` if the (public) config no
    /// longer holds the `Relative` variant.
    pub fn num_buckets(&self) -> usize {
        if let PositionEncodingType::Relative { num_buckets, .. } = self.config.encoding_type {
            num_buckets
        } else {
            0
        }
    }

    /// Returns the configured `max_distance`, or `0` if the (public) config no
    /// longer holds the `Relative` variant.
    pub fn max_distance(&self) -> usize {
        if let PositionEncodingType::Relative { max_distance, .. } = self.config.encoding_type {
            max_distance
        } else {
            0
        }
    }
}
/// Graph builder for the rotary ("RoPE") position-encoding variant.
#[derive(Debug, Clone)]
pub struct RotaryPositionEncoding {
    /// Configuration that passed validation in [`RotaryPositionEncoding::new`].
    pub config: PositionEncodingConfig,
}

impl RotaryPositionEncoding {
    /// Validates `config` and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `PositionEncodingConfig::validate`, and
    /// returns `TrustformerError::InvalidDimension` when the encoding type is
    /// not `Rotary`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        if !matches!(&config.encoding_type, PositionEncodingType::Rotary { .. }) {
            return Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: String::from("Expected Rotary encoding type"),
            });
        }
        Ok(Self { config })
    }

    /// Appends the rotation sub-graph to `graph` and returns the index of the
    /// `rope_output` tensor in a one-element vector.
    ///
    /// The nodes emitted, in order, are:
    /// `split_even_odd(0) -> rope_x_even`,
    /// `mul(x_even, 1)`, `mul(x_odd, 2)`, `sub` of those two -> `rope_rotated_0`,
    /// `mul(x_even, 2)`, `mul(x_odd, 1)`, `add` of those two -> `rope_rotated_1`,
    /// and finally `concat(rotated_0, rotated_1) -> rope_output`.
    ///
    /// NOTE(review): indices `0`, `1`, `2` are presumably the input and the
    /// cached cos/sin tables registered by the caller (the tests in this file
    /// add `x`, `cos_cached`, `sin_cached` in that order) — confirm.
    ///
    /// NOTE(review): `rope_x_odd` is registered and consumed, but no node
    /// added here names it as its last argument; if the last argument of
    /// `elem_unary`/`elem_binary` is the output, nothing in this method
    /// produces it — confirm how it is populated.
    ///
    /// # Errors
    ///
    /// Propagates failures from `EinsumGraph::add_node`.
    pub fn build_encoding_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let even_half = graph.add_tensor("rope_x_even");
        let odd_half = graph.add_tensor("rope_x_odd");
        graph.add_node(EinsumNode::elem_unary("split_even_odd", 0, even_half))?;

        // Registers the result tensor first and then the node, matching the
        // tensor/node numbering of a hand-unrolled sequence.
        let mut emit = |op: &str, lhs: usize, rhs: usize, result_name: &str| -> Result<usize> {
            let result = graph.add_tensor(result_name);
            graph.add_node(EinsumNode::elem_binary(op, lhs, rhs, result))?;
            Ok(result)
        };

        let even_cos = emit("mul", even_half, 1, "rope_even_cos")?;
        let odd_sin = emit("mul", odd_half, 2, "rope_odd_sin")?;
        let first_rotated = emit("sub", even_cos, odd_sin, "rope_rotated_0")?;

        let even_sin = emit("mul", even_half, 2, "rope_even_sin")?;
        let odd_cos = emit("mul", odd_half, 1, "rope_odd_cos")?;
        let second_rotated = emit("add", even_sin, odd_cos, "rope_rotated_1")?;

        let joined = emit("concat", first_rotated, second_rotated, "rope_output")?;
        Ok(vec![joined])
    }

    /// Returns the configured `base`, or `10000.0` if the (public) config no
    /// longer holds the `Rotary` variant.
    pub fn base(&self) -> f64 {
        if let PositionEncodingType::Rotary { base, .. } = self.config.encoding_type {
            base
        } else {
            10000.0
        }
    }

    /// Returns the configured `scaling_factor`, or `1.0` if the (public)
    /// config no longer holds the `Rotary` variant.
    pub fn scaling_factor(&self) -> f64 {
        if let PositionEncodingType::Rotary { scaling_factor, .. } = self.config.encoding_type {
            scaling_factor
        } else {
            1.0
        }
    }
}
/// Graph builder for the ALiBi attention-bias variant.
#[derive(Debug, Clone)]
pub struct AlibiPositionEncoding {
    /// Configuration that passed validation in [`AlibiPositionEncoding::new`].
    pub config: PositionEncodingConfig,
}

impl AlibiPositionEncoding {
    /// Validates `config` and wraps it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `PositionEncodingConfig::validate`, and
    /// returns `TrustformerError::InvalidDimension` when the encoding type is
    /// not `Alibi`.
    pub fn new(config: PositionEncodingConfig) -> Result<Self> {
        config.validate()?;
        let is_alibi = matches!(&config.encoding_type, PositionEncodingType::Alibi { .. });
        if is_alibi {
            Ok(Self { config })
        } else {
            Err(TrustformerError::InvalidDimension {
                expected: 0,
                got: 1,
                context: String::from("Expected Alibi encoding type"),
            })
        }
    }

    /// Appends the bias sub-graph to `graph` and returns the index of the
    /// `scores_with_alibi` tensor in a one-element vector.
    ///
    /// The nodes emitted, in order, are `expand_dims(1)`, `mul(<expanded>, 2)`,
    /// `neg(<product>)` and `add(0, <negated>)`.
    ///
    /// NOTE(review): indices `0`, `1`, `2` are presumably the attention
    /// scores, the slopes and the distance matrix registered by the caller
    /// (the tests in this file add `attention_scores`, `alibi_slopes`,
    /// `distance_matrix` in that order) — confirm against callers.
    ///
    /// # Errors
    ///
    /// Propagates failures from `EinsumGraph::add_node`.
    pub fn build_bias_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let expanded = graph.add_tensor("alibi_slopes_expanded");
        graph.add_node(EinsumNode::elem_unary("expand_dims", 1, expanded))?;

        let scaled = graph.add_tensor("alibi_bias");
        graph.add_node(EinsumNode::elem_binary("mul", expanded, 2, scaled))?;

        let negated = graph.add_tensor("alibi_neg_bias");
        graph.add_node(EinsumNode::elem_unary("neg", scaled, negated))?;

        let biased_scores = graph.add_tensor("scores_with_alibi");
        graph.add_node(EinsumNode::elem_binary("add", 0, negated, biased_scores))?;

        Ok(vec![biased_scores])
    }

    /// Returns the configured `n_heads`, or `0` if the (public) config no
    /// longer holds the `Alibi` variant.
    pub fn n_heads(&self) -> usize {
        if let PositionEncodingType::Alibi { n_heads, .. } = self.config.encoding_type {
            n_heads
        } else {
            0
        }
    }

    /// Returns one slope per head: `2^(-8 * i / n)` for `i` in `1..=n`, where
    /// `n` is [`Self::n_heads`]. The result is empty when `n` is `0`.
    pub fn compute_slopes(&self) -> Vec<f64> {
        let heads = self.n_heads();
        // Same evaluation order as `-8.0 * i / n`: multiply first, then divide.
        let slope_for = |head: usize| 2_f64.powf(-8.0 * (head as f64) / (heads as f64));
        (1..=heads).map(slope_for).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh graph with `names` registered as tensors, in order.
    fn graph_with_tensors(names: &[&str]) -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        for name in names {
            graph.add_tensor(*name);
        }
        graph
    }

    // --- configuration constructors -------------------------------------

    #[test]
    fn test_sinusoidal_config_creation() {
        let cfg = PositionEncodingConfig::sinusoidal(512, 2048);
        assert_eq!(cfg.d_model, 512);
        assert_eq!(cfg.max_seq_len, 2048);
        assert_eq!(
            cfg.encoding_type,
            PositionEncodingType::Sinusoidal { base: 10000.0 }
        );
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_learned_config_creation() {
        let cfg = PositionEncodingConfig::learned(512, 2048);
        assert_eq!(cfg.d_model, 512);
        assert_eq!(cfg.max_seq_len, 2048);
        assert_eq!(cfg.encoding_type, PositionEncodingType::Learned);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_relative_config_creation() {
        let cfg = PositionEncodingConfig::relative(512, 32, 128);
        assert_eq!(cfg.d_model, 512);
        assert_eq!(
            cfg.encoding_type,
            PositionEncodingType::Relative {
                num_buckets: 32,
                max_distance: 128,
            }
        );
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_config_with_dropout() {
        let cfg = PositionEncodingConfig::sinusoidal(512, 2048).with_dropout(0.1);
        let deviation = (cfg.dropout - 0.1).abs();
        assert!(deviation < 1e-10);
        assert!(cfg.validate().is_ok());
    }

    // --- encoder construction -------------------------------------------

    #[test]
    fn test_sinusoidal_encoding_creation() {
        let enc = SinusoidalPositionEncoding::new(PositionEncodingConfig::sinusoidal(512, 2048))
            .expect("unwrap");
        assert_eq!(enc.config.d_model, 512);
        assert_eq!(enc.base(), 10000.0);
    }

    #[test]
    fn test_learned_encoding_creation() {
        let enc = LearnedPositionEncoding::new(PositionEncodingConfig::learned(512, 2048))
            .expect("unwrap");
        assert_eq!(enc.max_seq_len(), 2048);
    }

    #[test]
    fn test_relative_encoding_creation() {
        let enc = RelativePositionEncoding::new(PositionEncodingConfig::relative(512, 32, 128))
            .expect("unwrap");
        assert_eq!(enc.num_buckets(), 32);
        assert_eq!(enc.max_distance(), 128);
    }

    // --- graph building -------------------------------------------------

    #[test]
    fn test_sinusoidal_graph_building() {
        let enc = SinusoidalPositionEncoding::new(PositionEncodingConfig::sinusoidal(512, 2048))
            .expect("unwrap");
        let mut graph = graph_with_tensors(&["x"]);
        let outputs = enc.build_encoding_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_learned_graph_building() {
        let enc = LearnedPositionEncoding::new(PositionEncodingConfig::learned(512, 2048))
            .expect("unwrap");
        let mut graph = graph_with_tensors(&["x", "position_embeddings"]);
        let outputs = enc.build_encoding_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_relative_bias_graph_building() {
        let enc = RelativePositionEncoding::new(PositionEncodingConfig::relative(512, 32, 128))
            .expect("unwrap");
        let mut graph = graph_with_tensors(&[
            "attention_scores",
            "relative_position_bias",
            "relative_position_indices",
        ]);
        let outputs = enc.build_bias_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    // --- validation failures --------------------------------------------

    #[test]
    fn test_invalid_config_zero_dimension() {
        let zero_model = PositionEncodingConfig::sinusoidal(0, 2048);
        assert!(zero_model.validate().is_err());

        let zero_length = PositionEncodingConfig::learned(512, 0);
        assert!(zero_length.validate().is_err());
    }

    #[test]
    fn test_invalid_dropout() {
        let cfg = PositionEncodingConfig::sinusoidal(512, 2048).with_dropout(1.5);
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_wrong_encoding_type() {
        let learned_cfg = PositionEncodingConfig::learned(512, 2048);
        assert!(SinusoidalPositionEncoding::new(learned_cfg).is_err());
    }

    // --- rotary ---------------------------------------------------------

    #[test]
    fn test_rotary_config_creation() {
        let cfg = PositionEncodingConfig::rotary(512, 2048);
        assert_eq!(cfg.d_model, 512);
        assert_eq!(cfg.max_seq_len, 2048);
        assert_eq!(
            cfg.encoding_type,
            PositionEncodingType::Rotary {
                base: 10000.0,
                scaling_factor: 1.0,
            }
        );
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_rotary_scaled_config() {
        let cfg = PositionEncodingConfig::rotary_scaled(512, 4096, 10000.0, 2.0);
        assert_eq!(cfg.max_seq_len, 4096);
        let PositionEncodingType::Rotary {
            base,
            scaling_factor,
        } = cfg.encoding_type
        else {
            panic!("Expected Rotary encoding type");
        };
        assert!((base - 10000.0).abs() < 1e-10);
        assert!((scaling_factor - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_rotary_encoding_creation() {
        let enc =
            RotaryPositionEncoding::new(PositionEncodingConfig::rotary(512, 2048)).expect("unwrap");
        assert_eq!(enc.config.d_model, 512);
        assert_eq!(enc.base(), 10000.0);
        assert_eq!(enc.scaling_factor(), 1.0);
    }

    #[test]
    fn test_rotary_graph_building() {
        let enc =
            RotaryPositionEncoding::new(PositionEncodingConfig::rotary(512, 2048)).expect("unwrap");
        let mut graph = graph_with_tensors(&["x", "cos_cached", "sin_cached"]);
        let outputs = enc.build_encoding_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_rotary_requires_even_d_model() {
        // 513 is odd, so validation must fail for the rotary variant.
        let cfg = PositionEncodingConfig::rotary(513, 2048);
        assert!(cfg.validate().is_err());
    }

    // --- ALiBi ----------------------------------------------------------

    #[test]
    fn test_alibi_config_creation() {
        let cfg = PositionEncodingConfig::alibi(512, 8, 2048);
        assert_eq!(cfg.d_model, 512);
        assert_eq!(cfg.max_seq_len, 2048);
        assert_eq!(
            cfg.encoding_type,
            PositionEncodingType::Alibi {
                n_heads: 8,
                max_seq_len: 2048,
            }
        );
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_alibi_encoding_creation() {
        let enc =
            AlibiPositionEncoding::new(PositionEncodingConfig::alibi(512, 8, 2048)).expect("unwrap");
        assert_eq!(enc.n_heads(), 8);
    }

    #[test]
    fn test_alibi_slopes_computation() {
        let enc =
            AlibiPositionEncoding::new(PositionEncodingConfig::alibi(512, 8, 2048)).expect("unwrap");
        let slopes = enc.compute_slopes();
        assert_eq!(slopes.len(), 8);
        // Slopes must be strictly decreasing from one head to the next.
        for pair in slopes.windows(2) {
            assert!(pair[1] < pair[0]);
        }
        assert!(slopes[0] < 1.0);
        assert!(slopes[0] > 0.0);
    }

    #[test]
    fn test_alibi_graph_building() {
        let enc =
            AlibiPositionEncoding::new(PositionEncodingConfig::alibi(512, 8, 2048)).expect("unwrap");
        let mut graph = graph_with_tensors(&["attention_scores", "alibi_slopes", "distance_matrix"]);
        let outputs = enc.build_bias_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_alibi_invalid_zero_heads() {
        let cfg = PositionEncodingConfig::alibi(512, 0, 2048);
        assert!(cfg.validate().is_err());
    }
}