llama_cpp_4/context.rs
//! Safe wrapper around `llama_context`.

use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use std::ptr::NonNull;
use std::slice;

use llama_cpp_sys_4::llama_pooling_type;
use params::LlamaPoolingType;
use perf::PerfContextData;

use crate::llama_batch::LlamaBatch;
use crate::model::{LlamaLoraAdapter, LlamaModel};
use crate::token::data::LlamaTokenData;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::LlamaToken;
use crate::{
    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
    LlamaLoraAdapterSetError,
};

pub mod kv_cache;
pub mod params;
pub mod perf;
pub mod session;
pub mod tensor_capture;

/// A safe wrapper around the `llama_context` C++ context.
///
/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
/// context-specific settings like embeddings and logits.
///
/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
/// the context pointer, preventing it from being null. The struct also holds a reference to the model
/// (`LlamaModel`) that the context is tied to, along with some internal state such as whether embeddings
/// are enabled and which logits have been initialized.
///
/// # Fields
///
/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main handle used for
///   interacting with the model.
/// - `model`: A reference to the `LlamaModel` associated with this context. The model provides the data
///   and parameters that the context operates on.
/// - `initialized_logits`: A vector tracking which batch positions have initialized logits. These are
///   consulted when reading logits back after a decode and are kept separate from the context data.
/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This
///   controls whether embedding data is generated when interacting with the model.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
    /// A reference to the context's model.
    pub model: &'a LlamaModel,
    initialized_logits: Vec<i32>,
    embeddings_enabled: bool,
}

impl Debug for LlamaContext<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaContext")
            .field("context", &self.context)
            .finish()
    }
}

impl<'model> LlamaContext<'model> {
    /// Creates a new instance of `LlamaContext` with the provided model, context pointer, and embeddings flag.
    ///
    /// This function initializes a new `LlamaContext` object, which is used to interact with the
    /// `LlamaModel`. The context is created from a pointer to a C++ context, and the embeddings flag
    /// determines whether embeddings are enabled in the context.
    ///
    /// # Parameters
    ///
    /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
    /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
    ///   the context created in previous steps. This context is necessary for interacting with the model.
    /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
    ///
    /// # Returns
    ///
    /// A new instance of `LlamaContext` initialized with the given parameters:
    /// - The model reference (`llama_model`) is stored in the context.
    /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
    /// - The `embeddings_enabled` flag determines whether embeddings are generated by the context.
    ///
    /// # Example
    /// ```ignore
    /// let llama_model = LlamaModel::load_from_file(&backend, "path/to/model", &params).unwrap();
    /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
    /// let context = LlamaContext::new(&llama_model, context_ptr, true);
    /// // Now you can use the context
    /// ```
    pub(crate) fn new(
        llama_model: &'model LlamaModel,
        llama_context: NonNull<llama_cpp_sys_4::llama_context>,
        embeddings_enabled: bool,
    ) -> Self {
        Self {
            context: llama_context,
            model: llama_model,
            initialized_logits: Vec::new(),
            embeddings_enabled,
        }
    }

    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
    }

    /// Gets the max number of physical tokens (hardware level) to decode in a batch. Must be less than or equal to `n_batch`.
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
    }

    /// Gets the size of the context.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
    }

    /// Decodes the batch.
    ///
    /// # Errors
    ///
    /// - `DecodeError` if the decoding failed.
    ///
    /// # Panics
    ///
    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into an `i32` (this should never happen on most systems)
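    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `tokens` came from the model's tokenizer;
    /// the batch capacity and sequence id are illustrative.
    /// ```ignore
    /// let mut batch = LlamaBatch::new(512, 1);
    /// let last = tokens.len() - 1;
    /// for (i, token) in tokens.into_iter().enumerate() {
    ///     // Only request logits for the final position of the prompt.
    ///     batch.add(token, i as i32, &[0], i == last)?;
    /// }
    /// ctx.decode(&mut batch)?;
    /// ```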
    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
        let result =
            unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };

        match NonZeroI32::new(result) {
            None => {
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(DecodeError::from(error)),
        }
    }

    /// Encodes the batch.
    ///
    /// # Errors
    ///
    /// - `EncodeError` if the encoding failed.
    ///
    /// # Panics
    ///
    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into an `i32` (this should never happen on most systems)
    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
        let result =
            unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };

        match NonZeroI32::new(result) {
            None => {
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(EncodeError::from(error)),
        }
    }

    /// Returns the pooling type of the context.
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };

        LlamaPoolingType::from(pooling_type)
    }

    /// Get the embeddings for the `i`th sequence in the current context.
    ///
    /// # Returns
    ///
    /// A slice containing the embeddings for the last decoded batch.
    /// The size corresponds to the `n_embd` parameter of the context's model.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - If the current model has a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`].
    /// - If the given sequence index exceeds the max sequence id.
    ///
    /// # Panics
    ///
    /// - `n_embd` does not fit into a `usize`
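    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the context was created with embeddings
    /// enabled and a pooled embedding batch was decoded with sequence id 0.
    /// ```ignore
    /// ctx.decode(&mut batch)?;
    /// let embedding: &[f32] = ctx.embeddings_seq_ith(0)?;
    /// assert_eq!(embedding.len(), ctx.model.n_embd() as usize);
    /// ```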
    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);

            // Technically also possible whenever `i >= max(batch.n_seq)`, but we can't check that here.
            if embedding.is_null() {
                Err(EmbeddingsError::NonePoolType)
            } else {
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

    /// Get the embeddings for the `i`th token in the current context.
    ///
    /// # Returns
    ///
    /// A slice containing the embeddings of the given token for the last decoded batch.
    /// The size corresponds to the `n_embd` parameter of the context's model.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - When the given token didn't have logits enabled when it was passed.
    /// - If the given token index exceeds the max token id.
    ///
    /// # Panics
    ///
    /// - `n_embd` does not fit into a `usize`
    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
            // Technically also possible whenever `i >= batch.n_tokens`, but there is no good way of checking `n_tokens` here.
            if embedding.is_null() {
                Err(EmbeddingsError::LogitsNotEnabled)
            } else {
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

    /// Get the logits for the last token in the context.
    ///
    /// # Returns
    ///
    /// An iterator over unsorted `LlamaTokenData` containing the
    /// logits for the last token in the context.
    ///
    /// # Panics
    ///
    /// - underlying logits data is null
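    ///
    /// # Example
    ///
    /// A minimal sketch of greedy selection over the raw logits (accessor
    /// names follow this crate's token data API); a sampler chain is usually
    /// preferable in practice.
    /// ```ignore
    /// let best = ctx
    ///     .candidates()
    ///     .max_by(|a, b| a.logit().total_cmp(&b.logit()))
    ///     .expect("vocabulary is not empty");
    /// ```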
    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
            let token = LlamaToken::new(i);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

    /// Get the token data array for the last token in the context.
    ///
    /// This is a convenience method that implements:
    /// ```ignore
    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
    /// ```
    ///
    /// # Panics
    ///
    /// - underlying logits data is null
    #[must_use]
    pub fn token_data_array(&self) -> LlamaTokenDataArray {
        LlamaTokenDataArray::from_iter(self.candidates(), false)
    }

    /// Token logits obtained from the last call to `decode()`.
    /// The logits for which `batch.logits[i] != 0` are stored contiguously
    /// in the order they appeared in the batch.
    /// Rows: number of tokens for which `batch.logits[i] != 0`
    /// Cols: `n_vocab`
    ///
    /// # Returns
    ///
    /// A slice containing the logits for the last decoded token.
    /// The size corresponds to the `n_vocab` parameter of the context's model.
    ///
    /// # Panics
    ///
    /// - `n_vocab` does not fit into a `usize`
    /// - token data returned is null
    #[must_use]
    pub fn get_logits(&self) -> &[f32] {
        let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
        assert!(!data.is_null(), "logits data for last token is null");
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        unsafe { slice::from_raw_parts(data, len) }
    }

    /// Get the token candidates for the `i`th token in the context.
    ///
    /// # Panics
    ///
    /// - logit `i` is not initialized.
    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
            let token = LlamaToken::new(i);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

    /// Get the logits for the `i`th token in the context.
    ///
    /// # Panics
    ///
    /// - `i` is greater than `n_ctx`
    /// - `n_vocab` does not fit into a `usize`
    /// - logit `i` is not initialized.
    #[must_use]
    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
        assert!(
            self.initialized_logits.contains(&i),
            "logit {i} is not initialized. only {:?} are",
            self.initialized_logits
        );
        assert!(
            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
            "n_ctx ({}) must be greater than i ({})",
            self.n_ctx(),
            i
        );

        let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        unsafe { slice::from_raw_parts(data, len) }
    }

    /// Get the number of context tokens per sequence.
    #[must_use]
    pub fn n_ctx_seq(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
    }

    /// Get the maximum number of sequences.
    #[must_use]
    pub fn n_seq_max(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
    }

    /// Get the number of threads used for generation.
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
    }

    /// Get the number of threads used for batch processing.
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
    }

    /// Set the number of threads used for generation and batch processing.
    pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
        unsafe {
            llama_cpp_sys_4::llama_set_n_threads(self.context.as_ptr(), n_threads, n_threads_batch);
        }
    }

    /// Set whether to use causal attention.
    ///
    /// If set to `false`, the model will use non-causal attention, which is
    /// needed for embedding models.
    pub fn set_causal_attn(&mut self, causal_attn: bool) {
        unsafe {
            llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
        }
    }

    /// Set whether to compute embeddings.
    ///
    /// This allows toggling embedding mode at runtime (as opposed to only at
    /// context creation time).
    pub fn set_embeddings(&mut self, embeddings: bool) {
        self.embeddings_enabled = embeddings;
        unsafe {
            llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
        }
    }

    /// Mark the next computation as a warmup run.
    ///
    /// Warmup runs are useful for GPU backends to compile kernels before
    /// actual inference begins.
    pub fn set_warmup(&mut self, warmup: bool) {
        unsafe {
            llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
        }
    }

    /// Wait for all pending async computations to finish.
    pub fn synchronize(&mut self) {
        unsafe {
            llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
        }
    }

    /// Get all embeddings for the current context.
    ///
    /// Returns a slice of all embeddings from the last decoded batch.
    /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - If the embeddings pointer is null.
    ///
    /// # Panics
    ///
    /// - `n_embd` does not fit into a `usize`
    pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
            if embedding.is_null() {
                Err(EmbeddingsError::NonePoolType)
            } else {
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

    /// Reset the timings for the context.
    pub fn reset_timings(&mut self) {
        unsafe { llama_cpp_sys_4::ggml_time_init() }
    }

    /// Returns the timings for the context.
    pub fn timings(&mut self) -> PerfContextData {
        let perf_context_data =
            unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
        PerfContextData { perf_context_data }
    }

    /// Reset the performance counters for the context.
    pub fn perf_context_reset(&mut self) {
        unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
    }

    /// Check if the KV cache memory supports shifting.
    #[must_use]
    pub fn memory_can_shift(&self) -> bool {
        unsafe {
            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
            llama_cpp_sys_4::llama_memory_can_shift(mem)
        }
    }

    /// Get the minimum position in a sequence's KV cache.
    #[must_use]
    pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
        unsafe {
            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
            llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
        }
    }

    /// Print a breakdown of the memory usage.
    pub fn memory_breakdown_print(&self) {
        unsafe {
            llama_cpp_sys_4::common_memory_breakdown_print(self.context.as_ptr());
        }
    }

    /// Get the size of the full context state in bytes.
    ///
    /// This is the size needed for [`state_get_data`](Self::state_get_data) and
    /// [`state_set_data`](Self::state_set_data).
    #[must_use]
    pub fn state_get_size(&mut self) -> usize {
        unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
    }

    /// Copy the full context state into a byte buffer.
    ///
    /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
    ///
    /// Returns the number of bytes written.
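    ///
    /// # Example
    ///
    /// A minimal sketch of a save/restore round trip through a byte buffer.
    /// ```ignore
    /// let mut buf = vec![0u8; ctx.state_get_size()];
    /// let written = ctx.state_get_data(&mut buf);
    /// buf.truncate(written);
    /// // ... later, on a context created with the same model and parameters:
    /// let read = ctx.state_set_data(&buf);
    /// assert_eq!(read, written);
    /// ```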
    pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
        unsafe {
            llama_cpp_sys_4::llama_state_get_data(
                self.context.as_ptr(),
                dst.as_mut_ptr(),
                dst.len(),
            )
        }
    }

    /// Restore the full context state from a byte buffer.
    ///
    /// Returns the number of bytes read.
    pub fn state_set_data(&mut self, src: &[u8]) -> usize {
        unsafe {
            llama_cpp_sys_4::llama_state_set_data(self.context.as_ptr(), src.as_ptr(), src.len())
        }
    }

    /// Save the context state to a file along with the given tokens.
    ///
    /// Returns `true` on success.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
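    ///
    /// # Example
    ///
    /// A minimal sketch pairing this with [`state_load_file`](Self::state_load_file);
    /// the path and token capacity are illustrative.
    /// ```ignore
    /// assert!(ctx.state_save_file("session.bin", &tokens));
    /// let mut restored = Vec::new();
    /// assert!(ctx.state_load_file("session.bin", &mut restored, 1024));
    /// assert_eq!(restored, tokens);
    /// ```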
    pub fn state_save_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        tokens: &[LlamaToken],
    ) -> bool {
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        unsafe {
            llama_cpp_sys_4::llama_state_save_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                tokens.as_ptr().cast(),
                tokens.len(),
            )
        }
    }

    /// Load a context state from a file.
    ///
    /// Returns `true` on success and fills `tokens_out` with the saved tokens.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_load_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        tokens_out: &mut Vec<LlamaToken>,
        n_token_capacity: usize,
    ) -> bool {
        tokens_out.resize(n_token_capacity, LlamaToken(0));
        let mut n_token_count: usize = 0;
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        let ok = unsafe {
            llama_cpp_sys_4::llama_state_load_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                tokens_out.as_mut_ptr().cast(),
                n_token_capacity,
                std::ptr::addr_of_mut!(n_token_count),
            )
        };
        if ok {
            tokens_out.truncate(n_token_count);
        }
        ok
    }

    /// Get the size of a single sequence's state in bytes.
    #[must_use]
    pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
        unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
    }

    /// Copy a single sequence's state into a byte buffer.
    ///
    /// Returns the number of bytes written.
    pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
        unsafe {
            llama_cpp_sys_4::llama_state_seq_get_data(
                self.context.as_ptr(),
                dst.as_mut_ptr(),
                dst.len(),
                seq_id,
            )
        }
    }

    /// Restore a single sequence's state from a byte buffer.
    ///
    /// Returns the number of bytes read.
    pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
        unsafe {
            llama_cpp_sys_4::llama_state_seq_set_data(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
            )
        }
    }

    /// Save a single sequence's state to a file.
    ///
    /// Returns the number of bytes written (0 on failure).
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_seq_save_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        seq_id: i32,
        tokens: &[LlamaToken],
    ) -> usize {
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        unsafe {
            llama_cpp_sys_4::llama_state_seq_save_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                seq_id,
                tokens.as_ptr().cast(),
                tokens.len(),
            )
        }
    }

    /// Load a single sequence's state from a file.
    ///
    /// Returns the number of bytes read (0 on failure).
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_seq_load_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        dest_seq_id: i32,
        tokens_out: &mut Vec<LlamaToken>,
        n_token_capacity: usize,
    ) -> usize {
        tokens_out.resize(n_token_capacity, LlamaToken(0));
        let mut n_token_count: usize = 0;
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        let ret = unsafe {
            llama_cpp_sys_4::llama_state_seq_load_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                dest_seq_id,
                tokens_out.as_mut_ptr().cast(),
                n_token_capacity,
                std::ptr::addr_of_mut!(n_token_count),
            )
        };
        if ret > 0 {
            tokens_out.truncate(n_token_count);
        }
        ret
    }

    /// Set a control vector on the context.
    ///
    /// # Parameters
    ///
    /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
    /// - `n_embd`: The embedding dimension.
    /// - `il_start`: The starting layer index (inclusive).
    /// - `il_end`: The ending layer index (exclusive).
    ///
    /// # Errors
    ///
    /// Returns `Err` with the error code if the operation fails.
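    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `direction` holds one `n_embd`-sized vector
    /// per layer in the targeted range; the layer count is illustrative.
    /// ```ignore
    /// let n_layers = 32; // use the actual layer count of your model
    /// ctx.set_adapter_cvec(&direction, ctx.model.n_embd(), 1, n_layers)?;
    /// // Clearing the control vector again:
    /// ctx.set_adapter_cvec(&[], ctx.model.n_embd(), 0, 0)?;
    /// ```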
    pub fn set_adapter_cvec(
        &mut self,
        data: &[f32],
        n_embd: i32,
        il_start: i32,
        il_end: i32,
    ) -> Result<(), i32> {
        let ret = unsafe {
            llama_cpp_sys_4::llama_set_adapter_cvec(
                self.context.as_ptr(),
                data.as_ptr(),
                data.len(),
                n_embd,
                il_start,
                il_end,
            )
        };
        if ret != 0 {
            Err(ret)
        } else {
            Ok(())
        }
    }

    /// Get sampled token debug info for the `i`th position.
    ///
    /// Returns the sampled token at position `i` from the last decode call.
    #[must_use]
    pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
        let token =
            unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
        LlamaToken(token)
    }

    /// Get sampled candidate tokens for the `i`th position.
    ///
    /// Returns a slice of candidate tokens from the last decode call.
    #[must_use]
    pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
        let count = unsafe {
            llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
        } as usize;
        if count == 0 {
            return &[];
        }
        let ptr =
            unsafe { llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i) };
        if ptr.is_null() {
            return &[];
        }
        unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
    }

    /// Get the number of sampled logits for the `i`th position.
    #[must_use]
    pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
        unsafe { llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i) }
    }

    /// Get sampled logits for the `i`th position.
    ///
    /// Returns a slice of logit values from the last decode call.
    #[must_use]
    pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
        let count = self.get_sampled_logits_count_ith(i) as usize;
        if count == 0 {
            return &[];
        }
        let ptr =
            unsafe { llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i) };
        if ptr.is_null() {
            return &[];
        }
        unsafe { slice::from_raw_parts(ptr, count) }
    }

    /// Get the number of sampled probabilities for the `i`th position.
    #[must_use]
    pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
        unsafe { llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i) }
    }

    /// Get sampled probabilities for the `i`th position.
    ///
    /// Returns a slice of probability values from the last decode call.
    #[must_use]
    pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
        let count = self.get_sampled_probs_count_ith(i) as usize;
        if count == 0 {
            return &[];
        }
        let ptr = unsafe { llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i) };
        if ptr.is_null() {
            return &[];
        }
        unsafe { slice::from_raw_parts(ptr, count) }
    }

    /// Get the size of a single sequence's state with flags.
    #[must_use]
    pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
        unsafe {
            llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
        }
    }

    /// Copy a single sequence's state into a byte buffer with flags.
    ///
    /// Returns the number of bytes written.
    pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
        unsafe {
            llama_cpp_sys_4::llama_state_seq_get_data_ext(
                self.context.as_ptr(),
                dst.as_mut_ptr(),
                dst.len(),
                seq_id,
                flags,
            )
        }
    }

    /// Restore a single sequence's state from a byte buffer with flags.
    ///
    /// Returns the number of bytes read.
    pub fn state_seq_set_data_ext(&mut self, src: &[u8], dest_seq_id: i32, flags: u32) -> usize {
        unsafe {
            llama_cpp_sys_4::llama_state_seq_set_data_ext(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
                flags,
            )
        }
    }

    /// Set an abort callback for the context.
    ///
    /// The callback is called periodically during computation. If it returns `true`,
    /// the computation is aborted.
    ///
    /// # Safety
    ///
    /// The callback data must remain valid for the lifetime of the context or until
    /// the callback is replaced.
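    ///
    /// # Example
    ///
    /// A minimal sketch using an `AtomicBool` as the callback payload; the
    /// flag wiring is illustrative.
    /// ```ignore
    /// use std::sync::atomic::{AtomicBool, Ordering};
    ///
    /// unsafe extern "C" fn should_abort(data: *mut std::ffi::c_void) -> bool {
    ///     // `data` points at the AtomicBool passed to `set_abort_callback`.
    ///     (*data.cast::<AtomicBool>()).load(Ordering::Relaxed)
    /// }
    ///
    /// static ABORT: AtomicBool = AtomicBool::new(false);
    /// unsafe {
    ///     ctx.set_abort_callback(
    ///         Some(should_abort),
    ///         std::ptr::addr_of!(ABORT).cast_mut().cast(),
    ///     );
    /// }
    /// ```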
    pub unsafe fn set_abort_callback(
        &mut self,
        callback: llama_cpp_sys_4::ggml_abort_callback,
        data: *mut std::ffi::c_void,
    ) {
        llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
    }

    /// Attach thread pools to the context.
    ///
    /// # Safety
    ///
    /// The thread pools must remain valid for the lifetime of the context or until
    /// they are detached.
    pub unsafe fn attach_threadpool(
        &mut self,
        threadpool: llama_cpp_sys_4::ggml_threadpool_t,
        threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
    ) {
        llama_cpp_sys_4::llama_attach_threadpool(
            self.context.as_ptr(),
            threadpool,
            threadpool_batch,
        );
    }

    /// Detach the thread pools from the context.
    pub fn detach_threadpool(&mut self) {
        unsafe {
            llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
        }
    }

    /// Set a sampler for a specific sequence.
    ///
    /// Returns `true` on success.
    pub fn set_sampler(
        &mut self,
        seq_id: i32,
        sampler: &mut crate::sampling::LlamaSampler,
    ) -> bool {
        unsafe {
            llama_cpp_sys_4::llama_set_sampler(
                self.context.as_ptr(),
                seq_id,
                sampler.sampler.as_ptr(),
            )
        }
    }

    /// Get the raw model pointer from this context.
    ///
    /// This is mainly useful for FFI interop. In normal usage, access
    /// the model via the `model` field instead.
    #[must_use]
    pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
        unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
    }

    /// Sets a LoRA adapter on the context.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterSetError`] for more information.
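    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `adapter` was loaded through the model's
    /// LoRA adapter API; a scale of 1.0 applies the adapter at full strength.
    /// ```ignore
    /// ctx.lora_adapter_set(&mut adapter, 1.0)?;
    /// // ... run decode calls with the adapter applied ...
    /// ctx.lora_adapter_remove(&mut adapter)?;
    /// ```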
    pub fn lora_adapter_set(
        &self,
        adapter: &mut LlamaLoraAdapter,
        scale: f32,
    ) -> Result<(), LlamaLoraAdapterSetError> {
        let err_code = unsafe {
            // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora,
            // which takes a full list of adapters + scales at once (b8249+).
            let mut adapter_ptr = adapter.lora_adapter.as_ptr();
            let mut scale_val = scale;
            llama_cpp_sys_4::llama_set_adapters_lora(
                self.context.as_ptr(),
                &raw mut adapter_ptr,
                1,
                &raw mut scale_val,
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
        }

        tracing::debug!("Set LoRA adapter");
        Ok(())
    }

    /// Remove all LoRA adapters from the context.
    ///
    /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
    /// `llama_set_adapters_lora`, which operates on the full adapter list at once.
    /// Calling this function clears **all** adapters currently set on the context.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterRemoveError`] for more information.
    pub fn lora_adapter_remove(
        &self,
        _adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), LlamaLoraAdapterRemoveError> {
        let err_code = unsafe {
            llama_cpp_sys_4::llama_set_adapters_lora(
                self.context.as_ptr(),
                std::ptr::null_mut(),
                0,
                std::ptr::null_mut(),
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
        }

        tracing::debug!("Removed all LoRA adapters");
        Ok(())
    }
}

impl Drop for LlamaContext<'_> {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
    }
}