// llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19 LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26pub mod tensor_capture;
27
/// A safe wrapper around the `llama_context` C++ context.
///
/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
/// context-specific settings like embeddings and logits.
///
/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
/// the context pointer, preventing it from being null. The struct also holds a reference to the model
/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are enabled
/// and the initialized logits for the context.
///
/// # Fields
///
/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
/// - `model`: A reference to the `LlamaModel` associated with this context. This model provides the data and parameters
///   that the context interacts with.
/// - `initialized_logits`: A vector used to store the initialized logits. These are used in the model's processing and
///   are kept separate from the context data.
/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This is useful for
///   controlling whether embedding data is generated during the interaction with the model.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    // Owned pointer to the native context; freed in `Drop` via `llama_free`.
    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
    /// a reference to the context's model.
    pub model: &'a LlamaModel,
    // Batch positions whose logits were initialized by the last `decode`/`encode`
    // call; consulted by `get_logits_ith` before dereferencing FFI pointers.
    initialized_logits: Vec<i32>,
    // Whether embeddings were requested at creation (or toggled via `set_embeddings`).
    embeddings_enabled: bool,
}
56
impl Debug for LlamaContext<'_> {
    // Only the raw context pointer is printed; the model and internal state are
    // omitted to keep debug output short.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaContext")
            .field("context", &self.context)
            .finish()
    }
}
64
65impl<'model> LlamaContext<'model> {
66 /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
67 ///
68 /// This function initializes a new `LlamaContext` object, which is used to interact with the
69 /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
70 /// determines whether embeddings are enabled in the context.
71 ///
72 /// # Parameters
73 ///
74 /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
75 /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
76 /// the context created in previous steps. This context is necessary for interacting with the model.
77 /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
78 ///
79 /// # Returns
80 ///
81 /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
82 /// - The model reference (`llama_model`) is stored in the context.
83 /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
84 /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
85 ///
86 /// # Example
87 /// ```ignore
88 /// let llama_model = LlamaModel::load_from_file(&backend, "path/to/model", ¶ms).unwrap();
89 /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
90 /// let context = LlamaContext::new(&llama_model, context_ptr, true);
91 /// // Now you can use the context
92 /// ```
93 pub(crate) fn new(
94 llama_model: &'model LlamaModel,
95 llama_context: NonNull<llama_cpp_sys_4::llama_context>,
96 embeddings_enabled: bool,
97 ) -> Self {
98 Self {
99 context: llama_context,
100 model: llama_model,
101 initialized_logits: Vec::new(),
102 embeddings_enabled,
103 }
104 }
105
106 /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
107 #[must_use]
108 pub fn n_batch(&self) -> u32 {
109 unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
110 }
111
112 /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
113 #[must_use]
114 pub fn n_ubatch(&self) -> u32 {
115 unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
116 }
117
118 /// Gets the size of the context.
119 #[must_use]
120 pub fn n_ctx(&self) -> u32 {
121 unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
122 }
123
124 /// Decodes the batch.
125 ///
126 /// # Errors
127 ///
128 /// - `DecodeError` if the decoding failed.
129 ///
130 /// # Panics
131 ///
132 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
133 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
134 let result =
135 unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
136
137 match NonZeroI32::new(result) {
138 None => {
139 self.initialized_logits
140 .clone_from(&batch.initialized_logits);
141 Ok(())
142 }
143 Some(error) => Err(DecodeError::from(error)),
144 }
145 }
146
147 /// Encodes the batch.
148 ///
149 /// # Errors
150 ///
151 /// - `EncodeError` if the decoding failed.
152 ///
153 /// # Panics
154 ///
155 /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
156 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
157 let result =
158 unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
159
160 match NonZeroI32::new(result) {
161 None => {
162 self.initialized_logits
163 .clone_from(&batch.initialized_logits);
164 Ok(())
165 }
166 Some(error) => Err(EncodeError::from(error)),
167 }
168 }
169
170 /// Return Pooling type for Llama's Context
171 #[must_use]
172 pub fn pooling_type(&self) -> LlamaPoolingType {
173 let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
174
175 LlamaPoolingType::from(pooling_type)
176 }
177
178 /// Get the embeddings for the `i`th sequence in the current context.
179 ///
180 /// # Returns
181 ///
182 /// A slice containing the embeddings for the last decoded batch.
183 /// The size corresponds to the `n_embd` parameter of the context's model.
184 ///
185 /// # Errors
186 ///
187 /// - When the current context was constructed without enabling embeddings.
188 /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
189 /// - If the given sequence index exceeds the max sequence id.
190 ///
191 /// # Panics
192 ///
193 /// * `n_embd` does not fit into a usize
194 pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
195 if !self.embeddings_enabled {
196 return Err(EmbeddingsError::NotEnabled);
197 }
198
199 let n_embd =
200 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
201
202 unsafe {
203 let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
204
205 // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
206 if embedding.is_null() {
207 Err(EmbeddingsError::NonePoolType)
208 } else {
209 Ok(slice::from_raw_parts(embedding, n_embd))
210 }
211 }
212 }
213
214 /// Get the embeddings for the `i`th token in the current context.
215 ///
216 /// # Returns
217 ///
218 /// A slice containing the embeddings for the last decoded batch of the given token.
219 /// The size corresponds to the `n_embd` parameter of the context's model.
220 ///
221 /// # Errors
222 ///
223 /// - When the current context was constructed without enabling embeddings.
224 /// - When the given token didn't have logits enabled when it was passed.
225 /// - If the given token index exceeds the max token id.
226 ///
227 /// # Panics
228 ///
229 /// * `n_embd` does not fit into a usize
230 pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
231 if !self.embeddings_enabled {
232 return Err(EmbeddingsError::NotEnabled);
233 }
234
235 let n_embd =
236 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
237
238 unsafe {
239 let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
240 // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
241 if embedding.is_null() {
242 Err(EmbeddingsError::LogitsNotEnabled)
243 } else {
244 Ok(slice::from_raw_parts(embedding, n_embd))
245 }
246 }
247 }
248
249 /// Get the logits for the last token in the context.
250 ///
251 /// # Returns
252 /// An iterator over unsorted `LlamaTokenData` containing the
253 /// logits for the last token in the context.
254 ///
255 /// # Panics
256 ///
257 /// - underlying logits data is null
258 pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
259 (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
260 let token = LlamaToken::new(i);
261 LlamaTokenData::new(token, *logit, 0_f32)
262 })
263 }
264
265 /// Get the token data array for the last token in the context.
266 ///
267 /// This is a convience method that implements:
268 /// ```ignore
269 /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
270 /// ```
271 ///
272 /// # Panics
273 ///
274 /// - underlying logits data is null
275 #[must_use]
276 pub fn token_data_array(&self) -> LlamaTokenDataArray {
277 LlamaTokenDataArray::from_iter(self.candidates(), false)
278 }
279
280 /// Token logits obtained from the last call to `decode()`.
281 /// The logits for which `batch.logits[i] != 0` are stored contiguously
282 /// in the order they have appeared in the batch.
283 /// Rows: number of tokens for which `batch.logits[i] != 0`
284 /// Cols: `n_vocab`
285 ///
286 /// # Returns
287 ///
288 /// A slice containing the logits for the last decoded token.
289 /// The size corresponds to the `n_vocab` parameter of the context's model.
290 ///
291 /// # Panics
292 ///
293 /// - `n_vocab` does not fit into a usize
294 /// - token data returned is null
295 #[must_use]
296 pub fn get_logits(&self) -> &[f32] {
297 let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
298 assert!(!data.is_null(), "logits data for last token is null");
299 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
300
301 unsafe { slice::from_raw_parts(data, len) }
302 }
303
304 /// Get the logits for the ith token in the context.
305 ///
306 /// # Panics
307 ///
308 /// - logit `i` is not initialized.
309 pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
310 (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
311 let token = LlamaToken::new(i);
312 LlamaTokenData::new(token, *logit, 0_f32)
313 })
314 }
315
316 /// Get the logits for the ith token in the context.
317 ///
318 /// # Panics
319 ///
320 /// - `i` is greater than `n_ctx`
321 /// - `n_vocab` does not fit into a usize
322 /// - logit `i` is not initialized.
323 #[must_use]
324 pub fn get_logits_ith(&self, i: i32) -> &[f32] {
325 assert!(
326 self.initialized_logits.contains(&i),
327 "logit {i} is not initialized. only {:?} is",
328 self.initialized_logits
329 );
330 assert!(
331 self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
332 "n_ctx ({}) must be greater than i ({})",
333 self.n_ctx(),
334 i
335 );
336
337 let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
338 let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
339
340 unsafe { slice::from_raw_parts(data, len) }
341 }
342
343 /// Get the number of context tokens per sequence.
344 #[must_use]
345 pub fn n_ctx_seq(&self) -> u32 {
346 unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
347 }
348
349 /// Get the maximum number of sequences.
350 #[must_use]
351 pub fn n_seq_max(&self) -> u32 {
352 unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
353 }
354
355 /// Get the number of recurrent-state snapshots per sequence.
356 ///
357 /// This is only available when built with the `mtp` feature.
358 #[cfg(feature = "mtp")]
359 #[must_use]
360 pub fn n_rs_seq(&self) -> u32 {
361 unsafe { llama_cpp_sys_4::llama_n_rs_seq(self.context.as_ptr()) }
362 }
363
364 /// Get the number of threads used for generation.
365 #[must_use]
366 pub fn n_threads(&self) -> i32 {
367 unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
368 }
369
370 /// Get the number of threads used for batch processing.
371 #[must_use]
372 pub fn n_threads_batch(&self) -> i32 {
373 unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
374 }
375
376 /// Set the number of threads used for generation and batch processing.
377 pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
378 unsafe {
379 llama_cpp_sys_4::llama_set_n_threads(self.context.as_ptr(), n_threads, n_threads_batch);
380 }
381 }
382
383 /// Set whether to use causal attention.
384 ///
385 /// If set to `false`, the model will use non-causal attention, which is
386 /// needed for embedding models.
387 pub fn set_causal_attn(&mut self, causal_attn: bool) {
388 unsafe {
389 llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
390 }
391 }
392
393 /// Set whether to compute embeddings.
394 ///
395 /// This allows toggling embedding mode at runtime (as opposed to only at
396 /// context creation time).
397 pub fn set_embeddings(&mut self, embeddings: bool) {
398 self.embeddings_enabled = embeddings;
399 unsafe {
400 llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
401 }
402 }
403
404 /// Mark the next computation as a warmup run.
405 ///
406 /// Warmup runs are useful for GPU backends to compile kernels before
407 /// actual inference begins.
408 pub fn set_warmup(&mut self, warmup: bool) {
409 unsafe {
410 llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
411 }
412 }
413
414 /// Attach or detach an MTP context used for speculative drafting.
415 ///
416 /// Pass `Some(&mtp_ctx)` to register an MTP context, or `None` to detach it.
417 /// This is only available when built with the `mtp` feature.
418 #[cfg(feature = "mtp")]
419 pub fn set_mtp(&mut self, mtp_ctx: Option<&LlamaContext<'_>>) {
420 let ptr = mtp_ctx.map_or(std::ptr::null_mut(), |ctx| ctx.context.as_ptr());
421 unsafe {
422 llama_cpp_sys_4::llama_set_mtp(self.context.as_ptr(), ptr);
423 }
424 }
425
426 /// Wait for all pending async computations to finish.
427 pub fn synchronize(&mut self) {
428 unsafe {
429 llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
430 }
431 }
432
433 /// Get all embeddings for the current context.
434 ///
435 /// Returns a slice of all embeddings from the last decoded batch.
436 /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
437 ///
438 /// # Errors
439 ///
440 /// - When the current context was constructed without enabling embeddings.
441 /// - If the embeddings pointer is null.
442 ///
443 /// # Panics
444 ///
445 /// * `n_embd` does not fit into a usize
446 pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
447 if !self.embeddings_enabled {
448 return Err(EmbeddingsError::NotEnabled);
449 }
450
451 let n_embd =
452 usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
453
454 unsafe {
455 let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
456 if embedding.is_null() {
457 Err(EmbeddingsError::NonePoolType)
458 } else {
459 Ok(slice::from_raw_parts(embedding, n_embd))
460 }
461 }
462 }
463
464 /// Reset the timings for the context.
465 pub fn reset_timings(&mut self) {
466 unsafe { llama_cpp_sys_4::ggml_time_init() }
467 }
468
469 /// Returns the timings for the context.
470 pub fn timings(&mut self) -> PerfContextData {
471 let perf_context_data =
472 unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
473 PerfContextData { perf_context_data }
474 }
475
476 /// Reset the performance counters for the context.
477 pub fn perf_context_reset(&mut self) {
478 unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
479 }
480
481 /// Check if the KV cache memory supports shifting.
482 #[must_use]
483 pub fn memory_can_shift(&self) -> bool {
484 unsafe {
485 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
486 llama_cpp_sys_4::llama_memory_can_shift(mem)
487 }
488 }
489
490 /// Get the minimum position in a sequence's KV cache.
491 #[must_use]
492 pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
493 unsafe {
494 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
495 llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
496 }
497 }
498
499 /// Print a breakdown of the memory usage.
500 pub fn memory_breakdown_print(&self) {
501 unsafe {
502 llama_cpp_sys_4::common_memory_breakdown_print(self.context.as_ptr());
503 }
504 }
505
506 /// Get the size of the full context state in bytes.
507 ///
508 /// This is the size needed for [`state_get_data`](Self::state_get_data) and
509 /// [`state_set_data`](Self::state_set_data).
510 #[must_use]
511 pub fn state_get_size(&mut self) -> usize {
512 unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
513 }
514
515 /// Copy the full context state into a byte buffer.
516 ///
517 /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
518 ///
519 /// Returns the number of bytes written.
520 pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
521 unsafe {
522 llama_cpp_sys_4::llama_state_get_data(
523 self.context.as_ptr(),
524 dst.as_mut_ptr(),
525 dst.len(),
526 )
527 }
528 }
529
530 /// Restore the full context state from a byte buffer.
531 ///
532 /// Returns the number of bytes read.
533 pub fn state_set_data(&mut self, src: &[u8]) -> usize {
534 unsafe {
535 llama_cpp_sys_4::llama_state_set_data(self.context.as_ptr(), src.as_ptr(), src.len())
536 }
537 }
538
539 /// Save the context state to a file along with the given tokens.
540 ///
541 /// Returns `true` on success.
542 ///
543 /// # Panics
544 ///
545 /// Panics if the path contains null bytes.
546 pub fn state_save_file(
547 &mut self,
548 path: impl AsRef<std::path::Path>,
549 tokens: &[LlamaToken],
550 ) -> bool {
551 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
552 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
553 unsafe {
554 llama_cpp_sys_4::llama_state_save_file(
555 self.context.as_ptr(),
556 c_path.as_ptr(),
557 tokens.as_ptr().cast(),
558 tokens.len(),
559 )
560 }
561 }
562
563 /// Load a context state from a file.
564 ///
565 /// Returns `true` on success and fills `tokens_out` with the saved tokens.
566 ///
567 /// # Panics
568 ///
569 /// Panics if the path contains null bytes.
570 pub fn state_load_file(
571 &mut self,
572 path: impl AsRef<std::path::Path>,
573 tokens_out: &mut Vec<LlamaToken>,
574 n_token_capacity: usize,
575 ) -> bool {
576 tokens_out.resize(n_token_capacity, LlamaToken(0));
577 let mut n_token_count: usize = 0;
578 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
579 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
580 let ok = unsafe {
581 llama_cpp_sys_4::llama_state_load_file(
582 self.context.as_ptr(),
583 c_path.as_ptr(),
584 tokens_out.as_mut_ptr().cast(),
585 n_token_capacity,
586 std::ptr::addr_of_mut!(n_token_count),
587 )
588 };
589 if ok {
590 tokens_out.truncate(n_token_count);
591 }
592 ok
593 }
594
595 /// Get the size of a single sequence's state in bytes.
596 #[must_use]
597 pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
598 unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
599 }
600
601 /// Copy a single sequence's state into a byte buffer.
602 ///
603 /// Returns the number of bytes written.
604 pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
605 unsafe {
606 llama_cpp_sys_4::llama_state_seq_get_data(
607 self.context.as_ptr(),
608 dst.as_mut_ptr(),
609 dst.len(),
610 seq_id,
611 )
612 }
613 }
614
615 /// Restore a single sequence's state from a byte buffer.
616 ///
617 /// Returns the number of bytes read.
618 pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
619 unsafe {
620 llama_cpp_sys_4::llama_state_seq_set_data(
621 self.context.as_ptr(),
622 src.as_ptr(),
623 src.len(),
624 dest_seq_id,
625 )
626 }
627 }
628
629 /// Save a single sequence's state to a file.
630 ///
631 /// Returns the number of bytes written (0 on failure).
632 ///
633 /// # Panics
634 ///
635 /// Panics if the path contains null bytes.
636 pub fn state_seq_save_file(
637 &mut self,
638 path: impl AsRef<std::path::Path>,
639 seq_id: i32,
640 tokens: &[LlamaToken],
641 ) -> usize {
642 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
643 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
644 unsafe {
645 llama_cpp_sys_4::llama_state_seq_save_file(
646 self.context.as_ptr(),
647 c_path.as_ptr(),
648 seq_id,
649 tokens.as_ptr().cast(),
650 tokens.len(),
651 )
652 }
653 }
654
655 /// Load a single sequence's state from a file.
656 ///
657 /// Returns the number of bytes read (0 on failure).
658 ///
659 /// # Panics
660 ///
661 /// Panics if the path contains null bytes.
662 pub fn state_seq_load_file(
663 &mut self,
664 path: impl AsRef<std::path::Path>,
665 dest_seq_id: i32,
666 tokens_out: &mut Vec<LlamaToken>,
667 n_token_capacity: usize,
668 ) -> usize {
669 tokens_out.resize(n_token_capacity, LlamaToken(0));
670 let mut n_token_count: usize = 0;
671 let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
672 let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
673 let ret = unsafe {
674 llama_cpp_sys_4::llama_state_seq_load_file(
675 self.context.as_ptr(),
676 c_path.as_ptr(),
677 dest_seq_id,
678 tokens_out.as_mut_ptr().cast(),
679 n_token_capacity,
680 std::ptr::addr_of_mut!(n_token_count),
681 )
682 };
683 if ret > 0 {
684 tokens_out.truncate(n_token_count);
685 }
686 ret
687 }
688
689 /// Set a control vector on the context.
690 ///
691 /// # Parameters
692 ///
693 /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
694 /// - `n_embd`: The embedding dimension.
695 /// - `il_start`: The starting layer index (inclusive).
696 /// - `il_end`: The ending layer index (exclusive).
697 ///
698 /// # Errors
699 ///
700 /// Returns `Err` with the error code if the operation fails.
701 pub fn set_adapter_cvec(
702 &mut self,
703 data: &[f32],
704 n_embd: i32,
705 il_start: i32,
706 il_end: i32,
707 ) -> Result<(), i32> {
708 let ret = unsafe {
709 llama_cpp_sys_4::llama_set_adapter_cvec(
710 self.context.as_ptr(),
711 data.as_ptr(),
712 data.len(),
713 n_embd,
714 il_start,
715 il_end,
716 )
717 };
718 if ret != 0 {
719 Err(ret)
720 } else {
721 Ok(())
722 }
723 }
724
725 /// Get sampled token debug info for the `i`th position.
726 ///
727 /// Returns the sampled token at position `i` from the last decode call.
728 #[must_use]
729 pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
730 let token =
731 unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
732 LlamaToken(token)
733 }
734
735 /// Get sampled candidate tokens for the `i`th position.
736 ///
737 /// Returns a slice of candidate tokens from the last decode call.
738 #[must_use]
739 pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
740 let count = unsafe {
741 llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
742 } as usize;
743 if count == 0 {
744 return &[];
745 }
746 let ptr =
747 unsafe { llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i) };
748 if ptr.is_null() {
749 return &[];
750 }
751 unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
752 }
753
754 /// Get the number of sampled logits for the `i`th position.
755 #[must_use]
756 pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
757 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i) }
758 }
759
760 /// Get sampled logits for the `i`th position.
761 ///
762 /// Returns a slice of logit values from the last decode call.
763 #[must_use]
764 pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
765 let count = self.get_sampled_logits_count_ith(i) as usize;
766 if count == 0 {
767 return &[];
768 }
769 let ptr =
770 unsafe { llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i) };
771 if ptr.is_null() {
772 return &[];
773 }
774 unsafe { slice::from_raw_parts(ptr, count) }
775 }
776
777 /// Get the number of sampled probabilities for the `i`th position.
778 #[must_use]
779 pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
780 unsafe { llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i) }
781 }
782
783 /// Get sampled probabilities for the `i`th position.
784 ///
785 /// Returns a slice of probability values from the last decode call.
786 #[must_use]
787 pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
788 let count = self.get_sampled_probs_count_ith(i) as usize;
789 if count == 0 {
790 return &[];
791 }
792 let ptr = unsafe { llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i) };
793 if ptr.is_null() {
794 return &[];
795 }
796 unsafe { slice::from_raw_parts(ptr, count) }
797 }
798
799 /// Get the size of a single sequence's state with flags.
800 #[must_use]
801 pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
802 unsafe {
803 llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
804 }
805 }
806
807 /// Copy a single sequence's state into a byte buffer with flags.
808 ///
809 /// Returns the number of bytes written.
810 pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
811 unsafe {
812 llama_cpp_sys_4::llama_state_seq_get_data_ext(
813 self.context.as_ptr(),
814 dst.as_mut_ptr(),
815 dst.len(),
816 seq_id,
817 flags,
818 )
819 }
820 }
821
822 /// Restore a single sequence's state from a byte buffer with flags.
823 ///
824 /// Returns the number of bytes read.
825 pub fn state_seq_set_data_ext(&mut self, src: &[u8], dest_seq_id: i32, flags: u32) -> usize {
826 unsafe {
827 llama_cpp_sys_4::llama_state_seq_set_data_ext(
828 self.context.as_ptr(),
829 src.as_ptr(),
830 src.len(),
831 dest_seq_id,
832 flags,
833 )
834 }
835 }
836
837 /// Set an abort callback for the context.
838 ///
839 /// The callback is called periodically during computation. If it returns `true`,
840 /// the computation is aborted.
841 ///
842 /// # Safety
843 ///
844 /// The callback data must remain valid for the lifetime of the context or until
845 /// the callback is replaced.
846 pub unsafe fn set_abort_callback(
847 &mut self,
848 callback: llama_cpp_sys_4::ggml_abort_callback,
849 data: *mut std::ffi::c_void,
850 ) {
851 llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
852 }
853
854 /// Attach a thread pool to the context.
855 ///
856 /// # Safety
857 ///
858 /// The thread pools must remain valid for the lifetime of the context or until
859 /// they are detached.
860 pub unsafe fn attach_threadpool(
861 &mut self,
862 threadpool: llama_cpp_sys_4::ggml_threadpool_t,
863 threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
864 ) {
865 llama_cpp_sys_4::llama_attach_threadpool(
866 self.context.as_ptr(),
867 threadpool,
868 threadpool_batch,
869 );
870 }
871
872 /// Detach the thread pool from the context.
873 pub fn detach_threadpool(&mut self) {
874 unsafe {
875 llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
876 }
877 }
878
879 /// Set a sampler for a specific sequence.
880 ///
881 /// Returns `true` on success.
882 pub fn set_sampler(
883 &mut self,
884 seq_id: i32,
885 sampler: &mut crate::sampling::LlamaSampler,
886 ) -> bool {
887 unsafe {
888 llama_cpp_sys_4::llama_set_sampler(
889 self.context.as_ptr(),
890 seq_id,
891 sampler.sampler.as_ptr(),
892 )
893 }
894 }
895
896 /// Get the raw model pointer from this context.
897 ///
898 /// This is mainly useful for FFI interop. In normal usage, access
899 /// the model via the `model` field instead.
900 #[must_use]
901 pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
902 unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
903 }
904
905 /// Sets a lora adapter.
906 ///
907 /// # Errors
908 ///
909 /// See [`LlamaLoraAdapterSetError`] for more information.
910 pub fn lora_adapter_set(
911 &self,
912 adapter: &mut LlamaLoraAdapter,
913 scale: f32,
914 ) -> Result<(), LlamaLoraAdapterSetError> {
915 let err_code = unsafe {
916 // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
917 // which takes a full list of adapters + scales at once (b8249+)
918 let mut adapter_ptr = adapter.lora_adapter.as_ptr();
919 let mut scale_val = scale;
920 llama_cpp_sys_4::llama_set_adapters_lora(
921 self.context.as_ptr(),
922 &raw mut adapter_ptr,
923 1,
924 &raw mut scale_val,
925 )
926 };
927 if err_code != 0 {
928 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
929 }
930
931 tracing::debug!("Set lora adapter");
932 Ok(())
933 }
934
935 /// Remove all lora adapters from the context.
936 ///
937 /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
938 /// `llama_set_adapters_lora` which operates on the full adapter list at once.
939 /// Calling this function clears **all** adapters currently set on the context.
940 ///
941 /// # Errors
942 ///
943 /// See [`LlamaLoraAdapterRemoveError`] for more information.
944 pub fn lora_adapter_remove(
945 &self,
946 _adapter: &mut LlamaLoraAdapter,
947 ) -> Result<(), LlamaLoraAdapterRemoveError> {
948 let err_code = unsafe {
949 llama_cpp_sys_4::llama_set_adapters_lora(
950 self.context.as_ptr(),
951 std::ptr::null_mut(),
952 0,
953 std::ptr::null_mut(),
954 )
955 };
956 if err_code != 0 {
957 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
958 }
959
960 tracing::debug!("Remove lora adapter");
961 Ok(())
962 }
963}
964
impl Drop for LlamaContext<'_> {
    fn drop(&mut self) {
        // SAFETY: `context` is the owned, valid pointer this wrapper was
        // constructed with; freeing it here releases the native context
        // exactly once.
        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
    }
}