use axum::http::HeaderMap;
use derive_builder::Builder;
use dynamo_protocols::types::StopReason;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::{Validate, ValidationError};
pub use crate::agents::context::AgentContext;
pub use crate::protocols::common::timing::TimingInfo;
/// Request header whose value, when it parses as a `u64`, is written to both
/// `NvExt::backend_instance_id` and `NvExt::decode_worker_id` by
/// [`apply_header_routing_overrides`].
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
/// Request header whose value, when it parses as a `u64`, is written to
/// `NvExt::prefill_worker_id` by [`apply_header_routing_overrides`].
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
/// Request header whose value is parsed as a `u32` and used as the source of
/// `NvExt::dp_rank` by [`apply_header_routing_overrides`].
pub const HEADER_DP_RANK: &str = "x-dp-rank";
/// Request header whose value is parsed as a `u32` and used as the source of
/// `NvExt::prefill_dp_rank` by [`apply_header_routing_overrides`].
pub const HEADER_PREFILL_DP_RANK: &str = "x-prefill-dp-rank";
/// Rank value (`u32::MAX`) that [`apply_header_routing_overrides`] treats as
/// "no rank supplied" instead of as a real rank.
const UNSET_DP_RANK_SENTINEL: u32 = u32::MAX;
/// Overlays routing overrides carried in HTTP headers onto a request's `nvext`.
///
/// Recognised headers and the fields they overwrite:
///
/// * [`HEADER_WORKER_INSTANCE_ID`] → `backend_instance_id` and `decode_worker_id`
/// * [`HEADER_PREFILL_INSTANCE_ID`] → `prefill_worker_id`
/// * [`HEADER_DP_RANK`] → `dp_rank`
/// * [`HEADER_PREFILL_DP_RANK`] → `prefill_dp_rank`
///
/// A header that is absent, cannot be read as a string, or does not parse as the
/// expected unsigned integer is ignored. Both rank headers additionally ignore
/// `UNSET_DP_RANK_SENTINEL` (`u32::MAX`), which means "no rank supplied".
///
/// Returns `nvext` unchanged (including `None`) when no header contributes an
/// override; otherwise returns `Some` with the overrides applied on top of the
/// supplied value, or on top of `NvExt::default()` when `nvext` was `None`.
pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) -> Option<NvExt> {
    /// Reads `name` from `headers` and parses it as `T`; any failure yields `None`.
    fn header_value<T: std::str::FromStr>(headers: &HeaderMap, name: &'static str) -> Option<T> {
        headers.get(name)?.to_str().ok()?.parse().ok()
    }

    /// Reads a data-parallel rank header, mapping the "unset" sentinel to `None`.
    fn dp_rank_header(headers: &HeaderMap, name: &'static str) -> Option<u32> {
        header_value::<u32>(headers, name).filter(|rank| *rank != UNSET_DP_RANK_SENTINEL)
    }

    let worker_id = header_value::<u64>(headers, HEADER_WORKER_INSTANCE_ID);
    let prefill_id = header_value::<u64>(headers, HEADER_PREFILL_INSTANCE_ID);
    // Both rank headers share the same sentinel handling so that an "unset"
    // decode rank is never forwarded as a real rank.
    let dp_rank = dp_rank_header(headers, HEADER_DP_RANK);
    let prefill_dp_rank = dp_rank_header(headers, HEADER_PREFILL_DP_RANK);

    // Nothing to override: hand the caller's value back untouched so a missing
    // `nvext` stays `None` instead of becoming an empty default.
    if worker_id.is_none() && prefill_id.is_none() && dp_rank.is_none() && prefill_dp_rank.is_none()
    {
        return nvext;
    }

    let mut ext = nvext.unwrap_or_default();
    if let Some(id) = worker_id {
        // The worker header targets both the generic backend id and the decode worker.
        ext.backend_instance_id = Some(id);
        ext.decode_worker_id = Some(id);
    }
    if let Some(id) = prefill_id {
        ext.prefill_worker_id = Some(id);
    }
    if let Some(rank) = dp_rank {
        ext.dp_rank = Some(rank);
    }
    if let Some(rank) = prefill_dp_rank {
        ext.prefill_dp_rank = Some(rank);
    }
    Some(ext)
}
/// Accessor trait for types that may carry an [`NvExt`] block and a raw prompt.
///
/// NOTE(review): the implementors live outside this file; the descriptions
/// below are derived from the signatures only.
pub trait NvExtProvider {
/// Returns a borrow of the attached [`NvExt`], or `None` when there is none.
fn nvext(&self) -> Option<&NvExt>;
/// Returns an owned raw prompt string, or `None` when none is available.
/// Presumably related to `NvExt::use_raw_prompt` — confirm against implementors.
fn raw_prompt(&self) -> Option<String>;
}
// Worker identifiers reported back to the client in `NvExtResponse::worker_id`.
// Every field is optional and is left out of the serialized form when `None`.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct WorkerIdInfo {
// Id of the prefill worker, when one is known.
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
// Rank associated with the prefill worker (presumably data-parallel rank — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
// Id of the decode worker, when one is known.
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
// Rank associated with the decode worker (presumably data-parallel rank — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub decode_dp_rank: Option<u32>,
}
// Response-side `nvext` payload. It is assembled by
// `NvExtResponseFieldSelection::build_response_nvext`; every field is optional
// and is omitted from the serialized form when `None`.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone)]
pub struct NvExtResponse {
// Worker identifiers, taken from the request tracker when selected.
#[serde(skip_serializing_if = "Option::is_none")]
pub worker_id: Option<WorkerIdInfo>,
// Timing information, taken from the request tracker when selected.
#[serde(skip_serializing_if = "Option::is_none")]
pub timing: Option<TimingInfo>,
// Token ids, read from the `"token_ids"` key of the disaggregated params.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>,
// Opaque JSON copied from the `"routed_experts"` key of the disaggregated params.
#[serde(skip_serializing_if = "Option::is_none")]
pub routed_experts: Option<serde_json::Value>,
// Opaque JSON supplied by the backend and passed through unchanged.
#[serde(skip_serializing_if = "Option::is_none")]
pub engine_data: Option<serde_json::Value>,
// JSON serialization of the backend's `StopReason`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<serde_json::Value>,
}
/// Folds `incoming` into `target`.
///
/// * `incoming == None` leaves `target` untouched.
/// * When both sides are JSON objects, the incoming entries are added to the
///   existing object; keys present on both sides take the incoming value.
/// * In every other case `target` is replaced wholesale by `incoming`.
pub(crate) fn merge_response_nvext(
    target: &mut Option<serde_json::Value>,
    incoming: Option<serde_json::Value>,
) {
    match incoming {
        // Nothing new arrived.
        None => {}
        // Object payload: merge key-by-key if there is an object to merge into.
        Some(serde_json::Value::Object(new_entries)) => {
            if let Some(serde_json::Value::Object(existing)) = target.as_mut() {
                existing.extend(new_entries);
            } else {
                *target = Some(serde_json::Value::Object(new_entries));
            }
        }
        // Non-object payload always overwrites.
        Some(replacement) => *target = Some(replacement),
    }
}
/// Set of flags recording which optional fields of [`NvExtResponse`] should be
/// produced for a request. The derived `Default` has every flag `false`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NvExtResponseFieldSelection {
/// Emit `NvExtResponse::worker_id`.
pub worker_id: bool,
/// Emit `NvExtResponse::timing`.
pub timing: bool,
/// Emit `NvExtResponse::token_ids`.
pub token_ids: bool,
/// Emit `NvExtResponse::routed_experts`.
pub routed_experts: bool,
/// Emit `NvExtResponse::engine_data`.
pub engine_data: bool,
/// Emit `NvExtResponse::stop_reason`.
pub stop_reason: bool,
}
impl NvExtResponseFieldSelection {
pub fn from_nvext(nvext: Option<&NvExt>) -> Self {
let Some(ext) = nvext else {
return Self::default();
};
let mut selection = Self::default();
if let Some(fields) = ext.extra_fields.as_ref() {
for field in fields {
match field.as_str() {
"worker_id" => selection.worker_id = true,
"timing" => selection.timing = true,
"routed_experts" => selection.routed_experts = true,
"engine_data" => selection.engine_data = true,
"stop_reason" => selection.stop_reason = true,
_ => {}
}
}
}
if ext.has_query_instance_id_annotation() {
selection.worker_id = true;
selection.token_ids = true;
}
selection
}
pub fn build_response_nvext(
&self,
tracker: Option<&std::sync::Arc<crate::protocols::common::timing::RequestTracker>>,
disaggregated_params: Option<&serde_json::Value>,
finish_reason_present: bool,
engine_data_from_backend: Option<serde_json::Value>,
stop_reason_from_backend: Option<StopReason>,
) -> Option<NvExtResponse> {
let worker_id = if self.worker_id {
tracker.and_then(|t| t.get_worker_info())
} else {
None
};
let token_ids = if self.token_ids {
disaggregated_params
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok())
} else {
None
};
let routed_experts = if self.routed_experts {
disaggregated_params
.and_then(|params| params.get("routed_experts"))
.cloned()
} else {
None
};
let timing = if finish_reason_present && self.timing {
tracker.map(|t| t.get_timing_info())
} else {
None
};
let engine_data = if self.engine_data {
engine_data_from_backend
} else {
None
};
let stop_reason = if self.stop_reason {
stop_reason_from_backend.and_then(|reason| serde_json::to_value(reason).ok())
} else {
None
};
if worker_id.is_none()
&& token_ids.is_none()
&& routed_experts.is_none()
&& timing.is_none()
&& engine_data.is_none()
&& stop_reason.is_none()
{
return None;
}
Some(NvExtResponse {
worker_id,
timing,
token_ids,
routed_experts,
engine_data,
stop_reason,
})
}
}
// Request-side `nvext` extension block.
//
// Every field is optional: it defaults to `None` when absent from the input, is
// omitted from the serialized form when `None`, and has a builder setter that
// takes the inner value directly (`strip_option`). Schema-level validation is
// delegated to `validate_nv_ext`.
#[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
pub struct NvExt {
// NOTE(review): not read in this file; presumably toggles greedy sampling — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub greed_sampling: Option<bool>,
// NOTE(review): not read in this file; presumably pairs with
// `NvExtProvider::raw_prompt` — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub use_raw_prompt: Option<bool>,
// Free-form annotation strings. An entry starting with `query_instance_id:`
// is detected by `has_query_instance_id_annotation`.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub annotations: Option<Vec<String>>,
// Target backend instance; overwritten from the `x-worker-instance-id` header
// by `apply_header_routing_overrides`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub backend_instance_id: Option<u64>,
// NOTE(review): not read in this file; presumably pre-tokenized input — confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token_data: Option<Vec<u32>>,
// NOTE(review): not read in this file; presumably a cap on reasoning tokens — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub max_thinking_tokens: Option<u32>,
// Names of optional response fields to return. Recognised by
// `NvExtResponseFieldSelection::from_nvext`: "worker_id", "timing",
// "routed_experts", "engine_data", "stop_reason"; other names are ignored.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
// Target prefill worker; overwritten from the `x-prefill-instance-id` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
// Target decode worker; overwritten from the `x-worker-instance-id` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
// Rank override; overwritten from the `x-dp-rank` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
// Prefill rank override; overwritten from the `x-prefill-dp-rank` header
// unless that header carries the `u32::MAX` "unset" sentinel.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
// Optional per-request hints; see `AgentHints`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub agent_hints: Option<AgentHints>,
// Optional agent context; the type is defined in `crate::agents::context`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub agent_context: Option<AgentContext>,
// NOTE(review): not read in this file; the name suggests a request timestamp
// in milliseconds — confirm epoch and producer.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_timestamp_ms: Option<f64>,
// Optional session control block; see `SessionControl`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub session_control: Option<SessionControl>,
// Optional router parameters; the type is defined in
// `crate::protocols::common::preprocessor`.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub router: Option<RouterParams>,
}
pub use crate::protocols::common::preprocessor::RouterParams;
// Optional hints attached to a request through `NvExt::agent_hints`.
// Every field defaults to `None` and is omitted from the serialized form when
// `None`. None of the fields are interpreted in this file, so the notes below
// are assumptions drawn from the names — confirm against the consumers.
#[derive(ToSchema, Serialize, Deserialize, Builder, Debug, Clone, Default, PartialEq)]
pub struct AgentHints {
// Presumably a scheduling priority — TODO confirm ordering semantics.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
// Presumably the expected output sequence length — TODO confirm unit.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub osl: Option<u32>,
// Presumably enables speculative prefill — TODO confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub speculative_prefill: Option<bool>,
// Excluded from the generated schema via `#[schema(ignore)]`, but still
// (de)serialized. Meaning of the value is not visible here — TODO confirm.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(ignore)]
pub latency_sensitivity: Option<f64>,
}
/// Serde default for `SessionControl::timeout`, used when the field is absent
/// from the input.
fn default_session_timeout() -> u64 {
    const DEFAULT_SESSION_TIMEOUT: u64 = 300;
    DEFAULT_SESSION_TIMEOUT
}
// Session control block carried in `NvExt::session_control`.
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct SessionControl {
// Session identifier; required in the input.
pub session_id: String,
// Requested action. Defaults to `None` when absent from the input and is
// omitted from the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub action: Option<SessionAction>,
// Session timeout; `default_session_timeout()` (300) is used when absent from
// the input. Unit is presumably seconds — TODO confirm.
#[serde(default = "default_session_timeout")]
pub timeout: u64,
}
// Action requested through `SessionControl::action`.
// Variants are (de)serialized in snake_case, i.e. "open" and "close".
#[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum SessionAction {
// Serialized as "open".
Open,
// Serialized as "close".
Close,
}
impl Default for NvExt {
fn default() -> Self {
NvExt::builder().build().unwrap()
}
}
impl NvExt {
pub fn builder() -> NvExtBuilder {
NvExtBuilder::default()
}
pub fn has_query_instance_id_annotation(&self) -> bool {
self.annotations.as_ref().is_some_and(|annotations| {
annotations
.iter()
.any(|annotation| annotation.starts_with("query_instance_id:"))
})
}
}
/// Schema-level validation hook named by `#[validate(schema(function = "validate_nv_ext"))]`
/// on [`NvExt`]. It currently performs no checks and accepts every value.
fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
Ok(())
}
impl NvExtBuilder {
    /// Appends a single annotation to the builder's annotation list, creating the
    /// list first if no annotations have been set yet.
    ///
    /// Returns `&mut Self` so calls can be chained like the generated setters.
    pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
        // The builder stores the field as `Option<Option<Vec<String>>>`: the outer
        // layer tracks "was the setter called", the inner one is the field value.
        // Initialise whichever layer is missing instead of panicking on it.
        self.annotations
            .get_or_insert_with(|| Some(Vec::new()))
            .get_or_insert_with(Vec::new)
            .push(annotation.into());
        self
    }
}
#[cfg(test)]
mod tests {
    use validator::Validate;

    use super::*;

    // ---------------------------------------------------------------------
    // NvExt builder
    // ---------------------------------------------------------------------

    /// An untouched builder leaves every optional field unset.
    #[test]
    fn test_nv_ext_builder_default() {
        let nv_ext = NvExt::builder().build().unwrap();
        assert_eq!(nv_ext.greed_sampling, None);
        assert_eq!(nv_ext.use_raw_prompt, None);
        assert_eq!(nv_ext.annotations, None);
        assert_eq!(nv_ext.backend_instance_id, None);
        assert_eq!(nv_ext.token_data, None);
        assert_eq!(nv_ext.max_thinking_tokens, None);
        assert_eq!(nv_ext.extra_fields, None);
        assert_eq!(nv_ext.prefill_worker_id, None);
        assert_eq!(nv_ext.decode_worker_id, None);
        assert_eq!(nv_ext.dp_rank, None);
        assert_eq!(nv_ext.prefill_dp_rank, None);
        assert_eq!(nv_ext.agent_hints, None);
        assert_eq!(nv_ext.agent_context, None);
        assert_eq!(nv_ext.request_timestamp_ms, None);
        assert_eq!(nv_ext.session_control, None);
        assert!(nv_ext.router.is_none());
    }

    /// Values passed to the generated setters are stored and pass validation.
    #[test]
    fn test_nv_ext_builder_custom() {
        let nv_ext = NvExt::builder()
            .greed_sampling(true)
            .use_raw_prompt(true)
            .backend_instance_id(42)
            .token_data(vec![1, 2, 3, 4])
            .max_thinking_tokens(1024)
            .extra_fields(vec!["worker_id".to_string()])
            .build()
            .unwrap();
        assert_eq!(nv_ext.greed_sampling, Some(true));
        assert_eq!(nv_ext.use_raw_prompt, Some(true));
        assert_eq!(nv_ext.backend_instance_id, Some(42));
        assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4]));
        assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
        assert_eq!(nv_ext.extra_fields, Some(vec!["worker_id".to_string()]));
        assert!(nv_ext.validate().is_ok());
    }

    /// Prefill and decode worker ids are independent fields.
    #[test]
    fn test_nv_ext_disagg_worker_ids() {
        let nv_ext = NvExt::builder()
            .prefill_worker_id(100)
            .decode_worker_id(200)
            .build()
            .unwrap();
        assert_eq!(nv_ext.prefill_worker_id, Some(100));
        assert_eq!(nv_ext.decode_worker_id, Some(200));
        assert!(nv_ext.validate().is_ok());
    }

    // ---------------------------------------------------------------------
    // Serde round-trips
    // ---------------------------------------------------------------------

    /// Covers explicit actions, the default timeout, a missing action, nesting
    /// inside `NvExt`, and a serialize/deserialize round-trip.
    #[test]
    fn test_session_control_serde() {
        let sc_json = r#"{"session_id": "sub-1", "action": "open", "timeout": 60}"#;
        let sc: SessionControl = serde_json::from_str(sc_json).unwrap();
        assert_eq!(sc.action, Some(SessionAction::Open));
        assert_eq!(sc.session_id, "sub-1");
        assert_eq!(sc.timeout, 60);

        // `timeout` falls back to the serde default when omitted.
        let sc_close = r#"{"session_id": "sub-1", "action": "close"}"#;
        let sc: SessionControl = serde_json::from_str(sc_close).unwrap();
        assert_eq!(sc.action, Some(SessionAction::Close));
        assert_eq!(sc.timeout, 300);

        // `action` is optional.
        let sc_continue = r#"{"session_id": "sub-1"}"#;
        let sc: SessionControl = serde_json::from_str(sc_continue).unwrap();
        assert_eq!(sc.action, None);
        assert_eq!(sc.session_id, "sub-1");

        let nvext_json =
            r#"{"session_control": {"session_id": "sub-2", "action": "open", "timeout": 300}}"#;
        let nvext: NvExt = serde_json::from_str(nvext_json).unwrap();
        assert!(nvext.session_control.is_some());
        let sc = nvext.session_control.unwrap();
        assert_eq!(sc.action, Some(SessionAction::Open));
        assert_eq!(sc.session_id, "sub-2");

        let original = SessionControl {
            session_id: "test-session".to_string(),
            action: Some(SessionAction::Close),
            timeout: 90,
        };
        let json = serde_json::to_string(&original).unwrap();
        let deser: SessionControl = serde_json::from_str(&json).unwrap();
        assert_eq!(deser, original);
    }

    /// A fully populated `agent_context` object deserializes field by field.
    #[test]
    fn test_agent_context_serde() {
        let json = r#"{
"agent_context": {
"session_type_id": "deep_research:v1",
"session_id": "run-123",
"trajectory_id": "run-123:researcher-0",
"parent_trajectory_id": "run-123:orchestrator"
}
}"#;
        let nvext: NvExt = serde_json::from_str(json).unwrap();
        let agent_context = nvext.agent_context.expect("agent_context should parse");
        assert_eq!(agent_context.session_type_id, "deep_research:v1");
        assert_eq!(agent_context.session_id, "run-123");
        assert_eq!(agent_context.trajectory_id, "run-123:researcher-0");
        assert_eq!(
            agent_context.parent_trajectory_id.as_deref(),
            Some("run-123:orchestrator")
        );
    }

    /// Omitting `session_id` from `agent_context` is a deserialization error.
    #[test]
    fn test_agent_context_missing_required_field_fails() {
        let json = r#"{
"agent_context": {
"session_type_id": "deep_research:v1",
"trajectory_id": "run-123:researcher-0"
}
}"#;
        assert!(serde_json::from_str::<NvExt>(json).is_err());
    }

    // ---------------------------------------------------------------------
    // Header routing overrides
    // ---------------------------------------------------------------------

    /// All four routing headers are applied, and the worker header feeds both
    /// `backend_instance_id` and `decode_worker_id`.
    #[test]
    fn test_apply_header_routing_overrides() {
        use axum::http::HeaderMap;
        let mut headers = HeaderMap::new();
        headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap());
        headers.insert(HEADER_PREFILL_INSTANCE_ID, "456".parse().unwrap());
        headers.insert(HEADER_DP_RANK, "3".parse().unwrap());
        headers.insert(HEADER_PREFILL_DP_RANK, "5".parse().unwrap());
        let result = apply_header_routing_overrides(None, &headers).unwrap();
        assert_eq!(result.backend_instance_id, Some(123));
        assert_eq!(result.decode_worker_id, Some(123));
        assert_eq!(result.prefill_worker_id, Some(456));
        assert_eq!(result.dp_rank, Some(3));
        assert_eq!(result.prefill_dp_rank, Some(5));
    }

    /// A prefill rank header carrying the `u32::MAX` sentinel counts as "not
    /// set", so with no other header present nothing is overridden.
    #[test]
    fn test_apply_header_routing_overrides_ignores_unset_prefill_dp_rank() {
        use axum::http::HeaderMap;
        let mut headers = HeaderMap::new();
        headers.insert(
            HEADER_PREFILL_DP_RANK,
            u32::MAX.to_string().parse().unwrap(),
        );
        assert!(apply_header_routing_overrides(None, &headers).is_none());
    }

    // ---------------------------------------------------------------------
    // NvExtResponseFieldSelection::from_nvext
    // ---------------------------------------------------------------------

    #[test]
    fn test_nvext_response_field_selection_defaults_to_none() {
        let selection = NvExtResponseFieldSelection::from_nvext(None);
        assert_eq!(selection, NvExtResponseFieldSelection::default());
    }

    #[test]
    fn test_nvext_response_field_selection_respects_extra_fields() {
        let nvext = NvExt::builder()
            .extra_fields(vec!["worker_id".to_string(), "routed_experts".to_string()])
            .build()
            .unwrap();
        let selection = NvExtResponseFieldSelection::from_nvext(Some(&nvext));
        assert!(selection.worker_id);
        assert!(!selection.timing);
        assert!(!selection.token_ids);
        assert!(selection.routed_experts);
    }

    /// The `query_instance_id:` annotation enables `worker_id` and `token_ids`
    /// without any `extra_fields`.
    #[test]
    fn test_nvext_response_field_selection_query_instance_id_exception() {
        let nvext = NvExt::builder()
            .annotations(vec!["query_instance_id:".to_string()])
            .build()
            .unwrap();
        let selection = NvExtResponseFieldSelection::from_nvext(Some(&nvext));
        assert!(selection.worker_id);
        assert!(!selection.timing);
        assert!(selection.token_ids);
        assert!(!selection.routed_experts);
    }

    /// The annotation prefix must match exactly, colon included.
    #[test]
    fn test_nvext_response_field_selection_rejects_stray_annotation() {
        let nvext = NvExt::builder()
            .annotations(vec!["query_instance_id_extra:foo".to_string()])
            .build()
            .unwrap();
        assert_eq!(
            NvExtResponseFieldSelection::from_nvext(Some(&nvext)),
            NvExtResponseFieldSelection::default(),
        );
    }

    #[test]
    fn test_nvext_response_field_selection_worker_id_only() {
        let nvext = NvExt::builder()
            .extra_fields(vec!["worker_id".to_string()])
            .build()
            .unwrap();
        assert_eq!(
            NvExtResponseFieldSelection::from_nvext(Some(&nvext)),
            NvExtResponseFieldSelection {
                worker_id: true,
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_nvext_response_field_selection_timing_only() {
        let nvext = NvExt::builder()
            .extra_fields(vec!["timing".to_string()])
            .build()
            .unwrap();
        assert_eq!(
            NvExtResponseFieldSelection::from_nvext(Some(&nvext)),
            NvExtResponseFieldSelection {
                timing: true,
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_nvext_response_field_selection_routed_experts_only() {
        let nvext = NvExt::builder()
            .extra_fields(vec!["routed_experts".to_string()])
            .build()
            .unwrap();
        assert_eq!(
            NvExtResponseFieldSelection::from_nvext(Some(&nvext)),
            NvExtResponseFieldSelection {
                routed_experts: true,
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_nvext_response_field_selection_stop_reason_only() {
        let nvext = NvExt::builder()
            .extra_fields(vec!["stop_reason".to_string()])
            .build()
            .unwrap();
        assert_eq!(
            NvExtResponseFieldSelection::from_nvext(Some(&nvext)),
            NvExtResponseFieldSelection {
                stop_reason: true,
                ..Default::default()
            }
        );
    }

    // ---------------------------------------------------------------------
    // Fixtures for build_response_nvext
    // ---------------------------------------------------------------------

    /// Selection with every flag off.
    fn sel_all_false() -> NvExtResponseFieldSelection {
        NvExtResponseFieldSelection::default()
    }

    /// Tracker that has one prefill worker recorded.
    fn tracker_with_prefill_worker()
    -> std::sync::Arc<crate::protocols::common::timing::RequestTracker> {
        use crate::protocols::common::timing::{RequestTracker, WORKER_TYPE_PREFILL};
        let tracker = std::sync::Arc::new(RequestTracker::new());
        tracker.record_worker(42, Some(0), WORKER_TYPE_PREFILL);
        tracker
    }

    /// Disaggregated params carrying both `token_ids` and `routed_experts`.
    fn disagg_params_full() -> serde_json::Value {
        serde_json::json!({
            "token_ids": [11u32, 22u32, 33u32],
            "routed_experts": {"layer_0": [1, 3]},
        })
    }

    // ---------------------------------------------------------------------
    // NvExtResponseFieldSelection::build_response_nvext
    // ---------------------------------------------------------------------

    #[test]
    fn test_build_response_nvext_all_false_returns_none() {
        let sel = sel_all_false();
        assert!(
            sel.build_response_nvext(None, None, false, None, None)
                .is_none(),
            "no fields selected → None"
        );
        assert!(
            sel.build_response_nvext(None, None, true, None, None)
                .is_none(),
            "finish_reason alone does not force emission"
        );
    }

    #[test]
    fn test_build_response_nvext_worker_id_only_without_finish() {
        let sel = NvExtResponseFieldSelection {
            worker_id: true,
            ..Default::default()
        };
        let tracker = tracker_with_prefill_worker();
        let out = sel
            .build_response_nvext(Some(&tracker), None, false, None, None)
            .expect("worker_id should emit regardless of finish_reason");
        assert!(out.worker_id.is_some());
        assert!(out.timing.is_none());
        assert!(out.token_ids.is_none());
        assert!(out.routed_experts.is_none());
    }

    #[test]
    fn test_build_response_nvext_timing_suppressed_without_finish() {
        let sel = NvExtResponseFieldSelection {
            timing: true,
            ..Default::default()
        };
        let tracker = tracker_with_prefill_worker();
        assert!(
            sel.build_response_nvext(Some(&tracker), None, false, None, None)
                .is_none(),
            "timing is gated on finish_reason_present"
        );
    }

    #[test]
    fn test_build_response_nvext_timing_emitted_on_finish() {
        let sel = NvExtResponseFieldSelection {
            timing: true,
            ..Default::default()
        };
        let tracker = tracker_with_prefill_worker();
        let out = sel
            .build_response_nvext(Some(&tracker), None, true, None, None)
            .expect("timing should emit on finish");
        assert!(out.timing.is_some());
        assert!(out.worker_id.is_none());
        assert!(out.token_ids.is_none());
        assert!(out.routed_experts.is_none());
    }

    /// Without a tracker there is no timing source, even on finish.
    #[test]
    fn test_build_response_nvext_timing_requires_tracker() {
        let sel = NvExtResponseFieldSelection {
            timing: true,
            ..Default::default()
        };
        assert!(
            sel.build_response_nvext(None, None, true, None, None)
                .is_none()
        );
    }

    #[test]
    fn test_build_response_nvext_token_ids_from_disagg_params() {
        let sel = NvExtResponseFieldSelection {
            token_ids: true,
            ..Default::default()
        };
        let params = disagg_params_full();
        let out = sel
            .build_response_nvext(None, Some(&params), false, None, None)
            .expect("token_ids should emit when present");
        assert_eq!(out.token_ids, Some(vec![11u32, 22, 33]));
        assert!(out.worker_id.is_none());
        assert!(out.timing.is_none());
        assert!(out.routed_experts.is_none());
    }

    #[test]
    fn test_build_response_nvext_token_ids_malformed_falls_back_to_none() {
        let sel = NvExtResponseFieldSelection {
            token_ids: true,
            ..Default::default()
        };
        let params = serde_json::json!({ "token_ids": "not-an-array" });
        assert!(
            sel.build_response_nvext(None, Some(&params), false, None, None)
                .is_none(),
            "malformed token_ids silently suppressed; nothing else selected → None"
        );
    }

    #[test]
    fn test_build_response_nvext_routed_experts_cloned_as_is() {
        let sel = NvExtResponseFieldSelection {
            routed_experts: true,
            ..Default::default()
        };
        let params = disagg_params_full();
        let out = sel
            .build_response_nvext(None, Some(&params), false, None, None)
            .expect("routed_experts should emit when present");
        assert_eq!(
            out.routed_experts,
            Some(serde_json::json!({"layer_0": [1, 3]}))
        );
    }

    #[test]
    fn test_build_response_nvext_stop_reason_when_requested() {
        let sel = NvExtResponseFieldSelection {
            stop_reason: true,
            ..Default::default()
        };
        let out = sel
            .build_response_nvext(
                None,
                None,
                true,
                None,
                Some(StopReason::String("END".to_string())),
            )
            .expect("stop_reason should emit when requested and present");
        assert_eq!(out.stop_reason, Some(serde_json::json!("END")));
        assert!(out.worker_id.is_none());
        assert!(out.timing.is_none());
        assert!(out.token_ids.is_none());
        assert!(out.routed_experts.is_none());
    }

    #[test]
    fn test_build_response_nvext_stop_reason_suppressed_when_absent() {
        let sel = NvExtResponseFieldSelection {
            stop_reason: true,
            ..Default::default()
        };
        assert!(
            sel.build_response_nvext(None, None, true, None, None)
                .is_none()
        );
    }

    #[test]
    fn test_build_response_nvext_combined_emission() {
        let sel = NvExtResponseFieldSelection {
            worker_id: true,
            timing: true,
            token_ids: true,
            routed_experts: true,
            engine_data: false,
            stop_reason: false,
        };
        let tracker = tracker_with_prefill_worker();
        let params = disagg_params_full();
        let out = sel
            .build_response_nvext(Some(&tracker), Some(&params), true, None, None)
            .expect("all fields selected and available → Some");
        assert!(out.worker_id.is_some());
        assert!(out.timing.is_some());
        assert_eq!(out.token_ids, Some(vec![11u32, 22, 33]));
        assert_eq!(
            out.routed_experts,
            Some(serde_json::json!({"layer_0": [1, 3]}))
        );
    }

    /// `token_ids` stays off: it is never selectable through `extra_fields`.
    #[test]
    fn test_nvext_response_field_selection_multiple_extra_fields() {
        let nvext = NvExt::builder()
            .extra_fields(vec![
                "worker_id".to_string(),
                "timing".to_string(),
                "routed_experts".to_string(),
            ])
            .build()
            .unwrap();
        assert_eq!(
            NvExtResponseFieldSelection::from_nvext(Some(&nvext)),
            NvExtResponseFieldSelection {
                worker_id: true,
                timing: true,
                token_ids: false,
                routed_experts: true,
                engine_data: false,
                stop_reason: false,
            }
        );
    }
}