use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache};
use crate::dtype::DType;
use crate::error::{Error, Result};
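
// WGSL kernel sources, embedded at compile time.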
const VALIDATE_EIGENVALUES_SHADER: &str = include_str!("validate_eigenvalues_f32.wgsl");
const DIAGONAL_EXP_SHADER: &str = include_str!("diagonal_exp_f32.wgsl");
const DIAGONAL_LOG_SHADER: &str = include_str!("diagonal_log_f32.wgsl");
const DIAGONAL_SQRT_SHADER: &str = include_str!("diagonal_sqrt_f32.wgsl");
const PARLETT_COLUMN_SHADER: &str = include_str!("parlett_column_f32.wgsl");
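
/// Every kernel in this module is f32-only; reject other dtypes up front.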
fn check_dtype_f32(dtype: DType, op: &'static str) -> Result<()> {
match dtype {
DType::F32 => Ok(()),
_ => Err(Error::UnsupportedDType { dtype, op }),
}
}
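
/// Dispatches the eigenvalue-validation kernel over the triangular factor in
/// `matrix_t`, writing the kernel's verdict into `result`; `eps` is the
/// tolerance forwarded to the shader. f32 only.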
pub fn launch_validate_eigenvalues(
cache: &PipelineCache,
queue: &Queue,
matrix_t: &Buffer,
result: &Buffer,
n: usize,
eps: f32,
dtype: DType,
) -> Result<()> {
check_dtype_f32(dtype, "validate_eigenvalues")?;
let module =
cache.get_or_create_module("validate_eigenvalues_f32", VALIDATE_EIGENVALUES_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = cache.get_or_create_pipeline(
"validate_eigenvalues_f32",
"validate_eigenvalues_f32",
&module,
&layout,
);
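    // Pack the uniform params into a 16-byte block: [n, eps as raw bits,
    // two words of padding].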
let params: [u32; 4] = [n as u32, eps.to_bits(), 0, 0];
let params_buffer = cache.device().create_buffer(&wgpu::BufferDescriptor {
label: Some("validate_eigenvalues_params"),
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
    queue.write_buffer(&params_buffer, 0, bytemuck::cast_slice(&params));
    let bind_group = cache.create_bind_group(&layout, &[matrix_t, result, &params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("validate_eigenvalues"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("validate_eigenvalues"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
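
/// Applies `func_type` ("exp", "log", or "sqrt") entrywise to the diagonal of
/// the triangular factor in `input_t`, writing into `output_f`. This is the
/// first phase of evaluating a matrix function on a Schur form.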
pub fn launch_diagonal_func(
cache: &PipelineCache,
queue: &Queue,
input_t: &Buffer,
output_f: &Buffer,
n: usize,
eps: f32,
func_type: &str,
dtype: DType,
) -> Result<()> {
check_dtype_f32(dtype, "diagonal_func")?;
let (shader_src, module_name, entry_point): (&str, &'static str, &'static str) = match func_type
{
"exp" => (DIAGONAL_EXP_SHADER, "diagonal_exp_f32", "diagonal_exp_f32"),
"log" => (DIAGONAL_LOG_SHADER, "diagonal_log_f32", "diagonal_log_f32"),
"sqrt" => (
DIAGONAL_SQRT_SHADER,
"diagonal_sqrt_f32",
"diagonal_sqrt_f32",
),
_ => {
return Err(Error::Internal(format!(
"Unknown diagonal func type: {}",
func_type
)));
}
};
let module = cache.get_or_create_module(module_name, shader_src);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout);
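    // Same 16-byte params layout as in `launch_validate_eigenvalues`.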
let params: [u32; 4] = [n as u32, eps.to_bits(), 0, 0];
let params_buffer = cache.device().create_buffer(&wgpu::BufferDescriptor {
label: Some("diagonal_func_params"),
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
    queue.write_buffer(&params_buffer, 0, bytemuck::cast_slice(&params));
    let bind_group = cache.create_bind_group(&layout, &[input_t, output_f, &params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("diagonal_func"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("diagonal_func"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
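
/// Computes column `col` of the matrix function via the Parlett recurrence.
/// Each column depends on the columns to its left, so callers must launch
/// columns in increasing order; `eps` is passed through to the shader as a
/// tolerance.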
pub fn launch_parlett_column(
cache: &PipelineCache,
queue: &Queue,
input_t: &Buffer,
output_f: &Buffer,
n: usize,
col: usize,
eps: f32,
dtype: DType,
) -> Result<()> {
check_dtype_f32(dtype, "parlett_column")?;
let module = cache.get_or_create_module("parlett_column_f32", PARLETT_COLUMN_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline =
cache.get_or_create_pipeline("parlett_column_f32", "parlett_column_f32", &module, &layout);
let params: [u32; 4] = [n as u32, col as u32, eps.to_bits(), 0];
let params_buffer = cache.device().create_buffer(&wgpu::BufferDescriptor {
label: Some("parlett_column_params"),
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
    queue.write_buffer(&params_buffer, 0, bytemuck::cast_slice(&params));
    let bind_group = cache.create_bind_group(&layout, &[input_t, output_f, &params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("parlett_column"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("parlett_column"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
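        // One thread per entry above the diagonal in this column, 256 threads
        // per workgroup (assumed to match the shader's @workgroup_size);
        // dispatch at least one group even for small `col`.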
let workgroups = (col as u32 + 255) / 256;
pass.dispatch_workgroups(workgroups.max(1), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
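
/// Computes `F = f(T)` for an upper-triangular Schur factor `T` in `input_t`,
/// writing the result to `output_f`: the diagonal is mapped through
/// `func_type` first, then columns `1..n` are filled in left to right with
/// the Parlett recurrence.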
pub fn compute_schur_func_gpu(
cache: &PipelineCache,
queue: &Queue,
input_t: &Buffer,
output_f: &Buffer,
n: usize,
func_type: &str,
dtype: DType,
) -> Result<()> {
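    // Machine epsilon serves as the tolerance forwarded to every kernel.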
let eps = f32::EPSILON;
launch_diagonal_func(cache, queue, input_t, output_f, n, eps, func_type, dtype)?;
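    // Parlett columns are launched in order; wgpu executes submissions on the
    // same queue in submission order, which preserves the dependency.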
for col in 1..n {
launch_parlett_column(cache, queue, input_t, output_f, n, col, eps, dtype)?;
}
Ok(())
}