use derive_more::{Deref, DerefMut};
use std::cmp::min;
use std::ops::{Bound, RangeBounds};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use thiserror::Error;
use tokio::sync::mpsc::unbounded_channel;
use tracing::{error, info, trace, warn};
use llama_cpp_sys::{
llama_context, llama_copy_state_data, llama_decode, llama_free, llama_get_logits_ith,
llama_get_state_size, llama_kv_cache_seq_rm, llama_set_state_data, llama_token_data,
llama_token_data_array,
};
use crate::standard_sampler::StandardSampler;
use crate::{LlamaModel, LlamaTokenizationError, Sampler, Token};
mod completion;
mod params;
use crate::batch::Batch;
pub use completion::CompletionHandle;
pub use completion::*;
pub use params::*;
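/// A wrapper around the raw `llama_context` pointer; the context is freed with
/// `llama_free` when this is dropped.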
#[derive(Deref, DerefMut)]
pub(crate) struct LlamaContextInner {
pub(crate) ptr: *mut llama_context,
}
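// SAFETY: the context pointer is only ever used while holding the `Mutex` in
// `LlamaSessionInner`, so it is sound to move it to another thread.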
unsafe impl Send for LlamaContextInner {}
impl Drop for LlamaContextInner {
fn drop(&mut self) {
unsafe { llama_free(self.ptr) }
}
}
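/// An evaluation session for a [`LlamaModel`].
///
/// Cloning is cheap: all clones share the same [`LlamaSessionInner`] behind an [`Arc`].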
#[derive(Clone)]
pub struct LlamaSession {
pub(crate) inner: Arc<LlamaSessionInner>,
}
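/// The state shared by every clone of a [`LlamaSession`].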
pub(crate) struct LlamaSessionInner {
/// The model this session was created from.
pub(crate) model: LlamaModel,
/// The raw llama.cpp context, guarded by a mutex so only one thread decodes at a time.
pub(crate) ctx: Mutex<LlamaContextInner>,
/// Every token evaluated in this session so far.
pub(crate) tokens: RwLock<Vec<Token>>,
/// The number of tokens decoded in the most recent batch; used to index the last logits.
pub(crate) last_batch_size: AtomicUsize,
/// The maximum number of tokens that may be decoded in a single batch.
pub(crate) max_batch: u32,
/// The parameters this session was created with.
pub(crate) params: SessionParams,
}
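/// An error raised while tokenizing, decoding, or otherwise operating on a [`LlamaSession`].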
#[derive(Error, Debug)]
pub enum LlamaContextError {
#[error("tokenization failed: {0}")]
TokenizationFailed(#[from] LlamaTokenizationError),
#[error("{provided_tokens} were provided, but llama.cpp can only handle {max_tokens}")]
MaxTokensExceeded {
provided_tokens: usize,
max_tokens: usize,
},
#[error("failed to create llama context")]
SessionFailed,
#[error("advancing context failed (error code {0})")]
DecodeFailed(i32),
#[error("failed to process embeddings (reason: {0})")]
EmbeddingsFailed(String),
#[error("failed to operate over kv cache due to invalid range")]
InvalidRange,
#[error("cannot start completing without any history")]
NoContext,
}
impl LlamaSession {
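/// Advances the inner context of this model with `tokens`.
///
/// Tokens are decoded in batches of at most `max_batch` tokens, and logits are only
/// computed for the final token. This blocks the calling thread until decoding finishes.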
pub fn advance_context_with_tokens(
&mut self,
tokens: impl AsRef<[Token]>,
) -> Result<(), LlamaContextError> {
let tokens = tokens.as_ref();
let n_tokens = tokens.len();
if n_tokens == 0 {
return Ok(());
}
if n_tokens > i32::MAX as usize {
return Err(LlamaContextError::MaxTokensExceeded {
provided_tokens: n_tokens,
max_tokens: i32::MAX as usize,
});
}
info!("Advancing context with {n_tokens} tokens");
let batch_size = min(n_tokens, self.inner.max_batch as usize);
let sequences = tokens.chunks(batch_size);
if n_tokens > batch_size {
info!("Number of tokens exceeds the maximum batch size ({}) for this session, splitting the input", self.inner.max_batch);
}
let mut batch = Batch::new(batch_size, 0, 1);
let history_size = self.context_size();
let mut local_history = 0;
let mut last_batch_size = self.inner.last_batch_size.load(Ordering::SeqCst);
for sequence in sequences {
batch.clear();
for token in sequence {
batch.add(*token, history_size + local_history, &[0], false);
local_history += 1;
}
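// Request logits only for the final token of the final chunk.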
if local_history == n_tokens {
batch.set_logits(sequence.len() - 1, true);
}
trace!("Wrote {n_tokens} tokens to the token buffer");
trace!("Starting LLaMA decode for batch");
let err = unsafe {
let session_guard = self.inner.ctx.lock().unwrap();
llama_decode(**session_guard, batch.handle())
};
if err != 0 {
return Err(LlamaContextError::DecodeFailed(err));
}
trace!("Batch decode completed successfully");
last_batch_size = sequence.len();
}
self.inner.tokens.write().unwrap().extend_from_slice(tokens);
self.inner
.last_batch_size
.store(last_batch_size, Ordering::SeqCst);
Ok(())
}
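/// A non-blocking wrapper around [`LlamaSession::advance_context_with_tokens`]; decoding
/// runs on a blocking thread via `tokio::task::spawn_blocking`.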
pub async fn advance_context_with_tokens_async(
&mut self,
tokens: impl AsRef<[Token]>,
) -> Result<(), LlamaContextError> {
let tokens = tokens.as_ref().to_owned();
let mut session = self.clone();
tokio::task::spawn_blocking(move || session.advance_context_with_tokens(tokens))
.await
.unwrap()
}
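/// Tokenizes `ctx` with this session's model and advances the context with the result.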
pub fn advance_context(&mut self, ctx: impl AsRef<[u8]>) -> Result<(), LlamaContextError> {
let tokens = self
.inner
.model
.tokenize_bytes(ctx.as_ref(), false, true)?
.into_boxed_slice();
self.advance_context_with_tokens(tokens)
}
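/// A non-blocking wrapper around [`LlamaSession::advance_context`].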
pub async fn advance_context_async(
&mut self,
ctx: impl AsRef<[u8]>,
) -> Result<(), LlamaContextError> {
let ctx = ctx.as_ref().to_owned();
let mut session = self.clone();
tokio::task::spawn_blocking(move || session.advance_context(ctx))
.await
.unwrap()
}
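/// Starts completing from the current context with a greedy [`StandardSampler`], generating
/// at most enough tokens to fill the remaining context window.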
pub fn start_completing(&mut self) -> Result<CompletionHandle, LlamaContextError> {
self.start_completing_with(
StandardSampler::new_greedy(),
self.params().n_ctx as usize - self.context_size(),
)
}
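/// Starts completing from the current context using `sampler`, producing at most
/// `max_predictions` new tokens.
///
/// Generation runs on a dedicated thread; sampled tokens are streamed to the returned
/// [`CompletionHandle`].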
pub fn start_completing_with<S>(
&mut self,
mut sampler: S,
max_predictions: usize,
) -> Result<CompletionHandle, LlamaContextError>
where
S: Sampler + Send + Sync + 'static,
{
let history_size = self.context_size();
if history_size == 0 {
return Err(LlamaContextError::NoContext);
}
let (tx, rx) = unbounded_channel();
let session = self.clone();
info!("Generating completions with {history_size} tokens of history");
thread::spawn(move || {
let context = session.inner.ctx.lock().unwrap();
let vocab = session.model().vocabulary_size();
let end_of_stream = session.model().eos();
let mut token_buf = session.inner.tokens.write().unwrap();
let mut batch = Batch::new(1, 0, 1);
let mut current_pos = history_size;
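// If the last decode produced no logits (e.g. the context was truncated), re-decode the
// final token of the current context so sampling has logits to work from.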
if session.inner.last_batch_size.load(Ordering::SeqCst) == 0 {
unsafe {
llama_kv_cache_seq_rm(**context, -1, token_buf.len() as i32 - 1, -1);
}
// Re-decode the final token at its original position, one before `current_pos`.
batch.add(*token_buf.last().unwrap(), current_pos - 1, &[0], true);
let res = unsafe { llama_decode(**context, batch.handle()) };
if res != 0 {
error!("Failed to decode context ({res})");
return;
}
session
.inner
.last_batch_size
.store(batch.tokens(), Ordering::SeqCst);
batch.clear();
}
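// Sample, stream, and decode tokens until end-of-stream or the prediction limit is reached.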
loop {
let mut candidates: Vec<llama_token_data> = {
let i = session.inner.last_batch_size.load(Ordering::SeqCst);
let logits = unsafe { llama_get_logits_ith(**context, (i - 1) as i32) };
let logits = unsafe { std::slice::from_raw_parts(logits, vocab) };
logits
.iter()
.enumerate()
.map(|(id, &logit)| llama_token_data {
id: id as i32,
logit,
p: 0.0,
})
.collect()
};
let candidates_p = llama_token_data_array {
data: candidates.as_mut_ptr(),
size: vocab,
sorted: false,
};
let token = sampler.sample(**context, &token_buf, candidates_p);
if let Err(e) = tx.send(token) {
let token_str = String::from_utf8_lossy(session.inner.model.detokenize(e.0));
warn!("Cannot send token ({}): {}", token_str, e);
return;
}
if token == end_of_stream || token_buf.len() - history_size >= max_predictions {
return;
}
batch.add(token, current_pos, &[0], true);
let res = unsafe { llama_decode(**context, batch.handle()) };
if res != 0 {
error!("Failed to decode context ({res})");
return;
}
session
.inner
.last_batch_size
.store(batch.tokens(), Ordering::SeqCst);
// Record the sampled token and advance the position for the next decode.
token_buf.push(token);
current_pos = token_buf.len();
batch.clear();
}
});
Ok(CompletionHandle {
rx,
model: self.model(),
})
}
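/// Returns the model this session was created from.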
pub fn model(&self) -> LlamaModel {
self.inner.model.clone()
}
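/// Returns the parameters this session was created with.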
pub fn params(&self) -> &SessionParams {
&self.inner.params
}
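/// Returns the number of tokens currently in this session's context.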
pub fn context_size(&self) -> usize {
self.inner.tokens.read().unwrap().len()
}
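/// Returns a copy of the tokens currently in this session's context.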
pub fn context(&self) -> Vec<Token> {
self.inner.tokens.read().unwrap().clone()
}
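/// Removes all tokens within `range` from both the llama.cpp KV cache and this session's
/// token buffer.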
pub fn remove_tokens_in_range(
&mut self,
range: impl RangeBounds<usize>,
) -> Result<(), LlamaContextError> {
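// llama.cpp treats a negative position as an open bound on that side.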
let start_bound = match range.start_bound() {
Bound::Included(i) => *i as i32,
Bound::Excluded(i) => *i as i32 + 1,
Bound::Unbounded => -1,
};
let end_bound = match range.end_bound() {
Bound::Included(i) => *i as i32 + 1,
Bound::Excluded(i) => *i as i32,
Bound::Unbounded => -1,
};
let success = unsafe {
let context = self.inner.ctx.lock().unwrap();
llama_kv_cache_seq_rm(**context, -1, start_bound, end_bound)
};
if !success {
return Err(LlamaContextError::InvalidRange);
}
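// If the removal reached the end of the context, the cached logits no longer match the
// final token; record that so the next completion re-decodes it first.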
if end_bound == -1 || end_bound as usize >= self.context_size() {
self.inner.last_batch_size.store(0, Ordering::SeqCst);
}
self.inner.tokens.write().unwrap().drain(range);
Ok(())
}
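/// Truncates this session's context to its first `n_tokens` tokens.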
pub fn truncate_context(&mut self, n_tokens: usize) -> Result<(), LlamaContextError> {
self.remove_tokens_in_range(n_tokens..)
}
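/// Sets this session's context to `new_tokens`, keeping the longest shared prefix with the
/// existing context so that only the differing suffix is re-decoded.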
pub fn set_context_to_tokens(
&mut self,
new_tokens: impl AsRef<[Token]>,
) -> Result<(), LlamaContextError> {
let new_tokens = new_tokens.as_ref();
let old_tokens = self.inner.tokens.read().unwrap();
let shared_prefix = old_tokens
.iter()
.zip(new_tokens)
.position(|(t1, t2)| t1 != t2)
.unwrap_or(new_tokens.len().min(old_tokens.len()));
std::mem::drop(old_tokens);
self.truncate_context(shared_prefix)?;
self.advance_context_with_tokens(&new_tokens[shared_prefix..])
}
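/// Tokenizes `ctx` with this session's model and sets the session's context to the result.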
pub fn set_context(&mut self, ctx: impl AsRef<[u8]>) -> Result<(), LlamaContextError> {
let tokens = self
.inner
.model
.tokenize_bytes(ctx.as_ref(), false, false)?
.into_boxed_slice();
self.set_context_to_tokens(tokens)
}
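/// A non-blocking wrapper around [`LlamaSession::set_context_to_tokens`].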
pub async fn set_context_to_tokens_async(
&mut self,
tokens: impl AsRef<[Token]>,
) -> Result<(), LlamaContextError> {
let tokens = tokens.as_ref().to_owned();
let mut session = self.clone();
tokio::task::spawn_blocking(move || session.set_context_to_tokens(tokens))
.await
.unwrap()
}
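/// A non-blocking wrapper around [`LlamaSession::set_context`].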
pub async fn set_context_async(
&mut self,
ctx: impl AsRef<[u8]>,
) -> Result<(), LlamaContextError> {
let ctx = ctx.as_ref().to_owned();
let mut session = self.clone();
tokio::task::spawn_blocking(move || session.set_context(ctx))
.await
.unwrap()
}
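/// Creates a new session that contains a copy of this session's full llama.cpp state,
/// including its KV cache and token history.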
pub fn deep_copy(&self) -> Result<LlamaSession, LlamaContextError> {
let ctx = self.inner.ctx.lock().unwrap();
#[allow(unused_mut)]
let mut copy = self.model().create_session(self.inner.params.clone())?;
let size = unsafe { llama_get_state_size(**ctx) };
let mut buf = vec![0; size];
unsafe {
let copy_size = llama_copy_state_data(**ctx, buf.as_mut_ptr());
assert!(copy_size <= size);
let copy_guard = copy.inner.ctx.lock().unwrap();
let set_size = llama_set_state_data(**copy_guard, buf.as_mut_ptr());
assert_eq!(copy_size, set_size);
}
*copy.inner.tokens.write().unwrap() = self.inner.tokens.read().unwrap().clone();
copy.inner.last_batch_size.store(
self.inner.last_batch_size.load(Ordering::SeqCst),
Ordering::SeqCst,
);
Ok(copy)
}
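/// Returns the size, in bytes, of the llama.cpp state for this session.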
pub fn memory_size(&self) -> usize {
let ctx = self.inner.ctx.lock().unwrap();
unsafe { llama_get_state_size(**ctx) }
}
}