use crate::preprocessor::{OpenAIPreprocessor, PreprocessedRequest};
use crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use crate::protocols::openai::tools::get_json_schema_from_tools;
use dynamo_parsers::tool_calling::{ToolChoice, ToolDefinition};
use dynamo_protocols::types::{ChatCompletionTool, ChatCompletionToolChoiceOption, ResponseFormat};
use dynamo_runtime::error::{DynamoError, ErrorType};
/// Builds a [`DynamoError`] classified as [`ErrorType::InvalidArgument`]
/// whose message is `message`.
fn invalid_argument(message: impl Into<String>) -> DynamoError {
    let text: String = message.into();
    DynamoError::builder()
        .error_type(ErrorType::InvalidArgument)
        .message(text)
        .build()
}
impl OpenAIPreprocessor {
pub(super) fn apply_tool_choice_guided_decoding(
&self,
request: &NvCreateChatCompletionRequest,
common_request: &mut PreprocessedRequest,
prompt_injected_reasoning: bool,
) -> Result<bool, DynamoError> {
let tool_choice = request
.inner
.tool_choice
.as_ref()
.unwrap_or(&ChatCompletionToolChoiceOption::Auto);
let tools = request.inner.tools.as_deref().unwrap_or(&[]);
let is_forced_tool_choice = matches!(
tool_choice,
ChatCompletionToolChoiceOption::Required | ChatCompletionToolChoiceOption::Named(_)
);
let has_explicit_guided_decoding = has_explicit_guided_decoding(request);
let has_response_format_constraint = has_response_format_constraint(request);
if is_forced_tool_choice && has_explicit_guided_decoding {
return Err(invalid_argument(concat!(
"guided decoding cannot be used in the same request as ",
"tool_choice=\"required\" or a named tool_choice.",
)));
}
let has_assistant_constraint =
has_explicit_guided_decoding || has_response_format_constraint;
if !is_forced_tool_choice && has_assistant_constraint {
return Ok(false);
}
if is_forced_tool_choice
&& has_response_format_constraint
&& let Some(gd) = common_request.sampling_options.guided_decoding.as_mut()
{
gd.json = None;
}
if self.apply_tool_choice_structural_tag(
&convert_tool_choice(tool_choice),
&convert_tools(tools),
request.inner.parallel_tool_calls,
prompt_injected_reasoning,
common_request,
)? {
return Ok(true);
}
match get_json_schema_from_tools(Some(tool_choice), Some(tools)) {
Ok(Some(schema)) => {
let gd = common_request
.sampling_options
.guided_decoding
.get_or_insert_default();
gd.json = Some(schema);
}
Ok(None) => {}
Err(err) => {
return Err(invalid_argument(err.to_string()));
}
}
Ok(false)
}
}
/// Reports whether the request explicitly asks for guided decoding.
///
/// This is true when any of `guided_json`, `guided_regex` or `guided_grammar`
/// is set, or when `guided_choice` is set and non-empty.
fn has_explicit_guided_decoding(request: &NvCreateChatCompletionRequest) -> bool {
    let common = &request.common;
    let has_choices = match &common.guided_choice {
        Some(choices) => !choices.is_empty(),
        None => false,
    };
    has_choices
        || common.guided_json.is_some()
        || common.guided_regex.is_some()
        || common.guided_grammar.is_some()
}
/// Reports whether `response_format` constrains the output.
///
/// An absent `response_format` and `ResponseFormat::Text` impose no
/// constraint; every other variant does.
fn has_response_format_constraint(request: &NvCreateChatCompletionRequest) -> bool {
    match &request.inner.response_format {
        None | Some(ResponseFormat::Text) => false,
        Some(_) => true,
    }
}
/// Maps the OpenAI protocol `tool_choice` onto the parser crate's [`ToolChoice`].
///
/// A named choice carries over only the selected function's name.
fn convert_tool_choice(tool_choice: &ChatCompletionToolChoiceOption) -> ToolChoice {
    match tool_choice {
        ChatCompletionToolChoiceOption::Named(named) => {
            let function_name = named.function.name.clone();
            ToolChoice::Named(function_name)
        }
        ChatCompletionToolChoiceOption::Required => ToolChoice::Required,
        ChatCompletionToolChoiceOption::Auto => ToolChoice::Auto,
        ChatCompletionToolChoiceOption::None => ToolChoice::None,
    }
}
/// Converts OpenAI protocol tools into parser [`ToolDefinition`]s, preserving order.
///
/// Each definition copies the function's `name`, `parameters` and `strict` fields.
fn convert_tools(tools: &[ChatCompletionTool]) -> Vec<ToolDefinition> {
    let mut definitions = Vec::with_capacity(tools.len());
    for tool in tools {
        let function = &tool.function;
        definitions.push(ToolDefinition {
            name: function.name.clone(),
            parameters: function.parameters.clone(),
            strict: function.strict,
        });
    }
    definitions
}