impl FusedQKVAttention {
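/// Creates a fused QKV attention layer.
///
/// The number of heads is derived as `hidden_dim / head_dim`, and attention
/// scores are scaled by `1 / sqrt(head_dim)`. Weights are initialized with a
/// small deterministic pattern so construction needs no RNG.
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` if either dimension is zero or if
/// `hidden_dim` is not a multiple of `head_dim`.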
pub fn new(head_dim: usize, hidden_dim: usize) -> Result<Self> {
if head_dim == 0 {
return Err(RealizarError::InvalidShape {
reason: "head_dim must be > 0".to_string(),
});
}
if hidden_dim == 0 {
return Err(RealizarError::InvalidShape {
reason: "hidden_dim must be > 0".to_string(),
});
}
if !hidden_dim.is_multiple_of(head_dim) {
return Err(RealizarError::InvalidShape {
reason: format!(
"hidden_dim ({}) must be divisible by head_dim ({})",
hidden_dim, head_dim
),
});
}
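// One head per `head_dim`-wide slice of the hidden dimension; scores are
// scaled by 1/sqrt(head_dim) as in standard scaled dot-product attention.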
let num_heads = hidden_dim / head_dim;
let scale = 1.0 / (head_dim as f32).sqrt();
let proj_size = hidden_dim * hidden_dim;
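// Small deterministic initialization (no RNG); callers can overwrite the
// weights through the `*_mut` accessors below.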
let init_weight = |size: usize| -> Vec<f32> {
(0..size).map(|i| (i as f32 * 0.001).sin() * 0.02).collect()
};
Ok(Self {
head_dim,
hidden_dim,
num_heads,
scale,
w_q: init_weight(proj_size),
w_k: init_weight(proj_size),
w_v: init_weight(proj_size),
w_o: init_weight(proj_size),
})
}
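/// Runs causal multi-head self-attention over a `[seq_len, hidden_dim]` input.
///
/// The input is projected to Q, K, and V, each head computes
/// `softmax(Q K^T / sqrt(head_dim)) V` under a lower-triangular (causal) mask,
/// and the concatenated head outputs are passed through the output projection.
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` if the input is not 2-dimensional or
/// its last dimension does not match `hidden_dim`.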
pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let shape = input.shape();
// The kernel indexes the input as [seq_len, hidden_dim]; higher-rank inputs
// would be silently misread, so require exactly two dimensions.
if shape.len() != 2 {
return Err(RealizarError::InvalidShape {
reason: "Input must be 2-dimensional [seq_len, hidden_dim]".to_string(),
});
}
let seq_len = shape[0];
let input_dim = shape[1];
if input_dim != self.hidden_dim {
return Err(RealizarError::InvalidShape {
reason: format!(
"Input hidden_dim ({}) doesn't match layer hidden_dim ({})",
input_dim, self.hidden_dim
),
});
}
let data = input.data();
let mut q = vec![0.0f32; seq_len * self.hidden_dim];
let mut k = vec![0.0f32; seq_len * self.hidden_dim];
let mut v = vec![0.0f32; seq_len * self.hidden_dim];
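// Fused projection pass: a single sweep over the input fills Q, K, and V.
// Weights are stored row-major as [hidden_dim x hidden_dim], indexed [input, output].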
for i in 0..seq_len {
for j in 0..self.hidden_dim {
let mut sum_q = 0.0f32;
let mut sum_k = 0.0f32;
let mut sum_v = 0.0f32;
for l in 0..self.hidden_dim {
let inp = data[i * self.hidden_dim + l];
sum_q += inp * self.w_q[l * self.hidden_dim + j];
sum_k += inp * self.w_k[l * self.hidden_dim + j];
sum_v += inp * self.w_v[l * self.hidden_dim + j];
}
q[i * self.hidden_dim + j] = sum_q;
k[i * self.hidden_dim + j] = sum_k;
v[i * self.hidden_dim + j] = sum_v;
}
}
let mut output = vec![0.0f32; seq_len * self.hidden_dim];
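// Scaled dot-product attention per head with a causal mask:
// query position i attends only to key positions 0..=i.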
for head in 0..self.num_heads {
let head_offset = head * self.head_dim;
for i in 0..seq_len {
// Compute the scaled dot-product scores for this query row once.
let mut scores = vec![0.0f32; i + 1];
for (j, score) in scores.iter_mut().enumerate() {
let mut dot = 0.0f32;
for d in 0..self.head_dim {
let q_idx = i * self.hidden_dim + head_offset + d;
let k_idx = j * self.hidden_dim + head_offset + d;
dot += q[q_idx] * k[k_idx];
}
*score = dot * self.scale;
}
// Numerically stable softmax: subtract the row maximum before exponentiating.
let max_score = scores.iter().fold(f32::NEG_INFINITY, |m, &s| m.max(s));
let mut sum_exp = 0.0f32;
for score in scores.iter_mut() {
*score = (*score - max_score).exp();
sum_exp += *score;
}
// Weighted sum of V rows using the normalized attention weights.
if sum_exp > 0.0 {
for d in 0..self.head_dim {
let mut weighted_sum = 0.0f32;
for (j, &score) in scores.iter().enumerate() {
let v_idx = j * self.hidden_dim + head_offset + d;
weighted_sum += (score / sum_exp) * v[v_idx];
}
output[i * self.hidden_dim + head_offset + d] = weighted_sum;
}
}
}
}
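// Output projection: multiply the concatenated head outputs by w_o.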
let mut final_output = vec![0.0f32; seq_len * self.hidden_dim];
for i in 0..seq_len {
for j in 0..self.hidden_dim {
let mut sum = 0.0f32;
for l in 0..self.hidden_dim {
sum += output[i * self.hidden_dim + l] * self.w_o[l * self.hidden_dim + j];
}
final_output[i * self.hidden_dim + j] = sum;
}
}
Tensor::from_vec(vec![seq_len, self.hidden_dim], final_output)
}
#[must_use]
pub fn head_dim(&self) -> usize {
self.head_dim
}
#[must_use]
pub fn hidden_dim(&self) -> usize {
self.hidden_dim
}
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
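// Mutable access to the projection weights, e.g. for loading trained values
// in place of the deterministic initialization.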
pub fn w_q_mut(&mut self) -> &mut [f32] {
&mut self.w_q
}
pub fn w_k_mut(&mut self) -> &mut [f32] {
&mut self.w_k
}
pub fn w_v_mut(&mut self) -> &mut [f32] {
&mut self.w_v
}
pub fn w_o_mut(&mut self) -> &mut [f32] {
&mut self.w_o
}
}
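
// A minimal usage sketch (not from the original test suite). It relies only on
// the `Tensor::from_vec` / `shape` API already used by `forward` above, and it
// assumes `RealizarError` implements `Debug` (needed for `expect`).
#[cfg(test)]
mod fused_qkv_attention_usage {
    use super::*;

    #[test]
    fn forward_preserves_input_shape() {
        // Two heads of size 4 => hidden_dim = 8.
        let attn = FusedQKVAttention::new(4, 8).expect("valid dimensions");
        let seq_len = 3;
        let input = Tensor::from_vec(vec![seq_len, 8], vec![0.5f32; seq_len * 8])
            .expect("valid tensor");
        let out = attn.forward(&input).expect("forward pass should succeed");
        assert_eq!(out.shape()[0], seq_len);
        assert_eq!(out.shape()[1], 8);
    }
}

/// Multi-head attention built from separate Q/K/V/output projection layers.
///
/// `num_kv_heads` may differ from `num_heads`, e.g. for grouped-query
/// attention where several query heads share one key/value head.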
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
hidden_dim: usize,
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
attention: Attention,
}