//! Ternary attention: multi-head scaled dot-product attention whose Q/K/V/O
//! projections use ternary ({-1, 0, +1}) quantized weights, together with a
//! streaming softmax accumulator and bit-plane causal masking.

use super::types::{TernaryPlanes, TernaryTensor};
use crate::error::{Result, UnslothError};
use candle_core::Tensor;

/// Configuration for ternary attention.
#[derive(Debug, Clone)]
pub struct TernaryAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Whether to apply a causal (lower-triangular) mask.
    pub causal: bool,
    /// Minimum average weight sparsity required to take the ternary path.
    pub sparsity_threshold: f32,
}

impl Default for TernaryAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 12,
            head_dim: 64,
            causal: true,
            sparsity_threshold: 0.8,
        }
    }
}

/// Ternary-quantized projection weights for one attention layer.
#[derive(Debug, Clone)]
pub struct TernaryAttentionWeights {
    pub q_proj: TernaryTensor,
    pub k_proj: TernaryTensor,
    pub v_proj: TernaryTensor,
    pub o_proj: TernaryTensor,
    /// Per-head scale factors for the query projection (length = `num_heads`).
    pub q_scales: Vec<f32>,
    /// Per-head scale factors for the key projection (length = `num_heads`).
    pub k_scales: Vec<f32>,
}

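/// Computes a single attention score from bit-plane Q/K representations:
/// the integer ternary dot product rescaled by the two quantization scales.
///
/// # Example
///
/// A minimal sketch mirroring `test_ternary_attention_score` below; it assumes
/// the `TernaryPlanes::new`/`set` API used by this module's tests:
///
/// ```ignore
/// let mut q = TernaryPlanes::new(64);
/// let mut k = TernaryPlanes::new(64);
/// q.set(0, 1);
/// k.set(0, 1);
/// // (+1) * (+1) contributes +1 to the dot product; both scales are 1.0.
/// let score = ternary_attention_score(&q, &k, 1.0, 1.0);
/// assert!((score - 1.0).abs() < 1e-3);
/// ```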
#[must_use]
pub fn ternary_attention_score(
    q_planes: &TernaryPlanes,
    k_planes: &TernaryPlanes,
    scale_q: f32,
    scale_k: f32,
) -> f32 {
    let dot = q_planes.dot(k_planes);
    dot as f32 * scale_q * scale_k
}

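/// Numerically stable one-pass (streaming) softmax accumulator in the
/// FlashAttention style: it tracks a running maximum `m`, the normalizer
/// `sum = Σ exp(xᵢ - m)`, and the unnormalized output `Σ exp(xᵢ - m) · vᵢ`.
/// When a new maximum arrives, the accumulated `sum` and `output` are
/// rescaled by `exp(m_old - m_new)` so every term stays relative to the
/// current maximum.
///
/// # Example
///
/// A minimal sketch (same pattern as `test_online_softmax` below):
///
/// ```ignore
/// let mut state = OnlineSoftmaxState::new(2);
/// state.update(0.0, &[1.0, 0.0]);
/// state.update(0.0, &[0.0, 1.0]);
/// // Two equal scores -> uniform attention weights of 0.5 each.
/// assert_eq!(state.finalize(), vec![0.5, 0.5]);
/// ```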
#[derive(Debug, Clone)]
pub struct OnlineSoftmaxState {
    /// Running maximum of the scores seen so far.
    pub max: f32,
    /// Running sum of `exp(score - max)` terms.
    pub sum: f32,
    /// Unnormalized weighted sum of value vectors.
    pub output: Vec<f32>,
}

impl OnlineSoftmaxState {
    #[must_use]
    pub fn new(dim: usize) -> Self {
        Self {
            max: f32::NEG_INFINITY,
            sum: 0.0,
            output: vec![0.0; dim],
        }
    }

    /// Folds one `(score, value)` pair into the running softmax.
    pub fn update(&mut self, score: f32, value: &[f32]) {
        assert_eq!(value.len(), self.output.len(), "value dimension mismatch");
        if score > self.max {
            // New maximum: rescale existing terms from exp(x - m_old) to
            // exp(x - m_new) by multiplying with exp(m_old - m_new) <= 1.
            let correction = (self.max - score).exp();
            self.sum *= correction;
            for o in &mut self.output {
                *o *= correction;
            }
            self.max = score;
        }
        let exp_score = (score - self.max).exp();
        self.sum += exp_score;
        for (o, &v) in self.output.iter_mut().zip(value.iter()) {
            *o += exp_score * v;
        }
    }

    /// Normalizes and returns the accumulated output. If no scores were ever
    /// folded in (e.g. every position was masked), the zero vector is
    /// returned instead of dividing by zero.
    #[must_use]
    pub fn finalize(self) -> Vec<f32> {
        if self.sum == 0.0 {
            return self.output;
        }
        self.output.into_iter().map(|o| o / self.sum).collect()
    }
}

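/// Clears key positions after `query_pos` in both bit planes, enforcing
/// causality directly on the packed representation: a position whose `plus`
/// and `minus` bits are both cleared encodes the ternary value 0.
///
/// # Example
///
/// A minimal sketch (mirrors `test_causal_mask_planes` below):
///
/// ```ignore
/// let mut planes = TernaryPlanes::new(8);
/// for i in 0..8 {
///     planes.set(i, 1);
/// }
/// apply_causal_mask_to_planes(&mut planes, 3, 8);
/// assert_eq!(planes.get(3), 1); // positions up to the query survive
/// assert_eq!(planes.get(4), 0); // positions past the query are zeroed
/// ```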
pub fn apply_causal_mask_to_planes(planes: &mut TernaryPlanes, query_pos: usize, seq_len: usize) {
    for pos in (query_pos + 1)..seq_len {
        if pos < planes.num_dims {
            // Clear bit `pos` in both planes (00 encodes a ternary zero).
            let word_idx = pos / 32;
            let bit_idx = pos % 32;
            let mask = !(1u32 << bit_idx);
            if word_idx < planes.plus.len() {
                planes.plus[word_idx] &= mask;
                planes.minus[word_idx] &= mask;
            }
        }
    }
}

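/// Full multi-head attention forward pass on CPU with ternary projections.
/// `hidden_states` must be `[batch, seq_len, hidden]`, and the output has the
/// same shape. Q/K/V/O projections run through `ternary_matmul_cpu`, while
/// the score/softmax stage itself is computed in f32.
///
/// # Example
///
/// A minimal sketch, assuming `weights` holds projections sized to match the
/// default config (12 heads x 64 dims = 768 hidden):
///
/// ```ignore
/// let hidden = Tensor::zeros((1, 16, 768), DType::F32, &Device::Cpu)?;
/// let config = TernaryAttentionConfig::default();
/// let out = ternary_attention_cpu(&hidden, &weights, &config)?;
/// assert_eq!(out.dims(), &[1, 16, 768]);
/// ```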
pub fn ternary_attention_cpu(
    hidden_states: &Tensor,
    weights: &TernaryAttentionWeights,
    config: &TernaryAttentionConfig,
) -> Result<Tensor> {
    use super::matmul::ternary_matmul_cpu;

    let dims = hidden_states.dims();
    if dims.len() != 3 {
        return Err(UnslothError::ShapeMismatch {
            expected: vec![3],
            actual: dims.to_vec(),
        });
    }
    let (batch, seq_len, _hidden) = (dims[0], dims[1], dims[2]);
    let num_heads = config.num_heads;
    let head_dim = config.head_dim;

    // Project Q, K, V through the ternary weight matrices.
    let q = ternary_matmul_cpu(hidden_states, &weights.q_proj)?;
    let k = ternary_matmul_cpu(hidden_states, &weights.k_proj)?;
    let v = ternary_matmul_cpu(hidden_states, &weights.v_proj)?;

    // Split into heads: [batch, seq, hidden] -> [batch, heads, seq, head_dim].
    let q = q
        .reshape((batch, seq_len, num_heads, head_dim))?
        .transpose(1, 2)?;
    let k = k
        .reshape((batch, seq_len, num_heads, head_dim))?
        .transpose(1, 2)?;
    let v = v
        .reshape((batch, seq_len, num_heads, head_dim))?
        .transpose(1, 2)?;

    // Scaled dot-product attention: QK^T / sqrt(head_dim).
    let scale = (head_dim as f64).sqrt();
    let scores = q.matmul(&k.transpose(2, 3)?)?;
    let scores = (scores / scale)?;

    // Additive causal mask: -inf above the diagonal, 0 elsewhere.
    let scores = if config.causal {
        let mask = create_causal_mask(seq_len, hidden_states.device())?;
        let mask = mask.reshape((1, 1, seq_len, seq_len))?;
        scores.broadcast_add(&mask)?
    } else {
        scores
    };

    // Softmax over the key dimension, then the weighted sum of values.
    let attn_weights = candle_nn::ops::softmax(&scores, 3)?;
    let attn_output = attn_weights.matmul(&v)?;

    // Merge heads back and apply the ternary output projection.
    let attn_output = attn_output
        .transpose(1, 2)?
        .reshape((batch, seq_len, num_heads * head_dim))?;
    let output = ternary_matmul_cpu(&attn_output, &weights.o_proj)?;
    Ok(output)
}

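/// Builds an additive causal mask: 0 on and below the diagonal and `-inf`
/// above it, so masked scores vanish after softmax. For `seq_len = 3`:
///
/// ```text
/// [[  0.0, -inf, -inf],
///  [  0.0,  0.0, -inf],
///  [  0.0,  0.0,  0.0]]
/// ```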
fn create_causal_mask(seq_len: usize, device: &candle_core::Device) -> Result<Tensor> {
    let mut mask_data = vec![0.0f32; seq_len * seq_len];
    for i in 0..seq_len {
        for j in 0..seq_len {
            if j > i {
                mask_data[i * seq_len + j] = f32::NEG_INFINITY;
            }
        }
    }
    let mask = Tensor::from_vec(mask_data, (seq_len, seq_len), device)?;
    Ok(mask)
}

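/// Returns `true` when the mean sparsity of the four projection weights is at
/// least `config.sparsity_threshold`; dense (low-sparsity) weights gain
/// little from the ternary kernels.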
#[must_use]
pub fn should_use_ternary_attention(
    weights: &TernaryAttentionWeights,
    config: &TernaryAttentionConfig,
) -> bool {
    let avg_sparsity = (weights.q_proj.sparsity()
        + weights.k_proj.sparsity()
        + weights.v_proj.sparsity()
        + weights.o_proj.sparsity())
        / 4.0;
    avg_sparsity >= config.sparsity_threshold
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds 64x64 attention weights whose plus-plane words are all
    /// `plus_word` (all-zero words give fully sparse ternary weights).
    fn make_weights(plus_word: u32) -> TernaryAttentionWeights {
        let shape = (64, 64);
        let k_words = 2; // 64 dims pack into 2 u32 words per row
        let plus = vec![plus_word; 64 * k_words];
        let minus = vec![0u32; 64 * k_words];
        let scales = vec![1.0f32; 64];
        TernaryAttentionWeights {
            q_proj: TernaryTensor::new(plus.clone(), minus.clone(), scales.clone(), shape),
            k_proj: TernaryTensor::new(plus.clone(), minus.clone(), scales.clone(), shape),
            v_proj: TernaryTensor::new(plus.clone(), minus.clone(), scales.clone(), shape),
            o_proj: TernaryTensor::new(plus, minus, scales, shape),
            q_scales: vec![1.0; 12],
            k_scales: vec![1.0; 12],
        }
    }

    #[test]
    fn test_ternary_attention_score() {
        let mut q = TernaryPlanes::new(64);
        let mut k = TernaryPlanes::new(64);
        q.set(0, 1);
        q.set(1, -1);
        k.set(0, 1);
        k.set(1, -1);
        // (+1)(+1) + (-1)(-1) = 2 with unit scales.
        let score = ternary_attention_score(&q, &k, 1.0, 1.0);
        assert!((score - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_online_softmax() {
        let mut state = OnlineSoftmaxState::new(4);
        state.update(1.0, &[1.0, 0.0, 0.0, 0.0]);
        state.update(2.0, &[0.0, 1.0, 0.0, 0.0]);
        state.update(1.0, &[0.0, 0.0, 1.0, 0.0]);
        let output = state.finalize();
        // The normalized weights must sum to 1.
        let sum: f32 = output.iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_causal_mask_planes() {
        let mut planes = TernaryPlanes::new(8);
        for i in 0..8 {
            planes.set(i, 1);
        }
        apply_causal_mask_to_planes(&mut planes, 3, 8);
        assert_eq!(planes.get(0), 1);
        assert_eq!(planes.get(3), 1);
        assert_eq!(planes.get(4), 0);
        assert_eq!(planes.get(7), 0);
    }

    #[test]
    fn test_attention_config_default() {
        let config = TernaryAttentionConfig::default();
        assert_eq!(config.num_heads, 12);
        assert_eq!(config.head_dim, 64);
        assert!(config.causal);
    }

    #[test]
    fn test_should_use_ternary_attention_high_sparsity() {
        // All-zero weights are 100% sparse, well above the 0.8 threshold.
        let weights = make_weights(0);
        let config = TernaryAttentionConfig {
            sparsity_threshold: 0.8,
            ..Default::default()
        };
        assert!(should_use_ternary_attention(&weights, &config));
    }

    #[test]
    fn test_should_use_ternary_attention_low_sparsity() {
        // Every plus bit set -> 0% sparsity, below the 0.8 threshold.
        let weights = make_weights(u32::MAX);
        let config = TernaryAttentionConfig {
            sparsity_threshold: 0.8,
            ..Default::default()
        };
        assert!(!should_use_ternary_attention(&weights, &config));
    }

    #[test]
    fn test_should_use_ternary_attention_at_threshold() {
        // Sparsity exactly equal to the threshold must still pass (>=).
        let weights = make_weights(0);
        let config = TernaryAttentionConfig {
            sparsity_threshold: 1.0,
            ..Default::default()
        };
        assert!(should_use_ternary_attention(&weights, &config));
    }

    #[test]
    fn test_online_softmax_all_masked() {
        // No updates at all: finalize must return the zero vector, not NaN.
        let state = OnlineSoftmaxState::new(4);
        let output = state.finalize();
        assert_eq!(output, vec![0.0, 0.0, 0.0, 0.0]);
    }
}