use std::ptr::NonNull;
use crate::context::LlamaContext;
use crate::llama_batch::LlamaBatch;
use crate::token::LlamaToken;
/// Errors returned by the [`MtpSession`] wrapper.
#[derive(Debug, thiserror::Error)]
pub enum MtpSessionError {
/// The native constructor (`mtp_session_new`) returned a null pointer.
#[error("failed to create MTP draft session — check that ctx_dft was built with LlamaContextType::Mtp and the model has MTP heads")]
Init,
/// The native `mtp_session_process` call reported failure (returned `false`).
#[error("mtp_session_process failed (see llama.cpp logs)")]
Process,
/// A sequence id was negative or not smaller than the configured `n_seq`.
#[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
BadSeqId {
/// The rejected sequence id, exactly as supplied by the caller.
seq_id: i32,
/// The number of sequences the session was configured with.
n_seq: u32,
},
/// The configuration was rejected before any native call was made; the
/// payload names the violated constraint (`n_seq == 0` or `n_draft_max <= 0`).
#[error("invalid MTP session config: {0}")]
InvalidConfig(&'static str),
}
/// Plain-data configuration handed to the native MTP session constructor.
///
/// Build one with [`MtpSessionConfig::new`] and refine it with the `with_*`
/// methods; every method takes and returns the value by copy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MtpSessionConfig {
    /// Number of sequences; session methods accept sequence ids in `0..n_seq`.
    pub n_seq: u32,
    /// Size of the draft buffer, i.e. the most tokens one draft call can return.
    pub n_draft_max: i32,
    /// Forwarded unchanged to the native config.
    /// NOTE(review): presumably the minimum draft length — confirm against llama.cpp.
    pub n_min: i32,
    /// Forwarded unchanged to the native config.
    /// NOTE(review): presumably a minimum token probability — confirm against llama.cpp.
    pub p_min: f32,
}

impl MtpSessionConfig {
    /// Creates a config with the given sizes and `n_min = 0`, `p_min = 0.0`.
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        let (n_min, p_min) = (0, 0.0);
        Self {
            n_seq,
            n_draft_max,
            n_min,
            p_min,
        }
    }

    /// Returns a copy of this config with `n_min` replaced.
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Returns a copy of this config with `p_min` replaced.
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
/// Owning handle to a native `mtp_session`.
///
/// The handle is obtained from `mtp_session_new` and released with
/// `mtp_session_free` when this value is dropped.
pub struct MtpSession {
// Non-null pointer to the native session; checked for null at construction.
raw: NonNull<llama_cpp_sys_4::mtp_session>,
// Copy of the configuration the session was created with; backs the getters.
config: MtpSessionConfig,
}
// SAFETY: NOTE(review): this asserts the native session may be moved to another
// thread. Nothing in this file demonstrates that — confirm against the C
// implementation. `Sync` is deliberately not implemented here.
unsafe impl Send for MtpSession {}
impl MtpSession {
pub fn new(
target: &LlamaContext<'_>,
draft: &LlamaContext<'_>,
n_seq: u32,
n_draft_max: i32,
) -> Result<Self, MtpSessionError> {
Self::new_with_config(target, draft, MtpSessionConfig::new(n_seq, n_draft_max))
}
pub fn new_with_config(
target: &LlamaContext<'_>,
draft: &LlamaContext<'_>,
config: MtpSessionConfig,
) -> Result<Self, MtpSessionError> {
if config.n_seq == 0 {
return Err(MtpSessionError::InvalidConfig("n_seq must be > 0"));
}
if config.n_draft_max <= 0 {
return Err(MtpSessionError::InvalidConfig("n_draft_max must be > 0"));
}
let c_config = llama_cpp_sys_4::mtp_session_config {
n_seq: config.n_seq,
n_draft_max: config.n_draft_max,
n_min: config.n_min,
p_min: config.p_min,
};
let raw = unsafe {
llama_cpp_sys_4::mtp_session_new(
target.context.as_ptr(),
draft.context.as_ptr(),
&raw const c_config,
)
};
let raw = NonNull::new(raw).ok_or(MtpSessionError::Init)?;
Ok(Self { raw, config })
}
#[must_use]
pub fn config(&self) -> MtpSessionConfig {
self.config
}
#[must_use]
pub fn need_embd(&self) -> bool {
unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
}
#[must_use]
pub fn need_embd_pre_norm(&self) -> bool {
unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
}
#[must_use]
pub fn n_draft_max(&self) -> i32 {
self.config.n_draft_max
}
#[must_use]
pub fn n_min(&self) -> i32 {
self.config.n_min
}
#[must_use]
pub fn p_min(&self) -> f32 {
self.config.p_min
}
#[must_use]
pub fn n_seq(&self) -> u32 {
self.config.n_seq
}
pub fn print_stats(&self) {
unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
}
pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), MtpSessionError> {
self.check_seq(seq_id)?;
unsafe {
llama_cpp_sys_4::mtp_session_begin(
self.raw.as_ptr(),
seq_id,
prompt.as_ptr().cast(),
prompt.len(),
);
}
Ok(())
}
pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), MtpSessionError> {
let ok =
unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
if ok {
Ok(())
} else {
Err(MtpSessionError::Process)
}
}
pub fn draft(
&mut self,
seq_id: i32,
n_past: i32,
id_last: LlamaToken,
) -> Result<Vec<LlamaToken>, MtpSessionError> {
self.check_seq(seq_id)?;
let cap = self.config.n_draft_max.max(0) as usize;
let mut buf: Vec<i32> = vec![0; cap];
let mut out_n: i32 = cap as i32;
unsafe {
llama_cpp_sys_4::mtp_session_draft(
self.raw.as_ptr(),
seq_id,
n_past,
id_last.0,
buf.as_mut_ptr(),
&mut out_n,
);
}
let n = out_n.max(0) as usize;
buf.truncate(n);
Ok(buf.into_iter().map(LlamaToken).collect())
}
pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), MtpSessionError> {
self.check_seq(seq_id)?;
unsafe {
llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
}
Ok(())
}
fn check_seq(&self, seq_id: i32) -> Result<(), MtpSessionError> {
if seq_id < 0 || (seq_id as u32) >= self.config.n_seq {
return Err(MtpSessionError::BadSeqId {
seq_id,
n_seq: self.config.n_seq,
});
}
Ok(())
}
}
impl Drop for MtpSession {
    /// Releases the native session through `mtp_session_free`.
    fn drop(&mut self) {
        let handle = self.raw.as_ptr();
        // SAFETY: `handle` is the non-null pointer obtained at construction;
        // `drop` runs at most once, so it is not freed twice from here.
        unsafe { llama_cpp_sys_4::mtp_session_free(handle) }
    }
}
impl std::fmt::Debug for MtpSession {
    /// Formats the stored config plus the current `need_embd_pre_norm` value;
    /// the raw native pointer is intentionally left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("MtpSession");
        repr.field("config", &self.config);
        // Queried from the native session at formatting time.
        repr.field("need_embd_pre_norm", &self.need_embd_pre_norm());
        repr.finish()
    }
}