1use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::context::params::LlamaContextParams;
13use crate::llama_backend::LlamaBackend;
14use crate::llama_batch::LlamaBatch;
15use crate::model::{LlamaLoraAdapter, LlamaModel};
16use crate::timing::LlamaTimings;
17use crate::token::LlamaToken;
18use crate::token::data::LlamaTokenData;
19use crate::token::data_array::LlamaTokenDataArray;
20use crate::{
21 DecodeError, EmbeddingsError, EncodeError, LlamaContextLoadError, LlamaLoraAdapterRemoveError,
22 LlamaLoraAdapterSetError, LogitsError,
23};
24
25const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
26 if err_code != 0 {
27 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
28 }
29
30 Ok(())
31}
32
33const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
34 if err_code != 0 {
35 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
36 }
37
38 Ok(())
39}
40
41pub mod kv_cache;
42pub mod kv_cache_type;
43pub mod llama_attention_type;
44pub mod llama_pooling_type;
45pub mod llama_state_seq_flags;
46pub mod load_seq_state_error;
47pub mod load_session_error;
48pub mod params;
49pub mod rope_scaling_type;
50pub mod save_seq_state_error;
51pub mod save_session_error;
52pub mod session;
53
/// C-ABI callback handed to `llama_set_abort_callback`.
///
/// Reports whether an abort has been requested by reading the `AtomicBool` that `data` points
/// to.
///
/// # Safety
///
/// `data` must point to a live `AtomicBool` for the whole duration of the call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: the caller guarantees `data` points to a live `AtomicBool`; it is only read
    // through a shared reference here.
    let abort_requested: &AtomicBool = unsafe { &*data.cast::<AtomicBool>() };

    abort_requested.load(Ordering::Relaxed)
}
59
/// A llama.cpp inference context, created from and borrowing a [`LlamaModel`].
pub struct LlamaContext<'model> {
    /// Raw handle to the underlying `llama_context`; released with `llama_free` when this value
    /// is dropped.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// The model this context was created from.
    pub model: &'model LlamaModel,
    /// Abort flag whose address was handed to llama.cpp; kept here so that address stays valid
    /// while the callback is registered.
    abort_flag: Option<Arc<AtomicBool>>,
    /// Token indices whose logits may be read, copied from the batch after the last successful
    /// decode/encode (or set via `mark_logits_initialized`).
    initialized_logits: Vec<i32>,
    /// Whether this context was created with embeddings enabled; gates the embeddings getters.
    embeddings_enabled: bool,
}
70
71impl Debug for LlamaContext<'_> {
72 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("LlamaContext")
74 .field("context", &self.context)
75 .finish()
76 }
77}
78
79impl<'model> LlamaContext<'model> {
80 #[must_use]
82 pub const fn new(
83 llama_model: &'model LlamaModel,
84 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
85 embeddings_enabled: bool,
86 ) -> Self {
87 Self {
88 context: llama_context,
89 model: llama_model,
90 abort_flag: None,
91 initialized_logits: Vec::new(),
92 embeddings_enabled,
93 }
94 }
95
96 #[expect(
105 clippy::needless_pass_by_value,
106 reason = "LlamaContextParams may become non-trivially copyable upstream"
107 )]
108 pub fn from_model(
109 model: &'model LlamaModel,
110 _backend: &LlamaBackend,
111 params: LlamaContextParams,
112 ) -> Result<Self, LlamaContextLoadError> {
113 let context_params = params.context_params;
114 let mut out_ctx: *mut llama_cpp_bindings_sys::llama_context = std::ptr::null_mut();
115 let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
116 let status = unsafe {
117 llama_cpp_bindings_sys::llama_rs_new_context_with_model(
118 model.model.as_ptr(),
119 context_params,
120 &raw mut out_ctx,
121 &raw mut out_error,
122 )
123 };
124 match status {
125 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_OK => {
126 let context = NonNull::new(out_ctx)
127 .ok_or(LlamaContextLoadError::Unconstructible)?;
128 Ok(Self::new(model, context, params.embeddings()))
129 }
130 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_RETURNED_NULL => {
131 Err(LlamaContextLoadError::Unconstructible)
132 }
133 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_ERROR_STRING_ALLOCATION_FAILED => {
134 Err(LlamaContextLoadError::NotEnoughMemory)
135 }
136 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_THREW_CXX_EXCEPTION => {
137 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
138 Err(LlamaContextLoadError::Reported { message })
139 }
140 other => unreachable!(
141 "llama_rs_new_context_with_model returned unrecognized status {other}"
142 ),
143 }
144 }
145
146 #[must_use]
148 pub fn n_batch(&self) -> u32 {
149 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
150 }
151
152 #[must_use]
154 pub fn n_ubatch(&self) -> u32 {
155 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
156 }
157
158 #[must_use]
160 pub fn n_ctx(&self) -> u32 {
161 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
162 }
163
164 #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
170 pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
171 let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
172 self.abort_flag = Some(flag);
173
174 unsafe {
175 llama_cpp_bindings_sys::llama_set_abort_callback(
176 self.context.as_ptr(),
177 Some(abort_callback_trampoline),
178 raw_ptr,
179 );
180 }
181 }
182
183 #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
185 pub fn clear_abort_callback(&mut self) {
186 self.abort_flag = None;
187
188 unsafe {
189 llama_cpp_bindings_sys::llama_set_abort_callback(
190 self.context.as_ptr(),
191 None,
192 std::ptr::null_mut(),
193 );
194 }
195 }
196
197 #[expect(unsafe_code, reason = "required for FFI synchronization call")]
202 pub fn synchronize(&self) {
203 unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
204 }
205
206 #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
211 pub fn detach_threadpool(&self) {
212 unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
213 }
214
215 pub fn mark_logits_initialized(&mut self, token_index: i32) {
219 self.initialized_logits = vec![token_index];
220 }
221
222 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
228 let mut out_vendored_return_code: i32 = 0;
229 let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
230 let status = unsafe {
231 llama_cpp_bindings_sys::llama_rs_decode(
232 self.context.as_ptr(),
233 batch.llama_batch,
234 &raw mut out_vendored_return_code,
235 &raw mut out_error,
236 )
237 };
238 match status {
239 llama_cpp_bindings_sys::LLAMA_RS_DECODE_OK => {
240 self.initialized_logits
241 .clone_from(&batch.initialized_logits);
242 Ok(())
243 }
244 llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_RETURNED_NONZERO_CODE => {
245 let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
246 unreachable!(
247 "llama_rs_decode reported a nonzero return code but the value was zero"
248 )
249 });
250 Err(DecodeError::from(code))
251 }
252 llama_cpp_bindings_sys::LLAMA_RS_DECODE_OUT_OF_MEMORY => {
253 Err(DecodeError::DecodeOutOfMemory)
254 }
255 llama_cpp_bindings_sys::LLAMA_RS_DECODE_COMPUTE_FAILED => {
256 Err(DecodeError::ComputeFailed)
257 }
258 llama_cpp_bindings_sys::LLAMA_RS_DECODE_ERROR_STRING_ALLOCATION_FAILED => {
259 Err(DecodeError::NotEnoughMemory)
260 }
261 llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_THREW_CXX_EXCEPTION => {
262 let message =
263 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
264 Err(DecodeError::Reported { message })
265 }
266 other => unreachable!("llama_rs_decode returned unrecognized status {other}"),
267 }
268 }
269
270 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
276 let mut out_vendored_return_code: i32 = 0;
277 let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
278 let status = unsafe {
279 llama_cpp_bindings_sys::llama_rs_encode(
280 self.context.as_ptr(),
281 batch.llama_batch,
282 &raw mut out_vendored_return_code,
283 &raw mut out_error,
284 )
285 };
286 match status {
287 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OK => {
288 self.initialized_logits
289 .clone_from(&batch.initialized_logits);
290 Ok(())
291 }
292 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_MODEL_HAS_NO_ENCODER => {
293 Err(EncodeError::ModelHasNoEncoder)
294 }
295 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_RETURNED_NONZERO_CODE => {
296 let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
297 unreachable!(
298 "llama_rs_encode reported a nonzero return code but the value was zero"
299 )
300 });
301 Err(EncodeError::from(code))
302 }
303 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OUT_OF_MEMORY => {
304 Err(EncodeError::EncodeOutOfMemory)
305 }
306 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_COMPUTE_FAILED => {
307 Err(EncodeError::ComputeFailed)
308 }
309 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_ERROR_STRING_ALLOCATION_FAILED => {
310 Err(EncodeError::NotEnoughMemory)
311 }
312 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_THREW_CXX_EXCEPTION => {
313 let message =
314 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
315 Err(EncodeError::Reported { message })
316 }
317 other => unreachable!("llama_rs_encode returned unrecognized status {other}"),
318 }
319 }
320
321 pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
335 if !self.embeddings_enabled {
336 return Err(EmbeddingsError::NotEnabled);
337 }
338
339 let n_embd = usize::try_from(self.model.n_embd())
340 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
341
342 unsafe {
343 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
344 self.context.as_ptr(),
345 sequence_index,
346 );
347
348 if embedding.is_null() {
349 Err(EmbeddingsError::NonePoolType)
350 } else {
351 Ok(slice::from_raw_parts(embedding, n_embd))
352 }
353 }
354 }
355
356 pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
370 if !self.embeddings_enabled {
371 return Err(EmbeddingsError::NotEnabled);
372 }
373
374 let n_embd = usize::try_from(self.model.n_embd())
375 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
376
377 unsafe {
378 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
379 self.context.as_ptr(),
380 token_index,
381 );
382
383 if embedding.is_null() {
384 Err(EmbeddingsError::LogitsNotEnabled)
385 } else {
386 Ok(slice::from_raw_parts(embedding, n_embd))
387 }
388 }
389 }
390
391 pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
400 let logits = self.get_logits()?;
401
402 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
403 let token = LlamaToken::new(token_id);
404 LlamaTokenData::new(token, *logit, 0_f32)
405 }))
406 }
407
408 pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
413 Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
414 }
415
416 pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
430 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
431
432 if data.is_null() {
433 return Err(LogitsError::NullLogits);
434 }
435
436 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
437
438 Ok(unsafe { slice::from_raw_parts(data, len) })
439 }
440
441 pub fn candidates_ith(
446 &self,
447 token_index: i32,
448 ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
449 let logits = self.get_logits_ith(token_index)?;
450
451 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
452 let token = LlamaToken::new(token_id);
453 LlamaTokenData::new(token, *logit, 0_f32)
454 }))
455 }
456
457 pub fn token_data_array_ith(
462 &self,
463 token_index: i32,
464 ) -> Result<LlamaTokenDataArray, LogitsError> {
465 Ok(LlamaTokenDataArray::from_iter(
466 self.candidates_ith(token_index)?,
467 false,
468 ))
469 }
470
471 pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
476 if !self.initialized_logits.contains(&token_index) {
477 return Err(LogitsError::TokenNotInitialized(token_index));
478 }
479
480 if token_index >= 0 {
481 let token_index_u32 =
482 u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
483
484 if self.n_ctx() <= token_index_u32 {
485 return Err(LogitsError::TokenIndexExceedsContext {
486 token_index: token_index_u32,
487 context_size: self.n_ctx(),
488 });
489 }
490 }
491
492 let data = unsafe {
493 llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
494 };
495 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
496
497 Ok(unsafe { slice::from_raw_parts(data, len) })
498 }
499
500 pub fn reset_timings(&mut self) {
502 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
503 }
504
505 pub fn timings(&mut self) -> LlamaTimings {
507 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
508 LlamaTimings { timings }
509 }
510
511 pub fn lora_adapter_set(
517 &self,
518 adapter: &mut LlamaLoraAdapter,
519 scale: f32,
520 ) -> Result<(), LlamaLoraAdapterSetError> {
521 let mut adapters = [adapter.lora_adapter.as_ptr()];
522 let mut scales = [scale];
523 let err_code = unsafe {
524 llama_cpp_bindings_sys::llama_set_adapters_lora(
525 self.context.as_ptr(),
526 adapters.as_mut_ptr(),
527 1,
528 scales.as_mut_ptr(),
529 )
530 };
531 check_lora_set_result(err_code)?;
532
533 log::debug!("Set lora adapter");
534 Ok(())
535 }
536
537 pub fn lora_adapter_remove(
546 &self,
547 _adapter: &mut LlamaLoraAdapter,
548 ) -> Result<(), LlamaLoraAdapterRemoveError> {
549 let err_code = unsafe {
550 llama_cpp_bindings_sys::llama_set_adapters_lora(
551 self.context.as_ptr(),
552 std::ptr::null_mut(),
553 0,
554 std::ptr::null_mut(),
555 )
556 };
557 check_lora_remove_result(err_code)?;
558
559 log::debug!("Remove lora adapter");
560 Ok(())
561 }
562}
563
564impl Drop for LlamaContext<'_> {
565 fn drop(&mut self) {
566 unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
567 }
568}
569
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterSetError::ErrorResult(-1));

        assert_eq!(check_lora_set_result(-1), expected);
    }

    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterRemoveError::ErrorResult(-1));

        assert_eq!(check_lora_remove_result(-1), expected);
    }
}