use super::*;
use minijinja::{context, value::Value};
use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest,
};
use tracing;
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
    /// Serializes the request's chat history into a template value.
    fn messages(&self) -> Value {
        Value::from_serialize(&self.inner.messages)
    }

    /// Returns the request's tool definitions as a template value, or `None`
    /// when the request carries no `tools` field.
    fn tools(&self) -> Option<Value> {
        self.inner.tools.as_ref().map(Value::from_serialize)
    }

    /// Returns the request's `tool_choice` as a template value, or `None`
    /// when it is unset.
    fn tool_choice(&self) -> Option<Value> {
        self.inner.tool_choice.as_ref().map(Value::from_serialize)
    }

    /// Decides whether the chat template should open a new assistant turn.
    ///
    /// Returns `false` only when the history ends with an assistant message
    /// (presumably so that trailing assistant turn can be continued/prefilled —
    /// confirm with callers). Any other final message — user, tool result,
    /// system, … — needs the assistant header appended so the model answers
    /// in a fresh turn. An empty history also returns `true`.
    fn should_add_generation_prompt(&self) -> bool {
        // Previously only a trailing `User` message enabled the generation prompt,
        // which left the prompt open after a tool result and broke tool-call round trips.
        !matches!(
            self.inner.messages.last(),
            Some(async_openai::types::ChatCompletionRequestMessage::Assistant(_))
        )
    }
}
impl OAIChatLikeRequest for CompletionRequest {
    /// Presents a legacy completions prompt as a one-turn chat history.
    ///
    /// The prompt is flattened to text with `prompt_to_string` and wrapped in a
    /// single `user` message (no `name`). Chat templates can then render it
    /// exactly like a chat request.
    fn messages(&self) -> minijinja::value::Value {
        use crate::protocols::openai::completions::prompt_to_string;
        use async_openai::types::{
            ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
            ChatCompletionRequestUserMessageContent,
        };

        let prompt_text = prompt_to_string(&self.inner.prompt);
        let user_turn = ChatCompletionRequestUserMessage {
            content: ChatCompletionRequestUserMessageContent::Text(prompt_text),
            name: None,
        };
        let history = vec![ChatCompletionRequestMessage::User(user_turn)];

        minijinja::value::Value::from_serialize(history)
    }

    /// The synthesized history always ends on the user's turn, so the template
    /// is always asked to open the assistant's reply.
    fn should_add_generation_prompt(&self) -> bool {
        true
    }
}
impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
    /// Returns the formatter's stored `supports_add_generation_prompt` flag
    /// (presumably computed when the chat template was loaded — confirm at construction site).
    fn supports_add_generation_prompt(&self) -> bool {
        self.supports_add_generation_prompt
    }

    /// Renders `req` into a prompt string with the formatter's chat template.
    ///
    /// The template context exposes:
    /// - `messages`
    /// - `tools`
    /// - `bos_token`, `eos_token`, `unk_token`
    /// - `add_generation_prompt`
    /// - everything provided by `self.mixins`
    ///
    /// When the request carries tools, the `tool_use` template is preferred. If
    /// the environment has no template of that name, `default` is used instead,
    /// mirroring how Hugging Face `transformers` selects from dict-style chat
    /// templates.
    ///
    /// # Errors
    ///
    /// Returns an error if no usable template is registered, or if minijinja
    /// fails while rendering.
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let mixins = Value::from_dyn_object(self.mixins.clone());

        let tools = req.tools();
        let has_tools = tools.is_some();
        let add_generation_prompt = req.should_add_generation_prompt();

        tracing::trace!(
            "Rendering prompt with tools: {:?}, add_generation_prompt: {}",
            has_tools,
            add_generation_prompt
        );

        let ctx = context! {
            messages => req.messages(),
            tools => tools,
            bos_token => self.config.bos_tok(),
            eos_token => self.config.eos_tok(),
            unk_token => self.config.unk_tok(),
            add_generation_prompt => add_generation_prompt,
            ..mixins
        };

        // Dict-style chat templates may register only `default`. Fall back to it
        // instead of failing when tools are present but no `tool_use` variant exists.
        let tmpl = if has_tools {
            self.env
                .get_template("tool_use")
                .or_else(|_| self.env.get_template("default"))?
        } else {
            self.env.get_template("default")?
        };

        Ok(tmpl.render(&ctx)?)
    }
}