use std::{collections::HashSet, sync::Arc};
use anyhow::{Context, Ok, Result};
use minijinja::Environment;
use crate::model_card::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};
mod context;
mod formatters;
mod oai;
mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
pub use tokcfg::ChatTemplate;
use tokcfg::ChatTemplateValue;
impl PromptFormatter {
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
let name_lower = mdc.display_name.to_lowercase();
if name_lower.contains("deepseek")
&& name_lower.contains("v3.2")
&& !name_lower.contains("exp")
{
tracing::info!("Detected DeepSeek V3.2 model (non-Exp), using native Rust formatter");
return Ok(Self::OAI(Arc::new(
super::deepseek_v32::DeepSeekV32Formatter::new_thinking(),
)));
}
match mdc
.prompt_formatter
.as_ref()
.ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
{
PromptFormatterArtifact::HfTokenizerConfigJson(checked_file) => {
let Some(file) = checked_file.path() else {
anyhow::bail!(
"HfTokenizerConfigJson for {} is a URL, cannot load",
mdc.display_name
);
};
let contents = std::fs::read_to_string(file).with_context(|| {
format!(
"PromptFormatter.from_mdc fs:read_to_string '{}'",
file.display()
)
})?;
let mut config: ChatTemplate =
serde_json::from_str(&contents).inspect_err(|err| {
crate::log_json_err(&file.display().to_string(), &contents, err)
})?;
match mdc.chat_template_file.as_ref() {
Some(PromptFormatterArtifact::HfChatTemplateJinja {
file: checked_file,
..
}) => {
let Some(path) = checked_file.path() else {
anyhow::bail!(
"HfChatTemplateJinja for {} is a URL, cannot load",
mdc.display_name
);
};
let chat_template = std::fs::read_to_string(path)
.with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
}
Some(PromptFormatterArtifact::HfChatTemplateJson {
file: checked_file,
..
}) => {
let Some(path) = checked_file.path() else {
anyhow::bail!(
"HfChatTemplateJson for {} is a URL, cannot load",
mdc.display_name
);
};
let raw = std::fs::read_to_string(path)
.with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
let wrapper: serde_json::Value =
serde_json::from_str(&raw).with_context(|| {
format!("Failed to parse '{}' as JSON", path.display())
})?;
let field = wrapper.get("chat_template").ok_or_else(|| {
anyhow::anyhow!(
"'{}' does not contain a 'chat_template' field",
path.display()
)
})?;
let value = serde_json::from_value::<ChatTemplateValue>(field.clone())
.with_context(|| {
format!(
"Failed to deserialize 'chat_template' in '{}'",
path.display()
)
})?;
config.chat_template = Some(value);
}
_ => {}
}
Self::from_parts(
config,
mdc.prompt_context
.clone()
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
mdc.runtime_config.exclude_tools_when_tool_choice_none,
)
}
PromptFormatterArtifact::HfChatTemplateJinja { .. }
| PromptFormatterArtifact::HfChatTemplateJson { .. } => Err(anyhow::anyhow!(
"prompt_formatter should not have type HfChatTemplate*"
)),
}
}
pub fn from_parts(
config: ChatTemplate,
context: ContextMixins,
exclude_tools_when_tool_choice_none: bool,
) -> Result<PromptFormatter> {
let formatter = HfTokenizerConfigJsonFormatter::with_options(
config,
context,
exclude_tools_when_tool_choice_none,
)?;
Ok(Self::OAI(Arc::new(formatter)))
}
}
struct JinjaEnvironment {
env: Environment<'static>,
}
#[derive(Debug)]
struct HfTokenizerConfigJsonFormatter {
env: Environment<'static>,
config: ChatTemplate,
mixins: Arc<ContextMixins>,
supports_add_generation_prompt: bool,
requires_content_arrays: bool,
exclude_tools_when_tool_choice_none: bool,
template_handles_reasoning: bool,
}
#[derive(Debug, Clone, Default)]
pub struct ContextMixins {
context_mixins: HashSet<PromptContextMixin>,
}