1use anyhow::{Result, bail};
29use std::path::PathBuf;
30
31type TokenizerFn<'a> = dyn Fn(&str) -> Result<Vec<u32>> + 'a;
32
/// Where the data of one media attachment (an image or an audio clip) comes
/// from.
///
/// Values compare by content, so two sources are equal when they hold the
/// same path or the same bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MediaSource {
    /// A path to the media. The path is stored as given; it is not opened or
    /// validated by this type.
    FilePath(PathBuf),
    /// The media held in memory as raw bytes.
    Bytes(Vec<u8>),
}
41
42#[derive(Debug, Clone)]
46pub struct MtmdTurn {
47 pub role: String,
48 pub text: String,
49 pub images: Vec<MediaSource>,
50 pub audio: Vec<MediaSource>,
51}
52
53impl MtmdTurn {
54 pub fn user(text: impl Into<String>) -> Self {
55 Self {
56 role: "user".into(),
57 text: text.into(),
58 images: Vec::new(),
59 audio: Vec::new(),
60 }
61 }
62 pub fn system(text: impl Into<String>) -> Self {
63 Self {
64 role: "system".into(),
65 text: text.into(),
66 images: Vec::new(),
67 audio: Vec::new(),
68 }
69 }
70 pub fn assistant(text: impl Into<String>) -> Self {
71 Self {
72 role: "assistant".into(),
73 text: text.into(),
74 images: Vec::new(),
75 audio: Vec::new(),
76 }
77 }
78 pub fn with_image_path(mut self, path: impl Into<PathBuf>) -> Self {
79 self.images.push(MediaSource::FilePath(path.into()));
80 self
81 }
82 pub fn with_image_bytes(mut self, bytes: Vec<u8>) -> Self {
83 self.images.push(MediaSource::Bytes(bytes));
84 self
85 }
86 pub fn with_audio_path(mut self, path: impl Into<PathBuf>) -> Self {
87 self.audio.push(MediaSource::FilePath(path.into()));
88 self
89 }
90 pub fn with_audio_bytes(mut self, bytes: Vec<u8>) -> Self {
91 self.audio.push(MediaSource::Bytes(bytes));
92 self
93 }
94
95 pub fn has_media(&self) -> bool {
96 !self.images.is_empty() || !self.audio.is_empty()
97 }
98}
99
100#[derive(Debug, Clone, Default)]
105pub struct AssembledTurn {
106 pub text_tokens: Vec<u32>,
107 pub image_refs: Vec<MediaSource>,
108 pub audio_refs: Vec<MediaSource>,
109}
110
/// Prompt-assembly settings: a chat template source plus optional BOS/EOS
/// token strings.
///
/// Derives `Debug` and `Clone` like the other public types in this file so it
/// can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct MtmdContext {
    /// Chat template source text, stored exactly as supplied.
    template_source: String,
    /// Beginning-of-sequence token text, if one is configured.
    bos_token: Option<String>,
    /// End-of-sequence token text, if one is configured.
    eos_token: Option<String>,
}
119
120impl MtmdContext {
121 pub fn from_template_source(src: impl Into<String>) -> Self {
124 Self {
125 template_source: src.into(),
126 bos_token: None,
127 eos_token: None,
128 }
129 }
130
131 pub fn with_tokens(mut self, bos: Option<String>, eos: Option<String>) -> Self {
132 self.bos_token = bos;
133 self.eos_token = eos;
134 self
135 }
136
137 pub fn template_source(&self) -> &str {
138 &self.template_source
139 }
140 pub fn bos_token(&self) -> Option<&str> {
141 self.bos_token.as_deref()
142 }
143 pub fn eos_token(&self) -> Option<&str> {
144 self.eos_token.as_deref()
145 }
146
147 pub fn build_turn(
161 &self,
162 turns: &[MtmdTurn],
163 tokenizer_fn: Option<&TokenizerFn<'_>>,
164 ) -> Result<AssembledTurn> {
165 if turns.is_empty() {
166 bail!("MtmdContext::build_turn: empty turn list");
167 }
168 let mut text = String::new();
169 let mut image_refs = Vec::new();
170 let mut audio_refs = Vec::new();
171
172 if let Some(bos) = self.bos_token.as_deref() {
176 text.push_str(bos);
177 }
178 for t in turns {
179 text.push_str("<|im_start|>");
180 text.push_str(&t.role);
181 text.push('\n');
182 text.push_str(&t.text);
183 for img in &t.images {
186 text.push_str("<|image|>");
187 image_refs.push(img.clone());
188 }
189 for au in &t.audio {
190 text.push_str("<|audio|>");
191 audio_refs.push(au.clone());
192 }
193 text.push_str("<|im_end|>\n");
194 }
195 if let Some(eos) = self.eos_token.as_deref() {
196 text.push_str(eos);
197 }
198
199 let text_tokens = match tokenizer_fn {
200 Some(f) => f(&text)?,
201 None => Vec::new(),
202 };
203 Ok(AssembledTurn {
204 text_tokens,
205 image_refs,
206 audio_refs,
207 })
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the path text of a `MediaSource::FilePath`; panics on `Bytes`
    /// so a wrong variant fails the test loudly.
    fn path_of(source: &MediaSource) -> &str {
        match source {
            MediaSource::FilePath(path) => path.to_str().expect("test paths are valid UTF-8"),
            MediaSource::Bytes(_) => panic!("expected a file path, got raw bytes"),
        }
    }

    /// Image and audio references are collected separately, each in the order
    /// the attachments were added.
    #[test]
    fn build_turn_records_media_order() {
        let ctx = MtmdContext::from_template_source("").with_tokens(None, None);
        let turn = MtmdTurn::user("describe")
            .with_image_path("/tmp/a.png")
            .with_audio_path("/tmp/b.wav")
            .with_image_path("/tmp/c.png");
        let out = ctx.build_turn(&[turn], None).unwrap();

        let images: Vec<&str> = out.image_refs.iter().map(path_of).collect();
        let audio: Vec<&str> = out.audio_refs.iter().map(path_of).collect();
        assert_eq!(images, ["/tmp/a.png", "/tmp/c.png"]);
        assert_eq!(audio, ["/tmp/b.wav"]);
        assert!(out.text_tokens.is_empty());
    }

    /// The tokenizer callback receives the rendered prompt and its output is
    /// returned unchanged.
    #[test]
    fn build_turn_invokes_tokenizer_callback() {
        let ctx = MtmdContext::from_template_source("");
        let seen = std::cell::RefCell::new(String::new());
        let tokenize = |s: &str| -> Result<Vec<u32>> {
            seen.borrow_mut().push_str(s);
            Ok(vec![1, 2, 3])
        };
        let turn = MtmdTurn::user("hello");
        let out = ctx
            .build_turn(
                &[turn],
                Some(&tokenize as &dyn Fn(&str) -> Result<Vec<u32>>),
            )
            .unwrap();
        assert_eq!(out.text_tokens, vec![1, 2, 3]);
        assert_eq!(
            seen.borrow().as_str(),
            "<|im_start|>user\nhello<|im_end|>\n",
            "tokenizer must see the rendered text"
        );
    }

    /// BOS/EOS wrap the whole prompt, and media markers follow the turn text
    /// with images before audio.
    #[test]
    fn build_turn_renders_tokens_and_media_markers() {
        let ctx = MtmdContext::from_template_source("")
            .with_tokens(Some("<s>".to_string()), Some("</s>".to_string()));
        let seen = std::cell::RefCell::new(String::new());
        let tokenize = |s: &str| -> Result<Vec<u32>> {
            seen.borrow_mut().push_str(s);
            Ok(Vec::new())
        };
        let turn = MtmdTurn::user("hi")
            .with_audio_path("a.wav")
            .with_image_path("b.png");
        ctx.build_turn(
            &[turn],
            Some(&tokenize as &dyn Fn(&str) -> Result<Vec<u32>>),
        )
        .unwrap();
        assert_eq!(
            seen.borrow().as_str(),
            "<s><|im_start|>user\nhi<|image|><|audio|><|im_end|>\n</s>"
        );
    }

    /// An empty turn list is an error, not an empty result.
    #[test]
    fn build_turn_rejects_empty() {
        let ctx = MtmdContext::from_template_source("");
        let err = ctx.build_turn(&[], None).unwrap_err();
        assert!(err.to_string().contains("empty turn list"));
    }
}