// llama_cpp_bindings/context.rs
use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use std::ptr::NonNull;
use std::slice;

use crate::llama_batch::LlamaBatch;
use crate::model::{LlamaLoraAdapter, LlamaModel};
use crate::timing::LlamaTimings;
use crate::token::LlamaToken;
use crate::token::data::LlamaTokenData;
use crate::token::data_array::LlamaTokenDataArray;
use crate::{
    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
    LlamaLoraAdapterSetError,
};

pub mod kv_cache;
pub mod params;
pub mod session;
/// Safe wrapper around a raw llama.cpp `llama_context`, borrowed against the
/// [`LlamaModel`] it was created from so the model cannot be dropped first.
pub struct LlamaContext<'model> {
    /// Raw pointer to the underlying `llama_context`; owned by this wrapper
    /// and freed exactly once in `Drop`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// The model this context was created from.
    pub model: &'model LlamaModel,
    /// Token positions whose logits were initialized by the most recent
    /// successful `decode`/`encode`; consulted by `get_logits_ith`.
    initialized_logits: Vec<i32>,
    /// Whether the context was created with embeddings enabled; gates the
    /// `embeddings_*` accessors.
    embeddings_enabled: bool,
}
32
33impl Debug for LlamaContext<'_> {
34 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("LlamaContext")
36 .field("context", &self.context)
37 .finish()
38 }
39}
40
41impl<'model> LlamaContext<'model> {
42 #[must_use]
44 pub fn new(
45 llama_model: &'model LlamaModel,
46 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
47 embeddings_enabled: bool,
48 ) -> Self {
49 Self {
50 context: llama_context,
51 model: llama_model,
52 initialized_logits: Vec::new(),
53 embeddings_enabled,
54 }
55 }
56
57 #[must_use]
59 pub fn n_batch(&self) -> u32 {
60 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
61 }
62
63 #[must_use]
65 pub fn n_ubatch(&self) -> u32 {
66 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
67 }
68
69 #[must_use]
71 pub fn n_ctx(&self) -> u32 {
72 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
73 }
74
75 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
85 let result = unsafe {
86 llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
87 };
88
89 match NonZeroI32::new(result) {
90 None => {
91 self.initialized_logits
92 .clone_from(&batch.initialized_logits);
93 Ok(())
94 }
95 Some(error) => Err(DecodeError::from(error)),
96 }
97 }
98
99 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
109 let result = unsafe {
110 llama_cpp_bindings_sys::llama_encode(self.context.as_ptr(), batch.llama_batch)
111 };
112
113 match NonZeroI32::new(result) {
114 None => {
115 self.initialized_logits
116 .clone_from(&batch.initialized_logits);
117 Ok(())
118 }
119 Some(error) => Err(EncodeError::from(error)),
120 }
121 }
122
123 pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
140 if !self.embeddings_enabled {
141 return Err(EmbeddingsError::NotEnabled);
142 }
143
144 let n_embd =
145 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
146
147 unsafe {
148 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
149 self.context.as_ptr(),
150 sequence_index,
151 );
152
153 if embedding.is_null() {
154 Err(EmbeddingsError::NonePoolType)
155 } else {
156 Ok(slice::from_raw_parts(embedding, n_embd))
157 }
158 }
159 }
160
161 pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
178 if !self.embeddings_enabled {
179 return Err(EmbeddingsError::NotEnabled);
180 }
181
182 let n_embd =
183 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
184
185 unsafe {
186 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
187 self.context.as_ptr(),
188 token_index,
189 );
190
191 if embedding.is_null() {
192 Err(EmbeddingsError::LogitsNotEnabled)
193 } else {
194 Ok(slice::from_raw_parts(embedding, n_embd))
195 }
196 }
197 }
198
199 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
209 (0_i32..).zip(self.get_logits()).map(|(token_id, logit)| {
210 let token = LlamaToken::new(token_id);
211 LlamaTokenData::new(token, *logit, 0_f32)
212 })
213 }
214
215 #[must_use]
226 pub fn token_data_array(&self) -> LlamaTokenDataArray {
227 LlamaTokenDataArray::from_iter(self.candidates(), false)
228 }
229
230 #[must_use]
246 pub fn get_logits(&self) -> &[f32] {
247 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
248 assert!(!data.is_null(), "logits data for last token is null");
249 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
250
251 unsafe { slice::from_raw_parts(data, len) }
252 }
253
254 pub fn candidates_ith(&self, token_index: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
260 (0_i32..)
261 .zip(self.get_logits_ith(token_index))
262 .map(|(token_id, logit)| {
263 let token = LlamaToken::new(token_id);
264 LlamaTokenData::new(token, *logit, 0_f32)
265 })
266 }
267
268 #[must_use]
279 pub fn token_data_array_ith(&self, token_index: i32) -> LlamaTokenDataArray {
280 LlamaTokenDataArray::from_iter(self.candidates_ith(token_index), false)
281 }
282
283 #[must_use]
291 pub fn get_logits_ith(&self, token_index: i32) -> &[f32] {
292 assert!(
293 self.initialized_logits.contains(&token_index),
294 "logit {token_index} is not initialized. only {:?} is",
295 self.initialized_logits
296 );
297 assert!(
298 self.n_ctx() > u32::try_from(token_index).expect("token_index does not fit into a u32"),
299 "n_ctx ({}) must be greater than token_index ({})",
300 self.n_ctx(),
301 token_index
302 );
303
304 let data = unsafe {
305 llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
306 };
307 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
308
309 unsafe { slice::from_raw_parts(data, len) }
310 }
311
312 pub fn reset_timings(&mut self) {
314 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
315 }
316
317 pub fn timings(&mut self) -> LlamaTimings {
319 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
320 LlamaTimings { timings }
321 }
322
323 pub fn lora_adapter_set(
329 &self,
330 adapter: &mut LlamaLoraAdapter,
331 scale: f32,
332 ) -> Result<(), LlamaLoraAdapterSetError> {
333 let mut adapters = [adapter.lora_adapter.as_ptr()];
334 let mut scales = [scale];
335 let err_code = unsafe {
336 llama_cpp_bindings_sys::llama_set_adapters_lora(
337 self.context.as_ptr(),
338 adapters.as_mut_ptr(),
339 1,
340 scales.as_mut_ptr(),
341 )
342 };
343 if err_code != 0 {
344 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
345 }
346
347 tracing::debug!("Set lora adapter");
348 Ok(())
349 }
350
351 pub fn lora_adapter_remove(
360 &self,
361 _adapter: &mut LlamaLoraAdapter,
362 ) -> Result<(), LlamaLoraAdapterRemoveError> {
363 let err_code = unsafe {
364 llama_cpp_bindings_sys::llama_set_adapters_lora(
365 self.context.as_ptr(),
366 std::ptr::null_mut(),
367 0,
368 std::ptr::null_mut(),
369 )
370 };
371 if err_code != 0 {
372 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
373 }
374
375 tracing::debug!("Remove lora adapter");
376 Ok(())
377 }
378}
379
impl Drop for LlamaContext<'_> {
    /// Frees the underlying `llama_context` when the wrapper is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.context` was obtained from llama.cpp at construction
        // and is freed exactly once, here; no other owner frees it.
        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
    }
}