use cudarc::driver::sys::{CUdeviceptr, CUstream};
use cxx::{Exception, UniquePtr};
// CXX bridge to the C++ engine implementation (libinfer/src/engine.h).
// Shared types below are visible to both Rust and C++; the generated
// code layout is sensitive, so keep this module minimal.
#[cxx::bridge]
mod ffi {
// Element type of a tensor binding. cxx generates this as a struct with
// associated constants plus PartialEq/Eq, so it is compared with `==`
// rather than matched structurally.
#[derive(Debug, Clone)]
enum TensorDataType {
UINT8,
FP32,
INT64,
BOOL,
}
// Name, dimensions, and element type of one engine input/output binding.
#[derive(Debug, Clone)]
struct TensorInfo {
name: String,
dims: Vec<u32>,
dtype: TensorDataType,
}
// Engine construction options; `path` is the serialized engine file.
#[derive(Debug, Clone)]
struct Options {
path: String,
}
unsafe extern "C++" {
include!("libinfer/src/engine.h");
type Engine;
// Deserialize an engine; C++ exceptions surface as `Err(Exception)`.
fn load_engine(options: &Options) -> Result<UniquePtr<Engine>>;
fn get_input_dims(self: &Engine) -> Vec<TensorInfo>;
fn get_output_dims(self: &Engine) -> Vec<TensorInfo>;
// Returns [min, opt, max] supported batch sizes — see Engine::get_batch_dims.
fn _get_batch_dims(self: &Engine) -> Vec<u32>;
fn get_output_len(self: &Engine) -> u32;
fn get_num_inputs(self: &Engine) -> usize;
fn get_num_outputs(self: &Engine) -> usize;
// Pointers are CUDA device addresses passed as u64; `stream` is a
// CUstream handle cast to u64. Caller must keep the allocations alive
// for the duration of the call (and until sync, for the async variant).
unsafe fn infer(
self: Pin<&mut Engine>,
input_ptrs: *const u64,
num_inputs: usize,
output_ptrs: *const u64,
num_outputs: usize,
stream: u64,
batch_size: u32,
) -> Result<()>;
unsafe fn infer_async(
self: Pin<&mut Engine>,
input_ptrs: *const u64,
num_inputs: usize,
output_ptrs: *const u64,
num_outputs: usize,
stream: u64,
batch_size: u32,
) -> Result<()>;
}
}
pub use ffi::{Options, TensorDataType, TensorInfo};
impl TensorDataType {
    /// Size in bytes of a single element of this data type.
    ///
    /// FP32 → 4, INT64 → 8, everything else (UINT8, BOOL) → 1.
    /// Guards are used because cxx-generated enums are structs with
    /// associated constants, not matchable variants.
    pub fn byte_size(&self) -> usize {
        match self {
            t if **t == TensorDataType::FP32 => 4,
            t if **t == TensorDataType::INT64 => 8,
            _ => 1,
        }
    }
}
impl TensorInfo {
    /// Total number of elements in one tensor of these dimensions
    /// (product of all dims; 1 for an empty dim list).
    pub fn elem_count(&self) -> usize {
        self.dims
            .iter()
            .fold(1usize, |acc, &d| acc * d as usize)
    }

    /// Total size in bytes of one tensor: element count times the
    /// per-element size of the data type.
    pub fn byte_size(&self) -> usize {
        self.dtype.byte_size() * self.elem_count()
    }
}
// SAFETY: asserts the underlying C++ Engine may be moved between threads.
// NOTE(review): this relies on the C++ implementation holding no
// thread-affine state — confirm against engine.h; not provable from here.
unsafe impl Send for ffi::Engine {}
/// Supported batch-size range of a loaded engine, as reported by the
/// C++ side (see `Engine::get_batch_dims`).
///
/// All fields are plain `u32`, so the type is additionally `Copy`,
/// `PartialEq`, `Eq`, and `Hash` — backward-compatible derives that let
/// callers compare and key on it without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BatchDims {
    /// Smallest batch size the engine accepts.
    pub min: u32,
    /// Batch size the engine was optimized for.
    pub opt: u32,
    /// Largest batch size the engine accepts.
    pub max: u32,
}
// Safe Rust handle to the C++ inference engine. Owns the C++ object via
// UniquePtr; all operations are forwarded through the `ffi` bridge.
pub struct Engine {
// Uniquely-owned C++ Engine; dropped (and the C++ destructor run) with self.
inner: UniquePtr<ffi::Engine>,
}
impl Engine {
    /// Load a serialized engine from the path given in `options`.
    ///
    /// # Errors
    /// Returns the C++ exception raised by `load_engine` (e.g. bad path
    /// or deserialization failure) as an [`Exception`].
    pub fn new(options: &Options) -> Result<Self, Exception> {
        let inner = ffi::load_engine(options)?;
        Ok(Self { inner })
    }

    /// Run blocking inference on `stream`.
    ///
    /// `inputs` / `outputs` hold one CUDA device pointer per engine
    /// input/output binding. The caller must ensure every pointer refers
    /// to a live device allocation large enough for the corresponding
    /// tensor at `batch_size`, and that the slices are in the engine's
    /// binding order — TODO(review): confirm ordering contract against
    /// engine.h.
    ///
    /// # Errors
    /// Propagates any C++ exception thrown during inference.
    pub fn infer(
        &mut self,
        inputs: &[CUdeviceptr],
        outputs: &[CUdeviceptr],
        stream: CUstream,
        batch_size: u32,
    ) -> Result<(), Exception> {
        // SAFETY: pointer/length pairs come from live slices that outlive
        // the call; validity of the device addresses themselves is the
        // caller's obligation (see doc comment).
        unsafe {
            self.inner.pin_mut().infer(
                inputs.as_ptr(),
                inputs.len(),
                outputs.as_ptr(),
                outputs.len(),
                stream as u64,
                batch_size,
            )
        }
    }

    /// Enqueue inference on `stream` without waiting for completion.
    ///
    /// Same pointer/ordering contract as [`Engine::infer`]; additionally
    /// the device allocations must stay alive until the stream is
    /// synchronized, since the kernel may still be reading/writing them
    /// after this call returns.
    ///
    /// # Errors
    /// Propagates any C++ exception thrown while enqueueing.
    pub fn infer_async(
        &mut self,
        inputs: &[CUdeviceptr],
        outputs: &[CUdeviceptr],
        stream: CUstream,
        batch_size: u32,
    ) -> Result<(), Exception> {
        // SAFETY: see `infer`; buffer lifetime until stream sync is the
        // caller's obligation.
        unsafe {
            self.inner.pin_mut().infer_async(
                inputs.as_ptr(),
                inputs.len(),
                outputs.as_ptr(),
                outputs.len(),
                stream as u64,
                batch_size,
            )
        }
    }

    /// Name, dims, and dtype of each input binding.
    pub fn get_input_dims(&self) -> Vec<TensorInfo> {
        // The FFI call already returns Vec<TensorInfo>; the previous
        // `.into_iter().collect()` round-trip was a no-op.
        self.inner.get_input_dims()
    }

    /// Name, dims, and dtype of each output binding.
    pub fn get_output_dims(&self) -> Vec<TensorInfo> {
        self.inner.get_output_dims()
    }

    /// Supported batch-size range of this engine.
    ///
    /// # Panics
    /// Panics if the C++ side returns fewer than three values, which
    /// would violate the `[min, opt, max]` contract of `_get_batch_dims`.
    pub fn get_batch_dims(&self) -> BatchDims {
        let vs = self.inner._get_batch_dims();
        assert!(
            vs.len() >= 3,
            "expected [min, opt, max] batch dims from C++, got {} value(s)",
            vs.len()
        );
        BatchDims {
            min: vs[0],
            opt: vs[1],
            max: vs[2],
        }
    }

    /// Number of input bindings.
    pub fn get_num_inputs(&self) -> usize {
        self.inner.get_num_inputs()
    }

    /// Number of output bindings.
    pub fn get_num_outputs(&self) -> usize {
        self.inner.get_num_outputs()
    }

    /// Output length as reported by the C++ engine.
    pub fn get_output_len(&self) -> u32 {
        self.inner.get_output_len()
    }
}
// SAFETY: Engine only wraps UniquePtr<ffi::Engine>; moving it between
// threads is sound given the Send assertion on ffi::Engine above.
// NOTE(review): soundness ultimately depends on the C++ Engine having no
// thread-affine state — confirm against engine.h.
unsafe impl Send for Engine {}