llama_cpp_bindings/mtmd/
mtmd_input_chunk.rs1use std::ffi::CStr;
2use std::ffi::c_char;
3use std::ptr::NonNull;
4use std::slice;
5
6use crate::context::LlamaContext;
7use crate::ffi_error_reader::read_and_free_cpp_error;
8use crate::token::LlamaToken;
9
10use super::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch;
11use super::mtmd_context::MtmdContext;
12use super::mtmd_eval_error::MtmdEvalError;
13use super::mtmd_input_chunk_error::MtmdInputChunkError;
14use super::mtmd_input_chunk_type::MtmdInputChunkType;
15use super::mtmd_input_chunk_type_error::MtmdInputChunkTypeError;
16
17const unsafe fn tokens_from_raw_ptr<'chunk>(
22 tokens_ptr: *const llama_cpp_bindings_sys::llama_token,
23 n_tokens: usize,
24) -> Option<&'chunk [LlamaToken]> {
25 if tokens_ptr.is_null() || n_tokens == 0 {
26 None
27 } else {
28 unsafe {
29 Some(slice::from_raw_parts(
30 tokens_ptr.cast::<LlamaToken>(),
31 n_tokens,
32 ))
33 }
34 }
35}
36
/// Safe-ish wrapper around a raw `mtmd_input_chunk` pointer from `llama_cpp_bindings_sys`.
///
/// When `owned` is `true`, the chunk is released with `mtmd_input_chunk_free` when the
/// wrapper is dropped. When it is `false`, the wrapper never frees the chunk; presumably
/// its memory is managed by some other owner (e.g. a containing chunk list) — confirm at
/// the construction sites.
#[derive(Debug)]
pub struct MtmdInputChunk {
    /// Non-null pointer to the underlying C chunk; assumed to stay live while `self` exists.
    pub chunk: NonNull<llama_cpp_bindings_sys::mtmd_input_chunk>,
    /// Whether dropping this wrapper frees `chunk` via `mtmd_input_chunk_free`.
    pub owned: bool,
}
48
49impl MtmdInputChunk {
50 pub fn chunk_type(&self) -> Result<MtmdInputChunkType, MtmdInputChunkTypeError> {
55 let chunk_type =
56 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
57 MtmdInputChunkType::try_from(chunk_type)
58 }
59
60 #[must_use]
64 pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
65 if self.chunk_type() != Ok(MtmdInputChunkType::Text) {
66 return None;
67 }
68
69 let mut n_tokens = 0usize;
70 let tokens_ptr = unsafe {
71 llama_cpp_bindings_sys::mtmd_input_chunk_get_tokens_text(
72 self.chunk.as_ptr(),
73 &raw mut n_tokens,
74 )
75 };
76
77 unsafe { tokens_from_raw_ptr(tokens_ptr, n_tokens) }
78 }
79
80 #[must_use]
82 pub fn n_tokens(&self) -> usize {
83 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
84 }
85
86 #[must_use]
88 pub fn n_positions(&self) -> i32 {
89 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
90 }
91
92 #[must_use]
96 pub fn id(&self) -> Option<String> {
97 let ptr = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
98 if ptr.is_null() {
99 None
100 } else {
101 unsafe { CStr::from_ptr(ptr) }
102 .to_string_lossy()
103 .into_owned()
104 .into()
105 }
106 }
107
108 pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
114 let chunk = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
115 let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::ChunkOperationFailed)?;
116
117 Ok(Self { chunk, owned: true })
118 }
119
120 pub fn eval_single(
140 &self,
141 mtmd_ctx: &MtmdContext,
142 llama_ctx: &LlamaContext,
143 start_position: llama_cpp_bindings_sys::llama_pos,
144 seq_id: llama_cpp_bindings_sys::llama_seq_id,
145 n_batch: i32,
146 logits_last: bool,
147 ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
148 let chunk_token_count = self.n_tokens();
149
150 if matches!(self.chunk_type(), Ok(MtmdInputChunkType::Image))
151 && i64::try_from(chunk_token_count).is_ok_and(|tokens| tokens > i64::from(n_batch))
152 {
153 #[expect(
154 clippy::cast_possible_truncation,
155 clippy::cast_sign_loss,
156 reason = "image token counts and n_batch are model-bounded and fit in u32"
157 )]
158 return Err(MtmdEvalError::ImageChunkExceedsBatchSize(
159 ImageChunkBatchSizeMismatch {
160 image_tokens: chunk_token_count as u32,
161 n_batch: n_batch as u32,
162 },
163 ));
164 }
165
166 let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
167 let mut out_vendored_return_code: i32 = 0;
168 let mut out_error: *mut c_char = std::ptr::null_mut();
169
170 let status = unsafe {
171 llama_cpp_bindings_sys::llama_rs_mtmd_eval_chunk_single(
172 mtmd_ctx.context.as_ptr(),
173 llama_ctx.context.as_ptr(),
174 self.chunk.as_ptr(),
175 start_position,
176 seq_id,
177 n_batch,
178 logits_last,
179 &raw mut final_position,
180 &raw mut out_vendored_return_code,
181 &raw mut out_error,
182 )
183 };
184
185 match status {
186 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_OK => Ok(final_position),
187 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_RETURNED_NONZERO_CODE => {
188 Err(MtmdEvalError::EvalFailed {
189 code: out_vendored_return_code,
190 })
191 }
192 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_ERROR_STRING_ALLOCATION_FAILED => {
193 Err(MtmdEvalError::NotEnoughMemory)
194 }
195 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_THREW_CXX_EXCEPTION => {
196 let message = unsafe { read_and_free_cpp_error(out_error) };
197 Err(MtmdEvalError::Reported { message })
198 }
199 other => unreachable!(
200 "llama_rs_mtmd_eval_chunk_single returned unrecognized status: {other}"
201 ),
202 }
203 }
204}
205
206impl Drop for MtmdInputChunk {
207 fn drop(&mut self) {
208 if self.owned {
209 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_free(self.chunk.as_ptr()) }
210 }
211 }
212}
213
#[cfg(test)]
mod unit_tests {
    use super::tokens_from_raw_ptr;

    #[test]
    fn tokens_from_raw_ptr_returns_none_for_null() {
        // SAFETY: a null pointer is rejected before any memory is read.
        let result = unsafe { tokens_from_raw_ptr(std::ptr::null(), 5) };

        assert!(result.is_none());
    }

    #[test]
    fn tokens_from_raw_ptr_returns_none_for_zero_count() {
        let token: llama_cpp_bindings_sys::llama_token = 42;
        // SAFETY: `token` is live for the whole call and a zero count yields `None`.
        let result = unsafe { tokens_from_raw_ptr(&raw const token, 0) };

        assert!(result.is_none());
    }

    #[test]
    fn tokens_from_raw_ptr_returns_some_for_valid() {
        let tokens: [llama_cpp_bindings_sys::llama_token; 2] = [1, 2];
        // SAFETY: `tokens` holds exactly two initialized elements and outlives `view`.
        let view = unsafe { tokens_from_raw_ptr(tokens.as_ptr(), 2) }
            .expect("non-null pointer with a non-zero count must produce a slice");

        assert_eq!(view.len(), 2);
        // The slice must alias the caller's buffer, not some other memory.
        assert_eq!(
            view.as_ptr().cast::<llama_cpp_bindings_sys::llama_token>(),
            tokens.as_ptr()
        );
    }
}