use std::sync::Arc;
use tokio::sync::oneshot;
use crate::builder::IBuilderConfig;
use crate::engine::{EnginePlan, TrtEngine};
use crate::error::TrtError;
use crate::runtime::{ExecutionBindings, ExecutionContext};
/// Payload of the `source` field of [`TrtMsg::Build`], tagged by the kind of bytes it carries.
#[derive(Debug, Clone)]
pub enum NetworkSource {
/// NOTE(review): presumably the raw bytes of an ONNX model — inferred from the variant name; confirm against the build handler.
Onnx(Vec<u8>),
/// NOTE(review): presumably an already-serialized engine plan — inferred from the variant name; confirm against the build handler.
SerializedPlan(Vec<u8>),
}
/// One named weight blob, carried in the `weights` list of [`TrtMsg::Refit`].
pub struct RefitWeights {
/// Name of the weight; the unit test in this file uses `"fc.weight"`.
pub name: String,
/// Raw weight bytes.
/// NOTE(review): presumably encoded according to `dtype` — confirm against the refit handler.
pub bytes: Vec<u8>,
/// Element type tag for `bytes` (for example `DataType::kHALF` in the unit test).
pub dtype: crate::sys::DataType,
}
// One-shot reply channels, one alias per `TrtMsg` variant. Several aliases
// share the same underlying type; the distinct names only document which
// request a sender belongs to.

/// Reply sender for [`TrtMsg::Build`]: yields the built [`EnginePlan`] or a [`TrtError`].
pub type BuildReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
/// Reply sender for [`TrtMsg::Deserialize`]: yields a shared [`TrtEngine`] or a [`TrtError`].
pub type DeserializeReply = oneshot::Sender<Result<Arc<TrtEngine>, TrtError>>;
/// Reply sender for [`TrtMsg::CreateContext`]: yields an [`ExecutionContext`] or a [`TrtError`].
pub type CreateContextReply = oneshot::Sender<Result<ExecutionContext, TrtError>>;
/// Reply sender for [`TrtMsg::EnqueueOnStream`]: carries only success or a [`TrtError`].
pub type EnqueueReply = oneshot::Sender<Result<(), TrtError>>;
/// Reply sender for [`TrtMsg::Refit`]: carries only success or a [`TrtError`].
pub type RefitReply = oneshot::Sender<Result<(), TrtError>>;
/// Reply sender for [`TrtMsg::Execute`]: carries only success or a [`TrtError`].
pub type ExecuteReply = oneshot::Sender<Result<(), TrtError>>;
/// Reply sender for [`TrtMsg::BuildFromOnnx`]: same underlying type as [`BuildReply`].
pub type BuildFromOnnxReply = oneshot::Sender<Result<EnginePlan, TrtError>>;
/// Request messages for the TensorRT actor.
///
/// Every variant ends with a `reply` one-shot sender typed with the `Result`
/// the requester expects back.
///
/// NOTE(review): the per-variant descriptions below are inferred from field
/// names and types only; confirm them against the code that handles `TrtMsg`.
pub enum TrtMsg {
/// Request an [`EnginePlan`] built from `source` with the given builder `config`.
Build {
source: NetworkSource,
// Boxed, which keeps the builder config out of the enum's inline storage.
config: Box<IBuilderConfig>,
reply: BuildReply,
},
/// Request that `plan` be turned into a shared [`TrtEngine`].
Deserialize {
plan: EnginePlan,
reply: DeserializeReply,
},
/// Request a new [`ExecutionContext`] for `engine`.
CreateContext {
engine: Arc<TrtEngine>,
reply: CreateContextReply,
},
/// Request that `context` be enqueued on `stream` with the given `bindings`.
EnqueueOnStream {
stream: Arc<cudarc::driver::CudaStream>,
context: ExecutionContext,
bindings: ExecutionBindings,
reply: EnqueueReply,
},
/// Request that `engine` be refitted with the supplied `weights`.
Refit {
engine: Arc<TrtEngine>,
weights: Vec<RefitWeights>,
reply: RefitReply,
},
/// Request one inference on `engine`; the fields mirror the parameters of
/// [`TrtActor::execute`].
Execute {
engine: Arc<TrtEngine>,
// `(tensor name, address)` pairs; `TrtActor::execute` casts each `u64` to a raw pointer.
bindings: Vec<(String, u64)>,
// `(tensor name, dims)` pairs; `TrtActor::execute` rejects more than 8 dims per tensor.
input_shapes: Vec<(String, Vec<i32>)>,
stream: Arc<cudarc::driver::CudaStream>,
reply: ExecuteReply,
},
/// Request an [`EnginePlan`] built from in-memory ONNX bytes; the fields mirror
/// the parameters of [`TrtActor::build_from_onnx`].
BuildFromOnnx {
onnx_bytes: Vec<u8>,
config: Box<IBuilderConfig>,
reply: BuildFromOnnxReply,
},
}
/// Owner of the TensorRT runtime handle and entry point for the FFI-backed
/// operations implemented on this type.
pub struct TrtActor {
// Starts out as `None` (see `new`) and is populated on demand by
// `ensure_runtime`; the mutex guards that one-time initialisation.
runtime: parking_lot::Mutex<Option<crate::runtime::TrtRuntime>>,
}
impl TrtActor {
/// Creates an actor whose runtime slot is still empty.
///
/// No TensorRT call is made here; [`TrtActor::ensure_runtime`] fills the
/// slot the first time it is invoked.
pub fn new() -> Self {
    let runtime = parking_lot::Mutex::new(None);
    Self { runtime }
}
/// Makes sure the runtime slot is populated, creating the runtime on first use.
///
/// The lock is held for the whole call, so concurrent callers cannot both
/// construct a runtime.
///
/// # Errors
/// Propagates whatever `TrtRuntime::new` returns when construction fails;
/// the slot is left empty in that case.
pub fn ensure_runtime(&self) -> Result<(), TrtError> {
    let mut slot = self.runtime.lock();
    // Already initialised: nothing to do.
    if slot.is_some() {
        return Ok(());
    }
    let fresh = crate::runtime::TrtRuntime::new()?;
    *slot = Some(fresh);
    Ok(())
}
/// Runs a single inference on `engine` through a short-lived execution context.
///
/// With the `tensorrt-link` feature enabled the call
/// 1. creates an execution context from `engine.raw()`,
/// 2. applies each `(tensor name, dims)` entry of `input_shapes`,
/// 3. registers each `(tensor name, address)` entry of `bindings`,
/// 4. enqueues the work on `_stream` via `enqueueV3`, and
/// 5. destroys the context again on every exit path.
///
/// Without the feature the arguments are ignored and
/// [`TrtError::NotLinked`] is returned.
///
/// # Errors
/// * [`TrtError::Execution`] — context creation returned null, or a
///   set-shape / set-address / enqueue call returned a non-zero code.
/// * [`TrtError::InvalidArg`] — a shape has more than 8 dimensions, or a
///   tensor name contains a NUL byte.
/// * [`TrtError::NotLinked`] — built without `tensorrt-link`.
pub fn execute(
    &self,
    engine: &Arc<TrtEngine>,
    bindings: &[(String, u64)],
    input_shapes: &[(String, Vec<i32>)],
    _stream: &Arc<cudarc::driver::CudaStream>,
) -> Result<(), TrtError> {
    #[cfg(feature = "tensorrt-link")]
    {
        use std::ffi::CString;
        use std::os::raw::{c_int, c_void};

        /// Scope guard: runs the wrapped closure when it goes out of scope.
        struct OnDrop<F: FnMut()>(F);

        impl<F: FnMut()> Drop for OnDrop<F> {
            fn drop(&mut self) {
                (self.0)();
            }
        }

        // Capacity of the fixed-size `d` array inside `crate::sys::Dims`.
        const MAX_DIMS: usize = 8;

        // Borrowing the name avoids the intermediate `String` copy that
        // `CString::new(name.clone())` would allocate for every tensor.
        let c_name = |name: &str| {
            CString::new(name).map_err(|e| {
                TrtError::InvalidArg(format!("tensor name contains NUL: {e}"))
            })
        };

        // SAFETY (assumed, confirm against the C shim): `engine.raw()` is a
        // live engine handle, every pointer handed to the shim below
        // (`cname`, `dims_struct`) outlives the call that receives it, and
        // `ctx_ptr` is destroyed exactly once by the guard.
        unsafe {
            let ctx_ptr = crate::sys::atomr_trt_engine_create_execution_context(engine.raw());
            if ctx_ptr.is_null() {
                return Err(TrtError::Execution(
                    "createExecutionContext returned null".into(),
                ));
            }
            // From here on every `return` (and the normal fall-through)
            // destroys the context via this guard.
            // NOTE(review): on success the context is destroyed right after
            // `enqueueV3` returns, without an explicit stream
            // synchronisation in this function — confirm that the shim /
            // TensorRT tolerates destroying a context with work in flight.
            let _ctx_guard = OnDrop(move || {
                crate::sys::atomr_trt_context_destroy(ctx_ptr);
            });

            for (name, dims) in input_shapes {
                if dims.len() > MAX_DIMS {
                    return Err(TrtError::InvalidArg(format!(
                        "tensor {name:?}: TensorRT supports at most 8 dims (got {})",
                        dims.len()
                    )));
                }
                let cname = c_name(name.as_str())?;

                // Unused trailing slots stay zero.
                let mut d = [0i32; MAX_DIMS];
                d[..dims.len()].copy_from_slice(dims);
                let dims_struct = crate::sys::Dims {
                    // Cannot truncate: the length was bounded by MAX_DIMS above.
                    nb_dims: dims.len() as c_int,
                    d,
                };

                let rc = crate::sys::atomr_trt_context_set_input_shape(
                    ctx_ptr,
                    cname.as_ptr(),
                    &dims_struct as *const crate::sys::Dims,
                );
                if rc != 0 {
                    return Err(TrtError::Execution(format!(
                        "set_input_shape({name}) returned {rc}"
                    )));
                }
            }

            for (name, addr) in bindings {
                let cname = c_name(name.as_str())?;
                // The caller supplies the address as a plain integer.
                // NOTE(review): presumably a device pointer — confirm with callers.
                let rc = crate::sys::atomr_trt_context_set_tensor_address(
                    ctx_ptr,
                    cname.as_ptr(),
                    *addr as *mut c_void,
                );
                if rc != 0 {
                    return Err(TrtError::Execution(format!(
                        "set_tensor_address({name}) returned {rc}"
                    )));
                }
            }

            let stream_raw = _stream.cu_stream() as *mut c_void;
            let rc = crate::sys::atomr_trt_context_enqueue_v3(ctx_ptr, stream_raw);
            if rc != 0 {
                Err(TrtError::Execution(format!("enqueueV3 returned {rc}")))
            } else {
                Ok(())
            }
        }
    }
    #[cfg(not(feature = "tensorrt-link"))]
    {
        // Silences unused-parameter warnings in the unlinked build.
        let _ = (engine, bindings, input_shapes, _stream);
        Err(TrtError::NotLinked(
            "TrtActor::execute requires the `tensorrt-link` feature",
        ))
    }
}
/// Parses an in-memory ONNX model and builds a serialized engine plan from it.
///
/// Requires both the `tensorrt-link` and `tensorrt-onnx` features; without
/// them the arguments are ignored and [`TrtError::NotLinked`] is returned.
///
/// The builder flags come from `_config.effective_flags()`, and a memory
/// pool limit is applied only when `_config.workspace_bytes` is non-zero.
///
/// # Errors
/// * [`TrtError::Build`] — builder, network, or config creation returned
///   null, the serialized build returned null, or the plan came back empty.
/// * [`TrtError::Onnx`] — the parser could not be created or parsing failed.
/// * [`TrtError::NotLinked`] — built without the required features.
pub fn build_from_onnx(
    &self,
    _onnx_bytes: &[u8],
    _config: &IBuilderConfig,
) -> Result<EnginePlan, TrtError> {
    #[cfg(all(feature = "tensorrt-link", feature = "tensorrt-onnx"))]
    {
        use crate::builder::BuilderFlags;

        /// Scope guard: runs the wrapped closure when it goes out of scope.
        struct OnDrop<F: FnMut()>(F);

        impl<F: FnMut()> Drop for OnDrop<F> {
            fn drop(&mut self) {
                (self.0)();
            }
        }

        // Guards are declared in creation order (builder, parser, config,
        // host memory) and therefore run in the reverse order when the
        // block is left: host memory, config, parser, builder.
        let plan_bytes = unsafe {
            let builder = crate::sys::atomr_trt_builder_create(0);
            if builder.is_null() {
                return Err(TrtError::Build("builder_create returned null".into()));
            }
            let _builder_guard = OnDrop(move || {
                crate::sys::atomr_trt_builder_destroy(builder);
            });

            // NOTE(review): `1u32 << 0` is presumably the explicit-batch
            // network flag — confirm against the shim.
            // NOTE(review): no destroy call for `network` exists in this
            // function; presumably its lifetime is tied to the builder —
            // confirm.
            let network = crate::sys::atomr_trt_builder_create_network(builder, 1u32 << 0);
            if network.is_null() {
                return Err(TrtError::Build("create_network returned null".into()));
            }

            let parser = crate::sys::atomr_trt_onnx_parser_create(network, 0);
            if parser.is_null() {
                return Err(TrtError::Onnx("onnx_parser_create returned null".into()));
            }
            let _parser_guard = OnDrop(move || {
                crate::sys::atomr_trt_onnx_parser_destroy(parser);
            });

            let parse_rc = crate::sys::atomr_trt_onnx_parser_parse(
                parser,
                _onnx_bytes.as_ptr(),
                _onnx_bytes.len(),
                std::ptr::null(),
            );
            // A zero return from the parse call is treated as failure.
            if parse_rc == 0 {
                let nerr = crate::sys::atomr_trt_onnx_parser_num_errors(parser);
                return Err(TrtError::Onnx(format!(
                    "onnx parse failed (rc={parse_rc}, errors={nerr})"
                )));
            }

            let cfg_ptr = crate::sys::atomr_trt_builder_create_config(builder);
            if cfg_ptr.is_null() {
                return Err(TrtError::Build(
                    "builder_create_config returned null".into(),
                ));
            }
            let _config_guard = OnDrop(move || {
                crate::sys::atomr_trt_config_destroy(cfg_ptr);
            });

            // Translate each requested high-level flag into its raw counterpart.
            let requested = _config.effective_flags();
            let flag_table = [
                (BuilderFlags::FP16, crate::sys::BuilderFlag::kFP16 as u32),
                (BuilderFlags::INT8, crate::sys::BuilderFlag::kINT8 as u32),
                (BuilderFlags::TF32, crate::sys::BuilderFlag::kTF32 as u32),
                (BuilderFlags::BF16, crate::sys::BuilderFlag::kBF16 as u32),
                (BuilderFlags::FP8, crate::sys::BuilderFlag::kFP8 as u32),
                (BuilderFlags::REFIT, crate::sys::BuilderFlag::kREFIT as u32),
                (
                    BuilderFlags::SPARSE_WEIGHTS,
                    crate::sys::BuilderFlag::kSPARSE_WEIGHTS as u32,
                ),
                (
                    BuilderFlags::STRIP_PLAN,
                    crate::sys::BuilderFlag::kSTRIP_PLAN as u32,
                ),
            ];
            for (wanted, raw) in flag_table {
                if requested.contains(wanted) {
                    crate::sys::atomr_trt_config_set_flag(cfg_ptr, raw, 1);
                }
            }

            // NOTE(review): pool index `0` is presumably the workspace pool,
            // given that `workspace_bytes` is what gets passed — confirm.
            if _config.workspace_bytes > 0 {
                crate::sys::atomr_trt_config_set_memory_pool_limit(
                    cfg_ptr,
                    0,
                    _config.workspace_bytes,
                );
            }

            let host_mem =
                crate::sys::atomr_trt_builder_build_serialized(builder, network, cfg_ptr);
            if host_mem.is_null() {
                return Err(TrtError::Build(
                    "buildSerializedNetwork returned null".into(),
                ));
            }
            let _host_mem_guard = OnDrop(move || {
                crate::sys::atomr_trt_host_memory_destroy(host_mem);
            });

            // Copy the plan out before the host memory is released.
            let data_ptr = crate::sys::atomr_trt_host_memory_data(host_mem);
            let data_len = crate::sys::atomr_trt_host_memory_size(host_mem);
            let copied = if data_ptr.is_null() || data_len == 0 {
                Vec::new()
            } else {
                std::slice::from_raw_parts(data_ptr, data_len).to_vec()
            };
            copied
        };

        if plan_bytes.is_empty() {
            return Err(TrtError::Build("serialised plan was empty".into()));
        }
        Ok(EnginePlan::new(plan_bytes))
    }
    #[cfg(not(all(feature = "tensorrt-link", feature = "tensorrt-onnx")))]
    {
        Err(TrtError::NotLinked(
            "TrtActor::build_from_onnx requires the `tensorrt-link` + `tensorrt-onnx` features",
        ))
    }
}
}
impl Default for TrtActor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Precision;

    /// Compiles only when `T` is `Send`.
    fn assert_send<T: Send>() {}

    /// Builds several `TrtMsg` variants from plain values and checks at
    /// compile time that `TrtActor` is `Send`.
    #[test]
    fn trt_msg_constructs() {
        let (build_tx, _build_rx) = oneshot::channel();
        let builder_config = IBuilderConfig::new().with_precision(Precision::Fp16);
        let _build = TrtMsg::Build {
            source: NetworkSource::SerializedPlan(vec![1, 2, 3]),
            config: Box::new(builder_config),
            reply: build_tx,
        };

        let (deser_tx, _deser_rx) = oneshot::channel();
        let _deser = TrtMsg::Deserialize {
            plan: EnginePlan::new(vec![0xAA; 8]),
            reply: deser_tx,
        };

        // One test engine is shared by the remaining messages.
        let shared_engine = Arc::new(TrtEngine::for_test());

        let (ctx_tx, _ctx_rx) = oneshot::channel();
        let _ctx = TrtMsg::CreateContext {
            engine: Arc::clone(&shared_engine),
            reply: ctx_tx,
        };

        let (refit_tx, _refit_rx) = oneshot::channel();
        let fc_weight = RefitWeights {
            name: "fc.weight".into(),
            bytes: vec![0; 16],
            dtype: crate::sys::DataType::kHALF,
        };
        let _refit = TrtMsg::Refit {
            engine: Arc::clone(&shared_engine),
            weights: vec![fc_weight],
            reply: refit_tx,
        };

        assert_send::<TrtActor>();
    }

    /// Without `tensorrt-link`, initialising the runtime must report
    /// `NotLinked`; with it, the test only constructs the actor.
    #[test]
    fn actor_runtime_lazy_init() {
        let actor = TrtActor::new();
        #[cfg(not(feature = "tensorrt-link"))]
        {
            let outcome = actor.ensure_runtime();
            assert!(matches!(outcome, Err(TrtError::NotLinked(_))));
        }
        #[cfg(feature = "tensorrt-link")]
        {
            let _ = actor;
        }
    }
}