use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, CapExceededPayload, EmptyInputPayload, Error,
InvariantViolationPayload, LengthMismatchPayload, MissingFieldPayload, OutOfRangePayload,
Result, try_to_vec, try_with_capacity,
},
};
/// Upper bound on any single item count (images, audios, videos) accepted by the
/// message formatters; `check_format_count` rejects larger counts with
/// `Error::CapExceeded`.
pub const MAX_MESSAGE_FORMAT_ITEMS: usize = 1024;
fn check_format_count(count: usize, label: &'static str, _model_name: &str) -> Result<()> {
if count > MAX_MESSAGE_FORMAT_ITEMS {
return Err(Error::CapExceeded(CapExceededPayload::new(
label,
"MAX_MESSAGE_FORMAT_ITEMS",
MAX_MESSAGE_FORMAT_ITEMS as u64,
count as u64,
)));
}
Ok(())
}
/// List of half-open `(start, end)` index ranges, one per run of image tokens
/// inside a token sequence (`end` is one past the last token of the run).
pub type ImageTokenSpans = Vec<(usize, usize)>;
/// What `insert_image_tokens` does when images are requested but `text_tokens`
/// contains no image marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum MarkerPolicy {
/// A missing marker is reported as `Error::MissingField`.
Required,
/// A missing marker is tolerated: the image placeholders are prepended to the text.
PrependIfAbsent,
}
impl MarkerPolicy {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Required => "required",
Self::PrependIfAbsent => "prepend_if_absent",
}
}
}
/// Finds every maximal run of `image_token_id` in `tokens`.
///
/// Returns the runs in order of appearance as half-open `(start, end)` ranges.
/// An empty input, or one without the token, yields an empty list.
pub fn locate_image_tokens(tokens: &[u32], image_token_id: u32) -> ImageTokenSpans {
    let mut spans = ImageTokenSpans::new();
    // Start index of the run currently being scanned, if any.
    let mut open_run: Option<usize> = None;
    for (idx, &tok) in tokens.iter().enumerate() {
        match (tok == image_token_id, open_run) {
            // First image token after non-image tokens: a run begins here.
            (true, None) => open_run = Some(idx),
            // First non-image token after a run: close it at `idx` (exclusive).
            (false, Some(start)) => {
                spans.push((start, idx));
                open_run = None;
            }
            _ => {}
        }
    }
    // A run that reaches the end of the input is closed at `tokens.len()`.
    if let Some(start) = open_run {
        spans.push((start, tokens.len()));
    }
    spans
}
pub fn insert_image_tokens(
text_tokens: &[u32],
image_count: usize,
image_marker_id: u32,
image_token_id: u32,
num_tokens_per_image: usize,
policy: MarkerPolicy,
) -> Result<Vec<u32>> {
if image_count == 0 {
return try_to_vec(text_tokens);
}
if num_tokens_per_image == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"insert_image_tokens: num_tokens_per_image (with image_count > 0)",
"must be > 0 — otherwise images would silently drop, config/model state is degenerate",
)));
}
let placeholder_total = image_count
.checked_mul(num_tokens_per_image)
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"insert_image_tokens: placeholder_total (image_count * num_tokens_per_image)",
"usize",
[
("image_count", image_count as u64),
("num_tokens_per_image", num_tokens_per_image as u64),
],
))
})?;
if let Some(run_start) = text_tokens.iter().position(|&t| t == image_marker_id) {
let run_end = text_tokens[run_start..]
.iter()
.position(|&t| t != image_marker_id)
.map_or(text_tokens.len(), |off| run_start + off);
let run_len = run_end - run_start;
if text_tokens[run_end..].contains(&image_marker_id) {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"insert_image_tokens: image_marker_id occurrences (after the first contiguous run)",
"must be 0 — the splice supports at most one contiguous marker run \
(mirrors python prompt_utils' `prompt.split(\"<image>\")` 2-chunk contract)",
)));
}
if run_len != image_count {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"insert_image_tokens: contiguous marker run length vs image_count (the chat-template \
producer should emit exactly `marker * image_count` adjacent markers; mismatch \
suggests caller/template skew)",
image_count,
run_len,
)));
}
let cap = text_tokens
.len()
.checked_add(placeholder_total)
.and_then(|n| n.checked_sub(run_len))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"insert_image_tokens: cap (text_len + placeholder_total - run_len)",
"usize",
[
("text_len", text_tokens.len() as u64),
("placeholder_total", placeholder_total as u64),
("run_len", run_len as u64),
],
))
})?;
let mut out: Vec<u32> = try_with_capacity(cap)?;
out.extend_from_slice(&text_tokens[..run_start]);
out.extend(std::iter::repeat_n(image_token_id, placeholder_total));
out.extend_from_slice(&text_tokens[run_end..]);
Ok(out)
} else {
if policy == MarkerPolicy::Required {
return Err(Error::MissingField(MissingFieldPayload::new(
"insert_image_tokens (MarkerPolicy::Required, image_count > 0; chat-template / tokenizer \
drift detected — pass MarkerPolicy::PrependIfAbsent if the model uses the \
PROMPT_WITH_IMAGE_TOKEN-family formatter)",
"image_marker_id token in text_tokens",
)));
}
let cap = text_tokens
.len()
.checked_add(placeholder_total)
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"insert_image_tokens: cap (text_len + placeholder_total)",
"usize",
[
("text_len", text_tokens.len() as u64),
("placeholder_total", placeholder_total as u64),
],
))
})?;
let mut out: Vec<u32> = try_with_capacity(cap)?;
out.extend(std::iter::repeat_n(image_token_id, placeholder_total));
out.extend_from_slice(text_tokens);
Ok(out)
}
}
/// Builds the attention mask for a sequence with no cached (past) keys.
///
/// Delegates to `build_multimodal_mask_with_past` with `past_len = 0`; errors are
/// whatever that function returns.
pub fn build_multimodal_mask(seq_len: usize, image_spans: &[(usize, usize)]) -> Result<Array> {
build_multimodal_mask_with_past(seq_len, 0, image_spans)
}
pub fn build_multimodal_mask_with_past(
seq_len: usize,
past_len: usize,
image_spans: &[(usize, usize)],
) -> Result<Array> {
let total_keys = past_len.checked_add(seq_len).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"build_multimodal_mask_with_past: total_keys (past_len + seq_len)",
"usize",
[("past_len", past_len as u64), ("seq_len", seq_len as u64)],
))
})?;
if seq_len == 0 {
if !image_spans.is_empty() {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"build_multimodal_mask_with_past: image_spans (with seq_len=0)",
"must be empty — an empty chunk cannot contain any image span",
)));
}
return Array::from_slice::<bool>(&[], &(1_usize, 1_usize, 0_usize, total_keys));
}
if total_keys > i32::MAX as usize {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"build_multimodal_mask_with_past: total_keys (past_len + seq_len)",
"must be <= i32::MAX (mlx dimension limit)",
format!("{total_keys}"),
)));
}
let mut sorted: Vec<(usize, usize)> = try_to_vec(image_spans)?;
sorted.sort_unstable_by_key(|&(s, _)| s);
let mut prev_end = 0usize;
for &(s, e) in &sorted {
if s >= e {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"build_multimodal_mask_with_past: image span (start, end)",
"start must be strictly less than end (empty spans not allowed)",
)));
}
if e > seq_len {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"build_multimodal_mask_with_past: chunk-local image span end vs seq_len",
"must be <= seq_len",
format!("end={e}, seq_len={seq_len}"),
)));
}
if s < prev_end {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"build_multimodal_mask_with_past: image span order (s vs prev_end)",
"spans must be monotone non-overlapping",
)));
}
prev_end = e;
}
let total = seq_len.checked_mul(total_keys).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"build_multimodal_mask_with_past: total (seq_len * total_keys)",
"usize",
[
("seq_len", seq_len as u64),
("total_keys", total_keys as u64),
],
))
})?;
let mut block_id: Vec<u32> = try_with_capacity(seq_len)?;
block_id.resize(seq_len, 0);
for (idx, &(s, e)) in sorted.iter().enumerate() {
let block = (idx + 1) as u32;
for slot in block_id.iter_mut().take(e).skip(s) {
*slot = block;
}
}
let mut buf: Vec<bool> = Vec::new();
buf
.try_reserve_exact(total)
.map_err(|_| Error::OutOfMemory)?;
for (q, &q_blk) in block_id.iter().enumerate() {
for k in 0..total_keys {
let attend = if k < past_len {
true
} else {
let k_local = k - past_len;
let causal = k_local <= q;
let same_image_span = q_blk != 0 && q_blk == block_id[k_local];
causal || same_image_span
};
buf.push(attend);
}
}
Array::from_slice::<bool>(&buf, &(1_usize, 1_usize, seq_len, total_keys))
}
/// Result of `assemble_multimodal_prompt`: the spliced token sequence together with
/// the image spans and the attention mask built for it.
#[derive(Debug)]
pub struct MultimodalPrompt {
/// Token ids after image placeholders have been inserted.
pub tokens: Vec<u32>,
/// Half-open `(start, end)` range of each image's placeholder tokens in `tokens`.
pub image_spans: ImageTokenSpans,
/// Mask produced by `build_multimodal_mask(tokens.len(), &image_spans)`.
pub attention_mask: Array,
}
pub fn assemble_multimodal_prompt(
text_tokens: &[u32],
image_count: usize,
image_marker_id: u32,
image_token_id: u32,
num_tokens_per_image: usize,
policy: MarkerPolicy,
) -> Result<MultimodalPrompt> {
if image_count > 0 && num_tokens_per_image == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"assemble_multimodal_prompt: num_tokens_per_image (with image_count > 0)",
"must be > 0 — otherwise images would silently drop, config/model state is degenerate",
)));
}
let (base, marker_run_len) = if image_count == 0 {
(0_usize, 0_usize)
} else if let Some(run_start) = text_tokens.iter().position(|&t| t == image_marker_id) {
let run_end = text_tokens[run_start..]
.iter()
.position(|&t| t != image_marker_id)
.map_or(text_tokens.len(), |off| run_start + off);
(run_start, run_end - run_start)
} else {
(0_usize, 0_usize)
};
let placeholder_total = image_count
.checked_mul(num_tokens_per_image)
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"assemble_multimodal_prompt: placeholder_total (image_count * num_tokens_per_image)",
"usize",
[
("image_count", image_count as u64),
("num_tokens_per_image", num_tokens_per_image as u64),
],
))
})?;
let final_len = if placeholder_total == 0 {
text_tokens.len()
} else {
text_tokens
.len()
.checked_add(placeholder_total)
.and_then(|n| n.checked_sub(marker_run_len))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"assemble_multimodal_prompt: final_len (text_len + placeholder_total - marker_run_len)",
"usize",
[
("text_len", text_tokens.len() as u64),
("placeholder_total", placeholder_total as u64),
("marker_run_len", marker_run_len as u64),
],
))
})?
};
if final_len > i32::MAX as usize {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"assemble_multimodal_prompt: final assembled length",
"must be <= i32::MAX (mlx dimension limit; reject before allocating splice buffer)",
format!("{final_len}"),
)));
}
let tokens = insert_image_tokens(
text_tokens,
image_count,
image_marker_id,
image_token_id,
num_tokens_per_image,
policy,
)?;
let mut image_spans = ImageTokenSpans::new();
if image_count > 0 && num_tokens_per_image > 0 {
image_spans = try_with_capacity(image_count)?;
for i in 0..image_count {
let i_times_n = i.checked_mul(num_tokens_per_image).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"assemble_multimodal_prompt: i * num_tokens_per_image",
"usize",
[
("i", i as u64),
("num_tokens_per_image", num_tokens_per_image as u64),
],
))
})?;
let start = base.checked_add(i_times_n).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"assemble_multimodal_prompt: start (base + i * num_tokens_per_image)",
"usize",
[("base", base as u64), ("i_times_n", i_times_n as u64)],
))
})?;
let end = start.checked_add(num_tokens_per_image).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"assemble_multimodal_prompt: end (start + num_tokens_per_image)",
"usize",
[
("start", start as u64),
("num_tokens_per_image", num_tokens_per_image as u64),
],
))
})?;
image_spans.push((start, end));
}
}
let attention_mask = build_multimodal_mask(tokens.len(), &image_spans)?;
Ok(MultimodalPrompt {
tokens,
image_spans,
attention_mask,
})
}
/// Message layout used for a model's chat prompt.
///
/// `MessageFormatter::format_message` dispatches on this value; `Display` prints
/// the snake_case name returned by `as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
#[non_exhaustive]
pub enum MessageFormat {
ListWithImage,
ListWithImageFirst,
ListWithImageUrlFirst,
ListWithImageType,
ListWithImageTypeText,
ListWithImageTypeTextImageLast,
ImageToken,
ImageTokenPipe,
StartImageToken,
ImageTokenNewline,
NumberedImageTokens,
PromptOnly,
PromptWithImageToken,
PromptWithStartImageToken,
VideoWithText,
}
impl MessageFormat {
pub const fn as_str(&self) -> &'static str {
match self {
Self::ListWithImage => "list_with_image",
Self::ListWithImageFirst => "list_with_image_first",
Self::ListWithImageUrlFirst => "list_with_image_url_first",
Self::ListWithImageType => "list_with_image_type",
Self::ListWithImageTypeText => "list_with_image_type_text",
Self::ListWithImageTypeTextImageLast => "list_with_image_type_text_image_last",
Self::ImageToken => "image_token",
Self::ImageTokenPipe => "image_token_pipe",
Self::StartImageToken => "start_image_token",
Self::ImageTokenNewline => "image_token_newline",
Self::NumberedImageTokens => "numbered_image_tokens",
Self::PromptOnly => "prompt_only",
Self::PromptWithImageToken => "prompt_with_image_token",
Self::PromptWithStartImageToken => "prompt_with_start_image_token",
Self::VideoWithText => "video_with_text",
}
}
}
/// Every `MessageFormat` variant, listed in declaration order.
pub const MESSAGE_FORMAT_VARIANTS: &[MessageFormat] = &[
MessageFormat::ListWithImage,
MessageFormat::ListWithImageFirst,
MessageFormat::ListWithImageUrlFirst,
MessageFormat::ListWithImageType,
MessageFormat::ListWithImageTypeText,
MessageFormat::ListWithImageTypeTextImageLast,
MessageFormat::ImageToken,
MessageFormat::ImageTokenPipe,
MessageFormat::StartImageToken,
MessageFormat::ImageTokenNewline,
MessageFormat::NumberedImageTokens,
MessageFormat::PromptOnly,
MessageFormat::PromptWithImageToken,
MessageFormat::PromptWithStartImageToken,
MessageFormat::VideoWithText,
];
/// Lookup table from lowercase model type to its `MessageFormat`.
///
/// Entries MUST stay sorted by key in byte order (`-` < digits < `_` < letters):
/// `MessageFormatter::for_model` searches this table with `binary_search_by`, so
/// an out-of-order entry can become unreachable.
pub const MODEL_CONFIG: &[(&str, MessageFormat)] = &[
("aya_vision", MessageFormat::ListWithImage),
("bunny-llama", MessageFormat::ImageTokenNewline),
("cohere2_vision", MessageFormat::ListWithImage),
("deepseek_vl_v2", MessageFormat::ImageTokenNewline),
("deepseekocr", MessageFormat::ImageTokenNewline),
("deepseekocr_2", MessageFormat::ImageTokenNewline),
("dots_ocr", MessageFormat::ListWithImageFirst),
("ernie4_5_moe_vl", MessageFormat::ListWithImageUrlFirst),
("falcon_ocr", MessageFormat::PromptOnly),
("florence2", MessageFormat::PromptOnly),
("gemma3", MessageFormat::StartImageToken),
("gemma3n", MessageFormat::ListWithImageTypeText),
("gemma4", MessageFormat::ListWithImageTypeText),
("glm4v", MessageFormat::ListWithImageFirst),
("glm4v_moe", MessageFormat::ListWithImageFirst),
("glm_ocr", MessageFormat::ListWithImageFirst),
("granite4_vision", MessageFormat::ListWithImage),
("granite_vision", MessageFormat::ListWithImage),
("hunyuan_vl", MessageFormat::ListWithImageFirst),
("idefics2", MessageFormat::ListWithImage),
("idefics3", MessageFormat::ListWithImageFirst),
("internvl_chat", MessageFormat::ListWithImageType),
("jina_vlm", MessageFormat::ImageTokenPipe),
("jvlm", MessageFormat::ImageTokenPipe),
("kimi_k25", MessageFormat::ListWithImage),
("kimi_vl", MessageFormat::ListWithImage),
("lfm2-vl", MessageFormat::ListWithImageFirst),
("lfm2_vl", MessageFormat::ListWithImageFirst),
("llama4", MessageFormat::ListWithImage),
("llava", MessageFormat::ListWithImage),
("llava-qwen2", MessageFormat::ImageTokenNewline),
("llava_next", MessageFormat::ListWithImage),
("llava_qwen2", MessageFormat::ImageTokenNewline),
("minicpmo", MessageFormat::ImageToken),
("mistral3", MessageFormat::ListWithImageFirst),
("mllama", MessageFormat::ListWithImage),
("molmo", MessageFormat::PromptOnly),
("molmo2", MessageFormat::ListWithImageFirst),
("molmo_point", MessageFormat::ListWithImageFirst),
("moondream3", MessageFormat::PromptOnly),
("multi_modality", MessageFormat::ImageToken),
("nemotron_h_nano_omni", MessageFormat::ListWithImageType),
(
"nemotronh_nano_omni_reasoning_v3",
MessageFormat::ListWithImageType,
),
("paddleocr_vl", MessageFormat::ListWithImageFirst),
("paligemma", MessageFormat::PromptWithImageToken),
("phi3_v", MessageFormat::NumberedImageTokens),
("phi4-siglip", MessageFormat::ImageTokenNewline),
("phi4mm", MessageFormat::NumberedImageTokens),
("pixtral", MessageFormat::ListWithImageTypeText),
("qwen2_5_vl", MessageFormat::ListWithImageFirst),
("qwen2_vl", MessageFormat::ListWithImage),
("qwen3_5", MessageFormat::ListWithImageFirst),
("qwen3_5_moe", MessageFormat::ListWithImageFirst),
("qwen3_omni_moe", MessageFormat::ListWithImageFirst),
("qwen3_vl", MessageFormat::ListWithImageFirst),
("qwen3_vl_moe", MessageFormat::ListWithImageFirst),
("smolvlm", MessageFormat::ListWithImageFirst),
("youtu_vl", MessageFormat::ListWithImageFirst),
];
/// Model types for which `MessageFormatter::format_message` rejects
/// `opts.num_images > 1`.
///
/// Must stay sorted in byte order: membership is tested with `binary_search`.
pub const SINGLE_IMAGE_ONLY_MODELS: &[&str] = &[
"bunny-llama",
"falcon_ocr",
"llava-qwen2",
"llava_next",
"mllama",
"multi_modality",
"paligemma",
];
// Model types that `MessageFormatter::format_message` routes to the video
// formatter whenever `opts.video` is non-empty.
// Must stay sorted in byte order: membership is tested with `binary_search`.
const VIDEO_FORMAT_MODELS: &[&str] = &[
"gemma4",
"qwen2_5_vl",
"qwen2_vl",
"qwen3_5",
"qwen3_5_moe",
"qwen3_omni_moe",
"qwen3_vl",
"qwen3_vl_moe",
];
/// One entry of a list-style message body, as produced by `MessageBuilder`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ContentItem {
/// Text entry built by `MessageBuilder::text_message`.
Text {
text: String,
},
/// Text entry built by `MessageBuilder::content_message`; kept distinct from `Text`.
ContentText {
text: String,
},
/// Image placeholder (carries no data).
Image,
/// Image-URL placeholder (carries no data).
ImageUrl,
/// Audio placeholder (carries no data).
Audio,
/// Video entry carrying the video string plus its `max_pixels` and `fps` settings.
Video {
video: String,
max_pixels: u32,
fps: u32,
},
}
/// Stateless factory for [`ContentItem`] values.
#[derive(Debug, Clone, Copy)]
pub struct MessageBuilder;

impl MessageBuilder {
    /// Builds a `ContentItem::Text` holding `text`.
    pub fn text_message(text: impl Into<String>) -> ContentItem {
        let text: String = text.into();
        ContentItem::Text { text }
    }

    /// Builds a `ContentItem::ContentText` holding `content`.
    pub fn content_message(content: impl Into<String>) -> ContentItem {
        let text: String = content.into();
        ContentItem::ContentText { text }
    }

    /// Builds an image placeholder item.
    pub fn image_message() -> ContentItem {
        ContentItem::Image
    }

    /// Builds an image-URL placeholder item.
    pub fn image_url_message() -> ContentItem {
        ContentItem::ImageUrl
    }

    /// Builds an audio placeholder item.
    pub fn audio_message() -> ContentItem {
        ContentItem::Audio
    }

    /// Builds a `ContentItem::Video` from a path plus its pixel budget and frame rate.
    pub fn video_message(video_path: impl Into<String>, max_pixels: u32, fps: u32) -> ContentItem {
        let video: String = video_path.into();
        ContentItem::Video {
            video,
            max_pixels,
            fps,
        }
    }
}
/// A single chat message: a role plus its content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
/// Role string, copied from `FormatOpts::role` by the formatters.
pub role: String,
/// Either a list of content items or a plain string.
pub content: MessageContent,
}
/// Body of a `Message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageContent {
/// Structured body: an ordered list of content items.
Items(Vec<ContentItem>),
/// Flat string body (used by the inline-token formats).
Text(String),
}
/// Output of `MessageFormatter::format_message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormattedMessage {
/// A role-tagged message.
Message(Message),
/// A bare prompt string with no role (produced by the prompt-only / prompt-with-token formats).
String(String),
}
/// Options controlling how `MessageFormatter` lays out a message.
#[derive(Debug, Clone)]
pub struct FormatOpts {
/// Message role; most formatters only emit image/audio placeholders when this is `"user"`.
pub role: String,
/// When true, formatters that honor it emit no image placeholders.
pub skip_image_token: bool,
/// When true, formatters that honor it emit no audio placeholders.
pub skip_audio_token: bool,
/// Number of image placeholders to emit.
pub num_images: usize,
/// Number of audio placeholders to emit.
pub num_audios: usize,
/// Video entries; one `ContentItem::Video` is emitted per element by the video formatter.
pub video: Vec<String>,
/// `max_pixels` value copied into every video item.
pub max_pixels: u32,
/// Per-video fps: empty (defaults to 1), a single value applied to all videos,
/// or exactly one value per video.
pub fps: Vec<u32>,
}
impl FormatOpts {
    /// Defaults used when driving a `MessageFormatter` directly: a `"user"` message
    /// with one image and one audio, no videos, and `max_pixels = 224 * 224`.
    pub fn formatter_default() -> Self {
        let role = String::from("user");
        Self {
            role,
            skip_image_token: false,
            skip_audio_token: false,
            num_images: 1,
            num_audios: 1,
            video: Vec::new(),
            max_pixels: 224 * 224,
            fps: Vec::new(),
        }
    }

    /// Defaults used by `get_message_json`: identical to
    /// [`formatter_default`](Self::formatter_default) except that no images or
    /// audios are requested.
    pub fn get_message_default() -> Self {
        Self {
            num_images: 0,
            num_audios: 0,
            ..Self::formatter_default()
        }
    }
}
/// `FormatOpts::default()` is the same as `FormatOpts::formatter_default()`.
impl Default for FormatOpts {
fn default() -> Self {
Self::formatter_default()
}
}
/// Formats prompts into the message layout a given model expects.
#[derive(Debug, Clone)]
pub struct MessageFormatter {
/// Lowercased model type, as stored by `for_model`.
pub model_name: String,
/// Layout looked up for `model_name` in `MODEL_CONFIG`.
pub format_type: MessageFormat,
}
impl MessageFormatter {
pub fn for_model(model_type: &str) -> Result<Self> {
let lower = model_type.to_lowercase();
let idx = MODEL_CONFIG
.binary_search_by(|(k, _)| (*k).cmp(lower.as_str()))
.map_err(|_| {
Error::MissingKey(crate::error::MissingKeyPayload::new(
"MessageFormatter::for_model: model_type not in MODEL_CONFIG",
model_type.to_owned(),
))
})?;
Ok(Self {
model_name: lower,
format_type: MODEL_CONFIG[idx].1,
})
}
pub fn format_message(&self, prompt: &str, opts: &FormatOpts) -> Result<FormattedMessage> {
if opts.num_images > 1
&& SINGLE_IMAGE_ONLY_MODELS
.binary_search(&self.model_name.as_str())
.is_ok()
{
return Err(Error::OutOfRange(OutOfRangePayload::new(
"MessageFormatter::format_message: opts.num_images (this model is in \
SINGLE_IMAGE_ONLY_MODELS — please use only 1 image)",
"must be <= 1",
format!("{}", opts.num_images),
)));
}
if !opts.video.is_empty()
&& VIDEO_FORMAT_MODELS
.binary_search(&self.model_name.as_str())
.is_ok()
{
return self.format_video_message(prompt, opts);
}
match self.format_type {
MessageFormat::ListWithImage => self.format_list_with_image(prompt, opts, false, false),
MessageFormat::ListWithImageFirst => self.format_list_with_image(prompt, opts, true, false),
MessageFormat::ListWithImageUrlFirst => self.format_list_with_image(prompt, opts, true, true),
MessageFormat::ListWithImageType => {
self.format_list_with_image_type(prompt, opts, ContentMessageKind::Content, true)
}
MessageFormat::ListWithImageTypeText => {
self.format_list_with_image_type(prompt, opts, ContentMessageKind::Text, true)
}
MessageFormat::ListWithImageTypeTextImageLast => {
self.format_list_with_image_type(prompt, opts, ContentMessageKind::Text, false)
}
MessageFormat::ImageToken => self.format_with_token(prompt, opts, "<image>", true),
MessageFormat::ImageTokenPipe => self.format_with_token(prompt, opts, "<|image|>", true),
MessageFormat::StartImageToken => {
self.format_with_token(prompt, opts, "<start_of_image>", false)
}
MessageFormat::ImageTokenNewline => self.format_with_token(prompt, opts, "<image>\n", true),
MessageFormat::NumberedImageTokens => self.format_numbered_tokens(prompt, opts),
MessageFormat::PromptOnly => Ok(FormattedMessage::String(prompt.to_string())),
MessageFormat::PromptWithImageToken => self.format_prompt_with_image_token(prompt, opts),
MessageFormat::PromptWithStartImageToken => {
self.format_prompt_with_start_image_token(prompt, opts)
}
MessageFormat::VideoWithText => self.format_video_message(prompt, opts),
}
}
fn format_prompt_with_image_token(
&self,
prompt: &str,
opts: &FormatOpts,
) -> Result<FormattedMessage> {
check_format_count(opts.num_images, "num_images", &self.model_name)?;
let effective_n = opts.num_images;
let cap = effective_n
.checked_mul(7) .and_then(|n| n.checked_add(prompt.len()))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"MessageFormatter::format_prompt_with_image_token: cap (7 * num_images + prompt.len())",
"usize",
[
("num_images", effective_n as u64),
("prompt_len", prompt.len() as u64),
],
))
})?;
let mut s = String::new();
s.try_reserve_exact(cap).map_err(|_| Error::OutOfMemory)?;
for _ in 0..effective_n {
s.push_str("<image>");
}
s.push_str(prompt);
Ok(FormattedMessage::String(s))
}
fn format_prompt_with_start_image_token(
&self,
prompt: &str,
opts: &FormatOpts,
) -> Result<FormattedMessage> {
check_format_count(opts.num_images, "num_images", &self.model_name)?;
let effective_n = opts.num_images;
let cap = effective_n
.checked_mul(16) .and_then(|n| n.checked_add(prompt.len()))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"MessageFormatter::format_prompt_with_start_image_token: cap (16 * num_images + \
prompt.len())",
"usize",
[
("num_images", effective_n as u64),
("prompt_len", prompt.len() as u64),
],
))
})?;
let mut s = String::new();
s.try_reserve_exact(cap).map_err(|_| Error::OutOfMemory)?;
s.push_str(prompt);
for _ in 0..effective_n {
s.push_str("<start_of_image>");
}
Ok(FormattedMessage::String(s))
}
fn format_list_with_image(
&self,
prompt: &str,
opts: &FormatOpts,
image_first: bool,
use_image_url: bool,
) -> Result<FormattedMessage> {
let effective_n = if opts.role == "user" && !opts.skip_image_token {
check_format_count(opts.num_images, "num_images", &self.model_name)?;
opts.num_images
} else {
0
};
let cap = 1usize.checked_add(effective_n).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"MessageFormatter::format_list_with_image: cap (1 + num_images)",
"usize",
[("num_images", effective_n as u64)],
))
})?;
let mut content: Vec<ContentItem> = try_with_capacity(cap)?;
let text_first = !image_first || effective_n == 0;
if text_first {
content.push(MessageBuilder::text_message(prompt));
}
if effective_n > 0 {
let image_builder = if use_image_url {
MessageBuilder::image_url_message
} else {
MessageBuilder::image_message
};
for _ in 0..effective_n {
content.push(image_builder());
}
}
if !text_first {
content.push(MessageBuilder::text_message(prompt));
}
Ok(FormattedMessage::Message(Message {
role: opts.role.clone(),
content: MessageContent::Items(content),
}))
}
fn format_list_with_image_type(
&self,
prompt: &str,
opts: &FormatOpts,
msg_kind: ContentMessageKind,
image_first: bool,
) -> Result<FormattedMessage> {
if opts.role == "assistant" {
let s = match msg_kind {
ContentMessageKind::Content | ContentMessageKind::Text => prompt.to_string(),
};
return Ok(FormattedMessage::Message(Message {
role: opts.role.clone(),
content: MessageContent::Text(s),
}));
}
let n_img = if opts.role == "user" && !opts.skip_image_token {
check_format_count(opts.num_images, "num_images", &self.model_name)?;
opts.num_images
} else {
0
};
let n_aud = if opts.role == "user" && !opts.skip_audio_token {
check_format_count(opts.num_audios, "num_audios", &self.model_name)?;
opts.num_audios
} else {
0
};
let cap = 1usize
.checked_add(n_img)
.and_then(|n| n.checked_add(n_aud))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"MessageFormatter::format_list_with_image_type: cap (1 + num_images + num_audios)",
"usize",
[("num_images", n_img as u64), ("num_audios", n_aud as u64)],
))
})?;
let msg = match msg_kind {
ContentMessageKind::Content => MessageBuilder::content_message(prompt),
ContentMessageKind::Text => MessageBuilder::text_message(prompt),
};
let mut content: Vec<ContentItem> = try_with_capacity(cap)?;
let text_first = !image_first || n_img == 0;
if text_first {
content.push(msg);
for _ in 0..n_img {
content.push(MessageBuilder::image_message());
}
} else {
for _ in 0..n_img {
content.push(MessageBuilder::image_message());
}
content.push(msg);
}
for _ in 0..n_aud {
content.push(MessageBuilder::audio_message());
}
Ok(FormattedMessage::Message(Message {
role: opts.role.clone(),
content: MessageContent::Items(content),
}))
}
/// Builds a plain-text message made of numbered audio tags, `token` repeated once
/// per image, and `prompt`.
///
/// Layout: `<|audio_1|>…<|audio_n|>` always comes first; then either
/// `token × num_images` + `prompt` (`image_first`) or `prompt` +
/// `token × num_images`. Images/audios are emitted only for `"user"` messages
/// whose corresponding skip flag is off.
///
/// The whole output is written into one buffer reserved up front with
/// `try_reserve_exact`; the audio tags are formatted in place so that no
/// temporary `String` is allocated per tag.
///
/// # Errors
/// `CapExceeded` for too many images or audios, `ArithmeticOverflow` for the
/// byte-budget computation, `OutOfMemory` if the buffer cannot be reserved.
fn format_with_token(
    &self,
    prompt: &str,
    opts: &FormatOpts,
    token: &str,
    image_first: bool,
) -> Result<FormattedMessage> {
    // Function-scoped import: only needed for `write!` into the `String` below.
    use std::fmt::Write as _;

    let n_img = if opts.role == "user" && !opts.skip_image_token {
        check_format_count(opts.num_images, "num_images", &self.model_name)?;
        opts.num_images
    } else {
        0
    };
    let n_aud = if opts.role == "user" && !opts.skip_audio_token {
        check_format_count(opts.num_audios, "num_audios", &self.model_name)?;
        opts.num_audios
    } else {
        0
    };

    // Generous per-tag budget: "<|audio_N|>" is 9 bytes plus the digits of N.
    const AUDIO_BYTES_PER_AUDIO: usize = 32;
    let token_bytes = token.len().checked_mul(n_img).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "MessageFormatter::format_with_token: token_bytes (token.len() * num_images)",
            "usize",
            [
                ("token_len", token.len() as u64),
                ("num_images", n_img as u64),
            ],
        ))
    })?;
    let audio_bytes = AUDIO_BYTES_PER_AUDIO.checked_mul(n_aud).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "MessageFormatter::format_with_token: audio_bytes (AUDIO_BYTES_PER_AUDIO * num_audios)",
            "usize",
            [
                ("AUDIO_BYTES_PER_AUDIO", AUDIO_BYTES_PER_AUDIO as u64),
                ("num_audios", n_aud as u64),
            ],
        ))
    })?;
    let cap = token_bytes
        .checked_add(audio_bytes)
        .and_then(|n| n.checked_add(prompt.len()))
        .ok_or_else(|| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "MessageFormatter::format_with_token: cap (token_bytes + audio_bytes + prompt.len())",
                "usize",
                [
                    ("token_bytes", token_bytes as u64),
                    ("audio_bytes", audio_bytes as u64),
                    ("prompt_len", prompt.len() as u64),
                ],
            ))
        })?;

    let mut content = String::new();
    content
        .try_reserve_exact(cap)
        .map_err(|_| Error::OutOfMemory)?;

    for i in 0..n_aud {
        // Format straight into the reserved buffer. `fmt::Write` for `String`
        // never returns an error, so the result is deliberately discarded.
        let _ = write!(content, "<|audio_{}|>", i + 1);
    }
    if image_first {
        for _ in 0..n_img {
            content.push_str(token);
        }
        content.push_str(prompt);
    } else {
        content.push_str(prompt);
        for _ in 0..n_img {
            content.push_str(token);
        }
    }

    Ok(FormattedMessage::Message(Message {
        role: opts.role.clone(),
        content: MessageContent::Text(content),
    }))
}
/// Builds a plain-text message of the form
/// `<|image_1|>…<|image_n|><|audio_1|>…<|audio_m|>` followed by `prompt`.
///
/// Images/audios are emitted only for `"user"` messages whose corresponding skip
/// flag is off. Tags are numbered from 1.
///
/// The whole output is written into one buffer reserved up front with
/// `try_reserve_exact`; the tags are formatted in place so that no temporary
/// `String` is allocated per tag.
///
/// # Errors
/// `CapExceeded` for too many images or audios, `ArithmeticOverflow` for the
/// byte-budget computation, `OutOfMemory` if the buffer cannot be reserved.
fn format_numbered_tokens(&self, prompt: &str, opts: &FormatOpts) -> Result<FormattedMessage> {
    // Function-scoped import: only needed for `write!` into the `String` below.
    use std::fmt::Write as _;

    let n_img = if opts.role == "user" && !opts.skip_image_token {
        check_format_count(opts.num_images, "num_images", &self.model_name)?;
        opts.num_images
    } else {
        0
    };
    let n_aud = if opts.role == "user" && !opts.skip_audio_token {
        check_format_count(opts.num_audios, "num_audios", &self.model_name)?;
        opts.num_audios
    } else {
        0
    };

    // Per-tag budget: "<|image_N|>" / "<|audio_N|>" are 9 bytes plus the digits of N.
    const BYTES_PER_TOKEN: usize = 16;
    let img_bytes = BYTES_PER_TOKEN.checked_mul(n_img).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "MessageFormatter::format_numbered_tokens: img_bytes (BYTES_PER_TOKEN * num_images)",
            "usize",
            [
                ("BYTES_PER_TOKEN", BYTES_PER_TOKEN as u64),
                ("num_images", n_img as u64),
            ],
        ))
    })?;
    let aud_bytes = BYTES_PER_TOKEN.checked_mul(n_aud).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "MessageFormatter::format_numbered_tokens: aud_bytes (BYTES_PER_TOKEN * num_audios)",
            "usize",
            [
                ("BYTES_PER_TOKEN", BYTES_PER_TOKEN as u64),
                ("num_audios", n_aud as u64),
            ],
        ))
    })?;
    let cap = img_bytes
        .checked_add(aud_bytes)
        .and_then(|n| n.checked_add(prompt.len()))
        .ok_or_else(|| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "MessageFormatter::format_numbered_tokens: cap (img_bytes + aud_bytes + prompt.len())",
                "usize",
                [
                    ("img_bytes", img_bytes as u64),
                    ("aud_bytes", aud_bytes as u64),
                    ("prompt_len", prompt.len() as u64),
                ],
            ))
        })?;

    let mut content = String::new();
    content
        .try_reserve_exact(cap)
        .map_err(|_| Error::OutOfMemory)?;

    // Format straight into the reserved buffer. `fmt::Write` for `String` never
    // returns an error, so the results are deliberately discarded.
    for i in 0..n_img {
        let _ = write!(content, "<|image_{}|>", i + 1);
    }
    for i in 0..n_aud {
        let _ = write!(content, "<|audio_{}|>", i + 1);
    }
    content.push_str(prompt);

    Ok(FormattedMessage::Message(Message {
        role: opts.role.clone(),
        content: MessageContent::Text(content),
    }))
}
fn format_video_message(&self, prompt: &str, opts: &FormatOpts) -> Result<FormattedMessage> {
if opts.video.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"MessageFormatter::format_video_message: opts.video (the python branch unconditionally \
dereferences kwargs['video'])",
)));
}
check_format_count(opts.video.len(), "video.len()", &self.model_name)?;
let n_vid = opts.video.len();
let fps_list: Vec<u32> = if opts.fps.is_empty() {
let mut v: Vec<u32> = try_with_capacity(n_vid)?;
v.resize(n_vid, 1u32);
v
} else if opts.fps.len() == 1 {
let mut v: Vec<u32> = try_with_capacity(n_vid)?;
v.resize(n_vid, opts.fps[0]);
v
} else if opts.fps.len() == n_vid {
let mut v: Vec<u32> = try_with_capacity(n_vid)?;
v.extend_from_slice(&opts.fps);
v
} else {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"MessageFormatter::format_video_message: opts.fps vs opts.video length (fps must be empty, \
a scalar, or match video.len() exactly)",
n_vid,
opts.fps.len(),
)));
};
let cap = n_vid.checked_add(1).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"MessageFormatter::format_video_message: cap (video.len() + 1)",
"usize",
[("video_len", n_vid as u64)],
))
})?;
let mut content: Vec<ContentItem> = try_with_capacity(cap)?;
for (v, f) in opts.video.iter().zip(fps_list.iter()) {
content.push(MessageBuilder::video_message(
v.clone(),
opts.max_pixels,
*f,
));
}
content.push(MessageBuilder::text_message(prompt));
Ok(FormattedMessage::Message(Message {
role: opts.role.clone(),
content: MessageContent::Items(content),
}))
}
}
// Selects which text item `format_list_with_image_type` emits for the prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContentMessageKind {
// Prompt is wrapped with `MessageBuilder::content_message` (`ContentItem::ContentText`).
Content,
// Prompt is wrapped with `MessageBuilder::text_message` (`ContentItem::Text`).
Text,
}
/// Convenience wrapper: looks up the formatter for `model_name` and formats
/// `prompt` with it.
///
/// When `opts` is `None`, `FormatOpts::get_message_default()` is used (a `"user"`
/// message with no images or audios).
///
/// # Errors
/// Any error from `MessageFormatter::for_model` or
/// `MessageFormatter::format_message`.
pub fn get_message_json(
    model_name: &str,
    prompt: &str,
    opts: Option<&FormatOpts>,
) -> Result<FormattedMessage> {
    let formatter = MessageFormatter::for_model(model_name)?;
    match opts {
        Some(given) => formatter.format_message(prompt, given),
        None => {
            let fallback = FormatOpts::get_message_default();
            formatter.format_message(prompt, &fallback)
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn opts_with(role: &str, num_images: usize, num_audios: usize) -> FormatOpts {
FormatOpts {
role: role.to_string(),
num_images,
num_audios,
..FormatOpts::formatter_default()
}
}
fn items(fm: &FormattedMessage) -> Vec<ContentItem> {
match fm {
FormattedMessage::Message(Message {
content: MessageContent::Items(v),
..
}) => v.clone(),
other => panic!("expected Message(Items), got {other:?}"),
}
}
fn text_content(fm: &FormattedMessage) -> String {
match fm {
FormattedMessage::Message(Message {
content: MessageContent::Text(s),
..
}) => s.clone(),
other => panic!("expected Message(Text), got {other:?}"),
}
}
#[test]
fn marker_policy_as_str_and_display() {
assert_eq!(MarkerPolicy::Required.as_str(), "required");
assert_eq!(MarkerPolicy::PrependIfAbsent.as_str(), "prepend_if_absent");
assert_eq!(format!("{}", MarkerPolicy::Required), "required");
assert_eq!(
format!("{}", MarkerPolicy::PrependIfAbsent),
"prepend_if_absent"
);
assert!(MarkerPolicy::Required.is_required());
assert!(MarkerPolicy::PrependIfAbsent.is_prepend_if_absent());
assert!(!MarkerPolicy::Required.is_prepend_if_absent());
}
/// Empty input and input without the image id yield no spans; a trailing run and a
/// leading single token are each reported as one half-open span.
#[test]
fn locate_image_tokens_empty_and_runs() {
    let no_spans = ImageTokenSpans::new();
    let image_id = 99;

    let empty: [u32; 0] = [];
    assert_eq!(locate_image_tokens(&empty, image_id), no_spans);
    assert_eq!(locate_image_tokens(&[1, 2, 3], image_id), no_spans);

    let trailing_run = locate_image_tokens(&[1, 99, 99], image_id);
    assert_eq!(trailing_run, vec![(1, 3)]);

    let leading_single = locate_image_tokens(&[99, 1], image_id);
    assert_eq!(leading_single, vec![(0, 1)]);
}
/// With a two-marker run under `MarkerPolicy::Required`, two images of width
/// `usize::MAX / 2` must fail with an `ArithmeticOverflow` whose context names the
/// marker-path capacity computation.
#[test]
fn insert_image_tokens_marker_cap_overflow() {
    let marker_run = [7_u32, 7];
    let width = usize::MAX / 2;
    let outcome = insert_image_tokens(&marker_run, 2, 7, 99, width, MarkerPolicy::Required);
    match outcome.unwrap_err() {
        Error::ArithmeticOverflow(p) => assert_eq!(
            p.context(),
            "insert_image_tokens: cap (text_len + placeholder_total - run_len)"
        ),
        err => panic!("expected ArithmeticOverflow, got {err:?}"),
    }
}
/// With no marker in the text under `MarkerPolicy::PrependIfAbsent`, two images of
/// width `usize::MAX / 2` must fail with an `ArithmeticOverflow` whose context names
/// the prepend-path capacity computation.
#[test]
fn insert_image_tokens_prepend_cap_overflow() {
    let plain_text = [1_u32, 2];
    let width = usize::MAX / 2;
    let outcome = insert_image_tokens(
        &plain_text,
        2,
        7,
        99,
        width,
        MarkerPolicy::PrependIfAbsent,
    );
    match outcome.unwrap_err() {
        Error::ArithmeticOverflow(p) => assert_eq!(
            p.context(),
            "insert_image_tokens: cap (text_len + placeholder_total)"
        ),
        err => panic!("expected ArithmeticOverflow, got {err:?}"),
    }
}
/// Walks the rejection paths of `insert_image_tokens` under `MarkerPolicy::Required`
/// and checks the error variant (and payload where the original did) for each input.
#[test]
fn insert_image_tokens_degenerate_and_overflow_and_invariant() {
    let required = MarkerPolicy::Required;

    // Zero tokens per image while one image is requested.
    let zero_width = insert_image_tokens(&[1, 7, 2], 1, 7, 99, 0, required).unwrap_err();
    assert!(matches!(zero_width, Error::InvariantViolation(_)));

    // image_count * num_tokens_per_image cannot be represented.
    let product = insert_image_tokens(&[7, 7], 2, 7, 99, usize::MAX, required).unwrap_err();
    assert!(matches!(product, Error::ArithmeticOverflow(_)));
    if let Error::ArithmeticOverflow(p) = &product {
        assert_eq!(
            p.context(),
            "insert_image_tokens: placeholder_total (image_count * num_tokens_per_image)"
        );
    }

    // Marker tokens separated by another token.
    let split_markers = insert_image_tokens(&[7, 1, 7], 1, 7, 99, 3, required).unwrap_err();
    assert!(matches!(split_markers, Error::InvariantViolation(_)));

    // A run of two markers for a single image.
    let run_too_long = insert_image_tokens(&[1, 7, 7, 2], 1, 7, 99, 3, required).unwrap_err();
    assert!(matches!(run_too_long, Error::LengthMismatch(_)));
    if let Error::LengthMismatch(p) = &run_too_long {
        assert_eq!(p.expected(), 1);
        assert_eq!(p.actual(), 2);
    }

    // No marker token at all.
    let no_marker = insert_image_tokens(&[1, 2], 1, 7, 99, 3, required).unwrap_err();
    assert!(matches!(no_marker, Error::MissingField(_)));
    if let Error::MissingField(p) = &no_marker {
        assert_eq!(p.field(), "image_marker_id token in text_tokens");
    }
}
/// Arguments `1` and `usize::MAX` must be reported as an `ArithmeticOverflow` on the
/// `total_keys` sum.
#[test]
fn mask_with_past_total_keys_overflow() {
    let outcome = build_multimodal_mask_with_past(1, usize::MAX, &[]);
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, Error::ArithmeticOverflow(_)));
    if let Error::ArithmeticOverflow(payload) = &failure {
        let context = payload.context();
        assert_eq!(
            context,
            "build_multimodal_mask_with_past: total_keys (past_len + seq_len)"
        );
    }
}
/// Arguments `1` and `i32::MAX` (as `usize`) must be reported as `OutOfRange` on
/// `total_keys`.
#[test]
fn mask_with_past_total_keys_exceeds_i32_max() {
    let largest_i32 = i32::MAX as usize;
    let failure = build_multimodal_mask_with_past(1, largest_i32, &[]).unwrap_err();
    assert!(matches!(failure, Error::OutOfRange(_)));
    if let Error::OutOfRange(payload) = &failure {
        let context = payload.context();
        assert_eq!(
            context,
            "build_multimodal_mask_with_past: total_keys (past_len + seq_len)"
        );
    }
}
/// Arguments `(0, 3)` without spans give a `[1, 1, 0, 3]` mask; adding a span to the
/// same call is an invariant violation.
#[test]
fn mask_with_past_empty_chunk() {
    let empty_chunk = build_multimodal_mask_with_past(0, 3, &[]).unwrap();
    assert_eq!(empty_chunk.shape(), vec![1, 1, 0, 3]);

    let with_span = build_multimodal_mask_with_past(0, 3, &[(0, 1)]);
    let failure = with_span.unwrap_err();
    assert!(matches!(failure, Error::InvariantViolation(_)));
}
/// Span validation in `build_multimodal_mask`: an empty span, a span ending past the
/// sequence, and two overlapping spans are each rejected with the expected variant.
#[test]
fn mask_span_validation_errors() {
    // start == end
    let empty_span = build_multimodal_mask(4, &[(2, 2)]).unwrap_err();
    assert!(matches!(empty_span, Error::InvariantViolation(_)));

    // end (5) beyond the length argument (4)
    let past_end = build_multimodal_mask(4, &[(2, 5)]).unwrap_err();
    assert!(matches!(past_end, Error::OutOfRange(_)));

    // second span starts before the first one ends
    let overlapping = build_multimodal_mask(8, &[(1, 4), (3, 6)]).unwrap_err();
    assert!(matches!(overlapping, Error::InvariantViolation(_)));
}
/// Compares a 4x4 mask with one image span `(1, 3)` against an oracle: a key is
/// visible when it is causal (`k <= q`) or shares the query's image block.
#[test]
fn mask_image_span_bidirectional_oracle() {
    let mut mask = build_multimodal_mask(4, &[(1, 3)]).unwrap();
    assert_eq!(mask.shape(), vec![1, 1, 4, 4]);
    let v = mask.to_vec::<bool>().unwrap();

    // Image-block id per position; 0 means "not part of an image".
    let block = [0u32, 1, 1, 0];
    let stride = block.len();
    for (q, &query_block) in block.iter().enumerate() {
        for (k, &key_block) in block.iter().enumerate() {
            let expected = k <= q || (query_block != 0 && query_block == key_block);
            assert_eq!(
                v[q * stride + k],
                expected,
                "mask[{q}][{k}] expected {expected}"
            );
        }
    }
    // Row 1, column 2: a non-causal entry inside the image span.
    assert!(v[6]);
}
/// For arguments `(2, 2)` and span `(0, 2)` the mask has shape `[1, 1, 2, 4]`: the
/// first two key columns are always visible, the rest follow the causal-or-same-image
/// rule.
#[test]
fn mask_with_past_attends_all_past_keys() {
    let mut mask = build_multimodal_mask_with_past(2, 2, &[(0, 2)]).unwrap();
    assert_eq!(mask.shape(), vec![1, 1, 2, 4]);
    let v = mask.to_vec::<bool>().unwrap();

    let block = [1u32, 1];
    for q in 0..2usize {
        for k in 0..4usize {
            let expected = match k.checked_sub(2) {
                // Columns 0 and 1 are expected to be fully visible.
                None => true,
                Some(kl) => (kl <= q) || (block[q] != 0 && block[q] == block[kl]),
            };
            assert_eq!(v[q * 4 + k], expected, "mask[{q}][{k}]");
        }
    }
}
/// `assemble_multimodal_prompt` with one image and zero tokens per image must fail
/// with an `InvariantViolation` carrying its own context string.
#[test]
fn assemble_degenerate_zero_width() {
    let text = [1, 7, 2];
    let outcome = assemble_multimodal_prompt(&text, 1, 7, 99, 0, MarkerPolicy::Required);
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, Error::InvariantViolation(_)));
    if let Error::InvariantViolation(payload) = &failure {
        assert_eq!(
            payload.context(),
            "assemble_multimodal_prompt: num_tokens_per_image (with image_count > 0)"
        );
    }
}
/// Two images of width `usize::MAX` must fail with an `ArithmeticOverflow` on
/// `placeholder_total`.
#[test]
fn assemble_placeholder_total_overflow() {
    let text = [7, 7];
    let outcome =
        assemble_multimodal_prompt(&text, 2, 7, 99, usize::MAX, MarkerPolicy::Required);
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, Error::ArithmeticOverflow(_)));
    if let Error::ArithmeticOverflow(payload) = &failure {
        assert_eq!(
            payload.context(),
            "assemble_multimodal_prompt: placeholder_total (image_count * num_tokens_per_image)"
        );
    }
}
/// Two images of width `usize::MAX / 2` prepended to a two-token text must fail with
/// an `ArithmeticOverflow` on `final_len`.
#[test]
fn assemble_final_len_overflow() {
    let text = [1, 2];
    let width = usize::MAX / 2;
    let policy = MarkerPolicy::PrependIfAbsent;
    let failure = assemble_multimodal_prompt(&text, 2, 7, 99, width, policy).unwrap_err();
    assert!(matches!(failure, Error::ArithmeticOverflow(_)));
    if let Error::ArithmeticOverflow(payload) = &failure {
        assert_eq!(
            payload.context(),
            "assemble_multimodal_prompt: final_len (text_len + placeholder_total - marker_run_len)"
        );
    }
}
/// A single image of width `i32::MAX + 1` must be rejected as `OutOfRange` on the
/// final assembled length.
#[test]
fn assemble_final_len_exceeds_i32_max() {
    let width = i32::MAX as usize + 1;
    let policy = MarkerPolicy::PrependIfAbsent;
    let failure = assemble_multimodal_prompt(&[1], 1, 7, 99, width, policy).unwrap_err();
    assert!(matches!(failure, Error::OutOfRange(_)));
    if let Error::OutOfRange(payload) = &failure {
        assert_eq!(
            payload.context(),
            "assemble_multimodal_prompt: final assembled length"
        );
    }
}
/// A run of two markers expands into two adjacent three-token image spans, and the
/// attention mask is sized to the nine resulting tokens.
#[test]
fn assemble_multi_image_spans() {
    let with_two_markers = [1_u32, 2, 7, 7, 3];
    let prompt =
        assemble_multimodal_prompt(&with_two_markers, 2, 7, 99, 3, MarkerPolicy::Required)
            .unwrap();
    assert_eq!(prompt.tokens, vec![1, 2, 99, 99, 99, 99, 99, 99, 3]);
    assert_eq!(prompt.image_spans, vec![(2, 5), (5, 8)]);
    assert_eq!(prompt.attention_mask.shape(), vec![1, 1, 9, 9]);
}
/// With zero images the tokens pass through unchanged, there are no image spans, and
/// the mask is 3x3.
#[test]
fn assemble_zero_images() {
    let plain = [1_u32, 2, 3];
    let prompt = assemble_multimodal_prompt(&plain, 0, 7, 99, 3, MarkerPolicy::Required).unwrap();
    assert_eq!(prompt.tokens, vec![1, 2, 3]);
    assert!(prompt.image_spans.is_empty());
    assert_eq!(prompt.attention_mask.shape(), vec![1, 1, 3, 3]);
}
/// Every `MessageFormat` variant maps to its expected tag through both `as_str` and
/// `Display`, and `MESSAGE_FORMAT_VARIANTS` lists the 15 variants in this same order.
#[test]
fn message_format_as_str_all_variants() {
    use MessageFormat as F;
    let expected_tags: &[(MessageFormat, &str)] = &[
        (F::ListWithImage, "list_with_image"),
        (F::ListWithImageFirst, "list_with_image_first"),
        (F::ListWithImageUrlFirst, "list_with_image_url_first"),
        (F::ListWithImageType, "list_with_image_type"),
        (F::ListWithImageTypeText, "list_with_image_type_text"),
        (
            F::ListWithImageTypeTextImageLast,
            "list_with_image_type_text_image_last",
        ),
        (F::ImageToken, "image_token"),
        (F::ImageTokenPipe, "image_token_pipe"),
        (F::StartImageToken, "start_image_token"),
        (F::ImageTokenNewline, "image_token_newline"),
        (F::NumberedImageTokens, "numbered_image_tokens"),
        (F::PromptOnly, "prompt_only"),
        (F::PromptWithImageToken, "prompt_with_image_token"),
        (
            F::PromptWithStartImageToken,
            "prompt_with_start_image_token",
        ),
        (F::VideoWithText, "video_with_text"),
    ];

    for (variant, tag) in expected_tags.iter() {
        assert_eq!(variant.as_str(), *tag);
        assert_eq!(variant.to_string(), *tag);
    }

    assert_eq!(MESSAGE_FORMAT_VARIANTS.len(), 15);
    for (index, variant) in expected_tags.iter().map(|(variant, _)| variant).enumerate() {
        assert_eq!(MESSAGE_FORMAT_VARIANTS[index], *variant);
    }
}
/// Each `MessageBuilder` constructor yields the matching `ContentItem` variant with
/// its arguments stored verbatim.
#[test]
fn message_builder_constructors() {
    let text = MessageBuilder::text_message("hi");
    assert_eq!(
        text,
        ContentItem::Text {
            text: String::from("hi")
        }
    );

    let content = MessageBuilder::content_message("hi");
    assert_eq!(
        content,
        ContentItem::ContentText {
            text: String::from("hi")
        }
    );

    assert_eq!(MessageBuilder::image_message(), ContentItem::Image);
    assert_eq!(MessageBuilder::image_url_message(), ContentItem::ImageUrl);
    assert_eq!(MessageBuilder::audio_message(), ContentItem::Audio);

    let video = MessageBuilder::video_message("v.mp4", 100, 2);
    assert_eq!(
        video,
        ContentItem::Video {
            video: String::from("v.mp4"),
            max_pixels: 100,
            fps: 2,
        }
    );
}
/// Pins the values produced by the three `FormatOpts` default constructors.
#[test]
fn format_opts_defaults() {
    let formatter_defaults = FormatOpts::formatter_default();
    assert_eq!(formatter_defaults.num_images, 1);
    assert_eq!(formatter_defaults.num_audios, 1);
    assert_eq!(formatter_defaults.role, "user");
    assert_eq!(formatter_defaults.max_pixels, 224 * 224);

    let trait_defaults = FormatOpts::default();
    assert_eq!(trait_defaults.num_images, 1);
    assert_eq!(trait_defaults.num_audios, 1);

    let get_message_defaults = FormatOpts::get_message_default();
    assert_eq!(get_message_defaults.num_images, 0);
    assert_eq!(get_message_defaults.num_audios, 0);
    assert_eq!(get_message_defaults.role, "user");
}
/// A mixed-case model name resolves to the lowercase entry; an unknown name yields
/// `MissingKey` carrying that name.
#[test]
fn for_model_lookup_and_errors() {
    let resolved = MessageFormatter::for_model("Qwen2_VL").unwrap();
    assert_eq!(resolved.model_name, "qwen2_vl");
    assert_eq!(resolved.format_type, MessageFormat::ListWithImage);

    let failure = MessageFormatter::for_model("nonexistent_model").unwrap_err();
    assert!(matches!(failure, Error::MissingKey(_)));
    if let Error::MissingKey(payload) = &failure {
        assert_eq!(payload.key(), "nonexistent_model");
    }
}
/// `paligemma` with two images must be rejected as `OutOfRange`.
#[test]
fn dispatch_single_image_guard() {
    let formatter = MessageFormatter::for_model("paligemma").unwrap();
    let two_images = opts_with("user", 2, 0);
    let failure = formatter.format_message("hi", &two_images).unwrap_err();
    assert!(matches!(failure, Error::OutOfRange(_)));
}
/// `qwen2_vl`, one image, user role: the text item comes first, then the image item.
#[test]
fn dispatch_list_with_image() {
    let formatter = MessageFormatter::for_model("qwen2_vl").unwrap();
    let opts = opts_with("user", 1, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = vec![
        ContentItem::Text {
            text: String::from("hi"),
        },
        ContentItem::Image,
    ];
    assert_eq!(items(&formatted), expected);
}
/// `qwen2_5_vl`, two images, user role: both image items precede the text item.
#[test]
fn dispatch_list_with_image_first() {
    let formatter = MessageFormatter::for_model("qwen2_5_vl").unwrap();
    let opts = opts_with("user", 2, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = vec![
        ContentItem::Image,
        ContentItem::Image,
        ContentItem::Text {
            text: String::from("hi"),
        },
    ];
    assert_eq!(items(&formatted), expected);
}
/// `ernie4_5_moe_vl`, one image, user role: an image-URL item precedes the text item.
#[test]
fn dispatch_list_with_image_url_first() {
    let formatter = MessageFormatter::for_model("ernie4_5_moe_vl").unwrap();
    let opts = opts_with("user", 1, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = vec![
        ContentItem::ImageUrl,
        ContentItem::Text {
            text: String::from("hi"),
        },
    ];
    assert_eq!(items(&formatted), expected);
}
/// `internvl_chat`, one image and one audio, user role: image, content-text, audio.
#[test]
fn dispatch_list_with_image_type_content() {
    let formatter = MessageFormatter::for_model("internvl_chat").unwrap();
    let opts = opts_with("user", 1, 1);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = vec![
        ContentItem::Image,
        ContentItem::ContentText {
            text: String::from("hi"),
        },
        ContentItem::Audio,
    ];
    assert_eq!(items(&formatted), expected);
}
/// `pixtral`, one image, user role: the image item precedes a plain text item.
#[test]
fn dispatch_list_with_image_type_text() {
    let formatter = MessageFormatter::for_model("pixtral").unwrap();
    let opts = opts_with("user", 1, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = vec![
        ContentItem::Image,
        ContentItem::Text {
            text: String::from("hi"),
        },
    ];
    assert_eq!(items(&formatted), expected);
}
/// `minicpmo`, two images: two `<image>` tokens are placed before the prompt.
#[test]
fn dispatch_image_token() {
    let formatter = MessageFormatter::for_model("minicpmo").unwrap();
    let opts = opts_with("user", 2, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "<image><image>hi");
}
/// `jina_vlm`, one image: a `<|image|>` token is placed before the prompt.
#[test]
fn dispatch_image_token_pipe() {
    let formatter = MessageFormatter::for_model("jina_vlm").unwrap();
    let opts = opts_with("user", 1, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "<|image|>hi");
}
/// `gemma3`, one image: `<start_of_image>` is appended after the prompt.
#[test]
fn dispatch_start_image_token() {
    let formatter = MessageFormatter::for_model("gemma3").unwrap();
    let opts = opts_with("user", 1, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "hi<start_of_image>");
}
/// `deepseek_vl_v2`, one image: `<image>` and a newline precede the prompt.
#[test]
fn dispatch_image_token_newline() {
    let formatter = MessageFormatter::for_model("deepseek_vl_v2").unwrap();
    let opts = opts_with("user", 1, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "<image>\nhi");
}
/// `phi3_v`, two images and one audio: numbered image tokens, then the numbered audio
/// token, then the prompt.
#[test]
fn dispatch_numbered_image_tokens() {
    let formatter = MessageFormatter::for_model("phi3_v").unwrap();
    let opts = opts_with("user", 2, 1);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "<|image_1|><|image_2|><|audio_1|>hi");
}
/// `florence2` returns the bare prompt as `FormattedMessage::String`.
#[test]
fn dispatch_prompt_only() {
    let formatter = MessageFormatter::for_model("florence2").unwrap();
    let opts = opts_with("user", 1, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = FormattedMessage::String(String::from("hi"));
    assert_eq!(formatted, expected);
}
/// `paligemma` prefixes `<image>` for one image under both the user and assistant
/// roles, and returns the bare prompt when there are no images.
#[test]
fn dispatch_prompt_with_image_token() {
    let formatter = MessageFormatter::for_model("paligemma").unwrap();
    // (role, num_images, expected string)
    let cases = [
        ("user", 1, "<image>hi"),
        ("assistant", 1, "<image>hi"),
        ("user", 0, "hi"),
    ];
    for &(role, num_images, want) in cases.iter() {
        let opts = opts_with(role, num_images, 0);
        let formatted = formatter.format_message("hi", &opts).unwrap();
        assert_eq!(formatted, FormattedMessage::String(want.to_string()));
    }
}
/// `qwen2_vl` with two videos and an empty `fps` list: every video item gets fps 1 and
/// the configured `max_pixels`, followed by the text item.
#[test]
fn dispatch_video_with_text_default_fps() {
    let clips = ["a.mp4", "b.mp4"];
    let formatter = MessageFormatter::for_model("qwen2_vl").unwrap();
    let opts = FormatOpts {
        role: String::from("user"),
        video: clips.iter().map(|clip| clip.to_string()).collect(),
        fps: Vec::new(),
        max_pixels: 256,
        ..FormatOpts::formatter_default()
    };
    let formatted = formatter.format_message("cap", &opts).unwrap();

    let mut expected: Vec<ContentItem> = clips
        .iter()
        .map(|clip| ContentItem::Video {
            video: clip.to_string(),
            max_pixels: 256,
            fps: 1,
        })
        .collect();
    expected.push(ContentItem::Text {
        text: String::from("cap"),
    });
    assert_eq!(items(&formatted), expected);
}
/// `qwen2_5_vl` with two videos and a single fps value: that value is applied to both
/// video items.
#[test]
fn video_message_scalar_fps_broadcast() {
    let clips = ["a.mp4", "b.mp4"];
    let formatter = MessageFormatter::for_model("qwen2_5_vl").unwrap();
    let opts = FormatOpts {
        role: String::from("user"),
        video: clips.iter().map(|clip| clip.to_string()).collect(),
        fps: vec![5],
        max_pixels: 224 * 224,
        ..FormatOpts::formatter_default()
    };
    let formatted = formatter.format_message("cap", &opts).unwrap();

    let mut expected: Vec<ContentItem> = clips
        .iter()
        .map(|clip| ContentItem::Video {
            video: clip.to_string(),
            max_pixels: 224 * 224,
            fps: 5,
        })
        .collect();
    expected.push(ContentItem::Text {
        text: String::from("cap"),
    });
    assert_eq!(items(&formatted), expected);
}
/// `qwen3_vl` with one fps value per video: each video item carries its own value.
#[test]
fn video_message_per_video_fps() {
    let formatter = MessageFormatter::for_model("qwen3_vl").unwrap();
    let opts = FormatOpts {
        role: String::from("user"),
        video: vec![String::from("a.mp4"), String::from("b.mp4")],
        fps: vec![2, 3],
        ..FormatOpts::formatter_default()
    };
    let formatted = formatter.format_message("cap", &opts).unwrap();
    let content = items(&formatted);
    match (&content[0], &content[1]) {
        (
            ContentItem::Video { fps: first_fps, .. },
            ContentItem::Video { fps: second_fps, .. },
        ) => {
            assert_eq!(*first_fps, 2);
            assert_eq!(*second_fps, 3);
        }
        _ => panic!("expected two Video items"),
    }
}
/// Three videos with two fps values must fail with `LengthMismatch` (expected 3,
/// actual 2).
#[test]
fn video_message_fps_length_mismatch() {
    let formatter = MessageFormatter::for_model("qwen2_vl").unwrap();
    let clips = ["a.mp4", "b.mp4", "c.mp4"];
    let opts = FormatOpts {
        role: String::from("user"),
        video: clips.iter().map(|clip| clip.to_string()).collect(),
        fps: vec![1, 2],
        ..FormatOpts::formatter_default()
    };
    let failure = formatter.format_message("cap", &opts).unwrap_err();
    assert!(matches!(failure, Error::LengthMismatch(_)));
    if let Error::LengthMismatch(payload) = &failure {
        assert_eq!(payload.expected(), 3);
        assert_eq!(payload.actual(), 2);
    }
}
/// `qwen2_vl` under the assistant role emits only the text item even when three
/// images are requested.
#[test]
fn list_with_image_assistant_role_no_images() {
    let formatter = MessageFormatter::for_model("qwen2_vl").unwrap();
    let opts = opts_with("assistant", 3, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let text_only = vec![ContentItem::Text {
        text: String::from("hi"),
    }];
    assert_eq!(items(&formatted), text_only);
}
/// `skip_image_token` suppresses the image items for `qwen2_vl` even with three
/// images requested.
#[test]
fn list_with_image_skip_image_token() {
    let formatter = MessageFormatter::for_model("qwen2_vl").unwrap();
    let opts = FormatOpts {
        role: String::from("user"),
        num_images: 3,
        skip_image_token: true,
        ..FormatOpts::formatter_default()
    };
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let text_only = vec![ContentItem::Text {
        text: String::from("hi"),
    }];
    assert_eq!(items(&formatted), text_only);
}
/// `internvl_chat` under the assistant role collapses to a plain text message.
#[test]
fn list_with_image_type_assistant_collapse() {
    let formatter = MessageFormatter::for_model("internvl_chat").unwrap();
    let opts = opts_with("assistant", 2, 2);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = text_content_msg("assistant", "hi");
    assert_eq!(formatted, expected);
}
/// `internvl_chat` under the system role emits only the content-text item, ignoring
/// the requested images and audios.
#[test]
fn list_with_image_type_system_role_no_media() {
    let formatter = MessageFormatter::for_model("internvl_chat").unwrap();
    let opts = opts_with("system", 5, 5);
    let formatted = formatter.format_message("sys", &opts).unwrap();
    let expected = vec![ContentItem::ContentText {
        text: String::from("sys"),
    }];
    assert_eq!(items(&formatted), expected);
}
/// `internvl_chat` with no images and two audios: content-text first, then both audio
/// items.
#[test]
fn list_with_image_type_text_first_with_audio() {
    let formatter = MessageFormatter::for_model("internvl_chat").unwrap();
    let opts = opts_with("user", 0, 2);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = vec![
        ContentItem::ContentText {
            text: String::from("hi"),
        },
        ContentItem::Audio,
        ContentItem::Audio,
    ];
    assert_eq!(items(&formatted), expected);
}
/// `minicpmo` with one image and two audios: numbered audio tokens come first, then
/// the image token, then the prompt.
#[test]
fn with_token_audio_prefix_then_images() {
    let formatter = MessageFormatter::for_model("minicpmo").unwrap();
    let opts = opts_with("user", 1, 2);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "<|audio_1|><|audio_2|><image>hi");
}
/// `jina_vlm` under the assistant role returns the bare prompt despite requested
/// media.
#[test]
fn with_token_assistant_no_media() {
    let formatter = MessageFormatter::for_model("jina_vlm").unwrap();
    let opts = opts_with("assistant", 3, 3);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "hi");
}
/// `skip_audio_token` drops the audio tokens for `minicpmo` while the image token is
/// kept.
#[test]
fn with_token_skip_audio_only() {
    let formatter = MessageFormatter::for_model("minicpmo").unwrap();
    let opts = FormatOpts {
        role: String::from("user"),
        num_images: 1,
        num_audios: 3,
        skip_audio_token: true,
        ..FormatOpts::formatter_default()
    };
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "<image>hi");
}
/// `phi3_v` under the assistant role returns the bare prompt despite requested media.
#[test]
fn numbered_tokens_assistant_no_media() {
    let formatter = MessageFormatter::for_model("phi3_v").unwrap();
    let opts = opts_with("assistant", 3, 3);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "hi");
}
/// `skip_image_token` drops the numbered image tokens for `phi4mm` while the numbered
/// audio tokens are kept.
#[test]
fn numbered_tokens_skip_image_keeps_audio() {
    let formatter = MessageFormatter::for_model("phi4mm").unwrap();
    let opts = FormatOpts {
        role: String::from("user"),
        num_images: 3,
        num_audios: 2,
        skip_image_token: true,
        ..FormatOpts::formatter_default()
    };
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let rendered = text_content(&formatted);
    assert_eq!(rendered, "<|audio_1|><|audio_2|>hi");
}
/// Requesting one image more than `MAX_MESSAGE_FORMAT_ITEMS` must fail with
/// `CapExceeded` reporting the cap, the observed count and the cap's name.
#[test]
fn format_count_cap_exceeded() {
    let over_cap = MAX_MESSAGE_FORMAT_ITEMS + 1;
    let formatter = MessageFormatter::for_model("qwen2_vl").unwrap();
    let opts = opts_with("user", over_cap, 0);
    let failure = formatter.format_message("hi", &opts).unwrap_err();
    assert!(matches!(failure, Error::CapExceeded(_)));
    if let Error::CapExceeded(payload) = &failure {
        assert_eq!(payload.cap(), MAX_MESSAGE_FORMAT_ITEMS as u64);
        assert_eq!(payload.observed(), over_cap as u64);
        assert_eq!(payload.cap_name(), "MAX_MESSAGE_FORMAT_ITEMS");
    }
}
/// The cap is inclusive: exactly `MAX_MESSAGE_FORMAT_ITEMS` passes, one more fails.
#[test]
fn format_count_at_cap_ok() {
    let at_cap = check_format_count(MAX_MESSAGE_FORMAT_ITEMS, "num_images", "m");
    assert!(at_cap.is_ok());

    let above_cap = check_format_count(MAX_MESSAGE_FORMAT_ITEMS + 1, "num_images", "m");
    assert!(above_cap.is_err());
}
/// A hand-built formatter using `PromptWithStartImageToken` appends one
/// `<start_of_image>` per image after the prompt, for both user and assistant roles.
#[test]
fn prompt_with_start_image_token_direct() {
    let formatter = MessageFormatter {
        model_name: String::from("synthetic"),
        format_type: MessageFormat::PromptWithStartImageToken,
    };
    // (role, num_images, expected string)
    let cases = [
        ("user", 2, "hi<start_of_image><start_of_image>"),
        ("assistant", 1, "hi<start_of_image>"),
    ];
    for &(role, num_images, want) in cases.iter() {
        let opts = opts_with(role, num_images, 0);
        let formatted = formatter.format_message("hi", &opts).unwrap();
        assert_eq!(formatted, FormattedMessage::String(want.to_string()));
    }
}
/// A hand-built formatter using `ListWithImageTypeTextImageLast` emits the text item,
/// then the image items, then the audio item.
#[test]
fn list_with_image_type_text_image_last_direct() {
    let formatter = MessageFormatter {
        model_name: String::from("synthetic"),
        format_type: MessageFormat::ListWithImageTypeTextImageLast,
    };
    let opts = opts_with("user", 2, 1);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let expected = vec![
        ContentItem::Text {
            text: String::from("hi"),
        },
        ContentItem::Image,
        ContentItem::Image,
        ContentItem::Audio,
    ];
    assert_eq!(items(&formatted), expected);
}
/// A hand-built formatter using `VideoWithText` emits the video item followed by the
/// text item, and rejects an empty video list with `EmptyInput`.
#[test]
fn video_with_text_base_format_arm() {
    let formatter = MessageFormatter {
        model_name: String::from("synthetic"),
        format_type: MessageFormat::VideoWithText,
    };

    // One video with an explicit fps value.
    let single_video = FormatOpts {
        role: String::from("user"),
        video: vec![String::from("v.mp4")],
        fps: vec![4],
        max_pixels: 64,
        ..FormatOpts::formatter_default()
    };
    let formatted = formatter.format_message("cap", &single_video).unwrap();
    let expected = vec![
        ContentItem::Video {
            video: String::from("v.mp4"),
            max_pixels: 64,
            fps: 4,
        },
        ContentItem::Text {
            text: String::from("cap"),
        },
    ];
    assert_eq!(items(&formatted), expected);

    // No videos at all.
    let no_video = FormatOpts {
        role: String::from("user"),
        video: Vec::new(),
        ..FormatOpts::formatter_default()
    };
    let failure = formatter.format_message("cap", &no_video).unwrap_err();
    assert!(matches!(failure, Error::EmptyInput(_)));
    if let Error::EmptyInput(payload) = &failure {
        assert!(payload.context().contains("format_video_message"));
    }
}
/// Without explicit options, `get_message_json` for `qwen2_vl` yields only the text
/// item.
#[test]
fn get_message_json_default_text_only() {
    let formatted = get_message_json("qwen2_vl", "hi", None).unwrap();
    let text_only = vec![ContentItem::Text {
        text: String::from("hi"),
    }];
    assert_eq!(items(&formatted), text_only);
}
/// Explicit options are honoured by `get_message_json`, and an unknown model name
/// surfaces as `MissingKey`.
#[test]
fn get_message_json_explicit_opts_and_unknown_model() {
    let one_image = opts_with("user", 1, 0);
    let formatted = get_message_json("qwen2_vl", "hi", Some(&one_image)).unwrap();
    let expected = vec![
        ContentItem::Text {
            text: String::from("hi"),
        },
        ContentItem::Image,
    ];
    assert_eq!(items(&formatted), expected);

    let failure = get_message_json("does_not_exist", "hi", None).unwrap_err();
    assert!(matches!(failure, Error::MissingKey(_)));
}
/// Test helper: builds a `FormattedMessage::Message` with the given role and plain
/// text content.
fn text_content_msg(role: &str, text: &str) -> FormattedMessage {
    let message = Message {
        role: role.to_owned(),
        content: MessageContent::Text(text.to_owned()),
    };
    FormattedMessage::Message(message)
}
/// Adjacent `MODEL_CONFIG` keys must be strictly ascending, which also rules out
/// duplicates.
#[test]
fn model_config_keys_strictly_sorted_and_unique() {
    let neighbours = MODEL_CONFIG.iter().zip(MODEL_CONFIG.iter().skip(1));
    for (prev, next) in neighbours {
        let (a, b) = (prev.0, next.0);
        assert!(
            a < b,
            "MODEL_CONFIG must be strictly ascending for binary_search_by: {a:?} !< {b:?}"
        );
    }
}
/// Adjacent `SINGLE_IMAGE_ONLY_MODELS` entries must be strictly ascending, which also
/// rules out duplicates.
#[test]
fn single_image_only_models_strictly_sorted_and_unique() {
    let neighbours = SINGLE_IMAGE_ONLY_MODELS
        .iter()
        .zip(SINGLE_IMAGE_ONLY_MODELS.iter().skip(1));
    for (prev, next) in neighbours {
        assert!(
            prev < next,
            "SINGLE_IMAGE_ONLY_MODELS must be strictly ascending for binary_search: \
             {:?} !< {:?}",
            prev,
            next
        );
    }
}
/// Adjacent `VIDEO_FORMAT_MODELS` entries must be strictly ascending, which also rules
/// out duplicates.
#[test]
fn video_format_models_strictly_sorted_and_unique() {
    let neighbours = VIDEO_FORMAT_MODELS
        .iter()
        .zip(VIDEO_FORMAT_MODELS.iter().skip(1));
    for (prev, next) in neighbours {
        assert!(
            prev < next,
            "VIDEO_FORMAT_MODELS must be strictly ascending for binary_search: {:?} !< {:?}",
            prev,
            next
        );
    }
}
/// Every `MODEL_CONFIG` key resolves through `for_model` to its configured format and
/// keeps its own spelling as the stored model name.
#[test]
fn model_config_every_key_resolves() {
    for entry in MODEL_CONFIG.iter() {
        let (key, fmt) = *entry;
        let resolved = MessageFormatter::for_model(key).unwrap();
        assert_eq!(
            resolved.format_type, fmt,
            "for_model({key:?}) format mismatch"
        );
        assert_eq!(
            resolved.model_name, key,
            "for_model({key:?}) should lowercase to itself"
        );
    }
}
/// Every single-image-only model that `for_model` can resolve must reject two images
/// with `OutOfRange`. Entries that do not resolve are skipped, not failed.
#[test]
fn single_image_only_models_reject_multi_image() {
    for &model in SINGLE_IMAGE_ONLY_MODELS {
        let formatter = match MessageFormatter::for_model(model) {
            Ok(formatter) => formatter,
            Err(_) => continue,
        };
        let two_images = opts_with("user", 2, 0);
        let err = formatter.format_message("hi", &two_images).unwrap_err();
        assert!(
            matches!(err, Error::OutOfRange(_)),
            "model {model:?} should reject num_images=2, got {err:?}"
        );
    }
}
/// No two `MessageFormat` variants share an `as_str` tag.
#[test]
fn message_format_tags_are_pairwise_distinct() {
    let tags: Vec<&'static str> = MESSAGE_FORMAT_VARIANTS.iter().map(|f| f.as_str()).collect();
    for (index, tag) in tags.iter().enumerate() {
        // Compare each tag only with the ones after it.
        for later in &tags[index + 1..] {
            assert_ne!(tag, later, "duplicate MessageFormat tag {:?}", tag);
        }
    }
}
/// Consecutive image tokens collapse into one span, and a non-image token splits
/// runs into separate spans.
#[test]
fn locate_image_tokens_collapse_and_split() {
    let two_runs = locate_image_tokens(&[5, 9, 9, 5, 9, 5], 9);
    assert_eq!(two_runs, vec![(1, 3), (4, 5)]);

    let single_run = locate_image_tokens(&[9, 9, 9, 9], 9);
    assert_eq!(single_run, vec![(0, 4)]);
}
/// A span whose end equals the length argument is accepted, and the resulting 3x3
/// mask matches the causal-or-same-image oracle.
#[test]
fn mask_span_end_equals_seq_len_is_valid() {
    let mut mask = build_multimodal_mask(3, &[(1, 3)]).unwrap();
    assert_eq!(mask.shape(), vec![1, 1, 3, 3]);
    let v = mask.to_vec::<bool>().unwrap();

    // Image-block id per position; 0 means "not part of an image".
    let block = [0u32, 1, 1];
    let stride = block.len();
    for (q, &query_block) in block.iter().enumerate() {
        for (k, &key_block) in block.iter().enumerate() {
            let visible = k <= q || (query_block != 0 && query_block == key_block);
            assert_eq!(v[q * stride + k], visible, "mask[{q}][{k}]");
        }
    }
}
/// `qwen2_5_vl` under the assistant role emits only the text item even when four
/// images are requested.
#[test]
fn list_with_image_first_assistant_role_text_only() {
    let formatter = MessageFormatter::for_model("qwen2_5_vl").unwrap();
    let opts = opts_with("assistant", 4, 0);
    let formatted = formatter.format_message("hi", &opts).unwrap();
    let text_only = vec![ContentItem::Text {
        text: String::from("hi"),
    }];
    assert_eq!(items(&formatted), text_only);
}
/// `get_message_json` for `florence2` returns the bare prompt as
/// `FormattedMessage::String`.
#[test]
fn get_message_json_prompt_only_string_branch() {
    let one_image = opts_with("user", 1, 0);
    let formatted = get_message_json("florence2", "describe", Some(&one_image)).unwrap();
    let expected = FormattedMessage::String(String::from("describe"));
    assert_eq!(formatted, expected);
}
}