use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
/// Lookup table mapping integer indices to dense vectors of length `embedding_dim`.
///
/// The table is registered as the `"weight"` parameter with shape
/// `[num_embeddings, embedding_dim]`; indices are read from the elements of the
/// tensor passed to `forward`.
pub struct Embedding {
    base: ModuleBase,
    /// Number of rows in the table.
    num_embeddings: usize,
    /// Length of each embedding vector.
    embedding_dim: usize,
    /// Index whose lookup produces an all-zero vector, if any.
    padding_idx: Option<usize>,
    /// When set, looked-up vectors whose `norm_type`-norm exceeds this are rescaled to it.
    max_norm: Option<f32>,
    /// Exponent `p` of the p-norm used together with `max_norm`.
    norm_type: f32,
    /// Configuration flag; stored but not consulted by `forward`.
    scale_grad_by_freq: bool,
    /// Configuration flag; stored but not consulted by `forward`.
    sparse: bool,
}
impl Embedding {
    /// Creates a `[num_embeddings, embedding_dim]` table initialized with
    /// `crate::init::xavier_uniform`, with no padding index, no max-norm and
    /// `norm_type = 2.0`.
    ///
    /// # Panics
    /// Panics if weight initialization fails.
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        let mut base = ModuleBase::new();
        let initial_weight = crate::init::xavier_uniform(&[num_embeddings, embedding_dim])
            .expect("Failed to initialize embedding weight");
        base.register_parameter(String::from("weight"), Parameter::new(initial_weight));
        Self {
            base,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            max_norm: None,
            norm_type: 2.0,
            scale_grad_by_freq: false,
            sparse: false,
        }
    }

    /// Like [`Embedding::new`], but lookups of `padding_idx` return all-zero vectors.
    pub fn with_padding_idx(
        num_embeddings: usize,
        embedding_dim: usize,
        padding_idx: usize,
    ) -> Self {
        Self {
            padding_idx: Some(padding_idx),
            ..Self::new(num_embeddings, embedding_dim)
        }
    }

    /// Like [`Embedding::new`], with every optional setting supplied explicitly.
    pub fn with_config(
        num_embeddings: usize,
        embedding_dim: usize,
        padding_idx: Option<usize>,
        max_norm: Option<f32>,
        norm_type: f32,
        scale_grad_by_freq: bool,
        sparse: bool,
    ) -> Self {
        Self {
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            ..Self::new(num_embeddings, embedding_dim)
        }
    }
}
impl Module for Embedding {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let weight = self.base.parameters["weight"].tensor().read().clone();
let weight_data = weight.to_vec()?;
let input_data = input.to_vec()?;
let binding = input.shape();
let input_shape = binding.dims();
let mut output_shape = input_shape.to_vec();
output_shape.push(self.embedding_dim);
let num_indices: usize = input_shape.iter().product();
let total_output_size = num_indices * self.embedding_dim;
let mut output_data = Vec::with_capacity(total_output_size);
for &idx_f32 in input_data.iter() {
let idx = idx_f32 as usize;
if let Some(padding_idx) = self.padding_idx {
if idx == padding_idx {
output_data.extend(vec![0.0; self.embedding_dim]);
continue;
}
}
if idx >= self.num_embeddings {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Index {} out of bounds for embedding with {} embeddings",
idx, self.num_embeddings
)));
}
let start_idx = idx * self.embedding_dim;
let end_idx = start_idx + self.embedding_dim;
let mut embedding_vec = weight_data[start_idx..end_idx].to_vec();
if let Some(max_norm) = self.max_norm {
let norm = if self.norm_type == 2.0 {
embedding_vec.iter().map(|x| x * x).sum::<f32>().sqrt()
} else {
embedding_vec
.iter()
.map(|x| x.abs().powf(self.norm_type))
.sum::<f32>()
.powf(1.0 / self.norm_type)
};
if norm > max_norm {
let scale = max_norm / norm;
for val in &mut embedding_vec {
*val *= scale;
}
}
}
output_data.extend(embedding_vec);
}
Tensor::from_vec(output_data, &output_shape)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for Embedding {
    /// Formats the full configuration of the layer. The weight tensor is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Embedding")
            .field("num_embeddings", &self.num_embeddings)
            .field("embedding_dim", &self.embedding_dim)
            .field("padding_idx", &self.padding_idx)
            .field("max_norm", &self.max_norm)
            .field("norm_type", &self.norm_type)
            .field("scale_grad_by_freq", &self.scale_grad_by_freq)
            .field("sparse", &self.sparse)
            .finish()
    }
}
/// Names a positional-encoding scheme.
///
/// The variants presumably correspond to the encoding types defined in this module.
/// `PartialEq` allows configurations to be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub enum PositionalEncodingType {
    /// Fixed sine/cosine encodings.
    Sinusoidal,
    /// A trainable per-position table.
    Learnable,
    /// Encodings indexed by relative distance between positions.
    Relative,
    /// Rotary position embeddings with the given frequency `base`.
    Rotary { base: f32 },
    /// Attention with Linear Biases.
    Alibi,
}
/// Adds a fixed sinusoidal position table to its input, optionally followed by dropout.
pub struct SinusoidalPositionalEncoding {
    base: ModuleBase,
    /// Width of each encoding vector.
    d_model: usize,
    /// Number of precomputed positions.
    max_len: usize,
    /// Dropout probability applied to the output in training mode; 0 disables it.
    dropout: f32,
}
impl SinusoidalPositionalEncoding {
    /// Precomputes a `[max_len, d_model]` sinusoidal table and registers it as the
    /// `"pe"` parameter.
    pub fn new(d_model: usize, max_len: usize, dropout: f32) -> Self {
        let table = create_sinusoidal_encoding(max_len, d_model);
        let mut base = ModuleBase::new();
        base.register_parameter(String::from("pe"), Parameter::new(table));
        Self {
            base,
            d_model,
            max_len,
            dropout,
        }
    }
}
impl Module for SinusoidalPositionalEncoding {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let pe = self.base.parameters["pe"].tensor().read().clone();
let seq_len = input.shape().dims()[1];
let pe_slice = pe.narrow(0, 0, seq_len.min(self.max_len))?;
let output = input.add_op(&pe_slice.unsqueeze(0)?)?;
if self.dropout > 0.0 && self.training() {
crate::functional::dropout(&output, self.dropout, self.training())
} else {
Ok(output)
}
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
/// Adds a trainable per-position table to its input, optionally followed by dropout.
pub struct LearnablePositionalEncoding {
    base: ModuleBase,
    /// Width of each encoding vector.
    d_model: usize,
    /// Number of positions in the table.
    max_len: usize,
    /// Dropout probability applied to the output in training mode; 0 disables it.
    dropout: f32,
}
impl LearnablePositionalEncoding {
    /// Creates a `[max_len, d_model]` table initialized with
    /// `crate::init::xavier_uniform` and registers it as the `"pe"` parameter.
    ///
    /// # Panics
    /// Panics if initialization fails.
    pub fn new(d_model: usize, max_len: usize, dropout: f32) -> Self {
        let table = crate::init::xavier_uniform(&[max_len, d_model])
            .expect("Failed to initialize positional encoding");
        let mut base = ModuleBase::new();
        base.register_parameter(String::from("pe"), Parameter::new(table));
        Self {
            base,
            d_model,
            max_len,
            dropout,
        }
    }
}
impl Module for LearnablePositionalEncoding {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let pe = self.base.parameters["pe"].tensor().read().clone();
let seq_len = input.shape().dims()[1];
let pe_slice = pe.narrow(0, 0, seq_len.min(self.max_len))?;
let output = input.add_op(&pe_slice.unsqueeze(0)?)?;
if self.dropout > 0.0 && self.training() {
crate::functional::dropout(&output, self.dropout, self.training())
} else {
Ok(output)
}
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
pub struct RotaryPositionalEmbedding {
d_model: usize,
base: f32,
max_seq_len: usize,
}
impl RotaryPositionalEmbedding {
pub fn new(d_model: usize, base: f32, max_seq_len: usize) -> Self {
Self {
d_model,
base,
max_seq_len,
}
}
pub fn apply_rope(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
let seq_len = q.shape().dims()[2];
let head_dim = q.shape().dims()[3];
let freqs = self.create_frequencies(seq_len, head_dim)?;
let q_rot = self.rotate_tensor(q, &freqs)?;
let k_rot = self.rotate_tensor(k, &freqs)?;
Ok((q_rot, k_rot))
}
fn create_frequencies(&self, seq_len: usize, head_dim: usize) -> Result<Tensor> {
let mut freqs = Vec::new();
for pos in 0..seq_len {
for i in (0..head_dim).step_by(2) {
let freq = 1.0 / self.base.powf(i as f32 / head_dim as f32);
let angle = pos as f32 * freq;
freqs.push(angle.cos());
freqs.push(-angle.sin());
freqs.push(angle.sin());
freqs.push(angle.cos());
}
}
Tensor::from_vec(freqs, &[seq_len, head_dim, 2, 2])
}
fn rotate_tensor(&self, x: &Tensor, _freqs: &Tensor) -> Result<Tensor> {
Ok(x.clone())
}
}
/// ALiBi (Attention with Linear Biases): a fixed per-head bias proportional to the
/// distance between two positions.
pub struct AlibiPositionalBias {
    /// Number of attention heads; one slope is produced per head.
    num_heads: usize,
    /// Stored maximum sequence length; `create_bias` does not enforce it.
    max_seq_len: usize,
}
impl AlibiPositionalBias {
    /// Creates an ALiBi bias generator for `num_heads` heads.
    pub fn new(num_heads: usize, max_seq_len: usize) -> Self {
        Self {
            num_heads,
            max_seq_len,
        }
    }

    /// Builds the bias tensor of shape `[num_heads, seq_len, seq_len]`.
    ///
    /// Each entry is `bias[h][i][j] = -slope[h] * |i - j|`.
    pub fn create_bias(&self, seq_len: usize) -> Result<Tensor> {
        let slopes = self.get_slopes();
        let mut bias_data = Vec::with_capacity(self.num_heads * seq_len * seq_len);
        for &slope in &slopes {
            for i in 0..seq_len {
                bias_data.extend((0..seq_len).map(|j| -slope * i.abs_diff(j) as f32));
            }
        }
        Tensor::from_vec(bias_data, &[self.num_heads, seq_len, seq_len])
    }

    /// Returns one slope per head, following the ALiBi reference scheme.
    ///
    /// For a power-of-two head count `n`, the slopes are `r^1, ..., r^n` with
    /// `r = 2^(-8/n)`. For any other `n`:
    /// - take the slopes for the largest power of two `p < n`;
    /// - append the odd-indexed powers `r'^1, r'^3, ...` of `r' = 2^(-8/(2p))` until
    ///   there are `n` slopes in total.
    fn get_slopes(&self) -> Vec<f32> {
        /// Geometric slopes `r^1..=r^n` with `r = 2^(-8/n)`.
        fn power_of_two_slopes(n: usize) -> Vec<f32> {
            let ratio = 2.0_f32.powf(-8.0 / n as f32);
            (1..=n).map(|i| ratio.powf(i as f32)).collect()
        }
        let n = self.num_heads;
        if n == 0 {
            return Vec::new();
        }
        if n.is_power_of_two() {
            return power_of_two_slopes(n);
        }
        // Largest power of two strictly below a non-power-of-two `n`.
        let closest = n.next_power_of_two() / 2;
        let mut slopes = power_of_two_slopes(closest);
        slopes.extend(
            power_of_two_slopes(2 * closest)
                .into_iter()
                .step_by(2)
                .take(n - closest),
        );
        slopes
    }
}
pub struct RelativePositionalEncoding {
base: ModuleBase,
d_model: usize,
max_relative_distance: usize,
}
impl RelativePositionalEncoding {
pub fn new(d_model: usize, max_relative_distance: usize) -> Self {
let mut base = ModuleBase::new();
let num_positions = 2 * max_relative_distance + 1;
let relative_pe = crate::init::xavier_uniform(&[num_positions, d_model])
.expect("Failed to initialize relative positional encoding");
base.register_parameter("relative_pe".to_string(), Parameter::new(relative_pe));
Self {
base,
d_model,
max_relative_distance,
}
}
pub fn get_relative_embeddings(&self, seq_len: usize) -> Result<Tensor> {
let relative_pe = self.base.parameters["relative_pe"].tensor().read().clone();
let mut relative_data = Vec::new();
for i in 0..seq_len {
for j in 0..seq_len {
let relative_distance = (i as i32 - j as i32).clamp(
-(self.max_relative_distance as i32),
self.max_relative_distance as i32,
);
let idx = (relative_distance + self.max_relative_distance as i32) as usize;
let embedding = relative_pe.narrow(0, idx as i64, 1)?.squeeze(0)?;
let embedding_data = embedding.to_vec()?;
relative_data.extend(embedding_data);
}
}
Tensor::from_vec(relative_data, &[seq_len, seq_len, self.d_model])
}
}
impl Module for RelativePositionalEncoding {
    /// Adds relative-position information from the `"relative_pe"` table to `input`.
    ///
    /// For query position `i` and key position `j`, the table row for relative distance
    /// `i - j`, clipped to `±max_relative_distance`, is used. This matches the
    /// convention of `get_relative_embeddings`.
    ///
    /// * Rank 3, `[batch, seq_len, d_model]`: position `i` receives the mean over all
    ///   `j` of the rows for `(i, j)`.
    /// * Rank 4, `[batch, heads, seq_len, seq_len]`: element `[.., i, j]` receives the
    ///   mean of the row for `(i, j)` as a scalar bias.
    /// * Any other rank: `input` is returned unchanged.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if:
    /// - the `"relative_pe"` parameter is missing or has an unexpected size;
    /// - a rank-3 feature size differs from `d_model`;
    /// - a rank-4 input is not square in its last two dimensions.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let input_shape_binding = input.shape();
        let input_shape = input_shape_binding.dims();
        let rank = input_shape.len();
        if rank != 3 && rank != 4 {
            return Ok(input.clone());
        }
        let d_model = self.d_model;
        let max_dist = self.max_relative_distance;
        // The table is registered under "relative_pe" by `new`.
        let table = self
            .base
            .parameters
            .get("relative_pe")
            .ok_or_else(|| {
                TorshError::InvalidArgument(
                    "RelativePositionalEncoding missing relative_pe parameter".to_string(),
                )
            })?
            .tensor()
            .read()
            .to_vec()?;
        let num_positions = 2 * max_dist + 1;
        if table.len() != num_positions * d_model {
            return Err(TorshError::InvalidArgument(format!(
                "relative_pe has {} elements, expected {} x {}",
                table.len(),
                num_positions,
                d_model
            )));
        }
        // Table row for the clipped relative distance `i - j`; always < num_positions.
        let row_of = |i: usize, j: usize| -> usize {
            if i >= j {
                max_dist + (i - j).min(max_dist)
            } else {
                max_dist - (j - i).min(max_dist)
            }
        };
        if rank == 3 {
            let seq_len = input_shape[1];
            let features = input_shape[2];
            if features != d_model {
                return Err(TorshError::InvalidArgument(format!(
                    "RelativePositionalEncoding expects feature size {}, got {}",
                    d_model, features
                )));
            }
            let mut output_data = input.to_vec()?;
            let sample_len = seq_len * d_model;
            if sample_len > 0 {
                // The added term depends only on the query position, so build it once
                // and reuse it for every batch element.
                let mut offsets = vec![0.0f32; sample_len];
                for (i, acc) in offsets.chunks_exact_mut(d_model).enumerate() {
                    for j in 0..seq_len {
                        let row = &table[row_of(i, j) * d_model..][..d_model];
                        for (a, &e) in acc.iter_mut().zip(row) {
                            *a += e / seq_len as f32;
                        }
                    }
                }
                for sample in output_data.chunks_exact_mut(sample_len) {
                    for (x, &o) in sample.iter_mut().zip(&offsets) {
                        *x += o;
                    }
                }
            }
            Tensor::from_vec(output_data, input_shape)
        } else {
            let seq_len = input_shape[2];
            if seq_len != input_shape[3] {
                return Err(TorshError::InvalidArgument(
                    "RelativePositionalEncoding requires square attention matrices".to_string(),
                ));
            }
            let mut output_data = input.to_vec()?;
            let plane = seq_len * seq_len;
            if plane > 0 {
                // NOTE(review): reducing a d_model-wide embedding to one scalar bias is a
                // design choice; the row mean is used here. Confirm the intended semantics.
                let mut bias = Vec::with_capacity(plane);
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        let row = &table[row_of(i, j) * d_model..][..d_model];
                        let mean = if d_model == 0 {
                            0.0
                        } else {
                            row.iter().sum::<f32>() / d_model as f32
                        };
                        bias.push(mean);
                    }
                }
                for matrix in output_data.chunks_exact_mut(plane) {
                    for (x, &b) in matrix.iter_mut().zip(&bias) {
                        *x += b;
                    }
                }
            }
            Tensor::from_vec(output_data, input_shape)
        }
    }
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }
    fn training(&self) -> bool {
        self.base.training()
    }
    fn train(&mut self) {
        self.base.set_training(true);
    }
    fn eval(&mut self) {
        self.base.set_training(false);
    }
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
pub struct PositionInterpolation;
impl PositionInterpolation {
pub fn interpolate_positions(pe: &Tensor, new_max_len: usize) -> Result<Tensor> {
let shape = pe.shape();
let pe_shape = shape.dims();
let old_max_len = pe_shape[0];
let d_model = pe_shape[1];
if new_max_len <= old_max_len {
return pe.narrow(0, 0, new_max_len);
}
let scale_factor = old_max_len as f32 / new_max_len as f32;
let mut interpolated_data = Vec::new();
for new_pos in 0..new_max_len {
let old_pos_f = new_pos as f32 * scale_factor;
let old_pos_low = old_pos_f.floor() as usize;
let old_pos_high = (old_pos_low + 1).min(old_max_len - 1);
let alpha = old_pos_f - old_pos_low as f32;
let pe_low = pe.narrow(0, old_pos_low as i64, 1)?.squeeze(0)?;
let pe_high = pe.narrow(0, old_pos_high as i64, 1)?.squeeze(0)?;
let pe_low_data = pe_low.to_vec()?;
let pe_high_data = pe_high.to_vec()?;
for i in 0..d_model {
let interpolated = pe_low_data[i] * (1.0 - alpha) + pe_high_data[i] * alpha;
interpolated_data.push(interpolated);
}
}
Tensor::from_vec(interpolated_data, &[new_max_len, d_model])
}
pub fn interpolate_rope_frequencies(base: f32, scale_factor: f32, d_model: usize) -> Vec<f32> {
let mut freqs = Vec::new();
for i in (0..d_model).step_by(2) {
let freq = 1.0 / (base * scale_factor).powf(i as f32 / d_model as f32);
freqs.push(freq);
}
freqs
}
}
/// Fixed sinusoidal position embeddings with an optional learned scalar scale.
///
/// The `[max_len, d_model]` table is registered as the `"embeddings"` parameter. When
/// built with [`SinusoidalPositionEmbedding::with_learned_scale`], a `[1]`-shaped
/// `"scale"` parameter multiplies every embedding returned by this type.
pub struct SinusoidalPositionEmbedding {
    base: ModuleBase,
    /// Embedding width; `new` enforces that it is even.
    d_model: usize,
    /// Number of precomputed positions.
    max_len: usize,
    /// Whether the `"scale"` parameter exists and is applied.
    learned_scale: bool,
}
impl SinusoidalPositionEmbedding {
    /// Creates the embedding table for positions `0..max_len`.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `d_model` is odd. Propagates
    /// tensor-creation errors.
    pub fn new(d_model: usize, max_len: usize) -> Result<Self> {
        if d_model % 2 != 0 {
            return Err(TorshError::InvalidArgument(format!(
                "d_model must be even, got {}",
                d_model
            )));
        }
        let mut base = ModuleBase::new();
        let embeddings = Self::create_embeddings(max_len, d_model)?;
        base.register_parameter("embeddings".to_string(), Parameter::new(embeddings));
        Ok(Self {
            base,
            d_model,
            max_len,
            learned_scale: false,
        })
    }

    /// Like [`SinusoidalPositionEmbedding::new`], plus a trainable `"scale"` parameter
    /// initialized to `1.0`.
    pub fn with_learned_scale(d_model: usize, max_len: usize) -> Result<Self> {
        let mut layer = Self::new(d_model, max_len)?;
        layer.learned_scale = true;
        let scale = Tensor::from_vec(vec![1.0], &[1])?;
        layer
            .base
            .register_parameter("scale".to_string(), Parameter::new(scale));
        Ok(layer)
    }

    /// Builds the `[max_len, d_model]` table.
    ///
    /// Column `2k` holds `sin(pos * 10000^(-2k / d_model))` and column `2k + 1` holds
    /// the matching cosine.
    fn create_embeddings(max_len: usize, d_model: usize) -> Result<Tensor> {
        let mut embeddings = vec![0.0f32; max_len * d_model];
        let div_term: Vec<f32> = (0..d_model)
            .step_by(2)
            .map(|i| 1.0 / 10000.0_f32.powf(i as f32 / d_model as f32))
            .collect();
        for pos in 0..max_len {
            let pos_f = pos as f32;
            for (i, &div) in div_term.iter().enumerate() {
                let angle = pos_f * div;
                embeddings[pos * d_model + (i * 2)] = angle.sin();
                if (i * 2 + 1) < d_model {
                    embeddings[pos * d_model + (i * 2 + 1)] = angle.cos();
                }
            }
        }
        Tensor::from_vec(embeddings, &[max_len, d_model])
    }

    /// Looks up the embeddings for the positions stored in `positions`.
    ///
    /// `positions` must be rank 1 (`[seq_len]`) or rank 2 (`[batch, seq_len]`). The
    /// result has shape `positions.shape()` with `d_model` appended, and is multiplied
    /// by the learned scale when enabled. Fractional positions are truncated toward
    /// zero.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if any position is negative, non-finite
    /// or `>= max_len`, or if `positions` is not rank 1 or 2.
    pub fn get_embeddings(&self, positions: &Tensor) -> Result<Tensor> {
        let positions_data = positions.to_vec()?;
        let binding = positions.shape();
        let positions_shape = binding.dims();
        for &pos in &positions_data {
            // `as usize` would saturate negative values and NaN to position 0.
            if !pos.is_finite() || pos < 0.0 {
                return Err(TorshError::InvalidArgument(format!(
                    "Invalid position {}: positions must be finite and non-negative",
                    pos
                )));
            }
            let pos_usize = pos as usize;
            if pos_usize >= self.max_len {
                return Err(TorshError::InvalidArgument(format!(
                    "Position {} exceeds max_len {}",
                    pos_usize, self.max_len
                )));
            }
        }
        if !matches!(positions_shape.len(), 1 | 2) {
            return Err(TorshError::InvalidArgument(format!(
                "Expected 1D or 2D positions tensor, got {}D",
                positions_shape.len()
            )));
        }
        let embeddings_data = self.base.parameters["embeddings"]
            .tensor()
            .read()
            .to_vec()?;
        let mut output = Vec::with_capacity(positions_data.len() * self.d_model);
        for &pos in &positions_data {
            let start = pos as usize * self.d_model;
            output.extend_from_slice(&embeddings_data[start..start + self.d_model]);
        }
        let mut output_shape = positions_shape.to_vec();
        output_shape.push(self.d_model);
        let result = Tensor::from_vec(output, &output_shape)?;
        self.apply_scale(result)
    }

    /// Returns the embeddings for positions `0..seq_len`, shape `[seq_len, d_model]`.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` if `seq_len > max_len`.
    pub fn get_embeddings_for_length(&self, seq_len: usize) -> Result<Tensor> {
        if seq_len > self.max_len {
            return Err(TorshError::InvalidArgument(format!(
                "Sequence length {} exceeds max_len {}",
                seq_len, self.max_len
            )));
        }
        let embeddings = self.base.parameters["embeddings"].tensor().read().clone();
        let result = embeddings.narrow(0, 0, seq_len)?;
        self.apply_scale(result)
    }

    /// Multiplies `embeddings` by the learned `"scale"` parameter when it is enabled.
    fn apply_scale(&self, embeddings: Tensor) -> Result<Tensor> {
        if !self.learned_scale {
            return Ok(embeddings);
        }
        let scale = self.base.parameters["scale"].tensor().read().clone();
        embeddings.mul_op(&scale)
    }

    /// Number of precomputed positions.
    pub fn max_len(&self) -> usize {
        self.max_len
    }

    /// Embedding width.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}
impl Module for SinusoidalPositionEmbedding {
    /// Dispatches on the rank of `input`:
    /// - rank 1 or 2: `input` holds positions, and their embeddings are returned via
    ///   `get_embeddings`;
    /// - rank 3, `[batch, seq_len, d_model]`: returns `input` plus the embeddings for
    ///   positions `0..seq_len`, broadcast over the batch;
    /// - any other rank is an `InvalidArgument` error.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let binding = input.shape();
        let input_shape = binding.dims();
        match input_shape {
            [_] | [_, _] => self.get_embeddings(input),
            [_, seq_len, _] => {
                let table = self.get_embeddings_for_length(*seq_len)?;
                input.add_op(&table.unsqueeze(0)?)
            }
            _ => Err(TorshError::InvalidArgument(format!(
                "Unexpected input shape: {:?}",
                input_shape
            ))),
        }
    }
    /// Reports only the optional `"scale"` parameter; the `"embeddings"` table is not
    /// returned.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let mut exposed = HashMap::new();
        if !self.learned_scale {
            return exposed;
        }
        if let Some(scale) = self.base.parameters.get("scale") {
            exposed.insert("scale".to_string(), scale.clone());
        }
        exposed
    }
    fn training(&self) -> bool {
        self.base.training()
    }
    fn train(&mut self) {
        self.base.set_training(true);
    }
    fn eval(&mut self) {
        self.base.set_training(false);
    }
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }
    /// Same set as `parameters`: only the optional `"scale"` entry.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.parameters()
    }
}
impl std::fmt::Debug for SinusoidalPositionEmbedding {
    /// Shows the configuration (`d_model`, `max_len`, `learned_scale`), not the table.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SinusoidalPositionEmbedding");
        out.field("d_model", &self.d_model);
        out.field("max_len", &self.max_len);
        out.field("learned_scale", &self.learned_scale);
        out.finish()
    }
}
/// Builds a `[max_len, d_model]` sinusoidal table.
///
/// For position `pos` and even column `i`, column `i` is
/// `sin(pos / 10000^(i / d_model))` and column `i + 1`, when it exists, holds the
/// matching cosine.
///
/// # Panics
/// Panics if `Tensor::from_vec` rejects the buffer. The buffer always holds exactly
/// `max_len * d_model` values.
fn create_sinusoidal_encoding(max_len: usize, d_model: usize) -> Tensor {
    let mut table = vec![0.0f32; max_len * d_model];
    // `chunks_exact_mut(0)` would panic; with d_model == 0 there is nothing to fill.
    if d_model > 0 {
        for (pos, row) in table.chunks_exact_mut(d_model).enumerate() {
            for (pair, cells) in row.chunks_mut(2).enumerate() {
                let i = pair * 2;
                let angle = pos as f32 / 10000.0_f32.powf(i as f32 / d_model as f32);
                cells[0] = angle.sin();
                if let Some(cos_cell) = cells.get_mut(1) {
                    *cos_cell = angle.cos();
                }
            }
        }
    }
    Tensor::from_vec(table, &[max_len, d_model]).expect("tensor creation should succeed")
}
impl std::fmt::Debug for SinusoidalPositionalEncoding {
    /// Shows the configuration, not the `"pe"` table.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SinusoidalPositionalEncoding");
        out.field("d_model", &self.d_model);
        out.field("max_len", &self.max_len);
        out.field("dropout", &self.dropout);
        out.finish()
    }
}
impl std::fmt::Debug for LearnablePositionalEncoding {
    /// Shows the configuration, not the `"pe"` table.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LearnablePositionalEncoding");
        out.field("d_model", &self.d_model);
        out.field("max_len", &self.max_len);
        out.field("dropout", &self.dropout);
        out.finish()
    }
}
impl std::fmt::Debug for RotaryPositionalEmbedding {
    /// Shows every field: `d_model`, `base` and `max_seq_len`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RotaryPositionalEmbedding");
        out.field("d_model", &self.d_model);
        out.field("base", &self.base);
        out.field("max_seq_len", &self.max_seq_len);
        out.finish()
    }
}
impl std::fmt::Debug for AlibiPositionalBias {
    /// Shows every field: `num_heads` and `max_seq_len`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AlibiPositionalBias");
        out.field("num_heads", &self.num_heads);
        out.field("max_seq_len", &self.max_seq_len);
        out.finish()
    }
}
impl std::fmt::Debug for RelativePositionalEncoding {
    /// Shows the configuration, not the `"relative_pe"` table.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RelativePositionalEncoding");
        out.field("d_model", &self.d_model);
        out.field("max_relative_distance", &self.max_relative_distance);
        out.finish()
    }
}
// Unit tests for `Embedding` lookups and `SinusoidalPositionEmbedding`.
// The Embedding tests overwrite the randomly initialized "weight" parameter with a
// known row-major table so that expected outputs can be stated exactly.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// Index 2 selects the third row [7, 8, 9] of a 5x3 table. A length-1 input yields
// output shape [1, 3].
#[test]
fn test_embedding_basic_lookup() -> Result<()> {
let mut embedding = Embedding::new(5, 3);
let weight_data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, ];
let weight = Tensor::from_vec(weight_data, &[5, 3])?;
*embedding
.base
.parameters
.get_mut("weight")
.expect("operation should succeed")
.tensor()
.write() = weight;
let input = Tensor::from_vec(vec![2.0], &[1])?;
let output = embedding.forward(&input)?;
let output_data = output.to_vec()?;
assert_eq!(output.shape().dims(), &[1, 3]);
assert_relative_eq!(output_data[0], 7.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 8.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 9.0, epsilon = 1e-6);
Ok(())
}
// Rows are emitted in input order: indices [0, 2, 1] give rows [1,2], [5,6], [3,4].
#[test]
fn test_embedding_multiple_indices() -> Result<()> {
let mut embedding = Embedding::new(4, 2);
let weight_data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ];
let weight = Tensor::from_vec(weight_data, &[4, 2])?;
*embedding
.base
.parameters
.get_mut("weight")
.expect("operation should succeed")
.tensor()
.write() = weight;
let input = Tensor::from_vec(vec![0.0, 2.0, 1.0], &[3])?;
let output = embedding.forward(&input)?;
let output_data = output.to_vec()?;
assert_eq!(output.shape().dims(), &[3, 2]);
assert_relative_eq!(output_data[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 2.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 5.0, epsilon = 1e-6);
assert_relative_eq!(output_data[3], 6.0, epsilon = 1e-6);
assert_relative_eq!(output_data[4], 3.0, epsilon = 1e-6);
assert_relative_eq!(output_data[5], 4.0, epsilon = 1e-6);
Ok(())
}
// A [2, 2] index tensor produces a [2, 2, embedding_dim] output in row-major order.
#[test]
fn test_embedding_2d_indices() -> Result<()> {
let mut embedding = Embedding::new(3, 2);
let weight_data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ];
let weight = Tensor::from_vec(weight_data, &[3, 2])?;
*embedding
.base
.parameters
.get_mut("weight")
.expect("operation should succeed")
.tensor()
.write() = weight;
let input = Tensor::from_vec(vec![0.0, 1.0, 2.0, 0.0], &[2, 2])?;
let output = embedding.forward(&input)?;
assert_eq!(output.shape().dims(), &[2, 2, 2]);
let output_data = output.to_vec()?;
assert_relative_eq!(output_data[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 2.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 3.0, epsilon = 1e-6);
assert_relative_eq!(output_data[3], 4.0, epsilon = 1e-6);
assert_relative_eq!(output_data[4], 5.0, epsilon = 1e-6);
assert_relative_eq!(output_data[5], 6.0, epsilon = 1e-6);
assert_relative_eq!(output_data[6], 1.0, epsilon = 1e-6);
assert_relative_eq!(output_data[7], 2.0, epsilon = 1e-6);
Ok(())
}
// The padding index (0) yields zeros even though row 0 of the weight is non-zero.
#[test]
fn test_embedding_with_padding_idx() -> Result<()> {
let mut embedding = Embedding::with_padding_idx(4, 3, 0);
let weight_data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ];
let weight = Tensor::from_vec(weight_data, &[4, 3])?;
*embedding
.base
.parameters
.get_mut("weight")
.expect("operation should succeed")
.tensor()
.write() = weight;
let input = Tensor::from_vec(vec![1.0, 0.0, 2.0], &[3])?;
let output = embedding.forward(&input)?;
let output_data = output.to_vec()?;
assert_relative_eq!(output_data[0], 4.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 5.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 6.0, epsilon = 1e-6);
assert_relative_eq!(output_data[3], 0.0, epsilon = 1e-6);
assert_relative_eq!(output_data[4], 0.0, epsilon = 1e-6);
assert_relative_eq!(output_data[5], 0.0, epsilon = 1e-6);
assert_relative_eq!(output_data[6], 7.0, epsilon = 1e-6);
assert_relative_eq!(output_data[7], 8.0, epsilon = 1e-6);
assert_relative_eq!(output_data[8], 9.0, epsilon = 1e-6);
Ok(())
}
// An index >= num_embeddings must surface as InvalidArgument mentioning "out of bounds".
#[test]
fn test_embedding_out_of_bounds() {
let mut embedding = Embedding::new(3, 2);
let weight_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let weight = Tensor::from_vec(weight_data, &[3, 2]).expect("Tensor should succeed");
*embedding
.base
.parameters
.get_mut("weight")
.expect("operation should succeed")
.tensor()
.write() = weight;
let input = Tensor::from_vec(vec![5.0], &[1]).expect("Tensor should succeed");
let result = embedding.forward(&input);
assert!(result.is_err());
if let Err(torsh_core::error::TorshError::InvalidArgument(msg)) = result {
assert!(msg.contains("out of bounds"));
} else {
panic!("Expected InvalidArgument error for out-of-bounds index");
}
}
// Row 0 is [3, 4], whose L2 norm is 5 > max_norm 1, so it is rescaled by 1/5 to
// [0.6, 0.8].
#[test]
fn test_embedding_with_max_norm() -> Result<()> {
let mut embedding = Embedding::with_config(
3,
2,
None, Some(1.0), 2.0, false, false, );
let weight_data = vec![
3.0, 4.0, 1.0, 0.0, 0.6, 0.8, ];
let weight = Tensor::from_vec(weight_data, &[3, 2])?;
*embedding
.base
.parameters
.get_mut("weight")
.expect("operation should succeed")
.tensor()
.write() = weight;
let input = Tensor::from_vec(vec![0.0], &[1])?;
let output = embedding.forward(&input)?;
let output_data = output.to_vec()?;
assert_relative_eq!(output_data[0], 0.6, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 0.8, epsilon = 1e-6);
let norm = (output_data[0] * output_data[0] + output_data[1] * output_data[1]).sqrt();
assert_relative_eq!(norm, 1.0, epsilon = 1e-6);
Ok(())
}
// The output shape is the input shape with embedding_dim appended, for 1-D, 2-D and
// 3-D inputs. The random initial weights are used, so only shapes are checked.
#[test]
fn test_embedding_shape_preservation() -> Result<()> {
let embedding = Embedding::new(10, 5);
let input1d = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3])?;
let output1d = embedding.forward(&input1d)?;
assert_eq!(output1d.shape().dims(), &[3, 5]);
let input2d = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], &[2, 4])?;
let output2d = embedding.forward(&input2d)?;
assert_eq!(output2d.shape().dims(), &[2, 4, 5]);
let input3d = Tensor::from_vec(
vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 1.0],
&[2, 3, 2],
)?;
let output3d = embedding.forward(&input3d)?;
assert_eq!(output3d.shape().dims(), &[2, 3, 2, 5]);
Ok(())
}
// The accessors report the construction arguments; an odd d_model is rejected.
#[test]
fn test_sinusoidal_position_embedding_creation() -> Result<()> {
let pos_emb = SinusoidalPositionEmbedding::new(64, 100)?;
assert_eq!(pos_emb.d_model(), 64);
assert_eq!(pos_emb.max_len(), 100);
let result = SinusoidalPositionEmbedding::new(63, 100);
assert!(result.is_err());
Ok(())
}
// A rank-1 input is treated as a list of positions; the sum check is a sanity check
// that the output is not all zeros.
#[test]
fn test_sinusoidal_position_embedding_forward() -> Result<()> {
let pos_emb = SinusoidalPositionEmbedding::new(128, 1000)?;
let positions = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5])?;
let embeddings = pos_emb.forward(&positions)?;
assert_eq!(embeddings.shape().dims(), &[5, 128]);
let data = embeddings.to_vec()?;
let sum: f32 = data.iter().sum();
assert!(sum.abs() > 0.1);
Ok(())
}
// Row 0 of the table is [sin(0), cos(0), ...] = [0, 1, ...].
#[test]
fn test_sinusoidal_position_embedding_sequence_length() -> Result<()> {
let pos_emb = SinusoidalPositionEmbedding::new(64, 200)?;
let embeddings = pos_emb.get_embeddings_for_length(50)?;
assert_eq!(embeddings.shape().dims(), &[50, 64]);
let emb0 = embeddings.narrow(0, 0, 1)?.squeeze(0)?;
let data0 = emb0.to_vec()?;
assert_relative_eq!(data0[0], 0.0, epsilon = 1e-6); assert_relative_eq!(data0[1], 1.0, epsilon = 1e-6);
Ok(())
}
// The first frequency is 10000^0 = 1, so position 1's first pair is [sin(1), cos(1)].
#[test]
fn test_sinusoidal_position_embedding_mathematical_properties() -> Result<()> {
let d_model = 64;
let pos_emb = SinusoidalPositionEmbedding::new(d_model, 100)?;
let positions = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3])?;
let embeddings = pos_emb.forward(&positions)?;
let data = embeddings.to_vec()?;
assert_relative_eq!(data[0], 0.0, epsilon = 1e-6); assert_relative_eq!(data[1], 1.0, epsilon = 1e-6);
let pos1_start = d_model;
assert_relative_eq!(data[pos1_start], (1.0_f32).sin(), epsilon = 1e-5);
assert_relative_eq!(data[pos1_start + 1], (1.0_f32).cos(), epsilon = 1e-5);
Ok(())
}
// The embeddings of distant positions (0 and 100) should differ in most of the
// 128 dimensions.
#[test]
fn test_sinusoidal_position_embedding_periodicity() -> Result<()> {
let pos_emb = SinusoidalPositionEmbedding::new(128, 10000)?;
let pos1 = Tensor::from_vec(vec![0.0], &[1])?;
let emb1 = pos_emb.forward(&pos1)?;
let pos2 = Tensor::from_vec(vec![100.0], &[1])?;
let emb2 = pos_emb.forward(&pos2)?;
let data1 = emb1.to_vec()?;
let data2 = emb2.to_vec()?;
let mut different_count = 0;
for (v1, v2) in data1.iter().zip(data2.iter()) {
if (v1 - v2).abs() > 1e-6 {
different_count += 1;
}
}
assert!(different_count > 100);
Ok(())
}
// A rank-2 [batch, seq_len] positions tensor gives [batch, seq_len, d_model].
#[test]
fn test_sinusoidal_position_embedding_batch_support() -> Result<()> {
let pos_emb = SinusoidalPositionEmbedding::new(64, 100)?;
let positions = Tensor::from_vec(
vec![
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, ],
&[2, 3],
)?;
let embeddings = pos_emb.forward(&positions)?;
assert_eq!(embeddings.shape().dims(), &[2, 3, 64]);
Ok(())
}
// Both lookup paths reject requests beyond max_len (100).
#[test]
fn test_sinusoidal_position_embedding_bounds_checking() -> Result<()> {
let pos_emb = SinusoidalPositionEmbedding::new(64, 100)?;
let positions = Tensor::from_vec(vec![101.0], &[1])?;
let result = pos_emb.forward(&positions);
assert!(result.is_err());
let result = pos_emb.get_embeddings_for_length(101);
assert!(result.is_err());
Ok(())
}
// With a learned scale, only "scale" is reported as a parameter, and lookups still
// produce the expected shape.
#[test]
fn test_sinusoidal_position_embedding_with_learned_scale() -> Result<()> {
let pos_emb = SinusoidalPositionEmbedding::with_learned_scale(64, 100)?;
let params = pos_emb.parameters();
assert_eq!(params.len(), 1);
assert!(params.contains_key("scale"));
let positions = Tensor::from_vec(vec![0.0, 1.0], &[2])?;
let embeddings = pos_emb.forward(&positions)?;
assert_eq!(embeddings.shape().dims(), &[2, 64]);
Ok(())
}
// A rank-3 input is treated as token features and the embeddings are added. Most
// elements should change; not all do, because position 0's sine columns are exactly 0.
#[test]
fn test_sinusoidal_position_embedding_add_to_tokens() -> Result<()> {
let d_model = 64;
let seq_len = 10;
let batch_size = 2;
let pos_emb = SinusoidalPositionEmbedding::new(d_model, 100)?;
let token_data = vec![0.5_f32; batch_size * seq_len * d_model];
let tokens = Tensor::from_vec(token_data, &[batch_size, seq_len, d_model])?;
let output = pos_emb.forward(&tokens)?;
assert_eq!(output.shape().dims(), &[batch_size, seq_len, d_model]);
let output_data = output.to_vec()?;
let tokens_data = tokens.to_vec()?;
let mut different_count = 0;
for (out, tok) in output_data.iter().zip(tokens_data.iter()) {
if (out - tok).abs() > 1e-6 {
different_count += 1;
}
}
assert!(different_count > d_model * seq_len);
Ok(())
}
}