llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2//!
3//! Submodules:
4//!
5//! - [`tensor_capture`] — hook `cb_eval` during [`LlamaContext::decode`] to copy
6//! intermediate tensors (per-layer hidden states, norms, …).
7//! - [`memory_breakdown`] — per-buffer memory usage after load/decode.
8//! - [`kv_cache`] — sequence copy, shift, and clear helpers.
9
10use std::fmt::{Debug, Formatter};
11use std::num::NonZeroI32;
12use std::ptr::NonNull;
13use std::slice;
14
15use llama_cpp_sys_4::llama_pooling_type;
16use params::LlamaPoolingType;
17use perf::PerfContextData;
18
19use crate::llama_batch::LlamaBatch;
20use crate::model::{LlamaLoraAdapter, LlamaModel};
21use crate::token::data::LlamaTokenData;
22use crate::token::data_array::LlamaTokenDataArray;
23use crate::token::LlamaToken;
24use crate::{
25 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
26 LlamaLoraAdapterSetError,
27};
28
29pub mod kv_cache;
30pub mod memory_breakdown;
31pub mod params;
32pub mod perf;
33pub mod session;
34pub mod tensor_capture;
35
36pub use memory_breakdown::MemoryBreakdownEntry;
37pub use tensor_capture::{CapturedTensor, TensorCapture};
38
39/// A safe wrapper around the `llama_context` C++ context.
40///
41/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
42/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
43/// context-specific settings like embeddings and logits.
44///
45/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
46/// the context pointer, preventing it from being null. The struct also holds a reference to the model
47/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are enabled
48/// and the initialized logits for the context.
49///
50/// # Fields
51///
52/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
53/// - `model`: A reference to the `LlamaModel` associated with this context. This model provides the data and parameters
54/// that the context interacts with.
55/// - `initialized_logits`: A vector used to store the initialized logits. These are used in the model's processing and
56/// are kept separate from the context data.
57/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This is useful for
58/// controlling whether embedding data is generated during the interaction with the model.
59#[allow(clippy::module_name_repetitions)]
60pub struct LlamaContext<'a> {
61 pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
62 /// a reference to the contexts model.
63 pub model: &'a LlamaModel,
64 initialized_logits: Vec<i32>,
65 embeddings_enabled: bool,
66}
67
68impl Debug for LlamaContext<'_> {
69 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("LlamaContext")
71 .field("context", &self.context)
72 .finish()
73 }
74}
75
76impl<'model> LlamaContext<'model> {
77 /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
78 ///
79 /// This function initializes a new `LlamaContext` object, which is used to interact with the
80 /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
81 /// determines whether embeddings are enabled in the context.
82 ///
83 /// # Parameters
84 ///
85 /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
86 /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
87 /// the context created in previous steps. This context is necessary for interacting with the model.
88 /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
89 ///
90 /// # Returns
91 ///
92 /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
93 /// - The model reference (`llama_model`) is stored in the context.
94 /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
95 /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
96 ///
97 /// # Example
98 /// ```ignore
99 /// let llama_model = LlamaModel::load_from_file(&backend, "path/to/model", ¶ms).unwrap();
100 /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
101 /// let context = LlamaContext::new(&llama_model, context_ptr, true);
102 /// // Now you can use the context
103 /// ```
104 pub(crate) fn new(
105 llama_model: &'model LlamaModel,
106 llama_context: NonNull<llama_cpp_sys_4::llama_context>,
107 embeddings_enabled: bool,
108 ) -> Self {
109 Self {
110 context: llama_context,
111 model: llama_model,
112 initialized_logits: Vec::new(),
113 embeddings_enabled,
114 }
115 }
116
117 /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
118 #[must_use]
119 pub fn n_batch(&self) -> u32 {
120 unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
121 }
122
123 /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
124 #[must_use]
125 pub fn n_ubatch(&self) -> u32 {
126 unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
127 }
128
129 /// Gets the size of the context.
130 #[must_use]
131 pub fn n_ctx(&self) -> u32 {
132 unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
133 }
134
135 /// Decodes the batch.
136 ///
137 /// # Errors
138 ///
139 /// - `DecodeError` if the decoding failed.
140 ///
141 /// # Panics
142 ///
143 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
144 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
145 let result =
146 unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
147
148 match NonZeroI32::new(result) {
149 None => {
150 self.initialized_logits
151 .clone_from(&batch.initialized_logits);
152 Ok(())
153 }
154 Some(error) => Err(DecodeError::from(error)),
155 }
156 }
157
158 /// Encodes the batch.
159 ///
160 /// # Errors
161 ///
162 /// - `EncodeError` if the decoding failed.
163 ///
164 /// # Panics
165 ///
166 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
167 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
168 let result =
169 unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
170
171 match NonZeroI32::new(result) {
172 None => {
173 self.initialized_logits
174 .clone_from(&batch.initialized_logits);
175 Ok(())
176 }
177 Some(error) => Err(EncodeError::from(error)),
178 }
179 }
180
181 /// Return Pooling type for Llama's Context
182 #[must_use]
183 pub fn pooling_type(&self) -> LlamaPoolingType {
184 let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
185
186 LlamaPoolingType::from(pooling_type)
187 }
188
189 /// Get the embeddings for the `i`th sequence in the current context.
190 ///
191 /// # Returns
192 ///
193 /// A slice containing the embeddings for the last decoded batch.
194 /// The size corresponds to the `n_embd` parameter of the context's model.
195 ///
196 /// # Errors
197 ///
198 /// - When the current context was constructed without enabling embeddings.
199 /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
200 /// - If the given sequence index exceeds the max sequence id.
201 ///
202 /// # Panics
203 ///
204 /// * `n_embd` does not fit into a usize
205 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
206 if !self.embeddings_enabled {
207 return Err(EmbeddingsError::NotEnabled);
208 }
209
210 let n_embd =
211 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
212
213 unsafe {
214 let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
215
216 // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
217 if embedding.is_null() {
218 Err(EmbeddingsError::NonePoolType)
219 } else {
220 Ok(slice::from_raw_parts(embedding, n_embd))
221 }
222 }
223 }
224
225 /// Get the embeddings for the `i`th token in the current context.
226 ///
227 /// # Returns
228 ///
229 /// A slice containing the embeddings for the last decoded batch of the given token.
230 /// The size corresponds to the `n_embd` parameter of the context's model.
231 ///
232 /// # Errors
233 ///
234 /// - When the current context was constructed without enabling embeddings.
235 /// - When the given token didn't have logits enabled when it was passed.
236 /// - If the given token index exceeds the max token id.
237 ///
238 /// # Panics
239 ///
240 /// * `n_embd` does not fit into a usize
241 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
242 if !self.embeddings_enabled {
243 return Err(EmbeddingsError::NotEnabled);
244 }
245
246 let n_embd =
247 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
248
249 unsafe {
250 let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
251 // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
252 if embedding.is_null() {
253 Err(EmbeddingsError::LogitsNotEnabled)
254 } else {
255 Ok(slice::from_raw_parts(embedding, n_embd))
256 }
257 }
258 }
259
260 /// Get the logits for the last token in the context.
261 ///
262 /// # Returns
263 /// An iterator over unsorted `LlamaTokenData` containing the
264 /// logits for the last token in the context.
265 ///
266 /// # Panics
267 ///
268 /// - underlying logits data is null
269 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
270 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
271 let token = LlamaToken::new(i);
272 LlamaTokenData::new(token, *logit, 0_f32)
273 })
274 }
275
276 /// Get the token data array for the last token in the context.
277 ///
278 /// This is a convience method that implements:
279 /// ```ignore
280 /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
281 /// ```
282 ///
283 /// # Panics
284 ///
285 /// - underlying logits data is null
286 #[must_use]
287 pub fn token_data_array(&self) -> LlamaTokenDataArray {
288 LlamaTokenDataArray::from_iter(self.candidates(), false)
289 }
290
291 /// Token logits obtained from the last call to `decode()`.
292 /// The logits for which `batch.logits[i] != 0` are stored contiguously
293 /// in the order they have appeared in the batch.
294 /// Rows: number of tokens for which `batch.logits[i] != 0`
295 /// Cols: `n_vocab`
296 ///
297 /// # Returns
298 ///
299 /// A slice containing the logits for the last decoded token.
300 /// The size corresponds to the `n_vocab` parameter of the context's model.
301 ///
302 /// # Panics
303 ///
304 /// - `n_vocab` does not fit into a usize
305 /// - token data returned is null
306 #[must_use]
307 pub fn get_logits(&self) -> &[f32] {
308 let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
309 assert!(!data.is_null(), "logits data for last token is null");
310 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
311
312 unsafe { slice::from_raw_parts(data, len) }
313 }
314
315 /// Get the logits for the ith token in the context.
316 ///
317 /// # Panics
318 ///
319 /// - logit `i` is not initialized.
320 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
321 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
322 let token = LlamaToken::new(i);
323 LlamaTokenData::new(token, *logit, 0_f32)
324 })
325 }
326
327 /// Get the logits for the ith token in the context.
328 ///
329 /// # Panics
330 ///
331 /// - `i` is greater than `n_ctx`
332 /// - `n_vocab` does not fit into a usize
333 /// - logit `i` is not initialized.
334 #[must_use]
335 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
336 assert!(
337 self.initialized_logits.contains(&i),
338 "logit {i} is not initialized. only {:?} is",
339 self.initialized_logits
340 );
341 assert!(
342 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
343 "n_ctx ({}) must be greater than i ({})",
344 self.n_ctx(),
345 i
346 );
347
348 let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
349 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
350
351 unsafe { slice::from_raw_parts(data, len) }
352 }
353
354 /// Get the number of context tokens per sequence.
355 #[must_use]
356 pub fn n_ctx_seq(&self) -> u32 {
357 unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
358 }
359
360 /// Get the maximum number of sequences.
361 #[must_use]
362 pub fn n_seq_max(&self) -> u32 {
363 unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
364 }
365
366 /// Get the number of recurrent-state snapshots per sequence.
367 #[must_use]
368 pub fn n_rs_seq(&self) -> u32 {
369 unsafe { llama_cpp_sys_4::llama_n_rs_seq(self.context.as_ptr()) }
370 }
371
372 /// Get the number of threads used for generation.
373 #[must_use]
374 pub fn n_threads(&self) -> i32 {
375 unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
376 }
377
378 /// Get the number of threads used for batch processing.
379 #[must_use]
380 pub fn n_threads_batch(&self) -> i32 {
381 unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
382 }
383
384 /// Set the number of threads used for generation and batch processing.
385 pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
386 unsafe {
387 llama_cpp_sys_4::llama_set_n_threads(self.context.as_ptr(), n_threads, n_threads_batch);
388 }
389 }
390
391 /// Set whether to use causal attention.
392 ///
393 /// If set to `false`, the model will use non-causal attention, which is
394 /// needed for embedding models.
395 pub fn set_causal_attn(&mut self, causal_attn: bool) {
396 unsafe {
397 llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
398 }
399 }
400
401 /// Set whether to compute embeddings.
402 ///
403 /// This allows toggling embedding mode at runtime (as opposed to only at
404 /// context creation time).
405 pub fn set_embeddings(&mut self, embeddings: bool) {
406 self.embeddings_enabled = embeddings;
407 unsafe {
408 llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
409 }
410 }
411
412 /// Mark the next computation as a warmup run.
413 ///
414 /// Warmup runs are useful for GPU backends to compile kernels before
415 /// actual inference begins.
416 pub fn set_warmup(&mut self, warmup: bool) {
417 unsafe {
418 llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
419 }
420 }
421
422 /// Wait for all pending async computations to finish.
423 pub fn synchronize(&mut self) {
424 unsafe {
425 llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
426 }
427 }
428
429 /// Get all embeddings for the current context.
430 ///
431 /// Returns a slice of all embeddings from the last decoded batch.
432 /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
433 ///
434 /// # Errors
435 ///
436 /// - When the current context was constructed without enabling embeddings.
437 /// - If the embeddings pointer is null.
438 ///
439 /// # Panics
440 ///
441 /// * `n_embd` does not fit into a usize
442 pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
443 if !self.embeddings_enabled {
444 return Err(EmbeddingsError::NotEnabled);
445 }
446
447 let n_embd =
448 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
449
450 unsafe {
451 let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
452 if embedding.is_null() {
453 Err(EmbeddingsError::NonePoolType)
454 } else {
455 Ok(slice::from_raw_parts(embedding, n_embd))
456 }
457 }
458 }
459
460 /// Toggle extraction of next-n embeddings (Rust name: pre-norm) — hidden
461 /// states used by MTP draft heads. Upstream C API: `llama_set_embeddings_nextn`
462 /// (llama.cpp PR #23198 and later renames).
463 ///
464 /// If `masked` is `true`, pre-norm rows are extracted only for tokens
465 /// whose `batch.logits[i]` is non-zero. If `masked` is `false`, rows are
466 /// extracted for every token in the batch regardless of `batch.logits` —
467 /// callers can then leave `batch.logits[i] = false` on prompt-fill
468 /// positions and avoid copying the full logits row for each one.
469 ///
470 /// Upstream's MTP session init configures pre-norm extraction on the target
471 /// and draft contexts automatically. Call this manually only for custom
472 /// speculative setups.
473 pub fn set_embeddings_pre_norm(&mut self, value: bool, masked: bool) {
474 unsafe {
475 llama_cpp_sys_4::llama_set_embeddings_nextn(self.context.as_ptr(), value, masked);
476 }
477 }
478
479 /// Get the full pre-norm embeddings buffer for the last decoded batch.
480 ///
481 /// Returns `None` when pre-norm embeddings are disabled or the buffer
482 /// hasn't been populated. The length of the returned slice is
483 /// `n_embd * <number of pre-norm rows>` — interpretation of the row
484 /// count depends on whether the setter was called with `masked=true`
485 /// (one row per sampled token) or `masked=false` (one row per batch
486 /// token). Use [`get_embeddings_pre_norm_ith`](Self::get_embeddings_pre_norm_ith)
487 /// when you only need a single row.
488 ///
489 /// # Panics
490 ///
491 /// Panics if `n_embd` does not fit in `usize`.
492 #[must_use]
493 pub fn get_embeddings_pre_norm(&self) -> Option<&[f32]> {
494 let n_embd =
495 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
496 unsafe {
497 let p = llama_cpp_sys_4::llama_get_embeddings_nextn(self.context.as_ptr());
498 if p.is_null() {
499 None
500 } else {
501 Some(slice::from_raw_parts(p, n_embd))
502 }
503 }
504 }
505
506 /// Get the pre-norm embedding row for the `i`th output position of the
507 /// last decoded batch. Returns `None` if upstream rejects the index
508 /// (e.g. masked mode with `batch.logits[i] == 0`, or out of range).
509 ///
510 /// # Panics
511 ///
512 /// Panics if `n_embd` does not fit in `usize`.
513 #[must_use]
514 pub fn get_embeddings_pre_norm_ith(&self, i: i32) -> Option<&[f32]> {
515 let n_embd =
516 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
517 unsafe {
518 let p = llama_cpp_sys_4::llama_get_embeddings_nextn_ith(self.context.as_ptr(), i);
519 if p.is_null() {
520 None
521 } else {
522 Some(slice::from_raw_parts(p, n_embd))
523 }
524 }
525 }
526
527 /// Select which `NextN` block the MTP draft graph runs.
528 ///
529 /// `offset` indexes past the trunk transformer layers (`0` = first `NextN`
530 /// head). Required for multi-head MTP models such as Step3.5; restore to
531 /// `0` after drafting. See [`crate::mtp`] for the full speculative loop.
532 ///
533 /// # Examples
534 ///
535 /// ```ignore
536 /// for head in 0..model.n_layer_nextn() {
537 /// draft.set_nextn_layer_offset(head);
538 /// let drafts = session.draft(0, n_past, last_token)?;
539 /// }
540 /// draft.set_nextn_layer_offset(0);
541 /// ```
542 pub fn set_nextn_layer_offset(&mut self, offset: i32) {
543 unsafe {
544 llama_cpp_sys_4::llama_set_nextn_layer_offset(self.context.as_ptr(), offset);
545 }
546 }
547
548 /// Return the paired context set via
549 /// [`crate::context::params::LlamaContextParams::with_ctx_other`].
550 ///
551 /// The pointer refers to the other live context created during
552 /// [`crate::model::LlamaModel::new_context`]; it is `None` when no pairing
553 /// was configured.
554 #[must_use]
555 pub fn ctx_other(&self) -> Option<NonNull<llama_cpp_sys_4::llama_context>> {
556 NonNull::new(unsafe { llama_cpp_sys_4::llama_get_ctx_other(self.context.as_ptr()) })
557 }
558
559 /// Reset the timings for the context.
560 pub fn reset_timings(&mut self) {
561 unsafe { llama_cpp_sys_4::ggml_time_init() }
562 }
563
564 /// Returns the timings for the context.
565 pub fn timings(&mut self) -> PerfContextData {
566 let perf_context_data =
567 unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
568 PerfContextData { perf_context_data }
569 }
570
571 /// Reset the performance counters for the context.
572 pub fn perf_context_reset(&mut self) {
573 unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
574 }
575
576 /// Check if the KV cache memory supports shifting.
577 #[must_use]
578 pub fn memory_can_shift(&self) -> bool {
579 unsafe {
580 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
581 llama_cpp_sys_4::llama_memory_can_shift(mem)
582 }
583 }
584
585 /// Get the minimum position in a sequence's KV cache.
586 #[must_use]
587 pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
588 unsafe {
589 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
590 llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
591 }
592 }
593
594 /// Print a human-readable memory breakdown to stderr via llama.cpp.
595 ///
596 /// For structured access use [`Self::memory_breakdown`].
597 pub fn memory_breakdown_print(&self) {
598 unsafe {
599 llama_cpp_sys_4::common_memory_breakdown_print(self.context.as_ptr());
600 }
601 }
602
603 /// Return structured per-buffer memory usage for this context.
604 ///
605 /// Each [`memory_breakdown::MemoryBreakdownEntry`] reports model weights,
606 /// KV / recurrent cache, and compute scratch bytes for one ggml buffer
607 /// type. Returns an empty vector when no buffers are registered.
608 ///
609 /// # Examples
610 ///
611 /// ```no_run
612 /// use llama_cpp_4::prelude::*;
613 ///
614 /// fn main() {
615 /// let backend = LlamaBackend::init().unwrap();
616 /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default()).unwrap();
617 /// let ctx = model.new_context(&backend, LlamaContextParams::default()).unwrap();
618 /// let total: usize = ctx.memory_breakdown().iter().map(|e| e.total()).sum();
619 /// println!("context uses {total} bytes across all buffer types");
620 /// }
621 /// ```
622 #[must_use]
623 pub fn memory_breakdown(&self) -> Vec<memory_breakdown::MemoryBreakdownEntry> {
624 memory_breakdown::collect_memory_breakdown(self.context.as_ptr())
625 }
626
627 /// Enable or disable extraction of input embeddings for a transformer layer.
628 ///
629 /// Maps to `llama_set_embeddings_layer_inp`. After a successful
630 /// [`Self::decode`], read the vector with [`Self::get_embeddings_layer_inp`].
631 pub fn set_embeddings_layer_inp(&mut self, layer_id: u32, value: bool) {
632 unsafe {
633 llama_cpp_sys_4::llama_set_embeddings_layer_inp(self.context.as_ptr(), layer_id, value);
634 }
635 }
636
637 /// Get input embeddings for `layer_id` from the last decoded batch.
638 ///
639 /// Returns `None` when the layer was not enabled via
640 /// [`Self::set_embeddings_layer_inp`] or when upstream has no data for
641 /// `layer_id`. The slice length is [`LlamaModel::n_embd`].
642 ///
643 /// # Panics
644 ///
645 /// Panics if `n_embd` does not fit in `usize`.
646 #[must_use]
647 pub fn get_embeddings_layer_inp(&self, layer_id: u32) -> Option<&[f32]> {
648 let n_embd =
649 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
650 unsafe {
651 let p =
652 llama_cpp_sys_4::llama_get_embeddings_layer_inp(self.context.as_ptr(), layer_id);
653 if p.is_null() {
654 None
655 } else {
656 Some(slice::from_raw_parts(p, n_embd))
657 }
658 }
659 }
660
661 /// Get the size of the full context state in bytes.
662 ///
663 /// This is the size needed for [`state_get_data`](Self::state_get_data) and
664 /// [`state_set_data`](Self::state_set_data).
665 #[must_use]
666 pub fn state_get_size(&mut self) -> usize {
667 unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
668 }
669
670 /// Copy the full context state into a byte buffer.
671 ///
672 /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
673 ///
674 /// Returns the number of bytes written.
675 pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
676 unsafe {
677 llama_cpp_sys_4::llama_state_get_data(
678 self.context.as_ptr(),
679 dst.as_mut_ptr(),
680 dst.len(),
681 )
682 }
683 }
684
685 /// Restore the full context state from a byte buffer.
686 ///
687 /// Returns the number of bytes read.
688 pub fn state_set_data(&mut self, src: &[u8]) -> usize {
689 unsafe {
690 llama_cpp_sys_4::llama_state_set_data(self.context.as_ptr(), src.as_ptr(), src.len())
691 }
692 }
693
694 /// Save the context state to a file along with the given tokens.
695 ///
696 /// Returns `true` on success.
697 ///
698 /// # Panics
699 ///
700 /// Panics if the path contains null bytes.
701 pub fn state_save_file(
702 &mut self,
703 path: impl AsRef<std::path::Path>,
704 tokens: &[LlamaToken],
705 ) -> bool {
706 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
707 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
708 unsafe {
709 llama_cpp_sys_4::llama_state_save_file(
710 self.context.as_ptr(),
711 c_path.as_ptr(),
712 tokens.as_ptr().cast(),
713 tokens.len(),
714 )
715 }
716 }
717
718 /// Load a context state from a file.
719 ///
720 /// Returns `true` on success and fills `tokens_out` with the saved tokens.
721 ///
722 /// # Panics
723 ///
724 /// Panics if the path contains null bytes.
725 pub fn state_load_file(
726 &mut self,
727 path: impl AsRef<std::path::Path>,
728 tokens_out: &mut Vec<LlamaToken>,
729 n_token_capacity: usize,
730 ) -> bool {
731 tokens_out.resize(n_token_capacity, LlamaToken(0));
732 let mut n_token_count: usize = 0;
733 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
734 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
735 let ok = unsafe {
736 llama_cpp_sys_4::llama_state_load_file(
737 self.context.as_ptr(),
738 c_path.as_ptr(),
739 tokens_out.as_mut_ptr().cast(),
740 n_token_capacity,
741 std::ptr::addr_of_mut!(n_token_count),
742 )
743 };
744 if ok {
745 tokens_out.truncate(n_token_count);
746 }
747 ok
748 }
749
750 /// Get the size of a single sequence's state in bytes.
751 #[must_use]
752 pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
753 unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
754 }
755
756 /// Copy a single sequence's state into a byte buffer.
757 ///
758 /// Returns the number of bytes written.
759 pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
760 unsafe {
761 llama_cpp_sys_4::llama_state_seq_get_data(
762 self.context.as_ptr(),
763 dst.as_mut_ptr(),
764 dst.len(),
765 seq_id,
766 )
767 }
768 }
769
770 /// Restore a single sequence's state from a byte buffer.
771 ///
772 /// Returns the number of bytes read.
773 pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
774 unsafe {
775 llama_cpp_sys_4::llama_state_seq_set_data(
776 self.context.as_ptr(),
777 src.as_ptr(),
778 src.len(),
779 dest_seq_id,
780 )
781 }
782 }
783
784 /// Save a single sequence's state to a file.
785 ///
786 /// Returns the number of bytes written (0 on failure).
787 ///
788 /// # Panics
789 ///
790 /// Panics if the path contains null bytes.
791 pub fn state_seq_save_file(
792 &mut self,
793 path: impl AsRef<std::path::Path>,
794 seq_id: i32,
795 tokens: &[LlamaToken],
796 ) -> usize {
797 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
798 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
799 unsafe {
800 llama_cpp_sys_4::llama_state_seq_save_file(
801 self.context.as_ptr(),
802 c_path.as_ptr(),
803 seq_id,
804 tokens.as_ptr().cast(),
805 tokens.len(),
806 )
807 }
808 }
809
810 /// Load a single sequence's state from a file.
811 ///
812 /// Returns the number of bytes read (0 on failure).
813 ///
814 /// # Panics
815 ///
816 /// Panics if the path contains null bytes.
817 pub fn state_seq_load_file(
818 &mut self,
819 path: impl AsRef<std::path::Path>,
820 dest_seq_id: i32,
821 tokens_out: &mut Vec<LlamaToken>,
822 n_token_capacity: usize,
823 ) -> usize {
824 tokens_out.resize(n_token_capacity, LlamaToken(0));
825 let mut n_token_count: usize = 0;
826 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
827 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
828 let ret = unsafe {
829 llama_cpp_sys_4::llama_state_seq_load_file(
830 self.context.as_ptr(),
831 c_path.as_ptr(),
832 dest_seq_id,
833 tokens_out.as_mut_ptr().cast(),
834 n_token_capacity,
835 std::ptr::addr_of_mut!(n_token_count),
836 )
837 };
838 if ret > 0 {
839 tokens_out.truncate(n_token_count);
840 }
841 ret
842 }
843
844 /// Set a control vector on the context.
845 ///
846 /// # Parameters
847 ///
848 /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
849 /// - `n_embd`: The embedding dimension.
850 /// - `il_start`: The starting layer index (inclusive).
851 /// - `il_end`: The ending layer index (exclusive).
852 ///
853 /// # Errors
854 ///
855 /// Returns `Err` with the error code if the operation fails.
856 pub fn set_adapter_cvec(
857 &mut self,
858 data: &[f32],
859 n_embd: i32,
860 il_start: i32,
861 il_end: i32,
862 ) -> Result<(), i32> {
863 let ret = unsafe {
864 llama_cpp_sys_4::llama_set_adapter_cvec(
865 self.context.as_ptr(),
866 data.as_ptr(),
867 data.len(),
868 n_embd,
869 il_start,
870 il_end,
871 )
872 };
873 if ret != 0 {
874 Err(ret)
875 } else {
876 Ok(())
877 }
878 }
879
880 /// Get sampled token debug info for the `i`th position.
881 ///
882 /// Returns the sampled token at position `i` from the last decode call.
883 #[must_use]
884 pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
885 let token =
886 unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
887 LlamaToken(token)
888 }
889
890 /// Get sampled candidate tokens for the `i`th position.
891 ///
892 /// Returns a slice of candidate tokens from the last decode call.
893 #[must_use]
894 pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
895 let count = unsafe {
896 llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
897 } as usize;
898 if count == 0 {
899 return &[];
900 }
901 let ptr =
902 unsafe { llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i) };
903 if ptr.is_null() {
904 return &[];
905 }
906 unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
907 }
908
909 /// Get the number of sampled logits for the `i`th position.
910 #[must_use]
911 pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
912 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i) }
913 }
914
915 /// Get sampled logits for the `i`th position.
916 ///
917 /// Returns a slice of logit values from the last decode call.
918 #[must_use]
919 pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
920 let count = self.get_sampled_logits_count_ith(i) as usize;
921 if count == 0 {
922 return &[];
923 }
924 let ptr =
925 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i) };
926 if ptr.is_null() {
927 return &[];
928 }
929 unsafe { slice::from_raw_parts(ptr, count) }
930 }
931
932 /// Get the number of sampled probabilities for the `i`th position.
933 #[must_use]
934 pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
935 unsafe { llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i) }
936 }
937
938 /// Get sampled probabilities for the `i`th position.
939 ///
940 /// Returns a slice of probability values from the last decode call.
941 #[must_use]
942 pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
943 let count = self.get_sampled_probs_count_ith(i) as usize;
944 if count == 0 {
945 return &[];
946 }
947 let ptr = unsafe { llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i) };
948 if ptr.is_null() {
949 return &[];
950 }
951 unsafe { slice::from_raw_parts(ptr, count) }
952 }
953
954 /// Get the size of a single sequence's state with flags.
955 #[must_use]
956 pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
957 unsafe {
958 llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
959 }
960 }
961
962 /// Copy a single sequence's state into a byte buffer with flags.
963 ///
964 /// Returns the number of bytes written.
965 pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
966 unsafe {
967 llama_cpp_sys_4::llama_state_seq_get_data_ext(
968 self.context.as_ptr(),
969 dst.as_mut_ptr(),
970 dst.len(),
971 seq_id,
972 flags,
973 )
974 }
975 }
976
977 /// Restore a single sequence's state from a byte buffer with flags.
978 ///
979 /// Returns the number of bytes read.
980 pub fn state_seq_set_data_ext(&mut self, src: &[u8], dest_seq_id: i32, flags: u32) -> usize {
981 unsafe {
982 llama_cpp_sys_4::llama_state_seq_set_data_ext(
983 self.context.as_ptr(),
984 src.as_ptr(),
985 src.len(),
986 dest_seq_id,
987 flags,
988 )
989 }
990 }
991
992 /// Set an abort callback for the context.
993 ///
994 /// The callback is called periodically during computation. If it returns `true`,
995 /// the computation is aborted.
996 ///
997 /// # Safety
998 ///
999 /// The callback data must remain valid for the lifetime of the context or until
1000 /// the callback is replaced.
1001 pub unsafe fn set_abort_callback(
1002 &mut self,
1003 callback: llama_cpp_sys_4::ggml_abort_callback,
1004 data: *mut std::ffi::c_void,
1005 ) {
1006 llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
1007 }
1008
1009 /// Attach a thread pool to the context.
1010 ///
1011 /// # Safety
1012 ///
1013 /// The thread pools must remain valid for the lifetime of the context or until
1014 /// they are detached.
1015 pub unsafe fn attach_threadpool(
1016 &mut self,
1017 threadpool: llama_cpp_sys_4::ggml_threadpool_t,
1018 threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
1019 ) {
1020 llama_cpp_sys_4::llama_attach_threadpool(
1021 self.context.as_ptr(),
1022 threadpool,
1023 threadpool_batch,
1024 );
1025 }
1026
1027 /// Detach the thread pool from the context.
1028 pub fn detach_threadpool(&mut self) {
1029 unsafe {
1030 llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
1031 }
1032 }
1033
1034 /// Set a sampler for a specific sequence.
1035 ///
1036 /// Returns `true` on success.
1037 pub fn set_sampler(
1038 &mut self,
1039 seq_id: i32,
1040 sampler: &mut crate::sampling::LlamaSampler,
1041 ) -> bool {
1042 unsafe {
1043 llama_cpp_sys_4::llama_set_sampler(
1044 self.context.as_ptr(),
1045 seq_id,
1046 sampler.sampler.as_ptr(),
1047 )
1048 }
1049 }
1050
1051 /// Get the raw model pointer from this context.
1052 ///
1053 /// This is mainly useful for FFI interop. In normal usage, access
1054 /// the model via the `model` field instead.
1055 #[must_use]
1056 pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
1057 unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
1058 }
1059
1060 /// Sets a lora adapter.
1061 ///
1062 /// # Errors
1063 ///
1064 /// See [`LlamaLoraAdapterSetError`] for more information.
1065 pub fn lora_adapter_set(
1066 &self,
1067 adapter: &mut LlamaLoraAdapter,
1068 scale: f32,
1069 ) -> Result<(), LlamaLoraAdapterSetError> {
1070 let err_code = unsafe {
1071 // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
1072 // which takes a full list of adapters + scales at once (b8249+)
1073 let mut adapter_ptr = adapter.lora_adapter.as_ptr();
1074 let mut scale_val = scale;
1075 llama_cpp_sys_4::llama_set_adapters_lora(
1076 self.context.as_ptr(),
1077 &raw mut adapter_ptr,
1078 1,
1079 &raw mut scale_val,
1080 )
1081 };
1082 if err_code != 0 {
1083 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
1084 }
1085
1086 tracing::debug!("Set lora adapter");
1087 Ok(())
1088 }
1089
1090 /// Remove all lora adapters from the context.
1091 ///
1092 /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
1093 /// `llama_set_adapters_lora` which operates on the full adapter list at once.
1094 /// Calling this function clears **all** adapters currently set on the context.
1095 ///
1096 /// # Errors
1097 ///
1098 /// See [`LlamaLoraAdapterRemoveError`] for more information.
1099 pub fn lora_adapter_remove(
1100 &self,
1101 _adapter: &mut LlamaLoraAdapter,
1102 ) -> Result<(), LlamaLoraAdapterRemoveError> {
1103 let err_code = unsafe {
1104 llama_cpp_sys_4::llama_set_adapters_lora(
1105 self.context.as_ptr(),
1106 std::ptr::null_mut(),
1107 0,
1108 std::ptr::null_mut(),
1109 )
1110 };
1111 if err_code != 0 {
1112 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
1113 }
1114
1115 tracing::debug!("Remove lora adapter");
1116 Ok(())
1117 }
1118}
1119
1120impl Drop for LlamaContext<'_> {
1121 fn drop(&mut self) {
1122 unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
1123 }
1124}