Module sequence_parallel

Sequence parallelism for transformer pretraining.

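Distributes the sequence dimension across multiple GPUs. Each GPU processes a contiguous chunk of the sequence and exchanges key/value (K,V) blocks with its ring neighbors through point-to-point communication, so that attention is still computed over the full sequence.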
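A minimal sketch of the contiguous split, assuming S need not be divisible by N. The helper `chunk_ranges` is hypothetical (not this module's API): it gives the first S mod N ranks one extra token, whereas the crate may instead require exact divisibility.

```rust
use std::ops::Range;

/// Hypothetical helper: splits `seq_len` token positions into `world_size`
/// contiguous, non-overlapping ranges, one per GPU rank. When `seq_len` is not
/// divisible by `world_size`, the first `seq_len % world_size` ranks get one
/// extra token.
fn chunk_ranges(seq_len: usize, world_size: usize) -> Vec<Range<usize>> {
    assert!(world_size > 0, "world_size must be positive");
    let base = seq_len / world_size;
    let extra = seq_len % world_size;
    let mut ranges = Vec::with_capacity(world_size);
    let mut start = 0;
    for rank in 0..world_size {
        // usize::from(bool) is 1 for the ranks that absorb the remainder.
        let len = base + usize::from(rank < extra);
        ranges.push(start..start + len);
        start += len;
    }
    ranges
}
```

With N = 2 and an even S this reproduces the split in the diagram: `chunk_ranges(S, 2)` yields `[0..S/2, S/2..S]`.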

§Architecture (Ring Attention)

GPU 0: tokens[0..S/2]      GPU 1: tokens[S/2..S]
──────────────────────     ──────────────────────
X₀ = embed(tok[0..S/2])   X₁ = embed(tok[S/2..S])
Q₀ = proj_q(X₀)           Q₁ = proj_q(X₁)
K₀ = proj_k(X₀)           K₁ = proj_k(X₁)
V₀ = proj_v(X₀)           V₁ = proj_v(X₁)

Ring step 1 (GPU 0): attn(Q₀, K₀, V₀) while sending K₀,V₀ to GPU 1 and receiving K₁,V₁ from GPU 1
Ring step 2 (GPU 0): attn(Q₀, K₁, V₁)   [fully masked, so skipped, under a causal mask]
─── Merge the per-step partial outputs locally (online-softmax rescaling) ───
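To make the ring steps concrete, here is a single-process simulation that plays all N GPUs in a loop for one attention head with no causal mask. Each simulated rank starts from its local K,V chunk, then visits the chunks that originated on earlier ranks, folding each block into a running max and running sum (the online-softmax update) instead of performing a cross-GPU reduction. The result is checked against ordinary full attention. The function names, the `Vec<Vec<f64>>` layout, and the requirement that S be divisible by N are assumptions for illustration, not this module's API.

```rust
/// Dot product of two equal-length vectors.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Reference: softmax(Q Kᵀ / sqrt(d)) V over the whole sequence, one head, no mask.
fn full_attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    q.iter()
        .map(|qi| {
            let scores: Vec<f64> = k.iter().map(|kj| dot(qi, kj) * scale).collect();
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let weights: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let denom: f64 = weights.iter().sum();
            let mut out = vec![0.0; v[0].len()];
            for (w, vj) in weights.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w * x / denom;
                }
            }
            out
        })
        .collect()
}

/// Simulated ring attention: rank r owns rows [r*c, (r+1)*c) of Q, K and V.
/// At step t it attends to the K,V chunk that started on rank (r - t) mod N.
fn ring_attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>], world_size: usize) -> Vec<Vec<f64>> {
    let seq_len = q.len();
    assert!(seq_len % world_size == 0, "sketch assumes S divisible by N");
    let c = seq_len / world_size;
    let d_v = v[0].len();
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    let mut output: Vec<Vec<f64>> = Vec::with_capacity(seq_len);
    for rank in 0..world_size {
        let q_chunk = &q[rank * c..(rank + 1) * c];
        let mut run_max = vec![f64::NEG_INFINITY; c]; // running max score per query row
        let mut run_sum = vec![0.0_f64; c]; // running softmax denominator
        let mut acc = vec![vec![0.0_f64; d_v]; c]; // unnormalized weighted sum of V rows
        for step in 0..world_size {
            // On real hardware this block would have arrived from the previous rank.
            let src = (rank + world_size - step) % world_size;
            let k_blk = &k[src * c..(src + 1) * c];
            let v_blk = &v[src * c..(src + 1) * c];
            for i in 0..c {
                let scores: Vec<f64> = k_blk.iter().map(|kj| dot(&q_chunk[i], kj) * scale).collect();
                let new_max = scores.iter().cloned().fold(run_max[i], f64::max);
                // Rescale earlier contributions to the new max; exp(-inf) = 0 on the first block.
                let correction = (run_max[i] - new_max).exp();
                run_sum[i] *= correction;
                for a in acc[i].iter_mut() {
                    *a *= correction;
                }
                for (s, vj) in scores.iter().zip(v_blk) {
                    let p = (s - new_max).exp();
                    run_sum[i] += p;
                    for (a, x) in acc[i].iter_mut().zip(vj) {
                        *a += p * x;
                    }
                }
                run_max[i] = new_max;
            }
        }
        // Normalize once after all N blocks have been folded in.
        for i in 0..c {
            output.push(acc[i].iter().map(|a| a / run_sum[i]).collect::<Vec<f64>>());
        }
    }
    output
}
```

Because each rank only keeps per-row statistics (`run_max`, `run_sum`) and an accumulator, the merge step needs no communication; concatenating the ranks' rows gives the full attention output up to floating-point rounding.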

§Communication Pattern

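Each GPU sends the K,V block it currently holds to the next GPU in the ring and receives a K,V block from the previous GPU, overlapping each transfer with attention on the current block. After N-1 transfers (N attention steps, counting the local chunk), each GPU has computed attention against all N K,V chunks.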
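The ring order can be written down per rank. This sketch uses a hypothetical `RingStepSketch` record (the crate's `RingStep` may carry different fields): at step t, rank r attends to the chunk that originated on rank (r - t) mod N, and every step except the last overlaps one send to the next rank with one receive from the previous rank.

```rust
/// Hypothetical per-step record for one rank; not the crate's `RingStep`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RingStepSketch {
    step: usize,
    /// Index of the K,V chunk this rank attends to during `step`.
    kv_chunk: usize,
    /// (send_to, recv_from) ranks for the transfer overlapped with this step;
    /// `None` on the last step, because no further K,V block is needed.
    transfer: Option<(usize, usize)>,
}

/// Builds the N-step ring schedule for `rank` in a ring of `world_size` GPUs.
fn ring_schedule(rank: usize, world_size: usize) -> Vec<RingStepSketch> {
    assert!(rank < world_size, "rank out of range");
    let next = (rank + 1) % world_size;
    let prev = (rank + world_size - 1) % world_size;
    (0..world_size)
        .map(|step| RingStepSketch {
            step,
            // Step 0 uses the local chunk; step t uses the chunk that started
            // t ranks earlier in the ring.
            kv_chunk: (rank + world_size - step) % world_size,
            transfer: if step + 1 < world_size { Some((next, prev)) } else { None },
        })
        .collect()
}
```

For N = 2 and rank 0 this gives the two steps in the diagram: chunk 0 while exchanging with GPU 1, then chunk 1 with no further transfer.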

§When to Use

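Most valuable when the sequence length is large relative to the hidden size (8K+ token sequences), where the S × S attention scores dominate activation memory. Each GPU keeps only S/N query rows, so per-GPU score memory drops from O(S² × H) to O((S/N) × S × H) = O(S² × H / N) if a GPU scores its queries against all keys at once, and to O((S/N)² × H) if each ring step's scores are freed before the next K,V block arrives (H = number of attention heads).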
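The scaling claim can be checked with plain arithmetic. The helpers below are hypothetical and count score elements for one layer, assuming one S × S score matrix per head (H heads); they ignore kernels that avoid materializing scores at all.

```rust
/// Score elements for full-sequence attention on one device: S x S per head.
fn full_score_elems(seq_len: u64, heads: u64) -> u64 {
    seq_len * seq_len * heads
}

/// One GPU's S/N query rows scored against all S keys at once.
fn per_gpu_all_keys_elems(seq_len: u64, heads: u64, world_size: u64) -> u64 {
    (seq_len / world_size) * seq_len * heads
}

/// One GPU's S/N query rows scored against a single S/N-row K,V block (one ring step).
fn per_gpu_ring_step_elems(seq_len: u64, heads: u64, world_size: u64) -> u64 {
    let chunk = seq_len / world_size;
    chunk * chunk * heads
}
```

For S = 8192, H = 32 and N = 8 in bf16 (2 bytes per element), the full score tensor is 2³¹ elements (4 GiB), a GPU's rows against all keys take 512 MiB, and one ring-step block takes 64 MiB.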

§Contract (C-SP-001)

  • Sequence chunks are contiguous and non-overlapping
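  • Each GPU’s attention output matches the corresponding rows of the full-sequence attention result (up to floating-point rounding)
  • Every ring step applies the causal mask appropriate to its (query chunk, K,V chunk) pair, so no token attends to a later token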
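With contiguous chunks, the mask a ring step needs depends only on how the K,V chunk index compares with the query chunk index. This sketch uses a hypothetical `MaskKind` enum (the real `CausalMaskType` listed below may name or split its variants differently) and checks that the per-step masks reproduce exactly the global causal rule k ≤ q.

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for `CausalMaskType`; variants are assumptions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MaskKind {
    /// K,V chunk lies entirely before the query chunk: every key is visible.
    Unmasked,
    /// K,V chunk is the query chunk itself: lower-triangular mask.
    Diagonal,
    /// K,V chunk lies entirely after the query chunk: the step can be skipped.
    FullyMasked,
}

/// Mask needed when query chunk `query_chunk` meets K,V chunk `kv_chunk`.
fn mask_for(query_chunk: usize, kv_chunk: usize) -> MaskKind {
    match kv_chunk.cmp(&query_chunk) {
        Ordering::Less => MaskKind::Unmasked,
        Ordering::Equal => MaskKind::Diagonal,
        Ordering::Greater => MaskKind::FullyMasked,
    }
}

/// Whether local query row `qi` may attend to local key row `kj` under `mask`.
fn is_visible(mask: MaskKind, qi: usize, kj: usize) -> bool {
    match mask {
        MaskKind::Unmasked => true,
        MaskKind::Diagonal => kj <= qi,
        MaskKind::FullyMasked => false,
    }
}
```

One consequence of contiguous chunking: rank r skips N - 1 - r of its N steps, so causal work is unevenly spread across ranks. In an online-softmax loop, fully masked steps should be skipped outright rather than fed all-masked scores, which would produce a NaN correction factor.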

Structs§

RingAttentionSchedule
Ring attention schedule for a single GPU.
RingStep
A single step in the ring attention protocol.
SequenceParallelConfig
Sequence parallel configuration.
SpCommCost
Communication cost estimate for sequence parallelism.
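As a rough illustration of what an `SpCommCost`-style estimate might contain, this sketch counts the bytes one GPU sends in the forward pass of one attention layer: N-1 transfers, each carrying one K block and one V block of (S/N) × d_kv elements. `CommCostSketch`, `ring_comm_cost` and `kv_dim` (total K or V width per token) are hypothetical; the backward pass and transfer/compute overlap are not modeled.

```rust
/// Hypothetical forward-pass estimate for one attention layer on one GPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CommCostSketch {
    /// Ring transfers (one per step except the last).
    transfers: u64,
    /// Bytes this GPU sends: one K block plus one V block per transfer.
    bytes_sent: u64,
}

fn ring_comm_cost(seq_len: u64, world_size: u64, kv_dim: u64, bytes_per_elem: u64) -> CommCostSketch {
    assert!(world_size > 0 && seq_len % world_size == 0, "sketch assumes S divisible by N");
    let transfers = world_size - 1;
    let block_elems = (seq_len / world_size) * kv_dim; // elements in one K (or V) chunk
    CommCostSketch {
        transfers,
        bytes_sent: transfers * 2 * block_elems * bytes_per_elem,
    }
}
```

For S = 8192, N = 8, d_kv = 4096 in bf16, each K or V block is 8 MiB and a GPU sends 7 × 16 MiB = 112 MiB per layer. Per-step traffic scales with S/N while per-step attention compute scales with (S/N)², which is why larger chunks make the transfers easier to hide behind compute.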

Enums§

CausalMaskType
Type of causal mask needed for a ring attention step.