llama_cpp_bindings/mtmd/
mtmd_input_chunk.rs1use std::ffi::CStr;
2use std::ptr::NonNull;
3use std::slice;
4
5use crate::context::LlamaContext;
6use crate::token::LlamaToken;
7
8use super::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch;
9use super::mtmd_context::MtmdContext;
10use super::mtmd_error::MtmdEvalError;
11use super::mtmd_error::MtmdInputChunkError;
12use super::mtmd_input_chunk_type::{MtmdInputChunkType, MtmdInputChunkTypeError};
13
14const unsafe fn tokens_from_raw_ptr<'chunk>(
19 tokens_ptr: *const llama_cpp_bindings_sys::llama_token,
20 n_tokens: usize,
21) -> Option<&'chunk [LlamaToken]> {
22 if tokens_ptr.is_null() || n_tokens == 0 {
23 None
24 } else {
25 unsafe {
26 Some(slice::from_raw_parts(
27 tokens_ptr.cast::<LlamaToken>(),
28 n_tokens,
29 ))
30 }
31 }
32}
33
34#[derive(Debug)]
40pub struct MtmdInputChunk {
41 pub chunk: NonNull<llama_cpp_bindings_sys::mtmd_input_chunk>,
43 pub owned: bool,
44}
45
46impl MtmdInputChunk {
47 pub fn chunk_type(&self) -> Result<MtmdInputChunkType, MtmdInputChunkTypeError> {
52 let chunk_type =
53 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
54 MtmdInputChunkType::try_from(chunk_type)
55 }
56
57 #[must_use]
61 pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
62 if self.chunk_type() != Ok(MtmdInputChunkType::Text) {
63 return None;
64 }
65
66 let mut n_tokens = 0usize;
67 let tokens_ptr = unsafe {
68 llama_cpp_bindings_sys::mtmd_input_chunk_get_tokens_text(
69 self.chunk.as_ptr(),
70 &raw mut n_tokens,
71 )
72 };
73
74 unsafe { tokens_from_raw_ptr(tokens_ptr, n_tokens) }
75 }
76
77 #[must_use]
79 pub fn n_tokens(&self) -> usize {
80 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
81 }
82
83 #[must_use]
85 pub fn n_positions(&self) -> i32 {
86 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
87 }
88
89 #[must_use]
93 pub fn id(&self) -> Option<String> {
94 let ptr = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
95 if ptr.is_null() {
96 None
97 } else {
98 unsafe { CStr::from_ptr(ptr) }
99 .to_string_lossy()
100 .into_owned()
101 .into()
102 }
103 }
104
105 pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
111 let chunk = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
112 let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::NullResult)?;
113
114 Ok(Self { chunk, owned: true })
115 }
116
117 pub fn eval_single(
137 &self,
138 mtmd_ctx: &MtmdContext,
139 llama_ctx: &LlamaContext,
140 start_position: llama_cpp_bindings_sys::llama_pos,
141 seq_id: llama_cpp_bindings_sys::llama_seq_id,
142 n_batch: i32,
143 logits_last: bool,
144 ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
145 let chunk_token_count = self.n_tokens();
146
147 if matches!(self.chunk_type(), Ok(MtmdInputChunkType::Image))
148 && i64::try_from(chunk_token_count).is_ok_and(|tokens| tokens > i64::from(n_batch))
149 {
150 #[expect(
151 clippy::cast_possible_truncation,
152 clippy::cast_sign_loss,
153 reason = "image token counts and n_batch are model-bounded and fit in u32"
154 )]
155 return Err(MtmdEvalError::ImageChunkExceedsBatchSize(
156 ImageChunkBatchSizeMismatch {
157 image_tokens: chunk_token_count as u32,
158 n_batch: n_batch as u32,
159 },
160 ));
161 }
162
163 let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
164
165 let result = unsafe {
166 llama_cpp_bindings_sys::mtmd_helper_eval_chunk_single(
167 mtmd_ctx.context.as_ptr(),
168 llama_ctx.context.as_ptr(),
169 self.chunk.as_ptr(),
170 start_position,
171 seq_id,
172 n_batch,
173 logits_last,
174 &raw mut final_position,
175 )
176 };
177
178 if result == 0 {
179 Ok(final_position)
180 } else {
181 Err(MtmdEvalError::EvalFailure(result))
182 }
183 }
184}
185
186impl Drop for MtmdInputChunk {
187 fn drop(&mut self) {
188 if self.owned {
189 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_free(self.chunk.as_ptr()) }
190 }
191 }
192}
193
#[cfg(test)]
mod unit_tests {
    use super::tokens_from_raw_ptr;

    #[test]
    fn tokens_from_raw_ptr_returns_none_for_null() {
        // A null pointer must be rejected even when a non-zero length is claimed.
        // SAFETY: the null check short-circuits before any dereference.
        let result = unsafe { tokens_from_raw_ptr(std::ptr::null(), 5) };
        assert!(result.is_none());
    }

    #[test]
    fn tokens_from_raw_ptr_returns_none_for_zero_count() {
        // A valid pointer paired with a zero length yields `None`, not an empty slice.
        let token: llama_cpp_bindings_sys::llama_token = 42;
        // SAFETY: `token` is a live local and the zero length means nothing is read.
        let result = unsafe { tokens_from_raw_ptr(&raw const token, 0) };
        assert!(result.is_none());
    }

    #[test]
    fn tokens_from_raw_ptr_returns_some_for_valid() {
        let tokens: [llama_cpp_bindings_sys::llama_token; 2] = [1, 2];
        // SAFETY: the pointer/length pair describes the live local array `tokens`.
        let view = unsafe { tokens_from_raw_ptr(tokens.as_ptr(), tokens.len()) }
            .expect("a non-null pointer with a non-zero count must produce a slice");

        assert_eq!(view.len(), tokens.len());
    }
}
217}