cubecl_runtime/
feature_set.rs

1use crate::{
2    Features, TypeUsage,
3    memory_management::{HardwareProperties, MemoryDeviceProperties},
4};
5use cubecl_common::profile::TimingMethod;
6use cubecl_ir::{SemanticType, StorageType, Type};
7use enumset::EnumSet;
8
9/// Properties of what the device can do, like what `Feature` are
10/// supported by it and what its memory properties are.
11#[derive(Debug)]
12pub struct DeviceProperties {
13    /// The features supported by the runtime.
14    pub features: Features,
15    /// The memory properties of this client.
16    pub memory: MemoryDeviceProperties,
17    /// The topology properties of this client.
18    pub hardware: HardwareProperties,
19    /// The method used for profiling on the device.
20    pub timing_method: TimingMethod,
21}
22
23impl DeviceProperties {
24    /// Create a new feature set with the given features and memory properties.
25    pub fn new(
26        features: Features,
27        memory_props: MemoryDeviceProperties,
28        hardware: HardwareProperties,
29        timing_method: TimingMethod,
30    ) -> Self {
31        DeviceProperties {
32            features,
33            memory: memory_props,
34            hardware,
35            timing_method,
36        }
37    }
38
39    /// Get the usages for a type
40    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
41        self.features.type_usage(ty)
42    }
43
44    /// Whether the type is supported in any way
45    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
46        self.features.supports_type(ty)
47    }
48
49    /// Register a storage type to the features
50    pub fn register_type_usage(
51        &mut self,
52        ty: impl Into<StorageType>,
53        uses: impl Into<EnumSet<TypeUsage>>,
54    ) {
55        *self.features.storage_types.entry(ty.into()).or_default() |= uses.into();
56    }
57
58    /// Register a semantic type to the features
59    pub fn register_semantic_type(&mut self, ty: SemanticType) {
60        self.features.semantic_types.insert(ty);
61    }
62}