use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
use ash::vk;
use rand::Rng;
use std::collections::HashMap;
use std::ffi::CStr;
use std::sync::{Arc, Mutex};
/// Embeds the file `$OUT_DIR/<kernel>.spv` into the binary via `include_bytes!`.
macro_rules! spv {
    ($kernel:literal) => {
        include_bytes!(concat!(env!("OUT_DIR"), "/", $kernel, ".spv"))
    };
}
/// Maps a kernel name to the SPIR-V bytes embedded at build time by `spv!`.
///
/// # Errors
/// Bails with "vulkan: no SPIR-V kernel for `<name>`" when `name` is not in the table.
fn kernel_spv(name: &str) -> Result<&'static [u8]> {
// The explicit slice annotation lets arms of different array lengths unify.
let b: &'static [u8] = match name {
"add" => spv!("add"),
"sub" => spv!("sub"),
"mul" => spv!("mul"),
"div" => spv!("div"),
"affine" => spv!("affine"),
"neg" => spv!("neg"),
"exp" => spv!("exp"),
"silu" => spv!("silu"),
"silu_mul" => spv!("silu_mul"),
"gelu" => spv!("gelu"),
"relu" => spv!("relu"),
"sqr" => spv!("sqr"),
"sqrt" => spv!("sqrt"),
"recip" => spv!("recip"),
"tanh" => spv!("tanh"),
"matmul" => spv!("matmul"),
"bmm" => spv!("bmm"),
"bmm_reg" => spv!("bmm_reg"),
"bmm_reg_nt" => spv!("bmm_reg_nt"),
"bmm_coopmat" => spv!("bmm_coopmat"),
"bmm_coopmat_rb" => spv!("bmm_coopmat_rb"),
"bmm_coopmat_rb_nt" => spv!("bmm_coopmat_rb_nt"),
"cast_f2h" => spv!("cast_f2h"),
"mul_mat_vec_q8" => spv!("mul_mat_vec_q8"),
"copy" => spv!("copy"),
"reduce_sum" => spv!("reduce_sum"),
"reduce_max" => spv!("reduce_max"),
"strided_copy" => spv!("strided_copy"),
"index_select" => spv!("index_select"),
"where_cond" => spv!("where_cond"),
"softmax_rows" => spv!("softmax_rows"),
"rms_norm" => spv!("rms_norm"),
"rope" => spv!("rope"),
"sin" => spv!("sin"),
"cos" => spv!("cos"),
"log" => spv!("log"),
"abs" => spv!("abs"),
"floor" => spv!("floor"),
"ceil" => spv!("ceil"),
"round" => spv!("round"),
"sign" => spv!("sign"),
"erf" => spv!("erf"),
"gelu_erf" => spv!("gelu_erf"),
"powf" => spv!("powf"),
"elu" => spv!("elu"),
"maximum" => spv!("maximum"),
"minimum" => spv!("minimum"),
"cmp" => spv!("cmp"),
"cast_f2u" => spv!("cast_f2u"),
"cast_u2f" => spv!("cast_u2f"),
"reduce_min" => spv!("reduce_min"),
"reduce_argmin" => spv!("reduce_argmin"),
"reduce_argmax" => spv!("reduce_argmax"),
"gather" => spv!("gather"),
"scatter_set" => spv!("scatter_set"),
"scatter_add_set" => spv!("scatter_add_set"),
"conv1d" => spv!("conv1d"),
"conv2d" => spv!("conv2d"),
"conv_transpose1d" => spv!("conv_transpose1d"),
"conv_transpose2d" => spv!("conv_transpose2d"),
"avg_pool2d" => spv!("avg_pool2d"),
"max_pool2d" => spv!("max_pool2d"),
"upsample_nearest1d" => spv!("upsample_nearest1d"),
"upsample_nearest2d" => spv!("upsample_nearest2d"),
"upsample_bilinear2d" => spv!("upsample_bilinear2d"),
_ => crate::bail!("vulkan: no SPIR-V kernel for `{name}`"),
};
Ok(b)
}
/// Error type of the Vulkan backend; its only variant carries a message string.
#[derive(thiserror::Error, Debug)]
pub enum VulkanError {
/// Free-form message, displayed verbatim.
#[error("{0}")]
Message(String),
}
impl From<String> for VulkanError {
fn from(e: String) -> Self {
VulkanError::Message(e)
}
}
impl From<VulkanError> for Error {
fn from(e: VulkanError) -> Self {
Error::Msg(e.to_string())
}
}
/// Handles for one compute pipeline as stored in the per-device pipeline cache.
///
/// Cloning copies the raw Vulkan handles; it does not duplicate the underlying objects.
#[derive(Clone)]
struct CachedPipeline {
// The compute pipeline itself.
pipeline: vk::Pipeline,
// Pipeline layout (descriptor set layout + push-constant range) it was built with.
layout: vk::PipelineLayout,
// Descriptor set layout used to allocate a set for each dispatch.
set_layout: vk::DescriptorSetLayout,
// Number of storage-buffer bindings the pipeline was created with.
n_buffers: usize,
}
/// Shared state behind a [`VulkanDevice`]: Vulkan objects plus the caches that
/// every clone of the device handle shares.
struct VkInner {
// Loader entry; not read in the visible code (presumably held to keep the loader alive).
_entry: ash::Entry,
instance: ash::Instance,
device: ash::Device,
// Queue obtained from family `qfi`, index 0.
queue: vk::Queue,
// Index of the queue family that advertised COMPUTE.
qfi: u32,
// Ordinal passed to `new`, i.e. the index among non-CPU physical devices.
gpu_id: usize,
mem_props: vk::PhysicalDeviceMemoryProperties,
// True when a f16 x f16 -> f32 subgroup cooperative-matrix configuration was found.
coopmat: bool,
// (M, N, K) of that configuration, or (0, 0, 0) when `coopmat` is false.
cm_mnk: (u32, u32, u32),
// `coopmat` unless the HANZO_VK_COOPMAT environment variable is "0".
cm_use: bool,
// Value stored by `set_seed` and returned by `get_current_seed`.
seed: Mutex<u64>,
// Pipelines keyed by kernel name only.
pipelines: Mutex<HashMap<&'static str, CachedPipeline>>,
// Single command buffer / fence / descriptor pool used to batch dispatches.
submitter: Mutex<Submitter>,
// Recycled buffers, keyed by byte size.
bufpool: Mutex<BufPool>,
}
/// Recording state for the single command buffer that batches compute dispatches.
struct Submitter {
cpool: vk::CommandPool,
cmd: vk::CommandBuffer,
// Fence waited on after every submission.
fence: vk::Fence,
// Pool from which one descriptor set is allocated per dispatch; reset when a batch begins.
dpool: vk::DescriptorPool,
// True while `cmd` is between begin_command_buffer and end_command_buffer.
recording: bool,
// Number of dispatches recorded in the current batch.
n: u32,
}
// NOTE(review): asserts the raw handles may be moved across threads; access is
// serialized through `Mutex<Submitter>` in `VkInner` — confirm no other sharing exists.
unsafe impl Send for Submitter {}
/// Handle to a Vulkan compute device; clones share the same `VkInner` through an `Arc`.
#[derive(Clone)]
pub struct VulkanDevice {
inner: Arc<VkInner>,
}
/// Debug output shows only the device ordinal, e.g. `VulkanDevice(0)`.
impl std::fmt::Debug for VulkanDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ordinal = self.inner.gpu_id;
        f.write_fmt(format_args!("VulkanDevice({})", ordinal))
    }
}
/// A device buffer holding `count` elements of `dtype`.
///
/// On drop the buffer is handed to the owning device's buffer pool rather than destroyed.
pub struct VulkanStorage {
buffer: vk::Buffer,
memory: vk::DeviceMemory,
// Number of elements (not bytes).
count: usize,
dtype: DType,
// Device whose pool receives the buffer on drop.
device: VulkanDevice,
}
/// Debug output shows the element count and dtype, not the raw handles.
impl std::fmt::Debug for VulkanStorage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (count, dtype) = (self.count, self.dtype);
        f.write_fmt(format_args!(
            "VulkanStorage(count={}, dtype={:?})",
            count, dtype
        ))
    }
}
/// Size-keyed recycling pool for device buffers.
#[derive(Default)]
struct BufPool {
// Released buffers as (bytes, buffer, memory); moved into `free` by `reclaim`.
pending: Vec<(u64, vk::Buffer, vk::DeviceMemory)>,
// Buffers ready for reuse, keyed by exact byte size.
free: HashMap<u64, Vec<(vk::Buffer, vk::DeviceMemory)>>,
}
/// Returns the buffer to the device pool's `pending` list instead of freeing it.
impl Drop for VulkanStorage {
    fn drop(&mut self) {
        // A null handle means there is nothing to recycle.
        if self.buffer != vk::Buffer::null() {
            // Same size key the allocator uses: byte length, at least 4.
            let size = (self.count * self.dtype.size_in_bytes()).max(4) as u64;
            // Best effort: a poisoned pool lock simply skips recycling.
            if let Ok(mut pool) = self.device.inner.bufpool.lock() {
                pool.pending.push((size, self.buffer, self.memory));
            }
        }
    }
}
impl VulkanDevice {
fn dev(&self) -> &ash::Device {
&self.inner.device
}
/// Returns the cooperative-matrix `(M, N, K)` tile found at device creation,
/// or `None` when no matching configuration was detected.
pub fn coopmat_info(&self) -> Option<(u32, u32, u32)> {
    if self.inner.coopmat {
        Some(self.inner.cm_mnk)
    } else {
        None
    }
}
/// Quantizes a row-major `nout x k` f32 matrix into 8-bit blocks and uploads it.
///
/// Every 32 consecutive values form one block of nine `u32` words: word 0 holds
/// the f16 bits of the block scale (`amax / 127`, or 1.0 for an all-zero block),
/// words 1..=8 hold the 32 signed 8-bit quants, four per word, lowest byte first.
///
/// # Errors
/// Bails when `k` is not a multiple of 32 or `w.len() != nout * k`; upload
/// errors are propagated.
pub fn quantize_q8(&self, w: &[f32], nout: usize, k: usize) -> Result<VulkanStorage> {
    if k % 32 != 0 {
        crate::bail!("quantize_q8: k must be a multiple of 32, got {k}");
    }
    if w.len() != nout * k {
        crate::bail!("quantize_q8: w len {} != {nout}*{k}", w.len());
    }
    let blocks_per_row = k / 32;
    let mut packed = vec![0u32; nout * blocks_per_row * 9];
    // Rows are contiguous and `k` is a multiple of 32, so walking `w` in chunks
    // of 32 visits blocks in the same (row, block) order as the output layout.
    for (dst, blk) in packed.chunks_exact_mut(9).zip(w.chunks_exact(32)) {
        let amax = blk.iter().fold(0f32, |m, &v| m.max(v.abs()));
        let (inv, scale) = if amax > 0.0 {
            (127.0 / amax, amax / 127.0)
        } else {
            (0.0, 1.0)
        };
        let (head, quants) = dst.split_at_mut(1);
        head[0] = half::f16::from_f32(scale).to_bits() as u32;
        for (word, quad) in quants.iter_mut().zip(blk.chunks_exact(4)) {
            *word = quad.iter().enumerate().fold(0u32, |acc, (lane, &v)| {
                let q = (v * inv).round().clamp(-127.0, 127.0) as i32 as i8;
                acc | (((q as u8) as u32) << (lane * 8))
            });
        }
    }
    self.upload_u32(&packed)
}
/// Host-side convenience wrapper: uploads `x`, runs the quantized matrix-vector
/// kernel against `wq`, and reads the `nout` results back.
///
/// # Errors
/// Bails when `x.len() != k`; upload, dispatch and readback errors are propagated.
pub fn matvec_q8(
    &self,
    wq: &VulkanStorage,
    x: &[f32],
    nout: usize,
    k: usize,
) -> Result<Vec<f32>> {
    if x.len() != k {
        crate::bail!("matvec_q8: x len {} != k {k}", x.len());
    }
    let x_gpu = self.upload_f32(x)?;
    let y_gpu = self.matvec_q8_gpu(wq, &x_gpu, nout, k)?;
    y_gpu.to_vec_f32()
}
/// Records the `mul_mat_vec_q8` kernel: `out[nout] = wq[nout x k] * x[k]`.
///
/// `wq` is expected in the layout produced by `quantize_q8` (nine `u32` words
/// per 32-value block). The result stays on the device.
///
/// # Errors
/// Bails when `k` is not a multiple of 32, when `x` holds fewer than `k`
/// elements, or when `wq` is too small for `nout * (k / 32)` blocks; allocation
/// and dispatch errors are propagated.
pub fn matvec_q8_gpu(
    &self,
    wq: &VulkanStorage,
    x: &VulkanStorage,
    nout: usize,
    k: usize,
) -> Result<VulkanStorage> {
    // Same block-size contract as `quantize_q8`.
    if k % 32 != 0 {
        crate::bail!("matvec_q8_gpu: k must be a multiple of 32, got {k}");
    }
    if x.count < k {
        crate::bail!("matvec_q8_gpu: x count {} < k {k}", x.count);
    }
    // Compare in bytes so the check holds whatever dtype `wq` is labelled with.
    let need_bytes = nout * (k / 32) * 9 * 4;
    let have_bytes = wq.count * wq.dtype.size_in_bytes();
    if have_bytes < need_bytes {
        crate::bail!(
            "matvec_q8_gpu: wq holds {have_bytes} bytes, need {need_bytes} for {nout}x{k}"
        );
    }
    let out = self.alloc_f32(nout)?;
    let push = push_u32(&[nout as u32, k as u32]);
    self.dispatch(
        "mul_mat_vec_q8",
        &[wq.buffer, x.buffer, out.buffer],
        &push,
        ((nout as u32).div_ceil(WG1D), 1, 1),
    )?;
    Ok(out)
}
unsafe fn raw_buffer(&self, bytes: u64) -> Result<(vk::Buffer, vk::DeviceMemory)> {
let bytes = bytes.max(4); if let Some(pair) = self
.inner
.bufpool
.lock()
.unwrap()
.free
.get_mut(&bytes)
.and_then(Vec::pop)
{
return Ok(pair);
}
let dev = self.dev();
let info = vk::BufferCreateInfo::default()
.size(bytes)
.usage(vk::BufferUsageFlags::STORAGE_BUFFER)
.sharing_mode(vk::SharingMode::EXCLUSIVE);
let buf = dev.create_buffer(&info, None).map_err(vkerr)?;
let req = dev.get_buffer_memory_requirements(buf);
let mp = &self.inner.mem_props;
let base = vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT;
let cached = base | vk::MemoryPropertyFlags::HOST_CACHED;
let usable = |i: u32, flags: vk::MemoryPropertyFlags| {
(req.memory_type_bits & (1 << i)) != 0
&& mp.memory_types[i as usize].property_flags.contains(flags)
};
let idx = (0..mp.memory_type_count)
.find(|&i| usable(i, cached))
.or_else(|| (0..mp.memory_type_count).find(|&i| usable(i, base)))
.ok_or_else(|| Error::Msg("vulkan: no host-visible memory type".into()))?;
let mem = dev
.allocate_memory(
&vk::MemoryAllocateInfo::default()
.allocation_size(req.size)
.memory_type_index(idx),
None,
)
.map_err(vkerr)?;
dev.bind_buffer_memory(buf, mem, 0).map_err(vkerr)?;
Ok((buf, mem))
}
/// Copies `data` to the start of `mem` through a temporary host mapping.
/// An empty slice is a no-op.
///
/// # Safety
/// `mem` must be a mappable allocation of this device holding at least
/// `data.len() * 4` bytes, and must not already be mapped.
unsafe fn write_f32(&self, mem: vk::DeviceMemory, data: &[f32]) -> Result<()> {
    if !data.is_empty() {
        let byte_len = (data.len() * 4) as u64;
        let device = self.dev();
        let mapped = device
            .map_memory(mem, 0, byte_len, vk::MemoryMapFlags::empty())
            .map_err(vkerr)?;
        std::ptr::copy_nonoverlapping(data.as_ptr(), mapped as *mut f32, data.len());
        device.unmap_memory(mem);
    }
    Ok(())
}
/// Reads the first `n` f32 values of `mem` back to the host.
///
/// Pending GPU work is flushed first so the mapping observes completed results.
///
/// # Safety
/// `mem` must be a mappable allocation of this device holding at least
/// `n * 4` bytes, and must not already be mapped.
unsafe fn read_f32(&self, mem: vk::DeviceMemory, n: usize) -> Result<Vec<f32>> {
    if n == 0 {
        return Ok(Vec::new());
    }
    self.flush()?;
    let device = self.dev();
    let byte_len = (n * 4) as u64;
    let mapped = device
        .map_memory(mem, 0, byte_len, vk::MemoryMapFlags::empty())
        .map_err(vkerr)?;
    let mut host = Vec::with_capacity(n);
    host.extend_from_slice(std::slice::from_raw_parts(mapped as *const f32, n));
    device.unmap_memory(mem);
    Ok(host)
}
/// Copies `data` to the start of `mem` through a temporary host mapping.
/// An empty slice is a no-op.
///
/// # Safety
/// `mem` must be a mappable allocation of this device holding at least
/// `data.len() * 4` bytes, and must not already be mapped.
unsafe fn write_u32(&self, mem: vk::DeviceMemory, data: &[u32]) -> Result<()> {
    if !data.is_empty() {
        let byte_len = (data.len() * 4) as u64;
        let device = self.dev();
        let mapped = device
            .map_memory(mem, 0, byte_len, vk::MemoryMapFlags::empty())
            .map_err(vkerr)?;
        std::ptr::copy_nonoverlapping(data.as_ptr(), mapped as *mut u32, data.len());
        device.unmap_memory(mem);
    }
    Ok(())
}
/// Reads the first `n` u32 values of `mem` back to the host.
///
/// Pending GPU work is flushed first so the mapping observes completed results.
///
/// # Safety
/// `mem` must be a mappable allocation of this device holding at least
/// `n * 4` bytes, and must not already be mapped.
unsafe fn read_u32(&self, mem: vk::DeviceMemory, n: usize) -> Result<Vec<u32>> {
    if n == 0 {
        return Ok(Vec::new());
    }
    self.flush()?;
    let device = self.dev();
    let byte_len = (n * 4) as u64;
    let mapped = device
        .map_memory(mem, 0, byte_len, vk::MemoryMapFlags::empty())
        .map_err(vkerr)?;
    let mut host = Vec::with_capacity(n);
    host.extend_from_slice(std::slice::from_raw_parts(mapped as *const u32, n));
    device.unmap_memory(mem);
    Ok(host)
}
/// Returns the compute pipeline for kernel `name`, building and caching it on
/// first use.
///
/// The cache is keyed by `name` only: `n_buffers` (storage-buffer bindings
/// 0..n) and `push_size` (push-constant range, at least 4 bytes) are taken
/// from whichever call builds the pipeline first.
///
/// Every failure path destroys the Vulkan objects created so far. If another
/// thread cached the same kernel while this one was building, the freshly
/// built objects are destroyed and the cached entry is returned.
///
/// # Errors
/// Unknown kernel names, malformed SPIR-V and Vulkan creation failures.
fn pipeline(
    &self,
    name: &'static str,
    n_buffers: usize,
    push_size: usize,
) -> Result<CachedPipeline> {
    if let Some(p) = self.inner.pipelines.lock().unwrap().get(name) {
        return Ok(p.clone());
    }
    // Resolve and parse the shader before creating any Vulkan object, so an
    // unknown or malformed kernel has nothing to clean up.
    let spv_bytes = kernel_spv(name)?;
    let spv = ash::util::read_spv(&mut std::io::Cursor::new(spv_bytes))
        .map_err(|e| Error::Msg(format!("vulkan: bad SPIR-V `{name}`: {e}")))?;
    let dev = self.dev();
    let cached = unsafe {
        let binds: Vec<_> = (0..n_buffers as u32)
            .map(|i| {
                vk::DescriptorSetLayoutBinding::default()
                    .binding(i)
                    .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                    .descriptor_count(1)
                    .stage_flags(vk::ShaderStageFlags::COMPUTE)
            })
            .collect();
        let set_layout = dev
            .create_descriptor_set_layout(
                &vk::DescriptorSetLayoutCreateInfo::default().bindings(&binds),
                None,
            )
            .map_err(vkerr)?;
        let set_layouts = [set_layout];
        let pcr = [vk::PushConstantRange::default()
            .stage_flags(vk::ShaderStageFlags::COMPUTE)
            .offset(0)
            .size(push_size.max(4) as u32)];
        let layout = match dev.create_pipeline_layout(
            &vk::PipelineLayoutCreateInfo::default()
                .set_layouts(&set_layouts)
                .push_constant_ranges(&pcr),
            None,
        ) {
            Ok(layout) => layout,
            Err(e) => {
                dev.destroy_descriptor_set_layout(set_layout, None);
                return Err(vkerr(e));
            }
        };
        let module = match dev
            .create_shader_module(&vk::ShaderModuleCreateInfo::default().code(&spv), None)
        {
            Ok(module) => module,
            Err(e) => {
                dev.destroy_pipeline_layout(layout, None);
                dev.destroy_descriptor_set_layout(set_layout, None);
                return Err(vkerr(e));
            }
        };
        let stage = vk::PipelineShaderStageCreateInfo::default()
            .stage(vk::ShaderStageFlags::COMPUTE)
            .module(module)
            .name(c"main");
        let created = dev.create_compute_pipelines(
            vk::PipelineCache::null(),
            &[vk::ComputePipelineCreateInfo::default()
                .stage(stage)
                .layout(layout)],
            None,
        );
        // The module is only needed during pipeline creation, success or not.
        dev.destroy_shader_module(module, None);
        let pipeline = match created {
            Ok(pipelines) => pipelines[0],
            Err((partial, e)) => {
                for p in partial {
                    if p != vk::Pipeline::null() {
                        dev.destroy_pipeline(p, None);
                    }
                }
                dev.destroy_pipeline_layout(layout, None);
                dev.destroy_descriptor_set_layout(set_layout, None);
                return Err(vkerr(e));
            }
        };
        CachedPipeline {
            pipeline,
            layout,
            set_layout,
            n_buffers,
        }
    };
    let mut cache = self.inner.pipelines.lock().unwrap();
    if let Some(winner) = cache.get(name).cloned() {
        // Lost a build race: keep the entry other threads may already be
        // using and discard ours, which nothing else has seen.
        drop(cache);
        unsafe {
            dev.destroy_pipeline(cached.pipeline, None);
            dev.destroy_pipeline_layout(cached.layout, None);
            dev.destroy_descriptor_set_layout(cached.set_layout, None);
        }
        return Ok(winner);
    }
    cache.insert(name, cached.clone());
    Ok(cached)
}
/// Records one compute dispatch of kernel `name` into the shared command buffer.
///
/// `bufs` are bound as storage buffers at bindings 0..bufs.len() (whole buffer
/// range), `push` is written as push constants at offset 0 when non-empty, and
/// `groups` is the workgroup count. Nothing is submitted here unless the batch
/// reaches `BATCH_CAP` dispatches, in which case the batch is flushed and
/// released buffers are reclaimed.
///
/// # Errors
/// Pipeline creation, descriptor allocation and flush failures are propagated.
fn dispatch(
&self,
name: &'static str,
bufs: &[vk::Buffer],
push: &[u8],
groups: (u32, u32, u32),
) -> Result<()> {
let p = self.pipeline(name, bufs.len(), push.len())?;
// The cache is keyed by name only; catch callers that disagree on binding count.
debug_assert_eq!(
p.n_buffers,
bufs.len(),
"vulkan: kernel `{name}` binding count drift"
);
let dev = self.dev();
let queue = self.inner.queue;
let mut s = self.inner.submitter.lock().unwrap();
unsafe {
// Start a new batch lazily: reset the descriptor pool and begin recording.
if !s.recording {
dev.reset_descriptor_pool(s.dpool, vk::DescriptorPoolResetFlags::empty())
.map_err(vkerr)?;
dev.begin_command_buffer(s.cmd, &vk::CommandBufferBeginInfo::default())
.map_err(vkerr)?;
s.recording = true;
s.n = 0;
}
// One descriptor set per dispatch, allocated from the batch's pool.
let set_layouts = [p.set_layout];
let set = dev
.allocate_descriptor_sets(
&vk::DescriptorSetAllocateInfo::default()
.descriptor_pool(s.dpool)
.set_layouts(&set_layouts),
)
.map_err(vkerr)?[0];
// Each binding covers its whole buffer (offset 0, WHOLE_SIZE).
let infos: Vec<[vk::DescriptorBufferInfo; 1]> = bufs
.iter()
.map(|&b| {
[vk::DescriptorBufferInfo::default()
.buffer(b)
.range(vk::WHOLE_SIZE)]
})
.collect();
let writes: Vec<_> = (0..bufs.len())
.map(|i| {
vk::WriteDescriptorSet::default()
.dst_set(set)
.dst_binding(i as u32)
.descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
.buffer_info(&infos[i])
})
.collect();
dev.update_descriptor_sets(&writes, &[]);
let cmd = s.cmd;
dev.cmd_bind_pipeline(cmd, vk::PipelineBindPoint::COMPUTE, p.pipeline);
dev.cmd_bind_descriptor_sets(
cmd,
vk::PipelineBindPoint::COMPUTE,
p.layout,
0,
&[set],
&[],
);
if !push.is_empty() {
dev.cmd_push_constants(cmd, p.layout, vk::ShaderStageFlags::COMPUTE, 0, push);
}
dev.cmd_dispatch(cmd, groups.0, groups.1, groups.2);
// Compute-to-compute barrier so later dispatches in the batch see these shader writes.
let bar = [vk::MemoryBarrier::default()
.src_access_mask(vk::AccessFlags::SHADER_WRITE)
.dst_access_mask(vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE)];
dev.cmd_pipeline_barrier(
cmd,
vk::PipelineStageFlags::COMPUTE_SHADER,
vk::PipelineStageFlags::COMPUTE_SHADER,
vk::DependencyFlags::empty(),
&bar,
&[],
&[],
);
s.n += 1;
// Batch is full: submit and wait, then release the submitter lock before reclaiming.
if s.n >= BATCH_CAP {
flush_locked(dev, queue, &mut s)?;
drop(s);
self.reclaim();
}
}
Ok(())
}
/// Submits the current batch (if any), waits for it to finish, then makes
/// buffers released in the meantime available for reuse.
fn flush(&self) -> Result<()> {
    let submitted = {
        let mut guard = self.inner.submitter.lock().unwrap();
        let outcome = flush_locked(self.dev(), self.inner.queue, &mut guard);
        outcome
    };
    // The submitter lock is released before the pool lock is taken.
    submitted?;
    self.reclaim();
    Ok(())
}
fn reclaim(&self) {
let mut pool = self.inner.bufpool.lock().unwrap();
let pending = std::mem::take(&mut pool.pending);
for (bytes, buf, mem) in pending {
pool.free.entry(bytes).or_default().push((buf, mem));
}
}
/// Allocates (or recycles) a buffer for `count` f32 elements; contents are not initialized here.
fn alloc_f32(&self, count: usize) -> Result<VulkanStorage> {
    let byte_len = (count * 4) as u64;
    let (buffer, memory) = unsafe { self.raw_buffer(byte_len) }?;
    let device = self.clone();
    Ok(VulkanStorage {
        buffer,
        memory,
        count,
        dtype: DType::F32,
        device,
    })
}
/// Allocates a raw scratch buffer for `count` 16-bit elements.
///
/// Returns the handles together with the byte size used as the pool key.
fn alloc_f16(&self, count: usize) -> Result<(vk::Buffer, vk::DeviceMemory, u64)> {
    let bytes = (count * 2).max(4) as u64;
    unsafe { self.raw_buffer(bytes) }.map(|(buffer, memory)| (buffer, memory, bytes))
}
/// Queues a raw scratch buffer for recycling under the size key `bytes`.
/// Best effort: nothing happens if the pool lock is poisoned.
fn free_scratch(&self, bytes: u64, buffer: vk::Buffer, memory: vk::DeviceMemory) {
    let Ok(mut pool) = self.inner.bufpool.lock() else {
        return;
    };
    pool.pending.push((bytes, buffer, memory));
}
/// Allocates an f32 storage of `data.len()` elements and copies `data` into it.
pub(crate) fn upload_f32(&self, data: &[f32]) -> Result<VulkanStorage> {
    let storage = self.alloc_f32(data.len())?;
    unsafe { self.write_f32(storage.memory, data) }?;
    Ok(storage)
}
/// Allocates (or recycles) a buffer for `count` u32 elements; contents are not initialized here.
fn alloc_u32(&self, count: usize) -> Result<VulkanStorage> {
    let byte_len = (count * 4) as u64;
    unsafe { self.raw_buffer(byte_len) }.map(|(buffer, memory)| VulkanStorage {
        buffer,
        memory,
        count,
        dtype: DType::U32,
        device: self.clone(),
    })
}
/// Allocates a u32 storage of `data.len()` elements and copies `data` into it.
fn upload_u32(&self, data: &[u32]) -> Result<VulkanStorage> {
    let storage = self.alloc_u32(data.len())?;
    unsafe { self.write_u32(storage.memory, data) }?;
    Ok(storage)
}
}
fn vkerr(e: vk::Result) -> Error {
Error::Msg(format!("vulkan: {e:?}"))
}
/// Maximum number of dispatches recorded before a batch is submitted; also
/// sizes the descriptor pool (this many sets, four storage-buffer descriptors each).
const BATCH_CAP: u32 = 4096;
/// Ends the current batch, submits it and blocks until the fence signals.
///
/// Does nothing when no batch is being recorded. The caller must hold the
/// submitter lock (it passes the guarded `Submitter` in).
///
/// The recording flag is cleared *before* any Vulkan call. If ending or
/// submitting fails, the command buffer is no longer in the recording state,
/// so the next dispatch must begin a fresh batch rather than append to it.
///
/// # Errors
/// Failures of end/reset/submit/wait are mapped through `vkerr`; work recorded
/// in the failed batch is dropped.
fn flush_locked(dev: &ash::Device, queue: vk::Queue, s: &mut Submitter) -> Result<()> {
    if !s.recording {
        return Ok(());
    }
    s.recording = false;
    s.n = 0;
    unsafe {
        dev.end_command_buffer(s.cmd).map_err(vkerr)?;
        dev.reset_fences(&[s.fence]).map_err(vkerr)?;
        let cmds = [s.cmd];
        dev.queue_submit(
            queue,
            &[vk::SubmitInfo::default().command_buffers(&cmds)],
            s.fence,
        )
        .map_err(vkerr)?;
        dev.wait_for_fences(&[s.fence], true, u64::MAX)
            .map_err(vkerr)?;
    }
    Ok(())
}
/// Serializes `v` into a push-constant byte blob: each word in native byte
/// order, concatenated in slice order.
fn push_u32(v: &[u32]) -> Vec<u8> {
    v.iter().flat_map(|word| word.to_ne_bytes()).collect()
}
/// Divisor used to turn an element count into a 1-D workgroup count.
/// NOTE(review): presumably must match the kernels' local_size_x — confirm against the shaders.
const WG1D: u32 = 64;
impl BackendDevice for VulkanDevice {
type Storage = VulkanStorage;
fn new(ordinal: usize) -> Result<Self> {
unsafe {
let entry = ash::Entry::load()
.map_err(|e| Error::Msg(format!("vulkan: loader not found: {e}")))?;
let app = vk::ApplicationInfo::default().api_version(vk::make_api_version(0, 1, 3, 0));
let instance = entry
.create_instance(
&vk::InstanceCreateInfo::default().application_info(&app),
None,
)
.map_err(vkerr)?;
let mut gpus = Vec::new();
for pd in instance.enumerate_physical_devices().map_err(vkerr)? {
let p = instance.get_physical_device_properties(pd);
let name = CStr::from_ptr(p.device_name.as_ptr())
.to_string_lossy()
.into_owned();
let is_cpu = p.device_type == vk::PhysicalDeviceType::CPU
|| name.to_lowercase().contains("llvmpipe");
if !is_cpu {
gpus.push(pd);
}
}
if gpus.is_empty() {
instance.destroy_instance(None);
return Err(Error::Msg("vulkan: no non-CPU Vulkan device".into()));
}
let pdev = *gpus
.get(ordinal)
.ok_or_else(|| Error::Msg(format!("vulkan: no device at ordinal {ordinal}")))?;
let qfi = instance
.get_physical_device_queue_family_properties(pdev)
.iter()
.position(|q| q.queue_flags.contains(vk::QueueFlags::COMPUTE))
.ok_or_else(|| Error::Msg("vulkan: no compute queue".into()))?
as u32;
let prios = [1.0f32];
let qci = [vk::DeviceQueueCreateInfo::default()
.queue_family_index(qfi)
.queue_priorities(&prios)];
let dev_exts = instance
.enumerate_device_extension_properties(pdev)
.unwrap_or_default();
let has_cm_ext = dev_exts.iter().any(|e| {
CStr::from_ptr(e.extension_name.as_ptr()) == ash::khr::cooperative_matrix::NAME
});
let cm_mnk = if has_cm_ext {
let cm = ash::khr::cooperative_matrix::Instance::new(&entry, &instance);
cm.get_physical_device_cooperative_matrix_properties(pdev)
.ok()
.and_then(|props| {
props.into_iter().find(|p| {
p.a_type == vk::ComponentTypeKHR::FLOAT16
&& p.b_type == vk::ComponentTypeKHR::FLOAT16
&& p.c_type == vk::ComponentTypeKHR::FLOAT32
&& p.result_type == vk::ComponentTypeKHR::FLOAT32
&& p.scope == vk::ScopeKHR::SUBGROUP
})
})
.map(|p| (p.m_size, p.n_size, p.k_size))
} else {
None
};
let coopmat = cm_mnk.is_some();
let cm_mnk = cm_mnk.unwrap_or((0, 0, 0));
let cm_use = coopmat
&& std::env::var("HANZO_VK_COOPMAT")
.map(|v| v != "0")
.unwrap_or(true);
let cm_ext_names = [ash::khr::cooperative_matrix::NAME.as_ptr()];
let mut cm_feat =
vk::PhysicalDeviceCooperativeMatrixFeaturesKHR::default().cooperative_matrix(true);
let mut mm_feat =
vk::PhysicalDeviceVulkanMemoryModelFeatures::default().vulkan_memory_model(true);
let mut f16_feat =
vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true);
let mut s16_feat =
vk::PhysicalDevice16BitStorageFeatures::default().storage_buffer16_bit_access(true);
let mut dci = vk::DeviceCreateInfo::default().queue_create_infos(&qci);
if coopmat {
dci = dci
.enabled_extension_names(&cm_ext_names)
.push_next(&mut cm_feat)
.push_next(&mut mm_feat)
.push_next(&mut f16_feat)
.push_next(&mut s16_feat);
}
let device = instance.create_device(pdev, &dci, None).map_err(vkerr)?;
let queue = device.get_device_queue(qfi, 0);
let mem_props = instance.get_physical_device_memory_properties(pdev);
let cpool = device
.create_command_pool(
&vk::CommandPoolCreateInfo::default()
.queue_family_index(qfi)
.flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER),
None,
)
.map_err(vkerr)?;
let cmd = device
.allocate_command_buffers(
&vk::CommandBufferAllocateInfo::default()
.command_pool(cpool)
.level(vk::CommandBufferLevel::PRIMARY)
.command_buffer_count(1),
)
.map_err(vkerr)?[0];
let fence = device
.create_fence(&vk::FenceCreateInfo::default(), None)
.map_err(vkerr)?;
let dpool_sizes = [vk::DescriptorPoolSize::default()
.ty(vk::DescriptorType::STORAGE_BUFFER)
.descriptor_count(BATCH_CAP * 4)];
let dpool = device
.create_descriptor_pool(
&vk::DescriptorPoolCreateInfo::default()
.pool_sizes(&dpool_sizes)
.max_sets(BATCH_CAP),
None,
)
.map_err(vkerr)?;
let submitter = Mutex::new(Submitter {
cpool,
cmd,
fence,
dpool,
recording: false,
n: 0,
});
let inner = VkInner {
_entry: entry,
instance,
device,
queue,
qfi,
gpu_id: ordinal,
mem_props,
seed: Mutex::new(299792458),
pipelines: Mutex::new(HashMap::new()),
submitter,
bufpool: Mutex::new(BufPool::default()),
coopmat,
cm_mnk,
cm_use,
};
Ok(Self {
inner: Arc::new(inner),
})
}
}
/// Reports this device as `DeviceLocation::Vulkan` with the ordinal it was created with.
fn location(&self) -> crate::DeviceLocation {
    let gpu_id = self.inner.gpu_id;
    crate::DeviceLocation::Vulkan { gpu_id }
}
/// Two handles are the same device only when they share the same `VkInner` allocation.
fn same_device(&self, rhs: &Self) -> bool {
    let (mine, theirs) = (&self.inner, &rhs.inner);
    Arc::ptr_eq(mine, theirs)
}
/// Creates a zero-filled storage by uploading a host vector of zeros.
/// Only f32 and u32 are supported.
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
    let count = shape.elem_count();
    if dtype == DType::F32 {
        return self.upload_f32(&vec![0f32; count]);
    }
    if dtype == DType::U32 {
        return self.upload_u32(&vec![0u32; count]);
    }
    crate::bail!("vulkan: only f32/u32 supported, got {dtype:?}")
}
/// Allocates storage for `shape` without writing to it. Only f32 and u32 are supported.
///
/// # Safety
/// The contents are whatever the (possibly recycled) buffer held; the caller
/// must write before reading.
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
    let count = shape.elem_count();
    match dtype {
        DType::F32 => self.alloc_f32(count),
        DType::U32 => self.alloc_u32(count),
        _ => crate::bail!("vulkan: only f32/u32 supported, got {dtype:?}"),
    }
}
/// Uploads a typed host slice by first converting it to a `CpuStorage`.
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
    let host = T::to_cpu_storage(s);
    self.storage_from_cpu_storage(&host)
}
/// Uploads host storage to the device.
///
/// f32 and u32 are copied as-is; f16 and bf16 are widened to f32 on the host
/// first, so the resulting storage is f32. Other dtypes are rejected.
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
    match s {
        CpuStorage::F32(data) => self.upload_f32(data),
        CpuStorage::U32(data) => self.upload_u32(data),
        CpuStorage::F16(data) => {
            let widened: Vec<f32> = data.iter().map(|h| h.to_f32()).collect();
            self.upload_f32(&widened)
        }
        CpuStorage::BF16(data) => {
            let widened: Vec<f32> = data.iter().map(|h| h.to_f32()).collect();
            self.upload_f32(&widened)
        }
        other => crate::bail!(
            "vulkan: only f32/u32/f16/bf16 supported, got {:?}",
            other.dtype()
        ),
    }
}
/// Owned variant of `storage_from_cpu_storage`; the host storage is dropped after upload.
fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
    let host = &s;
    self.storage_from_cpu_storage(host)
}
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
min: f64,
max: f64,
) -> Result<Self::Storage> {
if dtype != DType::F32 {
crate::bail!("vulkan: rand_uniform only f32, got {dtype:?}");
}
let mut rng = rand::rng();
let n = shape.elem_count();
let uniform = rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
let mut data = Vec::with_capacity(n);
for _ in 0..n {
data.push(rng.sample::<f32, _>(uniform));
}
self.upload_f32(&data)
}
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
std: f64,
) -> Result<Self::Storage> {
if dtype != DType::F32 {
crate::bail!("vulkan: rand_normal only f32, got {dtype:?}");
}
use rand_distr::Distribution;
let mut rng = rand::rng();
let n = shape.elem_count();
let normal = rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
let mut data = Vec::with_capacity(n);
for _ in 0..n {
data.push(normal.sample(&mut rng));
}
self.upload_f32(&data)
}
/// Stores `seed` in the device state.
fn set_seed(&self, seed: u64) -> Result<()> {
    let mut slot = self.inner.seed.lock().unwrap();
    *slot = seed;
    Ok(())
}
/// Returns the seed last stored with `set_seed` (or the initial value).
fn get_current_seed(&self) -> Result<u64> {
    let current = *self.inner.seed.lock().unwrap();
    Ok(current)
}
/// Submits any batched work and then waits for the whole device to go idle.
fn synchronize(&self) -> Result<()> {
    self.flush()?;
    let waited = unsafe { self.dev().device_wait_idle() };
    waited.map_err(vkerr)
}
}
impl VulkanStorage {
/// Number of elements (not bytes) held by this storage.
fn count(&self) -> usize {
    let elements = self.count;
    elements
}
/// Reads the whole storage back to the host as f32 values (flushes pending work first).
fn to_vec_f32(&self) -> Result<Vec<f32>> {
    let (memory, len) = (self.memory, self.count);
    unsafe { self.device.read_f32(memory, len) }
}
/// Reads the whole storage back to the host as u32 values (flushes pending work first).
fn to_vec_u32(&self) -> Result<Vec<u32>> {
    let (memory, len) = (self.memory, self.count);
    unsafe { self.device.read_u32(memory, len) }
}
/// Workgroup count for a 1-D launch over `n` items: `ceil(n / WG1D)` groups in x.
/// `n` is narrowed to `u32` before dividing.
fn groups_1d(n: usize) -> (u32, u32, u32) {
    let groups_x = (n as u32).div_ceil(WG1D);
    (groups_x, 1, 1)
}
fn contiguous(&self, layout: &Layout) -> Result<VulkanStorage> {
let dims = layout.dims();
let rank = dims.len();
if rank > 6 {
crate::bail!("vulkan: contiguous supports rank <= 6, got {rank}");
}
let n = layout.shape().elem_count();
let out = self.device.alloc_f32(n)?;
let strides = layout.stride();
let mut p = vec![n as u32, rank as u32, layout.start_offset() as u32];
let mut shape6 = [0u32; 6];
let mut stride6 = [0u32; 6];
for d in 0..rank {
shape6[d] = dims[d] as u32;
stride6[d] = strides[d] as u32;
}
p.extend_from_slice(&shape6);
p.extend_from_slice(&stride6);
self.device.dispatch(
"strided_copy",
&[self.buffer, out.buffer],
&push_u32(&p),
Self::groups_1d(n),
)?;
Ok(out)
}
/// Returns a buffer holding `layout`'s elements densely from offset 0.
///
/// When the layout is already contiguous with zero offset this is the
/// storage's own buffer. Otherwise a dense copy is made and parked in `keep`
/// so it outlives the dispatch that uses the returned handle.
fn contig_buf(&self, layout: &Layout, keep: &mut Option<VulkanStorage>) -> Result<vk::Buffer> {
    let already_dense = layout.is_contiguous() && layout.start_offset() == 0;
    if already_dense {
        return Ok(self.buffer);
    }
    let copy = self.contiguous(layout)?;
    Ok(keep.insert(copy).buffer)
}
/// Runs the `softmax_rows` kernel over the last dimension.
///
/// The input is treated as `rows x cols`, where `cols` is the last dimension
/// (1 for rank 0) and `rows` is the remaining element count.
pub fn softmax_last_dim(&self, layout: &Layout) -> Result<VulkanStorage> {
    let mut scratch = None;
    let src = self.contig_buf(layout, &mut scratch)?;
    let cols = layout.dims().last().copied().unwrap_or(1);
    let rows = layout.shape().elem_count() / cols.max(1);
    let out = self.device.alloc_f32(rows * cols)?;
    let push = push_u32(&[rows as u32, cols as u32]);
    self.device.dispatch(
        "softmax_rows",
        &[src, out.buffer],
        &push,
        Self::groups_1d(rows),
    )?;
    Ok(out)
}
/// Runs the `rms_norm` kernel over the last dimension with weights `alpha`.
///
/// Push constants are `[rows, cols]` followed by `eps` as native-endian f32 bytes.
pub fn rms_norm(
    &self,
    layout: &Layout,
    alpha: &VulkanStorage,
    alpha_l: &Layout,
    eps: f32,
) -> Result<VulkanStorage> {
    let (mut x_keep, mut alpha_keep) = (None, None);
    let x_buf = self.contig_buf(layout, &mut x_keep)?;
    let alpha_buf = alpha.contig_buf(alpha_l, &mut alpha_keep)?;
    let cols = layout.dims().last().copied().unwrap_or(1);
    let rows = layout.shape().elem_count() / cols.max(1);
    let out = self.device.alloc_f32(rows * cols)?;
    let push: Vec<u8> = push_u32(&[rows as u32, cols as u32])
        .into_iter()
        .chain(eps.to_ne_bytes())
        .collect();
    self.device.dispatch(
        "rms_norm",
        &[x_buf, alpha_buf, out.buffer],
        &push,
        Self::groups_1d(rows),
    )?;
    Ok(out)
}
/// Runs the fused `silu_mul` kernel on `self` and `rhs`, element-wise over
/// `layout`'s element count.
pub fn silu_mul(
    &self,
    layout: &Layout,
    rhs: &VulkanStorage,
    rhs_l: &Layout,
) -> Result<VulkanStorage> {
    let (mut lhs_keep, mut rhs_keep) = (None, None);
    let lhs_buf = self.contig_buf(layout, &mut lhs_keep)?;
    let rhs_buf = rhs.contig_buf(rhs_l, &mut rhs_keep)?;
    let n = layout.shape().elem_count();
    let out = self.device.alloc_f32(n)?;
    let push = (n as u32).to_ne_bytes();
    self.device.dispatch(
        "silu_mul",
        &[lhs_buf, rhs_buf, out.buffer],
        &push,
        Self::groups_1d(n),
    )?;
    Ok(out)
}
/// Runs the `rope` kernel on a rank-4 `(b, h, t, d)` input with the given
/// cos/sin tables.
///
/// The last push constant is 1 when both tables are rank 3 and 0 otherwise.
/// One invocation is launched per element pair, `b * h * t * (d / 2)` in total.
///
/// # Errors
/// Fails when the input is not rank 4; copy, allocation and dispatch errors
/// are propagated.
pub fn rope(
    &self,
    layout: &Layout,
    cos: &VulkanStorage,
    cos_l: &Layout,
    sin: &VulkanStorage,
    sin_l: &Layout,
) -> Result<VulkanStorage> {
    let (mut src_keep, mut cos_keep, mut sin_keep) = (None, None, None);
    let src_buf = self.contig_buf(layout, &mut src_keep)?;
    let cos_buf = cos.contig_buf(cos_l, &mut cos_keep)?;
    let sin_buf = sin.contig_buf(sin_l, &mut sin_keep)?;
    let (b, h, t, d) = layout.shape().dims4()?;
    let tables_rank3 = cos_l.dims().len() == 3 && sin_l.dims().len() == 3;
    let unbatched = tables_rank3 as u32;
    let out = self.device.alloc_f32(b * h * t * d)?;
    let pairs = b * h * t * (d / 2);
    let push = push_u32(&[b as u32, h as u32, t as u32, d as u32, unbatched]);
    self.device.dispatch(
        "rope",
        &[src_buf, cos_buf, sin_buf, out.buffer],
        &push,
        Self::groups_1d(pairs),
    )?;
    Ok(out)
}
/// Shared driver for the scatter kernels: writes into `self`'s buffer in place.
///
/// `ids` must be u32. Both `ids` and `src` are copied to dense buffers first.
/// Push constants are `[n, right, dim_src, dim_dst]`, where `n` is the source
/// element count, `right` the product of source dims after `dim`, and
/// `dim_src` / `dim_dst` the sizes of `dim` in source and destination.
///
/// NOTE(review): `self.buffer` is bound directly, so the destination layout's
/// strides and offset are not applied — confirm callers pass a dense destination.
fn scatter_impl(
    &self,
    kernel: &'static str,
    l: &Layout,
    ids: &VulkanStorage,
    ids_l: &Layout,
    src: &VulkanStorage,
    src_l: &Layout,
    dim: usize,
) -> Result<()> {
    if !matches!(ids.dtype, DType::U32) {
        crate::bail!("vulkan: scatter requires u32 ids, got {:?}", ids.dtype);
    }
    let dense_ids = ids.contiguous(ids_l)?;
    let dense_src = src.contiguous(src_l)?;
    let (src_dims, dst_dims) = (src_l.dims(), l.dims());
    let right = src_dims[dim + 1..].iter().product::<usize>();
    let (dim_src, dim_dst) = (src_dims[dim], dst_dims[dim]);
    let n = src_l.shape().elem_count();
    let push = push_u32(&[n as u32, right as u32, dim_src as u32, dim_dst as u32]);
    self.device.dispatch(
        kernel,
        &[self.buffer, dense_src.buffer, dense_ids.buffer],
        &push,
        Self::groups_1d(n),
    )?;
    Ok(())
}
fn pool2d(
&self,
kernel: &'static str,
l: &Layout,
k: (usize, usize),
stride: (usize, usize),
) -> Result<VulkanStorage> {
let inp = self.contiguous(l)?;
let (b, c, ih, iw) = l.shape().dims4()?;
let (kh, kw) = k;
let (sh, sw) = stride;
let oh = (ih - kh) / sh + 1;
let ow = (iw - kw) / sw + 1;
let out = self.device.alloc_f32(b * c * oh * ow)?;
self.device.dispatch(
kernel,
&[inp.buffer, out.buffer],
&push_u32(&[
b as u32, c as u32, ih as u32, iw as u32, oh as u32, ow as u32, kh as u32,
kw as u32, sh as u32, sw as u32,
]),
Self::groups_1d(b * c * oh * ow),
)?;
Ok(out)
}
}
impl BackendStorage for VulkanStorage {
type Device = VulkanDevice;
/// Deep-copies the storage by reading it back to the host and uploading it again.
/// The layout argument is ignored; the whole buffer is copied.
fn try_clone(&self, _: &Layout) -> Result<Self> {
    let device = &self.device;
    match self.dtype {
        DType::F32 => {
            let host = self.to_vec_f32()?;
            device.upload_f32(&host)
        }
        DType::U32 => {
            let host = self.to_vec_u32()?;
            device.upload_u32(&host)
        }
        dt => crate::bail!("vulkan: try_clone unsupported dtype {dt:?}"),
    }
}
/// Element type this storage is labelled with.
fn dtype(&self) -> DType {
    let label = self.dtype;
    label
}
fn device(&self) -> &Self::Device {
&self.device
}
/// Reads the whole buffer back to the host: u32 storages as `CpuStorage::U32`,
/// everything else as `CpuStorage::F32`.
fn to_cpu_storage(&self) -> Result<CpuStorage> {
    if let DType::U32 = self.dtype {
        return self.to_vec_u32().map(CpuStorage::U32);
    }
    self.to_vec_f32().map(CpuStorage::F32)
}
/// Runs the `affine` kernel with `mul` and `add` narrowed to f32.
/// Push constants: `n` (u32), `mul` (f32), `add` (f32).
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
    let src = self.contiguous(layout)?;
    let n = layout.shape().elem_count();
    let out = self.device.alloc_f32(n)?;
    let mut push = Vec::with_capacity(12);
    push.extend_from_slice(&(n as u32).to_ne_bytes());
    push.extend_from_slice(&(mul as f32).to_ne_bytes());
    push.extend_from_slice(&(add as f32).to_ne_bytes());
    let groups = Self::groups_1d(n);
    self.device
        .dispatch("affine", &[src.buffer, out.buffer], &push, groups)?;
    Ok(out)
}
/// Runs the `powf` kernel with exponent `e` narrowed to f32.
/// Push constants: `n` (u32), `e` (f32).
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
    let src = self.contiguous(layout)?;
    let n = layout.shape().elem_count();
    let out = self.device.alloc_f32(n)?;
    let mut push = Vec::with_capacity(8);
    push.extend_from_slice(&(n as u32).to_ne_bytes());
    push.extend_from_slice(&(e as f32).to_ne_bytes());
    let groups = Self::groups_1d(n);
    self.device
        .dispatch("powf", &[src.buffer, out.buffer], &push, groups)?;
    Ok(out)
}
/// Runs the `elu` kernel with `alpha` narrowed to f32.
/// Push constants: `n` (u32), `alpha` (f32).
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
    let src = self.contiguous(layout)?;
    let n = layout.shape().elem_count();
    let out = self.device.alloc_f32(n)?;
    let mut push = Vec::with_capacity(8);
    push.extend_from_slice(&(n as u32).to_ne_bytes());
    push.extend_from_slice(&(alpha as f32).to_ne_bytes());
    let groups = Self::groups_1d(n);
    self.device
        .dispatch("elu", &[src.buffer, out.buffer], &push, groups)?;
    Ok(out)
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let (kernel, is_arg) = match op {
ReduceOp::Sum => ("reduce_sum", false),
ReduceOp::Max => ("reduce_max", false),
ReduceOp::Min => ("reduce_min", false),
ReduceOp::ArgMin => ("reduce_argmin", true),
ReduceOp::ArgMax => ("reduce_argmax", true),
};
let dims = layout.dims();
let rank = dims.len();
if rank == 0 {
crate::bail!("vulkan: reduce_op on scalar not supported");
}
if sum_dims != [rank - 1] {
crate::bail!(
"vulkan: reduce_op only supports the last dim (got dims={sum_dims:?}, rank={rank})"
);
}
let c = self.contiguous(layout)?;
let cols = dims[rank - 1];
let rows: usize = dims[..rank - 1].iter().product();
let out = if is_arg {
self.device.alloc_u32(rows)?
} else {
self.device.alloc_f32(rows)?
};
let push = push_u32(&[rows as u32, cols as u32]);
self.device.dispatch(
kernel,
&[c.buffer, out.buffer],
&push,
Self::groups_1d(rows),
)?;
Ok(out)
}
/// Element-wise comparison producing a u32 storage.
///
/// The operator is passed to the `cmp` kernel as a code:
/// Eq=0, Ne=1, Le=2, Ge=3, Lt=4, Gt=5.
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
    let lhs_dense = self.contiguous(lhs_l)?;
    let rhs_dense = rhs.contiguous(rhs_l)?;
    let n = lhs_l.shape().elem_count();
    let out = self.device.alloc_u32(n)?;
    let code: u32 = match op {
        CmpOp::Eq => 0,
        CmpOp::Ne => 1,
        CmpOp::Le => 2,
        CmpOp::Ge => 3,
        CmpOp::Lt => 4,
        CmpOp::Gt => 5,
    };
    let push = push_u32(&[n as u32, code]);
    let buffers = [lhs_dense.buffer, rhs_dense.buffer, out.buffer];
    self.device
        .dispatch("cmp", &buffers, &push, Self::groups_1d(n))?;
    Ok(out)
}
/// Converts to `dtype`.
///
/// Same-type conversions are a dense copy. f32 <-> u32 run the `cast_f2u` /
/// `cast_u2f` kernels. Every other combination goes through the CPU backend:
/// read back, convert on the host, upload.
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
    let n = layout.shape().elem_count();
    let cast_kernel = match (self.dtype, dtype) {
        (DType::F32, DType::F32) | (DType::U32, DType::U32) => {
            return self.contiguous(layout);
        }
        (DType::F32, DType::U32) => "cast_f2u",
        (DType::U32, DType::F32) => "cast_u2f",
        _ => {
            let host = self.to_cpu_storage()?;
            let converted = crate::backend::BackendStorage::to_dtype(&host, layout, dtype)?;
            return self.device.storage_from_cpu_storage(&converted);
        }
    };
    let src = self.contiguous(layout)?;
    // Only the two cast kernels reach this point, so the target is u32 or f32.
    let out = if dtype == DType::U32 {
        self.device.alloc_u32(n)?
    } else {
        self.device.alloc_f32(n)?
    };
    self.device.dispatch(
        cast_kernel,
        &[src.buffer, out.buffer],
        &push_u32(&[n as u32]),
        Self::groups_1d(n),
    )?;
    Ok(out)
}
/// Applies unary op `B` element-wise.
///
/// Ops whose name appears in the kernel list run on the GPU with the kernel
/// of the same name. Any other op is computed by the CPU backend (read back,
/// compute, upload).
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
    static GPU_KERNELS: [&str; 19] = [
        "silu", "gelu", "relu", "exp", "neg", "sqr", "sqrt", "recip", "tanh", "sin", "cos",
        "log", "abs", "floor", "ceil", "round", "sign", "erf", "gelu_erf",
    ];
    let Some(kernel) = GPU_KERNELS.iter().copied().find(|k| *k == B::NAME) else {
        let host = self.to_cpu_storage()?;
        let computed = crate::backend::BackendStorage::unary_impl::<B>(&host, layout)?;
        return self.device.storage_from_cpu_storage(&computed);
    };
    let mut scratch = None;
    let src = self.contig_buf(layout, &mut scratch)?;
    let n = layout.shape().elem_count();
    let out = self.device.alloc_f32(n)?;
    let push = (n as u32).to_ne_bytes();
    let groups = Self::groups_1d(n);
    self.device
        .dispatch(kernel, &[src, out.buffer], &push, groups)?;
    Ok(out)
}
/// Applies binary op `B` element-wise over `lhs_l`'s element count.
///
/// add/sub/mul/div/maximum/minimum run on the GPU with the kernel of the same
/// name. Any other op is computed by the CPU backend (read back both
/// operands, compute, upload).
fn binary_impl<B: BinaryOpT>(
    &self,
    rhs: &Self,
    lhs_l: &Layout,
    rhs_l: &Layout,
) -> Result<Self> {
    static GPU_KERNELS: [&str; 6] = ["add", "sub", "mul", "div", "maximum", "minimum"];
    let Some(kernel) = GPU_KERNELS.iter().copied().find(|k| *k == B::NAME) else {
        let lhs_host = self.to_cpu_storage()?;
        let rhs_host = rhs.to_cpu_storage()?;
        let computed = crate::backend::BackendStorage::binary_impl::<B>(
            &lhs_host, &rhs_host, lhs_l, rhs_l,
        )?;
        return self.device.storage_from_cpu_storage(&computed);
    };
    let (mut lhs_keep, mut rhs_keep) = (None, None);
    let lhs_buf = self.contig_buf(lhs_l, &mut lhs_keep)?;
    let rhs_buf = rhs.contig_buf(rhs_l, &mut rhs_keep)?;
    let n = lhs_l.shape().elem_count();
    let out = self.device.alloc_f32(n)?;
    let push = (n as u32).to_ne_bytes();
    let groups = Self::groups_1d(n);
    self.device
        .dispatch(kernel, &[lhs_buf, rhs_buf, out.buffer], &push, groups)?;
    Ok(out)
}
/// Element-wise select: dispatches the `where_cond` kernel with `self` as the condition
/// and `t` / `f` as the two value operands, returning a new f32 storage.
///
/// # Errors
/// Fails if the condition is not `U32`, or if any allocation, copy or dispatch fails.
fn where_cond(
    &self,
    l: &Layout,
    t: &Self,
    t_l: &Layout,
    f: &Self,
    f_l: &Layout,
) -> Result<Self> {
    if self.dtype != DType::U32 {
        crate::bail!(
            "vulkan: where_cond requires u32 condition, got {:?}",
            self.dtype
        );
    }
    // The kernel is handed the whole buffer, so the condition must be contiguous AND start
    // at element 0. `is_contiguous()` alone is not sufficient: a layout with a non-zero
    // start offset would make the kernel read the wrong elements. Materialise a packed copy
    // in that case (the same thing `gather` does for its u32 ids).
    let cond_copy = if l.is_contiguous() && l.start_offset() == 0 {
        None
    } else {
        Some(self.contiguous(l)?)
    };
    let cond_buf = cond_copy.as_ref().map_or(self.buffer, |c| c.buffer);

    // The `keep_*` slots receive whatever temporaries `contig_buf` needs kept alive.
    let mut keep_t = None;
    let mut keep_f = None;
    let t_buf = t.contig_buf(t_l, &mut keep_t)?;
    let f_buf = f.contig_buf(f_l, &mut keep_f)?;

    let n = l.shape().elem_count();
    let out = self.device.alloc_f32(n)?;
    let push = push_u32(&[n as u32]);
    self.device.dispatch(
        "where_cond",
        &[cond_buf, t_buf, f_buf, out.buffer],
        &push,
        Self::groups_1d(n),
    )?;
    Ok(out)
}
/// Runs the `conv1d` kernel on contiguous versions of the input and the weights.
///
/// The output holds `b_size * c_out * l_out` f32 elements and the dispatch is sized from
/// that same count.
fn conv1d(
    &self,
    l: &Layout,
    kernel: &Self,
    kernel_l: &Layout,
    p: &crate::conv::ParamsConv1D,
) -> Result<Self> {
    let input = self.contiguous(l)?;
    let weights = kernel.contiguous(kernel_l)?;
    let l_out = p.l_out();
    let total = p.b_size * p.c_out * l_out;
    let out = self.device.alloc_f32(total)?;

    // Push-constant words, in the order the dispatch receives them.
    let words = [
        p.b_size as u32,
        p.c_in as u32,
        p.c_out as u32,
        p.l_in as u32,
        l_out as u32,
        p.k_size as u32,
        p.padding as u32,
        p.stride as u32,
        p.dilation as u32,
    ];
    let push = push_u32(&words);
    self.device.dispatch(
        "conv1d",
        &[input.buffer, weights.buffer, out.buffer],
        &push,
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// Runs the `conv_transpose1d` kernel on contiguous versions of the input and the weights.
///
/// The output holds `b_size * c_out * l_out` f32 elements and the dispatch is sized from
/// that same count.
fn conv_transpose1d(
    &self,
    l: &Layout,
    kernel: &Self,
    kernel_l: &Layout,
    p: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
    let input = self.contiguous(l)?;
    let weights = kernel.contiguous(kernel_l)?;
    let l_out = p.l_out();
    let total = p.b_size * p.c_out * l_out;
    let out = self.device.alloc_f32(total)?;

    // Push-constant words, in the order the dispatch receives them.
    let words = [
        p.b_size as u32,
        p.c_in as u32,
        p.c_out as u32,
        p.l_in as u32,
        l_out as u32,
        p.k_size as u32,
        p.padding as u32,
        p.stride as u32,
        p.dilation as u32,
    ];
    let push = push_u32(&words);
    self.device.dispatch(
        "conv_transpose1d",
        &[input.buffer, weights.buffer, out.buffer],
        &push,
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// Runs the `conv2d` kernel on contiguous versions of the input and the weights.
///
/// The output holds `b_size * c_out * out_h * out_w` f32 elements and the dispatch is
/// sized from that same count.
fn conv2d(
    &self,
    l: &Layout,
    kernel: &Self,
    kernel_l: &Layout,
    p: &crate::conv::ParamsConv2D,
) -> Result<Self> {
    let input = self.contiguous(l)?;
    let weights = kernel.contiguous(kernel_l)?;
    let out_h = p.out_h();
    let out_w = p.out_w();
    let total = p.b_size * p.c_out * out_h * out_w;
    let out = self.device.alloc_f32(total)?;

    // Push-constant words, in the order the dispatch receives them.
    let words = [
        p.b_size as u32,
        p.c_in as u32,
        p.c_out as u32,
        p.i_h as u32,
        p.i_w as u32,
        out_h as u32,
        out_w as u32,
        p.k_h as u32,
        p.k_w as u32,
        p.padding as u32,
        p.stride as u32,
        p.dilation as u32,
    ];
    let push = push_u32(&words);
    self.device.dispatch(
        "conv2d",
        &[input.buffer, weights.buffer, out.buffer],
        &push,
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// Runs the `conv_transpose2d` kernel on contiguous versions of the input and the weights.
///
/// The output holds `b_size * c_out * out_h * out_w` f32 elements and the dispatch is
/// sized from that same count.
fn conv_transpose2d(
    &self,
    l: &Layout,
    kernel: &Self,
    kernel_l: &Layout,
    p: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
    let input = self.contiguous(l)?;
    let weights = kernel.contiguous(kernel_l)?;
    let out_h = p.out_h();
    let out_w = p.out_w();
    let total = p.b_size * p.c_out * out_h * out_w;
    let out = self.device.alloc_f32(total)?;

    // Push-constant words, in the order the dispatch receives them.
    let words = [
        p.b_size as u32,
        p.c_in as u32,
        p.c_out as u32,
        p.i_h as u32,
        p.i_w as u32,
        out_h as u32,
        out_w as u32,
        p.k_h as u32,
        p.k_w as u32,
        p.padding as u32,
        p.stride as u32,
        p.dilation as u32,
    ];
    let push = push_u32(&words);
    self.device.dispatch(
        "conv_transpose2d",
        &[input.buffer, weights.buffer, out.buffer],
        &push,
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// 2-D average pooling; forwards to `pool2d` with `"avg_pool2d"` as its first argument
/// (presumably the kernel name — confirm against `pool2d`).
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
    const KERNEL: &str = "avg_pool2d";
    self.pool2d(KERNEL, l, k, stride)
}
/// 2-D max pooling; forwards to `pool2d` with `"max_pool2d"` as its first argument
/// (presumably the kernel name — confirm against `pool2d`).
fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
    const KERNEL: &str = "max_pool2d";
    self.pool2d(KERNEL, l, k, stride)
}
/// Dispatches the `upsample_nearest1d` kernel on a rank-3 input, producing an f32
/// output of `b * c * sz` elements.
///
/// # Errors
/// Fails if the layout is not rank 3, or if any allocation, copy or dispatch fails.
fn upsample_nearest1d(&self, l: &Layout, sz: usize) -> Result<Self> {
    let input = self.contiguous(l)?;
    let (batch, channels, len_in) = l.shape().dims3()?;
    let total = batch * channels * sz;
    let out = self.device.alloc_f32(total)?;
    let words = [batch as u32, channels as u32, len_in as u32, sz as u32];
    self.device.dispatch(
        "upsample_nearest1d",
        &[input.buffer, out.buffer],
        &push_u32(&words),
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// Dispatches the `upsample_nearest2d` kernel on a rank-4 input, producing an f32
/// output of `b * c * oh * ow` elements.
///
/// # Errors
/// Fails if the layout is not rank 4, or if any allocation, copy or dispatch fails.
fn upsample_nearest2d(&self, l: &Layout, oh: usize, ow: usize) -> Result<Self> {
    let input = self.contiguous(l)?;
    let (batch, channels, in_h, in_w) = l.shape().dims4()?;
    let total = batch * channels * oh * ow;
    let out = self.device.alloc_f32(total)?;
    let words = [
        batch as u32,
        channels as u32,
        in_h as u32,
        in_w as u32,
        oh as u32,
        ow as u32,
    ];
    self.device.dispatch(
        "upsample_nearest2d",
        &[input.buffer, out.buffer],
        &push_u32(&words),
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// Dispatches the `upsample_bilinear2d` kernel on a rank-4 input, producing an f32
/// output of `b * c * oh * ow` elements.
///
/// The per-axis source step is computed on the host:
/// * with `align_corners`: `(in - 1) / (out - 1)`, or `0` when `out <= 1`;
/// * otherwise: `1 / scale` when a scale is supplied, else `in / out`.
///
/// Both steps are appended to the push constants as f32 after the seven u32 words.
fn upsample_bilinear2d(
    &self,
    l: &Layout,
    oh: usize,
    ow: usize,
    align_corners: bool,
    scale_h: Option<f64>,
    scale_w: Option<f64>,
) -> Result<Self> {
    let input = self.contiguous(l)?;
    let (batch, channels, in_h, in_w) = l.shape().dims4()?;

    // Source step along one axis.
    let axis_step = |in_sz: usize, out_sz: usize, scale: Option<f64>| -> f64 {
        match (align_corners, scale) {
            (true, _) if out_sz > 1 => (in_sz - 1) as f64 / (out_sz - 1) as f64,
            (true, _) => 0.0,
            (false, Some(s)) => 1.0 / s,
            (false, None) => in_sz as f64 / out_sz as f64,
        }
    };
    let step_h = axis_step(in_h, oh, scale_h);
    let step_w = axis_step(in_w, ow, scale_w);

    let total = batch * channels * oh * ow;
    let out = self.device.alloc_f32(total)?;
    let mut push = push_u32(&[
        batch as u32,
        channels as u32,
        in_h as u32,
        in_w as u32,
        oh as u32,
        ow as u32,
        align_corners as u32,
    ]);
    for step in [step_h, step_w] {
        push.extend_from_slice(&(step as f32).to_ne_bytes());
    }
    self.device.dispatch(
        "upsample_bilinear2d",
        &[input.buffer, out.buffer],
        &push,
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// Dispatches the `gather` kernel along `dim`, producing an f32 output with one element
/// per entry of `ids`.
///
/// Both the source and the ids are made contiguous first.
///
/// # Errors
/// Fails if `ids` is not `U32`, or if any allocation, copy or dispatch fails.
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
    if ids.dtype != DType::U32 {
        crate::bail!("vulkan: gather requires u32 ids, got {:?}", ids.dtype);
    }
    let values = self.contiguous(l)?;
    let indices = ids.contiguous(ids_l)?;

    let (ids_dims, src_dims) = (ids_l.dims(), l.dims());
    // Product of the ids dimensions that follow `dim`.
    let inner: usize = ids_dims[dim + 1..].iter().product();
    let ids_extent = ids_dims[dim];
    let src_extent = src_dims[dim];

    let count = ids_l.shape().elem_count();
    let out = self.device.alloc_f32(count)?;
    let words = [
        count as u32,
        inner as u32,
        ids_extent as u32,
        src_extent as u32,
    ];
    self.device.dispatch(
        "gather",
        &[values.buffer, indices.buffer, out.buffer],
        &push_u32(&words),
        Self::groups_1d(count),
    )?;
    Ok(out)
}
/// In-place scatter; forwards every argument to `scatter_impl` with `"scatter_set"` as
/// its first argument.
fn scatter_set(
    &mut self,
    l: &Layout,
    ids: &Self,
    ids_l: &Layout,
    src: &Self,
    src_l: &Layout,
    dim: usize,
) -> Result<()> {
    const KERNEL: &str = "scatter_set";
    self.scatter_impl(KERNEL, l, ids, ids_l, src, src_l, dim)
}
/// In-place scatter-add; forwards every argument to `scatter_impl` with
/// `"scatter_add_set"` as its first argument.
fn scatter_add_set(
    &mut self,
    l: &Layout,
    ids: &Self,
    ids_l: &Layout,
    src: &Self,
    src_l: &Layout,
    dim: usize,
) -> Result<()> {
    const KERNEL: &str = "scatter_add_set";
    self.scatter_impl(KERNEL, l, ids, ids_l, src, src_l, dim)
}
/// Dispatches the `index_select` kernel along `dim`, producing an f32 output of
/// `left * n_ids * right` elements, where `left` / `right` are the products of the source
/// dimensions before / after `dim`.
///
/// # Errors
/// Fails if `ids` is not `U32`, or if any allocation, copy or dispatch fails.
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
    if ids.dtype != DType::U32 {
        crate::bail!("vulkan: index_select requires u32 ids, got {:?}", ids.dtype);
    }
    // The kernel is handed the whole ids buffer, so the ids must be contiguous AND start at
    // element 0. `is_contiguous()` alone is not sufficient: a layout with a non-zero start
    // offset would make the kernel read the wrong indices. Materialise a packed copy in that
    // case (the same thing `gather` does for its ids).
    let ids_copy = if ids_l.is_contiguous() && ids_l.start_offset() == 0 {
        None
    } else {
        Some(ids.contiguous(ids_l)?)
    };
    let ids_buf = ids_copy.as_ref().map_or(ids.buffer, |c| c.buffer);

    let src_c = self.contiguous(l)?;
    let dims = l.dims();
    let left: usize = dims[..dim].iter().product();
    let dim_size = dims[dim];
    let right: usize = dims[dim + 1..].iter().product();
    let n_ids = ids_l.shape().elem_count();
    let total = left * n_ids * right;
    let out = self.device.alloc_f32(total)?;
    let push = push_u32(&[left as u32, dim_size as u32, right as u32, n_ids as u32]);
    self.device.dispatch(
        "index_select",
        &[ids_buf, src_c.buffer, out.buffer],
        &push,
        Self::groups_1d(total),
    )?;
    Ok(out)
}
/// Not implemented for this backend: every call returns an error and all arguments
/// are ignored.
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
crate::bail!("vulkan: index_add not implemented")
}
/// Batched matrix multiply producing a new f32 storage of `b * m * n` elements.
///
/// * `lhs` is used in place when its layout is contiguous at offset 0, otherwise a
///   contiguous copy is made.
/// * `rhs` is used in place when it is contiguous at offset 0 **or** when its layout is
///   dims `(k, n)` with strides `(1, k)` (optionally batched with batch stride `n * k`),
///   i.e. a transposed view of a contiguous `n x k` matrix; the `_nt` kernel variants are
///   selected for that case.
/// * When the device reports cooperative-matrix support for 16x16x16 tiles and `m`, `n`,
///   `k` are all multiples of 16, both operands are first cast to f16 scratch buffers and
///   a `bmm_coopmat_rb*` kernel is used; otherwise a `bmm_reg*` kernel runs on the f32
///   buffers directly.
///
/// # Errors
/// Propagates any allocation, copy or dispatch failure. The f16 scratch buffers are
/// released on every path, including failures.
fn matmul(
    &self,
    rhs: &Self,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs_l: &Layout,
    rhs_l: &Layout,
) -> Result<Self> {
    // `lkeep` / `rkeep` own the contiguous copies (if any) for the rest of the function so
    // the raw buffer handles below stay valid while the kernels are dispatched.
    let lkeep = if lhs_l.is_contiguous() && lhs_l.start_offset() == 0 {
        None
    } else {
        Some(self.contiguous(lhs_l)?)
    };
    let lc_buf = lkeep.as_ref().map_or(self.buffer, |s| s.buffer);

    // Detect the transposed-rhs stride pattern that the `_nt` kernels consume directly.
    let d = rhs_l.dims();
    let st = rhs_l.stride();
    let nt = rhs_l.start_offset() == 0
        && ((d.len() == 2 && d[0] == k && d[1] == n && st[0] == 1 && st[1] == k)
            || (d.len() == 3
                && d[0] == b
                && d[1] == k
                && d[2] == n
                && st[0] == n * k
                && st[1] == 1
                && st[2] == k));
    let rkeep = if nt || (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) {
        None
    } else {
        Some(rhs.contiguous(rhs_l)?)
    };
    let rc_buf = rkeep.as_ref().map_or(rhs.buffer, |s| s.buffer);

    let out = self.device.alloc_f32(b * m * n)?;
    let push = push_u32(&[b as u32, m as u32, k as u32, n as u32]);

    if self.device.inner.cm_use
        && matches!(self.device.coopmat_info(), Some((16, 16, 16)))
        && m % 16 == 0
        && n % 16 == 0
        && k % 16 == 0
    {
        let a_elems = b * m * k;
        let b_elems = b * k * n;
        let (a16, a16_mem, a16_bytes) = self.device.alloc_f16(a_elems)?;
        // If the second scratch allocation fails, hand the first one back before bailing.
        let b_scratch = self.device.alloc_f16(b_elems);
        if b_scratch.is_err() {
            self.device.free_scratch(a16_bytes, a16, a16_mem);
        }
        let (b16, b16_mem, b16_bytes) = b_scratch?;

        let mt = (m / 16) as u32;
        let nt_tiles = (n / 16) as u32;
        let groups = (nt_tiles.div_ceil(4), mt.div_ceil(4), b as u32);
        let kernel = if nt {
            "bmm_coopmat_rb_nt"
        } else {
            "bmm_coopmat_rb"
        };

        // Run the two casts and the matmul, stopping at the first failure, but do NOT
        // return yet: the scratch buffers must be released whether or not this succeeded.
        let status = self
            .device
            .dispatch(
                "cast_f2h",
                &[lc_buf, a16],
                &push_u32(&[a_elems as u32]),
                Self::groups_1d(a_elems),
            )
            .and_then(|_| {
                self.device.dispatch(
                    "cast_f2h",
                    &[rc_buf, b16],
                    &push_u32(&[b_elems as u32]),
                    Self::groups_1d(b_elems),
                )
            })
            .and_then(|_| {
                self.device
                    .dispatch(kernel, &[a16, b16, out.buffer], &push, groups)
            });
        self.device.free_scratch(a16_bytes, a16, a16_mem);
        self.device.free_scratch(b16_bytes, b16, b16_mem);
        status?;
        return Ok(out);
    }

    let groups = ((n as u32).div_ceil(64), (m as u32).div_ceil(64), b as u32);
    let kernel = if nt { "bmm_reg_nt" } else { "bmm_reg" };
    self.device
        .dispatch(kernel, &[lc_buf, rc_buf, out.buffer], &push, groups)?;
    Ok(out)
}
/// Copies the elements described by `src_l` from `self` into the start of `dst` using
/// the `strided_copy` kernel.
///
/// The push constants are: element count, rank, source start offset, then six shape
/// words and six stride words (unused trailing slots are zero).
///
/// # Errors
/// Fails if `dst_offset` is non-zero, if `dst` holds fewer than the required elements,
/// if the rank exceeds 6, or if the dispatch fails. An empty source is a no-op.
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
    let n = src_l.shape().elem_count();
    if n == 0 {
        return Ok(());
    }
    if dst_offset != 0 {
        crate::bail!("vulkan: copy_strided_src with non-zero dst offset not supported");
    }
    if dst.count() < n {
        crate::bail!(
            "vulkan: copy_strided_src dst too small ({} < {n})",
            dst.count()
        );
    }
    let dims = src_l.dims();
    let rank = dims.len();
    if rank > 6 {
        crate::bail!("vulkan: copy_strided_src supports rank <= 6, got {rank}");
    }
    let strides = src_l.stride();

    // Zero-padded fixed-size views of the shape and strides.
    let mut shape6 = [0u32; 6];
    let mut stride6 = [0u32; 6];
    for (slot, &extent) in shape6.iter_mut().zip(dims.iter()) {
        *slot = extent as u32;
    }
    for (slot, &step) in stride6[..rank].iter_mut().zip(&strides[..rank]) {
        *slot = step as u32;
    }

    let mut words: Vec<u32> = Vec::with_capacity(3 + 6 + 6);
    words.extend_from_slice(&[n as u32, rank as u32, src_l.start_offset() as u32]);
    words.extend_from_slice(&shape6);
    words.extend_from_slice(&stride6);

    self.device.dispatch(
        "strided_copy",
        &[self.buffer, dst.buffer],
        &push_u32(&words),
        Self::groups_1d(n),
    )
}
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_stride1: usize,
dst_stride1: usize,
src_offset: usize,
dst_offset: usize,
) -> Result<()> {
let src = self.to_vec_f32()?;
let mut dstv = dst.to_vec_f32()?;
for i in 0..d1 {
let so = src_offset + i * src_stride1;
let d_o = dst_offset + i * dst_stride1;
dstv[d_o..d_o + d2].copy_from_slice(&src[so..so + d2]);
}
unsafe { dst.device.write_f32(dst.memory, &dstv) }
}
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
if !layout.is_contiguous() {
crate::bail!("vulkan: const_set requires contiguous layout");
}
let v = s.to_f64() as f32;
let n = layout.shape().elem_count();
let off = layout.start_offset();
let mut data = self.to_vec_f32()?;
if off + n > data.len() {
crate::bail!("vulkan: const_set out of range");
}
for x in &mut data[off..off + n] {
*x = v;
}
unsafe { self.device.write_f32(self.memory, &data) }
}
}