#![cfg(test)]
use super::rope::{
BatchedRopeKernel, PreciseRopeIndirectKernel, PreciseRopeKernel, RopeIndirectKernel,
RopeKernel, RopeNeoxIndirectKernel, RopeNeoxKernel,
};
use crate::kernels::Kernel;
/// Builds the plain RoPE kernel and checks two things:
/// - its name and constructor arguments are exposed unchanged;
/// - the emitted PTX contains a version directive, a `rope` entry point and an
///   `x_ptr` parameter.
#[test]
fn test_rope_kernel_basic() {
    let rope = RopeKernel::new(8, 64, 10000.0);

    assert_eq!(rope.name(), "rope");
    assert_eq!(rope.num_heads, 8);
    assert_eq!(rope.head_dim, 64);
    assert_eq!(rope.theta, 10000.0);

    let ptx_text = rope.emit_ptx();
    for needle in [".version", ".entry rope", "x_ptr"] {
        assert!(ptx_text.contains(needle), "PTX is missing `{}`", needle);
    }
}
/// LLaMA-sized configuration (32 heads, head_dim 128, theta 10000).
///
/// Besides the `.version` smoke check, this verifies that the constructor kept
/// the requested geometry and that the PTX still declares the `rope` entry
/// point, as `test_rope_kernel_basic` does for the smaller configuration.
#[test]
fn test_rope_kernel_llama_config() {
    let kernel = RopeKernel::new(32, 128, 10000.0);
    assert_eq!(kernel.num_heads, 32);
    assert_eq!(kernel.head_dim, 128);
    assert_eq!(kernel.theta, 10000.0);

    let ptx = kernel.emit_ptx();
    assert!(ptx.contains(".version"));
    assert!(ptx.contains(".entry rope"));
}
/// Small configuration (4 heads, head_dim 32, theta 10000).
///
/// Checks that the geometry is stored as given and that the emitted PTX is
/// versioned and declares the `rope` entry point.
#[test]
fn test_rope_kernel_small() {
    let kernel = RopeKernel::new(4, 32, 10000.0);
    assert_eq!(kernel.num_heads, 4);
    assert_eq!(kernel.head_dim, 32);
    assert_eq!(kernel.theta, 10000.0);

    let ptx = kernel.emit_ptx();
    assert!(ptx.contains(".version"));
    assert!(ptx.contains(".entry rope"));
}
/// The indirect variant must advertise "indirect" in its name and expose a
/// `pos_ptr` parameter in the generated PTX.
#[test]
fn test_rope_indirect_kernel() {
    let indirect = RopeIndirectKernel::new(8, 64, 10000.0);
    let name_ok = indirect.name().contains("indirect");
    assert!(name_ok);

    let module_src = indirect.emit_ptx();
    assert!(module_src.contains(".version"));
    assert!(module_src.contains("pos_ptr"));
}
/// Smoke test: the indirect variant emits versioned PTX at 32 heads × 128 dims.
#[test]
fn test_rope_indirect_kernel_large() {
    let (num_heads, head_dim, theta) = (32, 128, 10000.0);
    let large = RopeIndirectKernel::new(num_heads, head_dim, theta);
    assert!(large.emit_ptx().contains(".version"));
}
/// The NeoX variant reports "neox" in its name and emits versioned PTX.
#[test]
fn test_rope_neox_kernel() {
    let neox = RopeNeoxKernel::new(8, 64, 10000.0);
    assert!(neox.name().contains("neox"));

    let generated = neox.emit_ptx();
    assert!(generated.contains(".version"));
}
/// Smoke test of the NeoX variant with the GPT-named geometry used here
/// (16 heads × 96 dims, theta 10000).
#[test]
fn test_rope_neox_kernel_gpt_config() {
    let (num_heads, head_dim, theta) = (16, 96, 10000.0);
    let gpt_style = RopeNeoxKernel::new(num_heads, head_dim, theta);
    assert!(gpt_style.emit_ptx().contains(".version"));
}
/// The NeoX indirect variant must carry "indirect" in its name and declare a
/// `pos_ptr` parameter alongside the PTX version directive.
#[test]
fn test_rope_neox_indirect_kernel() {
    let neox_indirect = RopeNeoxIndirectKernel::new(8, 64, 10000.0);
    assert!(neox_indirect.name().contains("indirect"));

    let emitted = neox_indirect.emit_ptx();
    for required in [".version", "pos_ptr"] {
        assert!(emitted.contains(required), "PTX is missing `{}`", required);
    }
}
/// The batched variant names itself "batched", stores the requested batch size
/// and emits versioned PTX.
#[test]
fn test_batched_rope_kernel() {
    let batched = BatchedRopeKernel::new(8, 64, 4, 10000.0);
    let labelled_batched = batched.name().contains("batched");
    assert!(labelled_batched);
    assert_eq!(batched.batch_size, 4);

    let ptx_source = batched.emit_ptx();
    assert!(ptx_source.contains(".version"));
}
/// Large-batch configuration (16 heads × 128 dims, batch 32).
///
/// The original only checked for a PTX version directive. It now also checks
/// that the large batch size was actually recorded on the kernel, which is the
/// property the test name refers to.
#[test]
fn test_batched_rope_kernel_large_batch() {
    let kernel = BatchedRopeKernel::new(16, 128, 32, 10000.0);
    assert_eq!(kernel.batch_size, 32);

    let ptx = kernel.emit_ptx();
    assert!(ptx.contains(".version"));
}
/// The precise variant, built with theta = 1e6, reports "precise" in its name
/// and emits versioned PTX.
#[test]
fn test_precise_rope_kernel() {
    let precise = PreciseRopeKernel::new(8, 64, 1_000_000.0);
    assert!(precise.name().contains("precise"));

    let text = precise.emit_ptx();
    assert!(text.contains(".version"));
}
/// Smoke test of the precise variant with the Qwen-named geometry used here
/// (28 heads × 128 dims, theta 1e6).
#[test]
fn test_precise_rope_kernel_qwen_config() {
    let (num_heads, head_dim, theta) = (28, 128, 1_000_000.0);
    let qwen_like = PreciseRopeKernel::new(num_heads, head_dim, theta);
    assert!(qwen_like.emit_ptx().contains(".version"));
}
/// The precise indirect variant carries "indirect" in its name and its PTX
/// declares a `pos_ptr` parameter next to the version directive.
#[test]
fn test_precise_rope_indirect_kernel() {
    let precise_indirect = PreciseRopeIndirectKernel::new(8, 64, 1_000_000.0);
    let is_indirect = precise_indirect.name().contains("indirect");
    assert!(is_indirect);

    let body = precise_indirect.emit_ptx();
    assert!(body.contains(".version"));
    assert!(body.contains("pos_ptr"));
}
/// Sweeps every RoPE kernel variant over several (num_heads, head_dim, theta)
/// configurations.
///
/// For each combination it checks two things:
/// - the variant's name marker, which the single-config tests in this file
///   check only once;
/// - that the emitted PTX has a `.version` directive.
///
/// Every assertion message carries the configuration, so a failure points at
/// the offending combination.
#[test]
fn test_all_rope_kernel_variants() {
    // A fixed array suffices for iteration; `vec!` only added a heap
    // allocation (clippy::useless_vec).
    let configs = [(8, 64, 10000.0), (16, 128, 10000.0), (32, 64, 1_000_000.0)];
    for (num_heads, head_dim, theta) in configs {
        let cfg = (num_heads, head_dim, theta);

        let k1 = RopeKernel::new(num_heads, head_dim, theta);
        assert_eq!(k1.name(), "rope", "config {:?}", cfg);
        assert!(k1.emit_ptx().contains(".version"), "rope PTX, config {:?}", cfg);

        let k2 = RopeIndirectKernel::new(num_heads, head_dim, theta);
        assert!(k2.name().contains("indirect"), "config {:?}", cfg);
        assert!(k2.emit_ptx().contains(".version"), "rope indirect PTX, config {:?}", cfg);

        let k3 = RopeNeoxKernel::new(num_heads, head_dim, theta);
        assert!(k3.name().contains("neox"), "config {:?}", cfg);
        assert!(k3.emit_ptx().contains(".version"), "neox PTX, config {:?}", cfg);

        let k4 = RopeNeoxIndirectKernel::new(num_heads, head_dim, theta);
        assert!(k4.name().contains("indirect"), "config {:?}", cfg);
        assert!(k4.emit_ptx().contains(".version"), "neox indirect PTX, config {:?}", cfg);

        let k5 = BatchedRopeKernel::new(num_heads, head_dim, 4, theta);
        assert!(k5.name().contains("batched"), "config {:?}", cfg);
        assert!(k5.emit_ptx().contains(".version"), "batched PTX, config {:?}", cfg);

        let k6 = PreciseRopeKernel::new(num_heads, head_dim, theta);
        assert!(k6.name().contains("precise"), "config {:?}", cfg);
        assert!(k6.emit_ptx().contains(".version"), "precise PTX, config {:?}", cfg);

        let k7 = PreciseRopeIndirectKernel::new(num_heads, head_dim, theta);
        assert!(k7.name().contains("indirect"), "config {:?}", cfg);
        assert!(k7.emit_ptx().contains(".version"), "precise indirect PTX, config {:?}", cfg);
    }
}
/// `RopeKernel::new` must store the theta base exactly as passed, for each of
/// the bases exercised here.
#[test]
fn test_rope_theta_values() {
    for base in [10000.0, 500000.0, 1_000_000.0] {
        let kernel = RopeKernel::new(8, 64, base);
        assert_eq!(kernel.theta, base);
    }
}