Skip to main content

llama_cpp_bindings/mtmd/
mtmd_input_chunk.rs

1use std::ffi::CStr;
2use std::ptr::NonNull;
3use std::slice;
4
5use crate::context::LlamaContext;
6use crate::token::LlamaToken;
7
8use super::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch;
9use super::mtmd_context::MtmdContext;
10use super::mtmd_error::MtmdEvalError;
11use super::mtmd_error::MtmdInputChunkError;
12use super::mtmd_input_chunk_type::{MtmdInputChunkType, MtmdInputChunkTypeError};
13
14/// # Safety
15///
16/// `tokens_ptr` must point to at least `n_tokens` valid `llama_token` values
17/// that remain valid for the lifetime `'chunk`.
18const unsafe fn tokens_from_raw_ptr<'chunk>(
19    tokens_ptr: *const llama_cpp_bindings_sys::llama_token,
20    n_tokens: usize,
21) -> Option<&'chunk [LlamaToken]> {
22    if tokens_ptr.is_null() || n_tokens == 0 {
23        None
24    } else {
25        unsafe {
26            Some(slice::from_raw_parts(
27                tokens_ptr.cast::<LlamaToken>(),
28                n_tokens,
29            ))
30        }
31    }
32}
33
34/// Safe wrapper around `mtmd_input_chunk`.
35///
36/// Represents a single chunk of input data, which can be either text tokens,
37/// image tokens, or audio tokens. The chunk type determines what kind of
38/// data and operations are available.
39#[derive(Debug)]
40pub struct MtmdInputChunk {
41    /// Raw pointer to the underlying `mtmd_input_chunk`.
42    pub chunk: NonNull<llama_cpp_bindings_sys::mtmd_input_chunk>,
43    pub owned: bool,
44}
45
46impl MtmdInputChunk {
47    /// Get the type of this chunk
48    ///
49    /// # Errors
50    /// Returns an error if the chunk type is unknown.
51    pub fn chunk_type(&self) -> Result<MtmdInputChunkType, MtmdInputChunkTypeError> {
52        let chunk_type =
53            unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
54        MtmdInputChunkType::try_from(chunk_type)
55    }
56
57    /// Get text tokens from this chunk.
58    ///
59    /// Only valid for text chunks. Returns `None` for image or audio chunks.
60    #[must_use]
61    pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
62        if self.chunk_type() != Ok(MtmdInputChunkType::Text) {
63            return None;
64        }
65
66        let mut n_tokens = 0usize;
67        let tokens_ptr = unsafe {
68            llama_cpp_bindings_sys::mtmd_input_chunk_get_tokens_text(
69                self.chunk.as_ptr(),
70                &raw mut n_tokens,
71            )
72        };
73
74        unsafe { tokens_from_raw_ptr(tokens_ptr, n_tokens) }
75    }
76
77    /// Get the number of tokens in this chunk
78    #[must_use]
79    pub fn n_tokens(&self) -> usize {
80        unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
81    }
82
83    /// Get the number of positions in this chunk.
84    #[must_use]
85    pub fn n_positions(&self) -> i32 {
86        unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
87    }
88
89    /// Get chunk ID if available.
90    ///
91    /// Returns `None` for text chunks, may return an ID for image/audio chunks.
92    #[must_use]
93    pub fn id(&self) -> Option<String> {
94        let ptr = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
95        if ptr.is_null() {
96            None
97        } else {
98            unsafe { CStr::from_ptr(ptr) }
99                .to_string_lossy()
100                .into_owned()
101                .into()
102        }
103    }
104
105    /// Create a copy of this chunk that you own.
106    ///
107    /// # Errors
108    ///
109    /// Returns `MtmdInputChunkError::NullResult` if copying fails.
110    pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
111        let chunk = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
112        let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::NullResult)?;
113
114        Ok(Self { chunk, owned: true })
115    }
116
117    /// Evaluate this single chunk through the multimodal helper.
118    ///
119    /// Mirrors `MtmdInputChunks::eval_chunks` but for one chunk at a time, so
120    /// callers can interleave per-chunk decode with per-chunk bookkeeping
121    /// (token counting, marker state-machine replay) inside one loop instead
122    /// of running the helper-level all-chunks eval and a separate ingest pass.
123    ///
124    /// Image chunks are decoded as one `llama_decode` call inside the helper,
125    /// so their token count must fit in `n_batch`. When it would not, the
126    /// binding refuses the call up front because the C-side
127    /// `GGML_ASSERT(n_tokens_all <= cparams.n_batch)` would otherwise abort
128    /// the process.
129    ///
130    /// # Errors
131    ///
132    /// Returns [`MtmdEvalError::ImageChunkExceedsBatchSize`] when this is an
133    /// image chunk whose token count exceeds `n_batch`. Returns
134    /// [`MtmdEvalError::EvalFailure`] if the underlying encode or decode step
135    /// fails.
136    pub fn eval_single(
137        &self,
138        mtmd_ctx: &MtmdContext,
139        llama_ctx: &LlamaContext,
140        start_position: llama_cpp_bindings_sys::llama_pos,
141        seq_id: llama_cpp_bindings_sys::llama_seq_id,
142        n_batch: i32,
143        logits_last: bool,
144    ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
145        let chunk_token_count = self.n_tokens();
146
147        if matches!(self.chunk_type(), Ok(MtmdInputChunkType::Image))
148            && i64::try_from(chunk_token_count).is_ok_and(|tokens| tokens > i64::from(n_batch))
149        {
150            #[expect(
151                clippy::cast_possible_truncation,
152                clippy::cast_sign_loss,
153                reason = "image token counts and n_batch are model-bounded and fit in u32"
154            )]
155            return Err(MtmdEvalError::ImageChunkExceedsBatchSize(
156                ImageChunkBatchSizeMismatch {
157                    image_tokens: chunk_token_count as u32,
158                    n_batch: n_batch as u32,
159                },
160            ));
161        }
162
163        let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
164
165        let result = unsafe {
166            llama_cpp_bindings_sys::mtmd_helper_eval_chunk_single(
167                mtmd_ctx.context.as_ptr(),
168                llama_ctx.context.as_ptr(),
169                self.chunk.as_ptr(),
170                start_position,
171                seq_id,
172                n_batch,
173                logits_last,
174                &raw mut final_position,
175            )
176        };
177
178        if result == 0 {
179            Ok(final_position)
180        } else {
181            Err(MtmdEvalError::EvalFailure(result))
182        }
183    }
184}
185
186impl Drop for MtmdInputChunk {
187    fn drop(&mut self) {
188        if self.owned {
189            unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_free(self.chunk.as_ptr()) }
190        }
191    }
192}
193
#[cfg(test)]
mod unit_tests {
    use super::tokens_from_raw_ptr;

    /// A null pointer is rejected even when the claimed length is non-zero.
    #[test]
    fn tokens_from_raw_ptr_returns_none_for_null() {
        let null_tokens = std::ptr::null::<llama_cpp_bindings_sys::llama_token>();
        let view = unsafe { tokens_from_raw_ptr(null_tokens, 5) };

        assert!(view.is_none());
    }

    /// A zero length is rejected even when the pointer itself is valid.
    #[test]
    fn tokens_from_raw_ptr_returns_none_for_zero_count() {
        let single_token: llama_cpp_bindings_sys::llama_token = 42;
        let view = unsafe { tokens_from_raw_ptr(&raw const single_token, 0) };

        assert!(view.is_none());
    }

    /// A valid pointer with a non-zero length produces a slice of that length.
    #[test]
    fn tokens_from_raw_ptr_returns_some_for_valid() {
        let raw_tokens: [llama_cpp_bindings_sys::llama_token; 2] = [1, 2];
        let view = unsafe { tokens_from_raw_ptr(raw_tokens.as_ptr(), 2) };

        assert_eq!(view.map(<[_]>::len), Some(2));
    }
}