// llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19 LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26pub mod tensor_capture;
27
/// A safe wrapper around the `llama_context` C++ context.
///
/// Encapsulates the raw, non-null `llama_context` pointer together with a reference to the
/// [`LlamaModel`] it was created from, plus bookkeeping for which logits were initialized
/// and whether embeddings are enabled.
///
/// The `NonNull` wrapper guarantees the pointer is never null; the `'a` lifetime ties the
/// context to the model it borrows.
///
/// # Fields
///
/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
/// - `model`: A reference to the `LlamaModel` associated with this context.
/// - `initialized_logits`: Token indices whose logits were initialized by the last successful
///   `decode`/`encode` (copied from the batch); used by `get_logits_ith` to validate accesses.
/// - `embeddings_enabled`: Whether embeddings are enabled — set at creation and toggled by
///   [`LlamaContext::set_embeddings`]; checked by the `embeddings_*` accessors.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
    /// a reference to the context's model.
    pub model: &'a LlamaModel,
    // Logit indices initialized by the last decode/encode (cloned from the batch on success).
    initialized_logits: Vec<i32>,
    // Cached embeddings flag; kept in sync with llama.cpp by `set_embeddings`.
    embeddings_enabled: bool,
}
56
57impl Debug for LlamaContext<'_> {
58 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("LlamaContext")
60 .field("context", &self.context)
61 .finish()
62 }
63}
64
65impl<'model> LlamaContext<'model> {
66 /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
67 ///
68 /// This function initializes a new `LlamaContext` object, which is used to interact with the
69 /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
70 /// determines whether embeddings are enabled in the context.
71 ///
72 /// # Parameters
73 ///
74 /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
75 /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
76 /// the context created in previous steps. This context is necessary for interacting with the model.
77 /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
78 ///
79 /// # Returns
80 ///
81 /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
82 /// - The model reference (`llama_model`) is stored in the context.
83 /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
84 /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
85 ///
86 /// # Example
87 /// ```
88 /// let llama_model = LlamaModel::load("path/to/model").unwrap();
89 /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
90 /// let context = LlamaContext::new(&llama_model, context_ptr, true);
91 /// // Now you can use the context
92 /// ```
93 pub(crate) fn new(
94 llama_model: &'model LlamaModel,
95 llama_context: NonNull<llama_cpp_sys_4::llama_context>,
96 embeddings_enabled: bool,
97 ) -> Self {
98 Self {
99 context: llama_context,
100 model: llama_model,
101 initialized_logits: Vec::new(),
102 embeddings_enabled,
103 }
104 }
105
106 /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
107 #[must_use]
108 pub fn n_batch(&self) -> u32 {
109 unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
110 }
111
112 /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
113 #[must_use]
114 pub fn n_ubatch(&self) -> u32 {
115 unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
116 }
117
118 /// Gets the size of the context.
119 #[must_use]
120 pub fn n_ctx(&self) -> u32 {
121 unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
122 }
123
124 /// Decodes the batch.
125 ///
126 /// # Errors
127 ///
128 /// - `DecodeError` if the decoding failed.
129 ///
130 /// # Panics
131 ///
132 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
133 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
134 let result =
135 unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
136
137 match NonZeroI32::new(result) {
138 None => {
139 self.initialized_logits
140 .clone_from(&batch.initialized_logits);
141 Ok(())
142 }
143 Some(error) => Err(DecodeError::from(error)),
144 }
145 }
146
147 /// Encodes the batch.
148 ///
149 /// # Errors
150 ///
151 /// - `EncodeError` if the decoding failed.
152 ///
153 /// # Panics
154 ///
155 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
156 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
157 let result =
158 unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
159
160 match NonZeroI32::new(result) {
161 None => {
162 self.initialized_logits
163 .clone_from(&batch.initialized_logits);
164 Ok(())
165 }
166 Some(error) => Err(EncodeError::from(error)),
167 }
168 }
169
170 /// Return Pooling type for Llama's Context
171 #[must_use]
172 pub fn pooling_type(&self) -> LlamaPoolingType {
173 let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
174
175 LlamaPoolingType::from(pooling_type)
176 }
177
178 /// Get the embeddings for the `i`th sequence in the current context.
179 ///
180 /// # Returns
181 ///
182 /// A slice containing the embeddings for the last decoded batch.
183 /// The size corresponds to the `n_embd` parameter of the context's model.
184 ///
185 /// # Errors
186 ///
187 /// - When the current context was constructed without enabling embeddings.
188 /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
189 /// - If the given sequence index exceeds the max sequence id.
190 ///
191 /// # Panics
192 ///
193 /// * `n_embd` does not fit into a usize
194 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
195 if !self.embeddings_enabled {
196 return Err(EmbeddingsError::NotEnabled);
197 }
198
199 let n_embd =
200 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
201
202 unsafe {
203 let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
204
205 // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
206 if embedding.is_null() {
207 Err(EmbeddingsError::NonePoolType)
208 } else {
209 Ok(slice::from_raw_parts(embedding, n_embd))
210 }
211 }
212 }
213
214 /// Get the embeddings for the `i`th token in the current context.
215 ///
216 /// # Returns
217 ///
218 /// A slice containing the embeddings for the last decoded batch of the given token.
219 /// The size corresponds to the `n_embd` parameter of the context's model.
220 ///
221 /// # Errors
222 ///
223 /// - When the current context was constructed without enabling embeddings.
224 /// - When the given token didn't have logits enabled when it was passed.
225 /// - If the given token index exceeds the max token id.
226 ///
227 /// # Panics
228 ///
229 /// * `n_embd` does not fit into a usize
230 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
231 if !self.embeddings_enabled {
232 return Err(EmbeddingsError::NotEnabled);
233 }
234
235 let n_embd =
236 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
237
238 unsafe {
239 let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
240 // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
241 if embedding.is_null() {
242 Err(EmbeddingsError::LogitsNotEnabled)
243 } else {
244 Ok(slice::from_raw_parts(embedding, n_embd))
245 }
246 }
247 }
248
249 /// Get the logits for the last token in the context.
250 ///
251 /// # Returns
252 /// An iterator over unsorted `LlamaTokenData` containing the
253 /// logits for the last token in the context.
254 ///
255 /// # Panics
256 ///
257 /// - underlying logits data is null
258 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
259 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
260 let token = LlamaToken::new(i);
261 LlamaTokenData::new(token, *logit, 0_f32)
262 })
263 }
264
265 /// Get the token data array for the last token in the context.
266 ///
267 /// This is a convience method that implements:
268 /// ```ignore
269 /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
270 /// ```
271 ///
272 /// # Panics
273 ///
274 /// - underlying logits data is null
275 #[must_use]
276 pub fn token_data_array(&self) -> LlamaTokenDataArray {
277 LlamaTokenDataArray::from_iter(self.candidates(), false)
278 }
279
280 /// Token logits obtained from the last call to `decode()`.
281 /// The logits for which `batch.logits[i] != 0` are stored contiguously
282 /// in the order they have appeared in the batch.
283 /// Rows: number of tokens for which `batch.logits[i] != 0`
284 /// Cols: `n_vocab`
285 ///
286 /// # Returns
287 ///
288 /// A slice containing the logits for the last decoded token.
289 /// The size corresponds to the `n_vocab` parameter of the context's model.
290 ///
291 /// # Panics
292 ///
293 /// - `n_vocab` does not fit into a usize
294 /// - token data returned is null
295 #[must_use]
296 pub fn get_logits(&self) -> &[f32] {
297 let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
298 assert!(!data.is_null(), "logits data for last token is null");
299 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
300
301 unsafe { slice::from_raw_parts(data, len) }
302 }
303
304 /// Get the logits for the ith token in the context.
305 ///
306 /// # Panics
307 ///
308 /// - logit `i` is not initialized.
309 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
310 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
311 let token = LlamaToken::new(i);
312 LlamaTokenData::new(token, *logit, 0_f32)
313 })
314 }
315
316 /// Get the logits for the ith token in the context.
317 ///
318 /// # Panics
319 ///
320 /// - `i` is greater than `n_ctx`
321 /// - `n_vocab` does not fit into a usize
322 /// - logit `i` is not initialized.
323 #[must_use]
324 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
325 assert!(
326 self.initialized_logits.contains(&i),
327 "logit {i} is not initialized. only {:?} is",
328 self.initialized_logits
329 );
330 assert!(
331 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
332 "n_ctx ({}) must be greater than i ({})",
333 self.n_ctx(),
334 i
335 );
336
337 let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
338 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
339
340 unsafe { slice::from_raw_parts(data, len) }
341 }
342
343 /// Get the number of context tokens per sequence.
344 #[must_use]
345 pub fn n_ctx_seq(&self) -> u32 {
346 unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
347 }
348
349 /// Get the maximum number of sequences.
350 #[must_use]
351 pub fn n_seq_max(&self) -> u32 {
352 unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
353 }
354
355 /// Get the number of threads used for generation.
356 #[must_use]
357 pub fn n_threads(&self) -> i32 {
358 unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
359 }
360
361 /// Get the number of threads used for batch processing.
362 #[must_use]
363 pub fn n_threads_batch(&self) -> i32 {
364 unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
365 }
366
367 /// Set the number of threads used for generation and batch processing.
368 pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
369 unsafe {
370 llama_cpp_sys_4::llama_set_n_threads(
371 self.context.as_ptr(),
372 n_threads,
373 n_threads_batch,
374 );
375 }
376 }
377
378 /// Set whether to use causal attention.
379 ///
380 /// If set to `false`, the model will use non-causal attention, which is
381 /// needed for embedding models.
382 pub fn set_causal_attn(&mut self, causal_attn: bool) {
383 unsafe {
384 llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
385 }
386 }
387
388 /// Set whether to compute embeddings.
389 ///
390 /// This allows toggling embedding mode at runtime (as opposed to only at
391 /// context creation time).
392 pub fn set_embeddings(&mut self, embeddings: bool) {
393 self.embeddings_enabled = embeddings;
394 unsafe {
395 llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
396 }
397 }
398
399 /// Mark the next computation as a warmup run.
400 ///
401 /// Warmup runs are useful for GPU backends to compile kernels before
402 /// actual inference begins.
403 pub fn set_warmup(&mut self, warmup: bool) {
404 unsafe {
405 llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
406 }
407 }
408
409 /// Wait for all pending async computations to finish.
410 pub fn synchronize(&mut self) {
411 unsafe {
412 llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
413 }
414 }
415
416 /// Get all embeddings for the current context.
417 ///
418 /// Returns a slice of all embeddings from the last decoded batch.
419 /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
420 ///
421 /// # Errors
422 ///
423 /// - When the current context was constructed without enabling embeddings.
424 /// - If the embeddings pointer is null.
425 ///
426 /// # Panics
427 ///
428 /// * `n_embd` does not fit into a usize
429 pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
430 if !self.embeddings_enabled {
431 return Err(EmbeddingsError::NotEnabled);
432 }
433
434 let n_embd =
435 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
436
437 unsafe {
438 let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
439 if embedding.is_null() {
440 Err(EmbeddingsError::NonePoolType)
441 } else {
442 Ok(slice::from_raw_parts(embedding, n_embd))
443 }
444 }
445 }
446
447 /// Reset the timings for the context.
448 pub fn reset_timings(&mut self) {
449 unsafe { llama_cpp_sys_4::ggml_time_init() }
450 }
451
452 /// Returns the timings for the context.
453 pub fn timings(&mut self) -> PerfContextData {
454 let perf_context_data =
455 unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
456 PerfContextData { perf_context_data }
457 }
458
459 /// Reset the performance counters for the context.
460 pub fn perf_context_reset(&mut self) {
461 unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
462 }
463
464 /// Check if the KV cache memory supports shifting.
465 #[must_use]
466 pub fn memory_can_shift(&self) -> bool {
467 unsafe {
468 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
469 llama_cpp_sys_4::llama_memory_can_shift(mem)
470 }
471 }
472
473 /// Get the minimum position in a sequence's KV cache.
474 #[must_use]
475 pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
476 unsafe {
477 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
478 llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
479 }
480 }
481
482 /// Print a breakdown of the memory usage.
483 pub fn memory_breakdown_print(&self) {
484 unsafe {
485 llama_cpp_sys_4::llama_memory_breakdown_print(self.context.as_ptr());
486 }
487 }
488
489 /// Get the size of the full context state in bytes.
490 ///
491 /// This is the size needed for [`state_get_data`](Self::state_get_data) and
492 /// [`state_set_data`](Self::state_set_data).
493 #[must_use]
494 pub fn state_get_size(&mut self) -> usize {
495 unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
496 }
497
498 /// Copy the full context state into a byte buffer.
499 ///
500 /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
501 ///
502 /// Returns the number of bytes written.
503 pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
504 unsafe {
505 llama_cpp_sys_4::llama_state_get_data(
506 self.context.as_ptr(),
507 dst.as_mut_ptr(),
508 dst.len(),
509 )
510 }
511 }
512
513 /// Restore the full context state from a byte buffer.
514 ///
515 /// Returns the number of bytes read.
516 pub fn state_set_data(&mut self, src: &[u8]) -> usize {
517 unsafe {
518 llama_cpp_sys_4::llama_state_set_data(
519 self.context.as_ptr(),
520 src.as_ptr(),
521 src.len(),
522 )
523 }
524 }
525
    /// Save the context state to a file along with the given tokens.
    ///
    /// Returns `true` on success.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_save_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        tokens: &[LlamaToken],
    ) -> bool {
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        // `LlamaToken` wraps the raw token type, so the slice pointer is cast directly
        // for the FFI call — assumes a transparent layout; TODO confirm on LlamaToken.
        unsafe {
            llama_cpp_sys_4::llama_state_save_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                tokens.as_ptr().cast(),
                tokens.len(),
            )
        }
    }
549
    /// Load a context state from a file.
    ///
    /// Returns `true` on success and fills `tokens_out` with the saved tokens.
    /// On failure, `tokens_out` is left resized to `n_token_capacity` entries
    /// (padded with token 0) and its contents should not be used.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_load_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        tokens_out: &mut Vec<LlamaToken>,
        n_token_capacity: usize,
    ) -> bool {
        // Pre-size the buffer so llama.cpp can write up to `n_token_capacity` tokens into it.
        tokens_out.resize(n_token_capacity, LlamaToken(0));
        let mut n_token_count: usize = 0;
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        let ok = unsafe {
            llama_cpp_sys_4::llama_state_load_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                tokens_out.as_mut_ptr().cast(),
                n_token_capacity,
                std::ptr::addr_of_mut!(n_token_count),
            )
        };
        if ok {
            // Shrink to the number of tokens llama.cpp actually wrote.
            tokens_out.truncate(n_token_count);
        }
        ok
    }
581
582 /// Get the size of a single sequence's state in bytes.
583 #[must_use]
584 pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
585 unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
586 }
587
588 /// Copy a single sequence's state into a byte buffer.
589 ///
590 /// Returns the number of bytes written.
591 pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
592 unsafe {
593 llama_cpp_sys_4::llama_state_seq_get_data(
594 self.context.as_ptr(),
595 dst.as_mut_ptr(),
596 dst.len(),
597 seq_id,
598 )
599 }
600 }
601
602 /// Restore a single sequence's state from a byte buffer.
603 ///
604 /// Returns the number of bytes read.
605 pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
606 unsafe {
607 llama_cpp_sys_4::llama_state_seq_set_data(
608 self.context.as_ptr(),
609 src.as_ptr(),
610 src.len(),
611 dest_seq_id,
612 )
613 }
614 }
615
    /// Save a single sequence's state to a file.
    ///
    /// Returns the number of bytes written (0 on failure).
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_seq_save_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        seq_id: i32,
        tokens: &[LlamaToken],
    ) -> usize {
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        // The token slice pointer is cast directly for the FFI call — assumes
        // `LlamaToken` has a transparent layout over the raw token type; TODO confirm.
        unsafe {
            llama_cpp_sys_4::llama_state_seq_save_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                seq_id,
                tokens.as_ptr().cast(),
                tokens.len(),
            )
        }
    }
641
    /// Load a single sequence's state from a file.
    ///
    /// Returns the number of bytes read (0 on failure).
    /// On failure, `tokens_out` is left resized to `n_token_capacity` entries
    /// (padded with token 0) and its contents should not be used.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_seq_load_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        dest_seq_id: i32,
        tokens_out: &mut Vec<LlamaToken>,
        n_token_capacity: usize,
    ) -> usize {
        // Pre-size the buffer so llama.cpp can write up to `n_token_capacity` tokens into it.
        tokens_out.resize(n_token_capacity, LlamaToken(0));
        let mut n_token_count: usize = 0;
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        let ret = unsafe {
            llama_cpp_sys_4::llama_state_seq_load_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                dest_seq_id,
                tokens_out.as_mut_ptr().cast(),
                n_token_capacity,
                std::ptr::addr_of_mut!(n_token_count),
            )
        };
        if ret > 0 {
            // Shrink to the number of tokens llama.cpp actually wrote.
            tokens_out.truncate(n_token_count);
        }
        ret
    }
675
676 /// Set a control vector on the context.
677 ///
678 /// # Parameters
679 ///
680 /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
681 /// - `n_embd`: The embedding dimension.
682 /// - `il_start`: The starting layer index (inclusive).
683 /// - `il_end`: The ending layer index (exclusive).
684 ///
685 /// # Errors
686 ///
687 /// Returns `Err` with the error code if the operation fails.
688 pub fn set_adapter_cvec(
689 &mut self,
690 data: &[f32],
691 n_embd: i32,
692 il_start: i32,
693 il_end: i32,
694 ) -> Result<(), i32> {
695 let ret = unsafe {
696 llama_cpp_sys_4::llama_set_adapter_cvec(
697 self.context.as_ptr(),
698 data.as_ptr(),
699 data.len(),
700 n_embd,
701 il_start,
702 il_end,
703 )
704 };
705 if ret != 0 {
706 Err(ret)
707 } else {
708 Ok(())
709 }
710 }
711
712 /// Get sampled token debug info for the `i`th position.
713 ///
714 /// Returns the sampled token at position `i` from the last decode call.
715 #[must_use]
716 pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
717 let token =
718 unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
719 LlamaToken(token)
720 }
721
722 /// Get sampled candidate tokens for the `i`th position.
723 ///
724 /// Returns a slice of candidate tokens from the last decode call.
725 #[must_use]
726 pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
727 let count = unsafe {
728 llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
729 } as usize;
730 if count == 0 {
731 return &[];
732 }
733 let ptr = unsafe {
734 llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i)
735 };
736 if ptr.is_null() {
737 return &[];
738 }
739 unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
740 }
741
742 /// Get the number of sampled logits for the `i`th position.
743 #[must_use]
744 pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
745 unsafe {
746 llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i)
747 }
748 }
749
750 /// Get sampled logits for the `i`th position.
751 ///
752 /// Returns a slice of logit values from the last decode call.
753 #[must_use]
754 pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
755 let count = self.get_sampled_logits_count_ith(i) as usize;
756 if count == 0 {
757 return &[];
758 }
759 let ptr = unsafe {
760 llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i)
761 };
762 if ptr.is_null() {
763 return &[];
764 }
765 unsafe { slice::from_raw_parts(ptr, count) }
766 }
767
768 /// Get the number of sampled probabilities for the `i`th position.
769 #[must_use]
770 pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
771 unsafe {
772 llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i)
773 }
774 }
775
776 /// Get sampled probabilities for the `i`th position.
777 ///
778 /// Returns a slice of probability values from the last decode call.
779 #[must_use]
780 pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
781 let count = self.get_sampled_probs_count_ith(i) as usize;
782 if count == 0 {
783 return &[];
784 }
785 let ptr = unsafe {
786 llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i)
787 };
788 if ptr.is_null() {
789 return &[];
790 }
791 unsafe { slice::from_raw_parts(ptr, count) }
792 }
793
794 /// Get the size of a single sequence's state with flags.
795 #[must_use]
796 pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
797 unsafe {
798 llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
799 }
800 }
801
802 /// Copy a single sequence's state into a byte buffer with flags.
803 ///
804 /// Returns the number of bytes written.
805 pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
806 unsafe {
807 llama_cpp_sys_4::llama_state_seq_get_data_ext(
808 self.context.as_ptr(),
809 dst.as_mut_ptr(),
810 dst.len(),
811 seq_id,
812 flags,
813 )
814 }
815 }
816
817 /// Restore a single sequence's state from a byte buffer with flags.
818 ///
819 /// Returns the number of bytes read.
820 pub fn state_seq_set_data_ext(
821 &mut self,
822 src: &[u8],
823 dest_seq_id: i32,
824 flags: u32,
825 ) -> usize {
826 unsafe {
827 llama_cpp_sys_4::llama_state_seq_set_data_ext(
828 self.context.as_ptr(),
829 src.as_ptr(),
830 src.len(),
831 dest_seq_id,
832 flags,
833 )
834 }
835 }
836
    /// Set an abort callback for the context.
    ///
    /// The callback is called periodically during computation. If it returns `true`,
    /// the computation is aborted.
    ///
    /// # Safety
    ///
    /// The callback data must remain valid for the lifetime of the context or until
    /// the callback is replaced. No ownership of `data` is taken; it is passed
    /// through to llama.cpp verbatim.
    pub unsafe fn set_abort_callback(
        &mut self,
        callback: llama_cpp_sys_4::ggml_abort_callback,
        data: *mut std::ffi::c_void,
    ) {
        llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
    }
853
    /// Attach a thread pool to the context.
    ///
    /// `threadpool` is used for generation, `threadpool_batch` for batch
    /// processing — presumably mirroring `n_threads`/`n_threads_batch`; confirm
    /// against the llama.cpp header.
    ///
    /// # Safety
    ///
    /// The thread pools must remain valid for the lifetime of the context or until
    /// they are detached.
    pub unsafe fn attach_threadpool(
        &mut self,
        threadpool: llama_cpp_sys_4::ggml_threadpool_t,
        threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
    ) {
        llama_cpp_sys_4::llama_attach_threadpool(
            self.context.as_ptr(),
            threadpool,
            threadpool_batch,
        );
    }
871
872 /// Detach the thread pool from the context.
873 pub fn detach_threadpool(&mut self) {
874 unsafe {
875 llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
876 }
877 }
878
879 /// Set a sampler for a specific sequence.
880 ///
881 /// Returns `true` on success.
882 pub fn set_sampler(
883 &mut self,
884 seq_id: i32,
885 sampler: &mut crate::sampling::LlamaSampler,
886 ) -> bool {
887 unsafe {
888 llama_cpp_sys_4::llama_set_sampler(
889 self.context.as_ptr(),
890 seq_id,
891 sampler.sampler.as_ptr(),
892 )
893 }
894 }
895
896 /// Get the raw model pointer from this context.
897 ///
898 /// This is mainly useful for FFI interop. In normal usage, access
899 /// the model via the `model` field instead.
900 #[must_use]
901 pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
902 unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
903 }
904
    /// Sets a lora adapter.
    ///
    /// NOTE(review): this passes a one-element adapter list to
    /// `llama_set_adapters_lora`, which (per the b8249+ API) sets the *full*
    /// adapter list at once — so calling this likely replaces any adapters set
    /// previously rather than adding to them; confirm against llama.cpp.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterSetError`] for more information.
    pub fn lora_adapter_set(
        &self,
        adapter: &mut LlamaLoraAdapter,
        scale: f32,
    ) -> Result<(), LlamaLoraAdapterSetError> {
        let err_code = unsafe {
            // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
            // which takes a full list of adapters + scales at once (b8249+)
            // The locals exist so we can take `*mut` pointers to a one-element
            // adapter list and its matching scale list.
            let mut adapter_ptr = adapter.lora_adapter.as_ptr();
            let mut scale_val = scale;
            llama_cpp_sys_4::llama_set_adapters_lora(
                self.context.as_ptr(),
                &raw mut adapter_ptr,
                1,
                &raw mut scale_val,
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
        }

        tracing::debug!("Set lora adapter");
        Ok(())
    }
934
    /// Remove all lora adapters from the context.
    ///
    /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
    /// `llama_set_adapters_lora` which operates on the full adapter list at once.
    /// Calling this function clears **all** adapters currently set on the context;
    /// the `_adapter` argument is ignored and kept only for signature compatibility.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterRemoveError`] for more information.
    pub fn lora_adapter_remove(
        &self,
        _adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), LlamaLoraAdapterRemoveError> {
        let err_code = unsafe {
            // An empty (null, 0) list clears every adapter on the context.
            llama_cpp_sys_4::llama_set_adapters_lora(
                self.context.as_ptr(),
                std::ptr::null_mut(),
                0,
                std::ptr::null_mut(),
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
        }

        tracing::debug!("Remove lora adapter");
        Ok(())
    }
963}
964
impl Drop for LlamaContext<'_> {
    /// Frees the underlying C `llama_context`.
    fn drop(&mut self) {
        // SAFETY: `self.context` is a valid pointer owned by this wrapper and
        // `drop` runs at most once, so it is freed exactly once here.
        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
    }
}