Skip to main content

llama_cpp_bindings/mtmd/
mtmd_context.rs

1use std::ffi::CString;
2use std::ffi::c_char;
3use std::ptr::NonNull;
4
5use crate::ffi_error_reader::read_and_free_cpp_error;
6use crate::model::LlamaModel;
7
8use super::mtmd_bitmap::MtmdBitmap;
9use super::mtmd_context_params::MtmdContextParams;
10use super::mtmd_encode_error::MtmdEncodeError;
11use super::mtmd_init_error::MtmdInitError;
12use super::mtmd_input_chunk::MtmdInputChunk;
13use super::mtmd_input_chunks::MtmdInputChunks;
14use super::mtmd_input_text::MtmdInputText;
15use super::mtmd_tokenize_error::MtmdTokenizeError;
16
17fn map_tokenize_status(
18    status: llama_cpp_bindings_sys::llama_rs_mtmd_tokenize_status,
19    undocumented_return_code: i32,
20    out_error: *mut c_char,
21) -> Result<(), MtmdTokenizeError> {
22    match status {
23        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_OK => Ok(()),
24        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_BITMAP_COUNT_DOES_NOT_MATCH_MARKER_COUNT => {
25            Err(MtmdTokenizeError::BitmapCountDoesNotMatchMarkerCount)
26        }
27        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_IMAGE_PREPROCESSING_ERROR => {
28            Err(MtmdTokenizeError::MediaPreprocessingFailed)
29        }
30        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_RETURNED_UNDOCUMENTED_NONZERO_CODE => {
31            Err(MtmdTokenizeError::UnknownStatus {
32                code: undocumented_return_code,
33            })
34        }
35        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_ERROR_STRING_ALLOCATION_FAILED => {
36            Err(MtmdTokenizeError::NotEnoughMemory)
37        }
38        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_THREW_CXX_EXCEPTION => {
39            let message = unsafe { read_and_free_cpp_error(out_error) };
40            Err(MtmdTokenizeError::Reported { message })
41        }
42        llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_NULL_BITMAPS_ARG_WHEN_NUM_BITMAPS_NONZERO => unreachable!("llama_rs_mtmd_tokenize NULL_BITMAPS_ARG: Rust always passes a non-null bitmaps pointer when count > 0"),
43        other => unreachable!("llama_rs_mtmd_tokenize returned unrecognized status: {other}"),
44    }
45}
46
47fn map_encode_chunk_status(
48    status: llama_cpp_bindings_sys::llama_rs_mtmd_encode_chunk_status,
49    vendored_return_code: i32,
50    out_error: *mut c_char,
51) -> Result<(), MtmdEncodeError> {
52    match status {
53        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_OK => Ok(()),
54        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_VENDORED_RETURNED_NONZERO_CODE => {
55            Err(MtmdEncodeError::EncodingFailed {
56                code: vendored_return_code,
57            })
58        }
59        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_ERROR_STRING_ALLOCATION_FAILED => {
60            Err(MtmdEncodeError::NotEnoughMemory)
61        }
62        llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_VENDORED_THREW_CXX_EXCEPTION => {
63            let message = unsafe { read_and_free_cpp_error(out_error) };
64            Err(MtmdEncodeError::Reported { message })
65        }
66        other => unreachable!("llama_rs_mtmd_encode_chunk returned unrecognized status: {other}"),
67    }
68}
69
70#[derive(Debug)]
71pub struct MtmdContext {
72    pub context: NonNull<llama_cpp_bindings_sys::mtmd_context>,
73}
74
75unsafe impl Send for MtmdContext {}
76unsafe impl Sync for MtmdContext {}
77
78impl MtmdContext {
79    /// # Errors
80    ///
81    /// Returns an [`MtmdInitError`] variant matching the wrapper's status code.
82    pub fn init_from_file(
83        mmproj_path: &str,
84        text_model: &LlamaModel,
85        params: &MtmdContextParams,
86    ) -> Result<Self, MtmdInitError> {
87        let path_cstr = CString::new(mmproj_path)?;
88        let ctx_params = llama_cpp_bindings_sys::mtmd_context_params::from(params);
89
90        let mut out_ctx: *mut llama_cpp_bindings_sys::mtmd_context = std::ptr::null_mut();
91        let mut out_error: *mut c_char = std::ptr::null_mut();
92
93        let status = unsafe {
94            llama_cpp_bindings_sys::llama_rs_mtmd_init_from_file(
95                path_cstr.as_ptr(),
96                text_model.model.as_ptr(),
97                ctx_params,
98                &raw mut out_ctx,
99                &raw mut out_error,
100            )
101        };
102
103        match status {
104            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_OK => {
105                let context = NonNull::new(out_ctx).ok_or_else(|| MtmdInitError::Unloadable {
106                    path: std::path::PathBuf::from(mmproj_path),
107                })?;
108                Ok(Self { context })
109            }
110            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_VENDORED_RETURNED_NULL => {
111                Err(MtmdInitError::Unloadable {
112                    path: std::path::PathBuf::from(mmproj_path),
113                })
114            }
115            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_ERROR_STRING_ALLOCATION_FAILED => {
116                Err(MtmdInitError::NotEnoughMemory)
117            }
118            llama_cpp_bindings_sys::LLAMA_RS_MTMD_INIT_FROM_FILE_VENDORED_THREW_CXX_EXCEPTION => {
119                let message = unsafe { read_and_free_cpp_error(out_error) };
120                Err(MtmdInitError::Reported { message })
121            }
122            other => {
123                unreachable!("llama_rs_mtmd_init_from_file returned unrecognized status: {other}")
124            }
125        }
126    }
127
128    #[must_use]
129    pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk) -> bool {
130        unsafe {
131            llama_cpp_bindings_sys::mtmd_decode_use_non_causal(
132                self.context.as_ptr(),
133                chunk.chunk.as_ptr(),
134            )
135        }
136    }
137
138    #[must_use]
139    pub fn decode_use_mrope(&self) -> bool {
140        unsafe { llama_cpp_bindings_sys::mtmd_decode_use_mrope(self.context.as_ptr()) }
141    }
142
143    #[must_use]
144    pub fn support_vision(&self) -> bool {
145        unsafe { llama_cpp_bindings_sys::mtmd_support_vision(self.context.as_ptr()) }
146    }
147
148    #[must_use]
149    pub fn support_audio(&self) -> bool {
150        unsafe { llama_cpp_bindings_sys::mtmd_support_audio(self.context.as_ptr()) }
151    }
152
153    #[must_use]
154    pub fn get_audio_sample_rate(&self) -> Option<u32> {
155        let rate =
156            unsafe { llama_cpp_bindings_sys::mtmd_get_audio_sample_rate(self.context.as_ptr()) };
157        (rate > 0).then_some(rate.unsigned_abs())
158    }
159
160    /// # Errors
161    ///
162    /// Returns an [`MtmdTokenizeError`] variant matching the wrapper's status code.
163    pub fn tokenize(
164        &self,
165        text: MtmdInputText,
166        bitmaps: &[&MtmdBitmap],
167    ) -> Result<MtmdInputChunks, MtmdTokenizeError> {
168        let chunks = MtmdInputChunks::new()?;
169        let text_cstring = CString::new(text.text)?;
170        let input_text = llama_cpp_bindings_sys::mtmd_input_text {
171            text: text_cstring.as_ptr(),
172            add_special: text.add_special,
173            parse_special: text.parse_special,
174        };
175
176        let bitmap_ptrs: Vec<*const llama_cpp_bindings_sys::mtmd_bitmap> = bitmaps
177            .iter()
178            .map(|bitmap| bitmap.bitmap.as_ptr().cast_const())
179            .collect();
180
181        let mut out_undocumented_return_code: i32 = 0;
182        let mut out_error: *mut c_char = std::ptr::null_mut();
183
184        let status = unsafe {
185            llama_cpp_bindings_sys::llama_rs_mtmd_tokenize(
186                self.context.as_ptr(),
187                chunks.chunks.as_ptr(),
188                &raw const input_text,
189                bitmap_ptrs.as_ptr().cast_mut(),
190                bitmaps.len(),
191                &raw mut out_undocumented_return_code,
192                &raw mut out_error,
193            )
194        };
195
196        map_tokenize_status(status, out_undocumented_return_code, out_error)?;
197        Ok(chunks)
198    }
199
200    /// # Errors
201    ///
202    /// Returns an [`MtmdEncodeError`] variant matching the wrapper's status code.
203    pub fn encode_chunk(&self, chunk: &MtmdInputChunk) -> Result<(), MtmdEncodeError> {
204        let mut out_vendored_return_code: i32 = 0;
205        let mut out_error: *mut c_char = std::ptr::null_mut();
206
207        let status = unsafe {
208            llama_cpp_bindings_sys::llama_rs_mtmd_encode_chunk(
209                self.context.as_ptr(),
210                chunk.chunk.as_ptr(),
211                &raw mut out_vendored_return_code,
212                &raw mut out_error,
213            )
214        };
215
216        map_encode_chunk_status(status, out_vendored_return_code, out_error)
217    }
218}
219
220impl Drop for MtmdContext {
221    fn drop(&mut self) {
222        unsafe { llama_cpp_bindings_sys::mtmd_free(self.context.as_ptr()) }
223    }
224}
225
/// Tests for the status-mapping helpers.
///
/// Only statuses whose mapping never reads `out_error` are exercised, so a null
/// pointer can stand in for the error string without touching FFI.
#[cfg(test)]
mod unit_tests {
    use super::map_encode_chunk_status;
    use super::map_tokenize_status;
    use crate::mtmd::mtmd_encode_error::MtmdEncodeError;
    use crate::mtmd::mtmd_tokenize_error::MtmdTokenizeError;

    #[test]
    fn tokenize_status_maps_bitmap_count_mismatch() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_BITMAP_COUNT_DOES_NOT_MATCH_MARKER_COUNT,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdTokenizeError::BitmapCountDoesNotMatchMarkerCount)
        ));
    }

    #[test]
    fn tokenize_status_maps_media_preprocessing_failed() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_REPORTED_IMAGE_PREPROCESSING_ERROR,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdTokenizeError::MediaPreprocessingFailed)
        ));
    }

    #[test]
    fn tokenize_status_maps_unknown_status_with_value() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_VENDORED_RETURNED_UNDOCUMENTED_NONZERO_CODE,
            42,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdTokenizeError::UnknownStatus { code: 42 })
        ));
    }

    // The allocation-failure arm never reads `out_error`, so null is safe here.
    #[test]
    fn tokenize_status_maps_error_string_allocation_failed_to_not_enough_memory() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_ERROR_STRING_ALLOCATION_FAILED,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Err(MtmdTokenizeError::NotEnoughMemory)));
    }

    #[test]
    fn tokenize_status_maps_ok_to_unit() {
        let result = map_tokenize_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_TOKENIZE_OK,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Ok(())));
    }

    #[test]
    fn encode_chunk_status_maps_ok_to_unit() {
        let result = map_encode_chunk_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_OK,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Ok(())));
    }

    #[test]
    fn encode_chunk_status_maps_encoding_failed_with_code() {
        let result = map_encode_chunk_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_VENDORED_RETURNED_NONZERO_CODE,
            5,
            std::ptr::null_mut(),
        );

        assert!(matches!(
            result,
            Err(MtmdEncodeError::EncodingFailed { code: 5 })
        ));
    }

    // The allocation-failure arm never reads `out_error`, so null is safe here.
    #[test]
    fn encode_chunk_status_maps_error_string_allocation_failed_to_not_enough_memory() {
        let result = map_encode_chunk_status(
            llama_cpp_bindings_sys::LLAMA_RS_MTMD_ENCODE_CHUNK_ERROR_STRING_ALLOCATION_FAILED,
            0,
            std::ptr::null_mut(),
        );

        assert!(matches!(result, Err(MtmdEncodeError::NotEnoughMemory)));
    }
}