use std::collections::HashMap;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;
use crate::engines::ValidateRequest;
use crate::preprocessor::media::MediaDecoder;
use super::{
OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt,
nvext::NvExtProvider,
validate,
};
pub mod aggregator;
mod delta;
pub mod jail;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
use dynamo_parsers::tool_calling::{ToolCallResponse, ToolCallResponseChunk};
use dynamo_protocols::types::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, FunctionCall,
FunctionCallStream, FunctionType,
};
pub(crate) fn tool_call_response_to_protocol(
parsed: ToolCallResponse,
) -> ChatCompletionMessageToolCall {
ChatCompletionMessageToolCall {
id: parsed.id,
r#type: FunctionType::Function,
function: FunctionCall {
name: parsed.function.name,
arguments: parsed.function.arguments,
},
}
}
/// Converts a parsed streaming tool-call chunk into the protocol-level
/// [`ChatCompletionMessageToolCallChunk`].
///
/// `index` and `id` are carried over as-is. A present `tp` becomes
/// `Some(FunctionType::Function)` and an absent one stays `None`. The optional
/// function fragment keeps its `name` / `arguments` values untouched.
#[allow(dead_code)]
pub(crate) fn tool_call_response_chunk_to_protocol(
    parsed: ToolCallResponseChunk,
) -> ChatCompletionMessageToolCallChunk {
    // The protocol only knows one tool type, so only the presence of `tp` matters.
    let kind = parsed.tp.is_some().then_some(FunctionType::Function);
    let function = parsed.function.map(|fragment| FunctionCallStream {
        name: fragment.name,
        arguments: fragment.arguments,
    });
    ChatCompletionMessageToolCallChunk {
        index: parsed.index,
        id: parsed.id,
        r#type: kind,
        function,
    }
}
/// Chat completion request: the standard request body plus extension fields.
///
/// The standard fields and the common extension fields are flattened, so they
/// are read from and written to the top level of the JSON object.
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest {
/// Standard chat completion request fields (flattened into the top level).
#[serde(flatten)]
#[schema(value_type = Object)]
pub inner: dynamo_protocols::types::CreateChatCompletionRequest,
/// Common extension fields (flattened); falls back to `CommonExt::default()`.
#[serde(flatten, default)]
pub common: CommonExt,
/// Optional `nvext` extension object; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
/// Extra arguments for the chat template. Also accepted under the key
/// `chat_template_kwargs` when deserializing.
#[serde(
default,
skip_serializing_if = "Option::is_none",
alias = "chat_template_kwargs"
)]
pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,
/// Raw `thinking` payload, kept as untyped JSON. It is interpreted and then
/// cleared by `normalize_reasoning_template_args`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub thinking: Option<serde_json::Value>,
/// Optional media decoder settings; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub media_io_kwargs: Option<MediaDecoder>,
/// Optional flag surfaced through `get_return_tokens_as_token_ids`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub return_tokens_as_token_ids: Option<bool>,
/// Flattened map of remaining top-level keys (e.g. `stop_token_ids` in the
/// tests below). Never serialized; checked by `validate`.
#[serde(flatten, default, skip_serializing)]
pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
}
impl NvCreateChatCompletionRequest {
pub fn normalize_reasoning_template_args(&mut self) -> anyhow::Result<()> {
let thinking_enabled = self
.thinking
.as_ref()
.map(openai_thinking_enabled)
.transpose()?
.flatten();
let reasoning_effort = self
.inner
.reasoning_effort
.as_ref()
.and_then(|effort| serde_json::to_value(effort).ok());
if thinking_enabled.is_none() && reasoning_effort.is_none() {
return Ok(());
}
let args = self.chat_template_args.get_or_insert_with(HashMap::new);
if let Some(enabled) = thinking_enabled {
args.entry("thinking".to_string())
.or_insert(serde_json::Value::Bool(enabled));
}
if let Some(effort) = reasoning_effort {
args.entry("reasoning_effort".to_string()).or_insert(effort);
}
self.thinking = None;
Ok(())
}
}
fn openai_thinking_enabled(value: &serde_json::Value) -> anyhow::Result<Option<bool>> {
if let Some(enabled) = value.as_bool() {
return Ok(Some(enabled));
}
let Some(thinking_object) = value.as_object() else {
anyhow::bail!(
"`thinking` must be a boolean or an object with `type` set to `enabled` or `disabled`"
);
};
let Some(thinking_type) = thinking_object.get("type").and_then(|v| v.as_str()) else {
anyhow::bail!("`thinking.type` must be `enabled` or `disabled`");
};
match thinking_type {
"enabled" => Ok(Some(true)),
"disabled" => Ok(Some(false)),
_ => anyhow::bail!("`thinking.type` must be `enabled` or `disabled`"),
}
}
/// Unary chat completion response: the standard response body plus an
/// optional `nvext` extension value.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct NvCreateChatCompletionResponse {
/// Standard response fields, flattened into the top level of the JSON object.
#[serde(flatten)]
pub inner: dynamo_protocols::types::CreateChatCompletionResponse,
/// Untyped extension payload; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
}
/// Streaming chat completion chunk: the standard stream response body plus an
/// optional `nvext` extension value.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct NvCreateChatCompletionStreamResponse {
/// Standard stream-chunk fields, flattened into the top level of the JSON object.
#[serde(flatten)]
pub inner: dynamo_protocols::types::CreateChatCompletionStreamResponse,
/// Untyped extension payload; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<serde_json::Value>,
}
impl NvExtProvider for NvCreateChatCompletionRequest {
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
fn raw_prompt(&self) -> Option<String> {
None
}
fn unsupported_fields(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
Some(&self.unsupported_fields)
}
}
/// Exposes the annotations carried in `nvext` through [`AnnotationsProvider`].
impl AnnotationsProvider for NvCreateChatCompletionRequest {
    /// Returns a copy of `nvext.annotations`, or `None` when either `nvext`
    /// or its annotation list is absent.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext.as_ref()?.annotations.clone()
    }

    /// Reports whether `annotation` is one of the requested annotations.
    ///
    /// Returns `false` when `nvext` or its annotation list is absent.
    fn has_annotation(&self, annotation: &str) -> bool {
        // Compare borrowed strings directly; no temporary `String` is built
        // for the lookup.
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_deref())
            .is_some_and(|names| names.iter().any(|name| name == annotation))
    }
}
/// Sampling options, read from the standard request fields.
impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
    /// Sampling temperature from the standard request.
    fn get_temperature(&self) -> Option<f32> {
        self.inner.temperature
    }

    /// Nucleus-sampling `top_p` from the standard request.
    fn get_top_p(&self) -> Option<f32> {
        self.inner.top_p
    }

    /// `frequency_penalty` from the standard request.
    fn get_frequency_penalty(&self) -> Option<f32> {
        self.inner.frequency_penalty
    }

    /// `presence_penalty` from the standard request.
    fn get_presence_penalty(&self) -> Option<f32> {
        self.inner.presence_penalty
    }

    /// Borrows the `nvext` extension object, if present.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// `seed` from the standard request.
    fn get_seed(&self) -> Option<i64> {
        self.inner.seed
    }

    /// Number of choices (`n`) from the standard request.
    fn get_n(&self) -> Option<u8> {
        self.inner.n
    }

    /// Always `None`: this request type does not supply a `best_of` value.
    fn get_best_of(&self) -> Option<u8> {
        None
    }
}
/// Common extension options, read from the flattened `common` field.
impl CommonExtProvider for NvCreateChatCompletionRequest {
    /// Borrows the common extension block; always `Some`.
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Resolves the JSON schema used for guided decoding.
    ///
    /// An explicit `guided_json` wins. Otherwise the schema is derived from
    /// `response_format`: `JsonObject` maps to `{"type": "object"}`,
    /// `JsonSchema` yields its embedded schema (if any), and `Text` yields
    /// nothing.
    fn get_guided_json(&self) -> Option<serde_json::Value> {
        use dynamo_protocols::types::ResponseFormat;

        self.common.guided_json.clone().or_else(|| {
            match self.inner.response_format.as_ref()? {
                ResponseFormat::Text => None,
                ResponseFormat::JsonObject => Some(serde_json::json!({ "type": "object" })),
                ResponseFormat::JsonSchema { json_schema } => json_schema.schema.clone(),
            }
        })
    }

    /// Copy of `common.guided_regex`.
    fn get_guided_regex(&self) -> Option<String> {
        self.common.guided_regex.clone()
    }

    /// Copy of `common.guided_grammar`.
    fn get_guided_grammar(&self) -> Option<String> {
        self.common.guided_grammar.clone()
    }

    /// Copy of `common.guided_choice`.
    fn get_guided_choice(&self) -> Option<Vec<String>> {
        self.common.guided_choice.clone()
    }

    /// Copy of `common.guided_decoding_backend`.
    fn get_guided_decoding_backend(&self) -> Option<String> {
        self.common.guided_decoding_backend.clone()
    }

    /// Copy of `common.guided_whitespace_pattern`.
    fn get_guided_whitespace_pattern(&self) -> Option<String> {
        self.common.guided_whitespace_pattern.clone()
    }

    /// `common.top_k`.
    fn get_top_k(&self) -> Option<i32> {
        self.common.top_k
    }

    /// `common.min_p`.
    fn get_min_p(&self) -> Option<f32> {
        self.common.min_p
    }

    /// `common.repetition_penalty`.
    fn get_repetition_penalty(&self) -> Option<f32> {
        self.common.repetition_penalty
    }

    /// `common.include_stop_str_in_output`.
    fn get_include_stop_str_in_output(&self) -> Option<bool> {
        self.common.include_stop_str_in_output
    }

    /// `common.skip_special_tokens`.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        self.common.skip_special_tokens
    }

    /// `common.prompt_logprobs`.
    fn get_prompt_logprobs_count(&self) -> Option<u32> {
        self.common.prompt_logprobs
    }
}
/// Stop conditions, combining standard request fields with extension fields.
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
    /// Token budget: `max_completion_tokens` when set, otherwise the
    /// deprecated `max_tokens`.
    #[allow(deprecated)]
    fn get_max_tokens(&self) -> Option<u32> {
        let preferred = self.inner.max_completion_tokens;
        preferred.or(self.inner.max_tokens)
    }

    /// `common.min_tokens`.
    fn get_min_tokens(&self) -> Option<u32> {
        self.common.min_tokens
    }

    /// String form of `stop`, as reported by the stop value's `strings()`.
    fn get_stop(&self) -> Option<Vec<String>> {
        self.inner.stop.as_ref()?.strings()
    }

    /// Stop token ids.
    ///
    /// Ids carried by `stop` itself take priority. Otherwise a top-level
    /// `stop_token_ids` entry captured in `unsupported_fields` is decoded;
    /// an entry that does not decode as a list of token ids yields `None`.
    fn get_stop_token_ids(&self) -> Option<Vec<crate::types::TokenIdType>> {
        self.inner
            .stop
            .as_ref()
            .and_then(|stop| stop.token_ids())
            .or_else(|| {
                let raw = self.unsupported_fields.get("stop_token_ids")?;
                serde_json::from_value::<Vec<crate::types::TokenIdType>>(raw.clone()).ok()
            })
    }

    /// Borrows the `nvext` extension object, if present.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// `common.ignore_eos`.
    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }

    /// `ignore_eos`; for this request type it is the same value as
    /// `get_common_ignore_eos`.
    fn get_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
}
impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
fn get_logprobs(&self) -> Option<u32> {
match self.inner.logprobs {
Some(true) => match self.inner.top_logprobs {
Some(top_logprobs) => Some(top_logprobs as u32),
None => Some(1_u32),
},
Some(false) => None,
None => None,
}
}
fn get_prompt_logprobs(&self) -> Option<u32> {
self.common.prompt_logprobs
}
fn get_skip_special_tokens(&self) -> Option<bool> {
CommonExtProvider::get_skip_special_tokens(self)
}
fn get_formatted_prompt(&self) -> Option<bool> {
None
}
fn get_return_tokens_as_token_ids(&self) -> Option<bool> {
self.return_tokens_as_token_ids
}
}
/// Request-level validation for chat completion requests.
impl ValidateRequest for NvCreateChatCompletionRequest {
/// Runs each field validator in sequence and returns the first error
/// encountered (every call is propagated with `?`); `Ok(())` when all pass.
///
/// The call order determines which error is reported when several fields
/// are invalid at once.
fn validate(&self) -> Result<(), anyhow::Error> {
// Leftover top-level fields captured by the flattened map are checked first.
validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
validate::validate_messages(&self.inner.messages)?;
validate::validate_model(&self.inner.model)?;
validate::validate_reasoning_effort(&self.inner.reasoning_effort)?;
validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
validate::validate_logit_bias(&self.inner.logit_bias)?;
validate::validate_top_logprobs(self.inner.top_logprobs)?;
validate::validate_max_completion_tokens(self.inner.max_completion_tokens)?;
validate::validate_n(self.inner.n)?;
// An omitted `n` is treated as a single choice for this check.
super::nvext::validate_completion_token_ids_single_choice(
self.inner.n.unwrap_or(1) as usize,
self.nvext.as_ref(),
)?;
validate::validate_presence_penalty(self.inner.presence_penalty)?;
validate::validate_response_format(&self.inner.response_format)?;
validate::validate_service_tier(&self.inner.service_tier)?;
validate::validate_stop(&self.inner.stop)?;
validate::validate_temperature(self.inner.temperature)?;
validate::validate_top_p(self.inner.top_p)?;
validate::validate_tools(&self.inner.tools.as_deref())?;
// `tool_choice` is validated against the declared tools.
validate::validate_tool_choice(&self.inner.tool_choice, self.inner.tools.as_deref())?;
validate::validate_user(self.inner.user.as_deref())?;
// Extension sampling values are read through this type's `CommonExtProvider` getters.
validate::validate_repetition_penalty(self.get_repetition_penalty())?;
validate::validate_min_p(self.get_min_p())?;
validate::validate_top_k(self.get_top_k())?;
// Cross-field check involving both `n` and `temperature`.
validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?;
Ok(())
}
}
// Unit tests for request deserialization, option extraction, validation,
// tool-call mapping and `thinking` normalization.
#[cfg(test)]
mod tests {
use super::*;
use crate::engines::ValidateRequest;
use crate::protocols::common::{OutputOptionsProvider, StopConditionsProvider};
use dynamo_protocols::types::{ChatCompletionTool, ChatCompletionToolType, FunctionObject};
use serde_json::json;
// A request without `skip_special_tokens` leaves the value unset both on the
// request and in the extracted output options.
#[test]
fn test_skip_special_tokens_none() {
let json_str = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");
assert_eq!(request.common.skip_special_tokens, None);
let output_options = request
.extract_output_options()
.expect("Failed to extract output options");
assert_eq!(output_options.skip_special_tokens, None);
}
// An explicit top-level `skip_special_tokens` (true or false) reaches the
// extracted output options unchanged.
#[test]
fn test_skip_special_tokens_propagates() {
for skip_value in [true, false] {
let json_str = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"skip_special_tokens": skip_value
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");
let output_options = request
.extract_output_options()
.expect("Failed to extract output options");
assert_eq!(output_options.skip_special_tokens, Some(skip_value));
}
}
// Pins how the different shapes of `stop` / `stop_token_ids` are interpreted.
#[test]
fn test_stop_contract() {
// Case: a single stop string.
let one_stop = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop": " The"
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(one_stop).expect("Failed to deserialize request");
assert_eq!(request.get_stop(), Some(vec![" The".to_string()]));
assert_eq!(request.get_stop_token_ids(), None);
// Case: an array of stop strings.
let many_stops = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop": ["A", "B"]
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(many_stops).expect("Failed to deserialize request");
assert_eq!(
request.get_stop(),
Some(vec!["A".to_string(), "B".to_string()])
);
assert_eq!(request.get_stop_token_ids(), None);
// Case: an array of integers is treated as stop token ids, not strings.
let token_id_stops = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop": [32, 34]
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(token_id_stops).expect("Failed to deserialize request");
assert_eq!(request.get_stop(), None);
assert_eq!(request.get_stop_token_ids(), Some(vec![32, 34]));
let stop_conditions = request
.extract_stop_conditions()
.expect("extract stop conditions");
assert_eq!(stop_conditions.stop, None);
assert_eq!(stop_conditions.stop_token_ids, Some(vec![32, 34]));
// Case: a string that merely looks like "token_id:N" stays a literal stop string.
let token_id_display_string_stop = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop": "token_id:576"
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(token_id_display_string_stop)
.expect("Failed to deserialize request");
assert_eq!(request.get_stop(), Some(vec!["token_id:576".to_string()]));
assert_eq!(request.get_stop_token_ids(), None);
// Case: the same holds when that string is wrapped in an array.
let token_id_display_string_array_stop = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop": ["token_id:576"]
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(token_id_display_string_array_stop)
.expect("Failed to deserialize request");
assert_eq!(request.get_stop(), Some(vec!["token_id:576".to_string()]));
assert_eq!(request.get_stop_token_ids(), None);
// Case: a bare integer `stop` fails at deserialization time.
let scalar_token_id_stop = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop": 576
});
let result: Result<NvCreateChatCompletionRequest, _> =
serde_json::from_value(scalar_token_id_stop);
assert!(result.is_err());
// Case: a top-level `stop_token_ids` array is picked up and passes validation.
let whitelisted_stop_token_ids = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop_token_ids": [576]
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(whitelisted_stop_token_ids)
.expect("Failed to deserialize request");
assert_eq!(request.get_stop_token_ids(), Some(vec![576]));
assert!(
ValidateRequest::validate(&request).is_ok(),
"stop_token_ids must be accepted via PASSTHROUGH_EXTRA_FIELDS"
);
// Case: a non-array `stop_token_ids` deserializes but is rejected by `validate`.
let invalid_stop_token_ids = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"stop_token_ids": "bad"
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(invalid_stop_token_ids).expect("Failed to deserialize request");
let err = ValidateRequest::validate(&request).expect_err("invalid stop_token_ids");
assert!(err.to_string().contains("stop_token_ids"));
}
// `allowed_token_ids` and `bad_words_token_ids` land in `unsupported_fields`
// verbatim and still pass validation.
#[test]
fn test_passthrough_token_constraints_validate() {
let request_json = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"allowed_token_ids": [10, 11],
"bad_words_token_ids": [[12, 13]]
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(request_json).expect("Failed to deserialize request");
assert_eq!(
request.unsupported_fields.get("allowed_token_ids"),
Some(&serde_json::json!([10, 11]))
);
assert_eq!(
request.unsupported_fields.get("bad_words_token_ids"),
Some(&serde_json::json!([[12, 13]]))
);
assert!(ValidateRequest::validate(&request).is_ok());
}
// Requesting `completion_token_ids` via nvext together with `n = 2` must fail
// validation with an error naming `completion_token_ids`.
#[test]
fn test_completion_token_ids_rejected_for_multi_choice() {
let request_json = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"n": 2,
"nvext": {
"extra_fields": ["completion_token_ids"]
}
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(request_json).expect("Failed to deserialize request");
let err = ValidateRequest::validate(&request).expect_err("multi-choice token ids");
assert!(err.to_string().contains("completion_token_ids"));
}
// `tool_choice: "required"` without any tools is a validation error.
#[test]
fn test_validate_tool_choice_required_rejects_empty_tools() {
let request_json = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"tool_choice": "required"
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(request_json).expect("Failed to deserialize request");
let err = ValidateRequest::validate(&request).expect_err("required needs tools");
assert!(
err.to_string()
.contains("tool_choice is \"required\" but tools is empty")
);
}
// A named `tool_choice` must refer to a tool that is present in `tools`.
#[test]
fn test_validate_tool_choice_named_rejects_missing_tool() {
let request_json = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {"type": "object", "properties": {}}
}
}],
"tool_choice": {
"type": "function",
"function": {"name": "search"}
}
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(request_json).expect("Failed to deserialize request");
let err = ValidateRequest::validate(&request).expect_err("named tool must exist");
assert!(
err.to_string()
.contains("tool named \"search\" in tool_choice is not present in tools")
);
}
// `truncate_prompt_tokens` deserializes but is rejected by validation.
#[test]
fn test_truncate_prompt_tokens_rejected_until_supported() {
let request_json = json!({
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"truncate_prompt_tokens": 2
});
let request: NvCreateChatCompletionRequest =
serde_json::from_value(request_json).expect("Failed to deserialize request");
assert!(ValidateRequest::validate(&request).is_err());
}
// Parser-side types used by the mapping tests below.
use dynamo_parsers::tool_calling::{
CalledFunction, CalledFunctionStream, ToolCallResponse, ToolCallResponseChunk, ToolCallType,
};
// Builds a parser-side unary tool call.
fn native_call(id: &str, name: &str, args: &str) -> ToolCallResponse {
ToolCallResponse {
id: id.to_string(),
tp: ToolCallType::Function,
function: CalledFunction {
name: name.to_string(),
arguments: args.to_string(),
},
}
}
// Builds a parser-side streaming chunk with every optional part populated.
fn native_chunk(index: u32, id: &str, name: &str, args: &str) -> ToolCallResponseChunk {
ToolCallResponseChunk {
index,
id: Some(id.to_string()),
tp: Some(ToolCallType::Function),
function: Some(CalledFunctionStream {
name: Some(name.to_string()),
arguments: Some(args.to_string()),
}),
}
}
// Builds the expected protocol-side unary tool call directly.
fn legacy_unary(id: &str, name: &str, args: &str) -> ChatCompletionMessageToolCall {
ChatCompletionMessageToolCall {
id: id.to_string(),
r#type: FunctionType::Function,
function: FunctionCall {
name: name.to_string(),
arguments: args.to_string(),
},
}
}
// Builds the expected protocol-side streaming chunk directly.
fn legacy_chunk(
index: u32,
id: &str,
name: &str,
args: &str,
) -> ChatCompletionMessageToolCallChunk {
ChatCompletionMessageToolCallChunk {
index,
id: Some(id.to_string()),
r#type: Some(FunctionType::Function),
function: Some(FunctionCallStream {
name: Some(name.to_string()),
arguments: Some(args.to_string()),
}),
}
}
// The unary mapping must equal the directly-built protocol value, both as a
// struct and as serialized JSON.
#[test]
fn unary_mapping_matches_legacy_struct_and_json() {
for (id, name, args) in [
(
"call_1",
"get_weather",
r#"{"location":"SF","unit":"celsius"}"#,
),
("call_2", "ping", "{}"), ] {
let mapped = tool_call_response_to_protocol(native_call(id, name, args));
let legacy = legacy_unary(id, name, args);
assert_eq!(mapped, legacy, "struct mismatch for {name}");
assert_eq!(
serde_json::to_string(&mapped).unwrap(),
serde_json::to_string(&legacy).unwrap(),
"serialized JSON mismatch for {name}"
);
}
}
// Same check for a list of several unary calls mapped one by one.
#[test]
fn unary_mapping_multi_call_matches_legacy() {
let inputs = [
("a", "first", r#"{"k":"v1"}"#),
("b", "second", r#"{"k":"v2"}"#),
];
let mapped: Vec<_> = inputs
.iter()
.map(|(id, n, a)| tool_call_response_to_protocol(native_call(id, n, a)))
.collect();
let legacy: Vec<_> = inputs
.iter()
.map(|(id, n, a)| legacy_unary(id, n, a))
.collect();
assert_eq!(mapped, legacy);
assert_eq!(
serde_json::to_string(&mapped).unwrap(),
serde_json::to_string(&legacy).unwrap()
);
}
// The streaming mapping must equal the directly-built protocol chunk, both as
// a struct and as serialized JSON.
#[test]
fn stream_mapping_matches_legacy_struct_and_json() {
for (idx, id, name, args) in [
(0u32, "call_1", "get_weather", r#"{"location":"SF"}"#),
(1u32, "call_2", "ping", "{}"), ] {
let mapped = tool_call_response_chunk_to_protocol(native_chunk(idx, id, name, args));
let legacy = legacy_chunk(idx, id, name, args);
assert_eq!(mapped, legacy, "struct mismatch for {name}");
assert_eq!(
serde_json::to_string(&mapped).unwrap(),
serde_json::to_string(&legacy).unwrap(),
"serialized JSON mismatch for {name}"
);
}
}
// Same check for several chunks; the chunk index must be carried through.
#[test]
fn stream_mapping_multi_call_indexes_and_matches_legacy() {
let inputs = [
(0u32, "a", "first", r#"{"k":"v1"}"#),
(1u32, "b", "second", r#"{"k":"v2"}"#),
];
let mapped: Vec<_> = inputs
.iter()
.map(|(i, id, n, a)| tool_call_response_chunk_to_protocol(native_chunk(*i, id, n, a)))
.collect();
let legacy: Vec<_> = inputs
.iter()
.map(|(i, id, n, a)| legacy_chunk(*i, id, n, a))
.collect();
assert_eq!(mapped, legacy);
assert_eq!(
serde_json::to_string(&mapped).unwrap(),
serde_json::to_string(&legacy).unwrap()
);
}
// Tool names made of letters, digits, `_` and `-` are accepted.
#[test]
fn test_validate_tools_valid_names() {
fn make_tool(name: &str) -> ChatCompletionTool {
ChatCompletionTool {
r#type: ChatCompletionToolType::Function,
function: FunctionObject {
name: name.to_string(),
description: None,
parameters: Some(json!({"type": "object", "properties": {}})),
strict: None,
},
}
}
let tools = vec![
make_tool("func_name"),
make_tool("func-name_v2"),
make_tool("FuncName"),
make_tool("Func_Name-123"),
];
assert!(validate::validate_tools(&Some(&tools)).is_ok());
}
// Tool names containing `<`/`>`, spaces, `@`, `,`, or an empty name are rejected.
#[test]
fn test_validate_tools_invalid_names() {
for name in ["<func_name>", "func name", "func@name", "func,name", ""] {
let tools = vec![ChatCompletionTool {
r#type: ChatCompletionToolType::Function,
function: FunctionObject {
name: name.to_string(),
description: None,
parameters: Some(json!({"type": "object", "properties": {}})),
strict: None,
},
}];
assert!(
validate::validate_tools(&Some(&tools)).is_err(),
"expected error for name: {name:?}"
);
}
}
// `thinking: {"type": "enabled"}` plus `reasoning_effort` are folded into
// `chat_template_args` as `thinking = true` and the effort string.
#[test]
fn test_openai_thinking_payload_normalizes_to_template_args() {
let json_str = json!({
"model": "deepseek-ai/DeepSeek-V4-Pro",
"messages": [
{"role": "user", "content": "Hello"}
],
"reasoning_effort": "max",
"thinking": {"type": "enabled"}
});
let mut request: NvCreateChatCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");
request
.normalize_reasoning_template_args()
.expect("thinking payload should normalize");
let args = request
.chat_template_args
.as_ref()
.expect("chat_template_args should be populated");
assert_eq!(args.get("thinking"), Some(&json!(true)));
assert_eq!(args.get("reasoning_effort"), Some(&json!("max")));
}
// Malformed `thinking` payloads deserialize (the field is untyped JSON) but
// are rejected during normalization.
#[test]
fn test_invalid_openai_thinking_payload_is_rejected() {
for invalid_thinking in [
json!("enabled"),
json!({"type": "auto"}),
json!({"type": true}),
json!({}),
] {
let json_str = json!({
"model": "deepseek-ai/DeepSeek-V4-Pro",
"messages": [
{"role": "user", "content": "Hello"}
],
"thinking": invalid_thinking
});
let mut request: NvCreateChatCompletionRequest =
serde_json::from_value(json_str).expect("Failed to deserialize request");
assert!(request.normalize_reasoning_template_args().is_err());
}
}
}