#![allow(clippy::needless_range_loop)]
use crate::{Error, Result};
use candle_core::{DType, Device, Tensor};
/// A rooted tree of `u32` tokens stored as two parallel vectors indexed by node.
///
/// Node 0 is the root. The two constructors visible in this file
/// (`from_parent_table` and `linear`) always store at least the root, record
/// the root's parent as 0, and give every other node a parent whose index is
/// strictly smaller than the node's own index.
#[derive(Debug, Clone)]
pub struct DraftTree {
/// Token carried by each node; `tokens[0]` belongs to the root.
tokens: Vec<u32>,
/// Parent index of each node; `parents[0]` is 0 (the root points at itself).
parents: Vec<usize>,
}
impl DraftTree {
pub fn from_parent_table(nodes: &[(usize, u32)]) -> Result<Self> {
if nodes.is_empty() {
return Err(Error::Sampling(
"DraftTree must have at least a root".into(),
));
}
let mut tokens = Vec::with_capacity(nodes.len());
let mut parents = Vec::with_capacity(nodes.len());
tokens.push(nodes[0].1);
parents.push(0);
for (i, &(p, tok)) in nodes.iter().enumerate().skip(1) {
if p >= i {
return Err(Error::Sampling(format!(
"node {i} has parent index {p}, which is not strictly smaller",
)));
}
tokens.push(tok);
parents.push(p);
}
Ok(Self { tokens, parents })
}
pub fn linear(root: u32, tail: &[u32]) -> Self {
let mut tokens = Vec::with_capacity(tail.len() + 1);
let mut parents = Vec::with_capacity(tail.len() + 1);
tokens.push(root);
parents.push(0);
for (i, &t) in tail.iter().enumerate() {
tokens.push(t);
parents.push(i); }
Self { tokens, parents }
}
pub fn len(&self) -> usize {
self.tokens.len()
}
pub fn is_empty(&self) -> bool {
self.tokens.len() <= 1
}
pub fn token_at(&self, i: usize) -> u32 {
self.tokens[i]
}
pub fn tokens(&self) -> &[u32] {
&self.tokens
}
pub fn parent_of(&self, i: usize) -> usize {
self.parents[i]
}
pub fn ancestors(&self, mut i: usize) -> Vec<usize> {
let mut out = vec![i];
while i != 0 {
i = self.parents[i];
out.push(i);
}
out
}
pub fn depth_of(&self, i: usize) -> usize {
let mut d = 0;
let mut cur = i;
while cur != 0 {
cur = self.parents[cur];
d += 1;
}
d
}
pub fn position_ids(&self, prefix_len: usize) -> Vec<usize> {
(0..self.len())
.map(|i| prefix_len + self.depth_of(i))
.collect()
}
pub fn attention_mask_bool(&self) -> Vec<Vec<bool>> {
let n = self.len();
let mut mask = vec![vec![false; n]; n];
for i in 0..n {
for j in self.ancestors(i) {
mask[i][j] = true;
}
}
mask
}
pub fn paths(&self) -> Vec<Vec<usize>> {
let mut is_leaf = vec![true; self.len()];
for &p in self.parents.iter().skip(1) {
is_leaf[p] = false;
}
let mut out = Vec::new();
for (i, &leaf) in is_leaf.iter().enumerate() {
if leaf {
let mut chain = self.ancestors(i);
chain.reverse(); out.push(chain);
}
}
out
}
pub fn path_to(&self, target: usize) -> Vec<usize> {
let mut chain = self.ancestors(target);
chain.reverse();
chain
}
pub fn tree_self_bias(&self, device: &Device, dtype: DType) -> Result<Tensor> {
let n = self.len();
let mut data = vec![0f32; n * n];
for i in 0..n {
for j in 0..n {
let allowed = self.is_ancestor_of(j, i);
if !allowed {
data[i * n + j] = f32::NEG_INFINITY;
}
}
}
let t = Tensor::from_slice(&data, (n, n), device).map_err(Error::Candle)?;
if dtype != DType::F32 {
t.to_dtype(dtype).map_err(Error::Candle)
} else {
Ok(t)
}
}
pub fn full_attention_bias(
&self,
prefix_len: usize,
device: &Device,
dtype: DType,
) -> Result<Tensor> {
let n = self.len();
let total = prefix_len + n;
let mut data = vec![0f32; n * total];
for i in 0..n {
for j in 0..n {
if !self.is_ancestor_of(j, i) {
data[i * total + prefix_len + j] = f32::NEG_INFINITY;
}
}
}
let t = Tensor::from_slice(&data, (n, total), device).map_err(Error::Candle)?;
if dtype != DType::F32 {
t.to_dtype(dtype).map_err(Error::Candle)
} else {
Ok(t)
}
}
pub fn full_attention_bias_4d(
&self,
prefix_len: usize,
batch: usize,
head_dim_size: usize,
device: &Device,
dtype: DType,
) -> Result<Tensor> {
let bias = self.full_attention_bias(prefix_len, device, dtype)?;
let n = self.len();
bias.reshape((1, 1, n, prefix_len + n))
.and_then(|t| t.expand((batch, head_dim_size, n, prefix_len + n)))
.map_err(Error::Candle)
}
fn is_ancestor_of(&self, ancestor_idx: usize, node_idx: usize) -> bool {
let mut cur = node_idx;
if cur == ancestor_idx {
return true;
}
while cur != 0 {
cur = self.parents[cur];
if cur == ancestor_idx {
return true;
}
}
ancestor_idx == 0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: root 0 with children 1 and 2; node 1 has children 3
    /// and 4; node 2 has child 5.
    fn branching() -> DraftTree {
        let table = [(0, 100), (0, 11), (0, 12), (1, 23), (1, 24), (2, 35)];
        DraftTree::from_parent_table(&table).unwrap()
    }

    /// Every `(row, col)` pair of a `side x side` square, row-major.
    fn square(side: usize) -> impl Iterator<Item = (usize, usize)> {
        (0..side).flat_map(move |i| (0..side).map(move |j| (i, j)))
    }

    // A linear tree keeps the tokens in order and chains each node to the
    // one before it.
    #[test]
    fn linear_tree_is_a_chain() {
        let chain = DraftTree::linear(10, &[20, 30, 40]);
        assert_eq!(chain.len(), 4);
        assert_eq!(chain.tokens(), &[10, 20, 30, 40]);
        for (node, &parent) in [0usize, 0, 1, 2].iter().enumerate() {
            assert_eq!(chain.parent_of(node), parent);
        }
        assert_eq!(chain.depth_of(3), 3);
        assert_eq!(chain.paths(), vec![vec![0, 1, 2, 3]]);
    }

    // The branching fixture has three leaves, each with its own root path.
    #[test]
    fn branching_tree_paths() {
        let tree = branching();
        assert_eq!(tree.len(), 6);
        assert_eq!(tree.depth_of(3), 2);
        assert_eq!(tree.depth_of(5), 2);

        let mut leaf_paths = tree.paths();
        // Order by length first, then lexicographically.
        leaf_paths.sort_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));
        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 3], vec![0, 1, 4], vec![0, 2, 5]];
        assert_eq!(leaf_paths, expected);
    }

    // On a chain the ancestor mask is exactly the causal (lower-triangular) mask.
    #[test]
    fn linear_mask_is_lower_triangular() {
        let mask = DraftTree::linear(10, &[20, 30, 40]).attention_mask_bool();
        for (i, j) in square(4) {
            assert_eq!(mask[i][j], j <= i, "expected causal at ({i},{j})");
        }
    }

    // Nodes see themselves and their ancestors, never other branches.
    #[test]
    fn branching_mask_blocks_siblings() {
        let mask = branching().attention_mask_bool();

        let row3 = &mask[3];
        assert!(row3[0] && row3[1] && row3[3]);
        assert!(!row3[2], "node 3 must NOT see sibling-of-parent (2)");
        assert!(!row3[4], "node 3 must NOT see sibling (4)");
        assert!(!row3[5], "node 3 must NOT see other-branch leaf (5)");

        let row5 = &mask[5];
        assert!(row5[0] && row5[2] && row5[5]);
        assert!(!(row5[1] || row5[3] || row5[4]));
    }

    // Position ids are the node depth shifted by the prefix length.
    #[test]
    fn position_ids_offset_by_prefix() {
        let ids = DraftTree::linear(0, &[1, 2, 3]).position_ids(7);
        let expected: Vec<usize> = (7..=10).collect();
        assert_eq!(ids, expected);
    }

    // A parent index that is not smaller than the node's own index is rejected.
    #[test]
    fn rejects_forward_parent_reference() {
        let table = [(0, 0u32), (5, 1)];
        let outcome = DraftTree::from_parent_table(&table);
        assert!(outcome.is_err());
    }

    // A table without even a root is rejected.
    #[test]
    fn rejects_empty_tree() {
        let no_nodes: &[(usize, u32)] = &[];
        assert!(DraftTree::from_parent_table(no_nodes).is_err());
    }

    // `path_to` lists indices from the root down to the target.
    #[test]
    fn path_to_walks_root_first() {
        let table = [(0, 0), (0, 1), (1, 2), (2, 3)];
        let tree = DraftTree::from_parent_table(&table).unwrap();
        let expected: Vec<usize> = (0..=3).collect();
        assert_eq!(tree.path_to(3), expected);
    }

    // On a chain the additive bias is 0 on/below the diagonal and -inf above it.
    #[test]
    fn tree_self_bias_linear_is_lower_triangular() {
        let tree = DraftTree::linear(0, &[1, 2, 3]);
        let v = tree
            .tree_self_bias(&Device::Cpu, DType::F32)
            .unwrap()
            .to_vec2::<f32>()
            .unwrap();
        for (i, j) in square(4) {
            if j > i {
                assert!(v[i][j] == f32::NEG_INFINITY);
            } else {
                assert_eq!(v[i][j], 0.0, "expected 0 at allowed ({i},{j})");
            }
        }
    }

    // The additive bias opens exactly the ancestor columns for nodes 3 and 5.
    #[test]
    fn tree_self_bias_branching_blocks_siblings() {
        let v = branching()
            .tree_self_bias(&Device::Cpu, DType::F32)
            .unwrap()
            .to_vec2::<f32>()
            .unwrap();

        for j in 0..6 {
            match j {
                0 | 1 | 3 => assert_eq!(v[3][j], 0.0, "node 3 should attend to {j}"),
                _ => assert!(
                    v[3][j] == f32::NEG_INFINITY,
                    "node 3 should NOT attend to {j}"
                ),
            }
        }

        for j in 0..6 {
            match j {
                0 | 2 | 5 => assert_eq!(v[5][j], 0.0, "node 5 should attend to {j}"),
                _ => assert!(v[5][j] == f32::NEG_INFINITY),
            }
        }
    }

    // Prefix columns are never masked; the tree columns stay causal on a chain.
    #[test]
    fn full_attention_bias_keeps_prefix_unmasked() {
        const PREFIX: usize = 5;
        let tree = DraftTree::linear(0, &[1, 2]);
        let bias = tree
            .full_attention_bias(PREFIX, &Device::Cpu, DType::F32)
            .unwrap();
        assert_eq!(bias.dims(), &[3, PREFIX + 3]);

        let v = bias.to_vec2::<f32>().unwrap();
        for i in 0..3 {
            for j in 0..PREFIX {
                assert_eq!(v[i][j], 0.0, "prefix col {j} for tree row {i}");
            }
        }
        for (i, j) in square(3) {
            let v_ij = v[i][PREFIX + j];
            if j > i {
                assert!(v_ij == f32::NEG_INFINITY);
            } else {
                assert_eq!(v_ij, 0.0);
            }
        }
    }

    // With batch = heads = 1 the 4-D bias is the 2-D bias with two unit axes.
    #[test]
    fn full_attention_bias_4d_has_expected_shape() {
        let tree = DraftTree::linear(0, &[1, 2, 3]);
        let dims = tree
            .full_attention_bias_4d(7, 1, 1, &Device::Cpu, DType::F32)
            .unwrap()
            .dims()
            .to_vec();
        assert_eq!(dims, vec![1, 1, 4, 7 + 4]);
    }

    // The leading two axes take the requested batch and head extents.
    #[test]
    fn full_attention_bias_4d_broadcasts_to_heads() {
        let tree = DraftTree::linear(0, &[1, 2]);
        let dims = tree
            .full_attention_bias_4d(0, 2, 4, &Device::Cpu, DType::F32)
            .unwrap()
            .dims()
            .to_vec();
        assert_eq!(dims, vec![2, 4, 3, 3]);
    }

    // Asking for a non-f32 dtype yields a tensor of that dtype.
    #[test]
    fn tree_self_bias_dtype_conversion() {
        let tree = DraftTree::linear(0, &[1, 2]);
        let converted = tree.tree_self_bias(&Device::Cpu, DType::F16).unwrap();
        assert_eq!(converted.dtype(), DType::F16);
    }
}