use std::borrow::Cow;
use std::slice;
use bon::bon;
use pjrt_sys::{
PJRT_SerializedTopology, PJRT_TopologyDescription, PJRT_TopologyDescription_Attributes_Args,
PJRT_TopologyDescription_Destroy_Args, PJRT_TopologyDescription_GetDeviceDescriptions_Args,
PJRT_TopologyDescription_PlatformName_Args, PJRT_TopologyDescription_PlatformVersion_Args,
PJRT_TopologyDescription_Serialize_Args,
};
use crate::{utils, Api, DeviceDescription, NamedValue, NamedValueMap, Result};
/// Handle that pairs a raw `PJRT_TopologyDescription` pointer with the `Api`
/// used to issue calls on it.
///
/// The `Drop` impl for this type passes `ptr` to
/// `PJRT_TopologyDescription_Destroy`.
pub struct TopologyDescription {
/// API handle through which every PJRT call for this topology is made.
pub(crate) api: Api,
/// Raw topology pointer; `TopologyDescription::wrap` asserts it is non-null.
pub(crate) ptr: *mut PJRT_TopologyDescription,
}
impl Drop for TopologyDescription {
    /// Destroys the wrapped topology through `PJRT_TopologyDescription_Destroy`.
    ///
    /// # Panics
    ///
    /// Panics if the destroy call reports an error, unless the current thread
    /// is already unwinding from another panic. In that case the error is
    /// dropped, because a second panic raised from a destructor during
    /// unwinding aborts the whole process.
    fn drop(&mut self) {
        let mut args = PJRT_TopologyDescription_Destroy_Args::new();
        args.topology = self.ptr;
        if let Err(err) = self.api.PJRT_TopologyDescription_Destroy(args) {
            // Keep the original failure visible in the normal case, but never
            // turn an in-flight panic into a process abort.
            if !std::thread::panicking() {
                panic!("PJRT_TopologyDescription_Destroy: {:?}", err);
            }
        }
    }
}
#[bon]
impl TopologyDescription {
    /// Wraps an existing raw topology pointer together with a clone of `api`.
    ///
    /// The returned value destroys `ptr` when dropped (see the `Drop` impl).
    ///
    /// # Panics
    ///
    /// Panics if `ptr` is null.
    pub fn wrap(api: &Api, ptr: *mut PJRT_TopologyDescription) -> TopologyDescription {
        assert!(!ptr.is_null());
        Self {
            api: api.clone(),
            ptr,
        }
    }

    /// Builder that forwards `name` and `options` to `Api::create_topology`.
    ///
    /// `api` and `name` are supplied to the builder's start function;
    /// `options` defaults to an empty list. Call `build()` to finish.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Api::create_topology` reports.
    #[builder(finish_fn = build)]
    pub fn builder(
        #[builder(start_fn)] api: &Api,
        #[builder(start_fn)] name: impl AsRef<str>,
        #[builder(default = bon::vec![], into)] options: Vec<NamedValue>,
    ) -> Result<Self> {
        api.create_topology(name, options)
    }

    /// Returns the platform name reported by `PJRT_TopologyDescription_PlatformName`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call fails.
    pub fn platform_name(&self) -> Cow<'_, str> {
        let mut args = PJRT_TopologyDescription_PlatformName_Args::new();
        args.topology = self.ptr;
        args = self
            .api
            .PJRT_TopologyDescription_PlatformName(args)
            .expect("PJRT_TopologyDescription_PlatformName");
        utils::str_from_raw(args.platform_name, args.platform_name_size)
    }

    /// Returns the platform version reported by
    /// `PJRT_TopologyDescription_PlatformVersion`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call fails.
    pub fn platform_version(&self) -> Cow<'_, str> {
        let mut args = PJRT_TopologyDescription_PlatformVersion_Args::new();
        args.topology = self.ptr;
        args = self
            .api
            .PJRT_TopologyDescription_PlatformVersion(args)
            .expect("PJRT_TopologyDescription_PlatformVersion");
        utils::str_from_raw(args.platform_version, args.platform_version_size)
    }

    /// Returns one `DeviceDescription` per entry reported by
    /// `PJRT_TopologyDescription_GetDeviceDescriptions`.
    ///
    /// An empty vector is returned when the call reports no descriptions
    /// (null pointer or zero count).
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call fails.
    pub fn device_descriptions(&self) -> Vec<DeviceDescription> {
        let mut args = PJRT_TopologyDescription_GetDeviceDescriptions_Args::new();
        args.topology = self.ptr;
        args = self
            .api
            .PJRT_TopologyDescription_GetDeviceDescriptions(args)
            .expect("PJRT_TopologyDescription_GetDeviceDescriptions");

        // `slice::from_raw_parts` requires a non-null pointer even for a
        // zero-length slice, so an empty result must not reach it.
        if args.descriptions.is_null() || args.num_descriptions == 0 {
            return Vec::new();
        }

        // SAFETY: the pointer was just checked to be non-null.
        // NOTE(review): this relies on the PJRT plugin returning an array of
        // `num_descriptions` valid entries that stays alive for the duration
        // of this call - confirm against the PJRT C API contract.
        let descriptions =
            unsafe { slice::from_raw_parts(args.descriptions, args.num_descriptions) };
        descriptions
            .iter()
            .map(|ptr| DeviceDescription::wrap(&self.api, *ptr))
            .collect()
    }

    /// Returns the topology attributes reported by
    /// `PJRT_TopologyDescription_Attributes`, converted with
    /// `utils::to_named_value_map`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call fails.
    pub fn attributes(&self) -> NamedValueMap {
        let mut args = PJRT_TopologyDescription_Attributes_Args::new();
        args.topology = self.ptr;
        args = self
            .api
            .PJRT_TopologyDescription_Attributes(args)
            .expect("PJRT_TopologyDescription_Attributes");
        utils::to_named_value_map(args.attributes, args.num_attributes)
    }

    /// Serializes the topology through `PJRT_TopologyDescription_Serialize`.
    ///
    /// The returned `SerializedTopology` keeps the handle, its deleter and the
    /// byte range reported by the call; the deleter runs when it is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call fails or if it returns no deleter.
    pub fn serialize(&self) -> SerializedTopology {
        let mut args = PJRT_TopologyDescription_Serialize_Args::new();
        args.topology = self.ptr;
        args = self
            .api
            .PJRT_TopologyDescription_Serialize(args)
            .expect("PJRT_TopologyDescription_Serialize");
        SerializedTopology {
            ptr: args.serialized_topology,
            deleter: args.serialized_topology_deleter.expect("topology_deleter"),
            data_ptr: args.serialized_bytes as *const u8,
            data_len: args.serialized_bytes_size,
        }
    }
}
/// Serialized form of a topology, as produced by `TopologyDescription::serialize`.
///
/// Holds the handle and deleter returned by the serialize call together with
/// the byte range it reported; the deleter is invoked with `ptr` on drop.
pub struct SerializedTopology {
/// Handle returned by the serialize call; passed to `deleter` on drop.
ptr: *mut PJRT_SerializedTopology,
/// Callback invoked with `ptr` when this value is dropped.
deleter: unsafe extern "C" fn(topology: *mut PJRT_SerializedTopology),
/// Start of the serialized byte buffer exposed by `bytes`.
data_ptr: *const u8,
/// Number of bytes readable from `data_ptr`.
data_len: usize,
}
impl Drop for SerializedTopology {
    /// Hands the serialized-topology handle back to its deleter callback.
    fn drop(&mut self) {
        let release = self.deleter;
        let handle = self.ptr;
        // SAFETY: `release` and `handle` are the deleter/handle pair stored at
        // construction. NOTE(review): assumes nothing else has already invoked
        // the deleter on this handle - confirm no other owner exists.
        unsafe { release(handle) };
    }
}
impl SerializedTopology {
    /// Returns the serialized topology as a byte slice borrowed from `self`.
    ///
    /// Yields an empty slice when no bytes were reported (null data pointer or
    /// zero length).
    pub fn bytes(&self) -> &[u8] {
        // `slice::from_raw_parts` requires a non-null pointer even for a
        // zero-length slice, so handle the empty case without touching it.
        if self.data_ptr.is_null() || self.data_len == 0 {
            return &[];
        }
        // SAFETY: the pointer was just checked to be non-null.
        // NOTE(review): this relies on the buffer of `data_len` bytes staying
        // valid until the deleter runs in `Drop` - confirm against the PJRT
        // C API contract for serialized topologies.
        unsafe { std::slice::from_raw_parts(self.data_ptr, self.data_len) }
    }
}