llama_cpp_bindings/
context.rs1use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use crate::llama_batch::LlamaBatch;
9use crate::model::{LlamaLoraAdapter, LlamaModel};
10use crate::timing::LlamaTimings;
11use crate::token::LlamaToken;
12use crate::token::data::LlamaTokenData;
13use crate::token::data_array::LlamaTokenDataArray;
14use crate::{
15 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
16 LlamaLoraAdapterSetError,
17};
18
19pub mod kv_cache;
20pub mod params;
21pub mod session;
22
/// Safe wrapper around a raw `llama_context`, borrowing the [`LlamaModel`] it
/// was created from so the model is guaranteed to outlive the context.
pub struct LlamaContext<'model> {
    /// Raw, non-null pointer to the underlying `llama_context`; freed in `Drop`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// The model this context was created from; the borrow keeps it alive.
    pub model: &'model LlamaModel,
    // Token positions whose logits were requested by the last successful
    // decode/encode; `get_logits_ith` validates indices against this list.
    initialized_logits: Vec<i32>,
    // Whether the context was created with embeddings enabled; checked by the
    // `embeddings_*` accessors before touching the FFI.
    embeddings_enabled: bool,
}
32
33impl Debug for LlamaContext<'_> {
34 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("LlamaContext")
36 .field("context", &self.context)
37 .finish()
38 }
39}
40
41impl<'model> LlamaContext<'model> {
42 #[must_use]
44 pub fn new(
45 llama_model: &'model LlamaModel,
46 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
47 embeddings_enabled: bool,
48 ) -> Self {
49 Self {
50 context: llama_context,
51 model: llama_model,
52 initialized_logits: Vec::new(),
53 embeddings_enabled,
54 }
55 }
56
57 #[must_use]
59 pub fn n_batch(&self) -> u32 {
60 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
61 }
62
63 #[must_use]
65 pub fn n_ubatch(&self) -> u32 {
66 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
67 }
68
69 #[must_use]
71 pub fn n_ctx(&self) -> u32 {
72 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
73 }
74
75 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
85 let result = unsafe {
86 llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
87 };
88
89 match NonZeroI32::new(result) {
90 None => {
91 self.initialized_logits
92 .clone_from(&batch.initialized_logits);
93 Ok(())
94 }
95 Some(error) => Err(DecodeError::from(error)),
96 }
97 }
98
99 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
109 let result = unsafe {
110 llama_cpp_bindings_sys::llama_encode(self.context.as_ptr(), batch.llama_batch)
111 };
112
113 match NonZeroI32::new(result) {
114 None => {
115 self.initialized_logits
116 .clone_from(&batch.initialized_logits);
117 Ok(())
118 }
119 Some(error) => Err(EncodeError::from(error)),
120 }
121 }
122
123 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
140 if !self.embeddings_enabled {
141 return Err(EmbeddingsError::NotEnabled);
142 }
143
144 let n_embd =
145 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
146
147 unsafe {
148 let embedding =
149 llama_cpp_bindings_sys::llama_get_embeddings_seq(self.context.as_ptr(), i);
150
151 if embedding.is_null() {
153 Err(EmbeddingsError::NonePoolType)
154 } else {
155 Ok(slice::from_raw_parts(embedding, n_embd))
156 }
157 }
158 }
159
160 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
177 if !self.embeddings_enabled {
178 return Err(EmbeddingsError::NotEnabled);
179 }
180
181 let n_embd =
182 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
183
184 unsafe {
185 let embedding =
186 llama_cpp_bindings_sys::llama_get_embeddings_ith(self.context.as_ptr(), i);
187 if embedding.is_null() {
189 Err(EmbeddingsError::LogitsNotEnabled)
190 } else {
191 Ok(slice::from_raw_parts(embedding, n_embd))
192 }
193 }
194 }
195
196 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
206 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
207 let token = LlamaToken::new(i);
208 LlamaTokenData::new(token, *logit, 0_f32)
209 })
210 }
211
212 #[must_use]
223 pub fn token_data_array(&self) -> LlamaTokenDataArray {
224 LlamaTokenDataArray::from_iter(self.candidates(), false)
225 }
226
227 #[must_use]
243 pub fn get_logits(&self) -> &[f32] {
244 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
245 assert!(!data.is_null(), "logits data for last token is null");
246 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
247
248 unsafe { slice::from_raw_parts(data, len) }
249 }
250
251 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
257 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
258 let token = LlamaToken::new(i);
259 LlamaTokenData::new(token, *logit, 0_f32)
260 })
261 }
262
263 #[must_use]
274 pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
275 LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
276 }
277
278 #[must_use]
286 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
287 assert!(
288 self.initialized_logits.contains(&i),
289 "logit {i} is not initialized. only {:?} is",
290 self.initialized_logits
291 );
292 assert!(
293 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
294 "n_ctx ({}) must be greater than i ({})",
295 self.n_ctx(),
296 i
297 );
298
299 let data =
300 unsafe { llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), i) };
301 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
302
303 unsafe { slice::from_raw_parts(data, len) }
304 }
305
306 pub fn reset_timings(&mut self) {
308 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
309 }
310
311 pub fn timings(&mut self) -> LlamaTimings {
313 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
314 LlamaTimings { timings }
315 }
316
317 pub fn lora_adapter_set(
323 &self,
324 adapter: &mut LlamaLoraAdapter,
325 scale: f32,
326 ) -> Result<(), LlamaLoraAdapterSetError> {
327 let mut adapters = [adapter.lora_adapter.as_ptr()];
328 let mut scales = [scale];
329 let err_code = unsafe {
330 llama_cpp_bindings_sys::llama_set_adapters_lora(
331 self.context.as_ptr(),
332 adapters.as_mut_ptr(),
333 1,
334 scales.as_mut_ptr(),
335 )
336 };
337 if err_code != 0 {
338 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
339 }
340
341 tracing::debug!("Set lora adapter");
342 Ok(())
343 }
344
345 pub fn lora_adapter_remove(
354 &self,
355 _adapter: &mut LlamaLoraAdapter,
356 ) -> Result<(), LlamaLoraAdapterRemoveError> {
357 let err_code = unsafe {
358 llama_cpp_bindings_sys::llama_set_adapters_lora(
359 self.context.as_ptr(),
360 std::ptr::null_mut(),
361 0,
362 std::ptr::null_mut(),
363 )
364 };
365 if err_code != 0 {
366 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
367 }
368
369 tracing::debug!("Remove lora adapter");
370 Ok(())
371 }
372}
373
impl Drop for LlamaContext<'_> {
    fn drop(&mut self) {
        // SAFETY: `self.context` was a valid `llama_context` pointer at
        // construction and is freed exactly once here — assumes no other
        // owner frees it (the field is `pub`; TODO confirm no caller does).
        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
    }
}