use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;
use crate::engines::ValidateRequest;
use crate::preprocessor::media::MediaDecoder;
use super::{
OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt,
nvext::NvExtProvider,
tools, validate,
};
pub mod aggregator;
mod delta;
pub mod jail;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
/// NVIDIA-extended OpenAI chat-completions request.
///
/// Wraps the standard `CreateChatCompletionRequest` (serde-flattened, so its
/// fields appear at the top level of the JSON body) and layers on common
/// extension fields, the `nvext` extension block, chat-template arguments,
/// media decoding options, and a catch-all for unrecognized fields.
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest {
    /// The standard OpenAI chat-completions request body, flattened into the
    /// top level of this struct on (de)serialization.
    #[serde(flatten)]
    pub inner: dynamo_async_openai::types::CreateChatCompletionRequest,

    /// Common extension fields (guided decoding, top_k, min_p, etc.),
    /// also flattened into the top level.
    #[serde(flatten, default)]
    pub common: CommonExt,

    /// Optional NVIDIA-specific extension block; omitted from output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,

    /// Extra keyword arguments forwarded to the chat template.
    /// Accepts `chat_template_kwargs` as an input alias for compatibility.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        alias = "chat_template_kwargs"
    )]
    pub chat_template_args: Option<std::collections::HashMap<String, serde_json::Value>>,

    /// Optional media decoding configuration for multimodal inputs.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub media_io_kwargs: Option<MediaDecoder>,

    /// Any request fields not matched above are captured here (via serde
    /// flatten) and rejected later by `ValidateRequest::validate`.
    /// Never serialized back out.
    #[serde(flatten, default, skip_serializing)]
    pub unsupported_fields: std::collections::HashMap<String, serde_json::Value>,
}
/// Non-streaming chat-completions response; currently a direct alias of the
/// upstream OpenAI type (no NVIDIA-specific additions).
pub type NvCreateChatCompletionResponse = dynamo_async_openai::types::CreateChatCompletionResponse;

/// Streaming (SSE chunk) chat-completions response; direct alias of the
/// upstream OpenAI type.
pub type NvCreateChatCompletionStreamResponse =
dynamo_async_openai::types::CreateChatCompletionStreamResponse;
impl NvExtProvider for NvCreateChatCompletionRequest {
    /// Borrow the NVIDIA extension block attached to this request, if any.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Chat requests never expose a raw prompt string.
    fn raw_prompt(&self) -> Option<String> {
        None
    }
}
impl AnnotationsProvider for NvCreateChatCompletionRequest {
    /// Returns a copy of the annotation list from `nvext`, if one was supplied.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Reports whether `annotation` appears in the `nvext` annotation list.
    ///
    /// Compares against `&str` directly instead of allocating a `String`
    /// per call (the previous `annotation.to_string()` + `contains` did a
    /// heap allocation just to test membership).
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_deref())
            .map(|annotations| annotations.iter().any(|a| a == annotation))
            .unwrap_or(false)
    }
}
impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
    /// Sampling temperature from the inner OpenAI request.
    fn get_temperature(&self) -> Option<f32> {
        self.inner.temperature
    }

    /// Nucleus-sampling probability mass from the inner OpenAI request.
    fn get_top_p(&self) -> Option<f32> {
        self.inner.top_p
    }

    /// Frequency penalty from the inner OpenAI request.
    fn get_frequency_penalty(&self) -> Option<f32> {
        self.inner.frequency_penalty
    }

    /// Presence penalty from the inner OpenAI request.
    fn get_presence_penalty(&self) -> Option<f32> {
        self.inner.presence_penalty
    }

    /// Borrow the NVIDIA extension block, if any.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// RNG seed from the inner OpenAI request.
    fn get_seed(&self) -> Option<i64> {
        self.inner.seed
    }

    /// Number of completions requested (`n`) from the inner OpenAI request.
    fn get_n(&self) -> Option<u8> {
        self.inner.n
    }

    /// `best_of` is not part of the chat-completions request surface.
    fn get_best_of(&self) -> Option<u8> {
        None
    }
}
impl CommonExtProvider for NvCreateChatCompletionRequest {
    /// Borrow the common extension fields carried by this request.
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Resolve the guided-decoding JSON schema for this request.
    ///
    /// Precedence (first hit wins):
    /// 1. An explicit `guided_json` set in the common extension fields.
    /// 2. A schema derived from `tool_choice` + `tools` (derivation failure is
    ///    logged as a warning and falls through rather than erroring).
    /// 3. The request's `response_format`: `json_object` maps to a bare
    ///    `{"type": "object"}` schema; `json_schema` uses its embedded schema
    ///    when present; `text` imposes no constraint.
    fn get_guided_json(&self) -> Option<serde_json::Value> {
        if let Some(value) = self.common.guided_json.clone() {
            return Some(value);
        }
        // Both tool_choice and tools must be present to derive a schema.
        if let (Some(tool_choice), Some(tools)) =
            (self.inner.tool_choice.as_ref(), self.inner.tools.as_deref())
        {
            match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) {
                Ok(Some(schema)) => return Some(schema),
                Ok(None) => {}
                Err(err) => {
                    // Best-effort: a malformed tool spec should not fail the
                    // whole request here; validation handles hard errors.
                    tracing::warn!(
                        error = %err,
                        "failed to derive guided_json from tool_choice"
                    );
                }
            }
        }
        if let Some(response_format) = self.inner.response_format.as_ref() {
            use dynamo_async_openai::types::ResponseFormat;
            match response_format {
                ResponseFormat::Text => {}
                ResponseFormat::JsonObject => {
                    return Some(serde_json::json!({
                        "type": "object"
                    }));
                }
                ResponseFormat::JsonSchema { json_schema } => {
                    // The schema field is optional in the wire format.
                    if let Some(schema) = json_schema.schema.clone() {
                        return Some(schema);
                    }
                }
            }
        }
        None
    }

    /// Guided-decoding regex constraint, if set.
    fn get_guided_regex(&self) -> Option<String> {
        self.common.guided_regex.clone()
    }

    /// Guided-decoding grammar constraint, if set.
    fn get_guided_grammar(&self) -> Option<String> {
        self.common.guided_grammar.clone()
    }

    /// Guided-decoding choice list, if set.
    fn get_guided_choice(&self) -> Option<Vec<String>> {
        self.common.guided_choice.clone()
    }

    /// Backend selection for guided decoding, if set.
    fn get_guided_decoding_backend(&self) -> Option<String> {
        self.common.guided_decoding_backend.clone()
    }

    /// Whitespace pattern override for guided decoding, if set.
    fn get_guided_whitespace_pattern(&self) -> Option<String> {
        self.common.guided_whitespace_pattern.clone()
    }

    /// Top-k sampling parameter from the common extension fields.
    fn get_top_k(&self) -> Option<i32> {
        self.common.top_k
    }

    /// Min-p sampling parameter from the common extension fields.
    fn get_min_p(&self) -> Option<f32> {
        self.common.min_p
    }

    /// Repetition penalty from the common extension fields.
    fn get_repetition_penalty(&self) -> Option<f32> {
        self.common.repetition_penalty
    }

    /// Whether matched stop strings should be included in the output.
    fn get_include_stop_str_in_output(&self) -> Option<bool> {
        self.common.include_stop_str_in_output
    }

    /// Whether special tokens should be skipped during detokenization.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        self.common.skip_special_tokens
    }
}
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
    /// Effective max-token budget: prefers the current `max_completion_tokens`
    /// field, falling back to the deprecated `max_tokens`.
    #[allow(deprecated)]
    fn get_max_tokens(&self) -> Option<u32> {
        self.inner.max_completion_tokens.or(self.inner.max_tokens)
    }

    /// Minimum token count from the common extension fields.
    fn get_min_tokens(&self) -> Option<u32> {
        self.common.min_tokens
    }

    /// Normalize the OpenAI `stop` field to a list of stop strings.
    fn get_stop(&self) -> Option<Vec<String>> {
        use dynamo_async_openai::types::Stop;
        let stops = match self.inner.stop.as_ref()? {
            Stop::String(s) => vec![s.clone()],
            Stop::StringArray(arr) => arr.clone(),
        };
        Some(stops)
    }

    /// Borrow the NVIDIA extension block, if any.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// `ignore_eos` as carried by the common extension fields.
    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }

    /// Effective `ignore_eos`; sourced from the common extension fields.
    fn get_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
}
impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
    /// Number of logprobs to return per token.
    ///
    /// Only meaningful when the request explicitly enables logprobs; then
    /// `top_logprobs` sets the count, defaulting to 1.
    fn get_logprobs(&self) -> Option<u32> {
        if self.inner.logprobs == Some(true) {
            Some(self.inner.top_logprobs.map_or(1, |top| top as u32))
        } else {
            None
        }
    }

    /// Prompt logprobs are not supported on chat-completions requests.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        None
    }

    /// Delegates to the `CommonExtProvider` implementation so both trait
    /// surfaces report the same value.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        CommonExtProvider::get_skip_special_tokens(self)
    }

    /// Formatted-prompt echo is not supported on chat-completions requests.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}
impl ValidateRequest for NvCreateChatCompletionRequest {
    /// Validate the full request surface before processing.
    ///
    /// Rejects any fields captured in `unsupported_fields` first, then checks
    /// each recognized OpenAI field and the common-extension sampling fields.
    /// Short-circuits on the first failure via `?`.
    fn validate(&self) -> Result<(), anyhow::Error> {
        validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
        validate::validate_messages(&self.inner.messages)?;
        validate::validate_model(&self.inner.model)?;
        validate::validate_reasoning_effort(&self.inner.reasoning_effort)?;
        validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
        validate::validate_logit_bias(&self.inner.logit_bias)?;
        validate::validate_top_logprobs(self.inner.top_logprobs)?;
        validate::validate_max_completion_tokens(self.inner.max_completion_tokens)?;
        validate::validate_n(self.inner.n)?;
        validate::validate_presence_penalty(self.inner.presence_penalty)?;
        validate::validate_response_format(&self.inner.response_format)?;
        validate::validate_service_tier(&self.inner.service_tier)?;
        validate::validate_stop(&self.inner.stop)?;
        validate::validate_temperature(self.inner.temperature)?;
        validate::validate_top_p(self.inner.top_p)?;
        validate::validate_tools(&self.inner.tools.as_deref())?;
        validate::validate_user(self.inner.user.as_deref())?;
        // Common-extension fields, resolved through the provider accessors.
        validate::validate_repetition_penalty(self.get_repetition_penalty())?;
        validate::validate_min_p(self.get_min_p())?;
        validate::validate_top_k(self.get_top_k())?;
        // Cross-field check: n > 1 requires a nonzero temperature.
        validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::common::OutputOptionsProvider;
    use serde_json::json;

    /// When the request omits `skip_special_tokens`, it deserializes as `None`
    /// and stays `None` through output-option extraction.
    #[test]
    fn test_skip_special_tokens_none() {
        let json_str = json!({
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Hello"}
            ]
        });
        let request: NvCreateChatCompletionRequest =
            serde_json::from_value(json_str).expect("Failed to deserialize request");
        assert_eq!(request.common.skip_special_tokens, None);
        let output_options = request
            .extract_output_options()
            .expect("Failed to extract output options");
        assert_eq!(output_options.skip_special_tokens, None);
    }

    /// An explicit `skip_special_tokens` value (either true or false) survives
    /// deserialization (via the flattened `common` block) into output options.
    #[test]
    fn test_skip_special_tokens_propagates() {
        for skip_value in [true, false] {
            let json_str = json!({
                "model": "test-model",
                "messages": [
                    {"role": "user", "content": "Hello"}
                ],
                "skip_special_tokens": skip_value
            });
            let request: NvCreateChatCompletionRequest =
                serde_json::from_value(json_str).expect("Failed to deserialize request");
            let output_options = request
                .extract_output_options()
                .expect("Failed to extract output options");
            assert_eq!(output_options.skip_special_tokens, Some(skip_value));
        }
    }
}