use anyhow::Result;
use nexar::{NexarClient, Rank};
use std::sync::Arc;
use super::protocol::{PrefillDone, PrefillFn, PrefillRequest, tags};
use crate::distributed::inference::transport;
pub struct PrefillWorker {
client: Arc<NexarClient>,
rank: Rank,
router_rank: Rank,
max_kv_transfer_bytes: usize,
prefill_fn: PrefillFn,
}
impl PrefillWorker {
pub fn new(
client: Arc<NexarClient>,
rank: Rank,
router_rank: Rank,
max_kv_transfer_bytes: usize,
prefill_fn: PrefillFn,
) -> Self {
Self {
client,
rank,
router_rank,
max_kv_transfer_bytes,
prefill_fn,
}
}
pub async fn run_loop(&self) -> Result<()> {
tracing::info!(rank = self.rank, "Prefill worker loop starting");
loop {
let mut req_buf = [0u8; 16];
match transport::recv_bytes(
&self.client,
&mut req_buf,
self.router_rank,
tags::PREFILL_REQUEST,
)
.await
{
Ok(()) => {}
Err(e) => {
tracing::warn!(rank = self.rank, "Prefill recv error (shutdown?): {}", e);
break;
}
}
let req = PrefillRequest::from_bytes(&req_buf);
let token_bytes = req.seq_len as usize * std::mem::size_of::<i64>();
tracing::debug!(
rank = self.rank,
request_id = req.request_id,
seq_len = req.seq_len,
decode_rank = req.decode_rank,
"Received prefill request"
);
let mut token_buf = vec![0u8; token_bytes];
transport::recv_bytes(
&self.client,
&mut token_buf,
self.router_rank,
transport::tags::ACTIVATION,
)
.await?;
let (_activation, kv_cache) = (self.prefill_fn)(&token_buf, req.seq_len as usize);
if kv_cache.len() > self.max_kv_transfer_bytes {
tracing::error!(
rank = self.rank,
request_id = req.request_id,
kv_bytes = kv_cache.len(),
limit = self.max_kv_transfer_bytes,
"KV cache exceeds transfer limit; dropping request"
);
let done = PrefillDone {
request_id: req.request_id,
kv_bytes: 0,
};
let _ = transport::send_bytes(
&self.client,
&done.to_bytes(),
self.router_rank,
tags::PREFILL_DONE,
)
.await;
continue;
}
let kv_bytes_len = kv_cache.len() as u64;
let decode_rank = req.decode_rank as Rank;
let kv_len_bytes = kv_bytes_len.to_le_bytes();
transport::send_bytes(&self.client, &kv_len_bytes, decode_rank, tags::KV_CACHE).await?;
transport::send_bytes(&self.client, &kv_cache, decode_rank, tags::KV_CACHE).await?;
let mut ack_buf = [0u8; 8];
transport::recv_bytes(&self.client, &mut ack_buf, decode_rank, tags::KV_CACHE_ACK)
.await?;
let done = PrefillDone {
request_id: req.request_id,
kv_bytes: kv_bytes_len,
};
transport::send_bytes(
&self.client,
&done.to_bytes(),
self.router_rank,
tags::PREFILL_DONE,
)
.await?;
tracing::debug!(
rank = self.rank,
request_id = req.request_id,
kv_bytes = kv_bytes_len,
"Prefill done; KV cache transferred"
);
}
tracing::info!(rank = self.rank, "Prefill worker loop ended");
Ok(())
}
pub fn rank(&self) -> Rank {
self.rank
}
}