use std::borrow::Cow;
use std::fmt::{self, Debug, Display};
use std::slice;
use pjrt_sys::{
PJRT_Memory, PJRT_Memory_AddressableByDevices_Args, PJRT_Memory_DebugString_Args,
PJRT_Memory_Id_Args, PJRT_Memory_Kind_Args, PJRT_Memory_Kind_Id_Args,
PJRT_Memory_ToString_Args,
};
use crate::{utils, Client, Device};
/// Wrapper around a raw `PJRT_Memory` pointer together with the [`Client`]
/// whose PJRT API table is used to query it.
pub struct Memory {
    /// Raw PJRT handle; `Memory::wrap` refuses to build a value from a null pointer.
    pub(crate) ptr: *mut PJRT_Memory,
    /// Client cloned at construction; every query on this memory goes through `client.api()`.
    client: Client,
}
impl Memory {
    /// Wraps a raw `PJRT_Memory` pointer obtained through `client`.
    ///
    /// The returned value stores its own clone of `client`; all later queries
    /// on this memory are dispatched through that client's PJRT API table.
    ///
    /// # Panics
    ///
    /// Panics if `ptr` is null.
    pub fn wrap(client: &Client, ptr: *mut PJRT_Memory) -> Memory {
        assert!(!ptr.is_null());
        Self {
            client: client.clone(),
            ptr,
        }
    }

    /// Returns the client this memory was wrapped with.
    pub fn client(&self) -> &Client {
        &self.client
    }

    /// Returns the id reported by `PJRT_Memory_Id` for this memory.
    ///
    /// # Panics
    ///
    /// Panics if the `PJRT_Memory_Id` call fails.
    pub fn id(&self) -> i32 {
        let mut args = PJRT_Memory_Id_Args::new();
        args.memory = self.ptr;
        args = self
            .client
            .api()
            .PJRT_Memory_Id(args)
            .expect("PJRT_Memory_Id");
        args.id
    }

    /// Returns the memory-kind string reported by `PJRT_Memory_Kind`,
    /// decoded with `utils::str_from_raw`. The result borrows from `self`.
    ///
    /// # Panics
    ///
    /// Panics if the `PJRT_Memory_Kind` call fails.
    pub fn kind(&self) -> Cow<'_, str> {
        let mut args = PJRT_Memory_Kind_Args::new();
        args.memory = self.ptr;
        args = self
            .client
            .api()
            .PJRT_Memory_Kind(args)
            .expect("PJRT_Memory_Kind");
        utils::str_from_raw(args.kind, args.kind_size)
    }

    /// Returns the numeric kind id reported by `PJRT_Memory_Kind_Id`.
    ///
    /// # Panics
    ///
    /// Panics if the `PJRT_Memory_Kind_Id` call fails.
    pub fn kind_id(&self) -> i32 {
        let mut args = PJRT_Memory_Kind_Id_Args::new();
        args.memory = self.ptr;
        args = self
            .client
            .api()
            .PJRT_Memory_Kind_Id(args)
            .expect("PJRT_Memory_Kind_Id");
        args.kind_id
    }

    /// Returns the string produced by `PJRT_Memory_DebugString`.
    /// This string is what the `Debug` impl prints inside `Memory(...)`.
    ///
    /// # Panics
    ///
    /// Panics if the `PJRT_Memory_DebugString` call fails.
    pub fn debug_string(&self) -> Cow<'_, str> {
        let mut args = PJRT_Memory_DebugString_Args::new();
        args.memory = self.ptr;
        args = self
            .client
            .api()
            .PJRT_Memory_DebugString(args)
            .expect("PJRT_Memory_DebugString");
        utils::str_from_raw(args.debug_string, args.debug_string_size)
    }

    /// Returns the string produced by `PJRT_Memory_ToString`.
    ///
    /// Note: this inherent method takes precedence over the blanket
    /// `ToString::to_string` that the `Display` impl provides. It returns the
    /// bare PJRT string; use `format!("{memory}")` for the `Memory(...)` form.
    ///
    /// # Panics
    ///
    /// Panics if the `PJRT_Memory_ToString` call fails.
    pub fn to_string(&self) -> Cow<'_, str> {
        let mut args = PJRT_Memory_ToString_Args::new();
        args.memory = self.ptr;
        args = self
            .client
            .api()
            .PJRT_Memory_ToString(args)
            .expect("PJRT_Memory_ToString");
        utils::str_from_raw(args.to_string, args.to_string_size)
    }

    /// Returns the devices that can address this memory, as reported by
    /// `PJRT_Memory_AddressableByDevices`. Each device is wrapped with this
    /// memory's client.
    ///
    /// # Panics
    ///
    /// Panics if the PJRT call fails, or if it reports a non-zero device count
    /// together with a null device array.
    pub fn addressable_by_devices(&self) -> Vec<Device> {
        let mut args = PJRT_Memory_AddressableByDevices_Args::new();
        args.memory = self.ptr;
        args = self
            .client
            .api()
            .PJRT_Memory_AddressableByDevices(args)
            .expect("PJRT_Memory_AddressableByDevices");
        // An empty list may come back with a null `devices` pointer, and
        // `slice::from_raw_parts` requires a non-null pointer even for a
        // zero-length slice, so handle the empty case before building a slice.
        if args.num_devices == 0 {
            return Vec::new();
        }
        assert!(
            !args.devices.is_null(),
            "PJRT_Memory_AddressableByDevices returned a null device list with {} entries",
            args.num_devices
        );
        // SAFETY: `devices` is non-null (checked above). We assume the PJRT
        // implementation keeps `num_devices` valid device pointers there at
        // least for the duration of this call; the slice does not escape it.
        let devices = unsafe { slice::from_raw_parts(args.devices, args.num_devices) };
        devices
            .iter()
            .map(|&device| Device::wrap(&self.client, device))
            .collect()
    }
}
impl Display for Memory {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Memory({})", self.to_string())
}
}
impl Debug for Memory {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Memory({})", self.debug_string())
}
}