Skip to main content

llama_cpp_bindings/mtmd/
mtmd_context.rs

1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ptr::NonNull;
4
5use crate::ffi_error_reader::read_and_free_cpp_error;
6use crate::model::LlamaModel;
7
8use super::mtmd_bitmap::MtmdBitmap;
9use super::mtmd_context_params::MtmdContextParams;
10use super::mtmd_encode_error::MtmdEncodeError;
11use super::mtmd_init_error::MtmdInitError;
12use super::mtmd_input_chunk::MtmdInputChunk;
13use super::mtmd_input_chunks::MtmdInputChunks;
14use super::mtmd_input_text::MtmdInputText;
15use super::mtmd_tokenize_error::MtmdTokenizeError;
16
17fn map_tokenize_status(
18    status: llama_cpp_bindings_sys::llama_rs_mtmd_tokenize_status,
19    undocumented_return_code: i32,
20    out_error: *mut c_char,
21) -> Result<(), MtmdTokenizeError> {
22    match status {
23        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_OK => Ok(()),
24        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_BITMAP_COUNT_DOES_NOT_MATCH_MARKER_COUNT => {
25            Err(MtmdTokenizeError::BitmapCountDoesNotMatchMarkerCount)
26        }
27        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_IMAGE_PREPROCESSING_ERROR => {
28            Err(MtmdTokenizeError::MediaPreprocessingFailed)
29        }
30        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_RETURNED_UNDOCUMENTED_NONZERO_CODE => {
31            Err(MtmdTokenizeError::UnknownStatus {
32                code: undocumented_return_code,
33            })
34        }
35        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_ERROR_STRING_ALLOCATION_FAILED => {
36            Err(MtmdTokenizeError::NotEnoughMemory)
37        }
38        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_THREW_CXX_EXCEPTION => {
39            let message = unsafe { read_and_free_cpp_error(out_error) };
40            Err(MtmdTokenizeError::Reported { message })
41        }
42        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_NULL_BITMAPS_ARG_WHEN_NUM_BITMAPS_NONZERO => unreachable!("llama_rs_mtmd_tokenize NULL_BITMAPS_ARG: Rust always passes a non-null bitmaps pointer when count > 0"),
43        other => unreachable!("llama_rs_mtmd_tokenize returned unrecognized status: {other}"),
44    }
45}
46
47fn map_encode_chunk_status(
48    status: llama_cpp_bindings_sys::llama_rs_mtmd_encode_chunk_status,
49    vendored_return_code: i32,
50    out_error: *mut c_char,
51) -> Result<(), MtmdEncodeError> {
52    match status {
53        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_OK => Ok(()),
54        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_VENDORED_RETURNED_NONZERO_CODE => {
55            Err(MtmdEncodeError::EncodingFailed {
56                code: vendored_return_code,
57            })
58        }
59        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_ERROR_STRING_ALLOCATION_FAILED => {
60            Err(MtmdEncodeError::NotEnoughMemory)
61        }
62        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_VENDORED_THREW_CXX_EXCEPTION => {
63            let message = unsafe { read_and_free_cpp_error(out_error) };
64            Err(MtmdEncodeError::Reported { message })
65        }
66        other => unreachable!("llama_rs_mtmd_encode_chunk returned unrecognized status: {other}"),
67    }
68}
69
70/// Safe wrapper around `mtmd_context`.
71///
72/// This represents an initialized multimodal context that can process
73/// text, images, and audio through llama.cpp's multimodal interface.
74#[derive(Debug)]
75pub struct MtmdContext {
76    /// Raw pointer to the underlying `mtmd_context`.
77    pub context: NonNull<llama_cpp_bindings_sys::mtmd_context>,
78}
79
80unsafe impl Send for MtmdContext {}
81unsafe impl Sync for MtmdContext {}
82
83impl MtmdContext {
84    /// Initialize MTMD context from a multimodal projection file.
85    ///
86    /// # Errors
87    ///
88    /// Returns an [`MtmdInitError`] variant matching the wrapper's status code.
89    pub fn init_from_file(
90        mmproj_path: &str,
91        text_model: &LlamaModel,
92        params: &MtmdContextParams,
93    ) -> Result<Self, MtmdInitError> {
94        let path_cstr = CString::new(mmproj_path)?;
95        let ctx_params = llama_cpp_bindings_sys::mtmd_context_params::from(params);
96
97        let mut out_ctx: *mut llama_cpp_bindings_sys::mtmd_context = std::ptr::null_mut();
98        let mut out_error: *mut c_char = std::ptr::null_mut();
99
100        let status = unsafe {
101            llama_cpp_bindings_sys::llama_rs_mtmd_init_from_file(
102                path_cstr.as_ptr(),
103                text_model.model.as_ptr(),
104                ctx_params,
105                &raw mut out_ctx,
106                &raw mut out_error,
107            )
108        };
109
110        match status {
111            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_OK => {
112                let context = NonNull::new(out_ctx).ok_or_else(|| MtmdInitError::Unloadable {
113                    path: std::path::PathBuf::from(mmproj_path),
114                })?;
115                Ok(Self { context })
116            }
117            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_VENDORED_RETURNED_NULL => {
118                Err(MtmdInitError::Unloadable {
119                    path: std::path::PathBuf::from(mmproj_path),
120                })
121            }
122            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_ERROR_STRING_ALLOCATION_FAILED => {
123                Err(MtmdInitError::NotEnoughMemory)
124            }
125            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_VENDORED_THREW_CXX_EXCEPTION => {
126                let message = unsafe { read_and_free_cpp_error(out_error) };
127                Err(MtmdInitError::Reported { message })
128            }
129            other => {
130                unreachable!("llama_rs_mtmd_init_from_file returned unrecognized status: {other}")
131            }
132        }
133    }
134
135    /// Check whether non-causal attention mask is needed before `llama_decode`
136    /// for the given input chunk.
137    #[must_use]
138    pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk) -> bool {
139        unsafe {
140            llama_cpp_bindings_sys::mtmd_decode_use_non_causal(
141                self.context.as_ptr(),
142                chunk.chunk.as_ptr(),
143            )
144        }
145    }
146
147    /// Check whether the current model uses M-RoPE for `llama_decode`.
148    #[must_use]
149    pub fn decode_use_mrope(&self) -> bool {
150        unsafe { llama_cpp_bindings_sys::mtmd_decode_use_mrope(self.context.as_ptr()) }
151    }
152
153    /// Check whether the current model supports vision input.
154    #[must_use]
155    pub fn support_vision(&self) -> bool {
156        unsafe { llama_cpp_bindings_sys::mtmd_support_vision(self.context.as_ptr()) }
157    }
158
159    /// Check whether the current model supports audio input.
160    #[must_use]
161    pub fn support_audio(&self) -> bool {
162        unsafe { llama_cpp_bindings_sys::mtmd_support_audio(self.context.as_ptr()) }
163    }
164
165    /// Get audio sample rate in Hz (e.g., 16000 for Whisper).
166    /// Returns None if audio is not supported.
167    #[must_use]
168    pub fn get_audio_sample_rate(&self) -> Option<u32> {
169        let rate =
170            unsafe { llama_cpp_bindings_sys::mtmd_get_audio_sample_rate(self.context.as_ptr()) };
171        (rate > 0).then_some(rate.unsigned_abs())
172    }
173
174    /// Tokenize input text and bitmaps into chunks.
175    ///
176    /// The input text must contain media markers (default: `<__media__>`) that will be
177    /// replaced with the corresponding bitmap data from the `bitmaps` array.
178    /// The number of bitmaps must equal the number of markers in the text.
179    ///
180    /// # Errors
181    ///
182    /// Returns an [`MtmdTokenizeError`] variant matching the wrapper's status code.
183    pub fn tokenize(
184        &self,
185        text: MtmdInputText,
186        bitmaps: &[&MtmdBitmap],
187    ) -> Result<MtmdInputChunks, MtmdTokenizeError> {
188        let chunks = MtmdInputChunks::new()?;
189        let text_cstring = CString::new(text.text)?;
190        let input_text = llama_cpp_bindings_sys::mtmd_input_text {
191            text: text_cstring.as_ptr(),
192            add_special: text.add_special,
193            parse_special: text.parse_special,
194        };
195
196        let bitmap_ptrs: Vec<*const llama_cpp_bindings_sys::mtmd_bitmap> = bitmaps
197            .iter()
198            .map(|bitmap| bitmap.bitmap.as_ptr().cast_const())
199            .collect();
200
201        let mut out_undocumented_return_code: i32 = 0;
202        let mut out_error: *mut c_char = std::ptr::null_mut();
203
204        let status = unsafe {
205            llama_cpp_bindings_sys::llama_rs_mtmd_tokenize(
206                self.context.as_ptr(),
207                chunks.chunks.as_ptr(),
208                &raw const input_text,
209                bitmap_ptrs.as_ptr().cast_mut(),
210                bitmaps.len(),
211                &raw mut out_undocumented_return_code,
212                &raw mut out_error,
213            )
214        };
215
216        map_tokenize_status(status, out_undocumented_return_code, out_error)?;
217        Ok(chunks)
218    }
219
220    /// Encode a chunk for image/audio processing.
221    ///
222    /// # Errors
223    ///
224    /// Returns an [`MtmdEncodeError`] variant matching the wrapper's status code.
225    pub fn encode_chunk(&self, chunk: &MtmdInputChunk) -> Result<(), MtmdEncodeError> {
226        let mut out_vendored_return_code: i32 = 0;
227        let mut out_error: *mut c_char = std::ptr::null_mut();
228
229        let status = unsafe {
230            llama_cpp_bindings_sys::llama_rs_mtmd_encode_chunk(
231                self.context.as_ptr(),
232                chunk.chunk.as_ptr(),
233                &raw mut out_vendored_return_code,
234                &raw mut out_error,
235            )
236        };
237
238        map_encode_chunk_status(status, out_vendored_return_code, out_error)
239    }
240}
241
242impl Drop for MtmdContext {
243    fn drop(&mut self) {
244        unsafe { llama_cpp_bindings_sys::mtmd_free(self.context.as_ptr()) }
245    }
246}
247
#[cfg(test)]
mod unit_tests {
    // Unit tests for the status-code mapping helpers. These tests exercise only
    // the pure mapping logic; no native context is created.
    use super::map_encode_chunk_status;
    use super::map_tokenize_status;
    use crate::mtmd::mtmd_encode_error::MtmdEncodeError;
    use crate::mtmd::mtmd_tokenize_error::MtmdTokenizeError;

    #[test]
    fn tokenize_status_maps_bitmap_count_mismatch() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_BITMAP_COUNT_DOES_NOT_MATCH_MARKER_COUNT,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdTokenizeError::BitmapCountDoesNotMatchMarkerCount)
        ));
    }

    #[test]
    fn tokenize_status_maps_media_preprocessing_failed() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_IMAGE_PREPROCESSING_ERROR,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdTokenizeError::MediaPreprocessingFailed)
        ));
    }

    #[test]
    fn tokenize_status_maps_unknown_status_with_value() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_RETURNED_UNDOCUMENTED_NONZERO_CODE,
            42,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdTokenizeError::UnknownStatus { code: 42 })
        ));
    }

    #[test]
    fn tokenize_status_maps_ok_to_unit() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_OK,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Ok(())));
    }

    #[test]
    fn tokenize_status_maps_error_string_allocation_failure_to_not_enough_memory() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_ERROR_STRING_ALLOCATION_FAILED,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Err(MtmdTokenizeError::NotEnoughMemory)));
    }

    #[test]
    #[should_panic(expected = "NULL_BITMAPS_ARG")]
    fn tokenize_status_null_bitmaps_arg_is_treated_as_unreachable() {
        // The safe wrapper never passes a null bitmaps pointer with a non-zero
        // count, so this status indicates a bug and must panic loudly.
        let _ = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_NULL_BITMAPS_ARG_WHEN_NUM_BITMAPS_NONZERO,
            0,
            std::ptr::null_mut(),
        );
    }

    #[test]
    fn encode_chunk_status_maps_ok_to_unit() {
        let result = map_encode_chunk_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_OK,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Ok(())));
    }

    #[test]
    fn encode_chunk_status_maps_encoding_failed_with_code() {
        let result = map_encode_chunk_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_VENDORED_RETURNED_NONZERO_CODE,
            5,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdEncodeError::EncodingFailed { code: 5 })
        ));
    }

    #[test]
    fn encode_chunk_status_maps_error_string_allocation_failure_to_not_enough_memory() {
        let result = map_encode_chunk_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_ERROR_STRING_ALLOCATION_FAILED,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Err(MtmdEncodeError::NotEnoughMemory)));
    }
}