Skip to main content

llama_cpp_4/
mtmd.rs

1//! Safe wrappers for the `libmtmd` multimodal support library.
2//!
3//! `libmtmd` extends llama.cpp with the ability to encode image and audio
4//! inputs (bitmaps) into token embeddings that can then be fed into a
5//! standard [`crate::context::LlamaContext::decode`] call alongside normal text tokens.
6//!
7//! # Quick-start
8//!
9//! ```no_run
10//! # #[cfg(feature = "mtmd")]
11//! # {
12//! use std::path::Path;
13//! use llama_cpp_4::{
14//!     llama_backend::LlamaBackend,
15//!     model::{LlamaModel, params::LlamaModelParams, AddBos},
16//!     context::params::LlamaContextParams,
17//!     mtmd::{MtmdContext, MtmdContextParams, MtmdBitmap, MtmdInputChunks, MtmdInputText},
18//! };
19//!
20//! let backend  = LlamaBackend::init().unwrap();
21//! let model    = LlamaModel::load_from_file(&backend, Path::new("model.gguf"),
22//!                                            &LlamaModelParams::default()).unwrap();
23//! let mut lctx = model.new_context(&backend, LlamaContextParams::default()).unwrap();
24//!
25//! // Load the multimodal projector (mmproj) model.
26//! let ctx_params = MtmdContextParams::default();
27//! let mtmd_ctx   = MtmdContext::init_from_file(Path::new("mmproj.gguf"), &model, ctx_params)
28//!                               .unwrap();
29//!
30//! // Load an image from a file.
31//! let bitmap = MtmdBitmap::from_file(&mtmd_ctx, Path::new("image.jpg")).unwrap();
32//!
33//! // Tokenize a prompt that contains the media marker.
34//! let marker  = MtmdContext::default_marker();
35//! let prompt  = format!("Describe this image: {marker}");
36//! let text    = MtmdInputText::new(&prompt, true, true);
37//! let bitmaps = [&bitmap];
38//!
39//! let mut chunks = MtmdInputChunks::new();
40//! mtmd_ctx.tokenize(&text, &bitmaps, &mut chunks).unwrap();
41//!
42//! // Evaluate / decode all chunks.
43//! let n_batch = lctx.n_batch() as i32;
44//! let mut n_past = 0i32;
45//! mtmd_ctx.eval_chunks(lctx.as_ptr(), &chunks, 0, 0, n_batch, true, &mut n_past).unwrap();
46//! # }
47//! ```
48//!
49//! # Feature flag
50//!
51//! This module is only compiled when the `mtmd` Cargo feature is enabled.
52
53use std::ffi::{CStr, CString};
54use std::os::raw::c_void;
55use std::path::Path;
56use std::ptr::NonNull;
57use std::slice;
58
59use llama_cpp_sys_4 as sys;
60
61use crate::model::LlamaModel;
62
63// ─────────────────────────────────────────────────────────────────────────────
64// Error types
65// ─────────────────────────────────────────────────────────────────────────────
66
67/// All errors that can be returned by the mtmd module.
68#[derive(Debug, thiserror::Error)]
69pub enum MtmdError {
70    /// The context could not be created (e.g. bad mmproj file).
71    #[error("failed to create mtmd context (null return from mtmd_init_from_file)")]
72    ContextCreateFailed,
73
74    /// The bitmap could not be created.
75    #[error("failed to create mtmd bitmap")]
76    BitmapCreateFailed,
77
78    /// A path could not be converted to a valid C string (embedded NUL byte or non-UTF-8).
79    #[error("invalid path: {0}")]
80    InvalidPath(#[from] std::ffi::NulError),
81
82    /// A path was not representable as UTF-8.
83    #[error("path is not valid UTF-8")]
84    PathNotUtf8,
85
86    /// `mtmd_tokenize` returned an error code.
87    #[error("tokenize error: code {0} (1 = bitmap count mismatch, 2 = preprocessing error)")]
88    TokenizeError(i32),
89
90    /// `mtmd_encode_chunk` returned a non-zero code.
91    #[error("encode error: code {0}")]
92    EncodeError(i32),
93
94    /// `mtmd_helper_eval_chunks` (or single-chunk variant) returned a non-zero code.
95    #[error("eval error: code {0}")]
96    EvalError(i32),
97
98    /// A video stream could not be opened. Common causes: the build lacks
99    /// video support (`MTMD_VIDEO` was OFF), `ffmpeg`/`ffprobe` is not on
100    /// `PATH`, or the file is unreadable.
101    #[error("failed to open video stream (null return from mtmd_helper_video_init)")]
102    VideoInitFailed,
103
104    /// `mtmd_helper_video_read_next` returned an error code (`-2`).
105    #[error("video read error: code {0}")]
106    VideoReadError(i32),
107}
108
109/// A convenience `Result` alias for this module.
110pub type Result<T> = std::result::Result<T, MtmdError>;
111
/// Signature of the hook called while the CLIP/mmproj weights are loading.
///
/// The first argument is the load progress in `[0.0, 1.0]`. The second is
/// the opaque user-data pointer registered together with the callback.
/// Returning `true` lets loading continue; returning `false` aborts it
/// immediately.
pub type MtmdProgressCallback = unsafe extern "C" fn(f32, *mut c_void) -> bool;
117
118// ─────────────────────────────────────────────────────────────────────────────
119// MtmdContextParams
120// ─────────────────────────────────────────────────────────────────────────────
121
122/// Parameters used when creating an [`MtmdContext`].
123///
124/// Obtain a default-initialised instance via [`MtmdContextParams::default()`].
125pub struct MtmdContextParams {
126    pub(crate) params: sys::mtmd_context_params,
127}
128
129impl std::fmt::Debug for MtmdContextParams {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("MtmdContextParams")
132            .field("use_gpu", &self.params.use_gpu)
133            .field("print_timings", &self.params.print_timings)
134            .field("n_threads", &self.params.n_threads)
135            .field("warmup", &self.params.warmup)
136            .field("image_min_tokens", &self.params.image_min_tokens)
137            .field("image_max_tokens", &self.params.image_max_tokens)
138            .finish()
139    }
140}
141
142impl Default for MtmdContextParams {
143    fn default() -> Self {
144        let params = unsafe { sys::mtmd_context_params_default() };
145        Self { params }
146    }
147}
148
149impl MtmdContextParams {
150    /// Whether to run the vision/audio encoder on the GPU (default: `true`).
151    #[must_use]
152    pub fn use_gpu(mut self, v: bool) -> Self {
153        self.params.use_gpu = v;
154        self
155    }
156
157    /// Whether to print timing info after each encode (default: `false`).
158    #[must_use]
159    pub fn print_timings(mut self, v: bool) -> Self {
160        self.params.print_timings = v;
161        self
162    }
163
164    /// Number of threads used for the vision encoder (default taken from
165    /// `mtmd_context_params_default`).
166    #[must_use]
167    pub fn n_threads(mut self, n: i32) -> Self {
168        self.params.n_threads = n;
169        self
170    }
171
172    /// Whether to run a warm-up encode pass after initialisation.
173    #[must_use]
174    pub fn warmup(mut self, v: bool) -> Self {
175        self.params.warmup = v;
176        self
177    }
178
179    /// Minimum number of image tokens (0 = use model default).
180    #[must_use]
181    pub fn image_min_tokens(mut self, n: i32) -> Self {
182        self.params.image_min_tokens = n;
183        self
184    }
185
186    /// Maximum number of image tokens (0 = use model default).
187    #[must_use]
188    pub fn image_max_tokens(mut self, n: i32) -> Self {
189        self.params.image_max_tokens = n;
190        self
191    }
192
193    /// Maximum number of multimodal output tokens per batch.
194    ///
195    /// Maps to `mtmd_context_params.batch_max_tokens`. The upstream default
196    /// is `1024`. Increase for large images or long audio segments.
197    ///
198    /// # Examples
199    ///
200    /// ```rust
201    /// # #[cfg(feature = "mtmd")]
202    /// # {
203    /// use llama_cpp_4::mtmd::MtmdContextParams;
204    /// let params = MtmdContextParams::default().with_batch_max_tokens(2048);
205    /// assert_eq!(params.batch_max_tokens(), 2048);
206    /// # }
207    /// ```
208    #[must_use]
209    pub fn with_batch_max_tokens(mut self, n: i32) -> Self {
210        self.params.batch_max_tokens = n;
211        self
212    }
213
214    /// Get the configured batch token cap (`batch_max_tokens`).
215    #[must_use]
216    pub fn batch_max_tokens(&self) -> i32 {
217        self.params.batch_max_tokens
218    }
219
220    /// Set flash-attention mode for the vision encoder.
221    ///
222    /// Maps to `mtmd_context_params.flash_attn_type`. Uses the same
223    /// [`crate::context::params::LlamaFlashAttnType`] enum as text contexts.
224    ///
225    /// # Examples
226    ///
227    /// ```rust
228    /// # #[cfg(feature = "mtmd")]
229    /// # {
230    /// use llama_cpp_4::context::params::LlamaFlashAttnType;
231    /// use llama_cpp_4::mtmd::MtmdContextParams;
232    /// let params = MtmdContextParams::default()
233    ///     .with_flash_attn_type(LlamaFlashAttnType::Auto);
234    /// assert_eq!(params.flash_attn_type(), LlamaFlashAttnType::Auto);
235    /// # }
236    /// ```
237    #[must_use]
238    pub fn with_flash_attn_type(
239        mut self,
240        flash_attn_type: crate::context::params::LlamaFlashAttnType,
241    ) -> Self {
242        self.params.flash_attn_type = flash_attn_type.into();
243        self
244    }
245
246    /// Get flash-attention mode for the vision encoder.
247    #[must_use]
248    pub fn flash_attn_type(&self) -> crate::context::params::LlamaFlashAttnType {
249        crate::context::params::LlamaFlashAttnType::from(self.params.flash_attn_type)
250    }
251
252    /// Register a callback invoked while mmproj weights load.
253    ///
254    /// Maps to `mtmd_context_params.progress_callback`. Pass `None` to disable
255    /// progress reporting. The callback may return `false` to abort loading
256    /// early; see [`MtmdProgressCallback`].
257    ///
258    /// `user_data` is forwarded to each invocation and must remain valid until
259    /// [`MtmdContext::init_from_file`] returns.
260    #[must_use]
261    pub fn with_progress_callback(
262        mut self,
263        callback: Option<MtmdProgressCallback>,
264        user_data: *mut c_void,
265    ) -> Self {
266        self.params.progress_callback = callback;
267        self.params.progress_callback_user_data = user_data;
268        self
269    }
270
271    /// Override the media marker string (e.g. `"<image>"`).
272    ///
273    /// The provided string must not contain interior NUL bytes.  Pass `None`
274    /// to use the library default (`mtmd_default_marker()`).
275    ///
276    /// **Note:** the `CString` is stored inside the params so the pointer
277    /// remains valid as long as this `MtmdContextParams` lives.
278    /// # Errors
279    ///
280    /// Returns [`MtmdError`] if the marker string contains a NUL byte.
281    pub fn media_marker(mut self, marker: Option<&str>) -> std::result::Result<Self, MtmdError> {
282        match marker {
283            None => {
284                self.params.media_marker = std::ptr::null();
285                Ok(self)
286            }
287            Some(s) => {
288                let cs = CString::new(s)?;
289                self.params.media_marker = cs.as_ptr();
290                // Leak the CString so the raw pointer stays valid; the caller
291                // must ensure the params don't outlive the string.  Since
292                // MtmdContextParams is consumed by MtmdContext::init_from_file,
293                // this is safe.
294                std::mem::forget(cs);
295                Ok(self)
296            }
297        }
298    }
299}
300
301// ─────────────────────────────────────────────────────────────────────────────
302// MtmdContext
303// ─────────────────────────────────────────────────────────────────────────────
304
305/// The main multimodal context.
306///
307/// Wraps a `mtmd_context *`.  This context is tied to a specific mmproj model
308/// file and a loaded [`LlamaModel`].  It is safe to share across threads for
309/// `tokenize` calls (read-only), but `encode_chunk` / eval helpers mutate
310/// internal state and must not be called concurrently.
311pub struct MtmdContext {
312    ptr: NonNull<sys::mtmd_context>,
313}
314
315// The underlying mtmd_context is internally synchronised for tokenize().
316// encode / decode must be called from a single thread at a time (caller's
317// responsibility, enforced by the inference semaphore in the server).
318unsafe impl Send for MtmdContext {}
319unsafe impl Sync for MtmdContext {}
320
321impl std::fmt::Debug for MtmdContext {
322    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        f.debug_struct("MtmdContext")
324            .field("ptr", &self.ptr)
325            .finish()
326    }
327}
328
329impl Drop for MtmdContext {
330    fn drop(&mut self) {
331        unsafe { sys::mtmd_free(self.ptr.as_ptr()) }
332    }
333}
334
335impl MtmdContext {
336    /// Returns the default media marker string used in prompts
337    /// (currently `"<__media__>"`).
338    #[must_use]
339    pub fn default_marker() -> &'static str {
340        let ptr = unsafe { sys::mtmd_default_marker() };
341        unsafe { CStr::from_ptr(ptr) }
342            .to_str()
343            .unwrap_or("<__media__>")
344    }
345
346    /// Initialise a multimodal context from an mmproj GGUF file.
347    ///
348    /// # Parameters
349    ///
350    /// * `mmproj_path` – path to the mmproj `.gguf` file
351    /// * `text_model`  – the already-loaded text model
352    /// * `params`      – context parameters (use [`MtmdContextParams::default()`])
353    ///
354    /// # Errors
355    ///
356    /// Returns [`MtmdError::ContextCreateFailed`] if the underlying C call
357    /// returns a null pointer.
358    #[allow(clippy::needless_pass_by_value)]
359    pub fn init_from_file(
360        mmproj_path: impl AsRef<Path>,
361        text_model: &LlamaModel,
362        params: MtmdContextParams,
363    ) -> Result<Self> {
364        let path = mmproj_path
365            .as_ref()
366            .to_str()
367            .ok_or(MtmdError::PathNotUtf8)?;
368        let c_path = CString::new(path)?;
369
370        let ptr = unsafe {
371            sys::mtmd_init_from_file(c_path.as_ptr(), text_model.model.as_ptr(), params.params)
372        };
373
374        let ptr = NonNull::new(ptr).ok_or(MtmdError::ContextCreateFailed)?;
375        Ok(Self { ptr })
376    }
377
378    // ── Logging ──────────────────────────────────────────────────────────
379
380    /// Silence all clip/mtmd log output by installing a no-op callback.
381    ///
382    /// Call this right after [`init_from_file`](Self::init_from_file) to
383    /// suppress the verbose `clip_model_loader: tensor[N]…` lines that
384    /// clip.cpp emits to its own private logger (separate from `llama_log_set`).
385    pub fn void_logs() {
386        unsafe extern "C" fn noop(
387            _level: sys::ggml_log_level,
388            _text: *const ::std::os::raw::c_char,
389            _ud: *mut ::std::os::raw::c_void,
390        ) {
391        }
392        unsafe { sys::mtmd_log_set(Some(noop), std::ptr::null_mut()) };
393    }
394
395    /// Like [`void_logs`](Self::void_logs), but additionally silences logs
396    /// emitted by the `mtmd_helper_*` layer (e.g. eval/decode helpers).
397    ///
398    /// Internally calls `mtmd_helper_log_set` which also routes through
399    /// `mtmd_log_set`, so this is a strict superset of `void_logs`.
400    pub fn void_helper_logs() {
401        unsafe extern "C" fn noop(
402            _level: sys::ggml_log_level,
403            _text: *const ::std::os::raw::c_char,
404            _ud: *mut ::std::os::raw::c_void,
405        ) {
406        }
407        unsafe { sys::mtmd_helper_log_set(Some(noop), std::ptr::null_mut()) };
408    }
409
410    // ── Capability queries ────────────────────────────────────────────────
411
412    /// Returns `true` if the model supports vision (image) input.
413    #[must_use]
414    pub fn supports_vision(&self) -> bool {
415        unsafe { sys::mtmd_support_vision(self.ptr.as_ptr()) }
416    }
417
418    /// Returns `true` if the model supports audio input.
419    #[must_use]
420    pub fn supports_audio(&self) -> bool {
421        unsafe { sys::mtmd_support_audio(self.ptr.as_ptr()) }
422    }
423
424    /// Returns `true` if this build and model support video input.
425    ///
426    /// Video support additionally requires `ffmpeg`/`ffprobe` to be available
427    /// at runtime (see [`MtmdVideo`]). Wraps `mtmd_helper_support_video`.
428    #[must_use]
429    pub fn supports_video(&self) -> bool {
430        unsafe { sys::mtmd_helper_support_video(self.ptr.as_ptr()) }
431    }
432
433    /// Returns the media marker string configured for *this* context.
434    ///
435    /// Unlike [`default_marker`](Self::default_marker) (the library-wide
436    /// default), this reflects any override passed via
437    /// [`MtmdContextParams::media_marker`]. Wraps `mtmd_get_marker`.
438    #[must_use]
439    pub fn marker(&self) -> &str {
440        let ptr = unsafe { sys::mtmd_get_marker(self.ptr.as_ptr()) };
441        if ptr.is_null() {
442            return Self::default_marker();
443        }
444        unsafe { CStr::from_ptr(ptr) }
445            .to_str()
446            .unwrap_or_else(|_| Self::default_marker())
447    }
448
449    /// Returns the audio sample rate in Hz (e.g. `16_000` for Whisper), or `-1` if
450    /// audio is not supported.
451    #[must_use]
452    pub fn audio_sample_rate(&self) -> i32 {
453        unsafe { sys::mtmd_get_audio_sample_rate(self.ptr.as_ptr()) }
454    }
455
456    /// Whether `llama_decode` must use a non-causal attention mask when
457    /// decoding image embeddings for this model.
458    #[must_use]
459    pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk<'_>) -> bool {
460        unsafe { sys::mtmd_decode_use_non_causal(self.ptr.as_ptr(), chunk.as_ptr()) }
461    }
462
463    /// Whether the model uses M-RoPE for `llama_decode`.
464    #[must_use]
465    pub fn decode_use_mrope(&self) -> bool {
466        unsafe { sys::mtmd_decode_use_mrope(self.ptr.as_ptr()) }
467    }
468
469    // ── Core API ──────────────────────────────────────────────────────────
470
471    /// Tokenize a text prompt that contains one or more media markers.
472    ///
473    /// The number of `bitmaps` must equal the number of media markers in the
474    /// prompt text, otherwise [`MtmdError::TokenizeError`] with code `1` is returned.
475    ///
476    /// This call is **thread-safe** (shared `&self`).
477    ///
478    /// # Parameters
479    ///
480    /// * `text`    – text + tokenisation options
481    /// * `bitmaps` – slice of [`MtmdBitmap`] references, one per media marker
482    /// * `output`  – an [`MtmdInputChunks`] that will be populated with the result
483    ///
484    /// # Errors
485    ///
486    /// Returns [`MtmdError::TokenizeError`] if tokenization fails.
487    pub fn tokenize(
488        &self,
489        text: &MtmdInputText<'_>,
490        bitmaps: &[&MtmdBitmap],
491        output: &mut MtmdInputChunks,
492    ) -> Result<()> {
493        // The C signature is: mtmd_tokenize(..., mtmd_bitmap ** bitmaps, ...)
494        // where each element is a `const mtmd_bitmap *`.  We build a Vec of
495        // `*const mtmd_bitmap` and pass a mutable pointer to its first element
496        // (i.e. `*mut *const mtmd_bitmap`) to satisfy the C API.
497        let mut bitmap_ptrs: Vec<*const sys::mtmd_bitmap> = bitmaps
498            .iter()
499            .map(|b| b.ptr.as_ptr().cast_const())
500            .collect();
501
502        let c_text = sys::mtmd_input_text {
503            text: text.c_text.as_ptr(),
504            add_special: text.add_special,
505            parse_special: text.parse_special,
506        };
507
508        let ret = unsafe {
509            sys::mtmd_tokenize(
510                self.ptr.as_ptr(),
511                output.ptr.as_ptr(),
512                &raw const c_text,
513                bitmap_ptrs.as_mut_ptr(),
514                bitmap_ptrs.len(),
515            )
516        };
517
518        if ret != 0 {
519            return Err(MtmdError::TokenizeError(ret));
520        }
521        Ok(())
522    }
523
524    /// Encode a single input chunk (image or audio) and store the resulting
525    /// embeddings inside the context.
526    ///
527    /// After a successful call, the embeddings can be retrieved with
528    /// [`MtmdContext::output_embd`].
529    ///
530    /// This call is **NOT thread-safe**.
531    ///
532    /// # Errors
533    ///
534    /// Returns [`MtmdError::EncodeError`] if encoding fails.
535    pub fn encode_chunk(&self, chunk: &MtmdInputChunk<'_>) -> Result<()> {
536        let ret = unsafe { sys::mtmd_encode_chunk(self.ptr.as_ptr(), chunk.ptr) };
537        if ret != 0 {
538            return Err(MtmdError::EncodeError(ret));
539        }
540        Ok(())
541    }
542
543    /// Return a slice over the embeddings produced by the last
544    /// [`encode_chunk`](Self::encode_chunk) call.
545    ///
546    /// The length (in `f32` elements) is:
547    /// ```text
548    /// n_embd_inp(model)  *  chunk.n_tokens()
549    /// ```
550    ///
551    /// # Safety
552    ///
553    /// The returned slice is valid until the next call that mutates the
554    /// context (e.g. another `encode_chunk`).
555    #[must_use]
556    pub fn output_embd(&self, n_elements: usize) -> &[f32] {
557        let ptr = unsafe { sys::mtmd_get_output_embd(self.ptr.as_ptr()) };
558        if ptr.is_null() || n_elements == 0 {
559            return &[];
560        }
561        unsafe { slice::from_raw_parts(ptr, n_elements) }
562    }
563
564    // ── Helper API ────────────────────────────────────────────────────────
565
566    /// High-level helper: evaluate (decode) all chunks in sequence.
567    ///
568    /// * Text chunks are decoded via `llama_decode`.
569    /// * Image/audio chunks are first encoded with `mtmd_encode_chunk` and
570    ///   then decoded via `llama_decode`.
571    ///
572    /// On success `new_n_past` is updated with the new past position.
573    ///
574    /// This call is **NOT thread-safe**.
575    ///
576    /// # Parameters
577    ///
578    /// * `lctx`        – raw pointer to the llama context (from [`LlamaContext::as_ptr`])
579    /// * `chunks`      – the tokenized chunks to evaluate
580    /// * `n_past`      – current KV-cache position
581    /// * `seq_id`      – sequence ID
582    /// * `n_batch`     – maximum batch size (must be ≥ 1)
583    /// * `logits_last` – if `true`, compute logits only for the final token
584    /// * `new_n_past`  – updated KV-cache position after the call
585    ///
586    /// # Errors
587    ///
588    /// Returns [`MtmdError::EvalError`] if evaluation fails.
589    #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
590    pub fn eval_chunks(
591        &self,
592        lctx: *mut sys::llama_context,
593        chunks: &MtmdInputChunks,
594        n_past: i32,
595        seq_id: i32,
596        n_batch: i32,
597        logits_last: bool,
598        new_n_past: &mut i32,
599    ) -> Result<()> {
600        let ret = unsafe {
601            sys::mtmd_helper_eval_chunks(
602                self.ptr.as_ptr(),
603                lctx,
604                chunks.ptr.as_ptr(),
605                n_past,
606                seq_id,
607                n_batch,
608                logits_last,
609                new_n_past,
610            )
611        };
612        if ret != 0 {
613            return Err(MtmdError::EvalError(ret));
614        }
615        Ok(())
616    }
617
618    /// High-level helper: evaluate a single chunk.
619    ///
620    /// Works identically to [`eval_chunks`](Self::eval_chunks) but operates on
621    /// one chunk at a time.
622    ///
623    /// # Errors
624    ///
625    /// Returns [`MtmdError::EvalError`] if evaluation fails.
626    #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
627    pub fn eval_chunk_single(
628        &self,
629        lctx: *mut sys::llama_context,
630        chunk: &MtmdInputChunk<'_>,
631        n_past: i32,
632        seq_id: i32,
633        n_batch: i32,
634        logits_last: bool,
635        new_n_past: &mut i32,
636    ) -> Result<()> {
637        let ret = unsafe {
638            sys::mtmd_helper_eval_chunk_single(
639                self.ptr.as_ptr(),
640                lctx,
641                chunk.ptr,
642                n_past,
643                seq_id,
644                n_batch,
645                logits_last,
646                new_n_past,
647            )
648        };
649        if ret != 0 {
650            return Err(MtmdError::EvalError(ret));
651        }
652        Ok(())
653    }
654
655    /// Decode an image/audio chunk whose embeddings have already been
656    /// computed (e.g. via [`encode_chunk`](Self::encode_chunk) followed by
657    /// [`output_embd`](Self::output_embd)).
658    ///
659    /// Unlike [`eval_chunk_single`](Self::eval_chunk_single), this helper
660    /// handles batching plus the non-causal-attention setup required by
661    /// some models (e.g. Gemma 3, Gemma 4 audio) and the M-RoPE position
662    /// layout. Use it when the embeddings are already in hand and you want
663    /// the helper to take care of `llama_decode` plumbing.
664    ///
665    /// `encoded_embd` must contain `mtmd_image_tokens_get_n_tokens(chunk) *
666    /// llama_model_n_embd_inp(model)` `f32` elements. This call is **NOT
667    /// thread-safe**.
668    ///
669    /// # Errors
670    ///
671    /// Returns [`MtmdError::EvalError`] with code `-1` if `chunk` is not an
672    /// image/audio chunk, or `1` if `llama_decode` fails.
673    #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
674    pub fn decode_image_chunk(
675        &self,
676        lctx: *mut sys::llama_context,
677        chunk: &MtmdInputChunk<'_>,
678        encoded_embd: &[f32],
679        n_past: i32,
680        seq_id: i32,
681        n_batch: i32,
682        new_n_past: &mut i32,
683    ) -> Result<()> {
684        let ret = unsafe {
685            sys::mtmd_helper_decode_image_chunk(
686                self.ptr.as_ptr(),
687                lctx,
688                chunk.ptr,
689                encoded_embd.as_ptr().cast_mut(),
690                n_past,
691                seq_id,
692                n_batch,
693                new_n_past,
694                // No post-decode callback; preserves prior single-shot behavior.
695                None,
696                std::ptr::null_mut(),
697            )
698        };
699        if ret != 0 {
700            return Err(MtmdError::EvalError(ret));
701        }
702        Ok(())
703    }
704
705    /// Returns a raw pointer to the underlying `mtmd_context`.
706    ///
707    /// # Safety
708    ///
709    /// The returned pointer is valid for the lifetime of this `MtmdContext`.
710    /// The caller must not free it.
711    #[must_use]
712    pub fn as_ptr(&self) -> *mut sys::mtmd_context {
713        self.ptr.as_ptr()
714    }
715}
716
717// ─────────────────────────────────────────────────────────────────────────────
718// MtmdInputText
719// ─────────────────────────────────────────────────────────────────────────────
720
/// Prompt text handed to [`MtmdContext::tokenize`].
///
/// Each bitmap to be embedded must be matched by one occurrence of the media
/// marker (see [`MtmdContext::default_marker`]) inside the prompt.
#[derive(Debug)]
pub struct MtmdInputText<'a> {
    /// NUL-terminated copy of the prompt, passed to `mtmd_tokenize`.
    c_text: CString,
    /// Whether BOS/EOS tokens are added.
    add_special: bool,
    /// Whether special tokens written in the text are parsed as such.
    parse_special: bool,
    /// Keeps the public lifetime parameter; the text itself is copied into `c_text`.
    _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> MtmdInputText<'a> {
    /// Builds an `MtmdInputText`, panicking on invalid text.
    ///
    /// * `text`          – the prompt (must not contain interior NUL bytes)
    /// * `add_special`   – whether to add BOS/EOS tokens
    /// * `parse_special` – whether to parse special tokens embedded in the text
    ///
    /// # Panics
    ///
    /// Panics if `text` contains an interior NUL byte. Use
    /// [`MtmdInputText::try_new`] to handle that case without panicking.
    #[must_use]
    pub fn new(text: &'a str, add_special: bool, parse_special: bool) -> Self {
        Self::try_new(text, add_special, parse_special)
            .expect("MtmdInputText: text must not contain NUL bytes")
    }

    /// Fallible counterpart of [`MtmdInputText::new`].
    ///
    /// # Errors
    ///
    /// Returns [`std::ffi::NulError`] if `text` contains a NUL byte.
    pub fn try_new(
        text: &'a str,
        add_special: bool,
        parse_special: bool,
    ) -> std::result::Result<Self, std::ffi::NulError> {
        CString::new(text).map(|c_text| Self {
            c_text,
            add_special,
            parse_special,
            _marker: std::marker::PhantomData,
        })
    }
}
774
775// ─────────────────────────────────────────────────────────────────────────────
776// MtmdBitmap
777// ─────────────────────────────────────────────────────────────────────────────
778
779/// An image or audio bitmap ready for multimodal encoding.
780///
781/// # Image bitmaps
782///
783/// The raw pixel data must be in RGBRGBRGB… (interleaved) format.  The total
784/// number of bytes must be `nx * ny * 3`.
785///
786/// # Audio bitmaps
787///
788/// The raw sample data must be little-endian `f32` PCM samples.  The total
789/// number of bytes must be `n_samples * 4`.
790pub struct MtmdBitmap {
791    ptr: NonNull<sys::mtmd_bitmap>,
792}
793
794unsafe impl Send for MtmdBitmap {}
795unsafe impl Sync for MtmdBitmap {}
796
797impl std::fmt::Debug for MtmdBitmap {
798    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
799        f.debug_struct("MtmdBitmap")
800            .field("nx", &self.nx())
801            .field("ny", &self.ny())
802            .field("n_bytes", &self.n_bytes())
803            .field("is_audio", &self.is_audio())
804            .finish()
805    }
806}
807
808impl Drop for MtmdBitmap {
809    fn drop(&mut self) {
810        unsafe { sys::mtmd_bitmap_free(self.ptr.as_ptr()) }
811    }
812}
813
814impl MtmdBitmap {
815    /// Create a bitmap from raw RGB pixel data.
816    ///
817    /// * `nx`   – image width in pixels
818    /// * `ny`   – image height in pixels
819    /// * `data` – raw pixel bytes in RGBRGB… format; must be `nx * ny * 3` bytes
820    ///
821    /// # Errors
822    ///
823    /// Returns [`MtmdError::BitmapCreateFailed`] if the underlying C call
824    /// returns null.
825    pub fn from_rgb(nx: u32, ny: u32, data: &[u8]) -> Result<Self> {
826        let ptr = unsafe { sys::mtmd_bitmap_init(nx, ny, data.as_ptr()) };
827        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
828        Ok(Self { ptr })
829    }
830
831    /// Create an audio bitmap from PCM `f32` samples.
832    ///
833    /// * `samples` – slice of PCM float samples
834    ///
835    /// # Errors
836    ///
837    /// Returns [`MtmdError::BitmapCreateFailed`] if the underlying C call
838    /// returns null.
839    pub fn from_audio(samples: &[f32]) -> Result<Self> {
840        let ptr = unsafe { sys::mtmd_bitmap_init_from_audio(samples.len(), samples.as_ptr()) };
841        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
842        Ok(Self { ptr })
843    }
844
845    /// Build an `MtmdBitmap` from a `mtmd_helper_bitmap_wrapper`, taking
846    /// ownership of the `bitmap` and freeing any `video_ctx`.
847    ///
848    /// The `from_file`/`from_buf` constructors only support image/audio input.
849    /// When the input is a video the helper returns a non-null `video_ctx`
850    /// (an open ffmpeg stream) which is not representable as an `MtmdBitmap`;
851    /// we free it here to avoid leaking it. Use [`MtmdVideo`] for video input.
852    fn from_wrapper(wrapper: sys::mtmd_helper_bitmap_wrapper) -> Result<Self> {
853        if !wrapper.video_ctx.is_null() {
854            unsafe { sys::mtmd_helper_video_free(wrapper.video_ctx) };
855        }
856        let ptr = NonNull::new(wrapper.bitmap).ok_or(MtmdError::BitmapCreateFailed)?;
857        Ok(Self { ptr })
858    }
859
860    /// Load a bitmap from a file (image or audio).
861    ///
862    /// Supported image formats: JPEG, PNG, BMP, GIF, and others handled by
863    /// `stb_image`.  Supported audio formats: WAV, MP3, FLAC (via miniaudio).
864    ///
865    /// # Errors
866    ///
867    /// Returns [`MtmdError::BitmapCreateFailed`] if the file cannot be loaded.
868    pub fn from_file(ctx: &MtmdContext, path: impl AsRef<Path>) -> Result<Self> {
869        let path = path.as_ref().to_str().ok_or(MtmdError::PathNotUtf8)?;
870        let c_path = CString::new(path)?;
871
872        // `placeholder = false`: load the real bitmap data (not a token-count
873        // placeholder). For image/audio the returned `video_ctx` is always null.
874        let wrapper = unsafe {
875            sys::mtmd_helper_bitmap_init_from_file(ctx.ptr.as_ptr(), c_path.as_ptr(), false)
876        };
877        Self::from_wrapper(wrapper)
878    }
879
880    /// Load a bitmap from an in-memory buffer containing a file.
881    ///
882    /// The format is auto-detected (image vs audio via magic bytes).
883    ///
884    /// # Errors
885    ///
886    /// Returns [`MtmdError::BitmapCreateFailed`] if decoding fails.
887    pub fn from_buf(ctx: &MtmdContext, buf: &[u8]) -> Result<Self> {
888        // `placeholder = false`: load the real bitmap data (not a token-count
889        // placeholder). For image/audio the returned `video_ctx` is always null.
890        let wrapper = unsafe {
891            sys::mtmd_helper_bitmap_init_from_buf(ctx.ptr.as_ptr(), buf.as_ptr(), buf.len(), false)
892        };
893        Self::from_wrapper(wrapper)
894    }
895
896    // ── Getters ───────────────────────────────────────────────────────────
897
898    /// Width in pixels (for images) or 0 (for audio).
899    #[must_use]
900    pub fn nx(&self) -> u32 {
901        unsafe { sys::mtmd_bitmap_get_nx(self.ptr.as_ptr()) }
902    }
903
904    /// Height in pixels (for images) or 0 (for audio).
905    #[must_use]
906    pub fn ny(&self) -> u32 {
907        unsafe { sys::mtmd_bitmap_get_ny(self.ptr.as_ptr()) }
908    }
909
910    /// Total number of bytes in the bitmap data.
911    #[must_use]
912    pub fn n_bytes(&self) -> usize {
913        unsafe { sys::mtmd_bitmap_get_n_bytes(self.ptr.as_ptr()) }
914    }
915
916    /// Returns `true` if this bitmap contains audio (rather than image) data.
917    #[must_use]
918    pub fn is_audio(&self) -> bool {
919        unsafe { sys::mtmd_bitmap_is_audio(self.ptr.as_ptr()) }
920    }
921
922    /// Return the raw pixel / sample data.
923    #[must_use]
924    pub fn data(&self) -> &[u8] {
925        let n = self.n_bytes();
926        if n == 0 {
927            return &[];
928        }
929        let ptr = unsafe { sys::mtmd_bitmap_get_data(self.ptr.as_ptr()) };
930        unsafe { slice::from_raw_parts(ptr, n) }
931    }
932
933    /// Return the optional ID string attached to this bitmap (used for KV
934    /// cache tracking), or `None` if no ID has been set.
935    #[must_use]
936    pub fn id(&self) -> Option<&str> {
937        let ptr = unsafe { sys::mtmd_bitmap_get_id(self.ptr.as_ptr()) };
938        if ptr.is_null() {
939            return None;
940        }
941        unsafe { CStr::from_ptr(ptr) }.to_str().ok()
942    }
943
944    /// Attach an optional ID string to this bitmap (used for KV cache
945    /// tracking).
946    ///
947    /// # Errors
948    ///
949    /// Returns an error if `id` contains an interior NUL byte.
950    pub fn set_id(&mut self, id: &str) -> std::result::Result<(), std::ffi::NulError> {
951        let cs = CString::new(id)?;
952        unsafe { sys::mtmd_bitmap_set_id(self.ptr.as_ptr(), cs.as_ptr()) };
953        Ok(())
954    }
955}
956
957// ─────────────────────────────────────────────────────────────────────────────
958// Video input
959// ─────────────────────────────────────────────────────────────────────────────
960
961// `free()` from libc — used to release the heap-allocated text returned by
962// `mtmd_helper_video_read_next` (the C side allocates it with strdup/malloc and
963// documents that the caller must release it with `free()`).
964extern "C" {
965    fn free(ptr: *mut std::os::raw::c_void);
966}
967
968/// Parameters controlling how a [`MtmdVideo`] stream is opened and sampled.
969///
970/// Obtain a default-initialised instance via [`MtmdVideoParams::default()`]
971/// (which mirrors `mtmd_helper_video_init_params_default`: ~4 fps, native
972/// `ffmpeg`/`ffprobe` from `PATH`, and a 5 s timestamp interval) and tweak it
973/// with the builder methods.
974pub struct MtmdVideoParams {
975    params: sys::mtmd_helper_video_init_params,
976    // Keeps the `ffmpeg_bin_dir` C string alive for as long as `params`
977    // borrows it via a raw pointer.
978    ffmpeg_bin_dir: Option<CString>,
979}
980
981impl std::fmt::Debug for MtmdVideoParams {
982    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
983        f.debug_struct("MtmdVideoParams")
984            .field("fps_target", &self.params.fps_target)
985            .field("timestamp_interval_ms", &self.params.timestamp_interval_ms)
986            .field("ffmpeg_bin_dir", &self.ffmpeg_bin_dir)
987            .finish()
988    }
989}
990
991impl Default for MtmdVideoParams {
992    fn default() -> Self {
993        let params = unsafe { sys::mtmd_helper_video_init_params_default() };
994        Self {
995            params,
996            ffmpeg_bin_dir: None,
997        }
998    }
999}
1000
1001impl MtmdVideoParams {
1002    /// Desired output frame rate. Values `<= 0` mean "use the video's native
1003    /// fps" (the default is ~4 fps).
1004    #[must_use]
1005    pub fn fps_target(mut self, fps: f32) -> Self {
1006        self.params.fps_target = fps;
1007        self
1008    }
1009
1010    /// Interval, in milliseconds, between inserted timestamp text chunks (e.g.
1011    /// `"[10m50.5s]"`). Values `<= 0` disable timestamps (default 5000 ms).
1012    #[must_use]
1013    pub fn timestamp_interval_ms(mut self, ms: i64) -> Self {
1014        self.params.timestamp_interval_ms = ms;
1015        self
1016    }
1017
1018    /// Directory containing the `ffmpeg`/`ffprobe` binaries. Pass `None` to
1019    /// search `PATH` (the default).
1020    ///
1021    /// # Errors
1022    ///
1023    /// Returns an error if `dir` contains an interior NUL byte.
1024    pub fn ffmpeg_bin_dir(mut self, dir: Option<&str>) -> Result<Self> {
1025        match dir {
1026            None => {
1027                self.params.ffmpeg_bin_dir = std::ptr::null();
1028                self.ffmpeg_bin_dir = None;
1029            }
1030            Some(d) => {
1031                let cs = CString::new(d)?;
1032                self.params.ffmpeg_bin_dir = cs.as_ptr();
1033                // Store the owner so the pointer above stays valid.
1034                self.ffmpeg_bin_dir = Some(cs);
1035            }
1036        }
1037        Ok(self)
1038    }
1039}
1040
1041/// Metadata describing an open [`MtmdVideo`] stream.
1042#[derive(Debug, Clone, Copy, PartialEq)]
1043pub struct MtmdVideoInfo {
1044    /// Frame width in pixels.
1045    pub width: u32,
1046    /// Frame height in pixels.
1047    pub height: u32,
1048    /// Effective frames-per-second (the `fps_target` if set, else native fps).
1049    pub fps: f32,
1050    /// Estimated total frame count at the effective fps (`-1` if unknown).
1051    pub n_frames: i32,
1052}
1053
1054/// One item read from a [`MtmdVideo`] stream by [`MtmdVideo::read_next`].
1055#[derive(Debug)]
1056pub enum MtmdVideoItem {
1057    /// A decoded video frame, ready to be tokenized like any other image
1058    /// [`MtmdBitmap`].
1059    Frame(MtmdBitmap),
1060    /// A timestamp text marker (e.g. `"[10m50.5s]"`) to be inserted into the
1061    /// prompt between frames.
1062    Text(String),
1063}
1064
1065/// An open video stream, decoded frame-by-frame via `ffmpeg`.
1066///
1067/// The notion of "video" exists only at the helper level — it is decoded into
1068/// a sequence of image [frames](MtmdVideoItem::Frame) and timestamp
1069/// [text markers](MtmdVideoItem::Text) which are then fed through the normal
1070/// multimodal pipeline.
1071///
1072/// Requires a build with video support (see [`MtmdContext::supports_video`])
1073/// and `ffmpeg`/`ffprobe` available at runtime.
1074///
1075/// # Example
1076///
1077/// ```no_run
1078/// # #[cfg(feature = "mtmd")]
1079/// # fn run(mtmd_ctx: &llama_cpp_4::mtmd::MtmdContext) -> Result<(), llama_cpp_4::mtmd::MtmdError> {
1080/// use std::path::Path;
1081/// use llama_cpp_4::mtmd::{MtmdVideo, MtmdVideoParams, MtmdVideoItem};
1082///
1083/// let mut video = MtmdVideo::from_file(mtmd_ctx, Path::new("clip.mp4"),
1084///                                      &MtmdVideoParams::default())?;
1085/// while let Some(item) = video.read_next()? {
1086///     match item {
1087///         MtmdVideoItem::Frame(bitmap) => { /* tokenize the frame */ }
1088///         MtmdVideoItem::Text(ts)      => { /* insert the timestamp marker */ }
1089///     }
1090/// }
1091/// # Ok(())
1092/// # }
1093/// ```
1094pub struct MtmdVideo {
1095    ptr: NonNull<sys::mtmd_helper_video>,
1096}
1097
1098impl std::fmt::Debug for MtmdVideo {
1099    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1100        f.debug_struct("MtmdVideo")
1101            .field("info", &self.info())
1102            .finish()
1103    }
1104}
1105
1106impl Drop for MtmdVideo {
1107    fn drop(&mut self) {
1108        unsafe { sys::mtmd_helper_video_free(self.ptr.as_ptr()) }
1109    }
1110}
1111
1112impl MtmdVideo {
1113    /// Open a video file for frame-by-frame decoding.
1114    ///
1115    /// # Errors
1116    ///
1117    /// Returns [`MtmdError::VideoInitFailed`] if the stream cannot be opened
1118    /// (no video support compiled in, `ffprobe` not found, file unreadable,
1119    /// …), or [`MtmdError::InvalidPath`] / [`MtmdError::PathNotUtf8`] for a bad
1120    /// path.
1121    pub fn from_file(
1122        ctx: &MtmdContext,
1123        path: impl AsRef<Path>,
1124        params: &MtmdVideoParams,
1125    ) -> Result<Self> {
1126        let path = path.as_ref().to_str().ok_or(MtmdError::PathNotUtf8)?;
1127        let c_path = CString::new(path)?;
1128        let ptr = unsafe {
1129            sys::mtmd_helper_video_init(ctx.ptr.as_ptr(), c_path.as_ptr(), params.params)
1130        };
1131        let ptr = NonNull::new(ptr).ok_or(MtmdError::VideoInitFailed)?;
1132        Ok(Self { ptr })
1133    }
1134
1135    /// Open a video from an in-memory buffer. The buffer is copied internally,
1136    /// so it need not outlive this call.
1137    ///
1138    /// # Errors
1139    ///
1140    /// Returns [`MtmdError::VideoInitFailed`] if the stream cannot be opened.
1141    pub fn from_buf(ctx: &MtmdContext, buf: &[u8], params: &MtmdVideoParams) -> Result<Self> {
1142        let ptr = unsafe {
1143            sys::mtmd_helper_video_init_from_buf(
1144                ctx.ptr.as_ptr(),
1145                buf.as_ptr(),
1146                buf.len(),
1147                params.params,
1148            )
1149        };
1150        let ptr = NonNull::new(ptr).ok_or(MtmdError::VideoInitFailed)?;
1151        Ok(Self { ptr })
1152    }
1153
1154    /// Return metadata (resolution, effective fps, estimated frame count) for
1155    /// this stream.
1156    #[must_use]
1157    pub fn info(&self) -> MtmdVideoInfo {
1158        let info = unsafe { sys::mtmd_helper_video_get_info(self.ptr.as_ptr()) };
1159        MtmdVideoInfo {
1160            width: info.width,
1161            height: info.height,
1162            fps: info.fps,
1163            n_frames: info.n_frames,
1164        }
1165    }
1166
1167    /// Read the next item from the stream.
1168    ///
1169    /// Returns `Ok(Some(item))` for each frame or timestamp marker, and
1170    /// `Ok(None)` once the end of the stream is reached.
1171    ///
1172    /// # Errors
1173    ///
1174    /// Returns [`MtmdError::VideoReadError`] on a decode error.
1175    pub fn read_next(&mut self) -> Result<Option<MtmdVideoItem>> {
1176        let mut out_bitmap: *mut sys::mtmd_bitmap = std::ptr::null_mut();
1177        let mut out_text: *mut std::os::raw::c_char = std::ptr::null_mut();
1178        let ret = unsafe {
1179            sys::mtmd_helper_video_read_next(
1180                self.ptr.as_ptr(),
1181                &raw mut out_bitmap,
1182                &raw mut out_text,
1183            )
1184        };
1185        match ret {
1186            0 => {
1187                if let Some(ptr) = NonNull::new(out_bitmap) {
1188                    Ok(Some(MtmdVideoItem::Frame(MtmdBitmap { ptr })))
1189                } else if !out_text.is_null() {
1190                    let text = unsafe { CStr::from_ptr(out_text) }
1191                        .to_string_lossy()
1192                        .into_owned();
1193                    // The C side allocated this with strdup/malloc; release it.
1194                    unsafe { free(out_text.cast()) };
1195                    Ok(Some(MtmdVideoItem::Text(text)))
1196                } else {
1197                    // Success but nothing produced — treat as end of stream.
1198                    Ok(None)
1199                }
1200            }
1201            -1 => Ok(None), // EOF
1202            other => Err(MtmdError::VideoReadError(other)),
1203        }
1204    }
1205}
1206
1207// ─────────────────────────────────────────────────────────────────────────────
1208// MtmdInputChunks
1209// ─────────────────────────────────────────────────────────────────────────────
1210
1211/// A list of tokenized input chunks produced by [`MtmdContext::tokenize`].
1212///
1213/// Each chunk is either a text token sequence or a set of image/audio tokens.
1214pub struct MtmdInputChunks {
1215    ptr: NonNull<sys::mtmd_input_chunks>,
1216}
1217
1218impl std::fmt::Debug for MtmdInputChunks {
1219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1220        f.debug_struct("MtmdInputChunks")
1221            .field("len", &self.len())
1222            .finish()
1223    }
1224}
1225
1226impl Drop for MtmdInputChunks {
1227    fn drop(&mut self) {
1228        unsafe { sys::mtmd_input_chunks_free(self.ptr.as_ptr()) }
1229    }
1230}
1231
1232impl MtmdInputChunks {
1233    /// Create a new, empty chunk list.  Populated by
1234    /// [`MtmdContext::tokenize`].
1235    ///
1236    /// # Panics
1237    ///
1238    /// Panics if the underlying C allocation fails (OOM).
1239    #[must_use]
1240    pub fn new() -> Self {
1241        let ptr = unsafe { sys::mtmd_input_chunks_init() };
1242        let ptr = NonNull::new(ptr).expect("mtmd_input_chunks_init returned null");
1243        Self { ptr }
1244    }
1245
1246    /// Number of chunks in this list.
1247    #[must_use]
1248    pub fn len(&self) -> usize {
1249        unsafe { sys::mtmd_input_chunks_size(self.ptr.as_ptr()) }
1250    }
1251
1252    /// Returns `true` if there are no chunks.
1253    #[must_use]
1254    pub fn is_empty(&self) -> bool {
1255        self.len() == 0
1256    }
1257
1258    /// Get the `idx`-th chunk.  Returns `None` if `idx >= len()`.
1259    #[must_use]
1260    pub fn get(&self, idx: usize) -> Option<MtmdInputChunk<'_>> {
1261        if idx >= self.len() {
1262            return None;
1263        }
1264        let ptr = unsafe { sys::mtmd_input_chunks_get(self.ptr.as_ptr(), idx) };
1265        if ptr.is_null() {
1266            return None;
1267        }
1268        Some(MtmdInputChunk {
1269            ptr,
1270            _marker: std::marker::PhantomData,
1271        })
1272    }
1273
1274    /// Iterate over all chunks.
1275    pub fn iter(&self) -> impl Iterator<Item = MtmdInputChunk<'_>> {
1276        (0..self.len()).filter_map(|i| self.get(i))
1277    }
1278
1279    /// Total number of tokens across all chunks.
1280    ///
1281    /// Equivalent to `mtmd_helper_get_n_tokens`.
1282    #[must_use]
1283    pub fn n_tokens(&self) -> usize {
1284        unsafe { sys::mtmd_helper_get_n_tokens(self.ptr.as_ptr()) }
1285    }
1286
1287    /// Total number of *positions* across all chunks (used for KV-cache
1288    /// tracking with M-RoPE models where positions ≠ tokens).
1289    ///
1290    /// Equivalent to `mtmd_helper_get_n_pos`.
1291    #[must_use]
1292    pub fn n_pos(&self) -> i32 {
1293        unsafe { sys::mtmd_helper_get_n_pos(self.ptr.as_ptr()) }
1294    }
1295}
1296
1297impl Default for MtmdInputChunks {
1298    fn default() -> Self {
1299        Self::new()
1300    }
1301}
1302
1303// ─────────────────────────────────────────────────────────────────────────────
1304// MtmdInputChunkType
1305// ─────────────────────────────────────────────────────────────────────────────
1306
1307/// The type of an [`MtmdInputChunk`].
1308#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1309pub enum MtmdInputChunkType {
1310    /// Plain text tokens.
1311    Text,
1312    /// Image tokens (embeddings produced by the vision encoder).
1313    Image,
1314    /// Audio tokens (embeddings produced by the audio encoder).
1315    Audio,
1316}
1317
1318impl From<sys::mtmd_input_chunk_type> for MtmdInputChunkType {
1319    fn from(v: sys::mtmd_input_chunk_type) -> Self {
1320        // mtmd_input_chunk_type is a plain C `typedef unsigned int`.
1321        // The variants are exported as free-standing constants.
1322        if v == sys::MTMD_INPUT_CHUNK_TYPE_IMAGE {
1323            Self::Image
1324        } else if v == sys::MTMD_INPUT_CHUNK_TYPE_AUDIO {
1325            Self::Audio
1326        } else {
1327            Self::Text
1328        }
1329    }
1330}
1331
1332// ─────────────────────────────────────────────────────────────────────────────
1333// MtmdInputChunk
1334// ─────────────────────────────────────────────────────────────────────────────
1335
1336/// A single tokenized input chunk (text, image, or audio).
1337///
1338/// Instances are borrowed from an [`MtmdInputChunks`] list and live as long
1339/// as that list.
1340#[derive(Debug)]
1341pub struct MtmdInputChunk<'chunks> {
1342    ptr: *const sys::mtmd_input_chunk,
1343    _marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
1344}
1345
1346impl<'chunks> MtmdInputChunk<'chunks> {
1347    /// The type of this chunk.
1348    #[must_use]
1349    pub fn chunk_type(&self) -> MtmdInputChunkType {
1350        let t = unsafe { sys::mtmd_input_chunk_get_type(self.ptr) };
1351        MtmdInputChunkType::from(t)
1352    }
1353
1354    /// Total number of tokens in this chunk.
1355    #[must_use]
1356    pub fn n_tokens(&self) -> usize {
1357        unsafe { sys::mtmd_input_chunk_get_n_tokens(self.ptr) }
1358    }
1359
1360    /// Number of temporal positions (equals `n_tokens` for non-M-RoPE models).
1361    #[must_use]
1362    pub fn n_pos(&self) -> i32 {
1363        unsafe { sys::mtmd_input_chunk_get_n_pos(self.ptr) }
1364    }
1365
1366    /// Return the raw llama token IDs for a **text** chunk.
1367    ///
1368    /// Returns `None` if this chunk is not a text chunk.
1369    #[must_use]
1370    pub fn text_tokens(&self) -> Option<&[i32]> {
1371        if self.chunk_type() != MtmdInputChunkType::Text {
1372            return None;
1373        }
1374        let mut n: usize = 0;
1375        let ptr = unsafe { sys::mtmd_input_chunk_get_tokens_text(self.ptr, &raw mut n) };
1376        if ptr.is_null() || n == 0 {
1377            return Some(&[]);
1378        }
1379        Some(unsafe { slice::from_raw_parts(ptr, n) })
1380    }
1381
1382    /// Return the image token metadata for an **image** or **audio** chunk.
1383    ///
1384    /// Returns `None` for text chunks.
1385    #[must_use]
1386    pub fn image_tokens(&self) -> Option<MtmdImageTokens<'chunks>> {
1387        match self.chunk_type() {
1388            MtmdInputChunkType::Image | MtmdInputChunkType::Audio => {}
1389            MtmdInputChunkType::Text => return None,
1390        }
1391        let ptr = unsafe { sys::mtmd_input_chunk_get_tokens_image(self.ptr) };
1392        if ptr.is_null() {
1393            return None;
1394        }
1395        Some(MtmdImageTokens {
1396            ptr,
1397            _marker: std::marker::PhantomData,
1398        })
1399    }
1400
1401    /// Optional ID attached to this chunk (used for KV cache tracking).
1402    #[must_use]
1403    pub fn id(&self) -> Option<&str> {
1404        let ptr = unsafe { sys::mtmd_input_chunk_get_id(self.ptr) };
1405        if ptr.is_null() {
1406            return None;
1407        }
1408        unsafe { CStr::from_ptr(ptr) }.to_str().ok()
1409    }
1410
1411    /// Returns the raw `*const mtmd_input_chunk` pointer.
1412    ///
1413    /// # Safety
1414    ///
1415    /// The returned pointer is valid for the lifetime of the parent
1416    /// `MtmdInputChunks`.
1417    #[must_use]
1418    pub fn as_ptr(&self) -> *const sys::mtmd_input_chunk {
1419        self.ptr
1420    }
1421}
1422
1423// ─────────────────────────────────────────────────────────────────────────────
1424// MtmdDecoderPos
1425// ─────────────────────────────────────────────────────────────────────────────
1426
1427/// Per-token position used by M-RoPE decoder attention.
1428///
1429/// `t` is the temporal axis, `x`/`y` the spatial axes. `z` is reserved for
1430/// future use. Values are *relative* to a base `pos_0` provided when the
1431/// position is computed.
1432#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
1433#[repr(C)]
1434pub struct MtmdDecoderPos {
1435    /// Temporal index.
1436    pub t: u32,
1437    /// Spatial X.
1438    pub x: u32,
1439    /// Spatial Y.
1440    pub y: u32,
1441    /// Reserved.
1442    pub z: u32,
1443}
1444
1445// ─────────────────────────────────────────────────────────────────────────────
1446// MtmdImageTokens
1447// ─────────────────────────────────────────────────────────────────────────────
1448
1449/// Image/audio token metadata attached to a non-text [`MtmdInputChunk`].
1450#[derive(Debug)]
1451pub struct MtmdImageTokens<'chunks> {
1452    ptr: *const sys::mtmd_image_tokens,
1453    _marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
1454}
1455
1456impl MtmdImageTokens<'_> {
1457    /// Total number of embedding tokens.
1458    #[must_use]
1459    pub fn n_tokens(&self) -> usize {
1460        unsafe { sys::mtmd_image_tokens_get_n_tokens(self.ptr) }
1461    }
1462
1463    /// Width of the token grid.
1464    #[must_use]
1465    pub fn nx(&self) -> usize {
1466        unsafe { sys::mtmd_image_tokens_get_nx(self.ptr) }
1467    }
1468
1469    /// Height of the token grid.
1470    #[must_use]
1471    pub fn ny(&self) -> usize {
1472        unsafe { sys::mtmd_image_tokens_get_ny(self.ptr) }
1473    }
1474
1475    /// Number of temporal positions (M-RoPE variant; equals `n_tokens` otherwise).
1476    #[must_use]
1477    pub fn n_pos(&self) -> i32 {
1478        unsafe { sys::mtmd_image_tokens_get_n_pos(self.ptr) }
1479    }
1480
1481    /// Optional ID for KV cache tracking.
1482    #[must_use]
1483    pub fn id(&self) -> Option<&str> {
1484        let ptr = unsafe { sys::mtmd_image_tokens_get_id(self.ptr) };
1485        if ptr.is_null() {
1486            return None;
1487        }
1488        unsafe { CStr::from_ptr(ptr) }.to_str().ok()
1489    }
1490
1491    /// Compute the per-token decoder positions used by M-RoPE models.
1492    ///
1493    /// Returns a vector of length [`n_tokens`](Self::n_tokens). Each entry
1494    /// is relative to `pos_0`; for non-M-RoPE models this typically reduces
1495    /// to `(0, i, 0, 0)` for the i-th token.
1496    ///
1497    /// Wraps `mtmd_helper_image_get_decoder_pos`.
1498    #[must_use]
1499    pub fn decoder_positions(&self, pos_0: i32) -> Vec<MtmdDecoderPos> {
1500        let n = self.n_tokens();
1501        let mut out = vec![MtmdDecoderPos::default(); n];
1502        if n == 0 {
1503            return out;
1504        }
1505        unsafe {
1506            sys::mtmd_helper_image_get_decoder_pos(
1507                self.ptr,
1508                pos_0,
1509                out.as_mut_ptr().cast::<sys::mtmd_decoder_pos>(),
1510            );
1511        }
1512        out
1513    }
1514}
1515
1516// ─────────────────────────────────────────────────────────────────────────────
1517// LlamaContext extension
1518// ─────────────────────────────────────────────────────────────────────────────
1519
1520use crate::context::LlamaContext;
1521
1522impl LlamaContext<'_> {
1523    /// Expose the raw `llama_context` pointer for use with mtmd helpers.
1524    ///
1525    /// # Safety
1526    ///
1527    /// The pointer is valid for the lifetime of this `LlamaContext` and must
1528    /// not be freed by the caller.
1529    #[must_use]
1530    pub fn as_ptr(&self) -> *mut sys::llama_context {
1531        self.context.as_ptr()
1532    }
1533}
1534
1535#[cfg(test)]
1536mod tests {
1537    use super::*;
1538
1539    #[test]
1540    fn decoder_pos_layout_matches_sys() {
1541        // The Rust MtmdDecoderPos is cast to sys::mtmd_decoder_pos at the
1542        // FFI boundary in `MtmdImageTokens::decoder_positions`. Verify the
1543        // assumption.
1544        assert_eq!(
1545            std::mem::size_of::<MtmdDecoderPos>(),
1546            std::mem::size_of::<sys::mtmd_decoder_pos>(),
1547        );
1548        assert_eq!(
1549            std::mem::align_of::<MtmdDecoderPos>(),
1550            std::mem::align_of::<sys::mtmd_decoder_pos>(),
1551        );
1552        assert_eq!(std::mem::offset_of!(MtmdDecoderPos, t), 0);
1553        assert_eq!(std::mem::offset_of!(MtmdDecoderPos, x), 4);
1554        assert_eq!(std::mem::offset_of!(MtmdDecoderPos, y), 8);
1555        assert_eq!(std::mem::offset_of!(MtmdDecoderPos, z), 12);
1556    }
1557}