use crate::errors::InstructionModelError;
use crate::instructions::Instruction;
use crate::utils::dot::{DotKernel, dot};
/// Attention-style instruction that works in place on a shared `f32` computation buffer.
///
/// The stages implemented below are: a linear transform of the key vector
/// (`weights * key + bias`) written to the output region, a softmax over that
/// region, and an element-wise multiplication of the region with the query vector.
pub struct AttentionInstruction {
    /// Offset of the first query element inside the computation buffer.
    query_ptr: usize,
    /// Offset of the first key element inside the computation buffer.
    key_ptr: usize,
    /// Offset of the first output element inside the computation buffer.
    output_ptr: usize,
    /// Number of output elements, i.e. the number of weight rows and bias entries used.
    data_size: usize,
    /// Length of the key vector and of every weight row (taken from the first row).
    key_size: usize,
    /// Weight matrix flattened row by row.
    weights: Vec<f32>,
    /// Bias added to each row's dot product.
    bias: Vec<f32>,
    /// Dot-product kernel chosen once by `DotKernel::detect()` at construction time.
    kernel: DotKernel,
}

impl AttentionInstruction {
    /// Builds an instruction from buffer offsets, the output length and the layer parameters.
    ///
    /// `weights` is flattened row by row into one contiguous vector. The width of the
    /// first row (0 when `weights` is empty) is used as the key length for every row, so
    /// all rows are expected to share that width. `bias` is copied unchanged.
    ///
    /// Shapes are not validated here; the linear-transform stage checks that enough
    /// weight and bias values are available before it touches any memory.
    pub fn new(
        query_ptr: usize,
        key_ptr: usize,
        output_ptr: usize,
        data_size: usize,
        weights: &[Vec<f32>],
        bias: &[f32],
    ) -> Self {
        let key_size = weights.first().map_or(0, |row| row.len());
        let mut flattened_weights = Vec::with_capacity(weights.len() * key_size);
        for row in weights {
            flattened_weights.extend_from_slice(row);
        }
        Self {
            query_ptr,
            key_ptr,
            output_ptr,
            data_size,
            key_size,
            weights: flattened_weights,
            bias: bias.to_vec(),
            kernel: DotKernel::detect(),
        }
    }

    /// Writes `dot(weights[row], key) + bias[row]` to output slot `row` for every `row`
    /// in `0..data_size`.
    ///
    /// Rows are processed in order and each result is stored immediately, so if the
    /// output region overlaps the key region, later rows read already-overwritten values.
    ///
    /// # Panics
    ///
    /// Panics if the key range (`key_ptr..key_ptr + key_size`) or the output range
    /// (`output_ptr..output_ptr + data_size`) does not fit inside `buffer`, or if
    /// `weights` / `bias` hold fewer values than `data_size` rows require. These checks
    /// run once per call and are what make the raw-pointer accesses below sound.
    #[inline(always)]
    fn apply_linear_transform(&self, buffer: &mut [f32]) {
        let rows = self.data_size;
        let width = self.key_size;
        let len = buffer.len();

        // `checked_*` keeps an overflowing offset from wrapping around and passing the check.
        assert!(
            self.key_ptr
                .checked_add(width)
                .map_or(false, |end| end <= len),
            "key range exceeds the computation buffer"
        );
        assert!(
            self.output_ptr
                .checked_add(rows)
                .map_or(false, |end| end <= len),
            "output range exceeds the computation buffer"
        );
        assert!(
            rows.checked_mul(width)
                .map_or(false, |needed| needed <= self.weights.len()),
            "weights hold fewer than data_size * key_size values"
        );
        assert!(
            rows <= self.bias.len(),
            "bias holds fewer than data_size values"
        );

        // Both the read and the write pointer are derived from one mutable base pointer,
        // so the read pointer is not invalidated by a later mutable reborrow of `buffer`.
        let base = buffer.as_mut_ptr();
        // SAFETY: `key_ptr + width <= len`, so the offset is in bounds (or one past the
        // end when `width == 0`).
        let key_ptr: *const f32 = unsafe { base.add(self.key_ptr) };
        // SAFETY: `output_ptr + rows <= len`, so the offset is in bounds (or one past the
        // end when `rows == 0`).
        let output_ptr = unsafe { base.add(self.output_ptr) };
        let weights_ptr = self.weights.as_ptr();
        let kernel = self.kernel;

        for (row, &row_bias) in self.bias[..rows].iter().enumerate() {
            // SAFETY: `row < rows` and `rows * width <= self.weights.len()`, so the row
            // start and the `width` values after it lie inside `self.weights`.
            let row_weights_ptr = unsafe { weights_ptr.add(row * width) };
            let acc = dot(kernel, row_weights_ptr, key_ptr, width) + row_bias;
            // SAFETY: `row < rows` and `output_ptr + rows <= len`, so the slot is in bounds.
            unsafe { *output_ptr.add(row) = acc };
        }
    }

    /// Replaces the `data_size` output values with their softmax.
    ///
    /// The maximum is subtracted before exponentiating so that large inputs do not
    /// overflow in `exp`. Panics (via slice indexing) if the output range is out of bounds.
    #[inline(always)]
    fn apply_softmax(&self, buffer: &mut [f32]) {
        let output_start = self.output_ptr;
        let output_slice = &mut buffer[output_start..output_start + self.data_size];
        let max_val = output_slice
            .iter()
            .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let mut sum = 0.0f32;
        for val in output_slice.iter_mut() {
            *val = (*val - max_val).exp();
            sum += *val;
        }
        for val in output_slice.iter_mut() {
            *val /= sum;
        }
    }

    /// Multiplies each output value by the query value at the same index.
    ///
    /// Indexing is kept element by element because the query and output regions live in
    /// the same buffer and may overlap. Panics (via indexing) on out-of-bounds offsets.
    #[inline(always)]
    fn apply_elementwise_multiply(&self, buffer: &mut [f32]) {
        let query_start = self.query_ptr;
        let output_start = self.output_ptr;
        for i in 0..self.data_size {
            buffer[output_start + i] *= buffer[query_start + i];
        }
    }
}
impl Instruction for AttentionInstruction {
    /// Returns the buffer offset at which this instruction stores its results.
    fn output_ptr(&self) -> usize {
        self.output_ptr
    }

    /// Returns the number of output elements this instruction produces.
    fn data_size(&self) -> usize {
        self.data_size
    }

    /// Runs the instruction in place on `unified_computation_buffer`.
    ///
    /// The stages execute strictly in this order, each one consuming the output of the
    /// previous one: linear transform, softmax, element-wise multiplication.
    /// The current implementation has no error path and always returns `Ok(())`.
    fn apply(&self, unified_computation_buffer: &mut [f32]) -> Result<(), InstructionModelError> {
        let buffer = unified_computation_buffer;

        self.apply_linear_transform(buffer);
        self.apply_softmax(buffer);
        self.apply_elementwise_multiply(buffer);

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum absolute difference tolerated when comparing `f32` results.
    const DELTA: f32 = 1e-5;

    /// Asserts that `actual` lies within `DELTA` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < DELTA,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Applies attention with a 2x2 identity weight matrix to a six-element buffer laid
    /// out as `[query (2), key (2), output (2)]` and returns the resulting buffer.
    fn run_identity_attention(mut buffer: Vec<f32>, bias: [f32; 2]) -> Vec<f32> {
        let weights = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let instruction = AttentionInstruction::new(0, 2, 4, 2, &weights, &bias);
        instruction.apply(&mut buffer).unwrap();
        buffer
    }

    #[test]
    fn test_attention_basic() {
        // Equal keys give a uniform softmax of 0.5, which is then scaled by the query [1, 2].
        let result = run_identity_attention(vec![1.0, 2.0, 0.5, 0.5, 0.0, 0.0], [0.0, 0.0]);
        assert_close(result[4], 0.5);
        assert_close(result[5], 1.0);
    }

    #[test]
    fn test_attention_with_bias() {
        // Pre-softmax values are [2, 0]; after subtracting the maximum they become [0, -2].
        let result = run_identity_attention(vec![1.0, 1.0, 1.0, 1.0, 0.0, 0.0], [1.0, -1.0]);
        let exp_0 = 1.0f32;
        let exp_neg2 = (-2.0f32).exp();
        let sum = exp_0 + exp_neg2;
        assert_close(result[4], exp_0 / sum);
        assert_close(result[5], exp_neg2 / sum);
    }

    #[test]
    fn test_attention_softmax_numerical_stability() {
        // exp(100) overflows f32, so this only passes when the maximum is subtracted first.
        let result = run_identity_attention(vec![1.0, 1.0, 100.0, 100.0, 0.0, 0.0], [0.0, 0.0]);
        assert_close(result[4], 0.5);
        assert_close(result[5], 0.5);
    }
}