// llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19 LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26
/// A safe wrapper around the `llama_context` C++ context.
///
/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
/// context-specific settings like embeddings and logits.
///
/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
/// the context pointer, preventing it from being null. The struct also holds a reference to the model
/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are
/// enabled and which token positions had logits requested in the last decoded/encoded batch.
///
/// The raw context is freed when this wrapper is dropped (see the `Drop` impl).
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    /// Non-null pointer to the raw C++ `llama_context`. Owned by this wrapper
    /// and freed exactly once on `Drop`.
    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
    /// A reference to the context's model.
    pub model: &'a LlamaModel,
    /// Indices of the tokens in the last `decode()`/`encode()` batch for which
    /// logits were requested. `get_logits_ith` panics for indices not in this
    /// list (copied from `LlamaBatch::initialized_logits` on success).
    initialized_logits: Vec<i32>,
    /// Whether embeddings were enabled when this context was created; the
    /// `embeddings_*` accessors return `EmbeddingsError::NotEnabled` otherwise.
    embeddings_enabled: bool,
}
55
56impl Debug for LlamaContext<'_> {
57 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("LlamaContext")
59 .field("context", &self.context)
60 .finish()
61 }
62}
63
64impl<'model> LlamaContext<'model> {
65 /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
66 ///
67 /// This function initializes a new `LlamaContext` object, which is used to interact with the
68 /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
69 /// determines whether embeddings are enabled in the context.
70 ///
71 /// # Parameters
72 ///
73 /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
74 /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
75 /// the context created in previous steps. This context is necessary for interacting with the model.
76 /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
77 ///
78 /// # Returns
79 ///
80 /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
81 /// - The model reference (`llama_model`) is stored in the context.
82 /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
83 /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
84 ///
85 /// # Example
86 /// ```
87 /// let llama_model = LlamaModel::load("path/to/model").unwrap();
88 /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
89 /// let context = LlamaContext::new(&llama_model, context_ptr, true);
90 /// // Now you can use the context
91 /// ```
92 pub(crate) fn new(
93 llama_model: &'model LlamaModel,
94 llama_context: NonNull<llama_cpp_sys_4::llama_context>,
95 embeddings_enabled: bool,
96 ) -> Self {
97 Self {
98 context: llama_context,
99 model: llama_model,
100 initialized_logits: Vec::new(),
101 embeddings_enabled,
102 }
103 }
104
105 /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
106 #[must_use]
107 pub fn n_batch(&self) -> u32 {
108 unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
109 }
110
111 /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
112 #[must_use]
113 pub fn n_ubatch(&self) -> u32 {
114 unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
115 }
116
117 /// Gets the size of the context.
118 #[must_use]
119 pub fn n_ctx(&self) -> u32 {
120 unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
121 }
122
123 /// Decodes the batch.
124 ///
125 /// # Errors
126 ///
127 /// - `DecodeError` if the decoding failed.
128 ///
129 /// # Panics
130 ///
131 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
132 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
133 let result =
134 unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
135
136 match NonZeroI32::new(result) {
137 None => {
138 self.initialized_logits
139 .clone_from(&batch.initialized_logits);
140 Ok(())
141 }
142 Some(error) => Err(DecodeError::from(error)),
143 }
144 }
145
146 /// Encodes the batch.
147 ///
148 /// # Errors
149 ///
150 /// - `EncodeError` if the decoding failed.
151 ///
152 /// # Panics
153 ///
154 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
155 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
156 let result =
157 unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
158
159 match NonZeroI32::new(result) {
160 None => {
161 self.initialized_logits
162 .clone_from(&batch.initialized_logits);
163 Ok(())
164 }
165 Some(error) => Err(EncodeError::from(error)),
166 }
167 }
168
169 /// Return Pooling type for Llama's Context
170 #[must_use]
171 pub fn pooling_type(&self) -> LlamaPoolingType {
172 let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
173
174 LlamaPoolingType::from(pooling_type)
175 }
176
177 /// Get the embeddings for the `i`th sequence in the current context.
178 ///
179 /// # Returns
180 ///
181 /// A slice containing the embeddings for the last decoded batch.
182 /// The size corresponds to the `n_embd` parameter of the context's model.
183 ///
184 /// # Errors
185 ///
186 /// - When the current context was constructed without enabling embeddings.
187 /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
188 /// - If the given sequence index exceeds the max sequence id.
189 ///
190 /// # Panics
191 ///
192 /// * `n_embd` does not fit into a usize
193 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
194 if !self.embeddings_enabled {
195 return Err(EmbeddingsError::NotEnabled);
196 }
197
198 let n_embd =
199 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
200
201 unsafe {
202 let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
203
204 // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
205 if embedding.is_null() {
206 Err(EmbeddingsError::NonePoolType)
207 } else {
208 Ok(slice::from_raw_parts(embedding, n_embd))
209 }
210 }
211 }
212
213 /// Get the embeddings for the `i`th token in the current context.
214 ///
215 /// # Returns
216 ///
217 /// A slice containing the embeddings for the last decoded batch of the given token.
218 /// The size corresponds to the `n_embd` parameter of the context's model.
219 ///
220 /// # Errors
221 ///
222 /// - When the current context was constructed without enabling embeddings.
223 /// - When the given token didn't have logits enabled when it was passed.
224 /// - If the given token index exceeds the max token id.
225 ///
226 /// # Panics
227 ///
228 /// * `n_embd` does not fit into a usize
229 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
230 if !self.embeddings_enabled {
231 return Err(EmbeddingsError::NotEnabled);
232 }
233
234 let n_embd =
235 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
236
237 unsafe {
238 let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
239 // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
240 if embedding.is_null() {
241 Err(EmbeddingsError::LogitsNotEnabled)
242 } else {
243 Ok(slice::from_raw_parts(embedding, n_embd))
244 }
245 }
246 }
247
248 /// Get the logits for the last token in the context.
249 ///
250 /// # Returns
251 /// An iterator over unsorted `LlamaTokenData` containing the
252 /// logits for the last token in the context.
253 ///
254 /// # Panics
255 ///
256 /// - underlying logits data is null
257 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
258 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
259 let token = LlamaToken::new(i);
260 LlamaTokenData::new(token, *logit, 0_f32)
261 })
262 }
263
264 /// Get the token data array for the last token in the context.
265 ///
266 /// This is a convience method that implements:
267 /// ```ignore
268 /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
269 /// ```
270 ///
271 /// # Panics
272 ///
273 /// - underlying logits data is null
274 #[must_use]
275 pub fn token_data_array(&self) -> LlamaTokenDataArray {
276 LlamaTokenDataArray::from_iter(self.candidates(), false)
277 }
278
279 /// Token logits obtained from the last call to `decode()`.
280 /// The logits for which `batch.logits[i] != 0` are stored contiguously
281 /// in the order they have appeared in the batch.
282 /// Rows: number of tokens for which `batch.logits[i] != 0`
283 /// Cols: `n_vocab`
284 ///
285 /// # Returns
286 ///
287 /// A slice containing the logits for the last decoded token.
288 /// The size corresponds to the `n_vocab` parameter of the context's model.
289 ///
290 /// # Panics
291 ///
292 /// - `n_vocab` does not fit into a usize
293 /// - token data returned is null
294 #[must_use]
295 pub fn get_logits(&self) -> &[f32] {
296 let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
297 assert!(!data.is_null(), "logits data for last token is null");
298 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
299
300 unsafe { slice::from_raw_parts(data, len) }
301 }
302
303 /// Get the logits for the ith token in the context.
304 ///
305 /// # Panics
306 ///
307 /// - logit `i` is not initialized.
308 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
309 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
310 let token = LlamaToken::new(i);
311 LlamaTokenData::new(token, *logit, 0_f32)
312 })
313 }
314
315 /// Get the logits for the ith token in the context.
316 ///
317 /// # Panics
318 ///
319 /// - `i` is greater than `n_ctx`
320 /// - `n_vocab` does not fit into a usize
321 /// - logit `i` is not initialized.
322 #[must_use]
323 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
324 assert!(
325 self.initialized_logits.contains(&i),
326 "logit {i} is not initialized. only {:?} is",
327 self.initialized_logits
328 );
329 assert!(
330 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
331 "n_ctx ({}) must be greater than i ({})",
332 self.n_ctx(),
333 i
334 );
335
336 let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
337 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
338
339 unsafe { slice::from_raw_parts(data, len) }
340 }
341
342 /// Reset the timings for the context.
343 pub fn reset_timings(&mut self) {
344 unsafe { llama_cpp_sys_4::ggml_time_init() }
345 }
346
347 /// Returns the timings for the context.
348 pub fn timings(&mut self) -> PerfContextData {
349 let perf_context_data =
350 unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
351 PerfContextData { perf_context_data }
352 }
353
354 /// Sets a lora adapter.
355 ///
356 /// # Errors
357 ///
358 /// See [`LlamaLoraAdapterSetError`] for more information.
359 pub fn lora_adapter_set(
360 &self,
361 adapter: &mut LlamaLoraAdapter,
362 scale: f32,
363 ) -> Result<(), LlamaLoraAdapterSetError> {
364 let err_code = unsafe {
365 // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
366 // which takes a full list of adapters + scales at once (b8249+)
367 let mut adapter_ptr = adapter.lora_adapter.as_ptr();
368 let mut scale_val = scale;
369 llama_cpp_sys_4::llama_set_adapters_lora(
370 self.context.as_ptr(),
371 &raw mut adapter_ptr,
372 1,
373 &raw mut scale_val,
374 )
375 };
376 if err_code != 0 {
377 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
378 }
379
380 tracing::debug!("Set lora adapter");
381 Ok(())
382 }
383
384 /// Remove all lora adapters from the context.
385 ///
386 /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
387 /// `llama_set_adapters_lora` which operates on the full adapter list at once.
388 /// Calling this function clears **all** adapters currently set on the context.
389 ///
390 /// # Errors
391 ///
392 /// See [`LlamaLoraAdapterRemoveError`] for more information.
393 pub fn lora_adapter_remove(
394 &self,
395 _adapter: &mut LlamaLoraAdapter,
396 ) -> Result<(), LlamaLoraAdapterRemoveError> {
397 let err_code = unsafe {
398 llama_cpp_sys_4::llama_set_adapters_lora(
399 self.context.as_ptr(),
400 std::ptr::null_mut(),
401 0,
402 std::ptr::null_mut(),
403 )
404 };
405 if err_code != 0 {
406 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
407 }
408
409 tracing::debug!("Remove lora adapter");
410 Ok(())
411 }
412}
413
impl Drop for LlamaContext<'_> {
    /// Frees the underlying C++ `llama_context`.
    fn drop(&mut self) {
        // SAFETY: `self.context` is the pointer this wrapper took ownership of in
        // `new`; it is non-null and freed exactly once, here.
        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
    }
}