use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::{DeviceSliceMut, Stream};
use baracuda_kernels_types::{
ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef,
};
#[cfg(feature = "ring_attention")]
use core::ffi::c_void;
#[cfg(feature = "ring_attention")]
use super::map_status;
/// The only head dimension accepted by `RingAttentionPlan::select`; a descriptor
/// whose `head_dim` differs from this value is rejected as unsupported.
pub const RING_ATTENTION_HEAD_DIM: i32 = 128;
/// Problem description consumed by `RingAttentionPlan::select`.
///
/// `select` requires every extent to be non-negative, `scale` to be finite,
/// `head_dim` to equal `RING_ATTENTION_HEAD_DIM`, and `element` to match the
/// plan's element type.
#[derive(Copy, Clone, Debug)]
pub struct RingAttentionDescriptor {
/// Batch extent; dimension 0 of the expected Q / y / lse shapes.
pub batch_size: i32,
/// Head count; dimension 1 of the expected Q / y / lse shapes.
pub num_heads: i32,
/// Query length; dimension 2 of the expected Q / y / lse shapes. `run`
/// multiplies it by the communicator rank to obtain the global query offset.
pub query_len: i32,
/// Key length of one KV chunk; sizes the KV scratch buffers and, multiplied by
/// the chunk's originating rank, gives the global key offset passed to the kernel.
pub key_len: i32,
/// Head dimension; last dimension of the expected Q / y shapes.
pub head_dim: i32,
/// Scale value forwarded unchanged to the step kernel (presumably the softmax
/// scale — confirm against the kernel). Must be finite.
pub scale: f32,
/// Forwarded to the step kernel as a `1` / `0` flag.
pub is_causal: bool,
/// Element kind of the tensors; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
}
/// Tensors and scratch buffers for one ring-attention invocation.
///
/// `RingAttentionPlan::can_implement` checks these against the plan's descriptor.
pub struct RingAttentionArgs<'a, T: Element> {
/// Query tensor; must be contiguous with shape
/// `[batch_size, num_heads, query_len, head_dim]`.
pub q: TensorRef<'a, T, 4>,
/// Output tensor; must be contiguous with the same shape as `q`.
pub y: TensorMut<'a, T, 4>,
/// Optional output; when present it must be contiguous with shape
/// `[batch_size, num_heads, query_len]`. When absent, a null pointer is passed
/// to the finalize kernel.
pub lse: Option<TensorMut<'a, T, 3>>,
/// First KV buffer; needs at least `kv_scratch_elements()` elements. `run` reads
/// it during the first step as a K chunk followed by a V chunk, so the caller
/// presumably has to pre-fill it with the local K/V — TODO confirm with callers.
pub kv_scratch_a: DeviceSliceMut<'a, T>,
/// Second KV buffer of the same minimum size; receives the chunk arriving from
/// the previous rank while the other buffer is being sent.
pub kv_scratch_b: DeviceSliceMut<'a, T>,
/// Byte buffer of at least `accumulator_scratch_bytes()` bytes. `run` splits it
/// into three consecutive regions of 4-byte values: `o`, then `m`, then `l`.
pub accumulator_scratch: DeviceSliceMut<'a, u8>,
}
/// A validated ring-attention plan for element type `T`, created by `select`.
pub struct RingAttentionPlan<T: Element> {
// Copy of the descriptor that passed validation in `select`.
desc: RingAttentionDescriptor,
// Kernel SKU (including the precision guarantee) recorded at selection time.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> RingAttentionPlan<T> {
/// Validates `desc` and builds a plan for element type `T`.
///
/// `_stream` and `_pref` are currently unused.
///
/// # Errors
///
/// Checked in this order:
/// * `Error::Unsupported` when `desc.element` differs from `T::KIND`;
/// * `Error::InvalidProblem` when any extent is negative;
/// * `Error::InvalidProblem` when `desc.scale` is not finite;
/// * `Error::Unsupported` when `desc.head_dim != RING_ATTENTION_HEAD_DIM`;
/// * `Error::Unsupported` when `T` is neither f16 nor bf16.
pub fn select(
    _stream: &Stream,
    desc: &RingAttentionDescriptor,
    _pref: PlanPreference,
) -> Result<Self> {
    if desc.element != T::KIND {
        return Err(Error::Unsupported(
            "baracuda-kernels::RingAttentionPlan: descriptor element != T",
        ));
    }

    let extents = [
        desc.batch_size,
        desc.num_heads,
        desc.query_len,
        desc.key_len,
        desc.head_dim,
    ];
    if extents.iter().any(|&extent| extent < 0) {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::RingAttentionPlan: extents must be non-negative",
        ));
    }

    if !desc.scale.is_finite() {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::RingAttentionPlan: scale must be finite",
        ));
    }

    if desc.head_dim != RING_ATTENTION_HEAD_DIM {
        return Err(Error::Unsupported(
            "baracuda-kernels::RingAttentionPlan: Tier 1 requires head_dim == 128",
        ));
    }

    // Only half-precision element types are wired up.
    match T::KIND {
        ElementKind::F16 | ElementKind::Bf16 => {}
        _ => {
            return Err(Error::Unsupported(
                "baracuda-kernels::RingAttentionPlan: Tier 1 wired for `{f16, bf16}`",
            ));
        }
    }

    Ok(Self {
        desc: *desc,
        sku: KernelSku {
            category: OpCategory::Attention,
            op: AttentionKind::RingAttention as u16,
            element: T::KIND,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee: PrecisionGuarantee {
                math_precision: MathPrecision::F32,
                accumulator: ElementKind::F32,
                bit_stable_on_same_hardware: true,
                deterministic: true,
            },
        },
        _marker: PhantomData,
    })
}
pub fn can_implement(&self, args: &RingAttentionArgs<'_, T>) -> Result<()> {
let shape_q = [
self.desc.batch_size,
self.desc.num_heads,
self.desc.query_len,
self.desc.head_dim,
];
let shape_y = shape_q;
let shape_lse = [
self.desc.batch_size,
self.desc.num_heads,
self.desc.query_len,
];
if args.q.shape != shape_q {
return Err(Error::InvalidProblem(
"baracuda-kernels::RingAttentionPlan: Q shape mismatch",
));
}
if args.y.shape != shape_y {
return Err(Error::InvalidProblem(
"baracuda-kernels::RingAttentionPlan: y shape mismatch",
));
}
if let Some(lse) = args.lse.as_ref() {
if lse.shape != shape_lse {
return Err(Error::InvalidProblem(
"baracuda-kernels::RingAttentionPlan: lse shape must be [B, H, Q_local]",
));
}
if !lse.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::RingAttentionPlan: lse must be contiguous",
));
}
}
if !args.q.is_contiguous() || !args.y.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::RingAttentionPlan: Tier 1 requires contiguous tensors",
));
}
let kv_scratch_min = self.kv_scratch_elements();
if args.kv_scratch_a.len() < kv_scratch_min {
return Err(Error::BufferTooSmall {
needed: kv_scratch_min,
got: args.kv_scratch_a.len(),
});
}
if args.kv_scratch_b.len() < kv_scratch_min {
return Err(Error::BufferTooSmall {
needed: kv_scratch_min,
got: args.kv_scratch_b.len(),
});
}
let acc_min = self.accumulator_scratch_bytes();
if args.accumulator_scratch.len() < acc_min {
return Err(Error::WorkspaceTooSmall {
needed: acc_min,
got: args.accumulator_scratch.len(),
});
}
Ok(())
}
/// Minimum number of `T` elements each KV scratch buffer must hold:
/// `2 * batch_size * num_heads * key_len * head_dim` (one K chunk plus one V chunk).
///
/// The product is computed with saturating `usize` arithmetic. If the true size
/// does not fit, the result is `usize::MAX`, which makes `can_implement` reject
/// the buffers instead of accepting them against a wrapped-around (too small)
/// requirement. Negative extents, which `select` never lets through, count as zero.
#[inline]
pub fn kv_scratch_elements(&self) -> usize {
    let extents = [
        self.desc.batch_size,
        self.desc.num_heads,
        self.desc.key_len,
        self.desc.head_dim,
    ];
    // Start from 2 to account for the K and V halves.
    extents
        .iter()
        .fold(2usize, |acc, &extent| acc.saturating_mul(extent.max(0) as usize))
}
/// Minimum size in bytes of the accumulator scratch buffer.
///
/// With `bhq = batch_size * num_heads * query_len`, the buffer holds
/// `bhq * head_dim` values for `o` plus `bhq` values each for `m` and `l`,
/// all 4 bytes wide: `(bhq * head_dim + 2 * bhq) * 4`.
///
/// The computation saturates at `usize::MAX` instead of wrapping, so an
/// unrepresentable requirement makes `can_implement` fail rather than accept an
/// undersized workspace. Negative extents, which `select` never lets through,
/// count as zero.
#[inline]
pub fn accumulator_scratch_bytes(&self) -> usize {
    let rows = [
        self.desc.batch_size,
        self.desc.num_heads,
        self.desc.query_len,
    ]
    .iter()
    .fold(1usize, |acc, &extent| acc.saturating_mul(extent.max(0) as usize));
    let o_values = rows.saturating_mul(self.desc.head_dim.max(0) as usize);
    let ml_values = rows.saturating_mul(2);
    o_values.saturating_add(ml_values).saturating_mul(4)
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
/// Runs ring attention for this rank of `comm`, enqueueing all work on `stream`.
///
/// Sequence:
/// 1. validate `args` with `can_implement`; return `Ok(())` immediately when
///    `y` has no elements;
/// 2. validate the communicator's rank / world size and the global offsets;
/// 3. initialise the `o` / `m` / `l` accumulators inside `accumulator_scratch`;
/// 4. for each of `world_size` steps, run the step kernel on the KV buffer
///    currently held, then (except after the last step) send that buffer to
///    rank `rank + 1` while receiving the next chunk from rank `rank - 1` into
///    the other buffer;
/// 5. run the finalize kernel, writing `y` and, if supplied, `lse`.
///
/// The first step reads `args.kv_scratch_a`, so the caller presumably has to
/// place the local K chunk followed by the local V chunk there beforehand —
/// TODO confirm against callers.
///
/// # Errors
///
/// * anything returned by `can_implement`;
/// * `Error::InvalidProblem` when `world_size < 1`, the rank is out of range,
///   or a global query/key offset does not fit in `i32`;
/// * `Error::Unsupported` when `T` is neither f16 nor bf16;
/// * any error produced by `map_status` for a kernel launch, or by `rotate_kv`.
#[cfg(feature = "ring_attention")]
pub fn run(
    &self,
    stream: &Stream,
    comm: &baracuda_nccl::Communicator,
    mut args: RingAttentionArgs<'_, T>,
) -> Result<()> {
    self.can_implement(&args)?;
    if args.y.numel() == 0 {
        return Ok(());
    }
    let world_size = comm.world_size();
    let rank = comm.rank();
    if world_size < 1 {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::RingAttentionPlan: world_size must be >= 1",
        ));
    }
    if rank < 0 || rank >= world_size {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::RingAttentionPlan: rank out of range",
        ));
    }

    // Global offsets are handed to the kernels as i32. Reject problems whose
    // offsets would overflow before any work is enqueued, instead of panicking
    // (debug) or silently wrapping to a wrong offset (release).
    let q_global_base = rank
        .checked_mul(self.desc.query_len)
        .ok_or(Error::InvalidProblem(
            "baracuda-kernels::RingAttentionPlan: global query offset overflows i32",
        ))?;
    // The largest originating rank is `world_size - 1`; if its offset fits, every
    // per-step `origin_rank * key_len` below fits as well (key_len is non-negative).
    if (world_size - 1).checked_mul(self.desc.key_len).is_none() {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::RingAttentionPlan: global key offset overflows i32",
        ));
    }

    let stream_ptr = stream.as_raw() as *mut c_void;
    let kv_scratch_elems = self.kv_scratch_elements();
    // Each KV buffer is one K chunk followed by one V chunk of equal size.
    let k_chunk_elems = kv_scratch_elems / 2;

    // Layout of the accumulator scratch: [o | m | l], all 4-byte values.
    let bhq = (self.desc.batch_size as usize)
        * (self.desc.num_heads as usize)
        * (self.desc.query_len as usize);
    let o_len_f32 = bhq * (self.desc.head_dim as usize);
    let ml_len_f32 = bhq;
    let o_bytes = o_len_f32 * 4;
    let m_bytes = ml_len_f32 * 4;
    let acc_raw = args.accumulator_scratch.as_raw().0 as *mut u8;
    let o_ptr = acc_raw as *mut c_void;
    // SAFETY: `can_implement` verified the scratch holds at least
    // `(o_len + 2 * ml_len) * 4` bytes, so both offsets stay inside the buffer.
    let m_ptr = unsafe { acc_raw.add(o_bytes) } as *mut c_void;
    let l_ptr = unsafe { acc_raw.add(o_bytes + m_bytes) } as *mut c_void;

    // SAFETY: pointers and lengths describe the regions carved out above; the
    // kernel's own contract is not visible here and is assumed to match.
    let init_status = unsafe {
        baracuda_kernels_sys::baracuda_kernels_ring_attention_init_run(
            o_ptr,
            m_ptr,
            l_ptr,
            o_len_f32 as i64,
            ml_len_f32 as i64,
            stream_ptr,
        )
    };
    map_status(init_status)?;

    // Tracks which scratch buffer currently holds the KV chunk to attend over.
    let mut current_in_a = true;
    let q_ptr = args.q.data.as_raw().0 as *const c_void;
    for step in 0..world_size {
        // After `step` rotations this rank holds the chunk that started on
        // rank `rank - step` (mod world_size).
        let origin_rank = (rank - step + world_size) % world_size;
        let k_global_base = origin_rank * self.desc.key_len;
        let cur_ptr = if current_in_a {
            args.kv_scratch_a.as_raw().0 as *mut c_void
        } else {
            args.kv_scratch_b.as_raw().0 as *mut c_void
        };
        let k_cur = cur_ptr as *const c_void;
        // SAFETY: the buffer holds at least `2 * k_chunk_elems` elements
        // (checked by `can_implement`), so the V half starts in bounds.
        let v_cur = unsafe {
            (cur_ptr as *mut u8).add(k_chunk_elems * core::mem::size_of::<T>())
                as *const c_void
        };
        let step_status = match T::KIND {
            // SAFETY: all pointers come from buffers validated by
            // `can_implement`; kernel contract assumed to match these arguments.
            ElementKind::F16 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_ring_attention_f16_step_run(
                    self.desc.batch_size,
                    self.desc.num_heads,
                    self.desc.query_len,
                    self.desc.key_len,
                    self.desc.head_dim,
                    q_global_base,
                    k_global_base,
                    self.desc.scale,
                    if self.desc.is_causal { 1 } else { 0 },
                    q_ptr,
                    k_cur,
                    v_cur,
                    o_ptr,
                    m_ptr,
                    l_ptr,
                    stream_ptr,
                )
            },
            // SAFETY: as for the f16 arm.
            ElementKind::Bf16 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_ring_attention_bf16_step_run(
                    self.desc.batch_size,
                    self.desc.num_heads,
                    self.desc.query_len,
                    self.desc.key_len,
                    self.desc.head_dim,
                    q_global_base,
                    k_global_base,
                    self.desc.scale,
                    if self.desc.is_causal { 1 } else { 0 },
                    q_ptr,
                    k_cur,
                    v_cur,
                    o_ptr,
                    m_ptr,
                    l_ptr,
                    stream_ptr,
                )
            },
            _ => {
                return Err(Error::Unsupported(
                    "baracuda-kernels::RingAttentionPlan: dtype not in {f16, bf16}",
                ));
            }
        };
        map_status(step_status)?;

        // No rotation after the final step: every chunk has been consumed.
        if step + 1 < world_size {
            let next_peer = (rank + 1) % world_size;
            let prev_peer = (rank - 1 + world_size) % world_size;
            rotate_kv(
                comm,
                stream,
                current_in_a,
                &mut args.kv_scratch_a,
                &mut args.kv_scratch_b,
                k_chunk_elems,
                next_peer,
                prev_peer,
            )?;
            current_in_a = !current_in_a;
        }
    }

    let y_ptr = args.y.data.as_raw().0 as *mut c_void;
    // A null lse pointer tells the finalize kernel that no lse was requested.
    let lse_ptr = args
        .lse
        .as_mut()
        .map(|t| t.data.as_raw().0 as *mut c_void)
        .unwrap_or(core::ptr::null_mut());
    let fin_status = match T::KIND {
        // SAFETY: `y` and `lse` shapes/contiguity were validated by
        // `can_implement`; kernel contract assumed to match these arguments.
        ElementKind::F16 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_ring_attention_f16_finalize_run(
                self.desc.batch_size,
                self.desc.num_heads,
                self.desc.query_len,
                self.desc.head_dim,
                o_ptr,
                m_ptr,
                l_ptr,
                y_ptr,
                lse_ptr,
                stream_ptr,
            )
        },
        // SAFETY: as for the f16 arm.
        ElementKind::Bf16 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_ring_attention_bf16_finalize_run(
                self.desc.batch_size,
                self.desc.num_heads,
                self.desc.query_len,
                self.desc.head_dim,
                o_ptr,
                m_ptr,
                l_ptr,
                y_ptr,
                lse_ptr,
                stream_ptr,
            )
        },
        _ => {
            return Err(Error::Unsupported(
                "baracuda-kernels::RingAttentionPlan: dtype not in {f16, bf16}",
            ));
        }
    };
    map_status(fin_status)?;
    Ok(())
}
}
/// Performs one ring rotation of the KV buffers on `stream`.
///
/// Inside a single NCCL group, the buffer currently in use (`scratch_a` when
/// `current_in_a` is true, otherwise `scratch_b`) is sent to `next_peer`, and
/// the other buffer receives from `prev_peer`. Each transfer covers
/// `2 * k_chunk_elems` elements (the K and V halves together).
///
/// # Errors
///
/// * `Error::Unsupported` when the NCCL library cannot be loaded or `T` is
///   neither f16 nor bf16;
/// * `Error::CutlassInternal(7000..=7003)` when an NCCL entry point cannot be
///   resolved;
/// * `Error::CutlassInternal(7100 | 7200 | 7300 | 7400 + status)` when group
///   start, send, recv or group end respectively reports a failure.
#[cfg(feature = "ring_attention")]
#[allow(clippy::too_many_arguments)]
fn rotate_kv<T: Element>(
    comm: &baracuda_nccl::Communicator,
    stream: &Stream,
    current_in_a: bool,
    scratch_a: &mut DeviceSliceMut<'_, T>,
    scratch_b: &mut DeviceSliceMut<'_, T>,
    k_chunk_elems: usize,
    next_peer: i32,
    prev_peer: i32,
) -> Result<()> {
    use baracuda_nccl_sys::{nccl, ncclDataType_t};

    // K and V travel together, so the message is two chunks long.
    let element_count = 2 * k_chunk_elems;

    let lib = nccl().map_err(|_| {
        Error::Unsupported(
            "baracuda-kernels::RingAttentionPlan: NCCL library not available at runtime",
        )
    })?;

    let wire_dtype = match T::KIND {
        ElementKind::F16 => ncclDataType_t::Float16,
        ElementKind::Bf16 => ncclDataType_t::BFloat16,
        _ => {
            return Err(Error::Unsupported(
                "baracuda-kernels::RingAttentionPlan::rotate_kv: dtype not in {f16, bf16}",
            ));
        }
    };

    // The buffer in use is the outgoing one; the idle buffer receives.
    let ptr_a = scratch_a.as_raw().0 as *mut c_void;
    let ptr_b = scratch_b.as_raw().0 as *mut c_void;
    let (send_ptr, recv_ptr) = if current_in_a {
        (ptr_a, ptr_b)
    } else {
        (ptr_b, ptr_a)
    };

    // Resolve every entry point up front so that a missing symbol cannot leave
    // an NCCL group open.
    let group_start = lib
        .nccl_group_start()
        .map_err(|_| Error::CutlassInternal(7000))?;
    let group_end = lib
        .nccl_group_end()
        .map_err(|_| Error::CutlassInternal(7001))?;
    let send_fn = lib.nccl_send().map_err(|_| Error::CutlassInternal(7002))?;
    let recv_fn = lib.nccl_recv().map_err(|_| Error::CutlassInternal(7003))?;

    let comm_handle = comm.as_raw();
    let stream_raw = stream.as_raw() as _;

    let start_status = unsafe { group_start() };
    if !start_status.is_success() {
        return Err(Error::CutlassInternal(7100 + start_status.0));
    }

    let send_status = unsafe {
        send_fn(
            send_ptr as *const c_void,
            element_count,
            wire_dtype,
            next_peer,
            comm_handle,
            stream_raw,
        )
    };
    if !send_status.is_success() {
        // Close the group before bailing out; its own status is ignored.
        let _ = unsafe { group_end() };
        return Err(Error::CutlassInternal(7200 + send_status.0));
    }

    let recv_status = unsafe {
        recv_fn(
            recv_ptr,
            element_count,
            wire_dtype,
            prev_peer,
            comm_handle,
            stream_raw,
        )
    };
    if !recv_status.is_success() {
        // Close the group before bailing out; its own status is ignored.
        let _ = unsafe { group_end() };
        return Err(Error::CutlassInternal(7300 + recv_status.0));
    }

    let end_status = unsafe { group_end() };
    if !end_status.is_success() {
        return Err(Error::CutlassInternal(7400 + end_status.0));
    }
    Ok(())
}