use std::sync::Arc;
use tokio::sync::oneshot;
use crate::builder::IBuilderConfig;
use crate::engine::{EnginePlan, TrtEngine};
use crate::error::TrtError;
use crate::runtime::{ExecutionBindings, ExecutionContext};
/// Input from which a TensorRT engine is produced.
///
/// Both variants own their byte buffer. Nothing here validates the contents; that is
/// left to whoever consumes the message (see `TrtMsg::Build`).
#[derive(Clone, Debug)]
pub enum NetworkSource {
    /// Bytes intended to be parsed as an ONNX model.
    Onnx(Vec<u8>),
    /// Bytes of an already-serialized engine plan.
    SerializedPlan(Vec<u8>),
}
/// One named chunk of weight data carried by a [`TrtMsg::Refit`] request.
pub struct RefitWeights {
/// Name identifying which weight to replace (this file's tests use `"fc.weight"`).
pub name: String,
/// Raw weight bytes. NOTE(review): presumably packed elements of type `dtype` — confirm expected layout/length.
pub bytes: Vec<u8>,
/// Data type tag for the contents of `bytes` (this file's tests use `kHALF`).
pub dtype: crate::sys::DataType,
}
/// Shared shape of every reply channel: a one-shot sender of the operation's outcome.
type ReplySender<T> = oneshot::Sender<Result<T, TrtError>>;

/// Reply channel for [`TrtMsg::Build`]; carries the built [`EnginePlan`] or an error.
pub type BuildReply = ReplySender<EnginePlan>;
/// Reply channel for [`TrtMsg::Deserialize`]; carries a shared [`TrtEngine`] or an error.
pub type DeserializeReply = ReplySender<Arc<TrtEngine>>;
/// Reply channel for [`TrtMsg::CreateContext`]; carries an [`ExecutionContext`] or an error.
pub type CreateContextReply = ReplySender<ExecutionContext>;
/// Reply channel for [`TrtMsg::EnqueueOnStream`]; carries `()` on success or an error.
pub type EnqueueReply = ReplySender<()>;
/// Reply channel for [`TrtMsg::Refit`]; carries `()` on success or an error.
pub type RefitReply = ReplySender<()>;
/// Requests understood by the TensorRT actor.
///
/// Every variant carries a `reply` field, a `tokio::sync::oneshot` sender of `Result<_, TrtError>`,
/// through which the outcome is meant to be returned.
/// NOTE(review): the code that receives and dispatches these messages is not in this view — confirm.
pub enum TrtMsg {
/// Build an engine plan from `source` using `config`; `reply` yields an `EnginePlan` on success.
Build {
/// Network definition to build from.
source: NetworkSource,
/// Builder configuration (boxed — presumably to keep this variant small; confirm intent).
config: Box<IBuilderConfig>,
/// Receives the built plan or the error.
reply: BuildReply,
},
/// Deserialize `plan` into an engine; `reply` yields an `Arc<TrtEngine>` on success.
Deserialize {
/// Serialized plan to load.
plan: EnginePlan,
/// Receives the shared engine or the error.
reply: DeserializeReply,
},
/// Create an execution context for `engine`; `reply` yields an `ExecutionContext` on success.
CreateContext {
/// Engine the context is created from (shared, not moved out of other owners).
engine: Arc<TrtEngine>,
/// Receives the new context or the error.
reply: CreateContextReply,
},
/// Enqueue execution of `context` with `bindings` on `stream`; `reply` yields `()` on success.
/// The context and bindings are moved into the message.
/// NOTE(review): whether `Ok(())` means "submitted" or "finished on the GPU" is not visible here — confirm.
EnqueueOnStream {
/// CUDA stream to enqueue onto.
stream: Arc<cudarc::driver::CudaStream>,
/// Execution context to run (owned by the message).
context: ExecutionContext,
/// Input/output bindings for this execution (owned by the message).
bindings: ExecutionBindings,
/// Receives `()` or the error.
reply: EnqueueReply,
},
/// Refit `engine` with the supplied `weights`; `reply` yields `()` on success.
Refit {
/// Engine whose weights are to be replaced.
engine: Arc<TrtEngine>,
/// Named weight chunks to apply.
weights: Vec<RefitWeights>,
/// Receives `()` or the error.
reply: RefitReply,
},
}
/// Owner of the TensorRT runtime used to service [`TrtMsg`] requests.
///
/// The runtime slot starts empty and is filled on the first successful call to
/// [`TrtActor::ensure_runtime`]. The slot sits behind a `parking_lot::Mutex`, so the
/// "check, then create" step runs under a single lock acquisition.
pub struct TrtActor {
    runtime: parking_lot::Mutex<Option<crate::runtime::TrtRuntime>>,
}

impl TrtActor {
    /// Creates an actor whose runtime slot is empty; no runtime is constructed here.
    pub fn new() -> Self {
        let runtime = parking_lot::Mutex::new(None);
        TrtActor { runtime }
    }

    /// Makes sure the runtime exists, creating it on first use.
    ///
    /// The lock is held while `TrtRuntime::new()` runs, so two callers cannot both
    /// construct a runtime. Once a runtime is stored, further calls return `Ok(())`
    /// without doing anything.
    ///
    /// # Errors
    ///
    /// Propagates the error from `TrtRuntime::new()`. On error the slot is left empty,
    /// so a later call will attempt creation again.
    pub fn ensure_runtime(&self) -> Result<(), TrtError> {
        let mut slot = self.runtime.lock();
        if slot.is_some() {
            return Ok(());
        }
        let created = crate::runtime::TrtRuntime::new()?;
        *slot = Some(created);
        Ok(())
    }
}

impl Default for TrtActor {
    /// Same as [`TrtActor::new`]: an actor with no runtime yet.
    fn default() -> Self {
        TrtActor::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Precision;

    /// Constructs every message variant except `EnqueueOnStream` (which requires a
    /// `CudaStream`). Also checks that the actor type is `Send` and `Sync`.
    #[test]
    fn trt_msg_constructs() {
        let (b_tx, _b_rx) = oneshot::channel();
        let _build = TrtMsg::Build {
            source: NetworkSource::SerializedPlan(vec![1, 2, 3]),
            config: Box::new(IBuilderConfig::new().with_precision(Precision::Fp16)),
            reply: b_tx,
        };
        let (d_tx, _d_rx) = oneshot::channel();
        let _deser = TrtMsg::Deserialize {
            plan: EnginePlan::new(vec![0xAA; 8]),
            reply: d_tx,
        };
        let engine = Arc::new(TrtEngine::for_test());
        let (c_tx, _c_rx) = oneshot::channel();
        let _ctx = TrtMsg::CreateContext {
            engine: engine.clone(),
            reply: c_tx,
        };
        let (r_tx, _r_rx) = oneshot::channel();
        let _refit = TrtMsg::Refit {
            engine: engine.clone(),
            weights: vec![RefitWeights {
                name: "fc.weight".into(),
                bytes: vec![0; 16],
                dtype: crate::sys::DataType::kHALF,
            }],
            reply: r_tx,
        };
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<TrtActor>();
        // `ensure_runtime` takes `&self`, so sharing one actor across threads needs `Sync`.
        // `parking_lot::Mutex<T>` is `Sync` whenever `T: Send`, which the line above already
        // requires.
        assert_sync::<TrtActor>();
    }

    /// Checks that neither `new` nor `default` creates the runtime. Without the
    /// `tensorrt-link` feature, also checks that `ensure_runtime` fails with `NotLinked`
    /// and leaves the slot empty for a later retry.
    #[test]
    fn actor_runtime_lazy_init() {
        let actor = TrtActor::new();
        // Laziness: construction alone must not populate the runtime slot.
        assert!(actor.runtime.lock().is_none());
        assert!(TrtActor::default().runtime.lock().is_none());
        #[cfg(not(feature = "tensorrt-link"))]
        {
            let r = actor.ensure_runtime();
            assert!(matches!(r, Err(TrtError::NotLinked(_))));
            // `?` returns before assignment, so a failed init must not store anything.
            assert!(actor.runtime.lock().is_none());
        }
    }
}