#![allow(missing_docs)]
use crate::model::quantized_llama_local::ModelWeights;
use crate::model::{Decoder, TreeDecoder};
use crate::tree::DraftTree;
use crate::{Error, Result};
use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, IndexOp, Tensor};
use std::path::Path;
use tokenizers::Tokenizer;
/// Decoder backed by a GGUF-quantized Llama checkpoint.
///
/// Bundles the model weights, the tokenizer, and the bookkeeping needed for
/// incremental decoding and draft-tree verification (see the `Decoder` and
/// `TreeDecoder` impls).
pub struct LlamaQuantDecoder {
    /// Quantized weights; also holds the KV cache that `cache_len` tracks
    /// (the code truncates/clears it through this value).
    model: ModelWeights,
    /// Tokenizer used by `encode` / `decode`.
    tokenizer: Tokenizer,
    /// Tokens committed so far, in order.
    history: Vec<u32>,
    /// Device every input tensor is created on.
    device: Device,
    /// Entry count of the GGUF `tokenizer.ggml.tokens` array, or 128256 when
    /// that metadata is missing/unreadable.
    vocab_size: usize,
    /// Value of the GGUF `llama.embedding_length` metadata key.
    hidden_size: usize,
    /// Cached next-token logits for the end of `history`; `None` when they have
    /// to be recomputed by `next_logits`.
    last_logits: Option<Vec<f32>>,
    /// End-of-sequence ids supplied by the caller at construction.
    eos_token_ids: Vec<u32>,
    /// Number of positions the code believes the model's KV cache holds; passed
    /// as the start position to the model's forward calls.
    cache_len: usize,
}
impl std::fmt::Debug for LlamaQuantDecoder {
    /// Shows only the scalar bookkeeping (sizes, history/cache lengths, device);
    /// the weights and tokenizer are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let history_len = self.history.len();
        let mut builder = f.debug_struct("LlamaQuantDecoder");
        builder.field("vocab_size", &self.vocab_size);
        builder.field("hidden_size", &self.hidden_size);
        builder.field("history_len", &history_len);
        builder.field("cache_len", &self.cache_len);
        builder.field("device", &self.device);
        builder.finish()
    }
}
impl LlamaQuantDecoder {
    /// Vocabulary size assumed when the GGUF metadata has no readable
    /// `tokenizer.ggml.tokens` array.
    /// NOTE(review): presumably chosen to match Llama 3; confirm for other checkpoints.
    const DEFAULT_VOCAB_SIZE: usize = 128_256;

    /// Loads a quantized Llama model from a GGUF file plus a tokenizer file.
    ///
    /// `hidden_size` is read from the `llama.embedding_length` metadata key and
    /// `vocab_size` from the number of entries in `tokenizer.ggml.tokens`
    /// (falling back to `DEFAULT_VOCAB_SIZE` when that key is missing or cannot
    /// be read as an array). `eos_token_ids` is stored as given.
    ///
    /// # Errors
    ///
    /// Fails if the GGUF file cannot be opened or parsed, if
    /// `llama.embedding_length` is missing or not readable as `u32`, if the
    /// weights fail to load onto `device`, or if the tokenizer cannot be read.
    pub fn from_gguf(
        gguf_path: impl AsRef<Path>,
        tokenizer_path: impl AsRef<Path>,
        device: Device,
        eos_token_ids: Vec<u32>,
    ) -> Result<Self> {
        let mut file = std::fs::File::open(gguf_path.as_ref())
            .map_err(|e| Error::Other(anyhow::anyhow!("open gguf: {e}")))?;
        let content = gguf_file::Content::read(&mut file).map_err(Error::Candle)?;
        let hidden_size = content
            .metadata
            .get("llama.embedding_length")
            .ok_or_else(|| Error::Other(anyhow::anyhow!("missing llama.embedding_length")))?
            .to_u32()
            .map_err(Error::Candle)? as usize;
        let vocab_size = content
            .metadata
            .get("tokenizer.ggml.tokens")
            .and_then(|v| v.to_vec().ok())
            .map(|v| v.len())
            .unwrap_or(Self::DEFAULT_VOCAB_SIZE);
        let model = ModelWeights::from_gguf(content, &mut file, &device).map_err(Error::Candle)?;
        let tokenizer = Tokenizer::from_file(tokenizer_path.as_ref())
            .map_err(|e| Error::Tokenizer(e.to_string()))?;
        Ok(Self {
            model,
            tokenizer,
            history: Vec::new(),
            device,
            vocab_size,
            hidden_size,
            eos_token_ids,
            cache_len: 0,
            last_logits: None,
        })
    }

    /// Device the model and all input tensors live on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Model width as read from `llama.embedding_length`.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Projects `hidden` through the model's LM head, converting it to F32 first
    /// when it has another dtype.
    pub fn apply_lm_head(&self, hidden: &Tensor) -> Result<Tensor> {
        let hidden_owned;
        let hidden_use: &Tensor = if hidden.dtype() != DType::F32 {
            hidden_owned = hidden.to_dtype(DType::F32).map_err(Error::Candle)?;
            &hidden_owned
        } else {
            hidden
        };
        self.model.apply_lm_head(hidden_use).map_err(Error::Candle)
    }

    /// Tokenizes `text` into token ids.
    pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Vec<u32>> {
        let enc = self
            .tokenizer
            .encode(text, add_special_tokens)
            .map_err(|e| Error::Tokenizer(e.to_string()))?;
        Ok(enc.get_ids().to_vec())
    }

    /// Turns token ids back into text.
    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        self.tokenizer
            .decode(ids, skip_special_tokens)
            .map_err(|e| Error::Tokenizer(e.to_string()))
    }

    /// Runs `tokens` through the model starting at position `cache_len` and
    /// returns the 2-D logits (batch axis removed, one row per input token).
    ///
    /// The last row is cached in `last_logits`. `history` is NOT updated; the
    /// callers do that.
    fn forward_advance_logits(&mut self, tokens: &[u32]) -> Result<Tensor> {
        if tokens.is_empty() {
            return Err(Error::Sampling("forward_advance with empty tokens".into()));
        }
        // Invalidate up front so a failure below never leaves logits that describe
        // a different cache state.
        self.last_logits = None;
        let input = Tensor::new(tokens, &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(Error::Candle)?;
        let hidden = self
            .model
            .forward_hidden(&input, self.cache_len)
            .map_err(Error::Candle)?;
        // The KV cache grew as soon as forward_hidden succeeded; account for it even
        // if the LM head below fails.
        self.cache_len += tokens.len();
        let logits = self
            .model
            .apply_lm_head(&hidden)
            .and_then(|l| l.i((0, .., ..)))
            .map_err(Error::Candle)?;
        let n_rows = logits.dim(0).map_err(Error::Candle)?;
        let last_row = logits.i((n_rows - 1, ..)).map_err(Error::Candle)?;
        self.last_logits = Some(self.row_to_vec(&last_row)?);
        Ok(logits)
    }

    /// Copies a 1-D tensor to the host as `Vec<f32>`, converting to F32 if needed.
    fn row_to_vec(&self, t: &Tensor) -> Result<Vec<f32>> {
        let t = if t.dtype() == DType::F32 {
            t.clone()
        } else {
            t.to_dtype(DType::F32).map_err(Error::Candle)?
        };
        t.to_vec1::<f32>().map_err(Error::Candle)
    }

    /// Copies a 2-D tensor to the host as one `Vec<f32>` per row, converting to
    /// F32 if needed. Uses a single `to_vec2` call instead of one copy per row.
    fn rows_to_vecs(&self, t: &Tensor) -> Result<Vec<Vec<f32>>> {
        let t = if t.dtype() == DType::F32 {
            t.clone()
        } else {
            t.to_dtype(DType::F32).map_err(Error::Candle)?
        };
        t.to_vec2::<f32>().map_err(Error::Candle)
    }

    /// Drops the last committed token from the KV cache so it can be re-fed, and
    /// returns that token. `caller` is used in the empty-history error message.
    fn rewind_last_token(&mut self, caller: &str) -> Result<u32> {
        let last = *self
            .history
            .last()
            .ok_or_else(|| Error::Sampling(format!("{caller} with empty history")))?;
        let target_len = self.history.len() - 1;
        self.model
            .truncate_kv_cache_to(target_len)
            .map_err(Error::Candle)?;
        self.cache_len = target_len;
        Ok(last)
    }

    /// Feeds `ids` through the model, appends them to `history`, and returns the
    /// `forward_hidden` output at the last id with batch and sequence axes
    /// indexed away (presumably shape `(hidden_size,)`).
    ///
    /// No logits are produced here, so the cached next-token logits are cleared.
    pub fn observe_returning_last_hidden(&mut self, ids: &[u32]) -> Result<Tensor> {
        if ids.is_empty() {
            return Err(Error::Sampling(
                "observe_returning_last_hidden with empty ids".into(),
            ));
        }
        // Any cached logits describe the previous end of history and would be
        // returned verbatim by next_logits otherwise.
        self.last_logits = None;
        let input = Tensor::new(ids, &self.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(Error::Candle)?;
        let hidden = self
            .model
            .forward_hidden(&input, self.cache_len)
            .map_err(Error::Candle)?;
        self.cache_len += ids.len();
        self.history.extend_from_slice(ids);
        let last_idx = hidden.dim(1).map_err(Error::Candle)? - 1;
        hidden.i((0, last_idx, ..)).map_err(Error::Candle)
    }

    /// Re-runs the last committed token and returns its `forward_hidden` output.
    ///
    /// The KV cache ends with the same length (`history.len()`) as before.
    pub fn last_hidden_state(&mut self) -> Result<Tensor> {
        let last = self.rewind_last_token("last_hidden_state")?;
        let input = Tensor::from_slice(&[last], (1, 1), &self.device).map_err(Error::Candle)?;
        let hidden = self
            .model
            .forward_hidden(&input, self.cache_len)
            .map_err(Error::Candle)?;
        self.cache_len += 1;
        hidden.i((0, 0, ..)).map_err(Error::Candle)
    }

    /// Like `last_hidden_state`, but also returns the outputs of the requested
    /// intermediate `layers` (in the order `forward_hidden_with_layers` yields them).
    pub fn last_hidden_states_multi(
        &mut self,
        layers: &[usize],
    ) -> Result<(Tensor, Vec<Tensor>)> {
        let last = self.rewind_last_token("last_hidden_states_multi")?;
        let input = Tensor::from_slice(&[last], (1, 1), &self.device).map_err(Error::Candle)?;
        let (final_h, mids) = self
            .model
            .forward_hidden_with_layers(&input, self.cache_len, layers)
            .map_err(Error::Candle)?;
        self.cache_len += 1;
        let mids_last = mids
            .into_iter()
            .map(|t| t.i((0, 0, ..)).map_err(Error::Candle))
            .collect::<Result<Vec<_>>>()?;
        let final_last = final_h.i((0, 0, ..)).map_err(Error::Candle)?;
        Ok((final_last, mids_last))
    }

    /// Number of transformer layers in the model.
    pub fn num_hidden_layers(&self) -> usize {
        self.model.num_hidden_layers()
    }

    /// Looks up token embeddings for `input_ids`.
    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.model.embed_tokens(input_ids).map_err(Error::Candle)
    }

    /// Shared tree pass used by `tree_logits` and `tree_logits_keep_kv`.
    ///
    /// Validates the tree root against the last committed token, rewinds the KV
    /// cache to just before that token, and runs every tree node in one forward
    /// pass using the tree's attention bias. Returns `(prefix_len, hidden,
    /// logits)`. `logits` is the LM-head output with the batch axis removed,
    /// one row per node. On success `cache_len` is `prefix_len + tree.len()`.
    fn forward_tree(
        &mut self,
        tree: &DraftTree,
        caller: &str,
    ) -> Result<(usize, Tensor, Tensor)> {
        let last_committed = *self
            .history
            .last()
            .ok_or_else(|| Error::Sampling(format!("{caller} with empty history")))?;
        if tree.token_at(0) != last_committed {
            return Err(Error::Sampling(format!(
                "tree root ({}) must equal last committed token ({})",
                tree.token_at(0),
                last_committed
            )));
        }
        debug_assert_eq!(self.cache_len, self.history.len());
        // Position of the tree root; commit_tree_path derives the same value from
        // history, so both sides must agree even if cache_len has drifted.
        let prefix_len = self.history.len() - 1;
        // The cache is about to be rewritten; cached next-token logits are stale
        // until a caller recomputes them.
        self.last_logits = None;
        self.model
            .truncate_kv_cache_to(prefix_len)
            .map_err(Error::Candle)?;
        self.cache_len = prefix_len;
        // Every node sits at prefix_len + its depth, so nodes at equal depth share
        // a position.
        let positions: Vec<u32> = (0..tree.len())
            .map(|i| (prefix_len + tree.depth_of(i)) as u32)
            .collect();
        let position_tensor =
            Tensor::from_vec(positions, (tree.len(),), &self.device).map_err(Error::Candle)?;
        let bias = tree.full_attention_bias_4d(prefix_len, 1, 1, &self.device, DType::F32)?;
        let input_ids = Tensor::from_slice(tree.tokens(), (1, tree.len()), &self.device)
            .map_err(Error::Candle)?;
        let hidden = self
            .model
            .forward_with_positions(&input_ids, &position_tensor, &bias)
            .map_err(Error::Candle)?;
        self.cache_len = prefix_len + tree.len();
        let logits = self
            .model
            .apply_lm_head(&hidden)
            .and_then(|l| l.i((0, .., ..)))
            .map_err(Error::Candle)?;
        Ok((prefix_len, hidden, logits))
    }

    /// Scores every node of `tree` in one forward pass and returns one logits row
    /// per node (in tree index order).
    ///
    /// The root (index 0) must equal the last committed token. Afterwards the
    /// tree's KV entries are dropped and the root is re-fed on its own, so the
    /// cache again holds `history.len()` entries. Row 0 is replaced by the
    /// logits of that single-token step, which are also what `next_logits`
    /// returns next.
    ///
    /// # Errors
    ///
    /// Fails on empty history, a root mismatch, or any model failure.
    pub fn tree_logits(&mut self, tree: &DraftTree) -> Result<Vec<Vec<f32>>> {
        let (prefix_len, _hidden, logits) = self.forward_tree(tree, "tree_logits")?;
        let mut out = self.rows_to_vecs(&logits)?;
        self.model
            .truncate_kv_cache_to(prefix_len)
            .map_err(Error::Candle)?;
        self.cache_len = prefix_len;
        self.forward_advance_logits(&[tree.token_at(0)])?;
        // forward_advance_logits cached the root's row; reuse it rather than copying
        // it off the device again.
        if let (Some(root_row), Some(first)) = (self.last_logits.as_ref(), out.first_mut()) {
            first.clone_from(root_row);
        }
        debug_assert_eq!(self.cache_len, self.history.len());
        Ok(out)
    }

    /// Like `tree_logits`, but keeps the tree's KV entries and also returns each
    /// node's hidden state.
    ///
    /// Leaves `cache_len` at `history.len() - 1 + tree.len()` with `history`
    /// unchanged. `commit_tree_path` expects exactly this state.
    ///
    /// # Errors
    ///
    /// Fails on empty history, a root mismatch, or any model failure.
    pub fn tree_logits_keep_kv(
        &mut self,
        tree: &DraftTree,
    ) -> Result<(Vec<Vec<f32>>, Vec<Tensor>)> {
        self.last_logits = None;
        let (_prefix_len, hidden, logits) = self.forward_tree(tree, "tree_logits_keep_kv")?;
        let out_logits = self.rows_to_vecs(&logits)?;
        let out_hidden = (0..tree.len())
            .map(|i| hidden.i((0, i, ..)).map_err(Error::Candle))
            .collect::<Result<Vec<_>>>()?;
        Ok((out_logits, out_hidden))
    }

    /// Commits the accepted nodes of the tree last scored by `tree_logits_keep_kv`.
    ///
    /// `accepted_indices` are tree node indices and must start with 0 (the root,
    /// already in history). NOTE(review): each index is presumably a child of the
    /// previous one; that is not checked here. The KV cache is compacted to the
    /// committed prefix plus the accepted nodes, and the non-root accepted tokens
    /// are appended to `history`.
    ///
    /// # Errors
    ///
    /// Fails on empty history, an empty path or one not starting at 0, a root
    /// mismatch, an index outside the tree, or a KV cache that does not hold the
    /// scored tree.
    pub fn commit_tree_path(
        &mut self,
        tree: &DraftTree,
        accepted_indices: &[usize],
    ) -> Result<()> {
        self.last_logits = None;
        let last_committed = *self.history.last().ok_or_else(|| {
            Error::Sampling("commit_tree_path with empty history".into())
        })?;
        if accepted_indices.first() != Some(&0) {
            return Err(Error::Sampling(
                "commit_tree_path: accepted path must be non-empty and start at the root (index 0)"
                    .into(),
            ));
        }
        if tree.token_at(0) != last_committed {
            return Err(Error::Sampling(format!(
                "tree root ({}) must equal last committed token ({})",
                tree.token_at(0),
                last_committed
            )));
        }
        if let Some(&bad) = accepted_indices.iter().find(|&&ti| ti >= tree.len()) {
            return Err(Error::Sampling(format!(
                "commit_tree_path: accepted index {bad} out of range for tree of {} nodes",
                tree.len()
            )));
        }
        let prefix_len = self.history.len() - 1;
        let expected_cache = prefix_len + tree.len();
        // keep_kv_indices must not index past the cache, so require the tree pass
        // to still be in place.
        if self.cache_len != expected_cache {
            return Err(Error::Sampling(format!(
                "commit_tree_path: KV cache holds {} positions but {expected_cache} are expected; \
                 call tree_logits_keep_kv with this tree first",
                self.cache_len
            )));
        }
        let keep: Vec<u32> = (0..prefix_len)
            .chain(accepted_indices.iter().map(|&ti| prefix_len + ti))
            .map(|pos| pos as u32)
            .collect();
        self.model.keep_kv_indices(&keep).map_err(Error::Candle)?;
        self.cache_len = keep.len();
        self.history
            .extend(accepted_indices.iter().skip(1).map(|&ti| tree.token_at(ti)));
        Ok(())
    }
}
/// `TreeDecoder` support: every method forwards to the inherent method of the
/// same name on `LlamaQuantDecoder`.
impl TreeDecoder for LlamaQuantDecoder {
    fn num_hidden_layers(&self) -> usize {
        Self::num_hidden_layers(self)
    }

    fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        Self::embed_tokens(self, input_ids)
    }

    fn apply_lm_head(&self, hidden: &Tensor) -> Result<Tensor> {
        Self::apply_lm_head(self, hidden)
    }

    fn last_hidden_state(&mut self) -> Result<Tensor> {
        Self::last_hidden_state(self)
    }

    fn last_hidden_states_multi(
        &mut self,
        layers: &[usize],
    ) -> Result<(Tensor, Vec<Tensor>)> {
        Self::last_hidden_states_multi(self, layers)
    }

    fn observe_returning_last_hidden(&mut self, ids: &[u32]) -> Result<Tensor> {
        Self::observe_returning_last_hidden(self, ids)
    }

    fn tree_logits(&mut self, tree: &DraftTree) -> Result<Vec<Vec<f32>>> {
        Self::tree_logits(self, tree)
    }

    fn tree_logits_keep_kv(
        &mut self,
        tree: &DraftTree,
    ) -> Result<(Vec<Vec<f32>>, Vec<Tensor>)> {
        Self::tree_logits_keep_kv(self, tree)
    }

    fn commit_tree_path(
        &mut self,
        tree: &DraftTree,
        accepted_indices: &[usize],
    ) -> Result<()> {
        Self::commit_tree_path(self, tree, accepted_indices)
    }
}
impl Decoder for LlamaQuantDecoder {
fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Vec<u32>> {
LlamaQuantDecoder::encode(self, text, add_special_tokens)
}
fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
LlamaQuantDecoder::decode(self, ids, skip_special_tokens)
}
fn eos_token_ids(&self) -> Vec<u32> {
self.eos_token_ids.clone()
}
fn vocab_size(&self) -> usize {
self.vocab_size
}
fn history(&self) -> &[u32] {
&self.history
}
fn reset(&mut self) {
self.history.clear();
self.model.clear_kv_cache();
self.cache_len = 0;
self.last_logits = None;
}
fn observe(&mut self, ids: &[u32]) -> Result<()> {
if ids.is_empty() {
return Ok(());
}
let _ = self.forward_advance_logits(ids)?;
self.history.extend_from_slice(ids);
Ok(())
}
fn next_logits(&mut self) -> Result<Vec<f32>> {
if self.history.is_empty() {
return Err(Error::Sampling(
"next_logits called with empty history".into(),
));
}
if let Some(cached) = &self.last_logits {
return Ok(cached.clone());
}
let last = *self.history.last().unwrap();
let target_len = self.history.len() - 1;
self.model
.truncate_kv_cache_to(target_len)
.map_err(Error::Candle)?;
self.cache_len = target_len;
let logits = self.forward_advance_logits(&[last])?;
let last_row = logits
.i((logits.dim(0).map_err(Error::Candle)? - 1, ..))
.map_err(Error::Candle)?;
self.row_to_vec(&last_row)
}
fn batched_logits(&mut self, drafts: &[u32]) -> Result<Vec<Vec<f32>>> {
if drafts.is_empty() {
let logits = self.next_logits()?;
return Ok(vec![logits]);
}
if self.history.is_empty() {
return Err(Error::Sampling("batched_logits with empty history".into()));
}
let last = *self.history.last().unwrap();
let target_len = self.history.len() - 1;
self.model
.truncate_kv_cache_to(target_len)
.map_err(Error::Candle)?;
self.cache_len = target_len;
let mut combined: Vec<u32> = Vec::with_capacity(1 + drafts.len());
combined.push(last);
combined.extend_from_slice(drafts);
let logits = self.forward_advance_logits(&combined)?;
let n_rows = logits.dim(0).map_err(Error::Candle)?;
debug_assert_eq!(n_rows, drafts.len() + 1);
let mut out = Vec::with_capacity(n_rows);
for i in 0..n_rows {
let row = logits.i((i, ..)).map_err(Error::Candle)?;
out.push(self.row_to_vec(&row)?);
}
self.history.extend_from_slice(drafts);
Ok(out)
}
fn rollback_to(&mut self, len: usize) -> Result<()> {
if len > self.history.len() {
return Err(Error::CacheRollback(format!(
"rollback target {len} > history length {}",
self.history.len()
)));
}
self.history.truncate(len);
self.last_logits = None;
self.model
.truncate_kv_cache_to(len)
.map_err(Error::Candle)?;
self.cache_len = len;
Ok(())
}
}