use super::{Backend, ComputeOp};
use crate::error::TruenoError;
/// Weight matrices for a fused Q/K/V projection, stored row-major as flat
/// `f32` buffers (row `i` of a matrix occupies `w[i * hidden_size..(i + 1) * hidden_size]`).
#[derive(Debug, Clone)]
pub struct FusedQKVWeights {
// Query projection matrix; execute() reads `hidden_size * hidden_size` elements.
pub q_weight: Vec<f32>,
// Key projection matrix; execute() reads `kv_dim * hidden_size` elements.
pub k_weight: Vec<f32>,
// Value projection matrix; execute() reads `kv_dim * hidden_size` elements.
pub v_weight: Vec<f32>,
}
/// Fused Q/K/V projection for attention layers.
///
/// Computes the query, key, and value projections of a single hidden-state
/// vector in one operation. Supports grouped-query attention: when
/// `num_kv_heads < num_heads`, `kv_dim` is smaller than `hidden_size`.
#[derive(Debug, Clone)]
pub struct FusedQKVOp {
    /// Model hidden dimension — length of the input vector and of Q.
    pub hidden_size: usize,
    /// Combined K/V output dimension: `num_kv_heads * head_dim`.
    pub kv_dim: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Per-head dimension: `hidden_size / num_heads`.
    pub head_dim: usize,
}

impl FusedQKVOp {
    /// Builds the op from model hyperparameters.
    ///
    /// # Panics
    /// Panics if `num_heads == 0` (integer division by zero). In debug
    /// builds, also asserts that `hidden_size` is divisible by `num_heads` —
    /// otherwise `head_dim` would be silently truncated and `kv_dim` would
    /// no longer be consistent with the weight layout.
    pub fn new(hidden_size: usize, num_heads: usize, num_kv_heads: usize) -> Self {
        debug_assert!(num_heads > 0, "num_heads must be non-zero");
        debug_assert_eq!(
            hidden_size % num_heads,
            0,
            "hidden_size must be divisible by num_heads"
        );
        let head_dim = hidden_size / num_heads;
        Self { hidden_size, kv_dim: num_kv_heads * head_dim, num_heads, head_dim }
    }
}
impl ComputeOp for FusedQKVOp {
    type Input = (Vec<f32>, FusedQKVWeights);
    type Output = (Vec<f32>, Vec<f32>, Vec<f32>);

    fn name(&self) -> &'static str {
        "fused_qkv"
    }

    /// Projects `x` through the Q, K, and V weight matrices in one call,
    /// returning `(q, k, v)` of lengths `hidden_size`, `kv_dim`, `kv_dim`.
    ///
    /// Each output element is a dot product of `x` with one row-major weight
    /// row of length `hidden_size`.
    ///
    /// # Errors
    /// Returns `TruenoError::SizeMismatch` when `x` is not `hidden_size`
    /// long, or when any weight buffer is too short to hold its matrix.
    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
        let (x, weights) = input;
        let h = self.hidden_size;
        if x.len() != h {
            return Err(TruenoError::SizeMismatch { expected: h, actual: x.len() });
        }
        // Row-major matrix–vector product: `rows` dot products of length `h`.
        // Collecting from an iterator avoids the unsound
        // `with_capacity` + `set_len` pattern (uninitialized f32s), and
        // `chunks_exact` lets the optimizer elide per-row bounds checks.
        let matvec = |w: &[f32], rows: usize| -> Result<Vec<f32>, TruenoError> {
            let needed = rows * h;
            if w.len() < needed {
                return Err(TruenoError::SizeMismatch { expected: needed, actual: w.len() });
            }
            Ok(w.chunks_exact(h)
                .take(rows)
                .map(|row| super::attention::AttentionOp::simd_dot(&x, row))
                .collect())
        };
        let q = matvec(&weights.q_weight, h)?;
        let k = matvec(&weights.k_weight, self.kv_dim)?;
        let v = matvec(&weights.v_weight, self.kv_dim)?;
        Ok((q, k, v))
    }

    /// Work accounting: total number of output elements produced.
    fn tokens(&self, _input: &Self::Input) -> usize {
        self.hidden_size + 2 * self.kv_dim
    }
}
/// Weight matrices for the fused gate/up (SwiGLU-style) feed-forward
/// projection, stored row-major as flat `f32` buffers.
#[derive(Debug, Clone)]
pub struct FusedGateUpWeights {
// Gate projection matrix; execute() reads `intermediate_size * hidden_size` elements.
pub gate_weight: Vec<f32>,
// Up projection matrix; execute() reads `intermediate_size * hidden_size` elements.
pub up_weight: Vec<f32>,
}
/// Fused feed-forward gate/up op: computes `silu(W_gate · x) * (W_up · x)`
/// in a single pass over the input vector.
#[derive(Debug, Clone)]
pub struct FusedGateUpOp {
// Model hidden dimension — length of the input vector.
pub hidden_size: usize,
// Feed-forward intermediate dimension — length of the output vector.
pub intermediate_size: usize,
}
impl FusedGateUpOp {
    /// Creates a fused gate/up op for the given layer dimensions.
    pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
        Self { hidden_size, intermediate_size }
    }

    /// SiLU activation, `x * sigmoid(x)`, delegating to the scalar
    /// implementation in `crate::activations`.
    ///
    /// The call is bracketed by project contract macros
    /// (`contract_pre_silu!` / `contract_post_silu!`) — presumably
    /// pre-/post-condition checks; their exact semantics are defined
    /// elsewhere in the crate.
    #[inline]
    pub fn silu(x: f32) -> f32 {
        contract_pre_silu!();
        let activated = crate::activations::silu_scalar(x);
        contract_post_silu!(&[activated]);
        activated
    }
}
impl ComputeOp for FusedGateUpOp {
    type Input = (Vec<f32>, FusedGateUpWeights);
    type Output = Vec<f32>;

    fn name(&self) -> &'static str {
        "fused_gate_up"
    }

    /// Computes `silu(W_gate · x) * (W_up · x)` element-wise over the
    /// `intermediate_size` rows of both weight matrices.
    ///
    /// # Errors
    /// Returns `TruenoError::SizeMismatch` when `x` is not `hidden_size`
    /// long, or when either weight buffer is too short to hold its matrix.
    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
        let (x, weights) = input;
        let h = self.hidden_size;
        if x.len() != h {
            return Err(TruenoError::SizeMismatch { expected: h, actual: x.len() });
        }
        let needed = self.intermediate_size * h;
        if weights.gate_weight.len() < needed {
            return Err(TruenoError::SizeMismatch {
                expected: needed,
                actual: weights.gate_weight.len(),
            });
        }
        if weights.up_weight.len() < needed {
            return Err(TruenoError::SizeMismatch {
                expected: needed,
                actual: weights.up_weight.len(),
            });
        }
        // Collecting from an iterator avoids the unsound
        // `with_capacity` + `set_len` pattern (uninitialized f32s), and
        // `chunks_exact` walks both matrices row by row without per-element
        // bounds checks.
        let output = weights
            .gate_weight
            .chunks_exact(h)
            .zip(weights.up_weight.chunks_exact(h))
            .take(self.intermediate_size)
            .map(|(gate_row, up_row)| {
                let gate = super::attention::AttentionOp::simd_dot(&x, gate_row);
                let up = super::attention::AttentionOp::simd_dot(&x, up_row);
                Self::silu(gate) * up
            })
            .collect();
        Ok(output)
    }

    /// Work accounting: number of output elements produced.
    fn tokens(&self, _input: &Self::Input) -> usize {
        self.intermediate_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All-ones QKV weights with the given buffer lengths.
    fn ones_qkv(q_len: usize, kv_len: usize) -> FusedQKVWeights {
        FusedQKVWeights {
            q_weight: vec![1.0; q_len],
            k_weight: vec![1.0; kv_len],
            v_weight: vec![1.0; kv_len],
        }
    }

    /// All-ones gate/up weights with the given buffer length each.
    fn ones_gate_up(len: usize) -> FusedGateUpWeights {
        FusedGateUpWeights { gate_weight: vec![1.0; len], up_weight: vec![1.0; len] }
    }

    #[test]
    fn test_fused_qkv_basic() {
        let op = FusedQKVOp::new(4, 2, 1);
        let input = (vec![1.0, 2.0, 3.0, 4.0], ones_qkv(16, 8));
        let (q, k, v) = op.execute(input, Backend::Scalar).unwrap();
        assert_eq!(q.len(), 4);
        assert_eq!(k.len(), 2);
        assert_eq!(v.len(), 2);
    }

    #[test]
    fn test_fused_qkv_dimension_mismatch() {
        let op = FusedQKVOp::new(4, 2, 2);
        // Input vector is shorter than hidden_size, so execute must fail.
        let result = op.execute((vec![1.0, 2.0], ones_qkv(16, 8)), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_fused_gate_up_basic() {
        let op = FusedGateUpOp::new(4, 2);
        let input = (vec![1.0, 2.0, 3.0, 4.0], ones_gate_up(8));
        let output = op.execute(input, Backend::Scalar).unwrap();
        assert_eq!(output.len(), 2);
        // With all-ones weights each dot product is 10, so each output is
        // silu(10) * 10 ≈ 100.
        assert!(output[0] > 90.0 && output[0] < 110.0);
    }

    #[test]
    fn test_fused_gate_up_dimension_mismatch() {
        let op = FusedGateUpOp::new(4, 2);
        // Input vector is shorter than hidden_size, so execute must fail.
        let result = op.execute((vec![1.0, 2.0], ones_gate_up(8)), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_silu_values() {
        assert!(FusedGateUpOp::silu(0.0).abs() < 1e-6);
        assert!((FusedGateUpOp::silu(10.0) - 10.0).abs() < 0.01);
        assert!(FusedGateUpOp::silu(-10.0).abs() < 0.01);
    }

    #[test]
    fn test_fused_qkv_tokens() {
        let op = FusedQKVOp::new(128, 8, 4);
        // tokens() counts output elements: hidden_size + 2 * kv_dim.
        assert_eq!(op.tokens(&(vec![], ones_qkv(0, 0))), 256);
    }

    #[test]
    fn test_fused_gate_up_tokens() {
        let op = FusedGateUpOp::new(128, 256);
        assert_eq!(op.tokens(&(vec![], ones_gate_up(0))), 256);
    }
}