use serde::Deserialize;
use std::fs;
use std::path::Path;
use crate::DeviceProfile;
/// Static description of one device architecture, deserialized from TOML.
///
/// `validate` checks a subset of the fields, and `apply_to_profile` copies
/// most of the numeric facts onto a `DeviceProfile`.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct DeviceSignature {
/// Sort and lookup key for `DeviceSignatureTable`; `validate` rejects a blank value.
/// A trailing `_<digits>` segment is also treated as an architecture generation
/// by `matches_architecture_generation`.
pub id: String,
/// Architecture family name; `validate` rejects a blank value.
pub family: String,
/// Explicit architecture generation; `None` when the TOML key is absent.
#[serde(default)]
pub architecture_generation: Option<u32>,
/// Substrings looked for in a device name, compared ASCII-case-insensitively
/// by `matches_device_name`; empty when the TOML key is absent.
#[serde(default)]
pub device_name_contains: Vec<String>,
/// Copied to `DeviceProfile::compute_units`.
/// NOTE(review): presumably the streaming-multiprocessor count — confirm.
pub max_sm: u32,
/// Must be a non-zero power of two; copied to `DeviceProfile::subgroup_size`.
pub warp_size: u32,
/// Copied to `DeviceProfile::regs_per_thread_max`.
pub regs_per_thread_max: u32,
/// Multiplied by 1024 to fill `DeviceProfile::max_shared_memory_bytes`, but only
/// when the profile's value is zero; any non-zero value also sets `has_shared_memory`.
pub shared_mem_per_sm_kb: u32,
/// Multiplied by 1024 (saturating) into `DeviceProfile::l1_cache_bytes`.
pub l1_kb: u32,
/// Multiplied by 1024 (saturating) into `DeviceProfile::l2_cache_bytes`.
pub l2_kb: u32,
/// Copied unchanged to `DeviceProfile::mem_bw_gbps`.
/// NOTE(review): unit is presumably GB/s per the name — confirm.
pub mem_bw_gbps: u32,
/// Copied to `DeviceProfile::supports_tensor_cores`.
pub tensor_core_supported: bool,
/// Data-type names for tensor cores; empty when the TOML key is absent.
/// Not read by `validate` or `apply_to_profile`.
#[serde(default)]
pub tensor_core_dtypes: Vec<String>,
/// Copied to `DeviceProfile::ideal_unroll_depth`.
pub ideal_unroll_depth: u32,
/// Must be a positive multiple of 32; copied to `DeviceProfile::ideal_vector_pack_bits`.
pub ideal_vector_pack_bits: u32,
/// Three tile axes, each of which must be non-zero; copied to
/// `DeviceProfile::ideal_workgroup_tile`.
pub ideal_workgroup_tile: [u32; 3],
/// Must be non-zero; copied to `DeviceProfile::shared_memory_bank_count`.
pub bank_count: u32,
/// Must be non-zero; copied to `DeviceProfile::shared_memory_bank_width_bytes`.
pub bank_width_bytes: u32,
}
/// Collection of device signatures ordered by `id`.
///
/// Both `builtins` and `load_dir` sort their entries by id before constructing
/// the table, and `get` depends on that order for its binary search.
/// The derived `Default` yields an empty table.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct DeviceSignatureTable {
// Kept in ascending `id` order by every constructor in this impl.
signatures: Vec<DeviceSignature>,
}
impl DeviceSignature {
    /// TOML text of the built-in signature, embedded at compile time from
    /// `devices/blackwell_120.toml`.
    pub const BUILTIN_BLACKWELL_120: &'static str =
        include_str!("../../devices/blackwell_120.toml");

    /// Parses a single signature from TOML text and runs [`Self::validate`] on it.
    ///
    /// # Errors
    ///
    /// Returns a message when the text does not deserialize into a signature or
    /// when validation rejects one of its fields.
    pub fn from_toml_str(source: &str) -> Result<Self, String> {
        let signature: Self = toml::from_str(source)
            .map_err(|error| format!("device signature TOML parse failed. Fix: {error}"))?;
        signature.validate()?;
        Ok(signature)
    }

    /// Checks the invariants the rest of this type relies on.
    ///
    /// # Errors
    ///
    /// Returns the first violation found, in this order: blank `id`, blank
    /// `family`, `warp_size` not a non-zero power of two,
    /// `ideal_vector_pack_bits` not a positive multiple of 32, a zero
    /// `ideal_workgroup_tile` axis, zero bank metadata, and finally a blank
    /// `device_name_contains` entry.
    pub fn validate(&self) -> Result<(), String> {
        if self.id.trim().is_empty() {
            return Err("device signature id is empty. Fix: set a stable id.".to_string());
        }
        if self.family.trim().is_empty() {
            return Err(
                "device signature family is empty. Fix: set the architecture family.".to_string(),
            );
        }
        if self.warp_size == 0 || !self.warp_size.is_power_of_two() {
            return Err(format!(
                "device signature `{}` has invalid warp_size {}. Fix: use a non-zero power of two.",
                self.id, self.warp_size
            ));
        }
        if self.ideal_vector_pack_bits == 0 || self.ideal_vector_pack_bits % 32 != 0 {
            return Err(format!(
                "device signature `{}` has invalid ideal_vector_pack_bits {}. Fix: use a positive multiple of 32.",
                self.id, self.ideal_vector_pack_bits
            ));
        }
        if self.ideal_workgroup_tile.contains(&0) {
            return Err(format!(
                "device signature `{}` has a zero ideal_workgroup_tile axis. Fix: every axis must be positive.",
                self.id
            ));
        }
        if self.bank_count == 0 || self.bank_width_bytes == 0 {
            return Err(format!(
                "device signature `{}` has invalid shared-memory bank metadata. Fix: bank_count and bank_width_bytes must be non-zero.",
                self.id
            ));
        }
        // A blank needle is contained in every string, so it would make this
        // signature match every device name. Checked last so the earlier checks
        // keep their error precedence.
        let has_blank_needle = self
            .device_name_contains
            .iter()
            .any(|needle| needle.trim().is_empty());
        if has_blank_needle {
            return Err(format!(
                "device signature `{}` has a blank device_name_contains entry. Fix: remove it or use a non-empty substring.",
                self.id
            ));
        }
        Ok(())
    }

    /// Overlays this signature's facts on `profile` and returns the result.
    ///
    /// Most fields are overwritten unconditionally. Two are merged instead:
    /// `has_shared_memory` is only ever switched on, and
    /// `max_shared_memory_bytes` is filled in only when the incoming value is
    /// zero, so a limit already present on the profile is kept.
    #[must_use]
    pub fn apply_to_profile(&self, mut profile: DeviceProfile) -> DeviceProfile {
        profile.subgroup_size = self.warp_size;
        profile.supports_tensor_cores = self.tensor_core_supported;
        // Any non-zero warp size is treated as shuffle-capable.
        profile.has_subgroup_shuffle = self.warp_size > 0;
        profile.has_shared_memory |= self.shared_mem_per_sm_kb > 0;
        if profile.max_shared_memory_bytes == 0 {
            profile.max_shared_memory_bytes = self.shared_mem_per_sm_kb.saturating_mul(1024);
        }
        profile.compute_units = self.max_sm;
        profile.regs_per_thread_max = self.regs_per_thread_max;
        profile.l1_cache_bytes = self.l1_kb.saturating_mul(1024);
        profile.l2_cache_bytes = self.l2_kb.saturating_mul(1024);
        profile.mem_bw_gbps = self.mem_bw_gbps;
        profile.ideal_unroll_depth = self.ideal_unroll_depth;
        profile.ideal_vector_pack_bits = self.ideal_vector_pack_bits;
        profile.ideal_workgroup_tile = self.ideal_workgroup_tile;
        profile.shared_memory_bank_count = self.bank_count;
        profile.shared_memory_bank_width_bytes = self.bank_width_bytes;
        profile
    }

    /// Returns `true` when this signature describes `generation`.
    ///
    /// A signature matches through its explicit `architecture_generation`, or
    /// through the text after the last `_` in `id` when that text parses as a
    /// decimal number (the whole id is used when it contains no `_`).
    #[must_use]
    pub fn matches_architecture_generation(&self, generation: u32) -> bool {
        self.architecture_generation == Some(generation)
            || self.id.rsplit('_').next().and_then(parse_u32) == Some(generation)
    }

    /// Returns `true` when any `device_name_contains` entry occurs in
    /// `device_name`, comparing ASCII-case-insensitively.
    ///
    /// Blank entries are ignored: `str::contains` reports an empty pattern as
    /// present in every string, which would otherwise make the signature match
    /// every device. The fields are public, so a signature can reach this
    /// method without having gone through [`Self::validate`].
    #[must_use]
    pub fn matches_device_name(&self, device_name: &str) -> bool {
        let haystack = device_name.to_ascii_lowercase();
        self.device_name_contains
            .iter()
            .filter(|needle| !needle.trim().is_empty())
            .any(|needle| haystack.contains(needle.to_ascii_lowercase().as_str()))
    }
}
impl DeviceSignatureTable {
    /// Builds the table from the signatures embedded in the binary.
    ///
    /// # Errors
    ///
    /// Returns a message when an embedded signature fails to parse or validate,
    /// or when two embedded signatures share an id.
    pub fn builtins() -> Result<Self, String> {
        let mut signatures = vec![DeviceSignature::from_toml_str(
            DeviceSignature::BUILTIN_BLACKWELL_120,
        )?];
        signatures.sort_by(|left, right| left.id.cmp(&right.id));
        // Same uniqueness guarantee as `load_dir`, so `get` stays unambiguous
        // if further built-ins are added to the list above.
        dedupe_signature_ids(&signatures)?;
        Ok(Self { signatures })
    }

    /// Loads every `*.toml` file directly inside `dir` as one signature.
    ///
    /// Entries without a `toml` extension, and directories that happen to carry
    /// one, are skipped. The resulting table is sorted by id.
    ///
    /// # Errors
    ///
    /// Returns a message when the directory or one of its entries cannot be
    /// read, when a file fails to parse or validate (the file path is appended
    /// to that message), or when two files declare the same id.
    pub fn load_dir(dir: impl AsRef<Path>) -> Result<Self, String> {
        let dir = dir.as_ref();
        let entries = fs::read_dir(dir).map_err(|error| {
            format!(
                "device signature directory `{}` cannot be read. Fix: create it or pass the correct path: {error}",
                dir.display()
            )
        })?;
        let mut signatures = Vec::new();
        for entry in entries {
            let entry = entry.map_err(|error| {
                format!(
                    "device signature directory `{}` contains an unreadable entry. Fix: {error}",
                    dir.display()
                )
            })?;
            let path = entry.path();
            if path.extension().and_then(|ext| ext.to_str()) != Some("toml") {
                continue;
            }
            // A subdirectory named `something.toml` is not a signature file;
            // reading it would fail and abort the whole load.
            if path.is_dir() {
                continue;
            }
            let source = fs::read_to_string(&path).map_err(|error| {
                format!(
                    "device signature file `{}` cannot be read. Fix: {error}",
                    path.display()
                )
            })?;
            let signature = DeviceSignature::from_toml_str(&source)
                .map_err(|error| format!("{} in `{}`", error, path.display()))?;
            signatures.push(signature);
        }
        // Sorting first puts equal ids next to each other, which is what the
        // duplicate check compares.
        signatures.sort_by(|left, right| left.id.cmp(&right.id));
        dedupe_signature_ids(&signatures)?;
        Ok(Self { signatures })
    }

    /// Returns all signatures in ascending id order.
    #[must_use]
    pub fn signatures(&self) -> &[DeviceSignature] {
        &self.signatures
    }

    /// Looks up a signature by exact id using binary search over the sorted list.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<&DeviceSignature> {
        self.signatures
            .binary_search_by(|signature| signature.id.as_str().cmp(id))
            .ok()
            .and_then(|index| self.signatures.get(index))
    }

    /// Returns the first signature, in id order, that matches `generation`.
    #[must_use]
    pub fn find_architecture_generation(&self, generation: u32) -> Option<&DeviceSignature> {
        self.signatures
            .iter()
            .find(|signature| signature.matches_architecture_generation(generation))
    }

    /// Returns the first signature, in id order, that matches `device_name`.
    #[must_use]
    pub fn find_device_name(&self, device_name: &str) -> Option<&DeviceSignature> {
        self.signatures
            .iter()
            .find(|signature| signature.matches_device_name(device_name))
    }

    /// Applies the signature matching `generation` to `profile`, or returns
    /// `profile` unchanged when no signature matches.
    #[must_use]
    pub fn apply_generation_to_profile(
        &self,
        generation: u32,
        profile: DeviceProfile,
    ) -> DeviceProfile {
        self.find_architecture_generation(generation)
            .map_or(profile, |signature| signature.apply_to_profile(profile))
    }

    /// Applies the signature matching `device_name` to `profile`, or returns
    /// `profile` unchanged when no signature matches.
    #[must_use]
    pub fn apply_device_name_to_profile(
        &self,
        device_name: &str,
        profile: DeviceProfile,
    ) -> DeviceProfile {
        self.find_device_name(device_name)
            .map_or(profile, |signature| signature.apply_to_profile(profile))
    }
}
/// Reports the first id shared by two neighbouring signatures.
///
/// Only adjacent entries are compared, so the slice must already be sorted by
/// id for every duplicate to be found. Nothing is removed; the function either
/// succeeds or names the repeated id in the error.
fn dedupe_signature_ids(signatures: &[DeviceSignature]) -> Result<(), String> {
    let repeated = signatures
        .iter()
        .zip(signatures.iter().skip(1))
        .find(|(earlier, later)| earlier.id == later.id);
    match repeated {
        Some((earlier, _)) => Err(format!(
            "duplicate device signature id `{}`. Fix: keep exactly one TOML file per id.",
            earlier.id
        )),
        None => Ok(()),
    }
}
/// Parses `value` as an unsigned decimal number made only of ASCII digits.
///
/// Returns `None` for an empty string, for any non-digit byte (including a
/// leading `+` or `-`), and when the value does not fit in a `u32`. Leading
/// zeros are accepted.
fn parse_u32(value: &str) -> Option<u32> {
    // Without this guard the fold below would yield `Some(0)` for "", so an id
    // ending in `_` would be read as generation 0.
    if value.is_empty() {
        return None;
    }
    value.bytes().try_fold(0u32, |total, byte| {
        if !byte.is_ascii_digit() {
            return None;
        }
        total.checked_mul(10)?.checked_add(u32::from(byte - b'0'))
    })
}
#[cfg(test)]
mod tests {
    use super::{DeviceSignature, DeviceSignatureTable};
    use crate::DeviceProfile;

    // Valid signature fixture that leaves both optional matching keys
    // (`architecture_generation`, `device_name_contains`) unset.
    const SAMPLE: &str = r#"
id = "sample_arch"
family = "sample"
max_sm = 128
warp_size = 32
regs_per_thread_max = 255
shared_mem_per_sm_kb = 128
l1_kb = 128
l2_kb = 98304
mem_bw_gbps = 1700
tensor_core_supported = true
tensor_core_dtypes = ["f16", "bf16", "tf32"]
ideal_unroll_depth = 8
ideal_vector_pack_bits = 128
ideal_workgroup_tile = [16, 16, 1]
bank_count = 32
bank_width_bytes = 4
"#;

    // Parses the fixture, panicking if it is rejected.
    fn sample_signature() -> DeviceSignature {
        DeviceSignature::from_toml_str(SAMPLE).unwrap()
    }

    // The fixture parses, and the omitted optional generation stays `None`.
    #[test]
    fn parses_and_validates_signature() {
        let parsed = sample_signature();
        assert_eq!(parsed.id, "sample_arch");
        assert_eq!(parsed.architecture_generation, None);
        assert_eq!(parsed.warp_size, 32);
        assert!(parsed.tensor_core_supported);
    }

    // 48 is not a power of two, so validation must name the offending field.
    #[test]
    fn rejects_invalid_warp_size() {
        let broken = SAMPLE.replace("warp_size = 32", "warp_size = 48");
        let message = DeviceSignature::from_toml_str(&broken).unwrap_err();
        assert!(message.contains("warp_size"));
    }

    // Signature facts land on a conservative profile.
    #[test]
    fn applies_architecture_facts_to_profile() {
        let base = DeviceProfile::conservative("test");
        let overlaid = sample_signature().apply_to_profile(base);
        assert_eq!(overlaid.subgroup_size, 32);
        assert_eq!(overlaid.max_shared_memory_bytes, 128 * 1024);
        assert_eq!(overlaid.compute_units, 128);
        assert_eq!(overlaid.ideal_vector_pack_bits, 128);
        assert_eq!(overlaid.shared_memory_bank_width_bytes, 4);
        assert!(overlaid.supports_tensor_cores);
    }

    // A shared-memory limit already set on the profile is not overwritten.
    #[test]
    fn preserves_live_shared_memory_per_workgroup_limit() {
        let parsed = sample_signature();
        let mut native = DeviceProfile::conservative("native");
        native.max_shared_memory_bytes = 48 * 1024;
        let overlaid = parsed.apply_to_profile(native);
        assert_eq!(overlaid.max_shared_memory_bytes, 48 * 1024);
        assert!(overlaid.has_shared_memory);
        assert_eq!(overlaid.shared_memory_bank_count, 32);
    }

    // Files are written out of order; the loaded table must be sorted by id.
    #[test]
    fn loads_directory_in_id_order() {
        let scratch = tempfile::tempdir().unwrap();
        for id in ["b", "a"] {
            let file = scratch.path().join(format!("{id}.toml"));
            std::fs::write(file, SAMPLE.replace("sample_arch", id)).unwrap();
        }
        let table = DeviceSignatureTable::load_dir(scratch.path()).unwrap();
        let loaded = table.signatures();
        assert_eq!(loaded[0].id, "a");
        assert_eq!(loaded[1].id, "b");
        assert!(table.get("b").is_some());
    }

    // The signature files shipped in the repository load and validate.
    #[test]
    fn repository_device_signatures_load() {
        let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        let table = DeviceSignatureTable::load_dir(manifest_dir.join("../devices")).unwrap();
        assert!(table.get("blackwell_120").is_some());
    }

    // The embedded signature is reachable by generation and by device name.
    #[test]
    fn builtins_match_generation_and_device_name() {
        let table = DeviceSignatureTable::builtins().unwrap();
        let by_generation = table.find_architecture_generation(120).unwrap();
        assert_eq!(by_generation.id, "blackwell_120");
        assert!(table.find_device_name("RTX 5090").is_some());
    }

    // Applying the embedded signature fills the planner-facing fields.
    #[test]
    fn builtin_signature_materially_projects_planner_fields() {
        let base = DeviceProfile::conservative("backend");
        let overlaid = DeviceSignatureTable::builtins()
            .unwrap()
            .apply_generation_to_profile(120, base);
        assert_eq!(overlaid.ideal_unroll_depth, 8);
        assert_eq!(overlaid.ideal_vector_pack_bits, 128);
        assert_eq!(overlaid.ideal_workgroup_tile, [16, 16, 1]);
        assert_eq!(overlaid.shared_memory_bank_count, 32);
    }
}