use anyhow::{Result, bail};
use std::path::PathBuf;
/// Unsized callback type accepted by `MtmdContext::build_turn`: it receives the
/// rendered prompt text as `&str` and returns a `Vec<u32>` or an error.
/// The `'a` bound lets the callback borrow from its environment.
type TokenizerFn<'a> = dyn Fn(&str) -> Result<Vec<u32>> + 'a;
/// How a piece of media attached to a turn is supplied.
///
/// Nothing in this file opens the path or inspects the bytes; values are only
/// stored and cloned.
#[derive(Debug, Clone)]
pub enum MediaSource {
/// Media identified by a path.
FilePath(PathBuf),
/// Media supplied directly as a byte buffer.
Bytes(Vec<u8>),
}
/// One conversation turn: a role label, its text, and any attached media.
#[derive(Debug, Clone)]
pub struct MtmdTurn {
/// Role label; the provided constructors use `"user"`, `"system"` and `"assistant"`.
pub role: String,
/// Text content of the turn.
pub text: String,
/// Image attachments, in the order they were added.
pub images: Vec<MediaSource>,
/// Audio attachments, in the order they were added.
pub audio: Vec<MediaSource>,
}
impl MtmdTurn {
    /// Shared constructor: a turn with the given role label and text and no media.
    fn with_role(role: &str, text: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            text: text.into(),
            images: Vec::new(),
            audio: Vec::new(),
        }
    }

    /// Creates a media-free turn whose role is `"user"`.
    pub fn user(text: impl Into<String>) -> Self {
        Self::with_role("user", text)
    }

    /// Creates a media-free turn whose role is `"system"`.
    pub fn system(text: impl Into<String>) -> Self {
        Self::with_role("system", text)
    }

    /// Creates a media-free turn whose role is `"assistant"`.
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::with_role("assistant", text)
    }

    /// Appends an image given by `path` and returns the turn for chaining.
    pub fn with_image_path(mut self, path: impl Into<PathBuf>) -> Self {
        let source = MediaSource::FilePath(path.into());
        self.images.push(source);
        self
    }

    /// Appends an image given as raw `bytes` and returns the turn for chaining.
    pub fn with_image_bytes(mut self, bytes: Vec<u8>) -> Self {
        let source = MediaSource::Bytes(bytes);
        self.images.push(source);
        self
    }

    /// Appends an audio clip given by `path` and returns the turn for chaining.
    pub fn with_audio_path(mut self, path: impl Into<PathBuf>) -> Self {
        let source = MediaSource::FilePath(path.into());
        self.audio.push(source);
        self
    }

    /// Appends an audio clip given as raw `bytes` and returns the turn for chaining.
    pub fn with_audio_bytes(mut self, bytes: Vec<u8>) -> Self {
        let source = MediaSource::Bytes(bytes);
        self.audio.push(source);
        self
    }

    /// Returns `true` when at least one image or audio attachment is present.
    pub fn has_media(&self) -> bool {
        !(self.images.is_empty() && self.audio.is_empty())
    }
}
/// Result of `MtmdContext::build_turn`: the tokenized prompt plus the media
/// sources collected from the turns.
#[derive(Debug, Clone, Default)]
pub struct AssembledTurn {
/// Output of the tokenizer callback; empty when no callback was supplied.
pub text_tokens: Vec<u32>,
/// Clones of every image source, in turn order then attachment order.
pub image_refs: Vec<MediaSource>,
/// Clones of every audio source, in turn order then attachment order.
pub audio_refs: Vec<MediaSource>,
}
/// Holds the template text and optional BOS/EOS strings used when assembling
/// a prompt from a list of turns.
pub struct MtmdContext {
// NOTE(review): in the code visible here this is only stored and returned by
// the `template_source()` accessor; `build_turn` does not read it.
template_source: String,
// Written at the very start of the rendered prompt when set.
bos_token: Option<String>,
// Written at the very end of the rendered prompt when set.
eos_token: Option<String>,
}
impl MtmdContext {
    /// Creates a context that stores `src` as its template source, with no BOS
    /// or EOS token configured.
    pub fn from_template_source(src: impl Into<String>) -> Self {
        let template_source = src.into();
        Self {
            template_source,
            bos_token: None,
            eos_token: None,
        }
    }

    /// Replaces both the BOS and EOS tokens (passing `None` clears one) and
    /// returns the context for chaining.
    pub fn with_tokens(self, bos: Option<String>, eos: Option<String>) -> Self {
        Self {
            bos_token: bos,
            eos_token: eos,
            ..self
        }
    }

    /// The template text supplied at construction.
    pub fn template_source(&self) -> &str {
        self.template_source.as_str()
    }

    /// The configured BOS token, if any.
    pub fn bos_token(&self) -> Option<&str> {
        self.bos_token.as_deref()
    }

    /// The configured EOS token, if any.
    pub fn eos_token(&self) -> Option<&str> {
        self.eos_token.as_deref()
    }

    /// Renders `turns` into a single prompt string and optionally tokenizes it.
    ///
    /// The prompt is: the BOS token (if set); then for each turn
    /// `<|im_start|>{role}\n{text}`, one `<|image|>` per image, one `<|audio|>`
    /// per audio clip, and `<|im_end|>\n`; then the EOS token (if set). This
    /// markup is hard-coded here; the stored template source is not consulted.
    ///
    /// Every attached media source is cloned into the result, images and audio
    /// in separate lists, each in turn order. When `tokenizer_fn` is `None`
    /// the returned `text_tokens` is empty.
    ///
    /// # Errors
    ///
    /// Fails when `turns` is empty, and propagates any error returned by
    /// `tokenizer_fn`.
    pub fn build_turn(
        &self,
        turns: &[MtmdTurn],
        tokenizer_fn: Option<&TokenizerFn<'_>>,
    ) -> Result<AssembledTurn> {
        if turns.is_empty() {
            bail!("MtmdContext::build_turn: empty turn list");
        }

        let mut prompt = String::new();
        prompt.push_str(self.bos_token().unwrap_or(""));

        let mut image_refs = Vec::new();
        let mut audio_refs = Vec::new();
        for turn in turns {
            prompt.push_str("<|im_start|>");
            prompt.push_str(&turn.role);
            prompt.push('\n');
            prompt.push_str(&turn.text);

            // One placeholder per attachment: all images first, then all audio.
            prompt.push_str(&"<|image|>".repeat(turn.images.len()));
            image_refs.extend(turn.images.iter().cloned());
            prompt.push_str(&"<|audio|>".repeat(turn.audio.len()));
            audio_refs.extend(turn.audio.iter().cloned());

            prompt.push_str("<|im_end|>\n");
        }
        prompt.push_str(self.eos_token().unwrap_or(""));

        // No callback means no tokens; a callback error aborts the whole call.
        let text_tokens = tokenizer_fn
            .map(|tokenize| tokenize(&prompt))
            .transpose()?
            .unwrap_or_default();

        Ok(AssembledTurn {
            text_tokens,
            image_refs,
            audio_refs,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Returns the path inside a `MediaSource::FilePath`; panics on `Bytes`.
    fn path_of(source: &MediaSource) -> &Path {
        match source {
            MediaSource::FilePath(path) => path.as_path(),
            MediaSource::Bytes(_) => panic!("expected a file-path media source"),
        }
    }

    /// Media added in interleaved order must come out grouped by kind, with
    /// each group keeping the order in which its items were attached.
    #[test]
    fn build_turn_records_media_order() {
        let ctx = MtmdContext::from_template_source("").with_tokens(None, None);
        let turn = MtmdTurn::user("describe")
            .with_image_path("/tmp/a.png")
            .with_audio_path("/tmp/b.wav")
            .with_image_path("/tmp/c.png");
        let out = ctx.build_turn(&[turn], None).unwrap();
        assert_eq!(out.image_refs.len(), 2);
        assert_eq!(out.audio_refs.len(), 1);
        assert!(out.text_tokens.is_empty());

        // Check the order itself, not just the counts.
        let image_paths: Vec<&Path> = out.image_refs.iter().map(path_of).collect();
        assert_eq!(
            image_paths,
            [Path::new("/tmp/a.png"), Path::new("/tmp/c.png")]
        );
        let audio_paths: Vec<&Path> = out.audio_refs.iter().map(path_of).collect();
        assert_eq!(audio_paths, [Path::new("/tmp/b.wav")]);
    }

    /// The tokenizer callback receives the rendered text and its output is
    /// returned verbatim as `text_tokens`.
    #[test]
    fn build_turn_invokes_tokenizer_callback() {
        let ctx = MtmdContext::from_template_source("");
        let counter = std::cell::Cell::new(0u32);
        let tokenize = |s: &str| -> Result<Vec<u32>> {
            counter.set(s.len() as u32);
            Ok(vec![1, 2, 3])
        };
        let turn = MtmdTurn::user("hello");
        let out = ctx
            .build_turn(
                &[turn],
                Some(&tokenize as &dyn Fn(&str) -> Result<Vec<u32>>),
            )
            .unwrap();
        assert_eq!(out.text_tokens, vec![1, 2, 3]);
        assert!(counter.get() > 0, "tokenizer must see the rendered text");
    }

    /// An empty turn list is an error, not an empty result.
    #[test]
    fn build_turn_rejects_empty() {
        let ctx = MtmdContext::from_template_source("");
        let err = ctx.build_turn(&[], None).unwrap_err();
        assert!(format!("{err}").contains("empty turn list"));
    }
}