Skip to main content

llama_cpp_bindings/mtmd/
mtmd_input_chunks.rs

1use std::ptr::NonNull;
2
3use crate::context::LlamaContext;
4
5use super::mtmd_context::MtmdContext;
6use super::mtmd_eval_error::MtmdEvalError;
7use super::mtmd_input_chunk::MtmdInputChunk;
8use super::mtmd_input_chunks_error::MtmdInputChunksError;
9
10const fn check_eval_result(result: i32) -> Result<(), MtmdEvalError> {
11    if result == 0 {
12        Ok(())
13    } else {
14        Err(MtmdEvalError::EvalFailed { code: result })
15    }
16}
17
/// Wrapper around a native `mtmd_input_chunks` collection.
///
/// Instances built with `MtmdInputChunks::new` own the collection, which is released with
/// `mtmd_input_chunks_free` when the wrapper is dropped.
#[derive(Debug)]
pub struct MtmdInputChunks {
    /// Non-null handle to the native chunk collection.
    ///
    /// NOTE(review): this field is `pub`, so code outside this module can store an arbitrary
    /// pointer here, and `Drop` will then pass it to `mtmd_input_chunks_free`. Prefer
    /// constructing through `new`.
    pub chunks: NonNull<llama_cpp_bindings_sys::mtmd_input_chunks>,
}
22
23impl MtmdInputChunks {
24    /// # Errors
25    ///
26    /// Returns `MtmdInputChunksError::ChunksCreationFailed` if the underlying llama.cpp function
27    /// returns null.
28    ///
29    pub fn new() -> Result<Self, MtmdInputChunksError> {
30        let chunks = unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_init() };
31        let chunks = NonNull::new(chunks).ok_or(MtmdInputChunksError::ChunksCreationFailed)?;
32
33        Ok(Self { chunks })
34    }
35
36    #[must_use]
37    pub fn len(&self) -> usize {
38        unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_size(self.chunks.as_ptr()) }
39    }
40
41    #[must_use]
42    pub fn is_empty(&self) -> bool {
43        self.len() == 0
44    }
45
46    #[must_use]
47    pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
48        if index >= self.len() {
49            return None;
50        }
51
52        let chunk_ptr =
53            unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };
54
55        NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
56            chunk: ptr,
57            owned: false,
58        })
59    }
60
61    #[must_use]
62    pub fn total_tokens(&self) -> usize {
63        unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) }
64    }
65
66    #[must_use]
67    pub fn total_positions(&self) -> i32 {
68        unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_pos(self.chunks.as_ptr()) }
69    }
70
71    /// # Errors
72    ///
73    /// Returns `MtmdEvalError::EvalFailure` if any encoding or decoding operation fails.
74    pub fn eval_chunks(
75        &self,
76        mtmd_ctx: &MtmdContext,
77        llama_ctx: &LlamaContext,
78        start_position: llama_cpp_bindings_sys::llama_pos,
79        seq_id: llama_cpp_bindings_sys::llama_seq_id,
80        n_batch: i32,
81        logits_last: bool,
82    ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
83        let context_max_batch = llama_ctx.n_batch();
84
85        if n_batch > 0 && n_batch.cast_unsigned() > context_max_batch {
86            return Err(MtmdEvalError::BatchSizeExceedsContextLimit {
87                requested: n_batch,
88                context_max: context_max_batch,
89            });
90        }
91
92        let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
93
94        let result = unsafe {
95            llama_cpp_bindings_sys::mtmd_helper_eval_chunks(
96                mtmd_ctx.context.as_ptr(),
97                llama_ctx.context.as_ptr(),
98                self.chunks.as_ptr(),
99                start_position,
100                seq_id,
101                n_batch,
102                logits_last,
103                &raw mut final_position,
104            )
105        };
106
107        check_eval_result(result)?;
108
109        Ok(final_position)
110    }
111}
112
113impl Drop for MtmdInputChunks {
114    fn drop(&mut self) {
115        unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_free(self.chunks.as_ptr()) }
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::{MtmdEvalError, MtmdInputChunks, check_eval_result};

    #[test]
    fn new_creates_empty_chunks() {
        let chunks = MtmdInputChunks::new().unwrap();

        assert!(chunks.is_empty());
        assert_eq!(chunks.len(), 0);
    }

    #[test]
    fn get_out_of_bounds_returns_none() {
        let chunks = MtmdInputChunks::new().unwrap();

        assert!(chunks.get(0).is_none());
        assert!(chunks.get(999).is_none());
    }

    #[test]
    fn check_eval_result_ok_for_zero() {
        assert!(check_eval_result(0).is_ok());
    }

    #[test]
    fn check_eval_result_error_for_nonzero() {
        // Pin both the variant and the propagated code. Comparing discriminants alone would
        // still pass if the code were dropped or replaced.
        let err = check_eval_result(7).unwrap_err();

        assert!(matches!(err, MtmdEvalError::EvalFailed { code: 7 }));
    }

    #[test]
    fn check_eval_result_preserves_negative_codes() {
        let err = check_eval_result(-1).unwrap_err();

        assert!(matches!(err, MtmdEvalError::EvalFailed { code: -1 }));
    }
}