Skip to main content

ringkernel_ir/
capabilities.rs

1//! Backend capabilities for IR code generation.
2//!
3//! Tracks what features are available on different GPU backends.
4
5use std::collections::HashSet;
6
/// Individual GPU features that can be required by code or offered by a backend.
///
/// The enum is a plain, field-less list of flags; it is `Copy` and hashable so
/// it can be stored directly in a `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CapabilityFlag {
    /// Floating point values that are 64 bits wide (`f64`).
    Float64,
    /// Integers that are 64 bits wide.
    Int64,
    /// Atomic operations on 64-bit values.
    Atomic64,
    /// Cooperative groups, i.e. grid-level synchronisation.
    CooperativeGroups,
    /// Operations at subgroup (warp) level.
    Subgroups,
    /// Shuffle operations within a subgroup.
    SubgroupShuffle,
    /// Vote operations within a subgroup.
    SubgroupVote,
    /// Reduce operations within a subgroup.
    SubgroupReduce,
    /// Shared memory.
    SharedMemory,
    /// Dynamic shared memory.
    DynamicSharedMemory,
    /// Indirect command buffers.
    IndirectCommands,
    /// Persistent kernels.
    PersistentKernels,
    /// Half-precision floating point (`f16`).
    Float16,
    /// Tensor cores / matrix operations.
    TensorCores,
    /// Ray tracing.
    RayTracing,
    /// Bindless textures.
    BindlessTextures,
    /// Unified memory.
    UnifiedMemory,
    /// Support for multiple GPUs.
    MultiGpu,
}
47
48impl std::fmt::Display for CapabilityFlag {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            CapabilityFlag::Float64 => write!(f, "float64"),
52            CapabilityFlag::Int64 => write!(f, "int64"),
53            CapabilityFlag::Atomic64 => write!(f, "atomic64"),
54            CapabilityFlag::CooperativeGroups => write!(f, "cooperative_groups"),
55            CapabilityFlag::Subgroups => write!(f, "subgroups"),
56            CapabilityFlag::SubgroupShuffle => write!(f, "subgroup_shuffle"),
57            CapabilityFlag::SubgroupVote => write!(f, "subgroup_vote"),
58            CapabilityFlag::SubgroupReduce => write!(f, "subgroup_reduce"),
59            CapabilityFlag::SharedMemory => write!(f, "shared_memory"),
60            CapabilityFlag::DynamicSharedMemory => write!(f, "dynamic_shared_memory"),
61            CapabilityFlag::IndirectCommands => write!(f, "indirect_commands"),
62            CapabilityFlag::PersistentKernels => write!(f, "persistent_kernels"),
63            CapabilityFlag::Float16 => write!(f, "float16"),
64            CapabilityFlag::TensorCores => write!(f, "tensor_cores"),
65            CapabilityFlag::RayTracing => write!(f, "ray_tracing"),
66            CapabilityFlag::BindlessTextures => write!(f, "bindless_textures"),
67            CapabilityFlag::UnifiedMemory => write!(f, "unified_memory"),
68            CapabilityFlag::MultiGpu => write!(f, "multi_gpu"),
69        }
70    }
71}
72
73/// Set of capabilities required or available.
74#[derive(Debug, Clone, Default)]
75pub struct Capabilities {
76    flags: HashSet<CapabilityFlag>,
77}
78
79impl Capabilities {
80    /// Create empty capabilities.
81    pub fn new() -> Self {
82        Self::default()
83    }
84
85    /// Create with specific flags.
86    pub fn with_flags(flags: impl IntoIterator<Item = CapabilityFlag>) -> Self {
87        Self {
88            flags: flags.into_iter().collect(),
89        }
90    }
91
92    /// Add a capability.
93    pub fn add(&mut self, flag: CapabilityFlag) {
94        self.flags.insert(flag);
95    }
96
97    /// Remove a capability.
98    pub fn remove(&mut self, flag: CapabilityFlag) {
99        self.flags.remove(&flag);
100    }
101
102    /// Check if capability is present.
103    pub fn has(&self, flag: CapabilityFlag) -> bool {
104        self.flags.contains(&flag)
105    }
106
107    /// Check if all required capabilities are satisfied.
108    pub fn satisfies(&self, required: &Capabilities) -> bool {
109        required.flags.iter().all(|f| self.flags.contains(f))
110    }
111
112    /// Get missing capabilities.
113    pub fn missing(&self, required: &Capabilities) -> Vec<CapabilityFlag> {
114        required
115            .flags
116            .iter()
117            .filter(|f| !self.flags.contains(f))
118            .copied()
119            .collect()
120    }
121
122    /// Merge with another set.
123    pub fn merge(&mut self, other: &Capabilities) {
124        self.flags.extend(&other.flags);
125    }
126
127    /// Get all flags.
128    pub fn flags(&self) -> &HashSet<CapabilityFlag> {
129        &self.flags
130    }
131
132    /// Check if empty.
133    pub fn is_empty(&self) -> bool {
134        self.flags.is_empty()
135    }
136}
137
138/// Backend-specific capabilities.
139#[derive(Debug, Clone)]
140pub struct BackendCapabilities {
141    /// Backend name.
142    pub name: String,
143    /// Available capabilities.
144    pub capabilities: Capabilities,
145    /// Maximum threads per block.
146    pub max_threads_per_block: u32,
147    /// Maximum shared memory per block (bytes).
148    pub max_shared_memory: u32,
149    /// Warp/wavefront size.
150    pub warp_size: u32,
151    /// Maximum registers per thread.
152    pub max_registers: u32,
153}
154
155impl BackendCapabilities {
156    /// Create CUDA capabilities (SM 8.0+).
157    pub fn cuda_sm80() -> Self {
158        Self {
159            name: "CUDA SM 8.0".to_string(),
160            capabilities: Capabilities::with_flags([
161                CapabilityFlag::Float64,
162                CapabilityFlag::Int64,
163                CapabilityFlag::Atomic64,
164                CapabilityFlag::CooperativeGroups,
165                CapabilityFlag::Subgroups,
166                CapabilityFlag::SubgroupShuffle,
167                CapabilityFlag::SubgroupVote,
168                CapabilityFlag::SubgroupReduce,
169                CapabilityFlag::SharedMemory,
170                CapabilityFlag::DynamicSharedMemory,
171                CapabilityFlag::PersistentKernels,
172                CapabilityFlag::Float16,
173                CapabilityFlag::TensorCores,
174                CapabilityFlag::UnifiedMemory,
175            ]),
176            max_threads_per_block: 1024,
177            max_shared_memory: 163840, // 160 KB
178            warp_size: 32,
179            max_registers: 255,
180        }
181    }
182
183    /// Create WebGPU capabilities (baseline).
184    pub fn wgpu_baseline() -> Self {
185        Self {
186            name: "WebGPU Baseline".to_string(),
187            capabilities: Capabilities::with_flags([
188                CapabilityFlag::SharedMemory,
189                CapabilityFlag::Float16,
190            ]),
191            max_threads_per_block: 256,
192            max_shared_memory: 16384, // 16 KB
193            warp_size: 32,            // Varies by hardware
194            max_registers: 128,
195        }
196    }
197
198    /// Create WebGPU capabilities with subgroups.
199    pub fn wgpu_with_subgroups() -> Self {
200        let mut caps = Self::wgpu_baseline();
201        caps.name = "WebGPU with Subgroups".to_string();
202        caps.capabilities.add(CapabilityFlag::Subgroups);
203        caps.capabilities.add(CapabilityFlag::SubgroupVote);
204        caps
205    }
206
207    /// Create Metal capabilities (Apple Silicon).
208    pub fn metal_apple_silicon() -> Self {
209        Self {
210            name: "Metal Apple Silicon".to_string(),
211            capabilities: Capabilities::with_flags([
212                CapabilityFlag::Int64,
213                CapabilityFlag::Subgroups,
214                CapabilityFlag::SubgroupShuffle,
215                CapabilityFlag::SubgroupVote,
216                CapabilityFlag::SubgroupReduce,
217                CapabilityFlag::SharedMemory,
218                CapabilityFlag::DynamicSharedMemory,
219                CapabilityFlag::IndirectCommands,
220                CapabilityFlag::Float16,
221                CapabilityFlag::UnifiedMemory,
222            ]),
223            max_threads_per_block: 1024,
224            max_shared_memory: 32768, // 32 KB
225            warp_size: 32,            // SIMD width
226            max_registers: 256,
227        }
228    }
229
230    /// Check if backend supports required capabilities.
231    pub fn supports(&self, required: &Capabilities) -> bool {
232        self.capabilities.satisfies(required)
233    }
234
235    /// Get unsupported capabilities.
236    pub fn unsupported(&self, required: &Capabilities) -> Vec<CapabilityFlag> {
237        self.capabilities.missing(required)
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// A flag is reported by `has` only after it has been added.
    #[test]
    fn test_capabilities_add_has() {
        let mut set = Capabilities::new();
        let present_before = set.has(CapabilityFlag::Float64);

        set.add(CapabilityFlag::Float64);
        let present_after = set.has(CapabilityFlag::Float64);

        assert!(!present_before);
        assert!(present_after);
    }

    /// `satisfies` accepts a contained requirement and rejects an absent one.
    #[test]
    fn test_capabilities_satisfies() {
        let provided = Capabilities::with_flags([
            CapabilityFlag::Float64,
            CapabilityFlag::Int64,
            CapabilityFlag::SharedMemory,
        ]);
        let needs_f64 = Capabilities::with_flags([CapabilityFlag::Float64]);
        let needs_coop = Capabilities::with_flags([CapabilityFlag::CooperativeGroups]);

        assert!(provided.satisfies(&needs_f64));
        assert!(!provided.satisfies(&needs_coop));
    }

    /// `missing` lists exactly the required flags that are not provided.
    #[test]
    fn test_capabilities_missing() {
        let provided = Capabilities::with_flags([CapabilityFlag::Float64]);
        let wanted = Capabilities::with_flags([CapabilityFlag::Float64, CapabilityFlag::Int64]);

        assert_eq!(provided.missing(&wanted), vec![CapabilityFlag::Int64]);
    }

    /// The CUDA preset includes f64, cooperative groups and persistent kernels.
    #[test]
    fn test_cuda_capabilities() {
        let flags = BackendCapabilities::cuda_sm80().capabilities;

        for expected in [
            CapabilityFlag::Float64,
            CapabilityFlag::CooperativeGroups,
            CapabilityFlag::PersistentKernels,
        ] {
            assert!(flags.has(expected));
        }
    }

    /// The WebGPU baseline has shared memory but no f64.
    #[test]
    fn test_wgpu_capabilities() {
        let flags = BackendCapabilities::wgpu_baseline().capabilities;

        assert!(!flags.has(CapabilityFlag::Float64));
        assert!(flags.has(CapabilityFlag::SharedMemory));
    }

    /// The Metal preset has unified memory but no f64.
    #[test]
    fn test_metal_capabilities() {
        let flags = BackendCapabilities::metal_apple_silicon().capabilities;

        assert!(flags.has(CapabilityFlag::UnifiedMemory));
        // Metal doesn't support f64
        assert!(!flags.has(CapabilityFlag::Float64));
    }

    /// `supports` follows the flags of the individual backend presets.
    #[test]
    fn test_backend_supports() {
        let needs_f64 = Capabilities::with_flags([CapabilityFlag::Float64]);

        let cuda_ok = BackendCapabilities::cuda_sm80().supports(&needs_f64);
        let wgpu_ok = BackendCapabilities::wgpu_baseline().supports(&needs_f64);

        assert!(cuda_ok);
        assert!(!wgpu_ok);
    }
}