llama_cpp_bindings/mtmd/
mtmd_input_chunks.rs1use std::ptr::NonNull;
2
3use crate::context::LlamaContext;
4
5use super::mtmd_context::MtmdContext;
6use super::mtmd_eval_error::MtmdEvalError;
7use super::mtmd_input_chunk::MtmdInputChunk;
8use super::mtmd_input_chunks_error::MtmdInputChunksError;
9
10const fn check_eval_result(result: i32) -> Result<(), MtmdEvalError> {
11 if result == 0 {
12 Ok(())
13 } else {
14 Err(MtmdEvalError::EvalFailed { code: result })
15 }
16}
17
18#[derive(Debug)]
19pub struct MtmdInputChunks {
20 pub chunks: NonNull<llama_cpp_bindings_sys::mtmd_input_chunks>,
21}
22
23impl MtmdInputChunks {
24 pub fn new() -> Result<Self, MtmdInputChunksError> {
30 let chunks = unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_init() };
31 let chunks = NonNull::new(chunks).ok_or(MtmdInputChunksError::ChunksCreationFailed)?;
32
33 Ok(Self { chunks })
34 }
35
36 #[must_use]
37 pub fn len(&self) -> usize {
38 unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_size(self.chunks.as_ptr()) }
39 }
40
41 #[must_use]
42 pub fn is_empty(&self) -> bool {
43 self.len() == 0
44 }
45
46 #[must_use]
47 pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
48 if index >= self.len() {
49 return None;
50 }
51
52 let chunk_ptr =
53 unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };
54
55 NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
56 chunk: ptr,
57 owned: false,
58 })
59 }
60
61 #[must_use]
62 pub fn total_tokens(&self) -> usize {
63 unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) }
64 }
65
66 #[must_use]
67 pub fn total_positions(&self) -> i32 {
68 unsafe { llama_cpp_bindings_sys::mtmd_helper_get_n_pos(self.chunks.as_ptr()) }
69 }
70
71 pub fn eval_chunks(
75 &self,
76 mtmd_ctx: &MtmdContext,
77 llama_ctx: &LlamaContext,
78 start_position: llama_cpp_bindings_sys::llama_pos,
79 seq_id: llama_cpp_bindings_sys::llama_seq_id,
80 n_batch: i32,
81 logits_last: bool,
82 ) -> Result<llama_cpp_bindings_sys::llama_pos, MtmdEvalError> {
83 let context_max_batch = llama_ctx.n_batch();
84
85 if n_batch > 0 && n_batch.cast_unsigned() > context_max_batch {
86 return Err(MtmdEvalError::BatchSizeExceedsContextLimit {
87 requested: n_batch,
88 context_max: context_max_batch,
89 });
90 }
91
92 let mut final_position: llama_cpp_bindings_sys::llama_pos = start_position;
93
94 let result = unsafe {
95 llama_cpp_bindings_sys::mtmd_helper_eval_chunks(
96 mtmd_ctx.context.as_ptr(),
97 llama_ctx.context.as_ptr(),
98 self.chunks.as_ptr(),
99 start_position,
100 seq_id,
101 n_batch,
102 logits_last,
103 &raw mut final_position,
104 )
105 };
106
107 check_eval_result(result)?;
108
109 Ok(final_position)
110 }
111}
112
113impl Drop for MtmdInputChunks {
114 fn drop(&mut self) {
115 unsafe { llama_cpp_bindings_sys::mtmd_input_chunks_free(self.chunks.as_ptr()) }
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::MtmdEvalError;
    use super::MtmdInputChunks;
    use super::check_eval_result;

    #[test]
    fn new_creates_empty_chunks() {
        let chunks = MtmdInputChunks::new().unwrap();

        assert!(chunks.is_empty());
        assert_eq!(chunks.len(), 0);
    }

    /// The aggregate helpers must report zero for a freshly created, empty list.
    #[test]
    fn empty_chunks_report_zero_tokens_and_positions() {
        let chunks = MtmdInputChunks::new().unwrap();

        assert_eq!(chunks.total_tokens(), 0);
        assert_eq!(chunks.total_positions(), 0);
    }

    #[test]
    fn get_out_of_bounds_returns_none() {
        let chunks = MtmdInputChunks::new().unwrap();

        assert!(chunks.get(0).is_none());
        assert!(chunks.get(999).is_none());
    }

    #[test]
    fn check_eval_result_ok_for_zero() {
        assert!(check_eval_result(0).is_ok());
    }

    /// A non-zero status must be surfaced as `EvalFailed` carrying the exact code,
    /// not just the right variant.
    #[test]
    fn check_eval_result_error_for_nonzero() {
        let positive = check_eval_result(7).unwrap_err();
        assert!(matches!(positive, MtmdEvalError::EvalFailed { code: 7 }));

        let negative = check_eval_result(-1).unwrap_err();
        assert!(matches!(negative, MtmdEvalError::EvalFailed { code: -1 }));
    }
}