Skip to main content

llama_cpp_bindings/mtmd/
mtmd_input_chunk.rs

1use std::ffi::CStr;
2use std::ffi::c_char;
3use std::ptr::NonNull;
4use std::slice;
5
6use crate::context::LlamaContext;
7use crate::ffi_error_reader::read_and_free_cpp_error;
8use crate::token::LlamaToken;
9
10use super::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch;
11use super::mtmd_context::MtmdContext;
12use super::mtmd_eval_error::MtmdEvalError;
13use super::mtmd_input_chunk_error::MtmdInputChunkError;
14use super::mtmd_input_chunk_type::MtmdInputChunkType;
15use super::mtmd_input_chunk_type_error::MtmdInputChunkTypeError;
16
17/// # Safety
18///
19/// `tokens_ptr` must point to at least `n_tokens` valid `llama_token` values
20/// that remain valid for the lifetime `'chunk`.
21const unsafe fn tokens_from_raw_ptr<'chunk>(
22    tokens_ptr: *const llama_cpp_bindings_sys::llama_token,
23    n_tokens: usize,
24) -> Option<&'chunk [LlamaToken]> {
25    if tokens_ptr.is_null() || n_tokens == 0 {
26        None
27    } else {
28        unsafe {
29            Some(slice::from_raw_parts(
30                tokens_ptr.cast::<LlamaToken>(),
31                n_tokens,
32            ))
33        }
34    }
35}
36
37/// Safe wrapper around `mtmd_input_chunk`.
38///
39/// Represents a single chunk of input data, which can be either text tokens,
40/// image tokens, or audio tokens. The chunk type determines what kind of
41/// data and operations are available.
42#[derive(Debug)]
43pub struct MtmdInputChunk {
44    /// Raw pointer to the underlying `mtmd_input_chunk`.
45    pub chunk: NonNull<llama_cpp_bindings_sys::mtmd_input_chunk>,
46    pub owned: bool,
47}
48
49impl MtmdInputChunk {
50    /// Get the type of this chunk
51    ///
52    /// # Errors
53    /// Returns an error if the chunk type is unknown.
54    pub fn chunk_type(&self) -> Result<MtmdInputChunkType, MtmdInputChunkTypeError> {
55        let chunk_type =
56            unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
57        MtmdInputChunkType::try_from(chunk_type)
58    }
59
60    /// Get text tokens from this chunk.
61    ///
62    /// Only valid for text chunks. Returns `None` for image or audio chunks.
63    #[must_use]
64    pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
65        if self.chunk_type() != Ok(MtmdInputChunkType::Text) {
66            return None;
67        }
68
69        let mut n_tokens = 0usize;
70        let tokens_ptr = unsafe {
71            llama_cpp_bindings_sys::mtmd_input_chunk_get_tokens_text(
72                self.chunk.as_ptr(),
73                &raw mut n_tokens,
74            )
75        };
76
77        unsafe { tokens_from_raw_ptr(tokens_ptr, n_tokens) }
78    }
79
80    /// Get the number of tokens in this chunk
81    #[must_use]
82    pub fn n_tokens(&self) -> usize {
83        unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
84    }
85
86    /// Get the number of positions in this chunk.
87    #[must_use]
88    pub fn n_positions(&self) -> i32 {
89        unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
90    }
91
92    /// Get chunk ID if available.
93    ///
94    /// Returns `None` for text chunks, may return an ID for image/audio chunks.
95    #[must_use]
96    pub fn id(&self) -> Option<String> {
97        let ptr = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
98        if ptr.is_null() {
99            None
100        } else {
101            unsafe { CStr::from_ptr(ptr) }
102                .to_string_lossy()
103                .into_owned()
104                .into()
105        }
106    }
107
108    /// Create a copy of this chunk that you own.
109    ///
110    /// # Errors
111    ///
112    /// Returns `MtmdInputChunkError::ChunkOperationFailed` if copying fails.
113    pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
114        let chunk = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
115        let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::ChunkOperationFailed)?;
116
117        Ok(Self { chunk, owned: true })
118    }
119
120    /// Evaluate this single chunk through the multimodal helper.
121    ///
122    /// Mirrors `MtmdInputChunks::eval_chunks` but for one chunk at a time, so
123    /// callers can interleave per-chunk decode with per-chunk bookkeeping
124    /// (token counting, marker state-machine replay) inside one loop instead
125    /// of running the helper-level all-chunks eval and a separate ingest pass.
126    ///
127    /// Image chunks are decoded as one `llama_decode` call inside the helper,
128    /// so their token count must fit in `n_batch`. When it would not, the
129    /// binding refuses the call up front because the C-side
130    /// `GGML_ASSERT(n_tokens_all <= cparams.n_batch)` would otherwise abort
131    /// the process.
132    ///
133    /// # Errors
134    ///
135    /// Returns [`MtmdEvalError::ImageChunkExceedsBatchSize`] when this is an
136    /// image chunk whose token count exceeds `n_batch`. Returns
137    /// [`MtmdEvalError::EvalFailure`] if the underlying encode or decode step
138    /// fails.
139    pub fn eval_single(
140        &self,
141        mtmd_ctx: &MtmdContext,
142        llama_ctx: &LlamaContext,
143        start_position: llama_cpp_bindings_sys::llama_pos,
144        seq_id: llama_cpp_bindings_sys::llama_seq_id,
145        n_batch: i32,
146        logits_last: bool,
147    ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
148        let chunk_token_count = self.n_tokens();
149
150        if matches!(self.chunk_type(), Ok(MtmdInputChunkType::Image))
151            && i64::try_from(chunk_token_count).is_ok_and(|tokens| tokens > i64::from(n_batch))
152        {
153            #[expect(
154                clippy::cast_possible_truncation,
155                clippy::cast_sign_loss,
156                reason = "image token counts and n_batch are model-bounded and fit in u32"
157            )]
158            return Err(MtmdEvalError::ImageChunkExceedsBatchSize(
159                ImageChunkBatchSizeMismatch {
160                    image_tokens: chunk_token_count as u32,
161                    n_batch: n_batch as u32,
162                },
163            ));
164        }
165
166        let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
167        let mut out_vendored_return_code: i32 = 0;
168        let mut out_error: *mut c_char = std::ptr::null_mut();
169
170        let status = unsafe {
171            llama_cpp_bindings_sys::llama_rs_mtmd_eval_chunk_single(
172                mtmd_ctx.context.as_ptr(),
173                llama_ctx.context.as_ptr(),
174                self.chunk.as_ptr(),
175                start_position,
176                seq_id,
177                n_batch,
178                logits_last,
179                &raw mut final_position,
180                &raw mut out_vendored_return_code,
181                &raw mut out_error,
182            )
183        };
184
185        match status {
186            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_OK => Ok(final_position),
187            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_RETURNED_NONZERO_CODE => {
188                Err(MtmdEvalError::EvalFailed {
189                    code: out_vendored_return_code,
190                })
191            }
192            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_ERROR_STRING_ALLOCATION_FAILED => {
193                Err(MtmdEvalError::NotEnoughMemory)
194            }
195            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_THREW_CXX_EXCEPTION => {
196                let message = unsafe { read_and_free_cpp_error(out_error) };
197                Err(MtmdEvalError::Reported { message })
198            }
199            other => unreachable!(
200                "llama_rs_mtmd_eval_chunk_single returned unrecognized status: {other}"
201            ),
202        }
203    }
204}
205
206impl Drop for MtmdInputChunk {
207    fn drop(&mut self) {
208        if self.owned {
209            unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_free(self.chunk.as_ptr()) }
210        }
211    }
212}
213
#[cfg(test)]
mod unit_tests {
    use super::tokens_from_raw_ptr;

    /// Raw token type as exposed by the sys crate.
    type RawToken = llama_cpp_bindings_sys::llama_token;

    /// A null pointer yields `None` even when the count is non-zero.
    #[test]
    fn tokens_from_raw_ptr_returns_none_for_null() {
        let null_tokens: *const RawToken = std::ptr::null();
        let result = unsafe { tokens_from_raw_ptr(null_tokens, 5) };

        assert!(result.is_none());
    }

    /// A zero count yields `None` even when the pointer is valid.
    #[test]
    fn tokens_from_raw_ptr_returns_none_for_zero_count() {
        let token: RawToken = 42;
        let result = unsafe { tokens_from_raw_ptr(&raw const token, 0) };

        assert!(result.is_none());
    }

    /// A valid pointer and count yield a slice of exactly that length.
    #[test]
    fn tokens_from_raw_ptr_returns_some_for_valid() {
        let tokens: [RawToken; 2] = [1, 2];
        let result = unsafe { tokens_from_raw_ptr(tokens.as_ptr(), 2) };

        assert!(matches!(result, Some(found) if found.len() == 2));
    }
}
237}