use anyhow::{Context, Result, anyhow};
use nexar::{NexarClient, Rank};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use super::protocol::{
DecodeRequest, DecodedToken, DisaggConfig, PrefillDone, PrefillRequest, tags,
};
use crate::distributed::inference::transport;
/// Per-worker bookkeeping for prefill load balancing: a worker's rank paired
/// with a lock-free count of the requests currently routed to it.
struct PrefillLoad {
    // Nexar rank of the prefill worker this counter tracks.
    rank: Rank,
    // Requests currently in the prefill phase on this worker; incremented by
    // `choose_prefill_worker`, decremented by `release_prefill_worker`.
    in_flight: AtomicU64,
}
/// Routes generation requests across disaggregated prefill and decode worker
/// pools: least-loaded selection for prefill, round-robin with per-session KV
/// affinity for decode.
pub struct DisaggRouter {
    // Shared transport client used for all worker communication.
    client: Arc<NexarClient>,
    // Static worker topology (prefill and decode rank lists).
    config: DisaggConfig,
    // Monotonic request-id generator; starts at 1 (see `new`).
    next_request_id: AtomicU64,
    // One load counter per configured prefill worker.
    prefill_loads: Vec<Arc<PrefillLoad>>,
    // Round-robin position over `config.decode_workers`.
    decode_cursor: AtomicU64,
    // session key -> decode worker believed to hold that session's KV cache.
    kv_affinity: Mutex<HashMap<String, Rank>>,
}
impl DisaggRouter {
    /// Builds a router over the worker sets declared in `config`, creating one
    /// in-flight counter per prefill worker. Request ids start at 1.
    pub fn new(client: Arc<NexarClient>, config: DisaggConfig) -> Self {
        let prefill_loads = config
            .prefill_workers
            .iter()
            .map(|&rank| {
                Arc::new(PrefillLoad {
                    rank,
                    in_flight: AtomicU64::new(0),
                })
            })
            .collect();
        Self {
            client,
            config,
            next_request_id: AtomicU64::new(1),
            prefill_loads,
            decode_cursor: AtomicU64::new(0),
            kv_affinity: Mutex::new(HashMap::new()),
        }
    }

    /// Picks the prefill worker with the fewest in-flight requests, bumps its
    /// counter, and allocates a fresh request id.
    ///
    /// The min-scan and the increment are not one atomic step, so two
    /// concurrent callers can pick the same worker; that only skews the load
    /// heuristic, never correctness. Every call must be paired with a later
    /// `release_prefill_worker` for the returned rank.
    ///
    /// Panics if no prefill workers are configured (`route_request` guards
    /// against this before calling).
    fn choose_prefill_worker(&self) -> (Rank, u64) {
        let load = self
            .prefill_loads
            .iter()
            .min_by_key(|pl| pl.in_flight.load(Ordering::Relaxed))
            .expect("at least one prefill worker must be configured");
        let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
        load.in_flight.fetch_add(1, Ordering::Relaxed);
        (load.rank, request_id)
    }

    /// Decrements the in-flight counter taken by `choose_prefill_worker`.
    /// A rank that is no longer configured is silently ignored.
    fn release_prefill_worker(&self, rank: Rank) {
        if let Some(pl) = self.prefill_loads.iter().find(|pl| pl.rank == rank) {
            pl.in_flight.fetch_sub(1, Ordering::Relaxed);
        }
    }

    /// Chooses a decode worker: session affinity first (so a session keeps
    /// hitting the worker that holds its KV cache), round-robin otherwise.
    ///
    /// Panics (divide by zero) if `decode_workers` is empty; callers must
    /// check, as `route_request` does, before calling.
    fn choose_decode_worker(&self, session_key: Option<&str>) -> Rank {
        if let Some(key) = session_key {
            let affinity = self.kv_affinity.lock().expect("kv_affinity mutex poisoned");
            if let Some(&rank) = affinity.get(key) {
                return rank;
            }
        }
        let n = self.config.decode_workers.len() as u64;
        let idx = self.decode_cursor.fetch_add(1, Ordering::Relaxed) % n;
        self.config.decode_workers[idx as usize]
    }

    /// Pins `session_key` to `decode_rank` so follow-up requests for the same
    /// session reuse that worker's KV cache.
    pub fn record_kv_affinity(&self, session_key: String, decode_rank: Rank) {
        self.kv_affinity
            .lock()
            .expect("kv_affinity mutex poisoned")
            .insert(session_key, decode_rank);
    }

    /// Drops the affinity entry for `session_key`; subsequent requests for the
    /// session fall back to round-robin decode selection.
    pub fn evict_kv_affinity(&self, session_key: &str) {
        self.kv_affinity
            .lock()
            .expect("kv_affinity mutex poisoned")
            .remove(session_key);
    }

    /// Routes one generation request through the disaggregated pipeline:
    /// prefill on the least-loaded prefill worker, then token-by-token decode
    /// on a (possibly session-affine) decode worker.
    ///
    /// Returns the generated token ids. On success the chosen decode worker is
    /// recorded as the KV affinity for `session_key`.
    ///
    /// # Errors
    /// Fails if either worker pool is empty, or on a transport error in the
    /// prefill or decode exchanges.
    pub async fn route_request(
        &self,
        token_ids_bytes: &[u8],
        seq_len: u32,
        max_new_tokens: u32,
        session_key: Option<&str>,
    ) -> Result<Vec<i64>> {
        if self.config.prefill_workers.is_empty() {
            return Err(anyhow!("No prefill workers configured"));
        }
        if self.config.decode_workers.is_empty() {
            return Err(anyhow!("No decode workers configured"));
        }
        let decode_rank = self.choose_decode_worker(session_key);
        let (prefill_rank, request_id) = self.choose_prefill_worker();
        tracing::debug!(
            request_id,
            prefill_rank,
            decode_rank,
            seq_len,
            "Routing prefill request"
        );
        // Run the prefill leg, then release the load slot whether or not it
        // succeeded, so a transport failure cannot leak the in-flight count.
        let prefill_result = self
            .run_prefill(token_ids_bytes, seq_len, prefill_rank, decode_rank, request_id)
            .await;
        self.release_prefill_worker(prefill_rank);
        let prefill_done = prefill_result?;
        if prefill_done.request_id != request_id {
            // The worker echoed an id we did not send for this exchange —
            // likely an out-of-order or stale completion. Surface it rather
            // than silently trusting the payload.
            tracing::warn!(
                expected = request_id,
                received = prefill_done.request_id,
                "PrefillDone carries an unexpected request_id"
            );
        }
        tracing::debug!(
            request_id = prefill_done.request_id,
            kv_bytes = prefill_done.kv_bytes,
            "Prefill complete; starting decode"
        );
        let tokens = self
            .run_decode(decode_rank, request_id, max_new_tokens)
            .await?;
        if let Some(key) = session_key {
            self.record_kv_affinity(key.to_string(), decode_rank);
        }
        Ok(tokens)
    }

    /// Prefill leg of `route_request`: ships the token ids, sends the control
    /// request, and waits for the worker's `PrefillDone` acknowledgement.
    async fn run_prefill(
        &self,
        token_ids_bytes: &[u8],
        seq_len: u32,
        prefill_rank: Rank,
        decode_rank: Rank,
        request_id: u64,
    ) -> Result<PrefillDone> {
        transport::send_bytes(
            &self.client,
            token_ids_bytes,
            prefill_rank,
            transport::tags::ACTIVATION,
        )
        .await?;
        let prefill_req = PrefillRequest {
            request_id,
            seq_len,
            decode_rank,
        };
        transport::send_bytes(
            &self.client,
            &prefill_req.to_bytes(),
            prefill_rank,
            tags::PREFILL_REQUEST,
        )
        .await?;
        // PrefillDone is received into a fixed 16-byte record — assumed to be
        // the wire size `PrefillDone::from_bytes` expects.
        let mut done_buf = [0u8; 16];
        transport::recv_bytes(&self.client, &mut done_buf, prefill_rank, tags::PREFILL_DONE)
            .await?;
        Ok(PrefillDone::from_bytes(&done_buf))
    }

    /// Decode leg of `route_request`: sends the decode request, then streams
    /// tokens until the worker signals completion via a done-marked token.
    ///
    /// A transport error on the token channel is given one chance to resolve
    /// as a `DECODE_DONE` control message (the worker may have finished and
    /// torn down the token stream); otherwise the original error is propagated
    /// with context describing how far decoding got.
    async fn run_decode(
        &self,
        decode_rank: Rank,
        request_id: u64,
        max_new_tokens: u32,
    ) -> Result<Vec<i64>> {
        let decode_req = DecodeRequest {
            request_id,
            max_new_tokens,
        };
        transport::send_bytes(
            &self.client,
            &decode_req.to_bytes(),
            decode_rank,
            tags::DECODE_REQUEST,
        )
        .await?;
        let mut tokens = Vec::new();
        loop {
            let mut token_buf = [0u8; 16];
            match transport::recv_bytes(&self.client, &mut token_buf, decode_rank, tags::DECODE_TOKEN)
                .await
            {
                Ok(()) => {
                    let decoded = DecodedToken::from_bytes(&token_buf);
                    if decoded.is_done() {
                        break;
                    }
                    tokens.push(decoded.token_id);
                }
                Err(recv_err) => {
                    let mut done_buf = [0u8; 16];
                    match transport::recv_bytes(&self.client, &mut done_buf, decode_rank, tags::DECODE_DONE)
                        .await
                    {
                        Ok(()) => break,
                        Err(_) => {
                            return Err(recv_err.context(format!(
                                "decode transport error after {} tokens from worker {}",
                                tokens.len(),
                                decode_rank
                            )));
                        }
                    }
                }
            }
        }
        Ok(tokens)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The least-loaded scan must land on the worker with the smallest
    /// in-flight counter.
    #[test]
    fn test_router_choose_prefill_least_loaded() {
        let mut loads: Vec<Arc<PrefillLoad>> = Vec::new();
        for (rank, busy) in [(1, 5u64), (2, 1u64), (3, 3u64)] {
            loads.push(Arc::new(PrefillLoad {
                rank,
                in_flight: AtomicU64::new(busy),
            }));
        }
        let picked = loads
            .iter()
            .min_by_key(|pl| pl.in_flight.load(Ordering::Relaxed))
            .unwrap();
        assert_eq!(picked.rank, 2);
    }

    /// The affinity map supports insert, lookup, and eviction round-trips.
    #[test]
    fn test_router_kv_affinity() {
        let affinity: Mutex<HashMap<String, Rank>> = Mutex::new(HashMap::new());
        {
            let mut map = affinity.lock().unwrap();
            map.insert("session-abc".to_string(), 4);
        }
        let found = affinity.lock().unwrap().get("session-abc").copied();
        assert_eq!(found, Some(4));
        affinity.lock().unwrap().remove("session-abc");
        assert!(affinity.lock().unwrap().get("session-abc").is_none());
    }

    /// The decode cursor walks the worker list in strict round-robin order.
    #[test]
    fn test_decode_cursor_round_robin() {
        let cursor = AtomicU64::new(0);
        let workers = [10u32, 20u32, 30u32];
        let total = workers.len() as u64;
        let mut picks = Vec::with_capacity(6);
        for _ in 0..6 {
            let slot = cursor.fetch_add(1, Ordering::Relaxed) % total;
            picks.push(workers[slot as usize]);
        }
        assert_eq!(picks, vec![10, 20, 30, 10, 20, 30]);
    }
}