1use std::collections::HashSet;
6
/// One optional feature that a compute backend may or may not provide.
///
/// Flags are collected into a [`Capabilities`] set: a backend advertises the set it
/// supports, and a kernel declares the set it needs (compare with
/// `Capabilities::satisfies` / `Capabilities::missing`). Each variant's `Display`
/// form is a stable snake_case name, shown in parentheses below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CapabilityFlag {
    /// 64-bit floating-point values (`float64`).
    Float64,
    /// 64-bit integer values (`int64`).
    Int64,
    /// 64-bit atomic operations (`atomic64`).
    Atomic64,
    /// Cooperative groups (`cooperative_groups`).
    CooperativeGroups,
    /// Subgroup operations in general (`subgroups`).
    Subgroups,
    /// Subgroup shuffle operations (`subgroup_shuffle`).
    SubgroupShuffle,
    /// Subgroup vote operations (`subgroup_vote`).
    SubgroupVote,
    /// Subgroup reduction operations (`subgroup_reduce`).
    SubgroupReduce,
    /// Shared memory (`shared_memory`).
    SharedMemory,
    /// Dynamically sized shared memory (`dynamic_shared_memory`).
    DynamicSharedMemory,
    /// Indirect commands (`indirect_commands`).
    IndirectCommands,
    /// Persistent kernels (`persistent_kernels`).
    PersistentKernels,
    /// 16-bit floating-point values (`float16`).
    Float16,
    /// Tensor cores (`tensor_cores`).
    TensorCores,
    /// Ray tracing (`ray_tracing`).
    RayTracing,
    /// Bindless textures (`bindless_textures`).
    BindlessTextures,
    /// Unified memory (`unified_memory`).
    UnifiedMemory,
    /// Multiple GPUs (`multi_gpu`).
    MultiGpu,
}
47
48impl std::fmt::Display for CapabilityFlag {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 match self {
51 CapabilityFlag::Float64 => write!(f, "float64"),
52 CapabilityFlag::Int64 => write!(f, "int64"),
53 CapabilityFlag::Atomic64 => write!(f, "atomic64"),
54 CapabilityFlag::CooperativeGroups => write!(f, "cooperative_groups"),
55 CapabilityFlag::Subgroups => write!(f, "subgroups"),
56 CapabilityFlag::SubgroupShuffle => write!(f, "subgroup_shuffle"),
57 CapabilityFlag::SubgroupVote => write!(f, "subgroup_vote"),
58 CapabilityFlag::SubgroupReduce => write!(f, "subgroup_reduce"),
59 CapabilityFlag::SharedMemory => write!(f, "shared_memory"),
60 CapabilityFlag::DynamicSharedMemory => write!(f, "dynamic_shared_memory"),
61 CapabilityFlag::IndirectCommands => write!(f, "indirect_commands"),
62 CapabilityFlag::PersistentKernels => write!(f, "persistent_kernels"),
63 CapabilityFlag::Float16 => write!(f, "float16"),
64 CapabilityFlag::TensorCores => write!(f, "tensor_cores"),
65 CapabilityFlag::RayTracing => write!(f, "ray_tracing"),
66 CapabilityFlag::BindlessTextures => write!(f, "bindless_textures"),
67 CapabilityFlag::UnifiedMemory => write!(f, "unified_memory"),
68 CapabilityFlag::MultiGpu => write!(f, "multi_gpu"),
69 }
70 }
71}
72
73#[derive(Debug, Clone, Default)]
75pub struct Capabilities {
76 flags: HashSet<CapabilityFlag>,
77}
78
79impl Capabilities {
80 pub fn new() -> Self {
82 Self::default()
83 }
84
85 pub fn with_flags(flags: impl IntoIterator<Item = CapabilityFlag>) -> Self {
87 Self {
88 flags: flags.into_iter().collect(),
89 }
90 }
91
92 pub fn add(&mut self, flag: CapabilityFlag) {
94 self.flags.insert(flag);
95 }
96
97 pub fn remove(&mut self, flag: CapabilityFlag) {
99 self.flags.remove(&flag);
100 }
101
102 pub fn has(&self, flag: CapabilityFlag) -> bool {
104 self.flags.contains(&flag)
105 }
106
107 pub fn satisfies(&self, required: &Capabilities) -> bool {
109 required.flags.iter().all(|f| self.flags.contains(f))
110 }
111
112 pub fn missing(&self, required: &Capabilities) -> Vec<CapabilityFlag> {
114 required
115 .flags
116 .iter()
117 .filter(|f| !self.flags.contains(f))
118 .copied()
119 .collect()
120 }
121
122 pub fn merge(&mut self, other: &Capabilities) {
124 self.flags.extend(&other.flags);
125 }
126
127 pub fn flags(&self) -> &HashSet<CapabilityFlag> {
129 &self.flags
130 }
131
132 pub fn is_empty(&self) -> bool {
134 self.flags.is_empty()
135 }
136}
137
/// A named capability profile for one backend / device class: its feature flags
/// plus the numeric launch limits it reports.
///
/// Predefined profiles come from the associated constructors (`cuda_sm80`,
/// `wgpu_baseline`, `wgpu_with_subgroups`, `metal_apple_silicon`).
#[derive(Debug, Clone)]
pub struct BackendCapabilities {
    /// Human-readable profile name, e.g. `"CUDA SM 8.0"`.
    pub name: String,
    /// Feature flags this backend provides.
    pub capabilities: Capabilities,
    /// Maximum number of threads in a single block / workgroup.
    pub max_threads_per_block: u32,
    /// Shared-memory limit. Presumably in bytes (profiles use 16384, 32768, 163840).
    /// TODO confirm the unit and whether it is per block or per SM.
    pub max_shared_memory: u32,
    /// SIMD / warp width reported by the profile (32 in every predefined profile).
    /// NOTE(review): WebGPU does not fix a subgroup size, so for the wgpu profiles
    /// this looks like a placeholder. Confirm before using it for codegen.
    pub warp_size: u32,
    /// Register limit. Presumably registers per thread; TODO confirm.
    pub max_registers: u32,
}
154
155impl BackendCapabilities {
156 pub fn cuda_sm80() -> Self {
158 Self {
159 name: "CUDA SM 8.0".to_string(),
160 capabilities: Capabilities::with_flags([
161 CapabilityFlag::Float64,
162 CapabilityFlag::Int64,
163 CapabilityFlag::Atomic64,
164 CapabilityFlag::CooperativeGroups,
165 CapabilityFlag::Subgroups,
166 CapabilityFlag::SubgroupShuffle,
167 CapabilityFlag::SubgroupVote,
168 CapabilityFlag::SubgroupReduce,
169 CapabilityFlag::SharedMemory,
170 CapabilityFlag::DynamicSharedMemory,
171 CapabilityFlag::PersistentKernels,
172 CapabilityFlag::Float16,
173 CapabilityFlag::TensorCores,
174 CapabilityFlag::UnifiedMemory,
175 ]),
176 max_threads_per_block: 1024,
177 max_shared_memory: 163840, warp_size: 32,
179 max_registers: 255,
180 }
181 }
182
183 pub fn wgpu_baseline() -> Self {
185 Self {
186 name: "WebGPU Baseline".to_string(),
187 capabilities: Capabilities::with_flags([
188 CapabilityFlag::SharedMemory,
189 CapabilityFlag::Float16,
190 ]),
191 max_threads_per_block: 256,
192 max_shared_memory: 16384, warp_size: 32, max_registers: 128,
195 }
196 }
197
198 pub fn wgpu_with_subgroups() -> Self {
200 let mut caps = Self::wgpu_baseline();
201 caps.name = "WebGPU with Subgroups".to_string();
202 caps.capabilities.add(CapabilityFlag::Subgroups);
203 caps.capabilities.add(CapabilityFlag::SubgroupVote);
204 caps
205 }
206
207 pub fn metal_apple_silicon() -> Self {
209 Self {
210 name: "Metal Apple Silicon".to_string(),
211 capabilities: Capabilities::with_flags([
212 CapabilityFlag::Int64,
213 CapabilityFlag::Subgroups,
214 CapabilityFlag::SubgroupShuffle,
215 CapabilityFlag::SubgroupVote,
216 CapabilityFlag::SubgroupReduce,
217 CapabilityFlag::SharedMemory,
218 CapabilityFlag::DynamicSharedMemory,
219 CapabilityFlag::IndirectCommands,
220 CapabilityFlag::Float16,
221 CapabilityFlag::UnifiedMemory,
222 ]),
223 max_threads_per_block: 1024,
224 max_shared_memory: 32768, warp_size: 32, max_registers: 256,
227 }
228 }
229
230 pub fn supports(&self, required: &Capabilities) -> bool {
232 self.capabilities.satisfies(required)
233 }
234
235 pub fn unsupported(&self, required: &Capabilities) -> Vec<CapabilityFlag> {
237 self.capabilities.missing(required)
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_capabilities_add_has() {
        let mut caps = Capabilities::new();
        assert!(!caps.has(CapabilityFlag::Float64));

        caps.add(CapabilityFlag::Float64);
        assert!(caps.has(CapabilityFlag::Float64));
    }

    #[test]
    fn test_capabilities_satisfies() {
        let available = Capabilities::with_flags([
            CapabilityFlag::Float64,
            CapabilityFlag::Int64,
            CapabilityFlag::SharedMemory,
        ]);

        let required1 = Capabilities::with_flags([CapabilityFlag::Float64]);
        assert!(available.satisfies(&required1));

        let required2 = Capabilities::with_flags([CapabilityFlag::CooperativeGroups]);
        assert!(!available.satisfies(&required2));
    }

    #[test]
    fn test_capabilities_missing() {
        let available = Capabilities::with_flags([CapabilityFlag::Float64]);
        let required = Capabilities::with_flags([CapabilityFlag::Float64, CapabilityFlag::Int64]);

        let missing = available.missing(&required);
        assert_eq!(missing.len(), 1);
        assert!(missing.contains(&CapabilityFlag::Int64));
    }

    #[test]
    fn test_capabilities_empty_requirement() {
        // An empty requirement set is trivially satisfied and reports nothing missing.
        let nothing = Capabilities::new();
        assert!(Capabilities::new().satisfies(&nothing));
        assert!(Capabilities::new().missing(&nothing).is_empty());
    }

    #[test]
    fn test_capabilities_remove_is_empty() {
        let mut caps = Capabilities::with_flags([CapabilityFlag::Int64]);
        assert!(!caps.is_empty());

        caps.remove(CapabilityFlag::Int64);
        assert!(!caps.has(CapabilityFlag::Int64));
        assert!(caps.is_empty());

        // Removing an absent flag is a no-op, not an error.
        caps.remove(CapabilityFlag::Int64);
        assert!(caps.is_empty());
    }

    #[test]
    fn test_capabilities_merge() {
        let mut caps = Capabilities::with_flags([CapabilityFlag::Float64]);
        caps.merge(&Capabilities::with_flags([
            CapabilityFlag::Float64,
            CapabilityFlag::Int64,
        ]));
        assert!(caps.has(CapabilityFlag::Float64));
        assert!(caps.has(CapabilityFlag::Int64));
        assert_eq!(caps.flags().len(), 2);
    }

    #[test]
    fn test_display_names() {
        assert_eq!(CapabilityFlag::Float64.to_string(), "float64");
        assert_eq!(CapabilityFlag::CooperativeGroups.to_string(), "cooperative_groups");
        assert_eq!(CapabilityFlag::MultiGpu.to_string(), "multi_gpu");

        let all = [
            CapabilityFlag::Float64,
            CapabilityFlag::Int64,
            CapabilityFlag::Atomic64,
            CapabilityFlag::CooperativeGroups,
            CapabilityFlag::Subgroups,
            CapabilityFlag::SubgroupShuffle,
            CapabilityFlag::SubgroupVote,
            CapabilityFlag::SubgroupReduce,
            CapabilityFlag::SharedMemory,
            CapabilityFlag::DynamicSharedMemory,
            CapabilityFlag::IndirectCommands,
            CapabilityFlag::PersistentKernels,
            CapabilityFlag::Float16,
            CapabilityFlag::TensorCores,
            CapabilityFlag::RayTracing,
            CapabilityFlag::BindlessTextures,
            CapabilityFlag::UnifiedMemory,
            CapabilityFlag::MultiGpu,
        ];
        let names: std::collections::HashSet<String> =
            all.iter().map(|flag| flag.to_string()).collect();
        // Every flag must render to a distinct snake_case identifier.
        assert_eq!(names.len(), all.len());
        assert!(names.iter().all(|name| name
            .chars()
            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')));
    }

    #[test]
    fn test_cuda_capabilities() {
        let cuda = BackendCapabilities::cuda_sm80();
        assert!(cuda.capabilities.has(CapabilityFlag::Float64));
        assert!(cuda.capabilities.has(CapabilityFlag::CooperativeGroups));
        assert!(cuda.capabilities.has(CapabilityFlag::PersistentKernels));
    }

    #[test]
    fn test_wgpu_capabilities() {
        let wgpu = BackendCapabilities::wgpu_baseline();
        assert!(!wgpu.capabilities.has(CapabilityFlag::Float64));
        assert!(wgpu.capabilities.has(CapabilityFlag::SharedMemory));
    }

    #[test]
    fn test_wgpu_with_subgroups() {
        let wgpu = BackendCapabilities::wgpu_with_subgroups();
        assert_eq!(wgpu.name, "WebGPU with Subgroups");
        assert!(wgpu.capabilities.has(CapabilityFlag::Subgroups));
        assert!(wgpu.capabilities.has(CapabilityFlag::SubgroupVote));
        // Baseline flags and limits are inherited.
        assert!(wgpu.capabilities.has(CapabilityFlag::SharedMemory));
        assert_eq!(wgpu.max_threads_per_block, 256);
    }

    #[test]
    fn test_metal_capabilities() {
        let metal = BackendCapabilities::metal_apple_silicon();
        assert!(metal.capabilities.has(CapabilityFlag::UnifiedMemory));
        assert!(!metal.capabilities.has(CapabilityFlag::Float64));
    }

    #[test]
    fn test_backend_supports() {
        let cuda = BackendCapabilities::cuda_sm80();
        let wgpu = BackendCapabilities::wgpu_baseline();

        let requires_f64 = Capabilities::with_flags([CapabilityFlag::Float64]);
        assert!(cuda.supports(&requires_f64));
        assert!(!wgpu.supports(&requires_f64));
    }

    #[test]
    fn test_backend_unsupported() {
        let metal = BackendCapabilities::metal_apple_silicon();
        let required = Capabilities::with_flags([CapabilityFlag::Float64, CapabilityFlag::Int64]);
        assert!(!metal.supports(&required));
        assert_eq!(metal.unsupported(&required), vec![CapabilityFlag::Float64]);
    }
}