llama_cpp_bindings/mtmd/
mtmd_context.rs1use std::ffi::CString;
2use std::ptr::NonNull;
3
4use crate::model::LlamaModel;
5
6use super::mtmd_bitmap::MtmdBitmap;
7use super::mtmd_context_params::MtmdContextParams;
8use super::mtmd_error::{MtmdEncodeError, MtmdInitError, MtmdTokenizeError};
9use super::mtmd_input_chunk::MtmdInputChunk;
10use super::mtmd_input_chunks::MtmdInputChunks;
11use super::mtmd_input_text::MtmdInputText;
12
13const fn tokenize_result_to_error(result: i32) -> MtmdTokenizeError {
14 match result {
15 1 => MtmdTokenizeError::BitmapCountMismatch,
16 2 => MtmdTokenizeError::ImagePreprocessingError,
17 _ => MtmdTokenizeError::UnknownError(result),
18 }
19}
20
21const fn check_encode_result(result: i32) -> Result<(), MtmdEncodeError> {
22 if result == 0 {
23 Ok(())
24 } else {
25 Err(MtmdEncodeError::EncodeFailure(result))
26 }
27}
28
/// Owning handle to a native `mtmd_context`.
///
/// Instances are created by [`MtmdContext::init_from_file`]. The native
/// context is released with `mtmd_free` when the handle is dropped.
#[derive(Debug)]
pub struct MtmdContext {
    /// Non-null pointer to the underlying native context.
    ///
    /// NOTE(review): this field is `pub`, so code outside this module can
    /// build an `MtmdContext` around an arbitrary pointer (or copy this one,
    /// since `NonNull` is `Copy`). `Drop` will then pass that pointer to
    /// `mtmd_free`. Confirm the exposure is intentional.
    pub context: NonNull<llama_cpp_bindings_sys::mtmd_context>,
}

// SAFETY: NOTE(review) assumes libmtmd does not tie a context to the thread
// that created it, so ownership may move between threads. TODO confirm.
unsafe impl Send for MtmdContext {}
// NOTE(review): `tokenize` and `encode_chunk` take `&self` yet hand the raw
// context pointer to native code. Confirm that libmtmd tolerates concurrent
// calls on one context before relying on this `Sync` impl.
unsafe impl Sync for MtmdContext {}
41
42impl MtmdContext {
43 pub fn init_from_file(
51 mmproj_path: &str,
52 text_model: &LlamaModel,
53 params: &MtmdContextParams,
54 ) -> Result<Self, MtmdInitError> {
55 let path_cstr = CString::new(mmproj_path)?;
56 let ctx_params = llama_cpp_bindings_sys::mtmd_context_params::from(params);
57
58 let context = unsafe {
59 llama_cpp_bindings_sys::mtmd_init_from_file(
60 path_cstr.as_ptr(),
61 text_model.model.as_ptr(),
62 ctx_params,
63 )
64 };
65
66 let context = NonNull::new(context).ok_or(MtmdInitError::NullResult)?;
67
68 Ok(Self { context })
69 }
70
71 #[must_use]
74 pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk) -> bool {
75 unsafe {
76 llama_cpp_bindings_sys::mtmd_decode_use_non_causal(
77 self.context.as_ptr(),
78 chunk.chunk.as_ptr(),
79 )
80 }
81 }
82
83 #[must_use]
85 pub fn decode_use_mrope(&self) -> bool {
86 unsafe { llama_cpp_bindings_sys::mtmd_decode_use_mrope(self.context.as_ptr()) }
87 }
88
89 #[must_use]
91 pub fn support_vision(&self) -> bool {
92 unsafe { llama_cpp_bindings_sys::mtmd_support_vision(self.context.as_ptr()) }
93 }
94
95 #[must_use]
97 pub fn support_audio(&self) -> bool {
98 unsafe { llama_cpp_bindings_sys::mtmd_support_audio(self.context.as_ptr()) }
99 }
100
101 #[must_use]
104 pub fn get_audio_sample_rate(&self) -> Option<u32> {
105 let rate =
106 unsafe { llama_cpp_bindings_sys::mtmd_get_audio_sample_rate(self.context.as_ptr()) };
107 (rate > 0).then_some(rate.unsigned_abs())
108 }
109
110 pub fn tokenize(
137 &self,
138 text: MtmdInputText,
139 bitmaps: &[&MtmdBitmap],
140 ) -> Result<MtmdInputChunks, MtmdTokenizeError> {
141 let chunks = MtmdInputChunks::new()?;
142 let text_cstring = CString::new(text.text)?;
143 let input_text = llama_cpp_bindings_sys::mtmd_input_text {
144 text: text_cstring.as_ptr(),
145 add_special: text.add_special,
146 parse_special: text.parse_special,
147 };
148
149 let bitmap_ptrs: Vec<*const llama_cpp_bindings_sys::mtmd_bitmap> = bitmaps
150 .iter()
151 .map(|bitmap| bitmap.bitmap.as_ptr().cast_const())
152 .collect();
153
154 let result = unsafe {
155 llama_cpp_bindings_sys::mtmd_tokenize(
156 self.context.as_ptr(),
157 chunks.chunks.as_ptr(),
158 &raw const input_text,
159 bitmap_ptrs.as_ptr().cast_mut(),
160 bitmaps.len(),
161 )
162 };
163
164 if result == 0 {
165 Ok(chunks)
166 } else {
167 Err(tokenize_result_to_error(result))
168 }
169 }
170
171 pub fn encode_chunk(&self, chunk: &MtmdInputChunk) -> Result<(), MtmdEncodeError> {
177 let result = unsafe {
178 llama_cpp_bindings_sys::mtmd_encode_chunk(self.context.as_ptr(), chunk.chunk.as_ptr())
179 };
180
181 check_encode_result(result)
182 }
183}
184
185impl Drop for MtmdContext {
186 fn drop(&mut self) {
187 unsafe { llama_cpp_bindings_sys::mtmd_free(self.context.as_ptr()) }
188 }
189}
190
#[cfg(test)]
mod unit_tests {
    use super::{check_encode_result, tokenize_result_to_error};

    // Renders the error produced for `code` through its `Display` impl.
    fn tokenize_message(code: i32) -> String {
        tokenize_result_to_error(code).to_string()
    }

    #[test]
    fn tokenize_result_bitmap_count_mismatch() {
        assert!(tokenize_message(1).contains("does not match"));
    }

    #[test]
    fn tokenize_result_image_preprocessing_error() {
        assert!(tokenize_message(2).contains("Image preprocessing"));
    }

    #[test]
    fn tokenize_result_unknown_error() {
        assert!(tokenize_message(42).contains("Unknown error: 42"));
    }

    #[test]
    fn check_encode_result_ok_for_zero() {
        assert!(check_encode_result(0).is_ok());
    }

    #[test]
    fn check_encode_result_error_for_nonzero() {
        let message = check_encode_result(5).unwrap_err().to_string();
        assert!(message.contains("Encode failed with code: 5"));
    }
}