use vyre_foundation::optimizer::AdapterCaps;
use vyre_foundation::validate;
/// Capability and tuning snapshot of a single compute backend.
///
/// One profile is projected into three capability views:
/// [`DeviceProfile::validation_capabilities`], [`DeviceProfile::adapter_caps`],
/// and [`DeviceProfile::strategy_capabilities`].
///
/// A `0` in a size, count, metric, or tuning field presumably means "unknown"
/// (NOTE(review): confirm). `from_backend` records a missing subgroup size as
/// `0`, and both constructors in this module leave the cache/bandwidth metrics
/// and tuning hints at `0`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DeviceProfile {
    /// Backend identifier (`VyreBackend::id()` in `from_backend`, `"unknown"` for `Default`).
    pub backend: &'static str,
    /// Subgroup operations are available; `from_backend` also derives `has_subgroup_shuffle` from it.
    pub supports_subgroup_ops: bool,
    /// Indirect dispatch is available.
    pub supports_indirect_dispatch: bool,
    /// Specialization constants are available (`from_backend` always reports `false`).
    pub supports_specialization_constants: bool,
    /// f16 support; surfaced to strategy selection as `has_native_f16`.
    pub supports_f16: bool,
    /// bf16 support (not forwarded by any projection in this module).
    pub supports_bf16: bool,
    /// Trap propagation support (not forwarded by any projection in this module).
    pub supports_trap_propagation: bool,
    /// Tensor-core support; surfaced to strategy selection as `has_tensor_core_int`.
    pub supports_tensor_cores: bool,
    /// NOTE(review): presumably a native high-half multiply instruction — confirm.
    pub has_mul_high: bool,
    /// NOTE(review): presumably concurrent fp32 + int32 instruction issue — confirm.
    pub has_dual_issue_fp32_int32: bool,
    /// Subgroup shuffle support; surfaced to strategy selection as `has_warp_shuffle`.
    pub has_subgroup_shuffle: bool,
    /// Shared memory is available; surfaced to strategy selection as `has_shared_memory`.
    pub has_shared_memory: bool,
    /// Widest natively supported integer width in bits (32 in the conservative profile).
    pub max_native_int_width: u32,
    /// Maximum workgroup size per dimension.
    pub max_workgroup_size: [u32; 3],
    /// Maximum total invocations in a single workgroup.
    pub max_invocations_per_workgroup: u32,
    /// Maximum shared memory, in bytes.
    pub max_shared_memory_bytes: u32,
    /// Largest storage-buffer binding, in bytes (`max_storage_buffer_bytes()` in `from_backend`).
    pub max_storage_buffer_binding_size: u64,
    /// Subgroup size; `0` when the backend does not report one.
    pub subgroup_size: u32,
    /// Number of compute units.
    pub compute_units: u32,
    /// Maximum registers available to one thread.
    pub regs_per_thread_max: u32,
    /// L1 cache size, in bytes.
    pub l1_cache_bytes: u32,
    /// L2 cache size, in bytes.
    pub l2_cache_bytes: u32,
    /// Memory bandwidth; presumably GB/s per the field name — confirm.
    pub mem_bw_gbps: u32,
    /// Tuning hint: preferred loop-unroll depth.
    pub ideal_unroll_depth: u32,
    /// Tuning hint: preferred vector packing width, in bits.
    pub ideal_vector_pack_bits: u32,
    /// Tuning hint: preferred workgroup tile shape per dimension.
    pub ideal_workgroup_tile: [u32; 3],
    /// Number of shared-memory banks.
    pub shared_memory_bank_count: u32,
    /// Width of one shared-memory bank, in bytes.
    pub shared_memory_bank_width_bytes: u32,
}
impl Default for DeviceProfile {
fn default() -> Self {
Self::conservative("unknown")
}
}
impl DeviceProfile {
    /// Builds the most pessimistic profile for `backend`.
    ///
    /// Every optional feature flag is `false`. The workgroup limits allow a
    /// single invocation (`[1, 1, 1]` / `1`) and native integers are 32 bits
    /// wide. Every other size, count, metric, and tuning hint is `0`. Being a
    /// `const fn`, this is usable in `const`/`static` initializers.
    #[must_use]
    pub const fn conservative(backend: &'static str) -> Self {
        Self {
            backend,
            supports_subgroup_ops: false,
            supports_indirect_dispatch: false,
            supports_specialization_constants: false,
            supports_f16: false,
            supports_bf16: false,
            supports_trap_propagation: false,
            supports_tensor_cores: false,
            has_mul_high: false,
            has_dual_issue_fp32_int32: false,
            has_subgroup_shuffle: false,
            has_shared_memory: false,
            max_native_int_width: 32,
            max_workgroup_size: [1, 1, 1],
            max_invocations_per_workgroup: 1,
            max_shared_memory_bytes: 0,
            max_storage_buffer_binding_size: 0,
            subgroup_size: 0,
            compute_units: 0,
            regs_per_thread_max: 0,
            l1_cache_bytes: 0,
            l2_cache_bytes: 0,
            mem_bw_gbps: 0,
            ideal_unroll_depth: 0,
            ideal_vector_pack_bits: 0,
            ideal_workgroup_tile: [0, 0, 0],
            shared_memory_bank_count: 0,
            shared_memory_bank_width_bytes: 0,
        }
    }

    /// Builds a profile from the capabilities a live backend reports.
    ///
    /// The following are read from `backend`:
    /// - the id;
    /// - subgroup ops, indirect dispatch, f16, bf16, and tensor cores;
    /// - the workgroup and invocation limits;
    /// - the storage-buffer limit;
    /// - the subgroup size (`0` when `None`).
    ///
    /// Subgroup shuffle support is taken to be the same as subgroup-op support.
    /// Every field not queried here inherits its value from
    /// [`DeviceProfile::conservative`]. This covers specialization constants,
    /// trap propagation, `mul_high`, dual issue, shared memory, cache/bandwidth
    /// metrics, and tuning hints. When adding a field to `DeviceProfile`,
    /// decide whether it should be queried here.
    #[must_use]
    pub fn from_backend(backend: &dyn crate::backend::VyreBackend) -> Self {
        // Query once so the two subgroup-derived flags can never disagree.
        let subgroup_ops = backend.supports_subgroup_ops();
        Self {
            supports_subgroup_ops: subgroup_ops,
            supports_indirect_dispatch: backend.supports_indirect_dispatch(),
            supports_f16: backend.supports_f16(),
            supports_bf16: backend.supports_bf16(),
            supports_tensor_cores: backend.supports_tensor_cores(),
            // NOTE(review): shuffle support is assumed to accompany subgroup ops.
            has_subgroup_shuffle: subgroup_ops,
            max_workgroup_size: backend.max_workgroup_size(),
            max_invocations_per_workgroup: backend.max_compute_invocations_per_workgroup(),
            max_storage_buffer_binding_size: backend.max_storage_buffer_bytes(),
            subgroup_size: backend.subgroup_size().unwrap_or(0),
            ..Self::conservative(backend.id())
        }
    }

    /// Projects the feature flags checked by validation: subgroup ops,
    /// indirect dispatch, and specialization constants.
    #[must_use]
    pub const fn validation_capabilities(self) -> validate::BackendCapabilities {
        validate::BackendCapabilities {
            supports_subgroup_ops: self.supports_subgroup_ops,
            supports_indirect_dispatch: self.supports_indirect_dispatch,
            supports_specialization_constants: self.supports_specialization_constants,
        }
    }

    /// Projects this profile onto the optimizer's [`AdapterCaps`].
    ///
    /// The backend id, the three validation feature flags, and all limits,
    /// hardware metrics, and tuning hints are copied field-for-field. f16/bf16,
    /// tensor cores, trap propagation, and the strategy-only flags are not
    /// part of `AdapterCaps`.
    #[must_use]
    pub const fn adapter_caps(self) -> AdapterCaps {
        AdapterCaps {
            backend: self.backend,
            supports_subgroup_ops: self.supports_subgroup_ops,
            supports_indirect_dispatch: self.supports_indirect_dispatch,
            supports_specialization_constants: self.supports_specialization_constants,
            max_workgroup_size: self.max_workgroup_size,
            max_invocations_per_workgroup: self.max_invocations_per_workgroup,
            max_shared_memory_bytes: self.max_shared_memory_bytes,
            max_storage_buffer_binding_size: self.max_storage_buffer_binding_size,
            subgroup_size: self.subgroup_size,
            compute_units: self.compute_units,
            regs_per_thread_max: self.regs_per_thread_max,
            l1_cache_bytes: self.l1_cache_bytes,
            l2_cache_bytes: self.l2_cache_bytes,
            mem_bw_gbps: self.mem_bw_gbps,
            ideal_unroll_depth: self.ideal_unroll_depth,
            ideal_vector_pack_bits: self.ideal_vector_pack_bits,
            ideal_workgroup_tile: self.ideal_workgroup_tile,
            shared_memory_bank_count: self.shared_memory_bank_count,
            shared_memory_bank_width_bytes: self.shared_memory_bank_width_bytes,
        }
    }

    /// Projects this profile onto the strategy selector's capability set.
    ///
    /// Renamed mappings:
    /// - `supports_tensor_cores` → `has_tensor_core_int`
    /// - `supports_f16` → `has_native_f16`
    /// - `has_subgroup_shuffle` → `has_warp_shuffle`
    ///
    /// `has_transcendental_polynomial_emit` is always `true`, presumably
    /// because polynomial emission needs no special hardware (NOTE(review):
    /// confirm).
    #[must_use]
    pub const fn strategy_capabilities(self) -> crate::strategy::BackendCapabilities {
        crate::strategy::BackendCapabilities {
            has_mul_high: self.has_mul_high,
            has_dual_issue_fp32_int32: self.has_dual_issue_fp32_int32,
            has_tensor_core_int: self.supports_tensor_cores,
            has_native_f16: self.supports_f16,
            has_warp_shuffle: self.has_subgroup_shuffle,
            has_shared_memory: self.has_shared_memory,
            has_transcendental_polynomial_emit: true,
            max_native_int_width: self.max_native_int_width,
        }
    }
}
impl From<DeviceProfile> for AdapterCaps {
#[inline]
fn from(profile: DeviceProfile) -> Self {
profile.adapter_caps()
}
}
impl From<DeviceProfile> for validate::BackendCapabilities {
#[inline]
fn from(profile: DeviceProfile) -> Self {
profile.validation_capabilities()
}
}
impl From<DeviceProfile> for crate::strategy::BackendCapabilities {
#[inline]
fn from(profile: DeviceProfile) -> Self {
profile.strategy_capabilities()
}
}
#[cfg(test)]
mod tests {
    use super::DeviceProfile;

    /// A fully populated profile with realistic limits and tuning hints.
    fn sample_profile() -> DeviceProfile {
        DeviceProfile {
            backend: "test",
            supports_subgroup_ops: true,
            supports_indirect_dispatch: true,
            supports_specialization_constants: true,
            supports_f16: true,
            supports_bf16: false,
            supports_trap_propagation: true,
            supports_tensor_cores: true,
            has_mul_high: true,
            has_dual_issue_fp32_int32: true,
            has_subgroup_shuffle: true,
            has_shared_memory: true,
            max_native_int_width: 64,
            max_workgroup_size: [256, 1, 1],
            max_invocations_per_workgroup: 256,
            max_shared_memory_bytes: 48 * 1024,
            max_storage_buffer_binding_size: 1 << 30,
            subgroup_size: 32,
            compute_units: 128,
            regs_per_thread_max: 255,
            l1_cache_bytes: 128 * 1024,
            l2_cache_bytes: 64 * 1024 * 1024,
            mem_bw_gbps: 1700,
            ideal_unroll_depth: 8,
            ideal_vector_pack_bits: 128,
            ideal_workgroup_tile: [16, 16, 1],
            shared_memory_bank_count: 32,
            shared_memory_bank_width_bytes: 4,
        }
    }

    #[test]
    fn projections_share_the_same_feature_bits() {
        let profile = sample_profile();
        let validation = profile.validation_capabilities();
        let adapter = profile.adapter_caps();
        let strategy = profile.strategy_capabilities();

        assert!(validation.supports_subgroup_ops);
        assert!(validation.supports_indirect_dispatch);
        assert!(validation.supports_specialization_constants);

        assert_eq!(adapter.backend, "test");
        assert!(adapter.supports_subgroup_ops);
        assert_eq!(adapter.max_workgroup_size, [256, 1, 1]);
        assert_eq!(adapter.max_invocations_per_workgroup, 256);
        assert_eq!(adapter.max_storage_buffer_binding_size, 1 << 30);
        assert_eq!(adapter.subgroup_size, 32);
        assert_eq!(adapter.ideal_unroll_depth, 8);
        assert_eq!(adapter.ideal_vector_pack_bits, 128);
        assert_eq!(adapter.ideal_workgroup_tile, [16, 16, 1]);

        assert!(strategy.has_warp_shuffle);
        assert!(strategy.has_tensor_core_int);
        assert!(strategy.has_native_f16);
        assert_eq!(strategy.max_native_int_width, 64);
    }

    /// Enables a sparse subset of bits so that a swapped or mis-wired
    /// mapping in any projection shows up as a wrong value.
    #[test]
    fn each_feature_bit_routes_to_its_own_projection_field() {
        let profile = DeviceProfile {
            supports_indirect_dispatch: true,
            supports_f16: true,
            has_mul_high: true,
            has_shared_memory: true,
            ..DeviceProfile::conservative("routing")
        };

        let validation = profile.validation_capabilities();
        assert!(!validation.supports_subgroup_ops);
        assert!(validation.supports_indirect_dispatch);
        assert!(!validation.supports_specialization_constants);

        let adapter = profile.adapter_caps();
        assert_eq!(adapter.backend, "routing");
        assert!(!adapter.supports_subgroup_ops);
        assert!(adapter.supports_indirect_dispatch);
        assert!(!adapter.supports_specialization_constants);

        let strategy = profile.strategy_capabilities();
        assert!(strategy.has_mul_high);
        assert!(!strategy.has_dual_issue_fp32_int32);
        assert!(!strategy.has_tensor_core_int);
        assert!(strategy.has_native_f16);
        assert!(!strategy.has_warp_shuffle);
        assert!(strategy.has_shared_memory);
    }

    #[test]
    fn default_is_the_conservative_unknown_profile() {
        let profile = DeviceProfile::default();
        assert_eq!(profile, DeviceProfile::conservative("unknown"));
        assert_eq!(profile.backend, "unknown");
        assert_eq!(profile.max_workgroup_size, [1, 1, 1]);
        assert_eq!(profile.max_invocations_per_workgroup, 1);
        assert_eq!(profile.max_native_int_width, 32);
        assert_eq!(profile.subgroup_size, 0);

        let validation = profile.validation_capabilities();
        assert!(!validation.supports_subgroup_ops);
        assert!(!validation.supports_indirect_dispatch);
        assert!(!validation.supports_specialization_constants);

        let strategy = profile.strategy_capabilities();
        assert!(!strategy.has_mul_high);
        assert!(!strategy.has_dual_issue_fp32_int32);
        assert!(!strategy.has_tensor_core_int);
        assert!(!strategy.has_native_f16);
        assert!(!strategy.has_warp_shuffle);
        assert!(!strategy.has_shared_memory);
        // Polynomial transcendental emission is advertised unconditionally.
        assert!(strategy.has_transcendental_polynomial_emit);
        assert_eq!(strategy.max_native_int_width, 32);
    }
}