llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19 LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26
27/// A safe wrapper around the `llama_context` C++ context.
28///
29/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
30/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
31/// context-specific settings like embeddings and logits.
32///
33/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
34/// the context pointer, preventing it from being null. The struct also holds a reference to the model
35/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are enabled
36/// and the initialized logits for the context.
37///
38/// # Fields
39///
40/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
41/// - `model`: A reference to the `LlamaModel` associated with this context. This model provides the data and parameters
42/// that the context interacts with.
43/// - `initialized_logits`: A vector used to store the initialized logits. These are used in the model's processing and
44/// are kept separate from the context data.
45/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This is useful for
46/// controlling whether embedding data is generated during the interaction with the model.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    /// Non-null pointer to the underlying C `llama_context`. Owned by this
    /// wrapper and released via `llama_free` when the struct is dropped.
    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
    /// A reference to the context's model.
    pub model: &'a LlamaModel,
    /// Batch token indices whose logits were computed by the most recent
    /// successful `decode`/`encode` call; consulted by `get_logits_ith`.
    initialized_logits: Vec<i32>,
    /// Whether embeddings were enabled for this context (set at creation,
    /// may be toggled later via `set_embeddings`).
    embeddings_enabled: bool,
}
55
56impl Debug for LlamaContext<'_> {
57 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("LlamaContext")
59 .field("context", &self.context)
60 .finish()
61 }
62}
63
64impl<'model> LlamaContext<'model> {
65 /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
66 ///
67 /// This function initializes a new `LlamaContext` object, which is used to interact with the
68 /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
69 /// determines whether embeddings are enabled in the context.
70 ///
71 /// # Parameters
72 ///
73 /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
74 /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
75 /// the context created in previous steps. This context is necessary for interacting with the model.
76 /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
77 ///
78 /// # Returns
79 ///
80 /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
81 /// - The model reference (`llama_model`) is stored in the context.
82 /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
83 /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
84 ///
85 /// # Example
86 /// ```
87 /// let llama_model = LlamaModel::load("path/to/model").unwrap();
88 /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
89 /// let context = LlamaContext::new(&llama_model, context_ptr, true);
90 /// // Now you can use the context
91 /// ```
92 pub(crate) fn new(
93 llama_model: &'model LlamaModel,
94 llama_context: NonNull<llama_cpp_sys_4::llama_context>,
95 embeddings_enabled: bool,
96 ) -> Self {
97 Self {
98 context: llama_context,
99 model: llama_model,
100 initialized_logits: Vec::new(),
101 embeddings_enabled,
102 }
103 }
104
105 /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
106 #[must_use]
107 pub fn n_batch(&self) -> u32 {
108 unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
109 }
110
111 /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
112 #[must_use]
113 pub fn n_ubatch(&self) -> u32 {
114 unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
115 }
116
117 /// Gets the size of the context.
118 #[must_use]
119 pub fn n_ctx(&self) -> u32 {
120 unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
121 }
122
123 /// Decodes the batch.
124 ///
125 /// # Errors
126 ///
127 /// - `DecodeError` if the decoding failed.
128 ///
129 /// # Panics
130 ///
131 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
132 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
133 let result =
134 unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
135
136 match NonZeroI32::new(result) {
137 None => {
138 self.initialized_logits
139 .clone_from(&batch.initialized_logits);
140 Ok(())
141 }
142 Some(error) => Err(DecodeError::from(error)),
143 }
144 }
145
146 /// Encodes the batch.
147 ///
148 /// # Errors
149 ///
150 /// - `EncodeError` if the decoding failed.
151 ///
152 /// # Panics
153 ///
154 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
155 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
156 let result =
157 unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
158
159 match NonZeroI32::new(result) {
160 None => {
161 self.initialized_logits
162 .clone_from(&batch.initialized_logits);
163 Ok(())
164 }
165 Some(error) => Err(EncodeError::from(error)),
166 }
167 }
168
169 /// Return Pooling type for Llama's Context
170 #[must_use]
171 pub fn pooling_type(&self) -> LlamaPoolingType {
172 let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
173
174 LlamaPoolingType::from(pooling_type)
175 }
176
177 /// Get the embeddings for the `i`th sequence in the current context.
178 ///
179 /// # Returns
180 ///
181 /// A slice containing the embeddings for the last decoded batch.
182 /// The size corresponds to the `n_embd` parameter of the context's model.
183 ///
184 /// # Errors
185 ///
186 /// - When the current context was constructed without enabling embeddings.
187 /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
188 /// - If the given sequence index exceeds the max sequence id.
189 ///
190 /// # Panics
191 ///
192 /// * `n_embd` does not fit into a usize
193 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
194 if !self.embeddings_enabled {
195 return Err(EmbeddingsError::NotEnabled);
196 }
197
198 let n_embd =
199 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
200
201 unsafe {
202 let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
203
204 // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
205 if embedding.is_null() {
206 Err(EmbeddingsError::NonePoolType)
207 } else {
208 Ok(slice::from_raw_parts(embedding, n_embd))
209 }
210 }
211 }
212
213 /// Get the embeddings for the `i`th token in the current context.
214 ///
215 /// # Returns
216 ///
217 /// A slice containing the embeddings for the last decoded batch of the given token.
218 /// The size corresponds to the `n_embd` parameter of the context's model.
219 ///
220 /// # Errors
221 ///
222 /// - When the current context was constructed without enabling embeddings.
223 /// - When the given token didn't have logits enabled when it was passed.
224 /// - If the given token index exceeds the max token id.
225 ///
226 /// # Panics
227 ///
228 /// * `n_embd` does not fit into a usize
229 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
230 if !self.embeddings_enabled {
231 return Err(EmbeddingsError::NotEnabled);
232 }
233
234 let n_embd =
235 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
236
237 unsafe {
238 let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
239 // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
240 if embedding.is_null() {
241 Err(EmbeddingsError::LogitsNotEnabled)
242 } else {
243 Ok(slice::from_raw_parts(embedding, n_embd))
244 }
245 }
246 }
247
248 /// Get the logits for the last token in the context.
249 ///
250 /// # Returns
251 /// An iterator over unsorted `LlamaTokenData` containing the
252 /// logits for the last token in the context.
253 ///
254 /// # Panics
255 ///
256 /// - underlying logits data is null
257 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
258 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
259 let token = LlamaToken::new(i);
260 LlamaTokenData::new(token, *logit, 0_f32)
261 })
262 }
263
264 /// Get the token data array for the last token in the context.
265 ///
266 /// This is a convience method that implements:
267 /// ```ignore
268 /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
269 /// ```
270 ///
271 /// # Panics
272 ///
273 /// - underlying logits data is null
274 #[must_use]
275 pub fn token_data_array(&self) -> LlamaTokenDataArray {
276 LlamaTokenDataArray::from_iter(self.candidates(), false)
277 }
278
279 /// Token logits obtained from the last call to `decode()`.
280 /// The logits for which `batch.logits[i] != 0` are stored contiguously
281 /// in the order they have appeared in the batch.
282 /// Rows: number of tokens for which `batch.logits[i] != 0`
283 /// Cols: `n_vocab`
284 ///
285 /// # Returns
286 ///
287 /// A slice containing the logits for the last decoded token.
288 /// The size corresponds to the `n_vocab` parameter of the context's model.
289 ///
290 /// # Panics
291 ///
292 /// - `n_vocab` does not fit into a usize
293 /// - token data returned is null
294 #[must_use]
295 pub fn get_logits(&self) -> &[f32] {
296 let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
297 assert!(!data.is_null(), "logits data for last token is null");
298 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
299
300 unsafe { slice::from_raw_parts(data, len) }
301 }
302
303 /// Get the logits for the ith token in the context.
304 ///
305 /// # Panics
306 ///
307 /// - logit `i` is not initialized.
308 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
309 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
310 let token = LlamaToken::new(i);
311 LlamaTokenData::new(token, *logit, 0_f32)
312 })
313 }
314
315 /// Get the logits for the ith token in the context.
316 ///
317 /// # Panics
318 ///
319 /// - `i` is greater than `n_ctx`
320 /// - `n_vocab` does not fit into a usize
321 /// - logit `i` is not initialized.
322 #[must_use]
323 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
324 assert!(
325 self.initialized_logits.contains(&i),
326 "logit {i} is not initialized. only {:?} is",
327 self.initialized_logits
328 );
329 assert!(
330 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
331 "n_ctx ({}) must be greater than i ({})",
332 self.n_ctx(),
333 i
334 );
335
336 let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
337 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
338
339 unsafe { slice::from_raw_parts(data, len) }
340 }
341
342 /// Get the number of context tokens per sequence.
343 #[must_use]
344 pub fn n_ctx_seq(&self) -> u32 {
345 unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
346 }
347
348 /// Get the maximum number of sequences.
349 #[must_use]
350 pub fn n_seq_max(&self) -> u32 {
351 unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
352 }
353
354 /// Get the number of threads used for generation.
355 #[must_use]
356 pub fn n_threads(&self) -> i32 {
357 unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
358 }
359
360 /// Get the number of threads used for batch processing.
361 #[must_use]
362 pub fn n_threads_batch(&self) -> i32 {
363 unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
364 }
365
366 /// Set the number of threads used for generation and batch processing.
367 pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
368 unsafe {
369 llama_cpp_sys_4::llama_set_n_threads(
370 self.context.as_ptr(),
371 n_threads,
372 n_threads_batch,
373 );
374 }
375 }
376
377 /// Set whether to use causal attention.
378 ///
379 /// If set to `false`, the model will use non-causal attention, which is
380 /// needed for embedding models.
381 pub fn set_causal_attn(&mut self, causal_attn: bool) {
382 unsafe {
383 llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
384 }
385 }
386
387 /// Set whether to compute embeddings.
388 ///
389 /// This allows toggling embedding mode at runtime (as opposed to only at
390 /// context creation time).
391 pub fn set_embeddings(&mut self, embeddings: bool) {
392 self.embeddings_enabled = embeddings;
393 unsafe {
394 llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
395 }
396 }
397
398 /// Mark the next computation as a warmup run.
399 ///
400 /// Warmup runs are useful for GPU backends to compile kernels before
401 /// actual inference begins.
402 pub fn set_warmup(&mut self, warmup: bool) {
403 unsafe {
404 llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
405 }
406 }
407
408 /// Wait for all pending async computations to finish.
409 pub fn synchronize(&mut self) {
410 unsafe {
411 llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
412 }
413 }
414
415 /// Get all embeddings for the current context.
416 ///
417 /// Returns a slice of all embeddings from the last decoded batch.
418 /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
419 ///
420 /// # Errors
421 ///
422 /// - When the current context was constructed without enabling embeddings.
423 /// - If the embeddings pointer is null.
424 ///
425 /// # Panics
426 ///
427 /// * `n_embd` does not fit into a usize
428 pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
429 if !self.embeddings_enabled {
430 return Err(EmbeddingsError::NotEnabled);
431 }
432
433 let n_embd =
434 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
435
436 unsafe {
437 let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
438 if embedding.is_null() {
439 Err(EmbeddingsError::NonePoolType)
440 } else {
441 Ok(slice::from_raw_parts(embedding, n_embd))
442 }
443 }
444 }
445
446 /// Reset the timings for the context.
447 pub fn reset_timings(&mut self) {
448 unsafe { llama_cpp_sys_4::ggml_time_init() }
449 }
450
451 /// Returns the timings for the context.
452 pub fn timings(&mut self) -> PerfContextData {
453 let perf_context_data =
454 unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
455 PerfContextData { perf_context_data }
456 }
457
458 /// Reset the performance counters for the context.
459 pub fn perf_context_reset(&mut self) {
460 unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
461 }
462
463 /// Check if the KV cache memory supports shifting.
464 #[must_use]
465 pub fn memory_can_shift(&self) -> bool {
466 unsafe {
467 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
468 llama_cpp_sys_4::llama_memory_can_shift(mem)
469 }
470 }
471
472 /// Get the minimum position in a sequence's KV cache.
473 #[must_use]
474 pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
475 unsafe {
476 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
477 llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
478 }
479 }
480
481 /// Print a breakdown of the memory usage.
482 pub fn memory_breakdown_print(&self) {
483 unsafe {
484 llama_cpp_sys_4::llama_memory_breakdown_print(self.context.as_ptr());
485 }
486 }
487
488 /// Get the size of the full context state in bytes.
489 ///
490 /// This is the size needed for [`state_get_data`](Self::state_get_data) and
491 /// [`state_set_data`](Self::state_set_data).
492 #[must_use]
493 pub fn state_get_size(&mut self) -> usize {
494 unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
495 }
496
497 /// Copy the full context state into a byte buffer.
498 ///
499 /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
500 ///
501 /// Returns the number of bytes written.
502 pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
503 unsafe {
504 llama_cpp_sys_4::llama_state_get_data(
505 self.context.as_ptr(),
506 dst.as_mut_ptr(),
507 dst.len(),
508 )
509 }
510 }
511
512 /// Restore the full context state from a byte buffer.
513 ///
514 /// Returns the number of bytes read.
515 pub fn state_set_data(&mut self, src: &[u8]) -> usize {
516 unsafe {
517 llama_cpp_sys_4::llama_state_set_data(
518 self.context.as_ptr(),
519 src.as_ptr(),
520 src.len(),
521 )
522 }
523 }
524
525 /// Save the context state to a file along with the given tokens.
526 ///
527 /// Returns `true` on success.
528 ///
529 /// # Panics
530 ///
531 /// Panics if the path contains null bytes.
532 pub fn state_save_file(
533 &mut self,
534 path: impl AsRef<std::path::Path>,
535 tokens: &[LlamaToken],
536 ) -> bool {
537 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
538 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
539 unsafe {
540 llama_cpp_sys_4::llama_state_save_file(
541 self.context.as_ptr(),
542 c_path.as_ptr(),
543 tokens.as_ptr().cast(),
544 tokens.len(),
545 )
546 }
547 }
548
549 /// Load a context state from a file.
550 ///
551 /// Returns `true` on success and fills `tokens_out` with the saved tokens.
552 ///
553 /// # Panics
554 ///
555 /// Panics if the path contains null bytes.
556 pub fn state_load_file(
557 &mut self,
558 path: impl AsRef<std::path::Path>,
559 tokens_out: &mut Vec<LlamaToken>,
560 n_token_capacity: usize,
561 ) -> bool {
562 tokens_out.resize(n_token_capacity, LlamaToken(0));
563 let mut n_token_count: usize = 0;
564 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
565 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
566 let ok = unsafe {
567 llama_cpp_sys_4::llama_state_load_file(
568 self.context.as_ptr(),
569 c_path.as_ptr(),
570 tokens_out.as_mut_ptr().cast(),
571 n_token_capacity,
572 std::ptr::addr_of_mut!(n_token_count),
573 )
574 };
575 if ok {
576 tokens_out.truncate(n_token_count);
577 }
578 ok
579 }
580
581 /// Get the size of a single sequence's state in bytes.
582 #[must_use]
583 pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
584 unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
585 }
586
587 /// Copy a single sequence's state into a byte buffer.
588 ///
589 /// Returns the number of bytes written.
590 pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
591 unsafe {
592 llama_cpp_sys_4::llama_state_seq_get_data(
593 self.context.as_ptr(),
594 dst.as_mut_ptr(),
595 dst.len(),
596 seq_id,
597 )
598 }
599 }
600
601 /// Restore a single sequence's state from a byte buffer.
602 ///
603 /// Returns the number of bytes read.
604 pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
605 unsafe {
606 llama_cpp_sys_4::llama_state_seq_set_data(
607 self.context.as_ptr(),
608 src.as_ptr(),
609 src.len(),
610 dest_seq_id,
611 )
612 }
613 }
614
615 /// Save a single sequence's state to a file.
616 ///
617 /// Returns the number of bytes written (0 on failure).
618 ///
619 /// # Panics
620 ///
621 /// Panics if the path contains null bytes.
622 pub fn state_seq_save_file(
623 &mut self,
624 path: impl AsRef<std::path::Path>,
625 seq_id: i32,
626 tokens: &[LlamaToken],
627 ) -> usize {
628 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
629 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
630 unsafe {
631 llama_cpp_sys_4::llama_state_seq_save_file(
632 self.context.as_ptr(),
633 c_path.as_ptr(),
634 seq_id,
635 tokens.as_ptr().cast(),
636 tokens.len(),
637 )
638 }
639 }
640
641 /// Load a single sequence's state from a file.
642 ///
643 /// Returns the number of bytes read (0 on failure).
644 ///
645 /// # Panics
646 ///
647 /// Panics if the path contains null bytes.
648 pub fn state_seq_load_file(
649 &mut self,
650 path: impl AsRef<std::path::Path>,
651 dest_seq_id: i32,
652 tokens_out: &mut Vec<LlamaToken>,
653 n_token_capacity: usize,
654 ) -> usize {
655 tokens_out.resize(n_token_capacity, LlamaToken(0));
656 let mut n_token_count: usize = 0;
657 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
658 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
659 let ret = unsafe {
660 llama_cpp_sys_4::llama_state_seq_load_file(
661 self.context.as_ptr(),
662 c_path.as_ptr(),
663 dest_seq_id,
664 tokens_out.as_mut_ptr().cast(),
665 n_token_capacity,
666 std::ptr::addr_of_mut!(n_token_count),
667 )
668 };
669 if ret > 0 {
670 tokens_out.truncate(n_token_count);
671 }
672 ret
673 }
674
675 /// Set a control vector on the context.
676 ///
677 /// # Parameters
678 ///
679 /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
680 /// - `n_embd`: The embedding dimension.
681 /// - `il_start`: The starting layer index (inclusive).
682 /// - `il_end`: The ending layer index (exclusive).
683 ///
684 /// # Errors
685 ///
686 /// Returns `Err` with the error code if the operation fails.
687 pub fn set_adapter_cvec(
688 &mut self,
689 data: &[f32],
690 n_embd: i32,
691 il_start: i32,
692 il_end: i32,
693 ) -> Result<(), i32> {
694 let ret = unsafe {
695 llama_cpp_sys_4::llama_set_adapter_cvec(
696 self.context.as_ptr(),
697 data.as_ptr(),
698 data.len(),
699 n_embd,
700 il_start,
701 il_end,
702 )
703 };
704 if ret != 0 {
705 Err(ret)
706 } else {
707 Ok(())
708 }
709 }
710
711 /// Get sampled token debug info for the `i`th position.
712 ///
713 /// Returns the sampled token at position `i` from the last decode call.
714 #[must_use]
715 pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
716 let token =
717 unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
718 LlamaToken(token)
719 }
720
721 /// Get sampled candidate tokens for the `i`th position.
722 ///
723 /// Returns a slice of candidate tokens from the last decode call.
724 #[must_use]
725 pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
726 let count = unsafe {
727 llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
728 } as usize;
729 if count == 0 {
730 return &[];
731 }
732 let ptr = unsafe {
733 llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i)
734 };
735 if ptr.is_null() {
736 return &[];
737 }
738 unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
739 }
740
741 /// Get the number of sampled logits for the `i`th position.
742 #[must_use]
743 pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
744 unsafe {
745 llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i)
746 }
747 }
748
749 /// Get sampled logits for the `i`th position.
750 ///
751 /// Returns a slice of logit values from the last decode call.
752 #[must_use]
753 pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
754 let count = self.get_sampled_logits_count_ith(i) as usize;
755 if count == 0 {
756 return &[];
757 }
758 let ptr = unsafe {
759 llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i)
760 };
761 if ptr.is_null() {
762 return &[];
763 }
764 unsafe { slice::from_raw_parts(ptr, count) }
765 }
766
767 /// Get the number of sampled probabilities for the `i`th position.
768 #[must_use]
769 pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
770 unsafe {
771 llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i)
772 }
773 }
774
775 /// Get sampled probabilities for the `i`th position.
776 ///
777 /// Returns a slice of probability values from the last decode call.
778 #[must_use]
779 pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
780 let count = self.get_sampled_probs_count_ith(i) as usize;
781 if count == 0 {
782 return &[];
783 }
784 let ptr = unsafe {
785 llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i)
786 };
787 if ptr.is_null() {
788 return &[];
789 }
790 unsafe { slice::from_raw_parts(ptr, count) }
791 }
792
793 /// Get the size of a single sequence's state with flags.
794 #[must_use]
795 pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
796 unsafe {
797 llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
798 }
799 }
800
801 /// Copy a single sequence's state into a byte buffer with flags.
802 ///
803 /// Returns the number of bytes written.
804 pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
805 unsafe {
806 llama_cpp_sys_4::llama_state_seq_get_data_ext(
807 self.context.as_ptr(),
808 dst.as_mut_ptr(),
809 dst.len(),
810 seq_id,
811 flags,
812 )
813 }
814 }
815
816 /// Restore a single sequence's state from a byte buffer with flags.
817 ///
818 /// Returns the number of bytes read.
819 pub fn state_seq_set_data_ext(
820 &mut self,
821 src: &[u8],
822 dest_seq_id: i32,
823 flags: u32,
824 ) -> usize {
825 unsafe {
826 llama_cpp_sys_4::llama_state_seq_set_data_ext(
827 self.context.as_ptr(),
828 src.as_ptr(),
829 src.len(),
830 dest_seq_id,
831 flags,
832 )
833 }
834 }
835
836 /// Set an abort callback for the context.
837 ///
838 /// The callback is called periodically during computation. If it returns `true`,
839 /// the computation is aborted.
840 ///
841 /// # Safety
842 ///
843 /// The callback data must remain valid for the lifetime of the context or until
844 /// the callback is replaced.
845 pub unsafe fn set_abort_callback(
846 &mut self,
847 callback: llama_cpp_sys_4::ggml_abort_callback,
848 data: *mut std::ffi::c_void,
849 ) {
850 llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
851 }
852
853 /// Attach a thread pool to the context.
854 ///
855 /// # Safety
856 ///
857 /// The thread pools must remain valid for the lifetime of the context or until
858 /// they are detached.
859 pub unsafe fn attach_threadpool(
860 &mut self,
861 threadpool: llama_cpp_sys_4::ggml_threadpool_t,
862 threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
863 ) {
864 llama_cpp_sys_4::llama_attach_threadpool(
865 self.context.as_ptr(),
866 threadpool,
867 threadpool_batch,
868 );
869 }
870
871 /// Detach the thread pool from the context.
872 pub fn detach_threadpool(&mut self) {
873 unsafe {
874 llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
875 }
876 }
877
878 /// Set a sampler for a specific sequence.
879 ///
880 /// Returns `true` on success.
881 pub fn set_sampler(
882 &mut self,
883 seq_id: i32,
884 sampler: &mut crate::sampling::LlamaSampler,
885 ) -> bool {
886 unsafe {
887 llama_cpp_sys_4::llama_set_sampler(
888 self.context.as_ptr(),
889 seq_id,
890 sampler.sampler.as_ptr(),
891 )
892 }
893 }
894
895 /// Get the raw model pointer from this context.
896 ///
897 /// This is mainly useful for FFI interop. In normal usage, access
898 /// the model via the `model` field instead.
899 #[must_use]
900 pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
901 unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
902 }
903
904 /// Sets a lora adapter.
905 ///
906 /// # Errors
907 ///
908 /// See [`LlamaLoraAdapterSetError`] for more information.
909 pub fn lora_adapter_set(
910 &self,
911 adapter: &mut LlamaLoraAdapter,
912 scale: f32,
913 ) -> Result<(), LlamaLoraAdapterSetError> {
914 let err_code = unsafe {
915 // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
916 // which takes a full list of adapters + scales at once (b8249+)
917 let mut adapter_ptr = adapter.lora_adapter.as_ptr();
918 let mut scale_val = scale;
919 llama_cpp_sys_4::llama_set_adapters_lora(
920 self.context.as_ptr(),
921 &raw mut adapter_ptr,
922 1,
923 &raw mut scale_val,
924 )
925 };
926 if err_code != 0 {
927 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
928 }
929
930 tracing::debug!("Set lora adapter");
931 Ok(())
932 }
933
934 /// Remove all lora adapters from the context.
935 ///
936 /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
937 /// `llama_set_adapters_lora` which operates on the full adapter list at once.
938 /// Calling this function clears **all** adapters currently set on the context.
939 ///
940 /// # Errors
941 ///
942 /// See [`LlamaLoraAdapterRemoveError`] for more information.
943 pub fn lora_adapter_remove(
944 &self,
945 _adapter: &mut LlamaLoraAdapter,
946 ) -> Result<(), LlamaLoraAdapterRemoveError> {
947 let err_code = unsafe {
948 llama_cpp_sys_4::llama_set_adapters_lora(
949 self.context.as_ptr(),
950 std::ptr::null_mut(),
951 0,
952 std::ptr::null_mut(),
953 )
954 };
955 if err_code != 0 {
956 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
957 }
958
959 tracing::debug!("Remove lora adapter");
960 Ok(())
961 }
962}
963
impl Drop for LlamaContext<'_> {
    /// Releases the native llama.cpp context when the wrapper goes out of scope.
    fn drop(&mut self) {
        // SAFETY: `self.context` is a valid, non-null pointer obtained at
        // construction, and `drop` runs at most once, so the native context
        // is freed exactly once.
        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
    }
}