use crate::device::context::GpuContextInfo;
use crate::{Result, Tensor, TensorError};
use scirs2_core::numeric::Float;
use std::sync::Arc;
use wgpu::util::DeviceExt;
/// GPU attention operations backed by wgpu compute pipelines.
///
/// Holds shared handles to the wgpu device and queue, plus the compute
/// pipelines that `GpuAttentionOps::new` builds from the
/// `scaled_dot_product_attention_kernel`, `flash_attention_kernel` and
/// `multi_head_attention_kernel` entry points of `shaders/attention_ops.wgsl`.
pub struct GpuAttentionOps {
/// Device used to allocate buffers, bind groups and command encoders.
device: Arc<wgpu::Device>,
/// Queue to which the recorded command buffers are submitted.
queue: Arc<wgpu::Queue>,
/// Pipeline dispatched by `scaled_dot_product_attention`; the flash and
/// multi-head entry points currently delegate to that method, so they use it too.
scaled_dot_product_attention_pipeline: wgpu::ComputePipeline,
/// Built in `new`, but not dispatched by any method in the visible impl.
flash_attention_pipeline: wgpu::ComputePipeline,
/// Built in `new`, but not dispatched by any method in the visible impl.
multi_head_attention_pipeline: wgpu::ComputePipeline,
}
impl GpuAttentionOps {
pub fn new(gpu_context_info: &GpuContextInfo) -> Result<Self> {
let device = gpu_context_info.device.clone();
let queue = gpu_context_info.queue.clone();
let shader_source = include_str!("shaders/attention_ops.wgsl");
let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Attention Operations Shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let attention_bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("Attention Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 5,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Attention Pipeline Layout"),
bind_group_layouts: &[Some(&attention_bind_group_layout)],
immediate_size: 0,
});
let scaled_dot_product_attention_pipeline =
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Scaled Dot-Product Attention Pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some("scaled_dot_product_attention_kernel"),
cache: None,
compilation_options: Default::default(),
});
let flash_attention_pipeline =
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Flash Attention Pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some("flash_attention_kernel"),
cache: None,
compilation_options: Default::default(),
});
let multi_head_attention_pipeline =
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Multi-Head Attention Pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some("multi_head_attention_kernel"),
cache: None,
compilation_options: Default::default(),
});
Ok(Self {
device,
queue,
scaled_dot_product_attention_pipeline,
flash_attention_pipeline,
multi_head_attention_pipeline,
})
}
pub fn scaled_dot_product_attention<T>(
&self,
query: &Tensor<T>,
key: &Tensor<T>,
value: &Tensor<T>,
mask: Option<&Tensor<T>>,
scale: Option<T>,
) -> Result<Tensor<T>>
where
T: Float
+ Default
+ bytemuck::Pod
+ bytemuck::Zeroable
+ Send
+ Sync
+ 'static
+ Clone
+ scirs2_core::num_traits::Zero
+ scirs2_core::num_traits::One,
{
let query_shape = query.shape();
let seq_len = query_shape.dims()[0];
let head_dim = query_shape.dims()[1];
let scale_factor = scale.unwrap_or_else(|| {
T::from(head_dim as f64)
.expect("fallback value computation failed")
.sqrt()
.recip()
});
let query_buffer = self.create_buffer_from_tensor(query, "Query Buffer")?;
let key_buffer = self.create_buffer_from_tensor(key, "Key Buffer")?;
let value_buffer = self.create_buffer_from_tensor(value, "Value Buffer")?;
let output_size = (seq_len * head_dim * std::mem::size_of::<T>()) as wgpu::BufferAddress;
let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Attention Output Buffer"),
size: output_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let mask_buffer = if let Some(mask) = mask {
self.create_buffer_from_tensor(mask, "Mask Buffer")?
} else {
self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Empty Mask Buffer"),
size: 4, usage: wgpu::BufferUsages::STORAGE,
mapped_at_creation: false,
})
};
let params = [
seq_len as u32,
head_dim as u32,
scale_factor
.to_f32()
.expect("numeric conversion should succeed")
.to_bits(),
if mask.is_some() { 1u32 } else { 0u32 },
];
let params_buffer = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Attention Parameters Buffer"),
contents: bytemuck::cast_slice(¶ms),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Attention Bind Group"),
layout: &self
.scaled_dot_product_attention_pipeline
.get_bind_group_layout(0),
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: query_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: key_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: value_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: mask_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: output_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: params_buffer.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Attention Command Encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Attention Compute Pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&self.scaled_dot_product_attention_pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
let workgroup_size = 256;
let num_workgroups = ((seq_len as u32) + workgroup_size - 1) / workgroup_size;
compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
}
self.queue.submit(Some(encoder.finish()));
let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Attention Staging Buffer"),
size: output_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Attention Copy Encoder"),
});
encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size);
self.queue.submit(Some(encoder.finish()));
let buffer_slice = staging_buffer.slice(..);
buffer_slice.map_async(wgpu::MapMode::Read, |v| {
if let Err(e) = v {
eprintln!("Buffer mapping failed: {:?}", e);
}
});
self.device.poll(wgpu::PollType::wait_indefinitely()).ok();
let data = buffer_slice.get_mapped_range();
let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
drop(data);
staging_buffer.unmap();
Tensor::from_vec(result, query_shape.dims())
}
pub fn multi_head_attention<T>(
&self,
query: &Tensor<T>,
key: &Tensor<T>,
value: &Tensor<T>,
num_heads: usize,
mask: Option<&Tensor<T>>,
) -> Result<Tensor<T>>
where
T: Float + Default + bytemuck::Pod + bytemuck::Zeroable + Send + Sync + 'static,
{
let query_shape = query.shape();
let seq_len = query_shape.dims()[0];
let embed_dim = query_shape.dims()[1];
let head_dim = embed_dim / num_heads;
if embed_dim % num_heads != 0 {
return Err(TensorError::invalid_argument(format!(
"Embedding dimension {} is not divisible by number of heads {}",
embed_dim, num_heads
)));
}
self.scaled_dot_product_attention(query, key, value, mask, None)
}
pub fn flash_attention<T>(
&self,
query: &Tensor<T>,
key: &Tensor<T>,
value: &Tensor<T>,
mask: Option<&Tensor<T>>,
block_size: Option<usize>,
) -> Result<Tensor<T>>
where
T: Float + Default + bytemuck::Pod + bytemuck::Zeroable + Send + Sync + 'static,
{
let _block_size = block_size.unwrap_or(256);
self.scaled_dot_product_attention(query, key, value, mask, None)
}
fn create_buffer_from_tensor<
T: bytemuck::Pod
+ bytemuck::Zeroable
+ Clone
+ Default
+ Send
+ Sync
+ 'static
+ scirs2_core::num_traits::Zero
+ scirs2_core::num_traits::One,
>(
&self,
tensor: &Tensor<T>,
label: &str,
) -> Result<wgpu::Buffer> {
let data = tensor.to_vec()?;
let buffer = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some(label),
contents: bytemuck::cast_slice(&data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
});
Ok(buffer)
}
}