1use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::context::params::LlamaContextParams;
13use crate::llama_backend::LlamaBackend;
14use crate::llama_batch::LlamaBatch;
15use crate::model::{LlamaLoraAdapter, LlamaModel};
16use crate::timing::LlamaTimings;
17use crate::token::LlamaToken;
18use crate::token::data::LlamaTokenData;
19use crate::token::data_array::LlamaTokenDataArray;
20use crate::{
21 DecodeError, EmbeddingsError, EncodeError, LlamaContextLoadError, LlamaLoraAdapterRemoveError,
22 LlamaLoraAdapterSetError, LogitsError,
23};
24
25const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
26 if err_code != 0 {
27 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
28 }
29
30 Ok(())
31}
32
33const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
34 if err_code != 0 {
35 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
36 }
37
38 Ok(())
39}
40
41pub mod kv_cache;
42pub mod llama_state_seq_flags;
43pub mod load_seq_state_error;
44pub mod load_session_error;
45pub mod params;
46pub mod save_seq_state_error;
47pub mod save_session_error;
48pub mod session;
49
/// C-ABI callback handed to `llama_set_abort_callback`.
///
/// Reads the [`AtomicBool`] that `data` points to and returns its current value; the
/// native side receives `true` when the flag has been raised. A null `data` pointer is
/// treated as "flag not set" and yields `false` instead of being dereferenced.
///
/// # Safety
///
/// `data` must either be null or point to an `AtomicBool` that stays alive for the
/// whole duration of the call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: per this function's contract, a non-null `data` points to a live
    // `AtomicBool`; `as_ref` turns the null case into `None`.
    let flag = unsafe { data.cast::<AtomicBool>().as_ref() };

    flag.is_some_and(|flag| flag.load(Ordering::Relaxed))
}
55
/// Handle to a native `llama_context` paired with the model it was created from.
///
/// The native context is released with `llama_free` when this value is dropped.
pub struct LlamaContext<'model> {
    /// Non-null pointer to the underlying native context.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// Model borrowed for the `'model` lifetime; used to size logits and embedding slices.
    pub model: &'model LlamaModel,
    /// Keeps the flag passed to `set_abort_flag` alive while the native side holds a
    /// raw pointer to it. `None` when no abort flag is registered.
    abort_flag: Option<Arc<AtomicBool>>,
    /// Token indices accepted by `get_logits_ith`; overwritten after each successful
    /// `decode`/`encode` and by `mark_logits_initialized`.
    initialized_logits: Vec<i32>,
    /// Whether embeddings were enabled when the context was constructed; gates the
    /// `embeddings_*` accessors.
    embeddings_enabled: bool,
}
66
67impl Debug for LlamaContext<'_> {
68 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("LlamaContext")
70 .field("context", &self.context)
71 .finish()
72 }
73}
74
75impl<'model> LlamaContext<'model> {
76 #[must_use]
78 pub const fn new(
79 llama_model: &'model LlamaModel,
80 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
81 embeddings_enabled: bool,
82 ) -> Self {
83 Self {
84 context: llama_context,
85 model: llama_model,
86 abort_flag: None,
87 initialized_logits: Vec::new(),
88 embeddings_enabled,
89 }
90 }
91
92 #[expect(
101 clippy::needless_pass_by_value,
102 reason = "LlamaContextParams may become non-trivially copyable upstream"
103 )]
104 pub fn from_model(
105 model: &'model LlamaModel,
106 _backend: &LlamaBackend,
107 params: LlamaContextParams,
108 ) -> Result<Self, LlamaContextLoadError> {
109 let context_params = params.context_params;
110 let context = unsafe {
111 llama_cpp_bindings_sys::llama_new_context_with_model(
112 model.model.as_ptr(),
113 context_params,
114 )
115 };
116 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
117
118 Ok(Self::new(model, context, params.embeddings()))
119 }
120
121 #[must_use]
123 pub fn n_batch(&self) -> u32 {
124 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
125 }
126
127 #[must_use]
129 pub fn n_ubatch(&self) -> u32 {
130 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
131 }
132
133 #[must_use]
135 pub fn n_ctx(&self) -> u32 {
136 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
137 }
138
139 #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
145 pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
146 let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
147 self.abort_flag = Some(flag);
148
149 unsafe {
150 llama_cpp_bindings_sys::llama_set_abort_callback(
151 self.context.as_ptr(),
152 Some(abort_callback_trampoline),
153 raw_ptr,
154 );
155 }
156 }
157
158 #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
160 pub fn clear_abort_callback(&mut self) {
161 self.abort_flag = None;
162
163 unsafe {
164 llama_cpp_bindings_sys::llama_set_abort_callback(
165 self.context.as_ptr(),
166 None,
167 std::ptr::null_mut(),
168 );
169 }
170 }
171
172 #[expect(unsafe_code, reason = "required for FFI synchronization call")]
177 pub fn synchronize(&self) {
178 unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
179 }
180
181 #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
186 pub fn detach_threadpool(&self) {
187 unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
188 }
189
190 pub fn mark_logits_initialized(&mut self, token_index: i32) {
194 self.initialized_logits = vec![token_index];
195 }
196
197 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
207 let result = unsafe {
208 llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
209 };
210
211 match NonZeroI32::new(result) {
212 None => {
213 self.initialized_logits
214 .clone_from(&batch.initialized_logits);
215 Ok(())
216 }
217 Some(error) => Err(DecodeError::from(error)),
218 }
219 }
220
221 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
231 let status = unsafe {
232 llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
233 };
234
235 self.handle_encode_result(status, batch)
236 }
237
238 fn handle_encode_result(
239 &mut self,
240 status: llama_cpp_bindings_sys::llama_rs_status,
241 batch: &mut LlamaBatch,
242 ) -> Result<(), EncodeError> {
243 if crate::status_is_ok(status) {
244 self.initialized_logits
245 .clone_from(&batch.initialized_logits);
246
247 Ok(())
248 } else {
249 Err(EncodeError::from(
250 NonZeroI32::new(crate::status_to_i32(status))
251 .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
252 ))
253 }
254 }
255
256 pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
270 if !self.embeddings_enabled {
271 return Err(EmbeddingsError::NotEnabled);
272 }
273
274 let n_embd = usize::try_from(self.model.n_embd())
275 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
276
277 unsafe {
278 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
279 self.context.as_ptr(),
280 sequence_index,
281 );
282
283 if embedding.is_null() {
284 Err(EmbeddingsError::NonePoolType)
285 } else {
286 Ok(slice::from_raw_parts(embedding, n_embd))
287 }
288 }
289 }
290
291 pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
305 if !self.embeddings_enabled {
306 return Err(EmbeddingsError::NotEnabled);
307 }
308
309 let n_embd = usize::try_from(self.model.n_embd())
310 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
311
312 unsafe {
313 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
314 self.context.as_ptr(),
315 token_index,
316 );
317
318 if embedding.is_null() {
319 Err(EmbeddingsError::LogitsNotEnabled)
320 } else {
321 Ok(slice::from_raw_parts(embedding, n_embd))
322 }
323 }
324 }
325
326 pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
335 let logits = self.get_logits()?;
336
337 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
338 let token = LlamaToken::new(token_id);
339 LlamaTokenData::new(token, *logit, 0_f32)
340 }))
341 }
342
343 pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
348 Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
349 }
350
351 pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
365 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
366
367 if data.is_null() {
368 return Err(LogitsError::NullLogits);
369 }
370
371 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
372
373 Ok(unsafe { slice::from_raw_parts(data, len) })
374 }
375
376 pub fn candidates_ith(
381 &self,
382 token_index: i32,
383 ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
384 let logits = self.get_logits_ith(token_index)?;
385
386 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
387 let token = LlamaToken::new(token_id);
388 LlamaTokenData::new(token, *logit, 0_f32)
389 }))
390 }
391
392 pub fn token_data_array_ith(
397 &self,
398 token_index: i32,
399 ) -> Result<LlamaTokenDataArray, LogitsError> {
400 Ok(LlamaTokenDataArray::from_iter(
401 self.candidates_ith(token_index)?,
402 false,
403 ))
404 }
405
406 pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
411 if !self.initialized_logits.contains(&token_index) {
412 return Err(LogitsError::TokenNotInitialized(token_index));
413 }
414
415 if token_index >= 0 {
416 let token_index_u32 =
417 u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
418
419 if self.n_ctx() <= token_index_u32 {
420 return Err(LogitsError::TokenIndexExceedsContext {
421 token_index: token_index_u32,
422 context_size: self.n_ctx(),
423 });
424 }
425 }
426
427 let data = unsafe {
428 llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
429 };
430 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
431
432 Ok(unsafe { slice::from_raw_parts(data, len) })
433 }
434
435 pub fn reset_timings(&mut self) {
437 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
438 }
439
440 pub fn timings(&mut self) -> LlamaTimings {
442 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
443 LlamaTimings { timings }
444 }
445
446 pub fn lora_adapter_set(
452 &self,
453 adapter: &mut LlamaLoraAdapter,
454 scale: f32,
455 ) -> Result<(), LlamaLoraAdapterSetError> {
456 let mut adapters = [adapter.lora_adapter.as_ptr()];
457 let mut scales = [scale];
458 let err_code = unsafe {
459 llama_cpp_bindings_sys::llama_set_adapters_lora(
460 self.context.as_ptr(),
461 adapters.as_mut_ptr(),
462 1,
463 scales.as_mut_ptr(),
464 )
465 };
466 check_lora_set_result(err_code)?;
467
468 tracing::debug!("Set lora adapter");
469 Ok(())
470 }
471
472 pub fn lora_adapter_remove(
481 &self,
482 _adapter: &mut LlamaLoraAdapter,
483 ) -> Result<(), LlamaLoraAdapterRemoveError> {
484 let err_code = unsafe {
485 llama_cpp_bindings_sys::llama_set_adapters_lora(
486 self.context.as_ptr(),
487 std::ptr::null_mut(),
488 0,
489 std::ptr::null_mut(),
490 )
491 };
492 check_lora_remove_result(err_code)?;
493
494 tracing::debug!("Remove lora adapter");
495 Ok(())
496 }
497}
498
499impl Drop for LlamaContext<'_> {
500 fn drop(&mut self) {
501 unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
502 }
503}
504
/// Unit tests for the LoRA status-code helpers; none of them touch the native library.
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::LlamaLoraAdapterRemoveError;
    use crate::LlamaLoraAdapterSetError;

    /// A zero status from the set call maps to `Ok`.
    #[test]
    fn check_lora_set_result_ok_for_zero() {
        let outcome = check_lora_set_result(0);

        assert!(outcome.is_ok());
    }

    /// A non-zero status from the set call is carried inside `ErrorResult`.
    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterSetError::ErrorResult(-1));

        assert_eq!(check_lora_set_result(-1), expected);
    }

    /// A zero status from the remove call maps to `Ok`.
    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        let outcome = check_lora_remove_result(0);

        assert!(outcome.is_ok());
    }

    /// A non-zero status from the remove call is carried inside `ErrorResult`.
    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterRemoveError::ErrorResult(-1));

        assert_eq!(check_lora_remove_result(-1), expected);
    }
}