Skip to main content

llama_cpp_4/
mtmd.rs

1//! Safe wrappers for the `libmtmd` multimodal support library.
2//!
3//! `libmtmd` extends llama.cpp with the ability to encode image and audio
4//! inputs (bitmaps) into token embeddings that can then be fed into a
5//! standard [`llama_decode`] call alongside normal text tokens.
6//!
7//! # Quick-start
8//!
9//! ```no_run
10//! # #[cfg(feature = "mtmd")]
11//! # {
12//! use std::path::Path;
13//! use llama_cpp_4::{
14//!     llama_backend::LlamaBackend,
15//!     model::{LlamaModel, params::LlamaModelParams, AddBos},
16//!     context::params::LlamaContextParams,
17//!     mtmd::{MtmdContext, MtmdContextParams, MtmdBitmap, MtmdInputChunks, MtmdInputText},
18//! };
19//!
20//! let backend  = LlamaBackend::init().unwrap();
21//! let model    = LlamaModel::load_from_file(&backend, Path::new("model.gguf"),
22//!                                            &LlamaModelParams::default()).unwrap();
23//! let mut lctx = model.new_context(&backend, LlamaContextParams::default()).unwrap();
24//!
25//! // Load the multimodal projector (mmproj) model.
26//! let ctx_params = MtmdContextParams::default();
27//! let mtmd_ctx   = MtmdContext::init_from_file(Path::new("mmproj.gguf"), &model, ctx_params)
28//!                               .unwrap();
29//!
30//! // Load an image from a file.
31//! let bitmap = MtmdBitmap::from_file(&mtmd_ctx, Path::new("image.jpg")).unwrap();
32//!
33//! // Tokenize a prompt that contains the media marker.
34//! let marker  = MtmdContext::default_marker();
35//! let prompt  = format!("Describe this image: {marker}");
36//! let text    = MtmdInputText::new(&prompt, true, true);
37//! let bitmaps = [&bitmap];
38//!
39//! let mut chunks = MtmdInputChunks::new();
40//! mtmd_ctx.tokenize(&text, &bitmaps, &mut chunks).unwrap();
41//!
42//! // Evaluate / decode all chunks.
43//! let n_batch = lctx.n_batch() as i32;
44//! let mut n_past = 0i32;
45//! mtmd_ctx.eval_chunks(lctx.as_ptr(), &chunks, 0, 0, n_batch, true, &mut n_past).unwrap();
46//! # }
47//! ```
48//!
49//! # Feature flag
50//!
51//! This module is only compiled when the `mtmd` Cargo feature is enabled.
52
53use std::ffi::{CStr, CString};
54use std::path::Path;
55use std::ptr::NonNull;
56use std::slice;
57
58use llama_cpp_sys_4 as sys;
59
60use crate::model::LlamaModel;
61
62// ─────────────────────────────────────────────────────────────────────────────
63// Error types
64// ─────────────────────────────────────────────────────────────────────────────
65
66/// All errors that can be returned by the mtmd module.
67#[derive(Debug, thiserror::Error)]
68pub enum MtmdError {
69    /// The context could not be created (e.g. bad mmproj file).
70    #[error("failed to create mtmd context (null return from mtmd_init_from_file)")]
71    ContextCreateFailed,
72
73    /// The bitmap could not be created.
74    #[error("failed to create mtmd bitmap")]
75    BitmapCreateFailed,
76
77    /// A path could not be converted to a valid C string (embedded NUL byte or non-UTF-8).
78    #[error("invalid path: {0}")]
79    InvalidPath(#[from] std::ffi::NulError),
80
81    /// A path was not representable as UTF-8.
82    #[error("path is not valid UTF-8")]
83    PathNotUtf8,
84
85    /// `mtmd_tokenize` returned an error code.
86    #[error("tokenize error: code {0} (1 = bitmap count mismatch, 2 = preprocessing error)")]
87    TokenizeError(i32),
88
89    /// `mtmd_encode_chunk` returned a non-zero code.
90    #[error("encode error: code {0}")]
91    EncodeError(i32),
92
93    /// `mtmd_helper_eval_chunks` (or single-chunk variant) returned a non-zero code.
94    #[error("eval error: code {0}")]
95    EvalError(i32),
96}
97
98/// A convenience `Result` alias for this module.
99pub type Result<T> = std::result::Result<T, MtmdError>;
100
101// ─────────────────────────────────────────────────────────────────────────────
102// MtmdContextParams
103// ─────────────────────────────────────────────────────────────────────────────
104
105/// Parameters used when creating an [`MtmdContext`].
106///
107/// Obtain a default-initialised instance via [`MtmdContextParams::default()`].
108pub struct MtmdContextParams {
109    pub(crate) params: sys::mtmd_context_params,
110}
111
112impl std::fmt::Debug for MtmdContextParams {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("MtmdContextParams")
115            .field("use_gpu", &self.params.use_gpu)
116            .field("print_timings", &self.params.print_timings)
117            .field("n_threads", &self.params.n_threads)
118            .field("warmup", &self.params.warmup)
119            .field("image_min_tokens", &self.params.image_min_tokens)
120            .field("image_max_tokens", &self.params.image_max_tokens)
121            .finish()
122    }
123}
124
125impl Default for MtmdContextParams {
126    fn default() -> Self {
127        let params = unsafe { sys::mtmd_context_params_default() };
128        Self { params }
129    }
130}
131
132impl MtmdContextParams {
133    /// Whether to run the vision/audio encoder on the GPU (default: `true`).
134    #[must_use]
135    pub fn use_gpu(mut self, v: bool) -> Self {
136        self.params.use_gpu = v;
137        self
138    }
139
140    /// Whether to print timing info after each encode (default: `false`).
141    #[must_use]
142    pub fn print_timings(mut self, v: bool) -> Self {
143        self.params.print_timings = v;
144        self
145    }
146
147    /// Number of threads used for the vision encoder (default taken from
148    /// `mtmd_context_params_default`).
149    #[must_use]
150    pub fn n_threads(mut self, n: i32) -> Self {
151        self.params.n_threads = n;
152        self
153    }
154
155    /// Whether to run a warm-up encode pass after initialisation.
156    #[must_use]
157    pub fn warmup(mut self, v: bool) -> Self {
158        self.params.warmup = v;
159        self
160    }
161
162    /// Minimum number of image tokens (0 = use model default).
163    #[must_use]
164    pub fn image_min_tokens(mut self, n: i32) -> Self {
165        self.params.image_min_tokens = n;
166        self
167    }
168
169    /// Maximum number of image tokens (0 = use model default).
170    #[must_use]
171    pub fn image_max_tokens(mut self, n: i32) -> Self {
172        self.params.image_max_tokens = n;
173        self
174    }
175
176    /// Override the media marker string (e.g. `"<image>"`).
177    ///
178    /// The provided string must not contain interior NUL bytes.  Pass `None`
179    /// to use the library default (`mtmd_default_marker()`).
180    ///
181    /// **Note:** the `CString` is stored inside the params so the pointer
182    /// remains valid as long as this `MtmdContextParams` lives.
183    /// # Errors
184    ///
185    /// Returns [`MtmdError`] if the marker string contains a NUL byte.
186    pub fn media_marker(mut self, marker: Option<&str>) -> std::result::Result<Self, MtmdError> {
187        match marker {
188            None => {
189                self.params.media_marker = std::ptr::null();
190                Ok(self)
191            }
192            Some(s) => {
193                let cs = CString::new(s)?;
194                self.params.media_marker = cs.as_ptr();
195                // Leak the CString so the raw pointer stays valid; the caller
196                // must ensure the params don't outlive the string.  Since
197                // MtmdContextParams is consumed by MtmdContext::init_from_file,
198                // this is safe.
199                std::mem::forget(cs);
200                Ok(self)
201            }
202        }
203    }
204}
205
206// ─────────────────────────────────────────────────────────────────────────────
207// MtmdContext
208// ─────────────────────────────────────────────────────────────────────────────
209
210/// The main multimodal context.
211///
212/// Wraps a `mtmd_context *`.  This context is tied to a specific mmproj model
213/// file and a loaded [`LlamaModel`].  It is safe to share across threads for
214/// `tokenize` calls (read-only), but `encode_chunk` / eval helpers mutate
215/// internal state and must not be called concurrently.
216pub struct MtmdContext {
217    ptr: NonNull<sys::mtmd_context>,
218}
219
220// The underlying mtmd_context is internally synchronised for tokenize().
221// encode / decode must be called from a single thread at a time (caller's
222// responsibility, enforced by the inference semaphore in the server).
223unsafe impl Send for MtmdContext {}
224unsafe impl Sync for MtmdContext {}
225
226impl std::fmt::Debug for MtmdContext {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        f.debug_struct("MtmdContext")
229            .field("ptr", &self.ptr)
230            .finish()
231    }
232}
233
234impl Drop for MtmdContext {
235    fn drop(&mut self) {
236        unsafe { sys::mtmd_free(self.ptr.as_ptr()) }
237    }
238}
239
240impl MtmdContext {
241    /// Returns the default media marker string used in prompts
242    /// (currently `"<__media__>"`).
243    #[must_use]
244    pub fn default_marker() -> &'static str {
245        let ptr = unsafe { sys::mtmd_default_marker() };
246        unsafe { CStr::from_ptr(ptr) }
247            .to_str()
248            .unwrap_or("<__media__>")
249    }
250
251    /// Initialise a multimodal context from an mmproj GGUF file.
252    ///
253    /// # Parameters
254    ///
255    /// * `mmproj_path` – path to the mmproj `.gguf` file
256    /// * `text_model`  – the already-loaded text model
257    /// * `params`      – context parameters (use [`MtmdContextParams::default()`])
258    ///
259    /// # Errors
260    ///
261    /// Returns [`MtmdError::ContextCreateFailed`] if the underlying C call
262    /// returns a null pointer.
263    #[allow(clippy::needless_pass_by_value)]
264    pub fn init_from_file(
265        mmproj_path: impl AsRef<Path>,
266        text_model: &LlamaModel,
267        params: MtmdContextParams,
268    ) -> Result<Self> {
269        let path = mmproj_path
270            .as_ref()
271            .to_str()
272            .ok_or(MtmdError::PathNotUtf8)?;
273        let c_path = CString::new(path)?;
274
275        let ptr = unsafe {
276            sys::mtmd_init_from_file(c_path.as_ptr(), text_model.model.as_ptr(), params.params)
277        };
278
279        let ptr = NonNull::new(ptr).ok_or(MtmdError::ContextCreateFailed)?;
280        Ok(Self { ptr })
281    }
282
283    // ── Logging ──────────────────────────────────────────────────────────
284
285    /// Silence all clip/mtmd log output by installing a no-op callback.
286    ///
287    /// Call this right after [`init_from_file`](Self::init_from_file) to
288    /// suppress the verbose `clip_model_loader: tensor[N]…` lines that
289    /// clip.cpp emits to its own private logger (separate from `llama_log_set`).
290    pub fn void_logs() {
291        unsafe extern "C" fn noop(
292            _level: sys::ggml_log_level,
293            _text: *const ::std::os::raw::c_char,
294            _ud: *mut ::std::os::raw::c_void,
295        ) {
296        }
297        unsafe { sys::mtmd_log_set(Some(noop), std::ptr::null_mut()) };
298    }
299
300    // ── Capability queries ────────────────────────────────────────────────
301
302    /// Returns `true` if the model supports vision (image) input.
303    #[must_use]
304    pub fn supports_vision(&self) -> bool {
305        unsafe { sys::mtmd_support_vision(self.ptr.as_ptr()) }
306    }
307
308    /// Returns `true` if the model supports audio input.
309    #[must_use]
310    pub fn supports_audio(&self) -> bool {
311        unsafe { sys::mtmd_support_audio(self.ptr.as_ptr()) }
312    }
313
314    /// Returns the audio sample rate in Hz (e.g. 16 000 for Whisper), or
315    /// `-1` if audio is not supported.
316    #[must_use]
317    #[deprecated(note = "use audio_sample_rate() instead")]
318    pub fn audio_bitrate(&self) -> i32 {
319        self.audio_sample_rate()
320    }
321
322    /// Returns the audio sample rate in Hz.
323    #[must_use]
324    pub fn audio_sample_rate(&self) -> i32 {
325        unsafe { sys::mtmd_get_audio_sample_rate(self.ptr.as_ptr()) }
326    }
327
328    /// Whether `llama_decode` must use a non-causal attention mask when
329    /// decoding image embeddings for this model.
330    #[must_use]
331    pub fn decode_use_non_causal(&self) -> bool {
332        unsafe { sys::mtmd_decode_use_non_causal(self.ptr.as_ptr()) }
333    }
334
335    /// Whether the model uses M-RoPE for `llama_decode`.
336    #[must_use]
337    pub fn decode_use_mrope(&self) -> bool {
338        unsafe { sys::mtmd_decode_use_mrope(self.ptr.as_ptr()) }
339    }
340
341    // ── Core API ──────────────────────────────────────────────────────────
342
343    /// Tokenize a text prompt that contains one or more media markers.
344    ///
345    /// The number of `bitmaps` must equal the number of media markers in the
346    /// prompt text, otherwise [`MtmdError::TokenizeError(1)`] is returned.
347    ///
348    /// This call is **thread-safe** (shared `&self`).
349    ///
350    /// # Parameters
351    ///
352    /// * `text`    – text + tokenisation options
353    /// * `bitmaps` – slice of [`MtmdBitmap`] references, one per media marker
354    /// * `output`  – an [`MtmdInputChunks`] that will be populated with the result
355    ///
356    /// # Errors
357    ///
358    /// Returns [`MtmdError::TokenizeError`] if tokenization fails.
359    pub fn tokenize(
360        &self,
361        text: &MtmdInputText<'_>,
362        bitmaps: &[&MtmdBitmap],
363        output: &mut MtmdInputChunks,
364    ) -> Result<()> {
365        // The C signature is: mtmd_tokenize(..., mtmd_bitmap ** bitmaps, ...)
366        // where each element is a `const mtmd_bitmap *`.  We build a Vec of
367        // `*const mtmd_bitmap` and pass a mutable pointer to its first element
368        // (i.e. `*mut *const mtmd_bitmap`) to satisfy the C API.
369        let mut bitmap_ptrs: Vec<*const sys::mtmd_bitmap> = bitmaps
370            .iter()
371            .map(|b| b.ptr.as_ptr().cast_const())
372            .collect();
373
374        let c_text = sys::mtmd_input_text {
375            text: text.c_text.as_ptr(),
376            add_special: text.add_special,
377            parse_special: text.parse_special,
378        };
379
380        let ret = unsafe {
381            sys::mtmd_tokenize(
382                self.ptr.as_ptr(),
383                output.ptr.as_ptr(),
384                &raw const c_text,
385                bitmap_ptrs.as_mut_ptr(),
386                bitmap_ptrs.len(),
387            )
388        };
389
390        if ret != 0 {
391            return Err(MtmdError::TokenizeError(ret));
392        }
393        Ok(())
394    }
395
396    /// Encode a single input chunk (image or audio) and store the resulting
397    /// embeddings inside the context.
398    ///
399    /// After a successful call, the embeddings can be retrieved with
400    /// [`MtmdContext::output_embd`].
401    ///
402    /// This call is **NOT thread-safe**.
403    ///
404    /// # Errors
405    ///
406    /// Returns [`MtmdError::EncodeError`] if encoding fails.
407    pub fn encode_chunk(&self, chunk: &MtmdInputChunk<'_>) -> Result<()> {
408        let ret = unsafe { sys::mtmd_encode_chunk(self.ptr.as_ptr(), chunk.ptr) };
409        if ret != 0 {
410            return Err(MtmdError::EncodeError(ret));
411        }
412        Ok(())
413    }
414
415    /// Return a slice over the embeddings produced by the last
416    /// [`encode_chunk`](Self::encode_chunk) call.
417    ///
418    /// The length (in `f32` elements) is:
419    /// ```text
420    /// n_embd_inp(model)  *  chunk.n_tokens()
421    /// ```
422    ///
423    /// # Safety
424    ///
425    /// The returned slice is valid until the next call that mutates the
426    /// context (e.g. another `encode_chunk`).
427    #[must_use]
428    pub fn output_embd(&self, n_elements: usize) -> &[f32] {
429        let ptr = unsafe { sys::mtmd_get_output_embd(self.ptr.as_ptr()) };
430        if ptr.is_null() || n_elements == 0 {
431            return &[];
432        }
433        unsafe { slice::from_raw_parts(ptr, n_elements) }
434    }
435
436    // ── Helper API ────────────────────────────────────────────────────────
437
438    /// High-level helper: evaluate (decode) all chunks in sequence.
439    ///
440    /// * Text chunks are decoded via `llama_decode`.
441    /// * Image/audio chunks are first encoded with `mtmd_encode_chunk` and
442    ///   then decoded via `llama_decode`.
443    ///
444    /// On success `new_n_past` is updated with the new past position.
445    ///
446    /// This call is **NOT thread-safe**.
447    ///
448    /// # Parameters
449    ///
450    /// * `lctx`        – raw pointer to the llama context (from [`LlamaContext::as_ptr`])
451    /// * `chunks`      – the tokenized chunks to evaluate
452    /// * `n_past`      – current KV-cache position
453    /// * `seq_id`      – sequence ID
454    /// * `n_batch`     – maximum batch size (must be ≥ 1)
455    /// * `logits_last` – if `true`, compute logits only for the final token
456    /// * `new_n_past`  – updated KV-cache position after the call
457    ///
458    /// # Errors
459    ///
460    /// Returns [`MtmdError::EvalError`] if evaluation fails.
461    #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
462    pub fn eval_chunks(
463        &self,
464        lctx: *mut sys::llama_context,
465        chunks: &MtmdInputChunks,
466        n_past: i32,
467        seq_id: i32,
468        n_batch: i32,
469        logits_last: bool,
470        new_n_past: &mut i32,
471    ) -> Result<()> {
472        let ret = unsafe {
473            sys::mtmd_helper_eval_chunks(
474                self.ptr.as_ptr(),
475                lctx,
476                chunks.ptr.as_ptr(),
477                n_past,
478                seq_id,
479                n_batch,
480                logits_last,
481                new_n_past,
482            )
483        };
484        if ret != 0 {
485            return Err(MtmdError::EvalError(ret));
486        }
487        Ok(())
488    }
489
490    /// High-level helper: evaluate a single chunk.
491    ///
492    /// Works identically to [`eval_chunks`](Self::eval_chunks) but operates on
493    /// one chunk at a time.
494    ///
495    /// # Errors
496    ///
497    /// Returns [`MtmdError::EvalError`] if evaluation fails.
498    #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
499    pub fn eval_chunk_single(
500        &self,
501        lctx: *mut sys::llama_context,
502        chunk: &MtmdInputChunk<'_>,
503        n_past: i32,
504        seq_id: i32,
505        n_batch: i32,
506        logits_last: bool,
507        new_n_past: &mut i32,
508    ) -> Result<()> {
509        let ret = unsafe {
510            sys::mtmd_helper_eval_chunk_single(
511                self.ptr.as_ptr(),
512                lctx,
513                chunk.ptr,
514                n_past,
515                seq_id,
516                n_batch,
517                logits_last,
518                new_n_past,
519            )
520        };
521        if ret != 0 {
522            return Err(MtmdError::EvalError(ret));
523        }
524        Ok(())
525    }
526
527    /// Returns a raw pointer to the underlying `mtmd_context`.
528    ///
529    /// # Safety
530    ///
531    /// The returned pointer is valid for the lifetime of this `MtmdContext`.
532    /// The caller must not free it.
533    #[must_use]
534    pub fn as_ptr(&self) -> *mut sys::mtmd_context {
535        self.ptr.as_ptr()
536    }
537}
538
539// ─────────────────────────────────────────────────────────────────────────────
540// MtmdInputText
541// ─────────────────────────────────────────────────────────────────────────────
542
/// Prompt text and tokenisation flags passed to [`MtmdContext::tokenize`].
///
/// The prompt has to contain the media marker (see
/// [`MtmdContext::default_marker`]) once per bitmap that should be embedded.
/// The text is copied into an owned C string on construction.
#[derive(Debug)]
pub struct MtmdInputText<'a> {
    c_text: CString,
    add_special: bool,
    parse_special: bool,
    _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> MtmdInputText<'a> {
    /// Build an `MtmdInputText`, panicking on invalid text.
    ///
    /// * `text`          – the prompt (no interior NUL bytes allowed)
    /// * `add_special`   – whether BOS/EOS tokens should be added
    /// * `parse_special` – whether special tokens inside the text are parsed
    ///
    /// Use [`try_new`](Self::try_new) to handle invalid text without
    /// panicking.
    ///
    /// # Panics
    ///
    /// Panics if `text` contains an interior NUL byte.
    #[must_use]
    pub fn new(text: &'a str, add_special: bool, parse_special: bool) -> Self {
        Self::try_new(text, add_special, parse_special)
            .expect("MtmdInputText: text must not contain NUL bytes")
    }

    /// Fallible counterpart of [`new`](Self::new).
    ///
    /// # Errors
    ///
    /// Returns [`std::ffi::NulError`] if `text` contains a NUL byte.
    pub fn try_new(
        text: &'a str,
        add_special: bool,
        parse_special: bool,
    ) -> std::result::Result<Self, std::ffi::NulError> {
        CString::new(text).map(|c_text| Self {
            c_text,
            add_special,
            parse_special,
            _marker: std::marker::PhantomData,
        })
    }
}
596
597// ─────────────────────────────────────────────────────────────────────────────
598// MtmdBitmap
599// ─────────────────────────────────────────────────────────────────────────────
600
601/// An image or audio bitmap ready for multimodal encoding.
602///
603/// # Image bitmaps
604///
605/// The raw pixel data must be in RGBRGBRGB… (interleaved) format.  The total
606/// number of bytes must be `nx * ny * 3`.
607///
608/// # Audio bitmaps
609///
610/// The raw sample data must be little-endian `f32` PCM samples.  The total
611/// number of bytes must be `n_samples * 4`.
612pub struct MtmdBitmap {
613    ptr: NonNull<sys::mtmd_bitmap>,
614}
615
616unsafe impl Send for MtmdBitmap {}
617unsafe impl Sync for MtmdBitmap {}
618
619impl std::fmt::Debug for MtmdBitmap {
620    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
621        f.debug_struct("MtmdBitmap")
622            .field("nx", &self.nx())
623            .field("ny", &self.ny())
624            .field("n_bytes", &self.n_bytes())
625            .field("is_audio", &self.is_audio())
626            .finish()
627    }
628}
629
630impl Drop for MtmdBitmap {
631    fn drop(&mut self) {
632        unsafe { sys::mtmd_bitmap_free(self.ptr.as_ptr()) }
633    }
634}
635
636impl MtmdBitmap {
637    /// Create a bitmap from raw RGB pixel data.
638    ///
639    /// * `nx`   – image width in pixels
640    /// * `ny`   – image height in pixels
641    /// * `data` – raw pixel bytes in RGBRGB… format; must be `nx * ny * 3` bytes
642    ///
643    /// # Errors
644    ///
645    /// Returns [`MtmdError::BitmapCreateFailed`] if the underlying C call
646    /// returns null.
647    pub fn from_rgb(nx: u32, ny: u32, data: &[u8]) -> Result<Self> {
648        let ptr = unsafe { sys::mtmd_bitmap_init(nx, ny, data.as_ptr()) };
649        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
650        Ok(Self { ptr })
651    }
652
653    /// Create an audio bitmap from PCM `f32` samples.
654    ///
655    /// * `samples` – slice of PCM float samples
656    ///
657    /// # Errors
658    ///
659    /// Returns [`MtmdError::BitmapCreateFailed`] if the underlying C call
660    /// returns null.
661    pub fn from_audio(samples: &[f32]) -> Result<Self> {
662        let ptr = unsafe { sys::mtmd_bitmap_init_from_audio(samples.len(), samples.as_ptr()) };
663        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
664        Ok(Self { ptr })
665    }
666
667    /// Load a bitmap from a file (image or audio).
668    ///
669    /// Supported image formats: JPEG, PNG, BMP, GIF, and others handled by
670    /// `stb_image`.  Supported audio formats: WAV, MP3, FLAC (via miniaudio).
671    ///
672    /// # Errors
673    ///
674    /// Returns [`MtmdError::BitmapCreateFailed`] if the file cannot be loaded.
675    pub fn from_file(ctx: &MtmdContext, path: impl AsRef<Path>) -> Result<Self> {
676        let path = path.as_ref().to_str().ok_or(MtmdError::PathNotUtf8)?;
677        let c_path = CString::new(path)?;
678
679        let ptr =
680            unsafe { sys::mtmd_helper_bitmap_init_from_file(ctx.ptr.as_ptr(), c_path.as_ptr()) };
681        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
682        Ok(Self { ptr })
683    }
684
685    /// Load a bitmap from an in-memory buffer containing a file.
686    ///
687    /// The format is auto-detected (image vs audio via magic bytes).
688    ///
689    /// # Errors
690    ///
691    /// Returns [`MtmdError::BitmapCreateFailed`] if decoding fails.
692    pub fn from_buf(ctx: &MtmdContext, buf: &[u8]) -> Result<Self> {
693        let ptr = unsafe {
694            sys::mtmd_helper_bitmap_init_from_buf(ctx.ptr.as_ptr(), buf.as_ptr(), buf.len())
695        };
696        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
697        Ok(Self { ptr })
698    }
699
700    // ── Getters ───────────────────────────────────────────────────────────
701
702    /// Width in pixels (for images) or 0 (for audio).
703    #[must_use]
704    pub fn nx(&self) -> u32 {
705        unsafe { sys::mtmd_bitmap_get_nx(self.ptr.as_ptr()) }
706    }
707
708    /// Height in pixels (for images) or 0 (for audio).
709    #[must_use]
710    pub fn ny(&self) -> u32 {
711        unsafe { sys::mtmd_bitmap_get_ny(self.ptr.as_ptr()) }
712    }
713
714    /// Total number of bytes in the bitmap data.
715    #[must_use]
716    pub fn n_bytes(&self) -> usize {
717        unsafe { sys::mtmd_bitmap_get_n_bytes(self.ptr.as_ptr()) }
718    }
719
720    /// Returns `true` if this bitmap contains audio (rather than image) data.
721    #[must_use]
722    pub fn is_audio(&self) -> bool {
723        unsafe { sys::mtmd_bitmap_is_audio(self.ptr.as_ptr()) }
724    }
725
726    /// Return the raw pixel / sample data.
727    #[must_use]
728    pub fn data(&self) -> &[u8] {
729        let n = self.n_bytes();
730        if n == 0 {
731            return &[];
732        }
733        let ptr = unsafe { sys::mtmd_bitmap_get_data(self.ptr.as_ptr()) };
734        unsafe { slice::from_raw_parts(ptr, n) }
735    }
736
737    /// Return the optional ID string attached to this bitmap (used for KV
738    /// cache tracking), or `None` if no ID has been set.
739    #[must_use]
740    pub fn id(&self) -> Option<&str> {
741        let ptr = unsafe { sys::mtmd_bitmap_get_id(self.ptr.as_ptr()) };
742        if ptr.is_null() {
743            return None;
744        }
745        unsafe { CStr::from_ptr(ptr) }.to_str().ok()
746    }
747
748    /// Attach an optional ID string to this bitmap (used for KV cache
749    /// tracking).
750    ///
751    /// # Errors
752    ///
753    /// Returns an error if `id` contains an interior NUL byte.
754    pub fn set_id(&mut self, id: &str) -> std::result::Result<(), std::ffi::NulError> {
755        let cs = CString::new(id)?;
756        unsafe { sys::mtmd_bitmap_set_id(self.ptr.as_ptr(), cs.as_ptr()) };
757        Ok(())
758    }
759}
760
761// ─────────────────────────────────────────────────────────────────────────────
762// MtmdInputChunks
763// ─────────────────────────────────────────────────────────────────────────────
764
765/// A list of tokenized input chunks produced by [`MtmdContext::tokenize`].
766///
767/// Each chunk is either a text token sequence or a set of image/audio tokens.
768pub struct MtmdInputChunks {
769    ptr: NonNull<sys::mtmd_input_chunks>,
770}
771
772impl std::fmt::Debug for MtmdInputChunks {
773    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
774        f.debug_struct("MtmdInputChunks")
775            .field("len", &self.len())
776            .finish()
777    }
778}
779
780impl Drop for MtmdInputChunks {
781    fn drop(&mut self) {
782        unsafe { sys::mtmd_input_chunks_free(self.ptr.as_ptr()) }
783    }
784}
785
786impl MtmdInputChunks {
787    /// Create a new, empty chunk list.  Populated by
788    /// [`MtmdContext::tokenize`].
789    ///
790    /// # Panics
791    ///
792    /// Panics if the underlying C allocation fails (OOM).
793    #[must_use]
794    pub fn new() -> Self {
795        let ptr = unsafe { sys::mtmd_input_chunks_init() };
796        let ptr = NonNull::new(ptr).expect("mtmd_input_chunks_init returned null");
797        Self { ptr }
798    }
799
800    /// Number of chunks in this list.
801    #[must_use]
802    pub fn len(&self) -> usize {
803        unsafe { sys::mtmd_input_chunks_size(self.ptr.as_ptr()) }
804    }
805
806    /// Returns `true` if there are no chunks.
807    #[must_use]
808    pub fn is_empty(&self) -> bool {
809        self.len() == 0
810    }
811
812    /// Get the `idx`-th chunk.  Returns `None` if `idx >= len()`.
813    #[must_use]
814    pub fn get(&self, idx: usize) -> Option<MtmdInputChunk<'_>> {
815        if idx >= self.len() {
816            return None;
817        }
818        let ptr = unsafe { sys::mtmd_input_chunks_get(self.ptr.as_ptr(), idx) };
819        if ptr.is_null() {
820            return None;
821        }
822        Some(MtmdInputChunk {
823            ptr,
824            _marker: std::marker::PhantomData,
825        })
826    }
827
828    /// Iterate over all chunks.
829    pub fn iter(&self) -> impl Iterator<Item = MtmdInputChunk<'_>> {
830        (0..self.len()).filter_map(|i| self.get(i))
831    }
832
833    /// Total number of tokens across all chunks.
834    ///
835    /// Equivalent to `mtmd_helper_get_n_tokens`.
836    #[must_use]
837    pub fn n_tokens(&self) -> usize {
838        unsafe { sys::mtmd_helper_get_n_tokens(self.ptr.as_ptr()) }
839    }
840
841    /// Total number of *positions* across all chunks (used for KV-cache
842    /// tracking with M-RoPE models where positions ≠ tokens).
843    ///
844    /// Equivalent to `mtmd_helper_get_n_pos`.
845    #[must_use]
846    pub fn n_pos(&self) -> i32 {
847        unsafe { sys::mtmd_helper_get_n_pos(self.ptr.as_ptr()) }
848    }
849}
850
851impl Default for MtmdInputChunks {
852    fn default() -> Self {
853        Self::new()
854    }
855}
856
857// ─────────────────────────────────────────────────────────────────────────────
858// MtmdInputChunkType
859// ─────────────────────────────────────────────────────────────────────────────
860
861/// The type of an [`MtmdInputChunk`].
862#[derive(Debug, Clone, Copy, PartialEq, Eq)]
863pub enum MtmdInputChunkType {
864    /// Plain text tokens.
865    Text,
866    /// Image tokens (embeddings produced by the vision encoder).
867    Image,
868    /// Audio tokens (embeddings produced by the audio encoder).
869    Audio,
870}
871
872impl From<sys::mtmd_input_chunk_type> for MtmdInputChunkType {
873    fn from(v: sys::mtmd_input_chunk_type) -> Self {
874        // mtmd_input_chunk_type is a plain C `typedef unsigned int`.
875        // The variants are exported as free-standing constants.
876        if v == sys::MTMD_INPUT_CHUNK_TYPE_IMAGE {
877            Self::Image
878        } else if v == sys::MTMD_INPUT_CHUNK_TYPE_AUDIO {
879            Self::Audio
880        } else {
881            Self::Text
882        }
883    }
884}
885
886// ─────────────────────────────────────────────────────────────────────────────
887// MtmdInputChunk
888// ─────────────────────────────────────────────────────────────────────────────
889
890/// A single tokenized input chunk (text, image, or audio).
891///
892/// Instances are borrowed from an [`MtmdInputChunks`] list and live as long
893/// as that list.
894#[derive(Debug)]
895pub struct MtmdInputChunk<'chunks> {
896    ptr: *const sys::mtmd_input_chunk,
897    _marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
898}
899
900impl<'chunks> MtmdInputChunk<'chunks> {
901    /// The type of this chunk.
902    #[must_use]
903    pub fn chunk_type(&self) -> MtmdInputChunkType {
904        let t = unsafe { sys::mtmd_input_chunk_get_type(self.ptr) };
905        MtmdInputChunkType::from(t)
906    }
907
908    /// Total number of tokens in this chunk.
909    #[must_use]
910    pub fn n_tokens(&self) -> usize {
911        unsafe { sys::mtmd_input_chunk_get_n_tokens(self.ptr) }
912    }
913
914    /// Number of temporal positions (equals `n_tokens` for non-M-RoPE models).
915    #[must_use]
916    pub fn n_pos(&self) -> i32 {
917        unsafe { sys::mtmd_input_chunk_get_n_pos(self.ptr) }
918    }
919
920    /// Return the raw llama token IDs for a **text** chunk.
921    ///
922    /// Returns `None` if this chunk is not a text chunk.
923    #[must_use]
924    pub fn text_tokens(&self) -> Option<&[i32]> {
925        if self.chunk_type() != MtmdInputChunkType::Text {
926            return None;
927        }
928        let mut n: usize = 0;
929        let ptr = unsafe { sys::mtmd_input_chunk_get_tokens_text(self.ptr, &raw mut n) };
930        if ptr.is_null() || n == 0 {
931            return Some(&[]);
932        }
933        Some(unsafe { slice::from_raw_parts(ptr, n) })
934    }
935
936    /// Return the image token metadata for an **image** or **audio** chunk.
937    ///
938    /// Returns `None` for text chunks.
939    #[must_use]
940    pub fn image_tokens(&self) -> Option<MtmdImageTokens<'chunks>> {
941        match self.chunk_type() {
942            MtmdInputChunkType::Image | MtmdInputChunkType::Audio => {}
943            MtmdInputChunkType::Text => return None,
944        }
945        let ptr = unsafe { sys::mtmd_input_chunk_get_tokens_image(self.ptr) };
946        if ptr.is_null() {
947            return None;
948        }
949        Some(MtmdImageTokens {
950            ptr,
951            _marker: std::marker::PhantomData,
952        })
953    }
954
955    /// Optional ID attached to this chunk (used for KV cache tracking).
956    #[must_use]
957    pub fn id(&self) -> Option<&str> {
958        let ptr = unsafe { sys::mtmd_input_chunk_get_id(self.ptr) };
959        if ptr.is_null() {
960            return None;
961        }
962        unsafe { CStr::from_ptr(ptr) }.to_str().ok()
963    }
964
965    /// Returns the raw `*const mtmd_input_chunk` pointer.
966    ///
967    /// # Safety
968    ///
969    /// The returned pointer is valid for the lifetime of the parent
970    /// `MtmdInputChunks`.
971    #[must_use]
972    pub fn as_ptr(&self) -> *const sys::mtmd_input_chunk {
973        self.ptr
974    }
975}
976
977// ─────────────────────────────────────────────────────────────────────────────
978// MtmdImageTokens
979// ─────────────────────────────────────────────────────────────────────────────
980
981/// Image/audio token metadata attached to a non-text [`MtmdInputChunk`].
982#[derive(Debug)]
983pub struct MtmdImageTokens<'chunks> {
984    ptr: *const sys::mtmd_image_tokens,
985    _marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
986}
987
988impl MtmdImageTokens<'_> {
989    /// Total number of embedding tokens.
990    #[must_use]
991    pub fn n_tokens(&self) -> usize {
992        unsafe { sys::mtmd_image_tokens_get_n_tokens(self.ptr) }
993    }
994
995    /// Width of the token grid.
996    #[must_use]
997    pub fn nx(&self) -> usize {
998        unsafe { sys::mtmd_image_tokens_get_nx(self.ptr) }
999    }
1000
1001    /// Height of the token grid.
1002    #[must_use]
1003    pub fn ny(&self) -> usize {
1004        unsafe { sys::mtmd_image_tokens_get_ny(self.ptr) }
1005    }
1006
1007    /// Number of temporal positions (M-RoPE variant; equals `n_tokens` otherwise).
1008    #[must_use]
1009    pub fn n_pos(&self) -> i32 {
1010        unsafe { sys::mtmd_image_tokens_get_n_pos(self.ptr) }
1011    }
1012
1013    /// Optional ID for KV cache tracking.
1014    #[must_use]
1015    pub fn id(&self) -> Option<&str> {
1016        let ptr = unsafe { sys::mtmd_image_tokens_get_id(self.ptr) };
1017        if ptr.is_null() {
1018            return None;
1019        }
1020        unsafe { CStr::from_ptr(ptr) }.to_str().ok()
1021    }
1022}
1023
1024// ─────────────────────────────────────────────────────────────────────────────
1025// LlamaContext extension
1026// ─────────────────────────────────────────────────────────────────────────────
1027
1028use crate::context::LlamaContext;
1029
1030impl LlamaContext<'_> {
1031    /// Expose the raw `llama_context` pointer for use with mtmd helpers.
1032    ///
1033    /// # Safety
1034    ///
1035    /// The pointer is valid for the lifetime of this `LlamaContext` and must
1036    /// not be freed by the caller.
1037    #[must_use]
1038    pub fn as_ptr(&self) -> *mut sys::llama_context {
1039        self.context.as_ptr()
1040    }
1041}