#![cfg(feature = "tensorrt-plugin")]
use std::sync::Arc;
use crate::error::TrtError;
use crate::sys;
/// Facet of a plugin that can be requested through [`PluginV3::get_capability`].
///
/// NOTE(review): the variants presumably mirror TensorRT's V3 plugin
/// capability kinds (core / build / runtime) — confirm against the C side.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PluginCapability {
    /// Core (identity / metadata) capability.
    Core,
    /// Build-time capability.
    Build,
    /// Runtime (execution) capability.
    Runtime,
}
/// A single named configuration value passed to [`PluginV3::configure`].
#[derive(Clone, Debug)]
pub struct PluginField {
    /// Field name.
    pub name: String,
    /// Raw bytes of the field's value (NOTE(review): layout presumably
    /// dictated by `dtype` — confirm).
    pub data: Vec<u8>,
    /// Declared TensorRT data type of the field.
    pub dtype: sys::DataType,
}
/// A TensorRT "V3"-style plugin implemented in Rust.
///
/// Implementors are shared as `Arc<dyn PluginV3>` (see [`make`]) and handed to
/// [`register_plugin`]. `Send + Sync` is required so that handle may be used
/// from whatever thread calls back into Rust (NOTE(review): the C side's
/// threading model is not visible here — confirm).
///
/// Apart from the identity methods and `clone_boxed`, every method has a
/// no-op default.
pub trait PluginV3: Send + Sync {
    /// Plugin type name reported to the registry.
    fn name(&self) -> &str;

    /// Plugin version string reported to the registry.
    fn version(&self) -> &str;

    /// Registry namespace. Defaults to the empty string.
    fn namespace(&self) -> &str {
        ""
    }

    /// Returns an independent, owned copy of this plugin.
    ///
    /// Used to mint fresh plugin instances from the registered prototype.
    fn clone_boxed(&self) -> Box<dyn PluginV3>;

    /// Returns the object implementing capability `_cap`, if any.
    /// The default reports no capabilities.
    fn get_capability(&self, _cap: PluginCapability) -> Option<&dyn PluginV3> {
        None
    }

    /// Applies configuration fields. The default accepts and ignores them.
    ///
    /// # Errors
    /// Implementations may reject invalid fields with a [`TrtError`].
    fn configure(&mut self, _fields: &[PluginField]) -> Result<(), TrtError> {
        Ok(())
    }

    /// Computes output shapes from input shapes. The default returns no outputs.
    fn infer_shapes(&self, _input_shapes: &[Vec<i32>]) -> Vec<Vec<i32>> {
        vec![]
    }

    /// Launches the plugin's work.
    ///
    /// NOTE(review): `_inputs` / `_outputs` look like raw buffer addresses and
    /// `_stream` like a CUDA stream handle — confirm with the C side.
    /// The default does nothing and succeeds.
    ///
    /// # Errors
    /// Implementations report launch failures as a [`TrtError`].
    fn enqueue(
        &self,
        _inputs: &[u64],
        _outputs: &[u64],
        _stream: *mut std::os::raw::c_void,
    ) -> Result<(), TrtError> {
        Ok(())
    }
}
/// Wraps a concrete plugin in a shareable, type-erased handle suitable for
/// [`register_plugin`].
pub fn make<P>(plugin: P) -> Arc<dyn PluginV3>
where
    P: PluginV3 + 'static,
{
    // `Arc<P>` unsize-coerces to `Arc<dyn PluginV3>` at the return site.
    Arc::new(plugin)
}
/// Registers `plugin` with TensorRT's global plugin registry.
///
/// With the `tensorrt-link` feature, this builds a C-side plugin creator from a
/// table of `extern "C"` callbacks plus a heap-allocated copy of the `Arc`
/// handle, then hands the creator to the registry. Without that feature it is
/// a stub that always fails.
///
/// # Errors
///
/// * `TrtError::Plugin` if the creator could not be built (the FFI call
///   returned null) or the registry rejected it (non-zero return code, e.g. a
///   name/namespace collision).
/// * `TrtError::NotLinked` when the crate was built without `tensorrt-link`.
pub fn register_plugin(plugin: Arc<dyn PluginV3>) -> Result<(), TrtError> {
    #[cfg(feature = "tensorrt-link")]
    {
        // The callback table is a `static`, so the pointer handed to the C side
        // stays valid for the whole program. This holds whether the C side
        // copies the table or keeps the pointer; that contract is not visible
        // here. A stack local would dangle in the latter case.
        static VTABLE: sys::AtomrPluginVTable = sys::AtomrPluginVTable {
            get_name: plugin_get_name,
            get_version: plugin_get_version,
            get_namespace: plugin_get_namespace,
            create_plugin: plugin_create_instance,
            destroy: plugin_destroy_user,
            destroy_instance: plugin_destroy_instance,
        };

        crate::init_logger();

        // Ownership of this boxed `Arc` is handed to the creator. It is
        // reclaimed by `plugin_destroy_user` (presumably invoked by the C side
        // when the creator is destroyed — confirm).
        let user = Box::into_raw(Box::new(plugin)) as *mut std::ffi::c_void;

        // SAFETY: `VTABLE` has 'static lifetime, and `user` is a valid, uniquely
        // owned `Box<Arc<dyn PluginV3>>` pointer, which is the type every
        // callback in `VTABLE` casts it back to.
        let creator = unsafe { sys::atomr_trt_make_plugin_creator(&VTABLE, user) };
        if creator.is_null() {
            // SAFETY: a null creator is taken to mean the C side did not keep
            // `user`, so it is still ours and is reclaimed exactly once here.
            drop(unsafe { Box::from_raw(user as *mut Arc<dyn PluginV3>) });
            return Err(TrtError::Plugin(
                "atomr_trt_make_plugin_creator returned null".into(),
            ));
        }

        // SAFETY: `creator` is non-null and was just produced by
        // `atomr_trt_make_plugin_creator`.
        let rc = unsafe { sys::atomr_trt_register_plugin_creator(creator) };
        if rc != 0 {
            // NOTE(review): on this path `creator` (and the `user` box it owns)
            // is intentionally leaked. No creator-destroy entry point is
            // visible from this file, and freeing `user` while the creator may
            // still reference it would be unsound.
            return Err(TrtError::Plugin(format!(
                "registerCreator returned {rc}; plugin name/namespace may collide \
                 with an existing entry"
            )));
        }
        Ok(())
    }
    #[cfg(not(feature = "tensorrt-link"))]
    {
        let _ = plugin;
        Err(TrtError::NotLinked(
            "register_plugin requires the `tensorrt-link` feature",
        ))
    }
}
#[cfg(feature = "tensorrt-link")]
/// C callback: returns the plugin's name as an interned, NUL-terminated string.
///
/// # Safety
/// `user` must be the `Box<Arc<dyn PluginV3>>` pointer produced by
/// `register_plugin`, and that box must not have been freed yet.
unsafe extern "C" fn plugin_get_name(
    user: *const std::ffi::c_void,
) -> *const std::os::raw::c_char {
    let plugin: &Arc<dyn PluginV3> = &*user.cast::<Arc<dyn PluginV3>>();
    let name = plugin.name();
    cstr_for_str(name, &PLUGIN_NAME_CACHE)
}
#[cfg(feature = "tensorrt-link")]
/// C callback: returns the plugin's version as an interned, NUL-terminated string.
///
/// # Safety
/// `user` must be the `Box<Arc<dyn PluginV3>>` pointer produced by
/// `register_plugin`, and that box must not have been freed yet.
unsafe extern "C" fn plugin_get_version(
    user: *const std::ffi::c_void,
) -> *const std::os::raw::c_char {
    let plugin: &Arc<dyn PluginV3> = &*user.cast::<Arc<dyn PluginV3>>();
    let version = plugin.version();
    cstr_for_str(version, &PLUGIN_VERSION_CACHE)
}
#[cfg(feature = "tensorrt-link")]
/// C callback: returns the plugin's namespace as an interned, NUL-terminated string.
///
/// # Safety
/// `user` must be the `Box<Arc<dyn PluginV3>>` pointer produced by
/// `register_plugin`, and that box must not have been freed yet.
unsafe extern "C" fn plugin_get_namespace(
    user: *const std::ffi::c_void,
) -> *const std::os::raw::c_char {
    let plugin: &Arc<dyn PluginV3> = &*user.cast::<Arc<dyn PluginV3>>();
    let namespace = plugin.namespace();
    cstr_for_str(namespace, &PLUGIN_NS_CACHE)
}
#[cfg(feature = "tensorrt-link")]
/// C callback: creates a new plugin instance by cloning the registered prototype.
///
/// The result is a thin pointer to a heap-allocated `Box<dyn PluginV3>`. It must
/// eventually be released with `plugin_destroy_instance`. The requested name
/// is ignored.
///
/// # Safety
/// `user` must be the `Box<Arc<dyn PluginV3>>` pointer produced by
/// `register_plugin`, and that box must not have been freed yet.
unsafe extern "C" fn plugin_create_instance(
    user: *const std::ffi::c_void,
    _name: *const std::os::raw::c_char,
) -> *mut std::ffi::c_void {
    let prototype: &Arc<dyn PluginV3> = &*user.cast::<Arc<dyn PluginV3>>();
    // `Box<dyn PluginV3>` is a fat pointer; box it once more so C receives a
    // single thin pointer.
    let thin: *mut Box<dyn PluginV3> = Box::into_raw(Box::new(prototype.clone_boxed()));
    thin.cast::<std::ffi::c_void>()
}
#[cfg(feature = "tensorrt-link")]
/// C callback: releases the `Arc` handle that `register_plugin` boxed. Null is ignored.
///
/// # Safety
/// A non-null `user` must be the `Box<Arc<dyn PluginV3>>` pointer produced by
/// `register_plugin`, and must not be used again after this call.
unsafe extern "C" fn plugin_destroy_user(user: *mut std::ffi::c_void) {
    if user.is_null() {
        return;
    }
    let owned: Box<Arc<dyn PluginV3>> = Box::from_raw(user.cast::<Arc<dyn PluginV3>>());
    drop(owned);
}
#[cfg(feature = "tensorrt-link")]
/// C callback: frees an instance created by `plugin_create_instance`. Null is ignored.
///
/// # Safety
/// A non-null `instance` must come from `plugin_create_instance` and must not
/// be used again after this call.
unsafe extern "C" fn plugin_destroy_instance(instance: *mut std::ffi::c_void) {
    if instance.is_null() {
        return;
    }
    let owned: Box<Box<dyn PluginV3>> = Box::from_raw(instance.cast::<Box<dyn PluginV3>>());
    drop(owned);
}
/// Lazily-initialized map from a Rust string to its NUL-terminated copy,
/// used by `cstr_for_str` to hand stable C string pointers to TensorRT.
#[cfg(feature = "tensorrt-link")]
type PluginCStrCache = parking_lot::Mutex<std::collections::HashMap<String, std::ffi::CString>>;

/// Interned C strings for plugin names.
#[cfg(feature = "tensorrt-link")]
static PLUGIN_NAME_CACHE: std::sync::OnceLock<PluginCStrCache> = std::sync::OnceLock::new();

/// Interned C strings for plugin versions.
#[cfg(feature = "tensorrt-link")]
static PLUGIN_VERSION_CACHE: std::sync::OnceLock<PluginCStrCache> = std::sync::OnceLock::new();

/// Interned C strings for plugin namespaces.
#[cfg(feature = "tensorrt-link")]
static PLUGIN_NS_CACHE: std::sync::OnceLock<PluginCStrCache> = std::sync::OnceLock::new();
#[cfg(feature = "tensorrt-link")]
/// Interns `s` in `cache` and returns a pointer to its NUL-terminated copy.
///
/// The returned pointer stays valid for the rest of the program:
/// - the `CString` lives in `cache`, and nothing in this file removes entries;
/// - a `CString`'s heap buffer does not move when the map rehashes.
///
/// A string containing an interior NUL byte is interned as the empty string.
///
/// The lookup is done by `&str` first, so the common cache-hit path performs
/// no allocation. An owned key is built only on first insertion.
fn cstr_for_str(
    s: &str,
    cache: &std::sync::OnceLock<
        parking_lot::Mutex<std::collections::HashMap<String, std::ffi::CString>>,
    >,
) -> *const std::os::raw::c_char {
    let map = cache.get_or_init(|| parking_lot::Mutex::new(std::collections::HashMap::new()));
    let mut guard = map.lock();
    if let Some(existing) = guard.get(s) {
        return existing.as_ptr();
    }
    guard
        .entry(s.to_owned())
        .or_insert_with(|| std::ffi::CString::new(s).unwrap_or_default())
        .as_ptr()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal plugin used to exercise the trait's defaults and registration.
    struct StubPlugin {
        name: String,
        version: String,
    }

    impl PluginV3 for StubPlugin {
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> &str {
            &self.version
        }
        fn clone_boxed(&self) -> Box<dyn PluginV3> {
            Box::new(StubPlugin {
                name: self.name.clone(),
                version: self.version.clone(),
            })
        }
        fn get_capability(&self, _cap: PluginCapability) -> Option<&dyn PluginV3> {
            Some(self)
        }
    }

    #[test]
    fn plugin_v3_trait_object_safe() {
        let p: Arc<dyn PluginV3> = make(StubPlugin {
            name: "Stub".into(),
            version: "1".into(),
        });
        assert_eq!(p.name(), "Stub");
        assert_eq!(p.version(), "1");
        assert_eq!(p.namespace(), "");
        assert!(p.get_capability(PluginCapability::Core).is_some());
        let cloned = p.clone_boxed();
        assert_eq!(cloned.name(), "Stub");

        let r = register_plugin(p);
        if cfg!(feature = "tensorrt-link") {
            // A real registry may accept the stub or reject it (e.g. a
            // name/namespace collision); both are legitimate outcomes here.
            assert!(matches!(r, Ok(()) | Err(TrtError::Plugin(_))));
        } else {
            // Without linking, registration is a stub that must say so.
            assert!(matches!(r, Err(TrtError::NotLinked(_))));
        }

        fn assert_obj_safe<T: ?Sized + PluginV3>() {}
        assert_obj_safe::<dyn PluginV3>();
    }

    #[test]
    fn default_methods_are_noops() {
        let mut plugin: Box<dyn PluginV3> = Box::new(StubPlugin {
            name: "Stub".into(),
            version: "1".into(),
        });
        assert!(plugin.configure(&[]).is_ok());
        assert!(plugin.infer_shapes(&[vec![1, 2, 3]]).is_empty());
        assert!(plugin.enqueue(&[], &[], std::ptr::null_mut()).is_ok());
    }
}