/// Beginning-of-sequence marker string; supplied to the chat template as `bos_token`.
pub const BOS: &str = "<|startoftext|>";
/// Turn-start marker string `<|im_start|>`.
pub const IM_START: &str = "<|im_start|>";
/// Turn-end marker string `<|im_end|>`; its id is [`EOS_TOKEN_ID`].
pub const IM_END: &str = "<|im_end|>";
/// Padding token string.
pub const PAD: &str = "<|pad|>";
/// Per-image placeholder written into prompts. [`expand_image_placeholders`] replaces each
/// occurrence with a full image block, and that block repeats this same token once per image
/// token slot.
pub const IMAGE_TOKEN: &str = "<image>";
/// Opens each expanded image block.
pub const IMAGE_START: &str = "<|image_start|>";
/// Closes each expanded image block.
pub const IMAGE_END: &str = "<|image_end|>";
/// Precedes the thumbnail's image tokens inside a multi-tile image block.
pub const IMAGE_THUMBNAIL: &str = "<|img_thumbnail|>";
/// Tool-call opening delimiter string.
pub const TOOL_CALL_START: &str = "<|tool_call_start|>";
/// Tool-call closing delimiter string.
pub const TOOL_CALL_END: &str = "<|tool_call_end|>";
// Token ids below are expected to match the `added_tokens` entries of
// `../models/tokenizer.json`; the unit tests check them against that file.
/// Id of [`BOS`].
pub const BOS_TOKEN_ID: u32 = 1;
/// Id of [`IM_START`].
pub const IM_START_TOKEN_ID: u32 = 6;
/// End-of-sequence id. This is the id of [`IM_END`], not of a separate EOS string.
pub const EOS_TOKEN_ID: u32 = 7;
/// Id of [`PAD`].
pub const PAD_TOKEN_ID: u32 = 0;
/// Id of [`TOOL_CALL_START`].
pub const TOOL_CALL_START_TOKEN_ID: u32 = 10;
/// Id of [`TOOL_CALL_END`].
pub const TOOL_CALL_END_TOKEN_ID: u32 = 11;
/// Id of [`IMAGE_TOKEN`].
pub const IMAGE_TOKEN_ID: u32 = 396;
/// Id of [`IMAGE_THUMBNAIL`].
pub const IMAGE_THUMBNAIL_TOKEN_ID: u32 = 497;
/// Id of [`IMAGE_START`].
pub const IMAGE_START_TOKEN_ID: u32 = 498;
/// Id of [`IMAGE_END`].
pub const IMAGE_END_TOKEN_ID: u32 = 499;
/// Base id for the `<|img_row_R_col_C|>` tile-marker tokens.
///
/// NOTE(review): presumably the id of `<|img_row_1_col_1|>`, with the remaining markers
/// following it. Confirm the layout against `tokenizer.json` before doing arithmetic on it.
pub const IMG_ROW_COL_BASE_ID: u32 = 397;
/// Jinja chat template embedded at compile time from `../models/chat_template.jinja`.
#[cfg(feature = "inference")]
pub const BUNDLED_CHAT_TEMPLATE_JINJA: &str = include_str!("../models/chat_template.jinja");
/// Expands every [`IMAGE_TOKEN`] placeholder in `prompt` into the full token block for its image.
///
/// The `i`-th occurrence of `<image>` in `prompt` is replaced, in order, by the block built from
/// `images[i]`. That block is `IMAGE_START`, the tile/thumbnail image tokens, then `IMAGE_END`.
/// Text between placeholders is copied through unchanged.
///
/// # Errors
///
/// Returns `Error::ImageTokenCountMismatch` when the number of placeholders in `prompt` differs
/// from `images.len()`. In that error, `expected` is `images.len()` and `got` is the
/// placeholder count.
pub fn expand_image_placeholders(
    prompt: &str,
    images: &[ImagePlaceholderInfo],
) -> crate::error::Result<String> {
    // Generous upper bound for one marker string. The longest is `<|img_row_10_col_10|>`
    // (21 bytes); the start/end/thumbnail markers are shorter.
    const MARKER_BYTES: usize = 24;

    let pieces: Vec<&str> = prompt.split(IMAGE_TOKEN).collect();
    // `split` always yields at least one piece, so this cannot underflow.
    let placeholder_count = pieces.len() - 1;
    if placeholder_count != images.len() {
        return Err(crate::error::Error::ImageTokenCountMismatch {
            expected: images.len(),
            got: placeholder_count,
        });
    }

    // Size the buffer from the real token counts instead of a flat per-image guess, so large
    // tile grids don't trigger repeated reallocation. The allowance covers one marker per tile
    // plus the start/end/thumbnail markers (`+ 3`).
    let expanded_len: usize = images
        .iter()
        .map(|img| {
            img.num_image_tokens() * IMAGE_TOKEN.len()
                + (img.rows() * img.cols() + 3) * MARKER_BYTES
        })
        .sum();
    let mut out = String::with_capacity(prompt.len() + expanded_len);

    // `pieces` alternates with the images: text, image, text, ..., image, text. Every piece
    // except the last is followed by exactly one image block.
    let (tail, leading) = pieces
        .split_last()
        .expect("str::split always yields at least one piece");
    for (piece, img) in leading.iter().zip(images) {
        out.push_str(piece);
        build_image_block(&mut out, img);
    }
    out.push_str(tail);
    Ok(out)
}
/// Tiling description of one image, used to size and lay out its expanded placeholder block.
///
/// The image is either a single tile or a `rows × cols` grid of main tiles. Each main tile
/// accounts for `tokens_per_main_tile` image tokens, and an optional thumbnail adds
/// `thumbnail_tokens` more.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ImagePlaceholderInfo {
    /// Number of tile rows.
    rows: usize,
    /// Number of tile columns.
    cols: usize,
    /// Image tokens contributed by each main tile.
    tokens_per_main_tile: usize,
    /// Image tokens contributed by the thumbnail, if there is one.
    thumbnail_tokens: Option<usize>,
}

impl ImagePlaceholderInfo {
    /// Builds a description from its four components.
    pub const fn new(
        rows: usize,
        cols: usize,
        tokens_per_main_tile: usize,
        thumbnail_tokens: Option<usize>,
    ) -> Self {
        Self { rows, cols, tokens_per_main_tile, thumbnail_tokens }
    }

    /// Number of tile rows.
    pub const fn rows(&self) -> usize {
        self.rows
    }

    /// Number of tile columns.
    pub const fn cols(&self) -> usize {
        self.cols
    }

    /// Image tokens per main tile.
    pub const fn tokens_per_main_tile(&self) -> usize {
        self.tokens_per_main_tile
    }

    /// Image tokens for the thumbnail, or `None` when there is no thumbnail.
    pub const fn thumbnail_tokens(&self) -> Option<usize> {
        self.thumbnail_tokens
    }

    /// Replaces the row count in place.
    pub fn set_rows(&mut self, rows: usize) {
        *self = self.with_rows(rows);
    }

    /// Replaces the column count in place.
    pub fn set_cols(&mut self, cols: usize) {
        *self = self.with_cols(cols);
    }

    /// Replaces the per-tile token count in place.
    pub fn set_tokens_per_main_tile(&mut self, tokens_per_main_tile: usize) {
        *self = self.with_tokens_per_main_tile(tokens_per_main_tile);
    }

    /// Replaces the thumbnail token count in place.
    pub fn set_thumbnail_tokens(&mut self, thumbnail_tokens: Option<usize>) {
        *self = self.with_thumbnail_tokens(thumbnail_tokens);
    }

    /// Returns a copy with `rows` replaced.
    pub const fn with_rows(self, rows: usize) -> Self {
        Self { rows, ..self }
    }

    /// Returns a copy with `cols` replaced.
    pub const fn with_cols(self, cols: usize) -> Self {
        Self { cols, ..self }
    }

    /// Returns a copy with `tokens_per_main_tile` replaced.
    pub const fn with_tokens_per_main_tile(self, tokens_per_main_tile: usize) -> Self {
        Self { tokens_per_main_tile, ..self }
    }

    /// Returns a copy with `thumbnail_tokens` replaced.
    pub const fn with_thumbnail_tokens(self, thumbnail_tokens: Option<usize>) -> Self {
        Self { thumbnail_tokens, ..self }
    }

    /// Total image tokens for this image: `rows * cols * tokens_per_main_tile`, plus the
    /// thumbnail tokens when a thumbnail is present.
    pub const fn num_image_tokens(&self) -> usize {
        let grid = self.rows * self.cols * self.tokens_per_main_tile;
        let thumbnail = if let Some(count) = self.thumbnail_tokens { count } else { 0 };
        grid + thumbnail
    }
}
/// Appends the expanded token block for one image to `out`.
///
/// The block starts with `IMAGE_START`. Its body depends on the tiling:
/// - Multi-tile image (`rows > 1 || cols > 1`):
///   - One `<|img_row_R_col_C|>` marker per tile, in row-major order (`R` in `1..=rows`,
///     `C` in `1..=cols`), each followed by `tokens_per_main_tile` copies of `IMAGE_TOKEN`.
///   - Then, if `thumbnail_tokens` is `Some(n)`, `IMAGE_THUMBNAIL` followed by `n` copies of
///     `IMAGE_TOKEN`.
/// - Otherwise: `num_image_tokens()` copies of `IMAGE_TOKEN` with no markers.
///
/// The block ends with `IMAGE_END`.
fn build_image_block(out: &mut String, img: &ImagePlaceholderInfo) {
    /// Appends `count` copies of the image placeholder token.
    fn push_image_tokens(out: &mut String, count: usize) {
        for _ in 0..count {
            out.push_str(IMAGE_TOKEN);
        }
    }

    out.push_str(IMAGE_START);
    if img.rows() > 1 || img.cols() > 1 {
        // Rows on the outside, columns on the inside. This keeps the row label in each marker
        // within `1..=rows` and the column label within `1..=cols`. Walking columns on the
        // outside while labelling them as rows transposes non-square grids, e.g. it emits
        // `row_3` for a 2-row image.
        for row in 1..=img.rows() {
            for col in 1..=img.cols() {
                out.push_str("<|img_row_");
                push_usize(out, row);
                out.push_str("_col_");
                push_usize(out, col);
                out.push_str("|>");
                push_image_tokens(out, img.tokens_per_main_tile());
            }
        }
        if let Some(thumb) = img.thumbnail_tokens() {
            out.push_str(IMAGE_THUMBNAIL);
            push_image_tokens(out, thumb);
        }
    } else {
        // Single tile: no markers. The count includes any thumbnail tokens.
        push_image_tokens(out, img.num_image_tokens());
    }
    out.push_str(IMAGE_END);
}
/// Appends the decimal representation of `n` to `out`, without an intermediate `String`.
fn push_usize(out: &mut String, n: usize) {
    use std::fmt::Write;
    // `fmt::Write` for `String` cannot fail, so the returned `fmt::Result` carries no information.
    let _infallible = out.write_fmt(format_args!("{}", n));
}
/// Chat-template rendering; compiled only with the `inference` feature.
#[cfg(feature = "inference")]
mod render {
    use super::*;
    use serde::Serialize;
    use std::sync::OnceLock;

    /// Returns the bundled chat template with the `{%- generation -%}` and
    /// `{%- endgeneration -%}` markers removed.
    ///
    /// The stripped copy is built on first use and cached in a process-wide `OnceLock`, so the
    /// `replace` passes run only once. NOTE(review): presumably the markers are stripped because
    /// minijinja does not implement Hugging Face's `generation` block tag. Confirm.
    fn stripped_template() -> &'static str {
        static CELL: OnceLock<String> = OnceLock::new();
        CELL.get_or_init(|| {
            BUNDLED_CHAT_TEMPLATE_JINJA
                .replace("{%- generation -%}", "")
                .replace("{%- endgeneration -%}", "")
        })
    }

    /// Renders `messages` (and optional `tools`) through the bundled chat template.
    ///
    /// The template context contains `bos_token` (set to `BOS`), `messages`, `tools` and
    /// `add_generation_prompt`. The template may also call `strftime_now(fmt)`. That function
    /// ignores `fmt` and always returns today's UTC date as `YYYY-MM-DD`.
    ///
    /// # Errors
    ///
    /// Template parsing and rendering failures are converted with
    /// `crate::error::Error::tokenizer`.
    pub fn apply_chat_template(
        messages: &[Message<'_>],
        tools: Option<&serde_json::Value>,
        add_generation_prompt: bool,
    ) -> crate::error::Result<String> {
        use minijinja::{Environment, Value};
        // A fresh environment is created, and the template re-parsed, on every call.
        let mut env = Environment::new();
        // The format argument is intentionally ignored. NOTE(review): output only matches if
        // the template asks for a `%Y-%m-%d`-style date.
        env.add_function(
            "strftime_now",
            |_fmt: String| -> std::result::Result<String, minijinja::Error> { Ok(today_yyyymmdd()) },
        );
        let tmpl = env
            .template_from_str(stripped_template())
            .map_err(crate::error::Error::tokenizer)?;
        let ctx = Value::from_serialize(&RenderContext {
            bos_token: BOS,
            messages,
            tools,
            add_generation_prompt,
        });
        tmpl.render(ctx).map_err(crate::error::Error::tokenizer)
    }

    /// Variables exposed to the Jinja template. Field names become template variable names.
    #[derive(Serialize)]
    struct RenderContext<'a> {
        bos_token: &'a str,
        messages: &'a [Message<'a>],
        tools: Option<&'a serde_json::Value>,
        add_generation_prompt: bool,
    }

    /// Test-only access to `today_yyyymmdd`, so fixtures can substitute the expected date.
    #[cfg(test)]
    pub(super) fn today_yyyymmdd_for_test() -> String {
        today_yyyymmdd()
    }

    /// Returns the current UTC calendar date formatted as `YYYY-MM-DD`.
    ///
    /// No date library is used. The whole days since the Unix epoch are converted to a
    /// proleptic-Gregorian civil date with Howard Hinnant's `civil_from_days` algorithm.
    /// A system clock set before 1970 falls back to day 0 (`1970-01-01`).
    fn today_yyyymmdd() -> String {
        use std::time::{SystemTime, UNIX_EPOCH};
        let secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let days = (secs / 86400) as i64;
        // Re-base the day count on 0000-03-01 so that leap days fall at the end of each year.
        let z = days + 719_468;
        // 400-year era index. This parses as `(if .. else ..) / 146_097`, i.e. floor division
        // for negative z.
        let era = if z >= 0 { z } else { z - 146_096 } / 146_097;
        // Day of era, in [0, 146096].
        let doe = (z - era * 146_097) as u64;
        // Year of era, in [0, 399].
        let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146_096) / 365;
        let y_base = yoe as i64 + era * 400;
        // Day of year counted from March 1, in [0, 365].
        let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
        // Month index counted from March, in [0, 11].
        let mp = (5 * doy + 2) / 153;
        // Day of month, in [1, 31].
        let d = (doy - (153 * mp + 2) / 5 + 1) as u32;
        // Convert the March-based month index to a calendar month in [1, 12].
        let m = (if mp < 10 { mp + 3 } else { mp - 9 }) as u32;
        // January and February belong to the following civil year in the March-based scheme.
        let y = if m <= 2 { y_base + 1 } else { y_base };
        format!("{y:04}-{m:02}-{d:02}")
    }
}
// Re-exported so callers can reach `apply_chat_template` without naming the private `render`
// module.
#[cfg(feature = "inference")]
#[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
pub use render::apply_chat_template;
/// One chat message passed to `apply_chat_template`.
///
/// Serializes with an internal, lowercase `role` tag, e.g.
/// `{"role": "system", "content": "..."}`.
#[cfg(feature = "inference")]
#[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message<'a> {
    /// A message with `role: "system"`.
    System {
        /// Message text.
        content: &'a str,
    },
    /// A message with `role: "user"`, carrying plain text or a list of content items.
    User {
        /// Message body.
        content: UserContent<'a>,
    },
    /// A message with `role: "assistant"`.
    Assistant {
        /// Message text.
        content: &'a str,
        /// Optional reasoning text. When `None`, the `thinking` key is omitted from the
        /// serialized message. NOTE(review): how the template consumes it lives in
        /// `chat_template.jinja`, which is not visible here.
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<&'a str>,
    },
}
/// Body of a user message.
///
/// Untagged, so `Text` serializes as a bare JSON string and `Multimodal` as a JSON array of
/// [`ContentItem`]s.
#[cfg(feature = "inference")]
#[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
#[derive(Debug, Clone, serde::Serialize)]
#[serde(untagged)]
pub enum UserContent<'a> {
    /// Plain text content.
    Text(&'a str),
    /// Ordered mix of image and text items.
    Multimodal(Vec<ContentItem<'a>>),
}
/// One item of multimodal user content.
///
/// Serializes with an internal, lowercase `type` tag: `{"type": "image"}` or
/// `{"type": "text", "text": "..."}`.
#[cfg(feature = "inference")]
#[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ContentItem<'a> {
    /// An image slot. It carries no payload here; image data is handled elsewhere.
    Image,
    /// A text segment.
    Text {
        /// Segment text.
        text: &'a str,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every special-token string/id pair declared in this module must match the
    /// `added_tokens` table of the bundled `tokenizer.json`.
    #[test]
    fn special_token_ids_match_tokenizer_json() {
        use serde_json::Value;
        let tok_raw = include_str!("../models/tokenizer.json");
        let tok: Value = serde_json::from_str(tok_raw).expect("tokenizer.json must be valid JSON");
        let added: &Vec<Value> = tok["added_tokens"].as_array().expect("added_tokens array");
        // content -> id. Entries missing either field are skipped; later duplicates win.
        let found: std::collections::HashMap<&str, u64> = added
            .iter()
            .filter_map(|entry| Some((entry["content"].as_str()?, entry["id"].as_u64()?)))
            .collect();
        // (token string, declared id, constant name used in the failure message).
        let expected: [(&str, u32, &str); 11] = [
            (BOS, BOS_TOKEN_ID, "BOS_TOKEN_ID"),
            (IM_START, IM_START_TOKEN_ID, "IM_START_TOKEN_ID"),
            (IM_END, EOS_TOKEN_ID, "EOS_TOKEN_ID"),
            (PAD, PAD_TOKEN_ID, "PAD_TOKEN_ID"),
            (TOOL_CALL_START, TOOL_CALL_START_TOKEN_ID, "TOOL_CALL_START_TOKEN_ID"),
            (TOOL_CALL_END, TOOL_CALL_END_TOKEN_ID, "TOOL_CALL_END_TOKEN_ID"),
            (IMAGE_TOKEN, IMAGE_TOKEN_ID, "IMAGE_TOKEN_ID"),
            (IMAGE_THUMBNAIL, IMAGE_THUMBNAIL_TOKEN_ID, "IMAGE_THUMBNAIL_TOKEN_ID"),
            (IMAGE_START, IMAGE_START_TOKEN_ID, "IMAGE_START_TOKEN_ID"),
            (IMAGE_END, IMAGE_END_TOKEN_ID, "IMAGE_END_TOKEN_ID"),
            // Assumes the base id is the id of the first tile marker.
            ("<|img_row_1_col_1|>", IMG_ROW_COL_BASE_ID, "IMG_ROW_COL_BASE_ID"),
        ];
        for &(content, id, name) in expected.iter() {
            assert_eq!(
                found.get(content).copied(),
                Some(u64::from(id)),
                "{name} should be the id of {content}"
            );
        }
    }

    #[test]
    fn expand_count_mismatch() {
        let r = expand_image_placeholders("Hello <image>", &[]);
        assert!(matches!(
            r,
            Err(crate::error::Error::ImageTokenCountMismatch {
                expected: 0,
                got: 1
            })
        ));
    }

    #[test]
    fn expand_single_tile() {
        let info = ImagePlaceholderInfo::new(1, 1, 64, None);
        let out = expand_image_placeholders("X<image>Y", &[info]).unwrap();
        assert!(out.starts_with("X<|image_start|>"));
        assert!(out.ends_with("<|image_end|>Y"));
        assert_eq!(out.matches("<image>").count(), 64);
    }

    #[test]
    fn expand_multi_tile_with_thumbnail() {
        let info = ImagePlaceholderInfo::new(2, 2, 256, Some(64));
        let out = expand_image_placeholders("<image>", &[info]).unwrap();
        assert_eq!(out.matches("<image>").count(), 4 * 256 + 64);
        assert!(out.contains("<|img_row_1_col_1|>"));
        assert!(out.contains("<|img_row_1_col_2|>"));
        assert!(out.contains("<|img_row_2_col_1|>"));
        assert!(out.contains("<|img_row_2_col_2|>"));
        assert!(out.contains("<|img_thumbnail|>"));
    }

    #[test]
    fn expand_multi_image_preserves_order() {
        let a = ImagePlaceholderInfo::new(1, 1, 1, None);
        let b = ImagePlaceholderInfo::new(1, 1, 2, None);
        let out = expand_image_placeholders("A<image>B<image>C", &[a, b]).unwrap();
        assert_eq!(
            out,
            "A<|image_start|><image><|image_end|>B<|image_start|><image><image><|image_end|>C"
        );
    }

    #[test]
    fn expand_image_placeholders_matches_fixtures() {
        use serde_json::Value;
        let raw = include_str!("../tests/fixtures/image_expansion_cases.json");
        let cases: Value = serde_json::from_str(raw).expect("fixture must be valid JSON");
        let cases_arr = cases.as_array().expect("fixture must be an array");
        let mut failures: Vec<String> = Vec::new();
        for case in cases_arr {
            let name = case["name"].as_str().expect("name");
            let prompt = case["prompt"].as_str().expect("prompt");
            let expected = case["expected"].as_str().expect("expected");
            let info_v = &case["info"];
            let info = ImagePlaceholderInfo::new(
                info_v["rows"].as_u64().unwrap() as usize,
                info_v["cols"].as_u64().unwrap() as usize,
                info_v["tokens_per_main_tile"].as_u64().unwrap() as usize,
                info_v["thumbnail_tokens"].as_u64().map(|n| n as usize),
            );
            match expand_image_placeholders(prompt, &[info]) {
                Ok(actual) if actual == expected => {}
                Ok(actual) => failures.push(format!(
                    "case {name}:\n expected={expected}\n actual ={actual}"
                )),
                Err(e) => failures.push(format!("case {name}: error: {e}")),
            }
        }
        assert!(
            failures.is_empty(),
            "{} of {} expansion cases failed:\n{}",
            failures.len(),
            cases_arr.len(),
            failures.join("\n")
        );
    }

    #[cfg(feature = "inference")]
    #[test]
    fn apply_chat_template_matches_upstream_fixtures() {
        use serde_json::Value;
        let raw = include_str!("../tests/fixtures/chat_template_cases.json");
        let cases: Value = serde_json::from_str(raw).expect("fixture must be valid JSON");
        let cases_arr = cases.as_array().expect("fixture must be a JSON array");
        // Loop-invariant: look up the date once rather than once per case.
        let today = super::render::today_yyyymmdd_for_test();
        let mut failures: Vec<String> = Vec::new();
        for case in cases_arr {
            let name = case["name"].as_str().expect("each case has a name");
            let expected_raw = case["expected"]
                .as_str()
                .expect("each case has expected output");
            let messages = case["messages"].as_array().expect("messages is an array");
            let add_gen = case
                .get("add_generation_prompt")
                .and_then(|v| v.as_bool())
                .unwrap_or(true);
            let tools: Option<Value> = case.get("tools").cloned();
            let owned_msgs: Vec<OwnedMsg> = messages.iter().map(OwnedMsg::from_value).collect();
            let msg_refs: Vec<Message<'_>> = owned_msgs.iter().map(OwnedMsg::as_ref).collect();
            let tools_ref = tools.as_ref();
            // Fixtures use `__DATE__` where the template inserts `strftime_now(...)`.
            let expected = expected_raw.replace("__DATE__", &today);
            match apply_chat_template(&msg_refs, tools_ref, add_gen) {
                Ok(actual) if actual == expected => {}
                Ok(actual) => failures.push(format!(
                    "case {name}: actual differs from expected\n--- actual ---\n{actual}\n--- expected ---\n{expected}",
                )),
                Err(e) => failures.push(format!("case {name}: render failed: {e}")),
            }
        }
        assert!(
            failures.is_empty(),
            "{} of {} cases failed:\n{}",
            failures.len(),
            cases_arr.len(),
            failures.join("\n\n")
        );
    }

    /// Owned mirror of [`Message`], so fixture JSON can be borrowed as `Message<'_>`.
    #[cfg(feature = "inference")]
    enum OwnedMsg {
        System(String),
        UserText(String),
        UserMulti(Vec<OwnedItem>),
        Assistant {
            content: String,
            thinking: Option<String>,
        },
    }

    /// Owned mirror of [`ContentItem`].
    // NOTE(review): both variants appear to be constructed and read below; this `allow` may be
    // removable.
    #[cfg(feature = "inference")]
    #[allow(dead_code)]
    enum OwnedItem {
        Image,
        Text(String),
    }

    #[cfg(feature = "inference")]
    impl OwnedMsg {
        /// Leniently parses one fixture message. Unknown roles or shapes become empty user text.
        fn from_value(v: &serde_json::Value) -> Self {
            let role = v["role"].as_str().unwrap_or("");
            match role {
                "system" => Self::System(v["content"].as_str().unwrap_or("").to_string()),
                "user" => match &v["content"] {
                    serde_json::Value::String(s) => Self::UserText(s.clone()),
                    serde_json::Value::Array(items) => Self::UserMulti(
                        items
                            .iter()
                            .map(|i| match i["type"].as_str() {
                                Some("image") => OwnedItem::Image,
                                Some("text") => {
                                    OwnedItem::Text(i["text"].as_str().unwrap_or("").to_string())
                                }
                                _ => OwnedItem::Text(String::new()),
                            })
                            .collect(),
                    ),
                    _ => Self::UserText(String::new()),
                },
                "assistant" => Self::Assistant {
                    content: v["content"].as_str().unwrap_or("").to_string(),
                    thinking: v.get("thinking").and_then(|t| t.as_str()).map(String::from),
                },
                _ => Self::UserText(String::new()),
            }
        }

        /// Borrows this owned message as the public [`Message`] type.
        fn as_ref(&self) -> Message<'_> {
            match self {
                Self::System(c) => Message::System { content: c },
                Self::UserText(c) => Message::User {
                    content: UserContent::Text(c),
                },
                Self::UserMulti(items) => Message::User {
                    content: UserContent::Multimodal(
                        items
                            .iter()
                            .map(|i| match i {
                                OwnedItem::Image => ContentItem::Image,
                                OwnedItem::Text(t) => ContentItem::Text { text: t },
                            })
                            .collect(),
                    ),
                },
                Self::Assistant { content, thinking } => Message::Assistant {
                    content,
                    thinking: thinking.as_deref(),
                },
            }
        }
    }
}