use axum::http::HeaderMap;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::{Validate, ValidationError};
pub use crate::protocols::common::timing::TimingInfo;
/// Request header naming the worker instance to route to. Parsed as `u64` by
/// [`apply_header_routing_overrides`], which writes it to both `backend_instance_id`
/// and `decode_worker_id`; values that do not parse are ignored.
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
/// Request header naming the prefill worker instance. Parsed as `u64` into
/// `prefill_worker_id` by [`apply_header_routing_overrides`].
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
/// Request header selecting a data-parallel rank. Parsed as `u32` into `dp_rank`
/// by [`apply_header_routing_overrides`].
pub const HEADER_DP_RANK: &str = "x-dp-rank";
/// Request header selecting the prefill data-parallel rank. Parsed as `u32` into
/// `prefill_dp_rank` by [`apply_header_routing_overrides`].
pub const HEADER_PREFILL_DP_RANK: &str = "x-prefill-dp-rank";
/// Sentinel DP-rank value (`u32::MAX`) meaning "unset". `apply_header_routing_overrides`
/// discards header-supplied DP ranks equal to this value (at least for
/// `x-prefill-dp-rank`) instead of applying them as overrides.
const UNSET_DP_RANK_SENTINEL: u32 = u32::MAX;
/// Reads header `name` and parses it as `T`.
///
/// Returns `None` when the header is absent, cannot be read as a string
/// (`to_str` fails), or does not parse as `T`. Malformed overrides are
/// therefore ignored rather than rejected.
fn parse_routing_header<T: std::str::FromStr>(headers: &HeaderMap, name: &str) -> Option<T> {
    headers
        .get(name)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<T>().ok())
}

/// Reads a DP-rank header as `u32`, treating [`UNSET_DP_RANK_SENTINEL`] the same
/// as an absent header.
fn parse_dp_rank_header(headers: &HeaderMap, name: &str) -> Option<u32> {
    parse_routing_header::<u32>(headers, name).filter(|rank| *rank != UNSET_DP_RANK_SENTINEL)
}

/// Applies per-request routing overrides carried in HTTP headers to `nvext`.
///
/// Recognised headers (missing or unparseable values are ignored):
/// - [`HEADER_WORKER_INSTANCE_ID`] (`u64`): sets both `backend_instance_id` and
///   `decode_worker_id`.
/// - [`HEADER_PREFILL_INSTANCE_ID`] (`u64`): sets `prefill_worker_id`.
/// - [`HEADER_DP_RANK`] (`u32`): sets `dp_rank`.
/// - [`HEADER_PREFILL_DP_RANK`] (`u32`): sets `prefill_dp_rank`.
///
/// A DP-rank header whose value is `u32::MAX` (the "unset" sentinel) is treated
/// as absent.
///
/// # Returns
/// `nvext` unchanged (including `None`) when no usable override is present.
/// Otherwise, the overrides are written onto `nvext`, or onto a default `NvExt`
/// when `nvext` is `None`. Fields without a matching header keep their values.
pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) -> Option<NvExt> {
    let worker_id = parse_routing_header::<u64>(headers, HEADER_WORKER_INSTANCE_ID);
    let prefill_id = parse_routing_header::<u64>(headers, HEADER_PREFILL_INSTANCE_ID);
    // Both DP-rank headers honour the same "unset" sentinel, so a forwarded
    // `u32::MAX` never turns into a routing override for either stage.
    let dp_rank = parse_dp_rank_header(headers, HEADER_DP_RANK);
    let prefill_dp_rank = parse_dp_rank_header(headers, HEADER_PREFILL_DP_RANK);

    if worker_id.is_none() && prefill_id.is_none() && dp_rank.is_none() && prefill_dp_rank.is_none()
    {
        return nvext;
    }

    let mut ext = nvext.unwrap_or_default();
    if let Some(id) = worker_id {
        // The worker header pins both the generic backend target and the decode worker.
        ext.backend_instance_id = Some(id);
        ext.decode_worker_id = Some(id);
    }
    if let Some(id) = prefill_id {
        ext.prefill_worker_id = Some(id);
    }
    if let Some(rank) = dp_rank {
        ext.dp_rank = Some(rank);
    }
    if let Some(rank) = prefill_dp_rank {
        ext.prefill_dp_rank = Some(rank);
    }
    Some(ext)
}
/// Implemented by request types that can expose an optional [`NvExt`] extension
/// block together with a raw prompt.
pub trait NvExtProvider {
    /// Returns the raw prompt of the request, if the implementor has one.
    ///
    /// The prompt is returned owned, so implementors may assemble it on demand.
    fn raw_prompt(&self) -> Option<String>;

    /// Borrows the request's [`NvExt`] block, or returns `None` when it is absent.
    fn nvext(&self) -> Option<&NvExt>;
}
/// Worker and data-parallel-rank identifiers carried in [`NvExtResponse::worker_id`].
///
/// Every field is omitted from serialized output when it is `None`.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct WorkerIdInfo {
/// Prefill worker instance ID, if one was involved.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// Data-parallel rank used on the prefill worker, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
/// Decode worker instance ID, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// Data-parallel rank used on the decode worker, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_dp_rank: Option<u32>,
}
/// Optional extension data attached to a response.
///
/// Every field is omitted from serialized output when it is `None`.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone)]
pub struct NvExtResponse {
/// Worker / DP-rank identifiers (see [`WorkerIdInfo`]).
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_id: Option<WorkerIdInfo>,
/// Timing information, using the `TimingInfo` type re-exported from
/// `protocols::common::timing`.
#[serde(skip_serializing_if = "Option::is_none")]
pub timing: Option<TimingInfo>,
/// Token IDs. NOTE(review): presumably the generated (or prompt) tokens; confirm
/// against the code that populates this field.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>,
/// Arbitrary JSON payload. NOTE(review): presumably MoE expert-routing data,
/// judging by the name; the schema is not defined here.
#[serde(skip_serializing_if = "Option::is_none")]
pub routed_experts: Option<serde_json::Value>,
}
/// Optional per-request extension fields.
///
/// Every field is optional, defaults to `None` when missing from the input, and
/// is omitted from serialized output when `None`. Build one with [`NvExt::builder`]
/// (all builder fields default to unset) or `NvExt::default()`. Struct-level
/// validation runs `validate_nv_ext`.
#[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
pub struct NvExt {
/// Requests greedy sampling. The wire name is `greed_sampling` (field name as-is).
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub greed_sampling: Option<bool>,
/// NOTE(review): presumably asks the backend to use the raw prompt (see
/// `NvExtProvider::raw_prompt`) instead of a templated one; confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub use_raw_prompt: Option<bool>,
/// Free-form annotation strings; `NvExtBuilder::add_annotation` appends to this list.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub annotations: Option<Vec<String>>,
/// Target backend worker instance. `apply_header_routing_overrides` overwrites it
/// from the `x-worker-instance-id` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub backend_instance_id: Option<u64>,
/// Token IDs supplied with the request. NOTE(review): presumably pre-tokenized
/// input used instead of tokenizing the prompt; confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_data: Option<Vec<u32>>,
/// Limit on thinking/reasoning tokens; enforcement is not visible in this file.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub max_thinking_tokens: Option<u32>,
/// Names of extra fields requested by the client (the tests use `"worker_id"`).
/// NOTE(review): presumably selects optional `NvExtResponse` fields; confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
/// Prefill worker instance. `apply_header_routing_overrides` overwrites it from
/// the `x-prefill-instance-id` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// Decode worker instance. `apply_header_routing_overrides` overwrites it from
/// the `x-worker-instance-id` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// Data-parallel rank. `apply_header_routing_overrides` overwrites it from the
/// `x-dp-rank` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Prefill data-parallel rank. `apply_header_routing_overrides` overwrites it
/// from the `x-prefill-dp-rank` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
/// Scheduling hints (see [`AgentHints`]).
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub agent_hints: Option<AgentHints>,
/// Client-supplied request timestamp. NOTE(review): the name implies
/// milliseconds; the epoch/reference point is not stated here.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_timestamp_ms: Option<f64>,
/// Session open/close control (see [`SessionControl`]).
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub session_control: Option<SessionControl>,
}
/// Optional hints attached to a request via `NvExt::agent_hints`.
///
/// Every field defaults to `None` and is omitted from serialized output when `None`.
#[derive(ToSchema, Serialize, Deserialize, Builder, Debug, Clone, Default, PartialEq)]
pub struct AgentHints {
/// Request priority. NOTE(review): ordering (higher vs lower = more urgent) is
/// not defined in this file.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
/// NOTE(review): presumably the expected output sequence length; confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub osl: Option<u32>,
/// Speculative-prefill flag; its effect is implemented outside this file.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub speculative_prefill: Option<bool>,
/// Latency-sensitivity hint. Excluded from the generated OpenAPI schema via
/// `#[schema(ignore)]`, but still (de)serialized.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(ignore)]
pub latency_sensitivity: Option<f64>,
}
/// Serde default for `SessionControl::timeout`, used when the field is missing
/// from the input. The unit is not stated here (presumably seconds).
fn default_session_timeout() -> u64 {
    const DEFAULT_TIMEOUT: u64 = 300;
    DEFAULT_TIMEOUT
}
/// Session control attached to a request via `NvExt::session_control`.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct SessionControl {
/// Session identifier; has no serde default, so it must be present when deserializing.
pub session_id: String,
/// Explicit open/close action. `None` when omitted. NOTE(review): the tests name
/// this case "continue", i.e. a request within an existing session; confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub action: Option<SessionAction>,
/// Session timeout; defaults to `default_session_timeout()` (300) when omitted.
/// NOTE(review): unit not stated here, presumably seconds.
#[serde(default = "default_session_timeout")]
pub timeout: u64,
}
/// Explicit session lifecycle action; serialized as `"open"` / `"close"`
/// (`rename_all = "snake_case"`).
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum SessionAction {
/// Open a session.
Open,
/// Close a session.
Close,
}
/// An `NvExt` with every field unset (`None`).
///
/// Equivalent to building an empty [`NvExtBuilder`]: every builder field is
/// declared `#[builder(default)]`, so `build()` cannot fail here.
impl Default for NvExt {
    fn default() -> Self {
        NvExt::builder()
            .build()
            .expect("every NvExt builder field has a default, so build() cannot fail")
    }
}
impl NvExt {
pub fn builder() -> NvExtBuilder {
NvExtBuilder::default()
}
}
/// Struct-level validation hook for [`NvExt`], registered through
/// `#[validate(schema(function = "validate_nv_ext"))]`.
///
/// Currently accepts every value; no cross-field rules are enforced yet.
fn validate_nv_ext(_ext: &NvExt) -> Result<(), ValidationError> {
    Result::Ok(())
}
impl NvExtBuilder {
    /// Appends one annotation to the builder's `annotations` list.
    ///
    /// If a list was already set, the annotation is pushed onto it. Otherwise a new
    /// list is started. Repeated calls therefore accumulate in call order.
    pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
        // The builder field is `Option<Option<Vec<String>>>` (outer: "was set", inner:
        // the struct field's own `Option`). Initialise both layers without a panic path.
        self.annotations
            .get_or_insert_with(|| Some(Vec::new()))
            .get_or_insert_with(Vec::new)
            .push(annotation.into());
        self
    }
}
#[cfg(test)]
mod tests {
    use axum::http::HeaderMap;
    use validator::Validate;

    use super::*;

    #[test]
    fn test_nv_ext_builder_default() {
        let nv_ext = NvExt::builder().build().unwrap();
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
        assert_eq!(nv_ext.backend_instance_id, None);
        assert_eq!(nv_ext.token_data, None);
        assert_eq!(nv_ext.max_thinking_tokens, None);
        assert_eq!(nv_ext.extra_fields, None);
        assert_eq!(nv_ext.prefill_worker_id, None);
        assert_eq!(nv_ext.decode_worker_id, None);
        assert_eq!(nv_ext.dp_rank, None);
        assert_eq!(nv_ext.prefill_dp_rank, None);
        assert_eq!(nv_ext.agent_hints, None);
        assert_eq!(nv_ext.request_timestamp_ms, None);
        assert_eq!(nv_ext.session_control, None);
    }

    #[test]
    fn test_nv_ext_default_matches_empty_builder() {
        let nv_ext = NvExt::default();
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.annotations, None);
        assert_eq!(nv_ext.backend_instance_id, None);
        assert_eq!(nv_ext.session_control, None);
    }

    #[test]
    fn test_nv_ext_builder_custom() {
        let nv_ext = NvExt::builder()
            .greed_sampling(true)
            .use_raw_prompt(true)
            .backend_instance_id(42)
            .token_data(vec![1, 2, 3, 4])
            .max_thinking_tokens(1024)
            .extra_fields(vec!["worker_id".to_string()])
            .build()
            .unwrap();
        assert_eq!(nv_ext.greed_sampling, Some(true));
        assert_eq!(nv_ext.use_raw_prompt, Some(true));
        assert_eq!(nv_ext.backend_instance_id, Some(42));
        assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4]));
        assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
        assert_eq!(nv_ext.extra_fields, Some(vec!["worker_id".to_string()]));
        assert!(nv_ext.validate().is_ok());
    }

    #[test]
    fn test_add_annotation_accumulates() {
        let nv_ext = NvExt::builder()
            .add_annotation("first")
            .add_annotation(String::from("second"))
            .build()
            .unwrap();
        assert_eq!(
            nv_ext.annotations,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    #[test]
    fn test_nv_ext_disagg_worker_ids() {
        let nv_ext = NvExt::builder()
            .prefill_worker_id(100)
            .decode_worker_id(200)
            .build()
            .unwrap();
        assert_eq!(nv_ext.prefill_worker_id, Some(100));
        assert_eq!(nv_ext.decode_worker_id, Some(200));
        assert!(nv_ext.validate().is_ok());
    }

    #[test]
    fn test_session_control_serde() {
        let sc_json = r#"{"session_id": "sub-1", "action": "open", "timeout": 60}"#;
        let sc: SessionControl = serde_json::from_str(sc_json).unwrap();
        assert_eq!(sc.action, Some(SessionAction::Open));
        assert_eq!(sc.session_id, "sub-1");
        assert_eq!(sc.timeout, 60);

        let sc_close = r#"{"session_id": "sub-1", "action": "close"}"#;
        let sc: SessionControl = serde_json::from_str(sc_close).unwrap();
        assert_eq!(sc.action, Some(SessionAction::Close));
        assert_eq!(sc.timeout, 300);

        let sc_continue = r#"{"session_id": "sub-1"}"#;
        let sc: SessionControl = serde_json::from_str(sc_continue).unwrap();
        assert_eq!(sc.action, None);
        assert_eq!(sc.session_id, "sub-1");

        let nvext_json =
            r#"{"session_control": {"session_id": "sub-2", "action": "open", "timeout": 300}}"#;
        let nvext: NvExt = serde_json::from_str(nvext_json).unwrap();
        assert!(nvext.session_control.is_some());
        let sc = nvext.session_control.unwrap();
        assert_eq!(sc.action, Some(SessionAction::Open));
        assert_eq!(sc.session_id, "sub-2");

        let original = SessionControl {
            session_id: "test-session".to_string(),
            action: Some(SessionAction::Close),
            timeout: 90,
        };
        let json = serde_json::to_string(&original).unwrap();
        let deser: SessionControl = serde_json::from_str(&json).unwrap();
        assert_eq!(deser, original);
    }

    #[test]
    fn test_apply_header_routing_overrides() {
        let mut headers = HeaderMap::new();
        headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap());
        headers.insert(HEADER_PREFILL_INSTANCE_ID, "456".parse().unwrap());
        headers.insert(HEADER_DP_RANK, "3".parse().unwrap());
        headers.insert(HEADER_PREFILL_DP_RANK, "5".parse().unwrap());
        let result = apply_header_routing_overrides(None, &headers).unwrap();
        assert_eq!(result.backend_instance_id, Some(123));
        assert_eq!(result.decode_worker_id, Some(123));
        assert_eq!(result.prefill_worker_id, Some(456));
        assert_eq!(result.dp_rank, Some(3));
        assert_eq!(result.prefill_dp_rank, Some(5));
    }

    #[test]
    fn test_apply_header_routing_overrides_without_headers_is_passthrough() {
        let headers = HeaderMap::new();
        assert!(apply_header_routing_overrides(None, &headers).is_none());

        let original = NvExt::builder().greed_sampling(true).build().unwrap();
        let result = apply_header_routing_overrides(Some(original), &headers).unwrap();
        assert_eq!(result.greed_sampling, Some(true));
        assert_eq!(result.backend_instance_id, None);
    }

    #[test]
    fn test_apply_header_routing_overrides_preserves_existing_fields() {
        let mut headers = HeaderMap::new();
        headers.insert(HEADER_WORKER_INSTANCE_ID, "7".parse().unwrap());
        let original = NvExt::builder()
            .greed_sampling(true)
            .prefill_worker_id(9)
            .build()
            .unwrap();
        let result = apply_header_routing_overrides(Some(original), &headers).unwrap();
        assert_eq!(result.backend_instance_id, Some(7));
        assert_eq!(result.decode_worker_id, Some(7));
        assert_eq!(result.prefill_worker_id, Some(9));
        assert_eq!(result.greed_sampling, Some(true));
    }

    #[test]
    fn test_apply_header_routing_overrides_ignores_invalid_values() {
        let mut headers = HeaderMap::new();
        headers.insert(HEADER_WORKER_INSTANCE_ID, "not-a-number".parse().unwrap());
        headers.insert(
            HEADER_PREFILL_DP_RANK,
            UNSET_DP_RANK_SENTINEL.to_string().parse().unwrap(),
        );
        assert!(apply_header_routing_overrides(None, &headers).is_none());
    }
}