// llama_cpp_bindings/mtmd/mtmd_input_chunks.rs
use std::ptr::NonNull;
2
3use crate::context::LlamaContext;
4
5use super::mtmd_context::MtmdContext;
6use super::mtmd_eval_error::MtmdEvalError;
7use super::mtmd_input_chunk::MtmdInputChunk;
8use super::mtmd_input_chunks_error::MtmdInputChunksError;
9
10const fn check_eval_result(result: i32) -> Result<(), MtmdEvalError> {
11 if result == 0 {
12 Ok(())
13 } else {
14 Err(MtmdEvalError::EvalFailed { code: result })
15 }
16}
17
18#[derive(Debug)]
24pub struct MtmdInputChunks {
25 pub chunks: NonNull<llama_cpp_bindings_sys::mtmd_input_chunks>,
27}
28
29impl MtmdInputChunks {
30 pub fn new() -> Result<Self, MtmdInputChunksError> {
47 let chunks = unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_init() };
48 let chunks = NonNull::new(chunks).ok_or(MtmdInputChunksError::ChunksCreationFailed)?;
49
50 Ok(Self { chunks })
51 }
52
53 #[must_use]
55 pub fn len(&self) -> usize {
56 unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_size(self.chunks.as_ptr()) }
57 }
58
59 #[must_use]
61 pub fn is_empty(&self) -> bool {
62 self.len() == 0
63 }
64
65 #[must_use]
67 pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
68 if index >= self.len() {
69 return None;
70 }
71
72 let chunk_ptr =
73 unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };
74
75 NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
76 chunk: ptr,
77 owned: false,
78 })
79 }
80
81 #[must_use]
83 pub fn total_tokens(&self) -> usize {
84 unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) }
85 }
86
87 #[must_use]
89 pub fn total_positions(&self) -> i32 {
90 unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_pos(self.chunks.as_ptr()) }
91 }
92
93 pub fn eval_chunks(
99 &self,
100 mtmd_ctx: &MtmdContext,
101 llama_ctx: &LlamaContext,
102 start_position: llama_cpp_bindings_sys::llama_pos,
103 seq_id: llama_cpp_bindings_sys::llama_seq_id,
104 n_batch: i32,
105 logits_last: bool,
106 ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
107 let context_max_batch = llama_ctx.n_batch();
108
109 if n_batch > 0 && n_batch.cast_unsigned() > context_max_batch {
110 return Err(MtmdEvalError::BatchSizeExceedsContextLimit {
111 requested: n_batch,
112 context_max: context_max_batch,
113 });
114 }
115
116 let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
122
123 let result = unsafe {
124 llama_cpp_bindings_sys::mtmd_helper_eval_chunks(
125 mtmd_ctx.context.as_ptr(),
126 llama_ctx.context.as_ptr(),
127 self.chunks.as_ptr(),
128 start_position,
129 seq_id,
130 n_batch,
131 logits_last,
132 &raw mut final_position,
133 )
134 };
135
136 check_eval_result(result)?;
137
138 Ok(final_position)
139 }
140}
141
142impl Drop for MtmdInputChunks {
143 fn drop(&mut self) {
144 unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_free(self.chunks.as_ptr()) }
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::{MtmdEvalError, MtmdInputChunks, check_eval_result};

    /// A freshly created collection reports zero chunks.
    #[test]
    fn new_creates_empty_chunks() {
        let fresh = MtmdInputChunks::new().unwrap();

        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    /// Any index on an empty collection yields `None`.
    #[test]
    fn get_out_of_bounds_returns_none() {
        let empty = MtmdInputChunks::new().unwrap();

        for index in [0, 999] {
            assert!(empty.get(index).is_none());
        }
    }

    /// Status `0` maps to `Ok`.
    #[test]
    fn check_eval_result_ok_for_zero() {
        let outcome = check_eval_result(0);

        assert!(outcome.is_ok());
    }

    /// A non-zero status is surfaced as `EvalFailed` with the same code.
    #[test]
    fn check_eval_result_error_for_nonzero() {
        assert!(matches!(
            check_eval_result(7),
            Err(MtmdEvalError::EvalFailed { code: 7 })
        ));
    }
}