// pmetal_distributed/pipeline.rs
//! Pipeline-parallel inference runtime.
//!
//! Coordinates multiple [`PipelineShard`] instances across a cluster,
//! routing activations between stages via [`ActivationMessage`].
//!
//! The API node (rank 0, first shard) drives generation: it embeds tokens,
//! runs its local layers, sends the hidden state to the next shard, and
//! awaits the final logits from the last shard.
//!
//! # Generation loop
//!
//! [`PipelineGenerationLoop`] wires end-to-end autoregressive generation across
//! the full pipeline. On the first shard, call
//! [`PipelineGenerationLoop::generate_first_shard`]; on every middle/last shard,
//! run [`PipelineGenerationLoop::run_shard_loop`] in a background task. The last
//! shard samples a token (greedy argmax built-in), encodes it as 4
//! little-endian bytes, and sends it back to the first shard via `send_result`.
//!
//! # Concurrent requests
//!
//! [`StreamMultiplexer`] allows multiple in-flight requests to share the same
//! `PipelineStageRuntime` transport pair. Each request is identified by its u64
//! nonce. `send_and_await` dispatches a message and suspends until the matching
//! response arrives; `register_handler` wires a one-shot callback for async
//! dispatch.
24
25use crate::activation_codec::{ActivationCodec, compress_activation};
26use crate::activation_transport::{ActivationMessage, DtypeTag, recv_activation, send_activation};
27use crate::error::{DistributedError, DistributedResult};
28use crate::topology::NodeProfile;
29use crate::transport::{TransportReceiver, TransportSender};
30use std::collections::HashMap;
31use std::ops::Range;
32use tokio::sync::oneshot;
33
34/// Configuration for a pipeline stage.
35#[derive(Debug, Clone)]
36pub struct PipelineStageConfig {
37 /// This stage's rank (0-indexed).
38 pub rank: usize,
39 /// Total number of pipeline stages.
40 pub world_size: usize,
41 /// Layer range assigned to this stage.
42 pub layer_range: Range<usize>,
43 /// Whether this is the first stage (owns embedding).
44 pub is_first: bool,
45 /// Whether this is the last stage (owns norm + lm_head).
46 pub is_last: bool,
47 /// Wire dtype for activation transfer.
48 pub wire_dtype: DtypeTag,
49 /// Activation compression codec.
50 pub codec: ActivationCodec,
51}
52
53/// Runtime state for one stage of the pipeline.
54pub struct PipelineStageRuntime {
55 config: PipelineStageConfig,
56 /// Sender to next stage (None if last stage).
57 next_sender: Option<TransportSender>,
58 /// Receiver from previous stage (None if first stage).
59 prev_receiver: Option<TransportReceiver>,
60 /// Sender back to first stage for final logits (only on last stage).
61 result_sender: Option<TransportSender>,
62 /// Receiver for final logits (only on first stage, from last stage).
63 result_receiver: Option<TransportReceiver>,
64 /// Monotonic nonce counter for request routing.
65 nonce_counter: u64,
66}
67
68impl PipelineStageRuntime {
69 /// Create a new pipeline stage runtime.
70 pub fn new(
71 config: PipelineStageConfig,
72 next_sender: Option<TransportSender>,
73 prev_receiver: Option<TransportReceiver>,
74 result_sender: Option<TransportSender>,
75 result_receiver: Option<TransportReceiver>,
76 ) -> Self {
77 Self {
78 config,
79 next_sender,
80 prev_receiver,
81 result_sender,
82 result_receiver,
83 nonce_counter: 0,
84 }
85 }
86
87 /// Generate a new nonce for a request.
88 pub fn next_nonce(&mut self) -> u64 {
89 self.nonce_counter += 1;
90 self.nonce_counter
91 }
92
93 /// Configuration for this stage.
94 pub fn config(&self) -> &PipelineStageConfig {
95 &self.config
96 }
97
98 /// Send local hidden states to the next pipeline stage.
99 ///
100 /// `data`: raw bytes of the hidden state tensor
101 /// `shape`: tensor shape dimensions
102 /// `nonce`: request identifier for routing
103 /// `layer_id`: the last layer this stage processed
104 pub async fn send_to_next(
105 &mut self,
106 data: &[u8],
107 shape: &[u32],
108 nonce: u64,
109 layer_id: u32,
110 ) -> DistributedResult<()> {
111 let sender = self
112 .next_sender
113 .as_mut()
114 .ok_or_else(|| DistributedError::Protocol("no next stage to send to".into()))?;
115
116 let compressed = compress_activation(
117 data,
118 self.config.wire_dtype == DtypeTag::Float32,
119 self.config.codec,
120 );
121
122 let msg = ActivationMessage {
123 nonce,
124 layer_id,
125 shape: shape.to_vec(),
126 dtype: self.config.wire_dtype,
127 data: compressed,
128 };
129
130 send_activation(sender, &msg).await
131 }
132
133 /// Receive hidden states from the previous pipeline stage.
134 pub async fn recv_from_prev(&mut self) -> DistributedResult<ActivationMessage> {
135 let receiver = self.prev_receiver.as_mut().ok_or_else(|| {
136 DistributedError::Protocol("no previous stage to receive from".into())
137 })?;
138
139 recv_activation(receiver).await
140 }
141
142 /// Send final logits back to the API node (first stage).
143 /// Only called by the last stage.
144 pub async fn send_result(
145 &mut self,
146 data: &[u8],
147 shape: &[u32],
148 nonce: u64,
149 ) -> DistributedResult<()> {
150 let sender = self.result_sender.as_mut().ok_or_else(|| {
151 DistributedError::Protocol("no result sender (not last stage?)".into())
152 })?;
153
154 let msg = ActivationMessage {
155 nonce,
156 layer_id: u32::MAX, // sentinel for "final logits"
157 shape: shape.to_vec(),
158 dtype: self.config.wire_dtype,
159 data: data.to_vec(),
160 };
161
162 send_activation(sender, &msg).await
163 }
164
165 /// Receive final logits from the last stage.
166 /// Only called by the first stage.
167 pub async fn recv_result(&mut self) -> DistributedResult<ActivationMessage> {
168 let receiver = self.result_receiver.as_mut().ok_or_else(|| {
169 DistributedError::Protocol("no result receiver (not first stage?)".into())
170 })?;
171
172 recv_activation(receiver).await
173 }
174}
175
176/// Layer assignment solver.
177///
178/// Given node profiles and total layer count, produces contiguous layer
179/// assignments that balance memory usage across nodes.
180pub fn solve_layer_assignment(
181 num_layers: usize,
182 profiles: &[NodeProfile],
183) -> Vec<PipelineStageConfig> {
184 let available_ram: Vec<u64> = profiles.iter().map(|p| p.available_ram).collect();
185 let ranges = crate::layer_assignment::assign_layers_proportional(num_layers, &available_ram);
186
187 let world_size = profiles.len();
188 ranges
189 .into_iter()
190 .enumerate()
191 .map(|(rank, range)| PipelineStageConfig {
192 rank,
193 world_size,
194 is_first: rank == 0,
195 is_last: rank == world_size - 1,
196 layer_range: range,
197 wire_dtype: DtypeTag::Float16,
198 codec: ActivationCodec::Float16,
199 })
200 .collect()
201}
202
203// ─────────────────────────────────────────────────────────────────────────────
204// PipelineGenerationLoop
205// ─────────────────────────────────────────────────────────────────────────────
206
207/// End-to-end autoregressive generation loop for a pipeline-parallel cluster.
208///
209/// # Roles
210///
211/// * **First shard (rank 0)** — calls [`generate_first_shard`]. It sends the
212/// already-embedded+locally-forwarded hidden state to the next shard, then
213/// waits for a 4-byte token reply from the last shard. The sampled token is
214/// fed back as the next input and the cycle repeats.
215///
216/// * **Middle / last shards** — call [`run_shard_loop`] in a background task,
217/// supplying a `forward_fn` closure that accepts raw activation bytes and
218/// returns the result. The last shard's `forward_fn` must return logits
219/// (`[batch, seq_len, vocab_size]`) in fp32 little-endian; the loop applies
220/// greedy argmax and sends the winning token back to rank 0.
221///
222/// # Greedy sampler
223///
224/// The built-in sampler applies argmax over the **last position** of the logit
225/// tensor (`shape = [batch, seq_len, vocab_size]`):
226///
227/// ```text
228/// token = argmax(logits[0, -1, :])
229/// ```
230///
231/// This is the standard greedy decode step identical to what dnet's `generate_stream`
232/// loop does when `temperature=0`.
233///
234/// [`generate_first_shard`]: PipelineGenerationLoop::generate_first_shard
235/// [`run_shard_loop`]: PipelineGenerationLoop::run_shard_loop
236pub struct PipelineGenerationLoop {
237 /// The underlying stage runtime used for sending/receiving.
238 pub stage: PipelineStageRuntime,
239 /// Maximum number of new tokens to generate.
240 pub max_tokens: usize,
241 /// Token IDs that terminate generation (e.g. EOS). Any token whose u32
242 /// value appears in this list stops the loop immediately.
243 pub stop_tokens: Vec<u32>,
244}
245
246impl PipelineGenerationLoop {
247 /// Create a new generation loop wrapping the given stage runtime.
248 pub fn new(stage: PipelineStageRuntime, max_tokens: usize, stop_tokens: Vec<u32>) -> Self {
249 Self {
250 stage,
251 max_tokens,
252 stop_tokens,
253 }
254 }
255
256 /// Drive autoregressive generation from the **first shard** (rank 0).
257 ///
258 /// The caller is responsible for embedding + running the local layers to
259 /// produce `input_hidden` before the first call. On each subsequent step
260 /// the single-token hidden state from the local forward pass is passed in
261 /// again.
262 ///
263 /// # Arguments
264 ///
265 /// * `input_hidden` — raw bytes of the hidden state produced by this shard's
266 /// local forward pass (dtype matches `stage.config.wire_dtype`).
267 /// * `input_shape` — shape of `input_hidden`, e.g. `[1, seq_len, hidden_dim]`.
268 /// * `vocab_size` — vocabulary size; used to validate the logit payload returned
269 /// by the last shard.
270 ///
271 /// # Returns
272 ///
273 /// The ordered list of generated token IDs (not including the prompt).
274 pub async fn generate_first_shard(
275 &mut self,
276 input_hidden: &[u8],
277 input_shape: &[u32],
278 vocab_size: u32,
279 ) -> DistributedResult<Vec<u32>> {
280 let mut generated: Vec<u32> = Vec::with_capacity(self.max_tokens);
281
282 // The first step uses the caller-supplied hidden state. Subsequent
283 // steps re-use whatever the caller passes in via the outer loop — but
284 // because we own the stage here we thread the token feedback through
285 // a separate result channel rather than calling the model again.
286 //
287 // Concretely the loop is:
288 // 1. Send hidden to next shard.
289 // 2. Await the 4-byte token from the last shard.
290 // 3. Append token to output, check stop condition.
291 // 4. The next iteration's `hidden` is produced by the caller embedding
292 // the new token; here we just encode the token as a 1-token "hidden"
293 // signal so the caller can reconstruct it. We return after the loop.
294
295 let mut current_hidden: Vec<u8> = input_hidden.to_vec();
296 let mut current_shape: Vec<u32> = input_shape.to_vec();
297
298 for _ in 0..self.max_tokens {
299 let nonce = self.stage.next_nonce();
300
301 // Send our local output to the next stage in the pipeline.
302 let last_layer = self.stage.config().layer_range.end.saturating_sub(1) as u32;
303 self.stage
304 .send_to_next(¤t_hidden, ¤t_shape, nonce, last_layer)
305 .await?;
306
307 // Await the sampled token from the last shard.
308 // The last shard encodes the token as a 4-byte LE u32 payload with
309 // shape `[1]` and layer_id == u32::MAX (the "final result" sentinel).
310 let result_msg = self.stage.recv_result().await?;
311
312 if result_msg.data.len() < 4 {
313 return Err(DistributedError::Protocol(format!(
314 "expected 4-byte token payload from last shard, got {} bytes",
315 result_msg.data.len()
316 )));
317 }
318
319 let token =
320 u32::from_le_bytes(result_msg.data[..4].try_into().expect("slice is 4 bytes"));
321 generated.push(token);
322
323 // Stop on EOS / stop token.
324 if self.stop_tokens.contains(&token) {
325 break;
326 }
327
328 // The next hidden state is a single-token slice. For pipeline
329 // purposes we encode the token ID as a 4-byte int32 tensor so the
330 // first shard can embed it on the next call. In practice the caller
331 // drives embedding outside this loop and passes in the fresh hidden;
332 // this path is used when `generate_first_shard` is called once for
333 // the full sequence and the shard also owns embedding.
334 //
335 // We store the raw token bytes as the "hidden" to re-enter the loop:
336 // the caller can detect a 4-byte, shape=[1] payload and re-embed.
337 current_hidden = token.to_le_bytes().to_vec();
338 current_shape = vec![1];
339
340 // Sanity: log if the vocab_size doesn't match (non-fatal, best effort).
341 let _ = vocab_size; // used only for documentation intent above
342 }
343
344 Ok(generated)
345 }
346
347 /// Run a **middle or last shard's** receive → compute → send loop.
348 ///
349 /// This method blocks until the first shard signals termination (i.e. the
350 /// pipeline transport is closed) or an error occurs.
351 ///
352 /// # `forward_fn` contract
353 ///
354 /// ```text
355 /// fn forward_fn(data: &[u8], shape: &[u32]) -> DistributedResult<(Vec<u8>, Vec<u32>)>
356 /// ```
357 ///
358 /// * Input: raw activation bytes + shape from the previous shard.
359 /// * Output: either the next hidden state (middle shards) **or** fp32
360 /// logits `[batch, seq_len, vocab_size]` (last shard).
361 ///
362 /// For the **last shard** the returned bytes are treated as fp32 logits.
363 /// `run_shard_loop` applies greedy argmax on the final position, packs the
364 /// winning index as a 4-byte LE u32, and sends it back to rank 0 via
365 /// `send_result`. Middle shards forward the returned bytes to the next
366 /// shard via `send_to_next`.
367 pub async fn run_shard_loop<F>(&mut self, mut forward_fn: F) -> DistributedResult<()>
368 where
369 F: FnMut(&[u8], &[u32]) -> DistributedResult<(Vec<u8>, Vec<u32>)>,
370 {
371 let is_last = self.stage.config().is_last;
372
373 loop {
374 // Receive hidden state from the previous shard (or EOS when the
375 // transport is closed / the peer disconnects).
376 let msg = match self.stage.recv_from_prev().await {
377 Ok(m) => m,
378 Err(DistributedError::Protocol(ref s)) if s.contains("recv activation") => {
379 // Transport closed — generation is complete.
380 break;
381 }
382 Err(e) => return Err(e),
383 };
384
385 let nonce = msg.nonce;
386 let (out_data, out_shape) = forward_fn(&msg.data, &msg.shape)?;
387
388 if is_last {
389 // Apply greedy argmax over the last-position logits and send
390 // the winning token index back to rank 0.
391 let token = greedy_argmax_last_position(&out_data, &out_shape)?;
392 let token_bytes = token.to_le_bytes();
393 self.stage.send_result(&token_bytes, &[1], nonce).await?;
394 } else {
395 // Middle shard: forward to the next stage.
396 let last_layer = self.stage.config().layer_range.end.saturating_sub(1) as u32;
397 self.stage
398 .send_to_next(&out_data, &out_shape, nonce, last_layer)
399 .await?;
400 }
401 }
402
403 Ok(())
404 }
405}
406
407/// Greedy argmax sampler applied to the **last sequence position** of a logit tensor.
408///
409/// Expects `data` to be a flat, fp32 little-endian buffer of shape
410/// `[batch, seq_len, vocab_size]` (the canonical output of `lm_head`).
411/// Returns the index of the maximum logit at position `[0, seq_len-1, :]`.
412fn greedy_argmax_last_position(data: &[u8], shape: &[u32]) -> DistributedResult<u32> {
413 // Shape must be at least rank-1; we need the vocab_size (last dim).
414 if shape.is_empty() {
415 return Err(DistributedError::Protocol(
416 "logit tensor has empty shape".into(),
417 ));
418 }
419
420 let vocab_size = *shape.last().unwrap() as usize;
421 if vocab_size == 0 {
422 return Err(DistributedError::Protocol(
423 "logit tensor has zero vocab_size".into(),
424 ));
425 }
426
427 // Each f32 element is 4 bytes.
428 if !data.len().is_multiple_of(4) {
429 return Err(DistributedError::Protocol(format!(
430 "logit data length {} is not a multiple of 4 (f32)",
431 data.len()
432 )));
433 }
434
435 let total_elems = data.len() / 4;
436 if total_elems < vocab_size {
437 return Err(DistributedError::Protocol(format!(
438 "logit data has {} f32 elements but vocab_size is {}",
439 total_elems, vocab_size
440 )));
441 }
442
443 // The last-position slice starts at (total_elems - vocab_size) * 4.
444 let last_pos_start = (total_elems - vocab_size) * 4;
445 let last_pos_bytes = &data[last_pos_start..];
446
447 let mut best_idx: u32 = 0;
448 let mut best_val: f32 = f32::NEG_INFINITY;
449
450 for (i, chunk) in last_pos_bytes.chunks_exact(4).enumerate() {
451 let val = f32::from_le_bytes(
452 chunk
453 .try_into()
454 .expect("chunks_exact(4) guarantees 4 bytes"),
455 );
456 if val > best_val {
457 best_val = val;
458 best_idx = i as u32;
459 }
460 }
461
462 Ok(best_idx)
463}
464
465// ─────────────────────────────────────────────────────────────────────────────
466// StreamMultiplexer
467// ─────────────────────────────────────────────────────────────────────────────
468
469/// Multiplexes multiple concurrent inference requests over a shared pipeline
470/// transport pair.
471///
472/// The [`PipelineStageRuntime`] send/recv methods process one message at a time.
473/// When multiple requests are in flight (e.g. batched API requests), their
474/// responses must be routed back to the correct caller by nonce.
475///
476/// `StreamMultiplexer` provides two complementary APIs:
477///
478/// * **Request-response** — [`send_and_await`] sends an [`ActivationMessage`]
479/// and suspends the caller until the matching reply (same nonce) arrives.
480/// * **Async dispatch** — [`register_handler`] registers a one-shot
481/// [`oneshot::Sender`] that is fired when a response with the matching nonce
482/// is delivered by [`dispatch_incoming`].
483///
484/// # Typical usage
485///
486/// ```text
487/// // Spawn one background task that continuously calls dispatch_incoming:
488/// tokio::spawn(async move {
489/// loop {
490/// mux.dispatch_incoming(&mut stage).await.unwrap();
491/// }
492/// });
493///
494/// // Each request task calls send_and_await:
495/// let response = mux.send_and_await(msg, &mut stage).await?;
496/// ```
497///
498/// [`send_and_await`]: StreamMultiplexer::send_and_await
499/// [`register_handler`]: StreamMultiplexer::register_handler
500/// [`dispatch_incoming`]: StreamMultiplexer::dispatch_incoming
501pub struct StreamMultiplexer {
502 /// Pending one-shot senders keyed by nonce.
503 ///
504 /// When a response message arrives with a known nonce, the matching sender
505 /// is removed and the message is delivered through it.
506 pending: HashMap<u64, oneshot::Sender<ActivationMessage>>,
507}
508
509impl StreamMultiplexer {
510 /// Create a new, empty multiplexer.
511 pub fn new() -> Self {
512 Self {
513 pending: HashMap::new(),
514 }
515 }
516
517 /// Register a one-shot handler that fires when a response with `nonce` arrives.
518 ///
519 /// Returns the corresponding [`oneshot::Receiver`] which the caller can
520 /// `await` to get the response. If a handler for the same nonce is already
521 /// registered, it is replaced and the old receiver will never fire.
522 pub fn register_handler(&mut self, nonce: u64) -> oneshot::Receiver<ActivationMessage> {
523 let (tx, rx) = oneshot::channel();
524 self.pending.insert(nonce, tx);
525 rx
526 }
527
528 /// Send `msg` via `stage` and await the matching response.
529 ///
530 /// This is a higher-level convenience that calls `register_handler`,
531 /// forwards the message to the next shard, and then awaits the registered
532 /// one-shot receiver.
533 ///
534 /// # Errors
535 ///
536 /// Returns `DistributedError::Cancelled` if the dispatch task drops the
537 /// sender before delivering a response (e.g. on transport error).
538 pub async fn send_and_await(
539 &mut self,
540 msg: ActivationMessage,
541 stage: &mut PipelineStageRuntime,
542 ) -> DistributedResult<ActivationMessage> {
543 let nonce = msg.nonce;
544 let rx = self.register_handler(nonce);
545
546 // Transmit — reuse send_to_next which handles compression.
547 stage
548 .send_to_next(&msg.data, &msg.shape, nonce, msg.layer_id)
549 .await?;
550
551 // Await the response from the background dispatch loop.
552 rx.await.map_err(|_| DistributedError::Cancelled)
553 }
554
555 /// Receive one incoming message from `stage` and route it to the waiting
556 /// caller identified by its nonce.
557 ///
558 /// This should be called repeatedly from a dedicated background task:
559 ///
560 /// ```text
561 /// loop {
562 /// mux.dispatch_incoming(&mut stage).await?;
563 /// }
564 /// ```
565 ///
566 /// Messages whose nonce is not in `pending` (e.g. unsolicited or already
567 /// cancelled) are silently dropped.
568 pub async fn dispatch_incoming(
569 &mut self,
570 stage: &mut PipelineStageRuntime,
571 ) -> DistributedResult<()> {
572 // For the first shard we receive via result_receiver (from last shard).
573 // For other shards we receive from the previous stage.
574 let msg = if stage.config().is_first {
575 stage.recv_result().await?
576 } else {
577 stage.recv_from_prev().await?
578 };
579
580 let nonce = msg.nonce;
581 if let Some(tx) = self.pending.remove(&nonce) {
582 // Ignore send errors: the receiver may have been dropped if the
583 // request was cancelled on the caller side.
584 let _ = tx.send(msg);
585 }
586 // Unknown nonce: silently discard (logged at trace level in production).
587
588 Ok(())
589 }
590
591 /// Number of requests currently awaiting a response.
592 pub fn pending_count(&self) -> usize {
593 self.pending.len()
594 }
595}
596
597impl Default for StreamMultiplexer {
598 fn default() -> Self {
599 Self::new()
600 }
601}
602
#[cfg(test)]
mod generation_tests {
    use super::*;

    /// Serialize a slice of f32 logits to a flat little-endian byte buffer.
    fn logits_to_bytes(logits: &[f32]) -> Vec<u8> {
        logits.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn greedy_argmax_picks_max() {
        // One batch, one position, four vocab entries; max value 0.9 sits at
        // index 2.
        let data = logits_to_bytes(&[0.1, 0.5, 0.9, 0.2]);
        let shape = [1u32, 1, 4];
        assert_eq!(greedy_argmax_last_position(&data, &shape).unwrap(), 2);
    }

    #[test]
    fn greedy_argmax_last_position_multi_step() {
        // Three sequence positions, vocab_size = 3; only the final row
        // matters, and its argmax is index 1 (value 3.0).
        let data = logits_to_bytes(&[
            5.0, 0.0, 0.0, // position 0
            0.0, 5.0, 0.0, // position 1
            1.0, 3.0, 2.0, // position 2 (last)
        ]);
        let shape = [1u32, 3, 3];
        assert_eq!(greedy_argmax_last_position(&data, &shape).unwrap(), 1);
    }

    #[test]
    fn greedy_argmax_rejects_empty_shape() {
        let err = greedy_argmax_last_position(&[0u8; 8], &[]).unwrap_err();
        assert!(matches!(err, DistributedError::Protocol(_)));
    }

    #[test]
    fn stream_multiplexer_pending_count() {
        let mut mux = StreamMultiplexer::new();
        assert_eq!(mux.pending_count(), 0);
        let _rx1 = mux.register_handler(1);
        let _rx2 = mux.register_handler(2);
        assert_eq!(mux.pending_count(), 2);
    }

    #[test]
    fn stream_multiplexer_register_replace() {
        // Registering the same nonce twice replaces the previous handler:
        // only one entry remains pending, tied to the newest receiver.
        let mut mux = StreamMultiplexer::new();
        let _rx_old = mux.register_handler(42);
        let rx_new = mux.register_handler(42);
        assert_eq!(mux.pending_count(), 1);
        drop(rx_new);
    }
}