// entrenar/train/transformer_trainer/sequence_parallel.rs

/// Per-rank configuration for sequence parallelism (SP).
///
/// The full sequence of `full_seq_len` tokens is divided into `sp_size`
/// equally sized, contiguous chunks; this rank (`sp_rank`) owns one of them.
#[derive(Debug, Clone)]
pub struct SequenceParallelConfig {
    /// Index of this rank inside the sequence-parallel group.
    pub sp_rank: usize,
    /// Number of ranks in the sequence-parallel group.
    pub sp_size: usize,
    /// Length of the unsplit sequence, in tokens.
    pub full_seq_len: usize,
    /// Model hidden dimension.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head dimension (`hidden_size / num_heads`).
    pub head_dim: usize,
}

55impl SequenceParallelConfig {
56 pub fn new(
61 sp_rank: usize,
62 sp_size: usize,
63 full_seq_len: usize,
64 hidden_size: usize,
65 num_heads: usize,
66 ) -> Self {
67 assert!(
68 full_seq_len.is_multiple_of(sp_size),
69 "seq_len ({full_seq_len}) must be divisible by sp_size ({sp_size})"
70 );
71
72 let head_dim = hidden_size / num_heads;
73
74 Self { sp_rank, sp_size, full_seq_len, hidden_size, num_heads, head_dim }
75 }
76
77 pub fn local_seq_len(&self) -> usize {
79 self.full_seq_len / self.sp_size
80 }
81
82 pub fn seq_start(&self) -> usize {
84 self.sp_rank * self.local_seq_len()
85 }
86
87 pub fn seq_end(&self) -> usize {
89 self.seq_start() + self.local_seq_len()
90 }
91
92 pub fn attention_memory_savings(&self) -> f64 {
98 1.0 - (1.0 / self.sp_size as f64)
99 }
100
101 pub fn ring_steps(&self) -> usize {
105 self.sp_size - 1
106 }
107}
108
109#[derive(Debug, Clone)]
113pub struct RingAttentionSchedule {
114 pub steps: Vec<RingStep>,
116 pub rank: usize,
118 pub world_size: usize,
120}
121
/// A single rotation of the KV ring, seen from one rank.
#[derive(Debug, Clone, Copy)]
pub struct RingStep {
    /// Zero-based index of this step.
    pub step: usize,
    /// Rank this rank sends to (its successor in the ring).
    pub send_to: usize,
    /// Rank this rank receives from (its predecessor in the ring).
    pub recv_from: usize,
    /// Rank that originally owned the KV chunk handled at this step.
    pub kv_chunk_source: usize,
}

135impl RingAttentionSchedule {
136 pub fn new(rank: usize, world_size: usize) -> Self {
141 let mut steps = Vec::with_capacity(world_size - 1);
142
143 for step in 0..world_size - 1 {
144 let send_to = (rank + 1) % world_size;
145 let recv_from = (rank + world_size - 1) % world_size;
146 let kv_chunk_source = (rank + world_size - step - 1) % world_size;
148
149 steps.push(RingStep { step, send_to, recv_from, kv_chunk_source });
150 }
151
152 Self { steps, rank, world_size }
153 }
154
155 pub fn needs_causal_mask(&self, step: usize, local_seq_len: usize) -> CausalMaskType {
161 let kv_source = self.steps[step].kv_chunk_source;
162 let q_start = self.rank * local_seq_len;
163 let kv_start = kv_source * local_seq_len;
164
165 if kv_start + local_seq_len <= q_start {
166 CausalMaskType::FullAttention
168 } else if kv_start >= q_start + local_seq_len {
169 CausalMaskType::NoAttention
171 } else {
172 CausalMaskType::CausalMask
174 }
175 }
176}
177
/// Relationship between a query chunk and a KV chunk under causal attention.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CausalMaskType {
    /// The whole KV chunk precedes the query chunk; no masking is needed.
    FullAttention,
    /// The whole KV chunk follows the query chunk; it is fully masked out.
    NoAttention,
    /// The chunks overlap; a per-position causal mask is required.
    CausalMask,
}

/// Rough estimate of the K/V traffic generated by ring attention.
#[derive(Debug, Clone)]
pub struct SpCommCost {
    /// Bytes of K plus V that one rank sends in a single ring step of one block.
    pub kv_bytes_per_send: usize,
    /// Ring steps per block (`sp_size - 1`).
    pub ring_steps: usize,
    /// Number of blocks that each perform ring attention (presumably transformer layers).
    pub num_blocks: usize,
}

200impl SpCommCost {
201 pub fn estimate(
203 local_seq_len: usize,
204 head_dim: usize,
205 num_kv_heads: usize,
206 sp_size: usize,
207 num_blocks: usize,
208 ) -> Self {
209 let kv_bytes_per_send =
211 2 * local_seq_len * head_dim * num_kv_heads * std::mem::size_of::<f32>();
212
213 Self { kv_bytes_per_send, ring_steps: sp_size - 1, num_blocks }
214 }
215
216 pub fn total_bytes_per_step(&self) -> usize {
218 self.kv_bytes_per_send * self.ring_steps * self.num_blocks
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sp_config_basic() {
        let cfg = SequenceParallelConfig::new(0, 2, 2048, 1024, 16);
        assert_eq!((cfg.local_seq_len(), cfg.seq_start(), cfg.seq_end()), (1024, 0, 1024));
        assert!((cfg.attention_memory_savings() - 0.5).abs() < 1e-10);
        assert_eq!(cfg.ring_steps(), 1);
    }

    #[test]
    fn test_sp_config_4way() {
        let cfg = SequenceParallelConfig::new(2, 4, 8192, 1024, 16);
        assert_eq!((cfg.local_seq_len(), cfg.seq_start(), cfg.seq_end()), (2048, 4096, 6144));
        assert!((cfg.attention_memory_savings() - 0.75).abs() < 1e-10);
        assert_eq!(cfg.ring_steps(), 3);
    }

    #[test]
    #[should_panic(expected = "must be divisible")]
    fn test_sp_config_indivisible() {
        // 1000 tokens cannot be split evenly across 3 ranks.
        let _ = SequenceParallelConfig::new(0, 3, 1000, 1024, 16);
    }

    #[test]
    fn test_ring_attention_schedule_2gpu() {
        let sched = RingAttentionSchedule::new(0, 2);
        assert_eq!(sched.steps.len(), 1);
        let only = sched.steps[0];
        assert_eq!((only.send_to, only.recv_from, only.kv_chunk_source), (1, 1, 1));
    }

    #[test]
    fn test_ring_attention_schedule_4gpu() {
        let sched = RingAttentionSchedule::new(0, 4);
        assert_eq!(sched.steps.len(), 3);

        let first = sched.steps[0];
        assert_eq!((first.send_to, first.recv_from), (1, 3));

        // Chunks arrive from the predecessor, walking backwards around the ring.
        let sources: Vec<usize> = sched.steps.iter().map(|s| s.kv_chunk_source).collect();
        assert_eq!(sources, vec![3, 2, 1]);
    }

    #[test]
    fn test_ring_attention_all_chunks_seen() {
        const WORLD: usize = 4;
        let expected: Vec<usize> = (0..WORLD).collect();

        for rank in 0..WORLD {
            let sched = RingAttentionSchedule::new(rank, WORLD);
            // The rank's own chunk is local; every other chunk must arrive via the ring.
            let mut seen: Vec<usize> = std::iter::once(rank)
                .chain(sched.steps.iter().map(|s| s.kv_chunk_source))
                .collect();
            seen.sort_unstable();
            assert_eq!(seen, expected, "rank {rank} didn't see all chunks");
        }
    }

    #[test]
    fn test_causal_mask_type() {
        // Rank 2 of 4: step 0 carries chunk 1 (earlier), step 2 carries chunk 3 (later).
        let sched = RingAttentionSchedule::new(2, 4);
        let local_seq = 256;

        assert_eq!(sched.needs_causal_mask(0, local_seq), CausalMaskType::FullAttention);
        assert_eq!(sched.needs_causal_mask(2, local_seq), CausalMaskType::NoAttention);
    }

    #[test]
    fn test_sp_comm_cost() {
        let cost = SpCommCost::estimate(1024, 64, 4, 2, 24);
        // 2 tensors (K, V) * seq * head_dim * kv_heads * 4 bytes per f32.
        let expected_send = 2 * 1024 * 64 * 4 * 4;

        assert_eq!(cost.kv_bytes_per_send, expected_send);
        assert_eq!(cost.ring_steps, 1);
        assert_eq!(cost.total_bytes_per_step(), cost.kv_bytes_per_send * 24);
    }
}