/// Layout of one rank's shard of the sequence dimension under
/// sequence parallelism: the full sequence is split evenly across
/// `sp_size` ranks, and this rank owns a contiguous slice.
#[derive(Debug, Clone)]
pub struct SequenceParallelConfig {
    /// This rank's position in the sequence-parallel group (0-based).
    pub sp_rank: usize,
    /// Number of ranks the sequence is split across.
    pub sp_size: usize,
    /// Length of the full (unsharded) sequence.
    pub full_seq_len: usize,
    /// Model hidden size.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head dimension, derived as `hidden_size / num_heads`.
    pub head_dim: usize,
}
impl SequenceParallelConfig {
    /// Builds a config, validating the shard and head geometry up front
    /// so later arithmetic cannot silently produce wrong values.
    ///
    /// # Panics
    /// * if `sp_size` is zero or `sp_rank >= sp_size`;
    /// * if `full_seq_len` is not divisible by `sp_size`;
    /// * if `num_heads` is zero or `hidden_size` is not divisible by `num_heads`.
    pub fn new(
        sp_rank: usize,
        sp_size: usize,
        full_seq_len: usize,
        hidden_size: usize,
        num_heads: usize,
    ) -> Self {
        // Guard against division by zero in local_seq_len() and an
        // out-of-range rank producing slices past full_seq_len.
        assert!(sp_size > 0, "sp_size must be positive");
        assert!(
            sp_rank < sp_size,
            "sp_rank ({sp_rank}) must be less than sp_size ({sp_size})"
        );
        // Keep the "must be divisible" wording: callers/tests match on it.
        assert!(
            full_seq_len % sp_size == 0,
            "seq_len ({full_seq_len}) must be divisible by sp_size ({sp_size})"
        );
        // A non-divisible head split would silently truncate head_dim.
        assert!(
            num_heads > 0 && hidden_size % num_heads == 0,
            "hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        );
        let head_dim = hidden_size / num_heads;
        Self { sp_rank, sp_size, full_seq_len, hidden_size, num_heads, head_dim }
    }
    /// Number of sequence positions held locally by each rank.
    pub fn local_seq_len(&self) -> usize {
        self.full_seq_len / self.sp_size
    }
    /// First global sequence position owned by this rank (inclusive).
    pub fn seq_start(&self) -> usize {
        self.sp_rank * self.local_seq_len()
    }
    /// One past the last global sequence position owned by this rank (exclusive).
    pub fn seq_end(&self) -> usize {
        self.seq_start() + self.local_seq_len()
    }
    /// Fraction of attention activation memory saved relative to holding
    /// the full sequence: `1 - 1/sp_size` (e.g. 0.75 for 4-way SP).
    pub fn attention_memory_savings(&self) -> f64 {
        1.0 - (1.0 / self.sp_size as f64)
    }
    /// Number of ring-attention communication steps: each rank must see
    /// every other rank's KV chunk once, hence `sp_size - 1`.
    pub fn ring_steps(&self) -> usize {
        self.sp_size - 1
    }
}
/// Per-rank communication plan for ring attention: at every step each
/// rank forwards its current KV chunk to its next neighbor and receives
/// the previous neighbor's chunk.
#[derive(Debug, Clone)]
pub struct RingAttentionSchedule {
    /// One entry per communication step (`world_size - 1` in total).
    pub steps: Vec<RingStep>,
    /// This rank's position in the ring (0-based).
    pub rank: usize,
    /// Number of ranks participating in the ring.
    pub world_size: usize,
}
/// A single hop of the ring: where the KV buffer goes, where the next
/// one comes from, and which rank originally produced the chunk being
/// attended to at this step.
#[derive(Debug, Clone, Copy)]
pub struct RingStep {
    /// Step index, `0..world_size - 1`.
    pub step: usize,
    /// Neighbor that receives our current KV chunk.
    pub send_to: usize,
    /// Neighbor we receive the next KV chunk from.
    pub recv_from: usize,
    /// Rank whose KV chunk we attend to during this step.
    pub kv_chunk_source: usize,
}
impl RingAttentionSchedule {
    /// Constructs the schedule for `rank` in a ring of `world_size`.
    ///
    /// Neighbors are fixed for every step (send to `rank + 1`, receive
    /// from `rank - 1`, wrapping), while the chunk being processed walks
    /// backwards around the ring: step `s` sees the chunk originally
    /// owned by `rank - 1 - s (mod world_size)`.
    pub fn new(rank: usize, world_size: usize) -> Self {
        let next_rank = (rank + 1) % world_size;
        let prev_rank = (rank + world_size - 1) % world_size;
        let steps: Vec<RingStep> = (0..world_size - 1)
            .map(|step| RingStep {
                step,
                send_to: next_rank,
                recv_from: prev_rank,
                kv_chunk_source: (rank + world_size - 1 - step) % world_size,
            })
            .collect();
        Self { steps, rank, world_size }
    }
    /// Classifies the attention mask needed at `step`, assuming causal
    /// attention over a sequence split into equal chunks of
    /// `local_seq_len` positions, chunk `i` covering
    /// `[i * local_seq_len, (i + 1) * local_seq_len)`.
    ///
    /// The KV chunk is entirely before the query chunk → every position
    /// is visible (`FullAttention`); entirely after → nothing is visible
    /// (`NoAttention`); otherwise the blocks overlap and a triangular
    /// mask is required (`CausalMask`).
    pub fn needs_causal_mask(&self, step: usize, local_seq_len: usize) -> CausalMaskType {
        let kv_start = self.steps[step].kv_chunk_source * local_seq_len;
        let kv_end = kv_start + local_seq_len;
        let q_start = self.rank * local_seq_len;
        let q_end = q_start + local_seq_len;
        if kv_end <= q_start {
            CausalMaskType::FullAttention
        } else if kv_start >= q_end {
            CausalMaskType::NoAttention
        } else {
            CausalMaskType::CausalMask
        }
    }
}
/// How a (query-chunk, kv-chunk) pair interacts under causal attention.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CausalMaskType {
    /// KV block strictly precedes the query block: attend to everything.
    FullAttention,
    /// KV block strictly follows the query block: attend to nothing.
    NoAttention,
    /// Blocks overlap on the diagonal: apply a triangular causal mask.
    CausalMask,
}
/// Back-of-the-envelope communication cost model for ring attention.
#[derive(Debug, Clone)]
pub struct SpCommCost {
    /// Bytes transferred per KV send (K and V together, f32 elements).
    pub kv_bytes_per_send: usize,
    /// Number of ring communication steps (`sp_size - 1`).
    pub ring_steps: usize,
    /// Number of transformer blocks performing ring attention.
    pub num_blocks: usize,
}
impl SpCommCost {
    /// Estimates the cost model from the local KV-chunk geometry.
    ///
    /// Each send carries both the K and the V tensor (hence the factor
    /// of 2) for `local_seq_len * head_dim * num_kv_heads` f32 elements.
    pub fn estimate(
        local_seq_len: usize,
        head_dim: usize,
        num_kv_heads: usize,
        sp_size: usize,
        num_blocks: usize,
    ) -> Self {
        let elements_per_tensor = local_seq_len * head_dim * num_kv_heads;
        // 2 tensors (K + V), each element a 4-byte f32.
        let kv_bytes_per_send = 2 * elements_per_tensor * std::mem::size_of::<f32>();
        Self { kv_bytes_per_send, ring_steps: sp_size - 1, num_blocks }
    }
    /// Total bytes this rank sends over all ring steps and all blocks.
    ///
    /// NOTE(review): despite the `per_step` name, this is the aggregate
    /// across every ring step and every block — the name is kept for
    /// caller compatibility.
    pub fn total_bytes_per_step(&self) -> usize {
        self.ring_steps * self.num_blocks * self.kv_bytes_per_send
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke test: 2-way SP over a 2048-token sequence — each rank holds
    // half the sequence and the ring needs a single step.
    #[test]
    fn test_sp_config_basic() {
        let sp = SequenceParallelConfig::new(0, 2, 2048, 1024, 16);
        assert_eq!(sp.local_seq_len(), 1024);
        assert_eq!(sp.seq_start(), 0);
        assert_eq!(sp.seq_end(), 1024);
        assert!((sp.attention_memory_savings() - 0.5).abs() < 1e-10);
        assert_eq!(sp.ring_steps(), 1);
    }
    // 4-way SP, non-zero rank: rank 2 owns the third quarter [4096, 6144).
    #[test]
    fn test_sp_config_4way() {
        let sp = SequenceParallelConfig::new(2, 4, 8192, 1024, 16);
        assert_eq!(sp.local_seq_len(), 2048);
        assert_eq!(sp.seq_start(), 4096);
        assert_eq!(sp.seq_end(), 6144);
        assert!((sp.attention_memory_savings() - 0.75).abs() < 1e-10);
        assert_eq!(sp.ring_steps(), 3);
    }
    // 1000 is not divisible by 3, so construction must panic with the
    // documented "must be divisible" message.
    #[test]
    #[should_panic(expected = "must be divisible")]
    fn test_sp_config_indivisible() {
        SequenceParallelConfig::new(0, 3, 1000, 1024, 16); }
    // 2-GPU ring: the single neighbor is both send and receive target,
    // and step 0 processes the other rank's chunk.
    #[test]
    fn test_ring_attention_schedule_2gpu() {
        let sched = RingAttentionSchedule::new(0, 2);
        assert_eq!(sched.steps.len(), 1);
        assert_eq!(sched.steps[0].send_to, 1);
        assert_eq!(sched.steps[0].recv_from, 1);
        assert_eq!(sched.steps[0].kv_chunk_source, 1);
    }
    // 4-GPU ring from rank 0: chunks arrive in reverse rank order 3, 2, 1.
    #[test]
    fn test_ring_attention_schedule_4gpu() {
        let sched = RingAttentionSchedule::new(0, 4);
        assert_eq!(sched.steps.len(), 3);
        assert_eq!(sched.steps[0].send_to, 1);
        assert_eq!(sched.steps[0].recv_from, 3);
        assert_eq!(sched.steps[0].kv_chunk_source, 3);
        assert_eq!(sched.steps[1].kv_chunk_source, 2);
        assert_eq!(sched.steps[2].kv_chunk_source, 1);
    }
    // Invariant: every rank processes each remote chunk exactly once;
    // together with its own local chunk that covers all world_size chunks.
    #[test]
    fn test_ring_attention_all_chunks_seen() {
        let world_size = 4;
        for rank in 0..world_size {
            let sched = RingAttentionSchedule::new(rank, world_size);
            let mut seen: Vec<usize> = sched.steps.iter().map(|s| s.kv_chunk_source).collect();
            seen.push(rank); seen.sort_unstable();
            assert_eq!(seen, vec![0, 1, 2, 3], "rank {rank} didn't see all chunks");
        }
    }
    // For rank 2 of 4: step 0 sees chunk 1 (strictly earlier → full
    // attention); step 2 sees chunk 3 (strictly later → no attention).
    #[test]
    fn test_causal_mask_type() {
        let sched = RingAttentionSchedule::new(2, 4); let local_seq = 256;
        let mask = sched.needs_causal_mask(0, local_seq);
        assert_eq!(mask, CausalMaskType::FullAttention);
        let mask = sched.needs_causal_mask(2, local_seq);
        assert_eq!(mask, CausalMaskType::NoAttention);
    }
    // Cost model: K+V (factor 2) * seq * head_dim * kv_heads * 4 bytes,
    // aggregated over ring_steps (here 1) and 24 blocks.
    #[test]
    fn test_sp_comm_cost() {
        let cost = SpCommCost::estimate(1024, 64, 4, 2, 24);
        assert_eq!(cost.kv_bytes_per_send, 2 * 1024 * 64 * 4 * 4);
        assert_eq!(cost.ring_steps, 1);
        assert_eq!(cost.total_bytes_per_step(), cost.kv_bytes_per_send * 24);
    }
}