Skip to main content

llama_cpp_bindings/mtmd/
mtmd_input_chunks.rs

1use std::ptr::NonNull;
2
3use crate::context::LlamaContext;
4
5use super::mtmd_context::MtmdContext;
6use super::mtmd_error::MtmdEvalError;
7use super::mtmd_error::MtmdInputChunksError;
8use super::mtmd_input_chunk::MtmdInputChunk;
9
10const fn check_eval_result(result: i32) -> Result<(), MtmdEvalError> {
11    if result == 0 {
12        Ok(())
13    } else {
14        Err(MtmdEvalError::EvalFailure(result))
15    }
16}
17
/// Safe wrapper around `mtmd_input_chunks`.
///
/// This is a collection of input chunks created from tokenizing text and media.
/// The chunks represent the tokenized input that can be processed by the model,
/// with text chunks containing tokens and media chunks containing embeddings.
///
/// The wrapped collection is handed to `mtmd_input_chunks_free` when this value
/// is dropped.
#[derive(Debug)]
pub struct MtmdInputChunks {
    /// Raw pointer to the underlying `mtmd_input_chunks`.
    ///
    /// The accessor methods on this type and its `Drop` implementation pass this
    /// pointer straight to the llama.cpp bindings.
    ///
    /// NOTE(review): the field is public, so code outside this module can replace
    /// the pointer; presumably it must keep pointing at a live collection —
    /// confirm against callers.
    pub chunks: NonNull<llama_cpp_bindings_sys::mtmd_input_chunks>,
}
28
29impl MtmdInputChunks {
30    /// Create a new empty input chunks collection.
31    ///
32    /// # Errors
33    ///
34    /// Returns `MtmdInputChunksError::NullResult` if the underlying llama.cpp function
35    /// returns null.
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// use llama_cpp_bindings::mtmd::MtmdInputChunks;
41    ///
42    /// let chunks = MtmdInputChunks::new().unwrap();
43    /// assert_eq!(chunks.len(), 0);
44    /// assert!(chunks.is_empty());
45    /// ```
46    pub fn new() -> Result<Self, MtmdInputChunksError> {
47        let chunks = unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_init() };
48        let chunks = NonNull::new(chunks).ok_or(MtmdInputChunksError::NullResult)?;
49
50        Ok(Self { chunks })
51    }
52
53    /// Get the number of chunks
54    #[must_use]
55    pub fn len(&self) -> usize {
56        unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_size(self.chunks.as_ptr()) }
57    }
58
59    /// Check if chunks collection is empty
60    #[must_use]
61    pub fn is_empty(&self) -> bool {
62        self.len() == 0
63    }
64
65    /// Get a chunk by index
66    #[must_use]
67    pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
68        if index >= self.len() {
69            return None;
70        }
71
72        let chunk_ptr =
73            unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };
74
75        NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
76            chunk: ptr,
77            owned: false,
78        })
79    }
80
81    /// Get total number of tokens across all chunks.
82    #[must_use]
83    pub fn total_tokens(&self) -> usize {
84        unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) }
85    }
86
87    /// Get total position count across all chunks.
88    #[must_use]
89    pub fn total_positions(&self) -> i32 {
90        unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_pos(self.chunks.as_ptr()) }
91    }
92
93    /// Evaluate chunks using the multimodal context and LLAMA context.
94    ///
95    /// # Errors
96    ///
97    /// Returns `MtmdEvalError::EvalFailure` if any encoding or decoding operation fails.
98    pub fn eval_chunks(
99        &self,
100        mtmd_ctx: &MtmdContext,
101        llama_ctx: &LlamaContext,
102        start_position: llama_cpp_bindings_sys::llama_pos,
103        seq_id: llama_cpp_bindings_sys::llama_seq_id,
104        n_batch: i32,
105        logits_last: bool,
106    ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
107        let context_max_batch = llama_ctx.n_batch();
108
109        if n_batch > 0 && n_batch.cast_unsigned() > context_max_batch {
110            return Err(MtmdEvalError::BatchSizeExceedsContextLimit {
111                requested: n_batch,
112                context_max: context_max_batch,
113            });
114        }
115
116        // mtmd_helper_eval_chunks overwrites `*new_n_past` at the end of its
117        // chunk loop (mtmd-helper.cpp:413), so any seed would be fine — but
118        // we mirror the per-chunk wrapper's `start_position` / `final_position`
119        // shape here for parity, keeping the read-only input and write-only
120        // output strictly separated.
121        let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
122
123        let result = unsafe {
124            llama_cpp_bindings_sys::mtmd_helper_eval_chunks(
125                mtmd_ctx.context.as_ptr(),
126                llama_ctx.context.as_ptr(),
127                self.chunks.as_ptr(),
128                start_position,
129                seq_id,
130                n_batch,
131                logits_last,
132                &raw mut final_position,
133            )
134        };
135
136        check_eval_result(result)?;
137
138        Ok(final_position)
139    }
140}
141
142impl Drop for MtmdInputChunks {
143    fn drop(&mut self) {
144        unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_free(self.chunks.as_ptr()) }
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::MtmdInputChunks;
    use super::check_eval_result;

    /// A freshly created collection reports no chunks.
    #[test]
    fn new_creates_empty_chunks() {
        let fresh = MtmdInputChunks::new().unwrap();

        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    /// Any index into an empty collection yields `None`.
    #[test]
    fn get_out_of_bounds_returns_none() {
        let fresh = MtmdInputChunks::new().unwrap();

        let first = fresh.get(0);
        assert!(first.is_none());

        let far_past_end = fresh.get(999);
        assert!(far_past_end.is_none());
    }

    /// Status code zero maps to `Ok`.
    #[test]
    fn check_eval_result_ok_for_zero() {
        let outcome = check_eval_result(0);

        assert!(outcome.is_ok());
    }

    /// A non-zero status code surfaces in the error message.
    #[test]
    fn check_eval_result_error_for_nonzero() {
        let message = check_eval_result(7).unwrap_err().to_string();

        assert!(message.contains("Eval failed with code: 7"));
    }
}