1use crate::csr_array::CsrArray;
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, ArrayView1};
10use scirs2_core::numeric::{Float, SparseElement};
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::sync::Arc;
14
15#[cfg(feature = "gpu")]
16use scirs2_core::gpu::{GpuDevice, GpuError};
17
/// Optimization strategies for Vulkan-accelerated sparse matrix-vector multiply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VulkanOptimizationLevel {
    /// Straightforward compute shader: one invocation per matrix row.
    Basic,
    /// Compute shader that additionally caches vector entries in shared memory.
    ComputeShader,
    /// Uses subgroup (warp/wavefront) operations for the row reduction.
    Subgroup,
    /// Most aggressive level available (currently equivalent to `Subgroup`).
    Maximum,
}
30
/// Capability and limit information for a Vulkan compute device.
#[derive(Debug, Clone)]
pub struct VulkanDeviceInfo {
    /// Human-readable device name.
    pub device_name: String,
    /// PCI vendor id (0x10DE = NVIDIA, 0x1002 = AMD, 0x8086 = Intel).
    pub vendor_id: u32,
    /// Physical device category.
    pub device_type: VulkanDeviceType,
    /// Maximum shared (workgroup-local) memory per compute shader, in bytes.
    pub max_compute_shared_memory_size: usize,
    /// Maximum number of workgroups per dispatch, per dimension.
    pub max_compute_work_group_count: [u32; 3],
    /// Maximum total invocations in a single workgroup.
    pub max_compute_work_group_invocations: u32,
    /// Maximum workgroup size per dimension.
    pub max_compute_work_group_size: [u32; 3],
    /// Hardware subgroup (warp/wavefront) width.
    pub subgroup_size: u32,
    /// Whether subgroup operations are supported.
    pub supports_subgroups: bool,
    /// Whether 8-bit integer arithmetic is supported in shaders.
    pub supports_int8: bool,
    /// Whether 16-bit integer arithmetic is supported in shaders.
    pub supports_int16: bool,
    /// Whether 64-bit floating point is supported in shaders.
    pub supports_float64: bool,
}

/// Vulkan physical device categories (mirrors `VkPhysicalDeviceType`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VulkanDeviceType {
    Other,
    IntegratedGpu,
    DiscreteGpu,
    VirtualGpu,
    Cpu,
}

impl VulkanDeviceInfo {
    /// PCI vendor id for NVIDIA.
    const VENDOR_NVIDIA: u32 = 0x10DE;
    /// PCI vendor id for AMD/ATI.
    const VENDOR_AMD: u32 = 0x1002;
    /// PCI vendor id for Intel.
    const VENDOR_INTEL: u32 = 0x8086;

    /// Detects the Vulkan device's capabilities.
    ///
    /// NOTE(review): actual Vulkan device enumeration is not wired up yet;
    /// this returns conservative defaults typical of a discrete GPU
    /// (`vendor_id` 0, i.e. no specific vendor).
    pub fn detect() -> Self {
        Self {
            device_name: "Default Vulkan Device".to_string(),
            vendor_id: 0,
            device_type: VulkanDeviceType::DiscreteGpu,
            // 32 KiB shared memory — a widely-supported baseline.
            max_compute_shared_memory_size: 32768,
            max_compute_work_group_count: [65535, 65535, 65535],
            max_compute_work_group_invocations: 1024,
            max_compute_work_group_size: [1024, 1024, 64],
            subgroup_size: 32,
            supports_subgroups: true,
            supports_int8: true,
            supports_int16: true,
            supports_float64: true,
        }
    }

    /// Returns `true` if the device reports NVIDIA's PCI vendor id.
    pub fn is_nvidia(&self) -> bool {
        self.vendor_id == Self::VENDOR_NVIDIA
    }

    /// Returns `true` if the device reports AMD's PCI vendor id.
    pub fn is_amd(&self) -> bool {
        self.vendor_id == Self::VENDOR_AMD
    }

    /// Returns `true` if the device reports Intel's PCI vendor id.
    pub fn is_intel(&self) -> bool {
        self.vendor_id == Self::VENDOR_INTEL
    }

    /// Preferred compute workgroup size: the hardware subgroup width when
    /// subgroup operations are available, otherwise a safe default of 64.
    pub fn optimal_workgroup_size(&self) -> usize {
        if self.supports_subgroups {
            self.subgroup_size as usize
        } else {
            64
        }
    }
}
102
/// Bookkeeping for GPU buffer allocations: per-buffer sizes, the running
/// total, and the high-water mark since creation or the last reset.
#[derive(Debug)]
pub struct VulkanMemoryManager {
    // Size in bytes of each live buffer, keyed by caller-chosen id.
    allocated_buffers: HashMap<String, usize>,
    // Sum of all live buffer sizes in bytes.
    total_allocated: usize,
    // Largest value `total_allocated` has reached.
    peak_usage: usize,
}
110
111impl VulkanMemoryManager {
112 pub fn new() -> Self {
113 Self {
114 allocated_buffers: HashMap::new(),
115 total_allocated: 0,
116 peak_usage: 0,
117 }
118 }
119
120 pub fn allocate(&mut self, id: String, size: usize) -> SparseResult<()> {
121 self.allocated_buffers.insert(id, size);
122 self.total_allocated += size;
123 if self.total_allocated > self.peak_usage {
124 self.peak_usage = self.total_allocated;
125 }
126 Ok(())
127 }
128
129 pub fn deallocate(&mut self, id: &str) -> SparseResult<()> {
130 if let Some(size) = self.allocated_buffers.remove(id) {
131 self.total_allocated = self.total_allocated.saturating_sub(size);
132 }
133 Ok(())
134 }
135
136 pub fn current_usage(&self) -> usize {
137 self.total_allocated
138 }
139
140 pub fn peak_usage(&self) -> usize {
141 self.peak_usage
142 }
143
144 pub fn reset(&mut self) {
145 self.allocated_buffers.clear();
146 self.total_allocated = 0;
147 self.peak_usage = 0;
148 }
149}
150
151impl Default for VulkanMemoryManager {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
/// Vulkan-accelerated sparse matrix-vector multiplication engine.
pub struct VulkanSpMatVec {
    // Capabilities of the target device (see `VulkanDeviceInfo::detect`).
    device_info: VulkanDeviceInfo,
    // Buffer-allocation accounting for this engine.
    memory_manager: VulkanMemoryManager,
    // Compiled shader bytecode, keyed by shader name.
    shader_cache: HashMap<String, Arc<Vec<u8>>>,
}
163
164impl VulkanSpMatVec {
165 pub fn new() -> SparseResult<Self> {
167 let device_info = VulkanDeviceInfo::detect();
168
169 Ok(Self {
170 device_info,
171 memory_manager: VulkanMemoryManager::new(),
172 shader_cache: HashMap::new(),
173 })
174 }
175
176 pub fn device_info(&self) -> &VulkanDeviceInfo {
178 &self.device_info
179 }
180
181 pub fn memory_manager(&self) -> &VulkanMemoryManager {
183 &self.memory_manager
184 }
185
186 pub fn memory_manager_mut(&mut self) -> &mut VulkanMemoryManager {
188 &mut self.memory_manager
189 }
190
191 #[cfg(feature = "gpu")]
193 pub fn execute_spmv<T>(
194 &self,
195 matrix: &CsrArray<T>,
196 vector: &ArrayView1<T>,
197 device: &GpuDevice,
198 ) -> SparseResult<Array1<T>>
199 where
200 T: Float + SparseElement + Debug + Copy + std::iter::Sum,
201 {
202 self.execute_optimized_spmv(
203 matrix,
204 vector,
205 device,
206 VulkanOptimizationLevel::ComputeShader,
207 )
208 }
209
210 #[cfg(feature = "gpu")]
212 pub fn execute_optimized_spmv<T>(
213 &self,
214 matrix: &CsrArray<T>,
215 vector: &ArrayView1<T>,
216 device: &GpuDevice,
217 optimization_level: VulkanOptimizationLevel,
218 ) -> SparseResult<Array1<T>>
219 where
220 T: Float + SparseElement + Debug + Copy + std::iter::Sum,
221 {
222 let (nrows, ncols) = matrix.shape();
224 if vector.len() != ncols {
225 return Err(SparseError::DimensionMismatch {
226 expected: ncols,
227 found: vector.len(),
228 });
229 }
230
231 matrix.dot_vector(vector)
242 }
243
244 pub fn execute_spmv_cpu<T>(
246 &self,
247 matrix: &CsrArray<T>,
248 vector: &ArrayView1<T>,
249 ) -> SparseResult<Array1<T>>
250 where
251 T: Float + SparseElement + Debug + Copy + std::iter::Sum,
252 {
253 matrix.dot_vector(vector)
254 }
255
256 fn get_spmv_shader_source(&self, optimization_level: VulkanOptimizationLevel) -> &str {
258 match optimization_level {
259 VulkanOptimizationLevel::Basic => {
260 r#"
262#version 450
263
264layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
265
266layout(set = 0, binding = 0) readonly buffer IndptrBuffer {
267 uint indptr[];
268};
269
270layout(set = 0, binding = 1) readonly buffer IndicesBuffer {
271 uint indices[];
272};
273
274layout(set = 0, binding = 2) readonly buffer DataBuffer {
275 float data[];
276};
277
278layout(set = 0, binding = 3) readonly buffer VectorBuffer {
279 float vector[];
280};
281
282layout(set = 0, binding = 4) writeonly buffer ResultBuffer {
283 float result[];
284};
285
286layout(push_constant) uniform PushConstants {
287 uint nrows;
288} pc;
289
290void main() {
291 uint row = gl_GlobalInvocationID.x;
292
293 if (row >= pc.nrows) {
294 return;
295 }
296
297 uint row_start = indptr[row];
298 uint row_end = indptr[row + 1];
299
300 float sum = 0.0;
301 for (uint i = row_start; i < row_end; i++) {
302 uint col = indices[i];
303 sum += data[i] * vector[col];
304 }
305
306 result[row] = sum;
307}
308"#
309 }
310 VulkanOptimizationLevel::ComputeShader => {
311 r#"
313#version 450
314
315layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
316
317layout(set = 0, binding = 0) readonly buffer IndptrBuffer {
318 uint indptr[];
319};
320
321layout(set = 0, binding = 1) readonly buffer IndicesBuffer {
322 uint indices[];
323};
324
325layout(set = 0, binding = 2) readonly buffer DataBuffer {
326 float data[];
327};
328
329layout(set = 0, binding = 3) readonly buffer VectorBuffer {
330 float vector[];
331};
332
333layout(set = 0, binding = 4) writeonly buffer ResultBuffer {
334 float result[];
335};
336
337layout(push_constant) uniform PushConstants {
338 uint nrows;
339} pc;
340
341shared float shared_vector[256];
342
343void main() {
344 uint row = gl_GlobalInvocationID.x;
345 uint local_id = gl_LocalInvocationID.x;
346
347 if (row >= pc.nrows) {
348 return;
349 }
350
351 uint row_start = indptr[row];
352 uint row_end = indptr[row + 1];
353
354 float sum = 0.0;
355 for (uint i = row_start; i < row_end; i++) {
356 uint col = indices[i];
357
358 // Cooperative loading to shared memory for better cache utilization
359 if (col < 256) {
360 shared_vector[col] = vector[col];
361 memoryBarrierShared();
362 barrier();
363 sum += data[i] * shared_vector[col];
364 } else {
365 sum += data[i] * vector[col];
366 }
367 }
368
369 result[row] = sum;
370}
371"#
372 }
373 VulkanOptimizationLevel::Subgroup => {
374 r#"
376#version 450
377#extension GL_KHR_shader_subgroup_arithmetic : enable
378
379layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
380
381layout(set = 0, binding = 0) readonly buffer IndptrBuffer {
382 uint indptr[];
383};
384
385layout(set = 0, binding = 1) readonly buffer IndicesBuffer {
386 uint indices[];
387};
388
389layout(set = 0, binding = 2) readonly buffer DataBuffer {
390 float data[];
391};
392
393layout(set = 0, binding = 3) readonly buffer VectorBuffer {
394 float vector[];
395};
396
397layout(set = 0, binding = 4) writeonly buffer ResultBuffer {
398 float result[];
399};
400
401layout(push_constant) uniform PushConstants {
402 uint nrows;
403} pc;
404
405void main() {
406 uint row = gl_GlobalInvocationID.x;
407
408 if (row >= pc.nrows) {
409 return;
410 }
411
412 uint row_start = indptr[row];
413 uint row_end = indptr[row + 1];
414
415 float sum = 0.0;
416 for (uint i = row_start; i < row_end; i++) {
417 uint col = indices[i];
418 sum += data[i] * vector[col];
419 }
420
421 // Use subgroup reduction for better performance
422 sum = subgroupAdd(sum);
423
424 if (subgroupElect()) {
425 result[row] = sum;
426 }
427}
428"#
429 }
430 VulkanOptimizationLevel::Maximum => {
431 self.get_spmv_shader_source(VulkanOptimizationLevel::Subgroup)
433 }
434 }
435 }
436
437 fn compile_shader(&mut self, source: &str, name: &str) -> SparseResult<Arc<Vec<u8>>> {
439 let bytecode = Arc::new(source.as_bytes().to_vec());
442 self.shader_cache.insert(name.to_string(), bytecode.clone());
443 Ok(bytecode)
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vulkan_device_info() {
        let info = VulkanDeviceInfo::detect();
        assert!(!info.device_name.is_empty());
        assert!(info.optimal_workgroup_size() > 0);
    }

    #[test]
    fn test_vulkan_memory_manager() {
        let mut manager = VulkanMemoryManager::new();

        manager
            .allocate("buffer1".to_string(), 1024)
            .expect("Failed to allocate");
        assert_eq!(manager.current_usage(), 1024);

        manager
            .allocate("buffer2".to_string(), 2048)
            .expect("Failed to allocate");
        assert_eq!(manager.current_usage(), 3072);
        assert_eq!(manager.peak_usage(), 3072);

        manager.deallocate("buffer1").expect("Failed to deallocate");
        assert_eq!(manager.current_usage(), 2048);
        // Peak usage is a high-water mark and must survive deallocation.
        assert_eq!(manager.peak_usage(), 3072);

        manager.reset();
        assert_eq!(manager.current_usage(), 0);
    }

    #[test]
    fn test_vulkan_spmv_creation() {
        let result = VulkanSpMatVec::new();
        assert!(result.is_ok());

        let spmv = result.expect("Failed to create");
        assert!(spmv.device_info().optimal_workgroup_size() > 0);
    }

    #[test]
    fn test_vulkan_cpu_fallback() {
        let spmv = VulkanSpMatVec::new().expect("Failed to create");

        // A = [[1, 2, 0],
        //      [0, 3, 0],
        //      [0, 0, 4]]
        let rows = vec![0, 0, 1, 2];
        let cols = vec![0, 1, 1, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false)
            .expect("Failed to create matrix");

        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = spmv
            .execute_spmv_cpu(&matrix, &vector.view())
            .expect("Failed to execute");

        assert_eq!(result.len(), 3);
        // Check actual values, not just the length: A * [1, 2, 3] = [5, 6, 12].
        assert!((result[0] - 5.0).abs() < 1e-12);
        assert!((result[1] - 6.0).abs() < 1e-12);
        assert!((result[2] - 12.0).abs() < 1e-12);
    }

    #[test]
    fn test_shader_source_generation() {
        let spmv = VulkanSpMatVec::new().expect("Failed to create");

        let basic_shader = spmv.get_spmv_shader_source(VulkanOptimizationLevel::Basic);
        assert!(basic_shader.contains("#version 450"));
        assert!(basic_shader.contains("layout"));

        let optimized_shader = spmv.get_spmv_shader_source(VulkanOptimizationLevel::ComputeShader);
        assert!(optimized_shader.contains("shared"));

        let subgroup_shader = spmv.get_spmv_shader_source(VulkanOptimizationLevel::Subgroup);
        assert!(subgroup_shader.contains("subgroup"));

        // Maximum currently aliases the subgroup shader.
        let max_shader = spmv.get_spmv_shader_source(VulkanOptimizationLevel::Maximum);
        assert_eq!(max_shader, subgroup_shader);
    }
}
523}