// pjrt/device.rs

1use std::slice;
2
3use pjrt_sys::{
4    PJRT_Device, PJRT_Device_AddressableMemories_Args, PJRT_Device_DefaultMemory_Args,
5    PJRT_Device_GetDescription_Args, PJRT_Device_IsAddressable_Args,
6    PJRT_Device_LocalHardwareId_Args, PJRT_Device_MemoryStats_Args,
7};
8
9use crate::{Client, DeviceDescription, Memory, Result};
10
/// The logical global device ID.
///
/// This is unique among devices of this type (e.g. CPUs, GPUs). On multi-host
/// platforms, this will be unique across all hosts' devices.
pub type GlobalDeviceId = i32;

/// The logical local device ID.
///
/// This will be used to look up an addressable device local to a given client.
/// It is -1 if undefined.
pub type LocalDeviceId = i32;

/// The physical local device ID, e.g., the CUDA device number.
///
/// Multiple PJRT devices can have the same `LocalHardwareId` if these PJRT
/// devices share the same physical device. In general, it is not guaranteed to
/// be dense, and it is -1 if undefined.
pub type LocalHardwareId = i32;

26pub struct Device {
27    client: Client,
28    pub(crate) ptr: *mut PJRT_Device,
29}
30
31impl Device {
32    pub fn wrap(client: &Client, ptr: *mut PJRT_Device) -> Device {
33        assert!(!ptr.is_null());
34        Self {
35            client: client.clone(),
36            ptr,
37        }
38    }
39
40    pub fn client(&self) -> &Client {
41        &self.client
42    }
43
44    pub fn description(&self) -> DeviceDescription {
45        let mut args = PJRT_Device_GetDescription_Args::new();
46        args.device = self.ptr;
47        args = self
48            .client
49            .api()
50            .PJRT_Device_GetDescription(args)
51            .expect("PJRT_Device_GetDescription");
52        DeviceDescription::wrap(self.client.api(), args.device_description)
53    }
54
55    pub fn is_addressable(&self) -> bool {
56        let mut args = PJRT_Device_IsAddressable_Args::new();
57        args.device = self.ptr;
58        args = self
59            .client
60            .api()
61            .PJRT_Device_IsAddressable(args)
62            .expect("PJRT_Device_IsAddressable");
63        args.is_addressable
64    }
65
66    pub fn local_hardware_id(&self) -> LocalHardwareId {
67        let mut args = PJRT_Device_LocalHardwareId_Args::new();
68        args.device = self.ptr;
69        args = self
70            .client
71            .api()
72            .PJRT_Device_LocalHardwareId(args)
73            .expect("PJRT_Device_LocalHardwareId");
74        args.local_hardware_id
75    }
76
77    pub fn addressable_memories(&self) -> Vec<Memory> {
78        let mut args = PJRT_Device_AddressableMemories_Args::new();
79        args.device = self.ptr;
80        args = self
81            .client
82            .api()
83            .PJRT_Device_AddressableMemories(args)
84            .expect("PJRT_Device_AddressableMemories");
85        let memories = unsafe { slice::from_raw_parts(args.memories, args.num_memories) };
86        memories
87            .iter()
88            .cloned()
89            .map(|d| Memory::wrap(&self.client, d))
90            .collect()
91    }
92
93    pub fn default_memory(&self) -> Memory {
94        let mut args = PJRT_Device_DefaultMemory_Args::new();
95        args.device = self.ptr;
96        args = self
97            .client
98            .api()
99            .PJRT_Device_DefaultMemory(args)
100            .expect("PJRT_Device_DefaultMemory");
101        Memory::wrap(&self.client, args.memory)
102    }
103
104    pub fn memory_stats(&self) -> Result<MemoryStats> {
105        let mut args = PJRT_Device_MemoryStats_Args::new();
106        args.device = self.ptr;
107        args = self.client.api().PJRT_Device_MemoryStats(args)?;
108        Ok(MemoryStats::from(args))
109    }
110}
111
/// Memory usage statistics for a single device, as reported by
/// `PJRT_Device_MemoryStats`.
///
/// Apart from `bytes_in_use`, every statistic is paired with an `*_is_set`
/// flag recording whether the plugin reported it. A value whose flag is
/// `false` should not be relied upon. The `Option`-returning accessor methods
/// of the same name (e.g. [`MemoryStats::bytes_limit`]) fold each pair into a
/// single value.
///
/// `Default` yields `bytes_in_use == 0` with every optional statistic marked
/// as not set.
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MemoryStats {
    /// Bytes currently in use. This has no `_is_set` flag, so it is always populated.
    pub bytes_in_use: i64,
    /// Peak bytes in use; valid only when `peak_bytes_in_use_is_set` is true.
    pub peak_bytes_in_use: i64,
    /// Whether `peak_bytes_in_use` was reported.
    pub peak_bytes_in_use_is_set: bool,
    /// Number of allocations; valid only when `num_allocs_is_set` is true.
    pub num_allocs: i64,
    /// Whether `num_allocs` was reported.
    pub num_allocs_is_set: bool,
    /// Largest allocation size; valid only when `largest_alloc_size_is_set` is true.
    pub largest_alloc_size: i64,
    /// Whether `largest_alloc_size` was reported.
    pub largest_alloc_size_is_set: bool,
    /// Byte limit; valid only when `bytes_limit_is_set` is true.
    pub bytes_limit: i64,
    /// Whether `bytes_limit` was reported.
    pub bytes_limit_is_set: bool,
    /// Reserved bytes; valid only when `bytes_reserved_is_set` is true.
    pub bytes_reserved: i64,
    /// Whether `bytes_reserved` was reported.
    pub bytes_reserved_is_set: bool,
    /// Peak reserved bytes; valid only when `peak_bytes_reserved_is_set` is true.
    pub peak_bytes_reserved: i64,
    /// Whether `peak_bytes_reserved` was reported.
    pub peak_bytes_reserved_is_set: bool,
    /// Reservable byte limit; valid only when `bytes_reservable_limit_is_set` is true.
    pub bytes_reservable_limit: i64,
    /// Whether `bytes_reservable_limit` was reported.
    pub bytes_reservable_limit_is_set: bool,
    /// Largest free block in bytes; valid only when
    /// `largest_free_block_bytes_is_set` is true.
    pub largest_free_block_bytes: i64,
    /// Whether `largest_free_block_bytes` was reported.
    pub largest_free_block_bytes_is_set: bool,
    /// Pool bytes; valid only when `pool_bytes_is_set` is true.
    pub pool_bytes: i64,
    /// Whether `pool_bytes` was reported.
    pub pool_bytes_is_set: bool,
    /// Peak pool bytes; valid only when `peak_pool_bytes_is_set` is true.
    pub peak_pool_bytes: i64,
    /// Whether `peak_pool_bytes` was reported.
    pub peak_pool_bytes_is_set: bool,
}

impl MemoryStats {
    /// `peak_bytes_in_use`, or `None` if it was not reported.
    pub fn peak_bytes_in_use(&self) -> Option<i64> {
        self.peak_bytes_in_use_is_set
            .then_some(self.peak_bytes_in_use)
    }

    /// `num_allocs`, or `None` if it was not reported.
    pub fn num_allocs(&self) -> Option<i64> {
        self.num_allocs_is_set.then_some(self.num_allocs)
    }

    /// `largest_alloc_size`, or `None` if it was not reported.
    pub fn largest_alloc_size(&self) -> Option<i64> {
        self.largest_alloc_size_is_set
            .then_some(self.largest_alloc_size)
    }

    /// `bytes_limit`, or `None` if it was not reported.
    pub fn bytes_limit(&self) -> Option<i64> {
        self.bytes_limit_is_set.then_some(self.bytes_limit)
    }

    /// `bytes_reserved`, or `None` if it was not reported.
    pub fn bytes_reserved(&self) -> Option<i64> {
        self.bytes_reserved_is_set.then_some(self.bytes_reserved)
    }

    /// `peak_bytes_reserved`, or `None` if it was not reported.
    pub fn peak_bytes_reserved(&self) -> Option<i64> {
        self.peak_bytes_reserved_is_set
            .then_some(self.peak_bytes_reserved)
    }

    /// `bytes_reservable_limit`, or `None` if it was not reported.
    pub fn bytes_reservable_limit(&self) -> Option<i64> {
        self.bytes_reservable_limit_is_set
            .then_some(self.bytes_reservable_limit)
    }

    /// `largest_free_block_bytes`, or `None` if it was not reported.
    pub fn largest_free_block_bytes(&self) -> Option<i64> {
        self.largest_free_block_bytes_is_set
            .then_some(self.largest_free_block_bytes)
    }

    /// `pool_bytes`, or `None` if it was not reported.
    pub fn pool_bytes(&self) -> Option<i64> {
        self.pool_bytes_is_set.then_some(self.pool_bytes)
    }

    /// `peak_pool_bytes`, or `None` if it was not reported.
    pub fn peak_pool_bytes(&self) -> Option<i64> {
        self.peak_pool_bytes_is_set.then_some(self.peak_pool_bytes)
    }
}

137impl From<PJRT_Device_MemoryStats_Args> for MemoryStats {
138    fn from(args: PJRT_Device_MemoryStats_Args) -> Self {
139        Self {
140            bytes_in_use: args.bytes_in_use,
141            peak_bytes_in_use: args.peak_bytes_in_use,
142            peak_bytes_in_use_is_set: args.peak_bytes_in_use_is_set,
143            num_allocs: args.num_allocs,
144            num_allocs_is_set: args.num_allocs_is_set,
145            largest_alloc_size: args.largest_alloc_size,
146            largest_alloc_size_is_set: args.largest_alloc_size_is_set,
147            bytes_limit: args.bytes_limit,
148            bytes_limit_is_set: args.bytes_limit_is_set,
149            bytes_reserved: args.bytes_reserved,
150            bytes_reserved_is_set: args.bytes_reserved_is_set,
151            peak_bytes_reserved: args.peak_bytes_reserved,
152            peak_bytes_reserved_is_set: args.peak_bytes_reserved_is_set,
153            bytes_reservable_limit: args.bytes_reservable_limit,
154            bytes_reservable_limit_is_set: args.bytes_reservable_limit_is_set,
155            largest_free_block_bytes: args.largest_free_block_bytes,
156            largest_free_block_bytes_is_set: args.largest_free_block_bytes_is_set,
157            pool_bytes: args.pool_bytes,
158            pool_bytes_is_set: args.pool_bytes_is_set,
159            peak_pool_bytes: args.peak_pool_bytes,
160            peak_pool_bytes_is_set: args.peak_pool_bytes_is_set,
161        }
162    }
163}