use anyhow::Result;
use nexar::{NexarClient, Rank};
use std::sync::Arc;
use super::transport::{self, GenRequestHeader, LayerAssignment, tags};
/// Callback that runs this worker's slice of the forward pass.
///
/// It receives the raw input bytes for one request and the request's
/// sequence length, and it returns the raw output bytes to forward. The
/// `Send + Sync` bounds let one boxed callback be shared by the worker across
/// threads and tasks.
pub type ForwardFn = Box<dyn for<'input> Fn(&'input [u8], usize) -> Vec<u8> + Sync + Send>;
/// One stage of a layer-pipelined inference swarm.
///
/// A worker owns a contiguous range of model layers. It receives its input
/// either from the leader (first stage) or from the previous worker. It then
/// runs `forward_fn` and passes the result to the next worker, or back to the
/// leader if it is the last stage.
pub struct SwarmWorker {
    /// Shared transport handle used for every send and receive.
    client: Arc<NexarClient>,
    /// This worker's own rank (used in log output).
    rank: Rank,
    /// Layer range assigned to this worker (logged as `start_layer..end_layer`).
    assignment: LayerAssignment,
    /// Upstream stage. `None` means this is the first stage, so requests and
    /// token inputs arrive directly from the leader.
    prev_rank: Option<Rank>,
    /// Downstream stage. `None` means this is the last stage, so outputs are
    /// sent back to the leader as logits.
    next_rank: Option<Rank>,
    /// Rank of the coordinating leader.
    leader_rank: Rank,
    /// Computes this stage's output bytes from its input bytes.
    forward_fn: ForwardFn,
}
impl SwarmWorker {
pub fn new(
client: Arc<NexarClient>,
rank: Rank,
assignment: LayerAssignment,
prev_rank: Option<Rank>,
next_rank: Option<Rank>,
leader_rank: Rank,
forward_fn: ForwardFn,
) -> Self {
Self {
client,
rank,
assignment,
prev_rank,
next_rank,
leader_rank,
forward_fn,
}
}
pub async fn receive_assignment(
client: &NexarClient,
leader_rank: Rank,
) -> Result<LayerAssignment> {
let mut buf = [0u8; 10];
transport::recv_bytes(client, &mut buf, leader_rank, tags::LAYER_ASSIGNMENT).await?;
Ok(LayerAssignment::from_bytes(&buf))
}
pub async fn send_ready(client: &NexarClient, leader_rank: Rank) -> Result<()> {
let ack = [1u8];
transport::send_bytes(client, &ack, leader_rank, tags::WORKER_READY).await
}
pub async fn run_compute_loop(&self) -> Result<()> {
tracing::info!(
rank = self.rank,
layers = format!(
"{}..{}",
self.assignment.start_layer, self.assignment.end_layer
),
"Worker compute loop starting"
);
loop {
let mut header_buf = [0u8; 12];
let control_src = self.prev_rank.unwrap_or(self.leader_rank);
match transport::recv_bytes(
&self.client,
&mut header_buf,
control_src,
tags::GEN_REQUEST,
)
.await
{
Ok(()) => {}
Err(e) => {
tracing::warn!("Worker recv error (may be shutdown): {}", e);
break;
}
}
let header = GenRequestHeader::from_bytes(&header_buf);
let hidden_size = header.seq_len as usize;
tracing::debug!(
rank = self.rank,
seq_len = header.seq_len,
max_tokens = header.max_tokens,
"Received generation request"
);
if let Some(prev) = self.prev_rank {
let tensor_bytes = hidden_size * std::mem::size_of::<f32>();
let mut activation_buf = vec![0u8; tensor_bytes];
transport::recv_bytes(&self.client, &mut activation_buf, prev, tags::ACTIVATION)
.await?;
let output_buf = (self.forward_fn)(&activation_buf, hidden_size);
if let Some(next) = self.next_rank {
transport::send_bytes(&self.client, &output_buf, next, tags::ACTIVATION)
.await?;
} else {
transport::send_bytes(
&self.client,
&output_buf,
self.leader_rank,
tags::LOGITS,
)
.await?;
}
} else {
let token_bytes = header.seq_len as usize * std::mem::size_of::<i64>();
let mut token_buf = vec![0u8; token_bytes];
transport::recv_bytes(
&self.client,
&mut token_buf,
self.leader_rank,
tags::ACTIVATION,
)
.await?;
let output_buf = (self.forward_fn)(&token_buf, header.seq_len as usize);
if let Some(next) = self.next_rank {
transport::send_bytes(&self.client, &output_buf, next, tags::ACTIVATION)
.await?;
} else {
transport::send_bytes(
&self.client,
&output_buf,
self.leader_rank,
tags::LOGITS,
)
.await?;
}
}
tracing::debug!(rank = self.rank, "Forward pass complete for request");
}
tracing::info!(rank = self.rank, "Worker compute loop ended");
Ok(())
}
pub fn rank(&self) -> Rank {
self.rank
}
pub fn assignment(&self) -> &LayerAssignment {
&self.assignment
}
}