pub struct PipelineGenerationLoop {
pub stage: PipelineStageRuntime,
pub max_tokens: usize,
pub stop_tokens: Vec<u32>,
}
End-to-end autoregressive generation loop for a pipeline-parallel cluster.
§Roles
- First shard (rank 0): calls generate_first_shard. It sends the already-embedded and locally-forwarded hidden state to the next shard, then waits for a 4-byte token reply from the last shard. The sampled token is fed back as the next input and the cycle repeats.
- Middle / last shards: call run_shard_loop in a background task, supplying a forward_fn closure that accepts raw activation bytes and returns the result. The last shard’s forward_fn must return logits ([batch, seq_len, vocab_size]) in fp32 little-endian; the loop applies greedy argmax and sends the winning token back to rank 0.
§Greedy sampler
The built-in sampler applies argmax over the last position of the logit
tensor (shape = [batch, seq_len, vocab_size]):
    token = argmax(logits[0, -1, :])
This is the standard greedy decode step, identical to what dnet’s generate_stream
loop does when temperature=0.
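The argmax step above can be sketched as a standalone function. This is a minimal illustration and not the crate's actual sampler: the function name greedy_argmax_last_position is hypothetical, and the choices to return None on a malformed payload and to let the lowest index win on ties are assumptions.

```rust
/// Greedy argmax over logits[0, -1, :] for an fp32 little-endian buffer
/// with shape [batch, seq_len, vocab_size].
///
/// Hypothetical helper: returns None if the shape is not rank 3, any
/// dimension is zero, or the byte length does not match the shape.
pub fn greedy_argmax_last_position(bytes: &[u8], shape: &[u32]) -> Option<u32> {
    let (batch, seq_len, vocab) = match shape {
        [b, s, v] => (*b as usize, *s as usize, *v as usize),
        _ => return None,
    };
    if batch == 0 || seq_len == 0 || vocab == 0 {
        return None;
    }
    // 4 bytes per fp32 element; checked_mul guards against overflow.
    let expected = batch
        .checked_mul(seq_len)?
        .checked_mul(vocab)?
        .checked_mul(4)?;
    if bytes.len() != expected {
        return None;
    }

    // Batch 0 starts at offset 0, so its last position begins at
    // element (seq_len - 1) * vocab.
    let start = (seq_len - 1) * vocab * 4;
    let row = &bytes[start..start + vocab * 4];

    let mut best_idx = 0u32;
    let mut best_val = f32::NEG_INFINITY;
    for (i, chunk) in row.chunks_exact(4).enumerate() {
        let v = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        // Strict comparison: the first maximum wins, and NaN never wins.
        if v > best_val {
            best_val = v;
            best_idx = i as u32;
        }
    }
    Some(best_idx)
}
```

For a [1, 2, 3] tensor only the second row of three logits is inspected; earlier positions do not influence the sampled token.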
Fields§
stage: PipelineStageRuntime
    The underlying stage runtime used for sending/receiving.
max_tokens: usize
    Maximum number of new tokens to generate.
stop_tokens: Vec<u32>
    Token IDs that terminate generation (e.g. EOS). Any token whose u32 value appears in this list stops the loop immediately.
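How max_tokens and stop_tokens interact can be sketched with a small, transport-free loop. This is only an illustration under stated assumptions: collect_until_stop is a hypothetical helper, the next_token closure stands in for the round trip through the pipeline, and excluding the stop token from the returned list is an assumption that this page does not confirm.

```rust
/// Collect tokens until either `max_tokens` have been produced or a
/// token listed in `stop_tokens` appears.
///
/// Hypothetical helper: `next_token` stands in for one full pipeline
/// round trip. Assumption: the stop token itself is not returned.
pub fn collect_until_stop<F>(
    mut next_token: F,
    max_tokens: usize,
    stop_tokens: &[u32],
) -> Vec<u32>
where
    F: FnMut() -> u32,
{
    let mut generated = Vec::new();
    while generated.len() < max_tokens {
        let token = next_token();
        if stop_tokens.contains(&token) {
            // A stop token ends the loop immediately.
            break;
        }
        generated.push(token);
    }
    generated
}
```

A linear scan with contains is adequate here because stop lists are typically a handful of IDs.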
Implementations§
impl PipelineGenerationLoop
pub fn new(
    stage: PipelineStageRuntime,
    max_tokens: usize,
    stop_tokens: Vec<u32>,
) -> Self
Create a new generation loop wrapping the given stage runtime.
pub async fn generate_first_shard(
    &mut self,
    input_hidden: &[u8],
    input_shape: &[u32],
    vocab_size: u32,
) -> DistributedResult<Vec<u32>>
Drive autoregressive generation from the first shard (rank 0).
The caller is responsible for embedding + running the local layers to
produce input_hidden before the first call. On each subsequent step
the single-token hidden state from the local forward pass is passed in
again.
§Arguments
- input_hidden: raw bytes of the hidden state produced by this shard’s local forward pass (dtype matches stage.config.wire_dtype).
- input_shape: shape of input_hidden, e.g. [1, seq_len, hidden_dim].
- vocab_size: vocabulary size; used to validate the logit payload returned by the last shard.
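Because input_hidden is a raw byte buffer, it can be useful to check that its length agrees with input_shape before sending it. The following is a minimal sketch, not part of this crate: both function names are hypothetical, and bytes_per_elem is an assumed stand-in for whatever element size the configured wire dtype implies (for example 2 for a 16-bit float, 4 for fp32).

```rust
/// Number of bytes a tensor of `shape` occupies at `bytes_per_elem`
/// bytes per element, or None if the product overflows usize.
///
/// Hypothetical helper; `bytes_per_elem` is an assumption standing in
/// for the element size of the wire dtype.
pub fn expected_byte_len(shape: &[u32], bytes_per_elem: usize) -> Option<usize> {
    shape
        .iter()
        .try_fold(bytes_per_elem, |acc, &dim| acc.checked_mul(dim as usize))
}

/// True when `data` has exactly the length implied by `shape`.
pub fn payload_matches_shape(data: &[u8], shape: &[u32], bytes_per_elem: usize) -> bool {
    expected_byte_len(shape, bytes_per_elem) == Some(data.len())
}
```

For example, a [1, 8, 4096] hidden state at 2 bytes per element should be exactly 65536 bytes long.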
§Returns
The ordered list of generated token IDs (not including the prompt).
pub async fn run_shard_loop<F>(
    &mut self,
    forward_fn: F,
) -> DistributedResult<()>
Run a middle or last shard’s receive → compute → send loop.
This method does not return until the first shard signals termination (i.e. the pipeline transport is closed) or an error occurs.
§forward_fn contract
    fn forward_fn(data: &[u8], shape: &[u32]) -> DistributedResult<(Vec<u8>, Vec<u32>)>
- Input: raw activation bytes + shape from the previous shard.
- Output: either the next hidden state (middle shards) or fp32 logits [batch, seq_len, vocab_size] (last shard).
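A function with the shape described by this contract might look like the sketch below. It is purely illustrative: passthrough_forward is a hypothetical middle-shard function that returns its input unchanged instead of running real layers, and the DistributedResult alias is a local stand-in, since the crate's actual error type is not shown on this page.

```rust
// Local stand-in for the crate's DistributedResult (assumption: the
// real error type is not shown here, so a String is used instead).
type DistributedResult<T> = Result<T, String>;

/// Hypothetical middle-shard forward function: checks that the payload
/// divides evenly into the elements named by `shape`, then passes the
/// activations through unchanged. A real shard would run its layers.
pub fn passthrough_forward(
    data: &[u8],
    shape: &[u32],
) -> DistributedResult<(Vec<u8>, Vec<u32>)> {
    let elems: usize = shape.iter().map(|&d| d as usize).product();
    if elems == 0 || data.len() % elems != 0 {
        return Err(format!(
            "payload of {} bytes does not fit shape {:?}",
            data.len(),
            shape
        ));
    }
    // Output bytes and output shape are returned together, as in the
    // contract above.
    Ok((data.to_vec(), shape.to_vec()))
}
```

A last-shard variant would return fp32 logits and a [batch, seq_len, vocab_size] shape in the same tuple.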
For the last shard the returned bytes are treated as fp32 logits.
run_shard_loop applies greedy argmax on the final position, packs the
winning index as a 4-byte LE u32, and sends it back to rank 0 via
send_result. Middle shards forward the returned bytes to the next
shard via send_to_next.
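The 4-byte reply format described above can be illustrated with a pair of small helpers. These are hypothetical names, not functions from this crate; they only show the little-endian u32 encoding and how a receiver might reject a reply of the wrong length.

```rust
/// Pack a token ID as 4 little-endian bytes (hypothetical helper).
pub fn encode_token(token: u32) -> [u8; 4] {
    token.to_le_bytes()
}

/// Unpack a 4-byte little-endian reply into a token ID.
/// Returns None if the payload is not exactly 4 bytes long.
pub fn decode_token(payload: &[u8]) -> Option<u32> {
    match payload {
        [a, b, c, d] => Some(u32::from_le_bytes([*a, *b, *c, *d])),
        _ => None,
    }
}
```

Fixing the byte order on the wire keeps the reply unambiguous even if shards run on hosts with different native endianness.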
Auto Trait Implementations§
impl Freeze for PipelineGenerationLoop
impl RefUnwindSafe for PipelineGenerationLoop
impl Send for PipelineGenerationLoop
impl Sync for PipelineGenerationLoop
impl Unpin for PipelineGenerationLoop
impl UnsafeUnpin for PipelineGenerationLoop
impl UnwindSafe for PipelineGenerationLoop
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
    Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
    Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.