mod build_ptx;
/// Parameters describing one instance of the fused RMSNorm + gate/up + SwiGLU
/// Q4K kernel.
///
/// The struct is plain data: two `u32` dimensions and an `f32` epsilon. The
/// PTX generation that consumes these values is not defined in this block.
#[derive(Clone, Debug)]
pub struct FusedRmsNormGateUpSwigluQ4KKernel {
    /// First dimension parameter.
    ///
    /// NOTE(review): presumably the input (reduction) width of the
    /// projection; confirm against the PTX builder.
    pub k: u32,
    /// Second dimension parameter.
    ///
    /// NOTE(review): presumably the output width of the gate/up projections;
    /// confirm against the PTX builder.
    pub n: u32,
    /// Epsilon value carried alongside the dimensions.
    ///
    /// NOTE(review): presumably the RMSNorm numerical-stability term; confirm
    /// against the PTX builder.
    pub epsilon: f32,
}
impl FusedRmsNormGateUpSwigluQ4KKernel {
    /// Creates a kernel description with the given `k` and `n` dimensions.
    ///
    /// `epsilon` is initialised to `1e-6`; chain [`Self::with_epsilon`] to
    /// override it. This is a `const fn`, like `with_epsilon`, so a fully
    /// configured description can be built in `const`/`static` context.
    #[must_use]
    pub const fn new(k: u32, n: u32) -> Self {
        Self {
            k,
            n,
            // Default used when the caller does not call `with_epsilon`.
            epsilon: 1e-6,
        }
    }

    /// Returns `self` with `epsilon` replaced; `k` and `n` are left unchanged.
    #[must_use]
    pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
        self.epsilon = epsilon;
        self
    }
}
#[cfg(test)]
mod tests_3way_fusion {
    use super::*;
    use crate::kernels::Kernel;

    /// The emitted PTX mentions the kernel name and contains an `.entry`
    /// directive.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_builds() {
        let ptx = FusedRmsNormGateUpSwigluQ4KKernel::new(3584, 18944).emit_ptx();
        assert!(ptx.contains("fused_rmsnorm_gate_up_swiglu_q4k"));
        assert!(ptx.contains(".entry"));
    }

    /// `name()` reports the expected kernel identifier.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_name() {
        let subject = FusedRmsNormGateUpSwigluQ4KKernel::new(1024, 4096);
        assert_eq!(subject.name(), "fused_rmsnorm_gate_up_swiglu_q4k");
    }

    /// A clone carries the same `k`, `n`, and `epsilon` as its source.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_clone() {
        let source = FusedRmsNormGateUpSwigluQ4KKernel::new(2048, 8192);
        let duplicate = source.clone();
        assert_eq!(duplicate.k, source.k);
        assert_eq!(duplicate.n, source.n);
        assert_eq!(duplicate.epsilon, source.epsilon);
    }

    /// `with_epsilon` stores the value it is given.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_with_epsilon() {
        let configured = FusedRmsNormGateUpSwigluQ4KKernel::new(2048, 8192).with_epsilon(1e-5);
        assert_eq!(configured.epsilon, 1e-5);
    }

    /// The `Debug` rendering includes the type name and both dimensions.
    #[test]
    fn test_fused_rmsnorm_gate_up_swiglu_q4k_kernel_debug() {
        let subject = FusedRmsNormGateUpSwigluQ4KKernel::new(2048, 8192);
        let rendered = format!("{subject:?}");
        for needle in ["FusedRmsNormGateUpSwigluQ4KKernel", "2048", "8192"] {
            assert!(rendered.contains(needle));
        }
    }
}