Skip to main content

llama_cpp_bindings/mtmd/
mtmd_input_chunk.rs

1use std::ffi::CStr;
2use std::ffi::c_char;
3use std::ptr::NonNull;
4use std::slice;
5
6use crate::context::LlamaContext;
7use crate::ffi_error_reader::read_and_free_cpp_error;
8use crate::token::LlamaToken;
9
10use super::image_chunk_batch_size_mismatch::ImageChunkBatchSizeMismatch;
11use super::mtmd_context::MtmdContext;
12use super::mtmd_eval_error::MtmdEvalError;
13use super::mtmd_input_chunk_error::MtmdInputChunkError;
14use super::mtmd_input_chunk_type::MtmdInputChunkType;
15use super::mtmd_input_chunk_type_error::MtmdInputChunkTypeError;
16
17/// # Safety
18///
19/// `tokens_ptr` must point to at least `n_tokens` valid `llama_token` values
20/// that remain valid for the lifetime `'chunk`.
21const unsafe fn tokens_from_raw_ptr<'chunk>(
22    tokens_ptr: *const llama_cpp_bindings_sys::llama_token,
23    n_tokens: usize,
24) -> Option<&'chunk [LlamaToken]> {
25    if tokens_ptr.is_null() || n_tokens == 0 {
26        None
27    } else {
28        unsafe {
29            Some(slice::from_raw_parts(
30                tokens_ptr.cast::<LlamaToken>(),
31                n_tokens,
32            ))
33        }
34    }
35}
36
37#[derive(Debug)]
38pub struct MtmdInputChunk {
39    pub chunk: NonNull<llama_cpp_bindings_sys::mtmd_input_chunk>,
40    pub owned: bool,
41}
42
43impl MtmdInputChunk {
44    /// # Errors
45    /// Returns an error if the chunk type is unknown.
46    pub fn chunk_type(&self) -> Result<MtmdInputChunkType, MtmdInputChunkTypeError> {
47        let chunk_type =
48            unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
49        MtmdInputChunkType::try_from(chunk_type)
50    }
51
52    #[must_use]
53    pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
54        if self.chunk_type() != Ok(MtmdInputChunkType::Text) {
55            return None;
56        }
57
58        let mut n_tokens = 0usize;
59        let tokens_ptr = unsafe {
60            llama_cpp_bindings_sys::mtmd_input_chunk_get_tokens_text(
61                self.chunk.as_ptr(),
62                &raw mut n_tokens,
63            )
64        };
65
66        unsafe { tokens_from_raw_ptr(tokens_ptr, n_tokens) }
67    }
68
69    #[must_use]
70    pub fn n_tokens(&self) -> usize {
71        unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
72    }
73
74    #[must_use]
75    pub fn n_positions(&self) -> i32 {
76        unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
77    }
78
79    #[must_use]
80    pub fn id(&self) -> Option<String> {
81        let ptr = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
82        if ptr.is_null() {
83            None
84        } else {
85            unsafe { CStr::from_ptr(ptr) }
86                .to_string_lossy()
87                .into_owned()
88                .into()
89        }
90    }
91
92    /// # Errors
93    ///
94    /// Returns `MtmdInputChunkError::ChunkOperationFailed` if copying fails.
95    pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
96        let chunk = unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
97        let chunk = NonNull::new(chunk).ok_or(MtmdInputChunkError::ChunkOperationFailed)?;
98
99        Ok(Self { chunk, owned: true })
100    }
101
102    /// # Errors
103    ///
104    /// Returns [`MtmdEvalError::ImageChunkExceedsBatchSize`] when this is an
105    /// image chunk whose token count exceeds `n_batch`. Returns
106    /// [`MtmdEvalError::EvalFailure`] if the underlying encode or decode step
107    /// fails.
108    pub fn eval_single(
109        &self,
110        mtmd_ctx: &MtmdContext,
111        llama_ctx: &LlamaContext,
112        start_position: llama_cpp_bindings_sys::llama_pos,
113        seq_id: llama_cpp_bindings_sys::llama_seq_id,
114        n_batch: i32,
115        logits_last: bool,
116    ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
117        let chunk_token_count = self.n_tokens();
118
119        if matches!(self.chunk_type(), Ok(MtmdInputChunkType::Image))
120            && i64::try_from(chunk_token_count).is_ok_and(|tokens| tokens > i64::from(n_batch))
121        {
122            #[expect(
123                clippy::cast_possible_truncation,
124                clippy::cast_sign_loss,
125                reason = "image token counts and n_batch are model-bounded and fit in u32"
126            )]
127            return Err(MtmdEvalError::ImageChunkExceedsBatchSize(
128                ImageChunkBatchSizeMismatch {
129                    image_tokens: chunk_token_count as u32,
130                    n_batch: n_batch as u32,
131                },
132            ));
133        }
134
135        let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
136        let mut out_vendored_return_code: i32 = 0;
137        let mut out_error: *mut c_char = std::ptr::null_mut();
138
139        let status = unsafe {
140            llama_cpp_bindings_sys::llama_rs_mtmd_eval_chunk_single(
141                mtmd_ctx.context.as_ptr(),
142                llama_ctx.context.as_ptr(),
143                self.chunk.as_ptr(),
144                start_position,
145                seq_id,
146                n_batch,
147                logits_last,
148                &raw mut final_position,
149                &raw mut out_vendored_return_code,
150                &raw mut out_error,
151            )
152        };
153
154        match status {
155            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_OK => Ok(final_position),
156            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_RETURNED_NONZERO_CODE => {
157                Err(MtmdEvalError::EvalFailed {
158                    code: out_vendored_return_code,
159                })
160            }
161            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_ERROR_STRING_ALLOCATION_FAILED => {
162                Err(MtmdEvalError::NotEnoughMemory)
163            }
164            llama_cpp_bindings_sys::LLAMA_RS_MTMD_EVAL_CHUNK_SINGLE_VENDORED_THREW_CXX_EXCEPTION => {
165                let message = unsafe { read_and_free_cpp_error(out_error) };
166                Err(MtmdEvalError::Reported { message })
167            }
168            other => unreachable!(
169                "llama_rs_mtmd_eval_chunk_single returned unrecognized status: {other}"
170            ),
171        }
172    }
173}
174
175impl Drop for MtmdInputChunk {
176    fn drop(&mut self) {
177        if self.owned {
178            unsafe { llama_cpp_bindings_sys::mtmd_input_chunk_free(self.chunk.as_ptr()) }
179        }
180    }
181}
182
#[cfg(test)]
mod unit_tests {
    use std::ptr;

    use llama_cpp_bindings_sys::llama_token;

    use super::tokens_from_raw_ptr;

    /// A null pointer never yields a slice, whatever count is claimed.
    #[test]
    fn tokens_from_raw_ptr_returns_none_for_null() {
        let result = unsafe { tokens_from_raw_ptr(ptr::null(), 5) };

        assert!(result.is_none());
    }

    /// A valid pointer with a zero count is treated as "no tokens".
    #[test]
    fn tokens_from_raw_ptr_returns_none_for_zero_count() {
        let token: llama_token = 42;
        let result = unsafe { tokens_from_raw_ptr(&raw const token, 0) };

        assert!(result.is_none());
    }

    /// A valid pointer and count produce a slice of exactly that length.
    #[test]
    fn tokens_from_raw_ptr_returns_some_for_valid() {
        let tokens: [llama_token; 2] = [1, 2];
        let result = unsafe { tokens_from_raw_ptr(tokens.as_ptr(), tokens.len()) };

        assert_eq!(result.map(|slice| slice.len()), Some(2));
    }
}