#version 430
// Cache-optimized matrix-vector multiplication for GPU
// Optimized for sparse and dense matrices used in implicit ODE solvers
//
// Dispatch expectation: one invocation per matrix row on the y axis
// (ceil(num_rows / 16) work groups in y); the x axis cooperates on
// shared-memory loads. -- assumes the host dispatches accordingly; verify
// against the caller.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
// Dense matrix storage (row-major), num_rows * num_cols floats.
layout(std430, binding = 0) buffer MatrixBuffer {
float matrix[];
};
// Input vector, length num_cols.
layout(std430, binding = 1) buffer VectorBuffer {
float vector[];
};
// Output vector y = A * x, length num_rows.
layout(std430, binding = 2) buffer ResultBuffer {
float result[];
};
// Sparse matrix data (CSR format). row_ptr has num_rows + 1 entries; the
// nonzeros of row r live in [row_ptr[r], row_ptr[r+1]).
layout(std430, binding = 3) buffer RowPtrBuffer {
int row_ptr[];
};
// Column index of each stored nonzero.
layout(std430, binding = 4) buffer ColIndBuffer {
int col_ind[];
};
// Value of each stored nonzero, parallel to col_ind.
layout(std430, binding = 5) buffer ValuesBuffer {
float values[];
};
uniform int matrix_type; // 0=dense, 1=sparse_csr
uniform int num_rows;
uniform int num_cols;
uniform int block_size; // dense-path column-block width; must be <= 256 (shared_vector capacity)
// Shared memory for cache blocking of the input vector (dense path only).
shared float shared_vector[256];
// NOTE(review): shared_matrix is declared but never referenced by main();
// it only consumes 1 KiB of shared memory per work group. Consider removing.
shared float shared_matrix[256];
void main() {
// Computes result = matrix * vector for one row per invocation (y index).
// BUG FIX: the original early-returned for out-of-range rows BEFORE the
// barrier() calls below. barrier() in a compute shader must be reached by
// every invocation of the work group in uniform control flow; returning
// early from some invocations is undefined behavior (hang or corruption)
// whenever num_rows is not a multiple of the work-group height. All
// invocations now stay alive through the blocked loop, and only the
// per-row computation and store are guarded.
uint row = gl_GlobalInvocationID.y;
bool row_in_range = row < uint(num_rows);
float sum = 0.0;
if (matrix_type == 0) {
// Dense row-major path with cache blocking: stage block_size entries of
// the input vector in shared memory so the whole group reuses them.
// Requires block_size <= 256 (capacity of shared_vector).
uint cols = uint(num_cols);
uint bsize = uint(block_size);
uint num_blocks = (cols + bsize - 1u) / bsize; // ceil-div; uniform across the group
for (uint block = 0u; block < num_blocks; block++) {
uint col_start = block * bsize;
uint col_end = min(col_start + bsize, cols);
// Cooperative load of the vector block. Invocations that share a local
// x index write identical values, so the overlapping writes are benign.
// i < col_end - col_start already guarantees col_start + i < num_cols.
for (uint i = gl_LocalInvocationID.x; i < col_end - col_start; i += gl_WorkGroupSize.x) {
shared_vector[i] = vector[col_start + i];
}
barrier(); // shared_vector fully populated before any reads
if (row_in_range) {
// Partial dot product over this column block.
for (uint col = col_start; col < col_end; col++) {
sum += matrix[row * cols + col] * shared_vector[col - col_start];
}
}
barrier(); // all reads done before the next block overwrites shared_vector
}
if (row_in_range) {
result[row] = sum;
}
} else if (matrix_type == 1) {
// Sparse CSR path: no shared memory and no barriers, so guarding the
// whole branch per row is safe here.
if (row_in_range) {
int start = row_ptr[row];
int end = row_ptr[row + 1u];
for (int idx = start; idx < end; idx++) {
sum += values[idx] * vector[col_ind[idx]];
}
result[row] = sum;
}
}
}