pjrt 0.2.0

Safe PJRT C API bindings for Rust
Documentation
use std::borrow::Cow;
use std::rc::Rc;
use std::slice;

use bon::bon;
use pjrt_sys::{
    PJRT_Client, PJRT_Client_AddressableDevices_Args, PJRT_Client_AddressableMemories_Args,
    PJRT_Client_Compile_Args, PJRT_Client_DefaultDeviceAssignment_Args, PJRT_Client_Destroy_Args,
    PJRT_Client_Devices_Args, PJRT_Client_LookupAddressableDevice_Args,
    PJRT_Client_LookupDevice_Args, PJRT_Client_PlatformName_Args, PJRT_Client_PlatformVersion_Args,
    PJRT_Client_ProcessIndex_Args, PJRT_Client_TopologyDescription_Args,
    PJRT_Executable_DeserializeAndLoad_Args, PJRT_Program,
};

use crate::{
    utils, Api, CompileOptions, CompileToLoadedExecutable, Device, DeviceAssignment,
    GlobalDeviceId, KeyValueStore, LoadedExecutable, LocalHardwareId, Memory, NamedValue, Program,
    Result, TopologyDescription,
};

/// Pairs a raw `PJRT_Client` pointer with the [`Api`] handle used to operate on it.
///
/// The `Drop` implementation for this type calls `PJRT_Client_Destroy` on `ptr`
/// through `api`.
struct ClientRaw {
    /// API handle through which the client is destroyed.
    api: Api,
    /// Raw pointer to the underlying PJRT client.
    ptr: *mut PJRT_Client,
}

impl Drop for ClientRaw {
    /// Destroys the underlying PJRT client via `PJRT_Client_Destroy`.
    ///
    /// # Panics
    ///
    /// Panics if the destroy call reports an error, unless the current thread is
    /// already unwinding from another panic. In that case the error is discarded,
    /// because a second panic during unwinding would abort the whole process.
    fn drop(&mut self) {
        let mut args = PJRT_Client_Destroy_Args::new();
        args.client = self.ptr;
        let outcome = self.api.PJRT_Client_Destroy(args);
        if std::thread::panicking() {
            // Already unwinding: swallow the error rather than trigger a double panic.
            return;
        }
        outcome.expect("PJRT_Client_Destroy");
    }
}

/// A reference-counted handle to a PJRT client.
///
/// Cloning a `Client` only clones the inner `Rc`, so all clones refer to the same
/// `ClientRaw`; the raw client is dropped when the last clone goes away.
#[derive(Clone)]
pub struct Client {
    /// Shared owner of the raw client pointer and its `Api`.
    raw: Rc<ClientRaw>,
}

#[bon]
impl Client {
    /// Builds a [`Client`] around an existing raw `PJRT_Client` pointer.
    ///
    /// The `api` handle is cloned and stored next to the pointer; the pointer is
    /// destroyed through it once the last clone of the returned client is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `ptr` is null.
    pub fn wrap(api: &Api, ptr: *mut PJRT_Client) -> Self {
        assert!(!ptr.is_null());
        let raw = ClientRaw {
            api: api.clone(),
            ptr,
        };
        Self { raw: Rc::new(raw) }
    }

    /// Starts a builder for a new [`Client`] created through `api`.
    ///
    /// `options` defaults to an empty list when not set, and `kv_store` is optional.
    /// The finishing call `build()` forwards all three values to `Api::create_client`
    /// and returns its result.
    // NOTE(review): `&Box<dyn KeyValueStore>` is presumably kept because that is the
    // type `Api::create_client` accepts; confirm before changing the lint allowance.
    #[allow(clippy::borrowed_box)]
    #[builder(finish_fn = build)]
    pub fn builder(
        #[builder(start_fn)] api: &Api,
        #[builder(default = bon::vec![], into)] options: Vec<NamedValue>,
        #[builder] kv_store: Option<&Box<dyn KeyValueStore>>,
    ) -> Result<Self> {
        api.create_client(options, kv_store)
    }

    pub fn api(&self) -> &Api {
        &self.raw.api
    }

    pub(crate) fn ptr(&self) -> *mut PJRT_Client {
        self.raw.ptr
    }

    /// Returns the platform name reported by `PJRT_Client_PlatformName`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call returns an error.
    pub fn platform_name(&self) -> Cow<'_, str> {
        let request = {
            let mut req = PJRT_Client_PlatformName_Args::new();
            req.client = self.ptr();
            req
        };
        let reply = self
            .api()
            .PJRT_Client_PlatformName(request)
            .expect("PJRT_Client_PlatformName");
        utils::str_from_raw(reply.platform_name, reply.platform_name_size)
    }

    /// Returns the platform version reported by `PJRT_Client_PlatformVersion`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call returns an error.
    pub fn platform_version(&self) -> Cow<'_, str> {
        let request = {
            let mut req = PJRT_Client_PlatformVersion_Args::new();
            req.client = self.ptr();
            req
        };
        let reply = self
            .api()
            .PJRT_Client_PlatformVersion(request)
            .expect("PJRT_Client_PlatformVersion");
        utils::str_from_raw(reply.platform_version, reply.platform_version_size)
    }

    /// Returns the process index reported by `PJRT_Client_ProcessIndex`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call returns an error.
    pub fn process_index(&self) -> i32 {
        let request = {
            let mut req = PJRT_Client_ProcessIndex_Args::new();
            req.client = self.ptr();
            req
        };
        let reply = self
            .api()
            .PJRT_Client_ProcessIndex(request)
            .expect("PJRT_Client_ProcessIndex");
        reply.process_index
    }

    /// Returns the devices reported by `PJRT_Client_Devices`, each wrapped as a [`Device`].
    ///
    /// An empty vector is returned when the call yields a null pointer or a zero count.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call returns an error.
    pub fn devices(&self) -> Vec<Device> {
        let mut args = PJRT_Client_Devices_Args::new();
        args.client = self.ptr();
        args = self
            .api()
            .PJRT_Client_Devices(args)
            .expect("PJRT_Client_Devices");
        // `slice::from_raw_parts` requires a non-null pointer even for a zero-length
        // slice, so an empty result must be handled before building the slice.
        if args.devices.is_null() || args.num_devices == 0 {
            return Vec::new();
        }
        // SAFETY: the pointer was just checked to be non-null. This assumes the plugin
        // returned `num_devices` valid, properly aligned entries that stay alive for
        // the duration of this call (TODO confirm against the PJRT C API contract).
        let raw_devices = unsafe { slice::from_raw_parts(args.devices, args.num_devices) };
        raw_devices
            .iter()
            .copied()
            .map(|d| Device::wrap(self, d))
            .collect()
    }

    /// Returns the devices reported by `PJRT_Client_AddressableDevices`, each wrapped
    /// as a [`Device`].
    ///
    /// An empty vector is returned when the call yields a null pointer or a zero count.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call returns an error.
    pub fn addressable_devices(&self) -> Vec<Device> {
        let mut args = PJRT_Client_AddressableDevices_Args::new();
        args.client = self.ptr();
        args = self
            .api()
            .PJRT_Client_AddressableDevices(args)
            .expect("PJRT_Client_AddressableDevices");
        // `slice::from_raw_parts` requires a non-null pointer even for a zero-length
        // slice, so an empty result must be handled before building the slice.
        if args.addressable_devices.is_null() || args.num_addressable_devices == 0 {
            return Vec::new();
        }
        // SAFETY: the pointer was just checked to be non-null. This assumes the plugin
        // returned `num_addressable_devices` valid, properly aligned entries that stay
        // alive for the duration of this call (TODO confirm against the PJRT C API contract).
        let devices = unsafe {
            slice::from_raw_parts(args.addressable_devices, args.num_addressable_devices)
        };
        devices
            .iter()
            .copied()
            .map(|d| Device::wrap(self, d))
            .collect()
    }

    /// Returns the memories reported by `PJRT_Client_AddressableMemories`, each wrapped
    /// as a [`Memory`].
    ///
    /// An empty vector is returned when the call yields a null pointer or a zero count.
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call returns an error.
    pub fn addressable_memories(&self) -> Vec<Memory> {
        let mut args = PJRT_Client_AddressableMemories_Args::new();
        args.client = self.ptr();
        args = self
            .api()
            .PJRT_Client_AddressableMemories(args)
            .expect("PJRT_Client_AddressableMemories");
        // `slice::from_raw_parts` requires a non-null pointer even for a zero-length
        // slice, so an empty result must be handled before building the slice.
        if args.addressable_memories.is_null() || args.num_addressable_memories == 0 {
            return Vec::new();
        }
        // SAFETY: the pointer was just checked to be non-null. This assumes the plugin
        // returned `num_addressable_memories` valid, properly aligned entries that stay
        // alive for the duration of this call (TODO confirm against the PJRT C API contract).
        let memories = unsafe {
            slice::from_raw_parts(args.addressable_memories, args.num_addressable_memories)
        };
        memories
            .iter()
            .copied()
            .map(|m| Memory::wrap(self, m))
            .collect()
    }

    /// Looks up a device by its global device id via `PJRT_Client_LookupDevice`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying PJRT call.
    pub fn lookup_device(&self, global_device_id: GlobalDeviceId) -> Result<Device> {
        let request = {
            let mut req = PJRT_Client_LookupDevice_Args::new();
            req.client = self.ptr();
            req.id = global_device_id;
            req
        };
        let reply = self.api().PJRT_Client_LookupDevice(request)?;
        Ok(Device::wrap(self, reply.device))
    }

    /// Looks up an addressable device by its local hardware id via
    /// `PJRT_Client_LookupAddressableDevice`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying PJRT call.
    pub fn lookup_addressable_device(&self, local_hardware_id: LocalHardwareId) -> Result<Device> {
        let request = {
            let mut req = PJRT_Client_LookupAddressableDevice_Args::new();
            req.client = self.ptr();
            req.local_hardware_id = local_hardware_id;
            req
        };
        let reply = self.api().PJRT_Client_LookupAddressableDevice(request)?;
        Ok(Device::wrap(self, reply.addressable_device))
    }

    /// Compiles `program` with `options` by delegating to this client's
    /// [`CompileToLoadedExecutable<T>`] implementation.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected trait implementation produces.
    pub fn compile<T>(&self, program: &T, options: CompileOptions) -> Result<LoadedExecutable>
    where
        Self: CompileToLoadedExecutable<T>,
    {
        let opts = &options;
        <Self as CompileToLoadedExecutable<T>>::compile(self, program, opts)
    }

    /// Deserializes and loads an executable from `bytes` via
    /// `PJRT_Executable_DeserializeAndLoad`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying PJRT call.
    pub fn load_executable(&self, bytes: &[u8]) -> Result<LoadedExecutable> {
        let mut args = PJRT_Executable_DeserializeAndLoad_Args::new();
        args.client = self.ptr();
        // `cast()` infers the field's pointee type, so this compiles whether the
        // platform's `c_char` is `i8` or `u8` (a hard-coded `*const i8` does not).
        args.serialized_executable = bytes.as_ptr().cast();
        args.serialized_executable_size = bytes.len();
        args = self.api().PJRT_Executable_DeserializeAndLoad(args)?;
        Ok(LoadedExecutable::wrap(self, args.loaded_executable))
    }

    /// Queries `PJRT_Client_DefaultDeviceAssignment` for the given replica and
    /// partition counts and packages the result as a [`DeviceAssignment`].
    ///
    /// A buffer of `num_replicas * num_partitions` zeroed entries is passed to the
    /// call as the output array.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying PJRT call.
    pub fn default_device_assignment(
        &self,
        num_replicas: usize,
        num_partitions: usize,
    ) -> Result<DeviceAssignment> {
        let slot_count = num_replicas * num_partitions;
        let mut device_ids = vec![0; slot_count];

        let mut request = PJRT_Client_DefaultDeviceAssignment_Args::new();
        request.client = self.ptr();
        // NOTE(review): these `as i32` casts silently truncate counts above `i32::MAX`.
        request.num_replicas = num_replicas as i32;
        request.num_partitions = num_partitions as i32;
        request.default_assignment = device_ids.as_mut_ptr();
        request.default_assignment_size = device_ids.len();

        // The returned args are not needed; the results were written into `device_ids`.
        let _ = self.api().PJRT_Client_DefaultDeviceAssignment(request)?;
        Ok(DeviceAssignment::new(
            num_replicas,
            num_partitions,
            device_ids,
        ))
    }

    /// Returns the topology reported by `PJRT_Client_TopologyDescription`, wrapped as
    /// a [`TopologyDescription`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying PJRT call returns an error.
    pub fn topology(&self) -> TopologyDescription {
        let request = {
            let mut req = PJRT_Client_TopologyDescription_Args::new();
            req.client = self.ptr();
            req
        };
        let reply = self
            .api()
            .PJRT_Client_TopologyDescription(request)
            .expect("PJRT_Client_TopologyDescription");
        TopologyDescription::wrap(self.api(), reply.topology)
    }

    // TODO:
    // PJRT_Client_CreateViewOfDeviceBuffer
}

impl CompileToLoadedExecutable<Program> for Client {
    /// Compiles `program` on this client via `PJRT_Client_Compile`, passing the
    /// encoded form of `options`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying PJRT call.
    fn compile(&self, program: &Program, options: &CompileOptions) -> Result<LoadedExecutable> {
        // Must stay alive until the call returns: `args` only holds a pointer into it.
        let options_encoded = options.encode();
        let mut args = PJRT_Client_Compile_Args::new();
        args.client = self.ptr();
        args.program = &program.prog as *const PJRT_Program;
        // `cast()` infers the field's pointee type, so this compiles whether the
        // platform's `c_char` is `i8` or `u8` (a hard-coded `*const i8` does not).
        args.compile_options = options_encoded.as_ptr().cast();
        args.compile_options_size = options_encoded.len();
        args = self.api().PJRT_Client_Compile(args)?;
        Ok(LoadedExecutable::wrap(self, args.executable))
    }
}