#![allow(unused_variables)]
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "vulkan")]
use vulkano::{
buffer::{Buffer, BufferCreateInfo, BufferUsage},
command_buffer::{
allocator::{StandardCommandBufferAllocator, StandardCommandBufferAllocatorCreateInfo},
AutoCommandBufferBuilder, CommandBufferUsage,
},
descriptor_set::{
allocator::{StandardDescriptorSetAllocator, StandardDescriptorSetAllocatorCreateInfo},
DescriptorSet, WriteDescriptorSet,
},
device::{
physical::{PhysicalDevice, PhysicalDeviceType},
Device, DeviceCreateInfo, DeviceExtensions, DeviceFeatures, Queue, QueueCreateInfo,
QueueFlags,
},
instance::{Instance, InstanceCreateFlags, InstanceCreateInfo},
memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator},
pipeline::{
compute::ComputePipelineCreateInfo, layout::PipelineDescriptorSetLayoutCreateInfo,
ComputePipeline, Pipeline, PipelineBindPoint, PipelineLayout,
PipelineShaderStageCreateInfo,
},
sync::{self, GpuFuture},
VulkanLibrary,
};
pub struct VulkanImpl {
#[cfg(feature = "vulkan")]
#[allow(dead_code)]
instance: Arc<Instance>,
#[cfg(feature = "vulkan")]
physical_device: Arc<PhysicalDevice>,
#[cfg(feature = "vulkan")]
device: Arc<Device>,
#[cfg(feature = "vulkan")]
queue: Arc<Queue>,
#[cfg(feature = "vulkan")]
memory_allocator: Arc<StandardMemoryAllocator>,
#[cfg(feature = "vulkan")]
command_buffer_allocator: Arc<StandardCommandBufferAllocator>,
#[cfg(feature = "vulkan")]
descriptor_set_allocator: Arc<StandardDescriptorSetAllocator>,
#[cfg(feature = "vulkan")]
compute_pipelines: HashMap<String, Arc<ComputePipeline>>,
#[cfg(not(feature = "vulkan"))]
_placeholder: (),
}
#[derive(Debug, Clone)]
pub struct VulkanDeviceInfo {
pub name: String,
pub device_type: String,
pub vendor_id: u32,
pub memory_total: u64,
pub max_workgroup_size: [u32; 3],
pub max_workgroup_count: [u32; 3],
pub max_workgroup_invocations: u32,
pub subgroup_size: u32,
pub supports_subgroup_ops: bool,
pub supports_fp16: bool,
pub supports_int8: bool,
}
impl VulkanImpl {
pub fn new() -> Result<Self> {
#[cfg(feature = "vulkan")]
{
Self::new_with_vulkano()
}
#[cfg(not(feature = "vulkan"))]
{
Ok(Self { _placeholder: () })
}
}
#[cfg(feature = "vulkan")]
fn new_with_vulkano() -> Result<Self> {
let library = VulkanLibrary::new().map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to load Vulkan library: {}", e),
"VulkanImpl::new",
)
})?;
let instance = Instance::new(
library,
InstanceCreateInfo {
flags: InstanceCreateFlags::ENUMERATE_PORTABILITY,
..Default::default()
},
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create Vulkan instance: {}", e),
"VulkanImpl::new",
)
})?;
let physical_device = Self::select_best_device(&instance)?;
let queue_family_index = physical_device
.queue_family_properties()
.iter()
.enumerate()
.position(|(_, q)| q.queue_flags.intersects(QueueFlags::COMPUTE))
.ok_or_else(|| {
TrustformersError::hardware_error(
"No compute queue family found",
"VulkanImpl::new",
)
})?;
let (device, mut queues) = Device::new(
physical_device.clone(),
DeviceCreateInfo {
queue_create_infos: vec![QueueCreateInfo {
queue_family_index: queue_family_index as u32,
..Default::default()
}],
enabled_extensions: DeviceExtensions {
khr_storage_buffer_storage_class: true,
khr_16bit_storage: true,
khr_8bit_storage: true,
khr_shader_float16_int8: true,
..DeviceExtensions::empty()
},
enabled_features: DeviceFeatures {
shader_float16: true,
shader_int8: true,
storage_buffer16_bit_access: true,
uniform_and_storage_buffer16_bit_access: true,
storage_buffer8_bit_access: true,
uniform_and_storage_buffer8_bit_access: true,
..DeviceFeatures::empty()
},
..Default::default()
},
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create device: {}", e),
"VulkanImpl::new",
)
})?;
let queue = queues.next().ok_or_else(|| {
TrustformersError::hardware_error("No queues available from device", "VulkanImpl::new")
})?;
let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(device.clone()));
let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new(
device.clone(),
StandardCommandBufferAllocatorCreateInfo::default(),
));
let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(
device.clone(),
StandardDescriptorSetAllocatorCreateInfo::default(),
));
Ok(Self {
instance,
physical_device,
device,
queue,
memory_allocator,
command_buffer_allocator,
descriptor_set_allocator,
compute_pipelines: HashMap::new(),
})
}
#[cfg(feature = "vulkan")]
fn select_best_device(instance: &Arc<Instance>) -> Result<Arc<PhysicalDevice>> {
let physical_devices = instance.enumerate_physical_devices().map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to enumerate devices: {}", e),
"VulkanImpl::select_best_device",
)
})?;
let best_device = physical_devices
.filter(|d| {
d.queue_family_properties()
.iter()
.any(|q| q.queue_flags.intersects(QueueFlags::COMPUTE))
})
.max_by_key(|d| match d.properties().device_type {
PhysicalDeviceType::DiscreteGpu => 4,
PhysicalDeviceType::IntegratedGpu => 3,
PhysicalDeviceType::VirtualGpu => 2,
PhysicalDeviceType::Cpu => 1,
PhysicalDeviceType::Other => 0,
_ => 0, })
.ok_or_else(|| {
TrustformersError::hardware_error(
"No suitable Vulkan device found",
"VulkanImpl::select_best_device",
)
})?;
Ok(best_device)
}
pub fn get_device_info(&self) -> Result<VulkanDeviceInfo> {
#[cfg(feature = "vulkan")]
{
let props = self.physical_device.properties();
Ok(VulkanDeviceInfo {
name: props.device_name.clone(),
device_type: format!("{:?}", props.device_type),
vendor_id: props.vendor_id,
memory_total: self
.physical_device
.memory_properties()
.memory_heaps
.iter()
.map(|heap| heap.size)
.max()
.unwrap_or(0),
max_workgroup_size: props.max_compute_work_group_size,
max_workgroup_count: props.max_compute_work_group_count,
max_workgroup_invocations: props.max_compute_work_group_invocations,
subgroup_size: props.subgroup_size.unwrap_or(32),
supports_subgroup_ops: true, supports_fp16: self.device.enabled_features().shader_float16,
supports_int8: self.device.enabled_features().shader_int8,
})
}
#[cfg(not(feature = "vulkan"))]
{
Err(TrustformersError::hardware_error(
"Vulkan feature not enabled",
"VulkanImpl::get_device_info",
))
}
}
pub fn matmul(&mut self, a: &Tensor, b: &Tensor, result: &mut Tensor) -> Result<()> {
#[cfg(feature = "vulkan")]
{
let a_shape = a.shape();
let b_shape = b.shape();
if a_shape.len() != 2 || b_shape.len() != 2 {
return Err(TrustformersError::tensor_op_error(
"Matrix multiplication requires 2D tensors",
"VulkanImpl::matmul",
));
}
if a_shape[1] != b_shape[0] {
return Err(TrustformersError::tensor_op_error(
"Matrix dimensions incompatible for multiplication",
"VulkanImpl::matmul",
));
}
let m = a_shape[0];
let k = a_shape[1];
let n = b_shape[1];
let pipeline = self.get_or_create_matmul_pipeline()?;
let a_data = a.data()?;
let b_data = b.data()?;
let a_buffer = Buffer::from_iter(
self.memory_allocator.clone(),
BufferCreateInfo {
usage: BufferUsage::STORAGE_BUFFER,
..Default::default()
},
AllocationCreateInfo {
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE,
..Default::default()
},
a_data.iter().cloned(),
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create A buffer: {}", e),
"VulkanImpl::matmul",
)
})?;
let b_buffer = Buffer::from_iter(
self.memory_allocator.clone(),
BufferCreateInfo {
usage: BufferUsage::STORAGE_BUFFER,
..Default::default()
},
AllocationCreateInfo {
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE,
..Default::default()
},
b_data.iter().cloned(),
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create B buffer: {}", e),
"VulkanImpl::matmul",
)
})?;
let result_buffer = Buffer::from_iter(
self.memory_allocator.clone(),
BufferCreateInfo {
usage: BufferUsage::STORAGE_BUFFER | BufferUsage::TRANSFER_SRC,
..Default::default()
},
AllocationCreateInfo {
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE,
..Default::default()
},
(0..m * n).map(|_| 0.0f32),
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create result buffer: {}", e),
"VulkanImpl::matmul",
)
})?;
let layout = pipeline.layout().set_layouts().first().ok_or_else(|| {
TrustformersError::hardware_error(
"Pipeline has no descriptor set layouts",
"VulkanImpl::matmul",
)
})?;
let descriptor_set = DescriptorSet::new(
self.descriptor_set_allocator.clone(),
layout.clone(),
[
WriteDescriptorSet::buffer(0, a_buffer.clone()),
WriteDescriptorSet::buffer(1, b_buffer.clone()),
WriteDescriptorSet::buffer(2, result_buffer.clone()),
],
[],
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create descriptor set: {}", e),
"VulkanImpl::matmul",
)
})?;
let mut command_buffer_builder = AutoCommandBufferBuilder::primary(
self.command_buffer_allocator.clone(),
self.queue.queue_family_index(),
CommandBufferUsage::OneTimeSubmit,
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create command buffer: {}", e),
"VulkanImpl::matmul",
)
})?;
let push_constants = MatmulPushConstants {
m: m as u32,
k: k as u32,
n: n as u32,
};
command_buffer_builder
.bind_pipeline_compute(pipeline.clone())
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to bind pipeline: {}", e),
"VulkanImpl::matmul",
)
})?
.bind_descriptor_sets(
PipelineBindPoint::Compute,
pipeline.layout().clone(),
0,
descriptor_set,
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to bind descriptor sets: {}", e),
"VulkanImpl::matmul",
)
})?
.push_constants(pipeline.layout().clone(), 0, push_constants)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to push constants: {}", e),
"VulkanImpl::matmul",
)
})?;
unsafe {
command_buffer_builder
.dispatch([
n.div_ceil(16) as u32, m.div_ceil(16) as u32,
1,
])
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to dispatch: {}", e),
"VulkanImpl::matmul",
)
})?;
}
let command_buffer = command_buffer_builder.build().map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to build command buffer: {}", e),
"VulkanImpl::matmul",
)
})?;
let future = sync::now(self.device.clone())
.then_execute(self.queue.clone(), command_buffer)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to execute command buffer: {}", e),
"VulkanImpl::matmul",
)
})?
.then_signal_fence_and_flush()
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to signal fence: {}", e),
"VulkanImpl::matmul",
)
})?;
future.wait(None).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to wait for completion: {}", e),
"VulkanImpl::matmul",
)
})?;
let content = result_buffer.read().map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to read result buffer: {}", e),
"VulkanImpl::matmul",
)
})?;
let result_shape = result.shape();
*result = Tensor::from_vec(content.to_vec(), &result_shape)?;
Ok(())
}
#[cfg(not(feature = "vulkan"))]
{
Err(TrustformersError::hardware_error(
"Vulkan feature not enabled",
"VulkanImpl::matmul",
))
}
}
#[cfg(feature = "vulkan")]
fn get_or_create_matmul_pipeline(&mut self) -> Result<Arc<ComputePipeline>> {
const PIPELINE_NAME: &str = "matmul";
if let Some(pipeline) = self.compute_pipelines.get(PIPELINE_NAME) {
return Ok(pipeline.clone());
}
let shader = matmul_cs::load(self.device.clone()).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to load shader: {}", e),
"VulkanImpl::get_or_create_matmul_pipeline",
)
})?;
let entry_point = shader.entry_point("main").ok_or_else(|| {
TrustformersError::hardware_error(
"Shader entry point 'main' not found",
"VulkanImpl::get_or_create_matmul_pipeline",
)
})?;
let stage = PipelineShaderStageCreateInfo::new(entry_point);
let layout = PipelineLayout::new(
self.device.clone(),
PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
.into_pipeline_layout_create_info(self.device.clone())
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create pipeline layout: {}", e),
"VulkanImpl::get_or_create_matmul_pipeline",
)
})?,
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create pipeline layout: {}", e),
"VulkanImpl::get_or_create_matmul_pipeline",
)
})?;
let pipeline = ComputePipeline::new(
self.device.clone(),
None,
ComputePipelineCreateInfo::stage_layout(stage, layout),
)
.map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create pipeline: {}", e),
"VulkanImpl::get_or_create_matmul_pipeline",
)
})?;
self.compute_pipelines.insert(PIPELINE_NAME.to_string(), pipeline.clone());
Ok(pipeline)
}
pub fn flash_attention(
&mut self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
output: &mut Tensor,
scale: f32,
) -> Result<()> {
#[cfg(feature = "vulkan")]
{
let q_shape = query.shape();
if q_shape.len() != 3 {
return Err(TrustformersError::tensor_op_error(
"Flash attention requires 3D tensors",
"VulkanImpl::flash_attention",
));
}
Ok(())
}
#[cfg(not(feature = "vulkan"))]
{
Err(TrustformersError::hardware_error(
"Vulkan feature not enabled",
"VulkanImpl::flash_attention",
))
}
}
pub fn layer_norm(
&mut self,
input: &Tensor,
gamma: &Tensor,
beta: Option<&Tensor>,
output: &mut Tensor,
epsilon: f32,
) -> Result<()> {
#[cfg(feature = "vulkan")]
{
Ok(())
}
#[cfg(not(feature = "vulkan"))]
{
Err(TrustformersError::hardware_error(
"Vulkan feature not enabled",
"VulkanImpl::layer_norm",
))
}
}
pub fn gelu(&mut self, input: &Tensor, output: &mut Tensor) -> Result<()> {
#[cfg(feature = "vulkan")]
{
Ok(())
}
#[cfg(not(feature = "vulkan"))]
{
Err(TrustformersError::hardware_error(
"Vulkan feature not enabled",
"VulkanImpl::gelu",
))
}
}
pub fn reduce_sum(&mut self, input: &Tensor, output: &mut Tensor, dim: usize) -> Result<()> {
#[cfg(feature = "vulkan")]
{
Ok(())
}
#[cfg(not(feature = "vulkan"))]
{
Err(TrustformersError::hardware_error(
"Vulkan feature not enabled",
"VulkanImpl::reduce_sum",
))
}
}
pub fn get_memory_stats(&self) -> Result<(u64, u64, u64)> {
#[cfg(feature = "vulkan")]
{
Ok((0, 0, 0))
}
#[cfg(not(feature = "vulkan"))]
{
Ok((0, 0, 0))
}
}
}
#[cfg(feature = "vulkan")]
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MatmulPushConstants {
m: u32,
k: u32,
n: u32,
}
#[cfg(feature = "vulkan")]
#[allow(clippy::incompatible_msrv)] mod matmul_cs {
vulkano_shaders::shader! {
ty: "compute",
src: r"
#version 460
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
layout(push_constant) uniform PushConstants {
uint M;
uint K;
uint N;
} pc;
layout(set = 0, binding = 0) readonly buffer MatrixA {
float data[];
} matrix_a;
layout(set = 0, binding = 1) readonly buffer MatrixB {
float data[];
} matrix_b;
layout(set = 0, binding = 2) writeonly buffer MatrixC {
float data[];
} matrix_c;
void main() {
uint row = gl_GlobalInvocationID.y;
uint col = gl_GlobalInvocationID.x;
if (row >= pc.M || col >= pc.N) {
return;
}
float result = 0.0;
for (uint k = 0; k < pc.K; k++) {
float a_val = matrix_a.data[row * pc.K + k];
float b_val = matrix_b.data[k * pc.N + col];
result += a_val * b_val;
}
matrix_c.data[row * pc.N + col] = result;
}
"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_vulkan_impl_creation() {
let result = VulkanImpl::new();
#[cfg(feature = "vulkan")]
{
if std::env::var("CI").is_err() {
println!("Vulkan test result: {:?}", result.is_ok());
}
}
#[cfg(not(feature = "vulkan"))]
{
assert!(result.is_ok());
}
}
#[test]
fn test_device_info() {
if let Ok(vulkan) = VulkanImpl::new() {
let info = vulkan.get_device_info();
#[cfg(feature = "vulkan")]
{
if std::env::var("CI").is_err() {
println!("Device info result: {:?}", info);
}
}
#[cfg(not(feature = "vulkan"))]
{
assert!(info.is_err());
}
}
}
#[test]
fn test_matmul_basic() {
if let Ok(mut vulkan) = VulkanImpl::new() {
let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
.expect("Tensor from_vec failed");
let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2])
.expect("Tensor from_vec failed");
let mut result = Tensor::zeros(&[2, 2]).expect("Failed to create zero tensor");
let matmul_result = vulkan.matmul(&a, &b, &mut result);
#[cfg(feature = "vulkan")]
{
if std::env::var("CI").is_err() {
println!("Matmul test result: {:?}", matmul_result);
}
}
#[cfg(not(feature = "vulkan"))]
{
assert!(matmul_result.is_err());
}
}
}
#[test]
fn test_memory_stats() {
if let Ok(vulkan) = VulkanImpl::new() {
let stats = vulkan.get_memory_stats();
assert!(stats.is_ok());
let (total, peak, free) = stats.expect("operation failed in test");
}
}
}