llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19 LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26pub mod tensor_capture;
27
28/// A safe wrapper around the `llama_context` C++ context.
29///
30/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
31/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
32/// context-specific settings like embeddings and logits.
33///
34/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
35/// the context pointer, preventing it from being null. The struct also holds a reference to the model
36/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are enabled
37/// and the initialized logits for the context.
38///
39/// # Fields
40///
41/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
42/// - `model`: A reference to the `LlamaModel` associated with this context. This model provides the data and parameters
43/// that the context interacts with.
44/// - `initialized_logits`: A vector used to store the initialized logits. These are used in the model's processing and
45/// are kept separate from the context data.
46/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This is useful for
47/// controlling whether embedding data is generated during the interaction with the model.
48#[allow(clippy::module_name_repetitions)]
49pub struct LlamaContext<'a> {
50 pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
51 /// a reference to the contexts model.
52 pub model: &'a LlamaModel,
53 initialized_logits: Vec<i32>,
54 embeddings_enabled: bool,
55}
56
57impl Debug for LlamaContext<'_> {
58 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("LlamaContext")
60 .field("context", &self.context)
61 .finish()
62 }
63}
64
65impl<'model> LlamaContext<'model> {
66 /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
67 ///
68 /// This function initializes a new `LlamaContext` object, which is used to interact with the
69 /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
70 /// determines whether embeddings are enabled in the context.
71 ///
72 /// # Parameters
73 ///
74 /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
75 /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
76 /// the context created in previous steps. This context is necessary for interacting with the model.
77 /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
78 ///
79 /// # Returns
80 ///
81 /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
82 /// - The model reference (`llama_model`) is stored in the context.
83 /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
84 /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
85 ///
86 /// # Example
87 /// ```ignore
88 /// let llama_model = LlamaModel::load_from_file(&backend, "path/to/model", ¶ms).unwrap();
89 /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
90 /// let context = LlamaContext::new(&llama_model, context_ptr, true);
91 /// // Now you can use the context
92 /// ```
93 pub(crate) fn new(
94 llama_model: &'model LlamaModel,
95 llama_context: NonNull<llama_cpp_sys_4::llama_context>,
96 embeddings_enabled: bool,
97 ) -> Self {
98 Self {
99 context: llama_context,
100 model: llama_model,
101 initialized_logits: Vec::new(),
102 embeddings_enabled,
103 }
104 }
105
106 /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
107 #[must_use]
108 pub fn n_batch(&self) -> u32 {
109 unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
110 }
111
112 /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
113 #[must_use]
114 pub fn n_ubatch(&self) -> u32 {
115 unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
116 }
117
118 /// Gets the size of the context.
119 #[must_use]
120 pub fn n_ctx(&self) -> u32 {
121 unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
122 }
123
124 /// Decodes the batch.
125 ///
126 /// # Errors
127 ///
128 /// - `DecodeError` if the decoding failed.
129 ///
130 /// # Panics
131 ///
132 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
133 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
134 let result =
135 unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
136
137 match NonZeroI32::new(result) {
138 None => {
139 self.initialized_logits
140 .clone_from(&batch.initialized_logits);
141 Ok(())
142 }
143 Some(error) => Err(DecodeError::from(error)),
144 }
145 }
146
147 /// Encodes the batch.
148 ///
149 /// # Errors
150 ///
151 /// - `EncodeError` if the decoding failed.
152 ///
153 /// # Panics
154 ///
155 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
156 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
157 let result =
158 unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
159
160 match NonZeroI32::new(result) {
161 None => {
162 self.initialized_logits
163 .clone_from(&batch.initialized_logits);
164 Ok(())
165 }
166 Some(error) => Err(EncodeError::from(error)),
167 }
168 }
169
170 /// Return Pooling type for Llama's Context
171 #[must_use]
172 pub fn pooling_type(&self) -> LlamaPoolingType {
173 let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
174
175 LlamaPoolingType::from(pooling_type)
176 }
177
178 /// Get the embeddings for the `i`th sequence in the current context.
179 ///
180 /// # Returns
181 ///
182 /// A slice containing the embeddings for the last decoded batch.
183 /// The size corresponds to the `n_embd` parameter of the context's model.
184 ///
185 /// # Errors
186 ///
187 /// - When the current context was constructed without enabling embeddings.
188 /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
189 /// - If the given sequence index exceeds the max sequence id.
190 ///
191 /// # Panics
192 ///
193 /// * `n_embd` does not fit into a usize
194 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
195 if !self.embeddings_enabled {
196 return Err(EmbeddingsError::NotEnabled);
197 }
198
199 let n_embd =
200 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
201
202 unsafe {
203 let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
204
205 // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
206 if embedding.is_null() {
207 Err(EmbeddingsError::NonePoolType)
208 } else {
209 Ok(slice::from_raw_parts(embedding, n_embd))
210 }
211 }
212 }
213
214 /// Get the embeddings for the `i`th token in the current context.
215 ///
216 /// # Returns
217 ///
218 /// A slice containing the embeddings for the last decoded batch of the given token.
219 /// The size corresponds to the `n_embd` parameter of the context's model.
220 ///
221 /// # Errors
222 ///
223 /// - When the current context was constructed without enabling embeddings.
224 /// - When the given token didn't have logits enabled when it was passed.
225 /// - If the given token index exceeds the max token id.
226 ///
227 /// # Panics
228 ///
229 /// * `n_embd` does not fit into a usize
230 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
231 if !self.embeddings_enabled {
232 return Err(EmbeddingsError::NotEnabled);
233 }
234
235 let n_embd =
236 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
237
238 unsafe {
239 let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
240 // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
241 if embedding.is_null() {
242 Err(EmbeddingsError::LogitsNotEnabled)
243 } else {
244 Ok(slice::from_raw_parts(embedding, n_embd))
245 }
246 }
247 }
248
249 /// Get the logits for the last token in the context.
250 ///
251 /// # Returns
252 /// An iterator over unsorted `LlamaTokenData` containing the
253 /// logits for the last token in the context.
254 ///
255 /// # Panics
256 ///
257 /// - underlying logits data is null
258 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
259 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
260 let token = LlamaToken::new(i);
261 LlamaTokenData::new(token, *logit, 0_f32)
262 })
263 }
264
265 /// Get the token data array for the last token in the context.
266 ///
267 /// This is a convience method that implements:
268 /// ```ignore
269 /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
270 /// ```
271 ///
272 /// # Panics
273 ///
274 /// - underlying logits data is null
275 #[must_use]
276 pub fn token_data_array(&self) -> LlamaTokenDataArray {
277 LlamaTokenDataArray::from_iter(self.candidates(), false)
278 }
279
280 /// Token logits obtained from the last call to `decode()`.
281 /// The logits for which `batch.logits[i] != 0` are stored contiguously
282 /// in the order they have appeared in the batch.
283 /// Rows: number of tokens for which `batch.logits[i] != 0`
284 /// Cols: `n_vocab`
285 ///
286 /// # Returns
287 ///
288 /// A slice containing the logits for the last decoded token.
289 /// The size corresponds to the `n_vocab` parameter of the context's model.
290 ///
291 /// # Panics
292 ///
293 /// - `n_vocab` does not fit into a usize
294 /// - token data returned is null
295 #[must_use]
296 pub fn get_logits(&self) -> &[f32] {
297 let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
298 assert!(!data.is_null(), "logits data for last token is null");
299 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
300
301 unsafe { slice::from_raw_parts(data, len) }
302 }
303
304 /// Get the logits for the ith token in the context.
305 ///
306 /// # Panics
307 ///
308 /// - logit `i` is not initialized.
309 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
310 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
311 let token = LlamaToken::new(i);
312 LlamaTokenData::new(token, *logit, 0_f32)
313 })
314 }
315
316 /// Get the logits for the ith token in the context.
317 ///
318 /// # Panics
319 ///
320 /// - `i` is greater than `n_ctx`
321 /// - `n_vocab` does not fit into a usize
322 /// - logit `i` is not initialized.
323 #[must_use]
324 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
325 assert!(
326 self.initialized_logits.contains(&i),
327 "logit {i} is not initialized. only {:?} is",
328 self.initialized_logits
329 );
330 assert!(
331 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
332 "n_ctx ({}) must be greater than i ({})",
333 self.n_ctx(),
334 i
335 );
336
337 let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
338 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
339
340 unsafe { slice::from_raw_parts(data, len) }
341 }
342
343 /// Get the number of context tokens per sequence.
344 #[must_use]
345 pub fn n_ctx_seq(&self) -> u32 {
346 unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
347 }
348
349 /// Get the maximum number of sequences.
350 #[must_use]
351 pub fn n_seq_max(&self) -> u32 {
352 unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
353 }
354
355 /// Get the number of recurrent-state snapshots per sequence.
356 #[must_use]
357 pub fn n_rs_seq(&self) -> u32 {
358 unsafe { llama_cpp_sys_4::llama_n_rs_seq(self.context.as_ptr()) }
359 }
360
361 /// Get the number of threads used for generation.
362 #[must_use]
363 pub fn n_threads(&self) -> i32 {
364 unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
365 }
366
367 /// Get the number of threads used for batch processing.
368 #[must_use]
369 pub fn n_threads_batch(&self) -> i32 {
370 unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
371 }
372
373 /// Set the number of threads used for generation and batch processing.
374 pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
375 unsafe {
376 llama_cpp_sys_4::llama_set_n_threads(self.context.as_ptr(), n_threads, n_threads_batch);
377 }
378 }
379
380 /// Set whether to use causal attention.
381 ///
382 /// If set to `false`, the model will use non-causal attention, which is
383 /// needed for embedding models.
384 pub fn set_causal_attn(&mut self, causal_attn: bool) {
385 unsafe {
386 llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
387 }
388 }
389
390 /// Set whether to compute embeddings.
391 ///
392 /// This allows toggling embedding mode at runtime (as opposed to only at
393 /// context creation time).
394 pub fn set_embeddings(&mut self, embeddings: bool) {
395 self.embeddings_enabled = embeddings;
396 unsafe {
397 llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
398 }
399 }
400
401 /// Mark the next computation as a warmup run.
402 ///
403 /// Warmup runs are useful for GPU backends to compile kernels before
404 /// actual inference begins.
405 pub fn set_warmup(&mut self, warmup: bool) {
406 unsafe {
407 llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
408 }
409 }
410
411 /// Wait for all pending async computations to finish.
412 pub fn synchronize(&mut self) {
413 unsafe {
414 llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
415 }
416 }
417
418 /// Get all embeddings for the current context.
419 ///
420 /// Returns a slice of all embeddings from the last decoded batch.
421 /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
422 ///
423 /// # Errors
424 ///
425 /// - When the current context was constructed without enabling embeddings.
426 /// - If the embeddings pointer is null.
427 ///
428 /// # Panics
429 ///
430 /// * `n_embd` does not fit into a usize
431 pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
432 if !self.embeddings_enabled {
433 return Err(EmbeddingsError::NotEnabled);
434 }
435
436 let n_embd =
437 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
438
439 unsafe {
440 let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
441 if embedding.is_null() {
442 Err(EmbeddingsError::NonePoolType)
443 } else {
444 Ok(slice::from_raw_parts(embedding, n_embd))
445 }
446 }
447 }
448
449 /// Toggle extraction of next-n embeddings (Rust name: pre-norm) — hidden
450 /// states used by MTP draft heads. Upstream C API: `llama_set_embeddings_nextn`
451 /// (llama.cpp PR #23198 and later renames).
452 ///
453 /// If `masked` is `true`, pre-norm rows are extracted only for tokens
454 /// whose `batch.logits[i]` is non-zero. If `masked` is `false`, rows are
455 /// extracted for every token in the batch regardless of `batch.logits` —
456 /// callers can then leave `batch.logits[i] = false` on prompt-fill
457 /// positions and avoid copying the full logits row for each one.
458 ///
459 /// Upstream's MTP session init configures pre-norm extraction on the target
460 /// and draft contexts automatically. Call this manually only for custom
461 /// speculative setups.
462 pub fn set_embeddings_pre_norm(&mut self, value: bool, masked: bool) {
463 unsafe {
464 llama_cpp_sys_4::llama_set_embeddings_nextn(self.context.as_ptr(), value, masked);
465 }
466 }
467
468 /// Get the full pre-norm embeddings buffer for the last decoded batch.
469 ///
470 /// Returns `None` when pre-norm embeddings are disabled or the buffer
471 /// hasn't been populated. The length of the returned slice is
472 /// `n_embd * <number of pre-norm rows>` — interpretation of the row
473 /// count depends on whether the setter was called with `masked=true`
474 /// (one row per sampled token) or `masked=false` (one row per batch
475 /// token). Use [`get_embeddings_pre_norm_ith`](Self::get_embeddings_pre_norm_ith)
476 /// when you only need a single row.
477 #[must_use]
478 pub fn get_embeddings_pre_norm(&self) -> Option<&[f32]> {
479 let n_embd =
480 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
481 unsafe {
482 let p = llama_cpp_sys_4::llama_get_embeddings_nextn(self.context.as_ptr());
483 if p.is_null() {
484 None
485 } else {
486 Some(slice::from_raw_parts(p, n_embd))
487 }
488 }
489 }
490
491 /// Get the pre-norm embedding row for the `i`th output position of the
492 /// last decoded batch. Returns `None` if upstream rejects the index
493 /// (e.g. masked mode with `batch.logits[i] == 0`, or out of range).
494 #[must_use]
495 pub fn get_embeddings_pre_norm_ith(&self, i: i32) -> Option<&[f32]> {
496 let n_embd =
497 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
498 unsafe {
499 let p = llama_cpp_sys_4::llama_get_embeddings_nextn_ith(self.context.as_ptr(), i);
500 if p.is_null() {
501 None
502 } else {
503 Some(slice::from_raw_parts(p, n_embd))
504 }
505 }
506 }
507
508 /// Reset the timings for the context.
509 pub fn reset_timings(&mut self) {
510 unsafe { llama_cpp_sys_4::ggml_time_init() }
511 }
512
513 /// Returns the timings for the context.
514 pub fn timings(&mut self) -> PerfContextData {
515 let perf_context_data =
516 unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
517 PerfContextData { perf_context_data }
518 }
519
520 /// Reset the performance counters for the context.
521 pub fn perf_context_reset(&mut self) {
522 unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
523 }
524
525 /// Check if the KV cache memory supports shifting.
526 #[must_use]
527 pub fn memory_can_shift(&self) -> bool {
528 unsafe {
529 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
530 llama_cpp_sys_4::llama_memory_can_shift(mem)
531 }
532 }
533
534 /// Get the minimum position in a sequence's KV cache.
535 #[must_use]
536 pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
537 unsafe {
538 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
539 llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
540 }
541 }
542
543 /// Print a breakdown of the memory usage.
544 pub fn memory_breakdown_print(&self) {
545 unsafe {
546 llama_cpp_sys_4::common_memory_breakdown_print(self.context.as_ptr());
547 }
548 }
549
550 /// Get the size of the full context state in bytes.
551 ///
552 /// This is the size needed for [`state_get_data`](Self::state_get_data) and
553 /// [`state_set_data`](Self::state_set_data).
554 #[must_use]
555 pub fn state_get_size(&mut self) -> usize {
556 unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
557 }
558
559 /// Copy the full context state into a byte buffer.
560 ///
561 /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
562 ///
563 /// Returns the number of bytes written.
564 pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
565 unsafe {
566 llama_cpp_sys_4::llama_state_get_data(
567 self.context.as_ptr(),
568 dst.as_mut_ptr(),
569 dst.len(),
570 )
571 }
572 }
573
574 /// Restore the full context state from a byte buffer.
575 ///
576 /// Returns the number of bytes read.
577 pub fn state_set_data(&mut self, src: &[u8]) -> usize {
578 unsafe {
579 llama_cpp_sys_4::llama_state_set_data(self.context.as_ptr(), src.as_ptr(), src.len())
580 }
581 }
582
583 /// Save the context state to a file along with the given tokens.
584 ///
585 /// Returns `true` on success.
586 ///
587 /// # Panics
588 ///
589 /// Panics if the path contains null bytes.
590 pub fn state_save_file(
591 &mut self,
592 path: impl AsRef<std::path::Path>,
593 tokens: &[LlamaToken],
594 ) -> bool {
595 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
596 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
597 unsafe {
598 llama_cpp_sys_4::llama_state_save_file(
599 self.context.as_ptr(),
600 c_path.as_ptr(),
601 tokens.as_ptr().cast(),
602 tokens.len(),
603 )
604 }
605 }
606
607 /// Load a context state from a file.
608 ///
609 /// Returns `true` on success and fills `tokens_out` with the saved tokens.
610 ///
611 /// # Panics
612 ///
613 /// Panics if the path contains null bytes.
614 pub fn state_load_file(
615 &mut self,
616 path: impl AsRef<std::path::Path>,
617 tokens_out: &mut Vec<LlamaToken>,
618 n_token_capacity: usize,
619 ) -> bool {
620 tokens_out.resize(n_token_capacity, LlamaToken(0));
621 let mut n_token_count: usize = 0;
622 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
623 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
624 let ok = unsafe {
625 llama_cpp_sys_4::llama_state_load_file(
626 self.context.as_ptr(),
627 c_path.as_ptr(),
628 tokens_out.as_mut_ptr().cast(),
629 n_token_capacity,
630 std::ptr::addr_of_mut!(n_token_count),
631 )
632 };
633 if ok {
634 tokens_out.truncate(n_token_count);
635 }
636 ok
637 }
638
639 /// Get the size of a single sequence's state in bytes.
640 #[must_use]
641 pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
642 unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
643 }
644
645 /// Copy a single sequence's state into a byte buffer.
646 ///
647 /// Returns the number of bytes written.
648 pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
649 unsafe {
650 llama_cpp_sys_4::llama_state_seq_get_data(
651 self.context.as_ptr(),
652 dst.as_mut_ptr(),
653 dst.len(),
654 seq_id,
655 )
656 }
657 }
658
659 /// Restore a single sequence's state from a byte buffer.
660 ///
661 /// Returns the number of bytes read.
662 pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
663 unsafe {
664 llama_cpp_sys_4::llama_state_seq_set_data(
665 self.context.as_ptr(),
666 src.as_ptr(),
667 src.len(),
668 dest_seq_id,
669 )
670 }
671 }
672
673 /// Save a single sequence's state to a file.
674 ///
675 /// Returns the number of bytes written (0 on failure).
676 ///
677 /// # Panics
678 ///
679 /// Panics if the path contains null bytes.
680 pub fn state_seq_save_file(
681 &mut self,
682 path: impl AsRef<std::path::Path>,
683 seq_id: i32,
684 tokens: &[LlamaToken],
685 ) -> usize {
686 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
687 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
688 unsafe {
689 llama_cpp_sys_4::llama_state_seq_save_file(
690 self.context.as_ptr(),
691 c_path.as_ptr(),
692 seq_id,
693 tokens.as_ptr().cast(),
694 tokens.len(),
695 )
696 }
697 }
698
699 /// Load a single sequence's state from a file.
700 ///
701 /// Returns the number of bytes read (0 on failure).
702 ///
703 /// # Panics
704 ///
705 /// Panics if the path contains null bytes.
706 pub fn state_seq_load_file(
707 &mut self,
708 path: impl AsRef<std::path::Path>,
709 dest_seq_id: i32,
710 tokens_out: &mut Vec<LlamaToken>,
711 n_token_capacity: usize,
712 ) -> usize {
713 tokens_out.resize(n_token_capacity, LlamaToken(0));
714 let mut n_token_count: usize = 0;
715 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
716 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
717 let ret = unsafe {
718 llama_cpp_sys_4::llama_state_seq_load_file(
719 self.context.as_ptr(),
720 c_path.as_ptr(),
721 dest_seq_id,
722 tokens_out.as_mut_ptr().cast(),
723 n_token_capacity,
724 std::ptr::addr_of_mut!(n_token_count),
725 )
726 };
727 if ret > 0 {
728 tokens_out.truncate(n_token_count);
729 }
730 ret
731 }
732
733 /// Set a control vector on the context.
734 ///
735 /// # Parameters
736 ///
737 /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
738 /// - `n_embd`: The embedding dimension.
739 /// - `il_start`: The starting layer index (inclusive).
740 /// - `il_end`: The ending layer index (exclusive).
741 ///
742 /// # Errors
743 ///
744 /// Returns `Err` with the error code if the operation fails.
745 pub fn set_adapter_cvec(
746 &mut self,
747 data: &[f32],
748 n_embd: i32,
749 il_start: i32,
750 il_end: i32,
751 ) -> Result<(), i32> {
752 let ret = unsafe {
753 llama_cpp_sys_4::llama_set_adapter_cvec(
754 self.context.as_ptr(),
755 data.as_ptr(),
756 data.len(),
757 n_embd,
758 il_start,
759 il_end,
760 )
761 };
762 if ret != 0 {
763 Err(ret)
764 } else {
765 Ok(())
766 }
767 }
768
769 /// Get sampled token debug info for the `i`th position.
770 ///
771 /// Returns the sampled token at position `i` from the last decode call.
772 #[must_use]
773 pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
774 let token =
775 unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
776 LlamaToken(token)
777 }
778
779 /// Get sampled candidate tokens for the `i`th position.
780 ///
781 /// Returns a slice of candidate tokens from the last decode call.
782 #[must_use]
783 pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
784 let count = unsafe {
785 llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
786 } as usize;
787 if count == 0 {
788 return &[];
789 }
790 let ptr =
791 unsafe { llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i) };
792 if ptr.is_null() {
793 return &[];
794 }
795 unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
796 }
797
798 /// Get the number of sampled logits for the `i`th position.
799 #[must_use]
800 pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
801 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i) }
802 }
803
804 /// Get sampled logits for the `i`th position.
805 ///
806 /// Returns a slice of logit values from the last decode call.
807 #[must_use]
808 pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
809 let count = self.get_sampled_logits_count_ith(i) as usize;
810 if count == 0 {
811 return &[];
812 }
813 let ptr =
814 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i) };
815 if ptr.is_null() {
816 return &[];
817 }
818 unsafe { slice::from_raw_parts(ptr, count) }
819 }
820
821 /// Get the number of sampled probabilities for the `i`th position.
822 #[must_use]
823 pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
824 unsafe { llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i) }
825 }
826
827 /// Get sampled probabilities for the `i`th position.
828 ///
829 /// Returns a slice of probability values from the last decode call.
830 #[must_use]
831 pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
832 let count = self.get_sampled_probs_count_ith(i) as usize;
833 if count == 0 {
834 return &[];
835 }
836 let ptr = unsafe { llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i) };
837 if ptr.is_null() {
838 return &[];
839 }
840 unsafe { slice::from_raw_parts(ptr, count) }
841 }
842
843 /// Get the size of a single sequence's state with flags.
844 #[must_use]
845 pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
846 unsafe {
847 llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
848 }
849 }
850
851 /// Copy a single sequence's state into a byte buffer with flags.
852 ///
853 /// Returns the number of bytes written.
854 pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
855 unsafe {
856 llama_cpp_sys_4::llama_state_seq_get_data_ext(
857 self.context.as_ptr(),
858 dst.as_mut_ptr(),
859 dst.len(),
860 seq_id,
861 flags,
862 )
863 }
864 }
865
866 /// Restore a single sequence's state from a byte buffer with flags.
867 ///
868 /// Returns the number of bytes read.
869 pub fn state_seq_set_data_ext(&mut self, src: &[u8], dest_seq_id: i32, flags: u32) -> usize {
870 unsafe {
871 llama_cpp_sys_4::llama_state_seq_set_data_ext(
872 self.context.as_ptr(),
873 src.as_ptr(),
874 src.len(),
875 dest_seq_id,
876 flags,
877 )
878 }
879 }
880
881 /// Set an abort callback for the context.
882 ///
883 /// The callback is called periodically during computation. If it returns `true`,
884 /// the computation is aborted.
885 ///
886 /// # Safety
887 ///
888 /// The callback data must remain valid for the lifetime of the context or until
889 /// the callback is replaced.
890 pub unsafe fn set_abort_callback(
891 &mut self,
892 callback: llama_cpp_sys_4::ggml_abort_callback,
893 data: *mut std::ffi::c_void,
894 ) {
895 llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
896 }
897
898 /// Attach a thread pool to the context.
899 ///
900 /// # Safety
901 ///
902 /// The thread pools must remain valid for the lifetime of the context or until
903 /// they are detached.
904 pub unsafe fn attach_threadpool(
905 &mut self,
906 threadpool: llama_cpp_sys_4::ggml_threadpool_t,
907 threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
908 ) {
909 llama_cpp_sys_4::llama_attach_threadpool(
910 self.context.as_ptr(),
911 threadpool,
912 threadpool_batch,
913 );
914 }
915
916 /// Detach the thread pool from the context.
917 pub fn detach_threadpool(&mut self) {
918 unsafe {
919 llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
920 }
921 }
922
923 /// Set a sampler for a specific sequence.
924 ///
925 /// Returns `true` on success.
926 pub fn set_sampler(
927 &mut self,
928 seq_id: i32,
929 sampler: &mut crate::sampling::LlamaSampler,
930 ) -> bool {
931 unsafe {
932 llama_cpp_sys_4::llama_set_sampler(
933 self.context.as_ptr(),
934 seq_id,
935 sampler.sampler.as_ptr(),
936 )
937 }
938 }
939
940 /// Get the raw model pointer from this context.
941 ///
942 /// This is mainly useful for FFI interop. In normal usage, access
943 /// the model via the `model` field instead.
944 #[must_use]
945 pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
946 unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
947 }
948
949 /// Sets a lora adapter.
950 ///
951 /// # Errors
952 ///
953 /// See [`LlamaLoraAdapterSetError`] for more information.
954 pub fn lora_adapter_set(
955 &self,
956 adapter: &mut LlamaLoraAdapter,
957 scale: f32,
958 ) -> Result<(), LlamaLoraAdapterSetError> {
959 let err_code = unsafe {
960 // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
961 // which takes a full list of adapters + scales at once (b8249+)
962 let mut adapter_ptr = adapter.lora_adapter.as_ptr();
963 let mut scale_val = scale;
964 llama_cpp_sys_4::llama_set_adapters_lora(
965 self.context.as_ptr(),
966 &raw mut adapter_ptr,
967 1,
968 &raw mut scale_val,
969 )
970 };
971 if err_code != 0 {
972 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
973 }
974
975 tracing::debug!("Set lora adapter");
976 Ok(())
977 }
978
979 /// Remove all lora adapters from the context.
980 ///
981 /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
982 /// `llama_set_adapters_lora` which operates on the full adapter list at once.
983 /// Calling this function clears **all** adapters currently set on the context.
984 ///
985 /// # Errors
986 ///
987 /// See [`LlamaLoraAdapterRemoveError`] for more information.
988 pub fn lora_adapter_remove(
989 &self,
990 _adapter: &mut LlamaLoraAdapter,
991 ) -> Result<(), LlamaLoraAdapterRemoveError> {
992 let err_code = unsafe {
993 llama_cpp_sys_4::llama_set_adapters_lora(
994 self.context.as_ptr(),
995 std::ptr::null_mut(),
996 0,
997 std::ptr::null_mut(),
998 )
999 };
1000 if err_code != 0 {
1001 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
1002 }
1003
1004 tracing::debug!("Remove lora adapter");
1005 Ok(())
1006 }
1007}
1008
1009impl Drop for LlamaContext<'_> {
1010 fn drop(&mut self) {
1011 unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
1012 }
1013}