use super::{Backend, ComputeOp};
use crate::error::TruenoError;
/// Weight buffers consumed by [`FusedQKVOp`].
///
/// Every buffer is read as a row-major matrix with `hidden_size` columns
/// (element `(i, j)` lives at index `i * hidden_size + j`).
#[derive(Debug, Clone)]
pub struct FusedQKVWeights {
    /// Q projection, read as `hidden_size` rows of `hidden_size` values.
    pub q_weight: Vec<f32>,
    /// K projection, read as `kv_dim` rows of `hidden_size` values.
    pub k_weight: Vec<f32>,
    /// V projection, read as `kv_dim` rows of `hidden_size` values.
    pub v_weight: Vec<f32>,
}
/// Projects a single input vector into Q, K and V outputs in one call.
#[derive(Debug, Clone)]
pub struct FusedQKVOp {
    /// Length of the input vector and of the Q output.
    pub hidden_size: usize,
    /// Length of each of the K and V outputs (`num_kv_heads * head_dim` when built via `new`).
    pub kv_dim: usize,
    /// Head count given to `new`; not read by the projection itself.
    pub num_heads: usize,
    /// Per-head width, `hidden_size / num_heads` (integer division) when built via `new`.
    pub head_dim: usize,
}
impl FusedQKVOp {
    /// Builds the op from the input width and the two head counts.
    ///
    /// `head_dim` is `hidden_size / num_heads` (integer division, remainder
    /// discarded) and `kv_dim` is `num_kv_heads * head_dim`.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero (integer division by zero).
    pub fn new(hidden_size: usize, num_heads: usize, num_kv_heads: usize) -> Self {
        let per_head = hidden_size / num_heads;
        Self {
            hidden_size,
            kv_dim: num_kv_heads * per_head,
            num_heads,
            head_dim: per_head,
        }
    }
}
impl ComputeOp for FusedQKVOp {
    type Input = (Vec<f32>, FusedQKVWeights);
    type Output = (Vec<f32>, Vec<f32>, Vec<f32>);

    fn name(&self) -> &'static str {
        "fused_qkv"
    }

    /// Projects `x` through the Q, K and V weight matrices.
    ///
    /// Each weight buffer is read as a row-major matrix with `hidden_size`
    /// columns. `q_weight` must hold at least `hidden_size * hidden_size`
    /// values. `k_weight` and `v_weight` must each hold at least
    /// `kv_dim * hidden_size` values. Only that leading prefix is read, and
    /// extra trailing values are ignored.
    ///
    /// The computation is plain scalar code; `_backend` is not consulted.
    ///
    /// # Errors
    ///
    /// Returns [`TruenoError::SizeMismatch`] if `x.len() != hidden_size`, or if
    /// any weight buffer is shorter than its required size.
    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
        let (x, weights) = input;
        let cols = self.hidden_size;
        if x.len() != cols {
            return Err(TruenoError::SizeMismatch { expected: cols, actual: x.len() });
        }

        // Row-major mat-vec: out[i] = sum_j x[j] * w[i * cols + j].
        // Accumulates left-to-right starting at +0.0, matching a plain scalar
        // loop bit-for-bit. Index-based rows (rather than `chunks_exact`) keep
        // `cols == 0` well-defined (yields an all-zero output instead of panicking).
        let project = |w: &[f32], rows: usize| -> Result<Vec<f32>, TruenoError> {
            // Saturate so an absurd size can never wrap around and pass the check.
            let required = rows.saturating_mul(cols);
            if w.len() < required {
                return Err(TruenoError::SizeMismatch { expected: required, actual: w.len() });
            }
            Ok((0..rows)
                .map(|i| {
                    w[i * cols..(i + 1) * cols]
                        .iter()
                        .zip(&x)
                        .fold(0.0f32, |acc, (&wij, &xj)| acc + xj * wij)
                })
                .collect())
        };

        let q = project(&weights.q_weight, self.hidden_size)?;
        let k = project(&weights.k_weight, self.kv_dim)?;
        let v = project(&weights.v_weight, self.kv_dim)?;
        Ok((q, k, v))
    }

    /// Total number of output values produced: `hidden_size + 2 * kv_dim`.
    fn tokens(&self, _input: &Self::Input) -> usize {
        self.hidden_size + 2 * self.kv_dim
    }
}
/// Gate and up weight buffers consumed by [`FusedGateUpOp`].
///
/// Both buffers are read as row-major matrices with `hidden_size` columns and
/// `intermediate_size` rows (row `i` spans `i * hidden_size .. (i + 1) * hidden_size`).
#[derive(Debug, Clone)]
pub struct FusedGateUpWeights {
    /// Gate projection; each row's dot product with the input is passed through `silu`.
    pub gate_weight: Vec<f32>,
    /// Up projection; each row's dot product multiplies the activated gate value.
    pub up_weight: Vec<f32>,
}
/// Fused gate/up projection: output `i` is `silu(gate_row_i · x) * (up_row_i · x)`.
#[derive(Debug, Clone)]
pub struct FusedGateUpOp {
    /// Length of the input vector (number of columns in each weight matrix).
    pub hidden_size: usize,
    /// Length of the output vector (number of rows in each weight matrix).
    pub intermediate_size: usize,
}
impl FusedGateUpOp {
pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
Self { hidden_size, intermediate_size }
}
#[inline]
pub fn silu(x: f32) -> f32 {
crate::activations::silu_scalar(x)
}
}
impl ComputeOp for FusedGateUpOp {
    type Input = (Vec<f32>, FusedGateUpWeights);
    type Output = Vec<f32>;

    fn name(&self) -> &'static str {
        "fused_gate_up"
    }

    /// Computes `silu(gate_row_i · x) * (up_row_i · x)` for every row `i`.
    ///
    /// Both weight buffers are read as row-major matrices with `hidden_size`
    /// columns and `intermediate_size` rows. Each buffer must hold at least
    /// `intermediate_size * hidden_size` values, and extra trailing values are
    /// ignored. Dot products use `crate::Backend::select_best()`; `_backend` is
    /// not consulted.
    ///
    /// # Errors
    ///
    /// Returns [`TruenoError::SizeMismatch`] if `x.len() != hidden_size` or if
    /// either weight buffer is too short. Any error from a dot product is
    /// propagated rather than being replaced by `0.0`.
    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
        let (x, weights) = input;
        if x.len() != self.hidden_size {
            return Err(TruenoError::SizeMismatch { expected: self.hidden_size, actual: x.len() });
        }

        // Validate up front so the row slicing below cannot panic. Saturate so an
        // absurd size can never wrap around and pass the check.
        let required = self.intermediate_size.saturating_mul(self.hidden_size);
        for w in [&weights.gate_weight, &weights.up_weight] {
            if w.len() < required {
                return Err(TruenoError::SizeMismatch { expected: required, actual: w.len() });
            }
        }

        // Backend selection and input wrapping are loop-invariant; do them once.
        let simd_backend = crate::Backend::select_best();
        let x_vec = crate::Vector::from_slice_with_backend(&x, simd_backend);

        (0..self.intermediate_size)
            .map(|i| -> Result<f32, TruenoError> {
                let row = i * self.hidden_size..(i + 1) * self.hidden_size;
                let gate_row =
                    crate::Vector::from_slice_with_backend(&weights.gate_weight[row.clone()], simd_backend);
                let up_row = crate::Vector::from_slice_with_backend(&weights.up_weight[row], simd_backend);
                let gate_sum = x_vec.dot(&gate_row)?;
                let up_sum = x_vec.dot(&up_row)?;
                Ok(Self::silu(gate_sum) * up_sum)
            })
            .collect()
    }

    /// Number of output values produced: `intermediate_size`.
    fn tokens(&self, _input: &Self::Input) -> usize {
        self.intermediate_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fused_qkv_basic() {
        // hidden = 4, 2 heads -> head_dim = 2; 1 kv head -> kv_dim = 2.
        let op = FusedQKVOp::new(4, 2, 1);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weights = FusedQKVWeights {
            q_weight: vec![1.0; 4 * 4],
            k_weight: vec![1.0; 2 * 4],
            v_weight: vec![1.0; 2 * 4],
        };
        let (q, k, v) = op.execute((input, weights), Backend::Scalar).unwrap();
        assert_eq!([q.len(), k.len(), v.len()], [4, 2, 2]);
    }

    #[test]
    fn test_fused_qkv_dimension_mismatch() {
        let op = FusedQKVOp::new(4, 2, 2);
        let short_input = vec![1.0, 2.0];
        let weights = FusedQKVWeights {
            q_weight: vec![1.0; 16],
            k_weight: vec![1.0; 8],
            v_weight: vec![1.0; 8],
        };
        let outcome = op.execute((short_input, weights), Backend::Scalar);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_fused_gate_up_basic() {
        let op = FusedGateUpOp::new(4, 2);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weights = FusedGateUpWeights { gate_weight: vec![1.0; 2 * 4], up_weight: vec![1.0; 2 * 4] };
        let output = op.execute((input, weights), Backend::Scalar).unwrap();
        assert_eq!(output.len(), 2);
        // Each row sums to 10, so the value is silu(10) * 10, just under 100.
        let first = output[0];
        assert!(first > 90.0 && first < 110.0);
    }

    #[test]
    fn test_fused_gate_up_dimension_mismatch() {
        let op = FusedGateUpOp::new(4, 2);
        let short_input = vec![1.0, 2.0];
        let weights = FusedGateUpWeights { gate_weight: vec![1.0; 8], up_weight: vec![1.0; 8] };
        let outcome = op.execute((short_input, weights), Backend::Scalar);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_silu_values() {
        // (input, expected, tolerance)
        let cases = [(0.0_f32, 0.0_f32, 1e-6_f32), (10.0, 10.0, 0.01), (-10.0, 0.0, 0.01)];
        for (input, expected, tolerance) in cases {
            assert!((FusedGateUpOp::silu(input) - expected).abs() < tolerance);
        }
    }

    #[test]
    fn test_fused_qkv_tokens() {
        // head_dim = 128 / 8 = 16, kv_dim = 4 * 16 = 64 -> 128 + 2 * 64 = 256.
        let op = FusedQKVOp::new(128, 8, 4);
        let empty = FusedQKVWeights { q_weight: Vec::new(), k_weight: Vec::new(), v_weight: Vec::new() };
        assert_eq!(op.tokens(&(Vec::new(), empty)), 256);
    }

    #[test]
    fn test_fused_gate_up_tokens() {
        let op = FusedGateUpOp::new(128, 256);
        let empty = FusedGateUpWeights { gate_weight: Vec::new(), up_weight: Vec::new() };
        assert_eq!(op.tokens(&(Vec::new(), empty)), 256);
    }
}