use nexar::Rank;
/// Boxed, thread-shareable prefill callback.
///
/// Takes a byte slice and a `usize`, and yields two owned byte buffers.
/// NOTE(review): presumably `(input bytes, sequence length) -> (output, KV cache)`
/// — confirm against the code that constructs these.
pub type PrefillFn =
    Box<dyn Fn(&[u8], usize) -> (Vec<u8>, Vec<u8>) + Send + Sync>;

/// Boxed, thread-shareable single decode-step callback.
///
/// Takes a byte slice, an `i64`, and a `u32`, and yields an `i64` plus an
/// owned byte buffer.
/// NOTE(review): presumably `(KV cache, previous token, step) -> (next token,
/// updated KV cache)` — confirm against the code that constructs these.
pub type DecodeStepFn =
    Box<dyn Fn(&[u8], i64, u32) -> (i64, Vec<u8>) + Send + Sync>;
/// Numeric tags, one per protocol message kind, occupying the contiguous
/// range 30..=36.
///
/// NOTE(review): presumably used as the tag of point-to-point sends/receives
/// on the transport — confirm against callers.
pub mod tags {
    // Every tag below is an offset from this base, keeping the range contiguous.
    const FIRST_TAG: u32 = 30;

    pub const PREFILL_REQUEST: u32 = FIRST_TAG;
    pub const PREFILL_DONE: u32 = FIRST_TAG + 1;
    pub const KV_CACHE: u32 = FIRST_TAG + 2;
    pub const KV_CACHE_ACK: u32 = FIRST_TAG + 3;
    pub const DECODE_REQUEST: u32 = FIRST_TAG + 4;
    pub const DECODE_TOKEN: u32 = FIRST_TAG + 5;
    pub const DECODE_DONE: u32 = FIRST_TAG + 6;
}
/// The role a participant plays in the disaggregated setup.
///
/// Variants are listed in a fixed order; they are compared by equality only.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DisaggRole {
    /// Participant that runs prefill.
    Prefill,
    /// Participant that runs decode.
    Decode,
    /// Participant that routes requests.
    Router,
}
/// Static layout of a disaggregated deployment: which ranks do what.
#[derive(Clone, Debug)]
pub struct DisaggConfig {
    /// Ranks assigned the prefill role.
    pub prefill_workers: Vec<Rank>,
    /// Ranks assigned the decode role.
    pub decode_workers: Vec<Rank>,
    /// The single rank acting as router.
    pub router_rank: Rank,
    /// NOTE(review): presumably an upper bound, in bytes, on a single KV-cache
    /// transfer — enforcement is not visible here; confirm with callers.
    pub max_kv_transfer_bytes: usize,
}
/// Fixed-size request message asking a worker to run prefill.
///
/// Wire format (16 bytes, little-endian):
/// `[0..8)` `request_id`, `[8..12)` `seq_len`, `[12..16)` `decode_rank`.
#[derive(Debug, Clone, Copy)]
pub struct PrefillRequest {
    pub request_id: u64,
    /// NOTE(review): presumably the prompt length in tokens — confirm.
    pub seq_len: u32,
    /// Decode-worker rank, carried as a `u32` on the wire.
    pub decode_rank: u32,
}

impl PrefillRequest {
    /// Encodes this request into its 16-byte little-endian wire form.
    pub fn to_bytes(&self) -> [u8; 16] {
        let mut out = [0u8; 16];
        let (id_field, tail) = out.split_at_mut(8);
        let (len_field, rank_field) = tail.split_at_mut(4);
        id_field.copy_from_slice(&self.request_id.to_le_bytes());
        len_field.copy_from_slice(&self.seq_len.to_le_bytes());
        rank_field.copy_from_slice(&self.decode_rank.to_le_bytes());
        out
    }

    /// Decodes a request from its 16-byte little-endian wire form.
    ///
    /// Every byte pattern is valid, so this never fails.
    pub fn from_bytes(buf: &[u8; 16]) -> Self {
        let mut id = [0u8; 8];
        let mut len = [0u8; 4];
        let mut rank = [0u8; 4];
        id.copy_from_slice(&buf[..8]);
        len.copy_from_slice(&buf[8..12]);
        rank.copy_from_slice(&buf[12..]);
        Self {
            request_id: u64::from_le_bytes(id),
            seq_len: u32::from_le_bytes(len),
            decode_rank: u32::from_le_bytes(rank),
        }
    }
}
/// Fixed-size message reporting that prefill finished for a request.
///
/// Wire format (16 bytes, little-endian):
/// `[0..8)` `request_id`, `[8..16)` `kv_bytes`.
#[derive(Debug, Clone, Copy)]
pub struct PrefillDone {
    pub request_id: u64,
    /// NOTE(review): presumably the size in bytes of the KV cache produced by
    /// prefill — confirm against the sender.
    pub kv_bytes: u64,
}

impl PrefillDone {
    /// Encodes this message into its 16-byte little-endian wire form.
    pub fn to_bytes(&self) -> [u8; 16] {
        let mut out = [0u8; 16];
        let (id_field, size_field) = out.split_at_mut(8);
        id_field.copy_from_slice(&self.request_id.to_le_bytes());
        size_field.copy_from_slice(&self.kv_bytes.to_le_bytes());
        out
    }

    /// Decodes this message from its 16-byte little-endian wire form.
    ///
    /// Every byte pattern is valid, so this never fails.
    pub fn from_bytes(buf: &[u8; 16]) -> Self {
        let mut id = [0u8; 8];
        let mut size = [0u8; 8];
        id.copy_from_slice(&buf[..8]);
        size.copy_from_slice(&buf[8..]);
        Self {
            request_id: u64::from_le_bytes(id),
            kv_bytes: u64::from_le_bytes(size),
        }
    }
}
/// Fixed-size request message asking a worker to start decoding.
///
/// Wire format (16 bytes, little-endian):
/// `[0..8)` `request_id`, `[8..12)` `max_new_tokens`, `[12..16)` zero padding.
#[derive(Debug, Clone, Copy)]
pub struct DecodeRequest {
    pub request_id: u64,
    pub max_new_tokens: u32,
}

impl DecodeRequest {
    /// Encodes this request into its 16-byte wire form.
    ///
    /// The trailing four padding bytes are always written as zero.
    pub fn to_bytes(&self) -> [u8; 16] {
        let mut out = [0u8; 16];
        let (id_field, tail) = out.split_at_mut(8);
        id_field.copy_from_slice(&self.request_id.to_le_bytes());
        tail[..4].copy_from_slice(&self.max_new_tokens.to_le_bytes());
        // tail[4..] is padding and keeps its zero initialisation.
        out
    }

    /// Decodes a request from its 16-byte wire form.
    ///
    /// The padding bytes `[12..16)` are ignored, whatever they contain.
    pub fn from_bytes(buf: &[u8; 16]) -> Self {
        let mut id = [0u8; 8];
        let mut budget = [0u8; 4];
        id.copy_from_slice(&buf[..8]);
        budget.copy_from_slice(&buf[8..12]);
        Self {
            request_id: u64::from_le_bytes(id),
            max_new_tokens: u32::from_le_bytes(budget),
        }
    }
}
/// Fixed-size message carrying one decoded token, or an end-of-stream marker.
///
/// Wire format (16 bytes, little-endian):
/// `[0..8)` `request_id`, `[8..16)` `token_id` (two's-complement `i64`).
///
/// A `token_id` of `-1` is reserved as the end-of-stream sentinel. A genuine
/// token with id `-1` would be indistinguishable from it.
#[derive(Debug, Clone, Copy)]
pub struct DecodedToken {
    pub request_id: u64,
    pub token_id: i64,
}

impl DecodedToken {
    /// Reserved `token_id` value that marks the end of a request's stream.
    const DONE_TOKEN_ID: i64 = -1;

    /// Encodes this message into its 16-byte little-endian wire form.
    pub fn to_bytes(&self) -> [u8; 16] {
        let mut out = [0u8; 16];
        let (id_field, token_field) = out.split_at_mut(8);
        id_field.copy_from_slice(&self.request_id.to_le_bytes());
        token_field.copy_from_slice(&self.token_id.to_le_bytes());
        out
    }

    /// Decodes this message from its 16-byte little-endian wire form.
    ///
    /// Every byte pattern is valid, so this never fails.
    pub fn from_bytes(buf: &[u8; 16]) -> Self {
        let mut id = [0u8; 8];
        let mut token = [0u8; 8];
        id.copy_from_slice(&buf[..8]);
        token.copy_from_slice(&buf[8..]);
        Self {
            request_id: u64::from_le_bytes(id),
            token_id: i64::from_le_bytes(token),
        }
    }

    /// Returns `true` when this message is the end-of-stream sentinel.
    pub fn is_done(&self) -> bool {
        Self::DONE_TOKEN_ID == self.token_id
    }

    /// Builds the end-of-stream sentinel message for `request_id`.
    pub fn done(request_id: u64) -> Self {
        Self {
            request_id,
            token_id: Self::DONE_TOKEN_ID,
        }
    }
}
/// Unit tests for the fixed-size wire messages.
///
/// Round-trip tests alone cannot detect a symmetric change to the byte layout,
/// such as swapping field order or endianness in both directions at once.
/// The `*_wire_layout` tests therefore pin the exact bytes on the wire. The
/// remaining tests cover padding handling and the end-of-stream sentinel.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefill_request_roundtrip() {
        let orig = PrefillRequest {
            request_id: 42,
            seq_len: 128,
            decode_rank: 3,
        };
        let bytes = orig.to_bytes();
        let decoded = PrefillRequest::from_bytes(&bytes);
        assert_eq!(decoded.request_id, 42);
        assert_eq!(decoded.seq_len, 128);
        assert_eq!(decoded.decode_rank, 3);
    }

    #[test]
    fn test_prefill_request_wire_layout() {
        // Distinct byte values per field make any misplaced or reordered byte visible.
        let bytes = PrefillRequest {
            request_id: 0x0102_0304_0506_0708,
            seq_len: 0x0A0B_0C0D,
            decode_rank: 0x1122_3344,
        }
        .to_bytes();
        assert_eq!(
            bytes,
            [
                0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, // request_id (LE)
                0x0D, 0x0C, 0x0B, 0x0A, // seq_len (LE)
                0x44, 0x33, 0x22, 0x11, // decode_rank (LE)
            ]
        );
    }

    #[test]
    fn test_prefill_done_roundtrip() {
        let orig = PrefillDone {
            request_id: 99,
            kv_bytes: 1_048_576,
        };
        let bytes = orig.to_bytes();
        let decoded = PrefillDone::from_bytes(&bytes);
        assert_eq!(decoded.request_id, 99);
        assert_eq!(decoded.kv_bytes, 1_048_576);
    }

    #[test]
    fn test_prefill_done_extremes_roundtrip() {
        let decoded = PrefillDone::from_bytes(
            &PrefillDone {
                request_id: u64::MAX,
                kv_bytes: 0,
            }
            .to_bytes(),
        );
        assert_eq!(decoded.request_id, u64::MAX);
        assert_eq!(decoded.kv_bytes, 0);
    }

    #[test]
    fn test_decode_request_roundtrip() {
        let orig = DecodeRequest {
            request_id: 7,
            max_new_tokens: 512,
        };
        let bytes = orig.to_bytes();
        let decoded = DecodeRequest::from_bytes(&bytes);
        assert_eq!(decoded.request_id, 7);
        assert_eq!(decoded.max_new_tokens, 512);
    }

    #[test]
    fn test_decode_request_padding() {
        let mut bytes = DecodeRequest {
            request_id: 1,
            max_new_tokens: u32::MAX,
        }
        .to_bytes();
        // The trailing four bytes are padding and must be emitted as zero.
        assert_eq!(&bytes[12..], &[0u8; 4]);
        // Garbage in the padding must not leak into the decoded fields.
        bytes[12..].copy_from_slice(&[9, 9, 9, 9]);
        let decoded = DecodeRequest::from_bytes(&bytes);
        assert_eq!(decoded.request_id, 1);
        assert_eq!(decoded.max_new_tokens, u32::MAX);
    }

    #[test]
    fn test_decoded_token_roundtrip() {
        let orig = DecodedToken {
            request_id: 1,
            token_id: 12345,
        };
        let bytes = orig.to_bytes();
        let decoded = DecodedToken::from_bytes(&bytes);
        assert_eq!(decoded.request_id, 1);
        assert_eq!(decoded.token_id, 12345);
    }

    #[test]
    fn test_decoded_token_done_sentinel() {
        let done = DecodedToken::done(5);
        assert_eq!(done.request_id, 5);
        assert!(done.is_done());
        // The sentinel (-1) is all 0xFF bytes in two's complement.
        let bytes = done.to_bytes();
        assert_eq!(&bytes[8..], &[0xFFu8; 8]);
        let decoded = DecodedToken::from_bytes(&bytes);
        assert_eq!(decoded.request_id, 5);
        assert!(decoded.is_done());
        assert!(!DecodedToken {
            request_id: 5,
            token_id: 0,
        }
        .is_done());
    }

    #[test]
    fn test_decoded_token_negative_roundtrip() {
        // Negative ids other than the sentinel must survive and must not read as "done".
        let decoded = DecodedToken::from_bytes(
            &DecodedToken {
                request_id: 2,
                token_id: -2,
            }
            .to_bytes(),
        );
        assert_eq!(decoded.request_id, 2);
        assert_eq!(decoded.token_id, -2);
        assert!(!decoded.is_done());
    }
}