llama_cpp_4/mtmd.rs
1//! Safe wrappers for the `libmtmd` multimodal support library.
2//!
3//! `libmtmd` extends llama.cpp with the ability to encode image and audio
4//! inputs (bitmaps) into token embeddings that can then be fed into a
//! standard `llama_decode` call alongside normal text tokens.
6//!
7//! # Quick-start
8//!
9//! ```no_run
10//! # #[cfg(feature = "mtmd")]
11//! # {
12//! use std::path::Path;
13//! use llama_cpp_4::{
14//! llama_backend::LlamaBackend,
15//! model::{LlamaModel, params::LlamaModelParams, AddBos},
16//! context::params::LlamaContextParams,
17//! mtmd::{MtmdContext, MtmdContextParams, MtmdBitmap, MtmdInputChunks, MtmdInputText},
18//! };
19//!
20//! let backend = LlamaBackend::init().unwrap();
21//! let model = LlamaModel::load_from_file(&backend, Path::new("model.gguf"),
22//! &LlamaModelParams::default()).unwrap();
23//! let mut lctx = model.new_context(&backend, LlamaContextParams::default()).unwrap();
24//!
25//! // Load the multimodal projector (mmproj) model.
26//! let ctx_params = MtmdContextParams::default();
27//! let mtmd_ctx = MtmdContext::init_from_file(Path::new("mmproj.gguf"), &model, ctx_params)
28//! .unwrap();
29//!
30//! // Load an image from a file.
31//! let bitmap = MtmdBitmap::from_file(&mtmd_ctx, Path::new("image.jpg")).unwrap();
32//!
33//! // Tokenize a prompt that contains the media marker.
34//! let marker = MtmdContext::default_marker();
35//! let prompt = format!("Describe this image: {marker}");
36//! let text = MtmdInputText::new(&prompt, true, true);
37//! let bitmaps = [&bitmap];
38//!
39//! let mut chunks = MtmdInputChunks::new();
40//! mtmd_ctx.tokenize(&text, &bitmaps, &mut chunks).unwrap();
41//!
42//! // Evaluate / decode all chunks.
43//! let n_batch = lctx.n_batch() as i32;
44//! let mut n_past = 0i32;
45//! mtmd_ctx.eval_chunks(lctx.as_ptr(), &chunks, 0, 0, n_batch, true, &mut n_past).unwrap();
46//! # }
47//! ```
48//!
49//! # Feature flag
50//!
51//! This module is only compiled when the `mtmd` Cargo feature is enabled.
52
53use std::ffi::{CStr, CString};
54use std::path::Path;
55use std::ptr::NonNull;
56use std::slice;
57
58use llama_cpp_sys_4 as sys;
59
60use crate::model::LlamaModel;
61
62// ─────────────────────────────────────────────────────────────────────────────
63// Error types
64// ─────────────────────────────────────────────────────────────────────────────
65
66/// All errors that can be returned by the mtmd module.
67#[derive(Debug, thiserror::Error)]
68pub enum MtmdError {
69 /// The context could not be created (e.g. bad mmproj file).
70 #[error("failed to create mtmd context (null return from mtmd_init_from_file)")]
71 ContextCreateFailed,
72
73 /// The bitmap could not be created.
74 #[error("failed to create mtmd bitmap")]
75 BitmapCreateFailed,
76
77 /// A path could not be converted to a valid C string (embedded NUL byte or non-UTF-8).
78 #[error("invalid path: {0}")]
79 InvalidPath(#[from] std::ffi::NulError),
80
81 /// A path was not representable as UTF-8.
82 #[error("path is not valid UTF-8")]
83 PathNotUtf8,
84
85 /// `mtmd_tokenize` returned an error code.
86 #[error("tokenize error: code {0} (1 = bitmap count mismatch, 2 = preprocessing error)")]
87 TokenizeError(i32),
88
89 /// `mtmd_encode_chunk` returned a non-zero code.
90 #[error("encode error: code {0}")]
91 EncodeError(i32),
92
93 /// `mtmd_helper_eval_chunks` (or single-chunk variant) returned a non-zero code.
94 #[error("eval error: code {0}")]
95 EvalError(i32),
96}
97
98/// A convenience `Result` alias for this module.
99pub type Result<T> = std::result::Result<T, MtmdError>;
100
101// ─────────────────────────────────────────────────────────────────────────────
102// MtmdContextParams
103// ─────────────────────────────────────────────────────────────────────────────
104
105/// Parameters used when creating an [`MtmdContext`].
106///
107/// Obtain a default-initialised instance via [`MtmdContextParams::default()`].
108pub struct MtmdContextParams {
109 pub(crate) params: sys::mtmd_context_params,
110}
111
112impl std::fmt::Debug for MtmdContextParams {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("MtmdContextParams")
115 .field("use_gpu", &self.params.use_gpu)
116 .field("print_timings", &self.params.print_timings)
117 .field("n_threads", &self.params.n_threads)
118 .field("warmup", &self.params.warmup)
119 .field("image_min_tokens", &self.params.image_min_tokens)
120 .field("image_max_tokens", &self.params.image_max_tokens)
121 .finish()
122 }
123}
124
125impl Default for MtmdContextParams {
126 fn default() -> Self {
127 let params = unsafe { sys::mtmd_context_params_default() };
128 Self { params }
129 }
130}
131
132impl MtmdContextParams {
133 /// Whether to run the vision/audio encoder on the GPU (default: `true`).
134 #[must_use]
135 pub fn use_gpu(mut self, v: bool) -> Self {
136 self.params.use_gpu = v;
137 self
138 }
139
140 /// Whether to print timing info after each encode (default: `false`).
141 #[must_use]
142 pub fn print_timings(mut self, v: bool) -> Self {
143 self.params.print_timings = v;
144 self
145 }
146
147 /// Number of threads used for the vision encoder (default taken from
148 /// `mtmd_context_params_default`).
149 #[must_use]
150 pub fn n_threads(mut self, n: i32) -> Self {
151 self.params.n_threads = n;
152 self
153 }
154
155 /// Whether to run a warm-up encode pass after initialisation.
156 #[must_use]
157 pub fn warmup(mut self, v: bool) -> Self {
158 self.params.warmup = v;
159 self
160 }
161
162 /// Minimum number of image tokens (0 = use model default).
163 #[must_use]
164 pub fn image_min_tokens(mut self, n: i32) -> Self {
165 self.params.image_min_tokens = n;
166 self
167 }
168
169 /// Maximum number of image tokens (0 = use model default).
170 #[must_use]
171 pub fn image_max_tokens(mut self, n: i32) -> Self {
172 self.params.image_max_tokens = n;
173 self
174 }
175
176 /// Override the media marker string (e.g. `"<image>"`).
177 ///
178 /// The provided string must not contain interior NUL bytes. Pass `None`
179 /// to use the library default (`mtmd_default_marker()`).
180 ///
181 /// **Note:** the `CString` is stored inside the params so the pointer
182 /// remains valid as long as this `MtmdContextParams` lives.
183 /// # Errors
184 ///
185 /// Returns [`MtmdError`] if the marker string contains a NUL byte.
186 pub fn media_marker(mut self, marker: Option<&str>) -> std::result::Result<Self, MtmdError> {
187 match marker {
188 None => {
189 self.params.media_marker = std::ptr::null();
190 Ok(self)
191 }
192 Some(s) => {
193 let cs = CString::new(s)?;
194 self.params.media_marker = cs.as_ptr();
195 // Leak the CString so the raw pointer stays valid; the caller
196 // must ensure the params don't outlive the string. Since
197 // MtmdContextParams is consumed by MtmdContext::init_from_file,
198 // this is safe.
199 std::mem::forget(cs);
200 Ok(self)
201 }
202 }
203 }
204}
205
206// ─────────────────────────────────────────────────────────────────────────────
207// MtmdContext
208// ─────────────────────────────────────────────────────────────────────────────
209
210/// The main multimodal context.
211///
212/// Wraps a `mtmd_context *`. This context is tied to a specific mmproj model
213/// file and a loaded [`LlamaModel`]. It is safe to share across threads for
214/// `tokenize` calls (read-only), but `encode_chunk` / eval helpers mutate
215/// internal state and must not be called concurrently.
216pub struct MtmdContext {
217 ptr: NonNull<sys::mtmd_context>,
218}
219
220// The underlying mtmd_context is internally synchronised for tokenize().
221// encode / decode must be called from a single thread at a time (caller's
222// responsibility, enforced by the inference semaphore in the server).
223unsafe impl Send for MtmdContext {}
224unsafe impl Sync for MtmdContext {}
225
226impl std::fmt::Debug for MtmdContext {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("MtmdContext")
229 .field("ptr", &self.ptr)
230 .finish()
231 }
232}
233
234impl Drop for MtmdContext {
235 fn drop(&mut self) {
236 unsafe { sys::mtmd_free(self.ptr.as_ptr()) }
237 }
238}
239
240impl MtmdContext {
241 /// Returns the default media marker string used in prompts
242 /// (currently `"<__media__>"`).
243 #[must_use]
244 pub fn default_marker() -> &'static str {
245 let ptr = unsafe { sys::mtmd_default_marker() };
246 unsafe { CStr::from_ptr(ptr) }
247 .to_str()
248 .unwrap_or("<__media__>")
249 }
250
251 /// Initialise a multimodal context from an mmproj GGUF file.
252 ///
253 /// # Parameters
254 ///
255 /// * `mmproj_path` – path to the mmproj `.gguf` file
256 /// * `text_model` – the already-loaded text model
257 /// * `params` – context parameters (use [`MtmdContextParams::default()`])
258 ///
259 /// # Errors
260 ///
261 /// Returns [`MtmdError::ContextCreateFailed`] if the underlying C call
262 /// returns a null pointer.
263 #[allow(clippy::needless_pass_by_value)]
264 pub fn init_from_file(
265 mmproj_path: impl AsRef<Path>,
266 text_model: &LlamaModel,
267 params: MtmdContextParams,
268 ) -> Result<Self> {
269 let path = mmproj_path
270 .as_ref()
271 .to_str()
272 .ok_or(MtmdError::PathNotUtf8)?;
273 let c_path = CString::new(path)?;
274
275 let ptr = unsafe {
276 sys::mtmd_init_from_file(c_path.as_ptr(), text_model.model.as_ptr(), params.params)
277 };
278
279 let ptr = NonNull::new(ptr).ok_or(MtmdError::ContextCreateFailed)?;
280 Ok(Self { ptr })
281 }
282
283 // ── Logging ──────────────────────────────────────────────────────────
284
285 /// Silence all clip/mtmd log output by installing a no-op callback.
286 ///
287 /// Call this right after [`init_from_file`](Self::init_from_file) to
288 /// suppress the verbose `clip_model_loader: tensor[N]…` lines that
289 /// clip.cpp emits to its own private logger (separate from `llama_log_set`).
290 pub fn void_logs() {
291 unsafe extern "C" fn noop(
292 _level: sys::ggml_log_level,
293 _text: *const ::std::os::raw::c_char,
294 _ud: *mut ::std::os::raw::c_void,
295 ) {
296 }
297 unsafe { sys::mtmd_log_set(Some(noop), std::ptr::null_mut()) };
298 }
299
300 /// Like [`void_logs`](Self::void_logs), but additionally silences logs
301 /// emitted by the `mtmd_helper_*` layer (e.g. eval/decode helpers).
302 ///
303 /// Internally calls `mtmd_helper_log_set` which also routes through
304 /// `mtmd_log_set`, so this is a strict superset of `void_logs`.
305 pub fn void_helper_logs() {
306 unsafe extern "C" fn noop(
307 _level: sys::ggml_log_level,
308 _text: *const ::std::os::raw::c_char,
309 _ud: *mut ::std::os::raw::c_void,
310 ) {
311 }
312 unsafe { sys::mtmd_helper_log_set(Some(noop), std::ptr::null_mut()) };
313 }
314
315 // ── Capability queries ────────────────────────────────────────────────
316
317 /// Returns `true` if the model supports vision (image) input.
318 #[must_use]
319 pub fn supports_vision(&self) -> bool {
320 unsafe { sys::mtmd_support_vision(self.ptr.as_ptr()) }
321 }
322
323 /// Returns `true` if the model supports audio input.
324 #[must_use]
325 pub fn supports_audio(&self) -> bool {
326 unsafe { sys::mtmd_support_audio(self.ptr.as_ptr()) }
327 }
328
329 /// Returns the audio sample rate in Hz (e.g. 16 000 for Whisper), or
330 /// `-1` if audio is not supported.
331 #[must_use]
332 #[deprecated(note = "use audio_sample_rate() instead")]
333 pub fn audio_bitrate(&self) -> i32 {
334 self.audio_sample_rate()
335 }
336
337 /// Returns the audio sample rate in Hz.
338 #[must_use]
339 pub fn audio_sample_rate(&self) -> i32 {
340 unsafe { sys::mtmd_get_audio_sample_rate(self.ptr.as_ptr()) }
341 }
342
343 /// Whether `llama_decode` must use a non-causal attention mask when
344 /// decoding image embeddings for this model.
345 #[must_use]
346 pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk<'_>) -> bool {
347 unsafe { sys::mtmd_decode_use_non_causal(self.ptr.as_ptr(), chunk.as_ptr()) }
348 }
349
350 /// Whether the model uses M-RoPE for `llama_decode`.
351 #[must_use]
352 pub fn decode_use_mrope(&self) -> bool {
353 unsafe { sys::mtmd_decode_use_mrope(self.ptr.as_ptr()) }
354 }
355
356 // ── Core API ──────────────────────────────────────────────────────────
357
358 /// Tokenize a text prompt that contains one or more media markers.
359 ///
360 /// The number of `bitmaps` must equal the number of media markers in the
361 /// prompt text, otherwise [`MtmdError::TokenizeError(1)`] is returned.
362 ///
363 /// This call is **thread-safe** (shared `&self`).
364 ///
365 /// # Parameters
366 ///
367 /// * `text` – text + tokenisation options
368 /// * `bitmaps` – slice of [`MtmdBitmap`] references, one per media marker
369 /// * `output` – an [`MtmdInputChunks`] that will be populated with the result
370 ///
371 /// # Errors
372 ///
373 /// Returns [`MtmdError::TokenizeError`] if tokenization fails.
374 pub fn tokenize(
375 &self,
376 text: &MtmdInputText<'_>,
377 bitmaps: &[&MtmdBitmap],
378 output: &mut MtmdInputChunks,
379 ) -> Result<()> {
380 // The C signature is: mtmd_tokenize(..., mtmd_bitmap ** bitmaps, ...)
381 // where each element is a `const mtmd_bitmap *`. We build a Vec of
382 // `*const mtmd_bitmap` and pass a mutable pointer to its first element
383 // (i.e. `*mut *const mtmd_bitmap`) to satisfy the C API.
384 let mut bitmap_ptrs: Vec<*const sys::mtmd_bitmap> = bitmaps
385 .iter()
386 .map(|b| b.ptr.as_ptr().cast_const())
387 .collect();
388
389 let c_text = sys::mtmd_input_text {
390 text: text.c_text.as_ptr(),
391 add_special: text.add_special,
392 parse_special: text.parse_special,
393 };
394
395 let ret = unsafe {
396 sys::mtmd_tokenize(
397 self.ptr.as_ptr(),
398 output.ptr.as_ptr(),
399 &raw const c_text,
400 bitmap_ptrs.as_mut_ptr(),
401 bitmap_ptrs.len(),
402 )
403 };
404
405 if ret != 0 {
406 return Err(MtmdError::TokenizeError(ret));
407 }
408 Ok(())
409 }
410
411 /// Encode a single input chunk (image or audio) and store the resulting
412 /// embeddings inside the context.
413 ///
414 /// After a successful call, the embeddings can be retrieved with
415 /// [`MtmdContext::output_embd`].
416 ///
417 /// This call is **NOT thread-safe**.
418 ///
419 /// # Errors
420 ///
421 /// Returns [`MtmdError::EncodeError`] if encoding fails.
422 pub fn encode_chunk(&self, chunk: &MtmdInputChunk<'_>) -> Result<()> {
423 let ret = unsafe { sys::mtmd_encode_chunk(self.ptr.as_ptr(), chunk.ptr) };
424 if ret != 0 {
425 return Err(MtmdError::EncodeError(ret));
426 }
427 Ok(())
428 }
429
430 /// Return a slice over the embeddings produced by the last
431 /// [`encode_chunk`](Self::encode_chunk) call.
432 ///
433 /// The length (in `f32` elements) is:
434 /// ```text
435 /// n_embd_inp(model) * chunk.n_tokens()
436 /// ```
437 ///
438 /// # Safety
439 ///
440 /// The returned slice is valid until the next call that mutates the
441 /// context (e.g. another `encode_chunk`).
442 #[must_use]
443 pub fn output_embd(&self, n_elements: usize) -> &[f32] {
444 let ptr = unsafe { sys::mtmd_get_output_embd(self.ptr.as_ptr()) };
445 if ptr.is_null() || n_elements == 0 {
446 return &[];
447 }
448 unsafe { slice::from_raw_parts(ptr, n_elements) }
449 }
450
451 // ── Helper API ────────────────────────────────────────────────────────
452
453 /// High-level helper: evaluate (decode) all chunks in sequence.
454 ///
455 /// * Text chunks are decoded via `llama_decode`.
456 /// * Image/audio chunks are first encoded with `mtmd_encode_chunk` and
457 /// then decoded via `llama_decode`.
458 ///
459 /// On success `new_n_past` is updated with the new past position.
460 ///
461 /// This call is **NOT thread-safe**.
462 ///
463 /// # Parameters
464 ///
465 /// * `lctx` – raw pointer to the llama context (from [`LlamaContext::as_ptr`])
466 /// * `chunks` – the tokenized chunks to evaluate
467 /// * `n_past` – current KV-cache position
468 /// * `seq_id` – sequence ID
469 /// * `n_batch` – maximum batch size (must be ≥ 1)
470 /// * `logits_last` – if `true`, compute logits only for the final token
471 /// * `new_n_past` – updated KV-cache position after the call
472 ///
473 /// # Errors
474 ///
475 /// Returns [`MtmdError::EvalError`] if evaluation fails.
476 #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
477 pub fn eval_chunks(
478 &self,
479 lctx: *mut sys::llama_context,
480 chunks: &MtmdInputChunks,
481 n_past: i32,
482 seq_id: i32,
483 n_batch: i32,
484 logits_last: bool,
485 new_n_past: &mut i32,
486 ) -> Result<()> {
487 let ret = unsafe {
488 sys::mtmd_helper_eval_chunks(
489 self.ptr.as_ptr(),
490 lctx,
491 chunks.ptr.as_ptr(),
492 n_past,
493 seq_id,
494 n_batch,
495 logits_last,
496 new_n_past,
497 )
498 };
499 if ret != 0 {
500 return Err(MtmdError::EvalError(ret));
501 }
502 Ok(())
503 }
504
505 /// High-level helper: evaluate a single chunk.
506 ///
507 /// Works identically to [`eval_chunks`](Self::eval_chunks) but operates on
508 /// one chunk at a time.
509 ///
510 /// # Errors
511 ///
512 /// Returns [`MtmdError::EvalError`] if evaluation fails.
513 #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
514 pub fn eval_chunk_single(
515 &self,
516 lctx: *mut sys::llama_context,
517 chunk: &MtmdInputChunk<'_>,
518 n_past: i32,
519 seq_id: i32,
520 n_batch: i32,
521 logits_last: bool,
522 new_n_past: &mut i32,
523 ) -> Result<()> {
524 let ret = unsafe {
525 sys::mtmd_helper_eval_chunk_single(
526 self.ptr.as_ptr(),
527 lctx,
528 chunk.ptr,
529 n_past,
530 seq_id,
531 n_batch,
532 logits_last,
533 new_n_past,
534 )
535 };
536 if ret != 0 {
537 return Err(MtmdError::EvalError(ret));
538 }
539 Ok(())
540 }
541
542 /// Decode an image/audio chunk whose embeddings have already been
543 /// computed (e.g. via [`encode_chunk`](Self::encode_chunk) followed by
544 /// [`output_embd`](Self::output_embd)).
545 ///
546 /// Unlike [`eval_chunk_single`](Self::eval_chunk_single), this helper
547 /// handles batching plus the non-causal-attention setup required by
548 /// some models (e.g. Gemma 3, Gemma 4 audio) and the M-RoPE position
549 /// layout. Use it when the embeddings are already in hand and you want
550 /// the helper to take care of `llama_decode` plumbing.
551 ///
552 /// `encoded_embd` must contain `mtmd_image_tokens_get_n_tokens(chunk) *
553 /// llama_model_n_embd_inp(model)` `f32` elements. This call is **NOT
554 /// thread-safe**.
555 ///
556 /// # Errors
557 ///
558 /// Returns [`MtmdError::EvalError`] with code `-1` if `chunk` is not an
559 /// image/audio chunk, or `1` if `llama_decode` fails.
560 #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
561 pub fn decode_image_chunk(
562 &self,
563 lctx: *mut sys::llama_context,
564 chunk: &MtmdInputChunk<'_>,
565 encoded_embd: &[f32],
566 n_past: i32,
567 seq_id: i32,
568 n_batch: i32,
569 new_n_past: &mut i32,
570 ) -> Result<()> {
571 let ret = unsafe {
572 sys::mtmd_helper_decode_image_chunk(
573 self.ptr.as_ptr(),
574 lctx,
575 chunk.ptr,
576 encoded_embd.as_ptr().cast_mut(),
577 n_past,
578 seq_id,
579 n_batch,
580 new_n_past,
581 )
582 };
583 if ret != 0 {
584 return Err(MtmdError::EvalError(ret));
585 }
586 Ok(())
587 }
588
589 /// Returns a raw pointer to the underlying `mtmd_context`.
590 ///
591 /// # Safety
592 ///
593 /// The returned pointer is valid for the lifetime of this `MtmdContext`.
594 /// The caller must not free it.
595 #[must_use]
596 pub fn as_ptr(&self) -> *mut sys::mtmd_context {
597 self.ptr.as_ptr()
598 }
599}
600
601// ─────────────────────────────────────────────────────────────────────────────
602// MtmdInputText
603// ─────────────────────────────────────────────────────────────────────────────
604
/// Text input for [`MtmdContext::tokenize`].
///
/// Holds an owned, NUL-terminated copy of the prompt together with the two
/// tokenisation flags. The prompt must contain the media marker (see
/// [`MtmdContext::default_marker`]) once for every bitmap to be embedded.
#[derive(Debug)]
pub struct MtmdInputText<'a> {
    c_text: CString,
    add_special: bool,
    parse_special: bool,
    _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> MtmdInputText<'a> {
    /// Builds an `MtmdInputText`, panicking on interior NUL bytes.
    ///
    /// * `text` – the prompt (must not contain interior NUL bytes)
    /// * `add_special` – whether to add BOS/EOS tokens
    /// * `parse_special` – whether to parse special tokens embedded in the text
    ///
    /// # Panics
    ///
    /// Panics if `text` contains an interior NUL byte; use
    /// [`try_new`](Self::try_new) to handle that case without panicking.
    #[must_use]
    pub fn new(text: &'a str, add_special: bool, parse_special: bool) -> Self {
        Self::try_new(text, add_special, parse_special)
            .expect("MtmdInputText: text must not contain NUL bytes")
    }

    /// Fallible counterpart of [`new`](Self::new).
    ///
    /// # Errors
    ///
    /// Returns [`std::ffi::NulError`] if `text` contains a NUL byte.
    pub fn try_new(
        text: &'a str,
        add_special: bool,
        parse_special: bool,
    ) -> std::result::Result<Self, std::ffi::NulError> {
        CString::new(text).map(|c_text| Self {
            c_text,
            add_special,
            parse_special,
            _marker: std::marker::PhantomData,
        })
    }
}
658
659// ─────────────────────────────────────────────────────────────────────────────
660// MtmdBitmap
661// ─────────────────────────────────────────────────────────────────────────────
662
663/// An image or audio bitmap ready for multimodal encoding.
664///
665/// # Image bitmaps
666///
667/// The raw pixel data must be in RGBRGBRGB… (interleaved) format. The total
668/// number of bytes must be `nx * ny * 3`.
669///
670/// # Audio bitmaps
671///
672/// The raw sample data must be little-endian `f32` PCM samples. The total
673/// number of bytes must be `n_samples * 4`.
674pub struct MtmdBitmap {
675 ptr: NonNull<sys::mtmd_bitmap>,
676}
677
678unsafe impl Send for MtmdBitmap {}
679unsafe impl Sync for MtmdBitmap {}
680
681impl std::fmt::Debug for MtmdBitmap {
682 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
683 f.debug_struct("MtmdBitmap")
684 .field("nx", &self.nx())
685 .field("ny", &self.ny())
686 .field("n_bytes", &self.n_bytes())
687 .field("is_audio", &self.is_audio())
688 .finish()
689 }
690}
691
692impl Drop for MtmdBitmap {
693 fn drop(&mut self) {
694 unsafe { sys::mtmd_bitmap_free(self.ptr.as_ptr()) }
695 }
696}
697
698impl MtmdBitmap {
699 /// Create a bitmap from raw RGB pixel data.
700 ///
701 /// * `nx` – image width in pixels
702 /// * `ny` – image height in pixels
703 /// * `data` – raw pixel bytes in RGBRGB… format; must be `nx * ny * 3` bytes
704 ///
705 /// # Errors
706 ///
707 /// Returns [`MtmdError::BitmapCreateFailed`] if the underlying C call
708 /// returns null.
709 pub fn from_rgb(nx: u32, ny: u32, data: &[u8]) -> Result<Self> {
710 let ptr = unsafe { sys::mtmd_bitmap_init(nx, ny, data.as_ptr()) };
711 let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
712 Ok(Self { ptr })
713 }
714
715 /// Create an audio bitmap from PCM `f32` samples.
716 ///
717 /// * `samples` – slice of PCM float samples
718 ///
719 /// # Errors
720 ///
721 /// Returns [`MtmdError::BitmapCreateFailed`] if the underlying C call
722 /// returns null.
723 pub fn from_audio(samples: &[f32]) -> Result<Self> {
724 let ptr = unsafe { sys::mtmd_bitmap_init_from_audio(samples.len(), samples.as_ptr()) };
725 let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
726 Ok(Self { ptr })
727 }
728
729 /// Load a bitmap from a file (image or audio).
730 ///
731 /// Supported image formats: JPEG, PNG, BMP, GIF, and others handled by
732 /// `stb_image`. Supported audio formats: WAV, MP3, FLAC (via miniaudio).
733 ///
734 /// # Errors
735 ///
736 /// Returns [`MtmdError::BitmapCreateFailed`] if the file cannot be loaded.
737 pub fn from_file(ctx: &MtmdContext, path: impl AsRef<Path>) -> Result<Self> {
738 let path = path.as_ref().to_str().ok_or(MtmdError::PathNotUtf8)?;
739 let c_path = CString::new(path)?;
740
741 let ptr =
742 unsafe { sys::mtmd_helper_bitmap_init_from_file(ctx.ptr.as_ptr(), c_path.as_ptr()) };
743 let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
744 Ok(Self { ptr })
745 }
746
747 /// Load a bitmap from an in-memory buffer containing a file.
748 ///
749 /// The format is auto-detected (image vs audio via magic bytes).
750 ///
751 /// # Errors
752 ///
753 /// Returns [`MtmdError::BitmapCreateFailed`] if decoding fails.
754 pub fn from_buf(ctx: &MtmdContext, buf: &[u8]) -> Result<Self> {
755 let ptr = unsafe {
756 sys::mtmd_helper_bitmap_init_from_buf(ctx.ptr.as_ptr(), buf.as_ptr(), buf.len())
757 };
758 let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
759 Ok(Self { ptr })
760 }
761
762 // ── Getters ───────────────────────────────────────────────────────────
763
764 /// Width in pixels (for images) or 0 (for audio).
765 #[must_use]
766 pub fn nx(&self) -> u32 {
767 unsafe { sys::mtmd_bitmap_get_nx(self.ptr.as_ptr()) }
768 }
769
770 /// Height in pixels (for images) or 0 (for audio).
771 #[must_use]
772 pub fn ny(&self) -> u32 {
773 unsafe { sys::mtmd_bitmap_get_ny(self.ptr.as_ptr()) }
774 }
775
776 /// Total number of bytes in the bitmap data.
777 #[must_use]
778 pub fn n_bytes(&self) -> usize {
779 unsafe { sys::mtmd_bitmap_get_n_bytes(self.ptr.as_ptr()) }
780 }
781
782 /// Returns `true` if this bitmap contains audio (rather than image) data.
783 #[must_use]
784 pub fn is_audio(&self) -> bool {
785 unsafe { sys::mtmd_bitmap_is_audio(self.ptr.as_ptr()) }
786 }
787
788 /// Return the raw pixel / sample data.
789 #[must_use]
790 pub fn data(&self) -> &[u8] {
791 let n = self.n_bytes();
792 if n == 0 {
793 return &[];
794 }
795 let ptr = unsafe { sys::mtmd_bitmap_get_data(self.ptr.as_ptr()) };
796 unsafe { slice::from_raw_parts(ptr, n) }
797 }
798
799 /// Return the optional ID string attached to this bitmap (used for KV
800 /// cache tracking), or `None` if no ID has been set.
801 #[must_use]
802 pub fn id(&self) -> Option<&str> {
803 let ptr = unsafe { sys::mtmd_bitmap_get_id(self.ptr.as_ptr()) };
804 if ptr.is_null() {
805 return None;
806 }
807 unsafe { CStr::from_ptr(ptr) }.to_str().ok()
808 }
809
810 /// Attach an optional ID string to this bitmap (used for KV cache
811 /// tracking).
812 ///
813 /// # Errors
814 ///
815 /// Returns an error if `id` contains an interior NUL byte.
816 pub fn set_id(&mut self, id: &str) -> std::result::Result<(), std::ffi::NulError> {
817 let cs = CString::new(id)?;
818 unsafe { sys::mtmd_bitmap_set_id(self.ptr.as_ptr(), cs.as_ptr()) };
819 Ok(())
820 }
821}
822
823// ─────────────────────────────────────────────────────────────────────────────
824// MtmdInputChunks
825// ─────────────────────────────────────────────────────────────────────────────
826
827/// A list of tokenized input chunks produced by [`MtmdContext::tokenize`].
828///
829/// Each chunk is either a text token sequence or a set of image/audio tokens.
830pub struct MtmdInputChunks {
831 ptr: NonNull<sys::mtmd_input_chunks>,
832}
833
834impl std::fmt::Debug for MtmdInputChunks {
835 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836 f.debug_struct("MtmdInputChunks")
837 .field("len", &self.len())
838 .finish()
839 }
840}
841
842impl Drop for MtmdInputChunks {
843 fn drop(&mut self) {
844 unsafe { sys::mtmd_input_chunks_free(self.ptr.as_ptr()) }
845 }
846}
847
848impl MtmdInputChunks {
849 /// Create a new, empty chunk list. Populated by
850 /// [`MtmdContext::tokenize`].
851 ///
852 /// # Panics
853 ///
854 /// Panics if the underlying C allocation fails (OOM).
855 #[must_use]
856 pub fn new() -> Self {
857 let ptr = unsafe { sys::mtmd_input_chunks_init() };
858 let ptr = NonNull::new(ptr).expect("mtmd_input_chunks_init returned null");
859 Self { ptr }
860 }
861
862 /// Number of chunks in this list.
863 #[must_use]
864 pub fn len(&self) -> usize {
865 unsafe { sys::mtmd_input_chunks_size(self.ptr.as_ptr()) }
866 }
867
868 /// Returns `true` if there are no chunks.
869 #[must_use]
870 pub fn is_empty(&self) -> bool {
871 self.len() == 0
872 }
873
874 /// Get the `idx`-th chunk. Returns `None` if `idx >= len()`.
875 #[must_use]
876 pub fn get(&self, idx: usize) -> Option<MtmdInputChunk<'_>> {
877 if idx >= self.len() {
878 return None;
879 }
880 let ptr = unsafe { sys::mtmd_input_chunks_get(self.ptr.as_ptr(), idx) };
881 if ptr.is_null() {
882 return None;
883 }
884 Some(MtmdInputChunk {
885 ptr,
886 _marker: std::marker::PhantomData,
887 })
888 }
889
890 /// Iterate over all chunks.
891 pub fn iter(&self) -> impl Iterator<Item = MtmdInputChunk<'_>> {
892 (0..self.len()).filter_map(|i| self.get(i))
893 }
894
895 /// Total number of tokens across all chunks.
896 ///
897 /// Equivalent to `mtmd_helper_get_n_tokens`.
898 #[must_use]
899 pub fn n_tokens(&self) -> usize {
900 unsafe { sys::mtmd_helper_get_n_tokens(self.ptr.as_ptr()) }
901 }
902
903 /// Total number of *positions* across all chunks (used for KV-cache
904 /// tracking with M-RoPE models where positions ≠ tokens).
905 ///
906 /// Equivalent to `mtmd_helper_get_n_pos`.
907 #[must_use]
908 pub fn n_pos(&self) -> i32 {
909 unsafe { sys::mtmd_helper_get_n_pos(self.ptr.as_ptr()) }
910 }
911}
912
913impl Default for MtmdInputChunks {
914 fn default() -> Self {
915 Self::new()
916 }
917}
918
919// ─────────────────────────────────────────────────────────────────────────────
920// MtmdInputChunkType
921// ─────────────────────────────────────────────────────────────────────────────
922
923/// The type of an [`MtmdInputChunk`].
924#[derive(Debug, Clone, Copy, PartialEq, Eq)]
925pub enum MtmdInputChunkType {
926 /// Plain text tokens.
927 Text,
928 /// Image tokens (embeddings produced by the vision encoder).
929 Image,
930 /// Audio tokens (embeddings produced by the audio encoder).
931 Audio,
932}
933
934impl From<sys::mtmd_input_chunk_type> for MtmdInputChunkType {
935 fn from(v: sys::mtmd_input_chunk_type) -> Self {
936 // mtmd_input_chunk_type is a plain C `typedef unsigned int`.
937 // The variants are exported as free-standing constants.
938 if v == sys::MTMD_INPUT_CHUNK_TYPE_IMAGE {
939 Self::Image
940 } else if v == sys::MTMD_INPUT_CHUNK_TYPE_AUDIO {
941 Self::Audio
942 } else {
943 Self::Text
944 }
945 }
946}
947
948// ─────────────────────────────────────────────────────────────────────────────
949// MtmdInputChunk
950// ─────────────────────────────────────────────────────────────────────────────
951
952/// A single tokenized input chunk (text, image, or audio).
953///
954/// Instances are borrowed from an [`MtmdInputChunks`] list and live as long
955/// as that list.
956#[derive(Debug)]
957pub struct MtmdInputChunk<'chunks> {
958 ptr: *const sys::mtmd_input_chunk,
959 _marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
960}
961
962impl<'chunks> MtmdInputChunk<'chunks> {
963 /// The type of this chunk.
964 #[must_use]
965 pub fn chunk_type(&self) -> MtmdInputChunkType {
966 let t = unsafe { sys::mtmd_input_chunk_get_type(self.ptr) };
967 MtmdInputChunkType::from(t)
968 }
969
970 /// Total number of tokens in this chunk.
971 #[must_use]
972 pub fn n_tokens(&self) -> usize {
973 unsafe { sys::mtmd_input_chunk_get_n_tokens(self.ptr) }
974 }
975
976 /// Number of temporal positions (equals `n_tokens` for non-M-RoPE models).
977 #[must_use]
978 pub fn n_pos(&self) -> i32 {
979 unsafe { sys::mtmd_input_chunk_get_n_pos(self.ptr) }
980 }
981
982 /// Return the raw llama token IDs for a **text** chunk.
983 ///
984 /// Returns `None` if this chunk is not a text chunk.
985 #[must_use]
986 pub fn text_tokens(&self) -> Option<&[i32]> {
987 if self.chunk_type() != MtmdInputChunkType::Text {
988 return None;
989 }
990 let mut n: usize = 0;
991 let ptr = unsafe { sys::mtmd_input_chunk_get_tokens_text(self.ptr, &raw mut n) };
992 if ptr.is_null() || n == 0 {
993 return Some(&[]);
994 }
995 Some(unsafe { slice::from_raw_parts(ptr, n) })
996 }
997
998 /// Return the image token metadata for an **image** or **audio** chunk.
999 ///
1000 /// Returns `None` for text chunks.
1001 #[must_use]
1002 pub fn image_tokens(&self) -> Option<MtmdImageTokens<'chunks>> {
1003 match self.chunk_type() {
1004 MtmdInputChunkType::Image | MtmdInputChunkType::Audio => {}
1005 MtmdInputChunkType::Text => return None,
1006 }
1007 let ptr = unsafe { sys::mtmd_input_chunk_get_tokens_image(self.ptr) };
1008 if ptr.is_null() {
1009 return None;
1010 }
1011 Some(MtmdImageTokens {
1012 ptr,
1013 _marker: std::marker::PhantomData,
1014 })
1015 }
1016
1017 /// Optional ID attached to this chunk (used for KV cache tracking).
1018 #[must_use]
1019 pub fn id(&self) -> Option<&str> {
1020 let ptr = unsafe { sys::mtmd_input_chunk_get_id(self.ptr) };
1021 if ptr.is_null() {
1022 return None;
1023 }
1024 unsafe { CStr::from_ptr(ptr) }.to_str().ok()
1025 }
1026
1027 /// Returns the raw `*const mtmd_input_chunk` pointer.
1028 ///
1029 /// # Safety
1030 ///
1031 /// The returned pointer is valid for the lifetime of the parent
1032 /// `MtmdInputChunks`.
1033 #[must_use]
1034 pub fn as_ptr(&self) -> *const sys::mtmd_input_chunk {
1035 self.ptr
1036 }
1037}
1038
1039// ─────────────────────────────────────────────────────────────────────────────
1040// MtmdDecoderPos
1041// ─────────────────────────────────────────────────────────────────────────────
1042
/// Per-token position used by M-RoPE decoder attention.
///
/// `t` is the temporal axis, `x`/`y` the spatial axes. `z` is reserved for
/// future use. The values are produced by
/// `MtmdImageTokens::decoder_positions` from a base position `pos_0`.
/// NOTE(review): whether they are offsets from `pos_0` or already include it
/// is decided by the C helper; confirm against the libmtmd headers.
///
/// Layout note: this struct is reinterpreted as `sys::mtmd_decoder_pos` at
/// the FFI boundary, so its field order and `#[repr(C)]` must not change.
/// The unit test `decoder_pos_layout_matches_sys` guards this.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(C)]
pub struct MtmdDecoderPos {
    /// Temporal index.
    pub t: u32,
    /// Spatial X.
    pub x: u32,
    /// Spatial Y.
    pub y: u32,
    /// Reserved.
    pub z: u32,
}
1060
1061// ─────────────────────────────────────────────────────────────────────────────
1062// MtmdImageTokens
1063// ─────────────────────────────────────────────────────────────────────────────
1064
1065/// Image/audio token metadata attached to a non-text [`MtmdInputChunk`].
1066#[derive(Debug)]
1067pub struct MtmdImageTokens<'chunks> {
1068 ptr: *const sys::mtmd_image_tokens,
1069 _marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
1070}
1071
1072impl MtmdImageTokens<'_> {
1073 /// Total number of embedding tokens.
1074 #[must_use]
1075 pub fn n_tokens(&self) -> usize {
1076 unsafe { sys::mtmd_image_tokens_get_n_tokens(self.ptr) }
1077 }
1078
1079 /// Width of the token grid.
1080 #[must_use]
1081 pub fn nx(&self) -> usize {
1082 unsafe { sys::mtmd_image_tokens_get_nx(self.ptr) }
1083 }
1084
1085 /// Height of the token grid.
1086 #[must_use]
1087 pub fn ny(&self) -> usize {
1088 unsafe { sys::mtmd_image_tokens_get_ny(self.ptr) }
1089 }
1090
1091 /// Number of temporal positions (M-RoPE variant; equals `n_tokens` otherwise).
1092 #[must_use]
1093 pub fn n_pos(&self) -> i32 {
1094 unsafe { sys::mtmd_image_tokens_get_n_pos(self.ptr) }
1095 }
1096
1097 /// Optional ID for KV cache tracking.
1098 #[must_use]
1099 pub fn id(&self) -> Option<&str> {
1100 let ptr = unsafe { sys::mtmd_image_tokens_get_id(self.ptr) };
1101 if ptr.is_null() {
1102 return None;
1103 }
1104 unsafe { CStr::from_ptr(ptr) }.to_str().ok()
1105 }
1106
1107 /// Compute the per-token decoder positions used by M-RoPE models.
1108 ///
1109 /// Returns a vector of length [`n_tokens`](Self::n_tokens). Each entry
1110 /// is relative to `pos_0`; for non-M-RoPE models this typically reduces
1111 /// to `(0, i, 0, 0)` for the i-th token.
1112 ///
1113 /// Wraps `mtmd_helper_image_get_decoder_pos`.
1114 #[must_use]
1115 pub fn decoder_positions(&self, pos_0: i32) -> Vec<MtmdDecoderPos> {
1116 let n = self.n_tokens();
1117 let mut out = vec![MtmdDecoderPos::default(); n];
1118 if n == 0 {
1119 return out;
1120 }
1121 unsafe {
1122 sys::mtmd_helper_image_get_decoder_pos(
1123 self.ptr,
1124 pos_0,
1125 out.as_mut_ptr().cast::<sys::mtmd_decoder_pos>(),
1126 );
1127 }
1128 out
1129 }
1130}
1131
1132// ─────────────────────────────────────────────────────────────────────────────
1133// LlamaContext extension
1134// ─────────────────────────────────────────────────────────────────────────────
1135
1136use crate::context::LlamaContext;
1137
impl LlamaContext<'_> {
    /// Expose the raw `llama_context` pointer for use with mtmd helpers
    /// (for example, it is passed to `MtmdContext::eval_chunks`).
    ///
    /// # Pointer validity
    ///
    /// Calling this method is safe; dereferencing the result is not. The
    /// pointer is valid only for the lifetime of this `LlamaContext` and
    /// must not be freed by the caller.
    #[must_use]
    pub fn as_ptr(&self) -> *mut sys::llama_context {
        self.context.as_ptr()
    }
}
1150
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem;

    /// `MtmdImageTokens::decoder_positions` reinterprets its
    /// `Vec<MtmdDecoderPos>` buffer as `*mut sys::mtmd_decoder_pos` when
    /// calling into C. This test pins down the layout that cast relies on.
    #[test]
    fn decoder_pos_layout_matches_sys() {
        let rust_size = mem::size_of::<MtmdDecoderPos>();
        let c_size = mem::size_of::<sys::mtmd_decoder_pos>();
        assert_eq!(rust_size, c_size);

        let rust_align = mem::align_of::<MtmdDecoderPos>();
        let c_align = mem::align_of::<sys::mtmd_decoder_pos>();
        assert_eq!(rust_align, c_align);

        // Four consecutive `u32` fields in declaration order, with no padding.
        let expected_offsets = [
            ("t", mem::offset_of!(MtmdDecoderPos, t), 0),
            ("x", mem::offset_of!(MtmdDecoderPos, x), 4),
            ("y", mem::offset_of!(MtmdDecoderPos, y), 8),
            ("z", mem::offset_of!(MtmdDecoderPos, z), 12),
        ];
        for (field, actual, want) in expected_offsets {
            assert_eq!(actual, want, "unexpected offset for field `{field}`");
        }
    }
}