use dynamo_llm::protocols::common;
use dynamo_llm::protocols::common::llm_backend::BackendOutput;
use dynamo_protocols::types::{
ChatCompletionMessageContent, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage,
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest,
FunctionName, FunctionType,
};
/// Extracts the plain-text payload from a message content value.
///
/// Multi-part (`Parts`) content has no single text body, so it maps to the
/// empty string.
fn get_text(content: &ChatCompletionMessageContent) -> &str {
    if let ChatCompletionMessageContent::Text(text) = content {
        text
    } else {
        ""
    }
}
use dynamo_llm::protocols::openai::DeltaGeneratorExt;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
/// Builds a minimal, non-streaming chat completion request containing a
/// single user message — the baseline fixture for the tool-choice tests.
fn create_test_request() -> NvCreateChatCompletionRequest {
    let user_message = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
        name: None,
    });
    let inner = CreateChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![user_message],
        stream: Some(false),
        stream_options: None,
        ..Default::default()
    };
    NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
        chat_template_args: None,
        media_io_kwargs: None,
        unsupported_fields: Default::default(),
    }
}
async fn apply_jail_transformation(
raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
use dynamo_runtime::protocols::annotated::Annotated;
use futures::StreamExt;
use futures::stream;
let input_stream = stream::iter(vec![Annotated {
data: Some(raw_response),
id: None,
event: None,
comment: None,
error: None,
}]);
let mut builder = JailedStream::builder();
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(ref named)) => {
builder = builder.tool_choice_named(named.function.name.clone());
}
Some(ChatCompletionToolChoiceOption::Required) => {
builder = builder.tool_choice_required();
}
_ => {}
}
let jail = builder.build();
let output_stream = jail.apply_with_finish_reason(input_stream);
tokio::pin!(output_stream);
output_stream.next().await.unwrap().data.unwrap()
}
/// Feeds a sequence of raw stream responses through a `JailedStream`
/// configured from `tool_choice`, collecting every transformed response
/// that carries data.
async fn apply_jail_transformation_streaming(
    raw_responses: Vec<
        dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    >,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> Vec<dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse> {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::StreamExt;
    use futures::stream;
    // Wrap each raw response in an otherwise-empty annotation envelope.
    let annotated = raw_responses.into_iter().map(|response| Annotated {
        data: Some(response),
        id: None,
        event: None,
        comment: None,
        error: None,
    });
    // Configure the jail according to the requested tool-choice mode.
    let builder = JailedStream::builder();
    let builder = match &tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            builder.tool_choice_named(named.function.name.clone())
        }
        Some(ChatCompletionToolChoiceOption::Required) => builder.tool_choice_required(),
        _ => builder,
    };
    let transformed = builder
        .build()
        .apply_with_finish_reason(stream::iter(annotated));
    tokio::pin!(transformed);
    // Drop envelopes without data; keep only the transformed responses.
    transformed
        .filter_map(|annotated| async move { annotated.data })
        .collect()
        .await
}
/// Creates a `BackendOutput` carrying `text` as its whole payload with a
/// `Stop` finish reason — i.e. a single-shot (non-streaming) model output.
fn build_backend_output(text: &str) -> BackendOutput {
    BackendOutput {
        text: Some(text.to_owned()),
        finish_reason: Some(common::FinishReason::Stop),
        index: Some(0),
        token_ids: Vec::new(),
        tokens: Vec::new(),
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        stop_reason: None,
        completion_usage: None,
        disaggregated_params: None,
    }
}
#[tokio::test]
async fn test_named_tool_choice_parses_json() {
    // A named tool choice should turn the raw JSON output into exactly one
    // tool call targeting that function.
    let choice_opt = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));
    let mut request = create_test_request();
    request.inner.tool_choice = choice_opt.clone();
    let mut generator = request.response_generator("req-1".to_string());
    let raw = generator
        .choice_from_postprocessor(build_backend_output(r#"{"location":"Paris"}"#))
        .expect("choice generation");
    let response = apply_jail_transformation(raw, choice_opt).await;

    let choice = &response.inner.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_protocols::types::FinishReason::Stop)
    );
    // Content must be absent (or empty) once the text became a tool call.
    let delta = &choice.delta;
    assert!(delta.content.is_none() || delta.content.as_ref().map(get_text) == Some(""));
    let calls = delta.tool_calls.as_ref().expect("tool calls present");
    assert_eq!(calls.len(), 1);
    let call = &calls[0];
    assert_eq!(call.index, 0);
    assert!(call.id.as_ref().unwrap().starts_with("call-"));
    assert_eq!(call.r#type, Some(FunctionType::Function));
    let function = call.function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some(r#"{"location":"Paris"}"#)
    );
}
#[tokio::test]
async fn test_required_tool_choice_parses_json_array() {
    // With `tool_choice = required`, a JSON array of calls must be parsed
    // into parallel tool calls and the finish reason become `ToolCalls`.
    let choice_opt = Some(ChatCompletionToolChoiceOption::Required);
    let mut request = create_test_request();
    request.inner.tool_choice = choice_opt.clone();
    let mut generator = request.response_generator("req-2".to_string());
    let raw = generator
        .choice_from_postprocessor(build_backend_output(
            r#"[{"name":"search","parameters":{"query":"rust"}},
        {"name":"summarize","parameters":{"topic":"memory"}}]"#,
        ))
        .expect("choice generation");
    let response = apply_jail_transformation(raw, choice_opt).await;

    let choice = &response.inner.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_protocols::types::FinishReason::ToolCalls)
    );
    let delta = &choice.delta;
    assert!(delta.content.is_none() || delta.content.as_ref().map(get_text) == Some(""));
    let calls = delta.tool_calls.as_ref().expect("tool calls present");
    // Expected (name, arguments) pairs, in emission order.
    let expected = [
        ("search", r#"{"query":"rust"}"#),
        ("summarize", r#"{"topic":"memory"}"#),
    ];
    assert_eq!(calls.len(), expected.len());
    for (i, (name, args)) in expected.iter().enumerate() {
        let call = &calls[i];
        assert_eq!(call.index as usize, i);
        assert!(call.id.as_ref().unwrap().starts_with("call-"));
        assert_eq!(call.r#type, Some(FunctionType::Function));
        let function = call.function.as_ref().unwrap();
        assert_eq!(function.name.as_deref(), Some(*name));
        assert_eq!(function.arguments.as_deref(), Some(*args));
    }
}
#[tokio::test]
async fn test_tool_choice_parse_failure_returns_as_content() {
    // When the model output is not valid JSON, the jail must fall back to
    // plain text content rather than fabricating tool calls.
    let choice_opt = Some(ChatCompletionToolChoiceOption::Required);
    let mut request = create_test_request();
    request.inner.tool_choice = choice_opt.clone();
    let mut generator = request.response_generator("req-3".to_string());
    let raw = generator
        .choice_from_postprocessor(build_backend_output("not-json"))
        .expect("choice generation");
    let response = apply_jail_transformation(raw, choice_opt).await;
    let delta = &response.inner.choices[0].delta;
    assert_eq!(delta.content.as_ref().map(get_text), Some("not-json"));
    assert!(delta.tool_calls.is_none());
}
#[tokio::test]
async fn test_streaming_named_tool_buffers_until_finish() {
    // JSON split over several chunks must be buffered until the finishing
    // chunk arrives, then emitted as one complete tool call.
    let choice_opt = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));
    let mut request = create_test_request();
    request.inner.tool_choice = choice_opt.clone();
    let mut generator = request.response_generator("req-stream-1".to_string());
    let chunks = [r#"{"location":""#, r#"Paris","unit":""#, r#"celsius"}"#];
    let last = chunks.len() - 1;
    let raw_responses: Vec<_> = chunks
        .iter()
        .enumerate()
        .map(|(i, chunk)| {
            let output = BackendOutput {
                token_ids: vec![],
                tokens: vec![],
                text: Some(chunk.to_string()),
                cum_log_probs: None,
                log_probs: None,
                top_logprobs: None,
                // Only the final chunk carries a finish reason.
                finish_reason: (i == last).then_some(common::FinishReason::Stop),
                stop_reason: None,
                index: Some(0),
                completion_usage: None,
                disaggregated_params: None,
            };
            generator
                .choice_from_postprocessor(output)
                .expect("streaming chunk")
        })
        .collect();
    let all_responses = apply_jail_transformation_streaming(raw_responses, choice_opt).await;
    assert_eq!(all_responses.len(), 1);
    let choice = &all_responses[0].inner.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_protocols::types::FinishReason::Stop)
    );
    let calls = choice.delta.tool_calls.as_ref().expect("tool calls present");
    assert_eq!(calls.len(), 1);
    let function = calls[0].function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some(r#"{"location":"Paris","unit":"celsius"}"#)
    );
}
#[tokio::test]
async fn test_streaming_required_tool_parallel() {
    // A JSON array split across chunks must buffer, then surface parallel
    // tool calls with a `ToolCalls` finish reason.
    let choice_opt = Some(ChatCompletionToolChoiceOption::Required);
    let mut request = create_test_request();
    request.inner.tool_choice = choice_opt.clone();
    let mut generator = request.response_generator("req-stream-2".to_string());
    let chunks = [
        r#"[{"name":"search","parameters":{"query":"rust"}},"#,
        r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#,
    ];
    let last = chunks.len() - 1;
    let raw_responses: Vec<_> = chunks
        .iter()
        .enumerate()
        .map(|(i, chunk)| {
            let output = BackendOutput {
                token_ids: vec![],
                tokens: vec![],
                text: Some(chunk.to_string()),
                cum_log_probs: None,
                log_probs: None,
                top_logprobs: None,
                // Only the final chunk carries a finish reason.
                finish_reason: (i == last).then_some(common::FinishReason::Stop),
                stop_reason: None,
                index: Some(0),
                completion_usage: None,
                disaggregated_params: None,
            };
            generator
                .choice_from_postprocessor(output)
                .expect("streaming chunk")
        })
        .collect();
    let all_responses = apply_jail_transformation_streaming(raw_responses, choice_opt).await;
    assert_eq!(all_responses.len(), 1);
    let choice = &all_responses[0].inner.choices[0];
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_protocols::types::FinishReason::ToolCalls)
    );
    let calls = choice.delta.tool_calls.as_ref().expect("tool calls present");
    // Expected (name, arguments) pairs, in emission order.
    let expected = [
        ("search", r#"{"query":"rust"}"#),
        ("summarize", r#"{"topic":"memory"}"#),
    ];
    assert_eq!(calls.len(), expected.len());
    for (call, (name, args)) in calls.iter().zip(expected.iter()) {
        let function = call.function.as_ref().unwrap();
        assert_eq!(function.name.as_deref(), Some(*name));
        assert_eq!(function.arguments.as_deref(), Some(*args));
    }
}
#[test]
fn test_no_tool_choice_outputs_normal_text() {
    // Without any tool choice configured, plain text must pass straight
    // through as delta content and no tool calls may appear.
    let request = create_test_request();
    let mut generator = request.response_generator("req-stream-4".to_string());
    let output = BackendOutput {
        token_ids: vec![],
        tokens: vec![],
        text: Some("Hello world".to_string()),
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        finish_reason: None,
        stop_reason: None,
        index: Some(0),
        completion_usage: None,
        disaggregated_params: None,
    };
    let response = generator
        .choice_from_postprocessor(output)
        .expect("normal text");
    let delta = &response.inner.choices[0].delta;
    assert_eq!(delta.content.as_ref().map(get_text), Some("Hello world"));
    assert!(delta.tool_calls.is_none());
}
/// Builds a minimal streaming chat response chunk carrying `text` as
/// assistant content; `finish` marks it as the terminal chunk (`Stop`).
fn make_text_chunk(
    text: &str,
    finish: bool,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
    use dynamo_protocols::types::{
        ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionStreamResponseDelta, Role,
    };
    let finish_reason = finish.then_some(dynamo_protocols::types::FinishReason::Stop);
    // `function_call` is deprecated upstream but still part of the delta
    // struct; silence the lint for this construction only.
    #[allow(deprecated)]
    let delta = ChatCompletionStreamResponseDelta {
        role: Some(Role::Assistant),
        content: Some(ChatCompletionMessageContent::Text(text.to_owned())),
        tool_calls: None,
        function_call: None,
        refusal: None,
        reasoning_content: None,
    };
    dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
        inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
            id: "test-named-parser".to_string(),
            choices: vec![ChatChoiceStream {
                index: 0,
                delta,
                finish_reason,
                stop_reason: None,
                logprobs: None,
            }],
            created: 1234567890,
            model: "test-model".to_string(),
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
            service_tier: None,
        },
        nvext: None,
    }
}
/// Runs `chunks` through a jail configured with a tool-call `parser`
/// (e.g. "hermes") and a named-tool filter, returning every data-bearing
/// transformed response.
async fn apply_jail_named_with_parser(
    chunks: Vec<
        dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    >,
    parser: &str,
    named_tool: &str,
) -> Vec<dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse> {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::StreamExt;
    use futures::stream;
    // Wrap each chunk in an otherwise-empty annotation envelope.
    let annotated = chunks.into_iter().map(|chunk| Annotated {
        data: Some(chunk),
        id: None,
        event: None,
        comment: None,
        error: None,
    });
    let jail = JailedStream::builder()
        .tool_call_parser(parser)
        .named_tool_filter(named_tool)
        .build();
    let transformed = jail.apply_with_finish_reason(stream::iter(annotated));
    tokio::pin!(transformed);
    transformed
        .filter_map(|annotated| async move { annotated.data })
        .collect()
        .await
}
#[tokio::test]
async fn test_named_tool_with_parser_correct_tool_passes() {
    // A hermes-formatted call to the named tool must survive the filter
    // and be emitted as a tool call.
    let hermes_payload = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}\n</tool_call>";
    let chunks = vec![
        make_text_chunk(hermes_payload, false),
        make_text_chunk("", true),
    ];
    let responses = apply_jail_named_with_parser(chunks, "hermes", "get_weather").await;
    // Locate the first response whose first choice carries tool calls.
    let tool_calls = responses
        .iter()
        .filter_map(|r| r.inner.choices.first())
        .find_map(|c| c.delta.tool_calls.as_ref())
        .expect("expected a response with tool calls for the correct named tool");
    assert_eq!(tool_calls.len(), 1, "expected exactly one tool call");
    assert_eq!(
        tool_calls[0].function.as_ref().unwrap().name.as_deref(),
        Some("get_weather"),
        "tool call name should be get_weather"
    );
}
#[tokio::test]
async fn test_named_tool_with_parser_wrong_tool_is_filtered() {
    // A call to a tool other than the named one must be dropped by the
    // named-tool filter: no emitted tool call may reference "search".
    let hermes_wrong_tool = "<tool_call>\n{\"name\": \"search\", \"arguments\": {\"query\": \"Paris weather\"}}\n</tool_call>";
    let chunks = vec![
        make_text_chunk(hermes_wrong_tool, false),
        make_text_chunk("", true),
    ];
    let responses = apply_jail_named_with_parser(chunks, "hermes", "get_weather").await;
    // Collect every tool-call function name emitted across all responses.
    let emitted_names = responses
        .iter()
        .filter_map(|r| r.inner.choices.first())
        .filter_map(|c| c.delta.tool_calls.as_ref())
        .flatten()
        .filter_map(|tc| tc.function.as_ref().and_then(|f| f.name.as_deref()));
    for name in emitted_names {
        assert_ne!(
            name, "search",
            "wrong tool 'search' should have been filtered by named_tool_filter"
        );
    }
}