use crate::interfaces::RecordError;
pub use crate::tensor::Tensor;
use cxx::UniquePtr;
use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::pin::Pin;
use trtx_sys::nvinfer1::{IConcatenationLayer, INetworkDefinition, ITensor};
use trtx_sys::InterpolationMode;
use trtx_sys::{nvinfer1, LayerType, SampleMode, Weights};
use trtx_sys::{AsLayer, AsLayerTyped};
#[cfg(feature = "v_1_4")]
use trtx_sys::{CollectiveOperation, MoEActType, ReduceOperation};
use trtx_sys::{DataType, Dims64, MatrixOperation, ScaleMode, TopKOperation};
/// Panics unless `$target` was created by the network `$network`.
///
/// `$network` must be an identifier whose `inner.as_ptr()` yields the raw
/// `INetworkDefinition` pointer; `$target` is any expression (a layer, a
/// tensor, `self`, an indexed element, ...) exposing a `network` field that
/// holds the pointer of the network that created it. A single `expr` arm
/// covers plain identifiers as well, so no separate `ident` arm is needed.
macro_rules! check_network {
    ($network:ident, $target:expr) => {
        if $network.inner.as_ptr() != $target.network {
            panic!("Layer or tensor was created from different network")
        }
    };
}
pub(crate) use check_network;
use crate::error::{Error, OkOrFailedSettingProperty, PropertySetAttempt, Result};
use crate::interfaces::ErrorRecorder;
use log::{debug, trace};
/// Borrowed convolution weights: a raw kernel buffer plus an optional raw
/// bias buffer, each tagged with its element data type.
///
/// The buffers are untyped bytes; presumably their lengths must match the
/// element size of the accompanying dtype times the expected element count —
/// TODO confirm against the code that consumes this struct.
#[derive(Debug, Clone, Copy)]
pub struct ConvWeights<'a> {
    /// Raw kernel bytes.
    pub kernel_weights: &'a [u8],
    /// Element type of `kernel_weights`.
    pub kernel_dtype: crate::DataType,
    /// Raw bias bytes, if the layer has a bias term.
    pub bias_weights: Option<&'a [u8]>,
    /// Element type of `bias_weights`. Expected to be `Some` exactly when
    /// `bias_weights` is `Some` (not enforced by the type).
    pub bias_dtype: Option<crate::DataType>,
}

/// An owned weight tensor: its shape, element type and raw bytes.
#[derive(Debug, Clone)]
pub struct OwnedWeights {
    /// Tensor shape, one entry per dimension.
    pub shape: Vec<i64>,
    /// Element type of `values`.
    pub data_type: DataType,
    /// Raw element bytes.
    pub values: Vec<u8>,
}

/// Owned convolution weights: a kernel and an optional bias.
#[derive(Debug, Clone)]
pub struct OwnedConvWeights {
    /// Convolution kernel.
    pub kernel: OwnedWeights,
    /// Optional bias.
    pub bias: Option<OwnedWeights>,
}

impl OwnedConvWeights {
    /// Borrows these owned weights as a [`ConvWeights`] view.
    ///
    /// The bias buffer and bias dtype are both `Some` or both `None`, mirroring
    /// `self.bias`.
    pub fn as_weights(&self) -> ConvWeights<'_> {
        ConvWeights {
            kernel_weights: &self.kernel.values,
            kernel_dtype: self.kernel.data_type,
            bias_weights: self.bias.as_ref().map(|b| b.values.as_slice()),
            bias_dtype: self.bias.as_ref().map(|b| b.data_type),
        }
    }
}
/// Handle to a layer object owned by a TensorRT network.
///
/// `Inner` is the concrete C++ layer interface (e.g. `IShuffleLayer`), or
/// `ILayer` for an untyped handle.
pub struct Layer<'network, Inner: AsLayer> {
    /// Pinned mutable borrow of the underlying C++ layer object.
    pub(crate) inner: Pin<&'network mut Inner>,
    /// Raw pointer of the network that created this layer; compared by
    /// `check_network!` to reject use with a different `NetworkDefinition`.
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}
impl<'network, Inner: AsLayerTyped> Layer<'network, Inner> {
pub const fn layer_type(&self) -> LayerType {
Inner::TYPE
}
pub(crate) fn new(
network: *const nvinfer1::INetworkDefinition,
ptr: *mut Inner,
) -> Result<Self> {
unsafe {
let ptr = ptr
.as_mut()
.ok_or(Error::LayerCreationFailed(Inner::TYPE))?;
Ok(Self {
inner: Pin::new_unchecked(ptr),
network,
})
}
}
}
impl<'network> Layer<'network, nvinfer1::ILayer> {
pub fn layer_type_dynamic(&self) -> LayerType {
self.inner.as_layer().getType().into()
}
pub(crate) fn new_dyn(
network: *const nvinfer1::INetworkDefinition,
ptr: *mut nvinfer1::ILayer,
) -> Result<Self> {
unsafe {
let ptr = ptr.as_mut().ok_or(Error::GetLayerFailed)?;
Ok(Self {
inner: Pin::new_unchecked(ptr),
network,
})
}
}
}
impl<'network, Inner: AsLayer> Layer<'network, Inner> {
    /// Connects `tensor` to input slot `index` of this layer.
    ///
    /// # Panics
    /// Panics if this layer or `tensor` was created by a different network.
    pub fn set_input(
        &mut self,
        network: &'_ mut NetworkDefinition,
        index: i32,
        tensor: &'_ Tensor,
    ) -> Result<()> {
        check_network!(network, self);
        // The tensor must belong to the same network as the layer; wiring in a
        // tensor from another network would hand TensorRT a foreign object.
        check_network!(network, tensor);
        debug!(
            "set_input layer={} index={index} tensor={}",
            layer_dbg(network, self),
            tensor_dbg(network, tensor)
        );
        // SAFETY: the pinned layer is only used to call a C++ setter; it is not moved.
        unsafe { self.inner.as_mut().get_unchecked_mut() }
            .as_layer_pin_mut()
            .setInput(index, tensor.pin_mut());
        Ok(())
    }

    /// Returns the tensor connected to input slot `index`.
    pub fn input(&self, network: &'_ NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
        check_network!(network, self);
        let tensor = self.inner.as_layer().getInput(index);
        // SAFETY: the pointer comes from this layer, which belongs to `self.network`.
        unsafe { Tensor::new(self.network, tensor) }
    }

    #[deprecated = "use input instead"]
    pub fn get_input(
        &self,
        network: &'_ NetworkDefinition,
        index: i32,
    ) -> Result<Tensor<'network>> {
        self.input(network, index)
    }

    /// Returns the tensor produced at output slot `index`.
    pub fn output(&self, network: &'_ NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
        check_network!(network, self);
        let tensor = self.inner.as_layer().getOutput(index);
        // SAFETY: the pointer comes from this layer, which belongs to `self.network`.
        unsafe { Tensor::new(self.network, tensor) }
    }

    #[deprecated = "use output instead"]
    pub fn get_output(
        &self,
        network: &'_ NetworkDefinition,
        index: i32,
    ) -> Result<Tensor<'network>> {
        self.output(network, index)
    }

    /// Number of input slots of this layer.
    pub fn num_inputs(&self, network: &'_ NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_layer().getNbInputs()
    }

    #[deprecated = "use num_inputs instead"]
    pub fn get_num_inputs(&self, network: &'_ NetworkDefinition) -> i32 {
        self.num_inputs(network)
    }

    /// Number of outputs of this layer.
    pub fn num_outputs(&self, network: &'_ NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_layer().getNbOutputs()
    }

    #[deprecated = "use num_outputs instead"]
    pub fn get_num_outputs(&self, network: &'_ NetworkDefinition) -> i32 {
        self.num_outputs(network)
    }

    /// Sets the layer name.
    ///
    /// # Errors
    /// Fails if `name` contains an interior NUL byte.
    pub fn set_name(&mut self, network: &'_ mut NetworkDefinition, name: &str) -> Result<()> {
        check_network!(network, self);
        let name = CString::new(name)?;
        // SAFETY: the pinned layer is only used to call a C++ setter; it is not moved.
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .as_layer_pin_mut()
                .setName(name.as_ptr())
        };
        Ok(())
    }

    /// Returns the layer name, or `"(unnamed)"` if TensorRT reports none.
    ///
    /// Invalid UTF-8 is replaced lossily.
    pub fn name(&self, network: &NetworkDefinition) -> String {
        check_network!(network, self);
        let name = self.inner.as_layer().getName();
        if name.is_null() {
            "(unnamed)".to_string()
        } else {
            // SAFETY: non-null pointer returned by `getName`; assumed to be a
            // NUL-terminated string kept alive by the layer.
            unsafe { CStr::from_ptr(name) }
                .to_string_lossy()
                .into_owned()
        }
    }
}
// Typed aliases for `Layer` over each TensorRT layer interface.
pub type ActivationLayer<'layer> = Layer<'layer, nvinfer1::IActivationLayer>;
pub type AssertionLayer<'layer> = Layer<'layer, nvinfer1::IAssertionLayer>;
pub type CastLayer<'layer> = Layer<'layer, nvinfer1::ICastLayer>;
pub type ConcatenationLayer<'layer> = Layer<'layer, nvinfer1::IConcatenationLayer>;
pub type ConstantLayer<'layer> = Layer<'layer, nvinfer1::IConstantLayer>;
pub type ConvolutionLayer<'layer> = Layer<'layer, nvinfer1::IConvolutionLayer>;
pub type CumulativeLayer<'layer> = Layer<'layer, nvinfer1::ICumulativeLayer>;
pub type DeconvolutionLayer<'layer> = Layer<'layer, nvinfer1::IDeconvolutionLayer>;
pub type DequantizeLayer<'layer> = Layer<'layer, nvinfer1::IDequantizeLayer>;
pub type DynamicQuantizeLayer<'layer> = Layer<'layer, nvinfer1::IDynamicQuantizeLayer>;
pub type ElementWiseLayer<'layer> = Layer<'layer, nvinfer1::IElementWiseLayer>;
pub type EinsumLayer<'layer> = Layer<'layer, nvinfer1::IEinsumLayer>;
pub type FillLayer<'layer> = Layer<'layer, nvinfer1::IFillLayer>;
pub type GatherLayer<'layer> = Layer<'layer, nvinfer1::IGatherLayer>;
pub type GridSampleLayer<'layer> = Layer<'layer, nvinfer1::IGridSampleLayer>;
pub type IdentityLayer<'layer> = Layer<'layer, nvinfer1::IIdentityLayer>;
pub type MatrixMultiplyLayer<'layer> = Layer<'layer, nvinfer1::IMatrixMultiplyLayer>;
pub type NMSLayer<'layer> = Layer<'layer, nvinfer1::INMSLayer>;
pub type NonZeroLayer<'layer> = Layer<'layer, nvinfer1::INonZeroLayer>;
pub type NormalizationLayer<'layer> = Layer<'layer, nvinfer1::INormalizationLayer>;
pub type PaddingLayer<'layer> = Layer<'layer, nvinfer1::IPaddingLayer>;
pub type ParametricReLULayer<'layer> = Layer<'layer, nvinfer1::IParametricReLULayer>;
pub type PoolingLayer<'layer> = Layer<'layer, nvinfer1::IPoolingLayer>;
pub type QuantizeLayer<'layer> = Layer<'layer, nvinfer1::IQuantizeLayer>;
pub type RaggedSoftMaxLayer<'layer> = Layer<'layer, nvinfer1::IRaggedSoftMaxLayer>;
pub type ReduceLayer<'layer> = Layer<'layer, nvinfer1::IReduceLayer>;
pub type ResizeLayer<'layer> = Layer<'layer, nvinfer1::IResizeLayer>;
pub type RotaryEmbeddingLayer<'layer> = Layer<'layer, nvinfer1::IRotaryEmbeddingLayer>;
pub type ScaleLayer<'layer> = Layer<'layer, nvinfer1::IScaleLayer>;
pub type ScatterLayer<'layer> = Layer<'layer, nvinfer1::IScatterLayer>;
pub type SelectLayer<'layer> = Layer<'layer, nvinfer1::ISelectLayer>;
pub type ShapeLayer<'layer> = Layer<'layer, nvinfer1::IShapeLayer>;
pub type ShuffleLayer<'layer> = Layer<'layer, nvinfer1::IShuffleLayer>;
pub type SliceLayer<'layer> = Layer<'layer, nvinfer1::ISliceLayer>;
pub type SoftMaxLayer<'layer> = Layer<'layer, nvinfer1::ISoftMaxLayer>;
pub type SqueezeLayer<'layer> = Layer<'layer, nvinfer1::ISqueezeLayer>;
pub type TopKLayer<'layer> = Layer<'layer, nvinfer1::ITopKLayer>;
pub type UnaryLayer<'layer> = Layer<'layer, nvinfer1::IUnaryLayer>;
pub type UnsqueezeLayer<'layer> = Layer<'layer, nvinfer1::IUnsqueezeLayer>;
pub type ReverseSequenceLayer<'layer> = Layer<'layer, nvinfer1::IReverseSequenceLayer>;
pub type KVCacheUpdateLayer<'layer> = Layer<'layer, nvinfer1::IKVCacheUpdateLayer>;
pub type LrnLayer<'layer> = Layer<'layer, nvinfer1::ILRNLayer>;
pub type OneHotLayer<'layer> = Layer<'layer, nvinfer1::IOneHotLayer>;
// Layers only available with the `v_1_4` feature.
#[cfg(feature = "v_1_4")]
pub type MoELayer<'layer> = Layer<'layer, nvinfer1::IMoELayer>;
#[cfg(feature = "v_1_4")]
pub type DistCollectiveLayer<'layer> = Layer<'layer, nvinfer1::IDistCollectiveLayer>;
pub type AttentionInputLayer<'layer> = Layer<'layer, nvinfer1::IAttentionInputLayer>;
pub type AttentionOutputLayer<'layer> = Layer<'layer, nvinfer1::IAttentionOutputLayer>;
pub type AttentionBoundaryLayer<'layer> = Layer<'layer, nvinfer1::IAttentionBoundaryLayer>;
pub type LoopBoundaryLayer<'layer> = Layer<'layer, nvinfer1::ILoopBoundaryLayer>;
pub type RecurrenceLayer<'layer> = Layer<'layer, nvinfer1::IRecurrenceLayer>;
pub type LoopOutputLayer<'layer> = Layer<'layer, nvinfer1::ILoopOutputLayer>;
pub type TripLimitLayer<'layer> = Layer<'layer, nvinfer1::ITripLimitLayer>;
pub type IteratorLayer<'layer> = Layer<'layer, nvinfer1::IIteratorLayer>;
pub type ConditionLayer<'layer> = Layer<'layer, nvinfer1::IConditionLayer>;
pub type IfConditionalOutputLayer<'layer> = Layer<'layer, nvinfer1::IIfConditionalOutputLayer>;
pub type IfConditionalInputLayer<'layer> = Layer<'layer, nvinfer1::IIfConditionalInputLayer>;
/// Untyped layer handle, e.g. as returned by `NetworkDefinition::layer`.
pub type DynLayer<'layer> = Layer<'layer, nvinfer1::ILayer>;
/// Handle to a TensorRT `IAttention` construct owned by a network.
pub struct Attention<'network> {
    /// Pinned mutable borrow of the underlying C++ object.
    pub(crate) inner: Pin<&'network mut nvinfer1::IAttention>,
    /// Raw pointer of the network that created this object (ownership check).
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}
/// Handle to a TensorRT `ILoop` construct owned by a network.
pub struct Loop<'network> {
    /// Pinned mutable borrow of the underlying C++ object.
    pub(crate) inner: Pin<&'network mut nvinfer1::ILoop>,
    /// Raw pointer of the network that created this object (ownership check).
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}
/// Handle to a TensorRT `IIfConditional` construct owned by a network.
pub struct IfConditional<'network> {
    /// Pinned mutable borrow of the underlying C++ object.
    pub(crate) inner: Pin<&'network mut nvinfer1::IIfConditional>,
    /// Raw pointer of the network that created this object (ownership check).
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}
impl ShuffleLayer<'_> {
pub fn set_reshape_dimensions(
&mut self,
network: &mut NetworkDefinition,
dims: &[i64],
) -> Result<()> {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(dims);
self.inner.as_mut().setReshapeDimensions(&dims_obj);
Ok(())
}
pub fn set_first_transpose(
&mut self,
network: &mut NetworkDefinition,
order: &[i32],
) -> Result<()> {
check_network!(network, self);
let mut order_arr = [0i32; 8];
let n = order.len().min(8);
order_arr[..n].copy_from_slice(&order[..n]);
let perm = trtx_sys::nvinfer1::Permutation { order: order_arr };
self.inner.as_mut().setFirstTranspose(perm);
Ok(())
}
pub fn set_second_transpose(
&mut self,
network: &mut NetworkDefinition,
order: &[i32],
) -> Result<()> {
check_network!(network, self);
let mut order_arr = [0i32; 8];
let n = order.len().min(8);
order_arr[..n].copy_from_slice(&order[..n]);
let perm = trtx_sys::nvinfer1::Permutation { order: order_arr };
self.inner.as_mut().setSecondTranspose(perm);
Ok(())
}
}
impl ResizeLayer<'_> {
    /// Sets explicit output dimensions for the resize.
    pub fn set_output_dimensions(&mut self, network: &mut NetworkDefinition, dims: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(dims);
        self.inner.as_mut().setOutputDimensions(&dims_obj);
    }

    /// Returns the output dimensions stored on the layer.
    ///
    /// A negative rank reported by TensorRT yields an empty vector rather than a panic.
    pub fn output_dimensions(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getOutputDimensions();
        // `nbDims` may be negative for unset/invalid dims; `as usize` on that
        // would yield a huge slice bound and panic, so clamp to [0, len].
        d.d.iter().take(d.nbDims.max(0) as usize).copied().collect()
    }

    /// Sets per-dimension scale factors.
    pub fn set_scales(&mut self, network: &mut NetworkDefinition, scales: &[f32]) {
        check_network!(network, self);
        // SAFETY: pointer and length describe the live `scales` slice for the call.
        unsafe {
            self.inner
                .as_mut()
                .setScales(scales.as_ptr(), scales.len() as i32);
        }
    }

    /// Returns the scale factors, or `None` if none are set or the two-pass
    /// query reports inconsistent counts.
    pub fn scales(&self, network: &NetworkDefinition) -> Option<Vec<f32>> {
        check_network!(network, self);
        // First call with a null buffer to learn the count.
        // SAFETY: a null buffer with size 0 is only used to query the count.
        let n = unsafe { self.inner.as_ref().getScales(0, std::ptr::null_mut()) };
        if n <= 0 {
            return None;
        }
        let mut buf = vec![0.0_f32; n as usize];
        // SAFETY: `buf` has room for exactly `n` elements.
        let n2 = unsafe { self.inner.as_ref().getScales(n, buf.as_mut_ptr()) };
        if n2 != n {
            return None;
        }
        Some(buf)
    }

    pub fn set_resize_mode(&mut self, network: &mut NetworkDefinition, mode: trtx_sys::ResizeMode) {
        check_network!(network, self);
        self.inner.as_mut().setResizeMode(mode.into());
    }

    pub fn resize_mode(&self, network: &NetworkDefinition) -> trtx_sys::ResizeMode {
        check_network!(network, self);
        self.inner.as_ref().getResizeMode().into()
    }

    pub fn set_coordinate_transformation(
        &mut self,
        network: &mut NetworkDefinition,
        transform: trtx_sys::ResizeCoordinateTransformation,
    ) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setCoordinateTransformation(transform.into());
    }

    pub fn coordinate_transformation(
        &self,
        network: &NetworkDefinition,
    ) -> trtx_sys::ResizeCoordinateTransformation {
        check_network!(network, self);
        self.inner.as_ref().getCoordinateTransformation().into()
    }

    pub fn set_selector_for_single_pixel(
        &mut self,
        network: &mut NetworkDefinition,
        selector: trtx_sys::ResizeSelector,
    ) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setSelectorForSinglePixel(selector.into());
    }

    pub fn selector_for_single_pixel(
        &self,
        network: &NetworkDefinition,
    ) -> trtx_sys::ResizeSelector {
        check_network!(network, self);
        self.inner.as_ref().getSelectorForSinglePixel().into()
    }

    pub fn set_nearest_rounding(
        &mut self,
        network: &mut NetworkDefinition,
        mode: trtx_sys::ResizeRoundMode,
    ) {
        check_network!(network, self);
        self.inner.as_mut().setNearestRounding(mode.into());
    }

    pub fn nearest_rounding(&self, network: &NetworkDefinition) -> trtx_sys::ResizeRoundMode {
        check_network!(network, self);
        self.inner.as_ref().getNearestRounding().into()
    }

    pub fn set_cubic_coeff(&mut self, network: &mut NetworkDefinition, a: f32) {
        check_network!(network, self);
        self.inner.as_mut().setCubicCoeff(a);
    }

    pub fn cubic_coeff(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getCubicCoeff()
    }

    pub fn set_exclude_outside(&mut self, network: &mut NetworkDefinition, exclude: bool) {
        check_network!(network, self);
        self.inner.as_mut().setExcludeOutside(exclude);
    }

    pub fn exclude_outside(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.inner.as_ref().getExcludeOutside()
    }
}
impl GatherLayer<'_> {
pub fn set_gather_mode(&mut self, network: &mut NetworkDefinition, mode: trtx_sys::GatherMode) {
check_network!(network, self);
self.inner.as_mut().setMode(mode.into());
}
}
impl<'network> ScatterLayer<'network> {
pub fn set_scatter_mode(
&mut self,
network: &mut NetworkDefinition,
mode: trtx_sys::ScatterMode,
) {
check_network!(network, self);
self.inner.as_mut().setMode(mode.into());
}
pub fn set_axis(&mut self, network: &'_ mut NetworkDefinition, axis: i32) {
check_network!(network, self);
self.inner.as_mut().setAxis(axis);
}
}
impl<'network> ConvolutionLayer<'network> {
pub fn set_stride(&mut self, network: &mut NetworkDefinition, stride: &[i64; 2]) {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(stride);
self.inner.as_mut().setStrideNd(&dims_obj);
}
pub fn set_padding(&mut self, network: &mut NetworkDefinition, padding: &[i64; 2]) {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(padding);
self.inner.as_mut().setPaddingNd(&dims_obj);
}
pub fn set_dilation(&mut self, network: &mut NetworkDefinition, dilation: &[i64; 2]) {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(dilation);
self.inner.as_mut().setDilationNd(&dims_obj);
}
pub fn set_num_groups(&mut self, network: &mut NetworkDefinition, num_groups: i64) {
check_network!(network, self);
self.inner.as_mut().setNbGroups(num_groups);
}
}
impl<'network> DeconvolutionLayer<'network> {
pub fn set_stride(&mut self, network: &mut NetworkDefinition, stride: &[i64; 2]) -> Result<()> {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(stride);
self.inner.as_mut().setStrideNd(&dims_obj);
Ok(())
}
pub fn set_pre_padding(
&mut self,
network: &mut NetworkDefinition,
padding: &[i64; 2],
) -> Result<()> {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(padding);
self.inner.as_mut().setPrePadding(&dims_obj);
Ok(())
}
pub fn set_post_padding(
&mut self,
network: &mut NetworkDefinition,
padding: &[i64; 2],
) -> Result<()> {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(padding);
self.inner.as_mut().setPostPadding(&dims_obj);
Ok(())
}
pub fn set_dilation(
&mut self,
network: &mut NetworkDefinition,
dilation: &[i64; 2],
) -> Result<()> {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(dilation);
self.inner.as_mut().setDilationNd(&dims_obj);
Ok(())
}
pub fn set_num_groups(
&mut self,
network: &mut NetworkDefinition,
num_groups: i64,
) -> Result<()> {
check_network!(network, self);
self.inner.as_mut().setNbGroups(num_groups);
Ok(())
}
}
impl<'network> PoolingLayer<'network> {
    pub fn set_pooling_type(
        &mut self,
        network: &mut NetworkDefinition,
        pooling_type: trtx_sys::PoolingType,
    ) {
        check_network!(network, self);
        self.inner.as_mut().setPoolingType(pooling_type.into());
    }

    pub fn pooling_type(&self, network: &NetworkDefinition) -> trtx_sys::PoolingType {
        check_network!(network, self);
        self.inner.as_ref().getPoolingType().into()
    }

    pub fn set_blend_factor(&mut self, network: &mut NetworkDefinition, blend_factor: f32) {
        check_network!(network, self);
        self.inner.as_mut().setBlendFactor(blend_factor);
    }

    pub fn blend_factor(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getBlendFactor()
    }

    pub fn set_average_count_excludes_padding(
        &mut self,
        network: &mut NetworkDefinition,
        exclusive: bool,
    ) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setAverageCountExcludesPadding(exclusive);
    }

    pub fn average_count_excludes_padding(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.inner.as_ref().getAverageCountExcludesPadding()
    }

    pub fn set_pre_padding(&mut self, network: &mut NetworkDefinition, padding: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(padding);
        self.inner.as_mut().setPrePadding(&dims_obj);
    }

    /// Pre-padding per spatial dimension (empty if TensorRT reports a negative rank).
    pub fn pre_padding(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getPrePadding();
        // Clamp `nbDims` to [0, len]: a negative value cast to `usize` would panic.
        d.d.iter().take(d.nbDims.max(0) as usize).copied().collect()
    }

    pub fn set_post_padding(&mut self, network: &mut NetworkDefinition, padding: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(padding);
        self.inner.as_mut().setPostPadding(&dims_obj);
    }

    /// Post-padding per spatial dimension (empty if TensorRT reports a negative rank).
    pub fn post_padding(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getPostPadding();
        d.d.iter().take(d.nbDims.max(0) as usize).copied().collect()
    }

    pub fn set_padding_mode(
        &mut self,
        network: &mut NetworkDefinition,
        padding_mode: trtx_sys::PaddingMode,
    ) {
        check_network!(network, self);
        self.inner.as_mut().setPaddingMode(padding_mode.into());
    }

    pub fn padding_mode(&self, network: &NetworkDefinition) -> trtx_sys::PaddingMode {
        check_network!(network, self);
        self.inner.as_ref().getPaddingMode().into()
    }

    pub fn set_window_size_nd(&mut self, network: &mut NetworkDefinition, window_size: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(window_size);
        self.inner.as_mut().setWindowSizeNd(&dims_obj);
    }

    /// Pooling window size (empty if TensorRT reports a negative rank).
    pub fn window_size_nd(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getWindowSizeNd();
        d.d.iter().take(d.nbDims.max(0) as usize).copied().collect()
    }

    pub fn set_stride_nd(&mut self, network: &mut NetworkDefinition, stride: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(stride);
        self.inner.as_mut().setStrideNd(&dims_obj);
    }

    /// Pooling stride (empty if TensorRT reports a negative rank).
    pub fn stride_nd(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getStrideNd();
        d.d.iter().take(d.nbDims.max(0) as usize).copied().collect()
    }

    pub fn set_padding_nd(&mut self, network: &mut NetworkDefinition, padding: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(padding);
        self.inner.as_mut().setPaddingNd(&dims_obj);
    }

    /// Symmetric padding (empty if TensorRT reports a negative rank).
    pub fn padding_nd(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getPaddingNd();
        d.d.iter().take(d.nbDims.max(0) as usize).copied().collect()
    }
}
impl<'network> DynamicQuantizeLayer<'network> {
pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
check_network!(network, self);
self.inner.as_mut().setAxis(axis);
}
pub fn axis(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getAxis()
}
pub fn set_block_shape(&mut self, network: &mut NetworkDefinition, block_shape: &[i64]) {
check_network!(network, self);
let dims = Dims64::from_slice(block_shape);
self.inner.as_mut().setBlockShape(&dims);
}
pub fn block_shape(&self, network: &NetworkDefinition) -> Dims64 {
check_network!(network, self);
self.inner.getBlockShape()
}
pub fn set_block_size(&mut self, network: &mut NetworkDefinition, size: i32) {
check_network!(network, self);
self.inner.as_mut().setBlockSize(size);
}
pub fn block_size(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getBlockSize()
}
pub fn set_to_type(&mut self, network: &mut NetworkDefinition, to_type: DataType) {
check_network!(network, self);
self.inner.as_mut().setToType(to_type.into());
}
pub fn to_type(&self, network: &NetworkDefinition) -> DataType {
check_network!(network, self);
self.inner.getToType().into()
}
pub fn set_scale_type(&mut self, network: &mut NetworkDefinition, scale_type: DataType) {
check_network!(network, self);
self.inner.as_mut().setScaleType(scale_type.into());
}
pub fn scale_type(&self, network: &NetworkDefinition) -> DataType {
check_network!(network, self);
self.inner.getScaleType().into()
}
}
impl<'network> QuantizeLayer<'network> {
pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
check_network!(network, self);
self.inner.as_mut().setAxis(axis);
}
pub fn axis(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getAxis()
}
pub fn set_block_shape(
&mut self,
network: &mut NetworkDefinition,
block_shape: &[i64],
) -> Result<()> {
check_network!(network, self);
let dims = Dims64::from_slice(block_shape);
self.inner
.as_mut()
.setBlockShape(&dims)
.ok_or_err(PropertySetAttempt::QuantizeLayerBlockShape)
}
pub fn block_shape(&self, network: &NetworkDefinition) -> Dims64 {
check_network!(network, self);
self.inner.getBlockShape()
}
pub fn set_to_type(&mut self, network: &mut NetworkDefinition, to_type: DataType) {
check_network!(network, self);
self.inner.as_mut().setToType(to_type.into());
}
pub fn to_type(&self, network: &NetworkDefinition) -> DataType {
check_network!(network, self);
self.inner.getToType().into()
}
}
impl<'network> DequantizeLayer<'network> {
pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
check_network!(network, self);
self.inner.as_mut().setAxis(axis);
}
pub fn axis(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getAxis()
}
pub fn set_block_shape(
&mut self,
network: &mut NetworkDefinition,
block_shape: &[i64],
) -> Result<()> {
check_network!(network, self);
let dims = Dims64::from_slice(block_shape);
self.inner
.as_mut()
.setBlockShape(&dims)
.ok_or_err(PropertySetAttempt::DequantizeLayerBlockShape)
}
pub fn block_shape(&self, network: &NetworkDefinition) -> Dims64 {
check_network!(network, self);
self.inner.getBlockShape()
}
pub fn set_to_type(&mut self, network: &mut NetworkDefinition, to_type: DataType) {
check_network!(network, self);
self.inner.as_mut().setToType(to_type.into());
}
pub fn to_type(&self, network: &NetworkDefinition) -> DataType {
check_network!(network, self);
self.inner.getToType().into()
}
}
impl ConcatenationLayer<'_> {
pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
check_network!(network, self);
self.inner.as_mut().setAxis(axis);
}
}
impl NormalizationLayer<'_> {
pub fn set_epsilon(&mut self, network: &mut NetworkDefinition, eps: f32) {
check_network!(network, self);
self.inner.as_mut().setEpsilon(eps);
}
pub fn epsilon(&self, network: &NetworkDefinition) -> f32 {
check_network!(network, self);
self.inner.as_ref().getEpsilon()
}
#[deprecated = "use epsilon instead"]
pub fn get_epsilon(&self, network: &NetworkDefinition) -> f32 {
self.epsilon(network)
}
pub fn set_axes(&mut self, network: &mut NetworkDefinition, axes: crate::Axes) {
check_network!(network, self);
self.inner.as_mut().setAxes(axes.to_bits());
}
pub fn axes(&self, network: &NetworkDefinition) -> crate::Axes {
check_network!(network, self);
crate::Axes::from_bits(self.inner.as_ref().getAxes())
}
#[deprecated = "use axes instead"]
pub fn get_axes(&self, network: &NetworkDefinition) -> crate::Axes {
self.axes(network)
}
pub fn set_num_groups(&mut self, network: &mut NetworkDefinition, groups: i64) {
check_network!(network, self);
self.inner.as_mut().setNbGroups(groups);
}
pub fn num_groups(&self, network: &NetworkDefinition) -> i64 {
check_network!(network, self);
self.inner.as_ref().getNbGroups()
}
#[deprecated = "use num_groups instead"]
pub fn get_num_groups(&self, network: &NetworkDefinition) -> i64 {
self.num_groups(network)
}
pub fn is_v2(&self, network: &NetworkDefinition) -> bool {
check_network!(network, self);
self.inner.as_ref().isV2()
}
}
#[cfg(feature = "v_1_4")]
impl MoELayer<'_> {
    /// Sets gate/up/down expert weights and the activation type.
    ///
    /// # Panics
    /// Panics if the layer or any tensor was created by a different network.
    pub fn set_gated_weights(
        &mut self,
        network: &mut NetworkDefinition,
        fc_gate_weights: &Tensor,
        fc_up_weights: &Tensor,
        fc_down_weights: &Tensor,
        activation_type: MoEActType,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_gate_weights);
        check_network!(network, fc_up_weights);
        check_network!(network, fc_down_weights);
        self.inner.as_mut().setGatedWeights(
            fc_gate_weights.pin_mut(),
            fc_up_weights.pin_mut(),
            fc_down_weights.pin_mut(),
            activation_type.into(),
        );
        Ok(())
    }

    /// Sets gate/up/down expert biases.
    ///
    /// # Panics
    /// Panics if the layer or any tensor was created by a different network.
    pub fn set_gated_biases(
        &mut self,
        network: &mut NetworkDefinition,
        fc_gate_biases: &Tensor,
        fc_up_biases: &Tensor,
        fc_down_biases: &Tensor,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_gate_biases);
        check_network!(network, fc_up_biases);
        check_network!(network, fc_down_biases);
        self.inner.as_mut().setGatedBiases(
            fc_gate_biases.pin_mut(),
            fc_up_biases.pin_mut(),
            fc_down_biases.pin_mut(),
        );
        Ok(())
    }

    pub fn set_activation_type(
        &mut self,
        network: &mut NetworkDefinition,
        activation_type: MoEActType,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setActivationType(activation_type.into());
        Ok(())
    }

    pub fn activation_type(&self, network: &NetworkDefinition) -> MoEActType {
        check_network!(network, self);
        self.inner.as_ref().getActivationType().into()
    }

    pub fn set_quantization_static(
        &mut self,
        network: &mut NetworkDefinition,
        fc_down_activation_scale: &Tensor,
        data_type: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_down_activation_scale);
        self.inner
            .as_mut()
            .setQuantizationStatic(fc_down_activation_scale.pin_mut(), data_type.into());
        Ok(())
    }

    pub fn set_quantization_dynamic_dbl_q(
        &mut self,
        network: &mut NetworkDefinition,
        fc_down_activation_dbl_q_scale: &Tensor,
        data_type: DataType,
        block_shape: &[i64],
        dyn_q_output_scale_type: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_down_activation_dbl_q_scale);
        let block = trtx_sys::Dims::from_slice(block_shape);
        self.inner.as_mut().setQuantizationDynamicDblQ(
            fc_down_activation_dbl_q_scale.pin_mut(),
            data_type.into(),
            &block,
            dyn_q_output_scale_type.into(),
        );
        Ok(())
    }

    pub fn set_quantization_to_type(
        &mut self,
        network: &mut NetworkDefinition,
        type_: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setQuantizationToType(type_.into());
        Ok(())
    }

    pub fn quantization_to_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        self.inner.as_ref().getQuantizationToType().into()
    }

    pub fn set_quantization_block_shape(
        &mut self,
        network: &mut NetworkDefinition,
        block_shape: &[i64],
    ) -> Result<()> {
        check_network!(network, self);
        let block = trtx_sys::Dims::from_slice(block_shape);
        self.inner.as_mut().setQuantizationBlockShape(&block);
        Ok(())
    }

    /// Quantization block shape (empty if TensorRT reports a negative rank).
    pub fn quantization_block_shape(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getQuantizationBlockShape();
        // Clamp `nbDims` to [0, len]: a negative value cast to `usize` would panic.
        d.d.iter().take(d.nbDims.max(0) as usize).copied().collect()
    }

    pub fn set_dyn_q_output_scale_type(
        &mut self,
        network: &mut NetworkDefinition,
        type_: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setDynQOutputScaleType(type_.into());
        Ok(())
    }

    pub fn dyn_q_output_scale_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        self.inner.as_ref().getDynQOutputScaleType().into()
    }

    /// Sets all three SwiGLU parameters at once.
    pub fn set_swiglu_params(
        &mut self,
        network: &mut NetworkDefinition,
        limit: f32,
        alpha: f32,
        beta: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParams(limit, alpha, beta);
        Ok(())
    }

    pub fn set_swiglu_param_limit(
        &mut self,
        network: &mut NetworkDefinition,
        limit: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParamLimit(limit);
        Ok(())
    }

    pub fn swiglu_param_limit(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getSwigluParamLimit()
    }

    pub fn set_swiglu_param_alpha(
        &mut self,
        network: &mut NetworkDefinition,
        alpha: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParamAlpha(alpha);
        Ok(())
    }

    pub fn swiglu_param_alpha(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getSwigluParamAlpha()
    }

    pub fn set_swiglu_param_beta(
        &mut self,
        network: &mut NetworkDefinition,
        beta: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParamBeta(beta);
        Ok(())
    }

    pub fn swiglu_param_beta(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getSwigluParamBeta()
    }
}
// No DistCollectiveLayer-specific accessors are wrapped yet; the generic
// `Layer` methods (inputs, outputs, name) still apply to this layer type.
#[cfg(feature = "v_1_4")]
impl DistCollectiveLayer<'_> {}
/// Owning wrapper around a TensorRT `INetworkDefinition`, tied to the
/// lifetime of the builder that created it.
///
/// Field order matters: Rust drops fields in declaration order, so the C++
/// network (`inner`) is destroyed before `small_copied_weights` and
/// `error_recorder`, which it may presumably still reference while alive.
pub struct NetworkDefinition<'builder> {
    /// The owned C++ network object.
    pub(crate) inner: UniquePtr<INetworkDefinition>,
    /// Ties this network to the borrowing lifetime of its builder.
    _builder: PhantomData<&'builder trtx_sys::nvinfer1::IBuilder>,
    /// Weight buffers copied by this wrapper and kept alive alongside the network.
    small_copied_weights: Vec<Vec<u8>>,
    /// Error recorder installed on the network, if any.
    error_recorder: Option<Pin<Box<ErrorRecorder>>>,
}
fn tensor_dbg(network: &NetworkDefinition<'_>, tensor: &Tensor<'_>) -> String {
tensor
.name(network)
.unwrap_or_else(|_| "(unnamed)".to_string())
}
fn layer_dbg<Inner: AsLayer>(network: &NetworkDefinition<'_>, layer: &Layer<'_, Inner>) -> String {
layer.name(network)
}
impl<'network> NetworkDefinition<'network> {
pub(crate) fn from_ptr(ptr: *mut INetworkDefinition) -> Self {
Self {
inner: unsafe { UniquePtr::from_raw(ptr) },
error_recorder: None,
_builder: Default::default(),
small_copied_weights: Default::default(),
}
}
pub fn add_input(
&mut self,
name: &str,
data_type: trtx_sys::DataType,
dims: &[i64],
) -> Result<Tensor<'network>> {
debug!("add_input name={name:?} data_type={data_type:?} dims={dims:?}");
let name_cstr = std::ffi::CString::new(name)?;
let dims_struct = trtx_sys::Dims::from_slice(dims);
let tensor_ptr = unsafe {
self.inner
.pin_mut()
.addInput(name_cstr.as_ptr(), data_type.into(), &dims_struct)
};
unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
}
/// Marks `tensor` as a network output.
///
/// # Panics
/// Panics if `tensor` was created by a different network.
pub fn mark_output(&mut self, tensor: &'_ Tensor) {
    check_network!(self, tensor);
    // Log the operation actually performed (previously mislabelled "mark_input").
    debug!("mark_output tensor={}", tensor_dbg(self, tensor));
    self.inner.pin_mut().markOutput(tensor.pin_mut());
}
pub fn mark_tensor_debug(&mut self, tensor: &'_ Tensor) -> Result<()> {
check_network!(self, tensor);
let success = self.inner.pin_mut().markDebug(tensor.pin_mut());
if success {
Ok(())
} else {
Err(Error::Runtime("markDebug failed".to_string()))
}
}
pub fn is_debug_tensor(&self, tensor: &'_ Tensor) -> bool {
check_network!(self, tensor);
self.inner.isDebugTensor(tensor.as_ref())
}
pub fn nb_inputs(&self) -> i32 {
self.inner.getNbInputs()
}
#[deprecated = "use nb_inputs instead"]
pub fn get_nb_inputs(&self) -> i32 {
self.nb_inputs()
}
pub fn nb_outputs(&self) -> i32 {
self.inner.getNbOutputs()
}
#[deprecated = "use nb_outputs instead"]
pub fn get_nb_outputs(&self) -> i32 {
self.nb_outputs()
}
pub fn input(&self, index: i32) -> Result<Tensor<'network>> {
let tensor_ptr = self.inner.getInput(index);
if tensor_ptr.is_null() {
return Err(Error::Runtime(format!(
"Failed to get input at index {}",
index
)));
}
unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
}
#[deprecated = "use input instead"]
pub fn get_input(&self, index: i32) -> Result<Tensor<'network>> {
self.input(index)
}
pub fn output(&self, index: i32) -> Result<Tensor<'network>> {
let tensor_ptr = self.inner.getOutput(index);
if tensor_ptr.is_null() {
return Err(Error::Runtime(format!(
"Failed to get output at index {}",
index
)));
}
unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
}
#[deprecated = "use output instead"]
pub fn get_output(&self, index: i32) -> Result<Tensor<'network>> {
self.output(index)
}
pub fn nb_layers(&self) -> i32 {
self.inner.getNbLayers()
}
#[deprecated = "use nb_layers instead"]
pub fn get_nb_layers(&self) -> i32 {
self.nb_layers()
}
pub fn layer(&self, layer_index: i32) -> Result<DynLayer<'network>> {
let layer_ptr = self.inner.getLayer(layer_index);
DynLayer::new_dyn(self.inner.as_ptr(), layer_ptr)
}
#[deprecated = "use layer instead"]
pub fn get_layer(&self, layer_index: i32) -> Result<DynLayer<'network>> {
self.layer(layer_index)
}
pub fn layer_name(&self, layer_index: i32) -> Result<String> {
Ok(self.layer(layer_index)?.name(self))
}
#[deprecated = "use layer_name instead"]
pub fn get_layer_name(&self, layer_index: i32) -> Result<String> {
self.layer_name(layer_index)
}
pub fn layer_type(&self, layer_index: i32) -> Result<LayerType> {
Ok(self.layer(layer_index)?.layer_type_dynamic())
}
#[deprecated = "use layer_type instead"]
pub fn get_layer_type(&self, layer_index: i32) -> Result<LayerType> {
self.layer_type(layer_index)
}
pub fn add_activation(
&mut self,
input: &'_ Tensor,
activation_type: trtx_sys::ActivationType,
) -> Result<ActivationLayer<'network>> {
check_network!(self, input);
debug!(
"add_activation input={} activation_type={activation_type:?}",
tensor_dbg(self, input)
);
let layer_ptr = self
.inner
.pin_mut()
.addActivation(input.pin_mut(), activation_type.into());
ActivationLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_unary(
&mut self,
input: &'_ Tensor,
op: trtx_sys::UnaryOperation,
) -> Result<UnaryLayer<'network>> {
check_network!(self, input);
debug!("add_unary input={} op={op:?}", tensor_dbg(self, input));
let layer_ptr = self.inner.pin_mut().addUnary(input.pin_mut(), op.into());
UnaryLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_identity(&mut self, input: &'_ Tensor) -> Result<IdentityLayer<'network>> {
check_network!(self, input);
debug!("add_identity input={}", tensor_dbg(self, input));
let layer_ptr = self.inner.pin_mut().addIdentity(input.pin_mut());
IdentityLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_cast(
&mut self,
input: &'_ Tensor,
to_type: trtx_sys::DataType,
) -> Result<CastLayer<'network>> {
check_network!(self, input);
debug!(
"add_cast input={} to_type={to_type:?}",
tensor_dbg(self, input)
);
let layer_ptr = self
.inner
.pin_mut()
.addCast(input.pin_mut(), to_type.into());
CastLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_elementwise(
&mut self,
input1: &'_ Tensor,
input2: &'_ Tensor,
op: trtx_sys::ElementWiseOperation,
) -> Result<ElementWiseLayer<'network>> {
check_network!(self, input1);
check_network!(self, input2);
debug!(
"add_elementwise input1={} input2={} op={op:?}",
tensor_dbg(self, input1),
tensor_dbg(self, input2)
);
let layer_ptr =
self.inner
.pin_mut()
.addElementWise(input1.pin_mut(), input2.pin_mut(), op.into());
ElementWiseLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_pooling(
&'_ mut self,
input: &'_ Tensor,
pooling_type: trtx_sys::PoolingType,
window_size: &[i64; 2],
) -> Result<PoolingLayer<'network>> {
check_network!(self, input);
debug!(
"add_pooling input={} pooling_type={pooling_type:?} window_size={window_size:?}",
tensor_dbg(self, input)
);
let window_dims = trtx_sys::Dims::new_2d(window_size[0], window_size[1]);
let layer_ptr =
self.inner
.pin_mut()
.addPoolingNd(input.pin_mut(), pooling_type.into(), &window_dims);
PoolingLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_shuffle(&'_ mut self, input: &'_ Tensor) -> Result<ShuffleLayer<'network>> {
check_network!(self, input);
debug!("add_shuffle input={}", tensor_dbg(self, input));
let layer_ptr = self.inner.pin_mut().addShuffle(input.pin_mut());
ShuffleLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_matrix_multiply(
&'_ mut self,
input0: &'_ Tensor,
op0: MatrixOperation,
input1: &'_ Tensor,
op1: MatrixOperation,
) -> Result<MatrixMultiplyLayer<'network>> {
check_network!(self, input0);
check_network!(self, input1);
debug!(
"add_matrix_multiply input0={} op0={op0:?} input1={} op1={op1:?}",
tensor_dbg(self, input0),
tensor_dbg(self, input1)
);
let layer_ptr = self.inner.pin_mut().addMatrixMultiply(
input0.pin_mut(),
op0.into(),
input1.pin_mut(),
op1.into(),
);
MatrixMultiplyLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_convolution_deferrred_weights(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i32,
kernel_size: &[i32; 2],
) -> Result<ConvolutionLayer<'network>> {
debug!(
"add_convolution_deferrred_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0] as i64, kernel_size[1] as i64);
let layer_ptr = self.inner.pin_mut().addConvolutionNd(
input.pin_mut(),
nb_output_maps as i64,
&kernel_dims,
Weights {
type_: nvinfer1::DataType::kFLOAT,
values: std::ptr::null(),
count: 0,
},
Weights {
type_: nvinfer1::DataType::kFLOAT,
values: std::ptr::null(),
count: 0,
},
);
ConvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_convolution_owned_weights(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i32,
kernel_size: &[i32; 2],
weights: OwnedConvWeights,
) -> Result<ConvolutionLayer<'network>> {
debug!(
"add_convolution_owned_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let mut layer =
self.add_convolution_deferrred_weights(input, nb_output_maps, kernel_size)?;
let kernel = self
.add_constant_owned(
&weights.kernel.shape,
weights.kernel.values,
weights.kernel.data_type,
)?
.output(self, 0)?;
layer.set_input(self, 1, &kernel)?;
if let Some(bias) = weights.bias {
let bias = self
.add_constant_owned(&bias.shape, bias.values, bias.data_type)?
.output(self, 0)?;
layer.set_input(self, 2, &bias)?;
}
Ok(layer)
}
pub fn add_convolution(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i32,
kernel_size: &[i32; 2],
weights: &ConvWeights<'network>,
) -> Result<ConvolutionLayer<'network>> {
check_network!(self, input);
debug!(
"add_convolution input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let kernel_dtype = weights.kernel_dtype;
let kernel_weights = weights.kernel_weights;
let bias_weights = weights.bias_weights;
let bias_dtype = weights.bias_dtype;
let kernel_bpe = kernel_dtype.size_bits() / 8;
let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
let bias_bpe = bias_dtype_val.size_bits() / 8;
let bias_count = bias_weights
.map(|b| (b.len() / bias_bpe) as i64)
.unwrap_or(0);
let kernel_ptr = if weight_count > 0 {
kernel_weights.as_ptr() as *const std::ffi::c_void
} else {
std::ptr::null()
};
let bias_ptr = if bias_count > 0 {
bias_weights
.map(|b| b.as_ptr() as *const std::ffi::c_void)
.unwrap_or(std::ptr::null())
} else {
std::ptr::null()
};
let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0] as i64, kernel_size[1] as i64);
let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type(
kernel_dtype.into(),
kernel_ptr,
weight_count,
);
let bias_w =
trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count);
let layer_ptr = self.inner.pin_mut().addConvolutionNd(
input.pin_mut(),
nb_output_maps as i64,
&kernel_dims,
kernel_w,
bias_w,
);
ConvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_deconvolution(
&mut self,
input: &'_ Tensor,
nb_output_maps: i64,
kernel_size: &[i64; 2],
weights: &ConvWeights<'network>,
) -> Result<DeconvolutionLayer<'network>> {
check_network!(self, input);
debug!(
"add_deconvolution input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let kernel_dtype = weights.kernel_dtype;
let kernel_weights = weights.kernel_weights;
let bias_weights = weights.bias_weights;
let bias_dtype = weights.bias_dtype;
let kernel_bpe = kernel_dtype.size_bits() / 8;
let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
let bias_bpe = bias_dtype_val.size_bits() / 8;
let bias_count = bias_weights
.map(|b| (b.len() / bias_bpe) as i64)
.unwrap_or(0);
let kernel_ptr = if weight_count > 0 {
kernel_weights.as_ptr() as *const std::ffi::c_void
} else {
std::ptr::null()
};
let bias_ptr = if bias_count > 0 {
bias_weights
.map(|b| b.as_ptr() as *const std::ffi::c_void)
.unwrap_or(std::ptr::null())
} else {
std::ptr::null()
};
let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0], kernel_size[1]);
let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type(
kernel_dtype.into(),
kernel_ptr,
weight_count,
);
let bias_w =
trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count);
let layer_ptr = self.inner.pin_mut().addDeconvolutionNd(
input.pin_mut(),
nb_output_maps,
kernel_dims,
kernel_w,
bias_w,
);
DeconvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_deconvolution_deferred_weights(
&mut self,
input: &'_ Tensor,
nb_output_maps: i64,
kernel_size: &[i64; 2],
) -> Result<DeconvolutionLayer<'network>> {
check_network!(self, input);
debug!(
"add_deconvolution_deferred_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0], kernel_size[1]);
let layer_ptr = self.inner.pin_mut().addDeconvolutionNd(
input.pin_mut(),
nb_output_maps,
kernel_dims,
Weights {
type_: nvinfer1::DataType::kFLOAT,
values: std::ptr::null(),
count: 0,
},
Weights {
type_: nvinfer1::DataType::kFLOAT,
values: std::ptr::null(),
count: 0,
},
);
DeconvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_deconvolution_owned_weights(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i64,
kernel_size: &[i64; 2],
weights: OwnedConvWeights,
) -> Result<DeconvolutionLayer<'network>> {
debug!(
"add_deconvolution_owned_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let mut layer =
self.add_deconvolution_deferred_weights(input, nb_output_maps, kernel_size)?;
let kernel = self
.add_constant_owned(
&weights.kernel.shape,
weights.kernel.values,
weights.kernel.data_type,
)?
.output(self, 0)?;
layer.set_input(self, 1, &kernel)?;
if let Some(bias) = weights.bias {
let bias = self
.add_constant_owned(&bias.shape, bias.values, bias.data_type)?
.output(self, 0)?;
layer.set_input(self, 2, &bias)?;
}
Ok(layer)
}
pub fn add_concatenation(&self, inputs: &[&'_ Tensor]) -> Result<ConcatenationLayer<'network>> {
for t in inputs.iter() {
check_network!(self, t);
}
let input_names: Vec<String> = inputs.iter().map(|t| tensor_dbg(self, t)).collect();
debug!("add_concatenation inputs={input_names:?}");
let mut input_ptrs: Vec<*mut std::ffi::c_void> = inputs
.iter()
.map(|t| t.as_mut() as *mut ITensor as *mut _)
.collect();
let layer_ptr = unsafe {
trtx_sys::network_add_concatenation(
self.inner.as_mut_ptr() as *mut std::ffi::c_void,
input_ptrs.as_mut_ptr(),
inputs.len() as i32,
)
} as *mut IConcatenationLayer;
ConcatenationLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_small_constant_copied(
&mut self,
dims: &[i64],
weights: &[u8],
data_type: trtx_sys::DataType,
) -> Result<ConstantLayer<'network>> {
trace!(
"add_small_constant_copied dims={dims:?} data_type={data_type:?} weights_len={}",
weights.len()
);
unsafe { self.add_constant_unsafe(dims, weights, data_type, true) }
}
pub fn add_constant_owned(
&mut self,
dims: &[i64],
weights: Vec<u8>,
data_type: trtx_sys::DataType,
) -> Result<ConstantLayer<'network>> {
trace!(
"add_constant_owned dims={dims:?} data_type={data_type:?} weights_len={}",
weights.len()
);
let element_count: i64 = dims.iter().product();
let expected_bytes = element_count * data_type.size_bits() as i64 / 8;
if weights.len() as i64 != expected_bytes {
panic!(
"Weight size mismatch: expected {expected_bytes} bytes, got {} bytes",
weights.len()
);
}
let dims_struct = trtx_sys::Dims::from_slice(dims);
let weights_struct = trtx_sys::nvinfer1::Weights::new_with_type(
data_type.into(),
{
self.small_copied_weights.push(weights);
self.small_copied_weights
.last()
.expect("can't be empty. we just pushed")
.as_ptr()
} as *const std::ffi::c_void,
element_count,
);
let layer_ptr = self
.inner
.pin_mut()
.addConstant(&dims_struct, weights_struct);
ConstantLayer::new(self.inner.as_ptr(), layer_ptr)
}
unsafe fn add_constant_unsafe(
&mut self,
dims: &[i64],
weights: &[u8],
data_type: trtx_sys::DataType,
copy: bool,
) -> Result<ConstantLayer<'network>> {
let element_count: i64 = dims.iter().product();
let expected_bytes = element_count * data_type.size_bits() as i64 / 8;
if weights.len() as i64 != expected_bytes {
panic!(
"Weight size mismatch: expected {expected_bytes} bytes, got {} bytes",
weights.len()
);
}
let dims_struct = trtx_sys::Dims::from_slice(dims);
let weights_struct = trtx_sys::nvinfer1::Weights::new_with_type(
data_type.into(),
if copy {
self.small_copied_weights.push(weights.to_vec());
self.small_copied_weights
.last()
.expect("can't be empty. we just pushed")
.as_ptr()
} else {
weights.as_ptr()
} as *const std::ffi::c_void,
element_count,
);
let layer_ptr = self
.inner
.pin_mut()
.addConstant(&dims_struct, weights_struct);
ConstantLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_constant(
&mut self,
dims: &[i64],
weights: &'network [u8],
data_type: trtx_sys::DataType,
) -> Result<ConstantLayer<'network>> {
trace!(
"add_constant dims={dims:?} data_type={data_type:?} weights_len={}",
weights.len()
);
unsafe { self.add_constant_unsafe(dims, weights, data_type, false) }
}
pub fn add_softmax(
&mut self,
input: &'_ Tensor,
axes: crate::Axes,
) -> Result<SoftMaxLayer<'network>> {
check_network!(self, input);
debug!(
"add_softmax input={} axes={axes:?}",
tensor_dbg(self, input)
);
let layer_ptr = self.inner.pin_mut().addSoftMax(input.pin_mut());
let mut rtn = SoftMaxLayer::new(self.inner.as_ptr(), layer_ptr)?;
rtn.inner.as_mut().setAxes(axes.to_bits());
Ok(rtn)
}
pub fn add_scale(
&mut self,
input: &'_ Tensor,
mode: ScaleMode,
shift: &[u8],
scale: &[u8],
power: &[u8],
) -> Result<ScaleLayer<'network>> {
check_network!(self, input);
debug!(
"add_scale input={} mode={mode:?} shift_len={} scale_len={} power_len={}",
tensor_dbg(self, input),
shift.len(),
scale.len(),
power.len()
);
let weight_count = match mode {
ScaleMode::kUNIFORM => 1i64,
ScaleMode::kCHANNEL => {
let input_dims = input.dimensions(self)?;
if input_dims.len() >= 4 {
input_dims[1]
} else if !input_dims.is_empty() {
input_dims[0]
} else {
1i64
}
}
ScaleMode::kELEMENTWISE => {
let input_dims = input.dimensions(self)?;
input_dims.iter().product::<i64>()
}
};
let shift_w = trtx_sys::nvinfer1::Weights::new_float(
shift.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let scale_w = trtx_sys::nvinfer1::Weights::new_float(
scale.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let power_w = trtx_sys::nvinfer1::Weights::new_float(
power.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let layer_ptr =
self.inner
.pin_mut()
.addScale(input.pin_mut(), mode.into(), shift_w, scale_w, power_w);
ScaleLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_reduce(
&mut self,
input: &'_ Tensor,
op: trtx_sys::ReduceOperation,
axes: crate::Axes,
keep_dims: bool,
) -> Result<ReduceLayer<'network>> {
check_network!(self, input);
debug!(
"add_reduce input={} op={op:?} axes={axes:?} keep_dims={keep_dims}",
tensor_dbg(self, input)
);
let axes_bits = axes.to_bits();
let layer_ptr =
self.inner
.pin_mut()
.addReduce(input.pin_mut(), op.into(), axes_bits, keep_dims);
ReduceLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_cumulative(
&mut self,
input: &'_ Tensor,
axis: i32,
op: trtx_sys::CumulativeOperation,
exclusive: bool,
reverse: bool,
) -> Result<CumulativeLayer<'network>> {
check_network!(self, input);
debug!(
"add_cumulative input={} axis={axis} op={op:?} exclusive={exclusive} reverse={reverse}",
tensor_dbg(self, input)
);
let axis_bytes = axis.to_le_bytes();
let axis_constant =
self.add_small_constant_copied(&[], &axis_bytes, trtx_sys::DataType::kINT32)?;
let axis_tensor = axis_constant.output(self, 0)?;
self.add_cumulative_with_axis_tensor(input, &axis_tensor, op, exclusive, reverse)
}
pub fn add_cumulative_with_axis_tensor(
&mut self,
input: &'_ Tensor,
axis_tensor: &'_ Tensor,
op: trtx_sys::CumulativeOperation,
exclusive: bool,
reverse: bool,
) -> Result<CumulativeLayer<'network>> {
check_network!(self, input);
check_network!(self, axis_tensor);
debug!(
"add_cumulative_with_axis_tensor input={} axis_tensor={} op={op:?} exclusive={exclusive} reverse={reverse}",
tensor_dbg(self, input),
tensor_dbg(self, axis_tensor)
);
let layer_ptr = self.inner.pin_mut().addCumulative(
input.pin_mut(),
axis_tensor.pin_mut(),
op.into(),
exclusive,
reverse,
);
CumulativeLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_slice(
&mut self,
input: &'_ Tensor,
start: &[i64],
size: &[i64],
stride: &[i64],
) -> Result<SliceLayer<'network>> {
check_network!(self, input);
debug!(
"add_slice input={} start={start:?} size={size:?} stride={stride:?}",
tensor_dbg(self, input)
);
if start.len() != size.len() || start.len() != stride.len() {
return Err(Error::Runtime(
"start, size, and stride must have the same length".to_string(),
));
}
let start_dims = trtx_sys::Dims::from_slice(start);
let size_dims = trtx_sys::Dims::from_slice(size);
let stride_dims = trtx_sys::Dims::from_slice(stride);
let layer_ptr =
self.inner
.pin_mut()
.addSlice(input.pin_mut(), &start_dims, &size_dims, &stride_dims);
SliceLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_topk(
&mut self,
input: &'_ Tensor,
op: TopKOperation,
k: i32,
axes: crate::Axes,
) -> Result<TopKLayer<'network>> {
check_network!(self, input);
debug!(
"add_topk input={} op={op:?} k={k} axes={axes:?}",
tensor_dbg(self, input)
);
let axes_bits = axes.to_bits();
let layer_ptr = self
.inner
.pin_mut()
.addTopK(input.pin_mut(), op.into(), k, axes_bits);
TopKLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_resize(&mut self, input: &'_ Tensor) -> Result<ResizeLayer<'network>> {
check_network!(self, input);
debug!("add_resize input={}", tensor_dbg(self, input));
let layer_ptr = self.inner.pin_mut().addResize(input.pin_mut());
ResizeLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_gather(
&'_ mut self,
data: &'_ Tensor,
indices: &'_ Tensor,
axis: i32,
) -> Result<GatherLayer<'network>> {
check_network!(self, data);
check_network!(self, indices);
debug!(
"add_gather data={} indices={} axis={axis}",
tensor_dbg(self, data),
tensor_dbg(self, indices)
);
let layer_ptr = self
.inner
.pin_mut()
.addGather(data.pin_mut(), indices.pin_mut(), axis);
GatherLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_scatter(
&mut self,
data: &'_ Tensor,
indices: &'_ Tensor,
updates: &'_ Tensor,
mode: trtx_sys::ScatterMode,
) -> Result<ScatterLayer<'network>> {
check_network!(self, data);
check_network!(self, indices);
check_network!(self, updates);
debug!(
"add_scatter data={} indices={} updates={} mode={mode:?}",
tensor_dbg(self, data),
tensor_dbg(self, indices),
tensor_dbg(self, updates)
);
let layer_ptr = self.inner.pin_mut().addScatter(
data.pin_mut(),
indices.pin_mut(),
updates.pin_mut(),
mode.into(),
);
ScatterLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_quantize(
&'_ mut self,
input: &'_ Tensor,
scale: &'_ Tensor,
output_type: trtx_sys::DataType,
) -> Result<QuantizeLayer<'network>> {
check_network!(self, input);
check_network!(self, scale);
debug!(
"add_quantize input={} scale={} output_type={output_type:?}",
tensor_dbg(self, input),
tensor_dbg(self, scale)
);
#[cfg(not(feature = "enterprise"))]
let layer_ptr =
self.inner
.pin_mut()
.addQuantize(input.pin_mut(), scale.pin_mut(), output_type.into());
#[cfg(feature = "enterprise")]
let layer_ptr =
self.inner
.pin_mut()
.addQuantize1(input.pin_mut(), scale.pin_mut(), output_type.into());
QuantizeLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_dequantize(
&mut self,
input: &'_ Tensor,
scale: &'_ Tensor,
output_type: trtx_sys::DataType,
) -> Result<DequantizeLayer<'network>> {
check_network!(self, input);
check_network!(self, scale);
debug!(
"add_dequantize input={} scale={} output_type={output_type:?}",
tensor_dbg(self, input),
tensor_dbg(self, scale)
);
#[cfg(not(feature = "enterprise"))]
let layer_ptr = self.inner.pin_mut().addDequantize(
input.pin_mut(),
scale.pin_mut(),
output_type.into(),
);
#[cfg(feature = "enterprise")]
let layer_ptr = self.inner.pin_mut().addDequantize1(
input.pin_mut(),
scale.pin_mut(),
output_type.into(),
);
DequantizeLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_select(
&mut self,
condition: &'_ Tensor,
then_input: &'_ Tensor,
else_input: &'_ Tensor,
) -> Result<SelectLayer<'network>> {
check_network!(self, condition);
check_network!(self, then_input);
check_network!(self, else_input);
debug!(
"add_select condition={} then_input={} else_input={}",
tensor_dbg(self, condition),
tensor_dbg(self, then_input),
tensor_dbg(self, else_input)
);
let layer_ptr = self.inner.pin_mut().addSelect(
condition.pin_mut(),
then_input.pin_mut(),
else_input.pin_mut(),
);
SelectLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_padding(
&mut self,
input: &'_ Tensor,
pre_padding: &[i64],
post_padding: &[i64],
) -> Result<PaddingLayer<'network>> {
check_network!(self, input);
debug!(
"add_padding input={} pre_padding={pre_padding:?} post_padding={post_padding:?}",
tensor_dbg(self, input)
);
if pre_padding.len() != post_padding.len() {
return Err(Error::Runtime(
"pre_padding and post_padding must have the same length".to_string(),
));
}
let pre_dims = trtx_sys::Dims::from_slice(pre_padding);
let post_dims = trtx_sys::Dims::from_slice(post_padding);
let layer_ptr = self
.inner
.pin_mut()
.addPaddingNd(input.pin_mut(), &pre_dims, &post_dims);
PaddingLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_assertion(&mut self, condition: &'_ Tensor, message: &str) -> Result<()> {
check_network!(self, condition);
debug!(
"add_assertion condition={} message={message:?}",
tensor_dbg(self, condition)
);
let message_cstr = std::ffi::CString::new(message)?;
let layer_ptr = unsafe {
self.inner
.pin_mut()
.addAssertion(condition.pin_mut(), message_cstr.as_ptr())
};
let _ = AssertionLayer::new(self.inner.as_ptr(), layer_ptr)?;
Ok(())
}
pub fn add_loop(&mut self) -> Result<Loop<'network>> {
debug!("add_loop");
let loop_ptr = self.inner.pin_mut().addLoop();
let loop_ptr = unsafe { loop_ptr.as_mut() }
.ok_or_else(|| Error::Runtime("Failed to add loop".to_string()))?;
Ok(Loop {
inner: unsafe { Pin::new_unchecked(loop_ptr) },
network: self.inner.as_ptr(),
})
}
pub fn add_if_conditional(&mut self) -> Result<IfConditional<'network>> {
debug!("add_if_conditional");
let if_ptr = self.inner.pin_mut().addIfConditional();
let if_ptr = unsafe { if_ptr.as_mut() }
.ok_or_else(|| Error::Runtime("Failed to add if conditional".to_string()))?;
Ok(IfConditional {
inner: unsafe { Pin::new_unchecked(if_ptr) },
network: self.inner.as_ptr(),
})
}
pub fn add_attention(
&mut self,
query: &'_ Tensor,
key: &'_ Tensor,
value: &'_ Tensor,
norm_op: trtx_sys::AttentionNormalizationOp,
causal: bool,
) -> Result<Attention<'network>> {
check_network!(self, query);
check_network!(self, key);
check_network!(self, value);
debug!(
"add_attention query={} key={} value={} norm_op={norm_op:?} causal={causal}",
tensor_dbg(self, query),
tensor_dbg(self, key),
tensor_dbg(self, value)
);
let attn_ptr = self.inner.pin_mut().addAttention(
query.pin_mut(),
key.pin_mut(),
value.pin_mut(),
norm_op.into(),
causal,
);
let attn = unsafe { attn_ptr.as_mut() }
.ok_or_else(|| Error::Runtime("Failed to add attention".to_string()))?;
Ok(Attention {
inner: unsafe { Pin::new_unchecked(attn) },
network: self.inner.as_ptr(),
})
}
#[cfg(feature = "v_1_4")]
pub fn add_moe(
&mut self,
hidden_states: &Tensor,
selected_experts_for_tokens: &Tensor,
scores_for_selected_experts: &Tensor,
) -> Result<MoELayer<'network>> {
check_network!(self, hidden_states);
check_network!(self, selected_experts_for_tokens);
check_network!(self, scores_for_selected_experts);
debug!(
"add_moe hidden_states={} selected_experts_for_tokens={} scores_for_selected_experts={}",
tensor_dbg(self, hidden_states),
tensor_dbg(self, selected_experts_for_tokens),
tensor_dbg(self, scores_for_selected_experts)
);
let layer_ptr = self.inner.pin_mut().addMoE(
hidden_states.pin_mut(),
selected_experts_for_tokens.pin_mut(),
scores_for_selected_experts.pin_mut(),
);
MoELayer::new(self.inner.as_ptr(), layer_ptr)
}
#[cfg(feature = "v_1_4")]
pub fn add_dist_collective(
&mut self,
input: &Tensor,
dist_collective_op: CollectiveOperation,
reduce_op: ReduceOperation,
root: i64,
groups: &[i64],
) -> Result<DistCollectiveLayer<'network>> {
check_network!(self, input);
debug!(
"add_dist_collective input={} dist_collective_op={dist_collective_op:?} reduce_op={reduce_op:?} root={root} groups={groups:?}",
tensor_dbg(self, input)
);
let (groups_ptr, group_size) = if groups.is_empty() {
(std::ptr::null_mut(), 0i64)
} else {
(groups.as_ptr() as *mut i64, groups.len() as i64)
};
let layer_ptr = unsafe {
self.inner.pin_mut().addDistCollective(
input.pin_mut(),
dist_collective_op.into(),
reduce_op.into(),
root,
groups_ptr,
group_size,
)
};
DistCollectiveLayer::new(self.inner.as_ptr(), layer_ptr)
}
}
impl<'network> Attention<'network> {
pub fn set_normalization_operation(
&mut self,
network: &mut NetworkDefinition,
op: trtx_sys::AttentionNormalizationOp,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setNormalizationOperation(op.into())
.ok_or_err(PropertySetAttempt::AttentionLayerNormalizationOp)
}
pub fn normalization_operation(
&self,
network: &NetworkDefinition,
) -> trtx_sys::AttentionNormalizationOp {
check_network!(network, self);
self.inner.getNormalizationOperation().into()
}
pub fn set_mask(&mut self, network: &mut NetworkDefinition, mask: &Tensor) -> Result<()> {
check_network!(network, self);
check_network!(network, mask);
self.inner
.as_mut()
.setMask(mask.pin_mut())
.ok_or_err(PropertySetAttempt::AttentionLayerMask)
}
pub fn mask(&mut self, network: &mut NetworkDefinition) -> Option<Tensor<'network>> {
check_network!(network, self);
let p = self.inner.as_mut().getMask();
if p.is_null() {
None
} else {
unsafe { Tensor::new(self.network, p).ok() }
}
}
pub fn set_causal(&mut self, network: &mut NetworkDefinition, is_causal: bool) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setCausal(is_causal)
.ok_or_err(PropertySetAttempt::AttentionLayerCausal)
}
pub fn causal(&self, network: &NetworkDefinition) -> bool {
check_network!(network, self);
self.inner.getCausal()
}
pub fn set_decomposable(
&mut self,
network: &mut NetworkDefinition,
decomposable: bool,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setDecomposable(decomposable)
.ok_or_err(PropertySetAttempt::AttentionLayerDecomposable)
}
pub fn decomposable(&self, network: &NetworkDefinition) -> bool {
check_network!(network, self);
self.inner.getDecomposable()
}
pub fn set_input(
&mut self,
network: &mut NetworkDefinition,
index: i32,
input: &Tensor,
) -> Result<()> {
check_network!(network, self);
check_network!(network, input);
self.inner
.as_mut()
.setInput(index, input.pin_mut())
.ok_or_err(PropertySetAttempt::AttentionLayerInput)
}
pub fn num_inputs(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getNbInputs()
}
pub fn input(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
check_network!(network, self);
let tensor_ptr = self.inner.getInput(index);
unsafe { Tensor::new(self.network, tensor_ptr) }
}
#[deprecated = "use input instead"]
pub fn get_input(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
self.input(network, index)
}
pub fn num_outputs(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getNbOutputs()
}
pub fn output(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
check_network!(network, self);
let tensor_ptr = self.inner.getOutput(index);
unsafe { Tensor::new(self.network, tensor_ptr) }
}
#[deprecated = "use output instead"]
pub fn get_output(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
self.output(network, index)
}
pub fn set_name(&mut self, network: &mut NetworkDefinition, name: &str) -> Result<()> {
check_network!(network, self);
let name = CString::new(name)?;
unsafe { self.inner.as_mut().setName(name.as_ptr()) }
.ok_or_err(PropertySetAttempt::AttentionLayerName)
}
pub fn name(&self, network: &NetworkDefinition) -> String {
check_network!(network, self);
let name = self.inner.getName();
if name.is_null() {
"(unamed)".to_string()
} else {
unsafe { CStr::from_ptr(name).to_string_lossy().to_string() }
}
}
pub fn set_normalization_quantize_scale(
&mut self,
network: &mut NetworkDefinition,
tensor: &Tensor,
) -> Result<()> {
check_network!(network, self);
check_network!(network, tensor);
self.inner
.as_mut()
.setNormalizationQuantizeScale(tensor.pin_mut())
.ok_or_err(PropertySetAttempt::AttentionLayerQuantizeScale)
}
pub fn normalization_quantize_scale(
&self,
network: &NetworkDefinition,
) -> Option<Tensor<'network>> {
check_network!(network, self);
let p = self.inner.getNormalizationQuantizeScale();
if p.is_null() {
None
} else {
unsafe { Tensor::new(self.network, p).ok() }
}
}
pub fn set_normalization_quantize_to_type(
&mut self,
network: &mut NetworkDefinition,
type_: DataType,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setNormalizationQuantizeToType(type_.into())
.ok_or_err(PropertySetAttempt::AttentionLayerQuantizeToType)
}
pub fn normalization_quantize_to_type(&self, network: &NetworkDefinition) -> DataType {
check_network!(network, self);
self.inner.getNormalizationQuantizeToType().into()
}
#[deprecated = "use normalization_quantize_to_type instead"]
pub fn get_normalization_quantize_to_type(&self, network: &NetworkDefinition) -> DataType {
self.normalization_quantize_to_type(network)
}
pub fn set_metadata(&mut self, network: &mut NetworkDefinition, metadata: &str) -> Result<()> {
check_network!(network, self);
let metadata_cstr = CString::new(metadata)?;
unsafe { self.inner.as_mut().setMetadata(metadata_cstr.as_ptr()) }
.ok_or_err(PropertySetAttempt::AttentionLayerMetadata)
}
pub fn metadata(&self, network: &NetworkDefinition) -> String {
check_network!(network, self);
let p = self.inner.getMetadata();
if p.is_null() {
String::new()
} else {
unsafe { CStr::from_ptr(p).to_string_lossy().to_string() }
}
}
#[cfg(feature = "v_1_4")]
pub fn set_nb_ranks(&mut self, network: &mut NetworkDefinition, nb_ranks: i32) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setNbRanks(nb_ranks)
.ok_or_err(PropertySetAttempt::AttentionLayerNumRanks)
}
#[cfg(feature = "v_1_4")]
pub fn nb_ranks(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getNbRanks()
}
#[deprecated = "use nb_ranks instead"]
pub fn get_nb_ranks(&self, network: &NetworkDefinition) -> i32 {
self.nb_ranks(network)
}
}
/// Builders for the boundary layers of a loop created by `NetworkDefinition::add_loop`.
///
/// # Panics
///
/// Each method panics (via `check_network!`) if `network` or the tensor argument
/// does not belong to the same network as this loop.
impl<'network> Loop<'network> {
    /// Adds a recurrence layer seeded with `initial_value`.
    pub fn add_recurrence(
        &mut self,
        network: &mut NetworkDefinition,
        initial_value: &'_ Tensor,
    ) -> Result<RecurrenceLayer<'network>> {
        check_network!(network, self);
        check_network!(network, initial_value);
        debug!(
            "Loop::add_recurrence initial_value={}",
            tensor_dbg(network, initial_value)
        );
        let recurrence = self.inner.as_mut().addRecurrence(initial_value.pin_mut());
        let owner = network.inner.as_ptr();
        RecurrenceLayer::new(owner, recurrence)
    }

    /// Adds a trip-limit layer driven by `tensor` with the given `limit` kind.
    pub fn add_trip_limit(
        &mut self,
        network: &mut NetworkDefinition,
        tensor: &'_ Tensor,
        limit: trtx_sys::TripLimit,
    ) -> Result<TripLimitLayer<'network>> {
        check_network!(network, self);
        check_network!(network, tensor);
        debug!(
            "Loop::add_trip_limit tensor={} limit={limit:?}",
            tensor_dbg(network, tensor)
        );
        let trip_limit = self.inner.as_mut().addTripLimit(tensor.pin_mut(), limit.into());
        let owner = network.inner.as_ptr();
        TripLimitLayer::new(owner, trip_limit)
    }

    /// Adds an iterator layer over `tensor` along `axis`, optionally in reverse.
    pub fn add_iterator(
        &mut self,
        network: &mut NetworkDefinition,
        tensor: &'_ Tensor,
        axis: i32,
        reverse: bool,
    ) -> Result<IteratorLayer<'network>> {
        check_network!(network, self);
        check_network!(network, tensor);
        debug!(
            "Loop::add_iterator tensor={} axis={axis} reverse={reverse}",
            tensor_dbg(network, tensor)
        );
        let iterator = self.inner.as_mut().addIterator(tensor.pin_mut(), axis, reverse);
        let owner = network.inner.as_ptr();
        IteratorLayer::new(owner, iterator)
    }

    /// Adds a loop-output layer of kind `output_kind` for `tensor` along `axis`.
    pub fn add_loop_output(
        &mut self,
        network: &mut NetworkDefinition,
        tensor: &'_ Tensor,
        output_kind: trtx_sys::LoopOutput,
        axis: i32,
    ) -> Result<LoopOutputLayer<'network>> {
        check_network!(network, self);
        check_network!(network, tensor);
        debug!(
            "Loop::add_loop_output tensor={} output_kind={output_kind:?} axis={axis}",
            tensor_dbg(network, tensor)
        );
        let loop_output = self
            .inner
            .as_mut()
            .addLoopOutput(tensor.pin_mut(), output_kind.into(), axis);
        let owner = network.inner.as_ptr();
        LoopOutputLayer::new(owner, loop_output)
    }
}
impl<'network> IfConditional<'network> {
    /// Sets `condition` as the condition of this if-conditional (forwards to `setCondition`)
    /// and returns the resulting condition layer.
    ///
    /// # Errors
    /// Propagates any error from `ConditionLayer::new` when wrapping the returned layer.
    ///
    /// # Panics
    /// Panics if the conditional or `condition` was created from a different network than
    /// `network`.
    pub fn set_condition(
        &mut self,
        network: &mut NetworkDefinition,
        condition: &'_ Tensor,
    ) -> Result<ConditionLayer<'network>> {
        check_network!(network, self);
        check_network!(network, condition);
        // Trace like every other layer-building method in this file.
        debug!(
            "IfConditional::set_condition condition={}",
            tensor_dbg(network, condition)
        );
        let layer_ptr = self.inner.as_mut().setCondition(condition.pin_mut());
        ConditionLayer::new(network.inner.as_ptr(), layer_ptr)
    }

    /// Adds `input` as an input of this if-conditional (forwards to `addInput`).
    ///
    /// # Errors
    /// Propagates any error from `IfConditionalInputLayer::new` when wrapping the returned
    /// layer.
    ///
    /// # Panics
    /// Panics if the conditional or `input` was created from a different network than
    /// `network`.
    pub fn add_input(
        &mut self,
        network: &mut NetworkDefinition,
        input: &'_ Tensor,
    ) -> Result<IfConditionalInputLayer<'network>> {
        check_network!(network, self);
        check_network!(network, input);
        debug!(
            "IfConditional::add_input input={}",
            tensor_dbg(network, input)
        );
        let layer_ptr = self.inner.as_mut().addInput(input.pin_mut());
        IfConditionalInputLayer::new(network.inner.as_ptr(), layer_ptr)
    }

    /// Adds an output to this if-conditional that combines `true_output` (true branch) and
    /// `false_output` (false branch). Forwards to `addOutput`.
    ///
    /// # Errors
    /// Propagates any error from `IfConditionalOutputLayer::new` when wrapping the returned
    /// layer.
    ///
    /// # Panics
    /// Panics if the conditional or either tensor was created from a different network than
    /// `network`.
    pub fn add_output(
        &mut self,
        network: &mut NetworkDefinition,
        true_output: &'_ Tensor,
        false_output: &'_ Tensor,
    ) -> Result<IfConditionalOutputLayer<'network>> {
        check_network!(network, self);
        check_network!(network, true_output);
        check_network!(network, false_output);
        debug!(
            "IfConditional::add_output true_output={} false_output={}",
            tensor_dbg(network, true_output),
            tensor_dbg(network, false_output)
        );
        let layer_ptr = self
            .inner
            .as_mut()
            .addOutput(true_output.pin_mut(), false_output.pin_mut());
        IfConditionalOutputLayer::new(network.inner.as_ptr(), layer_ptr)
    }
}
// No RecurrenceLayer-specific methods yet; the generic `Layer` API covers it.
impl RecurrenceLayer<'_> {}
impl IteratorLayer<'_> {
    /// Sets the iteration axis (forwards to `setAxis`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }

    /// Sets the iteration direction flag (forwards to `setReverse`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn set_reverse(&mut self, network: &mut NetworkDefinition, reverse: bool) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setReverse(reverse);
    }
}
impl LoopOutputLayer<'_> {
    /// Returns the kind of loop output this layer produces (forwards to `getLoopOutput`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn loop_output(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::LoopOutput {
        check_network!(network, self);
        self.inner.getLoopOutput()
    }

    /// Deprecated alias kept for backward compatibility; forwards to `loop_output`.
    #[deprecated = "use loop_output instead"]
    pub fn get_loop_output(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::LoopOutput {
        Self::loop_output(self, network)
    }

    /// Sets the axis used by this loop output (forwards to `setAxis`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }
}
impl TripLimitLayer<'_> {
    /// Returns the kind of trip limit this layer imposes (forwards to `getTripLimit`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn trip_limit(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::TripLimit {
        check_network!(network, self);
        self.inner.getTripLimit()
    }

    /// Deprecated alias kept for backward compatibility; forwards to `trip_limit`.
    #[deprecated = "use trip_limit instead"]
    pub fn get_trip_limit(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::TripLimit {
        Self::trip_limit(self, network)
    }
}
impl<'builder> NetworkDefinition<'builder> {
    /// Adds a normalization layer over `input`, using the `scale` and `bias` tensors
    /// (forwards to `addNormalization`).
    ///
    /// `axes_mask` is converted with `Axes::to_bits` and passed as the axes bit mask.
    ///
    /// # Errors
    /// Propagates any error from `NormalizationLayer::new` when wrapping the returned layer.
    ///
    /// # Panics
    /// Panics if any of the tensors was created from a different network.
    pub fn add_normalization(
        &mut self,
        input: &'_ Tensor,
        scale: &'_ Tensor,
        bias: &'_ Tensor,
        axes_mask: crate::Axes,
    ) -> Result<NormalizationLayer<'builder>> {
        check_network!(self, input);
        check_network!(self, scale);
        check_network!(self, bias);
        debug!(
            "add_normalization input={} scale={} bias={} axes_mask={axes_mask:?}",
            tensor_dbg(self, input),
            tensor_dbg(self, scale),
            tensor_dbg(self, bias)
        );
        let axes_bits = axes_mask.to_bits();
        let ptr = self.inner.pin_mut().addNormalization(
            input.pin_mut(),
            scale.pin_mut(),
            bias.pin_mut(),
            axes_bits,
        );
        NormalizationLayer::new(self.inner.as_ptr(), ptr)
    }

    /// Same as [`Self::add_normalization`], but forwards to TensorRT's `addNormalizationV2`.
    ///
    /// # Errors
    /// Propagates any error from `NormalizationLayer::new` when wrapping the returned layer.
    ///
    /// # Panics
    /// Panics if any of the tensors was created from a different network.
    pub fn add_normalization_v2(
        &mut self,
        input: &'_ Tensor,
        scale: &'_ Tensor,
        bias: &'_ Tensor,
        axes_mask: crate::Axes,
    ) -> Result<NormalizationLayer<'builder>> {
        check_network!(self, input);
        check_network!(self, scale);
        check_network!(self, bias);
        debug!(
            "add_normalization_v2 input={} scale={} bias={} axes_mask={axes_mask:?}",
            tensor_dbg(self, input),
            tensor_dbg(self, scale),
            tensor_dbg(self, bias)
        );
        let axes_bits = axes_mask.to_bits();
        let ptr = self.inner.pin_mut().addNormalizationV2(
            input.pin_mut(),
            scale.pin_mut(),
            bias.pin_mut(),
            axes_bits,
        );
        NormalizationLayer::new(self.inner.as_ptr(), ptr)
    }

    /// Installs `error_recorder` as this network's TensorRT error recorder.
    ///
    /// The wrapping `ErrorRecorder` is stored in `self.error_recorder`, so it is owned by this
    /// network object. The value from `as_trt_error_recorder` is what gets handed to
    /// TensorRT. With the `mock` feature, the recorder is stored but not passed to TensorRT.
    ///
    /// # Errors
    /// Returns the error from `ErrorRecorder::new` if the recorder cannot be wrapped.
    ///
    /// # Panics
    /// Panics if an error recorder has already been set on this network.
    pub fn set_error_recorder(&mut self, error_recorder: Box<dyn RecordError>) -> Result<()> {
        let error_recorder = ErrorRecorder::new(error_recorder)?;
        if self.error_recorder.is_some() {
            panic!("Setting an error recorder more than once not supported at the moment");
        }
        // Store the wrapper before deriving the TensorRT handle, so the handle refers to the
        // instance owned by `self` rather than to a local that is about to be moved.
        let rec = self
            .error_recorder
            .insert(error_recorder)
            .as_trt_error_recorder();
        // SAFETY: `rec` comes from the recorder now owned by `self.error_recorder`. That
        // recorder is never replaced (a second call panics above), so it presumably outlives
        // TensorRT's use of it; TODO confirm the pointer stays stable across moves of `self`.
        #[cfg(not(feature = "mock"))]
        unsafe {
            self.inner.pin_mut().setErrorRecorder(rec)
        };
        // Nothing consumes the handle in mock builds; mark it used to avoid an unused warning.
        #[cfg(feature = "mock")]
        let _ = rec;
        Ok(())
    }

    /// Adds a grid-sample layer that samples `input` at the locations given by `grid`
    /// (forwards to `addGridSample`).
    ///
    /// # Errors
    /// Propagates any error from `GridSampleLayer::new` when wrapping the returned layer.
    ///
    /// # Panics
    /// Panics if `input` or `grid` was created from a different network.
    pub fn add_grid_sample(
        &mut self,
        input: &'_ Tensor,
        grid: &'_ Tensor,
    ) -> Result<GridSampleLayer<'builder>> {
        check_network!(self, input);
        check_network!(self, grid);
        // Trace like the other `add_*` builders in this impl.
        debug!(
            "add_grid_sample input={} grid={}",
            tensor_dbg(self, input),
            tensor_dbg(self, grid)
        );
        let ptr = self
            .inner
            .pin_mut()
            .addGridSample(input.pin_mut(), grid.pin_mut());
        GridSampleLayer::new(self.inner.as_ptr(), ptr)
    }
}
impl GridSampleLayer<'_> {
    /// Sets the interpolation mode (forwards to `setInterpolationMode`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn set_interpolation_mode(
        &mut self,
        network: &mut NetworkDefinition,
        mode: InterpolationMode,
    ) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setInterpolationMode(mode.into());
    }

    /// Returns the interpolation mode (forwards to `getInterpolationMode`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn interpolation_mode(&self, network: &NetworkDefinition) -> InterpolationMode {
        check_network!(network, self);
        let raw_mode = self.inner.getInterpolationMode();
        raw_mode.into()
    }

    /// Sets the sample mode (forwards to `setSampleMode`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn set_sample_mode(&mut self, network: &mut NetworkDefinition, mode: SampleMode) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setSampleMode(mode.into());
    }

    /// Returns the sample mode (forwards to `getSampleMode`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn sample_mode(&self, network: &NetworkDefinition) -> SampleMode {
        check_network!(network, self);
        let raw_mode = self.inner.getSampleMode();
        raw_mode.into()
    }

    /// Sets the align-corners flag (forwards to `setAlignCorners`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn set_align_corners(&mut self, network: &mut NetworkDefinition, align_corners: bool) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAlignCorners(align_corners);
    }

    /// Returns the align-corners flag (forwards to `getAlignCorners`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn align_corners(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.inner.as_ref().getAlignCorners()
    }
}
impl CastLayer<'_> {
    /// Sets the data type this layer casts its input to (forwards to `setToType`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn set_to_type(&mut self, network: &mut NetworkDefinition<'_>, data_type: DataType) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setToType(data_type.into());
    }

    /// Returns the data type this layer casts its input to (forwards to `getToType`).
    ///
    /// # Panics
    /// Panics if `network` is not the network this layer was created from.
    pub fn to_type(&self, network: &NetworkDefinition<'_>) -> DataType {
        check_network!(network, self);
        let raw_type = self.inner.getToType();
        raw_type.into()
    }
}
// Tests for looking up layers by index and for tensor/layer naming.
#[cfg(test)]
mod test {
use trtx_sys::LayerType;
use crate::{Builder, Logger};
// Builds a chain of three ReLU activations (input -> a -> b -> c) and checks:
// - layer lookup by index;
// - tensor naming;
// - that renaming a layer leaves its output tensor's name untouched.
// Needs a real TensorRT runtime, so it is skipped under the `mock` feature.
#[test]
#[cfg(not(feature = "mock"))]
fn test_get_layer() {
let logger = Logger::stderr().unwrap();
let mut builder = Builder::new(&logger).unwrap();
let mut network = builder.create_network(0).unwrap();
let input = network
.add_input("a", trtx_sys::DataType::kFLOAT, &[1])
.unwrap();
// `a`, `b`, `c` are the output tensors (not the layers) of layers 0, 1, 2.
let a = network
.add_activation(&input, trtx_sys::ActivationType::kRELU)
.unwrap()
.output(&network, 0)
.unwrap();
let b = network
.add_activation(&a, trtx_sys::ActivationType::kRELU)
.unwrap()
.output(&network, 0)
.unwrap();
let c = network
.add_activation(&b, trtx_sys::ActivationType::kRELU)
.unwrap()
.output(&network, 0)
.unwrap();
// Name the output tensors.
a.set_name(&mut network, "Fritz").unwrap();
b.set_name(&mut network, "Adam").unwrap();
c.set_name(&mut network, "James").unwrap();
// Layers are indexed in insertion order, so layer i's output is the i-th named tensor.
assert_eq!(
&network
.layer(0)
.unwrap()
.output(&network, 0)
.unwrap()
.name(&network)
.unwrap(),
"Fritz"
);
assert_eq!(
&network
.layer(1)
.unwrap()
.output(&network, 0)
.unwrap()
.name(&network)
.unwrap(),
"Adam"
);
assert_eq!(
&network
.layer(2)
.unwrap()
.output(&network, 0)
.unwrap()
.name(&network)
.unwrap(),
"James"
);
assert_eq!(
network.layer(2).unwrap().layer_type_dynamic(),
LayerType::kACTIVATION
);
// Rename the *layer* (not its output tensor).
network
.layer(1)
.unwrap()
.set_name(&mut network, "Eva")
.unwrap();
// Layer 1's output is the same tensor as layer 2's input.
assert_eq!(
&network
.layer(1)
.unwrap()
.output(&network, 0)
.unwrap()
.name(&network)
.unwrap(),
&network
.layer(2)
.unwrap()
.input(&network, 0)
.unwrap()
.name(&network)
.unwrap(),
);
// The tensor keeps its own name after the layer rename...
assert_eq!(
"Adam",
&network
.layer(2)
.unwrap()
.input(&network, 0)
.unwrap()
.name(&network)
.unwrap()
);
// ...while the layer reports its new name.
assert_eq!(&network.layer(1).unwrap().name(&network), "Eva");
}
}