use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
#[cfg(feature = "gpu")]
use scirs2_core::gpu::{GpuDevice, GpuError};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VulkanOptimizationLevel {
Basic,
ComputeShader,
Subgroup,
Maximum,
}
#[derive(Debug, Clone)]
pub struct VulkanDeviceInfo {
pub device_name: String,
pub vendor_id: u32,
pub device_type: VulkanDeviceType,
pub max_compute_shared_memory_size: usize,
pub max_compute_work_group_count: [u32; 3],
pub max_compute_work_group_invocations: u32,
pub max_compute_work_group_size: [u32; 3],
pub subgroup_size: u32,
pub supports_subgroups: bool,
pub supports_int8: bool,
pub supports_int16: bool,
pub supports_float64: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VulkanDeviceType {
Other,
IntegratedGpu,
DiscreteGpu,
VirtualGpu,
Cpu,
}
impl VulkanDeviceInfo {
pub fn detect() -> Self {
Self {
device_name: "Default Vulkan Device".to_string(),
vendor_id: 0,
device_type: VulkanDeviceType::DiscreteGpu,
max_compute_shared_memory_size: 32768, max_compute_work_group_count: [65535, 65535, 65535],
max_compute_work_group_invocations: 1024,
max_compute_work_group_size: [1024, 1024, 64],
subgroup_size: 32,
supports_subgroups: true,
supports_int8: true,
supports_int16: true,
supports_float64: true,
}
}
pub fn is_nvidia(&self) -> bool {
self.vendor_id == 0x10DE
}
pub fn is_amd(&self) -> bool {
self.vendor_id == 0x1002
}
pub fn is_intel(&self) -> bool {
self.vendor_id == 0x8086
}
pub fn optimal_workgroup_size(&self) -> usize {
if self.supports_subgroups {
self.subgroup_size as usize
} else {
64 }
}
}
#[derive(Debug)]
pub struct VulkanMemoryManager {
allocated_buffers: HashMap<String, usize>,
total_allocated: usize,
peak_usage: usize,
}
impl VulkanMemoryManager {
pub fn new() -> Self {
Self {
allocated_buffers: HashMap::new(),
total_allocated: 0,
peak_usage: 0,
}
}
pub fn allocate(&mut self, id: String, size: usize) -> SparseResult<()> {
self.allocated_buffers.insert(id, size);
self.total_allocated += size;
if self.total_allocated > self.peak_usage {
self.peak_usage = self.total_allocated;
}
Ok(())
}
pub fn deallocate(&mut self, id: &str) -> SparseResult<()> {
if let Some(size) = self.allocated_buffers.remove(id) {
self.total_allocated = self.total_allocated.saturating_sub(size);
}
Ok(())
}
pub fn current_usage(&self) -> usize {
self.total_allocated
}
pub fn peak_usage(&self) -> usize {
self.peak_usage
}
pub fn reset(&mut self) {
self.allocated_buffers.clear();
self.total_allocated = 0;
self.peak_usage = 0;
}
}
impl Default for VulkanMemoryManager {
fn default() -> Self {
Self::new()
}
}
pub struct VulkanSpMatVec {
device_info: VulkanDeviceInfo,
memory_manager: VulkanMemoryManager,
shader_cache: HashMap<String, Arc<Vec<u8>>>,
}
impl VulkanSpMatVec {
pub fn new() -> SparseResult<Self> {
let device_info = VulkanDeviceInfo::detect();
Ok(Self {
device_info,
memory_manager: VulkanMemoryManager::new(),
shader_cache: HashMap::new(),
})
}
pub fn device_info(&self) -> &VulkanDeviceInfo {
&self.device_info
}
pub fn memory_manager(&self) -> &VulkanMemoryManager {
&self.memory_manager
}
pub fn memory_manager_mut(&mut self) -> &mut VulkanMemoryManager {
&mut self.memory_manager
}
#[cfg(feature = "gpu")]
pub fn execute_spmv<T>(
&self,
matrix: &CsrArray<T>,
vector: &ArrayView1<T>,
device: &GpuDevice,
) -> SparseResult<Array1<T>>
where
T: Float + SparseElement + Debug + Copy + std::iter::Sum,
{
self.execute_optimized_spmv(
matrix,
vector,
device,
VulkanOptimizationLevel::ComputeShader,
)
}
#[cfg(feature = "gpu")]
pub fn execute_optimized_spmv<T>(
&self,
matrix: &CsrArray<T>,
vector: &ArrayView1<T>,
device: &GpuDevice,
optimization_level: VulkanOptimizationLevel,
) -> SparseResult<Array1<T>>
where
T: Float + SparseElement + Debug + Copy + std::iter::Sum,
{
let (nrows, ncols) = matrix.shape();
if vector.len() != ncols {
return Err(SparseError::DimensionMismatch {
expected: ncols,
found: vector.len(),
});
}
matrix.dot_vector(vector)
}
pub fn execute_spmv_cpu<T>(
&self,
matrix: &CsrArray<T>,
vector: &ArrayView1<T>,
) -> SparseResult<Array1<T>>
where
T: Float + SparseElement + Debug + Copy + std::iter::Sum,
{
matrix.dot_vector(vector)
}
fn get_spmv_shader_source(&self, optimization_level: VulkanOptimizationLevel) -> &str {
match optimization_level {
VulkanOptimizationLevel::Basic => {
r#"
#version 450
layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) readonly buffer IndptrBuffer {
uint indptr[];
};
layout(set = 0, binding = 1) readonly buffer IndicesBuffer {
uint indices[];
};
layout(set = 0, binding = 2) readonly buffer DataBuffer {
float data[];
};
layout(set = 0, binding = 3) readonly buffer VectorBuffer {
float vector[];
};
layout(set = 0, binding = 4) writeonly buffer ResultBuffer {
float result[];
};
layout(push_constant) uniform PushConstants {
uint nrows;
} pc;
void main() {
uint row = gl_GlobalInvocationID.x;
if (row >= pc.nrows) {
return;
}
uint row_start = indptr[row];
uint row_end = indptr[row + 1];
float sum = 0.0;
for (uint i = row_start; i < row_end; i++) {
uint col = indices[i];
sum += data[i] * vector[col];
}
result[row] = sum;
}
"#
}
VulkanOptimizationLevel::ComputeShader => {
r#"
#version 450
layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) readonly buffer IndptrBuffer {
uint indptr[];
};
layout(set = 0, binding = 1) readonly buffer IndicesBuffer {
uint indices[];
};
layout(set = 0, binding = 2) readonly buffer DataBuffer {
float data[];
};
layout(set = 0, binding = 3) readonly buffer VectorBuffer {
float vector[];
};
layout(set = 0, binding = 4) writeonly buffer ResultBuffer {
float result[];
};
layout(push_constant) uniform PushConstants {
uint nrows;
} pc;
shared float shared_vector[256];
void main() {
uint row = gl_GlobalInvocationID.x;
uint local_id = gl_LocalInvocationID.x;
if (row >= pc.nrows) {
return;
}
uint row_start = indptr[row];
uint row_end = indptr[row + 1];
float sum = 0.0;
for (uint i = row_start; i < row_end; i++) {
uint col = indices[i];
// Cooperative loading to shared memory for better cache utilization
if (col < 256) {
shared_vector[col] = vector[col];
memoryBarrierShared();
barrier();
sum += data[i] * shared_vector[col];
} else {
sum += data[i] * vector[col];
}
}
result[row] = sum;
}
"#
}
VulkanOptimizationLevel::Subgroup => {
r#"
#version 450
#extension GL_KHR_shader_subgroup_arithmetic : enable
layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) readonly buffer IndptrBuffer {
uint indptr[];
};
layout(set = 0, binding = 1) readonly buffer IndicesBuffer {
uint indices[];
};
layout(set = 0, binding = 2) readonly buffer DataBuffer {
float data[];
};
layout(set = 0, binding = 3) readonly buffer VectorBuffer {
float vector[];
};
layout(set = 0, binding = 4) writeonly buffer ResultBuffer {
float result[];
};
layout(push_constant) uniform PushConstants {
uint nrows;
} pc;
void main() {
uint row = gl_GlobalInvocationID.x;
if (row >= pc.nrows) {
return;
}
uint row_start = indptr[row];
uint row_end = indptr[row + 1];
float sum = 0.0;
for (uint i = row_start; i < row_end; i++) {
uint col = indices[i];
sum += data[i] * vector[col];
}
// Use subgroup reduction for better performance
sum = subgroupAdd(sum);
if (subgroupElect()) {
result[row] = sum;
}
}
"#
}
VulkanOptimizationLevel::Maximum => {
self.get_spmv_shader_source(VulkanOptimizationLevel::Subgroup)
}
}
}
fn compile_shader(&mut self, source: &str, name: &str) -> SparseResult<Arc<Vec<u8>>> {
let bytecode = Arc::new(source.as_bytes().to_vec());
self.shader_cache.insert(name.to_string(), bytecode.clone());
Ok(bytecode)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_vulkan_device_info() {
let info = VulkanDeviceInfo::detect();
assert!(!info.device_name.is_empty());
assert!(info.optimal_workgroup_size() > 0);
}
#[test]
fn test_vulkan_memory_manager() {
let mut manager = VulkanMemoryManager::new();
manager
.allocate("buffer1".to_string(), 1024)
.expect("Failed to allocate");
assert_eq!(manager.current_usage(), 1024);
manager
.allocate("buffer2".to_string(), 2048)
.expect("Failed to allocate");
assert_eq!(manager.current_usage(), 3072);
assert_eq!(manager.peak_usage(), 3072);
manager.deallocate("buffer1").expect("Failed to deallocate");
assert_eq!(manager.current_usage(), 2048);
assert_eq!(manager.peak_usage(), 3072);
manager.reset();
assert_eq!(manager.current_usage(), 0);
}
#[test]
fn test_vulkan_spmv_creation() {
let result = VulkanSpMatVec::new();
assert!(result.is_ok());
let spmv = result.expect("Failed to create");
assert!(spmv.device_info().optimal_workgroup_size() > 0);
}
#[test]
fn test_vulkan_cpu_fallback() {
let spmv = VulkanSpMatVec::new().expect("Failed to create");
let rows = vec![0, 0, 1, 2];
let cols = vec![0, 1, 1, 2];
let data = vec![1.0, 2.0, 3.0, 4.0];
let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false)
.expect("Failed to create matrix");
let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let result = spmv
.execute_spmv_cpu(&matrix, &vector.view())
.expect("Failed to execute");
assert_eq!(result.len(), 3);
}
#[test]
fn test_shader_source_generation() {
let spmv = VulkanSpMatVec::new().expect("Failed to create");
let basic_shader = spmv.get_spmv_shader_source(VulkanOptimizationLevel::Basic);
assert!(basic_shader.contains("#version 450"));
assert!(basic_shader.contains("layout"));
let optimized_shader = spmv.get_spmv_shader_source(VulkanOptimizationLevel::ComputeShader);
assert!(optimized_shader.contains("shared"));
let subgroup_shader = spmv.get_spmv_shader_source(VulkanOptimizationLevel::Subgroup);
assert!(subgroup_shader.contains("subgroup"));
}
}