// cubecl_runtime/feature_set.rs

1use crate::memory_management::{HardwareProperties, MemoryDeviceProperties};
2use alloc::collections::BTreeSet;
3use cubecl_common::profile::TimingMethod;
4
5/// Properties of what the device can do, like what `Feature` are
6/// supported by it and what its memory properties are.
7#[derive(Debug)]
8pub struct DeviceProperties<Feature: Ord + Copy> {
9    set: alloc::collections::BTreeSet<Feature>,
10    /// The memory properties of this client.
11    pub memory: MemoryDeviceProperties,
12    /// The topology properties of this client.
13    pub hardware: HardwareProperties,
14    /// The method used for profiling on the device.
15    pub timing_method: TimingMethod,
16}
17
18impl<Feature: Ord + Copy> DeviceProperties<Feature> {
19    /// Create a new feature set with the given features and memory properties.
20    pub fn new(
21        features: &[Feature],
22        memory_props: MemoryDeviceProperties,
23        hardware: HardwareProperties,
24        timing_method: TimingMethod,
25    ) -> Self {
26        let mut set = BTreeSet::new();
27        for feature in features {
28            set.insert(*feature);
29        }
30
31        DeviceProperties {
32            set,
33            memory: memory_props,
34            hardware,
35            timing_method,
36        }
37    }
38
39    /// Check if the provided `Feature` is supported by the runtime.
40    pub fn feature_enabled(&self, feature: Feature) -> bool {
41        self.set.contains(&feature)
42    }
43
44    /// Register a `Feature` supported by the compute server.
45    ///
46    /// This should only be used by a [runtime](cubecl_core::Runtime) when initializing a device.
47    pub fn register_feature(&mut self, feature: Feature) -> bool {
48        self.set.insert(feature)
49    }
50
51    /// Removes a `Feature` from the compute server.
52    ///
53    /// This should only be used by a [runtime](cubecl_core::Runtime) when initializing a device.
54    pub fn remove_feature(&mut self, feature: Feature) {
55        self.set.remove(&feature);
56    }
57}