use crate::CpuTensor;
use crate::DType;
use tl_backend::error::BackendError;
use tl_backend::fused_ops::GpuFusedOps;
type Result<T> = std::result::Result<T, BackendError>;
/// SiLU activation: `x * sigmoid(x)`, evaluated as `x / (1 + e^(-x))`.
fn silu(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    x / denominator
}
/// GELU activation, tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
fn gelu(x: f32) -> f32 {
    // std has no `sqrt(2/pi)` constant; build it as (2/sqrt(pi)) * (1/sqrt(2)).
    // Note: `FRAC_2_SQRT_PI` on its own is 2/sqrt(pi) ≈ 1.128, which is NOT the
    // coefficient the approximation calls for.
    const SQRT_2_OVER_PI: f32 =
        std::f32::consts::FRAC_2_SQRT_PI * std::f32::consts::FRAC_1_SQRT_2;
    let inner = SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3));
    0.5 * x * (1.0 + inner.tanh())
}
/// RMS-normalises `row` in place and scales it element-wise by `weight`:
/// `row[i] = row[i] * (1 / sqrt(mean(row^2) + eps)) * weight[i]`.
///
/// `weight` must be exactly as long as `row`; callers slice it to the row width.
fn rms_norm_row_in_place(row: &mut [f32], weight: &[f32], eps: f32) {
    debug_assert_eq!(row.len(), weight.len());
    let sum_sq: f32 = row.iter().map(|v| v * v).sum();
    let inv_rms = (sum_sq / row.len() as f32 + eps).sqrt().recip();
    for (v, &w) in row.iter_mut().zip(weight) {
        *v = *v * inv_rms * w;
    }
}

impl GpuFusedOps for CpuTensor<f32> {
    /// Element-wise `silu(self) * up`.
    ///
    /// # Panics
    /// Panics if `self` and `up` have different element counts.
    fn fused_silu_mul(&self, up: &Self) -> Result<Self> {
        let count = self.elem_count();
        assert_eq!(count, up.elem_count(), "fused_silu_mul: shape mismatch");
        let data: Vec<f32> = self
            .data
            .iter()
            .zip(up.data.iter())
            .map(|(&g, &u)| silu(g) * u)
            .collect();
        Ok(CpuTensor::<f32>::from_slice(&data, self.shape(), DType::F32))
    }

    /// RMS-normalises `self` over its last dimension and scales by `weight`.
    ///
    /// A 0-dimensional tensor is treated as one row of length 1. A zero-length
    /// last dimension yields an (empty) copy instead of dividing by zero.
    ///
    /// # Panics
    /// Panics if `weight` has fewer elements than the last dimension.
    fn fused_rms_norm(&self, weight: &Self, eps: f32) -> Result<Self> {
        let shape = self.shape();
        let d = shape.last().copied().unwrap_or(1);
        let mut data = self.data.to_vec();
        if d > 0 {
            assert!(
                weight.elem_count() >= d,
                "fused_rms_norm: weight is shorter than the last dimension"
            );
            let weight = &weight.data[..d];
            for row in data.chunks_exact_mut(d) {
                rms_norm_row_in_place(row, weight, eps);
            }
        }
        Ok(CpuTensor::<f32>::from_slice(&data, shape, DType::F32))
    }

    /// Computes `rms_norm(self + residual) * weight` over the last dimension.
    ///
    /// The sum is materialised once and then normalised in place, row by row.
    /// Shape edge cases are handled as in `fused_rms_norm`.
    ///
    /// # Panics
    /// Panics if `self` and `residual` have different element counts, or if
    /// `weight` has fewer elements than the last dimension.
    fn fused_add_rms_norm(&self, residual: &Self, weight: &Self, eps: f32) -> Result<Self> {
        let shape = self.shape();
        assert_eq!(self.elem_count(), residual.elem_count(), "fused_add_rms_norm: shape mismatch");
        let d = shape.last().copied().unwrap_or(1);
        let mut data: Vec<f32> = self
            .data
            .iter()
            .zip(residual.data.iter())
            .map(|(&a, &b)| a + b)
            .collect();
        if d > 0 {
            assert!(
                weight.elem_count() >= d,
                "fused_add_rms_norm: weight is shorter than the last dimension"
            );
            let weight = &weight.data[..d];
            for row in data.chunks_exact_mut(d) {
                rms_norm_row_in_place(row, weight, eps);
            }
        }
        Ok(CpuTensor::<f32>::from_slice(&data, shape, DType::F32))
    }

    /// Applies a rotate-half rotary embedding to every contiguous `head_dim`-sized
    /// chunk of `self`.
    ///
    /// For each chunk and each `i < head_dim / 2`, the pair `(x[i], x[i + head_dim/2])`
    /// is rotated by the angle `freqs[i]`. Elements that belong to no pair are copied
    /// through unchanged. That covers the last element of an odd `head_dim`, a trailing
    /// partial chunk, and everything when `head_dim < 2`.
    ///
    /// NOTE(review): `freqs` is indexed only by the pair index, so every chunk gets the
    /// same angles. If per-position angles are intended, confirm against the GPU kernel.
    ///
    /// # Panics
    /// Panics if at least one full chunk exists and `freqs` has fewer than
    /// `head_dim / 2` elements.
    fn fused_rotary_emb(&self, freqs: &Self, head_dim: usize) -> Result<Self> {
        let half_dim = head_dim / 2;
        let mut data = self.data.to_vec();
        if half_dim > 0 {
            for head in data.chunks_exact_mut(head_dim) {
                // `hi` has one extra element when head_dim is odd; zip stops at half_dim.
                let (lo, hi) = head.split_at_mut(half_dim);
                for (i, (x0, x1)) in lo.iter_mut().zip(hi.iter_mut()).enumerate() {
                    let (sin_val, cos_val) = freqs.data[i].sin_cos();
                    let (a, b) = (*x0, *x1);
                    *x0 = a * cos_val - b * sin_val;
                    *x1 = a * sin_val + b * cos_val;
                }
            }
        }
        Ok(CpuTensor::<f32>::from_slice(&data, self.shape(), DType::F32))
    }

    /// Element-wise `max(self + other, 0)`.
    ///
    /// # Panics
    /// Panics if `self` and `other` have different element counts.
    fn fused_add_relu(&self, other: &Self) -> Result<Self> {
        let count = self.elem_count();
        assert_eq!(count, other.elem_count(), "fused_add_relu: shape mismatch");
        let data: Vec<f32> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| (a + b).max(0.0))
            .collect();
        Ok(CpuTensor::<f32>::from_slice(&data, self.shape(), DType::F32))
    }

    /// Element-wise `gelu(self + bias)`, with `bias` repeated cyclically over the
    /// flattened input.
    ///
    /// # Panics
    /// Panics if `bias` is empty while `self` is not.
    fn fused_bias_gelu(&self, bias: &Self) -> Result<Self> {
        assert!(
            bias.elem_count() > 0 || self.elem_count() == 0,
            "fused_bias_gelu: bias must not be empty"
        );
        let data: Vec<f32> = self
            .data
            .iter()
            .zip(bias.data.iter().cycle())
            .map(|(&x, &b)| gelu(x + b))
            .collect();
        Ok(CpuTensor::<f32>::from_slice(&data, self.shape(), DType::F32))
    }
}