Skip to main content

entrenar/train/transformer_trainer/
sequence_parallel.rs

1//! Sequence parallelism for transformer pretraining.
2//!
//! Distributes the sequence dimension across multiple GPUs. Each GPU
//! processes a contiguous chunk of the sequence and exchanges K,V chunks
//! with its ring neighbours (point-to-point) for attention computation.
6//!
7//! # Architecture (Ring Attention)
8//!
9//! ```text
//! GPU 0: tokens[0..S/2]            GPU 1: tokens[S/2..S]
//! ───────────────────────          ───────────────────────
//! X₀ = embed(tok[0..S/2])          X₁ = embed(tok[S/2..S])
//! Q₀, K₀, V₀ = proj(X₀)            Q₁, K₁, V₁ = proj(X₁)
//!
//! Local:       attn(Q₀, K₀, V₀)    attn(Q₁, K₁, V₁)
//!              send K₀,V₀ → GPU 1  send K₁,V₁ → GPU 0
//! Ring step 0: attn(Q₀, K₁, V₁)    attn(Q₁, K₀, V₀)
//! ─── Combine partial attention outputs ───
19//! ```
20//!
21//! # Communication Pattern
22//!
23//! Each GPU sends its K,V to the next GPU in the ring and receives K,V
24//! from the previous GPU. After N-1 ring steps, each GPU has computed
25//! attention against all K,V chunks.
26//!
27//! # When to Use
28//!
//! Most valuable when sequence length >> hidden size (8K+ sequences).
//! Reduces per-GPU attention-score memory from O(H × S²) to
//! O(H × (S/N) × S) = O(H × S² / N), where H is the number of heads and
//! N is the SP group size.
31//!
32//! # Contract (C-SP-001)
33//!
34//! - Sequence chunks are contiguous and non-overlapping
//! - Each GPU's attention output equals the rows of the full-sequence
//!   attention output that correspond to its own token chunk
36//! - Ring communication maintains causal mask correctness
37
/// Sequence parallel configuration for one rank of an SP group.
///
/// The full sequence of `full_seq_len` tokens is split into `sp_size`
/// equal, contiguous chunks; rank `r` owns tokens
/// `[r * local_seq_len, (r + 1) * local_seq_len)`.
#[derive(Debug, Clone)]
pub struct SequenceParallelConfig {
    /// This GPU's rank in the SP group (`0..sp_size`)
    pub sp_rank: usize,
    /// Total GPUs in the SP group
    pub sp_size: usize,
    /// Full sequence length (before sharding)
    pub full_seq_len: usize,
    /// Hidden size (not sharded in SP)
    pub hidden_size: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Head dimension (`hidden_size / num_heads`)
    pub head_dim: usize,
}

impl SequenceParallelConfig {
    /// Create a new SP config.
    ///
    /// `head_dim` is derived as `hidden_size / num_heads`.
    ///
    /// # Panics
    /// Panics if
    /// - `sp_size` is zero,
    /// - `sp_rank >= sp_size`,
    /// - `full_seq_len` is not divisible by `sp_size`,
    /// - `num_heads` is zero, or
    /// - `hidden_size` is not divisible by `num_heads` (the derived
    ///   `head_dim` would otherwise be silently truncated).
    pub fn new(
        sp_rank: usize,
        sp_size: usize,
        full_seq_len: usize,
        hidden_size: usize,
        num_heads: usize,
    ) -> Self {
        assert!(sp_size > 0, "sp_size must be at least 1");
        assert!(
            sp_rank < sp_size,
            "sp_rank ({sp_rank}) must be less than sp_size ({sp_size})"
        );
        assert!(
            full_seq_len.is_multiple_of(sp_size),
            "seq_len ({full_seq_len}) must be divisible by sp_size ({sp_size})"
        );
        assert!(num_heads > 0, "num_heads must be at least 1");
        assert!(
            hidden_size.is_multiple_of(num_heads),
            "hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        );

        let head_dim = hidden_size / num_heads;

        Self { sp_rank, sp_size, full_seq_len, hidden_size, num_heads, head_dim }
    }

    /// Local sequence length on this GPU.
    pub fn local_seq_len(&self) -> usize {
        self.full_seq_len / self.sp_size
    }

    /// Start token index for this GPU's chunk.
    pub fn seq_start(&self) -> usize {
        self.sp_rank * self.local_seq_len()
    }

    /// End token index (exclusive) for this GPU's chunk.
    pub fn seq_end(&self) -> usize {
        self.seq_start() + self.local_seq_len()
    }

    /// Memory savings for attention scores.
    ///
    /// Attention score matrix: [num_heads, local_seq, full_seq] per GPU.
    /// Without SP: [num_heads, full_seq, full_seq].
    /// Savings: 1 - 1/sp_size (e.g., 50% with 2 GPUs, 0% with 1 GPU).
    pub fn attention_memory_savings(&self) -> f64 {
        1.0 - (1.0 / self.sp_size as f64)
    }

    /// Number of ring communication steps needed.
    ///
    /// Each GPU must see all other GPUs' K,V → sp_size - 1 steps.
    /// Saturates at 0 so a hand-built config with `sp_size == 0` does not underflow.
    pub fn ring_steps(&self) -> usize {
        self.sp_size.saturating_sub(1)
    }
}
108
/// Ring attention schedule for a single GPU.
///
/// Generates the sequence of (send_to, recv_from) pairs for each ring step.
/// The GPU's own K,V chunk is processed locally and is not part of `steps`,
/// so a schedule for `world_size` GPUs has `world_size - 1` steps.
#[derive(Debug, Clone)]
pub struct RingAttentionSchedule {
    /// Steps in the ring attention protocol
    pub steps: Vec<RingStep>,
    /// This GPU's rank
    pub rank: usize,
    /// Total GPUs
    pub world_size: usize,
}

/// A single step in the ring attention protocol.
#[derive(Debug, Clone, Copy)]
pub struct RingStep {
    /// Ring step index (0-based)
    pub step: usize,
    /// Rank to send K,V to
    pub send_to: usize,
    /// Rank to receive K,V from
    pub recv_from: usize,
    /// The chunk index being processed (which rank's K,V we're using)
    pub kv_chunk_source: usize,
}

impl RingAttentionSchedule {
    /// Generate a ring attention schedule for a given rank.
    ///
    /// In every step the GPU sends K,V to its right neighbour
    /// `(rank + 1) % world_size` and receives K,V from its left neighbour
    /// `(rank - 1) mod world_size`. At step `s` the K,V being processed
    /// originated on rank `(rank - s - 1) mod world_size`.
    ///
    /// # Panics
    /// Panics if `world_size` is zero or `rank >= world_size`.
    pub fn new(rank: usize, world_size: usize) -> Self {
        assert!(world_size > 0, "world_size must be at least 1");
        assert!(
            rank < world_size,
            "rank ({rank}) must be less than world_size ({world_size})"
        );

        // Ring neighbours are the same for every step of the pass.
        let send_to = (rank + 1) % world_size;
        let recv_from = (rank + world_size - 1) % world_size;

        let steps = (0..world_size - 1)
            .map(|step| RingStep {
                step,
                send_to,
                recv_from,
                // After `step` rotations, we have K,V from rank (rank - step - 1) mod N.
                kv_chunk_source: (rank + world_size - step - 1) % world_size,
            })
            .collect();

        Self { steps, rank, world_size }
    }

    /// Classify the causal mask needed for ring step `step`.
    ///
    /// For causal (autoregressive) attention, tokens can only attend
    /// to earlier tokens. Chunks are taken to be `local_seq_len` tokens each,
    /// laid out in rank order, so this GPU's queries cover
    /// `[rank * local_seq_len, (rank + 1) * local_seq_len)`. A K,V chunk that
    /// lies entirely before that range is fully visible, one entirely after it
    /// is fully masked, and an overlapping one needs a causal mask.
    ///
    /// # Panics
    /// Panics if `step` is not a valid index into `self.steps`.
    pub fn needs_causal_mask(&self, step: usize, local_seq_len: usize) -> CausalMaskType {
        let Some(ring_step) = self.steps.get(step) else {
            panic!("ring step {step} out of range (schedule has {} steps)", self.steps.len());
        };

        let q_start = self.rank * local_seq_len;
        let q_end = q_start + local_seq_len;
        let kv_start = ring_step.kv_chunk_source * local_seq_len;
        let kv_end = kv_start + local_seq_len;

        if kv_end <= q_start {
            // KV chunk is entirely before Q chunk → full attention
            CausalMaskType::FullAttention
        } else if kv_start >= q_end {
            // KV chunk is entirely after Q chunk → no attention (skip)
            CausalMaskType::NoAttention
        } else {
            // KV chunk overlaps with Q chunk → apply causal mask
            CausalMaskType::CausalMask
        }
    }
}

/// Type of causal mask needed for a ring attention step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CausalMaskType {
    /// All tokens can attend (KV is before Q)
    FullAttention,
    /// No tokens can attend (KV is after Q) — skip computation
    NoAttention,
    /// Standard causal mask needed (KV overlaps with Q)
    CausalMask,
}
188
/// Communication cost estimate for sequence parallelism.
///
/// Models one ring pass per transformer block: each GPU sends one K,V chunk
/// to its ring neighbour `sp_size - 1` times per block.
#[derive(Debug, Clone)]
pub struct SpCommCost {
    /// Bytes per K,V send (2 × local_seq × head_dim × num_kv_heads × element size);
    /// the factor 2 accounts for sending K and V together.
    pub kv_bytes_per_send: usize,
    /// Number of ring steps (sp_size - 1)
    pub ring_steps: usize,
    /// Number of blocks
    pub num_blocks: usize,
}

impl SpCommCost {
    /// Estimate SP communication cost assuming `f32` K,V elements.
    ///
    /// Equivalent to [`SpCommCost::estimate_with_elem_bytes`] with
    /// `elem_bytes = size_of::<f32>()`.
    ///
    /// # Panics
    /// Panics if `sp_size` is zero.
    pub fn estimate(
        local_seq_len: usize,
        head_dim: usize,
        num_kv_heads: usize,
        sp_size: usize,
        num_blocks: usize,
    ) -> Self {
        Self::estimate_with_elem_bytes(
            local_seq_len,
            head_dim,
            num_kv_heads,
            sp_size,
            num_blocks,
            std::mem::size_of::<f32>(),
        )
    }

    /// Estimate SP communication cost for K,V elements of `elem_bytes` bytes
    /// each (e.g. 2 for 16-bit floats, 4 for `f32`).
    ///
    /// # Panics
    /// Panics if `sp_size` is zero.
    pub fn estimate_with_elem_bytes(
        local_seq_len: usize,
        head_dim: usize,
        num_kv_heads: usize,
        sp_size: usize,
        num_blocks: usize,
        elem_bytes: usize,
    ) -> Self {
        assert!(sp_size > 0, "sp_size must be at least 1");

        // K + V = 2 × local_seq × kv_dim elements, kv_dim = head_dim × num_kv_heads.
        let kv_bytes_per_send = 2 * local_seq_len * head_dim * num_kv_heads * elem_bytes;

        Self { kv_bytes_per_send, ring_steps: sp_size - 1, num_blocks }
    }

    /// Total bytes one GPU sends per training step under this model
    /// (one ring pass over every block).
    ///
    /// Only the K,V exchange described by this struct is counted; if the
    /// training loop also exchanges K,V (or their gradients) during the
    /// backward pass, that traffic is not included.
    pub fn total_bytes_per_step(&self) -> usize {
        self.kv_bytes_per_send * self.ring_steps * self.num_blocks
    }
}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sp_config_basic() {
        let sp = SequenceParallelConfig::new(0, 2, 2048, 1024, 16);
        assert_eq!(sp.local_seq_len(), 1024);
        assert_eq!(sp.seq_start(), 0);
        assert_eq!(sp.seq_end(), 1024);
        assert!((sp.attention_memory_savings() - 0.5).abs() < 1e-10);
        assert_eq!(sp.ring_steps(), 1);
    }

    #[test]
    fn test_sp_config_4way() {
        let sp = SequenceParallelConfig::new(2, 4, 8192, 1024, 16);
        assert_eq!(sp.local_seq_len(), 2048);
        assert_eq!(sp.seq_start(), 4096);
        assert_eq!(sp.seq_end(), 6144);
        assert!((sp.attention_memory_savings() - 0.75).abs() < 1e-10);
        assert_eq!(sp.ring_steps(), 3);
    }

    #[test]
    fn test_sp_config_single_gpu() {
        // Degenerate group: one GPU owns everything, no ring traffic, no savings.
        let sp = SequenceParallelConfig::new(0, 1, 2048, 1024, 16);
        assert_eq!(sp.local_seq_len(), 2048);
        assert_eq!(sp.seq_start(), 0);
        assert_eq!(sp.seq_end(), 2048);
        assert!(sp.attention_memory_savings().abs() < 1e-10);
        assert_eq!(sp.ring_steps(), 0);
    }

    #[test]
    #[should_panic(expected = "must be divisible")]
    fn test_sp_config_indivisible() {
        SequenceParallelConfig::new(0, 3, 1000, 1024, 16); // 1000 % 3 != 0
    }

    #[test]
    fn test_sp_chunks_contiguous_and_cover_sequence() {
        // Contract C-SP-001: chunks are contiguous, non-overlapping, and
        // together cover the full sequence.
        let (sp_size, full_seq_len) = (4, 8192);
        let mut expected_start = 0;
        for rank in 0..sp_size {
            let sp = SequenceParallelConfig::new(rank, sp_size, full_seq_len, 1024, 16);
            assert_eq!(sp.seq_start(), expected_start, "rank {rank} chunk is not contiguous");
            assert_eq!(sp.seq_end() - sp.seq_start(), sp.local_seq_len());
            expected_start = sp.seq_end();
        }
        assert_eq!(expected_start, full_seq_len);
    }

    #[test]
    fn test_ring_attention_schedule_2gpu() {
        let sched = RingAttentionSchedule::new(0, 2);
        assert_eq!(sched.steps.len(), 1);
        assert_eq!(sched.steps[0].send_to, 1);
        assert_eq!(sched.steps[0].recv_from, 1);
        assert_eq!(sched.steps[0].kv_chunk_source, 1);
    }

    #[test]
    fn test_ring_attention_schedule_single_gpu() {
        let sched = RingAttentionSchedule::new(0, 1);
        assert!(sched.steps.is_empty());
    }

    #[test]
    fn test_ring_attention_schedule_4gpu() {
        let sched = RingAttentionSchedule::new(0, 4);
        assert_eq!(sched.steps.len(), 3);

        // Every step: send to 1, recv from 3 (neighbours are fixed).
        for (i, s) in sched.steps.iter().enumerate() {
            assert_eq!(s.step, i);
            assert_eq!(s.send_to, 1);
            assert_eq!(s.recv_from, 3);
        }

        // Step 0: processing chunk from rank 3
        assert_eq!(sched.steps[0].kv_chunk_source, 3);

        // Step 1: processing chunk from rank 2
        assert_eq!(sched.steps[1].kv_chunk_source, 2);

        // Step 2: processing chunk from rank 1
        assert_eq!(sched.steps[2].kv_chunk_source, 1);
    }

    #[test]
    fn test_ring_attention_all_chunks_seen() {
        // Each GPU must see K,V from all other GPUs
        let world_size = 4;
        for rank in 0..world_size {
            let sched = RingAttentionSchedule::new(rank, world_size);
            let mut seen: Vec<usize> = sched.steps.iter().map(|s| s.kv_chunk_source).collect();
            seen.push(rank); // own chunk (processed locally)
            seen.sort_unstable();
            assert_eq!(seen, vec![0, 1, 2, 3], "rank {rank} didn't see all chunks");
        }
    }

    #[test]
    fn test_causal_mask_type() {
        // 4 GPUs, seq_len=1024, local=256
        let sched = RingAttentionSchedule::new(2, 4); // rank 2: tokens 512..768
        let local_seq = 256;

        // Step 0: kv from rank 1 (tokens 256..512) — before us → full attention
        let mask = sched.needs_causal_mask(0, local_seq);
        assert_eq!(mask, CausalMaskType::FullAttention);

        // Step 1: kv from rank 0 (tokens 0..256) — before us → full attention
        let mask = sched.needs_causal_mask(1, local_seq);
        assert_eq!(mask, CausalMaskType::FullAttention);

        // Step 2: kv from rank 3 (tokens 768..1024) — after us → no attention
        let mask = sched.needs_causal_mask(2, local_seq);
        assert_eq!(mask, CausalMaskType::NoAttention);
    }

    #[test]
    fn test_causal_mask_first_and_last_rank() {
        let local_seq = 256;

        // Rank 0 holds the earliest tokens: every remote chunk is in the future.
        let first = RingAttentionSchedule::new(0, 4);
        for step in 0..first.steps.len() {
            assert_eq!(first.needs_causal_mask(step, local_seq), CausalMaskType::NoAttention);
        }

        // Rank 3 holds the latest tokens: every remote chunk is in the past.
        let last = RingAttentionSchedule::new(3, 4);
        for step in 0..last.steps.len() {
            assert_eq!(last.needs_causal_mask(step, local_seq), CausalMaskType::FullAttention);
        }
    }

    #[test]
    fn test_causal_mask_same_chunk() {
        // A step whose K,V comes from this GPU's own chunk needs a causal mask.
        let sched = RingAttentionSchedule {
            steps: vec![RingStep { step: 0, send_to: 0, recv_from: 0, kv_chunk_source: 0 }],
            rank: 0,
            world_size: 1,
        };
        assert_eq!(sched.needs_causal_mask(0, 256), CausalMaskType::CausalMask);
    }

    #[test]
    fn test_sp_comm_cost() {
        // 2 GPUs, seq=2048 (local=1024), head_dim=64, 4 KV heads, 24 blocks
        let cost = SpCommCost::estimate(1024, 64, 4, 2, 24);
        // K+V = 2 × 1024 × 64 × 4 × 4 = 2 MiB per send
        assert_eq!(cost.kv_bytes_per_send, 2 * 1024 * 64 * 4 * 4);
        assert_eq!(cost.ring_steps, 1);
        assert_eq!(cost.total_bytes_per_step(), cost.kv_bytes_per_send * 24);
    }

    #[test]
    fn test_sp_comm_cost_4way() {
        // 4 GPUs, local=512, head_dim=128, 8 KV heads, 32 blocks
        let cost = SpCommCost::estimate(512, 128, 8, 4, 32);
        assert_eq!(cost.kv_bytes_per_send, 2 * 512 * 128 * 8 * 4);
        assert_eq!(cost.ring_steps, 3);
        assert_eq!(cost.total_bytes_per_step(), cost.kv_bytes_per_send * 3 * 32);
    }
}