use minijinja::{Environment, Error as JErr, ErrorKind, Value as JValue, value::Kwargs};
use serde_json::Value;
use crate::Error;
/// `serde_json` formatter for single-line output that puts `", "` between
/// items and `": "` between an object key and its value.
struct PyCompactFormatter;

impl serde_json::ser::Formatter for PyCompactFormatter {
    /// Emits the `", "` separator in front of every array element but the first.
    fn begin_array_value<W>(&mut self, writer: &mut W, first: bool) -> std::io::Result<()>
    where
        W: ?Sized + std::io::Write,
    {
        if first {
            return Ok(());
        }
        writer.write_all(b", ")
    }

    /// Emits the `", "` separator in front of every object key but the first.
    fn begin_object_key<W>(&mut self, writer: &mut W, first: bool) -> std::io::Result<()>
    where
        W: ?Sized + std::io::Write,
    {
        if first {
            return Ok(());
        }
        writer.write_all(b", ")
    }

    /// Emits the `": "` separator between an object key and its value.
    fn begin_object_value<W>(&mut self, writer: &mut W) -> std::io::Result<()>
    where
        W: ?Sized + std::io::Write,
    {
        writer.write_all(b": ")
    }
}
fn py_json_dumps<S: serde::Serialize>(
v: &S,
indent: Option<&[u8]>,
) -> Result<String, serde_json::Error> {
match indent {
None => {
let mut buf = Vec::new();
let mut ser = serde_json::Serializer::with_formatter(&mut buf, PyCompactFormatter);
v.serialize(&mut ser)?;
Ok(String::from_utf8(buf).expect("serde_json emits UTF-8"))
}
Some(pad) => {
let mut buf = Vec::new();
let fmt = serde_json::ser::PrettyFormatter::with_indent(pad);
let mut ser = serde_json::Serializer::with_formatter(&mut buf, fmt);
v.serialize(&mut ser)?;
Ok(String::from_utf8(buf).expect("serde_json emits UTF-8"))
}
}
}
/// Converts the `indent` argument of the `tojson` filter into indentation bytes.
///
/// * none / undefined -> `Ok(None)` (no pretty-printing requested)
/// * boolean -> one space for `true`, empty for `false`
/// * string -> the string's bytes, verbatim
/// * integer -> that many spaces (negative values clamp to zero)
///
/// # Errors
/// Returns `ErrorKind::InvalidOperation` when the value is none of the above
/// or when an integer exceeds the 1024 cap.
fn coerce_indent(val: &JValue) -> Result<Option<Vec<u8>>, JErr> {
    const MAX_INDENT: i64 = 1024;

    if val.is_undefined() || val.is_none() {
        return Ok(None);
    }
    // Booleans are checked before integers so `true` maps to a single space.
    if let Ok(flag) = bool::try_from(val.clone()) {
        let pad = if flag { b" ".to_vec() } else { Vec::new() };
        return Ok(Some(pad));
    }
    if let Some(text) = val.as_str() {
        return Ok(Some(text.as_bytes().to_vec()));
    }
    let width = match i64::try_from(val.clone()) {
        Ok(n) => n,
        Err(_) => {
            return Err(JErr::new(
                ErrorKind::InvalidOperation,
                "tojson: `indent` must be an integer, boolean, or string",
            ));
        }
    };
    if width > MAX_INDENT {
        return Err(JErr::new(
            ErrorKind::InvalidOperation,
            "tojson: `indent` too large (max 1024)",
        ));
    }
    Ok(Some(vec![b' '; width.max(0) as usize]))
}
/// A hand-coded chat renderer that can be used in place of a Jinja chat template.
///
/// Implementations are looked up by name through `override_by_name`.
#[cfg(feature = "tokenizer-deepseek-v32")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-deepseek-v32")))]
pub trait ChatTemplateOverride: Send + Sync {
/// Renders `messages` (and optional `tools`) into a single prompt string.
///
/// The boolean flags carry the same names as the corresponding parameters of
/// `render_jinja`; how each one is interpreted is up to the implementation.
///
/// # Errors
/// Returns the crate `Error` when the implementation cannot render the input.
fn apply(
&self,
messages: &[Value],
tools: Option<&Value>,
add_generation_prompt: bool,
continue_final_message: bool,
enable_thinking: bool,
) -> Result<String, Error>;
}
/// Looks up a built-in [`ChatTemplateOverride`] by name.
///
/// Only `"deepseek_v32"` is recognised; any other name yields `None`.
#[cfg(feature = "tokenizer-deepseek-v32")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-deepseek-v32")))]
pub fn override_by_name(name: &str) -> Option<Box<dyn ChatTemplateOverride>> {
    if name == "deepseek_v32" {
        let renderer: Box<dyn ChatTemplateOverride> = Box::new(DeepseekV32);
        return Some(renderer);
    }
    None
}
/// Rewrites `{% generation %}` / `{% endgeneration %}` tags into
/// `{% if true %}` / `{% endif %}`, keeping each tag's whitespace-control
/// markers (`-` / `+`).
///
/// Comments (`{# #}`), expressions (`{{ }}`), raw blocks and every other tag are
/// copied through verbatim. Returns `Cow::Borrowed(template)` when nothing was
/// rewritten, so the common case does not hand back a new string.
fn strip_generation_tags(template: &str) -> std::borrow::Cow<'_, str> {
// Fast path: without the word "generation" there is nothing to rewrite.
if !template.contains("generation") {
return std::borrow::Cow::Borrowed(template);
}
let bytes = template.as_bytes();
let mut out = String::with_capacity(template.len());
let mut i = 0;
let mut rewrote = false;
while i < bytes.len() {
if bytes[i] == b'{' && i + 1 < bytes.len() {
match bytes[i + 1] {
// Comment: copy through to its close (or EOF) untouched.
b'#' => {
let end = find_comment_close(template, i + 2);
out.push_str(&template[i..end]);
i = end;
continue;
}
// Expression: copy through to its close (or EOF) untouched.
b'{' => {
let end = find_expr_close(template, i + 2);
out.push_str(&template[i..end]);
i = end;
continue;
}
b'%' => {
if let Some(close) = find_tag_close(template, i + 2) {
let inner = &template[i + 2..close];
let (kw, lead, trail) = classify_tag(inner);
// Raw block: everything up to the matching `endraw` (or EOF) is literal text.
if kw == TagKw::Raw {
let raw_end = find_endraw(template, close + 2);
out.push_str(&template[i..raw_end]);
i = raw_end;
continue;
}
let repl = match kw {
TagKw::Generation => Some("if true"),
TagKw::EndGeneration => Some("endif"),
_ => None,
};
if let Some(repl) = repl {
rewrote = true;
// Re-emit the tag with the replacement keyword and the original
// leading/trailing whitespace-control markers.
out.push_str("{%");
out.push_str(lead.open_str());
out.push_str(repl);
out.push_str(trail.close_str());
i = close + 2;
continue;
}
// Any other tag is copied unchanged, including its `%}`.
out.push_str(&template[i..close + 2]);
i = close + 2;
continue;
}
// Unterminated tag: copy the remainder of the template and stop.
out.push_str(&template[i..]);
break;
}
_ => {}
}
}
// Plain text: copy one whole UTF-8 character so slicing stays on a boundary.
let ch_len = utf8_char_len(bytes[i]);
out.push_str(&template[i..i + ch_len]);
i += ch_len;
}
if rewrote {
std::borrow::Cow::Owned(out)
} else {
std::borrow::Cow::Borrowed(template)
}
}
/// Returns the byte offset just past the first `#}` found at or after `start`,
/// or `template.len()` when the comment is never closed.
fn find_comment_close(template: &str, start: usize) -> usize {
    let bytes = template.as_bytes();
    bytes
        .get(start..)
        .and_then(|tail| tail.windows(2).position(|pair| pair == b"#}"))
        .map_or(bytes.len(), |offset| start + offset + 2)
}
/// Returns the byte offset just past the `}}` that closes an expression whose
/// body starts at `start`, or `template.len()` when it is never closed.
///
/// A `}}` inside a single- or double-quoted string literal is ignored, and a
/// backslash inside a string skips the following byte.
fn find_expr_close(template: &str, start: usize) -> usize {
    let bytes = template.as_bytes();
    let len = bytes.len();
    let mut in_string: Option<u8> = None;
    let mut pos = start;
    while pos < len {
        let cur = bytes[pos];
        match in_string {
            // Escape sequence inside a string: skip the escaped byte too.
            Some(_) if cur == b'\\' && pos + 1 < len => {
                pos += 2;
                continue;
            }
            Some(delim) if cur == delim => in_string = None,
            Some(_) => {}
            None if cur == b'\'' || cur == b'"' => in_string = Some(cur),
            None if cur == b'}' && bytes.get(pos + 1) == Some(&b'}') => return pos + 2,
            None => {}
        }
        pos += 1;
    }
    len
}
/// Whitespace-control marker found on one side of a `{% ... %}` tag.
#[derive(PartialEq, Eq, Clone, Copy)]
enum WsCtrl {
    /// No marker.
    None,
    /// A `-` marker.
    Strip,
    /// A `+` marker.
    Keep,
}

impl WsCtrl {
    /// Text placed right after `{%` when re-emitting a tag with this leading marker.
    fn open_str(self) -> &'static str {
        match self {
            Self::Strip => "- ",
            Self::Keep => "+ ",
            Self::None => " ",
        }
    }

    /// Text that closes a re-emitted tag with this trailing marker, `%}` included.
    fn close_str(self) -> &'static str {
        match self {
            Self::Strip => " -%}",
            Self::Keep => " +%}",
            Self::None => " %}",
        }
    }
}
/// Classification of a `{% ... %}` tag body, as produced by `classify_tag`.
#[derive(PartialEq, Eq, Clone, Copy)]
enum TagKw {
/// The tag body is exactly `generation`.
Generation,
/// The tag body is exactly `endgeneration`.
EndGeneration,
/// The tag body is `raw` (and its trailing marker is not `+`).
Raw,
/// Any other tag.
Other,
}
/// Splits a leading whitespace-control marker (`-` or `+`) off a tag body.
///
/// Leading whitespace is trimmed both before and after the marker; the
/// remainder of the body is returned alongside the marker kind.
fn split_lead_ws_ctrl(inner: &str) -> (WsCtrl, &str) {
    let trimmed = inner.trim_start();
    let ctrl = match trimmed.as_bytes().first() {
        Some(b'-') => WsCtrl::Strip,
        Some(b'+') => WsCtrl::Keep,
        _ => return (WsCtrl::None, trimmed),
    };
    // The marker is a single ASCII byte, so slicing at 1 is a char boundary.
    (ctrl, trimmed[1..].trim_start())
}
/// Splits a trailing whitespace-control marker (`-` or `+`) off a tag body.
///
/// Trailing whitespace is trimmed both after and before the marker; the
/// remainder of the body is returned alongside the marker kind.
fn split_trail_ws_ctrl(inner: &str) -> (&str, WsCtrl) {
    let trimmed = inner.trim_end();
    let ctrl = match trimmed.as_bytes().last() {
        Some(b'-') => WsCtrl::Strip,
        Some(b'+') => WsCtrl::Keep,
        _ => return (trimmed, WsCtrl::None),
    };
    // The marker is a single ASCII byte, so dropping one byte stays on a char boundary.
    (trimmed[..trimmed.len() - 1].trim_end(), ctrl)
}
/// Classifies the body of a `{% ... %}` tag (the text between `{%` and `%}`).
///
/// Returns the keyword class together with the leading and trailing
/// whitespace-control markers. `raw` followed by a `+` marker is reported as
/// `TagKw::Other`.
fn classify_tag(inner: &str) -> (TagKw, WsCtrl, WsCtrl) {
    let (lead, rest) = split_lead_ws_ctrl(inner);
    let (word, trail) = split_trail_ws_ctrl(rest);
    let word = word.trim();
    let tag = if word == "generation" {
        TagKw::Generation
    } else if word == "endgeneration" {
        TagKw::EndGeneration
    } else if word == "raw" && trail != WsCtrl::Keep {
        TagKw::Raw
    } else {
        TagKw::Other
    };
    (tag, lead, trail)
}
/// Returns the byte offset of the `%}` that closes a tag whose body starts at
/// `start`, or `None` when the tag is never closed.
///
/// A `%}` inside a single- or double-quoted string literal is ignored, and a
/// backslash inside a string skips the following byte.
fn find_tag_close(template: &str, start: usize) -> Option<usize> {
    let bytes = template.as_bytes();
    let len = bytes.len();
    let mut in_string: Option<u8> = None;
    let mut pos = start;
    while pos < len {
        let cur = bytes[pos];
        match in_string {
            // Escape sequence inside a string: skip the escaped byte too.
            Some(_) if cur == b'\\' && pos + 1 < len => {
                pos += 2;
                continue;
            }
            Some(delim) if cur == delim => in_string = None,
            Some(_) => {}
            None if cur == b'\'' || cur == b'"' => in_string = Some(cur),
            None if cur == b'%' && bytes.get(pos + 1) == Some(&b'}') => return Some(pos),
            None => {}
        }
        pos += 1;
    }
    None
}
/// Checks whether an `endraw` tag starts at byte offset `i`.
///
/// Accepts an optional `-`/`+` marker right after `{%` and right before `%}`,
/// with ASCII whitespace around the keyword. Returns the offset just past the
/// closing `%}` on a match, otherwise `None`.
fn match_endraw_delim(template: &str, i: usize) -> Option<usize> {
    const KEYWORD: &[u8] = b"endraw";
    let bytes = template.as_bytes();
    if !bytes.get(i..)?.starts_with(b"{%") {
        return None;
    }
    let skip_ws = |mut at: usize| {
        while at < bytes.len() && bytes[at].is_ascii_whitespace() {
            at += 1;
        }
        at
    };

    let mut cursor = i + 2;
    if matches!(bytes.get(cursor), Some(b'-' | b'+')) {
        cursor += 1;
    }
    cursor = skip_ws(cursor);
    if !bytes[cursor..].starts_with(KEYWORD) {
        return None;
    }
    cursor = skip_ws(cursor + KEYWORD.len());

    match &bytes[cursor..] {
        [b'-' | b'+', b'%', b'}', ..] => Some(cursor + 3),
        [b'%', b'}', ..] => Some(cursor + 2),
        _ => None,
    }
}
/// Scans forward from `start` for an `endraw` tag and returns the offset just
/// past it, or `template.len()` when no such tag exists.
fn find_endraw(template: &str, start: usize) -> usize {
    let bytes = template.as_bytes();
    let mut pos = start;
    while pos < bytes.len() {
        // `match_endraw_delim` itself rejects positions that do not start with `{%`.
        if let Some(end) = match_endraw_delim(template, pos) {
            return end;
        }
        // Step over a whole character (`{` is a single byte).
        pos += utf8_char_len(bytes[pos]);
    }
    bytes.len()
}
/// Length in bytes of the UTF-8 sequence introduced by lead byte `b`.
///
/// ASCII bytes, continuation bytes and invalid lead bytes all report 1.
fn utf8_char_len(b: u8) -> usize {
    // The number of leading 1-bits of a lead byte encodes the sequence length.
    match b.leading_ones() {
        2 => 2,
        3 => 3,
        4 => 4,
        _ => 1,
    }
}
/// Sentinel appended to the final message's content by
/// `continue_final_message_mutate`; `continue_final_message_trim` searches the
/// rendered output for it to find where to cut. The trailing space is part of
/// the sentinel.
const CONTINUE_FINAL_MESSAGE_TAG: &str = "CONTINUE_FINAL_MESSAGE_TAG ";
fn continue_final_message_mutate(messages: &Value) -> Result<(Value, String), Error> {
let arr = messages
.as_array()
.ok_or_else(|| Error::tokenizer("messages must be a list"))?;
let last = arr.last().ok_or_else(|| {
Error::tokenizer(
"continue_final_message is set but the conversation has no final message to continue",
)
})?;
let content = last.get("content").and_then(Value::as_str).ok_or_else(|| {
Error::tokenizer(
"continue_final_message is set but the final message has no string \"content\" to continue",
)
})?;
let original_content = content.to_string();
let mut mutated = arr.clone();
let new_content = format!("{content}{CONTINUE_FINAL_MESSAGE_TAG}");
if let Some(obj) = mutated.last_mut().and_then(Value::as_object_mut) {
obj.insert("content".to_string(), Value::String(new_content));
}
Ok((Value::Array(mutated), original_content))
}
/// Cuts `rendered` at the last occurrence of the continue sentinel.
///
/// When the sentinel is found with its trailing space intact, everything before
/// it is returned as-is; when only the space-less sentinel is found, the
/// returned head is additionally right-trimmed.
///
/// # Errors
/// Fails when `rendered` does not contain both the trimmed `original_content`
/// and the sentinel.
fn continue_final_message_trim(rendered: &str, original_content: &str) -> Result<String, Error> {
    let sentinel = CONTINUE_FINAL_MESSAGE_TAG.trim_end();
    let has_content = rendered.contains(original_content.trim());
    let has_sentinel = rendered.contains(sentinel);
    if !(has_content && has_sentinel) {
        return Err(Error::tokenizer(
            "continue_final_message is set but the final message does not appear in the \
             rendered chat (the template dropped the final message's content or the \
             continue sentinel) — refusing to continue a prompt the template did not render",
        ));
    }
    let Some(cut) = rendered.rfind(sentinel) else {
        return Err(Error::tokenizer(
            "continue_final_message: the rendered template does not contain the continue \
             sentinel",
        ));
    };

    let head = &rendered[..cut];
    // If the sentinel's trailing space was lost, treat whitespace before it as lost too.
    let sentinel_intact = rendered[cut..].starts_with(CONTINUE_FINAL_MESSAGE_TAG);
    let kept = if sentinel_intact { head } else { head.trim_end() };
    Ok(kept.to_string())
}
#[allow(clippy::too_many_arguments)]
pub fn render_jinja(
template: &str,
messages: &Value,
tools: Option<&Value>,
add_generation_prompt: bool,
continue_final_message: bool,
bos_token: Option<&str>,
eos_token: Option<&str>,
enable_thinking: bool,
extra: &Value,
) -> Result<String, Error> {
let mut env = Environment::new();
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
install_hf_extensions(&mut env);
env.set_trim_blocks(true);
env.set_lstrip_blocks(true);
let template = strip_generation_tags(template);
env
.add_template("chat", &template)
.map_err(|e| Error::tokenizer(format!("chat template parse: {e}")))?;
let tmpl = env
.get_template("chat")
.map_err(|e| Error::tokenizer(format!("chat template: {e}")))?;
let continued_messages;
let mut original_final_content = String::new();
let messages: &Value = if continue_final_message {
let (mutated, original) = continue_final_message_mutate(messages)?;
continued_messages = mutated;
original_final_content = original;
&continued_messages
} else {
messages
};
let mut ctx = serde_json::Map::new();
ctx.insert("messages".into(), messages.clone());
ctx.insert(
"add_generation_prompt".into(),
Value::Bool(add_generation_prompt),
);
ctx.insert("enable_thinking".into(), Value::Bool(enable_thinking));
if let Some(t) = tools {
ctx.insert("tools".into(), t.clone());
} else {
ctx.insert("tools".into(), Value::Null);
}
ctx.insert("documents".into(), Value::Null);
if let Some(b) = bos_token {
ctx.insert("bos_token".into(), Value::String(b.to_owned()));
}
if let Some(e) = eos_token {
ctx.insert("eos_token".into(), Value::String(e.to_owned()));
}
if let Some(obj) = extra.as_object() {
for (k, v) in obj {
ctx.insert(k.clone(), v.clone());
}
}
let rendered = tmpl
.render(JValue::from_serialize(Value::Object(ctx)))
.map_err(|e| Error::tokenizer(format!("chat template render: {e}")))?;
if continue_final_message {
continue_final_message_trim(&rendered, &original_final_content)
} else {
Ok(rendered)
}
}
fn install_hf_extensions(env: &mut Environment<'_>) {
env.add_function("raise_exception", |msg: String| -> Result<JValue, JErr> {
Err(JErr::new(ErrorKind::InvalidOperation, msg))
});
env.add_filter(
"tojson",
|v: JValue, indent: Option<JValue>, kwargs: Kwargs| -> Result<JValue, JErr> {
let indent_arg = match indent {
Some(i) => Some(i),
None => kwargs.get::<Option<JValue>>("indent")?,
};
kwargs.assert_all_used()?;
let indent = match indent_arg {
Some(ref i) => coerce_indent(i)?,
None => None,
};
let out = py_json_dumps(&v, indent.as_deref())
.map_err(|e| JErr::new(ErrorKind::InvalidOperation, e.to_string()))?;
Ok(JValue::from_safe_string(out))
},
);
env.add_function("strftime_now", |fmt: String| -> Result<JValue, JErr> {
strftime_now(&fmt)
.map(JValue::from)
.map_err(|e| JErr::new(ErrorKind::InvalidOperation, format!("strftime_now: {e}")))
});
}
/// Removes the `%z` and `%Z` directives from a strftime `format` string.
///
/// `%%` and every other `%x` pair are kept verbatim, as is a dangling `%` at
/// the end of the string.
fn strip_naive_unsupported(format: &str) -> String {
    let mut cleaned = String::with_capacity(format.len());
    let mut rest = format.chars();
    loop {
        match rest.next() {
            None => break,
            Some('%') => match rest.next() {
                Some('z' | 'Z') => {}
                // Covers `%%` as well: the pair is copied through unchanged.
                Some(directive) => {
                    cleaned.push('%');
                    cleaned.push(directive);
                }
                None => cleaned.push('%'),
            },
            Some(other) => cleaned.push(other),
        }
    }
    cleaned
}
/// Formats the civil date-time `dt` with a strftime-style `format`, after
/// dropping the `%z` / `%Z` directives via `strip_naive_unsupported`.
///
/// # Errors
/// Returns the `jiff::Error` produced by `jiff::fmt::strtime::format`.
pub fn strftime_at(dt: jiff::civil::DateTime, format: &str) -> Result<String, jiff::Error> {
    let naive_format = strip_naive_unsupported(format);
    jiff::fmt::strtime::format(naive_format, dt)
}
/// Formats the current date-time (taken from `jiff::Zoned::now()`) with `format`.
fn strftime_now(format: &str) -> Result<String, jiff::Error> {
    let now = jiff::Zoned::now();
    strftime_at(now.datetime(), format)
}
#[cfg(feature = "tokenizer-deepseek-v32")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-deepseek-v32")))]
pub use deepseek_v32::DeepseekV32;
#[cfg(feature = "tokenizer-deepseek-v32")]
mod deepseek_v32 {
use serde_json::Value;
use super::ChatTemplateOverride;
use crate::Error;
/// Written once at the start of every rendered prompt (see `apply`).
const BOS_TOKEN: &str = "<|begin▁of▁sentence|>";
/// Written after every rendered assistant message.
const EOS_TOKEN: &str = "<|end▁of▁sentence|>";
/// Marker that opens a thinking section.
const THINK_START: &str = "<think>";
/// Marker that closes a thinking section.
const THINK_END: &str = "</think>";
/// Prefix embedded in the tool-call tag names (`invoke`, `parameter`, `function_calls`).
const DSML: &str = "|DSML|";
/// Tools prompt text with `{dsml}`, `{think_start}`, `{think_end}` and
/// `{tool_schemas}` placeholders that `render_tools` substitutes.
const TOOLS_SYSTEM_TEMPLATE: &str = crate::tokenizer::generated::DEEPSEEK_V32_TEMPLATE;
/// Chat-template override that renders prompts in code rather than through Jinja.
pub struct DeepseekV32;
/// Serializes `v` to compact JSON, falling back to the text `null` if
/// serialization fails.
fn to_json(v: &Value) -> String {
    match serde_json::to_string(v) {
        Ok(text) => text,
        Err(_) => "null".into(),
    }
}
/// Fills `TOOLS_SYSTEM_TEMPLATE` with the DSML/think markers and one JSON
/// schema per line for every entry of `tools` (nothing when `tools` is not an
/// array).
fn render_tools(tools: &Value) -> String {
    let schemas = match tools.as_array() {
        Some(list) => {
            let lines: Vec<String> = list.iter().map(to_json).collect();
            lines.join("\n")
        }
        None => String::new(),
    };
    // Substitutions are applied one after another, in this order.
    let substitutions = [
        ("{dsml}", DSML),
        ("{think_start}", THINK_START),
        ("{think_end}", THINK_END),
        ("{tool_schemas}", schemas.as_str()),
    ];
    substitutions.iter().fold(
        TOOLS_SYSTEM_TEMPLATE.to_owned(),
        |text, &(placeholder, value)| text.replace(placeholder, value),
    )
}
/// Unwraps tool entries: each array element is replaced by its `function`
/// member when it has one, and kept as-is otherwise. A non-array input yields
/// an empty array.
fn tools_from_openai(tools: &Value) -> Value {
    let Some(entries) = tools.as_array() else {
        return Value::Array(vec![]);
    };
    let mut unwrapped = Vec::with_capacity(entries.len());
    for entry in entries {
        let schema = match entry.get("function") {
            Some(function) => function.clone(),
            None => entry.clone(),
        };
        unwrapped.push(schema);
    }
    Value::Array(unwrapped)
}
/// Index of the last message whose role is `user` or `developer`, or `-1` when
/// there is none.
fn find_last_user_index(messages: &[Value]) -> i64 {
    messages
        .iter()
        .rposition(|msg| {
            matches!(
                msg.get("role").and_then(Value::as_str),
                Some("user" | "developer")
            )
        })
        .map_or(-1, |idx| idx as i64)
}
fn encode_arguments_to_dsml(tc: &Value) -> Result<String, Error> {
let args_raw = tc.get("arguments");
let arguments: Value = match args_raw {
Some(Value::String(s)) => {
serde_json::from_str(s).map_err(|e| Error::tokenizer(format!("deepseek_v32 args: {e}")))?
}
Some(other) => other.clone(),
None => Value::Object(Default::default()),
};
let obj = arguments
.as_object()
.ok_or_else(|| Error::tokenizer("deepseek_v32: arguments not object"))?;
let mut parts = Vec::new();
for (k, v) in obj {
let is_str = v.is_string();
let value = if let Value::String(s) = v {
s.clone()
} else {
to_json(v)
};
parts.push(format!(
"<{DSML}parameter name=\"{k}\" string=\"{}\">{value}</{DSML}parameter>",
if is_str { "true" } else { "false" }
));
}
Ok(parts.join("\n"))
}
/// Renders the message at `index` of `messages` into its prompt fragment.
///
/// `thinking_mode` is compared against the literal `"thinking"`; any other
/// value takes the non-thinking branches. `tools`, when given, takes precedence
/// over a `tools` field on the message itself.
///
/// # Errors
/// Fails when `index` is out of range, when the role is not one of `system`,
/// `developer`, `user`, `tool`, `assistant`, or when tool-call arguments
/// cannot be encoded.
fn render_message(
index: usize,
messages: &[Value],
thinking_mode: &str,
tools: Option<&Value>,
) -> Result<String, Error> {
let mut prompt = String::new();
let msg = messages.get(index).ok_or_else(|| {
Error::tokenizer(format!(
"deepseek_v32: message index {index} out of range (len {})",
messages.len()
))
})?;
let last_user_idx = find_last_user_index(messages);
// Missing or non-string `role` / `content` are treated as empty strings.
let role = msg.get("role").and_then(Value::as_str).unwrap_or("");
let content = msg.get("content").and_then(Value::as_str).unwrap_or("");
let msg_tools = tools.cloned().or_else(|| msg.get("tools").cloned());
let response_format = msg.get("response_format");
let mut tool_calls = msg.get("tool_calls").cloned();
// Flatten each `{function: {name, arguments}}` entry to `{name, arguments}`;
// missing members become JSON null.
if let Some(Value::Array(tcs)) = &tool_calls {
tool_calls = Some(Value::Array(
tcs
.iter()
.map(|tc| {
let f = tc.get("function");
serde_json::json!({
"name": f.and_then(|x| x.get("name")).cloned().unwrap_or(Value::Null),
"arguments": f.and_then(|x| x.get("arguments")).cloned().unwrap_or(Value::Null),
})
})
.collect(),
));
}
let reasoning_content = msg
.get("reasoning_content")
.and_then(Value::as_str)
.unwrap_or("");
match role {
"system" => {
// System content is emitted bare, followed by the optional tools prompt
// and response-format section.
prompt.push_str(content);
if let Some(t) = &msg_tools {
prompt.push_str("\n\n");
prompt.push_str(&render_tools(&tools_from_openai(t)));
}
if let Some(rf) = response_format {
prompt.push_str(&format!(
"\n\n## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{}",
to_json(rf)
));
}
}
"developer" => {
// Developer messages are wrapped as a user turn: tools prompt and
// response format first, then the message content.
let mut cd = String::new();
if let Some(t) = &msg_tools {
cd.push_str("\n\n");
cd.push_str(&render_tools(&tools_from_openai(t)));
}
if let Some(rf) = response_format {
cd.push_str(&format!(
"\n\n## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{}",
to_json(rf)
));
}
cd.push_str(&format!("\n\n# The user's message is: {content}"));
prompt.push_str(&format!("<|User|>{cd}<|Assistant|>"));
// Only the last user/developer turn opens a thinking section, and only in thinking mode.
if index as i64 == last_user_idx && thinking_mode == "thinking" {
prompt.push_str(THINK_START);
} else {
prompt.push_str(THINK_END);
}
}
"user" => {
prompt.push_str(&format!("<|User|>{content}<|Assistant|>"));
// Only the last user/developer turn opens a thinking section, and only in thinking mode.
if index as i64 == last_user_idx && thinking_mode == "thinking" {
prompt.push_str(THINK_START);
} else {
prompt.push_str(THINK_END);
}
}
"tool" => {
// Walk back over the preceding run of tool messages to the message
// before it, which is read as the assistant that issued the calls.
let mut prev = index as i64 - 1;
while prev >= 0
&& messages[prev as usize].get("role").and_then(Value::as_str) == Some("tool")
{
prev -= 1;
}
// `prev` is -1 when the run starts at index 0; the index is clamped to 0.
let assistant = &messages[prev.max(0) as usize];
// 1-based position of this result within its run of tool messages.
let order = index as i64 - prev;
let assistant_tcs = assistant
.get("tool_calls")
.and_then(Value::as_array)
.map(|a| a.len())
.unwrap_or(0);
if order == 1 {
prompt.push_str("\n\n<function_results>");
}
prompt.push_str(&format!("\n<result>{content}</result>"));
// The results block is closed once the number of results matches the
// number of tool calls on that message.
if order as usize == assistant_tcs {
prompt.push_str("\n</function_results>");
if index as i64 >= last_user_idx && thinking_mode == "thinking" {
prompt.push_str(&format!("\n\n{THINK_START}"));
} else {
prompt.push_str(&format!("\n\n{THINK_END}"));
}
}
}
"assistant" => {
let mut thinking_part = String::new();
let mut tool_calls_content = String::new();
if let Some(Value::Array(tcs)) = &tool_calls {
let mut rendered = Vec::new();
for tc in tcs {
let name = tc.get("name").and_then(Value::as_str).unwrap_or("");
rendered.push(format!(
"<{DSML}invoke name=\"{name}\">\n{}\n</{DSML}invoke>",
encode_arguments_to_dsml(tc)?
));
}
tool_calls_content.push_str(&format!(
"\n\n<{DSML}function_calls>\n{}\n</{DSML}function_calls>",
rendered.join("\n")
));
}
// Reasoning is only emitted for assistant turns after the last user turn.
if thinking_mode == "thinking" && index as i64 > last_user_idx {
thinking_part = format!("{reasoning_content}{THINK_END}");
}
prompt.push_str(&format!(
"{thinking_part}{content}{tool_calls_content}{EOS_TOKEN}"
));
}
other => {
return Err(Error::tokenizer(format!(
"deepseek_v32: unknown role: {other}"
)));
}
}
Ok(prompt)
}
/// Returns a copy of `messages` with `reasoning_content` removed from every
/// assistant message that precedes the last user/developer message.
///
/// Messages at or after that index, and all `user` / `developer` / `system` /
/// `tool` messages, are copied unchanged. Earlier messages with any other
/// non-assistant role are omitted from the result.
fn drop_thinking_messages(messages: &[Value]) -> Vec<Value> {
    let last_user_idx = find_last_user_index(messages);
    messages
        .iter()
        .enumerate()
        .filter_map(|(idx, msg)| {
            let role = msg.get("role").and_then(Value::as_str).unwrap_or("");
            let keep_verbatim = matches!(role, "user" | "developer" | "system" | "tool")
                || idx as i64 >= last_user_idx;
            if keep_verbatim {
                return Some(msg.clone());
            }
            if role != "assistant" {
                return None;
            }
            let mut stripped = msg.clone();
            if let Value::Object(fields) = &mut stripped {
                fields.remove("reasoning_content");
            }
            Some(stripped)
        })
        .collect()
}
impl ChatTemplateOverride for DeepseekV32 {
    /// Renders the whole conversation: `BOS_TOKEN` followed by one fragment per
    /// message.
    ///
    /// Post-processing on the result:
    /// * without `add_generation_prompt`, a trailing `<|Assistant|><think>` is
    ///   removed when the last message is a user message;
    /// * with `continue_final_message`, a trailing `EOS_TOKEN` is removed when
    ///   the last message is an assistant message.
    ///
    /// # Errors
    /// Fails when a message cannot be rendered, or when both
    /// `continue_final_message` and `add_generation_prompt` are set.
    fn apply(
        &self,
        messages: &[Value],
        tools: Option<&Value>,
        add_generation_prompt: bool,
        continue_final_message: bool,
        enable_thinking: bool,
    ) -> Result<String, Error> {
        let thinking_mode = if enable_thinking { "thinking" } else { "chat" };
        let full = if enable_thinking {
            drop_thinking_messages(messages)
        } else {
            messages.to_vec()
        };

        let mut out = String::from(BOS_TOKEN);
        // The loop runs over the original message count while indexing into `full`.
        for idx in 0..messages.len() {
            let fragment = render_message(idx, &full, thinking_mode, tools)?;
            out.push_str(&fragment);
        }

        if continue_final_message && add_generation_prompt {
            return Err(Error::tokenizer(
                "Only one of continue_final_message or add_generation_prompt can be True",
            ));
        }

        let last_role = messages
            .last()
            .and_then(|m| m.get("role"))
            .and_then(Value::as_str);
        // Drops `suffix` from the end of `text` when present.
        let drop_suffix = |text: &mut String, suffix: &str| {
            if text.ends_with(suffix) {
                text.truncate(text.len() - suffix.len());
            }
        };
        if !add_generation_prompt && last_role == Some("user") {
            drop_suffix(&mut out, "<|Assistant|><think>");
        }
        if continue_final_message && last_role == Some("assistant") {
            drop_suffix(&mut out, EOS_TOKEN);
        }
        Ok(out)
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Both "absent" flavours of a value mean that no indent was requested.
#[test]
fn coerce_indent_none_and_undefined_yield_none() {
    for absent in [JValue::from(()), JValue::UNDEFINED] {
        assert_eq!(coerce_indent(&absent).unwrap(), None);
    }
}
/// `true` behaves like an indent of one space, `false` like an empty indent.
#[test]
fn coerce_indent_bool_maps_to_one_or_zero_spaces() {
    let for_true = coerce_indent(&JValue::from(true)).unwrap();
    let for_false = coerce_indent(&JValue::from(false)).unwrap();
    assert_eq!(for_true, Some(vec![b' ']));
    assert_eq!(for_false, Some(vec![]));
}
/// String indents are passed through byte-for-byte, including the empty string.
#[test]
fn coerce_indent_str_used_verbatim() {
    for text in ["\t", ""] {
        let got = coerce_indent(&JValue::from(text)).unwrap();
        assert_eq!(got, Some(text.as_bytes().to_vec()));
    }
}
/// Integer indents expand to that many spaces; negatives clamp to zero and the
/// cap value itself (1024) is still accepted.
#[test]
fn coerce_indent_int_yields_that_many_spaces() {
    // Expected values are built with `vec![b' '; n]` so the space count cannot be
    // mangled by whitespace normalisation of a string literal.
    assert_eq!(
        coerce_indent(&JValue::from(4i64)).unwrap(),
        Some(vec![b' '; 4])
    );
    assert_eq!(
        coerce_indent(&JValue::from(-3i64)).unwrap(),
        Some(Vec::new())
    );
    assert_eq!(
        coerce_indent(&JValue::from(1024i64)).unwrap(),
        Some(vec![b' '; 1024])
    );
}
/// An indent above the cap is reported as an error value rather than a panic.
#[test]
fn coerce_indent_over_cap_is_recoverable_error() {
    let too_wide = JValue::from(2000i64);
    let kind = coerce_indent(&too_wide).unwrap_err().kind();
    assert_eq!(kind, ErrorKind::InvalidOperation);
}
/// Values that are neither integer, string nor boolean are rejected.
#[test]
fn coerce_indent_non_int_string_bool_is_error() {
    let list = JValue::from_serialize(serde_json::json!([1, 2]));
    let list_err = coerce_indent(&list).unwrap_err();
    assert_eq!(list_err.kind(), ErrorKind::InvalidOperation);

    let fractional = JValue::from(3.5f64);
    assert!(coerce_indent(&fractional).is_err());
}
/// Compact output separates items with `", "` and keys from values with `": "`.
#[test]
fn py_json_dumps_compact_uses_python_separators() {
    let value = serde_json::json!({"a": 1, "b": [2, 3]});
    let dumped = py_json_dumps(&value, None).unwrap();
    assert_eq!(dumped, r#"{"a": 1, "b": [2, 3]}"#);
}
/// Pretty output indents nested lines with exactly the bytes that were passed in.
#[test]
fn py_json_dumps_pretty_uses_indent_bytes() {
    let value = serde_json::json!({"a": 1});
    let pad: &[u8] = b" ";
    let dumped = py_json_dumps(&value, Some(pad)).unwrap();
    assert_eq!(dumped, "{\n \"a\": 1\n}");
}
/// A template without the word "generation" is handed back borrowed.
#[test]
fn strip_generation_tags_fast_path_no_keyword() {
    let template = "{% for x in y %}{{ x }}{% endfor %}";
    let result = strip_generation_tags(template);
    assert!(matches!(result, std::borrow::Cow::Borrowed(_)));
}
/// Generation tags become `if true` / `endif`, while a generation tag inside a
/// raw block is left untouched.
#[test]
fn strip_generation_tags_rewrites_block_and_preserves_raw() {
    let raw_section = "{% raw %}{% generation %}{% endraw %}";
    let input = "{% generation %}X{% raw %}{% generation %}{% endraw %}Y{% endgeneration %}";
    let rewritten = strip_generation_tags(input);
    let text: &str = &rewritten;
    assert!(text.starts_with("{% if true %}X{% raw %}{% generation %}{% endraw %}Y"));
    assert!(text.ends_with("{% endif %}"));
    assert!(text.contains(raw_section));
}
/// The `-` / `+` markers of a rewritten tag survive the rewrite.
#[test]
fn strip_generation_tags_preserves_ws_control_markers() {
    let input = "{%- generation +%}body{%+ endgeneration -%}";
    let rewritten = strip_generation_tags(input);
    assert_eq!(&*rewritten, "{%- if true +%}body{%+ endif -%}");
}
/// A tag that never closes is copied through unchanged.
#[test]
fn strip_generation_tags_unterminated_tag_copied_to_eof() {
    let input = "before {% generation";
    let result = strip_generation_tags(input);
    assert_eq!(&*result, input);
}
/// A `{` that does not start a comment, expression or tag is ordinary text.
#[test]
fn strip_generation_tags_lone_brace_is_literal_text() {
    let input = "{x has generation word";
    let result = strip_generation_tags(input);
    assert_eq!(&*result, input);
}
/// A tag that merely mentions `generation` (here as a variable) is not rewritten.
#[test]
fn strip_generation_tags_non_target_tag_passed_through() {
    let input = "{% if generation %}hi{% endif %}";
    let result = strip_generation_tags(input);
    assert_eq!(&*result, input);
}
/// A closed comment yields the offset past `#}`; an open one yields the length.
#[test]
fn find_comment_close_terminated_and_unterminated() {
    let closed = "{# abc #} z";
    assert_eq!(find_comment_close(closed, 2), "{# abc #}".len());

    let open = "{# no close";
    assert_eq!(find_comment_close(open, 2), open.len());
}
/// A `}}` inside a quoted string (with an escaped quote) does not close the
/// expression; an unterminated expression runs to the end.
#[test]
fn find_expr_close_string_escape_and_unterminated() {
    let quoted = "{{ \"a\\\"}}b\" }}";
    let unterminated = "{{ x + y";
    for template in [quoted, unterminated] {
        assert_eq!(find_expr_close(template, 2), template.len());
    }
}
/// The tag close is found after string literals (including escaped quotes and
/// a quoted `%}`); an unterminated tag yields `None`.
#[test]
fn find_tag_close_string_escape_and_unterminated() {
    let escaped = "{% x \"a\\\"b\" %}";
    let at = find_tag_close(escaped, 2).unwrap();
    assert_eq!(&escaped[at..at + 2], "%}");

    let quoted_close = "{% set q = \"%}\" %}";
    let at = find_tag_close(quoted_close, 2).unwrap();
    assert_eq!(&quoted_close[at..at + 2], "%}");
    let literal_at = quoted_close.find("\"%}\"").unwrap();
    assert!(at > literal_at);

    assert_eq!(find_tag_close("{% foo", 2), None);
}
/// Table-driven check of every accept/reject branch of `match_endraw_delim`.
#[test]
fn match_endraw_delim_branches() {
    let cases: [(&str, Option<usize>); 7] = [
        ("xy", None),
        ("{% endraw %}", Some("{% endraw %}".len())),
        ("{%- endraw -%}", Some("{%- endraw -%}".len())),
        ("{% endraw +%}", Some("{% endraw +%}".len())),
        ("{%endraw-%", None),
        ("{% endraw xx", None),
        ("{% endfor %}", None),
    ];
    for (input, expected) in cases {
        assert_eq!(match_endraw_delim(input, 0), expected);
    }
}
/// Other tags inside a raw body are skipped; without `endraw` the scan runs to
/// the end of the input.
#[test]
fn find_endraw_skips_interior_and_handles_unterminated() {
    let closed = "raw {% foo %} body {% endraw %} after";
    let cut = find_endraw(closed, 0);
    assert_eq!(&closed[..cut], "raw {% foo %} body {% endraw %}");

    let open = "raw body {% foo %} no end";
    assert_eq!(find_endraw(open, 0), open.len());
}
/// One lead byte per sequence width, plus a continuation byte (reported as 1).
#[test]
fn utf8_char_len_all_widths() {
    let cases: [(u8, usize); 5] = [(0x41, 1), (0xC3, 2), (0xE2, 3), (0xF0, 4), (0x80, 1)];
    for (lead, width) in cases {
        assert_eq!(utf8_char_len(lead), width);
    }
}
/// The sentinel is appended to the last message and the original content is
/// returned separately.
#[test]
fn continue_final_message_mutate_appends_sentinel() {
    let conversation = serde_json::json!([{"role": "assistant", "content": "Hello"}]);
    let (tagged, original) = continue_final_message_mutate(&conversation).unwrap();
    assert_eq!(original, "Hello");

    let final_message = tagged.as_array().unwrap().last().unwrap();
    let content = final_message.get("content").unwrap().as_str().unwrap();
    assert_eq!(content, "HelloCONTINUE_FINAL_MESSAGE_TAG ");
}
/// Non-list input, an empty list and a non-string `content` are all rejected
/// with a tokenizer error.
#[test]
fn continue_final_message_mutate_errors() {
    let not_list = serde_json::json!({"role": "user"});
    // The argument must be the reference `&not_list`.
    assert!(matches!(
        continue_final_message_mutate(&not_list).unwrap_err(),
        Error::Tokenizer(_)
    ));
    let empty = serde_json::json!([]);
    assert!(matches!(
        continue_final_message_mutate(&empty).unwrap_err(),
        Error::Tokenizer(_)
    ));
    let no_str = serde_json::json!([{"role": "assistant", "content": 42}]);
    assert!(matches!(
        continue_final_message_mutate(&no_str).unwrap_err(),
        Error::Tokenizer(_)
    ));
}
/// With the sentinel intact, the text before it is kept exactly, including the
/// trailing space.
#[test]
fn continue_final_message_trim_full_tag_plain_truncates() {
    let rendered = "Hi there CONTINUE_FINAL_MESSAGE_TAG <eot>";
    let trimmed = continue_final_message_trim(rendered, "Hi there").unwrap();
    assert_eq!(trimmed, "Hi there ");
}
/// When the sentinel lost its trailing space, the kept head is right-trimmed.
#[test]
fn continue_final_message_trim_transformed_tag_rstrips() {
    let rendered = "Hi there CONTINUE_FINAL_MESSAGE_TAG";
    let trimmed = continue_final_message_trim(rendered, "Hi there").unwrap();
    assert_eq!(trimmed, "Hi there");
}
/// A missing sentinel or missing final content both produce a tokenizer error.
#[test]
fn continue_final_message_trim_guard_errors() {
    let cases = [
        ("no sentinel", "x"),
        ("unrelated CONTINUE_FINAL_MESSAGE_TAG tail", "missing-content"),
    ];
    for (rendered, original) in cases {
        let err = continue_final_message_trim(rendered, original).unwrap_err();
        assert!(matches!(err, Error::Tokenizer(_)));
    }
}
/// Empty original content is accepted.
#[test]
fn continue_final_message_trim_empty_content_is_valid() {
    let rendered = "prefixCONTINUE_FINAL_MESSAGE_TAG suffix";
    let trimmed = continue_final_message_trim(rendered, "").unwrap();
    assert_eq!(trimmed, "prefix");
}
/// Messages are rendered in order and the generation prompt is appended when requested.
#[test]
fn render_jinja_basic_with_generation_prompt() {
    let template = "{% for m in messages %}{{ m.role }}: {{ m.content }}|{% endfor %}{% if add_generation_prompt %}assistant:{% endif %}";
    let conversation = serde_json::json!([{"role": "user", "content": "hi"}]);
    let no_extra = Value::Null;
    let rendered = render_jinja(
        template, &conversation, None, true, false, None, None, false, &no_extra,
    )
    .unwrap();
    assert_eq!(rendered, "user: hi|assistant:");
}
/// Without `add_generation_prompt` the conditional suffix is not rendered.
#[test]
fn render_jinja_no_generation_prompt_omits_suffix() {
    let template = "{% for m in messages %}{{ m.content }}{% endfor %}{% if add_generation_prompt %}!{% endif %}";
    let conversation = serde_json::json!([{"role": "user", "content": "hi"}]);
    let no_extra = Value::Null;
    let rendered = render_jinja(
        template, &conversation, None, false, false, None, None, false, &no_extra,
    )
    .unwrap();
    assert_eq!(rendered, "hi");
}
/// `bos_token`, `eos_token`, `tools`, extra keys and a null `documents` are all
/// visible to the template.
#[test]
fn render_jinja_bos_eos_tools_extra_in_context() {
    let template = "{{ bos_token }}{{ eos_token }}{{ tools[0].name }}{{ custom }}{% if documents is none %}D{% endif %}";
    let conversation = serde_json::json!([{"role": "user", "content": "x"}]);
    let tool_list = serde_json::json!([{"name": "f"}]);
    let extra_vars = serde_json::json!({"custom": "C"});
    let rendered = render_jinja(
        template,
        &conversation,
        Some(&tool_list),
        false,
        false,
        Some("<s>"),
        Some("</s>"),
        false,
        &extra_vars,
    )
    .unwrap();
    assert_eq!(rendered, "<s></s>fCD");
}
/// `tojson` renders compact output by default and pretty output with `indent=2`.
#[test]
fn render_jinja_tojson_compact_and_indented() {
    let messages = serde_json::json!([{"role": "user", "content": {"a": 1, "b": 2}}]);
    let out = render_jinja(
        "{{ messages[0].content | tojson }}",
        &messages,
        None,
        false,
        false,
        None,
        None,
        false,
        &Value::Null,
    )
    .unwrap();
    assert_eq!(out, r#"{"a": 1, "b": 2}"#);
    let out2 = render_jinja(
        "{{ messages[0].content | tojson(indent=2) }}",
        &messages,
        None,
        false,
        false,
        None,
        None,
        false,
        &Value::Null,
    )
    .unwrap();
    // `indent=2` means two spaces per nesting level. The pad is built
    // programmatically so the count cannot be lost to whitespace normalisation
    // of a string literal.
    let pad = " ".repeat(2);
    let expected = format!("{{\n{pad}\"a\": 1,\n{pad}\"b\": 2\n}}");
    assert_eq!(out2, expected);
}
/// An explicit `indent=none` keyword behaves like no indent at all.
#[test]
fn render_jinja_tojson_indent_none_kwarg_compact() {
    let template = "{{ messages[0].content | tojson(indent=none) }}";
    let conversation = serde_json::json!([{"role": "user", "content": [1, 2]}]);
    let no_extra = Value::Null;
    let rendered = render_jinja(
        template, &conversation, None, false, false, None, None, false, &no_extra,
    )
    .unwrap();
    assert_eq!(rendered, "[1, 2]");
}
/// An unusable `indent` value surfaces as a tokenizer error from rendering.
#[test]
fn render_jinja_tojson_bad_indent_is_render_error() {
    let template = "{{ messages[0].content | tojson(indent=[1, 2]) }}";
    let conversation = serde_json::json!([{"role": "user", "content": 1}]);
    let no_extra = Value::Null;
    let failure = render_jinja(
        template, &conversation, None, false, false, None, None, false, &no_extra,
    )
    .unwrap_err();
    assert!(matches!(failure, Error::Tokenizer(_)));
}
/// `raise_exception` aborts rendering and its message reaches the caller.
#[test]
fn render_jinja_raise_exception_is_error() {
    let template = "{{ raise_exception('boom') }}";
    let conversation = serde_json::json!([{"role": "user", "content": "x"}]);
    let no_extra = Value::Null;
    let failure = render_jinja(
        template, &conversation, None, false, false, None, None, false, &no_extra,
    )
    .unwrap_err();
    assert!(matches!(failure, Error::Tokenizer(_)));
    let message = failure.to_string();
    assert!(message.contains("boom"));
}
/// `strftime_now('%Y')` renders a four-digit year.
#[test]
fn render_jinja_strftime_now_renders_some_digits() {
    let template = "{{ strftime_now('%Y') }}";
    let conversation = serde_json::json!([{"role": "user", "content": "x"}]);
    let no_extra = Value::Null;
    let year = render_jinja(
        template, &conversation, None, false, false, None, None, false, &no_extra,
    )
    .unwrap();
    assert_eq!(year.len(), 4);
    assert!(year.chars().all(|c| c.is_ascii_digit()));
}
/// An unsupported strftime directive turns into a tokenizer error.
#[test]
fn render_jinja_strftime_now_bad_directive_is_error() {
    let template = "{{ strftime_now('%Q') }}";
    let conversation = serde_json::json!([{"role": "user", "content": "x"}]);
    let no_extra = Value::Null;
    let failure = render_jinja(
        template, &conversation, None, false, false, None, None, false, &no_extra,
    )
    .unwrap_err();
    assert!(matches!(failure, Error::Tokenizer(_)));
}
/// A template that fails to parse is reported as a tokenizer error.
#[test]
fn render_jinja_parse_error_is_typed() {
    let template = "{% for x in %}";
    let conversation = serde_json::json!([{"role": "user", "content": "x"}]);
    let no_extra = Value::Null;
    let failure = render_jinja(
        template, &conversation, None, false, false, None, None, false, &no_extra,
    )
    .unwrap_err();
    assert!(matches!(failure, Error::Tokenizer(_)));
}
/// With `continue_final_message`, output stops right after the final content,
/// dropping whatever the template emits after it.
#[test]
fn render_jinja_continue_final_message_trims_at_content() {
    let template = "{% for m in messages %}{{ m.content }}<eot>{% endfor %}";
    let conversation = serde_json::json!([
        {"role": "user", "content": "Q"},
        {"role": "assistant", "content": "partial"}
    ]);
    let no_extra = Value::Null;
    let rendered = render_jinja(
        template, &conversation, None, false, true, None, None, false, &no_extra,
    )
    .unwrap();
    assert_eq!(rendered, "Q<eot>partial");
}
/// Formatting a fixed date-time gives fixed text; zone directives vanish and
/// `%%` yields a literal percent sign.
#[test]
fn strftime_at_fixed_datetime_is_deterministic() {
    let moment = jiff::civil::date(2024, 3, 5).at(13, 7, 9, 0);
    let cases = [
        ("%Y-%m-%d", "2024-03-05"),
        ("%H:%M:%S", "13:07:09"),
        ("[%z%Z]", "[]"),
        ("100%%", "100%"),
    ];
    for (format, expected) in cases {
        assert_eq!(strftime_at(moment, format).unwrap(), expected);
    }
}
/// An unknown directive is reported through `Err`, not a panic.
#[test]
fn strftime_at_bad_directive_errors_not_panics() {
    let midnight = jiff::civil::date(2024, 1, 1).at(0, 0, 0, 0);
    let outcome = strftime_at(midnight, "%Q");
    assert!(outcome.is_err());
}
/// Covers zone removal, `%%` preservation, ordinary directives, a dangling `%`
/// and plain text.
#[test]
fn strip_naive_unsupported_branches() {
    let cases = [
        ("%z%Z%%zx", "%%zx"),
        ("%Y-%d", "%Y-%d"),
        ("end%", "end%"),
        ("plain", "plain"),
    ];
    for (input, expected) in cases {
        assert_eq!(strip_naive_unsupported(input), expected);
    }
}
}
/// Tests for `DeepseekV32.apply` and `override_by_name`.
#[cfg(all(test, feature = "tokenizer-deepseek-v32"))]
mod deepseek_v32_tests {
    use serde_json::Value;
    use super::{ChatTemplateOverride, DeepseekV32, override_by_name};
    use crate::Error;

    /// End-of-sentence marker that several tests look for at the tail of the output.
    const EOS: &str = "<|end▁of▁sentence|>";

    /// Runs `DeepseekV32.apply` over a JSON array of messages.
    ///
    /// The three booleans are forwarded to `apply` unchanged, in the order
    /// `add_generation_prompt`, `continue_final_message`, `enable_thinking`.
    ///
    /// # Panics
    ///
    /// Panics when `messages` is not a JSON array.
    // `tools` is only borrowed below, but taking it by value lets call sites
    // pass `Some(serde_json::json!(..))` without a temporary binding.
    #[allow(clippy::needless_pass_by_value)]
    fn render(
        messages: Value,
        tools: Option<Value>,
        add_generation_prompt: bool,
        continue_final_message: bool,
        enable_thinking: bool,
    ) -> Result<String, Error> {
        // `messages` is owned, so move the array out instead of cloning it.
        let conversation = match messages {
            Value::Array(items) => items,
            other => panic!("messages fixture must be a JSON array, got: {other}"),
        };
        DeepseekV32.apply(
            &conversation,
            tools.as_ref(),
            add_generation_prompt,
            continue_final_message,
            enable_thinking,
        )
    }

    /// Asserts that every fragment occurs in `rendered`, checking them in order.
    #[track_caller]
    fn assert_contains_all(rendered: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(
                rendered.contains(fragment),
                "expected {fragment:?} in rendered output {rendered:?}"
            );
        }
    }

    /// Asserts that `failure` is the `Error::Tokenizer` variant.
    #[track_caller]
    fn assert_tokenizer_error(failure: &Error) {
        assert!(matches!(failure, Error::Tokenizer(_)));
    }

    /// `override_by_name` resolves `"deepseek_v32"` and rejects an unknown name.
    #[test]
    fn override_by_name_known_and_unknown() {
        let known = override_by_name("deepseek_v32");
        assert!(known.is_some());
        let unknown = override_by_name("does_not_exist");
        assert!(unknown.is_none());
    }

    /// Thinking disabled + generation prompt: the prompt ends with `</think>`.
    #[test]
    fn apply_user_chat_mode_appends_think_end() {
        let conversation = serde_json::json!([{"role": "user", "content": "hello"}]);
        let rendered = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert_eq!(
            rendered,
            "<|begin▁of▁sentence|><|User|>hello<|Assistant|></think>"
        );
    }

    /// Thinking enabled without a generation prompt: output stops after the user text.
    #[test]
    fn apply_user_no_generation_prompt_thinking_strips_think_suffix() {
        let conversation = serde_json::json!([{"role": "user", "content": "hi"}]);
        let rendered = render(
            conversation,
            None,  // tools
            false, // add_generation_prompt
            false, // continue_final_message
            true,  // enable_thinking
        )
        .unwrap();
        assert_eq!(rendered, "<|begin▁of▁sentence|><|User|>hi");
    }

    /// With thinking enabled, the final assistant turn carries its reasoning,
    /// then `</think>`, then the answer, and the output ends with the EOS marker.
    #[test]
    fn apply_assistant_thinking_includes_reasoning_and_eos() {
        let conversation = serde_json::json!([
            {"role": "user", "content": "q"},
            {"role": "assistant", "content": "answer", "reasoning_content": "because"}
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            true,  // enable_thinking
        )
        .unwrap();
        assert!(rendered.contains("because</think>answer"));
        assert!(rendered.ends_with(EOS));
    }

    /// Object-valued tool-call arguments are rendered as DSML invoke/parameter
    /// markup, with string and non-string values flagged accordingly.
    #[test]
    fn apply_assistant_tool_calls_render_dsml() {
        let conversation = serde_json::json!([
            {"role": "user", "content": "go"},
            {"role": "assistant", "content": "", "tool_calls": [
                {"function": {"name": "search", "arguments": {"q": "cats", "n": 5}}}
            ]}
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert_contains_all(
            &rendered,
            &[
                "<|DSML|invoke name=\"search\">",
                "<|DSML|function_calls>",
                "name=\"q\" string=\"true\">cats",
                "name=\"n\" string=\"false\">5",
            ],
        );
    }

    /// Tool-call arguments supplied as a JSON-encoded string are accepted too.
    #[test]
    fn apply_assistant_tool_call_arguments_as_json_string() {
        let conversation = serde_json::json!([
            {"role": "user", "content": "go"},
            {"role": "assistant", "content": "", "tool_calls": [
                {"function": {"name": "f", "arguments": "{\"k\": \"v\"}"}}
            ]}
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert_contains_all(&rendered, &["name=\"k\" string=\"true\">v"]);
    }

    /// A string `arguments` value that is not valid JSON is a tokenizer error.
    #[test]
    fn apply_assistant_tool_call_bad_arguments_json_errors() {
        let conversation = serde_json::json!([
            {"role": "user", "content": "go"},
            {"role": "assistant", "content": "", "tool_calls": [
                {"function": {"name": "f", "arguments": "not json"}}
            ]}
        ]);
        let failure = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap_err();
        assert_tokenizer_error(&failure);
    }

    /// An `arguments` value that is a JSON array (not an object) is a tokenizer error.
    #[test]
    fn apply_assistant_tool_call_arguments_not_object_errors() {
        let conversation = serde_json::json!([
            {"role": "user", "content": "go"},
            {"role": "assistant", "content": "", "tool_calls": [
                {"function": {"name": "f", "arguments": [1, 2]}}
            ]}
        ]);
        let failure = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap_err();
        assert_tokenizer_error(&failure);
    }

    /// A system message carrying `tools` and `response_format` renders the
    /// system text, a functions section, and a response-format section.
    #[test]
    fn apply_system_with_tools_and_response_format() {
        let conversation = serde_json::json!([
            {
                "role": "system",
                "content": "SYS",
                "tools": [{"type": "function", "function": {"name": "g"}}],
                "response_format": {"type": "json"}
            },
            {"role": "user", "content": "u"}
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert_contains_all(
            &rendered,
            &[
                "SYS",
                "<functions>",
                "<|DSML|function_calls>",
                "\"name\":\"g\"",
                "## Response Format:",
            ],
        );
    }

    /// A `developer` message is rendered inside a user turn, followed by an
    /// assistant marker, and includes its response-format section.
    #[test]
    fn apply_developer_role_renders_like_user_with_tools() {
        let conversation = serde_json::json!([
            {
                "role": "developer",
                "content": "DEV",
                "tools": [{"name": "h"}],
                "response_format": {"type": "x"}
            }
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            false, // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert_contains_all(
            &rendered,
            &[
                "# The user's message is: DEV",
                "## Response Format:",
                "<|User|>",
                "<|Assistant|>",
            ],
        );
    }

    /// A `tools` argument that is not a JSON array still renders successfully
    /// and still produces a `<functions>` section.
    #[test]
    fn apply_tools_arg_non_array_yields_empty_schemas() {
        let conversation = serde_json::json!([
            {"role": "system", "content": "S"},
            {"role": "user", "content": "u"}
        ]);
        let tools = serde_json::json!({"not": "an array"});
        let rendered = render(
            conversation,
            Some(tools),
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert_contains_all(&rendered, &["<functions>"]);
    }

    /// A `tool` message is wrapped in a `<function_results>` block with a
    /// `<result>` element holding its content.
    #[test]
    fn apply_tool_role_function_results_block() {
        let conversation = serde_json::json!([
            {"role": "user", "content": "q"},
            {"role": "assistant", "content": "", "tool_calls": [
                {"function": {"name": "f", "arguments": {}}}
            ]},
            {"role": "tool", "content": "RESULT"}
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert_contains_all(
            &rendered,
            &[
                "<function_results>",
                "<result>RESULT</result>",
                "</function_results>",
            ],
        );
    }

    /// An unrecognised role is a tokenizer error whose message mentions "unknown role".
    #[test]
    fn apply_unknown_role_errors() {
        let conversation = serde_json::json!([{"role": "wizard", "content": "x"}]);
        let failure = render(
            conversation,
            None,  // tools
            false, // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap_err();
        assert_tokenizer_error(&failure);
        assert!(failure.to_string().contains("unknown role"));
    }

    /// Requesting both a generation prompt and final-message continuation is rejected.
    #[test]
    fn apply_continue_and_generation_prompt_rejected() {
        let conversation = serde_json::json!([{"role": "user", "content": "x"}]);
        let failure = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            true,  // continue_final_message
            false, // enable_thinking
        )
        .unwrap_err();
        assert_tokenizer_error(&failure);
    }

    /// Continuing the final assistant message leaves the output ending at its
    /// content, without the EOS marker.
    #[test]
    fn apply_continue_final_message_strips_eos_for_assistant() {
        let conversation = serde_json::json!([
            {"role": "user", "content": "q"},
            {"role": "assistant", "content": "draft"}
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            false, // add_generation_prompt
            true,  // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert!(rendered.ends_with("draft"));
        assert!(!rendered.ends_with(EOS));
    }

    /// With thinking enabled, reasoning attached to an assistant turn that
    /// precedes the last user message is omitted, while its content is kept.
    #[test]
    fn apply_thinking_drops_prior_assistant_reasoning() {
        let conversation = serde_json::json!([
            {"role": "assistant", "content": "earlier", "reasoning_content": "SECRET_THOUGHT"},
            {"role": "user", "content": "now"}
        ]);
        let rendered = render(
            conversation,
            None,  // tools
            true,  // add_generation_prompt
            false, // continue_final_message
            true,  // enable_thinking
        )
        .unwrap();
        assert!(!rendered.contains("SECRET_THOUGHT"));
        assert_contains_all(&rendered, &["earlier", "<|User|>now"]);
    }

    /// A conversation without any user message still renders; the lone
    /// assistant turn is emitted and terminated with the EOS marker.
    #[test]
    fn apply_no_user_message_find_last_user_index_negative() {
        let conversation = serde_json::json!([{"role": "assistant", "content": "solo"}]);
        let rendered = render(
            conversation,
            None,  // tools
            false, // add_generation_prompt
            false, // continue_final_message
            false, // enable_thinking
        )
        .unwrap();
        assert!(rendered.contains("solo"));
        assert!(rendered.ends_with(EOS));
    }
}