use anyhow::Result;
use nexar::{NexarClient, Rank};
use std::sync::Arc;
use super::protocol::{DecodeRequest, DecodeStepFn, DecodedToken, tags};
use crate::distributed::inference::transport;
/// Worker that receives KV caches from prefill workers, runs a per-token decode
/// callback, and streams the generated tokens to the router.
pub struct DecodeWorker {
    /// Shared nexar client used for every send and receive.
    client: Arc<NexarClient>,
    /// This worker's own rank; used in log fields and returned by `rank()`.
    rank: Rank,
    /// Rank that sends decode requests and receives generated tokens and
    /// completion notices.
    router_rank: Rank,
    /// Prefill ranks tried, in order, for an incoming KV-cache transfer.
    prefill_workers: Vec<Rank>,
    /// Largest KV-cache payload (in bytes) this worker will accept; larger
    /// transfers are skipped.
    max_kv_transfer_bytes: usize,
    /// Callback invoked as `(kv_cache, last_token, position)` that returns the
    /// next token and the updated KV cache.
    decode_step_fn: DecodeStepFn,
}
impl DecodeWorker {
pub fn new(
client: Arc<NexarClient>,
rank: Rank,
router_rank: Rank,
prefill_workers: Vec<Rank>,
max_kv_transfer_bytes: usize,
decode_step_fn: DecodeStepFn,
) -> Self {
Self {
client,
rank,
router_rank,
prefill_workers,
max_kv_transfer_bytes,
decode_step_fn,
}
}
pub async fn run_loop(&self) -> Result<()> {
tracing::info!(rank = self.rank, "Decode worker loop starting");
loop {
let mut kv_len_buf = [0u8; 8];
let mut kv_source_rank: Rank = 0;
let mut received = false;
for &prefill_rank in &self.prefill_workers {
match transport::recv_bytes(
&self.client,
&mut kv_len_buf,
prefill_rank,
tags::KV_CACHE,
)
.await
{
Ok(()) => {
kv_source_rank = prefill_rank;
received = true;
break;
}
Err(_) => continue,
}
}
if !received {
tracing::warn!(
rank = self.rank,
"No KV cache received from any prefill worker (shutdown?)"
);
break;
}
let kv_len = u64::from_le_bytes(kv_len_buf) as usize;
if kv_len > self.max_kv_transfer_bytes {
tracing::error!(
rank = self.rank,
kv_bytes = kv_len,
limit = self.max_kv_transfer_bytes,
"KV cache exceeds limit; skipping request"
);
continue;
}
let mut kv_cache = vec![0u8; kv_len];
transport::recv_bytes(&self.client, &mut kv_cache, kv_source_rank, tags::KV_CACHE)
.await?;
let ack = [0u8; 8];
transport::send_bytes(&self.client, &ack, kv_source_rank, tags::KV_CACHE_ACK).await?;
let mut decode_req_buf = [0u8; 16];
transport::recv_bytes(
&self.client,
&mut decode_req_buf,
self.router_rank,
tags::DECODE_REQUEST,
)
.await?;
let decode_req = DecodeRequest::from_bytes(&decode_req_buf);
tracing::debug!(
rank = self.rank,
request_id = decode_req.request_id,
max_new_tokens = decode_req.max_new_tokens,
"Starting decode loop"
);
let mut last_token: i64 = 0;
let mut position: u32 = 0;
let mut tokens_generated: u32 = 0;
if kv_cache.len() >= 4 {
position = u32::from_le_bytes(kv_cache[0..4].try_into().unwrap());
}
loop {
if tokens_generated >= decode_req.max_new_tokens {
break;
}
let (next_token, updated_kv) =
(self.decode_step_fn)(&kv_cache, last_token, position);
kv_cache = updated_kv;
last_token = next_token;
position += 1;
tokens_generated += 1;
let tok = DecodedToken {
request_id: decode_req.request_id,
token_id: next_token,
};
transport::send_bytes(
&self.client,
&tok.to_bytes(),
self.router_rank,
tags::DECODE_TOKEN,
)
.await?;
if next_token == i64::MIN {
tracing::debug!(
rank = self.rank,
request_id = decode_req.request_id,
tokens_generated,
"EOS reached"
);
break;
}
}
let done_payload = decode_req.request_id.to_le_bytes();
transport::send_bytes(
&self.client,
&done_payload,
self.router_rank,
tags::DECODE_DONE,
)
.await?;
tracing::debug!(
rank = self.rank,
request_id = decode_req.request_id,
tokens_generated,
"Decode complete"
);
}
tracing::info!(rank = self.rank, "Decode worker loop ended");
Ok(())
}
pub fn rank(&self) -> Rank {
self.rank
}
}