#![cfg(feature = "tensorrt-plugin")]
use std::sync::Arc;
use crate::error::TrtError;
use crate::sys;
/// Selects which capability interface to request from a plugin via
/// [`PluginV3::get_capability`].
///
/// NOTE(review): the three variants presumably mirror TensorRT's plugin
/// capability kinds (core / build / runtime) — confirm against the C++ side.
///
/// `Hash` is derived alongside `Eq` so capabilities can be used as keys in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PluginCapability {
    /// The core capability interface.
    Core,
    /// The build-time capability interface.
    Build,
    /// The runtime capability interface.
    Runtime,
}
/// A named configuration value passed to [`PluginV3::configure`].
///
/// The value is carried as raw bytes in `data`, tagged with `dtype`.
/// NOTE(review): nothing here enforces the byte layout of `data` or that its
/// length matches `dtype` — confirm the expected encoding with the consumer.
#[derive(Debug, Clone)]
pub struct PluginField {
    /// Field name.
    pub name: String,
    /// Raw payload bytes; presumably interpreted according to `dtype` — TODO confirm.
    pub data: Vec<u8>,
    /// Data-type tag for `data`.
    pub dtype: sys::DataType,
}
/// Rust-side contract for a "V3"-style TensorRT plugin.
///
/// Implementors must be `Send + Sync`, so a plugin can be shared across
/// threads behind an `Arc<dyn PluginV3>`. The trait is object-safe.
///
/// Only [`name`](PluginV3::name), [`version`](PluginV3::version) and
/// [`clone_boxed`](PluginV3::clone_boxed) are required. Every other method
/// has a default that does nothing and reports success or emptiness.
pub trait PluginV3: Send + Sync {
    /// The plugin's type name.
    fn name(&self) -> &str;

    /// The plugin's version string.
    fn version(&self) -> &str;

    /// The plugin's namespace. Defaults to the empty string.
    fn namespace(&self) -> &str {
        ""
    }

    /// Produces an independent, boxed copy of this plugin.
    fn clone_boxed(&self) -> Box<dyn PluginV3>;

    /// Returns the interface implementing `capability`, if any.
    ///
    /// The default returns `None` for every capability.
    fn get_capability(&self, _capability: PluginCapability) -> Option<&dyn PluginV3> {
        None
    }

    /// Applies configuration fields to the plugin.
    ///
    /// The default ignores the fields and returns `Ok(())`.
    fn configure(&mut self, _params: &[PluginField]) -> Result<(), TrtError> {
        Ok(())
    }

    /// Computes output shapes from the input shapes.
    ///
    /// The default returns an empty list regardless of the inputs.
    fn infer_shapes(&self, _shapes: &[Vec<i32>]) -> Vec<Vec<i32>> {
        vec![]
    }

    /// Launches the plugin's work.
    ///
    /// NOTE(review): `inputs`/`outputs` presumably hold device buffer
    /// addresses and `stream` a CUDA stream handle — confirm with callers.
    /// The default does nothing and returns `Ok(())`.
    fn enqueue(
        &self,
        _input_ptrs: &[u64],
        _output_ptrs: &[u64],
        _stream: *mut std::ffi::c_void,
    ) -> Result<(), TrtError> {
        Ok(())
    }
}
/// Wraps a concrete plugin in a shared, type-erased `Arc<dyn PluginV3>`.
///
/// `P: 'static` is required because `Arc<dyn PluginV3>` means
/// `Arc<dyn PluginV3 + 'static>`, so the plugin may not borrow local data.
pub fn make<P>(plugin: P) -> Arc<dyn PluginV3>
where
    P: PluginV3 + 'static,
{
    // `Arc<P>` unsizes to `Arc<dyn PluginV3>` at the return site; no cast needed.
    Arc::new(plugin)
}
/// Requests registration of `plugin` with TensorRT.
///
/// # Errors
///
/// Currently always fails:
/// - with the `tensorrt-link` feature: [`TrtError::Plugin`], because the C++
///   `IPluginCreator` proxy is not implemented yet;
/// - without it: [`TrtError::NotLinked`].
///
/// The plugin handle is not retained; it is released before returning.
pub fn register_plugin(plugin: Arc<dyn PluginV3>) -> Result<(), TrtError> {
    // Nothing consumes the plugin yet, so let go of this handle up front.
    drop(plugin);

    #[cfg(feature = "tensorrt-link")]
    {
        Err(TrtError::Plugin(
            "C++ IPluginCreator proxy not yet implemented in this Phase 8 skeleton; \
             stub returns Plugin error. Link-time registration arrives in a follow-up commit."
                .into(),
        ))
    }
    #[cfg(not(feature = "tensorrt-link"))]
    {
        Err(TrtError::NotLinked(
            "register_plugin requires the `tensorrt-link` feature",
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal plugin used to exercise the trait surface and its defaults.
    #[derive(Clone)]
    struct StubPlugin {
        name: String,
        version: String,
    }

    impl PluginV3 for StubPlugin {
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> &str {
            &self.version
        }
        fn clone_boxed(&self) -> Box<dyn PluginV3> {
            Box::new(self.clone())
        }
        fn get_capability(&self, _cap: PluginCapability) -> Option<&dyn PluginV3> {
            Some(self)
        }
    }

    /// Builds a fresh shared stub plugin.
    fn stub() -> Arc<dyn PluginV3> {
        make(StubPlugin {
            name: "Stub".into(),
            version: "1".into(),
        })
    }

    #[test]
    fn plugin_v3_trait_object_safe() {
        // Compiles only if `dyn PluginV3` is a valid trait object type.
        fn assert_obj_safe<T: ?Sized + PluginV3>() {}
        assert_obj_safe::<dyn PluginV3>();

        let p = stub();
        assert_eq!(p.name(), "Stub");
        assert_eq!(p.version(), "1");
        assert_eq!(p.namespace(), "");
        assert!(p.get_capability(PluginCapability::Core).is_some());

        let cloned = p.clone_boxed();
        assert_eq!(cloned.name(), "Stub");
        assert_eq!(cloned.version(), "1");
    }

    #[test]
    fn default_methods_are_noops() {
        let p = stub();
        let mut owned = p.clone_boxed();
        assert!(owned.configure(&[]).is_ok());
        assert!(p.infer_shapes(&[vec![1, 3, 224, 224]]).is_empty());
        assert!(p.enqueue(&[], &[], std::ptr::null_mut()).is_ok());
    }

    #[test]
    fn register_plugin_reports_feature_specific_error() {
        let r = register_plugin(stub());
        // Pin the exact variant for the active feature set instead of
        // accepting either one, so a wrong-variant regression is caught.
        #[cfg(feature = "tensorrt-link")]
        assert!(matches!(r, Err(TrtError::Plugin(_))));
        #[cfg(not(feature = "tensorrt-link"))]
        assert!(matches!(r, Err(TrtError::NotLinked(_))));
    }
}