use std::ptr::NonNull;
use crate::context::LlamaContext;
use crate::llama_batch::LlamaBatch;
use crate::token::LlamaToken;
/// Errors returned by the [`MtpSession`] wrapper.
#[derive(Debug, thiserror::Error)]
pub enum MtpSessionError {
/// `mtp_session_new` returned a null pointer, so no session could be built.
#[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
Init,
/// `mtp_session_process` reported failure (its return value was `false`).
#[error("mtp_session_process failed (see llama.cpp logs)")]
Process,
/// A sequence id was negative or not smaller than the session's `n_seq`.
#[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
BadSeqId {
/// The sequence id that was rejected.
seq_id: i32,
/// The number of sequences the session was created with.
n_seq: u32,
},
}
/// Owning handle to a `llama_cpp_sys_4::mtp_session`.
///
/// The handle is obtained from `mtp_session_new` and released with
/// `mtp_session_free` when the value is dropped.
pub struct MtpSession {
/// Non-null pointer to the underlying C session object.
raw: NonNull<llama_cpp_sys_4::mtp_session>,
/// Number of sequences given at construction; valid sequence ids are
/// `0..n_seq`.
n_seq: u32,
/// Draft limit given at construction; used as the capacity of the output
/// buffer handed to `mtp_session_draft`.
n_draft_max: i32,
}
// Allows an `MtpSession` to be moved to another thread. Only `Send` is asserted
// here; no `Sync` impl is visible in this file.
// NOTE(review): this presumably relies on the C session having no thread
// affinity — nothing in this file establishes that, confirm against llama.cpp.
unsafe impl Send for MtpSession {}
impl MtpSession {
    /// Creates a session by calling `mtp_session_new` with the raw handles of
    /// `target` and `draft`.
    ///
    /// `n_seq` and `n_draft_max` are forwarded to the C constructor and also
    /// kept on the Rust side: `n_seq` bounds the sequence ids accepted by the
    /// other methods, and `n_draft_max` sizes the buffer used by
    /// [`MtpSession::draft`].
    ///
    /// NOTE(review): the returned value carries no lifetime tying it to
    /// `target` or `draft`. If the C session keeps using those context
    /// pointers (presumably it does — confirm), the caller must keep both
    /// contexts alive for as long as the session exists.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSessionError::Init`] when `mtp_session_new` returns null.
    pub fn new(
        target: &LlamaContext<'_>,
        draft: &LlamaContext<'_>,
        n_seq: u32,
        n_draft_max: i32,
    ) -> Result<Self, MtpSessionError> {
        let target_ptr = target.context.as_ptr();
        let draft_ptr = draft.context.as_ptr();
        // SAFETY: both pointers are taken from contexts that are borrowed for
        // the whole call. NOTE(review): any further requirements of
        // `mtp_session_new` are not visible from this file.
        let handle = unsafe {
            llama_cpp_sys_4::mtp_session_new(target_ptr, draft_ptr, n_seq, n_draft_max)
        };
        match NonNull::new(handle) {
            Some(raw) => Ok(Self {
                raw,
                n_seq,
                n_draft_max,
            }),
            None => Err(MtpSessionError::Init),
        }
    }

    /// Returns whatever `mtp_session_need_embd` reports for this session.
    ///
    /// NOTE(review): presumably tells the caller whether embeddings must be
    /// supplied to the draft path — confirm against the C API.
    #[must_use]
    pub fn need_embd(&self) -> bool {
        let handle = self.raw.as_ptr();
        // SAFETY: `raw` was checked to be non-null in `new` and is released
        // only in `Drop`.
        unsafe { llama_cpp_sys_4::mtp_session_need_embd(handle) }
    }

    /// Returns the `n_draft_max` value this session was created with.
    #[must_use]
    pub fn n_draft_max(&self) -> i32 {
        self.n_draft_max
    }

    /// Returns the `n_seq` value this session was created with.
    #[must_use]
    pub fn n_seq(&self) -> u32 {
        self.n_seq
    }

    /// Hands the tokens of `prompt` for sequence `seq_id` to
    /// `mtp_session_begin`.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSessionError::BadSeqId`] when `seq_id` is outside
    /// `0..n_seq`; the C function is not called in that case.
    pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
        self.check_seq(seq_id)?;
        let handle = self.raw.as_ptr();
        let n_tokens = prompt.len();
        // SAFETY: the pointer/length pair describes `prompt`, which stays
        // borrowed for the whole call. NOTE(review): the pointer cast assumes
        // `LlamaToken` is layout-compatible with the C token type — confirm.
        unsafe {
            llama_cpp_sys_4::mtp_session_begin(handle, seq_id, prompt.as_ptr().cast(), n_tokens);
        }
        Ok(())
    }

    /// Passes the raw `llama_batch` inside `batch` to `mtp_session_process`.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSessionError::Process`] when the C call returns `false`.
    pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
        let handle = self.raw.as_ptr();
        // SAFETY: `raw` is a live session handle and the batch reference is
        // valid for the duration of the call.
        let succeeded =
            unsafe { llama_cpp_sys_4::mtp_session_process(handle, &batch.llama_batch) };
        if !succeeded {
            return Err(MtpSessionError::Process);
        }
        Ok(())
    }

    /// Requests draft tokens for `seq_id` from `mtp_session_draft`.
    ///
    /// A zero-filled buffer of `max(n_draft_max, 0)` tokens is handed to the C
    /// function together with a counter that starts at that capacity. After
    /// the call, the first `max(counter, 0)` buffer entries are returned; a
    /// counter larger than the capacity yields the whole buffer.
    ///
    /// NOTE(review): `n_past` and `id_last` are forwarded unchanged; they are
    /// presumably the current position and the last sampled token — confirm.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSessionError::BadSeqId`] when `seq_id` is outside
    /// `0..n_seq`; the C function is not called in that case.
    pub fn draft(
        &mut self,
        seq_id: i32,
        n_past: i32,
        id_last: LlamaToken,
    ) -> Result<Vec<LlamaToken>, MtpSessionError> {
        self.check_seq(seq_id)?;
        // A negative limit is treated as "no room for drafts".
        let capacity = if self.n_draft_max > 0 {
            self.n_draft_max as usize
        } else {
            0
        };
        let mut tokens: Vec<i32> = vec![0; capacity];
        let mut n_written: i32 = capacity as i32;
        let handle = self.raw.as_ptr();
        // SAFETY: `tokens` owns `capacity` initialized slots and `n_written`
        // starts at that capacity. NOTE(review): this presumably relies on the
        // C side writing at most that many entries — confirm.
        unsafe {
            llama_cpp_sys_4::mtp_session_draft(
                handle,
                seq_id,
                n_past,
                id_last.0,
                tokens.as_mut_ptr(),
                &mut n_written,
            );
        }
        // Clamp a negative count to zero; `take` caps an oversized count at the
        // buffer length.
        let keep = if n_written > 0 { n_written as usize } else { 0 };
        Ok(tokens.into_iter().take(keep).map(LlamaToken).collect())
    }

    /// Forwards `n_accepted` for sequence `seq_id` to `mtp_session_accept`.
    ///
    /// NOTE(review): presumably the number of drafted tokens the target model
    /// accepted — confirm against the C API.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSessionError::BadSeqId`] when `seq_id` is outside
    /// `0..n_seq`; the C function is not called in that case.
    pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
        self.check_seq(seq_id)?;
        let handle = self.raw.as_ptr();
        // SAFETY: `raw` is a live session handle; `seq_id` was range-checked.
        unsafe {
            llama_cpp_sys_4::mtp_session_accept(handle, seq_id, n_accepted);
        }
        Ok(())
    }

    /// Validates that `seq_id` lies in `0..n_seq`.
    fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
        let n_seq = self.n_seq;
        // The sign test runs first, so the cast only ever sees non-negative ids.
        let in_range = seq_id >= 0 && (seq_id as u32) < n_seq;
        if in_range {
            Ok(())
        } else {
            Err(MtpSessionError::BadSeqId { seq_id, n_seq })
        }
    }
}
impl Drop for MtpSession {
    /// Releases the underlying C session with `mtp_session_free`.
    fn drop(&mut self) {
        let handle = self.raw.as_ptr();
        // SAFETY: `raw` is non-null by construction and this is the only place
        // in this file that frees it; `drop` runs at most once per value.
        unsafe {
            llama_cpp_sys_4::mtp_session_free(handle);
        }
    }
}
impl std::fmt::Debug for MtpSession {
    /// Formats the session as `MtpSession { n_seq, n_draft_max }`; the raw
    /// handle is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("MtpSession");
        repr.field("n_seq", &self.n_seq);
        repr.field("n_draft_max", &self.n_draft_max);
        repr.finish()
    }
}