pjrt 0.2.0

Safe Rust bindings for the PJRT C API
Documentation
use pjrt_sys::{
    PJRT_Chunk, PJRT_CopyToDeviceStream, PJRT_CopyToDeviceStream_AddChunk_Args,
    PJRT_CopyToDeviceStream_CurrentBytes_Args, PJRT_CopyToDeviceStream_Destroy_Args,
    PJRT_CopyToDeviceStream_GranuleSize_Args, PJRT_CopyToDeviceStream_TotalBytes_Args,
};

use crate::{Api, Chunk, Event, Result};

/// Wrapper around a raw `PJRT_CopyToDeviceStream` handle together with the
/// [`Api`] used to issue PJRT calls against it.
pub struct CopyToDeviceStream {
    /// Raw stream handle passed as `stream` to every PJRT call on this wrapper.
    pub(crate) ptr: *mut PJRT_CopyToDeviceStream,
    /// API handle through which all `PJRT_CopyToDeviceStream_*` calls are made.
    api: Api,
}

impl Drop for CopyToDeviceStream {
    /// Destroys the underlying stream via `PJRT_CopyToDeviceStream_Destroy`.
    ///
    /// # Panics
    ///
    /// Panics if the destroy call fails, unless the thread is already
    /// unwinding from another panic (a second panic there would abort the
    /// process, so the error is dropped in that case).
    fn drop(&mut self) {
        let mut args = PJRT_CopyToDeviceStream_Destroy_Args::new();
        args.stream = self.ptr;
        if let Err(err) = self.api.PJRT_CopyToDeviceStream_Destroy(args) {
            // Panicking inside `drop` while already unwinding aborts the whole
            // process; only surface the failure on the normal drop path.
            if !std::thread::panicking() {
                panic!("PJRT_CopyToDeviceStream_Destroy: {:?}", err);
            }
        }
    }
}

impl CopyToDeviceStream {
    pub fn wrap(api: &Api, ptr: *mut PJRT_CopyToDeviceStream) -> Self {
        assert!(!ptr.is_null());
        Self {
            api: api.clone(),
            ptr,
        }
    }

    pub fn api(&self) -> &Api {
        &self.api
    }

    pub fn call_add_chunk(&self, chunk: Chunk) -> Result<PJRT_CopyToDeviceStream_AddChunk_Args> {
        let mut args = PJRT_CopyToDeviceStream_AddChunk_Args::new();
        let mut chunk: PJRT_Chunk = chunk.into();
        args.stream = self.ptr;
        args.chunk = &mut chunk as *mut _;
        self.api.PJRT_CopyToDeviceStream_AddChunk(args)
    }

    pub fn add_chunk_sync(&self, chunk: Chunk) -> Result<()> {
        let args = self.call_add_chunk(chunk)?;
        let event = Event::wrap(&self.api, args.transfer_complete);
        event.wait()?;
        Ok(())
    }

    pub async fn add_chunk(&self, chunk: Chunk) -> Result<()> {
        let args = self.call_add_chunk(chunk)?;
        let event = Event::wrap(&self.api, args.transfer_complete);
        event.await?;
        Ok(())
    }

    pub fn total_bytes(&self) -> i64 {
        let mut args = PJRT_CopyToDeviceStream_TotalBytes_Args::new();
        args.stream = self.ptr;
        args = self
            .api
            .PJRT_CopyToDeviceStream_TotalBytes(args)
            .expect("PJRT_CopyToDeviceStream_TotalBytes");
        args.total_bytes
    }

    pub fn granule_size(&self) -> i64 {
        let mut args = PJRT_CopyToDeviceStream_GranuleSize_Args::new();
        args.stream = self.ptr;
        args = self
            .api
            .PJRT_CopyToDeviceStream_GranuleSize(args)
            .expect("PJRT_CopyToDeviceStream_GranuleSize");
        args.granule_size_in_bytes
    }

    pub fn current_bytes(&self) -> i64 {
        let mut args = PJRT_CopyToDeviceStream_CurrentBytes_Args::new();
        args.stream = self.ptr;
        args = self
            .api
            .PJRT_CopyToDeviceStream_CurrentBytes(args)
            .expect("PJRT_CopyToDeviceStream_CurrentBytes");
        args.current_bytes
    }
}