//! Safe wrapper around `llama_context`.

use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use std::ptr::NonNull;
use std::slice;

use crate::llama_batch::LlamaBatch;
use crate::model::{LlamaLoraAdapter, LlamaModel};
use crate::timing::LlamaTimings;
use crate::token::data::LlamaTokenData;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::LlamaToken;
use crate::{
    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
    LlamaLoraAdapterSetError,
};

pub mod kv_cache;
pub mod params;
pub mod session;

/// Safe wrapper around a `llama_cpp_sys_2::llama_context`.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    pub(crate) context: NonNull<llama_cpp_sys_2::llama_context>,
    /// A reference to the model this context was created from.
    pub model: &'a LlamaModel,
    /// Indices of the tokens whose logits were requested in the last decoded batch.
    initialized_logits: Vec<i32>,
    embeddings_enabled: bool,
}

impl Debug for LlamaContext<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaContext")
            .field("context", &self.context)
            .finish()
    }
}

impl<'model> LlamaContext<'model> {
    /// Wraps a raw `llama_context` pointer belonging to `llama_model`.
    pub(crate) fn new(
        llama_model: &'model LlamaModel,
        llama_context: NonNull<llama_cpp_sys_2::llama_context>,
        embeddings_enabled: bool,
    ) -> Self {
        Self {
            context: llama_context,
            model: llama_model,
            initialized_logits: Vec::new(),
            embeddings_enabled,
        }
    }

    /// Gets the maximum number of tokens that can be submitted to [`Self::decode`] in one batch.
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
    }

    /// Gets the maximum number of tokens processed in one micro-batch (the physical batch size).
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
    }

    /// Gets the size of the context window in tokens.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_n_ctx(self.context.as_ptr()) }
    }

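    /// Decodes the batch, making logits available for the token positions that
    /// requested them.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] if `llama_decode` reports a non-zero status.
    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`, so it is not compiled as a doctest);
    /// it assumes `ctx` is this context and `batch` is a [`LlamaBatch`] that was
    /// filled elsewhere with logits requested for its last token:
    ///
    /// ```ignore
    /// ctx.decode(&mut batch)?;
    /// // The logits requested in `batch` can now be read:
    /// let logits = ctx.get_logits();
    /// ```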
    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
        let result =
            unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };

        match NonZeroI32::new(result) {
            None => {
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(DecodeError::from(error)),
        }
    }

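    /// Encodes the batch with the model's encoder (for models that have one,
    /// e.g. encoder-decoder architectures).
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if `llama_encode` reports a non-zero status.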
    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
        let result =
            unsafe { llama_cpp_sys_2::llama_encode(self.context.as_ptr(), batch.llama_batch) };

        match NonZeroI32::new(result) {
            None => {
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(EncodeError::from(error)),
        }
    }

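    /// Gets the pooled embeddings for the `i`th sequence of the last decoded batch.
    ///
    /// Requires the context to have been created with embeddings enabled and a
    /// pooling type other than `None`.
    ///
    /// # Errors
    ///
    /// - [`EmbeddingsError::NotEnabled`] if the context was created without embeddings enabled.
    /// - [`EmbeddingsError::NonePoolType`] if no pooled embedding exists for the sequence.
    ///
    /// # Example
    ///
    /// A sketch (marked `ignore`, so it is not compiled as a doctest); it assumes
    /// `ctx` was created with embeddings enabled and sequence `0` has been decoded:
    ///
    /// ```ignore
    /// let embedding: &[f32] = ctx.embeddings_seq_ith(0)?;
    /// assert_eq!(embedding.len(), usize::try_from(ctx.model.n_embd()).unwrap());
    /// ```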
    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);

            if embedding.is_null() {
                // A null pointer means there is no pooled embedding for this
                // sequence (pooling type `None`).
                Err(EmbeddingsError::NonePoolType)
            } else {
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

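    /// Gets the embeddings for the `i`th token of the last decoded batch.
    ///
    /// # Errors
    ///
    /// - [`EmbeddingsError::NotEnabled`] if the context was created without embeddings enabled.
    /// - [`EmbeddingsError::LogitsNotEnabled`] if logits were not requested for token `i`.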
    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
            if embedding.is_null() {
                // A null pointer means logits (and thus embeddings) were not
                // requested for token `i` when the batch was decoded.
                Err(EmbeddingsError::LogitsNotEnabled)
            } else {
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

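    /// Returns an iterator over the candidate tokens for the last decoded token:
    /// one [`LlamaTokenData`] per vocabulary entry, carrying the raw logit and a
    /// probability initialized to `0.0` (samplers fill the probabilities in).
    ///
    /// # Example
    ///
    /// Greedy argmax over the raw logits (a sketch, marked `ignore` so it is not
    /// compiled as a doctest; assumes a batch has already been decoded):
    ///
    /// ```ignore
    /// let next_token = ctx
    ///     .candidates()
    ///     .max_by(|a, b| a.logit().total_cmp(&b.logit()))
    ///     .map(|data| data.id());
    /// ```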
    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
            let token = LlamaToken::new(i);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

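    /// Collects [`Self::candidates`] into a [`LlamaTokenDataArray`] suitable for sampling.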
    #[must_use]
    pub fn token_data_array(&self) -> LlamaTokenDataArray {
        LlamaTokenDataArray::from_iter(self.candidates(), false)
    }

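    /// Gets the logits for the last decoded token, as a slice of length `n_vocab`.
    ///
    /// # Panics
    ///
    /// - if `llama_get_logits` returns a null pointer (no logits have been computed yet).
    /// - if `n_vocab` does not fit into a `usize`.
    ///
    /// # Example
    ///
    /// A sketch (marked `ignore`, so it is not compiled as a doctest); it assumes
    /// a batch whose last token requested logits has already been decoded:
    ///
    /// ```ignore
    /// let logits = ctx.get_logits();
    /// let (token_id, _) = logits
    ///     .iter()
    ///     .copied()
    ///     .enumerate()
    ///     .max_by(|(_, a), (_, b)| a.total_cmp(b))
    ///     .expect("n_vocab > 0");
    /// ```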
    #[must_use]
    pub fn get_logits(&self) -> &[f32] {
        let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
        assert!(!data.is_null(), "logits data for last token is null");
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        unsafe { slice::from_raw_parts(data, len) }
    }

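    /// Returns an iterator over the candidate tokens for the `i`th token of the
    /// last decoded batch; the per-token analogue of [`Self::candidates`].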
    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
        (0_i32..).zip(self.get_logits_ith(i)).map(|(id, logit)| {
            let token = LlamaToken::new(id);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

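    /// Collects [`Self::candidates_ith`] into a [`LlamaTokenDataArray`] suitable for sampling.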
    #[must_use]
    pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
        LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
    }

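    /// Gets the logits for the `i`th token of the last decoded batch, as a slice
    /// of length `n_vocab`.
    ///
    /// # Panics
    ///
    /// - if logits were not requested for token `i` when the batch was decoded.
    /// - if `i` does not fit in the context window, or `n_vocab` does not fit into a `usize`.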
    #[must_use]
    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
        assert!(
            self.initialized_logits.contains(&i),
            "logit {i} is not initialized. only {:?} are",
            self.initialized_logits
        );
        assert!(
            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
            "n_ctx ({}) must be greater than i ({})",
            self.n_ctx(),
            i
        );

        let data = unsafe { llama_cpp_sys_2::llama_get_logits_ith(self.context.as_ptr(), i) };
        assert!(!data.is_null(), "logits data for token {i} is null");
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        unsafe { slice::from_raw_parts(data, len) }
    }

    /// Resets the performance timings for this context.
    pub fn reset_timings(&mut self) {
        unsafe { llama_cpp_sys_2::llama_perf_context_reset(self.context.as_ptr()) }
    }

    /// Returns the performance timings for this context.
    pub fn timings(&mut self) -> LlamaTimings {
        let timings = unsafe { llama_cpp_sys_2::llama_perf_context(self.context.as_ptr()) };
        LlamaTimings { timings }
    }

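    /// Attaches a LoRA adapter to this context with the given scale.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaLoraAdapterSetError::ErrorResult`] if `llama_set_adapter_lora`
    /// reports a non-zero status.
    ///
    /// # Example
    ///
    /// A sketch (marked `ignore`, so it is not compiled as a doctest); it assumes
    /// `adapter` is a [`LlamaLoraAdapter`] loaded for the same model:
    ///
    /// ```ignore
    /// ctx.lora_adapter_set(&mut adapter, 1.0)?;
    /// // ... generate with the adapter applied ...
    /// ctx.lora_adapter_remove(&mut adapter)?;
    /// ```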
    pub fn lora_adapter_set(
        &self,
        adapter: &mut LlamaLoraAdapter,
        scale: f32,
    ) -> Result<(), LlamaLoraAdapterSetError> {
        let err_code = unsafe {
            llama_cpp_sys_2::llama_set_adapter_lora(
                self.context.as_ptr(),
                adapter.lora_adapter.as_ptr(),
                scale,
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
        }

        tracing::debug!("Set lora adapter");
        Ok(())
    }

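    /// Removes a previously attached LoRA adapter from this context.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaLoraAdapterRemoveError::ErrorResult`] if `llama_rm_adapter_lora`
    /// reports a non-zero status (e.g. the adapter was never attached).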
    pub fn lora_adapter_remove(
        &self,
        adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), LlamaLoraAdapterRemoveError> {
        let err_code = unsafe {
            llama_cpp_sys_2::llama_rm_adapter_lora(
                self.context.as_ptr(),
                adapter.lora_adapter.as_ptr(),
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
        }

        tracing::debug!("Remove lora adapter");
        Ok(())
    }
}

impl Drop for LlamaContext<'_> {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free(self.context.as_ptr()) }
    }
}