pjrt 0.2.0

A safe PJRT C API bindings for Rust
Documentation
use std::fs;
use std::path::Path;

use pjrt_sys::PJRT_Program;

use crate::{Error, Result};

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ProgramFormat {
    MLIR,
    HLO,
}

impl ProgramFormat {
    pub fn as_str(&self) -> &'static str {
        match self {
            ProgramFormat::MLIR => "mlir",
            ProgramFormat::HLO => "hlo",
        }
    }

    pub fn as_bytes(&self) -> &'static [u8] {
        match self {
            ProgramFormat::MLIR => b"mlir",
            ProgramFormat::HLO => b"hlo",
        }
    }
}

impl TryFrom<&str> for ProgramFormat {
    type Error = Error;

    fn try_from(value: &str) -> Result<Self> {
        match value {
            "mlir" => Ok(ProgramFormat::MLIR),
            "hlo" => Ok(ProgramFormat::HLO),
            _ => Err(Error::InvalidProgramFormat(value.to_string())),
        }
    }
}

pub struct Program {
    format: ProgramFormat,
    code: Vec<u8>,
    pub(crate) prog: PJRT_Program,
}

impl Program {
    pub fn new(format: ProgramFormat, code: impl Into<Vec<u8>>) -> Self {
        let mut program = Program {
            format,
            code: code.into(),
            prog: PJRT_Program::new(),
        };
        program.prog.code = program.code.as_ptr() as *mut i8;
        program.prog.code_size = program.code.len();
        let format = program.format.as_bytes();
        program.prog.format = format.as_ptr() as *const i8;
        program.prog.format_size = format.len();
        program
    }

    pub fn format(&self) -> ProgramFormat {
        self.format
    }

    pub fn code(&self) -> &[u8] {
        &self.code
    }

    pub fn from_mlir<P: AsRef<Path>>(path: P) -> Result<Self> {
        let code = fs::read_to_string(path)?;
        Ok(Program::new(ProgramFormat::MLIR, code))
    }

    pub fn from_hlo<P: AsRef<Path>>(path: P) -> Result<Self> {
        let code = fs::read(path)?;
        Ok(Program::new(ProgramFormat::HLO, code))
    }
}