1use std::ffi::c_void;
2use std::fmt::{Debug, Formatter};
3use std::num::NonZeroI32;
4use std::ptr::NonNull;
5use std::slice;
6use std::sync::Arc;
7use std::sync::atomic::AtomicBool;
8use std::sync::atomic::Ordering;
9
10use crate::context::params::LlamaContextParams;
11use crate::llama_backend::LlamaBackend;
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::timing::LlamaTimings;
15use crate::token::LlamaToken;
16use crate::token::data::LlamaTokenData;
17use crate::token::data_array::LlamaTokenDataArray;
18use crate::{
19 DecodeError, EmbeddingsError, EncodeError, LlamaContextLoadError, LlamaLoraAdapterRemoveError,
20 LlamaLoraAdapterSetError, LogitsError,
21};
22
23const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
24 if err_code != 0 {
25 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
26 }
27
28 Ok(())
29}
30
31const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
32 if err_code != 0 {
33 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
34 }
35
36 Ok(())
37}
38
39pub mod kv_cache;
40pub mod kv_cache_type;
41pub mod llama_attention_type;
42pub mod llama_pooling_type;
43pub mod llama_state_seq_flags;
44pub mod load_seq_state_error;
45pub mod load_session_error;
46pub mod params;
47pub mod rope_scaling_type;
48pub mod save_seq_state_error;
49pub mod save_session_error;
50pub mod session;
51
/// C-ABI callback registered through `llama_set_abort_callback`; returns the current value of
/// the abort flag that `data` points to.
///
/// `set_abort_flag` registers this function together with a pointer to an `AtomicBool`.
/// A null `data` pointer is treated as "do not abort" and yields `false` instead of being
/// dereferenced.
///
/// # Safety
///
/// `data` must either be null or point to an `AtomicBool` that stays alive for the whole call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: per this function's contract `data` is null or a valid `AtomicBool`; `as_ref`
    // maps the null case to `None`, so no null pointer is ever dereferenced.
    let flag = unsafe { data.cast_const().cast::<AtomicBool>().as_ref() };

    // Only the boolean itself is consulted, so a relaxed load is used.
    flag.is_some_and(|flag| flag.load(Ordering::Relaxed))
}
57
/// Wrapper around a raw `llama_context` pointer that borrows the [`LlamaModel`] it runs on.
///
/// The wrapped pointer is released with `llama_free` when this value is dropped.
pub struct LlamaContext<'model> {
    /// Non-null pointer to the underlying `llama_context`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// Model borrowed for the lifetime `'model`; queried for `n_embd` and `n_vocab`.
    pub model: &'model LlamaModel,
    /// Flag handed to `set_abort_flag`, kept here so the pointer registered with the abort
    /// callback stays valid; `None` when no flag is registered.
    abort_flag: Option<Arc<AtomicBool>>,
    /// Token indices accepted by `get_logits_ith`. Replaced by `decode`, `encode` and
    /// `mark_logits_initialized`.
    initialized_logits: Vec<i32>,
    /// When `false`, `embeddings_ith` and `embeddings_seq_ith` return
    /// `EmbeddingsError::NotEnabled`.
    embeddings_enabled: bool,
}
65
66impl Debug for LlamaContext<'_> {
67 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("LlamaContext")
69 .field("context", &self.context)
70 .finish()
71 }
72}
73
74impl<'model> LlamaContext<'model> {
75 #[must_use]
76 pub const fn new(
77 llama_model: &'model LlamaModel,
78 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
79 embeddings_enabled: bool,
80 ) -> Self {
81 Self {
82 context: llama_context,
83 model: llama_model,
84 abort_flag: None,
85 initialized_logits: Vec::new(),
86 embeddings_enabled,
87 }
88 }
89
90 #[expect(
94 clippy::needless_pass_by_value,
95 reason = "LlamaContextParams may become non-trivially copyable upstream"
96 )]
97 pub fn from_model(
98 model: &'model LlamaModel,
99 _backend: &LlamaBackend,
100 params: LlamaContextParams,
101 ) -> Result<Self, LlamaContextLoadError> {
102 let context_params = params.context_params;
103 let mut out_ctx: *mut llama_cpp_bindings_sys::llama_context = std::ptr::null_mut();
104 let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
105 let status = unsafe {
106 llama_cpp_bindings_sys::llama_rs_new_context_with_model(
107 model.model.as_ptr(),
108 context_params,
109 &raw mut out_ctx,
110 &raw mut out_error,
111 )
112 };
113 match status {
114 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_OK => {
115 let context = NonNull::new(out_ctx)
116 .ok_or(LlamaContextLoadError::Unconstructible)?;
117 Ok(Self::new(model, context, params.embeddings()))
118 }
119 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_RETURNED_NULL => {
120 Err(LlamaContextLoadError::Unconstructible)
121 }
122 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_ERROR_STRING_ALLOCATION_FAILED => {
123 Err(LlamaContextLoadError::NotEnoughMemory)
124 }
125 llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_THREW_CXX_EXCEPTION => {
126 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
127 Err(LlamaContextLoadError::Reported { message })
128 }
129 other => unreachable!(
130 "llama_rs_new_context_with_model returned unrecognized status {other}"
131 ),
132 }
133 }
134
135 #[must_use]
136 pub fn n_batch(&self) -> u32 {
137 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
138 }
139
140 #[must_use]
141 pub fn n_ubatch(&self) -> u32 {
142 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
143 }
144
145 #[must_use]
146 pub fn n_ctx(&self) -> u32 {
147 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
148 }
149
150 #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
151 pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
152 let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
153 self.abort_flag = Some(flag);
154
155 unsafe {
156 llama_cpp_bindings_sys::llama_set_abort_callback(
157 self.context.as_ptr(),
158 Some(abort_callback_trampoline),
159 raw_ptr,
160 );
161 }
162 }
163
164 #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
165 pub fn clear_abort_callback(&mut self) {
166 self.abort_flag = None;
167
168 unsafe {
169 llama_cpp_bindings_sys::llama_set_abort_callback(
170 self.context.as_ptr(),
171 None,
172 std::ptr::null_mut(),
173 );
174 }
175 }
176
177 #[expect(unsafe_code, reason = "required for FFI synchronization call")]
178 pub fn synchronize(&self) {
179 unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
180 }
181
182 #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
183 pub fn detach_threadpool(&self) {
184 unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
185 }
186
187 pub fn mark_logits_initialized(&mut self, token_index: i32) {
188 self.initialized_logits = vec![token_index];
189 }
190
191 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
195 let mut out_vendored_return_code: i32 = 0;
196 let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
197 let status = unsafe {
198 llama_cpp_bindings_sys::llama_rs_decode(
199 self.context.as_ptr(),
200 batch.llama_batch,
201 &raw mut out_vendored_return_code,
202 &raw mut out_error,
203 )
204 };
205 match status {
206 llama_cpp_bindings_sys::LLAMA_RS_DECODE_OK => {
207 self.initialized_logits
208 .clone_from(&batch.initialized_logits);
209 Ok(())
210 }
211 llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_RETURNED_NONZERO_CODE => {
212 let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
213 unreachable!(
214 "llama_rs_decode reported a nonzero return code but the value was zero"
215 )
216 });
217 Err(DecodeError::from(code))
218 }
219 llama_cpp_bindings_sys::LLAMA_RS_DECODE_OUT_OF_MEMORY => {
220 Err(DecodeError::DecodeOutOfMemory)
221 }
222 llama_cpp_bindings_sys::LLAMA_RS_DECODE_COMPUTE_FAILED => {
223 Err(DecodeError::ComputeFailed)
224 }
225 llama_cpp_bindings_sys::LLAMA_RS_DECODE_ERROR_STRING_ALLOCATION_FAILED => {
226 Err(DecodeError::NotEnoughMemory)
227 }
228 llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_THREW_CXX_EXCEPTION => {
229 let message =
230 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
231 Err(DecodeError::Reported { message })
232 }
233 other => unreachable!("llama_rs_decode returned unrecognized status {other}"),
234 }
235 }
236
237 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
241 let mut out_vendored_return_code: i32 = 0;
242 let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
243 let status = unsafe {
244 llama_cpp_bindings_sys::llama_rs_encode(
245 self.context.as_ptr(),
246 batch.llama_batch,
247 &raw mut out_vendored_return_code,
248 &raw mut out_error,
249 )
250 };
251 match status {
252 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OK => {
253 self.initialized_logits
254 .clone_from(&batch.initialized_logits);
255 Ok(())
256 }
257 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_MODEL_HAS_NO_ENCODER => {
258 Err(EncodeError::ModelHasNoEncoder)
259 }
260 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_RETURNED_NONZERO_CODE => {
261 let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
262 unreachable!(
263 "llama_rs_encode reported a nonzero return code but the value was zero"
264 )
265 });
266 Err(EncodeError::from(code))
267 }
268 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OUT_OF_MEMORY => {
269 Err(EncodeError::EncodeOutOfMemory)
270 }
271 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_COMPUTE_FAILED => {
272 Err(EncodeError::ComputeFailed)
273 }
274 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_ERROR_STRING_ALLOCATION_FAILED => {
275 Err(EncodeError::NotEnoughMemory)
276 }
277 llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_THREW_CXX_EXCEPTION => {
278 let message =
279 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
280 Err(EncodeError::Reported { message })
281 }
282 other => unreachable!("llama_rs_encode returned unrecognized status {other}"),
283 }
284 }
285
286 pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
293 if !self.embeddings_enabled {
294 return Err(EmbeddingsError::NotEnabled);
295 }
296
297 let n_embd = usize::try_from(self.model.n_embd())
298 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
299
300 unsafe {
301 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
302 self.context.as_ptr(),
303 sequence_index,
304 );
305
306 if embedding.is_null() {
307 Err(EmbeddingsError::NonePoolType)
308 } else {
309 Ok(slice::from_raw_parts(embedding, n_embd))
310 }
311 }
312 }
313
314 pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
321 if !self.embeddings_enabled {
322 return Err(EmbeddingsError::NotEnabled);
323 }
324
325 let n_embd = usize::try_from(self.model.n_embd())
326 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
327
328 unsafe {
329 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
330 self.context.as_ptr(),
331 token_index,
332 );
333
334 if embedding.is_null() {
335 Err(EmbeddingsError::LogitsNotEnabled)
336 } else {
337 Ok(slice::from_raw_parts(embedding, n_embd))
338 }
339 }
340 }
341
342 pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
345 let logits = self.get_logits()?;
346
347 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
348 let token = LlamaToken::new(token_id);
349 LlamaTokenData::new(token, *logit, 0_f32)
350 }))
351 }
352
353 pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
356 Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
357 }
358
359 pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
362 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
363
364 if data.is_null() {
365 return Err(LogitsError::NullLogits);
366 }
367
368 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
369
370 Ok(unsafe { slice::from_raw_parts(data, len) })
371 }
372
373 pub fn candidates_ith(
376 &self,
377 token_index: i32,
378 ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
379 let logits = self.get_logits_ith(token_index)?;
380
381 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
382 let token = LlamaToken::new(token_id);
383 LlamaTokenData::new(token, *logit, 0_f32)
384 }))
385 }
386
387 pub fn token_data_array_ith(
390 &self,
391 token_index: i32,
392 ) -> Result<LlamaTokenDataArray, LogitsError> {
393 Ok(LlamaTokenDataArray::from_iter(
394 self.candidates_ith(token_index)?,
395 false,
396 ))
397 }
398
399 pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
402 if !self.initialized_logits.contains(&token_index) {
403 return Err(LogitsError::TokenNotInitialized(token_index));
404 }
405
406 if token_index >= 0 {
407 let token_index_u32 =
408 u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
409
410 if self.n_ctx() <= token_index_u32 {
411 return Err(LogitsError::TokenIndexExceedsContext {
412 token_index: token_index_u32,
413 context_size: self.n_ctx(),
414 });
415 }
416 }
417
418 let data = unsafe {
419 llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
420 };
421 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
422
423 Ok(unsafe { slice::from_raw_parts(data, len) })
424 }
425
426 pub fn reset_timings(&mut self) {
427 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
428 }
429
430 pub fn timings(&mut self) -> LlamaTimings {
431 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
432 LlamaTimings { timings }
433 }
434
435 pub fn lora_adapter_set(
439 &self,
440 adapter: &mut LlamaLoraAdapter,
441 scale: f32,
442 ) -> Result<(), LlamaLoraAdapterSetError> {
443 let mut adapters = [adapter.lora_adapter.as_ptr()];
444 let mut scales = [scale];
445 let err_code = unsafe {
446 llama_cpp_bindings_sys::llama_set_adapters_lora(
447 self.context.as_ptr(),
448 adapters.as_mut_ptr(),
449 1,
450 scales.as_mut_ptr(),
451 )
452 };
453 check_lora_set_result(err_code)?;
454
455 log::debug!("Set lora adapter");
456 Ok(())
457 }
458
459 pub fn lora_adapter_remove(
463 &self,
464 _adapter: &mut LlamaLoraAdapter,
465 ) -> Result<(), LlamaLoraAdapterRemoveError> {
466 let err_code = unsafe {
467 llama_cpp_bindings_sys::llama_set_adapters_lora(
468 self.context.as_ptr(),
469 std::ptr::null_mut(),
470 0,
471 std::ptr::null_mut(),
472 )
473 };
474 check_lora_remove_result(err_code)?;
475
476 log::debug!("Remove lora adapter");
477 Ok(())
478 }
479}
480
481impl Drop for LlamaContext<'_> {
482 fn drop(&mut self) {
483 unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
484 }
485}
486
/// Unit tests for the status-code helpers used by the LoRA adapter methods.
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    // A zero status means the adapter was set.
    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    // A nonzero status is carried through in the error.
    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterSetError::ErrorResult(-1));

        assert_eq!(check_lora_set_result(-1), expected);
    }

    // A zero status means the adapter was removed.
    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    // A nonzero status is carried through in the error.
    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterRemoveError::ErrorResult(-1));

        assert_eq!(check_lora_remove_result(-1), expected);
    }
}