Skip to main content

atomr_accel_tensorrt/
builder.rs

//! Safe wrappers for `nvinfer1::IBuilder` and
//! `nvinfer1::IBuilderConfig`.
//!
//! Construction is GPU-free and panic-free even when libnvinfer is
//! not installed: the `IBuilderConfig` struct is a pure-Rust value
//! that records the requested knobs and is later replayed against the
//! C++ builder via the FFI shim under `tensorrt-link`.
8
9use bitflags::bitflags;
10
11bitflags! {
12    /// Mirror of `nvinfer1::BuilderFlag` as a bitfield. Each
13    /// flag toggles a single TensorRT optimisation knob; combine with
14    /// `|`.
15    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
16    pub struct BuilderFlags: u32 {
17        const FP16                       = 1 <<  0;
18        const INT8                       = 1 <<  1;
19        const DEBUG_KERNELS              = 1 <<  2;
20        const GPU_FALLBACK               = 1 <<  3;
21        const REFIT                      = 1 <<  4;
22        const DISABLE_TIMING_CACHE       = 1 <<  5;
23        const TF32                       = 1 <<  6;
24        const SPARSE_WEIGHTS             = 1 <<  7;
25        const SAFETY_SCOPE               = 1 <<  8;
26        const OBEY_PRECISION_CONSTRAINTS = 1 <<  9;
27        const PREFER_PRECISION_CONSTRAINTS = 1 << 10;
28        const DIRECT_IO                  = 1 << 11;
29        const REJECT_EMPTY_ALGORITHMS    = 1 << 12;
30        const BF16                       = 1 << 13;
31        const FP8                        = 1 << 14;
32        const STRIP_PLAN                 = 1 << 15;
33        const VERSION_COMPATIBLE         = 1 << 16;
34        const EXCLUDE_LEAN_RUNTIME       = 1 << 17;
35    }
36}
37
38bitflags! {
39    /// Tactic sources to enable (mirrors `nvinfer1::TacticSource`).
40    #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
41    pub struct TacticSources: u32 {
42        const CUBLAS                 = 1 << 0;
43        const CUBLAS_LT              = 1 << 1;
44        const CUDNN                  = 1 << 2;
45        const EDGE_MASK_CONVOLUTIONS = 1 << 3;
46        const JIT_CONVOLUTIONS       = 1 << 4;
47    }
48}
49
/// High-level inference precision policy.
///
/// Each variant stands for a combination of `BuilderFlags`; for
/// example `Best` ⇒ FP16 | BF16 | INT8 | FP8 | TF32.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Precision {
    /// Default policy: no reduced-precision flag besides TF32.
    #[default]
    Fp32,
    /// Opt into the FP16 flag.
    Fp16,
    /// Opt into the BF16 flag.
    Bf16,
    /// Opt into the INT8 flag.
    Int8,
    /// Opt into the FP8 flag (FP16 is enabled alongside it).
    Fp8,
    /// Enable everything; let the builder pick the fastest tactic.
    Best,
}
63
64impl Precision {
65    pub fn flags(self) -> BuilderFlags {
66        match self {
67            Precision::Fp32 => BuilderFlags::TF32,
68            Precision::Fp16 => BuilderFlags::FP16 | BuilderFlags::TF32,
69            Precision::Bf16 => BuilderFlags::BF16 | BuilderFlags::TF32,
70            Precision::Int8 => BuilderFlags::INT8 | BuilderFlags::TF32,
71            Precision::Fp8 => BuilderFlags::FP8 | BuilderFlags::FP16 | BuilderFlags::TF32,
72            Precision::Best => {
73                BuilderFlags::FP16
74                    | BuilderFlags::BF16
75                    | BuilderFlags::INT8
76                    | BuilderFlags::FP8
77                    | BuilderFlags::TF32
78            }
79        }
80    }
81}
82
/// Device the engine is built for: the GPU (default) or a DLA core.
/// DLA is the Jetson AI accelerator; `Dla(core)` selects a specific
/// DLA core.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceType {
    /// Target the GPU.
    #[default]
    Gpu,
    /// Target the DLA core with the given index.
    Dla(i32),
}
91
/// Engine refit policy. `OnDemand` opts into `BuilderFlags::REFIT`,
/// `WeightsStreaming` further enables `STRIP_PLAN` so weights live
/// outside the engine plan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RefitPolicy {
    /// No refit-related flag is requested (default).
    #[default]
    Disabled,
    /// Request `REFIT`.
    OnDemand,
    /// Request `REFIT` together with `STRIP_PLAN`.
    WeightsStreaming,
}
102
103/// Pure-Rust mirror of `nvinfer1::IBuilderConfig`. Holds the requested
104/// knobs in a side table; the FFI shim under `tensorrt-link` replays
105/// them against the C++ object inside `BuilderActor::build`.
106#[derive(Debug, Clone)]
107pub struct IBuilderConfig {
108    pub precision: Precision,
109    pub device_type: DeviceType,
110    /// Enable structured 2:4 sparsity (Ampere+).
111    pub structured_sparsity: bool,
112    /// Tactic-source allow-list (default: all on).
113    pub tactic_sources: TacticSources,
114    /// Persist the per-build timing cache. `None` = no cache.
115    pub timing_cache: Option<Vec<u8>>,
116    /// Engine refit policy.
117    pub refit: RefitPolicy,
118    /// Workspace memory pool budget (bytes).
119    pub workspace_bytes: usize,
120    /// DLA SRAM pool budget (bytes), only honoured when `device_type ==
121    /// Dla(_)`.
122    pub dla_sram_bytes: usize,
123    /// Extra builder flags merged in on top of `precision.flags()` —
124    /// allows callers to toggle e.g. `DEBUG_KERNELS` without losing
125    /// the high-level precision policy.
126    pub extra_flags: BuilderFlags,
127}
128
129impl Default for IBuilderConfig {
130    fn default() -> Self {
131        Self {
132            precision: Precision::default(),
133            device_type: DeviceType::default(),
134            structured_sparsity: false,
135            tactic_sources: TacticSources::CUBLAS
136                | TacticSources::CUBLAS_LT
137                | TacticSources::CUDNN
138                | TacticSources::EDGE_MASK_CONVOLUTIONS
139                | TacticSources::JIT_CONVOLUTIONS,
140            timing_cache: None,
141            refit: RefitPolicy::default(),
142            workspace_bytes: 1 << 30, // 1 GiB
143            dla_sram_bytes: 0,
144            extra_flags: BuilderFlags::empty(),
145        }
146    }
147}
148
149impl IBuilderConfig {
150    pub fn new() -> Self {
151        Self::default()
152    }
153
154    pub fn with_precision(mut self, p: Precision) -> Self {
155        self.precision = p;
156        self
157    }
158
159    pub fn with_device(mut self, dt: DeviceType) -> Self {
160        self.device_type = dt;
161        self
162    }
163
164    pub fn with_sparsity(mut self, on: bool) -> Self {
165        self.structured_sparsity = on;
166        self
167    }
168
169    pub fn with_tactic_sources(mut self, ts: TacticSources) -> Self {
170        self.tactic_sources = ts;
171        self
172    }
173
174    pub fn with_timing_cache(mut self, cache: Vec<u8>) -> Self {
175        self.timing_cache = Some(cache);
176        self
177    }
178
179    pub fn with_refit(mut self, refit: RefitPolicy) -> Self {
180        self.refit = refit;
181        self
182    }
183
184    pub fn with_workspace_bytes(mut self, bytes: usize) -> Self {
185        self.workspace_bytes = bytes;
186        self
187    }
188
189    pub fn with_extra_flags(mut self, flags: BuilderFlags) -> Self {
190        self.extra_flags = flags;
191        self
192    }
193
194    /// Compute the final `BuilderFlags` bitmask the FFI shim would
195    /// pass to `IBuilderConfig::setFlag()`. Combines the precision
196    /// policy with refit + sparsity + caller-supplied extras.
197    pub fn effective_flags(&self) -> BuilderFlags {
198        let mut f = self.precision.flags() | self.extra_flags;
199        if self.structured_sparsity {
200            f |= BuilderFlags::SPARSE_WEIGHTS;
201        }
202        match self.refit {
203            RefitPolicy::Disabled => {}
204            RefitPolicy::OnDemand => f |= BuilderFlags::REFIT,
205            RefitPolicy::WeightsStreaming => {
206                f |= BuilderFlags::REFIT | BuilderFlags::STRIP_PLAN;
207            }
208        }
209        f
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter lands in its field, and `effective_flags`
    /// folds precision, refit, sparsity and extras into one mask.
    #[test]
    fn builder_config_round_trip() {
        let timing_blob = vec![1u8, 2, 3, 4];
        let cfg = IBuilderConfig::new()
            .with_precision(Precision::Best)
            .with_device(DeviceType::Dla(1))
            .with_sparsity(true)
            .with_refit(RefitPolicy::WeightsStreaming)
            .with_workspace_bytes(2 << 30)
            .with_extra_flags(BuilderFlags::DEBUG_KERNELS)
            .with_tactic_sources(TacticSources::CUBLAS | TacticSources::CUDNN)
            .with_timing_cache(timing_blob.clone());

        assert_eq!(cfg.precision, Precision::Best);
        assert_eq!(cfg.device_type, DeviceType::Dla(1));
        assert_eq!(cfg.refit, RefitPolicy::WeightsStreaming);
        assert!(cfg.structured_sparsity);
        assert_eq!(cfg.workspace_bytes, 2 << 30);
        assert_eq!(cfg.timing_cache.as_deref(), Some(timing_blob.as_slice()));
        assert_eq!(cfg.extra_flags, BuilderFlags::DEBUG_KERNELS);
        assert_eq!(
            cfg.tactic_sources,
            TacticSources::CUBLAS | TacticSources::CUDNN
        );
        assert!(!cfg.tactic_sources.contains(TacticSources::CUBLAS_LT));

        // Best ⇒ FP16/BF16/INT8/FP8/TF32, plus REFIT|STRIP_PLAN from
        // WeightsStreaming, plus SPARSE_WEIGHTS, plus DEBUG_KERNELS.
        let flags = cfg.effective_flags();
        for wanted in [
            BuilderFlags::FP16,
            BuilderFlags::BF16,
            BuilderFlags::INT8,
            BuilderFlags::FP8,
            BuilderFlags::TF32,
            BuilderFlags::REFIT,
            BuilderFlags::STRIP_PLAN,
            BuilderFlags::SPARSE_WEIGHTS,
            BuilderFlags::DEBUG_KERNELS,
        ] {
            assert!(flags.contains(wanted), "effective flags are missing {:?}", wanted);
        }
    }

    /// Each single-precision policy carries its same-named flag, and
    /// `Best` carries all of them.
    #[test]
    fn precision_flag_mapping_is_stable() {
        for (policy, flag) in [
            (Precision::Fp16, BuilderFlags::FP16),
            (Precision::Bf16, BuilderFlags::BF16),
            (Precision::Int8, BuilderFlags::INT8),
            (Precision::Fp8, BuilderFlags::FP8),
        ] {
            assert!(policy.flags().contains(flag), "{:?} is missing {:?}", policy, flag);
        }

        let best = Precision::Best.flags();
        for f in [
            BuilderFlags::FP16,
            BuilderFlags::BF16,
            BuilderFlags::INT8,
            BuilderFlags::FP8,
            BuilderFlags::TF32,
        ] {
            assert!(best.contains(f), "Best is missing {:?}", f);
        }
    }

    /// The default configuration requests neither refit-related flag.
    #[test]
    fn refit_disabled_does_not_set_refit_flag() {
        let mask = IBuilderConfig::new().effective_flags();
        assert!(!mask.intersects(BuilderFlags::REFIT | BuilderFlags::STRIP_PLAN));
    }

    /// `OnDemand` requests `REFIT` but leaves `STRIP_PLAN` off.
    #[test]
    fn refit_on_demand_sets_refit_only() {
        let mask = IBuilderConfig::new()
            .with_refit(RefitPolicy::OnDemand)
            .effective_flags();
        assert!(mask.contains(BuilderFlags::REFIT));
        assert!(!mask.contains(BuilderFlags::STRIP_PLAN));
    }
}