use crate::errors::InstructionModelError;
use crate::instructions::Instruction;
/// Scales a data region by per-head factors, with data, factors and output all
/// addressed by offsets into one flat `f32` buffer.
///
/// The data region of `data_size` elements is treated as `num_heads`
/// consecutive chunks of `head_dim` elements. Chunk `h` is multiplied by the
/// factor at `heads_ptr + h`, and each product is written to the same relative
/// position starting at `output_ptr`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiplyBufferHeadsInstruction {
    /// Offset of the first data element.
    data_ptr: usize,
    /// Offset of the first per-head scale factor (`num_heads` factors follow).
    heads_ptr: usize,
    /// Offset of the first output element.
    output_ptr: usize,
    /// Total number of data elements (and output elements written).
    data_size: usize,
    /// Number of heads, i.e. chunks the data region is split into.
    num_heads: usize,
    /// Elements per head; `num_heads * head_dim == data_size` when built via `new`.
    head_dim: usize,
}
impl MultiplyBufferHeadsInstruction {
    /// Creates an instruction that multiplies `data_size` elements starting at
    /// `data_ptr` by `heads_size` per-head factors starting at `heads_ptr`,
    /// writing the products starting at `output_ptr`.
    ///
    /// The data region is split into `heads_size` equal, consecutive chunks of
    /// `data_size / heads_size` elements. Chunk `h` is scaled by the factor at
    /// `heads_ptr + h`. The offsets are not checked against any buffer here.
    ///
    /// # Errors
    ///
    /// Returns [`InstructionModelError::InvalidBufferHeadsSize`] if `heads_size`
    /// is zero or does not evenly divide `data_size`.
    pub fn new(
        data_ptr: usize,
        heads_ptr: usize,
        output_ptr: usize,
        data_size: usize,
        heads_size: usize,
    ) -> Result<Self, InstructionModelError> {
        // A zero head count must be rejected explicitly: `0.is_multiple_of(0)`
        // is `true`, so the divisibility check alone would let `(0, 0)` through
        // to a division by zero below.
        if heads_size == 0 || !data_size.is_multiple_of(heads_size) {
            return Err(InstructionModelError::InvalidBufferHeadsSize {
                data_size,
                heads_size,
            });
        }
        Ok(Self {
            data_ptr,
            heads_ptr,
            output_ptr,
            data_size,
            num_heads: heads_size,
            head_dim: data_size / heads_size,
        })
    }
}
impl Instruction for MultiplyBufferHeadsInstruction {
    /// Offset in the buffer where [`apply`](Self::apply) starts writing.
    fn output_ptr(&self) -> usize {
        self.output_ptr
    }

    /// Element count of the data region. For an instruction built with `new`,
    /// this is also the number of output elements written by `apply`.
    fn data_size(&self) -> usize {
        self.data_size
    }

    /// Writes `data[i] * heads[i / head_dim]` to `output[i]` for every element,
    /// where each region is addressed relative to its offset in `buffer`.
    ///
    /// Elements are processed in increasing index order. Each head's factor is
    /// read once, before any of that head's elements are written, so
    /// overlapping regions observe the results of earlier writes.
    ///
    /// # Panics
    ///
    /// Panics if any data, head or output index falls outside `buffer`. This
    /// method never returns `Err`.
    fn apply(&self, buffer: &mut [f32]) -> Result<(), InstructionModelError> {
        for head in 0..self.num_heads {
            let scale = buffer[self.heads_ptr + head];
            let chunk_start = head * self.head_dim;
            for offset in chunk_start..chunk_start + self.head_dim {
                buffer[self.output_ptr + offset] = buffer[self.data_ptr + offset] * scale;
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multiply_buffer_heads_basic() {
        // Layout: data [0, 8), heads [8, 10), output [10, 18).
        let mut buffer: [f32; 18] = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // data
            2.0, 3.0, // heads
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // output
        ];
        let instruction = MultiplyBufferHeadsInstruction::new(0, 8, 10, 8, 2).unwrap();
        assert_eq!(instruction.output_ptr(), 10);
        assert_eq!(instruction.data_size(), 8);
        instruction.apply(&mut buffer).unwrap();
        let expected: [f32; 8] = [2.0, 4.0, 6.0, 8.0, 15.0, 18.0, 21.0, 24.0];
        assert_eq!(&buffer[10..18], &expected[..]);
    }

    #[test]
    fn test_multiply_buffer_heads_four_heads() {
        // Layout: data [0, 20) = 1..=20, heads [20, 24), output [24, 44).
        let mut buffer: Vec<f32> = (1u8..=20).map(f32::from).collect();
        buffer.extend([2.0_f32, 3.0, 4.0, 5.0]);
        buffer.extend([0.0_f32; 20]);
        let instruction = MultiplyBufferHeadsInstruction::new(0, 20, 24, 20, 4).unwrap();
        instruction.apply(&mut buffer).unwrap();
        let expected: [f32; 20] = [
            2.0, 4.0, 6.0, 8.0, 10.0, // head 0 (x2)
            18.0, 21.0, 24.0, 27.0, 30.0, // head 1 (x3)
            44.0, 48.0, 52.0, 56.0, 60.0, // head 2 (x4)
            80.0, 85.0, 90.0, 95.0, 100.0, // head 3 (x5)
        ];
        assert_eq!(&buffer[24..44], &expected[..]);
    }

    #[test]
    fn test_multiply_buffer_heads_invalid_size() {
        let result = MultiplyBufferHeadsInstruction::new(0, 7, 10, 7, 3);
        assert!(
            matches!(
                result,
                Err(InstructionModelError::InvalidBufferHeadsSize {
                    data_size: 7,
                    heads_size: 3,
                })
            ),
            "Expected InvalidBufferHeadsSize error"
        );
    }

    #[test]
    fn test_multiply_buffer_heads_single_head() {
        // Layout: data [0, 4), head [4, 5), output [5, 9).
        let mut buffer: [f32; 9] = [
            1.0, 2.0, 3.0, 4.0, // data
            5.0, // head
            0.0, 0.0, 0.0, 0.0, // output
        ];
        let instruction = MultiplyBufferHeadsInstruction::new(0, 4, 5, 4, 1).unwrap();
        instruction.apply(&mut buffer).unwrap();
        let expected: [f32; 4] = [5.0, 10.0, 15.0, 20.0];
        assert_eq!(&buffer[5..9], &expected[..]);
    }

    #[test]
    fn test_multiply_buffer_heads_in_place() {
        // Output region coincides with the data region; heads stay untouched.
        let mut buffer: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 2.0, 3.0];
        let instruction = MultiplyBufferHeadsInstruction::new(0, 4, 0, 4, 2).unwrap();
        instruction.apply(&mut buffer).unwrap();
        let expected: [f32; 6] = [2.0, 4.0, 9.0, 12.0, 2.0, 3.0];
        assert_eq!(buffer, expected);
    }
}