1use std::slice;
2
3use pjrt_sys::{
4 PJRT_Device, PJRT_Device_AddressableMemories_Args, PJRT_Device_DefaultMemory_Args,
5 PJRT_Device_GetDescription_Args, PJRT_Device_IsAddressable_Args,
6 PJRT_Device_LocalHardwareId_Args, PJRT_Device_MemoryStats_Args,
7};
8
9use crate::{Client, DeviceDescription, Memory, Result};
10
/// Integer identifier of a device across the whole PJRT client.
///
/// NOTE(review): presumably the id reported by `PJRT_DeviceDescription_Id`; confirm.
pub type GlobalDeviceId = i32;

/// Integer identifier of a device within the local process.
///
/// NOTE(review): presumably distinct from [`LocalHardwareId`]; confirm against the PJRT API.
pub type LocalDeviceId = i32;

/// Hardware identifier of a device, as returned by `PJRT_Device_LocalHardwareId`
/// (see `Device::local_hardware_id`).
pub type LocalHardwareId = i32;
25
26pub struct Device {
27 client: Client,
28 pub(crate) ptr: *mut PJRT_Device,
29}
30
31impl Device {
32 pub fn wrap(client: &Client, ptr: *mut PJRT_Device) -> Device {
33 assert!(!ptr.is_null());
34 Self {
35 client: client.clone(),
36 ptr,
37 }
38 }
39
40 pub fn client(&self) -> &Client {
41 &self.client
42 }
43
44 pub fn description(&self) -> DeviceDescription {
45 let mut args = PJRT_Device_GetDescription_Args::new();
46 args.device = self.ptr;
47 args = self
48 .client
49 .api()
50 .PJRT_Device_GetDescription(args)
51 .expect("PJRT_Device_GetDescription");
52 DeviceDescription::wrap(self.client.api(), args.device_description)
53 }
54
55 pub fn is_addressable(&self) -> bool {
56 let mut args = PJRT_Device_IsAddressable_Args::new();
57 args.device = self.ptr;
58 args = self
59 .client
60 .api()
61 .PJRT_Device_IsAddressable(args)
62 .expect("PJRT_Device_IsAddressable");
63 args.is_addressable
64 }
65
66 pub fn local_hardware_id(&self) -> LocalHardwareId {
67 let mut args = PJRT_Device_LocalHardwareId_Args::new();
68 args.device = self.ptr;
69 args = self
70 .client
71 .api()
72 .PJRT_Device_LocalHardwareId(args)
73 .expect("PJRT_Device_LocalHardwareId");
74 args.local_hardware_id
75 }
76
77 pub fn addressable_memories(&self) -> Vec<Memory> {
78 let mut args = PJRT_Device_AddressableMemories_Args::new();
79 args.device = self.ptr;
80 args = self
81 .client
82 .api()
83 .PJRT_Device_AddressableMemories(args)
84 .expect("PJRT_Device_AddressableMemories");
85 let memories = unsafe { slice::from_raw_parts(args.memories, args.num_memories) };
86 memories
87 .iter()
88 .cloned()
89 .map(|d| Memory::wrap(&self.client, d))
90 .collect()
91 }
92
93 pub fn default_memory(&self) -> Memory {
94 let mut args = PJRT_Device_DefaultMemory_Args::new();
95 args.device = self.ptr;
96 args = self
97 .client
98 .api()
99 .PJRT_Device_DefaultMemory(args)
100 .expect("PJRT_Device_DefaultMemory");
101 Memory::wrap(&self.client, args.memory)
102 }
103
104 pub fn memory_stats(&self) -> Result<MemoryStats> {
105 let mut args = PJRT_Device_MemoryStats_Args::new();
106 args.device = self.ptr;
107 args = self.client.api().PJRT_Device_MemoryStats(args)?;
108 Ok(MemoryStats::from(args))
109 }
110}
111
/// Memory usage statistics of a device, as returned by `Device::memory_stats`.
///
/// Each field is copied from the identically named field of
/// `PJRT_Device_MemoryStats_Args`. Every statistic except `bytes_in_use` is
/// paired with an `*_is_set` flag.
/// NOTE(review): presumably a statistic is only meaningful when its flag is
/// `true` — confirm against the PJRT C API header.
///
/// The derived `PartialOrd`/`Ord` compare fields lexicographically in
/// declaration order, so reordering fields changes the ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MemoryStats {
    /// Bytes currently in use; the only statistic without an `_is_set` flag.
    pub bytes_in_use: i64,
    /// Peak bytes in use.
    pub peak_bytes_in_use: i64,
    /// Whether `peak_bytes_in_use` was reported.
    pub peak_bytes_in_use_is_set: bool,
    /// Number of allocations.
    pub num_allocs: i64,
    /// Whether `num_allocs` was reported.
    pub num_allocs_is_set: bool,
    /// Size of the largest allocation.
    pub largest_alloc_size: i64,
    /// Whether `largest_alloc_size` was reported.
    pub largest_alloc_size_is_set: bool,
    /// Memory limit in bytes.
    pub bytes_limit: i64,
    /// Whether `bytes_limit` was reported.
    pub bytes_limit_is_set: bool,
    /// Reserved bytes.
    pub bytes_reserved: i64,
    /// Whether `bytes_reserved` was reported.
    pub bytes_reserved_is_set: bool,
    /// Peak reserved bytes.
    pub peak_bytes_reserved: i64,
    /// Whether `peak_bytes_reserved` was reported.
    pub peak_bytes_reserved_is_set: bool,
    /// Limit on reservable bytes.
    pub bytes_reservable_limit: i64,
    /// Whether `bytes_reservable_limit` was reported.
    pub bytes_reservable_limit_is_set: bool,
    /// Size of the largest free block in bytes.
    pub largest_free_block_bytes: i64,
    /// Whether `largest_free_block_bytes` was reported.
    pub largest_free_block_bytes_is_set: bool,
    /// Pool size in bytes.
    pub pool_bytes: i64,
    /// Whether `pool_bytes` was reported.
    pub pool_bytes_is_set: bool,
    /// Peak pool size in bytes.
    pub peak_pool_bytes: i64,
    /// Whether `peak_pool_bytes` was reported.
    pub peak_pool_bytes_is_set: bool,
}
136
137impl From<PJRT_Device_MemoryStats_Args> for MemoryStats {
138 fn from(args: PJRT_Device_MemoryStats_Args) -> Self {
139 Self {
140 bytes_in_use: args.bytes_in_use,
141 peak_bytes_in_use: args.peak_bytes_in_use,
142 peak_bytes_in_use_is_set: args.peak_bytes_in_use_is_set,
143 num_allocs: args.num_allocs,
144 num_allocs_is_set: args.num_allocs_is_set,
145 largest_alloc_size: args.largest_alloc_size,
146 largest_alloc_size_is_set: args.largest_alloc_size_is_set,
147 bytes_limit: args.bytes_limit,
148 bytes_limit_is_set: args.bytes_limit_is_set,
149 bytes_reserved: args.bytes_reserved,
150 bytes_reserved_is_set: args.bytes_reserved_is_set,
151 peak_bytes_reserved: args.peak_bytes_reserved,
152 peak_bytes_reserved_is_set: args.peak_bytes_reserved_is_set,
153 bytes_reservable_limit: args.bytes_reservable_limit,
154 bytes_reservable_limit_is_set: args.bytes_reservable_limit_is_set,
155 largest_free_block_bytes: args.largest_free_block_bytes,
156 largest_free_block_bytes_is_set: args.largest_free_block_bytes_is_set,
157 pool_bytes: args.pool_bytes,
158 pool_bytes_is_set: args.pool_bytes_is_set,
159 peak_pool_bytes: args.peak_pool_bytes,
160 peak_pool_bytes_is_set: args.peak_pool_bytes_is_set,
161 }
162 }
163}