use axum::http::HeaderMap;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::{Validate, ValidationError};
pub use crate::protocols::common::timing::TimingInfo;
/// Name of the HTTP header read by `apply_header_routing_overrides`: when its value parses
/// as a `u64`, that value overrides both `backend_instance_id` and `decode_worker_id`.
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
/// Name of the HTTP header read by `apply_header_routing_overrides`: when its value parses
/// as a `u64`, that value overrides `prefill_worker_id`.
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
/// Merges routing overrides carried in HTTP headers into an optional [`NvExt`].
///
/// Two headers are consulted, each expected to hold a decimal `u64`:
///
/// * [`HEADER_WORKER_INSTANCE_ID`] overrides both `backend_instance_id` and
///   `decode_worker_id`.
/// * [`HEADER_PREFILL_INSTANCE_ID`] overrides `prefill_worker_id`.
///
/// A header that is missing, is not valid visible ASCII, or does not parse as `u64`
/// is treated as absent. When neither header yields an id, `nvext` is returned
/// untouched (including `None`). Otherwise the overrides are written onto `nvext`,
/// or onto a default `NvExt` when `nvext` is `None`, and the result is returned
/// wrapped in `Some`. Fields not targeted by a header keep their existing values.
pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) -> Option<NvExt> {
    // Shared lookup: header -> &str -> u64, with every failure collapsing to `None`.
    let header_as_id = |name: &str| -> Option<u64> {
        let raw = headers.get(name)?;
        let text = raw.to_str().ok()?;
        text.parse::<u64>().ok()
    };

    let worker_override = header_as_id(HEADER_WORKER_INSTANCE_ID);
    let prefill_override = header_as_id(HEADER_PREFILL_INSTANCE_ID);

    match (worker_override, prefill_override) {
        // Nothing usable in the headers: hand the input back unchanged.
        (None, None) => nvext,
        (worker, prefill) => {
            let mut merged = nvext.unwrap_or_default();
            if let Some(id) = worker {
                // The worker header drives both the backend and the decode target.
                merged.backend_instance_id = Some(id);
                merged.decode_worker_id = Some(id);
            }
            if let Some(id) = prefill {
                merged.prefill_worker_id = Some(id);
            }
            Some(merged)
        }
    }
}
/// Read access to an optional [`NvExt`] and an optional raw prompt.
///
/// NOTE(review): no implementors are visible in this file; presumably implemented by
/// request types that carry an `nvext` field — confirm against the implementing modules.
pub trait NvExtProvider {
/// Returns a borrow of the associated [`NvExt`], or `None` when there is none.
fn nvext(&self) -> Option<&NvExt>;
/// Returns the raw prompt as an owned `String`, or `None` when there is none.
fn raw_prompt(&self) -> Option<String>;
}
/// Identifiers for the prefill and decode workers, each paired with a `dp_rank`.
///
/// Every field is optional and is left out of the serialized output when it is `None`.
/// NOTE(review): `dp_rank` presumably means data-parallel rank — confirm.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct WorkerIdInfo {
/// Id of the prefill worker, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// `dp_rank` associated with the prefill worker, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
/// Id of the decode worker, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// `dp_rank` associated with the decode worker, if known.
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_dp_rank: Option<u32>,
}
/// Extension payload holding optional worker ids, timing, token ids and routed-expert data.
///
/// Every field is optional and is left out of the serialized output when it is `None`.
/// NOTE(review): presumably attached to responses as the counterpart of `NvExt` — confirm.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone)]
pub struct NvExtResponse {
/// Prefill/decode worker identifiers, see [`WorkerIdInfo`].
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_id: Option<WorkerIdInfo>,
/// Timing data, see [`TimingInfo`].
#[serde(skip_serializing_if = "Option::is_none")]
pub timing: Option<TimingInfo>,
/// Token ids as `u32` values.
// NOTE(review): whether these are prompt or generated tokens is not visible here — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>,
/// Routed-expert data as free-form JSON; its schema is not defined in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub routed_experts: Option<serde_json::Value>,
}
/// Extension options bundle; every field is optional.
///
/// Each field defaults to `None` when absent during deserialization and is omitted from
/// serialized output when `None`. A builder (`NvExtBuilder`) is derived, with setters that
/// take the inner value rather than an `Option`. Schema-level validation is delegated to
/// `validate_nv_ext`.
#[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
pub struct NvExt {
/// NOTE(review): presumably toggles greedy sampling — confirm with the consumer of this flag.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub greed_sampling: Option<bool>,
/// NOTE(review): presumably requests that the prompt be used without templating — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub use_raw_prompt: Option<bool>,
/// List of annotation strings; `NvExtBuilder::add_annotation` appends to this list.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub annotations: Option<Vec<String>>,
/// Backend instance id. Overwritten by `apply_header_routing_overrides` when the
/// `x-worker-instance-id` header carries a valid `u64`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub backend_instance_id: Option<u64>,
/// Token data as `u32` values.
// NOTE(review): presumably pre-tokenized input — confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_data: Option<Vec<u32>>,
/// NOTE(review): presumably an upper bound on thinking/reasoning tokens — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub max_thinking_tokens: Option<u32>,
/// List of extra field names (the tests in this file use `"worker_id"`).
// NOTE(review): presumably selects which extra fields are returned — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
/// Prefill worker id. Overwritten by `apply_header_routing_overrides` when the
/// `x-prefill-instance-id` header carries a valid `u64`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// Decode worker id. Overwritten by `apply_header_routing_overrides` when the
/// `x-worker-instance-id` header carries a valid `u64`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// Optional hints, see [`AgentHints`].
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub agent_hints: Option<AgentHints>,
}
/// Optional hint values carried in `NvExt::agent_hints`.
///
/// Each field defaults to `None` when absent during deserialization and is omitted from
/// serialized output when `None`. How the hints are consumed is not visible in this file.
#[derive(ToSchema, Serialize, Deserialize, Builder, Debug, Clone, Default, PartialEq)]
pub struct AgentHints {
/// Latency sensitivity as an `f64`.
// NOTE(review): valid range and direction (higher = more sensitive?) are not shown here — confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub latency_sensitivity: Option<f64>,
/// NOTE(review): `osl` presumably stands for expected output sequence length — confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub osl: Option<u32>,
/// NOTE(review): presumably enables speculative prefill — confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub speculative_prefill: Option<bool>,
/// Priority as a signed integer.
// NOTE(review): ordering (higher vs. lower wins) is not shown here — confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
}
impl Default for NvExt {
fn default() -> Self {
NvExt::builder().build().unwrap()
}
}
impl NvExt {
pub fn builder() -> NvExtBuilder {
NvExtBuilder::default()
}
}
/// Schema-level validation hook named by `#[validate(schema(function = "validate_nv_ext"))]`
/// on `NvExt`.
///
/// Currently performs no checks: the argument is unused and the result is always `Ok(())`.
fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
Ok(())
}
impl NvExtBuilder {
    /// Appends one annotation to the builder's `annotations` list and returns the
    /// builder for chaining.
    ///
    /// The list is created on first use, so this may be called any number of times,
    /// before or after the generated `annotations(...)` setter.
    pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
        // The builder stores the field as `Option<Option<Vec<String>>>`: the outer layer
        // tracks "was it set", the inner layer is the field's own `Option`. Initialise
        // both layers on demand instead of panicking when the inner one is `None`.
        self.annotations
            .get_or_insert_with(|| Some(Vec::new()))
            .get_or_insert_with(Vec::new)
            .push(annotation.into());
        self
    }
}
// Unit tests for the `NvExt` builder and for header-driven routing overrides.
#[cfg(test)]
mod tests {
use validator::Validate;
use super::*;
// An untouched builder must yield an `NvExt` whose fields are all `None`.
#[test]
fn test_nv_ext_builder_default() {
let nv_ext = NvExt::builder().build().unwrap();
assert_eq!(nv_ext.greed_sampling, None);
assert_eq!(nv_ext.use_raw_prompt, None);
assert_eq!(nv_ext.annotations, None);
assert_eq!(nv_ext.backend_instance_id, None);
assert_eq!(nv_ext.token_data, None);
assert_eq!(nv_ext.max_thinking_tokens, None);
assert_eq!(nv_ext.extra_fields, None);
assert_eq!(nv_ext.prefill_worker_id, None);
assert_eq!(nv_ext.decode_worker_id, None);
assert_eq!(nv_ext.agent_hints, None);
}
// Values passed to the builder setters must come back wrapped in `Some`, and the
// resulting struct must pass validation.
#[test]
fn test_nv_ext_builder_custom() {
let nv_ext = NvExt::builder()
.greed_sampling(true)
.use_raw_prompt(true)
.backend_instance_id(42)
.token_data(vec![1, 2, 3, 4])
.max_thinking_tokens(1024)
.extra_fields(vec!["worker_id".to_string()])
.build()
.unwrap();
assert_eq!(nv_ext.greed_sampling, Some(true));
assert_eq!(nv_ext.use_raw_prompt, Some(true));
assert_eq!(nv_ext.backend_instance_id, Some(42));
assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4]));
assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
assert_eq!(nv_ext.extra_fields, Some(vec!["worker_id".to_string()]));
assert!(nv_ext.validate().is_ok());
}
// Prefill and decode worker ids can be set independently to different values.
#[test]
fn test_nv_ext_disagg_worker_ids() {
let nv_ext = NvExt::builder()
.prefill_worker_id(100)
.decode_worker_id(200)
.build()
.unwrap();
assert_eq!(nv_ext.prefill_worker_id, Some(100));
assert_eq!(nv_ext.decode_worker_id, Some(200));
assert!(nv_ext.validate().is_ok());
}
// With only the worker header present, `backend_instance_id` and `decode_worker_id`
// are replaced by the header value while `prefill_worker_id` keeps its body value.
#[test]
fn test_apply_header_routing_overrides() {
use axum::http::HeaderMap;
let mut headers = HeaderMap::new();
headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap());
let nvext = NvExt::builder()
.backend_instance_id(999)
.decode_worker_id(888)
.prefill_worker_id(777)
.build()
.unwrap();
let result = apply_header_routing_overrides(Some(nvext), &headers).unwrap();
assert_eq!(result.backend_instance_id, Some(123));
assert_eq!(result.decode_worker_id, Some(123));
assert_eq!(result.prefill_worker_id, Some(777));
}
}