Skip to main content

cubecl_common/device/
base.rs

1use crate::stub::Arc;
2use core::{any::Any, cmp::Ordering};
3
4/// The device id.
5#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
6pub struct DeviceId {
7    /// The type id identifies the type of the device.
8    pub type_id: u16,
9    /// The index id identifies the device number.
10    pub index_id: u32,
11}
12
13/// Device trait for all cubecl devices.
14pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync + 'static {
15    /// Create a device from its [id](DeviceId).
16    fn from_id(device_id: DeviceId) -> Self;
17    /// Retrieve the [device id](DeviceId) from the device.
18    fn to_id(&self) -> DeviceId;
19}
20
21impl core::fmt::Display for DeviceId {
22    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
23        f.write_fmt(format_args!(
24            "DeviceId(type={}, index={})",
25            self.type_id, self.index_id
26        ))
27    }
28}
29
30impl Ord for DeviceId {
31    fn cmp(&self, other: &Self) -> Ordering {
32        match self.type_id.cmp(&other.type_id) {
33            Ordering::Equal => self.index_id.cmp(&other.index_id),
34            other => other,
35        }
36    }
37}
38
39impl PartialOrd for DeviceId {
40    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
41        Some(self.cmp(other))
42    }
43}
44
/// A pointer to a service's server utilities.
///
/// The utilities are type-erased behind `dyn Any`, so a user of the handle
/// has to downcast it to the concrete type it expects.
pub type ServerUtilitiesHandle = Arc<dyn Any + Send + Sync>;
47
48/// Represent a service that runs on a device.
49pub trait DeviceService: Send + 'static {
50    /// Initializes the service. It is only called once per device.
51    fn init(device_id: DeviceId) -> Self;
52    /// Get the service utilities.
53    fn utilities(&self) -> ServerUtilitiesHandle;
54}