llama_cpp_bindings/mtmd/
mtmd_input_chunk.rs1use std::ffi::CStr;
2use std::ffi::c_char;
3use std::ptr::NonNull;
4use std::slice;
5
6use crate::context::LlamaContext;
7use crate::ffi_error_reader::read_and_free_cpp_error;
8use crate::token::LlamaToken;
9
10use super::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch;
11use super::mtmd_context::MtmdContext;
12use super::mtmd_eval_error::MtmdEvalError;
13use super::mtmd_input_chunk_error::MtmdInputChunkError;
14use super::mtmd_input_chunk_type::MtmdInputChunkType;
15use super::mtmd_input_chunk_type_error::MtmdInputChunkTypeError;
16
17const unsafe fn tokens_from_raw_ptr<'chunk>(
22 tokens_ptr: *const llama_cpp_bindings_sys::llama_token,
23 n_tokens: usize,
24) -> Option<&'chunk [LlamaToken]> {
25 if tokens_ptr.is_null() || n_tokens == 0 {
26 None
27 } else {
28 unsafe {
29 Some(slice::from_raw_parts(
30 tokens_ptr.cast::<LlamaToken>(),
31 n_tokens,
32 ))
33 }
34 }
35}
36
37#[derive(Debug)]
38pub struct MtmdInputChunk {
39 pub chunk: NonNull<llama_cpp_bindings_sys::mtmd_input_chunk>,
40 pub owned: bool,
41}
42
43impl MtmdInputChunk {
44 pub fn chunk_type(&self) -> Result<MtmdInputChunkType, MtmdInputChunkTypeError> {
47 let chunk_type =
48 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
49 MtmdInputChunkType::try_from(chunk_type)
50 }
51
52 #[must_use]
53 pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
54 if self.chunk_type() != Ok(MtmdInputChunkType::Text) {
55 return None;
56 }
57
58 let mut n_tokens = 0usize;
59 let tokens_ptr = unsafe {
60 llama_cpp_bindings_sys::mtmd_input_chunk_get_tokens_text(
61 self.chunk.as_ptr(),
62 &raw mut n_tokens,
63 )
64 };
65
66 unsafe { tokens_from_raw_ptr(tokens_ptr, n_tokens) }
67 }
68
69 #[must_use]
70 pub fn n_tokens(&self) -> usize {
71 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
72 }
73
74 #[must_use]
75 pub fn n_positions(&self) -> i32 {
76 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
77 }
78
79 #[must_use]
80 pub fn id(&self) -> Option<String> {
81 let ptr = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
82 if ptr.is_null() {
83 None
84 } else {
85 unsafe { CStr::from_ptr(ptr) }
86 .to_string_lossy()
87 .into_owned()
88 .into()
89 }
90 }
91
92 pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
96 let chunk = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
97 let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::ChunkOperationFailed)?;
98
99 Ok(Self { chunk, owned: true })
100 }
101
102 pub fn eval_single(
109 &self,
110 mtmd_ctx: &MtmdContext,
111 llama_ctx: &LlamaContext,
112 start_position: llama_cpp_bindings_sys::llama_pos,
113 seq_id: llama_cpp_bindings_sys::llama_seq_id,
114 n_batch: i32,
115 logits_last: bool,
116 ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
117 let chunk_token_count = self.n_tokens();
118
119 if matches!(self.chunk_type(), Ok(MtmdInputChunkType::Image))
120 && i64::try_from(chunk_token_count).is_ok_and(|tokens| tokens > i64::from(n_batch))
121 {
122 #[expect(
123 clippy::cast_possible_truncation,
124 clippy::cast_sign_loss,
125 reason = "image token counts and n_batch are model-bounded and fit in u32"
126 )]
127 return Err(MtmdEvalError::ImageChunkExceedsBatchSize(
128 ImageChunkBatchSizeMismatch {
129 image_tokens: chunk_token_count as u32,
130 n_batch: n_batch as u32,
131 },
132 ));
133 }
134
135 let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
136 let mut out_vendored_return_code: i32 = 0;
137 let mut out_error: *mut c_char = std::ptr::null_mut();
138
139 let status = unsafe {
140 llama_cpp_bindings_sys::llama_rs_mtmd_eval_chunk_single(
141 mtmd_ctx.context.as_ptr(),
142 llama_ctx.context.as_ptr(),
143 self.chunk.as_ptr(),
144 start_position,
145 seq_id,
146 n_batch,
147 logits_last,
148 &raw mut final_position,
149 &raw mut out_vendored_return_code,
150 &raw mut out_error,
151 )
152 };
153
154 match status {
155 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_OK => Ok(final_position),
156 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_RETURNED_NONZERO_CODE => {
157 Err(MtmdEvalError::EvalFailed {
158 code: out_vendored_return_code,
159 })
160 }
161 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_ERROR_STRING_ALLOCATION_FAILED => {
162 Err(MtmdEvalError::NotEnoughMemory)
163 }
164 llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_THREW_CXX_EXCEPTION => {
165 let message = unsafe { read_and_free_cpp_error(out_error) };
166 Err(MtmdEvalError::Reported { message })
167 }
168 other => unreachable!(
169 "llama_rs_mtmd_eval_chunk_single returned unrecognized status: {other}"
170 ),
171 }
172 }
173}
174
175impl Drop for MtmdInputChunk {
176 fn drop(&mut self) {
177 if self.owned {
178 unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_free(self.chunk.as_ptr()) }
179 }
180 }
181}
182
#[cfg(test)]
mod unit_tests {
    use llama_cpp_bindings_sys::llama_token;

    use super::tokens_from_raw_ptr;

    #[test]
    fn tokens_from_raw_ptr_returns_none_for_null() {
        // SAFETY: a null pointer is explicitly handled and never dereferenced.
        let result = unsafe { tokens_from_raw_ptr(std::ptr::null(), 5) };

        // Null must be rejected even when the count claims there is data.
        assert!(result.is_none());
    }

    #[test]
    fn tokens_from_raw_ptr_returns_none_for_zero_count() {
        let token: llama_token = 42;
        // SAFETY: `token` is a live local; a zero count means no slice is built.
        let result = unsafe { tokens_from_raw_ptr(&raw const token, 0) };

        assert!(result.is_none());
    }

    #[test]
    fn tokens_from_raw_ptr_returns_some_for_valid() {
        let tokens: [llama_token; 2] = [1, 2];
        // SAFETY: `tokens` outlives `result` and holds exactly two initialized values.
        let result = unsafe { tokens_from_raw_ptr(tokens.as_ptr(), 2) };

        let view = result.expect("non-null pointer with a positive count must yield a slice");
        assert_eq!(view.len(), 2);
        // The slice must alias the caller's buffer, not a copy of it.
        assert_eq!(view.as_ptr().cast::<llama_token>(), tokens.as_ptr());
    }
}