use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::{DnnError, DnnResult};
/// Element type used to store tensors for ring attention.
///
/// Arithmetic is always accumulated in `f32` (see [`RingAttentionDtype::ptx_accum_type`]);
/// this type only selects the storage format and its size.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RingAttentionDtype {
    /// Half precision, 2 bytes per element.
    F16,
    /// bfloat16, 2 bytes per element.
    BF16,
    /// Single precision, 4 bytes per element.
    F32,
}

impl RingAttentionDtype {
    /// Size of one stored element in bytes.
    #[must_use]
    pub const fn bytes(&self) -> usize {
        match *self {
            Self::F32 => 4,
            Self::F16 | Self::BF16 => 2,
        }
    }

    /// PTX type used for accumulation — `f32` for every storage dtype.
    #[must_use]
    pub const fn ptx_accum_type(&self) -> PtxType {
        PtxType::F32
    }

    /// PTX type that matches this storage dtype one-to-one.
    #[must_use]
    pub const fn ptx_storage_type(&self) -> PtxType {
        match *self {
            Self::F32 => PtxType::F32,
            Self::BF16 => PtxType::BF16,
            Self::F16 => PtxType::F16,
        }
    }
}
/// Configuration for ring attention: a sequence of `seq_len` tokens sharded
/// evenly across `num_devices` devices connected in a ring.
#[derive(Debug, Clone)]
pub struct RingAttentionConfig {
    /// Per-head feature dimension.
    pub head_dim: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Total (global) sequence length across all devices.
    pub seq_len: usize,
    /// Number of devices in the ring.
    pub num_devices: usize,
    /// Tokens held by each device; must equal `seq_len / num_devices`.
    pub chunk_size: usize,
    /// Softmax scale applied to the attention scores (commonly `1 / sqrt(head_dim)`).
    /// Must be finite.
    pub sm_scale: f32,
    /// Whether causal masking is requested.
    pub causal: bool,
    /// Storage element type (drives [`RingAttentionConfig::bytes_per_chunk`]).
    pub dtype: RingAttentionDtype,
}

impl RingAttentionConfig {
    /// Checks that the configuration describes a valid, evenly sharded ring.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidArgument`] when any dimension is zero,
    /// `seq_len` is not divisible by `num_devices`, `chunk_size` differs from
    /// `seq_len / num_devices`, or `sm_scale` is NaN or infinite.
    pub fn validate(&self) -> DnnResult<()> {
        if self.head_dim == 0 {
            return Err(DnnError::InvalidArgument("head_dim must be > 0".into()));
        }
        if self.num_heads == 0 {
            return Err(DnnError::InvalidArgument("num_heads must be > 0".into()));
        }
        if self.seq_len == 0 {
            return Err(DnnError::InvalidArgument("seq_len must be > 0".into()));
        }
        if self.num_devices == 0 {
            return Err(DnnError::InvalidArgument("num_devices must be > 0".into()));
        }
        if self.seq_len % self.num_devices != 0 {
            return Err(DnnError::InvalidArgument(format!(
                "seq_len ({}) must be divisible by num_devices ({})",
                self.seq_len, self.num_devices,
            )));
        }
        let expected_chunk = self.seq_len / self.num_devices;
        if self.chunk_size != expected_chunk {
            return Err(DnnError::InvalidArgument(format!(
                "chunk_size ({}) != seq_len / num_devices ({})",
                self.chunk_size, expected_chunk,
            )));
        }
        // A NaN or infinite scale would turn every scaled score into NaN/inf,
        // so reject it up front rather than producing garbage at run time.
        if !self.sm_scale.is_finite() {
            return Err(DnnError::InvalidArgument(format!(
                "sm_scale ({}) must be finite",
                self.sm_scale,
            )));
        }
        Ok(())
    }

    /// Tokens per device, `seq_len / num_devices`.
    ///
    /// Treats `num_devices == 0` as 1 so this never divides by zero, even on
    /// an unvalidated config.
    #[must_use]
    pub fn chunk_seq_len(&self) -> usize {
        self.seq_len / self.num_devices.max(1)
    }

    /// Bytes occupied by one device's chunk of a single tensor across all
    /// heads: `num_heads * chunk_size * head_dim * dtype.bytes()`.
    #[must_use]
    pub fn bytes_per_chunk(&self) -> usize {
        self.num_heads * self.chunk_size * self.head_dim * self.dtype.bytes()
    }
}
/// One step of a device's ring-attention schedule.
#[derive(Debug, Clone)]
pub struct RingStep {
/// Position of this step in the schedule, in `0..num_devices`.
pub step_index: usize,
/// Device whose K/V chunk is processed at this step:
/// `(device_id + num_devices - step_index) % num_devices`.
pub kv_source_device: usize,
/// True when causal masking is enabled and `kv_source_device >= device_id`,
/// i.e. (assuming chunks are laid out in device order — TODO confirm) the
/// K/V chunk does not lie strictly before this device's queries.
pub needs_causal_mask: bool,
/// True for `step_index == 0`.
pub is_first_step: bool,
/// True for `step_index == num_devices - 1`.
pub is_last_step: bool,
}
/// Point-to-point communication plan for one device in the ring.
#[derive(Debug, Clone)]
pub struct RingCommPlan {
/// Next device in the ring: `(device_id + 1) % num_devices`.
pub send_to: usize,
/// Previous device in the ring: `(device_id + num_devices - 1) % num_devices`.
pub recv_from: usize,
/// Elements in one chunk: `num_heads * chunk_size * head_dim`.
pub chunk_elements: usize,
/// Chunk-sized transfers per step (2 — presumably one K and one V chunk).
pub transfers_per_step: usize,
}
/// Cost estimates produced by `RingAttentionPlan::stats`.
#[derive(Debug, Clone)]
pub struct RingAttentionStats {
/// Number of ring steps (equal to `num_devices`).
pub total_steps: usize,
/// `4 * num_heads * chunk_size^2 * head_dim` FLOPs per step.
pub compute_flops_per_step: u64,
/// Bytes moved per step: `2 * bytes_per_chunk`.
pub comm_bytes_per_step: u64,
/// Idealised speedup, equal to `num_devices`.
pub theoretical_speedup: f64,
/// FLOPs per communicated byte; `f64::INFINITY` when no bytes are moved.
pub compute_comm_ratio: f64,
}
#[derive(Debug, Clone)]
pub struct RingAttentionPlan {
config: RingAttentionConfig,
steps: Vec<RingStep>,
comm_plan: RingCommPlan,
}
impl RingAttentionPlan {
pub fn new(config: RingAttentionConfig) -> DnnResult<Self> {
config.validate()?;
let p = config.num_devices;
let steps = Self::build_steps(&config, 0);
let chunk_elements = config.num_heads * config.chunk_size * config.head_dim;
let comm_plan = RingCommPlan {
send_to: 1 % p,
recv_from: (p - 1) % p,
chunk_elements,
transfers_per_step: 2, };
Ok(Self {
config,
steps,
comm_plan,
})
}
fn build_steps(config: &RingAttentionConfig, device_id: usize) -> Vec<RingStep> {
let p = config.num_devices;
(0..p)
.map(|i| {
let kv_source = (device_id + p - i) % p;
let needs_causal = if config.causal {
kv_source >= device_id
} else {
false
};
RingStep {
step_index: i,
kv_source_device: kv_source,
needs_causal_mask: needs_causal,
is_first_step: i == 0,
is_last_step: i == p - 1,
}
})
.collect()
}
pub fn steps_for_device(&self, device_id: usize) -> Vec<&RingStep> {
if device_id >= self.config.num_devices {
return Vec::new();
}
if device_id == 0 {
self.steps.iter().collect()
} else {
self.steps.iter().collect()
}
}
pub fn comm_plan_for_device(&self, device_id: usize) -> RingCommPlan {
let p = self.config.num_devices;
RingCommPlan {
send_to: (device_id + 1) % p,
recv_from: (device_id + p - 1) % p,
chunk_elements: self.comm_plan.chunk_elements,
transfers_per_step: self.comm_plan.transfers_per_step,
}
}
#[must_use]
pub fn stats(&self) -> RingAttentionStats {
let c = &self.config;
let chunk = c.chunk_size as u64;
let hd = c.head_dim as u64;
let nh = c.num_heads as u64;
let compute_flops_per_step = 4 * nh * chunk * chunk * hd;
let comm_bytes_per_step = 2 * c.bytes_per_chunk() as u64;
let p = c.num_devices as f64;
let theoretical_speedup = p;
let compute_comm_ratio = if comm_bytes_per_step > 0 {
compute_flops_per_step as f64 / comm_bytes_per_step as f64
} else {
f64::INFINITY
};
RingAttentionStats {
total_steps: c.num_devices,
compute_flops_per_step,
comm_bytes_per_step,
theoretical_speedup,
compute_comm_ratio,
}
}
#[must_use]
pub fn describe(&self) -> String {
let c = &self.config;
let stats = self.stats();
let causal_str = if c.causal { "causal" } else { "non-causal" };
format!(
"RingAttention Plan\n\
-------------------\n\
devices : {}\n\
seq_len : {} (chunk {})\n\
heads : {}\n\
head_dim : {}\n\
dtype : {:?}\n\
mode : {}\n\
steps : {}\n\
FLOPs/step : {}\n\
comm bytes/step: {}\n\
compute/comm : {:.2}\n\
theoretical {}x speedup",
c.num_devices,
c.seq_len,
c.chunk_size,
c.num_heads,
c.head_dim,
c.dtype,
causal_str,
stats.total_steps,
stats.compute_flops_per_step,
stats.comm_bytes_per_step,
stats.compute_comm_ratio,
stats.theoretical_speedup,
)
}
#[must_use]
pub fn shared_memory_bytes(&self) -> usize {
let cs = self.config.chunk_size;
let hd = self.config.head_dim;
let elem = 4_usize;
let q_tile = cs * hd * elem;
let k_tile = cs * hd * elem;
let v_tile = cs * hd * elem;
let s_tile = cs * cs * elem;
q_tile + k_tile + v_tile + s_tile
}
#[must_use]
pub fn launch_params(&self) -> (usize, usize) {
let block_size = self.config.chunk_size.min(256);
let block_size = block_size.div_ceil(32) * 32;
let grid_size = self.config.num_heads;
(grid_size, block_size)
}
pub fn generate_local_attention_ptx(&self) -> DnnResult<String> {
let chunk_sz = self.config.chunk_size;
let head_d = self.config.head_dim;
let n_heads = self.config.num_heads;
let kernel_name = "ring_attn_local_fwd";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm80)
.param("q_ptr", PtxType::U64)
.param("k_ptr", PtxType::U64)
.param("v_ptr", PtxType::U64)
.param("o_ptr", PtxType::U64)
.param("lse_ptr", PtxType::U64) .param("max_ptr", PtxType::U64) .param("chunk_size", PtxType::U32)
.param("head_dim", PtxType::U32)
.param("num_heads", PtxType::U32)
.param("scale_bits", PtxType::U32) .body(move |b| {
let tid = b.global_thread_id_x();
let chunk_param = b.load_param_u32("chunk_size");
let head_dim_param = b.load_param_u32("head_dim");
let num_heads_param = b.load_param_u32("num_heads");
b.comment("=== Ring Attention: Local Forward Kernel ===");
b.comment(&format!(
"chunk_size={}, head_dim={}, num_heads={}",
chunk_sz, head_d, n_heads,
));
b.comment("grid.x = num_heads");
b.comment("block.x = min(chunk_size, 256)");
b.comment("");
b.comment("Each thread block handles one attention head.");
b.comment("Phase 1: Compute S = Q_chunk @ K_chunk^T, apply sm_scale.");
b.comment("Phase 2: Row-wise softmax via online algorithm (max + sum-exp).");
b.comment("Phase 3: Accumulate O = P @ V_chunk.");
b.comment("Phase 4: Store O, row-max, and log-sum-exp for later accumulation.");
let q_base = b.load_param_u64("q_ptr");
let k_base = b.load_param_u64("k_ptr");
let v_base = b.load_param_u64("v_ptr");
let o_base = b.load_param_u64("o_ptr");
let lse_base = b.load_param_u64("lse_ptr");
let max_base = b.load_param_u64("max_ptr");
let scale_bits = b.load_param_u32("scale_bits");
let chunk_param2 = chunk_param.clone();
b.if_lt_u32(tid, chunk_param, |b| {
b.comment("Compute head index from block id");
let head_idx = b.block_id_x();
b.comment("Compute offsets into Q, K, V, O tensors");
let head_stride = b.mul_lo_u32(chunk_param2, head_dim_param);
let head_offset = b.mul_lo_u32(head_idx, head_stride);
b.comment("Phase 1: Tiled Q×K^T with sm_scale");
b.comment(" For each query row q in [0, chunk_size):");
b.comment(" For each key col k in [0, chunk_size):");
b.comment(" S[q,k] = dot(Q[q,:], K[k,:]) * scale");
b.comment("Phase 2: Online softmax per query row");
b.comment(" m[q] = max_k S[q,k]");
b.comment(" l[q] = sum_k exp(S[q,k] - m[q])");
b.comment(" P[q,k] = exp(S[q,k] - m[q]) / l[q]");
b.comment("Phase 3: O = P × V");
b.comment("Phase 4: Store O, m, log(l) for accumulation");
let _ = (
q_base,
k_base,
v_base,
o_base,
lse_base,
max_base,
scale_bits,
head_offset,
num_heads_param,
);
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn generate_accumulate_ptx(&self) -> DnnResult<String> {
let chunk_sz = self.config.chunk_size;
let head_d = self.config.head_dim;
let kernel_name = "ring_attn_accumulate";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm80)
.param("accum_o_ptr", PtxType::U64) .param("accum_lse_ptr", PtxType::U64) .param("accum_max_ptr", PtxType::U64) .param("partial_o_ptr", PtxType::U64) .param("partial_lse_ptr", PtxType::U64) .param("partial_max_ptr", PtxType::U64) .param("chunk_size", PtxType::U32)
.param("head_dim", PtxType::U32)
.param("num_heads", PtxType::U32)
.body(move |b| {
let tid = b.global_thread_id_x();
let chunk_param = b.load_param_u32("chunk_size");
let head_dim_param = b.load_param_u32("head_dim");
let num_heads_param = b.load_param_u32("num_heads");
b.comment("=== Ring Attention: Online Softmax Accumulation ===");
b.comment(&format!("chunk_size={}, head_dim={}", chunk_sz, head_d,));
b.comment("Combines partial attention from a new KV chunk with the");
b.comment("running accumulator using the log-sum-exp rescaling trick.");
b.comment("");
b.comment("For each query row q:");
b.comment(" new_max = max(accum_max[q], partial_max[q])");
b.comment(" s_old = exp(accum_max[q] - new_max)");
b.comment(" s_new = exp(partial_max[q] - new_max)");
b.comment(" accum_o[q,:] = accum_o[q,:] * s_old + partial_o[q,:] * s_new");
b.comment(" accum_lse[q] = log(accum_lse[q]*s_old + partial_lse[q]*s_new)");
b.comment(" accum_max[q] = new_max");
let total_elems = b.mul_lo_u32(chunk_param, head_dim_param);
let total_work = b.mul_lo_u32(total_elems, num_heads_param);
let accum_o = b.load_param_u64("accum_o_ptr");
let accum_lse = b.load_param_u64("accum_lse_ptr");
let accum_max = b.load_param_u64("accum_max_ptr");
let partial_o = b.load_param_u64("partial_o_ptr");
let partial_lse = b.load_param_u64("partial_lse_ptr");
let partial_max = b.load_param_u64("partial_max_ptr");
b.if_lt_u32(tid, total_work, |b| {
b.comment("Step 1: Load accum_max[q] and partial_max[q]");
b.comment("Step 2: new_max = max(accum_max, partial_max)");
b.comment("Step 3: s_old = exp(accum_max - new_max)");
b.comment("Step 4: s_new = exp(partial_max - new_max)");
b.comment("Step 5: accum_o = accum_o * s_old + partial_o * s_new");
b.comment("Step 6: Update accum_lse and accum_max");
let _ = (
accum_o,
accum_lse,
accum_max,
partial_o,
partial_lse,
partial_max,
);
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn generate_causal_mask_ptx(&self, step: &RingStep) -> DnnResult<String> {
let step_idx = step.step_index;
let kv_src_dev = step.kv_source_device;
let kernel_name = format!("ring_attn_causal_mask_step{}", step_idx);
let ptx = KernelBuilder::new(&kernel_name)
.target(SmVersion::Sm80)
.param("scores_ptr", PtxType::U64) .param("chunk_size", PtxType::U32)
.param("q_offset", PtxType::U32) .param("kv_offset", PtxType::U32) .body(move |b| {
let tid = b.global_thread_id_x();
let chunk_param = b.load_param_u32("chunk_size");
b.comment("=== Ring Attention: Causal Mask Kernel ===");
b.comment(&format!(
"step={}, kv_source_device={}",
step_idx, kv_src_dev,
));
b.comment("For each (q_local, k_local) in the score matrix:");
b.comment(" q_global = q_offset + q_local");
b.comment(" k_global = kv_offset + k_local");
b.comment(" if q_global < k_global: S[q_local, k_local] = -inf");
let chunk_param2 = chunk_param.clone();
let total_scores = b.mul_lo_u32(chunk_param, chunk_param2);
let scores_base = b.load_param_u64("scores_ptr");
let q_offset_param = b.load_param_u32("q_offset");
let kv_offset_param = b.load_param_u32("kv_offset");
let chunk_param_inner = b.load_param_u32("chunk_size");
let tid_inner = tid.clone();
b.if_lt_u32(tid, total_scores, |b| {
b.comment("Decompose tid into (q_local, k_local)");
let q_local = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!(
"div.u32 {q_local}, {tid_inner}, {chunk_param_inner};"
));
let k_local = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!(
"rem.u32 {k_local}, {tid_inner}, {chunk_param_inner};"
));
b.comment("Compute global positions");
let q_global = b.add_u32(q_offset_param, q_local);
let k_global = b.add_u32(kv_offset_param, k_local);
b.comment("If q_global < k_global, mask to -inf");
let neg_inf_bits = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {neg_inf_bits}, 0xFF800000;"));
let neg_inf = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {neg_inf}, {neg_inf_bits};"));
let byte_off_u32 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shl.b32 {byte_off_u32}, {tid_inner}, 2;"));
let byte_off_u64 = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {byte_off_u64}, {byte_off_u32};"));
let addr = b.add_u64(scores_base, byte_off_u64);
b.comment("Conditional store of -inf");
let pred = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lt.u32 {pred}, {q_global}, {k_global};"));
b.raw_ptx(&format!("@{pred} st.global.f32 [{addr}], {neg_inf};"));
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
}
// Unit tests covering config validation, ring step scheduling, communication
// topology, cost statistics, launch parameters and PTX generation.
#[cfg(test)]
mod tests {
use super::*;
// Baseline config: 1024 tokens, 8 heads of dim 64, F32, non-causal,
// split evenly over `num_devices`.
fn default_config(num_devices: usize) -> RingAttentionConfig {
let seq_len = 1024;
let chunk = seq_len / num_devices;
RingAttentionConfig {
head_dim: 64,
num_heads: 8,
seq_len,
num_devices,
chunk_size: chunk,
sm_scale: 1.0 / (64.0_f32).sqrt(),
causal: false,
dtype: RingAttentionDtype::F32,
}
}
#[test]
fn valid_config_passes() {
let cfg = default_config(4);
assert!(cfg.validate().is_ok());
}
#[test]
fn invalid_head_dim_zero() {
let mut cfg = default_config(4);
cfg.head_dim = 0;
assert!(cfg.validate().is_err());
}
// 1000 tokens cannot be split evenly across 3 devices.
#[test]
fn invalid_seq_len_not_divisible() {
let cfg = RingAttentionConfig {
head_dim: 64,
num_heads: 8,
seq_len: 1000,
num_devices: 3,
chunk_size: 333,
sm_scale: 0.125,
causal: false,
dtype: RingAttentionDtype::F32,
};
assert!(cfg.validate().is_err());
}
// chunk_size must equal seq_len / num_devices (256 here), so 100 is rejected.
#[test]
fn invalid_chunk_size_mismatch() {
let mut cfg = default_config(4);
cfg.chunk_size = 100; assert!(cfg.validate().is_err());
}
#[test]
fn chunk_seq_len_correct() {
let cfg = default_config(4);
assert_eq!(cfg.chunk_seq_len(), 256);
}
// bytes_per_chunk = num_heads * chunk_size * head_dim * dtype bytes.
#[test]
fn bytes_per_chunk_f32() {
let cfg = default_config(4);
assert_eq!(cfg.bytes_per_chunk(), 8 * 256 * 64 * 4);
}
#[test]
fn bytes_per_chunk_f16() {
let mut cfg = default_config(4);
cfg.dtype = RingAttentionDtype::F16;
assert_eq!(cfg.bytes_per_chunk(), 8 * 256 * 64 * 2);
}
#[test]
fn bytes_per_chunk_bf16() {
let mut cfg = default_config(4);
cfg.dtype = RingAttentionDtype::BF16;
assert_eq!(cfg.bytes_per_chunk(), 8 * 256 * 64 * 2);
}
#[test]
fn ring_steps_2_devices() {
let cfg = default_config(2);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let steps = plan.steps_for_device(0);
assert_eq!(steps.len(), 2);
assert_eq!(steps[0].kv_source_device, 0);
assert!(steps[0].is_first_step);
assert!(steps[1].is_last_step);
}
// Device 0 sees (0 - i) mod 4 at step i: sources 0, 3, 2, 1.
#[test]
fn ring_steps_4_devices() {
let cfg = default_config(4);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let steps = plan.steps_for_device(0);
assert_eq!(steps.len(), 4);
assert_eq!(steps[0].kv_source_device, 0);
assert_eq!(steps[1].kv_source_device, 3);
assert_eq!(steps[2].kv_source_device, 2);
assert_eq!(steps[3].kv_source_device, 1);
}
#[test]
fn ring_steps_8_devices() {
let mut cfg = default_config(8);
cfg.seq_len = 2048;
cfg.chunk_size = 256;
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let steps = plan.steps_for_device(0);
assert_eq!(steps.len(), 8);
assert!(steps[0].is_first_step);
assert!(steps[7].is_last_step);
}
// For device 0 every source satisfies kv_source >= 0, so step 0 (source 0)
// and step 1 (source 3) are both flagged.
#[test]
fn causal_mask_detection() {
let mut cfg = default_config(4);
cfg.causal = true;
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let steps = plan.steps_for_device(0);
assert!(steps[0].needs_causal_mask);
assert!(steps[1].needs_causal_mask);
}
// Each device sends to d + 1 and receives from d - 1, modulo the ring size.
#[test]
fn comm_plan_ring_topology() {
let cfg = default_config(4);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let cp0 = plan.comm_plan_for_device(0);
assert_eq!(cp0.send_to, 1);
assert_eq!(cp0.recv_from, 3);
assert_eq!(cp0.transfers_per_step, 2);
let cp2 = plan.comm_plan_for_device(2);
assert_eq!(cp2.send_to, 3);
assert_eq!(cp2.recv_from, 1);
}
#[test]
fn local_attention_ptx_has_entry() {
let cfg = default_config(2);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let ptx = plan.generate_local_attention_ptx().expect("ptx gen failed");
assert!(ptx.contains(".entry"));
assert!(ptx.contains("ring_attn_local_fwd"));
}
#[test]
fn accumulate_ptx_generation() {
let cfg = default_config(2);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let ptx = plan.generate_accumulate_ptx().expect("ptx gen failed");
assert!(ptx.contains(".entry"));
assert!(ptx.contains("ring_attn_accumulate"));
}
#[test]
fn causal_mask_ptx_generation() {
let mut cfg = default_config(2);
cfg.causal = true;
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let steps = plan.steps_for_device(0);
let ptx = plan
.generate_causal_mask_ptx(steps[0])
.expect("ptx gen failed");
assert!(ptx.contains(".entry"));
assert!(ptx.contains("ring_attn_causal_mask"));
}
#[test]
fn stats_computation() {
let cfg = default_config(4);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let stats = plan.stats();
assert_eq!(stats.total_steps, 4);
assert!(stats.compute_flops_per_step > 0);
assert!(stats.comm_bytes_per_step > 0);
assert!((stats.theoretical_speedup - 4.0).abs() < f64::EPSILON);
assert!(stats.compute_comm_ratio > 0.0);
}
#[test]
fn describe_output_format() {
let cfg = default_config(4);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let desc = plan.describe();
assert!(desc.contains("RingAttention Plan"));
assert!(desc.contains("devices"));
assert!(desc.contains("seq_len"));
assert!(desc.contains("non-causal"));
}
// Q, K, V tiles (chunk x head_dim) plus the chunk x chunk score tile, 4 bytes each.
#[test]
fn shared_memory_bytes_calculation() {
let cfg = default_config(4);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let smem = plan.shared_memory_bytes();
assert_eq!(smem, 3 * 256 * 64 * 4 + 256 * 256 * 4);
}
// grid = num_heads; block = min(chunk_size, 256) rounded up to a multiple of 32.
#[test]
fn launch_params_values() {
let cfg = default_config(4);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let (grid, block) = plan.launch_params();
assert_eq!(grid, 8); assert_eq!(block, 256); }
// With one device the ring collapses: a single step that is both first and
// last, and the device talks only to itself.
#[test]
fn single_device_degenerates() {
let cfg = default_config(1);
let plan = RingAttentionPlan::new(cfg).expect("plan creation failed");
let steps = plan.steps_for_device(0);
assert_eq!(steps.len(), 1);
assert!(steps[0].is_first_step);
assert!(steps[0].is_last_step);
assert_eq!(steps[0].kv_source_device, 0);
let cp = plan.comm_plan_for_device(0);
assert_eq!(cp.send_to, 0);
assert_eq!(cp.recv_from, 0);
}
}