use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorBase, TensorOps};
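/// A differentiable operation: `forward` computes outputs from inputs, and
/// `backward` returns one gradient per input (in the same order as `inputs`),
/// given the gradient of the loss with respect to the output.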
pub trait DifferentiableOp {
fn forward(&self, inputs: &[&DenseTensor]) -> DenseTensor;
fn backward(&self, inputs: &[&DenseTensor], grad_output: &DenseTensor) -> Vec<DenseTensor>;
}
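/// A dense linear layer: `y = x * W (+ b)`.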
#[derive(Debug, Clone)]
pub struct LinearOp {
pub weight: DenseTensor,
pub bias: Option<DenseTensor>,
}
impl LinearOp {
pub fn new(weight: DenseTensor, bias: Option<DenseTensor>) -> Self {
Self { weight, bias }
}
pub fn forward(&self, input: &DenseTensor) -> DenseTensor {
let output = input.matmul(&self.weight);
if let Some(bias) = &self.bias {
output.add(bias)
} else {
output
}
}
}
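/// A lookup-table embedding: each input index selects a row of `embeddings`
/// (shape `[vocab, dim]`); an out-of-range index panics on the slice access.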
#[derive(Debug, Clone)]
pub struct EmbeddingOp {
pub embeddings: DenseTensor,
}
impl EmbeddingOp {
pub fn new(embeddings: DenseTensor) -> Self {
Self { embeddings }
}
pub fn forward(&self, indices: &[usize]) -> DenseTensor {
let dim = self.embeddings.shape()[1];
let mut data = Vec::with_capacity(indices.len() * dim);
for &idx in indices {
let start = idx * dim;
let end = start + dim;
data.extend_from_slice(&self.embeddings.data()[start..end]);
}
DenseTensor::new(data, vec![indices.len(), dim])
}
}
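/// Scaled dot-product attention: `softmax(Q K^T / sqrt(d)) V`, with `scale`
/// precomputed as `1 / sqrt(head_dim)`.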
#[derive(Debug, Clone)]
pub struct ScaledDotProductOp {
pub scale: f64,
}
impl ScaledDotProductOp {
pub fn new(head_dim: usize) -> Self {
Self {
scale: 1.0 / (head_dim as f64).sqrt(),
}
}
    pub fn forward(&self, query: &DenseTensor, key: &DenseTensor, value: &DenseTensor) -> DenseTensor {
        // Unmasked attention is the masked path with no mask applied.
        self.forward_with_mask(query, key, value, None)
    }
pub fn forward_with_mask(
&self,
query: &DenseTensor,
key: &DenseTensor,
value: &DenseTensor,
mask: Option<&DenseTensor>,
) -> DenseTensor {
let key_t = key.transpose(None);
let mut scores = query.matmul(&key_t);
scores = scores.scale(self.scale);
if let Some(mask) = mask {
scores = scores.mask_fill(mask, f64::NEG_INFINITY);
}
let attn_weights = scores.softmax(-1);
attn_weights.matmul(value)
}
}
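/// Multi-head self-attention: projects `x` with `w_q`/`w_k`/`w_v`, splits the
/// hidden dimension into `num_heads` heads of size `head_dim`, applies scaled
/// dot-product attention per head, then merges heads and projects with `w_o`.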
#[derive(Debug, Clone)]
pub struct MultiHeadAttentionOp {
pub w_q: DenseTensor,
pub w_k: DenseTensor,
pub w_v: DenseTensor,
pub w_o: DenseTensor,
pub num_heads: usize,
pub head_dim: usize,
pub scale: f64,
}
impl MultiHeadAttentionOp {
pub fn new(
w_q: DenseTensor,
w_k: DenseTensor,
w_v: DenseTensor,
w_o: DenseTensor,
num_heads: usize,
) -> Self {
        // head_dim is set by the projection output width (the column count of w_q).
        let head_dim = w_q.shape()[1] / num_heads;
let scale = 1.0 / (head_dim as f64).sqrt();
Self {
w_q,
w_k,
w_v,
w_o,
num_heads,
head_dim,
scale,
}
}
pub fn forward(&self, x: &DenseTensor) -> DenseTensor {
let batch_size = x.shape()[0];
let seq_len = x.shape()[1];
let q = x.matmul(&self.w_q);
let k = x.matmul(&self.w_k);
let v = x.matmul(&self.w_v);
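        // Split the hidden dimension into heads:
        // [batch, seq, hidden] -> [batch, seq, num_heads, head_dim].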
let q = q.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim]);
let k = k.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim]);
let v = v.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim]);
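        // Bring heads in front of the sequence axis so attention runs per head:
        // [batch, seq, num_heads, head_dim] -> [batch, num_heads, seq, head_dim].
        // (This assumes `transpose_2d` swaps the two middle axes of a 4D tensor.)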
let q = q.transpose_2d();
let k = k.transpose_2d();
let v = v.transpose_2d();
let k_t = k.transpose(None);
let mut scores = q.matmul(&k_t);
scores = scores.scale(self.scale);
let attn_weights = scores.softmax(-1);
let attn_output = attn_weights.matmul(&v);
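        // Undo the head transpose and merge heads back into the hidden dimension:
        // [batch, num_heads, seq, head_dim] -> [batch, seq, num_heads * head_dim].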
let attn_output = attn_output.transpose_2d();
let hidden_dim = self.num_heads * self.head_dim;
let attn_output = attn_output.reshape(&[batch_size, seq_len, hidden_dim]);
attn_output.matmul(&self.w_o)
}
}
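/// SwiGLU feed-forward block: `(silu(x * W_gate) elementwise-mul (x * W_up)) * W_down`.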
#[derive(Debug, Clone)]
pub struct SwiGLUOp {
pub gate_proj: DenseTensor,
pub up_proj: DenseTensor,
pub down_proj: DenseTensor,
}
impl SwiGLUOp {
pub fn new(gate_proj: DenseTensor, up_proj: DenseTensor, down_proj: DenseTensor) -> Self {
Self {
gate_proj,
up_proj,
down_proj,
}
}
pub fn forward(&self, x: &DenseTensor) -> DenseTensor {
let gate = x.matmul(&self.gate_proj);
let gate = gate.silu();
let up = x.matmul(&self.up_proj);
let intermediate = gate.mul(&up);
intermediate.matmul(&self.down_proj)
}
}
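/// Layer normalization: `(x - mean) / sqrt(var + eps) * weight + bias`, with
/// mean and variance taken over the last dimension.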
#[derive(Debug, Clone)]
pub struct LayerNormOp {
pub weight: DenseTensor,
pub bias: DenseTensor,
pub epsilon: f64,
}
impl LayerNormOp {
pub fn new(weight: DenseTensor, bias: DenseTensor, epsilon: f64) -> Self {
Self {
weight,
bias,
epsilon,
}
}
pub fn forward(&self, x: &DenseTensor) -> DenseTensor {
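        // Note: this relies on `sub`/`div` broadcasting the reduced last
        // dimension back over `x` (unlike RMSNormOp below, which expands
        // the reduced dimension explicitly).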
let mean = x.mean_dim(-1);
let var = x.var_dim(-1);
        let normalized = x
            .sub(&mean)
            .div(&var.add(&DenseTensor::full(var.shape(), self.epsilon)).sqrt());
normalized.mul(&self.weight).add(&self.bias)
}
}
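/// RMS normalization: `x / sqrt(mean(x^2) + eps) * weight`, with the mean of
/// squares taken over the last (hidden) dimension.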
#[derive(Debug, Clone)]
pub struct RMSNormOp {
pub weight: DenseTensor,
pub epsilon: f64,
}
impl RMSNormOp {
pub fn new(weight: DenseTensor, epsilon: f64) -> Self {
Self { weight, epsilon }
}
pub fn forward(&self, x: &DenseTensor) -> DenseTensor {
let ndim = x.ndim();
if ndim == 3 {
let batch = x.shape()[0];
let seq = x.shape()[1];
let hidden = x.shape()[2];
let x_squared = x.mul(x);
let mean_square = x_squared.mean_dim(-1);
let mean_square_expanded = mean_square.expand_last_dim(hidden);
let eps_tensor = DenseTensor::full(&[batch, seq, hidden], self.epsilon);
let rms_input = mean_square_expanded.add(&eps_tensor);
let rms = rms_input.sqrt();
let normalized = x.div(&rms);
let weight_expanded = self.weight.expand_to_3d(batch, seq);
normalized.mul(&weight_expanded)
} else if ndim == 2 {
let seq = x.shape()[0];
let hidden = x.shape()[1];
let x_squared = x.mul(x);
let mean_square = x_squared.mean_dim(-1);
let mean_square_expanded = mean_square.expand_last_dim_2d(hidden);
let eps_tensor = DenseTensor::full(&[seq, hidden], self.epsilon);
let rms_input = mean_square_expanded.add(&eps_tensor);
let rms = rms_input.sqrt();
let normalized = x.div(&rms);
let weight_expanded = self.weight.expand_to_2d(seq);
normalized.mul(&weight_expanded)
} else {
panic!("RMSNormOp only supports 2D or 3D tensors");
}
}
}
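/// Rotary position embedding (RoPE): rotates each row of `x` by a
/// position-dependent angle using precomputed cos/sin caches, following the
/// half-split pairing convention implemented by `rotate_half`.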
#[derive(Debug, Clone)]
pub struct RoPEOp {
pub cos_cache: DenseTensor,
pub sin_cache: DenseTensor,
}
impl RoPEOp {
pub fn new(cos_cache: DenseTensor, sin_cache: DenseTensor) -> Self {
Self { cos_cache, sin_cache }
}
pub fn forward(&self, x: &DenseTensor, positions: &[usize]) -> DenseTensor {
let mut output = x.clone();
for (i, &pos) in positions.iter().enumerate() {
let cos = self.cos_cache.get_row(pos);
let sin = self.sin_cache.get_row(pos);
let x_row = x.get_row(i);
let rotated = self.rotate_half(&x_row);
let result = x_row.mul(&cos).add(&rotated.mul(&sin));
output.set_row(i, &result);
}
output
}
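    /// Maps `[x1, x2]` (first and second halves of the last dimension) to
    /// `[-x2, x1]`; applying it twice negates the input.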
fn rotate_half(&self, x: &DenseTensor) -> DenseTensor {
let dim = x.shape()[0];
let half_dim = dim / 2;
let mut data = vec![0.0; dim];
let x_data = x.data();
for i in 0..half_dim {
data[i] = -x_data[i + half_dim];
data[i + half_dim] = x_data[i];
}
        // Keep the input's 1-D shape so the caller's elementwise ops line up.
        DenseTensor::new(data, vec![dim])
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linear_op() {
let weight = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);
let bias = DenseTensor::new(vec![0.1, 0.1], vec![1, 2]);
let linear = LinearOp::new(weight, Some(bias));
let input = DenseTensor::new(vec![1.0, 2.0], vec![1, 2]);
let output = linear.forward(&input);
assert_eq!(output.shape(), &[1, 2]);
}
#[test]
fn test_embedding_op() {
let embeddings = DenseTensor::new(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![3, 2],
);
let embedding = EmbeddingOp::new(embeddings);
let indices = vec![0, 2];
let output = embedding.forward(&indices);
assert_eq!(output.shape(), &[2, 2]);
}
#[test]
fn test_scaled_dot_product() {
let op = ScaledDotProductOp::new(4);
let q = DenseTensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![1, 4]);
let k = DenseTensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![1, 4]);
let v = DenseTensor::new(vec![1.0, 2.0], vec![1, 2]);
let output = op.forward(&q, &k, &v);
assert_eq!(output.shape(), &[1, 2]);
}
#[test]
fn test_rms_norm() {
let weight = DenseTensor::ones(vec![4]);
let rms_norm = RMSNormOp::new(weight, 1e-6);
let x = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
let output = rms_norm.forward(&x);
assert_eq!(output.shape(), &[1, 4]);
        // With unit weight, the mean of squares after RMSNorm should be ~1.
        let mean_square = output.mul(&output).mean_dim(-1);
        let ms_val = mean_square.data()[0];
        assert!((ms_val - 1.0).abs() < 0.1);
}
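    #[test]
    fn test_rotate_half() {
        // A small property-check sketch: rotate_half maps [x1, x2] to [-x2, x1].
        // The cos/sin caches here are placeholders; rotate_half never reads them.
        let rope = RoPEOp::new(DenseTensor::ones(vec![1, 4]), DenseTensor::ones(vec![1, 4]));
        let x = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let rotated = rope.rotate_half(&x);
        assert_eq!(&rotated.data()[..], &[-3.0, -4.0, 1.0, 2.0][..]);
    }
    #[test]
    fn test_swiglu_op() {
        // Shape-only check in the style of the tests above; the weight values
        // are arbitrary and only the output shape is asserted.
        let gate = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);
        let up = DenseTensor::new(vec![0.5, 0.6, 0.7, 0.8], vec![2, 2]);
        let down = DenseTensor::new(vec![0.9, 1.0, 1.1, 1.2], vec![2, 2]);
        let swiglu = SwiGLUOp::new(gate, up, down);
        let x = DenseTensor::new(vec![1.0, 2.0], vec![1, 2]);
        let output = swiglu.forward(&x);
        assert_eq!(output.shape(), &[1, 2]);
    }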
}