use std::sync::Arc;
use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson};
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
use either::Either;
use minijinja::{Environment, Value, context};
use serde_json::json;
use tracing;
fn detect_content_array_usage(env: &Environment) -> bool {
let array_msg = context! {
messages => json!([{"role": "user", "content": [{"type": "text", "text": "template_test"}]}]),
add_generation_prompt => false,
};
let string_msg = context! {
messages => json!([{"role": "user", "content": "template_test"}]),
add_generation_prompt => false,
};
let out_array = env
.get_template("default")
.and_then(|t| t.render(&array_msg))
.unwrap_or_default();
let out_string = env
.get_template("default")
.and_then(|t| t.render(&string_msg))
.unwrap_or_default();
out_array.contains("template_test") && !out_string.contains("template_test")
}
/// Strips template tags that are known not to be standard Jinja2.
///
/// Every exact occurrence of `{% generation %}` and `{% endgeneration %}` is
/// deleted; the text between and around them is left untouched. Only these
/// exact spellings are matched, so variants such as `{%- generation -%}` pass
/// through unchanged.
///
/// The tags are removed one after another in the listed order, so a
/// `{% endgeneration %}` that only comes into existence once a
/// `{% generation %}` has been removed is stripped as well (but not the other
/// way round).
fn remove_known_non_jinja2_tags(template: &str) -> String {
    const NON_JINJA2_TAGS: [&str; 2] = ["{% generation %}", "{% endgeneration %}"];

    NON_JINJA2_TAGS
        .iter()
        .fold(template.to_owned(), |cleaned, tag| cleaned.replace(*tag, ""))
}
impl JinjaEnvironment {
fn env(self) -> Environment<'static> {
self.env
}
}
impl Default for JinjaEnvironment {
    /// Creates a wrapper around a new minijinja environment with both
    /// `trim_blocks` and `lstrip_blocks` switched on.
    fn default() -> Self {
        let mut jinja = Environment::new();
        jinja.set_trim_blocks(true);
        jinja.set_lstrip_blocks(true);
        Self { env: jinja }
    }
}
impl HfTokenizerConfigJsonFormatter {
    /// Builds a formatter from the chat template found in a tokenizer config.
    ///
    /// The minijinja environment is created through
    /// [`JinjaEnvironment::default`] and then extended with:
    /// - a `length` filter that yields `0` for undefined/none values and for
    ///   values that report no length,
    /// - the `tojson` filter and the `raise_exception` / `strftime_now`
    ///   functions,
    /// - the Python-compatibility unknown-method callback.
    ///
    /// Templates are loaded after `{% generation %}` / `{% endgeneration %}`
    /// markers have been stripped. A single-string template is registered under
    /// both `default` and `tool_use`; the map form registers each entry under
    /// its own key.
    ///
    /// `supports_add_generation_prompt` ends up `true` only when every loaded
    /// template mentions `add_generation_prompt`. `requires_content_arrays` is
    /// taken from [`detect_content_array_usage`].
    ///
    /// # Errors
    ///
    /// Returns an error when `config.chat_template` is absent, when a template
    /// cannot be added to the environment, or when the map form contributes no
    /// templates at all.
    pub fn new(config: ChatTemplate, mixins: ContextMixins) -> anyhow::Result<Self> {
        let mut env = JinjaEnvironment::default().env();

        // `ok_or_else` keeps the error from being constructed on the success path.
        let chat_template = config.chat_template.as_ref().ok_or_else(|| {
            anyhow::anyhow!("chat_template field is required in the tokenizer_config.json file")
        })?;

        // Lenient `length`: undefined/none and unsized values count as 0
        // instead of failing the render.
        env.add_filter("length", |value: Value| -> usize {
            use minijinja::value::ValueKind;
            match value.kind() {
                ValueKind::Undefined | ValueKind::None => 0,
                _ => value.len().unwrap_or(0),
            }
        });
        env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
        env.add_filter("tojson", tojson);
        env.add_function("raise_exception", raise_exception);
        env.add_function("strftime_now", strftime_now);

        let supports_add_generation_prompt = match &chat_template.0 {
            Either::Left(template) => {
                let supported = template.contains("add_generation_prompt");
                if supported {
                    tracing::debug!(
                        "Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
                    );
                }
                // The same cleaned template serves both lookup names.
                let template_cleaned = remove_known_non_jinja2_tags(template);
                env.add_template_owned("default", template_cleaned.clone())?;
                env.add_template_owned("tool_use", template_cleaned)?;
                supported
            }
            Either::Right(map) => {
                // Count first, report once after the loop: the diagnostics then
                // do not depend on the order in which templates are visited.
                let mut total = 0usize;
                let mut with_key = 0usize;
                for templates in map {
                    for (name, body) in templates.iter() {
                        total += 1;
                        if body.contains("add_generation_prompt") {
                            with_key += 1;
                        }
                        let template_cleaned = remove_known_non_jinja2_tags(body);
                        env.add_template_owned(name.to_string(), template_cleaned)?;
                    }
                }
                if env.templates().count() == 0 {
                    anyhow::bail!(
                        "Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools."
                    );
                }

                let all_support = with_key == total;
                if all_support {
                    tracing::debug!(
                        "Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
                    );
                } else if with_key > 0 {
                    // Mixed set: some templates use the key, others do not.
                    tracing::warn!(
                        "Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt."
                    );
                }
                all_support
            }
        };

        let requires_content_arrays = detect_content_array_usage(&env);
        Ok(HfTokenizerConfigJsonFormatter {
            env,
            config,
            mixins: Arc::new(mixins),
            supports_add_generation_prompt,
            requires_content_arrays,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One generation/endgeneration pair is dropped; the text it wraps stays.
    #[test]
    fn test_remove_known_non_jinja2_tags() {
        let cleaned = remove_known_non_jinja2_tags(
            "USER: {{ message }} ASSISTANT: {% generation %}Reply here{% endgeneration %}",
        );
        assert_eq!(cleaned, "USER: {{ message }} ASSISTANT: Reply here");
    }

    /// Ordinary Jinja2 block tags must pass through unchanged.
    #[test]
    fn test_remove_known_non_jinja2_tags_preserves_standard_tags() {
        let untouched = "{% for item in items %}{{ item }}{% endfor %}";
        assert_eq!(remove_known_non_jinja2_tags(untouched), untouched);
    }

    /// Every occurrence is removed, not just the first pair.
    #[test]
    fn test_remove_known_non_jinja2_tags_multiple() {
        let cleaned = remove_known_non_jinja2_tags(
            "Start {% generation %}Part 1{% endgeneration %} middle {% generation %}Part 2{% endgeneration %}",
        );
        assert_eq!(cleaned, "Start Part 1 middle Part 2");
    }
}