use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::pin::Pin;
use crate::interfaces::{ErrorRecorder, RecordError};
use crate::{
error::{Error, Result},
CudaEngine, Logger,
};
use autocxx::cxx::UniquePtr;
use trtx_sys::{nvinfer1, DataType};
/// Borrowed weight data together with the TensorRT element type it is handed over as.
///
/// The `'data` lifetime ties the wrapper to the slice it borrows; nothing is copied.
#[derive(Debug)]
pub struct Weights<'data, T> {
/// Backing storage; `as_raw` forwards its address to TensorRT as an untyped pointer.
data: &'data [T],
/// Element type reported to TensorRT; `as_raw` also uses its bit width to derive the
/// element count from the byte size of `data`.
data_type: DataType,
}
impl<T> Weights<'_, T> {
    /// Builds the raw `nvinfer1::Weights` view over the borrowed slice.
    ///
    /// The element count is derived from the slice's size in bits divided by the bit
    /// width of `data_type`, so it does not have to match `data.len()` when `T` and the
    /// declared element type differ in size.
    fn as_raw(&self) -> nvinfer1::Weights {
        let total_bits = size_of_val(self.data) * 8;
        let element_count = total_bits / self.data_type.size_bits();
        nvinfer1::Weights {
            type_: self.data_type.into(),
            values: self.data.as_ptr().cast::<std::ffi::c_void>(),
            count: element_count as i64,
        }
    }
}
/// Handle to a TensorRT `IRefitter` created for a particular engine and logger.
///
/// The `'logger` and `'engine` lifetimes are carried only by `PhantomData` markers: no
/// reference is stored, but the borrow checker keeps the `Logger` and `CudaEngine`
/// passed to `new` borrowed for as long as the refitter exists.
pub struct Refitter<'logger, 'engine> {
/// Owning pointer to the native `nvinfer1::IRefitter` (null in the `mock_runtime` build).
inner: UniquePtr<nvinfer1::IRefitter>,
/// Recorder installed through `set_error_recorder`; stored pinned on the heap here so
/// it lives as long as the refitter does.
error_recorder: Option<Pin<Box<ErrorRecorder>>>,
/// Marker for the borrow of the logger passed to `new`.
_logger: PhantomData<&'logger Logger>,
/// Marker for the borrow of the engine passed to `new`.
_engine: PhantomData<&'engine CudaEngine<'engine>>,
}
impl std::fmt::Debug for Refitter<'_, '_> {
    /// Renders the refitter as a non-exhaustive struct showing only the address of the
    /// native object, formatted as lowercase hexadecimal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address = self.inner.as_ptr() as usize;
        let rendered = format!("{address:x}");
        let mut builder = f.debug_struct("Refitter");
        builder.field("inner", &rendered);
        builder.finish_non_exhaustive()
    }
}
impl<'logger, 'engine> Refitter<'logger, 'engine> {
/// Stand-in constructor compiled when neither the `link_tensorrt_rtx` nor the
/// `dlopen_tensorrt_rtx` feature is enabled.
///
/// # Errors
///
/// Always returns [`Error::TrtRtxLibraryNotLoaded`]; both arguments are ignored.
#[cfg(not(any(
    feature = "link_tensorrt_rtx",
    feature = "dlopen_tensorrt_rtx"
)))]
pub fn new(_cuda_engine: &'engine CudaEngine, _logger: &'logger Logger) -> Result<Self> {
    Err(Error::TrtRtxLibraryNotLoaded)
}
/// Creates a refitter for `cuda_engine` that reports through `logger`.
///
/// * With `link_tensorrt_rtx`, the linked `trtx_sys::create_infer_refitter` factory is
///   called directly.
/// * Otherwise (with `dlopen_tensorrt_rtx`), the TensorRT library is loaded on first
///   use and `createInferRefitter_INTERNAL` is resolved from it.
/// * With `mock_runtime`, no native object is created and the handle wraps a null
///   pointer.
///
/// # Errors
///
/// Returns [`Error::Runtime`] when the factory yields a null pointer. In the `dlopen`
/// configuration, failures to take the library lock, load the library, or resolve the
/// symbol are propagated, and [`Error::TrtRtxLibraryNotLoaded`] is returned if the
/// library handle is still absent after the load attempt.
#[cfg(any(feature = "link_tensorrt_rtx", feature = "dlopen_tensorrt_rtx"))]
pub fn new(cuda_engine: &'engine CudaEngine, logger: &'logger Logger) -> Result<Self> {
    #[cfg(not(feature = "mock_runtime"))]
    {
        let logger_ptr = logger.as_logger_ptr();
        let engine_ptr = cuda_engine.inner.as_mut_ptr() as *mut std::ffi::c_void;
        let refitter = {
            #[cfg(feature = "link_tensorrt_rtx")]
            unsafe {
                trtx_sys::create_infer_refitter(engine_ptr, logger_ptr)
            }
            #[cfg(not(feature = "link_tensorrt_rtx"))]
            #[cfg(feature = "dlopen_tensorrt_rtx")]
            unsafe {
                use crate::TRTLIB;
                use libloading::Symbol;
                use std::ffi::c_void;
                // The read guard taken in the condition is released before the body
                // runs, so the loader is free to take the lock itself.
                if TRTLIB.read()?.is_none() {
                    crate::dynamically_load_tensorrt(None::<String>)?;
                }
                let lock = TRTLIB.read()?;
                // The factory is exported with C linkage, so it has to be called through
                // an `extern "C"` pointer; a plain `fn` pointer would use the Rust ABI.
                let create_infer_refitter: Symbol<
                    unsafe extern "C" fn(
                        *mut c_void,
                        *mut c_void,
                        u32,
                    ) -> *mut nvinfer1::IRefitter,
                > = lock
                    .as_ref()
                    .ok_or(Error::TrtRtxLibraryNotLoaded)?
                    .get(b"createInferRefitter_INTERNAL")?;
                create_infer_refitter(engine_ptr, logger_ptr, trtx_sys::get_tensorrt_version())
            }
        };
        if refitter.is_null() {
            return Err(Error::Runtime("Failed to create refitter".to_string()));
        }
        Ok(Self {
            // SAFETY: `refitter` was just checked to be non-null.
            // NOTE(review): assumes the factory transfers ownership of the object to the
            // caller, which is what wrapping it in `UniquePtr` relies on — confirm.
            inner: unsafe { UniquePtr::from_raw(refitter) },
            error_recorder: None,
            _engine: Default::default(),
            _logger: Default::default(),
        })
    }
    #[cfg(feature = "mock_runtime")]
    Ok(Self {
        inner: UniquePtr::null(),
        error_recorder: None,
        _engine: Default::default(),
        _logger: Default::default(),
    })
}
pub fn set_weights<T>(
&mut self,
layer_name: &str,
role: nvinfer1::WeightsRole,
weights: Weights<'engine, T>,
) -> Result<()> {
let name_cstr = CString::new(layer_name)?;
if unsafe {
self.inner
.pin_mut()
.setWeights(name_cstr.as_ptr(), role, weights.as_raw())
} {
Ok(())
} else {
Err(Error::Runtime(
"setWeights rejected (invalid layer/role/count/type)".to_string(),
))
}
}
pub fn refit_cuda_engine(&mut self) -> Result<()> {
if self.inner.pin_mut().refitCudaEngine() {
Ok(())
} else {
Err(Error::Runtime(
"refitCudaEngine failed (validation or getMissingWeights != 0)".to_string(),
))
}
}
pub fn missing(&self, max_count: i32) -> Result<Vec<(String, nvinfer1::WeightsRole)>> {
let n = max_count.max(0) as usize;
let mut layer_names: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
let mut roles: Vec<i32> = vec![0; n];
let refitter_ptr = self.refitter_ptr();
let count = unsafe {
trtx_sys::trtx_refitter_get_missing(
refitter_ptr,
n as i32,
layer_names.as_mut_ptr(),
roles.as_mut_ptr(),
)
};
let count = count.max(0) as usize;
let mut out = Vec::with_capacity(count);
for i in 0..count.min(layer_names.len()) {
let ptr = layer_names[i];
if ptr.is_null() {
break;
}
let s = unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string();
let role = unsafe { std::mem::transmute::<i32, nvinfer1::WeightsRole>(roles[i]) };
out.push((s, role));
}
Ok(out)
}
#[deprecated = "use missing instead"]
pub fn get_missing(&self, max_count: i32) -> Result<Vec<(String, nvinfer1::WeightsRole)>> {
self.missing(max_count)
}
pub fn all(&self, max_count: i32) -> Result<Vec<(String, nvinfer1::WeightsRole)>> {
let n = max_count.max(0) as usize;
let mut layer_names: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
let mut roles: Vec<i32> = vec![0; n];
let refitter_ptr = self.refitter_ptr();
let count = unsafe {
trtx_sys::trtx_refitter_get_all(
refitter_ptr,
n as i32,
layer_names.as_mut_ptr(),
roles.as_mut_ptr(),
)
};
let count = count.max(0) as usize;
let mut out = Vec::with_capacity(count);
for i in 0..count.min(layer_names.len()) {
let ptr = layer_names[i];
if ptr.is_null() {
break;
}
let s = unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string();
let role = unsafe { std::mem::transmute::<i32, nvinfer1::WeightsRole>(roles[i]) };
out.push((s, role));
}
Ok(out)
}
#[deprecated = "use all instead"]
pub fn get_all(&self, max_count: i32) -> Result<Vec<(String, nvinfer1::WeightsRole)>> {
self.all(max_count)
}
/// Installs `error_recorder` on the refitter and keeps it alive inside this handle.
///
/// In the `mock_runtime` build the recorder is stored but not forwarded to a native
/// object.
///
/// # Errors
///
/// Propagates the error returned by `ErrorRecorder::new`.
///
/// # Panics
///
/// Panics if an error recorder has already been installed on this refitter.
pub fn set_error_recorder(&mut self, error_recorder: Box<dyn RecordError>) -> Result<()> {
    // Reject a second recorder before wrapping it, so a refused call builds nothing.
    if self.error_recorder.is_some() {
        panic!("Setting an error recorder more than once not supported at the moment");
    }
    let installed = self
        .error_recorder
        .insert(ErrorRecorder::new(error_recorder)?);
    let rec = installed.as_trt_error_recorder();
    #[cfg(not(feature = "mock_runtime"))]
    unsafe {
        self.inner.pin_mut().setErrorRecorder(rec)
    };
    Ok(())
}
/// Returns the raw error-recorder pointer reported by the native `getErrorRecorder`.
///
/// NOTE(review): presumably null when no recorder has been installed — confirm against
/// the TensorRT documentation.
pub fn error_recorder(&self) -> *mut nvinfer1::IErrorRecorder {
self.inner.getErrorRecorder()
}
/// Deprecated alias for `error_recorder`.
#[deprecated = "use error_recorder instead"]
pub fn get_error_recorder(&self) -> *mut nvinfer1::IErrorRecorder {
self.error_recorder()
}
/// Supplies replacement values for the weights called `name`.
///
/// # Errors
///
/// Propagates the `CString` conversion error when `name` contains an interior NUL
/// byte, and returns [`Error::Runtime`] when `setNamedWeights` reports `false`.
pub fn set_named_weights<T>(
    &mut self,
    name: &str,
    weights: &Weights<'engine, T>,
) -> Result<()> {
    let c_name = CString::new(name)?;
    // SAFETY: `c_name` stays alive until after the call returns and the weight buffer
    // is borrowed for `'engine`.
    // NOTE(review): assumes TensorRT does not need the buffer beyond that borrow — confirm.
    let accepted = unsafe {
        self.inner
            .pin_mut()
            .setNamedWeights(c_name.as_ptr(), weights.as_raw())
    };
    if accepted {
        return Ok(());
    }
    Err(Error::Runtime(
        "setNamedWeights rejected (invalid name/count/type)".to_string(),
    ))
}
/// Lists the weight names reported by `trtx_refitter_get_missing_weights`.
///
/// At most `max_count` names are requested; a negative `max_count` is treated as 0.
/// Decoding stops at the first null pointer, and names that are not valid UTF-8 are
/// converted lossily.
pub fn missing_weights(&self, max_count: i32) -> Result<Vec<String>> {
    let n = max_count.max(0) as usize;
    let mut name_ptrs: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
    let reported = unsafe {
        trtx_sys::trtx_refitter_get_missing_weights(
            self.refitter_ptr(),
            n as i32,
            name_ptrs.as_mut_ptr(),
        )
    };
    // Never read past the buffer, even if the native side reports a larger count.
    let limit = (reported.max(0) as usize).min(name_ptrs.len());
    let mut found = Vec::new();
    for &ptr in &name_ptrs[..limit] {
        if ptr.is_null() {
            break;
        }
        let name = unsafe { CStr::from_ptr(ptr) };
        found.push(name.to_string_lossy().into_owned());
    }
    Ok(found)
}
/// Deprecated alias for `missing_weights`.
#[deprecated = "use missing_weights instead"]
pub fn get_missing_weights(&self, max_count: i32) -> Result<Vec<String>> {
    self.missing_weights(max_count)
}
/// Lists the weight names reported by `trtx_refitter_get_all_weights`.
///
/// At most `max_count` names are requested; a negative `max_count` is treated as 0.
/// Decoding stops at the first null pointer, and names that are not valid UTF-8 are
/// converted lossily.
pub fn all_weights(&self, max_count: i32) -> Result<Vec<String>> {
    let n = max_count.max(0) as usize;
    let mut name_ptrs: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
    let reported = unsafe {
        trtx_sys::trtx_refitter_get_all_weights(
            self.refitter_ptr(),
            n as i32,
            name_ptrs.as_mut_ptr(),
        )
    };
    // Never read past the buffer, even if the native side reports a larger count.
    let limit = (reported.max(0) as usize).min(name_ptrs.len());
    let mut collected = Vec::new();
    for &ptr in &name_ptrs[..limit] {
        if ptr.is_null() {
            break;
        }
        let name = unsafe { CStr::from_ptr(ptr) };
        collected.push(name.to_string_lossy().into_owned());
    }
    Ok(collected)
}
/// Deprecated alias for `all_weights`.
#[deprecated = "use all_weights instead"]
pub fn get_all_weights(&self, max_count: i32) -> Result<Vec<String>> {
    self.all_weights(max_count)
}
/// Returns the address of the native refitter as an untyped mutable pointer, the form
/// the `trtx_sys::trtx_refitter_*` shims take.
fn refitter_ptr(&self) -> *mut std::ffi::c_void {
    self.inner
        .as_ptr()
        .cast::<std::ffi::c_void>()
        .cast_mut()
}
/// Returns the raw logger pointer reported by the native `getLogger`.
pub fn logger(&self) -> *mut nvinfer1::ILogger {
self.inner.getLogger()
}
/// Deprecated alias for `logger`.
#[deprecated = "use logger instead"]
pub fn get_logger(&self) -> *mut nvinfer1::ILogger {
self.logger()
}
pub fn set_max_threads(&mut self, max_threads: i32) -> Result<()> {
if self.inner.pin_mut().setMaxThreads(max_threads) {
Ok(())
} else {
Err(Error::InvalidArgument("setMaxThreads failed".to_string()))
}
}
/// Returns the value reported by the native `getMaxThreads`.
pub fn max_threads(&self) -> i32 {
self.inner.getMaxThreads()
}
/// Deprecated alias for `max_threads`.
#[deprecated = "use max_threads instead"]
pub fn get_max_threads(&self) -> i32 {
self.max_threads()
}
pub unsafe fn set_named_weights_with_location(
&mut self,
name: &str,
weights: nvinfer1::Weights,
location: nvinfer1::TensorLocation,
) -> Result<()> {
let name_cstr = CString::new(name)?;
if unsafe {
self.inner
.pin_mut()
.setNamedWeights1(name_cstr.as_ptr(), weights, location)
} {
Ok(())
} else {
Err(Error::Runtime(
"setNamedWeights (with location) rejected".to_string(),
))
}
}
/// Returns the raw weights descriptor the native `getNamedWeights` reports for
/// `weights_name`.
///
/// # Panics
///
/// Panics if `weights_name` contains an interior NUL byte.
pub fn named_weights(&self, weights_name: &str) -> nvinfer1::Weights {
let name_cstr = CString::new(weights_name).expect("name contains null");
// `name_cstr` stays alive until after the call returns.
unsafe { self.inner.getNamedWeights(name_cstr.as_ptr()) }
}
/// Deprecated alias for `named_weights`.
#[deprecated = "use named_weights instead"]
pub fn get_named_weights(&self, weights_name: &str) -> nvinfer1::Weights {
self.named_weights(weights_name)
}
/// Returns the location the native `getWeightsLocation` reports for `weights_name`.
///
/// # Panics
///
/// Panics if `weights_name` contains an interior NUL byte.
pub fn weights_location(&self, weights_name: &str) -> nvinfer1::TensorLocation {
let name_cstr = CString::new(weights_name).expect("name contains null");
// `name_cstr` stays alive until after the call returns.
unsafe { self.inner.getWeightsLocation(name_cstr.as_ptr()) }
}
/// Deprecated alias for `weights_location`.
#[deprecated = "use weights_location instead"]
pub fn get_weights_location(&self, weights_name: &str) -> nvinfer1::TensorLocation {
self.weights_location(weights_name)
}
/// Calls the native `unsetNamedWeights` for `weights_name` and returns its result.
///
/// # Panics
///
/// Panics if `weights_name` contains an interior NUL byte.
pub fn unset_named_weights(&mut self, weights_name: &str) -> bool {
let name_cstr = CString::new(weights_name).expect("name contains null");
// `name_cstr` stays alive until after the call returns.
unsafe { self.inner.pin_mut().unsetNamedWeights(name_cstr.as_ptr()) }
}
/// Forwards the `weights_validation` flag to the native `setWeightsValidation`.
pub fn set_weights_validation(&mut self, weights_validation: bool) {
self.inner
.pin_mut()
.setWeightsValidation(weights_validation);
}
/// Returns the flag reported by the native `getWeightsValidation`.
pub fn weights_validation(&self) -> bool {
self.inner.getWeightsValidation()
}
/// Deprecated alias for `weights_validation`.
#[deprecated = "use weights_validation instead"]
pub fn get_weights_validation(&self) -> bool {
self.weights_validation()
}
/// Applies the weights supplied so far by calling `refitCudaEngineAsync` with
/// `cuda_stream`.
///
/// # Safety
///
/// `cuda_stream` is forwarded to the native call unchecked.
/// NOTE(review): presumably it must be a valid CUDA stream handle, and the supplied
/// weight buffers must stay valid until the enqueued work has completed — confirm
/// against the TensorRT documentation.
///
/// # Errors
///
/// Returns [`Error::Runtime`] when the native call reports `false`.
pub unsafe fn refit_cuda_engine_async(
    &mut self,
    cuda_stream: *mut std::ffi::c_void,
) -> Result<()> {
    let enqueued = self
        .inner
        .pin_mut()
        .refitCudaEngineAsync(cuda_stream as *mut _);
    if enqueued {
        return Ok(());
    }
    Err(Error::Runtime(
        "refitCudaEngineAsync failed (validation or getMissingWeights != 0)".to_string(),
    ))
}
/// Returns the raw weights descriptor the native `getWeightsPrototype` reports for
/// `weights_name`.
///
/// # Panics
///
/// Panics if `weights_name` contains an interior NUL byte.
pub fn weights_prototype(&self, weights_name: &str) -> nvinfer1::Weights {
let name_cstr = CString::new(weights_name).expect("name contains null");
// `name_cstr` stays alive until after the call returns.
unsafe { self.inner.getWeightsPrototype(name_cstr.as_ptr()) }
}
/// Deprecated alias for `weights_prototype`.
#[deprecated = "use weights_prototype instead"]
pub fn get_weights_prototype(&self, weights_name: &str) -> nvinfer1::Weights {
self.weights_prototype(weights_name)
}
}
#[cfg(test)]
#[cfg(not(feature = "mock_runtime"))]
mod tests {
use crate::interfaces::RecordError;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::{Arc, Mutex};
use trtx_sys::BuilderFlag;
use trtx_sys::ErrorCode;
use super::*;
use crate::builder::MemoryPoolType;
use crate::refitter::Weights;
use crate::{Builder, DataType, Logger, Runtime};
/// Test recorder that appends every reported error to a shared vector.
struct VecErrorRecorder {
/// Shared with the test body so it can inspect what was reported.
messages: Arc<Mutex<Vec<(ErrorCode, String)>>>,
/// Counter backing `inc_ref_count` / `dec_ref_count`.
ref_count: AtomicI32,
}
impl VecErrorRecorder {
/// Creates a recorder writing into `messages`, with a reference count of zero.
fn new(messages: Arc<Mutex<Vec<(ErrorCode, String)>>>) -> Self {
Self {
messages,
ref_count: AtomicI32::new(0),
}
}
}
impl RecordError for VecErrorRecorder {
    /// Number of errors recorded so far.
    fn nb_errors(&self) -> i32 {
        let recorded = self.messages.lock().unwrap();
        recorded.len() as i32
    }
    /// Code of the error at `error_idx`; panics when the index is out of range.
    fn error_code(&self, error_idx: i32) -> ErrorCode {
        let recorded = self.messages.lock().unwrap();
        recorded[error_idx as usize].0
    }
    /// Always yields an empty C string; the stored descriptions are not exposed here.
    fn error_desc(&self, _error_idx: i32) -> &CStr {
        <&CStr>::default()
    }
    /// This recorder never reports an overflow.
    fn has_overflowed(&self) -> bool {
        false
    }
    /// Drops every recorded error.
    fn clear(&self) {
        let mut recorded = self.messages.lock().unwrap();
        recorded.clear();
    }
    /// Stores the error and reports `true`.
    fn report_error(&self, val: ErrorCode, desc: &str) -> bool {
        let entry = (val, desc.to_string());
        let mut recorded = self.messages.lock().unwrap();
        recorded.push(entry);
        true
    }
    /// Increments the counter and returns the new value.
    fn inc_ref_count(&self) -> i32 {
        let previous = self.ref_count.fetch_add(1, Ordering::SeqCst);
        previous + 1
    }
    /// Decrements the counter and returns the new value.
    fn dec_ref_count(&self) -> i32 {
        let previous = self.ref_count.fetch_sub(1, Ordering::SeqCst);
        previous - 1
    }
}
/// Builds a serialized engine whose only layer is a 1x4 `f32` constant named
/// `refit_const`, exposed as the network output `output`.
///
/// The network is serialized twice with the same config: first with `kREFIT` set, then
/// again after `kSTRIP_PLAN` has also been set. The second plan is asserted to be
/// smaller than the first and is the one returned.
fn build_constant_network(logger: &Logger) -> Result<Vec<u8>> {
let mut builder = Builder::new(logger)?;
let mut network = builder.create_network(0)?;
let dims = [1, 4];
let initial: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
// View the four floats as raw bytes, the form `add_constant` takes here; the slice
// covers exactly the memory of `initial`.
let weights_bytes: &[u8] = unsafe {
std::slice::from_raw_parts(
initial.as_ptr() as *const u8,
initial.len() * std::mem::size_of::<f32>(),
)
};
let mut const_layer = network.add_constant(&dims, weights_bytes, DataType::kFLOAT, None)?;
const_layer.set_name(&mut network, "refit_const")?;
let output = const_layer.output(&network, 0)?;
output.set_name(&mut network, "output")?;
network.mark_output(&output);
let mut config = builder.create_config()?;
config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24);
config.set_flag(BuilderFlag::kREFIT);
// Reference build, kept only to compare sizes against the stripped plan below.
let unstripped_engine_data = builder.build_serialized_network(&mut network, &mut config)?;
config.set_flag(BuilderFlag::kSTRIP_PLAN);
let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
assert!(engine_data.len() < unstripped_engine_data.len());
Ok(engine_data.to_vec())
}
/// End-to-end check: build the stripped constant engine, deserialize it, supply new
/// values for the first refittable weight, and refit the engine.
#[test]
fn refitter_from_constant_network() {
let logger = Logger::stderr().expect("logger");
let engine_data = build_constant_network(&logger).expect("build network");
assert!(!engine_data.is_empty());
let mut runtime = Runtime::new(&logger).expect("runtime");
let engine = runtime
.deserialize_cuda_engine(&engine_data)
.expect("deserialize engine");
let mut refitter = Refitter::new(&engine, &logger).expect("refitter");
let all = refitter.all_weights(64).expect("get_all_weights");
assert!(
!all.is_empty(),
"engine should have at least one refittable weight (constant layer)"
);
let weight_name = &all[0];
let proto = refitter.weights_prototype(weight_name);
// Only checks that the prototype reports a count of -1 or a non-negative count.
assert!(proto.count >= 0 || proto.count == -1);
// Same element count and type as the constant the network was built with.
let new_vals: [f32; 4] = [10.0, 20.0, 30.0, 40.0];
let new_weights = Weights {
data_type: DataType::kFLOAT,
data: &new_vals,
};
refitter
.set_named_weights(weight_name, &new_weights)
.expect("set_named_weights");
refitter.refit_cuda_engine().expect("refit_cuda_engine");
}
/// Checks that an installed error recorder receives at least one error after the
/// refitter is given a one-element weight buffer and asked for an unknown weight name.
#[test]
fn refitter_error_recorder_collects_invalid_weight_error() {
let logger = Logger::stderr().expect("logger");
let engine_data = build_constant_network(&logger).expect("build network");
assert!(!engine_data.is_empty());
let mut runtime = Runtime::new(&logger).expect("runtime");
let engine = runtime
.deserialize_cuda_engine(&engine_data)
.expect("deserialize engine");
let mut refitter = Refitter::new(&engine, &logger).expect("refitter");
let weight_name = refitter.all_weights(64).expect("get_all_weights")[0].clone();
// Keep one handle to the shared vector so the test can inspect it afterwards.
let errors: Arc<Mutex<Vec<(ErrorCode, String)>>> = Arc::new(Mutex::new(Vec::new()));
let recorder = Box::new(VecErrorRecorder::new(Arc::clone(&errors)));
refitter.set_error_recorder(recorder).unwrap();
// One element, whereas the constant was built with four.
let wrong_weights = Weights {
data_type: DataType::kFLOAT,
data: &[1.0f32],
};
// The result is deliberately ignored; only the recorder contents are asserted on.
let _ = refitter.set_named_weights(&weight_name, &wrong_weights);
refitter.named_weights("nonexistent_weight_name");
let collected = errors.lock().unwrap();
assert!(
!collected.is_empty(),
"error recorder should have collected at least one error (invalid weight or nonexistent name)"
);
}
}