use std::sync::Arc;
use crate::error::TrtError;
use crate::sys;
/// Owning handle to a TensorRT `ICudaEngine` obtained through the `sys` FFI layer.
///
/// When the `tensorrt-link` feature is enabled, dropping this value destroys
/// the engine via `sys::atomr_trt_engine_destroy`.
#[derive(Debug)]
pub struct TrtEngine {
    /// Engine pointer. `from_raw` rejects null, so this is null only for
    /// handles built by `for_test`.
    raw: *mut sys::ICudaEngine,
    /// Number of I/O tensors, as supplied by the caller of `from_raw`.
    num_io: usize,
}
// SAFETY: within this file the handle is only read through `&self` (copying the
// pointer, reading `num_io`, and calling `atomr_trt_engine_io_tensor_name`) and is
// destroyed exactly once in `Drop`. NOTE(review): this relies on the engine and
// that shim call being safe to use from several threads at once — confirm against
// TensorRT's thread-safety guarantees for `ICudaEngine`.
unsafe impl Send for TrtEngine {}
unsafe impl Sync for TrtEngine {}
impl TrtEngine {
    /// Takes ownership of a raw engine pointer.
    ///
    /// # Errors
    ///
    /// Returns [`TrtError::NullEngine`] if `raw` is null.
    ///
    /// # Safety
    ///
    /// - `raw` must point to a live engine that nothing else will destroy. With
    ///   the `tensorrt-link` feature enabled, the returned handle passes it to
    ///   `sys::atomr_trt_engine_destroy` when dropped.
    /// - `num_io` must not exceed the engine's real I/O tensor count.
    ///   [`io_tensor_name`](Self::io_tensor_name) uses it as its only bounds
    ///   check before calling into the shim.
    pub unsafe fn from_raw(raw: *mut sys::ICudaEngine, num_io: usize) -> Result<Self, TrtError> {
        if raw.is_null() {
            return Err(TrtError::NullEngine);
        }
        Ok(Self { raw, num_io })
    }

    /// Builds a handle with a null engine and zero I/O tensors, for tests that
    /// must not touch TensorRT.
    // Presumably only used from test code; the allowance keeps non-test builds
    // free of dead-code warnings.
    #[allow(dead_code)]
    pub(crate) fn for_test() -> Self {
        Self {
            raw: std::ptr::null_mut(),
            num_io: 0,
        }
    }

    /// Returns the underlying engine pointer without giving up ownership.
    ///
    /// The caller must not destroy it. It is only valid while `self` is alive,
    /// because `Drop` destroys it when `tensorrt-link` is enabled.
    pub fn raw(&self) -> *mut sys::ICudaEngine {
        self.raw
    }

    /// Number of I/O tensors supplied to [`from_raw`](Self::from_raw).
    pub fn num_io_tensors(&self) -> usize {
        self.num_io
    }

    /// Name of the I/O tensor at `idx`.
    ///
    /// Returns `None` in any of these cases:
    /// - `idx` is out of range, or does not fit the shim's `i32` index;
    /// - the handle is null;
    /// - the shim returns a null pointer;
    /// - the name is not valid UTF-8;
    /// - the crate is built without `tensorrt-link`.
    pub fn io_tensor_name(&self, idx: usize) -> Option<String> {
        #[cfg(feature = "tensorrt-link")]
        {
            if idx >= self.num_io || self.raw.is_null() {
                return None;
            }
            // The shim takes an `i32` index. Reject values that would truncate
            // instead of silently wrapping to a different tensor.
            let c_idx = <i32 as std::convert::TryFrom<usize>>::try_from(idx).ok()?;
            // SAFETY: `raw` is non-null and owned by `self` (per the `from_raw`
            // contract), and `c_idx` is below the caller-supplied I/O count.
            let p = unsafe { sys::atomr_trt_engine_io_tensor_name(self.raw, c_idx) };
            if p.is_null() {
                return None;
            }
            // SAFETY: `p` is non-null. It is assumed to be a NUL-terminated string
            // that stays valid until copied below — TODO confirm against the shim.
            let cstr = unsafe { std::ffi::CStr::from_ptr(p) };
            cstr.to_str().ok().map(str::to_owned)
        }
        #[cfg(not(feature = "tensorrt-link"))]
        {
            // No TensorRT linked, so no names are available.
            let _ = idx;
            None
        }
    }

    /// Moves the engine into an [`Arc`] so it can be shared across threads.
    pub fn into_shared(self) -> Arc<TrtEngine> {
        Arc::new(self)
    }
}
impl Drop for TrtEngine {
    /// Destroys the owned engine through the C shim when `tensorrt-link` is
    /// enabled; otherwise does nothing.
    ///
    /// Null handles (e.g. from `for_test`) are skipped.
    fn drop(&mut self) {
        #[cfg(feature = "tensorrt-link")]
        {
            let engine = self.raw;
            if engine.is_null() {
                return;
            }
            // SAFETY: `engine` is non-null and was handed over by `from_raw` with
            // ownership. `drop` runs at most once, so this is the only destroy call.
            unsafe {
                sys::atomr_trt_engine_destroy(engine);
            }
        }
    }
}
/// Opaque bytes of an engine plan.
///
/// Presumably these are TensorRT's serialized engine format; the contents are
/// not validated here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnginePlan(pub Vec<u8>);
impl EnginePlan {
    /// Wraps an existing byte buffer without copying it.
    pub fn new(bytes: Vec<u8>) -> Self {
        Self(bytes)
    }
    /// Borrows the plan bytes.
    pub fn as_slice(&self) -> &[u8] {
        &self.0
    }
}
impl From<Vec<u8>> for EnginePlan {
    /// Equivalent to [`EnginePlan::new`].
    fn from(bytes: Vec<u8>) -> Self {
        Self::new(bytes)
    }
}
impl AsRef<[u8]> for EnginePlan {
    /// Equivalent to [`EnginePlan::as_slice`].
    fn as_ref(&self) -> &[u8] {
        self.as_slice()
    }
}
/// Owning handle to a TensorRT `IRefitter` obtained through the `sys` FFI layer.
///
/// When the `tensorrt-link` feature is enabled, dropping this value destroys
/// the refitter via `sys::atomr_trt_refitter_destroy`.
#[derive(Debug)]
pub struct TrtRefitter {
    /// Refitter pointer. `from_raw` rejects null, so this is null only for
    /// handles built by `for_test`.
    raw: *mut sys::IRefitter,
}
// SAFETY: within this file the only access is copying the pointer out through
// `raw()` and destroying it once in `Drop`. Any use of the refitter requires
// `unsafe` FFI calls by the holder of that pointer. NOTE(review): TensorRT
// refitters are not documented as safe for concurrent use — confirm that
// callers serialise refit calls.
unsafe impl Send for TrtRefitter {}
unsafe impl Sync for TrtRefitter {}
impl TrtRefitter {
    /// Takes ownership of a raw refitter pointer.
    ///
    /// # Errors
    ///
    /// Returns [`TrtError::Refit`] if `raw` is null.
    ///
    /// # Safety
    ///
    /// `raw` must point to a live refitter that nothing else will destroy. With
    /// the `tensorrt-link` feature enabled, the returned handle passes it to
    /// `sys::atomr_trt_refitter_destroy` when dropped.
    ///
    /// NOTE(review): this wrapper does not keep the engine being refitted alive.
    /// Confirm whether callers must guarantee that engine outlives the refitter.
    pub unsafe fn from_raw(raw: *mut sys::IRefitter) -> Result<Self, TrtError> {
        if raw.is_null() {
            return Err(TrtError::Refit("null refitter".into()));
        }
        Ok(Self { raw })
    }

    /// Builds a handle with a null refitter, for tests that must not touch
    /// TensorRT.
    // Presumably only used from test code; the allowance keeps non-test builds
    // free of dead-code warnings.
    #[allow(dead_code)]
    pub(crate) fn for_test() -> Self {
        Self {
            raw: std::ptr::null_mut(),
        }
    }

    /// Returns the underlying refitter pointer without giving up ownership.
    ///
    /// The caller must not destroy it. It is only valid while `self` is alive,
    /// because `Drop` destroys it when `tensorrt-link` is enabled.
    pub fn raw(&self) -> *mut sys::IRefitter {
        self.raw
    }
}
impl Drop for TrtRefitter {
    /// Destroys the owned refitter through the C shim when `tensorrt-link` is
    /// enabled; otherwise does nothing.
    ///
    /// Null handles (e.g. from `for_test`) are skipped.
    fn drop(&mut self) {
        #[cfg(feature = "tensorrt-link")]
        {
            let refitter = self.raw;
            if refitter.is_null() {
                return;
            }
            // SAFETY: `refitter` is non-null and was handed over by `from_raw` with
            // ownership. `drop` runs at most once, so this is the only destroy call.
            unsafe {
                sys::atomr_trt_refitter_destroy(refitter);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `T` can be sent to and shared between threads.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn engine_handle_send_sync() {
        assert_send_sync::<TrtEngine>();
        assert_send_sync::<Arc<TrtEngine>>();
        assert_send_sync::<TrtRefitter>();
        let e = TrtEngine::for_test();
        assert_eq!(e.num_io_tensors(), 0);
        let shared: Arc<TrtEngine> = e.into_shared();
        // A freshly wrapped engine has exactly one owner.
        assert_eq!(Arc::strong_count(&shared), 1);
    }

    #[test]
    fn test_engine_has_no_io_tensor_names() {
        let e = TrtEngine::for_test();
        assert!(e.raw().is_null());
        // `for_test` reports zero I/O tensors, so every index is out of range and
        // no FFI call is made, even with `tensorrt-link` enabled.
        assert_eq!(e.io_tensor_name(0), None);
        assert_eq!(e.io_tensor_name(usize::MAX), None);
    }

    #[test]
    fn null_handles_drop_cleanly() {
        // Null handles must be skipped by `Drop` rather than passed to the
        // destroy shims.
        let r = TrtRefitter::for_test();
        assert!(r.raw().is_null());
        drop(r);
        drop(TrtEngine::for_test());
    }

    #[test]
    fn engine_plan_round_trip() {
        let plan = EnginePlan::new(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(plan.as_slice(), &[0xDE, 0xAD, 0xBE, 0xEF]);
    }
}