Skip to main content

llama_cpp_bindings/mtmd/
mtmd_context.rs

1use std::ffi::CString;
2use std::ptr::NonNull;
3
4use crate::model::LlamaModel;
5
6use super::mtmd_bitmap::MtmdBitmap;
7use super::mtmd_context_params::MtmdContextParams;
8use super::mtmd_error::{MtmdEncodeError, MtmdInitError, MtmdTokenizeError};
9use super::mtmd_input_chunk::MtmdInputChunk;
10use super::mtmd_input_chunks::MtmdInputChunks;
11use super::mtmd_input_text::MtmdInputText;
12
13const fn tokenize_result_to_error(result: i32) -> MtmdTokenizeError {
14    match result {
15        1 => MtmdTokenizeError::BitmapCountMismatch,
16        2 => MtmdTokenizeError::ImagePreprocessingError,
17        _ => MtmdTokenizeError::UnknownError(result),
18    }
19}
20
21const fn check_encode_result(result: i32) -> Result<(), MtmdEncodeError> {
22    if result == 0 {
23        Ok(())
24    } else {
25        Err(MtmdEncodeError::EncodeFailure(result))
26    }
27}
28
/// Safe wrapper around `mtmd_context`.
///
/// This represents an initialized multimodal context that can process
/// text, images, and audio through llama.cpp's multimodal interface.
///
/// The underlying context is released with `mtmd_free` when this value is dropped.
#[derive(Debug)]
pub struct MtmdContext {
    /// Raw pointer to the underlying `mtmd_context`.
    ///
    /// NOTE(review): this field is `pub`, so safe code can build an `MtmdContext`
    /// around an arbitrary (or already-owned) pointer that `Drop` will later pass to
    /// `mtmd_free`. Consider narrowing its visibility in a future breaking release.
    pub context: NonNull<llama_cpp_bindings_sys::mtmd_context>,
}

// SAFETY: the wrapper is intended to be the sole owner of its `mtmd_context`,
// which is freed once in `Drop`, so moving it to another thread moves that ownership.
// NOTE(review): assumes mtmd does not tie a context to the thread that created it —
// confirm against upstream llama.cpp.
unsafe impl Send for MtmdContext {}
// SAFETY: NOTE(review): methods such as `encode_chunk` take `&self` but call into C
// code that may mutate context state. This impl is only sound if mtmd tolerates
// concurrent calls on one context — confirm, or require callers to synchronize.
unsafe impl Sync for MtmdContext {}
41
42impl MtmdContext {
43    /// Initialize MTMD context from a multimodal projection file.
44    ///
45    /// # Errors
46    ///
47    /// This function will return an error if:
48    /// - The path cannot be converted to a C string
49    /// - The underlying C function returns null (indicating initialization failure)
50    pub fn init_from_file(
51        mmproj_path: &str,
52        text_model: &LlamaModel,
53        params: &MtmdContextParams,
54    ) -> Result<Self, MtmdInitError> {
55        let path_cstr = CString::new(mmproj_path)?;
56        let ctx_params = llama_cpp_bindings_sys::mtmd_context_params::from(params);
57
58        let context = unsafe {
59            llama_cpp_bindings_sys::mtmd_init_from_file(
60                path_cstr.as_ptr(),
61                text_model.model.as_ptr(),
62                ctx_params,
63            )
64        };
65
66        let context = NonNull::new(context).ok_or(MtmdInitError::NullResult)?;
67
68        Ok(Self { context })
69    }
70
71    /// Check whether non-causal attention mask is needed before `llama_decode`
72    /// for the given input chunk.
73    #[must_use]
74    pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk) -> bool {
75        unsafe {
76            llama_cpp_bindings_sys::mtmd_decode_use_non_causal(
77                self.context.as_ptr(),
78                chunk.chunk.as_ptr(),
79            )
80        }
81    }
82
83    /// Check whether the current model uses M-RoPE for `llama_decode`.
84    #[must_use]
85    pub fn decode_use_mrope(&self) -> bool {
86        unsafe { llama_cpp_bindings_sys::mtmd_decode_use_mrope(self.context.as_ptr()) }
87    }
88
89    /// Check whether the current model supports vision input.
90    #[must_use]
91    pub fn support_vision(&self) -> bool {
92        unsafe { llama_cpp_bindings_sys::mtmd_support_vision(self.context.as_ptr()) }
93    }
94
95    /// Check whether the current model supports audio input.
96    #[must_use]
97    pub fn support_audio(&self) -> bool {
98        unsafe { llama_cpp_bindings_sys::mtmd_support_audio(self.context.as_ptr()) }
99    }
100
101    /// Get audio sample rate in Hz (e.g., 16000 for Whisper).
102    /// Returns None if audio is not supported.
103    #[must_use]
104    pub fn get_audio_sample_rate(&self) -> Option<u32> {
105        let rate =
106            unsafe { llama_cpp_bindings_sys::mtmd_get_audio_sample_rate(self.context.as_ptr()) };
107        (rate > 0).then_some(rate.unsigned_abs())
108    }
109
110    /// Tokenize input text and bitmaps into chunks.
111    ///
112    /// The input text must contain media markers (default: `<__media__>`) that will be
113    /// replaced with the corresponding bitmap data from the `bitmaps` array.
114    /// The number of bitmaps must equal the number of markers in the text.
115    ///
116    /// # Errors
117    ///
118    /// * `BitmapCountMismatch` - Number of bitmaps doesn't match number of markers
119    /// * `ImagePreprocessingError` - Error occurred during image preprocessing
120    /// * `UnknownError` - Other tokenization error occurred
121    ///
122    /// # Example
123    ///
124    /// ```no_run
125    /// # use llama_cpp_bindings::mtmd::*;
126    /// # fn example(ctx: &MtmdContext, bitmap: &MtmdBitmap) -> Result<(), Box<dyn std::error::Error>> {
127    /// let text = MtmdInputText {
128    ///     text: "Here is an image: <__media__>\nDescribe it.".to_string(),
129    ///     add_special: true,
130    ///     parse_special: true,
131    /// };
132    /// let chunks = ctx.tokenize(text, &[bitmap])?;
133    /// # Ok(())
134    /// # }
135    /// ```
136    pub fn tokenize(
137        &self,
138        text: MtmdInputText,
139        bitmaps: &[&MtmdBitmap],
140    ) -> Result<MtmdInputChunks, MtmdTokenizeError> {
141        let chunks = MtmdInputChunks::new()?;
142        let text_cstring = CString::new(text.text)?;
143        let input_text = llama_cpp_bindings_sys::mtmd_input_text {
144            text: text_cstring.as_ptr(),
145            add_special: text.add_special,
146            parse_special: text.parse_special,
147        };
148
149        let bitmap_ptrs: Vec<*const llama_cpp_bindings_sys::mtmd_bitmap> = bitmaps
150            .iter()
151            .map(|bitmap| bitmap.bitmap.as_ptr().cast_const())
152            .collect();
153
154        let result = unsafe {
155            llama_cpp_bindings_sys::mtmd_tokenize(
156                self.context.as_ptr(),
157                chunks.chunks.as_ptr(),
158                &raw const input_text,
159                bitmap_ptrs.as_ptr().cast_mut(),
160                bitmaps.len(),
161            )
162        };
163
164        if result == 0 {
165            Ok(chunks)
166        } else {
167            Err(tokenize_result_to_error(result))
168        }
169    }
170
171    /// Encode a chunk for image/audio processing.
172    ///
173    /// # Errors
174    ///
175    /// Returns `MtmdEncodeError::EncodeFailure` if encoding fails.
176    pub fn encode_chunk(&self, chunk: &MtmdInputChunk) -> Result<(), MtmdEncodeError> {
177        let result = unsafe {
178            llama_cpp_bindings_sys::mtmd_encode_chunk(self.context.as_ptr(), chunk.chunk.as_ptr())
179        };
180
181        check_encode_result(result)
182    }
183}
184
185impl Drop for MtmdContext {
186    fn drop(&mut self) {
187        unsafe { llama_cpp_bindings_sys::mtmd_free(self.context.as_ptr()) }
188    }
189}
190
#[cfg(test)]
mod unit_tests {
    use super::{check_encode_result, tokenize_result_to_error};

    /// Asserts that the `Display` output of `error` contains `needle`.
    fn assert_message_contains(error: &impl std::fmt::Display, needle: &str) {
        let message = error.to_string();
        assert!(
            message.contains(needle),
            "expected {message:?} to contain {needle:?}"
        );
    }

    #[test]
    fn tokenize_result_bitmap_count_mismatch() {
        assert_message_contains(&tokenize_result_to_error(1), "does not match");
    }

    #[test]
    fn tokenize_result_image_preprocessing_error() {
        assert_message_contains(&tokenize_result_to_error(2), "Image preprocessing");
    }

    #[test]
    fn tokenize_result_unknown_error() {
        assert_message_contains(&tokenize_result_to_error(42), "Unknown error: 42");
    }

    #[test]
    fn check_encode_result_ok_for_zero() {
        assert!(matches!(check_encode_result(0), Ok(())));
    }

    #[test]
    fn check_encode_result_error_for_nonzero() {
        let Err(error) = check_encode_result(5) else {
            panic!("expected an error for a non-zero encode result");
        };
        assert_message_contains(&error, "Encode failed with code: 5");
    }
}