1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
//! Sparse matrix operations CUDA kernel launchers
//!
//! This module provides Rust wrappers for CUDA sparse matrix kernels.
use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use cudarc::types::CudaTypeName;
use std::sync::Arc;
use super::loader::{get_kernel_function, get_or_load_module, kernel_names, launch_config};
use crate::error::{Error, Result};
// ============================================================================
// SpMV Launchers (Row-per-thread)
// ============================================================================
/// Launch CSR SpMV kernel (row-per-thread variant)
///
/// y = A * x where A is sparse CSR matrix. One GPU thread computes one
/// output row.
///
/// # Safety
///
/// Caller must ensure:
/// - All pointers are valid device pointers
/// - row_ptrs has length nrows + 1
/// - col_indices and values have length nnz
/// - x has length ncols
/// - y has length nrows
pub unsafe fn launch_csr_spmv<T: CudaTypeName>(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    row_ptrs: u64,
    col_indices: u64,
    values: u64,
    x: u64,
    y: u64,
    nrows: usize,
) -> Result<()> {
    // `CudaTypeName::NAME` is the CUDA C type name ("float", "double",
    // "__half", "__nv_bfloat16") — see the DSMM launcher in this module,
    // which already matches those. The Rust-style spellings are accepted
    // as well so either convention resolves to the right kernel.
    let kernel_name = match T::NAME {
        "float" | "f32" => "csr_spmv_f32",
        "double" | "f64" => "csr_spmv_f64",
        "__half" => "csr_spmv_f16",
        "__nv_bfloat16" => "csr_spmv_bf16",
        _ => {
            return Err(Error::Internal(format!(
                "Unsupported dtype for sparse SpMV: {}",
                T::NAME
            )));
        }
    };
    // An empty matrix needs no work; a zero-sized grid would be a launch error.
    if nrows == 0 {
        return Ok(());
    }
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::SPARSE_SPMV_MODULE)?;
        let func = get_kernel_function(&module, kernel_name)?;
        // One thread per row, 256 threads per block, grid rounded up.
        let block_size = 256;
        let grid_size = (nrows + block_size - 1) / block_size;
        let cfg = launch_config((grid_size as u32, 1, 1), (block_size as u32, 1, 1), 0);
        let nrows_i32 = nrows as i32;
        let mut builder = stream.launch_builder(&func);
        builder.arg(&row_ptrs);
        builder.arg(&col_indices);
        builder.arg(&values);
        builder.arg(&x);
        builder.arg(&y);
        builder.arg(&nrows_i32);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA sparse SpMV kernel launch failed: {:?}", e))
        })?;
        Ok(())
    }
}
// ============================================================================
// SpMV Launchers (Warp-level reduction)
// ============================================================================
/// Launch CSR SpMV kernel (warp-level reduction variant)
///
/// Better for very sparse matrices where each row has few non-zeros:
/// one 32-thread warp cooperates on each row.
///
/// # Safety
///
/// Same safety requirements as `launch_csr_spmv`
pub unsafe fn launch_csr_spmv_warp<T: CudaTypeName>(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    row_ptrs: u64,
    col_indices: u64,
    values: u64,
    x: u64,
    y: u64,
    nrows: usize,
) -> Result<()> {
    // `CudaTypeName::NAME` is the CUDA C type name ("float", "double", ...);
    // accept the Rust-style spellings as well (see `launch_csr_spmv`).
    let kernel_name = match T::NAME {
        "float" | "f32" => "csr_spmv_warp_f32",
        "double" | "f64" => "csr_spmv_warp_f64",
        "__half" => "csr_spmv_warp_f16",
        "__nv_bfloat16" => "csr_spmv_warp_bf16",
        _ => {
            return Err(Error::Internal(format!(
                "Unsupported dtype for sparse SpMV warp: {}",
                T::NAME
            )));
        }
    };
    // Grid is (nrows, 1, 1); a zero-dimension grid is a launch error, and an
    // empty matrix needs no work anyway.
    if nrows == 0 {
        return Ok(());
    }
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::SPARSE_SPMV_MODULE)?;
        let func = get_kernel_function(&module, kernel_name)?;
        // One block per row, 32 threads (one warp) per block
        let cfg = launch_config((nrows as u32, 1, 1), (32, 1, 1), 0);
        let nrows_i32 = nrows as i32;
        let mut builder = stream.launch_builder(&func);
        builder.arg(&row_ptrs);
        builder.arg(&col_indices);
        builder.arg(&values);
        builder.arg(&x);
        builder.arg(&y);
        builder.arg(&nrows_i32);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA sparse SpMV warp kernel launch failed: {:?}",
                e
            ))
        })?;
        Ok(())
    }
}
// ============================================================================
// SpMM Launchers
// ============================================================================
/// Launch CSR SpMM kernel
///
/// C = A * B where A is sparse CSR matrix, B is dense matrix
///
/// # Safety
///
/// Caller must ensure:
/// - All pointers are valid device pointers
/// - row_ptrs has length nrows + 1
/// - col_indices and values have length nnz
/// - B has shape [ncols, ncols_B] stored row-major
/// - C has shape [nrows, ncols_B] stored row-major
pub unsafe fn launch_csr_spmm<T: CudaTypeName>(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    row_ptrs: u64,
    col_indices: u64,
    values: u64,
    b: u64,
    c: u64,
    nrows: usize,
    ncols_b: usize,
) -> Result<()> {
    // `CudaTypeName::NAME` is the CUDA C type name ("float", "double", ...);
    // accept the Rust-style spellings as well (see `launch_csr_spmv`).
    let kernel_name = match T::NAME {
        "float" | "f32" => "csr_spmm_f32",
        "double" | "f64" => "csr_spmm_f64",
        "__half" => "csr_spmm_f16",
        "__nv_bfloat16" => "csr_spmm_bf16",
        _ => {
            return Err(Error::Internal(format!(
                "Unsupported dtype for sparse SpMM: {}",
                T::NAME
            )));
        }
    };
    // Empty output (no rows, or no B columns => block_size of 0) would make an
    // invalid launch configuration; there is nothing to compute, so succeed.
    if nrows == 0 || ncols_b == 0 {
        return Ok(());
    }
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::SPARSE_SPMV_MODULE)?;
        let func = get_kernel_function(&module, kernel_name)?;
        // One block per row, ncols_b threads per block (up to 1024)
        let block_size = ncols_b.min(1024);
        let cfg = launch_config((nrows as u32, 1, 1), (block_size as u32, 1, 1), 0);
        let nrows_i32 = nrows as i32;
        let ncols_b_i32 = ncols_b as i32;
        let mut builder = stream.launch_builder(&func);
        builder.arg(&row_ptrs);
        builder.arg(&col_indices);
        builder.arg(&values);
        builder.arg(&b);
        builder.arg(&c);
        builder.arg(&nrows_i32);
        builder.arg(&ncols_b_i32);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA sparse SpMM kernel launch failed: {:?}", e))
        })?;
        Ok(())
    }
}
// ============================================================================
// Helper Functions
// ============================================================================
/// Choose optimal SpMV kernel based on matrix sparsity
///
/// Returns true if warp-level kernel should be used, false for row-per-thread
pub fn should_use_warp_kernel(avg_nnz_per_row: f32) -> bool {
    // The warp-reduction kernel is the better choice for very sparse rows:
    // fewer non-zeros per row than the 32 lanes of a single warp.
    const WARP_SIZE: f32 = 32.0;
    avg_nnz_per_row < WARP_SIZE
}
// ============================================================================
// DSMM Launcher (Dense × Sparse Matrix Multiplication)
// ============================================================================
/// Launch DSMM (Dense × Sparse) kernel using CSC format
///
/// Computes C = A @ B where:
/// - A is dense [M, K] row-major
/// - B is sparse CSC [K, N]
/// - C is dense [M, N] row-major
///
/// # Safety
///
/// Caller must ensure:
/// - A, C are valid dense matrices with correct dimensions
/// - col_ptrs, row_indices, values describe valid CSC matrix
/// - All pointers are device pointers
pub unsafe fn launch_dsmm_csc<T: CudaTypeName>(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    a: u64, // Dense [M, K]
    col_ptrs: u64, // CSC [N+1]
    row_indices: u64, // CSC [nnz]
    values: u64, // CSC [nnz]
    c: u64, // Dense [M, N]
    m: usize,
    k: usize,
    n: usize,
) -> Result<()> {
    // Resolve the type-specialized kernel symbol from the CUDA C type name.
    let kernel_name = if T::NAME == "float" {
        "dsmm_csc_f32"
    } else if T::NAME == "double" {
        "dsmm_csc_f64"
    } else if T::NAME == "__half" {
        "dsmm_csc_f16"
    } else if T::NAME == "__nv_bfloat16" {
        "dsmm_csc_bf16"
    } else {
        return Err(Error::Internal(format!(
            "Unsupported dtype for DSMM: {}",
            T::NAME
        )));
    };
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::DSMM_MODULE)?;
        let kernel = get_kernel_function(&module, kernel_name)?;
        // Grid layout: one block per output column, 256 threads per block.
        let cfg = launch_config((n as u32, 1, 1), (256, 1, 1), 0);
        let (m_u32, k_u32, n_u32) = (m as u32, k as u32, n as u32);
        let mut args = stream.launch_builder(&kernel);
        args.arg(&a);
        args.arg(&col_ptrs);
        args.arg(&row_indices);
        args.arg(&values);
        args.arg(&c);
        args.arg(&m_u32);
        args.arg(&k_u32);
        args.arg(&n_u32);
        args.launch(cfg)
            .map_err(|e| Error::Internal(format!("CUDA DSMM kernel launch failed: {:?}", e)))?;
        Ok(())
    }
}