use std::marker::PhantomData;
use std::pin::Pin;
use crate::error::PropertySetAttempt;
use crate::interfaces::MonitorProgress;
use crate::interfaces::ProgressMonitor;
use crate::optimization_profile::OptimizationProfile;
use crate::Builder;
use crate::Error;
use crate::Result;
use cxx::UniquePtr;
use trtx_sys::nvinfer1::{self, IBuilderConfig};
use trtx_sys::{
BuilderFlag, DeviceType, EngineCapability, HardwareCompatibilityLevel, MemoryPoolType,
PreviewFeature, ProfilingVerbosity, RuntimePlatform, TilingOptimizationLevel,
};
#[cfg(not(feature = "enterprise"))]
use trtx_sys::ComputeCapability;
/// Owning wrapper around a TensorRT `IBuilderConfig`.
///
/// The `'builder` lifetime ties this config to the [`Builder`] it was created from, so the
/// config cannot outlive that builder.
pub struct BuilderConfig<'builder> {
/// Owning handle to the native builder config; freed when this wrapper is dropped.
pub(crate) inner: UniquePtr<IBuilderConfig>,
/// Monitor installed through `set_progress_monitor`, if any. It is boxed and pinned so its
/// address stays stable after it has been handed to the native config.
/// Must stay declared AFTER `inner`: fields drop in declaration order, so the native config
/// is destroyed before the monitor it was given.
progress_monitor: Option<Pin<Box<ProgressMonitor>>>,
/// Zero-sized marker that binds this config to the `'builder` lifetime.
_builder: PhantomData<&'builder Builder<'builder>>,
}
impl std::fmt::Debug for BuilderConfig<'_> {
    /// Shows only the (lower-case hex, un-prefixed) address of the wrapped native
    /// config; every other field is elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address = self.inner.as_ptr() as usize;
        let inner_hex = format!("{:x}", address);
        let mut view = f.debug_struct("BuilderConfig");
        view.field("inner", &inner_hex);
        view.finish_non_exhaustive()
    }
}
impl<'builder> BuilderConfig<'builder> {
pub(crate) fn new(builder_config: *mut nvinfer1::IBuilderConfig) -> Result<Self> {
#[cfg(not(feature = "mock"))]
if builder_config.is_null() {
return Err(Error::BuilderConfigCreationFailed);
}
Ok(Self {
inner: unsafe { UniquePtr::from_raw(builder_config) },
progress_monitor: None,
_builder: Default::default(),
})
}
pub fn set_progress_monitor(
&mut self,
progress_monitor: Box<dyn MonitorProgress>,
) -> Result<()> {
let progress_monitor = ProgressMonitor::new(progress_monitor)?;
if self.progress_monitor.is_some() {
panic!("Setting a progress monitor more than once not supported at the moment");
}
self.progress_monitor = Some(progress_monitor);
#[cfg(not(feature = "mock"))]
unsafe {
self.inner.pin_mut().setProgressMonitor(
self.progress_monitor
.as_mut()
.expect("progress_monitor can't be empty. we just set it")
.as_trt_progress_monitor(),
)
};
Ok(())
}
pub fn set_memory_pool_limit(&mut self, pool: MemoryPoolType, size: usize) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setMemoryPoolLimit(pool.into(), size);
}
pub fn set_profiling_verbosity(&mut self, verbosity: ProfilingVerbosity) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setProfilingVerbosity(verbosity.into());
}
pub fn profiling_verbosity(&self) -> ProfilingVerbosity {
if cfg!(not(feature = "mock")) {
self.inner.getProfilingVerbosity().into()
} else {
ProfilingVerbosity::kNONE
}
}
#[deprecated = "use profiling_verbosity instead"]
pub fn get_profiling_verbosity(&self) -> ProfilingVerbosity {
self.profiling_verbosity()
}
pub fn set_avg_timing_iterations(&mut self, avg_timing: i32) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setAvgTimingIterations(avg_timing);
}
pub fn avg_timing_iterations(&self) -> i32 {
if cfg!(not(feature = "mock")) {
self.inner.getAvgTimingIterations()
} else {
0
}
}
#[deprecated = "use avg_timing_iterations instead"]
pub fn get_avg_timing_iterations(&self) -> i32 {
self.avg_timing_iterations()
}
pub fn set_engine_capability(&mut self, capability: EngineCapability) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setEngineCapability(capability.into());
}
pub fn engine_capability(&self) -> EngineCapability {
if cfg!(not(feature = "mock")) {
self.inner.getEngineCapability().into()
} else {
EngineCapability::kSTANDARD
}
}
#[deprecated = "use engine_capability instead"]
pub fn get_engine_capability(&self) -> EngineCapability {
self.engine_capability()
}
pub fn set_flags(&mut self, flags: u32) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setFlags(flags);
}
pub fn flags(&self) -> u32 {
if cfg!(not(feature = "mock")) {
self.inner.getFlags()
} else {
0
}
}
#[deprecated = "use flags instead"]
pub fn get_flags(&self) -> u32 {
self.flags()
}
pub fn set_flag(&mut self, flag: BuilderFlag) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setFlag(flag.into());
}
pub fn clear_flag(&mut self, flag: BuilderFlag) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().clearFlag(flag.into());
}
pub fn flag(&self, flag: BuilderFlag) -> bool {
if cfg!(not(feature = "mock")) {
self.inner.getFlag(flag.into())
} else {
false
}
}
#[deprecated = "use flag instead"]
pub fn get_flag(&self, flag: BuilderFlag) -> bool {
self.flag(flag)
}
pub fn set_dla_core(&mut self, dla_core: i32) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setDLACore(dla_core);
}
pub fn dla_core(&self) -> i32 {
if cfg!(not(feature = "mock")) {
self.inner.getDLACore()
} else {
0
}
}
#[deprecated = "use dla_core instead"]
pub fn get_dla_core(&self) -> i32 {
self.dla_core()
}
pub fn set_default_device_type(&mut self, device_type: DeviceType) {
#[cfg(not(feature = "mock"))]
self.inner
.pin_mut()
.setDefaultDeviceType(device_type.into());
}
pub fn default_device_type(&self) -> DeviceType {
if cfg!(not(feature = "mock")) {
self.inner.getDefaultDeviceType().into()
} else {
DeviceType::kGPU
}
}
#[deprecated = "use default_device_type instead"]
pub fn get_default_device_type(&self) -> DeviceType {
self.default_device_type()
}
pub fn reset(&mut self) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().reset();
}
pub fn nb_optimization_profiles(&self) -> i32 {
if cfg!(not(feature = "mock")) {
self.inner.getNbOptimizationProfiles()
} else {
0
}
}
#[deprecated = "use nb_optimization_profiles instead"]
pub fn get_nb_optimization_profiles(&self) -> i32 {
self.nb_optimization_profiles()
}
pub fn add_optimization_profile(
&mut self,
profile: &mut OptimizationProfile<'_>,
) -> Result<i32> {
#[cfg(not(feature = "mock"))]
{
let idx = unsafe {
self.inner
.pin_mut()
.addOptimizationProfile(profile.inner.as_mut().get_unchecked_mut())
};
if idx >= 0 {
Ok(idx)
} else {
Err(Error::Runtime("addOptimizationProfile failed".to_string()))
}
}
#[cfg(feature = "mock")]
Ok(0)
}
pub fn set_tactic_sources(&mut self, sources: u32) -> crate::Result<()> {
if cfg!(not(feature = "mock")) {
if self.inner.pin_mut().setTacticSources(sources) {
Ok(())
} else {
Err(crate::Error::FailedToSetProperty(
PropertySetAttempt::BuilderConfigTacticSources,
))
}
} else {
Ok(())
}
}
pub fn tactic_sources(&self) -> u32 {
if cfg!(not(feature = "mock")) {
self.inner.getTacticSources()
} else {
0
}
}
#[deprecated = "use tactic_sources instead"]
pub fn get_tactic_sources(&self) -> u32 {
self.tactic_sources()
}
pub fn memory_pool_limit(&self, pool: MemoryPoolType) -> usize {
if cfg!(not(feature = "mock")) {
self.inner.getMemoryPoolLimit(pool.into())
} else {
0
}
}
#[deprecated = "use memory_pool_limit instead"]
pub fn get_memory_pool_limit(&self, pool: MemoryPoolType) -> usize {
self.memory_pool_limit(pool)
}
pub fn set_preview_feature(&mut self, feature: PreviewFeature, enable: bool) {
#[cfg(not(feature = "mock"))]
self.inner
.pin_mut()
.setPreviewFeature(feature.into(), enable);
}
pub fn preview_feature(&self, feature: PreviewFeature) -> bool {
if cfg!(not(feature = "mock")) {
self.inner.getPreviewFeature(feature.into())
} else {
false
}
}
#[deprecated = "use preview_feature instead"]
pub fn get_preview_feature(&self, feature: PreviewFeature) -> bool {
self.preview_feature(feature)
}
pub fn set_builder_optimization_level(&mut self, level: i32) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setBuilderOptimizationLevel(level);
}
pub fn builder_optimization_level(&mut self) -> i32 {
if cfg!(not(feature = "mock")) {
self.inner.pin_mut().getBuilderOptimizationLevel()
} else {
0
}
}
#[deprecated = "use builder_optimization_level instead"]
pub fn get_builder_optimization_level(&mut self) -> i32 {
self.builder_optimization_level()
}
pub fn set_hardware_compatibility_level(&mut self, level: HardwareCompatibilityLevel) {
#[cfg(not(feature = "mock"))]
self.inner
.pin_mut()
.setHardwareCompatibilityLevel(level.into());
}
pub fn hardware_compatibility_level(&self) -> HardwareCompatibilityLevel {
self.inner.getHardwareCompatibilityLevel().into()
}
#[deprecated = "use hardware_compatibility_level instead"]
pub fn get_hardware_compatibility_level(&self) -> HardwareCompatibilityLevel {
self.hardware_compatibility_level()
}
pub fn set_max_aux_streams(&mut self, nb_streams: i32) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setMaxAuxStreams(nb_streams);
}
pub fn max_aux_streams(&self) -> i32 {
if cfg!(not(feature = "mock")) {
self.inner.getMaxAuxStreams()
} else {
0
}
}
#[deprecated = "use max_aux_streams instead"]
pub fn get_max_aux_streams(&self) -> i32 {
self.max_aux_streams()
}
pub fn set_runtime_platform(&mut self, platform: RuntimePlatform) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setRuntimePlatform(platform.into());
}
pub fn runtime_platform(&self) -> RuntimePlatform {
if cfg!(not(feature = "mock")) {
self.inner.getRuntimePlatform().into()
} else {
RuntimePlatform::kSAME_AS_BUILD
}
}
#[deprecated = "use runtime_platform instead"]
pub fn get_runtime_platform(&self) -> RuntimePlatform {
self.runtime_platform()
}
pub fn set_max_nb_tactics(&mut self, max_nb_tactics: i32) {
#[cfg(not(feature = "mock"))]
self.inner.pin_mut().setMaxNbTactics(max_nb_tactics);
}
pub fn max_nb_tactics(&self) -> i32 {
if cfg!(not(feature = "mock")) {
self.inner.getMaxNbTactics()
} else {
0
}
}
#[deprecated = "use max_nb_tactics instead"]
pub fn get_max_nb_tactics(&self) -> i32 {
self.max_nb_tactics()
}
pub fn set_tiling_optimization_level(
&mut self,
level: TilingOptimizationLevel,
) -> crate::Result<()> {
if cfg!(not(feature = "mock")) {
if self
.inner
.pin_mut()
.setTilingOptimizationLevel(level.into())
{
Ok(())
} else {
Err(crate::Error::FailedToSetProperty(
PropertySetAttempt::BuilderConfigTilingOptimizationLevel,
))
}
} else {
Ok(())
}
}
pub fn tiling_optimization_level(&self) -> TilingOptimizationLevel {
if cfg!(not(feature = "mock")) {
self.inner.getTilingOptimizationLevel().into()
} else {
TilingOptimizationLevel::kNONE
}
}
#[deprecated = "use tiling_optimization_level instead"]
pub fn get_tiling_optimization_level(&self) -> TilingOptimizationLevel {
self.tiling_optimization_level()
}
pub fn set_l2_limit_for_tiling(&mut self, size: i64) -> crate::Result<()> {
if cfg!(not(feature = "mock")) {
if self.inner.pin_mut().setL2LimitForTiling(size) {
Ok(())
} else {
Err(crate::Error::FailedToSetProperty(
PropertySetAttempt::BuilderConfigL2LimitForTiling,
))
}
} else {
Ok(())
}
}
pub fn l2_limit_for_tiling(&self) -> i64 {
if cfg!(not(feature = "mock")) {
self.inner.getL2LimitForTiling()
} else {
0
}
}
#[deprecated = "use l2_limit_for_tiling instead"]
pub fn get_l2_limit_for_tiling(&self) -> i64 {
self.l2_limit_for_tiling()
}
#[cfg(not(feature = "enterprise"))]
pub fn set_nb_compute_capabilities(
&mut self,
max_nb_compute_capabilities: i32,
) -> crate::Result<()> {
if cfg!(not(feature = "mock")) {
if self
.inner
.pin_mut()
.setNbComputeCapabilities(max_nb_compute_capabilities)
{
Ok(())
} else {
Err(crate::Error::FailedToSetProperty(
PropertySetAttempt::BuilderConfigNbComputeCapabilities,
))
}
} else {
Ok(())
}
}
#[cfg(not(feature = "enterprise"))]
pub fn nb_compute_capabilities(&self) -> i32 {
if cfg!(not(feature = "mock")) {
self.inner.getNbComputeCapabilities()
} else {
0
}
}
#[cfg(not(feature = "enterprise"))]
#[deprecated = "use nb_compute_capabilities instead"]
pub fn get_nb_compute_capabilities(&self) -> i32 {
self.nb_compute_capabilities()
}
#[cfg(not(feature = "enterprise"))]
pub fn set_compute_capability(
&mut self,
compute_capability: ComputeCapability,
index: i32,
) -> crate::Result<()> {
if cfg!(not(feature = "mock")) {
if self
.inner
.pin_mut()
.setComputeCapability(compute_capability.into(), index)
{
Ok(())
} else {
Err(crate::Error::FailedToSetProperty(
PropertySetAttempt::BuilderConfigComputeCapability,
))
}
} else {
Ok(())
}
}
#[cfg(not(feature = "enterprise"))]
pub fn compute_capability(&self, index: i32) -> ComputeCapability {
if cfg!(not(feature = "mock")) {
self.inner.getComputeCapability(index).into()
} else {
ComputeCapability::kNONE
}
}
#[cfg(not(feature = "enterprise"))]
#[deprecated = "use compute_capability instead"]
pub fn get_compute_capability(&self, index: i32) -> ComputeCapability {
self.compute_capability(index)
}
}
// Tests that run a real TensorRT engine build with a progress monitor attached; they are
// compiled only when the `mock` feature is off because they need the native library.
#[cfg(test)]
#[cfg(not(feature = "mock"))]
mod tests {
use crate::builder::MemoryPoolType;
use crate::interfaces::MonitorProgress;
use crate::{Builder, DataType, Logger, NetworkDefinition};
use std::ops::ControlFlow;
use std::sync::atomic::{AtomicU32, Ordering};
/// Number of identity layers chained together by `build_heavy_network`.
const NUM_LAYERS: usize = 40;
/// Progress monitor that prints every callback to stdout and requests cancellation
/// (returns `ControlFlow::Break`) once `cancel_after` steps have completed.
struct StdoutProgressMonitor {
/// Total `step_complete` callbacks seen so far, counted across all phases.
step_count: AtomicU32,
/// Step count at which `step_complete` starts asking the build to stop.
cancel_after: u32,
}
impl StdoutProgressMonitor {
/// Creates a monitor with a zeroed step counter that cancels after `cancel_after` steps.
fn new(cancel_after: u32) -> Self {
Self {
step_count: AtomicU32::new(0),
cancel_after,
}
}
}
impl MonitorProgress for StdoutProgressMonitor {
/// Logs the start of a phase.
fn phase_start(&self, phase_name: &str, parent_phase: Option<&str>, num_steps: i32) {
println!(
"[progress] phase_start phase={:?} parent={:?} num_steps={}",
phase_name, parent_phase, num_steps
);
}
/// Logs the step, then breaks once the running step count reaches `cancel_after`.
fn step_complete(&self, phase_name: &str, step: i32) -> ControlFlow<()> {
// `fetch_add` returns the previous value, so `n + 1` is the count including this step.
let n = self.step_count.fetch_add(1, Ordering::SeqCst);
println!(
"[progress] step_complete phase={:?} step={}",
phase_name, step
);
if n + 1 >= self.cancel_after {
println!("[progress] cancel requested after {} steps", n + 1);
ControlFlow::Break(())
} else {
ControlFlow::Continue(())
}
}
/// Logs the end of a phase.
fn phase_finish(&self, phase_name: &str) {
println!("[progress] phase_finish phase={:?}", phase_name);
}
}
/// Builds a network with a `[1, 4]` float input named "input", followed by `NUM_LAYERS`
/// chained identity layers (`layer_0`, `layer_1`, ...), whose last output is marked as
/// the network output named "output". Presumably sized so the build reports several
/// progress steps — TODO confirm the step count is comfortably above 3.
fn build_heavy_network(logger: &Logger) -> crate::Result<(Builder<'_>, NetworkDefinition<'_>)> {
let mut builder = Builder::new(logger)?;
let mut network = builder.create_network(0)?;
let mut tensor = network.add_input("input", DataType::kFLOAT, &[1, 4])?;
for i in 0..NUM_LAYERS {
let mut layer = network.add_identity(&tensor)?;
layer.set_name(&mut network, &format!("layer_{}", i))?;
// Feed each layer's output into the next identity layer.
tensor = layer.output(&network, 0)?;
}
tensor.set_name(&mut network, "output")?;
network.mark_output(&tensor);
Ok((builder, network))
}
/// A monitor that cancels after 3 steps must make the build return an error.
#[test]
fn set_progress_monitor_cancel_build() {
let logger = Logger::stderr().expect("logger");
let (mut builder, mut network) = build_heavy_network(&logger).expect("build network");
let mut config = builder.create_config().expect("config");
// 1 << 24 workspace limit (16 MiB, assuming the limit is in bytes).
config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24);
let monitor = Box::new(StdoutProgressMonitor::new(3));
config.set_progress_monitor(monitor).unwrap();
let result = builder.build_serialized_network(&mut network, &mut config);
assert!(
result.is_err(),
"build should fail (cancelled by progress monitor)"
);
}
/// With a cancel threshold of 10000 steps (assumed never reached), the build must succeed
/// while still printing progress to stdout.
#[test]
fn set_progress_monitor_progress_to_stdout() {
let logger = Logger::stderr().expect("logger");
let (mut builder, mut network) = build_heavy_network(&logger).expect("build network");
let mut config = builder.create_config().expect("config");
config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24);
let monitor = Box::new(StdoutProgressMonitor::new(10000));
config.set_progress_monitor(monitor).unwrap();
let result = builder.build_serialized_network(&mut network, &mut config);
assert!(result.is_ok(), "build should succeed when not cancelling");
}
}