#![cfg(all(feature = "intel", any(target_os = "linux", target_os = "windows")))]
use crate::error::{HiveGpuError, Result};
use crate::traits::{GpuBackend, GpuContext};
use crate::types::{GpuCapabilities, GpuDeviceInfo, GpuMemoryStats};
use ash::vk;
use std::ffi::{CStr, CString};
use std::sync::Arc;
use tracing::debug;
/// PCI vendor ID for Intel GPUs; used to filter `enumerate_physical_devices`.
pub const INTEL_VENDOR_ID: u32 = 0x8086;
/// Environment variable: any non-empty value relaxes device selection to
/// accept GPUs from any vendor, not just Intel.
pub const UNIVERSAL_ENV: &str = "HIVE_GPU_VULKAN_UNIVERSAL";
// Precompiled SPIR-V compute shaders embedded at build time. They are
// emitted into OUT_DIR by the crate's build script — presumably compiled
// from shader sources in the repo (TODO confirm against build.rs).
const SGEMV_DOT_SPIRV: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/sgemv_dot.spv"));
const SGEMM_DOT_SPIRV: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/sgemm_dot.spv"));
/// Vulkan compute context for Intel GPUs (or any vendor when the
/// `UNIVERSAL_ENV` escape hatch is set).
///
/// Owns every Vulkan object it creates. None of these handles are RAII
/// wrappers: they are destroyed manually, in reverse creation order, by the
/// `Drop` impl further down in this file.
pub struct IntelContext {
    // Pipelines, layouts, modules and descriptor-set layouts for the two
    // embedded shaders (matrix-matrix and matrix-vector dot products).
    sgemm_dot_pipeline: vk::Pipeline,
    sgemv_dot_pipeline: vk::Pipeline,
    sgemm_dot_layout: vk::PipelineLayout,
    sgemv_dot_layout: vk::PipelineLayout,
    sgemm_dot_module: vk::ShaderModule,
    sgemv_dot_module: vk::ShaderModule,
    sgemm_dot_set_layout: vk::DescriptorSetLayout,
    sgemv_dot_set_layout: vk::DescriptorSetLayout,
    // Shared pools used by storage objects created from this context.
    command_pool: vk::CommandPool,
    descriptor_pool: vk::DescriptorPool,
    // The single compute queue (queue 0 of `queue_family_index`).
    queue: vk::Queue,
    queue_family_index: u32,
    device: ash::Device,
    #[allow(dead_code)] physical_device: vk::PhysicalDevice,
    instance: ash::Instance,
    // Kept alive so the Vulkan loader library stays loaded for the
    // lifetime of `instance` and `device`.
    #[allow(dead_code)] entry: ash::Entry,
    // Cached copies of the physical-device properties captured at creation.
    device_name: String,
    vendor_id: u32,
    device_id: u32,
    api_version: u32,
    driver_version: u32,
    #[allow(dead_code)] device_type: vk::PhysicalDeviceType,
    memory_properties: vk::PhysicalDeviceMemoryProperties,
    limits: vk::PhysicalDeviceLimits,
}
// SAFETY: the fields are raw Vulkan handles, POD property structs, and ash's
// loader wrappers, none of which carry thread affinity. NOTE(review): Vulkan
// still requires external synchronization around queue submission and pool
// usage — confirm that call sites serialize access to `queue`,
// `command_pool` and `descriptor_pool` across threads.
unsafe impl Send for IntelContext {}
unsafe impl Sync for IntelContext {}
impl std::fmt::Debug for IntelContext {
    /// Compact debug view: device name plus the hardware identifiers in hex.
    /// The many raw Vulkan handles are deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("IntelContext");
        out.field("device_name", &self.device_name);
        out.field("vendor_id", &format_args!("0x{:04x}", self.vendor_id));
        out.field("device_id", &format_args!("0x{:04x}", self.device_id));
        out.field("api_version", &format_args!("0x{:x}", self.api_version));
        out.finish()
    }
}
impl IntelContext {
    /// Builds a context, honoring the `UNIVERSAL_ENV` override: any
    /// non-empty value relaxes device selection to all vendors.
    pub fn new() -> Result<Arc<Self>> {
        let universal = std::env::var(UNIVERSAL_ENV)
            .map(|v| !v.is_empty())
            .unwrap_or(false);
        Self::new_with_preference(universal)
    }

    /// Builds a context on the first suitable GPU.
    ///
    /// When `universal` is false, only Intel (vendor 0x8086) devices are
    /// considered; discrete GPUs are preferred over other device types.
    ///
    /// # Errors
    /// Fails if the Vulkan loader is missing, no matching device exists, or
    /// any Vulkan object creation fails. Every object created before the
    /// failure is destroyed before returning (previously the instance,
    /// device, pools and partially built pipelines leaked on these paths).
    pub fn new_with_preference(universal: bool) -> Result<Arc<Self>> {
        let entry = unsafe { ash::Entry::load() }.map_err(|e| {
            HiveGpuError::IntelError(format!("failed to load Vulkan loader: {e:?}"))
        })?;
        let app_name = CString::new("hive-gpu").unwrap();
        let app_info = vk::ApplicationInfo::default()
            .application_name(&app_name)
            .application_version(vk::make_api_version(0, 0, 2, 0))
            .engine_name(&app_name)
            .engine_version(vk::make_api_version(0, 0, 2, 0))
            .api_version(vk::API_VERSION_1_2);
        let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);
        let instance = unsafe { entry.create_instance(&create_info, None) }
            .map_err(|e| HiveGpuError::VulkanError(format!("create_instance: {e:?}")))?;
        // From here on, every early return must tear down whatever has
        // already been created — these are raw handles with no RAII.
        let (physical_device, queue_family_index, props) =
            match select_physical_device(&instance, universal) {
                Ok(found) => found,
                Err(e) => {
                    unsafe { instance.destroy_instance(None) };
                    return Err(e);
                }
            };
        let queue_priorities = [1.0f32];
        let queue_infos = [vk::DeviceQueueCreateInfo::default()
            .queue_family_index(queue_family_index)
            .queue_priorities(&queue_priorities)];
        let device_create_info =
            vk::DeviceCreateInfo::default().queue_create_infos(&queue_infos);
        let device = match unsafe {
            instance.create_device(physical_device, &device_create_info, None)
        } {
            Ok(d) => d,
            Err(e) => {
                unsafe { instance.destroy_instance(None) };
                return Err(HiveGpuError::VulkanError(format!("create_device: {e:?}")));
            }
        };
        let queue = unsafe { device.get_device_queue(queue_family_index, 0) };
        let cp_info = vk::CommandPoolCreateInfo::default()
            .queue_family_index(queue_family_index)
            // Allows resetting individual command buffers rather than the
            // whole pool at once.
            .flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER);
        let command_pool = match unsafe { device.create_command_pool(&cp_info, None) } {
            Ok(p) => p,
            Err(e) => {
                unsafe {
                    device.destroy_device(None);
                    instance.destroy_instance(None);
                }
                return Err(HiveGpuError::VulkanError(format!(
                    "create_command_pool: {e:?}"
                )));
            }
        };
        let dp_sizes = [vk::DescriptorPoolSize {
            ty: vk::DescriptorType::STORAGE_BUFFER,
            descriptor_count: 4096,
        }];
        let dp_info = vk::DescriptorPoolCreateInfo::default()
            .max_sets(1024)
            .pool_sizes(&dp_sizes)
            // Storage objects free their sets individually.
            .flags(vk::DescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET);
        let descriptor_pool = match unsafe { device.create_descriptor_pool(&dp_info, None) } {
            Ok(p) => p,
            Err(e) => {
                unsafe {
                    device.destroy_command_pool(command_pool, None);
                    device.destroy_device(None);
                    instance.destroy_instance(None);
                }
                return Err(HiveGpuError::VulkanError(format!(
                    "create_descriptor_pool: {e:?}"
                )));
            }
        };
        let (sgemv_dot_module, sgemv_dot_set_layout, sgemv_dot_layout, sgemv_dot_pipeline) =
            match build_pipeline(
                &device,
                SGEMV_DOT_SPIRV,
                3,
                std::mem::size_of::<SgemvPushConstants>() as u32,
            ) {
                Ok(built) => built,
                Err(e) => {
                    unsafe {
                        device.destroy_descriptor_pool(descriptor_pool, None);
                        device.destroy_command_pool(command_pool, None);
                        device.destroy_device(None);
                        instance.destroy_instance(None);
                    }
                    return Err(e);
                }
            };
        let (sgemm_dot_module, sgemm_dot_set_layout, sgemm_dot_layout, sgemm_dot_pipeline) =
            match build_pipeline(
                &device,
                SGEMM_DOT_SPIRV,
                3,
                std::mem::size_of::<SgemmPushConstants>() as u32,
            ) {
                Ok(built) => built,
                Err(e) => {
                    // Also unwind the sgemv pipeline built just above.
                    unsafe {
                        device.destroy_pipeline(sgemv_dot_pipeline, None);
                        device.destroy_pipeline_layout(sgemv_dot_layout, None);
                        device.destroy_shader_module(sgemv_dot_module, None);
                        device.destroy_descriptor_set_layout(sgemv_dot_set_layout, None);
                        device.destroy_descriptor_pool(descriptor_pool, None);
                        device.destroy_command_pool(command_pool, None);
                        device.destroy_device(None);
                        instance.destroy_instance(None);
                    }
                    return Err(e);
                }
            };
        // Nothing below can fail, so plain `?`-free code is safe again.
        let device_name = unsafe { CStr::from_ptr(props.device_name.as_ptr()) }
            .to_string_lossy()
            .into_owned();
        let memory_properties =
            unsafe { instance.get_physical_device_memory_properties(physical_device) };
        debug!(
            "intel context ready: device={:?} vendor=0x{:04x} id=0x{:04x}",
            device_name, props.vendor_id, props.device_id
        );
        Ok(Arc::new(Self {
            sgemm_dot_pipeline,
            sgemv_dot_pipeline,
            sgemm_dot_layout,
            sgemv_dot_layout,
            sgemm_dot_module,
            sgemv_dot_module,
            sgemm_dot_set_layout,
            sgemv_dot_set_layout,
            command_pool,
            descriptor_pool,
            queue,
            queue_family_index,
            device,
            physical_device,
            instance,
            entry,
            device_name,
            vendor_id: props.vendor_id,
            device_id: props.device_id,
            api_version: props.api_version,
            driver_version: props.driver_version,
            device_type: props.device_type,
            memory_properties,
            limits: props.limits,
        }))
    }

    /// Cheap probe: true when the Vulkan loader works and a matching device
    /// (per `UNIVERSAL_ENV`) exists. Creates and destroys a throwaway
    /// instance; never returns an error.
    pub fn is_available() -> bool {
        let Ok(entry) = (unsafe { ash::Entry::load() }) else {
            return false;
        };
        let app_name = CString::new("hive-gpu-probe").unwrap();
        let app_info = vk::ApplicationInfo::default()
            .application_name(&app_name)
            .api_version(vk::API_VERSION_1_2);
        let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);
        let instance = match unsafe { entry.create_instance(&create_info, None) } {
            Ok(i) => i,
            Err(_) => return false,
        };
        let universal = std::env::var(UNIVERSAL_ENV)
            .map(|v| !v.is_empty())
            .unwrap_or(false);
        let found = enumerate_matching_devices(&instance, universal).is_some();
        unsafe { instance.destroy_instance(None) };
        found
    }

    /// Human-readable device name reported by the driver.
    pub fn device_name(&self) -> &str {
        &self.device_name
    }

    /// PCI vendor ID of the selected device (0x8086 unless universal mode).
    pub fn vendor_id(&self) -> u32 {
        self.vendor_id
    }

    /// Logical device handle; valid for the lifetime of `self`.
    pub(crate) fn device(&self) -> &ash::Device {
        &self.device
    }

    /// The compute queue (queue 0 of the selected family).
    pub(crate) fn queue(&self) -> vk::Queue {
        self.queue
    }

    /// Index of the queue family `queue` belongs to.
    #[allow(dead_code)]
    pub(crate) fn queue_family_index(&self) -> u32 {
        self.queue_family_index
    }

    /// Shared command pool (created with RESET_COMMAND_BUFFER).
    pub(crate) fn command_pool(&self) -> vk::CommandPool {
        self.command_pool
    }

    /// Shared descriptor pool (created with FREE_DESCRIPTOR_SET).
    pub(crate) fn descriptor_pool(&self) -> vk::DescriptorPool {
        self.descriptor_pool
    }

    /// Memory heaps/types of the physical device, captured at creation.
    pub(crate) fn memory_properties(&self) -> &vk::PhysicalDeviceMemoryProperties {
        &self.memory_properties
    }

    /// Handle bundle for the matrix-vector dot-product pipeline.
    /// The returned handles remain owned by `self`.
    pub(crate) fn sgemv_dot(&self) -> ComputePipeline {
        ComputePipeline {
            pipeline: self.sgemv_dot_pipeline,
            layout: self.sgemv_dot_layout,
            set_layout: self.sgemv_dot_set_layout,
        }
    }

    /// Handle bundle for the matrix-matrix dot-product pipeline.
    /// The returned handles remain owned by `self`.
    pub(crate) fn sgemm_dot(&self) -> ComputePipeline {
        ComputePipeline {
            pipeline: self.sgemm_dot_pipeline,
            layout: self.sgemm_dot_layout,
            set_layout: self.sgemm_dot_set_layout,
        }
    }
}
/// Copyable bundle of the handles needed to dispatch one compute pipeline:
/// the pipeline itself, its pipeline layout, and the descriptor-set layout
/// it was built against. Ownership stays with `IntelContext`, which destroys
/// the underlying objects in its `Drop` impl — holders must not outlive it.
#[derive(Clone, Copy)]
pub(crate) struct ComputePipeline {
    pub pipeline: vk::Pipeline,
    pub layout: vk::PipelineLayout,
    pub set_layout: vk::DescriptorSetLayout,
}
/// Push-constant block for the sgemv (matrix-vector) shader; its size is
/// passed to `build_pipeline` as the push-constant range.
/// `#[repr(C)]` so the in-memory layout matches the shader's push-constant
/// block — field order and types must mirror the shader source
/// (NOTE(review): confirm against sgemv_dot's shader definition).
#[repr(C)]
#[derive(Clone, Copy)]
pub(crate) struct SgemvPushConstants {
    pub dimension: u32,
    pub n_vectors: u32,
}
/// Push-constant block for the sgemm (matrix-matrix) shader; its size is
/// passed to `build_pipeline` as the push-constant range.
/// `#[repr(C)]` so the in-memory layout matches the shader's push-constant
/// block — field order and types must mirror the shader source
/// (NOTE(review): confirm against sgemm_dot's shader definition).
#[repr(C)]
#[derive(Clone, Copy)]
pub(crate) struct SgemmPushConstants {
    pub dimension: u32,
    pub n_list: u32,
    pub n_samples: u32,
}
impl Drop for IntelContext {
    /// Destroys all owned Vulkan objects in reverse creation order:
    /// pipelines, layouts, modules and set layouts first, then the pools,
    /// then the device, and finally the instance.
    fn drop(&mut self) {
        unsafe {
            // Vulkan forbids destroying objects still in use by the GPU;
            // wait for the device to go idle first. The result is ignored
            // deliberately — even on device loss we still tear down.
            let _ = self.device.device_wait_idle();
            self.device.destroy_pipeline(self.sgemm_dot_pipeline, None);
            self.device.destroy_pipeline(self.sgemv_dot_pipeline, None);
            self.device
                .destroy_pipeline_layout(self.sgemm_dot_layout, None);
            self.device
                .destroy_pipeline_layout(self.sgemv_dot_layout, None);
            self.device
                .destroy_shader_module(self.sgemm_dot_module, None);
            self.device
                .destroy_shader_module(self.sgemv_dot_module, None);
            self.device
                .destroy_descriptor_set_layout(self.sgemm_dot_set_layout, None);
            self.device
                .destroy_descriptor_set_layout(self.sgemv_dot_set_layout, None);
            self.device.destroy_command_pool(self.command_pool, None);
            self.device
                .destroy_descriptor_pool(self.descriptor_pool, None);
            self.device.destroy_device(None);
            self.instance.destroy_instance(None);
        }
    }
}
/// Picks a device via `enumerate_matching_devices`, mapping "none found" to
/// an error whose message depends on whether universal mode was requested.
fn select_physical_device(
    instance: &ash::Instance,
    universal: bool,
) -> Result<(vk::PhysicalDevice, u32, vk::PhysicalDeviceProperties)> {
    match enumerate_matching_devices(instance, universal) {
        Some(found) => Ok(found),
        // In universal mode there is simply no usable GPU at all.
        None if universal => Err(HiveGpuError::NoDeviceAvailable),
        // Otherwise hint at the vendor-filter escape hatch.
        None => Err(HiveGpuError::IntelError(
            "no Intel GPU found; set HIVE_GPU_VULKAN_UNIVERSAL=1 to accept any vendor"
                .to_string(),
        )),
    }
}
/// Scans physical devices for one that passes the vendor filter and exposes
/// a compute-capable queue family. The first discrete GPU wins immediately;
/// failing that, the first acceptable device of any type is returned.
fn enumerate_matching_devices(
    instance: &ash::Instance,
    universal: bool,
) -> Option<(vk::PhysicalDevice, u32, vk::PhysicalDeviceProperties)> {
    let devices = unsafe { instance.enumerate_physical_devices() }.ok()?;
    let mut first_match: Option<(vk::PhysicalDevice, u32, vk::PhysicalDeviceProperties)> = None;
    for pd in devices {
        let props = unsafe { instance.get_physical_device_properties(pd) };
        let vendor_ok = universal || props.vendor_id == INTEL_VENDOR_ID;
        if !vendor_ok {
            continue;
        }
        let families = unsafe { instance.get_physical_device_queue_family_properties(pd) };
        let compute_family = families
            .iter()
            .position(|fam| fam.queue_flags.contains(vk::QueueFlags::COMPUTE))
            .map(|idx| idx as u32);
        let Some(qfi) = compute_family else {
            continue;
        };
        if props.device_type == vk::PhysicalDeviceType::DISCRETE_GPU {
            // Discrete GPU: take it without looking further.
            return Some((pd, qfi, props));
        }
        // Remember the first non-discrete candidate as a fallback.
        first_match.get_or_insert((pd, qfi, props));
    }
    first_match
}
/// Compiles `spirv` into a compute pipeline with `n_bindings` storage-buffer
/// bindings (bindings 0..n_bindings, COMPUTE stage) and a COMPUTE-stage
/// push-constant range of `push_constant_size` bytes starting at offset 0.
/// The shader entry point must be named "main".
///
/// Returns `(module, set_layout, pipeline_layout, pipeline)`; the caller owns
/// all four and must destroy them.
///
/// # Errors
/// Fails if the SPIR-V blob is malformed or any Vulkan creation call fails.
/// Intermediate objects created before the failure are destroyed here
/// (previously the module / set layout / pipeline layout leaked on the
/// later error paths).
fn build_pipeline(
    device: &ash::Device,
    spirv: &[u8],
    n_bindings: u32,
    push_constant_size: u32,
) -> Result<(
    vk::ShaderModule,
    vk::DescriptorSetLayout,
    vk::PipelineLayout,
    vk::Pipeline,
)> {
    // vkCreateShaderModule consumes the code as 32-bit words.
    if spirv.len() % 4 != 0 {
        return Err(HiveGpuError::SpirvCompileError(
            "shader SPIR-V size not a multiple of 4 bytes".to_string(),
        ));
    }
    let code: Vec<u32> = spirv
        .chunks_exact(4)
        .map(|b| u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    let module_info = vk::ShaderModuleCreateInfo::default().code(&code);
    let module = unsafe { device.create_shader_module(&module_info, None) }
        .map_err(|e| HiveGpuError::SpirvCompileError(format!("create_shader_module: {e:?}")))?;
    let bindings: Vec<vk::DescriptorSetLayoutBinding> = (0..n_bindings)
        .map(|i| {
            vk::DescriptorSetLayoutBinding::default()
                .binding(i)
                .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
                .descriptor_count(1)
                .stage_flags(vk::ShaderStageFlags::COMPUTE)
        })
        .collect();
    let set_layout_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
    let set_layout =
        match unsafe { device.create_descriptor_set_layout(&set_layout_info, None) } {
            Ok(l) => l,
            Err(e) => {
                unsafe { device.destroy_shader_module(module, None) };
                return Err(HiveGpuError::VulkanError(format!(
                    "create_descriptor_set_layout: {e:?}"
                )));
            }
        };
    let push_constant_range = [vk::PushConstantRange::default()
        .stage_flags(vk::ShaderStageFlags::COMPUTE)
        .offset(0)
        .size(push_constant_size)];
    let set_layouts = [set_layout];
    let pl_info = vk::PipelineLayoutCreateInfo::default()
        .set_layouts(&set_layouts)
        .push_constant_ranges(&push_constant_range);
    let layout = match unsafe { device.create_pipeline_layout(&pl_info, None) } {
        Ok(l) => l,
        Err(e) => {
            unsafe {
                device.destroy_descriptor_set_layout(set_layout, None);
                device.destroy_shader_module(module, None);
            }
            return Err(HiveGpuError::VulkanError(format!(
                "create_pipeline_layout: {e:?}"
            )));
        }
    };
    let entry_name = CString::new("main").unwrap();
    let stage = vk::PipelineShaderStageCreateInfo::default()
        .stage(vk::ShaderStageFlags::COMPUTE)
        .module(module)
        .name(&entry_name);
    let pipeline_info = [vk::ComputePipelineCreateInfo::default()
        .stage(stage)
        .layout(layout)];
    // ash reports failure as (partial results, vk::Result); unwind everything.
    let pipelines = match unsafe {
        device.create_compute_pipelines(vk::PipelineCache::null(), &pipeline_info, None)
    } {
        Ok(p) => p,
        Err((_, e)) => {
            unsafe {
                device.destroy_pipeline_layout(layout, None);
                device.destroy_descriptor_set_layout(set_layout, None);
                device.destroy_shader_module(module, None);
            }
            return Err(HiveGpuError::VulkanError(format!(
                "create_compute_pipelines: {e:?}"
            )));
        }
    };
    Ok((module, set_layout, layout, pipelines[0]))
}
impl GpuBackend for IntelContext {
    /// Summarizes the selected device. VRAM is the sum of DEVICE_LOCAL heap
    /// sizes; this backend does no per-allocation tracking, so "used" is
    /// reported as zero and "available" equals the total.
    fn device_info(&self) -> GpuDeviceInfo {
        let heap_count = self.memory_properties.memory_heap_count as usize;
        let total_vram_bytes: u64 = self
            .memory_properties
            .memory_heaps
            .iter()
            .take(heap_count)
            .filter(|heap| heap.flags.contains(vk::MemoryHeapFlags::DEVICE_LOCAL))
            .map(|heap| heap.size)
            .sum();
        let major = vk::api_version_major(self.api_version);
        let minor = vk::api_version_minor(self.api_version);
        let patch = vk::api_version_patch(self.api_version);
        GpuDeviceInfo {
            name: self.device_name.clone(),
            backend: "Intel".to_string(),
            total_vram_bytes,
            available_vram_bytes: total_vram_bytes,
            used_vram_bytes: 0,
            driver_version: format!(
                "Vulkan {}.{}.{} (driver 0x{:x})",
                major, minor, patch, self.driver_version
            ),
            compute_capability: Some(format!("vk{}.{}-0x{:04x}", major, minor, self.device_id)),
            max_threads_per_block: self.limits.max_compute_work_group_invocations,
            max_shared_memory_per_block: self.limits.max_compute_shared_memory_size as u64,
            device_id: self.device_id as i32,
            pci_bus_id: None,
        }
    }

    /// Static capability flags for this backend.
    fn supports_operations(&self) -> GpuCapabilities {
        GpuCapabilities {
            supports_hnsw: false,
            supports_batch: true,
            max_dimension: 4096,
            max_batch_size: 100_000,
        }
    }

    /// Memory statistics derived from `device_info`. Because usage is not
    /// tracked, "allocated" and "utilization" are currently always zero.
    fn memory_stats(&self) -> GpuMemoryStats {
        let info = GpuBackend::device_info(self);
        let total = info.total_vram_bytes as usize;
        let used = info.used_vram_bytes as usize;
        GpuMemoryStats {
            total_allocated: used,
            available: info.available_vram_bytes as usize,
            utilization: match total {
                0 => 0.0,
                _ => used as f32 / total as f32,
            },
            buffer_count: 0,
        }
    }
}
impl GpuContext for IntelContext {
    /// Creates a vector storage bound to this context.
    ///
    /// SAFETY NOTE(review): the unsafe block below manufactures an
    /// `Arc<Self>` from `&self` by bumping the strong count and re-wrapping
    /// the raw pointer. This is sound ONLY if `self` actually lives inside
    /// an `Arc` — which the `new*` constructors guarantee, since they only
    /// ever return `Arc<Self>` — but it is undefined behavior if a bare
    /// `IntelContext` ever reaches this method. Consider changing the trait
    /// to take `self: &Arc<Self>` / cloning an `Arc` at the call site.
    fn create_storage(
        &self,
        dimension: usize,
        metric: crate::types::GpuDistanceMetric,
    ) -> Result<Box<dyn crate::traits::GpuVectorStorage>> {
        use super::vector_storage::IntelVectorStorage;
        let ctx: Arc<Self> = unsafe {
            let raw = self as *const Self;
            // Bump the refcount so `Arc::from_raw` does not steal the
            // caller's ownership; the new Arc is handed to the storage.
            Arc::increment_strong_count(raw);
            Arc::from_raw(raw)
        };
        let storage = IntelVectorStorage::new(ctx, dimension, metric)?;
        Ok(Box::new(storage))
    }

    /// HNSW configuration is not supported by this backend; the config is
    /// ignored and a plain storage is created instead.
    fn create_storage_with_config(
        &self,
        dimension: usize,
        metric: crate::types::GpuDistanceMetric,
        _config: crate::types::HnswConfig,
    ) -> Result<Box<dyn crate::traits::GpuVectorStorage>> {
        self.create_storage(dimension, metric)
    }

    // Delegates to the GpuBackend impl above.
    fn memory_stats(&self) -> GpuMemoryStats {
        GpuBackend::memory_stats(self)
    }

    // Delegates to the GpuBackend impl above; infallible here.
    fn device_info(&self) -> Result<GpuDeviceInfo> {
        Ok(GpuBackend::device_info(self))
    }
}