use wgpu::{Buffer, Queue};
use super::check_dtype_f32;
use crate::dtype::DType;
use crate::error::Result;
use crate::runtime::wgpu::shaders::linalg_shaders::basic_ops::BASIC_OPS_SHADER;
use crate::runtime::wgpu::shaders::pipeline::{LayoutKey, PipelineCache, workgroup_count};
pub fn launch_trace(
cache: &PipelineCache,
queue: &Queue,
input: &Buffer,
output: &Buffer,
params_buffer: &Buffer,
n: usize,
dtype: DType,
) -> Result<()> {
check_dtype_f32!(dtype, "trace");
let module = cache.get_or_create_module("linalg_basic_ops", BASIC_OPS_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = cache.get_or_create_pipeline("linalg_basic_ops", "trace_f32", &module, &layout);
let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("trace"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("trace"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(workgroup_count(n), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
pub fn launch_diag(
cache: &PipelineCache,
queue: &Queue,
input: &Buffer,
output: &Buffer,
params_buffer: &Buffer,
min_dim: usize,
dtype: DType,
) -> Result<()> {
check_dtype_f32!(dtype, "diag");
let module = cache.get_or_create_module("linalg_basic_ops", BASIC_OPS_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = cache.get_or_create_pipeline("linalg_basic_ops", "diag_f32", &module, &layout);
let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("diag"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("diag"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(workgroup_count(min_dim), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
pub fn launch_diagflat(
cache: &PipelineCache,
queue: &Queue,
input: &Buffer,
output: &Buffer,
params_buffer: &Buffer,
n: usize,
dtype: DType,
) -> Result<()> {
check_dtype_f32!(dtype, "diagflat");
let module = cache.get_or_create_module("linalg_basic_ops", BASIC_OPS_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline =
cache.get_or_create_pipeline("linalg_basic_ops", "diagflat_f32", &module, &layout);
let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("diagflat"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("diagflat"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(workgroup_count(n * n), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
pub fn launch_create_identity(
cache: &PipelineCache,
queue: &Queue,
output: &Buffer,
params_buffer: &Buffer,
n: usize,
dtype: DType,
) -> Result<()> {
check_dtype_f32!(dtype, "create_identity");
let module = cache.get_or_create_module("linalg_basic_ops", BASIC_OPS_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 1,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline =
cache.get_or_create_pipeline("linalg_basic_ops", "create_identity_f32", &module, &layout);
let bind_group = cache.create_bind_group(&layout, &[output, params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("create_identity"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("create_identity"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(workgroup_count(n * n), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
pub fn launch_kron(
cache: &PipelineCache,
queue: &Queue,
a: &Buffer,
b: &Buffer,
output: &Buffer,
params_buffer: &Buffer,
total_elements: usize,
dtype: DType,
) -> Result<()> {
check_dtype_f32!(dtype, "kron");
let module = cache.get_or_create_module("linalg_basic_ops", BASIC_OPS_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 3,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = cache.get_or_create_pipeline("linalg_basic_ops", "kron_f32", &module, &layout);
let bind_group = cache.create_bind_group(&layout, &[a, b, output, params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("kron"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("kron"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(workgroup_count(total_elements), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
pub fn launch_khatri_rao(
cache: &PipelineCache,
queue: &Queue,
a: &Buffer,
b: &Buffer,
output: &Buffer,
params_buffer: &Buffer,
total_elements: usize,
dtype: DType,
) -> Result<()> {
check_dtype_f32!(dtype, "khatri_rao");
let module = cache.get_or_create_module("linalg_basic_ops", BASIC_OPS_SHADER);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 3,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline =
cache.get_or_create_pipeline("linalg_basic_ops", "khatri_rao_f32", &module, &layout);
let bind_group = cache.create_bind_group(&layout, &[a, b, output, params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("khatri_rao"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("khatri_rao"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(workgroup_count(total_elements), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}