use crate::dialect_lookup::DialectLookup;
use crate::ir::DataType;
/// Plain-data snapshot of the optional features a backend reports.
///
/// Every flag is `false` in the [`Default`] value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct BackendCapabilities {
    /// `true` when the backend reports support for subgroup operations.
    pub supports_subgroup_ops: bool,
    /// `true` when the backend reports support for indirect dispatch.
    pub supports_indirect_dispatch: bool,
    /// `true` when the backend reports support for specialization constants.
    pub supports_specialization_constants: bool,
}
/// Queries a backend answers so callers can tailor checks to it.
///
/// Only [`backend_name`](Self::backend_name) and
/// [`supports_cast_target`](Self::supports_cast_target) are required. Each
/// optional feature query has a default implementation returning `false`, so
/// an implementor overrides only the features it actually provides.
pub trait BackendValidationCapabilities {
    /// Returns the static name of this backend.
    fn backend_name(&self) -> &'static str;

    /// Returns whether this backend accepts `target` as the destination type
    /// of a cast.
    fn supports_cast_target(&self, target: &DataType) -> bool;

    /// Returns whether subgroup operations are supported; `false` unless
    /// overridden.
    #[inline]
    fn supports_subgroup_ops(&self) -> bool {
        false
    }

    /// Returns whether indirect dispatch is supported; `false` unless
    /// overridden.
    #[inline]
    fn supports_indirect_dispatch(&self) -> bool {
        false
    }

    /// Returns whether specialization constants are supported; `false` unless
    /// overridden.
    #[inline]
    fn supports_specialization_constants(&self) -> bool {
        false
    }

    /// Gathers the three feature queries into one [`BackendCapabilities`]
    /// value.
    #[inline]
    #[must_use]
    fn backend_capabilities(&self) -> BackendCapabilities {
        // Queried in field order: subgroup, indirect dispatch, specialization.
        let supports_subgroup_ops = self.supports_subgroup_ops();
        let supports_indirect_dispatch = self.supports_indirect_dispatch();
        let supports_specialization_constants = self.supports_specialization_constants();
        BackendCapabilities {
            supports_subgroup_ops,
            supports_indirect_dispatch,
            supports_specialization_constants,
        }
    }
}
/// Options bundle carrying borrowed backend/dialect handles and flags.
///
/// The [`Default`] value has every handle unset and `allow_shadowing` off.
/// The struct is `Copy`, so it is cheap to pass by value.
#[derive(Default, Clone, Copy)]
pub struct ValidationOptions<'a> {
    /// Backend consulted for its name and cast-target support, if any.
    pub backend: Option<&'a dyn BackendValidationCapabilities>,
    /// Capability snapshot, if any. It is stored separately from `backend`
    /// and can be set without one.
    pub backend_capabilities: Option<BackendCapabilities>,
    /// Optional dialect lookup handle.
    pub dialect_lookup: Option<&'a dyn DialectLookup>,
    /// Shadowing flag; `false` by default.
    // NOTE(review): presumably permits name shadowing during validation;
    // confirm against the code that reads this flag.
    pub allow_shadowing: bool,
}
impl<'a> ValidationOptions<'a> {
#[must_use]
#[inline]
pub fn universal() -> Self {
Self::default()
}
#[must_use]
#[inline]
pub fn with_backend(mut self, backend: &'a dyn BackendValidationCapabilities) -> Self {
self.backend = Some(backend);
self.backend_capabilities = Some(backend.backend_capabilities());
self
}
#[must_use]
#[inline]
pub fn with_backend_capabilities(mut self, backend_capabilities: BackendCapabilities) -> Self {
self.backend_capabilities = Some(backend_capabilities);
self
}
#[must_use]
#[inline]
pub fn with_dialect_lookup(mut self, lookup: &'a dyn DialectLookup) -> Self {
self.dialect_lookup = Some(lookup);
self
}
#[must_use]
#[inline]
pub fn with_shadowing(mut self, allow_shadowing: bool) -> Self {
self.allow_shadowing = allow_shadowing;
self
}
#[must_use]
#[inline]
pub fn backend_name(&self) -> &'static str {
self.backend
.map(BackendValidationCapabilities::backend_name)
.unwrap_or("best-effort universal")
}
#[must_use]
#[inline]
pub fn supports_cast_target(&self, target: &DataType) -> bool {
self.backend
.map(|backend| backend.supports_cast_target(target))
.unwrap_or(true)
}
#[must_use]
#[inline]
pub fn requires_subgroup_ops(&self) -> bool {
self.backend_capabilities
.is_some_and(|caps| caps.supports_subgroup_ops)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture backend: accepts only `U32`/`F32` cast targets and overrides
    /// `supports_subgroup_ops` to `true`; the other features keep the trait
    /// defaults.
    struct CapabilityFixtureBackend;

    impl BackendValidationCapabilities for CapabilityFixtureBackend {
        fn backend_name(&self) -> &'static str {
            "capability-fixture-gpu"
        }

        fn supports_cast_target(&self, target: &DataType) -> bool {
            matches!(target, DataType::U32 | DataType::F32)
        }

        fn supports_subgroup_ops(&self) -> bool {
            true
        }
    }

    /// `universal()` has no backend, no shadowing, and the fallback name.
    #[test]
    fn universal_defaults() {
        let options = ValidationOptions::universal();
        assert!(options.backend.is_none());
        assert!(!options.allow_shadowing);
        assert_eq!(options.backend_name(), "best-effort universal");
    }

    /// Attaching a backend exposes its name and its capability snapshot.
    #[test]
    fn with_backend_sets_name_and_caps() {
        let fixture = CapabilityFixtureBackend;
        let options = ValidationOptions::universal().with_backend(&fixture);
        assert_eq!(options.backend_name(), "capability-fixture-gpu");
        assert!(options.requires_subgroup_ops());
    }

    /// Cast-target queries are forwarded to the attached backend.
    #[test]
    fn supports_cast_target_delegates_to_backend() {
        let fixture = CapabilityFixtureBackend;
        let options = ValidationOptions::universal().with_backend(&fixture);
        assert!(options.supports_cast_target(&DataType::U32));
        assert!(!options.supports_cast_target(&DataType::Bool));
    }

    /// With no backend, every cast target is accepted.
    #[test]
    fn supports_cast_target_defaults_true_without_backend() {
        let options = ValidationOptions::universal();
        assert!(options.supports_cast_target(&DataType::Bool));
    }

    /// `with_shadowing(true)` turns the flag on.
    #[test]
    fn with_shadowing_toggle() {
        let options = ValidationOptions::universal().with_shadowing(true);
        assert!(options.allow_shadowing);
    }

    /// The default capability snapshot has every flag cleared.
    #[test]
    fn backend_capabilities_default() {
        let BackendCapabilities {
            supports_subgroup_ops,
            supports_indirect_dispatch,
            supports_specialization_constants,
        } = BackendCapabilities::default();
        assert!(!supports_subgroup_ops);
        assert!(!supports_indirect_dispatch);
        assert!(!supports_specialization_constants);
    }

    /// An explicit snapshot is honoured even without a backend attached.
    #[test]
    fn with_backend_capabilities_snapshot() {
        let snapshot = BackendCapabilities {
            supports_subgroup_ops: true,
            ..BackendCapabilities::default()
        };
        let options = ValidationOptions::universal().with_backend_capabilities(snapshot);
        assert!(options.requires_subgroup_ops());
    }
}