use crate::interfaces::RecordError;
pub use crate::tensor::Tensor;
use cxx::UniquePtr;
use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::pin::Pin;
use trtx_sys::nvinfer1::{IConcatenationLayer, IEinsumLayer, INetworkDefinition, ITensor};
use trtx_sys::{nvinfer1, LayerType, SampleMode, Weights};
use trtx_sys::{AsLayer, AsLayerTyped};
#[cfg(feature = "v_1_5")]
use trtx_sys::{AttentionIOForm, CausalMaskKind};
#[cfg(feature = "v_1_4")]
use trtx_sys::{CollectiveOperation, MoEActType, ReduceOperation};
use trtx_sys::{
DataType, Dims64, FillOperation, GatherMode, MatrixOperation, ScaleMode, TopKOperation,
};
use trtx_sys::{InterpolationMode, KVCacheMode};
// Panics unless the layer/tensor given as the second argument belongs to the
// network given as the first argument.
//
// `$network` must name a `NetworkDefinition` (or a reference to one); the second
// argument must have a `network` field holding the raw `INetworkDefinition`
// pointer it was created from. Only the two raw pointers are compared, so this is
// a pure identity check.
//
// Both arms expand to identical code: the first matches a bare identifier such as
// `self`, the second any other expression.
macro_rules! check_network {
    ($network:ident, $this:ident) => {
        if $network.inner.as_ptr() != $this.network {
            panic!("Layer or tensor was created from different network")
        }
    };
    ($network:ident, $tensor:expr) => {
        if $network.inner.as_ptr() != $tensor.network {
            panic!("Layer or tensor was created from different network")
        }
    };
}
// Re-exported so other modules of the crate can invoke the macro by path.
pub(crate) use check_network;
use crate::error::{Error, OkOrElseError, OkOrFailedSettingProperty, PropertySetAttempt, Result};
use crate::interfaces::ErrorRecorder;
use log::{debug, trace};
/// Borrowed kernel weights plus optional bias weights, as raw bytes.
///
/// Each byte slice is interpreted according to its matching `*_dtype` field.
/// NOTE(review): presumably consumed when adding convolution-style layers; the
/// consumer is not visible in this part of the file.
#[derive(Debug)]
pub struct ConvWeights<'weights> {
    /// Raw bytes of the kernel weights.
    pub kernel_weights: &'weights [u8],
    /// Element type of `kernel_weights`.
    pub kernel_dtype: crate::DataType,
    /// Optional name attached to the kernel weights.
    pub kernel_name: Option<&'weights str>,
    /// Raw bytes of the bias weights, if a bias is present.
    pub bias_weights: Option<&'weights [u8]>,
    /// Element type of `bias_weights`, if a bias is present.
    pub bias_dtype: Option<crate::DataType>,
    /// Optional name attached to the bias weights.
    pub bias_name: Option<&'weights str>,
}
/// Owned weight buffer together with its shape, element type and optional name.
#[derive(Debug)]
pub struct OwnedWeights {
    /// Shape of the weights (one entry per dimension).
    pub shape: Vec<i64>,
    /// Element type of the bytes stored in `values`.
    pub data_type: DataType,
    /// Raw weight bytes.
    pub values: Vec<u8>,
    /// Optional name attached to these weights.
    pub name: Option<String>,
}
/// Owned counterpart of [`ConvWeights`]: a kernel plus an optional bias.
///
/// Use [`OwnedConvWeights::as_weights`] to obtain a borrowed view.
#[derive(Debug)]
pub struct OwnedConvWeights {
    /// Kernel weights.
    pub kernel: OwnedWeights,
    /// Bias weights, if any.
    pub bias: Option<OwnedWeights>,
}
impl OwnedConvWeights {
    /// Borrows these owned weights as a [`ConvWeights`] view.
    ///
    /// Nothing is copied: every slice and string in the result points into `self`.
    pub fn as_weights(&self) -> ConvWeights<'_> {
        let kernel = &self.kernel;
        let bias = self.bias.as_ref();
        ConvWeights {
            kernel_weights: kernel.values.as_slice(),
            kernel_dtype: kernel.data_type,
            kernel_name: kernel.name.as_deref(),
            bias_weights: bias.map(|weights| weights.values.as_slice()),
            bias_dtype: bias.map(|weights| weights.data_type),
            bias_name: bias.and_then(|weights| weights.name.as_deref()),
        }
    }
}
/// Wrapper around a TensorRT layer created from a network definition.
///
/// `Inner` is either a concrete layer interface (e.g. `nvinfer1::IShuffleLayer`) or
/// the type-erased `nvinfer1::ILayer`. Accessors take the originating
/// `NetworkDefinition` and panic (via `check_network!`) when handed a different one.
pub struct Layer<'network, Inner: AsLayer> {
    /// Pinned mutable reference to the underlying C++ layer object.
    pub(crate) inner: Pin<&'network mut Inner>,
    /// Raw pointer to the network this layer was created from; compared by
    /// `check_network!` and passed on to tensors obtained from this layer.
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}
impl<Inner: AsLayer> std::fmt::Debug for Layer<'_, Inner> {
    /// Formats the layer's address (lower-case hex) and its runtime layer type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address: *const Inner = &*self.inner;
        let layer_type: LayerType = self.inner.as_layer().getType().into();
        f.debug_struct("Layer")
            .field("inner", &format!("{:x}", address as usize))
            .field("layer_type", &layer_type)
            .finish_non_exhaustive()
    }
}
impl<'network, Inner: AsLayerTyped> Layer<'network, Inner> {
pub const fn layer_type(&self) -> LayerType {
Inner::TYPE
}
pub(crate) fn new(
network: *const nvinfer1::INetworkDefinition,
ptr: *mut Inner,
) -> Result<Self> {
unsafe {
let ptr = ptr
.as_mut()
.ok_or(Error::LayerCreationFailed(Inner::TYPE))?;
Ok(Self {
inner: Pin::new_unchecked(ptr),
network,
})
}
}
}
impl<'network> Layer<'network, nvinfer1::ILayer> {
pub fn layer_type_dynamic(&self) -> LayerType {
self.inner.as_layer().getType().into()
}
pub(crate) fn new_dyn(
network: *const nvinfer1::INetworkDefinition,
ptr: *mut nvinfer1::ILayer,
) -> Result<Self> {
unsafe {
let ptr = ptr.as_mut().ok_or(Error::GetLayerFailed)?;
Ok(Self {
inner: Pin::new_unchecked(ptr),
network,
})
}
}
}
impl<'network, Inner: AsLayer> Layer<'network, Inner> {
    /// Connects `tensor` to input slot `index` of this layer.
    ///
    /// # Panics
    /// Panics if this layer or `tensor` was not created from `network`.
    pub fn set_input(
        &mut self,
        network: &'_ mut NetworkDefinition,
        index: i32,
        tensor: &'_ Tensor,
    ) -> Result<()> {
        check_network!(network, self);
        // The tensor must belong to the same network as the layer, otherwise a
        // foreign tensor would be wired into this network's graph.
        check_network!(network, tensor);
        debug!(
            "set_input layer={} index={index} tensor={}",
            layer_dbg(network, self),
            tensor_dbg(network, tensor)
        );
        // SAFETY: the mutable reference is only used in place to reach the
        // `ILayer` base interface; the layer object is never moved out of it.
        unsafe { self.inner.as_mut().get_unchecked_mut() }
            .as_layer_pin_mut()
            .setInput(index, tensor.pin_mut());
        Ok(())
    }

    /// Returns the tensor connected to input slot `index`.
    ///
    /// # Errors
    /// Propagates the error from `Tensor::new` — presumably returned when TensorRT
    /// yields a null tensor (e.g. `index` out of range). TODO confirm.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn input(&self, network: &'_ NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
        check_network!(network, self);
        let tensor = self.inner.as_layer().getInput(index);
        unsafe { Tensor::new(self.network, tensor) }
    }

    /// Deprecated alias of [`Layer::input`].
    #[deprecated = "use input instead"]
    pub fn get_input(
        &self,
        network: &'_ NetworkDefinition,
        index: i32,
    ) -> Result<Tensor<'network>> {
        self.input(network, index)
    }

    /// Returns the tensor produced at output slot `index`.
    ///
    /// # Errors
    /// Propagates the error from `Tensor::new` — presumably returned when TensorRT
    /// yields a null tensor (e.g. `index` out of range). TODO confirm.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn output(&self, network: &'_ NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
        check_network!(network, self);
        let tensor = self.inner.as_layer().getOutput(index);
        unsafe { Tensor::new(self.network, tensor) }
    }

    /// Deprecated alias of [`Layer::output`].
    #[deprecated = "use output instead"]
    pub fn get_output(
        &self,
        network: &'_ NetworkDefinition,
        index: i32,
    ) -> Result<Tensor<'network>> {
        self.output(network, index)
    }

    /// Number of inputs of this layer, as reported by TensorRT.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn num_inputs(&self, network: &'_ NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_layer().getNbInputs()
    }

    /// Deprecated alias of [`Layer::num_inputs`].
    #[deprecated = "use num_inputs instead"]
    pub fn get_num_inputs(&self, network: &'_ NetworkDefinition) -> i32 {
        self.num_inputs(network)
    }

    /// Number of outputs of this layer, as reported by TensorRT.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn num_outputs(&self, network: &'_ NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_layer().getNbOutputs()
    }

    /// Deprecated alias of [`Layer::num_outputs`].
    #[deprecated = "use num_outputs instead"]
    pub fn get_num_outputs(&self, network: &'_ NetworkDefinition) -> i32 {
        self.num_outputs(network)
    }

    /// Sets the name of this layer.
    ///
    /// # Errors
    /// Fails if `name` contains an interior NUL byte.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn set_name(&mut self, network: &'_ mut NetworkDefinition, name: &str) -> Result<()> {
        check_network!(network, self);
        let name = CString::new(name)?;
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .as_layer_pin_mut()
                .setName(name.as_ptr())
        };
        Ok(())
    }

    /// Returns this layer's name, or `"(unnamed)"` when TensorRT reports none
    /// (null name pointer). Invalid UTF-8 is replaced with U+FFFD.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn name(&self, network: &NetworkDefinition) -> String {
        check_network!(network, self);
        let name = self.inner.as_layer().getName();
        if name.is_null() {
            // Same placeholder `tensor_dbg` uses for unnamed tensors.
            "(unnamed)".to_string()
        } else {
            // SAFETY: assumes non-null names from `getName` are NUL-terminated C
            // strings that stay valid for the duration of this call.
            unsafe { CStr::from_ptr(name).to_string_lossy().into_owned() }
        }
    }
}
// Typed aliases of `Layer` for each concrete TensorRT layer interface. The generic
// `Layer` accessors (inputs, outputs, name) work on all of them; layer-specific
// setters/getters are provided in separate `impl` blocks.
pub type ActivationLayer<'layer> = Layer<'layer, nvinfer1::IActivationLayer>;
pub type AssertionLayer<'layer> = Layer<'layer, nvinfer1::IAssertionLayer>;
pub type CastLayer<'layer> = Layer<'layer, nvinfer1::ICastLayer>;
pub type ConcatenationLayer<'layer> = Layer<'layer, nvinfer1::IConcatenationLayer>;
pub type ConstantLayer<'layer> = Layer<'layer, nvinfer1::IConstantLayer>;
pub type ConvolutionLayer<'layer> = Layer<'layer, nvinfer1::IConvolutionLayer>;
pub type CumulativeLayer<'layer> = Layer<'layer, nvinfer1::ICumulativeLayer>;
pub type DeconvolutionLayer<'layer> = Layer<'layer, nvinfer1::IDeconvolutionLayer>;
pub type DequantizeLayer<'layer> = Layer<'layer, nvinfer1::IDequantizeLayer>;
pub type DynamicQuantizeLayer<'layer> = Layer<'layer, nvinfer1::IDynamicQuantizeLayer>;
pub type ElementWiseLayer<'layer> = Layer<'layer, nvinfer1::IElementWiseLayer>;
pub type EinsumLayer<'layer> = Layer<'layer, nvinfer1::IEinsumLayer>;
pub type FillLayer<'layer> = Layer<'layer, nvinfer1::IFillLayer>;
pub type GatherLayer<'layer> = Layer<'layer, nvinfer1::IGatherLayer>;
pub type GridSampleLayer<'layer> = Layer<'layer, nvinfer1::IGridSampleLayer>;
pub type IdentityLayer<'layer> = Layer<'layer, nvinfer1::IIdentityLayer>;
pub type MatrixMultiplyLayer<'layer> = Layer<'layer, nvinfer1::IMatrixMultiplyLayer>;
pub type NMSLayer<'layer> = Layer<'layer, nvinfer1::INMSLayer>;
pub type NonZeroLayer<'layer> = Layer<'layer, nvinfer1::INonZeroLayer>;
pub type NormalizationLayer<'layer> = Layer<'layer, nvinfer1::INormalizationLayer>;
pub type PaddingLayer<'layer> = Layer<'layer, nvinfer1::IPaddingLayer>;
pub type ParametricReLULayer<'layer> = Layer<'layer, nvinfer1::IParametricReLULayer>;
pub type PoolingLayer<'layer> = Layer<'layer, nvinfer1::IPoolingLayer>;
pub type QuantizeLayer<'layer> = Layer<'layer, nvinfer1::IQuantizeLayer>;
pub type RaggedSoftMaxLayer<'layer> = Layer<'layer, nvinfer1::IRaggedSoftMaxLayer>;
pub type ReduceLayer<'layer> = Layer<'layer, nvinfer1::IReduceLayer>;
pub type ResizeLayer<'layer> = Layer<'layer, nvinfer1::IResizeLayer>;
pub type RotaryEmbeddingLayer<'layer> = Layer<'layer, nvinfer1::IRotaryEmbeddingLayer>;
pub type ScaleLayer<'layer> = Layer<'layer, nvinfer1::IScaleLayer>;
pub type ScatterLayer<'layer> = Layer<'layer, nvinfer1::IScatterLayer>;
pub type SelectLayer<'layer> = Layer<'layer, nvinfer1::ISelectLayer>;
pub type ShapeLayer<'layer> = Layer<'layer, nvinfer1::IShapeLayer>;
pub type ShuffleLayer<'layer> = Layer<'layer, nvinfer1::IShuffleLayer>;
pub type SliceLayer<'layer> = Layer<'layer, nvinfer1::ISliceLayer>;
pub type SoftMaxLayer<'layer> = Layer<'layer, nvinfer1::ISoftMaxLayer>;
pub type SqueezeLayer<'layer> = Layer<'layer, nvinfer1::ISqueezeLayer>;
pub type TopKLayer<'layer> = Layer<'layer, nvinfer1::ITopKLayer>;
pub type UnaryLayer<'layer> = Layer<'layer, nvinfer1::IUnaryLayer>;
pub type UnsqueezeLayer<'layer> = Layer<'layer, nvinfer1::IUnsqueezeLayer>;
pub type ReverseSequenceLayer<'layer> = Layer<'layer, nvinfer1::IReverseSequenceLayer>;
pub type KVCacheUpdateLayer<'layer> = Layer<'layer, nvinfer1::IKVCacheUpdateLayer>;
pub type LrnLayer<'layer> = Layer<'layer, nvinfer1::ILRNLayer>;
pub type OneHotLayer<'layer> = Layer<'layer, nvinfer1::IOneHotLayer>;
// Only available when built against the `v_1_4` feature set.
#[cfg(feature = "v_1_4")]
pub type MoELayer<'layer> = Layer<'layer, nvinfer1::IMoELayer>;
#[cfg(feature = "v_1_4")]
pub type DistCollectiveLayer<'layer> = Layer<'layer, nvinfer1::IDistCollectiveLayer>;
pub type AttentionInputLayer<'layer> = Layer<'layer, nvinfer1::IAttentionInputLayer>;
pub type AttentionOutputLayer<'layer> = Layer<'layer, nvinfer1::IAttentionOutputLayer>;
pub type AttentionBoundaryLayer<'layer> = Layer<'layer, nvinfer1::IAttentionBoundaryLayer>;
pub type LoopBoundaryLayer<'layer> = Layer<'layer, nvinfer1::ILoopBoundaryLayer>;
pub type RecurrenceLayer<'layer> = Layer<'layer, nvinfer1::IRecurrenceLayer>;
pub type LoopOutputLayer<'layer> = Layer<'layer, nvinfer1::ILoopOutputLayer>;
pub type TripLimitLayer<'layer> = Layer<'layer, nvinfer1::ITripLimitLayer>;
pub type IteratorLayer<'layer> = Layer<'layer, nvinfer1::IIteratorLayer>;
pub type ConditionLayer<'layer> = Layer<'layer, nvinfer1::IConditionLayer>;
pub type IfConditionalOutputLayer<'layer> = Layer<'layer, nvinfer1::IIfConditionalOutputLayer>;
pub type IfConditionalInputLayer<'layer> = Layer<'layer, nvinfer1::IIfConditionalInputLayer>;
/// Type-erased layer; use `Layer::layer_type_dynamic` to query its concrete kind.
pub type DynLayer<'layer> = Layer<'layer, nvinfer1::ILayer>;
/// Wrapper around a TensorRT `IAttention` object.
pub struct Attention<'network> {
    /// Pinned mutable reference to the underlying C++ object.
    pub(crate) inner: Pin<&'network mut nvinfer1::IAttention>,
    /// Raw pointer to the network this object was created from.
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}

impl std::fmt::Debug for Attention<'_> {
    /// Formats the wrapped object's address in lower-case hex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address: *const nvinfer1::IAttention = &*self.inner;
        let mut builder = f.debug_struct("Attention");
        builder.field("inner", &format!("{:x}", address as usize));
        builder.finish_non_exhaustive()
    }
}
/// Wrapper around a TensorRT `ILoop` object.
pub struct Loop<'network> {
    /// Pinned mutable reference to the underlying C++ object.
    pub(crate) inner: Pin<&'network mut nvinfer1::ILoop>,
    /// Raw pointer to the network this object was created from.
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}

impl std::fmt::Debug for Loop<'_> {
    /// Formats the wrapped object's address in lower-case hex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address: *const nvinfer1::ILoop = &*self.inner;
        let mut builder = f.debug_struct("Loop");
        builder.field("inner", &format!("{:x}", address as usize));
        builder.finish_non_exhaustive()
    }
}
/// Wrapper around a TensorRT `IIfConditional` object.
pub struct IfConditional<'network> {
    /// Pinned mutable reference to the underlying C++ object.
    pub(crate) inner: Pin<&'network mut nvinfer1::IIfConditional>,
    /// Raw pointer to the network this object was created from.
    pub(crate) network: *const nvinfer1::INetworkDefinition,
}

impl std::fmt::Debug for IfConditional<'_> {
    /// Formats the wrapped object's address in lower-case hex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address: *const nvinfer1::IIfConditional = &*self.inner;
        let mut builder = f.debug_struct("IfConditional");
        builder.field("inner", &format!("{:x}", address as usize));
        builder.finish_non_exhaustive()
    }
}
impl ShuffleLayer<'_> {
pub fn set_reshape_dimensions(
&mut self,
network: &mut NetworkDefinition,
dims: &[i64],
) -> Result<()> {
check_network!(network, self);
let dims_obj = trtx_sys::Dims::from_slice(dims);
self.inner.as_mut().setReshapeDimensions(&dims_obj);
Ok(())
}
pub fn set_first_transpose(
&mut self,
network: &mut NetworkDefinition,
order: &[i32],
) -> Result<()> {
check_network!(network, self);
let mut order_arr = [0i32; 8];
let n = order.len().min(8);
order_arr[..n].copy_from_slice(&order[..n]);
let perm = trtx_sys::nvinfer1::Permutation { order: order_arr };
self.inner.as_mut().setFirstTranspose(perm);
Ok(())
}
pub fn set_second_transpose(
&mut self,
network: &mut NetworkDefinition,
order: &[i32],
) -> Result<()> {
check_network!(network, self);
let mut order_arr = [0i32; 8];
let n = order.len().min(8);
order_arr[..n].copy_from_slice(&order[..n]);
let perm = trtx_sys::nvinfer1::Permutation { order: order_arr };
self.inner.as_mut().setSecondTranspose(perm);
Ok(())
}
}
impl ResizeLayer<'_> {
    /// Sets explicit output dimensions.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_output_dimensions(&mut self, network: &mut NetworkDefinition, dims: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(dims);
        self.inner.as_mut().setOutputDimensions(&dims_obj);
    }

    /// Returns the configured output dimensions.
    ///
    /// Returns an empty vector when TensorRT reports a negative rank (invalid dims).
    pub fn output_dimensions(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getOutputDimensions();
        // `nbDims` is signed; casting a negative value straight to `usize` would
        // produce a huge length and panic on the slice, so clamp it first.
        let rank = i64::from(d.nbDims).clamp(0, d.d.len() as i64) as usize;
        d.d[..rank].to_vec()
    }

    /// Sets per-dimension scale factors.
    pub fn set_scales(&mut self, network: &mut NetworkDefinition, scales: &[f32]) {
        check_network!(network, self);
        // SAFETY: pointer and length come from the same live slice, valid for the
        // duration of the call.
        unsafe {
            self.inner
                .as_mut()
                .setScales(scales.as_ptr(), scales.len() as i32);
        }
    }

    /// Returns the configured scale factors.
    ///
    /// Returns `None` when TensorRT reports no scales (count <= 0) or when the
    /// second query returns a different count than the first.
    pub fn scales(&self, network: &NetworkDefinition) -> Option<Vec<f32>> {
        check_network!(network, self);
        // First call with a null buffer only queries the number of scales.
        let n = unsafe { self.inner.as_ref().getScales(0, std::ptr::null_mut()) };
        if n <= 0 {
            return None;
        }
        let mut buf = vec![0.0_f32; n as usize];
        let n2 = unsafe { self.inner.as_ref().getScales(n, buf.as_mut_ptr()) };
        if n2 != n {
            return None;
        }
        Some(buf)
    }

    /// Sets the resize (interpolation) mode.
    pub fn set_resize_mode(&mut self, network: &mut NetworkDefinition, mode: trtx_sys::ResizeMode) {
        check_network!(network, self);
        self.inner.as_mut().setResizeMode(mode.into());
    }

    /// Returns the resize (interpolation) mode.
    pub fn resize_mode(&self, network: &NetworkDefinition) -> trtx_sys::ResizeMode {
        check_network!(network, self);
        self.inner.as_ref().getResizeMode().into()
    }

    /// Sets the coordinate transformation mode.
    pub fn set_coordinate_transformation(
        &mut self,
        network: &mut NetworkDefinition,
        transform: trtx_sys::ResizeCoordinateTransformation,
    ) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setCoordinateTransformation(transform.into());
    }

    /// Returns the coordinate transformation mode.
    pub fn coordinate_transformation(
        &self,
        network: &NetworkDefinition,
    ) -> trtx_sys::ResizeCoordinateTransformation {
        check_network!(network, self);
        self.inner.as_ref().getCoordinateTransformation().into()
    }

    /// Sets the selector used for single-pixel outputs.
    pub fn set_selector_for_single_pixel(
        &mut self,
        network: &mut NetworkDefinition,
        selector: trtx_sys::ResizeSelector,
    ) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setSelectorForSinglePixel(selector.into());
    }

    /// Returns the selector used for single-pixel outputs.
    pub fn selector_for_single_pixel(
        &self,
        network: &NetworkDefinition,
    ) -> trtx_sys::ResizeSelector {
        check_network!(network, self);
        self.inner.as_ref().getSelectorForSinglePixel().into()
    }

    /// Sets the rounding mode used by nearest-neighbor resizing.
    pub fn set_nearest_rounding(
        &mut self,
        network: &mut NetworkDefinition,
        mode: trtx_sys::ResizeRoundMode,
    ) {
        check_network!(network, self);
        self.inner.as_mut().setNearestRounding(mode.into());
    }

    /// Returns the rounding mode used by nearest-neighbor resizing.
    pub fn nearest_rounding(&self, network: &NetworkDefinition) -> trtx_sys::ResizeRoundMode {
        check_network!(network, self);
        self.inner.as_ref().getNearestRounding().into()
    }

    /// Sets the cubic interpolation coefficient.
    pub fn set_cubic_coeff(&mut self, network: &mut NetworkDefinition, a: f32) {
        check_network!(network, self);
        self.inner.as_mut().setCubicCoeff(a);
    }

    /// Returns the cubic interpolation coefficient.
    pub fn cubic_coeff(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getCubicCoeff()
    }

    /// Sets the "exclude outside" flag.
    pub fn set_exclude_outside(&mut self, network: &mut NetworkDefinition, exclude: bool) {
        check_network!(network, self);
        self.inner.as_mut().setExcludeOutside(exclude);
    }

    /// Returns the "exclude outside" flag.
    pub fn exclude_outside(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.inner.as_ref().getExcludeOutside()
    }
}
impl GatherLayer<'_> {
    /// Sets the gather mode of this layer.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn set_gather_mode(&mut self, network: &mut NetworkDefinition, mode: trtx_sys::GatherMode) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setMode(mode.into());
    }
}
impl<'network> ScatterLayer<'network> {
    /// Sets the scatter mode of this layer.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn set_scatter_mode(
        &mut self,
        network: &mut NetworkDefinition,
        mode: trtx_sys::ScatterMode,
    ) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setMode(mode.into());
    }

    /// Sets the axis along which the scatter is performed.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn set_axis(&mut self, network: &'_ mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }
}
impl<'network> ConvolutionLayer<'network> {
    /// Sets the convolution stride (two entries).
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_stride(&mut self, network: &mut NetworkDefinition, stride: &[i64; 2]) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setStrideNd(&trtx_sys::Dims::from_slice(stride));
    }

    /// Sets the symmetric padding (two entries).
    pub fn set_padding(&mut self, network: &mut NetworkDefinition, padding: &[i64; 2]) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setPaddingNd(&trtx_sys::Dims::from_slice(padding));
    }

    /// Sets the dilation (two entries).
    pub fn set_dilation(&mut self, network: &mut NetworkDefinition, dilation: &[i64; 2]) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setDilationNd(&trtx_sys::Dims::from_slice(dilation));
    }

    /// Sets the number of groups for grouped convolution.
    pub fn set_num_groups(&mut self, network: &mut NetworkDefinition, num_groups: i64) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setNbGroups(num_groups);
    }
}
impl<'network> DeconvolutionLayer<'network> {
    /// Sets the deconvolution stride (two entries). Always returns `Ok(())`.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_stride(&mut self, network: &mut NetworkDefinition, stride: &[i64; 2]) -> Result<()> {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setStrideNd(&trtx_sys::Dims::from_slice(stride));
        Ok(())
    }

    /// Sets the pre-padding (two entries). Always returns `Ok(())`.
    pub fn set_pre_padding(
        &mut self,
        network: &mut NetworkDefinition,
        padding: &[i64; 2],
    ) -> Result<()> {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setPrePadding(&trtx_sys::Dims::from_slice(padding));
        Ok(())
    }

    /// Sets the post-padding (two entries). Always returns `Ok(())`.
    pub fn set_post_padding(
        &mut self,
        network: &mut NetworkDefinition,
        padding: &[i64; 2],
    ) -> Result<()> {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setPostPadding(&trtx_sys::Dims::from_slice(padding));
        Ok(())
    }

    /// Sets the dilation (two entries). Always returns `Ok(())`.
    pub fn set_dilation(
        &mut self,
        network: &mut NetworkDefinition,
        dilation: &[i64; 2],
    ) -> Result<()> {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setDilationNd(&trtx_sys::Dims::from_slice(dilation));
        Ok(())
    }

    /// Sets the number of groups. Always returns `Ok(())`.
    pub fn set_num_groups(
        &mut self,
        network: &mut NetworkDefinition,
        num_groups: i64,
    ) -> Result<()> {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setNbGroups(num_groups);
        Ok(())
    }
}
impl<'network> PoolingLayer<'network> {
    /// Copies the leading `nb_dims` entries of a dims array.
    ///
    /// A negative rank (invalid dims) yields an empty vector and an oversized rank
    /// is clamped to the array length, instead of panicking on the slice.
    fn dims_prefix(nb_dims: i64, values: &[i64]) -> Vec<i64> {
        let rank = nb_dims.clamp(0, values.len() as i64) as usize;
        values[..rank].to_vec()
    }

    /// Sets the pooling type.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_pooling_type(
        &mut self,
        network: &mut NetworkDefinition,
        pooling_type: trtx_sys::PoolingType,
    ) {
        check_network!(network, self);
        self.inner.as_mut().setPoolingType(pooling_type.into());
    }

    /// Returns the pooling type.
    pub fn pooling_type(&self, network: &NetworkDefinition) -> trtx_sys::PoolingType {
        check_network!(network, self);
        self.inner.as_ref().getPoolingType().into()
    }

    /// Sets the blend factor.
    pub fn set_blend_factor(&mut self, network: &mut NetworkDefinition, blend_factor: f32) {
        check_network!(network, self);
        self.inner.as_mut().setBlendFactor(blend_factor);
    }

    /// Returns the blend factor.
    pub fn blend_factor(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getBlendFactor()
    }

    /// Sets whether the average count excludes padding.
    pub fn set_average_count_excludes_padding(
        &mut self,
        network: &mut NetworkDefinition,
        exclusive: bool,
    ) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setAverageCountExcludesPadding(exclusive);
    }

    /// Returns whether the average count excludes padding.
    pub fn average_count_excludes_padding(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        self.inner.as_ref().getAverageCountExcludesPadding()
    }

    /// Sets the pre-padding.
    pub fn set_pre_padding(&mut self, network: &mut NetworkDefinition, padding: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(padding);
        self.inner.as_mut().setPrePadding(&dims_obj);
    }

    /// Returns the pre-padding (empty if TensorRT reports invalid dims).
    pub fn pre_padding(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getPrePadding();
        Self::dims_prefix(i64::from(d.nbDims), &d.d)
    }

    /// Sets the post-padding.
    pub fn set_post_padding(&mut self, network: &mut NetworkDefinition, padding: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(padding);
        self.inner.as_mut().setPostPadding(&dims_obj);
    }

    /// Returns the post-padding (empty if TensorRT reports invalid dims).
    pub fn post_padding(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getPostPadding();
        Self::dims_prefix(i64::from(d.nbDims), &d.d)
    }

    /// Sets the padding mode.
    pub fn set_padding_mode(
        &mut self,
        network: &mut NetworkDefinition,
        padding_mode: trtx_sys::PaddingMode,
    ) {
        check_network!(network, self);
        self.inner.as_mut().setPaddingMode(padding_mode.into());
    }

    /// Returns the padding mode.
    pub fn padding_mode(&self, network: &NetworkDefinition) -> trtx_sys::PaddingMode {
        check_network!(network, self);
        self.inner.as_ref().getPaddingMode().into()
    }

    /// Sets the pooling window size.
    pub fn set_window_size_nd(&mut self, network: &mut NetworkDefinition, window_size: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(window_size);
        self.inner.as_mut().setWindowSizeNd(&dims_obj);
    }

    /// Returns the pooling window size (empty if TensorRT reports invalid dims).
    pub fn window_size_nd(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getWindowSizeNd();
        Self::dims_prefix(i64::from(d.nbDims), &d.d)
    }

    /// Sets the stride.
    pub fn set_stride_nd(&mut self, network: &mut NetworkDefinition, stride: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(stride);
        self.inner.as_mut().setStrideNd(&dims_obj);
    }

    /// Returns the stride (empty if TensorRT reports invalid dims).
    pub fn stride_nd(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getStrideNd();
        Self::dims_prefix(i64::from(d.nbDims), &d.d)
    }

    /// Sets the symmetric padding.
    pub fn set_padding_nd(&mut self, network: &mut NetworkDefinition, padding: &[i64]) {
        check_network!(network, self);
        let dims_obj = trtx_sys::Dims::from_slice(padding);
        self.inner.as_mut().setPaddingNd(&dims_obj);
    }

    /// Returns the symmetric padding (empty if TensorRT reports invalid dims).
    pub fn padding_nd(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getPaddingNd();
        Self::dims_prefix(i64::from(d.nbDims), &d.d)
    }
}
impl<'network> DynamicQuantizeLayer<'network> {
    /// Sets the quantization axis.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }

    /// Returns the quantization axis.
    pub fn axis(&self, network: &NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_ref().getAxis()
    }

    /// Sets the block shape.
    pub fn set_block_shape(&mut self, network: &mut NetworkDefinition, block_shape: &[i64]) {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setBlockShape(&Dims64::from_slice(block_shape));
    }

    /// Returns the block shape.
    pub fn block_shape(&self, network: &NetworkDefinition) -> Dims64 {
        check_network!(network, self);
        self.inner.as_ref().getBlockShape()
    }

    /// Sets the block size.
    pub fn set_block_size(&mut self, network: &mut NetworkDefinition, size: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setBlockSize(size);
    }

    /// Returns the block size.
    pub fn block_size(&self, network: &NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_ref().getBlockSize()
    }

    /// Sets the data type of the quantized output.
    pub fn set_to_type(&mut self, network: &mut NetworkDefinition, to_type: DataType) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setToType(to_type.into());
    }

    /// Returns the data type of the quantized output.
    pub fn to_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        let raw = self.inner.as_ref().getToType();
        raw.into()
    }

    /// Sets the data type of the scale output.
    pub fn set_scale_type(&mut self, network: &mut NetworkDefinition, scale_type: DataType) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setScaleType(scale_type.into());
    }

    /// Returns the data type of the scale output.
    pub fn scale_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        let raw = self.inner.as_ref().getScaleType();
        raw.into()
    }
}
impl<'network> QuantizeLayer<'network> {
    /// Sets the quantization axis.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }

    /// Returns the quantization axis.
    pub fn axis(&self, network: &NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_ref().getAxis()
    }

    /// Sets the block shape for block quantization.
    ///
    /// # Errors
    /// Fails with `PropertySetAttempt::QuantizeLayerBlockShape` when TensorRT
    /// rejects the shape.
    pub fn set_block_shape(
        &mut self,
        network: &mut NetworkDefinition,
        block_shape: &[i64],
    ) -> Result<()> {
        check_network!(network, self);
        let accepted = self
            .inner
            .as_mut()
            .setBlockShape(&Dims64::from_slice(block_shape));
        accepted.ok_or_err(PropertySetAttempt::QuantizeLayerBlockShape)
    }

    /// Returns the block shape.
    pub fn block_shape(&self, network: &NetworkDefinition) -> Dims64 {
        check_network!(network, self);
        self.inner.as_ref().getBlockShape()
    }

    /// Sets the output data type.
    pub fn set_to_type(&mut self, network: &mut NetworkDefinition, to_type: DataType) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setToType(to_type.into());
    }

    /// Returns the output data type.
    pub fn to_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        let raw = self.inner.as_ref().getToType();
        raw.into()
    }
}
impl<'network> DequantizeLayer<'network> {
    /// Sets the dequantization axis.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }

    /// Returns the dequantization axis.
    pub fn axis(&self, network: &NetworkDefinition) -> i32 {
        check_network!(network, self);
        self.inner.as_ref().getAxis()
    }

    /// Sets the block shape for block dequantization.
    ///
    /// # Errors
    /// Fails with `PropertySetAttempt::DequantizeLayerBlockShape` when TensorRT
    /// rejects the shape.
    pub fn set_block_shape(
        &mut self,
        network: &mut NetworkDefinition,
        block_shape: &[i64],
    ) -> Result<()> {
        check_network!(network, self);
        let accepted = self
            .inner
            .as_mut()
            .setBlockShape(&Dims64::from_slice(block_shape));
        accepted.ok_or_err(PropertySetAttempt::DequantizeLayerBlockShape)
    }

    /// Returns the block shape.
    pub fn block_shape(&self, network: &NetworkDefinition) -> Dims64 {
        check_network!(network, self);
        self.inner.as_ref().getBlockShape()
    }

    /// Sets the output data type.
    pub fn set_to_type(&mut self, network: &mut NetworkDefinition, to_type: DataType) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setToType(to_type.into());
    }

    /// Returns the output data type.
    pub fn to_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        let raw = self.inner.as_ref().getToType();
        raw.into()
    }
}
impl ConcatenationLayer<'_> {
    /// Sets the axis along which inputs are concatenated.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network`.
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }
}
impl RotaryEmbeddingLayer<'_> {
    /// Sets whether the rotary embedding uses interleaved layout.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_interleaved(&mut self, network: &mut NetworkDefinition, interleaved: bool) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setInterleaved(interleaved);
    }

    /// Returns whether the rotary embedding uses interleaved layout.
    pub fn interleaved(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        let layer = self.inner.as_ref();
        layer.getInterleaved()
    }

    /// Sets the rotary embedding dimension.
    ///
    /// # Errors
    /// Fails with `PropertySetAttempt::RotaryEmbeddingLayerRotaryEmbeddingDim`
    /// when TensorRT rejects the value.
    pub fn set_rotary_embedding_dim(
        &mut self,
        network: &mut NetworkDefinition,
        rotary_embedding_dim: i32,
    ) -> Result<()> {
        check_network!(network, self);
        let accepted = self
            .inner
            .as_mut()
            .setRotaryEmbeddingDim(rotary_embedding_dim);
        accepted.ok_or_err(PropertySetAttempt::RotaryEmbeddingLayerRotaryEmbeddingDim)
    }

    /// Returns the rotary embedding dimension.
    pub fn rotary_embedding_dim(&self, network: &NetworkDefinition) -> i32 {
        check_network!(network, self);
        let layer = self.inner.as_ref();
        layer.getRotaryEmbeddingDim()
    }
}
impl NormalizationLayer<'_> {
    /// Sets the epsilon added for numerical stability.
    ///
    /// # Panics
    /// Panics if this layer was not created from `network` (applies to every method here).
    pub fn set_epsilon(&mut self, network: &mut NetworkDefinition, eps: f32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setEpsilon(eps);
    }

    /// Returns the epsilon.
    pub fn epsilon(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        let layer = self.inner.as_ref();
        layer.getEpsilon()
    }

    /// Deprecated alias of [`NormalizationLayer::epsilon`].
    #[deprecated = "use epsilon instead"]
    pub fn get_epsilon(&self, network: &NetworkDefinition) -> f32 {
        Self::epsilon(self, network)
    }

    /// Sets the axes to normalize over.
    pub fn set_axes(&mut self, network: &mut NetworkDefinition, axes: crate::Axes) {
        check_network!(network, self);
        let bits = axes.to_bits();
        self.inner.as_mut().setAxes(bits);
    }

    /// Returns the axes normalized over.
    pub fn axes(&self, network: &NetworkDefinition) -> crate::Axes {
        check_network!(network, self);
        let bits = self.inner.as_ref().getAxes();
        crate::Axes::from_bits(bits)
    }

    /// Deprecated alias of [`NormalizationLayer::axes`].
    #[deprecated = "use axes instead"]
    pub fn get_axes(&self, network: &NetworkDefinition) -> crate::Axes {
        Self::axes(self, network)
    }

    /// Sets the number of groups.
    pub fn set_num_groups(&mut self, network: &mut NetworkDefinition, groups: i64) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setNbGroups(groups);
    }

    /// Returns the number of groups.
    pub fn num_groups(&self, network: &NetworkDefinition) -> i64 {
        check_network!(network, self);
        let layer = self.inner.as_ref();
        layer.getNbGroups()
    }

    /// Deprecated alias of [`NormalizationLayer::num_groups`].
    #[deprecated = "use num_groups instead"]
    pub fn get_num_groups(&self, network: &NetworkDefinition) -> i64 {
        Self::num_groups(self, network)
    }

    /// Returns what TensorRT's `isV2` reports for this layer.
    pub fn is_v2(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        let layer = self.inner.as_ref();
        layer.isV2()
    }
}
#[cfg(feature = "v_1_4")]
impl MoELayer<'_> {
    /// Sets the gated FC weight tensors (gate, up, down) and the activation type.
    /// Always returns `Ok(())`.
    ///
    /// # Panics
    /// Panics if this layer or any tensor was not created from `network`
    /// (the same network check applies to every method here).
    pub fn set_gated_weights(
        &mut self,
        network: &mut NetworkDefinition,
        fc_gate_weights: &Tensor,
        fc_up_weights: &Tensor,
        fc_down_weights: &Tensor,
        activation_type: MoEActType,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_gate_weights);
        check_network!(network, fc_up_weights);
        check_network!(network, fc_down_weights);
        self.inner.as_mut().setGatedWeights(
            fc_gate_weights.pin_mut(),
            fc_up_weights.pin_mut(),
            fc_down_weights.pin_mut(),
            activation_type.into(),
        );
        Ok(())
    }

    /// Sets the gated FC bias tensors (gate, up, down). Always returns `Ok(())`.
    pub fn set_gated_biases(
        &mut self,
        network: &mut NetworkDefinition,
        fc_gate_biases: &Tensor,
        fc_up_biases: &Tensor,
        fc_down_biases: &Tensor,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_gate_biases);
        check_network!(network, fc_up_biases);
        check_network!(network, fc_down_biases);
        self.inner.as_mut().setGatedBiases(
            fc_gate_biases.pin_mut(),
            fc_up_biases.pin_mut(),
            fc_down_biases.pin_mut(),
        );
        Ok(())
    }

    /// Sets the activation type. Always returns `Ok(())`.
    pub fn set_activation_type(
        &mut self,
        network: &mut NetworkDefinition,
        activation_type: MoEActType,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner
            .as_mut()
            .setActivationType(activation_type.into());
        Ok(())
    }

    /// Returns the activation type.
    pub fn activation_type(&self, network: &NetworkDefinition) -> MoEActType {
        check_network!(network, self);
        self.inner.as_ref().getActivationType().into()
    }

    /// Configures static quantization using the given down-projection activation
    /// scale tensor and data type. Always returns `Ok(())`.
    pub fn set_quantization_static(
        &mut self,
        network: &mut NetworkDefinition,
        fc_down_activation_scale: &Tensor,
        data_type: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_down_activation_scale);
        self.inner
            .as_mut()
            .setQuantizationStatic(fc_down_activation_scale.pin_mut(), data_type.into());
        Ok(())
    }

    /// Configures dynamic double quantization. Always returns `Ok(())`.
    pub fn set_quantization_dynamic_dbl_q(
        &mut self,
        network: &mut NetworkDefinition,
        fc_down_activation_dbl_q_scale: &Tensor,
        data_type: DataType,
        block_shape: &[i64],
        dyn_q_output_scale_type: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, fc_down_activation_dbl_q_scale);
        let block = trtx_sys::Dims::from_slice(block_shape);
        self.inner.as_mut().setQuantizationDynamicDblQ(
            fc_down_activation_dbl_q_scale.pin_mut(),
            data_type.into(),
            &block,
            dyn_q_output_scale_type.into(),
        );
        Ok(())
    }

    /// Sets the quantization target type. Always returns `Ok(())`.
    pub fn set_quantization_to_type(
        &mut self,
        network: &mut NetworkDefinition,
        type_: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setQuantizationToType(type_.into());
        Ok(())
    }

    /// Returns the quantization target type.
    pub fn quantization_to_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        self.inner.as_ref().getQuantizationToType().into()
    }

    /// Sets the quantization block shape. Always returns `Ok(())`.
    pub fn set_quantization_block_shape(
        &mut self,
        network: &mut NetworkDefinition,
        block_shape: &[i64],
    ) -> Result<()> {
        check_network!(network, self);
        let block = trtx_sys::Dims::from_slice(block_shape);
        self.inner.as_mut().setQuantizationBlockShape(&block);
        Ok(())
    }

    /// Returns the quantization block shape (empty if TensorRT reports invalid dims).
    pub fn quantization_block_shape(&self, network: &NetworkDefinition) -> Vec<i64> {
        check_network!(network, self);
        let d = self.inner.as_ref().getQuantizationBlockShape();
        // Clamp the signed rank: a negative `nbDims` would otherwise become a huge
        // `usize` and panic on the slice.
        let rank = i64::from(d.nbDims).clamp(0, d.d.len() as i64) as usize;
        d.d[..rank].to_vec()
    }

    /// Sets the output scale type used by dynamic quantization. Always returns `Ok(())`.
    pub fn set_dyn_q_output_scale_type(
        &mut self,
        network: &mut NetworkDefinition,
        type_: DataType,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setDynQOutputScaleType(type_.into());
        Ok(())
    }

    /// Returns the output scale type used by dynamic quantization.
    pub fn dyn_q_output_scale_type(&self, network: &NetworkDefinition) -> DataType {
        check_network!(network, self);
        self.inner.as_ref().getDynQOutputScaleType().into()
    }

    /// Sets all SwiGLU parameters at once. Always returns `Ok(())`.
    pub fn set_swiglu_params(
        &mut self,
        network: &mut NetworkDefinition,
        limit: f32,
        alpha: f32,
        beta: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParams(limit, alpha, beta);
        Ok(())
    }

    /// Sets the SwiGLU `limit` parameter. Always returns `Ok(())`.
    pub fn set_swiglu_param_limit(
        &mut self,
        network: &mut NetworkDefinition,
        limit: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParamLimit(limit);
        Ok(())
    }

    /// Returns the SwiGLU `limit` parameter.
    pub fn swiglu_param_limit(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getSwigluParamLimit()
    }

    /// Sets the SwiGLU `alpha` parameter. Always returns `Ok(())`.
    pub fn set_swiglu_param_alpha(
        &mut self,
        network: &mut NetworkDefinition,
        alpha: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParamAlpha(alpha);
        Ok(())
    }

    /// Returns the SwiGLU `alpha` parameter.
    pub fn swiglu_param_alpha(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getSwigluParamAlpha()
    }

    /// Sets the SwiGLU `beta` parameter. Always returns `Ok(())`.
    pub fn set_swiglu_param_beta(
        &mut self,
        network: &mut NetworkDefinition,
        beta: f32,
    ) -> Result<()> {
        check_network!(network, self);
        self.inner.as_mut().setSwigluParamBeta(beta);
        Ok(())
    }

    /// Returns the SwiGLU `beta` parameter.
    pub fn swiglu_param_beta(&self, network: &NetworkDefinition) -> f32 {
        check_network!(network, self);
        self.inner.as_ref().getSwigluParamBeta()
    }
}
// No `IDistCollectiveLayer`-specific accessors are wrapped yet; the generic
// `Layer` methods (inputs, outputs, name) are still available through the alias.
#[cfg(feature = "v_1_4")]
impl DistCollectiveLayer<'_> {}
/// Owning wrapper around a TensorRT `INetworkDefinition`.
///
/// Fields are dropped in declaration order, so the network (`inner`) is destroyed
/// before the weight buffers it was given pointers into (`small_copied_weights`)
/// and before the optional error recorder.
pub struct NetworkDefinition<'builder> {
    /// Owning pointer to the underlying TensorRT network.
    pub(crate) inner: UniquePtr<INetworkDefinition>,
    /// Ties this network's lifetime to the builder that created it.
    _builder: PhantomData<&'builder trtx_sys::nvinfer1::IBuilder>,
    /// Byte buffers whose pointers were handed to TensorRT. They are kept alive here
    /// (presumably because TensorRT reads weights by reference until build).
    small_copied_weights: Vec<Vec<u8>>,
    /// Optional error recorder, pinned on the heap so its address stays stable.
    error_recorder: Option<Pin<Box<ErrorRecorder>>>,
}
impl std::fmt::Debug for NetworkDefinition<'_> {
    /// Prints the network's raw address in hex; other fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address = format!("{:x}", self.inner.as_ptr() as usize);
        let mut out = f.debug_struct("NetworkDefinition");
        out.field("inner", &address);
        out.finish_non_exhaustive()
    }
}
/// Human-readable tensor label for log messages; falls back to `"(unnamed)"`
/// when the tensor's name cannot be read.
fn tensor_dbg(network: &NetworkDefinition<'_>, tensor: &Tensor<'_>) -> String {
    match tensor.name(network) {
        Ok(label) => label,
        Err(_) => "(unnamed)".to_string(),
    }
}
/// Human-readable layer label for log messages (the layer's name).
fn layer_dbg<L: AsLayer>(network: &NetworkDefinition<'_>, layer: &Layer<'_, L>) -> String {
    layer.name(network)
}
impl<'network> NetworkDefinition<'network> {
pub(crate) fn from_ptr(ptr: *mut INetworkDefinition) -> Self {
Self {
inner: unsafe { UniquePtr::from_raw(ptr) },
error_recorder: None,
_builder: Default::default(),
small_copied_weights: Default::default(),
}
}
/// Declares a new network input called `name` with the given element type and shape.
///
/// # Errors
/// Fails if `name` contains an interior NUL byte. Otherwise, errors are whatever
/// `Tensor::new` reports for the returned pointer.
pub fn add_input(
    &mut self,
    name: &str,
    data_type: trtx_sys::DataType,
    dims: &[i64],
) -> Result<Tensor<'network>> {
    debug!("add_input name={name:?} data_type={data_type:?} dims={dims:?}");
    let c_name = std::ffi::CString::new(name)?;
    let shape = trtx_sys::Dims::from_slice(dims);
    let network = self.inner.pin_mut();
    // SAFETY: `c_name` is NUL-terminated and outlives this call.
    let raw_tensor = unsafe { network.addInput(c_name.as_ptr(), data_type.into(), &shape) };
    // SAFETY: `raw_tensor` was just produced by this network.
    unsafe { Tensor::new(self.inner.as_ptr(), raw_tensor) }
}
/// Marks `tensor` as a network output.
///
/// # Panics
/// Panics if `tensor` was created from a different network.
pub fn mark_output(&mut self, tensor: &'_ Tensor) {
    check_network!(self, tensor);
    // Previously logged as "mark_input", which mislabelled the operation.
    debug!("mark_output tensor={}", tensor_dbg(self, tensor));
    self.inner.pin_mut().markOutput(tensor.pin_mut());
}
pub fn mark_tensor_debug(&mut self, tensor: &'_ Tensor) -> Result<()> {
check_network!(self, tensor);
let success = self.inner.pin_mut().markDebug(tensor.pin_mut());
if success {
Ok(())
} else {
Err(Error::Runtime("markDebug failed".to_string()))
}
}
pub fn mark_weights_refittable(&mut self, name: &str) -> Result<()> {
debug!("mark_weights_refittable: {name:?}");
let name_cstr = std::ffi::CString::new(name)?;
unsafe {
self.inner
.pin_mut()
.markWeightsRefittable(name_cstr.as_ptr())
.ok_or_else_err(|| Error::FailedToMarkWeightsRefittable {
weight_name: name.to_string(),
})
}
}
pub fn unmark_weights_refittable(&mut self, name: &str) -> Result<()> {
debug!("unmark_weights_refittable: {name:?}");
let name_cstr = std::ffi::CString::new(name)?;
unsafe {
self.inner
.pin_mut()
.unmarkWeightsRefittable(name_cstr.as_ptr())
.ok_or_else_err(|| Error::FailedToUnmarkWeightsRefittable {
weight_name: name.to_string(),
})
}
}
/// Reports whether the named weights are currently marked refittable.
///
/// # Errors
/// Fails only if `name` contains an interior NUL byte.
pub fn are_weights_marked_refittable(&self, name: &str) -> Result<bool> {
    let c_name = std::ffi::CString::new(name)?;
    // SAFETY: `c_name` is NUL-terminated and outlives this call.
    let marked = unsafe { self.inner.areWeightsMarkedRefittable(c_name.as_ptr()) };
    Ok(marked)
}
unsafe fn set_weights_name(&mut self, weights: nvinfer1::Weights, name: &str) -> Result<()> {
let cname = CString::new(name)?;
self.inner
.pin_mut()
.setWeightsName(weights, cname.as_ptr())
.ok_or_else_err(|| Error::FailedToSetWeightsName {
weight_name: name.to_string(),
})
}
/// Returns whether `tensor` has been marked as a debug tensor.
///
/// # Panics
/// Panics if `tensor` was created from a different network.
pub fn is_debug_tensor(&self, tensor: &'_ Tensor) -> bool {
    check_network!(self, tensor);
    let raw = tensor.as_ref();
    self.inner.isDebugTensor(raw)
}
pub fn nb_inputs(&self) -> i32 {
self.inner.getNbInputs()
}
#[deprecated = "use nb_inputs instead"]
pub fn get_nb_inputs(&self) -> i32 {
self.nb_inputs()
}
pub fn nb_outputs(&self) -> i32 {
self.inner.getNbOutputs()
}
#[deprecated = "use nb_outputs instead"]
pub fn get_nb_outputs(&self) -> i32 {
self.nb_outputs()
}
pub fn input(&self, index: i32) -> Result<Tensor<'network>> {
let tensor_ptr = self.inner.getInput(index);
if tensor_ptr.is_null() {
return Err(Error::Runtime(format!(
"Failed to get input at index {}",
index
)));
}
unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
}
#[deprecated = "use input instead"]
pub fn get_input(&self, index: i32) -> Result<Tensor<'network>> {
self.input(index)
}
pub fn output(&self, index: i32) -> Result<Tensor<'network>> {
let tensor_ptr = self.inner.getOutput(index);
if tensor_ptr.is_null() {
return Err(Error::Runtime(format!(
"Failed to get output at index {}",
index
)));
}
unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
}
#[deprecated = "use output instead"]
pub fn get_output(&self, index: i32) -> Result<Tensor<'network>> {
self.output(index)
}
/// Iterates over all network inputs, with the count captured at creation time.
pub fn inputs(&self) -> NetworkInputIter<'_, 'network> {
    let count = self.nb_inputs();
    NetworkInputIter {
        network: self,
        index: 0,
        count,
    }
}
/// Iterates over all network outputs, with the count captured at creation time.
pub fn outputs(&self) -> NetworkOutputIter<'_, 'network> {
    let count = self.nb_outputs();
    NetworkOutputIter {
        network: self,
        index: 0,
        count,
    }
}
/// Number of layers currently in the network.
pub fn nb_layers(&self) -> i32 {
    let network = &self.inner;
    network.getNbLayers()
}
/// Deprecated alias of [`Self::nb_layers`].
#[deprecated = "use nb_layers instead"]
pub fn get_nb_layers(&self) -> i32 {
    Self::nb_layers(self)
}
/// Returns a type-erased handle to the layer at `layer_index`.
///
/// # Errors
/// Whatever `DynLayer::new_dyn` reports for the pointer TensorRT returns.
pub fn layer(&self, layer_index: i32) -> Result<DynLayer<'network>> {
    let raw_layer = self.inner.getLayer(layer_index);
    DynLayer::new_dyn(self.inner.as_ptr(), raw_layer)
}
/// Iterates over all layers, with the count captured at creation time.
pub fn layers(&self) -> NetworkLayerIter<'_, 'network> {
    let count = self.nb_layers();
    NetworkLayerIter {
        network: self,
        index: 0,
        count,
    }
}
/// Deprecated alias of [`Self::layer`].
#[deprecated = "use layer instead"]
pub fn get_layer(&self, layer_index: i32) -> Result<DynLayer<'network>> {
    Self::layer(self, layer_index)
}
/// Name of the layer at `layer_index`.
pub fn layer_name(&self, layer_index: i32) -> Result<String> {
    let found = self.layer(layer_index)?;
    Ok(found.name(self))
}
/// Deprecated alias of [`Self::layer_name`].
#[deprecated = "use layer_name instead"]
pub fn get_layer_name(&self, layer_index: i32) -> Result<String> {
    Self::layer_name(self, layer_index)
}
/// Runtime layer type of the layer at `layer_index`.
pub fn layer_type(&self, layer_index: i32) -> Result<LayerType> {
    let found = self.layer(layer_index)?;
    Ok(found.layer_type_dynamic())
}
/// Deprecated alias of [`Self::layer_type`].
#[deprecated = "use layer_type instead"]
pub fn get_layer_type(&self, layer_index: i32) -> Result<LayerType> {
    Self::layer_type(self, layer_index)
}
/// Adds an activation layer of kind `activation_type` applied to `input`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_activation(
    &mut self,
    input: &'_ Tensor,
    activation_type: trtx_sys::ActivationType,
) -> Result<ActivationLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_activation input={} activation_type={activation_type:?}",
        tensor_dbg(self, input)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addActivation(input.pin_mut(), activation_type.into());
    ActivationLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a parametric ReLU whose per-element slope comes from the `slope` tensor.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_parametric_relu(
    &mut self,
    input: &'_ Tensor,
    slope: &'_ Tensor,
) -> Result<ParametricReLULayer<'network>> {
    check_network!(self, input);
    check_network!(self, slope);
    debug!(
        "add_parametric_relu input={} slope={}",
        tensor_dbg(self, input),
        tensor_dbg(self, slope)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addParametricReLU(input.pin_mut(), slope.pin_mut());
    ParametricReLULayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a local-response-normalization layer (`addLRN`) with the given parameters.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_lrn(
    &mut self,
    input: &'_ Tensor,
    window: i64,
    alpha: f32,
    beta: f32,
    k: f32,
) -> Result<LrnLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_lrn input={} window={window} alpha={alpha} beta={beta} k={k}",
        tensor_dbg(self, input)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addLRN(input.pin_mut(), window, alpha, beta, k);
    LrnLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a unary element-wise operation `op` on `input`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_unary(
    &mut self,
    input: &'_ Tensor,
    op: trtx_sys::UnaryOperation,
) -> Result<UnaryLayer<'network>> {
    check_network!(self, input);
    debug!("add_unary input={} op={op:?}", tensor_dbg(self, input));
    let network = self.inner.pin_mut();
    let raw_layer = network.addUnary(input.pin_mut(), op.into());
    UnaryLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds an identity layer that passes `input` through.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_identity(&mut self, input: &'_ Tensor) -> Result<IdentityLayer<'network>> {
    check_network!(self, input);
    debug!("add_identity input={}", tensor_dbg(self, input));
    let network = self.inner.pin_mut();
    let raw_layer = network.addIdentity(input.pin_mut());
    IdentityLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a layer that outputs the shape of `input`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_shape(&mut self, input: &'_ Tensor) -> Result<ShapeLayer<'network>> {
    check_network!(self, input);
    debug!("add_shape input={}", tensor_dbg(self, input));
    let network = self.inner.pin_mut();
    let raw_layer = network.addShape(input.pin_mut());
    ShapeLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a cast of `input` to element type `to_type`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_cast(
    &mut self,
    input: &'_ Tensor,
    to_type: trtx_sys::DataType,
) -> Result<CastLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_cast input={} to_type={to_type:?}",
        tensor_dbg(self, input)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addCast(input.pin_mut(), to_type.into());
    CastLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a binary element-wise operation `op` between `input1` and `input2`.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_elementwise(
    &mut self,
    input1: &'_ Tensor,
    input2: &'_ Tensor,
    op: trtx_sys::ElementWiseOperation,
) -> Result<ElementWiseLayer<'network>> {
    check_network!(self, input1);
    check_network!(self, input2);
    debug!(
        "add_elementwise input1={} input2={} op={op:?}",
        tensor_dbg(self, input1),
        tensor_dbg(self, input2)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addElementWise(input1.pin_mut(), input2.pin_mut(), op.into());
    ElementWiseLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a 2-D pooling layer (`addPoolingNd`) with window `window_size`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_pooling(
    &'_ mut self,
    input: &'_ Tensor,
    pooling_type: trtx_sys::PoolingType,
    window_size: &[i64; 2],
) -> Result<PoolingLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_pooling input={} pooling_type={pooling_type:?} window_size={window_size:?}",
        tensor_dbg(self, input)
    );
    let [height, width] = *window_size;
    let window = trtx_sys::Dims::new_2d(height, width);
    let network = self.inner.pin_mut();
    let raw_layer = network.addPoolingNd(input.pin_mut(), pooling_type.into(), &window);
    PoolingLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a shuffle (reshape/transpose) layer on `input`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_shuffle(&'_ mut self, input: &'_ Tensor) -> Result<ShuffleLayer<'network>> {
    check_network!(self, input);
    debug!("add_shuffle input={}", tensor_dbg(self, input));
    let network = self.inner.pin_mut();
    let raw_layer = network.addShuffle(input.pin_mut());
    ShuffleLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a matrix multiply of `input0` (via `op0`) by `input1` (via `op1`).
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_matrix_multiply(
    &'_ mut self,
    input0: &'_ Tensor,
    op0: MatrixOperation,
    input1: &'_ Tensor,
    op1: MatrixOperation,
) -> Result<MatrixMultiplyLayer<'network>> {
    check_network!(self, input0);
    check_network!(self, input1);
    debug!(
        "add_matrix_multiply input0={} op0={op0:?} input1={} op1={op1:?}",
        tensor_dbg(self, input0),
        tensor_dbg(self, input1)
    );
    let network = self.inner.pin_mut();
    let (lhs, rhs) = (input0.pin_mut(), input1.pin_mut());
    let raw_layer = network.addMatrixMultiply(lhs, op0.into(), rhs, op1.into());
    MatrixMultiplyLayer::new(self.inner.as_ptr(), raw_layer)
}
pub fn add_convolution_deferrred_weights(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i32,
kernel_size: &[i32; 2],
) -> Result<ConvolutionLayer<'network>> {
debug!(
"add_convolution_deferrred_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0] as i64, kernel_size[1] as i64);
let layer_ptr = self.inner.pin_mut().addConvolutionNd(
input.pin_mut(),
nb_output_maps as i64,
&kernel_dims,
Weights {
type_: nvinfer1::DataType::kFLOAT,
values: std::ptr::null(),
count: 0,
},
Weights {
type_: nvinfer1::DataType::kFLOAT,
values: std::ptr::null(),
count: 0,
},
);
ConvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_convolution_owned_weights(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i32,
kernel_size: &[i32; 2],
weights: OwnedConvWeights,
) -> Result<ConvolutionLayer<'network>> {
debug!(
"add_convolution_owned_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let mut layer =
self.add_convolution_deferrred_weights(input, nb_output_maps, kernel_size)?;
let kernel = self
.add_constant_owned(
&weights.kernel.shape,
weights.kernel.values,
weights.kernel.data_type,
weights.kernel.name.as_deref(),
)?
.output(self, 0)?;
layer.set_input(self, 1, &kernel)?;
if let Some(bias) = weights.bias {
let bias = self
.add_constant_owned(
&bias.shape,
bias.values,
bias.data_type,
bias.name.as_deref(),
)?
.output(self, 0)?;
layer.set_input(self, 2, &bias)?;
}
Ok(layer)
}
pub fn add_convolution(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i32,
kernel_size: &[i32; 2],
weights: &ConvWeights<'network>,
) -> Result<ConvolutionLayer<'network>> {
check_network!(self, input);
debug!(
"add_convolution input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let kernel_dtype = weights.kernel_dtype;
let kernel_weights = weights.kernel_weights;
let bias_weights = weights.bias_weights;
let bias_dtype = weights.bias_dtype;
let kernel_bpe = kernel_dtype.size_bits() / 8;
let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
let bias_bpe = bias_dtype_val.size_bits() / 8;
let bias_count = bias_weights
.map(|b| (b.len() / bias_bpe) as i64)
.unwrap_or(0);
let kernel_ptr = if weight_count > 0 {
kernel_weights.as_ptr() as *const std::ffi::c_void
} else {
std::ptr::null()
};
let bias_ptr = if bias_count > 0 {
bias_weights
.map(|b| b.as_ptr() as *const std::ffi::c_void)
.unwrap_or(std::ptr::null())
} else {
std::ptr::null()
};
let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0] as i64, kernel_size[1] as i64);
let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type(
kernel_dtype.into(),
kernel_ptr,
weight_count,
);
let bias_w =
trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count);
let layer_ptr = self.inner.pin_mut().addConvolutionNd(
input.pin_mut(),
nb_output_maps as i64,
&kernel_dims,
kernel_w,
bias_w,
);
ConvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_deconvolution(
&mut self,
input: &'_ Tensor,
nb_output_maps: i64,
kernel_size: &[i64; 2],
weights: &ConvWeights<'network>,
) -> Result<DeconvolutionLayer<'network>> {
check_network!(self, input);
debug!(
"add_deconvolution input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let kernel_dtype = weights.kernel_dtype;
let kernel_weights = weights.kernel_weights;
let bias_weights = weights.bias_weights;
let bias_dtype = weights.bias_dtype;
let kernel_bpe = kernel_dtype.size_bits() / 8;
let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
let bias_bpe = bias_dtype_val.size_bits() / 8;
let bias_count = bias_weights
.map(|b| (b.len() / bias_bpe) as i64)
.unwrap_or(0);
let kernel_ptr = if weight_count > 0 {
kernel_weights.as_ptr() as *const std::ffi::c_void
} else {
std::ptr::null()
};
let bias_ptr = if bias_count > 0 {
bias_weights
.map(|b| b.as_ptr() as *const std::ffi::c_void)
.unwrap_or(std::ptr::null())
} else {
std::ptr::null()
};
let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0], kernel_size[1]);
let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type(
kernel_dtype.into(),
kernel_ptr,
weight_count,
);
let bias_w =
trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count);
let layer_ptr = self.inner.pin_mut().addDeconvolutionNd(
input.pin_mut(),
nb_output_maps,
kernel_dims,
kernel_w,
bias_w,
);
DeconvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
}
/// Adds a 2-D deconvolution whose kernel and bias are supplied later.
///
/// Both weights are passed as empty FP32 weights (null pointer, count 0). The caller
/// wires them in afterwards as inputs 1 and 2, as
/// [`Self::add_deconvolution_owned_weights`] does.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_deconvolution_deferred_weights(
    &mut self,
    input: &'_ Tensor,
    nb_output_maps: i64,
    kernel_size: &[i64; 2],
) -> Result<DeconvolutionLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_deconvolution_deferred_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
        tensor_dbg(self, input)
    );
    let [rows, cols] = *kernel_size;
    let kernel_dims = trtx_sys::Dims::new_2d(rows, cols);
    let placeholder = || Weights {
        type_: nvinfer1::DataType::kFLOAT,
        values: std::ptr::null(),
        count: 0,
    };
    let network = self.inner.pin_mut();
    let raw_layer = network.addDeconvolutionNd(
        input.pin_mut(),
        nb_output_maps,
        kernel_dims,
        placeholder(),
        placeholder(),
    );
    DeconvolutionLayer::new(self.inner.as_ptr(), raw_layer)
}
pub fn add_deconvolution_owned_weights(
&'_ mut self,
input: &'_ Tensor,
nb_output_maps: i64,
kernel_size: &[i64; 2],
weights: OwnedConvWeights,
) -> Result<DeconvolutionLayer<'network>> {
debug!(
"add_deconvolution_owned_weights input={} nb_output_maps={nb_output_maps} kernel_size={kernel_size:?}",
tensor_dbg(self, input)
);
let mut layer =
self.add_deconvolution_deferred_weights(input, nb_output_maps, kernel_size)?;
let kernel = self
.add_constant_owned(
&weights.kernel.shape,
weights.kernel.values,
weights.kernel.data_type,
weights.kernel.name.as_deref(),
)?
.output(self, 0)?;
layer.set_input(self, 1, &kernel)?;
if let Some(bias) = weights.bias {
let bias = self
.add_constant_owned(
&bias.shape,
bias.values,
bias.data_type,
bias.name.as_deref(),
)?
.output(self, 0)?;
layer.set_input(self, 2, &bias)?;
}
Ok(layer)
}
/// Adds a concatenation of `inputs`, via the `network_add_concatenation` shim.
///
/// # Panics
/// Panics if any tensor was created from a different network.
pub fn add_concatenation(
    &'_ mut self,
    inputs: &[&'_ Tensor],
) -> Result<ConcatenationLayer<'network>> {
    for tensor in inputs {
        check_network!(self, tensor);
    }
    let input_names: Vec<String> = inputs.iter().map(|t| tensor_dbg(self, t)).collect();
    debug!("add_concatenation inputs={input_names:?}");
    let mut raw_inputs: Vec<*mut std::ffi::c_void> = inputs
        .iter()
        .map(|t| t.as_mut() as *mut ITensor as *mut _)
        .collect();
    let input_count = inputs.len() as i32;
    let network_ptr = self.inner.as_mut_ptr() as *mut std::ffi::c_void;
    // SAFETY: `raw_inputs` holds `input_count` tensor pointers from this network and
    // stays alive for the duration of the call.
    let raw_layer = unsafe {
        trtx_sys::network_add_concatenation(network_ptr, raw_inputs.as_mut_ptr(), input_count)
    } as *mut IConcatenationLayer;
    ConcatenationLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds an einsum layer over `inputs` described by `equation`.
///
/// # Errors
/// Fails if `equation` contains an interior NUL byte, or if layer creation fails.
///
/// # Panics
/// Panics if any tensor was created from a different network.
pub fn add_einsum(
    &'_ mut self,
    inputs: &[&'_ Tensor],
    equation: &str,
) -> Result<EinsumLayer<'network>> {
    for tensor in inputs {
        check_network!(self, tensor);
    }
    let input_names: Vec<String> = inputs.iter().map(|t| tensor_dbg(self, t)).collect();
    debug!("add_einsum inputs={input_names:?} equation={equation:?}");
    let c_equation = CString::new(equation)?;
    let mut raw_inputs: Vec<*mut std::ffi::c_void> = inputs
        .iter()
        .map(|t| t.as_mut() as *mut ITensor as *mut _)
        .collect();
    let input_count = inputs.len() as i32;
    let network_ptr = self.inner.as_mut_ptr() as *mut std::ffi::c_void;
    // SAFETY: `raw_inputs` and `c_equation` outlive the call; pointers come from this network.
    let raw_layer = unsafe {
        trtx_sys::network_add_einsum(
            network_ptr,
            raw_inputs.as_mut_ptr(),
            input_count,
            c_equation.as_ptr(),
        )
    } as *mut IEinsumLayer;
    EinsumLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a constant layer from a copy of `weights`.
///
/// The copy is stored inside this network, so `weights` need not outlive it.
///
/// # Panics
/// Panics if `weights.len()` does not match `dims` and `data_type`.
pub fn add_small_constant_copied(
    &mut self,
    dims: &[i64],
    weights: &[u8],
    data_type: trtx_sys::DataType,
    name: Option<&str>,
) -> Result<ConstantLayer<'network>> {
    trace!(
        "add_small_constant_copied dims={dims:?} data_type={data_type:?} weights_len={} name={name:?}",
        weights.len()
    );
    let copy = true;
    // SAFETY: with `copy == true` the pointer handed to TensorRT refers to the
    // network-owned copy, not to the caller's slice.
    unsafe { self.add_constant_unsafe(dims, weights, data_type, copy, name) }
}
/// Adds a constant layer that takes ownership of `weights`.
///
/// The buffer is moved into the network's internal storage so it stays alive as long
/// as the network. If `name` is given, it is attached to the weights.
///
/// # Panics
/// Panics if `weights.len()` differs from the byte size implied by `dims` and
/// `data_type`. That size is rounded up to whole bytes for packed sub-byte types.
pub fn add_constant_owned(
    &mut self,
    dims: &[i64],
    weights: Vec<u8>,
    data_type: trtx_sys::DataType,
    name: Option<&str>,
) -> Result<ConstantLayer<'network>> {
    trace!(
        "add_constant_owned dims={dims:?} data_type={data_type:?} weights_len={} name={name:?}",
        weights.len()
    );
    let element_count: i64 = dims.iter().product();
    // Round up: packed sub-byte types (e.g. 4-bit) with an odd element count
    // occupy a final, partially used byte. Flooring rejected correctly sized buffers.
    let expected_bytes = (element_count * data_type.size_bits() as i64 + 7) / 8;
    if weights.len() as i64 != expected_bytes {
        panic!(
            "Weight size mismatch: expected {expected_bytes} bytes, got {} bytes",
            weights.len()
        );
    }
    let dims_struct = trtx_sys::Dims::from_slice(dims);
    // Moving the Vec into the outer Vec does not move its heap buffer, so the
    // pointer stays valid for the network's lifetime.
    self.small_copied_weights.push(weights);
    let values = self
        .small_copied_weights
        .last()
        .expect("can't be empty. we just pushed")
        .as_ptr() as *const std::ffi::c_void;
    let weights_struct =
        trtx_sys::nvinfer1::Weights::new_with_type(data_type.into(), values, element_count);
    let layer_ptr = self
        .inner
        .pin_mut()
        .addConstant(&dims_struct, weights_struct);
    let layer = ConstantLayer::new(self.inner.as_ptr(), layer_ptr)?;
    if let Some(name) = name {
        // SAFETY: `values` points into the buffer just handed to `addConstant`,
        // which this network keeps alive.
        unsafe {
            self.set_weights_name(
                Weights {
                    type_: data_type.into(),
                    values,
                    count: element_count,
                },
                name,
            )?
        };
    }
    Ok(layer)
}
/// Shared implementation of the constant-layer constructors.
///
/// When `copy` is true, `weights` is copied into network-owned storage. Otherwise
/// its pointer is handed to TensorRT directly.
///
/// # Safety
/// If `copy` is false, `weights` must stay valid for as long as TensorRT may read it.
/// [`Self::add_constant`] enforces this by requiring a `&'network [u8]`.
///
/// # Panics
/// Panics if `weights.len()` differs from the byte size implied by `dims` and
/// `data_type`. That size is rounded up to whole bytes for packed sub-byte types.
unsafe fn add_constant_unsafe(
    &mut self,
    dims: &[i64],
    weights: &[u8],
    data_type: trtx_sys::DataType,
    copy: bool,
    name: Option<&str>,
) -> Result<ConstantLayer<'network>> {
    let element_count: i64 = dims.iter().product();
    // Round up: packed sub-byte types (e.g. 4-bit) with an odd element count
    // occupy a final, partially used byte. Flooring rejected correctly sized buffers.
    let expected_bytes = (element_count * data_type.size_bits() as i64 + 7) / 8;
    if weights.len() as i64 != expected_bytes {
        panic!(
            "Weight size mismatch: expected {expected_bytes} bytes, got {} bytes",
            weights.len()
        );
    }
    let dims_struct = trtx_sys::Dims::from_slice(dims);
    let values = if copy {
        self.small_copied_weights.push(weights.to_vec());
        self.small_copied_weights
            .last()
            .expect("can't be empty. we just pushed")
            .as_ptr()
    } else {
        weights.as_ptr()
    } as *const std::ffi::c_void;
    let weights_struct =
        trtx_sys::nvinfer1::Weights::new_with_type(data_type.into(), values, element_count);
    let layer_ptr = self
        .inner
        .pin_mut()
        .addConstant(&dims_struct, weights_struct);
    let layer = ConstantLayer::new(self.inner.as_ptr(), layer_ptr)?;
    if let Some(name) = name {
        self.set_weights_name(
            Weights {
                type_: data_type.into(),
                values,
                count: element_count,
            },
            name,
        )?;
    }
    Ok(layer)
}
/// Adds a constant layer that borrows `weights` for the network's lifetime (no copy).
///
/// # Panics
/// Panics if `weights.len()` does not match `dims` and `data_type`.
pub fn add_constant(
    &mut self,
    dims: &[i64],
    weights: &'network [u8],
    data_type: trtx_sys::DataType,
    name: Option<&str>,
) -> Result<ConstantLayer<'network>> {
    trace!(
        "add_constant dims={dims:?} data_type={data_type:?} weights_len={} name={name:?}",
        weights.len()
    );
    let copy = false;
    // SAFETY: the `'network` borrow keeps `weights` alive as long as this network.
    unsafe { self.add_constant_unsafe(dims, weights, data_type, copy, name) }
}
/// Adds a softmax over `input`, restricted to `axes`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_softmax(
    &mut self,
    input: &'_ Tensor,
    axes: crate::Axes,
) -> Result<SoftMaxLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_softmax input={} axes={axes:?}",
        tensor_dbg(self, input)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addSoftMax(input.pin_mut());
    let mut softmax = SoftMaxLayer::new(self.inner.as_ptr(), raw_layer)?;
    let axis_mask = axes.to_bits();
    softmax.inner.as_mut().setAxes(axis_mask);
    Ok(softmax)
}
/// Adds a ragged softmax over `input` with per-sequence lengths in `bounds`.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_ragged_softmax(
    &mut self,
    input: &'_ Tensor,
    bounds: &'_ Tensor,
) -> Result<RaggedSoftMaxLayer<'network>> {
    check_network!(self, input);
    check_network!(self, bounds);
    debug!(
        "add_ragged_softmax input={} bounds={}",
        tensor_dbg(self, input),
        tensor_dbg(self, bounds)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addRaggedSoftMax(input.pin_mut(), bounds.pin_mut());
    RaggedSoftMaxLayer::new(self.inner.as_ptr(), raw_layer)
}
pub fn add_scale(
&mut self,
input: &'_ Tensor,
mode: ScaleMode,
shift: &[u8],
scale: &[u8],
power: &[u8],
) -> Result<ScaleLayer<'network>> {
check_network!(self, input);
debug!(
"add_scale input={} mode={mode:?} shift_len={} scale_len={} power_len={}",
tensor_dbg(self, input),
shift.len(),
scale.len(),
power.len()
);
let weight_count = match mode {
ScaleMode::kUNIFORM => 1i64,
ScaleMode::kCHANNEL => {
let input_dims = input.dimensions(self)?;
if input_dims.len() >= 4 {
input_dims[1]
} else if !input_dims.is_empty() {
input_dims[0]
} else {
1i64
}
}
ScaleMode::kELEMENTWISE => {
let input_dims = input.dimensions(self)?;
input_dims.iter().product::<i64>()
}
};
let shift_w = trtx_sys::nvinfer1::Weights::new_float(
shift.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let scale_w = trtx_sys::nvinfer1::Weights::new_float(
scale.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let power_w = trtx_sys::nvinfer1::Weights::new_float(
power.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let layer_ptr =
self.inner
.pin_mut()
.addScale(input.pin_mut(), mode.into(), shift_w, scale_w, power_w);
ScaleLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_scale_nd(
&mut self,
input: &'_ Tensor,
mode: ScaleMode,
shift: &[u8],
scale: &[u8],
power: &[u8],
channel_axis: i32,
) -> Result<ScaleLayer<'network>> {
check_network!(self, input);
debug!(
"add_scale_nd input={} mode={mode:?} channel_axis={channel_axis} shift_len={} scale_len={} power_len={}",
tensor_dbg(self, input),
shift.len(),
scale.len(),
power.len()
);
let weight_count = match mode {
ScaleMode::kUNIFORM => 1i64,
ScaleMode::kCHANNEL => {
let input_dims = input.dimensions(self)?;
if input_dims.len() >= 4 {
input_dims[1]
} else if !input_dims.is_empty() {
input_dims[0]
} else {
1i64
}
}
ScaleMode::kELEMENTWISE => {
let input_dims = input.dimensions(self)?;
input_dims.iter().product::<i64>()
}
};
let shift_w = trtx_sys::nvinfer1::Weights::new_float(
shift.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let scale_w = trtx_sys::nvinfer1::Weights::new_float(
scale.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let power_w = trtx_sys::nvinfer1::Weights::new_float(
power.as_ptr() as *const std::ffi::c_void,
weight_count,
);
let layer_ptr = self.inner.pin_mut().addScaleNd(
input.pin_mut(),
mode.into(),
shift_w,
scale_w,
power_w,
channel_axis,
);
ScaleLayer::new(self.inner.as_ptr(), layer_ptr)
}
/// Adds a reduction `op` over `axes`, optionally keeping reduced dims.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_reduce(
    &mut self,
    input: &'_ Tensor,
    op: trtx_sys::ReduceOperation,
    axes: crate::Axes,
    keep_dims: bool,
) -> Result<ReduceLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_reduce input={} op={op:?} axes={axes:?} keep_dims={keep_dims}",
        tensor_dbg(self, input)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addReduce(input.pin_mut(), op.into(), axes.to_bits(), keep_dims);
    ReduceLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a cumulative `op` along the constant `axis`.
///
/// The axis is materialised as a scalar INT32 constant and forwarded to
/// [`Self::add_cumulative_with_axis_tensor`].
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_cumulative(
    &mut self,
    input: &'_ Tensor,
    axis: i32,
    op: trtx_sys::CumulativeOperation,
    exclusive: bool,
    reverse: bool,
) -> Result<CumulativeLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_cumulative input={} axis={axis} op={op:?} exclusive={exclusive} reverse={reverse}",
        tensor_dbg(self, input)
    );
    // TensorRT reads the constant as a host `int32_t`, so use native byte order
    // (identical to the former little-endian encoding on little-endian hosts).
    let axis_bytes = axis.to_ne_bytes();
    let axis_constant =
        self.add_small_constant_copied(&[], &axis_bytes, trtx_sys::DataType::kINT32, None)?;
    let axis_tensor = axis_constant.output(self, 0)?;
    self.add_cumulative_with_axis_tensor(input, &axis_tensor, op, exclusive, reverse)
}
/// Adds a cumulative `op` along the axis held in `axis_tensor`.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_cumulative_with_axis_tensor(
    &mut self,
    input: &'_ Tensor,
    axis_tensor: &'_ Tensor,
    op: trtx_sys::CumulativeOperation,
    exclusive: bool,
    reverse: bool,
) -> Result<CumulativeLayer<'network>> {
    check_network!(self, input);
    check_network!(self, axis_tensor);
    debug!(
        "add_cumulative_with_axis_tensor input={} axis_tensor={} op={op:?} exclusive={exclusive} reverse={reverse}",
        tensor_dbg(self, input),
        tensor_dbg(self, axis_tensor)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addCumulative(
        input.pin_mut(),
        axis_tensor.pin_mut(),
        op.into(),
        exclusive,
        reverse,
    );
    CumulativeLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a static slice of `input` defined by `start`, `size` and `stride`.
///
/// # Errors
/// `Error::Runtime` if the three slices differ in length.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_slice(
    &mut self,
    input: &'_ Tensor,
    start: &[i64],
    size: &[i64],
    stride: &[i64],
) -> Result<SliceLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_slice input={} start={start:?} size={size:?} stride={stride:?}",
        tensor_dbg(self, input)
    );
    let rank = start.len();
    if [size.len(), stride.len()].iter().any(|&len| len != rank) {
        return Err(Error::Runtime(
            "start, size, and stride must have the same length".to_string(),
        ));
    }
    let (start_dims, size_dims, stride_dims) = (
        trtx_sys::Dims::from_slice(start),
        trtx_sys::Dims::from_slice(size),
        trtx_sys::Dims::from_slice(stride),
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addSlice(input.pin_mut(), &start_dims, &size_dims, &stride_dims);
    SliceLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a top-k layer whose index output uses element type `indices_type`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_topk_with_indices(
    &mut self,
    input: &'_ Tensor,
    op: TopKOperation,
    k: i32,
    axes: crate::Axes,
    indices_type: DataType,
) -> Result<TopKLayer<'network>> {
    check_network!(self, input);
    // Log under this method's own name (previously mislabelled "add_topk").
    debug!(
        "add_topk_with_indices input={} op={op:?} k={k} axes={axes:?} indices_type={indices_type:?}",
        tensor_dbg(self, input)
    );
    let axes_bits = axes.to_bits();
    let layer_ptr = self.inner.pin_mut().addTopK1(
        input.pin_mut(),
        op.into(),
        k,
        axes_bits,
        indices_type.into(),
    );
    TopKLayer::new(self.inner.as_ptr(), layer_ptr)
}
/// Adds a top-k layer selecting `k` elements along `axes`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_topk(
    &mut self,
    input: &'_ Tensor,
    op: TopKOperation,
    k: i32,
    axes: crate::Axes,
) -> Result<TopKLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_topk input={} op={op:?} k={k} axes={axes:?}",
        tensor_dbg(self, input)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addTopK(input.pin_mut(), op.into(), k, axes.to_bits());
    TopKLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a resize layer on `input`; configure it through the returned handle.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_resize(&mut self, input: &'_ Tensor) -> Result<ResizeLayer<'network>> {
    check_network!(self, input);
    debug!("add_resize input={}", tensor_dbg(self, input));
    let network = self.inner.pin_mut();
    let raw_layer = network.addResize(input.pin_mut());
    ResizeLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a rotary-embedding layer using the given cosine/sine caches.
///
/// # Panics
/// Panics if any tensor was created from a different network.
pub fn add_rotary_embedding(
    &mut self,
    input: &'_ Tensor,
    cos_cache: &'_ Tensor,
    sin_cache: &'_ Tensor,
    interleaved: bool,
    rotary_embedding_dim: i32,
) -> Result<RotaryEmbeddingLayer<'network>> {
    check_network!(self, input);
    check_network!(self, cos_cache);
    check_network!(self, sin_cache);
    debug!(
        "add_rotary_embedding input={} cos_cache={} sin_cache={} interleaved={interleaved} rotary_embedding_dim={rotary_embedding_dim}",
        tensor_dbg(self, input),
        tensor_dbg(self, cos_cache),
        tensor_dbg(self, sin_cache)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addRotaryEmbedding(
        input.pin_mut(),
        cos_cache.pin_mut(),
        sin_cache.pin_mut(),
        interleaved,
        rotary_embedding_dim,
    );
    RotaryEmbeddingLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a gather of `data` at `indices` along `axis`.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_gather(
    &'_ mut self,
    data: &'_ Tensor,
    indices: &'_ Tensor,
    axis: i32,
) -> Result<GatherLayer<'network>> {
    check_network!(self, data);
    check_network!(self, indices);
    debug!(
        "add_gather data={} indices={} axis={axis}",
        tensor_dbg(self, data),
        tensor_dbg(self, indices)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addGather(data.pin_mut(), indices.pin_mut(), axis);
    GatherLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a gather of `data` at `indices` using gather `mode`.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_gather_v2(
    &mut self,
    data: &'_ Tensor,
    indices: &'_ Tensor,
    mode: GatherMode,
) -> Result<GatherLayer<'network>> {
    check_network!(self, data);
    check_network!(self, indices);
    debug!(
        "add_gather_v2 data={} indices={} mode={mode:?}",
        tensor_dbg(self, data),
        tensor_dbg(self, indices)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addGatherV2(data.pin_mut(), indices.pin_mut(), mode.into());
    GatherLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a scatter of `updates` into `data` at `indices` using scatter `mode`.
///
/// # Panics
/// Panics if any tensor was created from a different network.
pub fn add_scatter(
    &mut self,
    data: &'_ Tensor,
    indices: &'_ Tensor,
    updates: &'_ Tensor,
    mode: trtx_sys::ScatterMode,
) -> Result<ScatterLayer<'network>> {
    check_network!(self, data);
    check_network!(self, indices);
    check_network!(self, updates);
    debug!(
        "add_scatter data={} indices={} updates={} mode={mode:?}",
        tensor_dbg(self, data),
        tensor_dbg(self, indices),
        tensor_dbg(self, updates)
    );
    let network = self.inner.pin_mut();
    let raw_layer = network.addScatter(
        data.pin_mut(),
        indices.pin_mut(),
        updates.pin_mut(),
        mode.into(),
    );
    ScatterLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a quantize layer producing `output_type` using the `scale` tensor.
///
/// On `enterprise` builds without `v_1_5`, the binding is exposed as `addQuantize1`.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_quantize(
    &'_ mut self,
    input: &'_ Tensor,
    scale: &'_ Tensor,
    output_type: trtx_sys::DataType,
) -> Result<QuantizeLayer<'network>> {
    check_network!(self, input);
    check_network!(self, scale);
    debug!(
        "add_quantize input={} scale={} output_type={output_type:?}",
        tensor_dbg(self, input),
        tensor_dbg(self, scale)
    );
    let network = self.inner.pin_mut();
    #[cfg(not(all(feature = "enterprise", not(feature = "v_1_5"))))]
    let raw_layer = network.addQuantize(input.pin_mut(), scale.pin_mut(), output_type.into());
    #[cfg(all(feature = "enterprise", not(feature = "v_1_5")))]
    let raw_layer = network.addQuantize1(input.pin_mut(), scale.pin_mut(), output_type.into());
    QuantizeLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a dequantize layer producing `output_type` using the `scale` tensor.
///
/// On `enterprise` builds without `v_1_5`, the binding is exposed as `addDequantize1`.
///
/// # Panics
/// Panics if either tensor was created from a different network.
pub fn add_dequantize(
    &mut self,
    input: &'_ Tensor,
    scale: &'_ Tensor,
    output_type: trtx_sys::DataType,
) -> Result<DequantizeLayer<'network>> {
    check_network!(self, input);
    check_network!(self, scale);
    debug!(
        "add_dequantize input={} scale={} output_type={output_type:?}",
        tensor_dbg(self, input),
        tensor_dbg(self, scale)
    );
    let network = self.inner.pin_mut();
    #[cfg(not(all(feature = "enterprise", not(feature = "v_1_5"))))]
    let raw_layer =
        network.addDequantize(input.pin_mut(), scale.pin_mut(), output_type.into());
    #[cfg(all(feature = "enterprise", not(feature = "v_1_5")))]
    let raw_layer =
        network.addDequantize1(input.pin_mut(), scale.pin_mut(), output_type.into());
    DequantizeLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a dynamic-quantize layer with a single `axis` and `block_size`.
///
/// Superseded by [`Self::add_dynamic_quantize_v2`], which takes a full block shape.
///
/// # Panics
/// Panics if `input` was created from a different network.
#[deprecated(note = "use add_dynamic_quantize_v2 instead")]
pub fn add_dynamic_quantize(
    &mut self,
    input: &'_ Tensor,
    axis: i32,
    block_size: i32,
    output_type: DataType,
    scale_type: DataType,
) -> Result<DynamicQuantizeLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_dynamic_quantize input={} axis={axis} block_size={block_size} output_type={output_type:?} scale_type={scale_type:?}",
        tensor_dbg(self, input)
    );
    let network = self.inner.pin_mut();
    // The underlying TensorRT binding is itself deprecated.
    #[allow(deprecated)]
    let raw_layer = network.addDynamicQuantize(
        input.pin_mut(),
        axis,
        block_size,
        output_type.into(),
        scale_type.into(),
    );
    DynamicQuantizeLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a dynamic-quantize layer with an explicit `block_shape`.
///
/// # Panics
/// Panics if `input` was created from a different network.
pub fn add_dynamic_quantize_v2(
    &mut self,
    input: &'_ Tensor,
    block_shape: &[i64],
    output_type: DataType,
    scale_type: DataType,
) -> Result<DynamicQuantizeLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_dynamic_quantize_v2 input={} block_shape={block_shape:?} output_type={output_type:?} scale_type={scale_type:?}",
        tensor_dbg(self, input)
    );
    let block = Dims64::from_slice(block_shape);
    let network = self.inner.pin_mut();
    let raw_layer = network.addDynamicQuantizeV2(
        input.pin_mut(),
        &block,
        output_type.into(),
        scale_type.into(),
    );
    DynamicQuantizeLayer::new(self.inner.as_ptr(), raw_layer)
}
/// Adds a select layer choosing between `then_input` and `else_input` based on `condition`.
///
/// # Panics
/// Panics if any of the three tensors was created by a different network.
pub fn add_select(
    &mut self,
    condition: &'_ Tensor,
    then_input: &'_ Tensor,
    else_input: &'_ Tensor,
) -> Result<SelectLayer<'network>> {
    check_network!(self, condition);
    check_network!(self, then_input);
    check_network!(self, else_input);
    debug!(
        "add_select condition={} then_input={} else_input={}",
        tensor_dbg(self, condition),
        tensor_dbg(self, then_input),
        tensor_dbg(self, else_input)
    );
    let network_ptr = self.inner.as_ptr();
    let (cond, on_true, on_false) = (
        condition.pin_mut(),
        then_input.pin_mut(),
        else_input.pin_mut(),
    );
    let raw_layer = self.inner.pin_mut().addSelect(cond, on_true, on_false);
    SelectLayer::new(network_ptr, raw_layer)
}
/// Adds an N-dimensional padding layer around `input`.
///
/// # Errors
/// Returns `Error::Runtime` if `pre_padding` and `post_padding` differ in length.
///
/// # Panics
/// Panics if `input` was created by a different network.
pub fn add_padding(
    &mut self,
    input: &'_ Tensor,
    pre_padding: &[i64],
    post_padding: &[i64],
) -> Result<PaddingLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_padding input={} pre_padding={pre_padding:?} post_padding={post_padding:?}",
        tensor_dbg(self, input)
    );
    let same_rank = pre_padding.len() == post_padding.len();
    if !same_rank {
        return Err(Error::Runtime(
            "pre_padding and post_padding must have the same length".to_string(),
        ));
    }
    let network_ptr = self.inner.as_ptr();
    let (before, after) = (
        trtx_sys::Dims::from_slice(pre_padding),
        trtx_sys::Dims::from_slice(post_padding),
    );
    let raw_layer = self
        .inner
        .pin_mut()
        .addPaddingNd(input.pin_mut(), &before, &after);
    PaddingLayer::new(network_ptr, raw_layer)
}
/// Adds a fill layer producing a tensor of shape `dimensions` and type `output_type`,
/// using the fill operation `op`.
pub fn add_fill(
    &mut self,
    dimensions: &[i64],
    op: FillOperation,
    output_type: DataType,
) -> Result<FillLayer<'network>> {
    debug!("add_fill dimensions={dimensions:?} op={op:?} output_type={output_type:?}");
    let network_ptr = self.inner.as_ptr();
    let shape = Dims64::from_slice(dimensions);
    let raw_layer = self
        .inner
        .pin_mut()
        .addFill(&shape, op.into(), output_type.into());
    FillLayer::new(network_ptr, raw_layer)
}
/// Adds a one-hot layer built from `indices`, `values` and `depth` along `axis`.
///
/// # Panics
/// Panics if any of the tensors was created by a different network.
pub fn add_one_hot(
    &mut self,
    indices: &'_ Tensor,
    values: &'_ Tensor,
    depth: &'_ Tensor,
    axis: i32,
) -> Result<OneHotLayer<'network>> {
    check_network!(self, indices);
    check_network!(self, values);
    check_network!(self, depth);
    debug!(
        "add_one_hot indices={} values={} depth={} axis={axis}",
        tensor_dbg(self, indices),
        tensor_dbg(self, values),
        tensor_dbg(self, depth)
    );
    let network_ptr = self.inner.as_ptr();
    let (idx, vals, dep) = (indices.pin_mut(), values.pin_mut(), depth.pin_mut());
    let raw_layer = self.inner.pin_mut().addOneHot(idx, vals, dep, axis);
    OneHotLayer::new(network_ptr, raw_layer)
}
/// Adds a non-zero layer over `input`, emitting indices of type `indices_type`.
///
/// # Panics
/// Panics if `input` was created by a different network.
pub fn add_non_zero(
    &mut self,
    input: &'_ Tensor,
    indices_type: DataType,
) -> Result<NonZeroLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_non_zero input={} indices_type={indices_type:?}",
        tensor_dbg(self, input)
    );
    let network_ptr = self.inner.as_ptr();
    let source = input.pin_mut();
    let raw_layer = self.inner.pin_mut().addNonZero1(source, indices_type.into());
    NonZeroLayer::new(network_ptr, raw_layer)
}
/// Adds a non-maximum-suppression layer over `boxes` and `scores`.
///
/// `max_output_boxes_per_class` is passed as a tensor; the emitted indices use `indices_type`.
///
/// # Panics
/// Panics if any of the tensors was created by a different network.
pub fn add_nms(
    &mut self,
    boxes: &'_ Tensor,
    scores: &'_ Tensor,
    max_output_boxes_per_class: &'_ Tensor,
    indices_type: DataType,
) -> Result<NMSLayer<'network>> {
    check_network!(self, boxes);
    check_network!(self, scores);
    check_network!(self, max_output_boxes_per_class);
    debug!(
        "add_nms boxes={} scores={} max_output_boxes_per_class={} indices_type={indices_type:?}",
        tensor_dbg(self, boxes),
        tensor_dbg(self, scores),
        tensor_dbg(self, max_output_boxes_per_class)
    );
    let network_ptr = self.inner.as_ptr();
    let (box_t, score_t, limit_t) = (
        boxes.pin_mut(),
        scores.pin_mut(),
        max_output_boxes_per_class.pin_mut(),
    );
    let raw_layer = self
        .inner
        .pin_mut()
        .addNMS1(box_t, score_t, limit_t, indices_type.into());
    NMSLayer::new(network_ptr, raw_layer)
}
/// Adds a reverse-sequence layer over `input`, with per-sequence lengths in `sequence_lens`.
///
/// # Panics
/// Panics if either tensor was created by a different network.
pub fn add_reverse_sequence(
    &mut self,
    input: &'_ Tensor,
    sequence_lens: &'_ Tensor,
) -> Result<ReverseSequenceLayer<'network>> {
    check_network!(self, input);
    check_network!(self, sequence_lens);
    debug!(
        "add_reverse_sequence input={} sequence_lens={}",
        tensor_dbg(self, input),
        tensor_dbg(self, sequence_lens)
    );
    let network_ptr = self.inner.as_ptr();
    let (source, lens) = (input.pin_mut(), sequence_lens.pin_mut());
    let raw_layer = self.inner.pin_mut().addReverseSequence(source, lens);
    ReverseSequenceLayer::new(network_ptr, raw_layer)
}
/// Adds a squeeze layer removing the axes listed in the `axes` tensor from `input`.
///
/// # Panics
/// Panics if either tensor was created by a different network.
pub fn add_squeeze(
    &mut self,
    input: &'_ Tensor,
    axes: &'_ Tensor,
) -> Result<SqueezeLayer<'network>> {
    check_network!(self, input);
    check_network!(self, axes);
    debug!(
        "add_squeeze input={} axes={}",
        tensor_dbg(self, input),
        tensor_dbg(self, axes)
    );
    let network_ptr = self.inner.as_ptr();
    let (source, which) = (input.pin_mut(), axes.pin_mut());
    let raw_layer = self.inner.pin_mut().addSqueeze(source, which);
    SqueezeLayer::new(network_ptr, raw_layer)
}
/// Adds an unsqueeze layer inserting the axes listed in the `axes` tensor into `input`.
///
/// # Panics
/// Panics if either tensor was created by a different network.
pub fn add_unsqueeze(
    &mut self,
    input: &'_ Tensor,
    axes: &'_ Tensor,
) -> Result<UnsqueezeLayer<'network>> {
    check_network!(self, input);
    check_network!(self, axes);
    debug!(
        "add_unsqueeze input={} axes={}",
        tensor_dbg(self, input),
        tensor_dbg(self, axes)
    );
    let network_ptr = self.inner.as_ptr();
    let (source, which) = (input.pin_mut(), axes.pin_mut());
    let raw_layer = self.inner.pin_mut().addUnsqueeze(source, which);
    UnsqueezeLayer::new(network_ptr, raw_layer)
}
/// Adds a KV-cache update layer that writes `update` into `cache` at `write_indices`,
/// using the given `cache_mode`.
///
/// # Panics
/// Panics if any of the tensors was created by a different network.
pub fn add_kv_cache_update(
    &mut self,
    cache: &'_ Tensor,
    update: &'_ Tensor,
    write_indices: &'_ Tensor,
    cache_mode: KVCacheMode,
) -> Result<KVCacheUpdateLayer<'network>> {
    check_network!(self, cache);
    check_network!(self, update);
    check_network!(self, write_indices);
    debug!(
        "add_kv_cache_update cache={} update={} write_indices={} cache_mode={cache_mode:?}",
        tensor_dbg(self, cache),
        tensor_dbg(self, update),
        tensor_dbg(self, write_indices)
    );
    let network_ptr = self.inner.as_ptr();
    let (cache_t, update_t, indices_t) = (
        cache.pin_mut(),
        update.pin_mut(),
        write_indices.pin_mut(),
    );
    let raw_layer = self
        .inner
        .pin_mut()
        .addKVCacheUpdate(cache_t, update_t, indices_t, cache_mode.into());
    KVCacheUpdateLayer::new(network_ptr, raw_layer)
}
pub fn add_assertion(
&mut self,
condition: &'_ Tensor,
message: &str,
) -> Result<AssertionLayer<'network>> {
check_network!(self, condition);
debug!(
"add_assertion condition={} message={message:?}",
tensor_dbg(self, condition)
);
let message_cstr = std::ffi::CString::new(message)?;
let layer_ptr = unsafe {
self.inner
.pin_mut()
.addAssertion(condition.pin_mut(), message_cstr.as_ptr())
};
AssertionLayer::new(self.inner.as_ptr(), layer_ptr)
}
pub fn add_loop(&mut self) -> Result<Loop<'network>> {
debug!("add_loop");
let loop_ptr = self.inner.pin_mut().addLoop();
let loop_ptr = unsafe { loop_ptr.as_mut() }
.ok_or_else(|| Error::Runtime("Failed to add loop".to_string()))?;
Ok(Loop {
inner: unsafe { Pin::new_unchecked(loop_ptr) },
network: self.inner.as_ptr(),
})
}
pub fn add_if_conditional(&mut self) -> Result<IfConditional<'network>> {
debug!("add_if_conditional");
let if_ptr = self.inner.pin_mut().addIfConditional();
let if_ptr = unsafe { if_ptr.as_mut() }
.ok_or_else(|| Error::Runtime("Failed to add if conditional".to_string()))?;
Ok(IfConditional {
inner: unsafe { Pin::new_unchecked(if_ptr) },
network: self.inner.as_ptr(),
})
}
/// Adds an attention block over `query`, `key` and `value` using `norm_op` and an explicit
/// causal-mask kind (`addAttentionV2`).
///
/// # Errors
/// Returns `Error::Runtime` if TensorRT hands back a null pointer.
///
/// # Panics
/// Panics if any of the tensors was created by a different network.
#[cfg(feature = "v_1_5")]
pub fn add_attention_v2(
    &mut self,
    query: &'_ Tensor,
    key: &'_ Tensor,
    value: &'_ Tensor,
    norm_op: trtx_sys::AttentionNormalizationOp,
    causal_kind: CausalMaskKind,
) -> Result<Attention<'network>> {
    check_network!(self, query);
    check_network!(self, key);
    check_network!(self, value);
    debug!(
        "add_attention_v2 query={} key={} value={} norm_op={norm_op:?} causal_kind={causal_kind:?}",
        tensor_dbg(self, query),
        tensor_dbg(self, key),
        tensor_dbg(self, value)
    );
    let (q, k, v) = (query.pin_mut(), key.pin_mut(), value.pin_mut());
    let raw_attn = self
        .inner
        .pin_mut()
        .addAttentionV2(q, k, v, norm_op.into(), causal_kind.into());
    // SAFETY: a null pointer becomes an error; a non-null one is owned by this network.
    match unsafe { raw_attn.as_mut() } {
        Some(attn_ref) => Ok(Attention {
            // SAFETY: the object is only accessed through this pointer, so Rust never moves it.
            inner: unsafe { Pin::new_unchecked(attn_ref) },
            network: self.inner.as_ptr(),
        }),
        None => Err(Error::Runtime("Failed to add attention".to_string())),
    }
}
pub fn add_attention(
&mut self,
query: &'_ Tensor,
key: &'_ Tensor,
value: &'_ Tensor,
norm_op: trtx_sys::AttentionNormalizationOp,
causal: bool,
) -> Result<Attention<'network>> {
check_network!(self, query);
check_network!(self, key);
check_network!(self, value);
debug!(
"add_attention query={} key={} value={} norm_op={norm_op:?} causal={causal}",
tensor_dbg(self, query),
tensor_dbg(self, key),
tensor_dbg(self, value)
);
let attn_ptr = self.inner.pin_mut().addAttention(
query.pin_mut(),
key.pin_mut(),
value.pin_mut(),
norm_op.into(),
causal,
);
let attn = unsafe { attn_ptr.as_mut() }
.ok_or_else(|| Error::Runtime("Failed to add attention".to_string()))?;
Ok(Attention {
inner: unsafe { Pin::new_unchecked(attn) },
network: self.inner.as_ptr(),
})
}
/// Adds a mixture-of-experts layer over `hidden_states`, routed by
/// `selected_experts_for_tokens` and weighted by `scores_for_selected_experts`.
///
/// # Panics
/// Panics if any of the tensors was created by a different network.
#[cfg(feature = "v_1_4")]
pub fn add_moe(
    &mut self,
    hidden_states: &Tensor,
    selected_experts_for_tokens: &Tensor,
    scores_for_selected_experts: &Tensor,
) -> Result<MoELayer<'network>> {
    check_network!(self, hidden_states);
    check_network!(self, selected_experts_for_tokens);
    check_network!(self, scores_for_selected_experts);
    debug!(
        "add_moe hidden_states={} selected_experts_for_tokens={} scores_for_selected_experts={}",
        tensor_dbg(self, hidden_states),
        tensor_dbg(self, selected_experts_for_tokens),
        tensor_dbg(self, scores_for_selected_experts)
    );
    let network_ptr = self.inner.as_ptr();
    let (hidden, experts, weights) = (
        hidden_states.pin_mut(),
        selected_experts_for_tokens.pin_mut(),
        scores_for_selected_experts.pin_mut(),
    );
    let raw_layer = self.inner.pin_mut().addMoE(hidden, experts, weights);
    MoELayer::new(network_ptr, raw_layer)
}
/// Adds a distributed collective layer over `input`.
///
/// An empty `groups` slice is passed to TensorRT as a null pointer with size 0.
///
/// # Panics
/// Panics if `input` was created by a different network.
#[cfg(feature = "v_1_4")]
pub fn add_dist_collective(
    &mut self,
    input: &Tensor,
    dist_collective_op: CollectiveOperation,
    reduce_op: ReduceOperation,
    root: i64,
    groups: &[i64],
) -> Result<DistCollectiveLayer<'network>> {
    check_network!(self, input);
    debug!(
        "add_dist_collective input={} dist_collective_op={dist_collective_op:?} reduce_op={reduce_op:?} root={root} groups={groups:?}",
        tensor_dbg(self, input)
    );
    let (groups_ptr, group_size) = match groups {
        [] => (std::ptr::null_mut(), 0i64),
        populated => (populated.as_ptr() as *mut i64, populated.len() as i64),
    };
    let network_ptr = self.inner.as_ptr();
    // SAFETY: `groups_ptr`/`group_size` describe the live `groups` slice (or null/0).
    // NOTE(review): the binding takes `*mut i64` while we only hold `&[i64]`; this assumes
    // TensorRT neither writes through nor retains the pointer — confirm.
    let raw_layer = unsafe {
        self.inner.pin_mut().addDistCollective(
            input.pin_mut(),
            dist_collective_op.into(),
            reduce_op.into(),
            root,
            groups_ptr,
            group_size,
        )
    };
    DistCollectiveLayer::new(network_ptr, raw_layer)
}
}
/// Iterator yielding `network.input(i)` for every index from `index` up to (not including)
/// `count`.
#[derive(Debug)]
pub struct NetworkInputIter<'a, 'network> {
    network: &'a NetworkDefinition<'network>,
    /// Index of the next input to yield.
    index: i32,
    /// Exclusive upper bound on the indices to yield.
    count: i32,
}
impl<'network> Iterator for NetworkInputIter<'_, 'network> {
    type Item = Tensor<'network>;
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.index;
        (position < self.count).then(|| {
            let tensor = self.network.input(position).expect("valid input index");
            self.index = position + 1;
            tensor
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = match self.count - self.index {
            left if left > 0 => left as usize,
            _ => 0,
        };
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for NetworkInputIter<'_, '_> {}
/// Iterator yielding `network.output(i)` for every index from `index` up to (not including)
/// `count`.
#[derive(Debug)]
pub struct NetworkOutputIter<'a, 'network> {
    network: &'a NetworkDefinition<'network>,
    /// Index of the next output to yield.
    index: i32,
    /// Exclusive upper bound on the indices to yield.
    count: i32,
}
impl<'network> Iterator for NetworkOutputIter<'_, 'network> {
    type Item = Tensor<'network>;
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.index;
        (position < self.count).then(|| {
            let tensor = self.network.output(position).expect("valid output index");
            self.index = position + 1;
            tensor
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = match self.count - self.index {
            left if left > 0 => left as usize,
            _ => 0,
        };
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for NetworkOutputIter<'_, '_> {}
/// Iterator yielding `network.layer(i)` for every index from `index` up to (not including)
/// `count`.
#[derive(Debug)]
pub struct NetworkLayerIter<'a, 'network> {
    network: &'a NetworkDefinition<'network>,
    /// Index of the next layer to yield.
    index: i32,
    /// Exclusive upper bound on the indices to yield.
    count: i32,
}
impl<'network> Iterator for NetworkLayerIter<'_, 'network> {
    type Item = DynLayer<'network>;
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.index;
        (position < self.count).then(|| {
            let layer = self.network.layer(position).expect("valid layer index");
            self.index = position + 1;
            layer
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = match self.count - self.index {
            left if left > 0 => left as usize,
            _ => 0,
        };
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for NetworkLayerIter<'_, '_> {}
impl<'network> Attention<'network> {
pub fn set_normalization_operation(
&mut self,
network: &mut NetworkDefinition,
op: trtx_sys::AttentionNormalizationOp,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setNormalizationOperation(op.into())
.ok_or_err(PropertySetAttempt::AttentionLayerNormalizationOp)
}
pub fn normalization_operation(
&self,
network: &NetworkDefinition,
) -> trtx_sys::AttentionNormalizationOp {
check_network!(network, self);
self.inner.getNormalizationOperation().into()
}
pub fn set_mask(&mut self, network: &mut NetworkDefinition, mask: &Tensor) -> Result<()> {
check_network!(network, self);
check_network!(network, mask);
self.inner
.as_mut()
.setMask(mask.pin_mut())
.ok_or_err(PropertySetAttempt::AttentionLayerMask)
}
pub fn mask(&mut self, network: &mut NetworkDefinition) -> Option<Tensor<'network>> {
check_network!(network, self);
let p = self.inner.as_mut().getMask();
if p.is_null() {
None
} else {
unsafe { Tensor::new(self.network, p).ok() }
}
}
pub fn set_causal(&mut self, network: &mut NetworkDefinition, is_causal: bool) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setCausal(is_causal)
.ok_or_err(PropertySetAttempt::AttentionLayerCausal)
}
pub fn causal(&self, network: &NetworkDefinition) -> bool {
check_network!(network, self);
self.inner.getCausal()
}
#[cfg(feature = "v_1_5")]
pub fn set_causal_kind(
&mut self,
network: &mut NetworkDefinition,
kind: CausalMaskKind,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setCausalKind(kind.into())
.ok_or_err(PropertySetAttempt::AttentionLayerCausalKind)
}
#[cfg(feature = "v_1_5")]
pub fn causal_kind(&self, network: &NetworkDefinition) -> CausalMaskKind {
check_network!(network, self);
self.inner.getCausalKind().into()
}
pub fn set_decomposable(
&mut self,
network: &mut NetworkDefinition,
decomposable: bool,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setDecomposable(decomposable)
.ok_or_err(PropertySetAttempt::AttentionLayerDecomposable)
}
pub fn decomposable(&self, network: &NetworkDefinition) -> bool {
check_network!(network, self);
self.inner.getDecomposable()
}
pub fn set_input(
&mut self,
network: &mut NetworkDefinition,
index: i32,
input: &Tensor,
) -> Result<()> {
check_network!(network, self);
check_network!(network, input);
self.inner
.as_mut()
.setInput(index, input.pin_mut())
.ok_or_err(PropertySetAttempt::AttentionLayerInput)
}
pub fn num_inputs(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getNbInputs()
}
pub fn input(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
check_network!(network, self);
let tensor_ptr = self.inner.getInput(index);
unsafe { Tensor::new(self.network, tensor_ptr) }
}
#[deprecated = "use input instead"]
pub fn get_input(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
self.input(network, index)
}
pub fn num_outputs(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getNbOutputs()
}
pub fn output(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
check_network!(network, self);
let tensor_ptr = self.inner.getOutput(index);
unsafe { Tensor::new(self.network, tensor_ptr) }
}
#[deprecated = "use output instead"]
pub fn get_output(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
self.output(network, index)
}
pub fn set_name(&mut self, network: &mut NetworkDefinition, name: &str) -> Result<()> {
check_network!(network, self);
let name = CString::new(name)?;
unsafe { self.inner.as_mut().setName(name.as_ptr()) }
.ok_or_err(PropertySetAttempt::AttentionLayerName)
}
pub fn name(&self, network: &NetworkDefinition) -> String {
check_network!(network, self);
let name = self.inner.getName();
if name.is_null() {
"(unamed)".to_string()
} else {
unsafe { CStr::from_ptr(name).to_string_lossy().to_string() }
}
}
pub fn set_normalization_quantize_scale(
&mut self,
network: &mut NetworkDefinition,
tensor: &Tensor,
) -> Result<()> {
check_network!(network, self);
check_network!(network, tensor);
self.inner
.as_mut()
.setNormalizationQuantizeScale(tensor.pin_mut())
.ok_or_err(PropertySetAttempt::AttentionLayerQuantizeScale)
}
pub fn normalization_quantize_scale(
&self,
network: &NetworkDefinition,
) -> Option<Tensor<'network>> {
check_network!(network, self);
let p = self.inner.getNormalizationQuantizeScale();
if p.is_null() {
None
} else {
unsafe { Tensor::new(self.network, p).ok() }
}
}
pub fn set_normalization_quantize_to_type(
&mut self,
network: &mut NetworkDefinition,
type_: DataType,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setNormalizationQuantizeToType(type_.into())
.ok_or_err(PropertySetAttempt::AttentionLayerQuantizeToType)
}
pub fn normalization_quantize_to_type(&self, network: &NetworkDefinition) -> DataType {
check_network!(network, self);
self.inner.getNormalizationQuantizeToType().into()
}
#[deprecated = "use normalization_quantize_to_type instead"]
pub fn get_normalization_quantize_to_type(&self, network: &NetworkDefinition) -> DataType {
self.normalization_quantize_to_type(network)
}
pub fn set_metadata(&mut self, network: &mut NetworkDefinition, metadata: &str) -> Result<()> {
check_network!(network, self);
let metadata_cstr = CString::new(metadata)?;
unsafe { self.inner.as_mut().setMetadata(metadata_cstr.as_ptr()) }
.ok_or_err(PropertySetAttempt::AttentionLayerMetadata)
}
pub fn metadata(&self, network: &NetworkDefinition) -> String {
check_network!(network, self);
let p = self.inner.getMetadata();
if p.is_null() {
String::new()
} else {
unsafe { CStr::from_ptr(p).to_string_lossy().to_string() }
}
}
#[cfg(feature = "v_1_4")]
pub fn set_nb_ranks(&mut self, network: &mut NetworkDefinition, nb_ranks: i32) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setNbRanks(nb_ranks)
.ok_or_err(PropertySetAttempt::AttentionLayerNumRanks)
}
#[cfg(feature = "v_1_4")]
pub fn nb_ranks(&self, network: &NetworkDefinition) -> i32 {
check_network!(network, self);
self.inner.getNbRanks()
}
#[deprecated = "use nb_ranks instead"]
pub fn get_nb_ranks(&self, network: &NetworkDefinition) -> i32 {
self.nb_ranks(network)
}
#[cfg(feature = "v_1_5")]
pub fn set_query_form(
&mut self,
network: &mut NetworkDefinition,
form: AttentionIOForm,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setQueryForm(form.into())
.ok_or_err(PropertySetAttempt::AttentionLayerQueryForm)
}
#[cfg(feature = "v_1_5")]
pub fn set_key_value_form(
&mut self,
network: &mut NetworkDefinition,
form: AttentionIOForm,
) -> Result<()> {
check_network!(network, self);
self.inner
.as_mut()
.setKeyValueForm(form.into())
.ok_or_err(PropertySetAttempt::AttentionLayerKeyValueForm)
}
#[cfg(feature = "v_1_5")]
pub fn query_form(&self, network: &NetworkDefinition) -> AttentionIOForm {
check_network!(network, self);
self.inner.getQueryForm().into()
}
#[cfg(feature = "v_1_5")]
pub fn key_value_form(&self, network: &NetworkDefinition) -> AttentionIOForm {
check_network!(network, self);
self.inner.getQueryForm().into()
}
#[cfg(feature = "v_1_5")]
pub fn query_lengths(&self, network: &NetworkDefinition) -> Result<Tensor<'network>> {
check_network!(network, self);
unsafe { Tensor::new(self.network, self.inner.getQueryLengths()) }
}
#[cfg(feature = "v_1_5")]
pub fn key_value_lengths(&self, network: &NetworkDefinition) -> Result<Tensor<'network>> {
check_network!(network, self);
unsafe { Tensor::new(self.network, self.inner.getKeyValueLengths()) }
}
#[cfg(feature = "v_1_5")]
pub fn set_query_lengths(
&mut self,
network: &mut NetworkDefinition,
lengths: &'_ Tensor,
) -> Result<()> {
check_network!(network, self);
check_network!(network, lengths);
unsafe {
self.inner
.as_mut()
.setQueryLengths(lengths.inner)
.ok_or_err(PropertySetAttempt::AttentionLayerQueryLengths)
}
}
#[cfg(feature = "v_1_5")]
pub fn set_key_value_lengths(
&mut self,
network: &mut NetworkDefinition,
lengths: &'_ Tensor,
) -> Result<()> {
check_network!(network, self);
check_network!(network, lengths);
unsafe {
self.inner
.as_mut()
.setKeyValueLengths(lengths.inner)
.ok_or_err(PropertySetAttempt::AttentionLayerKeyValueLengths)
}
}
}
/// Accessors for a KV-cache update layer returned by `NetworkDefinition::add_kv_cache_update`.
///
/// Every method panics if `network` (or a tensor argument's network) is not the network that
/// created this layer. Setters report a rejected TensorRT setter call as an error.
impl<'network> KVCacheUpdateLayer<'network> {
    /// Changes the cache mode.
    pub fn set_cache_mode(
        &mut self,
        network: &mut NetworkDefinition,
        mode: KVCacheMode,
    ) -> Result<()> {
        check_network!(network, self);
        let accepted = self.inner.as_mut().setCacheMode(mode.into());
        accepted.ok_or_err(PropertySetAttempt::KVCacheUpdateMode)
    }
    /// Returns the current cache mode.
    pub fn cache_mode(&self, network: &NetworkDefinition) -> KVCacheMode {
        check_network!(network, self);
        let mode = self.inner.getCacheMode();
        mode.into()
    }
    /// Changes the I/O form of the update tensor.
    #[cfg(feature = "v_1_5")]
    pub fn set_update_form(
        &mut self,
        network: &mut NetworkDefinition,
        form: AttentionIOForm,
    ) -> Result<()> {
        check_network!(network, self);
        let accepted = self.inner.as_mut().setUpdateForm(form.into());
        accepted.ok_or_err(PropertySetAttempt::KVCacheUpdateUpdateForm)
    }
    /// Returns the I/O form of the update tensor.
    #[cfg(feature = "v_1_5")]
    pub fn update_form(&self, network: &NetworkDefinition) -> AttentionIOForm {
        check_network!(network, self);
        let form = self.inner.getUpdateForm();
        form.into()
    }
    /// Sets the update-lengths tensor.
    #[cfg(feature = "v_1_5")]
    pub fn set_update_lengths(
        &mut self,
        network: &mut NetworkDefinition,
        lengths: &Tensor,
    ) -> Result<()> {
        check_network!(network, self);
        check_network!(network, lengths);
        // SAFETY: `lengths.inner` refers to a tensor of this network (checked above).
        let accepted = unsafe { self.inner.as_mut().setUpdateLengths(lengths.inner) };
        accepted.ok_or_err(PropertySetAttempt::KVCacheUpdateLayerUpdateLengths)
    }
    /// Returns the update-lengths tensor; errors if `Tensor::new` rejects TensorRT's pointer.
    #[cfg(feature = "v_1_5")]
    pub fn update_lengths(&self, network: &NetworkDefinition) -> Result<Tensor<'network>> {
        check_network!(network, self);
        unsafe {
            let raw = self.inner.getUpdateLengths();
            Tensor::new(self.network, raw)
        }
    }
}
/// Builders for the layers that make up a loop returned by `NetworkDefinition::add_loop`.
///
/// Each method panics if the loop or the tensor argument belongs to a different network.
impl<'network> Loop<'network> {
    /// Adds a recurrence layer seeded with `initial_value`.
    pub fn add_recurrence(
        &mut self,
        network: &mut NetworkDefinition,
        initial_value: &'_ Tensor,
    ) -> Result<RecurrenceLayer<'network>> {
        check_network!(network, self);
        check_network!(network, initial_value);
        debug!(
            "Loop::add_recurrence initial_value={}",
            tensor_dbg(network, initial_value)
        );
        let owner = network.inner.as_ptr();
        let seed = initial_value.pin_mut();
        let raw_layer = self.inner.as_mut().addRecurrence(seed);
        RecurrenceLayer::new(owner, raw_layer)
    }
    /// Adds a trip-limit layer driven by `tensor`, of kind `limit`.
    pub fn add_trip_limit(
        &mut self,
        network: &mut NetworkDefinition,
        tensor: &'_ Tensor,
        limit: trtx_sys::TripLimit,
    ) -> Result<TripLimitLayer<'network>> {
        check_network!(network, self);
        check_network!(network, tensor);
        debug!(
            "Loop::add_trip_limit tensor={} limit={limit:?}",
            tensor_dbg(network, tensor)
        );
        let owner = network.inner.as_ptr();
        let driver = tensor.pin_mut();
        let raw_layer = self.inner.as_mut().addTripLimit(driver, limit.into());
        TripLimitLayer::new(owner, raw_layer)
    }
    /// Adds an iterator layer over `tensor` along `axis`, optionally in reverse.
    pub fn add_iterator(
        &mut self,
        network: &mut NetworkDefinition,
        tensor: &'_ Tensor,
        axis: i32,
        reverse: bool,
    ) -> Result<IteratorLayer<'network>> {
        check_network!(network, self);
        check_network!(network, tensor);
        debug!(
            "Loop::add_iterator tensor={} axis={axis} reverse={reverse}",
            tensor_dbg(network, tensor)
        );
        let owner = network.inner.as_ptr();
        let source = tensor.pin_mut();
        let raw_layer = self.inner.as_mut().addIterator(source, axis, reverse);
        IteratorLayer::new(owner, raw_layer)
    }
    /// Adds a loop-output layer for `tensor` with kind `output_kind` along `axis`.
    pub fn add_loop_output(
        &mut self,
        network: &mut NetworkDefinition,
        tensor: &'_ Tensor,
        output_kind: trtx_sys::LoopOutput,
        axis: i32,
    ) -> Result<LoopOutputLayer<'network>> {
        check_network!(network, self);
        check_network!(network, tensor);
        debug!(
            "Loop::add_loop_output tensor={} output_kind={output_kind:?} axis={axis}",
            tensor_dbg(network, tensor)
        );
        let owner = network.inner.as_ptr();
        let source = tensor.pin_mut();
        let raw_layer = self
            .inner
            .as_mut()
            .addLoopOutput(source, output_kind.into(), axis);
        LoopOutputLayer::new(owner, raw_layer)
    }
}
/// Builders for the pieces of an if-conditional returned by
/// `NetworkDefinition::add_if_conditional`.
///
/// Each method panics if the conditional or a tensor argument belongs to a different network.
impl<'network> IfConditional<'network> {
    /// Sets the condition tensor, returning the resulting `ConditionLayer`.
    pub fn set_condition(
        &mut self,
        network: &mut NetworkDefinition,
        condition: &'_ Tensor,
    ) -> Result<ConditionLayer<'network>> {
        check_network!(network, self);
        check_network!(network, condition);
        // Trace like the sibling builders so every layer addition shows up in debug logs.
        debug!(
            "IfConditional::set_condition condition={}",
            tensor_dbg(network, condition)
        );
        let layer_ptr = self.inner.as_mut().setCondition(condition.pin_mut());
        ConditionLayer::new(network.inner.as_ptr(), layer_ptr)
    }
    /// Adds `input` as an input of the conditional.
    pub fn add_input(
        &mut self,
        network: &mut NetworkDefinition,
        input: &'_ Tensor,
    ) -> Result<IfConditionalInputLayer<'network>> {
        check_network!(network, self);
        check_network!(network, input);
        debug!(
            "IfConditional::add_input input={}",
            tensor_dbg(network, input)
        );
        let layer_ptr = self.inner.as_mut().addInput(input.pin_mut());
        IfConditionalInputLayer::new(network.inner.as_ptr(), layer_ptr)
    }
    /// Adds an output taking `true_output` or `false_output` depending on the branch.
    pub fn add_output(
        &mut self,
        network: &mut NetworkDefinition,
        true_output: &'_ Tensor,
        false_output: &'_ Tensor,
    ) -> Result<IfConditionalOutputLayer<'network>> {
        check_network!(network, self);
        check_network!(network, true_output);
        check_network!(network, false_output);
        debug!(
            "IfConditional::add_output true_output={} false_output={}",
            tensor_dbg(network, true_output),
            tensor_dbg(network, false_output)
        );
        let layer_ptr = self
            .inner
            .as_mut()
            .addOutput(true_output.pin_mut(), false_output.pin_mut());
        IfConditionalOutputLayer::new(network.inner.as_ptr(), layer_ptr)
    }
}
// Placeholder: `RecurrenceLayer` has no type-specific accessors yet.
impl RecurrenceLayer<'_> {}
/// Setters for a loop iterator layer; both panic if `network` did not create this layer.
impl IteratorLayer<'_> {
    /// Changes the axis the iterator walks along.
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }
    /// Changes whether the iterator walks in reverse.
    pub fn set_reverse(&mut self, network: &mut NetworkDefinition, reverse: bool) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setReverse(reverse);
    }
}
/// Accessors for a loop-output layer; each panics if `network` did not create this layer.
impl LoopOutputLayer<'_> {
    /// Returns the kind of loop output this layer produces.
    pub fn loop_output(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::LoopOutput {
        check_network!(network, self);
        let layer = self.inner.as_ref();
        layer.getLoopOutput()
    }
    #[deprecated = "use loop_output instead"]
    pub fn get_loop_output(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::LoopOutput {
        self.loop_output(network)
    }
    /// Changes the axis used by this loop output.
    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAxis(axis);
    }
}
/// Accessors for a loop trip-limit layer; each panics if `network` did not create it.
impl TripLimitLayer<'_> {
    /// Returns the kind of trip limit this layer imposes.
    pub fn trip_limit(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::TripLimit {
        check_network!(network, self);
        let layer = self.inner.as_ref();
        layer.getTripLimit()
    }
    #[deprecated = "use trip_limit instead"]
    pub fn get_trip_limit(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::TripLimit {
        self.trip_limit(network)
    }
}
impl<'builder> NetworkDefinition<'builder> {
    /// Adds a normalization layer over `input` with `scale` and `bias`, normalizing over the
    /// axes selected in `axes_mask` (passed to TensorRT as a bitmask).
    ///
    /// # Panics
    /// Panics if any tensor was created by a different network.
    pub fn add_normalization(
        &mut self,
        input: &'_ Tensor,
        scale: &'_ Tensor,
        bias: &'_ Tensor,
        axes_mask: crate::Axes,
    ) -> Result<NormalizationLayer<'builder>> {
        check_network!(self, input);
        check_network!(self, scale);
        check_network!(self, bias);
        debug!(
            "add_normalization input={} scale={} bias={} axes_mask={axes_mask:?}",
            tensor_dbg(self, input),
            tensor_dbg(self, scale),
            tensor_dbg(self, bias)
        );
        let axes_bits = axes_mask.to_bits();
        let ptr = self.inner.pin_mut().addNormalization(
            input.pin_mut(),
            scale.pin_mut(),
            bias.pin_mut(),
            axes_bits,
        );
        NormalizationLayer::new(self.inner.as_ptr(), ptr)
    }
    /// Same as `add_normalization`, but routed through TensorRT's `addNormalizationV2` entry
    /// point (NOTE(review): semantic differences from V1 are not visible here — see the
    /// TensorRT docs).
    ///
    /// # Panics
    /// Panics if any tensor was created by a different network.
    pub fn add_normalization_v2(
        &mut self,
        input: &'_ Tensor,
        scale: &'_ Tensor,
        bias: &'_ Tensor,
        axes_mask: crate::Axes,
    ) -> Result<NormalizationLayer<'builder>> {
        check_network!(self, input);
        check_network!(self, scale);
        check_network!(self, bias);
        debug!(
            "add_normalization_v2 input={} scale={} bias={} axes_mask={axes_mask:?}",
            tensor_dbg(self, input),
            tensor_dbg(self, scale),
            tensor_dbg(self, bias)
        );
        let axes_bits = axes_mask.to_bits();
        let ptr = self.inner.pin_mut().addNormalizationV2(
            input.pin_mut(),
            scale.pin_mut(),
            bias.pin_mut(),
            axes_bits,
        );
        NormalizationLayer::new(self.inner.as_ptr(), ptr)
    }
    /// Installs `error_recorder` as this network's TensorRT error recorder.
    ///
    /// The wrapped recorder is stored in `self.error_recorder`, so it lives as long as this
    /// network wrapper. In `mock` builds it is stored but not forwarded to TensorRT.
    ///
    /// # Errors
    /// Propagates any failure from `ErrorRecorder::new`.
    ///
    /// # Panics
    /// Panics if an error recorder has already been set on this network.
    pub fn set_error_recorder(&mut self, error_recorder: Box<dyn RecordError>) -> Result<()> {
        // Reject a second recorder before wrapping it, so nothing is built only to be dropped.
        if self.error_recorder.is_some() {
            panic!("Setting an error recorder more than once not supported at the moment");
        }
        let error_recorder = ErrorRecorder::new(error_recorder)?;
        let rec = self
            .error_recorder
            .insert(error_recorder)
            .as_trt_error_recorder();
        // SAFETY: `rec` comes from the recorder now owned by `self.error_recorder`, which is
        // kept alive alongside `self.inner`.
        #[cfg(not(feature = "mock"))]
        unsafe {
            self.inner.pin_mut().setErrorRecorder(rec)
        };
        // Mock builds have no TensorRT network to hand the recorder to.
        #[cfg(feature = "mock")]
        let _ = rec;
        Ok(())
    }
    /// Adds a grid-sample layer that samples `input` at the locations given by `grid`.
    ///
    /// # Panics
    /// Panics if either tensor was created by a different network.
    pub fn add_grid_sample(
        &mut self,
        input: &'_ Tensor,
        grid: &'_ Tensor,
    ) -> Result<GridSampleLayer<'builder>> {
        check_network!(self, input);
        check_network!(self, grid);
        // Trace like every other `add_*` builder.
        debug!(
            "add_grid_sample input={} grid={}",
            tensor_dbg(self, input),
            tensor_dbg(self, grid)
        );
        let ptr = self
            .inner
            .pin_mut()
            .addGridSample(input.pin_mut(), grid.pin_mut());
        GridSampleLayer::new(self.inner.as_ptr(), ptr)
    }
}
/// Accessors for a grid-sample layer; each panics if `network` did not create this layer.
impl GridSampleLayer<'_> {
    /// Changes the interpolation mode.
    pub fn set_interpolation_mode(
        &mut self,
        network: &mut NetworkDefinition,
        mode: InterpolationMode,
    ) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setInterpolationMode(mode.into());
    }
    /// Returns the interpolation mode.
    pub fn interpolation_mode(&self, network: &NetworkDefinition) -> InterpolationMode {
        check_network!(network, self);
        let mode = self.inner.getInterpolationMode();
        mode.into()
    }
    /// Changes the sample (out-of-bounds handling) mode.
    pub fn set_sample_mode(&mut self, network: &mut NetworkDefinition, mode: SampleMode) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setSampleMode(mode.into());
    }
    /// Returns the sample mode.
    pub fn sample_mode(&self, network: &NetworkDefinition) -> SampleMode {
        check_network!(network, self);
        let mode = self.inner.getSampleMode();
        mode.into()
    }
    /// Changes the align-corners flag.
    pub fn set_align_corners(&mut self, network: &mut NetworkDefinition, align_corners: bool) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setAlignCorners(align_corners);
    }
    /// Returns the align-corners flag.
    pub fn align_corners(&self, network: &NetworkDefinition) -> bool {
        check_network!(network, self);
        let aligned = self.inner.getAlignCorners();
        aligned
    }
}
/// Accessors for a cast layer; each panics if `network` did not create this layer.
impl CastLayer<'_> {
    /// Changes the data type the layer casts to.
    pub fn set_to_type(&mut self, network: &mut NetworkDefinition<'_>, data_type: DataType) {
        check_network!(network, self);
        let layer = self.inner.as_mut();
        layer.setToType(data_type.into())
    }
    /// Returns the data type the layer casts to.
    pub fn to_type(&self, network: &NetworkDefinition<'_>) -> DataType {
        check_network!(network, self);
        let target = self.inner.getToType();
        target.into()
    }
}
// Tests that build real TensorRT networks. Every test needs the native library, so the whole
// module is compiled only for non-`mock` test builds. Gating the module rather than each test
// also keeps the module-level imports from becoming unused (and warning) under `mock`.
#[cfg(all(test, not(feature = "mock")))]
mod test {
    use trtx_sys::LayerType;
    use crate::{Builder, Logger};

    /// Layers are indexed in creation order, and renaming a layer leaves the names of its
    /// tensors untouched.
    #[test]
    fn test_get_layer() {
        let logger = Logger::stderr().unwrap();
        let mut builder = Builder::new(&logger).unwrap();
        let mut network = builder.create_network(0).unwrap();
        let input = network
            .add_input("a", trtx_sys::DataType::kFLOAT, &[1])
            .unwrap();
        let a = network
            .add_activation(&input, trtx_sys::ActivationType::kRELU)
            .unwrap()
            .output(&network, 0)
            .unwrap();
        let b = network
            .add_activation(&a, trtx_sys::ActivationType::kRELU)
            .unwrap()
            .output(&network, 0)
            .unwrap();
        let c = network
            .add_activation(&b, trtx_sys::ActivationType::kRELU)
            .unwrap()
            .output(&network, 0)
            .unwrap();
        a.set_name(&mut network, "Fritz").unwrap();
        b.set_name(&mut network, "Adam").unwrap();
        c.set_name(&mut network, "James").unwrap();
        assert_eq!(
            &network
                .layer(0)
                .unwrap()
                .output(&network, 0)
                .unwrap()
                .name(&network)
                .unwrap(),
            "Fritz"
        );
        assert_eq!(
            &network
                .layer(1)
                .unwrap()
                .output(&network, 0)
                .unwrap()
                .name(&network)
                .unwrap(),
            "Adam"
        );
        assert_eq!(
            &network
                .layer(2)
                .unwrap()
                .output(&network, 0)
                .unwrap()
                .name(&network)
                .unwrap(),
            "James"
        );
        assert_eq!(
            network.layer(2).unwrap().layer_type_dynamic(),
            LayerType::kACTIVATION
        );
        // Rename the middle layer: its output tensor (layer 2's input) keeps its own name.
        network
            .layer(1)
            .unwrap()
            .set_name(&mut network, "Eva")
            .unwrap();
        assert_eq!(
            &network
                .layer(1)
                .unwrap()
                .output(&network, 0)
                .unwrap()
                .name(&network)
                .unwrap(),
            &network
                .layer(2)
                .unwrap()
                .input(&network, 0)
                .unwrap()
                .name(&network)
                .unwrap(),
        );
        assert_eq!(
            "Adam",
            &network
                .layer(2)
                .unwrap()
                .input(&network, 0)
                .unwrap()
                .name(&network)
                .unwrap()
        );
        assert_eq!(&network.layer(1).unwrap().name(&network), "Eva");
    }

    /// `inputs()` / `outputs()` yield tensors in order, report exact lengths, and agree with
    /// index-based access through `nb_inputs()` / `input(i)`.
    #[test]
    fn test_inputs_outputs_iter() {
        let logger = Logger::stderr().unwrap();
        let mut builder = Builder::new(&logger).unwrap();
        let mut network = builder.create_network(0).unwrap();
        let a = network
            .add_input("input_a", trtx_sys::DataType::kFLOAT, &[1])
            .unwrap();
        let b = network
            .add_input("input_b", trtx_sys::DataType::kFLOAT, &[1])
            .unwrap();
        let out = network
            .add_elementwise(&a, &b, trtx_sys::ElementWiseOperation::kSUM)
            .unwrap()
            .output(&network, 0)
            .unwrap();
        out.set_name(&mut network, "output_c").unwrap();
        network.mark_output(&out);
        let input_names: Vec<_> = network
            .inputs()
            .map(|t| t.name(&network).unwrap())
            .collect();
        assert_eq!(input_names, ["input_a", "input_b"]);
        assert_eq!(network.inputs().len(), 2);
        let output_names: Vec<_> = network
            .outputs()
            .map(|t| t.name(&network).unwrap())
            .collect();
        assert_eq!(output_names, ["output_c"]);
        assert_eq!(network.outputs().len(), 1);
        let mut old_style = Vec::new();
        for i in 0..network.nb_inputs() {
            old_style.push(network.input(i).unwrap().name(&network).unwrap());
        }
        assert_eq!(input_names, old_style);
    }

    /// `layers()` reports its length and yields one item per layer.
    #[test]
    fn test_layers_iter() {
        // `LayerType` comes from the module-level import; re-importing it here was redundant.
        let logger = Logger::stderr().unwrap();
        let mut builder = Builder::new(&logger).unwrap();
        let mut network = builder.create_network(0).unwrap();
        let input = network
            .add_input("a", trtx_sys::DataType::kFLOAT, &[1])
            .unwrap();
        network
            .add_activation(&input, trtx_sys::ActivationType::kRELU)
            .unwrap();
        network
            .add_activation(&input, trtx_sys::ActivationType::kSIGMOID)
            .unwrap();
        assert_eq!(network.layers().len(), 2);
        let types: Vec<_> = network.layers().map(|l| l.layer_type_dynamic()).collect();
        assert_eq!(types, [LayerType::kACTIVATION, LayerType::kACTIVATION]);
    }
}