use anyhow::{Ok, Result};
use dynamo_runtime::config::environment_names::model::huggingface as env_hf;
use dynamo_llm::model_card::{ModelDeploymentCard, PromptContextMixin};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::preprocessor::prompt::PromptFormatter;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use serde::{Deserialize, Serialize};
use hf_hub::{Cache, Repo, RepoType, api::tokio::ApiBuilder};
use rstest::rstest;
use std::path::PathBuf;
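/// Read the Hugging Face access token from the environment, rejecting unset or
/// blank values.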
fn get_hf_token() -> Result<String> {
let token = std::env::var(env_hf::HF_TOKEN)
.map_err(|_| anyhow::anyhow!("HF_TOKEN environment variable is not set"))?;
if token.trim().is_empty() {
anyhow::bail!("HF_TOKEN environment variable is empty");
}
Ok(token)
}
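// A minimal sketch of how the repeated HF_TOKEN guard in the tests below could
// be centralized. The macro name is hypothetical; the tests keep their inline
// guards unchanged.
#[allow(unused_macros)]
macro_rules! skip_if_no_hf_token {
    () => {
        if let Err(e) = get_hf_token() {
            println!("HF_TOKEN is not set, skipping test: {}", e);
            return;
        }
    };
}
/// Build a `ModelDeploymentCard` for `hf_repo` at `hf_revision`, downloading
/// the required files under `local_path` and applying any prompt-context mixins.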
async fn make_mdc_from_repo(
local_path: &str,
hf_repo: &str,
hf_revision: &str,
mixins: Option<Vec<PromptContextMixin>>,
) -> ModelDeploymentCard {
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::load_from_disk(downloaded_path, None).unwrap();
mdc.set_name(&display_name);
mdc.prompt_context = mixins;
mdc
}
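/// Download the files needed to build a model card, reusing the hf-hub cache
/// under `local_path` when the files are already present. A missing required
/// file aborts the test; optional files only log a message. Returns the local
/// snapshot directory.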
async fn maybe_download_model(local_path: &str, model: &str, revision: &str) -> String {
let cache = Cache::new(PathBuf::from(local_path));
let token = get_hf_token().expect("HF_TOKEN is required to download models from Hugging Face");
let api = ApiBuilder::from_cache(cache)
.with_progress(false)
.with_token(Some(token))
.build()
.unwrap();
let repo = Repo::with_revision(String::from(model), RepoType::Model, String::from(revision));
let files_to_download = vec!["config.json", "tokenizer.json", "tokenizer_config.json"];
let optional_files = vec![
"generation_config.json",
"chat_template.jinja",
"chat_template.json",
];
let repo_builder = api.repo(repo);
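// All required files land in the same snapshot directory, so keeping only the
// last downloaded path is enough to recover the model root below.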
let mut downloaded_path = PathBuf::new();
for file in &files_to_download {
downloaded_path = repo_builder.get(file).await.unwrap();
}
for file in &optional_files {
if let Err(e) = repo_builder.get(file).await {
println!(
"Failed to download optional file {} for model {}: {}",
file, model, e
);
}
}
downloaded_path.parent().unwrap().display().to_string()
}
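/// Model deployment cards exercised by the snapshot tests; each entry pins a
/// specific revision so rendered prompts stay stable.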
async fn make_mdcs() -> Vec<ModelDeploymentCard> {
vec![
make_mdc_from_repo(
"tests/data/sample-models",
"meta-llama/Llama-3.1-70B-Instruct",
"1605565",
Some(vec![PromptContextMixin::Llama3DateTime]),
)
.await,
]
}
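// Chat fixtures: raw OpenAI-style message arrays that `Request::from`
// deserializes into typed messages.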
const SINGLE_CHAT_MESSAGE: &str = r#"
[
{
"role": "user",
"content": "What is deep learning?"
}
]
"#;
const THREE_TURN_CHAT_MESSAGE: &str = r#"
[
{
"role": "user",
"content": "How do I reverse a string in Python?"
},
{
"role": "assistant",
"content": "You can reverse a string in Python using slicing:\n\n```python\nreversed_string = your_string[::-1]\n```\n\nAlternatively, you can use `reversed()` with `join()`:\n\n```python\nreversed_string = ''.join(reversed(your_string))\n```\n"
},
{
"role": "user",
"content": "What if I want to reverse each word in a sentence but keep their order?"
}
]"#;
const THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM: &str = r#"
[
{
"role": "system",
"content": "You are a very helpful assistant!"
},
{
"role": "user",
"content": "How do I reverse a string in Python?"
},
{
"role": "assistant",
"content": "You can reverse a string in Python using slicing:\n\n```python\nreversed_string = your_string[::-1]\n```\n\nAlternatively, you can use `reversed()` with `join()`:\n\n```python\nreversed_string = ''.join(reversed(your_string))\n```\n"
},
{
"role": "user",
"content": "What if I want to reverse each word in a sentence but keep their order?"
}
]"#;
const MULTI_TURN_WITH_CONTINUATION: &str = r#"
[
{
"role": "system",
"content": "You are a very helpful assistant!"
},
{
"role": "user",
"content": "How do I reverse a string in Python?"
},
{
"role": "assistant",
"content": "You can reverse a "
}
]"#;
const TOOLS: &str = r#"
[
{
"type": "function",
"function": {
"name": "get_current_temperature",
"description": "Get the current temperature for a specific location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g., San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["Celsius", "Fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location."
}
},
"required": ["location", "unit"]
}
}
},
{
"type": "function",
"function": {
"name": "get_rain_probability",
"description": "Get the probability of rain for a specific location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g., San Francisco, CA"
}
},
"required": ["location"]
}
}
}
]
"#;
#[derive(Serialize, Deserialize)]
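/// JSON shape of the request fixtures above; never instantiated directly and
/// used only as a namespace for the constructor below.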
struct Request {
messages: Vec<dynamo_protocols::types::ChatCompletionRequestMessage>,
tools: Option<Vec<dynamo_protocols::types::ChatCompletionTool>>,
tool_choice: Option<dynamo_protocols::types::ChatCompletionToolChoiceOption>,
}
impl Request {
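/// Build an `NvCreateChatCompletionRequest` from raw JSON fixtures. Despite
/// the name, this is a plain constructor, not an implementation of the `From`
/// trait.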
fn from(
messages: &str,
tools: Option<&str>,
tool_choice: Option<dynamo_protocols::types::ChatCompletionToolChoiceOption>,
model: String,
) -> NvCreateChatCompletionRequest {
let messages: Vec<dynamo_protocols::types::ChatCompletionRequestMessage> =
serde_json::from_str(messages).unwrap();
let tools: Option<Vec<dynamo_protocols::types::ChatCompletionTool>> =
tools.map(|x| serde_json::from_str(x).unwrap());
let mut inner = dynamo_protocols::types::CreateChatCompletionRequestArgs::default();
inner.model(model);
inner.messages(messages);
if let Some(tools) = tools {
inner.tools(tools);
}
if let Some(tool_choice) = tool_choice {
inner.tool_choice(tool_choice);
}
let inner = inner.build().unwrap();
NvCreateChatCompletionRequest {
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
}
}
}
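// Each snapshot test renders a fixture through the model's chat template and
// compares the result against a stored `insta` snapshot; the "Today Date"
// filter keeps the `Llama3DateTime` mixin's injected date from churning the
// snapshots.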
#[tokio::test]
async fn test_single_turn() {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = match PromptFormatter::from_mdc(mdc).unwrap() {
PromptFormatter::OAI(formatter) => formatter,
};
let request = Request::from(SINGLE_CHAT_MESSAGE, None, None, mdc.slug().to_string());
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_single_turn_with_tools() {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = match PromptFormatter::from_mdc(mdc).unwrap() {
PromptFormatter::OAI(formatter) => formatter,
};
let request = Request::from(
SINGLE_CHAT_MESSAGE,
Some(TOOLS),
Some(dynamo_protocols::types::ChatCompletionToolChoiceOption::Auto),
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_multi_turn_without_system() {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = match PromptFormatter::from_mdc(mdc).unwrap() {
PromptFormatter::OAI(formatter) => formatter,
};
let request = Request::from(THREE_TURN_CHAT_MESSAGE, None, None, mdc.slug().to_string());
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_multi_turn_with_system() {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = match PromptFormatter::from_mdc(mdc).unwrap() {
PromptFormatter::OAI(formatter) => formatter,
};
let request = Request::from(
THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM,
None,
None,
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_multi_turn_with_system_with_tools() {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = match PromptFormatter::from_mdc(mdc).unwrap() {
PromptFormatter::OAI(formatter) => formatter,
};
let request = Request::from(
THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM,
Some(TOOLS),
Some(dynamo_protocols::types::ChatCompletionToolChoiceOption::Auto),
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_multi_turn_with_continuation() {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdc = make_mdc_from_repo(
"tests/data/sample-models",
"meta-llama/Llama-3.1-70B-Instruct",
"1605565",
Some(vec![PromptContextMixin::Llama3DateTime]),
)
.await;
let formatter = match PromptFormatter::from_mdc(&mdc).unwrap() {
PromptFormatter::OAI(formatter) => formatter,
};
let request = Request::from(
MULTI_TURN_WITH_CONTINUATION,
None,
None,
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
pub mod openai_preprocessor_tests {
use super::*;
use std::collections::HashSet;
#[tokio::test]
async fn test_stop_condition() {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdc = make_mdc_from_repo(
"tests/data/sample-models",
"openai/gpt-oss-120b",
"b5c939de8f754692c1647ca79fbf85e8c1e70f8a",
Some(vec![PromptContextMixin::OaiChat]),
)
.await;
let oai_preprocessor = OpenAIPreprocessor::new(mdc.clone()).unwrap();
let request = Request::from(SINGLE_CHAT_MESSAGE, None, None, mdc.slug().to_string());
let preprocessed_request = oai_preprocessor
.preprocess_request(&request, None)
.await
.unwrap()
.0;
assert!(
preprocessed_request
.stop_conditions
.stop_token_ids_hidden
.is_some()
);
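// The expected set mirrors the stop tokens the preprocessor derives from the
// model's configuration.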
let eos_token_id_set: HashSet<_> = preprocessed_request
.stop_conditions
.stop_token_ids_hidden
.unwrap()
.iter()
.cloned()
.collect();
assert_eq!(
eos_token_id_set,
vec![200002, 199999, 200012].into_iter().collect()
);
}
}
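/// Build a single-user-turn message whose content pairs one text part with the
/// requested number of image/video/audio URL parts.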
fn build_message(text: &str, chunks: &[(&str, usize)]) -> String {
let mut content_parts = vec![format!(r#"{{"type": "text", "text": "{}"}}"#, text)];
for (chunk_type, count) in chunks {
for i in 1..=*count {
let chunk = match *chunk_type {
"image_url" => format!(
r#"{{"type": "image_url", "image_url": {{"url": "https://example.com/img{}.jpg"}}}}"#,
i
),
"video_url" => format!(
r#"{{"type": "video_url", "video_url": {{"url": "https://example.com/vid{}.mp4"}}}}"#,
i
),
"audio_url" => format!(
r#"{{"type": "audio_url", "audio_url": {{"url": "https://example.com/audio{}.mp3"}}}}"#,
i
),
_ => panic!("Unknown chunk type: {}", chunk_type),
};
content_parts.push(chunk);
}
}
format!(
r#"[{{"role": "user", "content": [{}]}}]"#,
content_parts.join(", ")
)
}
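// Each case lists (chunk type, count) pairs that must pass through
// preprocessing intact.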
#[rstest]
#[case::no_media(&[])]
#[case::single_video(&[("video_url", 1)])]
#[case::three_images(&[("image_url", 3)])]
#[case::mixed_multiple(&[("image_url", 2), ("video_url", 1), ("audio_url", 2)])]
#[tokio::test]
async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let preprocessor = OpenAIPreprocessor::new(mdc.clone()).unwrap();
let message = build_message("Test multimodal content", media_chunks);
let request = Request::from(&message, None, None, mdc.slug().to_string());
let (preprocessed, _annotations, _) = preprocessor
.preprocess_request(&request, None)
.await
.unwrap();
if media_chunks.is_empty() {
assert!(
preprocessed.multi_modal_data.is_none()
|| preprocessed.multi_modal_data.as_ref().unwrap().is_empty(),
"Multimodal data should be None or empty when no media is present"
);
} else {
assert!(
preprocessed.multi_modal_data.is_some(),
"Multimodal data should be present"
);
let media_map = preprocessed.multi_modal_data.as_ref().unwrap();
for (media_type, expected_count) in media_chunks {
assert!(
media_map.contains_key(*media_type),
"Should contain {} key",
media_type
);
assert_eq!(
media_map.get(*media_type).unwrap().len(),
*expected_count,
"Should have {} {} item(s)",
expected_count,
media_type
);
}
}
}
}
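// Context-length validation: prompts that leave no room to generate must be
// rejected with `InvalidArgument`, which the frontend surfaces as HTTP 400
// (hence the test names below).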
mod context_length_validation {
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use dynamo_runtime::error::{DynamoError, ErrorType};
const MODEL_PATH: &str = "tests/data/sample-models/mock-llama-3.1-8b-instruct";
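/// Minimal chat request builder for this module; like `Request::from` above
/// but without tool support.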
fn make_chat_request(message: &str, model: &str) -> NvCreateChatCompletionRequest {
let messages: Vec<dynamo_protocols::types::ChatCompletionRequestMessage> =
serde_json::from_str(message).unwrap();
let inner = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model(model)
.messages(messages)
.build()
.unwrap();
NvCreateChatCompletionRequest {
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
}
}
#[tokio::test]
async fn test_prompt_exceeding_context_length_returns_400() {
let mut mdc = ModelDeploymentCard::load_from_disk(MODEL_PATH, None).unwrap();
mdc.context_length = 5;
let preprocessor = OpenAIPreprocessor::new(mdc).unwrap();
let request = make_chat_request(
r#"[{"role": "user", "content": "What is deep learning?"}]"#,
"test-model",
);
let result = preprocessor.preprocess_request(&request, None).await;
let err = result.expect_err("should reject prompt exceeding context_length");
let dynamo_err = err
.downcast_ref::<DynamoError>()
.expect("error should be DynamoError");
assert_eq!(dynamo_err.error_type(), ErrorType::InvalidArgument);
assert!(
dynamo_err
.message()
.contains("maximum context length is 5 tokens"),
"error message should state the context limit, got: {}",
dynamo_err.message()
);
assert!(
dynamo_err.message().contains("Please reduce the length"),
"error message should tell user what to do, got: {}",
dynamo_err.message()
);
}
#[tokio::test]
async fn test_prompt_exactly_at_context_length_returns_400() {
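// Two-phase check: first render under a generous limit to measure the prompt's
// token count, then re-run with context_length set to exactly that count. A
// prompt that fills the whole context leaves no room to generate.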
let mut mdc = ModelDeploymentCard::load_from_disk(MODEL_PATH, None).unwrap();
mdc.context_length = 131072;
let preprocessor = OpenAIPreprocessor::new(mdc.clone()).unwrap();
let request = make_chat_request(
r#"[{"role": "user", "content": "What is deep learning?"}]"#,
"test-model",
);
let (preprocessed, _, _) = preprocessor
.preprocess_request(&request, None)
.await
.unwrap();
let token_count = preprocessed.token_ids.len() as u32;
mdc.context_length = token_count;
let preprocessor = OpenAIPreprocessor::new(mdc).unwrap();
let request = make_chat_request(
r#"[{"role": "user", "content": "What is deep learning?"}]"#,
"test-model",
);
let result = preprocessor.preprocess_request(&request, None).await;
let err = result.expect_err("should reject prompt that fills entire context_length");
let dynamo_err = err
.downcast_ref::<DynamoError>()
.expect("error should be DynamoError");
assert_eq!(dynamo_err.error_type(), ErrorType::InvalidArgument);
}
#[tokio::test]
async fn test_context_length_zero_skips_validation() {
let mut mdc = ModelDeploymentCard::load_from_disk(MODEL_PATH, None).unwrap();
mdc.context_length = 0;
let preprocessor = OpenAIPreprocessor::new(mdc).unwrap();
let request = make_chat_request(
r#"[{"role": "user", "content": "What is deep learning?"}]"#,
"test-model",
);
let result = preprocessor.preprocess_request(&request, None).await;
assert!(result.is_ok(), "context_length=0 should skip validation");
}
}