llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19 LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26pub mod tensor_capture;
27
28/// A safe wrapper around the `llama_context` C++ context.
29///
30/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
31/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
32/// context-specific settings like embeddings and logits.
33///
34/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
35/// the context pointer, preventing it from being null. The struct also holds a reference to the model
36/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are enabled
37/// and the initialized logits for the context.
38///
39/// # Fields
40///
41/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
42/// - `model`: A reference to the `LlamaModel` associated with this context. This model provides the data and parameters
43/// that the context interacts with.
44/// - `initialized_logits`: A vector used to store the initialized logits. These are used in the model's processing and
45/// are kept separate from the context data.
46/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This is useful for
47/// controlling whether embedding data is generated during the interaction with the model.
48#[allow(clippy::module_name_repetitions)]
49pub struct LlamaContext<'a> {
50 pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
51 /// a reference to the contexts model.
52 pub model: &'a LlamaModel,
53 initialized_logits: Vec<i32>,
54 embeddings_enabled: bool,
55}
56
57impl Debug for LlamaContext<'_> {
58 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("LlamaContext")
60 .field("context", &self.context)
61 .finish()
62 }
63}
64
65impl<'model> LlamaContext<'model> {
66 /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
67 ///
68 /// This function initializes a new `LlamaContext` object, which is used to interact with the
69 /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
70 /// determines whether embeddings are enabled in the context.
71 ///
72 /// # Parameters
73 ///
74 /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
75 /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
76 /// the context created in previous steps. This context is necessary for interacting with the model.
77 /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
78 ///
79 /// # Returns
80 ///
81 /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
82 /// - The model reference (`llama_model`) is stored in the context.
83 /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
84 /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
85 ///
86 /// # Example
87 /// ```ignore
88 /// let llama_model = LlamaModel::load_from_file(&backend, "path/to/model", ¶ms).unwrap();
89 /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
90 /// let context = LlamaContext::new(&llama_model, context_ptr, true);
91 /// // Now you can use the context
92 /// ```
93 pub(crate) fn new(
94 llama_model: &'model LlamaModel,
95 llama_context: NonNull<llama_cpp_sys_4::llama_context>,
96 embeddings_enabled: bool,
97 ) -> Self {
98 Self {
99 context: llama_context,
100 model: llama_model,
101 initialized_logits: Vec::new(),
102 embeddings_enabled,
103 }
104 }
105
106 /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
107 #[must_use]
108 pub fn n_batch(&self) -> u32 {
109 unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
110 }
111
112 /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
113 #[must_use]
114 pub fn n_ubatch(&self) -> u32 {
115 unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
116 }
117
118 /// Gets the size of the context.
119 #[must_use]
120 pub fn n_ctx(&self) -> u32 {
121 unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
122 }
123
124 /// Decodes the batch.
125 ///
126 /// # Errors
127 ///
128 /// - `DecodeError` if the decoding failed.
129 ///
130 /// # Panics
131 ///
132 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
133 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
134 let result =
135 unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
136
137 match NonZeroI32::new(result) {
138 None => {
139 self.initialized_logits
140 .clone_from(&batch.initialized_logits);
141 Ok(())
142 }
143 Some(error) => Err(DecodeError::from(error)),
144 }
145 }
146
147 /// Encodes the batch.
148 ///
149 /// # Errors
150 ///
151 /// - `EncodeError` if the decoding failed.
152 ///
153 /// # Panics
154 ///
155 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
156 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
157 let result =
158 unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
159
160 match NonZeroI32::new(result) {
161 None => {
162 self.initialized_logits
163 .clone_from(&batch.initialized_logits);
164 Ok(())
165 }
166 Some(error) => Err(EncodeError::from(error)),
167 }
168 }
169
170 /// Return Pooling type for Llama's Context
171 #[must_use]
172 pub fn pooling_type(&self) -> LlamaPoolingType {
173 let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
174
175 LlamaPoolingType::from(pooling_type)
176 }
177
178 /// Get the embeddings for the `i`th sequence in the current context.
179 ///
180 /// # Returns
181 ///
182 /// A slice containing the embeddings for the last decoded batch.
183 /// The size corresponds to the `n_embd` parameter of the context's model.
184 ///
185 /// # Errors
186 ///
187 /// - When the current context was constructed without enabling embeddings.
188 /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
189 /// - If the given sequence index exceeds the max sequence id.
190 ///
191 /// # Panics
192 ///
193 /// * `n_embd` does not fit into a usize
194 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
195 if !self.embeddings_enabled {
196 return Err(EmbeddingsError::NotEnabled);
197 }
198
199 let n_embd =
200 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
201
202 unsafe {
203 let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
204
205 // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
206 if embedding.is_null() {
207 Err(EmbeddingsError::NonePoolType)
208 } else {
209 Ok(slice::from_raw_parts(embedding, n_embd))
210 }
211 }
212 }
213
214 /// Get the embeddings for the `i`th token in the current context.
215 ///
216 /// # Returns
217 ///
218 /// A slice containing the embeddings for the last decoded batch of the given token.
219 /// The size corresponds to the `n_embd` parameter of the context's model.
220 ///
221 /// # Errors
222 ///
223 /// - When the current context was constructed without enabling embeddings.
224 /// - When the given token didn't have logits enabled when it was passed.
225 /// - If the given token index exceeds the max token id.
226 ///
227 /// # Panics
228 ///
229 /// * `n_embd` does not fit into a usize
230 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
231 if !self.embeddings_enabled {
232 return Err(EmbeddingsError::NotEnabled);
233 }
234
235 let n_embd =
236 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
237
238 unsafe {
239 let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
240 // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
241 if embedding.is_null() {
242 Err(EmbeddingsError::LogitsNotEnabled)
243 } else {
244 Ok(slice::from_raw_parts(embedding, n_embd))
245 }
246 }
247 }
248
249 /// Get the logits for the last token in the context.
250 ///
251 /// # Returns
252 /// An iterator over unsorted `LlamaTokenData` containing the
253 /// logits for the last token in the context.
254 ///
255 /// # Panics
256 ///
257 /// - underlying logits data is null
258 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
259 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
260 let token = LlamaToken::new(i);
261 LlamaTokenData::new(token, *logit, 0_f32)
262 })
263 }
264
265 /// Get the token data array for the last token in the context.
266 ///
267 /// This is a convience method that implements:
268 /// ```ignore
269 /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
270 /// ```
271 ///
272 /// # Panics
273 ///
274 /// - underlying logits data is null
275 #[must_use]
276 pub fn token_data_array(&self) -> LlamaTokenDataArray {
277 LlamaTokenDataArray::from_iter(self.candidates(), false)
278 }
279
280 /// Token logits obtained from the last call to `decode()`.
281 /// The logits for which `batch.logits[i] != 0` are stored contiguously
282 /// in the order they have appeared in the batch.
283 /// Rows: number of tokens for which `batch.logits[i] != 0`
284 /// Cols: `n_vocab`
285 ///
286 /// # Returns
287 ///
288 /// A slice containing the logits for the last decoded token.
289 /// The size corresponds to the `n_vocab` parameter of the context's model.
290 ///
291 /// # Panics
292 ///
293 /// - `n_vocab` does not fit into a usize
294 /// - token data returned is null
295 #[must_use]
296 pub fn get_logits(&self) -> &[f32] {
297 let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
298 assert!(!data.is_null(), "logits data for last token is null");
299 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
300
301 unsafe { slice::from_raw_parts(data, len) }
302 }
303
304 /// Get the logits for the ith token in the context.
305 ///
306 /// # Panics
307 ///
308 /// - logit `i` is not initialized.
309 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
310 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
311 let token = LlamaToken::new(i);
312 LlamaTokenData::new(token, *logit, 0_f32)
313 })
314 }
315
316 /// Get the logits for the ith token in the context.
317 ///
318 /// # Panics
319 ///
320 /// - `i` is greater than `n_ctx`
321 /// - `n_vocab` does not fit into a usize
322 /// - logit `i` is not initialized.
323 #[must_use]
324 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
325 assert!(
326 self.initialized_logits.contains(&i),
327 "logit {i} is not initialized. only {:?} is",
328 self.initialized_logits
329 );
330 assert!(
331 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
332 "n_ctx ({}) must be greater than i ({})",
333 self.n_ctx(),
334 i
335 );
336
337 let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
338 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
339
340 unsafe { slice::from_raw_parts(data, len) }
341 }
342
343 /// Get the number of context tokens per sequence.
344 #[must_use]
345 pub fn n_ctx_seq(&self) -> u32 {
346 unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
347 }
348
349 /// Get the maximum number of sequences.
350 #[must_use]
351 pub fn n_seq_max(&self) -> u32 {
352 unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
353 }
354
355 /// Get the number of recurrent-state snapshots per sequence.
356 #[must_use]
357 pub fn n_rs_seq(&self) -> u32 {
358 unsafe { llama_cpp_sys_4::llama_n_rs_seq(self.context.as_ptr()) }
359 }
360
361 /// Get the number of threads used for generation.
362 #[must_use]
363 pub fn n_threads(&self) -> i32 {
364 unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
365 }
366
367 /// Get the number of threads used for batch processing.
368 #[must_use]
369 pub fn n_threads_batch(&self) -> i32 {
370 unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
371 }
372
373 /// Set the number of threads used for generation and batch processing.
374 pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
375 unsafe {
376 llama_cpp_sys_4::llama_set_n_threads(self.context.as_ptr(), n_threads, n_threads_batch);
377 }
378 }
379
380 /// Set whether to use causal attention.
381 ///
382 /// If set to `false`, the model will use non-causal attention, which is
383 /// needed for embedding models.
384 pub fn set_causal_attn(&mut self, causal_attn: bool) {
385 unsafe {
386 llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
387 }
388 }
389
390 /// Set whether to compute embeddings.
391 ///
392 /// This allows toggling embedding mode at runtime (as opposed to only at
393 /// context creation time).
394 pub fn set_embeddings(&mut self, embeddings: bool) {
395 self.embeddings_enabled = embeddings;
396 unsafe {
397 llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
398 }
399 }
400
401 /// Mark the next computation as a warmup run.
402 ///
403 /// Warmup runs are useful for GPU backends to compile kernels before
404 /// actual inference begins.
405 pub fn set_warmup(&mut self, warmup: bool) {
406 unsafe {
407 llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
408 }
409 }
410
411 /// Wait for all pending async computations to finish.
412 pub fn synchronize(&mut self) {
413 unsafe {
414 llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
415 }
416 }
417
418 /// Get all embeddings for the current context.
419 ///
420 /// Returns a slice of all embeddings from the last decoded batch.
421 /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
422 ///
423 /// # Errors
424 ///
425 /// - When the current context was constructed without enabling embeddings.
426 /// - If the embeddings pointer is null.
427 ///
428 /// # Panics
429 ///
430 /// * `n_embd` does not fit into a usize
431 pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
432 if !self.embeddings_enabled {
433 return Err(EmbeddingsError::NotEnabled);
434 }
435
436 let n_embd =
437 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
438
439 unsafe {
440 let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
441 if embedding.is_null() {
442 Err(EmbeddingsError::NonePoolType)
443 } else {
444 Ok(slice::from_raw_parts(embedding, n_embd))
445 }
446 }
447 }
448
449 /// Reset the timings for the context.
450 pub fn reset_timings(&mut self) {
451 unsafe { llama_cpp_sys_4::ggml_time_init() }
452 }
453
454 /// Returns the timings for the context.
455 pub fn timings(&mut self) -> PerfContextData {
456 let perf_context_data =
457 unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
458 PerfContextData { perf_context_data }
459 }
460
461 /// Reset the performance counters for the context.
462 pub fn perf_context_reset(&mut self) {
463 unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
464 }
465
466 /// Check if the KV cache memory supports shifting.
467 #[must_use]
468 pub fn memory_can_shift(&self) -> bool {
469 unsafe {
470 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
471 llama_cpp_sys_4::llama_memory_can_shift(mem)
472 }
473 }
474
475 /// Get the minimum position in a sequence's KV cache.
476 #[must_use]
477 pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
478 unsafe {
479 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
480 llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
481 }
482 }
483
484 /// Print a breakdown of the memory usage.
485 pub fn memory_breakdown_print(&self) {
486 unsafe {
487 llama_cpp_sys_4::common_memory_breakdown_print(self.context.as_ptr());
488 }
489 }
490
491 /// Get the size of the full context state in bytes.
492 ///
493 /// This is the size needed for [`state_get_data`](Self::state_get_data) and
494 /// [`state_set_data`](Self::state_set_data).
495 #[must_use]
496 pub fn state_get_size(&mut self) -> usize {
497 unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
498 }
499
500 /// Copy the full context state into a byte buffer.
501 ///
502 /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
503 ///
504 /// Returns the number of bytes written.
505 pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
506 unsafe {
507 llama_cpp_sys_4::llama_state_get_data(
508 self.context.as_ptr(),
509 dst.as_mut_ptr(),
510 dst.len(),
511 )
512 }
513 }
514
515 /// Restore the full context state from a byte buffer.
516 ///
517 /// Returns the number of bytes read.
518 pub fn state_set_data(&mut self, src: &[u8]) -> usize {
519 unsafe {
520 llama_cpp_sys_4::llama_state_set_data(self.context.as_ptr(), src.as_ptr(), src.len())
521 }
522 }
523
524 /// Save the context state to a file along with the given tokens.
525 ///
526 /// Returns `true` on success.
527 ///
528 /// # Panics
529 ///
530 /// Panics if the path contains null bytes.
531 pub fn state_save_file(
532 &mut self,
533 path: impl AsRef<std::path::Path>,
534 tokens: &[LlamaToken],
535 ) -> bool {
536 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
537 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
538 unsafe {
539 llama_cpp_sys_4::llama_state_save_file(
540 self.context.as_ptr(),
541 c_path.as_ptr(),
542 tokens.as_ptr().cast(),
543 tokens.len(),
544 )
545 }
546 }
547
548 /// Load a context state from a file.
549 ///
550 /// Returns `true` on success and fills `tokens_out` with the saved tokens.
551 ///
552 /// # Panics
553 ///
554 /// Panics if the path contains null bytes.
555 pub fn state_load_file(
556 &mut self,
557 path: impl AsRef<std::path::Path>,
558 tokens_out: &mut Vec<LlamaToken>,
559 n_token_capacity: usize,
560 ) -> bool {
561 tokens_out.resize(n_token_capacity, LlamaToken(0));
562 let mut n_token_count: usize = 0;
563 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
564 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
565 let ok = unsafe {
566 llama_cpp_sys_4::llama_state_load_file(
567 self.context.as_ptr(),
568 c_path.as_ptr(),
569 tokens_out.as_mut_ptr().cast(),
570 n_token_capacity,
571 std::ptr::addr_of_mut!(n_token_count),
572 )
573 };
574 if ok {
575 tokens_out.truncate(n_token_count);
576 }
577 ok
578 }
579
580 /// Get the size of a single sequence's state in bytes.
581 #[must_use]
582 pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
583 unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
584 }
585
586 /// Copy a single sequence's state into a byte buffer.
587 ///
588 /// Returns the number of bytes written.
589 pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
590 unsafe {
591 llama_cpp_sys_4::llama_state_seq_get_data(
592 self.context.as_ptr(),
593 dst.as_mut_ptr(),
594 dst.len(),
595 seq_id,
596 )
597 }
598 }
599
600 /// Restore a single sequence's state from a byte buffer.
601 ///
602 /// Returns the number of bytes read.
603 pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
604 unsafe {
605 llama_cpp_sys_4::llama_state_seq_set_data(
606 self.context.as_ptr(),
607 src.as_ptr(),
608 src.len(),
609 dest_seq_id,
610 )
611 }
612 }
613
614 /// Save a single sequence's state to a file.
615 ///
616 /// Returns the number of bytes written (0 on failure).
617 ///
618 /// # Panics
619 ///
620 /// Panics if the path contains null bytes.
621 pub fn state_seq_save_file(
622 &mut self,
623 path: impl AsRef<std::path::Path>,
624 seq_id: i32,
625 tokens: &[LlamaToken],
626 ) -> usize {
627 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
628 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
629 unsafe {
630 llama_cpp_sys_4::llama_state_seq_save_file(
631 self.context.as_ptr(),
632 c_path.as_ptr(),
633 seq_id,
634 tokens.as_ptr().cast(),
635 tokens.len(),
636 )
637 }
638 }
639
640 /// Load a single sequence's state from a file.
641 ///
642 /// Returns the number of bytes read (0 on failure).
643 ///
644 /// # Panics
645 ///
646 /// Panics if the path contains null bytes.
647 pub fn state_seq_load_file(
648 &mut self,
649 path: impl AsRef<std::path::Path>,
650 dest_seq_id: i32,
651 tokens_out: &mut Vec<LlamaToken>,
652 n_token_capacity: usize,
653 ) -> usize {
654 tokens_out.resize(n_token_capacity, LlamaToken(0));
655 let mut n_token_count: usize = 0;
656 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
657 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
658 let ret = unsafe {
659 llama_cpp_sys_4::llama_state_seq_load_file(
660 self.context.as_ptr(),
661 c_path.as_ptr(),
662 dest_seq_id,
663 tokens_out.as_mut_ptr().cast(),
664 n_token_capacity,
665 std::ptr::addr_of_mut!(n_token_count),
666 )
667 };
668 if ret > 0 {
669 tokens_out.truncate(n_token_count);
670 }
671 ret
672 }
673
674 /// Set a control vector on the context.
675 ///
676 /// # Parameters
677 ///
678 /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
679 /// - `n_embd`: The embedding dimension.
680 /// - `il_start`: The starting layer index (inclusive).
681 /// - `il_end`: The ending layer index (exclusive).
682 ///
683 /// # Errors
684 ///
685 /// Returns `Err` with the error code if the operation fails.
686 pub fn set_adapter_cvec(
687 &mut self,
688 data: &[f32],
689 n_embd: i32,
690 il_start: i32,
691 il_end: i32,
692 ) -> Result<(), i32> {
693 let ret = unsafe {
694 llama_cpp_sys_4::llama_set_adapter_cvec(
695 self.context.as_ptr(),
696 data.as_ptr(),
697 data.len(),
698 n_embd,
699 il_start,
700 il_end,
701 )
702 };
703 if ret != 0 {
704 Err(ret)
705 } else {
706 Ok(())
707 }
708 }
709
710 /// Get sampled token debug info for the `i`th position.
711 ///
712 /// Returns the sampled token at position `i` from the last decode call.
713 #[must_use]
714 pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
715 let token =
716 unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
717 LlamaToken(token)
718 }
719
720 /// Get sampled candidate tokens for the `i`th position.
721 ///
722 /// Returns a slice of candidate tokens from the last decode call.
723 #[must_use]
724 pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
725 let count = unsafe {
726 llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
727 } as usize;
728 if count == 0 {
729 return &[];
730 }
731 let ptr =
732 unsafe { llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i) };
733 if ptr.is_null() {
734 return &[];
735 }
736 unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
737 }
738
739 /// Get the number of sampled logits for the `i`th position.
740 #[must_use]
741 pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
742 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i) }
743 }
744
745 /// Get sampled logits for the `i`th position.
746 ///
747 /// Returns a slice of logit values from the last decode call.
748 #[must_use]
749 pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
750 let count = self.get_sampled_logits_count_ith(i) as usize;
751 if count == 0 {
752 return &[];
753 }
754 let ptr =
755 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i) };
756 if ptr.is_null() {
757 return &[];
758 }
759 unsafe { slice::from_raw_parts(ptr, count) }
760 }
761
762 /// Get the number of sampled probabilities for the `i`th position.
763 #[must_use]
764 pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
765 unsafe { llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i) }
766 }
767
768 /// Get sampled probabilities for the `i`th position.
769 ///
770 /// Returns a slice of probability values from the last decode call.
771 #[must_use]
772 pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
773 let count = self.get_sampled_probs_count_ith(i) as usize;
774 if count == 0 {
775 return &[];
776 }
777 let ptr = unsafe { llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i) };
778 if ptr.is_null() {
779 return &[];
780 }
781 unsafe { slice::from_raw_parts(ptr, count) }
782 }
783
784 /// Get the size of a single sequence's state with flags.
785 #[must_use]
786 pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
787 unsafe {
788 llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
789 }
790 }
791
792 /// Copy a single sequence's state into a byte buffer with flags.
793 ///
794 /// Returns the number of bytes written.
795 pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
796 unsafe {
797 llama_cpp_sys_4::llama_state_seq_get_data_ext(
798 self.context.as_ptr(),
799 dst.as_mut_ptr(),
800 dst.len(),
801 seq_id,
802 flags,
803 )
804 }
805 }
806
807 /// Restore a single sequence's state from a byte buffer with flags.
808 ///
809 /// Returns the number of bytes read.
810 pub fn state_seq_set_data_ext(&mut self, src: &[u8], dest_seq_id: i32, flags: u32) -> usize {
811 unsafe {
812 llama_cpp_sys_4::llama_state_seq_set_data_ext(
813 self.context.as_ptr(),
814 src.as_ptr(),
815 src.len(),
816 dest_seq_id,
817 flags,
818 )
819 }
820 }
821
822 /// Set an abort callback for the context.
823 ///
824 /// The callback is called periodically during computation. If it returns `true`,
825 /// the computation is aborted.
826 ///
827 /// # Safety
828 ///
829 /// The callback data must remain valid for the lifetime of the context or until
830 /// the callback is replaced.
831 pub unsafe fn set_abort_callback(
832 &mut self,
833 callback: llama_cpp_sys_4::ggml_abort_callback,
834 data: *mut std::ffi::c_void,
835 ) {
836 llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
837 }
838
839 /// Attach a thread pool to the context.
840 ///
841 /// # Safety
842 ///
843 /// The thread pools must remain valid for the lifetime of the context or until
844 /// they are detached.
845 pub unsafe fn attach_threadpool(
846 &mut self,
847 threadpool: llama_cpp_sys_4::ggml_threadpool_t,
848 threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
849 ) {
850 llama_cpp_sys_4::llama_attach_threadpool(
851 self.context.as_ptr(),
852 threadpool,
853 threadpool_batch,
854 );
855 }
856
857 /// Detach the thread pool from the context.
858 pub fn detach_threadpool(&mut self) {
859 unsafe {
860 llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
861 }
862 }
863
864 /// Set a sampler for a specific sequence.
865 ///
866 /// Returns `true` on success.
867 pub fn set_sampler(
868 &mut self,
869 seq_id: i32,
870 sampler: &mut crate::sampling::LlamaSampler,
871 ) -> bool {
872 unsafe {
873 llama_cpp_sys_4::llama_set_sampler(
874 self.context.as_ptr(),
875 seq_id,
876 sampler.sampler.as_ptr(),
877 )
878 }
879 }
880
881 /// Get the raw model pointer from this context.
882 ///
883 /// This is mainly useful for FFI interop. In normal usage, access
884 /// the model via the `model` field instead.
885 #[must_use]
886 pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
887 unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
888 }
889
890 /// Sets a lora adapter.
891 ///
892 /// # Errors
893 ///
894 /// See [`LlamaLoraAdapterSetError`] for more information.
895 pub fn lora_adapter_set(
896 &self,
897 adapter: &mut LlamaLoraAdapter,
898 scale: f32,
899 ) -> Result<(), LlamaLoraAdapterSetError> {
900 let err_code = unsafe {
901 // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
902 // which takes a full list of adapters + scales at once (b8249+)
903 let mut adapter_ptr = adapter.lora_adapter.as_ptr();
904 let mut scale_val = scale;
905 llama_cpp_sys_4::llama_set_adapters_lora(
906 self.context.as_ptr(),
907 &raw mut adapter_ptr,
908 1,
909 &raw mut scale_val,
910 )
911 };
912 if err_code != 0 {
913 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
914 }
915
916 tracing::debug!("Set lora adapter");
917 Ok(())
918 }
919
920 /// Remove all lora adapters from the context.
921 ///
922 /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
923 /// `llama_set_adapters_lora` which operates on the full adapter list at once.
924 /// Calling this function clears **all** adapters currently set on the context.
925 ///
926 /// # Errors
927 ///
928 /// See [`LlamaLoraAdapterRemoveError`] for more information.
929 pub fn lora_adapter_remove(
930 &self,
931 _adapter: &mut LlamaLoraAdapter,
932 ) -> Result<(), LlamaLoraAdapterRemoveError> {
933 let err_code = unsafe {
934 llama_cpp_sys_4::llama_set_adapters_lora(
935 self.context.as_ptr(),
936 std::ptr::null_mut(),
937 0,
938 std::ptr::null_mut(),
939 )
940 };
941 if err_code != 0 {
942 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
943 }
944
945 tracing::debug!("Remove lora adapter");
946 Ok(())
947 }
948}
949
950impl Drop for LlamaContext<'_> {
951 fn drop(&mut self) {
952 unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
953 }
954}