use super::types::TernaryTensor;
use crate::error::{Result, UnslothError};
use candle_core::Tensor;
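/// Tuning parameters for the ternary attention-score kernels: K-dimension tile width,
/// CUDA block size, and number of score outputs computed per thread.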
#[derive(Debug, Clone)]
pub struct TernaryAttentionScoreConfig {
pub tile_k: u32,
pub block_size: u32,
pub outputs_per_thread: u32,
}
impl TernaryAttentionScoreConfig {
#[must_use]
pub fn rtx_5080() -> Self {
Self {
tile_k: 64,
block_size: 256,
outputs_per_thread: 2,
}
}
#[must_use]
pub fn rtx_3090_ti() -> Self {
Self {
tile_k: 32,
block_size: 256,
outputs_per_thread: 2,
}
}
#[must_use]
pub fn default_config() -> Self {
Self::rtx_3090_ti()
}
}
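/// CUDA grid and block dimensions derived from a `TernaryAttentionScoreConfig` and the problem size.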
#[derive(Debug, Clone)]
pub struct AttentionScoreLaunchConfig {
pub grid_dim: (u32, u32, u32),
pub block_dim: (u32, u32, u32),
}
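/// Maps the problem size onto a launch grid: one x-slot per (batch, head) pair and enough
/// y-blocks to cover every query position, with each block handling
/// `block_size * outputs_per_thread` positions.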
#[must_use]
pub fn get_attention_score_launch_config(
batch_size: usize,
num_heads: usize,
q_seq_len: usize,
config: &TernaryAttentionScoreConfig,
) -> AttentionScoreLaunchConfig {
let outputs_per_block = config.block_size * config.outputs_per_thread;
let blocks_per_query_dim =
((q_seq_len as u32 + outputs_per_block - 1) / outputs_per_block).max(1);
let grid_x = (batch_size * num_heads) as u32;
let grid_y = blocks_per_query_dim;
AttentionScoreLaunchConfig {
grid_dim: (grid_x, grid_y, 1),
block_dim: (config.block_size, 1, 1),
}
}
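/// CPU reference implementation of the ternary attention-score kernel: for every
/// (batch, head, query, key) pair it computes the bit-plane dot product of the query with the
/// ternary key row, scales it, and returns a `[batch, heads, q_len, k_len]` score tensor.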
pub fn ternary_attention_score_kernel_cpu(
q: &Tensor,
k_ternary: &TernaryTensor,
k_seq_len: usize,
scale: f64,
_config: &TernaryAttentionScoreConfig,
) -> Result<Tensor> {
let q_dims = q.dims();
if q_dims.len() != 4 {
return Err(UnslothError::ShapeMismatch {
expected: vec![4],
actual: q_dims.to_vec(),
});
}
let (batch, num_heads, q_seq_len, head_dim) = (q_dims[0], q_dims[1], q_dims[2], q_dims[3]);
let (k_out_features, k_in_features) = k_ternary.shape();
let expected_k_shape = (num_heads * k_seq_len, head_dim);
if (k_out_features, k_in_features) != expected_k_shape {
return Err(UnslothError::ShapeMismatch {
expected: vec![expected_k_shape.0, expected_k_shape.1],
actual: vec![k_out_features, k_in_features],
});
}
let mut scores_data = vec![0.0f32; batch * num_heads * q_seq_len * k_seq_len];
let q_flat: Vec<f32> = q.flatten_all()?.to_vec1::<f32>()?;
for b in 0..batch {
for h in 0..num_heads {
for qi in 0..q_seq_len {
let q_offset = ((b * num_heads + h) * q_seq_len + qi) * head_dim;
let q_vec = &q_flat[q_offset..q_offset + head_dim];
for ki in 0..k_seq_len {
let k_row = h * k_seq_len + ki;
let score = ternary_dot_product_fp_query(q_vec, k_ternary, k_row, head_dim);
let output_idx = ((b * num_heads + h) * q_seq_len + qi) * k_seq_len + ki;
scores_data[output_idx] = score * (scale as f32);
}
}
}
}
let scores = Tensor::from_vec(
scores_data,
(batch, num_heads, q_seq_len, k_seq_len),
q.device(),
)?;
Ok(scores)
}
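/// Dot product between a full-precision query vector and one ternary key row.
/// The query is first binarized with a fixed +/-0.5 threshold into positive/negative bit
/// planes, then matched against the key's planes with AND + popcount; the signed popcount
/// difference is scaled by the key row's per-row scale.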
fn ternary_dot_product_fp_query(
q_vec: &[f32],
k_ternary: &TernaryTensor,
k_row: usize,
head_dim: usize,
) -> f32 {
const THRESHOLD: f32 = 0.5;
let k_words = (head_dim + 31) / 32;
let mut q_plus = vec![0u32; k_words];
let mut q_minus = vec![0u32; k_words];
for (dim_idx, &val) in q_vec.iter().enumerate().take(head_dim) {
let word_idx = dim_idx / 32;
let bit_idx = (dim_idx % 32) as u32;
if val > THRESHOLD {
q_plus[word_idx] |= 1u32 << bit_idx;
} else if val < -THRESHOLD {
q_minus[word_idx] |= 1u32 << bit_idx;
}
}
let k_plus = k_ternary.plus_plane();
let k_minus = k_ternary.minus_plane();
let k_scale = k_ternary.scale(k_row);
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
for i in 0..k_words {
let k_plus_word = k_plus[k_row * k_words + i];
let k_minus_word = k_minus[k_row * k_words + i];
pos_sum += (q_plus[i] & k_plus_word).count_ones();
pos_sum += (q_minus[i] & k_minus_word).count_ones();
neg_sum += (q_plus[i] & k_minus_word).count_ones();
neg_sum += (q_minus[i] & k_plus_word).count_ones();
}
let dot = (pos_sum as i32 - neg_sum as i32) as f32;
dot * k_scale
}
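/// GPU entry point for ternary attention scores; currently delegates to the CPU reference path.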
pub fn ternary_attention_score_cuda(
q: &Tensor,
k_ternary: &TernaryTensor,
k_seq_len: usize,
scale: f64,
config: &TernaryAttentionScoreConfig,
) -> Result<Tensor> {
ternary_attention_score_kernel_cpu(q, k_ternary, k_seq_len, scale, config)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
pub(super) fn create_test_query(batch: usize, heads: usize, seq_len: usize, dim: usize) -> Tensor {
let size = batch * heads * seq_len * dim;
let data: Vec<f32> = (0..size).map(|i| (i % 10) as f32 * 0.1).collect();
Tensor::from_vec(data, (batch, heads, seq_len, dim), &Device::Cpu).unwrap()
}
pub(super) fn create_test_ternary_keys(heads: usize, seq_len: usize, dim: usize) -> TernaryTensor {
let out_features = heads * seq_len;
let k_words = (dim + 31) / 32;
let mut plus = vec![0u32; out_features * k_words];
let mut minus = vec![0u32; out_features * k_words];
for i in 0..out_features {
if i % 3 == 0 {
plus[i * k_words] = 0xFF00FF00;
} else if i % 3 == 1 {
minus[i * k_words] = 0x00FF00FF;
}
}
let scales = vec![1.0f32; out_features];
let shape = (out_features, dim);
TernaryTensor::new(plus, minus, scales, shape)
}
#[test]
fn test_attention_score_config_creation() {
let config = TernaryAttentionScoreConfig::rtx_5080();
assert_eq!(config.tile_k, 64);
assert_eq!(config.block_size, 256);
assert_eq!(config.outputs_per_thread, 2);
let config = TernaryAttentionScoreConfig::rtx_3090_ti();
assert_eq!(config.tile_k, 32);
}
#[test]
fn test_attention_score_launch_config() {
let config = TernaryAttentionScoreConfig::default_config();
let launch = get_attention_score_launch_config(2, 8, 128, &config);
assert_eq!(launch.grid_dim.0, 2 * 8);
assert_eq!(launch.block_dim.0, config.block_size);
}
#[test]
fn test_ternary_attention_score_kernel_shape() {
let batch = 2;
let heads = 4;
let q_seq = 8;
let k_seq = 8;
let dim = 64;
let q = create_test_query(batch, heads, q_seq, dim);
let k = create_test_ternary_keys(heads, k_seq, dim);
let config = TernaryAttentionScoreConfig::default_config();
let scores = ternary_attention_score_kernel_cpu(&q, &k, k_seq, 1.0, &config).unwrap();
assert_eq!(scores.dims(), &[batch, heads, q_seq, k_seq]);
}
#[test]
fn test_ternary_attention_score_kernel_numerical() {
let batch = 1;
let heads = 1;
let q_seq = 2;
let k_seq = 2;
let dim = 32;
let mut q_data = vec![0.0f32; batch * heads * q_seq * dim];
q_data[0] = 1.0;
q_data[2] = -1.0;
q_data[dim + 1] = 1.0;
q_data[dim + 3] = -1.0;
let q = Tensor::from_vec(q_data, (batch, heads, q_seq, dim), &Device::Cpu).unwrap();
let k_words = 1;
let mut plus = vec![0u32; heads * k_seq * k_words];
let mut minus = vec![0u32; heads * k_seq * k_words];
plus[0] = 1u32;
plus[1] = 2u32;
let scales = vec![1.0f32; heads * k_seq];
let k = TernaryTensor::new(plus, minus, scales, (heads * k_seq, dim));
let config = TernaryAttentionScoreConfig::default_config();
let scores = ternary_attention_score_kernel_cpu(&q, &k, k_seq, 1.0, &config).unwrap();
let scores_flat: Vec<f32> = scores.flatten_all().unwrap().to_vec1::<f32>().unwrap();
// Row-major layout over [batch=1, heads=1, q_seq, k_seq].
let score_q0_k0 = scores_flat[0];
let score_q0_k1 = scores_flat[1];
assert!(score_q0_k0 > score_q0_k1);
let score_q1_k0 = scores_flat[k_seq];
let score_q1_k1 = scores_flat[k_seq + 1];
assert!(score_q1_k1 > score_q1_k0);
}
#[test]
fn test_ternary_dot_product_simple() {
let dim = 32;
let mut q = vec![0.0f32; dim];
q[0] = 1.0;
q[2] = -1.0;
let plus = vec![1u32];
let minus = vec![4u32];
let scales = vec![1.0f32];
let k = TernaryTensor::new(plus, minus, scales, (1, dim));
let score = ternary_dot_product_fp_query(&q, &k, 0, dim);
assert!((score - 2.0).abs() < 0.001, "Expected 2.0, got {}", score);
}
#[test]
fn test_ternary_attention_score_cuda_fallback() {
let batch = 1;
let heads = 2;
let q_seq = 4;
let k_seq = 4;
let dim = 64;
let q = create_test_query(batch, heads, q_seq, dim);
let k = create_test_ternary_keys(heads, k_seq, dim);
let config = TernaryAttentionScoreConfig::default_config();
let scores = ternary_attention_score_cuda(&q, &k, k_seq, 1.0, &config).unwrap();
assert_eq!(scores.dims(), &[batch, heads, q_seq, k_seq]);
}
}
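/// Running state for a streaming (online) softmax: tracks the running maximum and the
/// rescaled sum of exponentials so scores can be normalized in a single pass over the keys.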
#[derive(Debug, Clone)]
struct OnlineSoftmaxState {
max_score: f32,
exp_sum: f32,
}
impl OnlineSoftmaxState {
#[must_use]
fn new() -> Self {
Self {
max_score: f32::NEG_INFINITY,
exp_sum: 0.0,
}
}
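/// Folds one score into the running softmax. Returns `(running_max, correction, weight)`,
/// where `correction` rescales previously accumulated values when the maximum changes and
/// `weight` is the un-normalized softmax weight of the new score.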
fn update(&mut self, score: f32) -> (f32, f32, f32) {
let old_max = self.max_score;
if score > self.max_score {
self.max_score = score;
}
// No correction is needed on the first update or when the running max is unchanged.
let correction = if old_max.is_finite() && self.max_score != old_max {
(old_max - self.max_score).exp()
} else {
1.0
};
let weight = (score - self.max_score).exp();
self.exp_sum = self.exp_sum * correction + weight;
(self.max_score, correction, weight)
}
#[must_use]
fn get_norm_factor(&self) -> f32 {
if self.exp_sum > 0.0 {
1.0 / self.exp_sum
} else {
0.0
}
}
}
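/// CPU reference for fused ternary attention: scores each key with the bit-plane dot product,
/// folds it through the online softmax, and accumulates the weighted values, so the full
/// score matrix is never materialized.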
pub fn ternary_attention_online_softmax_cpu(
q: &Tensor,
k_ternary: &TernaryTensor,
v: &Tensor,
k_seq_len: usize,
scale: f64,
_config: &TernaryAttentionScoreConfig,
) -> Result<Tensor> {
let q_dims = q.dims();
if q_dims.len() != 4 {
return Err(UnslothError::ShapeMismatch {
expected: vec![4],
actual: q_dims.to_vec(),
});
}
let (batch, num_heads, q_seq_len, head_dim) = (q_dims[0], q_dims[1], q_dims[2], q_dims[3]);
let v_dims = v.dims();
if v_dims != [batch, num_heads, k_seq_len, head_dim] {
return Err(UnslothError::ShapeMismatch {
expected: vec![batch, num_heads, k_seq_len, head_dim],
actual: v_dims.to_vec(),
});
}
let (k_out_features, k_in_features) = k_ternary.shape();
let expected_k_shape = (num_heads * k_seq_len, head_dim);
if (k_out_features, k_in_features) != expected_k_shape {
return Err(UnslothError::ShapeMismatch {
expected: vec![expected_k_shape.0, expected_k_shape.1],
actual: vec![k_out_features, k_in_features],
});
}
let mut output_data = vec![0.0f32; batch * num_heads * q_seq_len * head_dim];
let q_flat: Vec<f32> = q.flatten_all()?.to_vec1::<f32>()?;
let v_flat: Vec<f32> = v.flatten_all()?.to_vec1::<f32>()?;
for b in 0..batch {
for h in 0..num_heads {
for qi in 0..q_seq_len {
let mut softmax_state = OnlineSoftmaxState::new();
let mut output_accum = vec![0.0f32; head_dim];
let q_offset = ((b * num_heads + h) * q_seq_len + qi) * head_dim;
let q_vec = &q_flat[q_offset..q_offset + head_dim];
for ki in 0..k_seq_len {
let k_row = h * k_seq_len + ki;
let score = ternary_dot_product_fp_query(q_vec, k_ternary, k_row, head_dim);
let scaled_score = score * (scale as f32);
let (_max, correction, weight) = softmax_state.update(scaled_score);
if correction != 1.0 {
for val in &mut output_accum {
*val *= correction;
}
}
let v_offset = ((b * num_heads + h) * k_seq_len + ki) * head_dim;
let v_vec = &v_flat[v_offset..v_offset + head_dim];
for (j, v_val) in v_vec.iter().enumerate() {
output_accum[j] += weight * v_val;
}
}
let norm_factor = softmax_state.get_norm_factor();
for val in &mut output_accum {
*val *= norm_factor;
}
let out_offset = ((b * num_heads + h) * q_seq_len + qi) * head_dim;
output_data[out_offset..out_offset + head_dim].copy_from_slice(&output_accum);
}
}
}
let output = Tensor::from_vec(
output_data,
(batch, num_heads, q_seq_len, head_dim),
q.device(),
)?;
Ok(output)
}
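/// GPU entry point for fused ternary attention; currently delegates to the CPU online-softmax path.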
pub fn ternary_attention_cuda(
q: &Tensor,
k_ternary: &TernaryTensor,
v: &Tensor,
k_seq_len: usize,
scale: f64,
config: &TernaryAttentionScoreConfig,
) -> Result<Tensor> {
ternary_attention_online_softmax_cpu(q, k_ternary, v, k_seq_len, scale, config)
}
#[cfg(test)]
mod online_softmax_tests {
use super::*;
use candle_core::Device;
pub(super) fn create_test_values(batch: usize, heads: usize, seq_len: usize, dim: usize) -> Tensor {
let size = batch * heads * seq_len * dim;
let data: Vec<f32> = (0..size).map(|i| ((i % 7) as f32 + 1.0) * 0.1).collect();
Tensor::from_vec(data, (batch, heads, seq_len, dim), &Device::Cpu).unwrap()
}
#[test]
fn test_online_softmax_state_update() {
let mut state = OnlineSoftmaxState::new();
let (max1, corr1, weight1) = state.update(1.0);
assert_eq!(max1, 1.0);
assert_eq!(corr1, 1.0);
assert!((weight1 - 1.0).abs() < 0.001);
let (max2, corr2, weight2) = state.update(2.0);
assert_eq!(max2, 2.0);
assert!(corr2 < 1.0);
assert!((weight2 - 1.0).abs() < 0.001);
let (max3, corr3, weight3) = state.update(0.5);
assert_eq!(max3, 2.0);
assert_eq!(corr3, 1.0);
assert!(weight3 < weight2);
}
#[test]
fn test_online_softmax_normalization() {
let mut state = OnlineSoftmaxState::new();
state.update(1.0);
state.update(2.0);
state.update(3.0);
let norm = state.get_norm_factor();
assert!(norm > 0.0);
assert!(norm < 1.0);
}
#[test]
fn test_ternary_attention_online_softmax_shape() {
let batch = 2;
let heads = 4;
let q_seq = 8;
let k_seq = 8;
let dim = 64;
let q = super::tests::create_test_query(batch, heads, q_seq, dim);
let k = super::tests::create_test_ternary_keys(heads, k_seq, dim);
let v = create_test_values(batch, heads, k_seq, dim);
let config = TernaryAttentionScoreConfig::default_config();
let output = ternary_attention_online_softmax_cpu(&q, &k, &v, k_seq, 1.0, &config).unwrap();
assert_eq!(output.dims(), &[batch, heads, q_seq, dim]);
}
#[test]
fn test_ternary_attention_online_softmax_numerical() {
let batch = 1;
let heads = 1;
let q_seq = 1;
let k_seq = 3;
let dim = 32;
let mut q_data = vec![0.0f32; batch * heads * q_seq * dim];
q_data[0] = 1.0;
let q = Tensor::from_vec(q_data, (batch, heads, q_seq, dim), &Device::Cpu).unwrap();
let k_words = 1;
let mut plus = vec![0u32; heads * k_seq * k_words];
let mut minus = vec![0u32; heads * k_seq * k_words];
plus[0] = 1u32;
plus[1] = 2u32;
plus[2] = 0u32;
let scales = vec![1.0f32; heads * k_seq];
let k = TernaryTensor::new(plus, minus, scales, (heads * k_seq, dim));
let mut v_data = vec![0.0f32; batch * heads * k_seq * dim];
v_data[0] = 1.0;
v_data[dim + 1] = 1.0;
v_data[2 * dim + 2] = 1.0;
let v = Tensor::from_vec(v_data, (batch, heads, k_seq, dim), &Device::Cpu).unwrap();
let config = TernaryAttentionScoreConfig::default_config();
let output = ternary_attention_online_softmax_cpu(&q, &k, &v, k_seq, 1.0, &config).unwrap();
let output_flat: Vec<f32> = output.flatten_all().unwrap().to_vec1::<f32>().unwrap();
assert!(output_flat[0] > output_flat[1]);
assert!(output_flat[0] > output_flat[2]);
}
#[test]
fn test_ternary_attention_online_softmax_stability() {
let batch = 1;
let heads = 1;
let q_seq = 1;
let k_seq = 2;
let dim = 32;
let q = super::tests::create_test_query(batch, heads, q_seq, dim);
let k = super::tests::create_test_ternary_keys(heads, k_seq, dim);
let v = create_test_values(batch, heads, k_seq, dim);
let config = TernaryAttentionScoreConfig::default_config();
let output =
ternary_attention_online_softmax_cpu(&q, &k, &v, k_seq, 10.0, &config).unwrap();
let output_flat: Vec<f32> = output.flatten_all().unwrap().to_vec1::<f32>().unwrap();
for val in output_flat {
assert!(val.is_finite());
}
}
#[test]
fn test_ternary_attention_cuda_fallback() {
let batch = 1;
let heads = 2;
let q_seq = 4;
let k_seq = 4;
let dim = 64;
let q = super::tests::create_test_query(batch, heads, q_seq, dim);
let k = super::tests::create_test_ternary_keys(heads, k_seq, dim);
let v = create_test_values(batch, heads, k_seq, dim);
let config = TernaryAttentionScoreConfig::default_config();
let output = ternary_attention_cuda(&q, &k, &v, k_seq, 1.0, &config).unwrap();
assert_eq!(output.dims(), &[batch, heads, q_seq, dim]);
}
}
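/// Which attention kernel to use: always full precision, always ternary-K, or automatic
/// selection based on K sparsity and sequence length.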
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttentionDispatchMode {
FullPrecision,
TernaryK,
Auto,
}
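/// Settings that drive dispatch between ternary and full-precision attention, plus the
/// ternary kernel tuning used when the ternary path is selected.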
#[derive(Debug, Clone)]
pub struct HybridAttentionConfig {
pub ternary_config: TernaryAttentionScoreConfig,
pub mode: AttentionDispatchMode,
pub sparsity_threshold: f32,
pub seq_len_threshold: usize,
pub accuracy_tolerance: f32,
}
impl HybridAttentionConfig {
#[must_use]
pub fn speed_optimized() -> Self {
Self {
ternary_config: TernaryAttentionScoreConfig::rtx_5080(),
mode: AttentionDispatchMode::Auto,
sparsity_threshold: 0.70,
seq_len_threshold: 512,
accuracy_tolerance: 0.05,
}
}
#[must_use]
pub fn accuracy_optimized() -> Self {
Self {
ternary_config: TernaryAttentionScoreConfig::rtx_3090_ti(),
mode: AttentionDispatchMode::Auto,
sparsity_threshold: 0.90,
seq_len_threshold: 2048,
accuracy_tolerance: 0.01,
}
}
#[must_use]
pub fn balanced() -> Self {
Self {
ternary_config: TernaryAttentionScoreConfig::default_config(),
mode: AttentionDispatchMode::Auto,
sparsity_threshold: 0.80,
seq_len_threshold: 1024,
accuracy_tolerance: 0.02,
}
}
#[must_use]
pub fn force_fp() -> Self {
let mut config = Self::balanced();
config.mode = AttentionDispatchMode::FullPrecision;
config
}
#[must_use]
pub fn force_ternary() -> Self {
let mut config = Self::balanced();
config.mode = AttentionDispatchMode::TernaryK;
config
}
}
impl Default for HybridAttentionConfig {
fn default() -> Self {
Self::balanced()
}
}
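/// Outcome of dispatch: the selected kernel, the measured K sparsity, the sequence length,
/// and a human-readable reason.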
#[derive(Debug, Clone)]
pub struct DispatchDecision {
pub selected_mode: AttentionDispatchMode,
pub k_sparsity: f32,
pub seq_len: usize,
pub reason: String,
}
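/// Chooses between ternary and full-precision attention. Forced modes pass straight through;
/// in `Auto` mode the ternary path is used when K sparsity meets the threshold, or when the
/// sequence is long enough and K is at least 50% sparse.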
#[must_use]
pub fn decide_attention_mode(
k_ternary: &TernaryTensor,
seq_len: usize,
config: &HybridAttentionConfig,
) -> DispatchDecision {
let k_sparsity = k_ternary.sparsity();
match config.mode {
AttentionDispatchMode::FullPrecision => {
return DispatchDecision {
selected_mode: AttentionDispatchMode::FullPrecision,
k_sparsity,
seq_len,
reason: "Forced full precision mode".to_string(),
};
}
AttentionDispatchMode::TernaryK => {
return DispatchDecision {
selected_mode: AttentionDispatchMode::TernaryK,
k_sparsity,
seq_len,
reason: "Forced ternary mode".to_string(),
};
}
AttentionDispatchMode::Auto => {
}
}
let prefer_ternary_sparsity = k_sparsity >= config.sparsity_threshold;
let prefer_ternary_length = seq_len >= config.seq_len_threshold;
let use_ternary =
prefer_ternary_sparsity || (prefer_ternary_length && k_sparsity >= 0.5);
let (selected_mode, reason) = if use_ternary {
(
AttentionDispatchMode::TernaryK,
format!(
"Auto: K sparsity={:.1}% (threshold={:.1}%), seq_len={} (threshold={}), using ternary",
k_sparsity * 100.0,
config.sparsity_threshold * 100.0,
seq_len,
config.seq_len_threshold
),
)
} else {
(
AttentionDispatchMode::FullPrecision,
format!(
"Auto: K sparsity={:.1}% (threshold={:.1}%), seq_len={} (threshold={}), using FP",
k_sparsity * 100.0,
config.sparsity_threshold * 100.0,
seq_len,
config.seq_len_threshold
),
)
};
DispatchDecision {
selected_mode,
k_sparsity,
seq_len,
reason,
}
}
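/// Runs attention with the kernel selected by `decide_attention_mode` and returns the output
/// together with the dispatch decision. `k_fp` is only read on the full-precision path.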
pub fn hybrid_attention(
q: &Tensor,
k_ternary: &TernaryTensor,
k_fp: &Tensor,
v: &Tensor,
k_seq_len: usize,
scale: f64,
config: &HybridAttentionConfig,
) -> Result<(Tensor, DispatchDecision)> {
let decision = decide_attention_mode(k_ternary, k_seq_len, config);
let output = match decision.selected_mode {
AttentionDispatchMode::TernaryK => {
ternary_attention_cuda(q, k_ternary, v, k_seq_len, scale, &config.ternary_config)?
}
AttentionDispatchMode::FullPrecision => {
fallback_fp_attention(q, k_fp, v, scale)?
}
AttentionDispatchMode::Auto => {
unreachable!("Auto mode should be resolved by decide_attention_mode")
}
};
Ok((output, decision))
}
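/// Standard softmax(scale * Q K^T) V attention used on the full-precision path.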
fn fallback_fp_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
let scores = q.matmul(&k.transpose(2, 3)?)?;
let scores = (scores * scale)?;
let probs = candle_nn::ops::softmax(&scores, candle_core::D::Minus1)?;
let output = probs.matmul(v)?;
Ok(output)
}
#[cfg(test)]
mod hybrid_dispatch_tests {
use super::*;
use candle_core::Device;
pub(super) fn create_sparse_ternary_keys(
heads: usize,
seq_len: usize,
dim: usize,
sparsity: f32,
) -> TernaryTensor {
let out_features = heads * seq_len;
let k_words = (dim + 31) / 32;
let mut plus = vec![0u32; out_features * k_words];
let mut minus = vec![0u32; out_features * k_words];
let bits_per_feature = dim;
let active_bits = ((1.0 - sparsity) * bits_per_feature as f32) as usize;
for i in 0..out_features {
for bit in 0..active_bits.min(bits_per_feature) {
let word_idx = bit / 32;
let bit_idx = (bit % 32) as u32;
if i % 2 == 0 {
plus[i * k_words + word_idx] |= 1u32 << bit_idx;
} else {
minus[i * k_words + word_idx] |= 1u32 << bit_idx;
}
}
}
let scales = vec![1.0f32; out_features];
TernaryTensor::new(plus, minus, scales, (out_features, dim))
}
#[test]
fn test_hybrid_config_presets() {
let speed = HybridAttentionConfig::speed_optimized();
assert_eq!(speed.mode, AttentionDispatchMode::Auto);
assert!(speed.sparsity_threshold < 0.80);
let accuracy = HybridAttentionConfig::accuracy_optimized();
assert!(accuracy.sparsity_threshold > 0.80);
let balanced = HybridAttentionConfig::balanced();
assert_eq!(balanced.sparsity_threshold, 0.80);
}
#[test]
fn test_dispatch_decision_forced_fp() {
let config = HybridAttentionConfig::force_fp();
let k = create_sparse_ternary_keys(4, 8, 64, 0.95);
let decision = decide_attention_mode(&k, 2048, &config);
assert_eq!(decision.selected_mode, AttentionDispatchMode::FullPrecision);
assert!(decision.reason.contains("Forced"));
}
#[test]
fn test_dispatch_decision_forced_ternary() {
let config = HybridAttentionConfig::force_ternary();
let k = create_sparse_ternary_keys(4, 8, 64, 0.10);
let decision = decide_attention_mode(&k, 128, &config);
assert_eq!(decision.selected_mode, AttentionDispatchMode::TernaryK);
assert!(decision.reason.contains("Forced"));
}
#[test]
fn test_dispatch_decision_auto_high_sparsity() {
let config = HybridAttentionConfig::balanced();
let k = create_sparse_ternary_keys(4, 8, 64, 0.90);
let decision = decide_attention_mode(&k, 512, &config);
assert_eq!(decision.selected_mode, AttentionDispatchMode::TernaryK);
assert!(decision.k_sparsity >= config.sparsity_threshold);
}
#[test]
fn test_dispatch_decision_auto_low_sparsity() {
let config = HybridAttentionConfig::balanced();
let k = create_sparse_ternary_keys(4, 8, 64, 0.20);
let decision = decide_attention_mode(&k, 128, &config);
assert_eq!(decision.selected_mode, AttentionDispatchMode::FullPrecision);
assert!(decision.k_sparsity < config.sparsity_threshold);
}
#[test]
fn test_dispatch_decision_auto_long_sequence() {
let config = HybridAttentionConfig::balanced();
let k = create_sparse_ternary_keys(4, 2048, 64, 0.60);
let decision = decide_attention_mode(&k, 2048, &config);
assert_eq!(decision.selected_mode, AttentionDispatchMode::TernaryK);
}
#[test]
fn test_hybrid_attention_shape() {
let batch = 2;
let heads = 4;
let q_seq = 8;
let k_seq = 8;
let dim = 64;
let q = super::tests::create_test_query(batch, heads, q_seq, dim);
let k_ternary = create_sparse_ternary_keys(heads, k_seq, dim, 0.90);
let k_fp = Tensor::zeros(
(batch, heads, k_seq, dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let v = super::online_softmax_tests::create_test_values(batch, heads, k_seq, dim);
let config = HybridAttentionConfig::force_ternary();
let (output, decision) =
hybrid_attention(&q, &k_ternary, &k_fp, &v, k_seq, 1.0, &config).unwrap();
assert_eq!(output.dims(), &[batch, heads, q_seq, dim]);
assert_eq!(decision.selected_mode, AttentionDispatchMode::TernaryK);
}
#[test]
fn test_fallback_fp_attention_shape() {
let batch = 1;
let heads = 2;
let seq = 4;
let dim = 64;
let q = Tensor::zeros(
(batch, heads, seq, dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let k = Tensor::zeros(
(batch, heads, seq, dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let v = Tensor::ones(
(batch, heads, seq, dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let output = fallback_fp_attention(&q, &k, &v, 1.0).unwrap();
assert_eq!(output.dims(), &[batch, heads, seq, dim]);
}
}
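/// Configuration for causal-masked ternary attention: the additive mask value and the
/// underlying hybrid attention settings.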
#[derive(Debug, Clone)]
pub struct CausalMaskConfig {
pub mask_value: f32,
pub enable_plane_optimization: bool,
pub attention_config: HybridAttentionConfig,
}
impl Default for CausalMaskConfig {
fn default() -> Self {
Self {
mask_value: -1e9,
enable_plane_optimization: true,
attention_config: HybridAttentionConfig::balanced(),
}
}
}
impl CausalMaskConfig {
#[must_use]
pub fn with_attention_config(attention_config: HybridAttentionConfig) -> Self {
Self {
mask_value: -1e9,
enable_plane_optimization: true,
attention_config,
}
}
}
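/// Adds `mask_value` to every score above the diagonal (key position > query position) so that
/// softmax suppresses future keys. Expects `[batch, heads, q_len, k_len]` scores.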
pub fn apply_causal_mask_to_scores(scores: &Tensor, mask_value: f32) -> Result<Tensor> {
let shape = scores.dims();
if shape.len() != 4 {
return Err(UnslothError::ShapeMismatch {
expected: vec![4],
actual: shape.to_vec(),
});
}
let (batch, heads, q_len, k_len) = (shape[0], shape[1], shape[2], shape[3]);
let mut mask_data = vec![0.0f32; q_len * k_len];
for i in 0..q_len {
for j in 0..k_len {
if j > i {
mask_data[i * k_len + j] = mask_value;
}
}
}
let device = scores.device();
let mask = Tensor::from_vec(mask_data, (q_len, k_len), device)?;
let mask_broadcast = mask.broadcast_as((batch, heads, q_len, k_len))?;
let masked_scores = (scores + mask_broadcast)?;
Ok(masked_scores)
}
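/// Causal attention over ternary keys: computes 1/sqrt(head_dim)-scaled scores with the CPU
/// kernel, applies the causal mask, then softmax and the value matmul.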
pub fn causal_masked_ternary_attention(
q: &Tensor,
k_ternary: &TernaryTensor,
v: &Tensor,
config: &CausalMaskConfig,
) -> Result<Tensor> {
let q_dims = q.dims();
let (num_heads, head_dim) = (q_dims[1], q_dims[3]);
// Ternary K rows are laid out as (num_heads * k_seq_len, head_dim).
let k_seq_len = k_ternary.shape().0 / num_heads;
let scale = 1.0 / (head_dim as f64).sqrt();
let scores = ternary_attention_score_kernel_cpu(
q,
k_ternary,
k_seq_len,
scale,
&config.attention_config.ternary_config,
)?;
let masked_scores = apply_causal_mask_to_scores(&scores, config.mask_value)?;
let output = compute_attention_from_scores(&masked_scores, v)?;
Ok(output)
}
fn compute_attention_from_scores(scores: &Tensor, v: &Tensor) -> Result<Tensor> {
let attn_weights = candle_nn::ops::softmax_last_dim(scores)?;
let output = attn_weights.matmul(v)?;
Ok(output)
}
#[cfg(test)]
mod causal_mask_tests {
use super::*;
use candle_core::{Device, Tensor};
#[test]
fn test_causal_mask_config() {
let config = CausalMaskConfig::default();
assert_eq!(config.mask_value, -1e9);
assert!(config.enable_plane_optimization);
let custom =
CausalMaskConfig::with_attention_config(HybridAttentionConfig::speed_optimized());
assert_eq!(custom.mask_value, -1e9);
}
#[test]
fn test_apply_causal_mask_shape() {
let batch = 2;
let heads = 4;
let seq = 8;
let scores = Tensor::zeros(
(batch, heads, seq, seq),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let masked = apply_causal_mask_to_scores(&scores, -1e9).unwrap();
assert_eq!(masked.dims(), &[batch, heads, seq, seq]);
}
#[test]
fn test_apply_causal_mask_values() {
let batch = 1;
let heads = 1;
let seq = 4;
let scores = Tensor::ones(
(batch, heads, seq, seq),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let masked = apply_causal_mask_to_scores(&scores, -1e9).unwrap();
// batch = heads = 1, so the flattened layout is simply [i * seq + j].
let masked_data: Vec<f32> = masked.flatten_all().unwrap().to_vec1::<f32>().unwrap();
for i in 0..seq {
for j in 0..seq {
let value = masked_data[i * seq + j];
if j > i {
assert!(value < -1e8, "Position ({}, {}) should be masked", i, j);
} else {
assert!(
(value - 1.0).abs() < 1e-5,
"Position ({}, {}) should be unmasked",
i,
j
);
}
}
}
}
#[test]
fn test_causal_masked_attention_shape() {
let batch = 2;
let heads = 4;
let seq = 8;
let dim = 64;
let q = super::tests::create_test_query(batch, heads, seq, dim);
let k_ternary =
super::hybrid_dispatch_tests::create_sparse_ternary_keys(heads, seq, dim, 0.80);
let v = super::online_softmax_tests::create_test_values(batch, heads, seq, dim);
let config = CausalMaskConfig::default();
let output = causal_masked_ternary_attention(&q, &k_ternary, &v, &config).unwrap();
assert_eq!(output.dims(), &[batch, heads, seq, dim]);
}
#[test]
fn test_causal_masked_attention_numerical() {
let batch = 1;
let heads = 1;
let seq = 4;
let dim = 8;
let q = Tensor::ones(
(batch, heads, seq, dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let k_ternary =
super::hybrid_dispatch_tests::create_sparse_ternary_keys(heads, seq, dim, 0.50);
let v = Tensor::ones(
(batch, heads, seq, dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let config = CausalMaskConfig::default();
let output = causal_masked_ternary_attention(&q, &k_ternary, &v, &config).unwrap();
let output_data: Vec<f32> = output.flatten_all().unwrap().to_vec1::<f32>().unwrap();
for &val in output_data.iter() {
assert!(val.is_finite(), "Output contains non-finite value: {}", val);
}
}
}
#[cfg(test)]
mod integration_tests {
use super::*;
use candle_core::{Device, Tensor};
fn create_test_ternary_keys(
heads: usize,
seq: usize,
dim: usize,
sparsity: f32,
) -> TernaryTensor {
use super::super::quantize::quantize_to_ternary;
let total_values = heads * seq * dim;
let num_zeros = (total_values as f32 * sparsity) as usize;
let num_ones = (total_values - num_zeros) / 2;
let num_neg_ones = total_values - num_zeros - num_ones;
let mut values = vec![0.0f32; num_zeros];
values.extend(vec![1.0f32; num_ones]);
values.extend(vec![-1.0f32; num_neg_ones]);
// Lay the keys out as (heads * seq, dim) rows, matching the shape the attention kernels expect.
let k_fp = Tensor::from_vec(values, (heads * seq, dim), &Device::Cpu).unwrap();
quantize_to_ternary(&k_fp, 0.5).unwrap()
}
#[test]
fn test_end_to_end_multihead_attention() {
let batch = 2;
let num_heads = 8;
let seq_len = 16;
let head_dim = 64;
let q = Tensor::randn(
0.0f32,
1.0,
(batch, num_heads, seq_len, head_dim),
&Device::Cpu,
)
.unwrap();
let k_ternary = create_test_ternary_keys(num_heads, seq_len, head_dim, 0.85);
let v = Tensor::randn(
0.0f32,
1.0,
(batch, num_heads, seq_len, head_dim),
&Device::Cpu,
)
.unwrap();
let config = HybridAttentionConfig::balanced();
let k_fp = Tensor::zeros(
(batch, num_heads, seq_len, head_dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let (output, decision) =
hybrid_attention(&q, &k_ternary, &k_fp, &v, seq_len, 1.0, &config).unwrap();
assert_eq!(output.dims(), &[batch, num_heads, seq_len, head_dim]);
assert_eq!(decision.selected_mode, AttentionDispatchMode::TernaryK);
assert!(decision.k_sparsity >= 0.80);
let output_data = output.flatten_all().unwrap().to_vec1::<f32>().unwrap();
for &val in output_data.iter() {
assert!(val.is_finite(), "Output contains invalid value");
}
}
#[test]
fn test_end_to_end_grouped_query_attention() {
let batch = 1;
let q_heads = 8;
let kv_heads = 2;
let seq_len = 32;
let head_dim = 128;
let q = Tensor::randn(
0.0f32,
1.0,
(batch, q_heads, seq_len, head_dim),
&Device::Cpu,
)
.unwrap();
let k_ternary = create_test_ternary_keys(kv_heads, seq_len, head_dim, 0.90);
let v = Tensor::randn(
0.0f32,
1.0,
(batch, kv_heads, seq_len, head_dim),
&Device::Cpu,
)
.unwrap();
let q_single_head = q.narrow(1, 0, 1).unwrap();
let v_single_head = v.narrow(1, 0, 1).unwrap();
// Build a single-KV-head ternary tensor by slicing the first head's rows out of the
// bit planes and per-row scales (the same accessors the CPU kernel uses).
let k_words = (head_dim + 31) / 32;
let plus = k_ternary.plus_plane()[..seq_len * k_words].to_vec();
let minus = k_ternary.minus_plane()[..seq_len * k_words].to_vec();
let scales: Vec<f32> = (0..seq_len).map(|r| k_ternary.scale(r)).collect();
let k_single_ternary = TernaryTensor::new(plus, minus, scales, (seq_len, head_dim));
let config = CausalMaskConfig::default();
let output = causal_masked_ternary_attention(
&q_single_head,
&k_single_ternary,
&v_single_head,
&config,
)
.unwrap();
assert_eq!(output.dims(), &[batch, 1, seq_len, head_dim]);
}
#[test]
fn test_end_to_end_long_sequence() {
let batch = 1;
let heads = 4;
let seq_len = 512;
let head_dim = 64;
let q =
Tensor::randn(0.0f32, 0.1, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let k_ternary = create_test_ternary_keys(heads, seq_len, head_dim, 0.75);
let v =
Tensor::randn(0.0f32, 0.1, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let config = HybridAttentionConfig::speed_optimized();
let k_fp = Tensor::zeros(
(batch, heads, seq_len, head_dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let (output, decision) =
hybrid_attention(&q, &k_ternary, &k_fp, &v, seq_len, 1.0, &config).unwrap();
assert_eq!(decision.selected_mode, AttentionDispatchMode::TernaryK);
assert_eq!(output.dims(), &[batch, heads, seq_len, head_dim]);
}
#[test]
fn test_end_to_end_causal_vs_non_causal() {
let batch = 1;
let heads = 2;
let seq_len = 8;
let head_dim = 32;
let q =
Tensor::randn(0.0f32, 1.0, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let k_ternary = create_test_ternary_keys(heads, seq_len, head_dim, 0.80);
// V must vary across key positions; with a constant V both attention variants collapse
// to the same constant output and the comparison below would be vacuous.
let v =
Tensor::randn(0.0f32, 1.0, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let causal_config = CausalMaskConfig::default();
let causal_output =
causal_masked_ternary_attention(&q, &k_ternary, &v, &causal_config).unwrap();
let score_config = TernaryAttentionScoreConfig::default_config();
let non_causal_output =
ternary_attention_online_softmax_cpu(&q, &k_ternary, &v, seq_len, 1.0, &score_config)
.unwrap();
assert_eq!(causal_output.dims(), non_causal_output.dims());
let causal_data = causal_output
.flatten_all()
.unwrap()
.to_vec1::<f32>()
.unwrap();
let non_causal_data = non_causal_output
.flatten_all()
.unwrap()
.to_vec1::<f32>()
.unwrap();
let mut diffs = 0;
for (c, nc) in causal_data.iter().zip(non_causal_data.iter()) {
if (c - nc).abs() > 1e-5 {
diffs += 1;
}
}
assert!(diffs > 0, "Causal and non-causal outputs should differ");
}
#[test]
fn test_end_to_end_sparsity_threshold_behavior() {
let batch = 1;
let heads = 2;
let seq_len = 16;
let head_dim = 64;
let q =
Tensor::randn(0.0f32, 1.0, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let v =
Tensor::randn(0.0f32, 1.0, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let k_fp = Tensor::zeros(
(batch, heads, seq_len, head_dim),
candle_core::DType::F32,
&Device::Cpu,
)
.unwrap();
let sparse_k = create_test_ternary_keys(heads, seq_len, head_dim, 0.95);
let dense_k = create_test_ternary_keys(heads, seq_len, head_dim, 0.20);
let speed_config = HybridAttentionConfig::speed_optimized();
let accuracy_config = HybridAttentionConfig::accuracy_optimized();
let (_, decision1) =
hybrid_attention(&q, &sparse_k, &k_fp, &v, seq_len, 1.0, &speed_config).unwrap();
assert_eq!(decision1.selected_mode, AttentionDispatchMode::TernaryK);
let (_, decision2) =
hybrid_attention(&q, &sparse_k, &k_fp, &v, seq_len, 1.0, &accuracy_config).unwrap();
assert_eq!(decision2.selected_mode, AttentionDispatchMode::TernaryK);
let (_, decision3) =
hybrid_attention(&q, &dense_k, &k_fp, &v, seq_len, 1.0, &speed_config).unwrap();
assert_eq!(
decision3.selected_mode,
AttentionDispatchMode::FullPrecision
);
let (_, decision4) =
hybrid_attention(&q, &dense_k, &k_fp, &v, seq_len, 1.0, &accuracy_config).unwrap();
assert_eq!(
decision4.selected_mode,
AttentionDispatchMode::FullPrecision
);
}
#[test]
fn test_end_to_end_numerical_stability() {
let batch = 1;
let heads = 2;
let seq_len = 8;
let head_dim = 32;
let q = Tensor::from_vec(
vec![100.0f32; batch * heads * seq_len * head_dim],
(batch, heads, seq_len, head_dim),
&Device::Cpu,
)
.unwrap();
let k_ternary = create_test_ternary_keys(heads, seq_len, head_dim, 0.80);
let v =
Tensor::randn(0.0f32, 1.0, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let config = TernaryAttentionScoreConfig::default_config();
let output =
ternary_attention_online_softmax_cpu(&q, &k_ternary, &v, seq_len, 1.0, &config).unwrap();
let output_data = output.flatten_all().unwrap().to_vec1::<f32>().unwrap();
for &val in output_data.iter() {
assert!(
val.is_finite(),
"Numerical instability: output contains invalid value"
);
assert!(val.abs() < 1e6, "Numerical instability: output too large");
}
}
#[test]
fn test_end_to_end_batch_independence() {
let batch = 4;
let heads = 2;
let seq_len = 8;
let head_dim = 32;
let q =
Tensor::randn(0.0f32, 1.0, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let k_ternary = create_test_ternary_keys(heads, seq_len, head_dim, 0.85);
let v =
Tensor::randn(0.0f32, 1.0, (batch, heads, seq_len, head_dim), &Device::Cpu).unwrap();
let config = TernaryAttentionScoreConfig::default_config();
let full_output =
ternary_attention_online_softmax_cpu(&q, &k_ternary, &v, seq_len, 1.0, &config).unwrap();
for b in 0..batch {
let q_single = q.narrow(0, b, 1).unwrap();
let v_single = v.narrow(0, b, 1).unwrap();
let single_output =
ternary_attention_online_softmax_cpu(&q_single, &k_ternary, &v_single, seq_len, 1.0, &config)
.unwrap();
let full_slice = full_output.narrow(0, b, 1).unwrap();
let single_data = single_output
.flatten_all()
.unwrap()
.to_vec1::<f32>()
.unwrap();
let full_data = full_slice.flatten_all().unwrap().to_vec1::<f32>().unwrap();
for (s, f) in single_data.iter().zip(full_data.iter()) {
assert!((s - f).abs() < 1e-4, "Batch independence violated");
}
}
}
}