use super::device::ComputeDevice;
use super::tensor::ComputeTensor;
use crate::Scalar;
use crate::expr::node::ExprId;
/// Element-wise sum `a + b`, computed by a single device elementwise kernel.
///
/// The result takes `a`'s shape. Only the element counts are compared, so
/// operands whose shapes differ but whose element counts agree are accepted.
///
/// # Panics
/// Panics if `a` and `b` hold a different number of elements.
pub fn add_tensors<D: ComputeDevice>(
    dev: &D,
    a: &ComputeTensor<D::Buffer>,
    b: &ComputeTensor<D::Buffer>,
) -> ComputeTensor<D::Buffer> {
    let len = a.numel();
    assert_eq!(len, b.numel(), "add_tensors: numel mismatch");

    // Kernel body: input 0 plus input 1.
    let sum = |x: &[ExprId]| x[0] + x[1];
    let out = dev.elementwise(&[&a.buffer, &b.buffer], len, &sum);

    ComputeTensor::from_buffer(out, a.shape().to_vec())
}
/// Adds `bias` to every consecutive chunk of `bias.numel()` elements of `matrix`.
///
/// `matrix` is handled as flat storage of `matrix.numel()` elements, and the
/// bias is repeated across each contiguous chunk of length `bias.numel()`.
/// For example, a `[2, 3]` matrix plus a `[3]` bias adds the bias to each row.
/// The result keeps `matrix`'s shape.
///
/// # Panics
/// Panics if `bias` is empty, or if `matrix.numel()` is not a multiple of
/// `bias.numel()`.
pub fn bias_add<D: ComputeDevice>(
    dev: &D,
    matrix: &ComputeTensor<D::Buffer>,
    bias: &ComputeTensor<D::Buffer>,
) -> ComputeTensor<D::Buffer> {
    let numel = matrix.numel();
    let dim = bias.numel();
    // Checked first: `numel % 0` would otherwise panic with a bare
    // "remainder with a divisor of zero" message that hides the real cause.
    assert!(dim > 0, "bias_add: bias must not be empty");
    assert_eq!(numel % dim, 0, "bias_add: matrix numel not divisible by bias dim");
    let buf = dev.bias_add(&matrix.buffer, &bias.buffer, numel, dim);
    ComputeTensor::from_buffer(buf, matrix.shape().to_vec())
}
/// Fused SwiGLU activation: `silu(gate) * up`, where
/// `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
///
/// The whole expression runs as one device elementwise kernel over two
/// inputs. The result takes `gate`'s shape.
///
/// # Panics
/// Panics if `gate` and `up` hold a different number of elements.
pub fn swiglu_fused<D: ComputeDevice>(
    dev: &D,
    gate: &ComputeTensor<D::Buffer>,
    up: &ComputeTensor<D::Buffer>,
) -> ComputeTensor<D::Buffer> {
    let count = gate.numel();
    assert_eq!(count, up.numel(), "swiglu_fused: numel mismatch");

    let out = dev.elementwise(
        &[&gate.buffer, &up.buffer],
        count,
        &|x: &[ExprId]| {
            let (g, u) = (x[0], x[1]);
            let one = ExprId::from_f64(1.0);
            // sigmoid(g) = 1 / (1 + e^{-g})
            let sig = one / (one + Scalar::exp(-g));
            // (g * sigmoid(g)) * u
            g * sig * u
        },
    );

    ComputeTensor::from_buffer(out, gate.shape().to_vec())
}
/// Backward pass of [`swiglu_fused`].
///
/// With `s = sigmoid(gate)` and `silu(gate) = gate * s`, the forward output is
/// `silu(gate) * up`. For an upstream gradient `g` this computes:
///
/// - `grad_up   = g * silu(gate)`
/// - `grad_gate = g * up * s * (1 + gate * (1 - s))`
///
/// Each gradient is produced by its own device elementwise kernel.
/// Returns `(grad_gate, grad_up)`, both shaped like `grad_output`.
///
/// # Panics
/// Panics if `gate` or `up` does not hold the same number of elements as
/// `grad_output`.
pub fn swiglu_backward<D: ComputeDevice>(
    dev: &D,
    grad_output: &ComputeTensor<D::Buffer>,
    gate: &ComputeTensor<D::Buffer>,
    up: &ComputeTensor<D::Buffer>,
) -> (ComputeTensor<D::Buffer>, ComputeTensor<D::Buffer>) {
    let numel = grad_output.numel();
    // Both kernels read `numel` elements from every input buffer, so reject
    // mismatched inputs up front (the same contract `swiglu_fused` enforces).
    assert_eq!(gate.numel(), numel, "swiglu_backward: gate numel mismatch");
    assert_eq!(up.numel(), numel, "swiglu_backward: up numel mismatch");

    // Inputs: [grad_output, gate].
    let grad_up_buf = dev.elementwise(
        &[&grad_output.buffer, &gate.buffer],
        numel,
        &|ids: &[ExprId]| {
            let one = ExprId::from_f64(1.0);
            let neg_gate = -ids[1];
            let exp_neg = Scalar::exp(neg_gate);
            let sigmoid = one / (one + exp_neg);
            let silu = ids[1] * sigmoid;
            ids[0] * silu
        },
    );

    // Inputs: [grad_output, gate, up].
    let grad_gate_buf = dev.elementwise(
        &[&grad_output.buffer, &gate.buffer, &up.buffer],
        numel,
        &|ids: &[ExprId]| {
            let one = ExprId::from_f64(1.0);
            let neg_gate = -ids[1];
            let exp_neg = Scalar::exp(neg_gate);
            let sigmoid = one / (one + exp_neg);
            // d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x)) = s(x) * (1 + x * (1 - s(x)))
            let dsilu = sigmoid * (one + ids[1] * (one - sigmoid));
            ids[0] * ids[2] * dsilu
        },
    );

    (
        ComputeTensor::from_buffer(grad_gate_buf, grad_output.shape().to_vec()),
        ComputeTensor::from_buffer(grad_up_buf, grad_output.shape().to_vec()),
    )
}
/// Backward pass of causal attention, delegated to the device's
/// `causal_attention_backward` kernel.
///
/// Queries use `n_heads` heads and keys/values use `n_kv_heads` heads, each
/// `head_dim` wide. Returns `(grad_q, grad_k, grad_v)` with shapes
/// `[seq_len, n_heads * head_dim]`, `[seq_len, n_kv_heads * head_dim]` and
/// `[seq_len, n_kv_heads * head_dim]`.
///
/// # Panics
/// Panics if `grad_output` or `q` does not hold `seq_len * n_heads * head_dim`
/// elements, or if `k` or `v` does not hold `seq_len * n_kv_heads * head_dim`
/// elements.
// The argument list mirrors the device-level call one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn causal_attention_backward<D: ComputeDevice>(
    dev: &D,
    grad_output: &ComputeTensor<D::Buffer>,
    q: &ComputeTensor<D::Buffer>,
    k: &ComputeTensor<D::Buffer>,
    v: &ComputeTensor<D::Buffer>,
    seq_len: usize,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
) -> (ComputeTensor<D::Buffer>, ComputeTensor<D::Buffer>, ComputeTensor<D::Buffer>) {
    let total_dim = n_heads * head_dim;
    let kv_dim = n_kv_heads * head_dim;

    // Each gradient is wrapped with its input's shape below. Validate the
    // inputs so that shape metadata can never disagree with buffer length.
    assert_eq!(
        grad_output.numel(),
        seq_len * total_dim,
        "causal_attention_backward: grad_output numel mismatch"
    );
    assert_eq!(
        q.numel(),
        seq_len * total_dim,
        "causal_attention_backward: q numel mismatch"
    );
    assert_eq!(
        k.numel(),
        seq_len * kv_dim,
        "causal_attention_backward: k numel mismatch"
    );
    assert_eq!(
        v.numel(),
        seq_len * kv_dim,
        "causal_attention_backward: v numel mismatch"
    );

    let (gq, gk, gv) = dev.causal_attention_backward(
        &grad_output.buffer,
        &q.buffer,
        &k.buffer,
        &v.buffer,
        seq_len,
        n_heads,
        n_kv_heads,
        head_dim,
    );
    (
        ComputeTensor::from_buffer(gq, vec![seq_len, total_dim]),
        ComputeTensor::from_buffer(gk, vec![seq_len, kv_dim]),
        ComputeTensor::from_buffer(gv, vec![seq_len, kv_dim]),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `CpuDevice` is not defined or imported in the parent
    // module as shown here; confirm this path resolves, or import it from
    // wherever the CPU backend lives.
    use super::CpuDevice;

    #[test]
    fn test_add_tensors() {
        let dev = CpuDevice::new();
        let a = ComputeTensor::from_data(&dev, &[1.0, 2.0, 3.0], &[3]);
        let b = ComputeTensor::from_data(&dev, &[4.0, 5.0, 6.0], &[3]);
        let c = add_tensors(&dev, &a, &b);
        let out = c.to_vec();
        assert!((out[0] - 5.0).abs() < 1e-5);
        assert!((out[1] - 7.0).abs() < 1e-5);
        assert!((out[2] - 9.0).abs() < 1e-5);
    }

    #[test]
    fn test_bias_add() {
        let dev = CpuDevice::new();
        let mat = ComputeTensor::from_data(&dev, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let bias = ComputeTensor::from_data(&dev, &[0.1, 0.2, 0.3], &[3]);
        let out = bias_add(&dev, &mat, &bias);
        let v = out.to_vec();
        assert!((v[0] - 1.1).abs() < 1e-5);
        assert!((v[1] - 2.2).abs() < 1e-5);
        assert!((v[2] - 3.3).abs() < 1e-5);
        assert!((v[3] - 4.1).abs() < 1e-5);
        assert!((v[4] - 5.2).abs() < 1e-5);
        assert!((v[5] - 6.3).abs() < 1e-5);
    }

    #[test]
    fn test_swiglu_backward() {
        let dev = CpuDevice::new();
        let gate = ComputeTensor::from_data(&dev, &[0.0, 1.0, -1.0], &[3]);
        let up = ComputeTensor::from_data(&dev, &[1.0, 1.0, 1.0], &[3]);
        let grad = ComputeTensor::from_data(&dev, &[1.0, 1.0, 1.0], &[3]);
        let (grad_gate, grad_up) = swiglu_backward(&dev, &grad, &gate, &up);
        let gg = grad_gate.to_vec();
        let gu = grad_up.to_vec();

        // grad_up = grad * silu(gate): silu(0) = 0, silu(1) ~ 0.7311, silu(-1) ~ -0.2689.
        assert!(gu[0].abs() < 1e-5);
        assert!((gu[1] - 0.7311).abs() < 1e-3);
        assert!((gu[2] - (-0.2689)).abs() < 1e-3);

        // grad_gate = grad * up * s * (1 + gate * (1 - s)) with s = sigmoid(gate):
        // gate = 0 -> 0.5, gate = 1 -> ~0.9277, gate = -1 -> ~0.0723.
        assert!((gg[0] - 0.5).abs() < 1e-3);
        assert!((gg[1] - 0.9277).abs() < 1e-3);
        assert!((gg[2] - 0.0723).abs() < 1e-3);
        for v in &gg {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_causal_attention_backward() {
        let dev = CpuDevice::new();
        let q = ComputeTensor::from_data(&dev, &[1.0, 0.0, 0.0, 1.0], &[2, 2]);
        let k = ComputeTensor::from_data(&dev, &[1.0, 0.0, 0.0, 1.0], &[2, 2]);
        let v = ComputeTensor::from_data(&dev, &[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let grad_out = ComputeTensor::from_data(&dev, &[1.0, 0.0, 0.0, 1.0], &[2, 2]);
        let (gq, gk, gv) = causal_attention_backward(
            &dev, &grad_out, &q, &k, &v, 2, 1, 1, 2,
        );
        assert_eq!(gq.shape(), &[2, 2]);
        assert_eq!(gk.shape(), &[2, 2]);
        assert_eq!(gv.shape(), &[2, 2]);
        for v in gq.to_vec().iter().chain(gk.to_vec().iter()).chain(gv.to_vec().iter()) {
            assert!(v.is_finite(), "gradient should be finite");
        }
    }

    #[test]
    fn test_swiglu_fused() {
        let dev = CpuDevice::new();
        let gate = ComputeTensor::from_data(&dev, &[0.0, 1.0, -1.0], &[3]);
        let up = ComputeTensor::from_data(&dev, &[1.0, 1.0, 1.0], &[3]);
        let out = swiglu_fused(&dev, &gate, &up);
        let v = out.to_vec();
        assert!(v[0].abs() < 1e-5);
        assert!((v[1] - 0.7311).abs() < 1e-3);
        assert!((v[2] - (-0.2689)).abs() < 1e-3);
    }
}