Sequence parallelism for transformer pretraining.
Distributes the sequence dimension across multiple GPUs. Each GPU processes a contiguous chunk of the sequence and exchanges key/value chunks with its neighbors in a ring (point-to-point) to compute attention over the full sequence.
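A minimal sketch of the contiguous, non-overlapping split, assuming ranks are numbered `0..N` and that any remainder tokens go to the lowest ranks. `chunk_range` is a hypothetical helper; the module itself may instead require S to be divisible by N.

```rust
use std::ops::Range;

/// Token range owned by `rank` when `seq_len` tokens are split across
/// `world_size` GPUs. Hypothetical policy: the first `seq_len % world_size`
/// ranks each get one extra token.
fn chunk_range(seq_len: usize, world_size: usize, rank: usize) -> Range<usize> {
    assert!(world_size > 0 && rank < world_size, "invalid rank or world size");
    let base = seq_len / world_size;
    let extra = seq_len % world_size;
    // Every rank before this one contributed `base` tokens, plus one more
    // for each of them that falls below `extra`.
    let start = rank * base + rank.min(extra);
    let len = base + usize::from(rank < extra);
    start..start + len
}
```

With `S = 8192` and two GPUs this gives `0..4096` and `4096..8192`, the `tokens[0..S/2]` / `tokens[S/2..S]` split used in the diagram.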
§Architecture (Ring Attention)
GPU 0: tokens[0..S/2]              GPU 1: tokens[S/2..S]
──────────────────────             ──────────────────────
X₀ = embed(tok[0..S/2])            X₁ = embed(tok[S/2..S])
Q₀, K₀, V₀ = proj(X₀)              Q₁, K₁, V₁ = proj(X₁)
Ring step 1 (GPU 0): attn(Q₀, K₀, V₀) + send K₀,V₀ to GPU 1, recv K₁,V₁ from GPU 1
Ring step 2 (GPU 0): attn(Q₀, K₁, V₁)
─── Each GPU merges its partial outputs (softmax rescaling) ───
§Communication Pattern
At each ring step, every GPU sends the K,V chunk it currently holds to the next GPU in the ring and receives a chunk from the previous GPU. After N-1 exchanges (N attention steps, counting the local chunk), each GPU has computed attention against all N K,V chunks.
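The per-GPU schedule can be sketched as follows. `StepPlan` and `ring_schedule` are hypothetical names, not the fields of `RingAttentionSchedule` or `RingStep`, and the sketch assumes chunk `i` starts on rank `i` and moves one hop forward per step.

```rust
/// What one GPU does at one ring step (illustrative, not the module's API).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct StepPlan {
    step: usize,
    /// Index of the K,V chunk this GPU attends to at this step.
    kv_chunk: usize,
    /// Rank that receives our current K,V chunk (None on the last step).
    send_to: Option<usize>,
    /// Rank we receive the next K,V chunk from (None on the last step).
    recv_from: Option<usize>,
}

fn ring_schedule(rank: usize, world_size: usize) -> Vec<StepPlan> {
    assert!(world_size > 0 && rank < world_size, "invalid rank or world size");
    let next = (rank + 1) % world_size;
    let prev = (rank + world_size - 1) % world_size;
    (0..world_size)
        .map(|step| {
            // After `step` hops we hold the chunk that started `step` ranks behind us.
            let kv_chunk = (rank + world_size - step) % world_size;
            // The last step needs no exchange: N attention steps, N-1 exchanges.
            let last = step + 1 == world_size;
            StepPlan {
                step,
                kv_chunk,
                send_to: if last { None } else { Some(next) },
                recv_from: if last { None } else { Some(prev) },
            }
        })
        .collect()
}
```

For the two-GPU diagram, `ring_schedule(0, 2)` yields chunk 0 and then chunk 1, with a single exchange with rank 1. In practice the send/receive for step t+1 is usually overlapped with the attention compute of step t.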
§When to Use
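Most valuable when sequence length is much larger than hidden size (e.g. 8K+ token sequences). Per GPU, activations shrink from O(S × H) to O(S/N × H), and each ring step materializes only an (S/N) × (S/N) block of attention scores per head instead of S × S. Attention work per GPU drops from O(S² × H) to N steps of O((S/N)² × H), i.e. O(S² × H / N).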
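The arithmetic behind these estimates, as a rough sketch. It assumes S is divisible by N, that K and V each carry `hidden` values per token (no grouped-query sharing), and it counts only the Q·Kᵀ multiply-accumulates. The function names are illustrative and unrelated to the fields of `SpCommCost`.

```rust
/// Q·Kᵀ multiply-accumulates per GPU: N ring steps, each an
/// (S/N) x (S/N) score block with H-wide dot products.
fn qk_macs_per_gpu(seq_len: u64, hidden: u64, world_size: u64) -> u64 {
    let chunk = seq_len / world_size;
    world_size * chunk * chunk * hidden
}

/// Score entries per head held at once if a ring step materializes its block.
fn score_block_elems(seq_len: u64, world_size: u64) -> u64 {
    let chunk = seq_len / world_size;
    chunk * chunk
}

/// K and V elements each GPU sends over its N-1 ring exchanges.
fn kv_elems_sent_per_gpu(seq_len: u64, hidden: u64, world_size: u64) -> u64 {
    let chunk = seq_len / world_size;
    (world_size - 1) * 2 * chunk * hidden
}
```

For S = 8192 and N = 4, each step holds a 2048 × 2048 score block per head, and per-GPU Q·Kᵀ work is a quarter of the single-GPU figure.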
§Contract (C-SP-001)
- Sequence chunks are contiguous and non-overlapping
- Each GPU’s attention output is identical to the corresponding rows (its own chunk) of the full-sequence result
- Ring communication maintains causal mask correctness
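The second contract bullet requires that the partial results from each ring step be combined exactly rather than averaged. One standard way to do this, used by ring attention and FlashAttention-style kernels, is online softmax: keep a running max and normalizer per query row and rescale earlier partial sums whenever a new K,V block arrives. This sketch assumes a single query row with scalar values for brevity; `OnlineSoftmax` and `full_attention` are illustrative names, not this module's API.

```rust
/// Running softmax state for one query row: max score `m`, normalizer `l`,
/// and un-normalized weighted sum `acc` of (scalar) values.
#[derive(Debug, Clone, Copy)]
struct OnlineSoftmax {
    m: f64,
    l: f64,
    acc: f64,
}

impl OnlineSoftmax {
    fn new() -> Self {
        Self { m: f64::NEG_INFINITY, l: 0.0, acc: 0.0 }
    }

    /// Folds in one K,V block (one ring step): its scores and values.
    fn update(&mut self, scores: &[f64], values: &[f64]) {
        let block_max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let m_new = self.m.max(block_max);
        if m_new == f64::NEG_INFINITY {
            return; // every score so far is masked: nothing to add
        }
        // Rescale earlier partial sums from the old max to the new one.
        let scale = (self.m - m_new).exp();
        self.l *= scale;
        self.acc *= scale;
        for (&s, &v) in scores.iter().zip(values) {
            let w = (s - m_new).exp();
            self.l += w;
            self.acc += w * v;
        }
        self.m = m_new;
    }

    fn output(&self) -> f64 {
        self.acc / self.l
    }
}

/// Reference result: softmax over all scores at once.
fn full_attention(scores: &[f64], values: &[f64]) -> f64 {
    let m = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let w: Vec<f64> = scores.iter().map(|s| (s - m).exp()).collect();
    let l: f64 = w.iter().sum();
    w.iter().zip(values).map(|(w, v)| w * v).sum::<f64>() / l
}
```

Because each update only rescales by exp(m_old - m_new), blocks can be folded in any order, which is what lets each GPU consume chunks in ring order. A fully masked block (all scores -∞) leaves the state unchanged.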
Structs§
- RingAttentionSchedule - Ring attention schedule for a single GPU.
- RingStep - A single step in the ring attention protocol.
- SequenceParallelConfig - Sequence parallel configuration.
- SpCommCost - Communication cost estimate for sequence parallelism.
Enums§
- CausalMaskType - Type of causal mask needed for a ring attention step.
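With contiguous chunks ordered by rank, the causal mask for a ring step depends only on how the K,V chunk index compares with the local Q chunk index. The enum below is a hypothetical stand-in for `CausalMaskType`; its variant names are assumptions, not the module's.

```rust
use std::cmp::Ordering;

/// Illustrative stand-in for `CausalMaskType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MaskKind {
    /// K,V chunk lies entirely before the Q chunk: no masking needed.
    Full,
    /// Same chunk: standard lower-triangular causal mask.
    Diagonal,
    /// K,V chunk lies entirely after the Q chunk: every score is masked,
    /// so the attention compute for this step can be skipped.
    Skip,
}

/// Mask needed when the GPU owning `q_chunk` attends to `kv_chunk`.
fn mask_kind(q_chunk: usize, kv_chunk: usize) -> MaskKind {
    match kv_chunk.cmp(&q_chunk) {
        Ordering::Less => MaskKind::Full,
        Ordering::Equal => MaskKind::Diagonal,
        Ordering::Greater => MaskKind::Skip,
    }
}
```

One consequence of contiguous chunking under a causal mask is load imbalance: rank 0 can skip N-1 of its N steps, while rank N-1 computes all of them.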