use std::ptr::NonNull;
use crate::context::LlamaContext;
use crate::llama_batch::LlamaBatch;
use crate::token::LlamaToken;
/// Errors returned by the EAGLE-3 session constructors and methods in this module.
#[derive(Debug, thiserror::Error)]
pub enum Eagle3SessionError {
/// The native `mtp_session_new` call returned a null pointer.
#[error("failed to create EAGLE-3 draft session — check that `draft` is a context over a valid EAGLE-3 draft model (3 extract layers) built from the same target")]
Init,
/// The native `mtp_session_process` call returned `false`.
#[error("EAGLE-3 process failed (see llama.cpp logs)")]
Process,
/// A sequence id was negative or not strictly below the session's `n_seq`.
#[error("sequence id {seq_id} out of range (n_seq = {n_seq})")]
BadSeqId {
/// The sequence id that was rejected.
seq_id: i32,
/// The `n_seq` value of the session that rejected it.
n_seq: u32,
},
/// The session configuration failed validation before the native session was created;
/// the payload names the offending constraint.
#[error("invalid EAGLE-3 session config: {0}")]
InvalidConfig(&'static str),
}
/// Plain-data configuration used to create an `Eagle3Session`.
///
/// Nothing is validated when a value of this type is built; validation happens
/// when the session itself is constructed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Eagle3SessionConfig {
    /// Number of sequences; session methods reject sequence ids that are not in `0..n_seq`.
    pub n_seq: u32,
    /// Upper bound used to size the buffer that receives drafted tokens.
    pub n_draft_max: i32,
    /// Forwarded unchanged to the native session config.
    /// Presumably a minimum draft length — confirm against the native implementation.
    pub n_min: i32,
    /// Forwarded unchanged to the native session config.
    /// Presumably a minimum acceptance probability — confirm against the native implementation.
    pub p_min: f32,
}

impl Eagle3SessionConfig {
    /// Builds a configuration from the two required values, leaving `n_min` at `0`
    /// and `p_min` at `0.0`.
    #[must_use]
    pub fn new(n_seq: u32, n_draft_max: i32) -> Self {
        let n_min = 0;
        let p_min = 0.0;
        Self {
            n_seq,
            n_draft_max,
            n_min,
            p_min,
        }
    }

    /// Returns a copy of this configuration with `n_min` replaced.
    #[must_use]
    pub fn with_n_min(self, n_min: i32) -> Self {
        Self { n_min, ..self }
    }

    /// Returns a copy of this configuration with `p_min` replaced.
    #[must_use]
    pub fn with_p_min(self, p_min: f32) -> Self {
        Self { p_min, ..self }
    }
}
/// Owning handle to a native `mtp_session` created with the EAGLE-3 spec type.
///
/// The native session is released with `mtp_session_free` when this value is dropped.
pub struct Eagle3Session {
// Non-null pointer returned by `mtp_session_new`.
raw: NonNull<llama_cpp_sys_4::mtp_session>,
// Copy of the configuration the session was created with; backs the getter methods.
config: Eagle3SessionConfig,
}
// SAFETY(review): `NonNull` is `!Send`, so this impl is a manual promise that the native
// session may be moved to and used from another thread. This assumes the C side has no
// thread affinity or thread-local state — TODO confirm against the native implementation.
unsafe impl Send for Eagle3Session {}
impl Eagle3Session {
pub fn new(
target: &LlamaContext<'_>,
draft: &LlamaContext<'_>,
n_seq: u32,
n_draft_max: i32,
) -> Result<Self, Eagle3SessionError> {
Self::new_with_config(target, draft, Eagle3SessionConfig::new(n_seq, n_draft_max))
}
pub fn new_with_config(
target: &LlamaContext<'_>,
draft: &LlamaContext<'_>,
config: Eagle3SessionConfig,
) -> Result<Self, Eagle3SessionError> {
if config.n_seq == 0 {
return Err(Eagle3SessionError::InvalidConfig("n_seq must be > 0"));
}
if config.n_draft_max <= 0 {
return Err(Eagle3SessionError::InvalidConfig("n_draft_max must be > 0"));
}
let c_config = llama_cpp_sys_4::mtp_session_config {
n_seq: config.n_seq,
n_draft_max: config.n_draft_max,
n_min: config.n_min,
p_min: config.p_min,
spec_type: llama_cpp_sys_4::MTP_SPEC_TYPE_EAGLE3 as i32,
};
let raw = unsafe {
llama_cpp_sys_4::mtp_session_new(
target.context.as_ptr(),
draft.context.as_ptr(),
&raw const c_config,
)
};
let raw = NonNull::new(raw).ok_or(Eagle3SessionError::Init)?;
Ok(Self { raw, config })
}
#[must_use]
pub fn config(&self) -> Eagle3SessionConfig {
self.config
}
#[must_use]
pub fn need_embd(&self) -> bool {
unsafe { llama_cpp_sys_4::mtp_session_need_embd(self.raw.as_ptr()) }
}
#[must_use]
pub fn need_embd_pre_norm(&self) -> bool {
unsafe { llama_cpp_sys_4::mtp_session_need_embd_pre_norm(self.raw.as_ptr()) }
}
#[must_use]
pub fn n_draft_max(&self) -> i32 {
self.config.n_draft_max
}
#[must_use]
pub fn n_min(&self) -> i32 {
self.config.n_min
}
#[must_use]
pub fn p_min(&self) -> f32 {
self.config.p_min
}
#[must_use]
pub fn n_seq(&self) -> u32 {
self.config.n_seq
}
pub fn print_stats(&self) {
unsafe { llama_cpp_sys_4::mtp_session_print_stats(self.raw.as_ptr()) }
}
pub fn begin(&mut self, seq_id: i32, prompt: &[LlamaToken]) -> Result<(), Eagle3SessionError> {
self.check_seq(seq_id)?;
unsafe {
llama_cpp_sys_4::mtp_session_begin(
self.raw.as_ptr(),
seq_id,
prompt.as_ptr().cast(),
prompt.len(),
);
}
Ok(())
}
pub fn process(&mut self, batch: &LlamaBatch) -> Result<(), Eagle3SessionError> {
let ok =
unsafe { llama_cpp_sys_4::mtp_session_process(self.raw.as_ptr(), &batch.llama_batch) };
if ok {
Ok(())
} else {
Err(Eagle3SessionError::Process)
}
}
pub fn draft(
&mut self,
seq_id: i32,
n_past: i32,
id_last: LlamaToken,
) -> Result<Vec<LlamaToken>, Eagle3SessionError> {
self.check_seq(seq_id)?;
let cap = self.config.n_draft_max.max(0) as usize;
let mut buf: Vec<i32> = vec![0; cap];
let mut out_n: i32 = cap as i32;
unsafe {
llama_cpp_sys_4::mtp_session_draft(
self.raw.as_ptr(),
seq_id,
n_past,
id_last.0,
buf.as_mut_ptr(),
&mut out_n,
);
}
let n = out_n.max(0) as usize;
buf.truncate(n);
Ok(buf.into_iter().map(LlamaToken).collect())
}
pub fn accept(&mut self, seq_id: i32, n_accepted: u16) -> Result<(), Eagle3SessionError> {
self.check_seq(seq_id)?;
unsafe {
llama_cpp_sys_4::mtp_session_accept(self.raw.as_ptr(), seq_id, n_accepted);
}
Ok(())
}
fn check_seq(&self, seq_id: i32) -> Result<(), Eagle3SessionError> {
if seq_id < 0 || (seq_id as u32) >= self.config.n_seq {
return Err(Eagle3SessionError::BadSeqId {
seq_id,
n_seq: self.config.n_seq,
});
}
Ok(())
}
}
impl Drop for Eagle3Session {
    /// Releases the native session with `mtp_session_free`.
    fn drop(&mut self) {
        let session = self.raw.as_ptr();
        // SAFETY: `session` is the non-null handle stored at construction, and `drop`
        // runs at most once, so the handle is freed exactly once here.
        unsafe { llama_cpp_sys_4::mtp_session_free(session) }
    }
}
impl std::fmt::Debug for Eagle3Session {
    /// Writes the type name and the `config` field; the raw native pointer is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Eagle3Session");
        builder.field("config", &self.config);
        builder.finish()
    }
}