use pjrt_sys::{
PJRT_Buffer, PJRT_Buffer_CopyToDevice_Args, PJRT_Buffer_CopyToMemory_Args,
PJRT_Buffer_Delete_Args, PJRT_Buffer_Destroy_Args, PJRT_Buffer_Device_Args,
PJRT_Buffer_Dimensions_Args, PJRT_Buffer_DynamicDimensionIndices_Args,
PJRT_Buffer_ElementType_Args, PJRT_Buffer_GetMemoryLayout_Args, PJRT_Buffer_IsDeleted_Args,
PJRT_Buffer_IsOnCpu_Args, PJRT_Buffer_Memory_Args, PJRT_Buffer_OnDeviceSizeInBytes_Args,
PJRT_Buffer_ReadyEvent_Args, PJRT_Buffer_ToHostBuffer_Args,
PJRT_Buffer_UnpaddedDimensions_Args,
};
use crate::event::Event;
use crate::{Client, Device, HostBuffer, Memory, MemoryLayout, PrimitiveType, Result};
pub struct Buffer {
client: Client,
pub(crate) ptr: *mut PJRT_Buffer,
}
impl Drop for Buffer {
fn drop(&mut self) {
let mut args = PJRT_Buffer_Destroy_Args::new();
args.buffer = self.ptr;
self.client
.api()
.PJRT_Buffer_Destroy(args)
.expect("PJRT_Buffer_Destroy");
}
}
impl Buffer {
pub fn wrap(client: &Client, ptr: *mut PJRT_Buffer) -> Self {
assert!(!ptr.is_null());
Self {
client: client.clone(),
ptr,
}
}
pub fn client(&self) -> &Client {
&self.client
}
pub fn primitive_type(&self) -> PrimitiveType {
let mut args = PJRT_Buffer_ElementType_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_ElementType(args)
.expect("PJRT_Buffer_ElementType");
PrimitiveType::try_from(args.type_).expect("PrimitiveType")
}
pub fn dims(&self) -> Vec<i64> {
let mut args = PJRT_Buffer_Dimensions_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_Dimensions(args)
.expect("PJRT_Buffer_Dimensions");
if args.num_dims == 0 {
return vec![];
}
let s = unsafe { std::slice::from_raw_parts(args.dims, args.num_dims) };
s.to_owned()
}
pub fn unpadded_dims(&self) -> Vec<i64> {
let mut args = PJRT_Buffer_UnpaddedDimensions_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_UnpaddedDimensions(args)
.expect("PJRT_Buffer_UnpaddedDimensions");
let s = unsafe { std::slice::from_raw_parts(args.unpadded_dims, args.num_dims) };
s.to_owned()
}
pub fn dynamic_dims_indices(&self) -> Vec<usize> {
let mut args = PJRT_Buffer_DynamicDimensionIndices_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_DynamicDimensionIndices(args)
.expect("PJRT_Buffer_DynamicDimensionIndices");
let s =
unsafe { std::slice::from_raw_parts(args.dynamic_dim_indices, args.num_dynamic_dims) };
s.to_owned()
}
pub fn layout(&self) -> MemoryLayout {
let mut args = PJRT_Buffer_GetMemoryLayout_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_GetMemoryLayout(args)
.expect("PJRT_Buffer_GetMemoryLayout");
MemoryLayout::try_from(&args.layout).expect("layout")
}
pub fn on_device_size(&self) -> usize {
let mut args = PJRT_Buffer_OnDeviceSizeInBytes_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_OnDeviceSizeInBytes(args)
.expect("PJRT_Buffer_GetMemoryLayout");
args.on_device_size_in_bytes
}
pub fn is_on_cpu(&self) -> bool {
let mut args = PJRT_Buffer_IsOnCpu_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_IsOnCpu(args)
.expect("PJRT_Buffer_IsOnCpu");
args.is_on_cpu
}
pub fn device(&self) -> Device {
let mut args = PJRT_Buffer_Device_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_Device(args)
.expect("PJRT_Buffer_Device");
Device::wrap(&self.client, args.device)
}
pub fn memory(&self) -> Memory {
let mut args = PJRT_Buffer_Memory_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_Memory(args)
.expect("PJRT_Buffer_Memory");
Memory::wrap(&self.client, args.memory)
}
pub fn delete(self) {
let mut args = PJRT_Buffer_Delete_Args::new();
args.buffer = self.ptr;
self.client
.api()
.PJRT_Buffer_Delete(args)
.expect("PJRT_Buffer_Delete");
}
pub fn is_deleted(&self) -> bool {
let mut args = PJRT_Buffer_IsDeleted_Args::new();
args.buffer = self.ptr;
args = self
.client
.api()
.PJRT_Buffer_IsDeleted(args)
.expect("PJRT_Buffer_IsDeleted");
args.is_deleted
}
pub(crate) fn ready_event(&self) -> Result<Event> {
let mut args = PJRT_Buffer_ReadyEvent_Args::new();
args.buffer = self.ptr;
args = self.client.api().PJRT_Buffer_ReadyEvent(args)?;
Ok(Event::wrap(self.client.api(), args.event))
}
fn call_copy_to_device(&self, device: &Device) -> Result<PJRT_Buffer_CopyToDevice_Args> {
let mut args = PJRT_Buffer_CopyToDevice_Args::new();
args.buffer = self.ptr;
args.dst_device = device.ptr;
self.client.api().PJRT_Buffer_CopyToDevice(args)
}
pub async fn copy_to_device(&self, device: &Device) -> Result<Buffer> {
let args = self.call_copy_to_device(device)?;
let buf = Buffer::wrap(device.client(), args.dst_buffer);
let event = buf.ready_event()?;
event.await?;
Ok(buf)
}
pub fn copy_to_device_sync(&self, device: &Device) -> Result<Buffer> {
let args = self.call_copy_to_device(device)?;
let buf = Buffer::wrap(device.client(), args.dst_buffer);
let event = buf.ready_event()?;
event.wait()?;
Ok(buf)
}
fn call_copy_to_memory(&self, memory: &Memory) -> Result<PJRT_Buffer_CopyToMemory_Args> {
let mut args = PJRT_Buffer_CopyToMemory_Args::new();
args.buffer = self.ptr;
args.dst_memory = memory.ptr;
self.client.api().PJRT_Buffer_CopyToMemory(args)
}
pub async fn copy_to_memory(&self, memory: &Memory) -> Result<Buffer> {
let args = self.call_copy_to_memory(memory)?;
let buf = Buffer::wrap(memory.client(), args.dst_buffer);
let event = buf.ready_event()?;
event.await?;
Ok(buf)
}
pub fn copy_to_memory_sync(&self, memory: &Memory) -> Result<Buffer> {
let args = self.call_copy_to_memory(memory)?;
let buf = Buffer::wrap(memory.client(), args.dst_buffer);
let event = buf.ready_event()?;
event.wait()?;
Ok(buf)
}
pub fn call_copy_to_host(&self) -> Result<(PJRT_Buffer_ToHostBuffer_Args, Vec<u8>)> {
let mut args = PJRT_Buffer_ToHostBuffer_Args::new();
args.src = self.ptr;
args = self.client.api().PJRT_Buffer_ToHostBuffer(args)?;
let buf_size = args.dst_size;
let mut buf: Vec<u8> = vec![0; buf_size];
args.dst = buf.as_mut_ptr() as *mut _;
args = self.client.api().PJRT_Buffer_ToHostBuffer(args)?;
Ok((args, buf))
}
pub async fn copy_to_host(&self) -> Result<HostBuffer> {
let (args, data) = self.call_copy_to_host()?;
let event = Event::wrap(self.client.api(), args.event);
event.await?;
let ty = self.primitive_type();
let dims = self.dims();
let layout = self.layout();
HostBuffer::builder()
.bytes(data, ty)
.dims(dims)
.layout(layout)
.build()
}
pub fn copy_to_host_sync(&self) -> Result<HostBuffer> {
let (args, data) = self.call_copy_to_host()?;
let event = Event::wrap(self.client.api(), args.event);
event.wait()?;
let ty = self.primitive_type();
let dims = self.dims();
let layout = self.layout();
HostBuffer::builder()
.bytes(data, ty)
.dims(dims)
.layout(layout)
.build()
}
}