trtx 0.7.0+rtx1.5

Safe Rust bindings to NVIDIA TensorRT-RTX (EXPERIMENTAL - NOT FOR PRODUCTION)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
//! Builder configuration for TensorRT engine builds.
//!
//! Wraps [`trtx_sys::nvinfer1::IBuilderConfig`]; C++: [`nvinfer1::IBuilderConfig`](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/_static/cpp-api/classnvinfer1_1_1_i_builder_config.html).

use std::marker::PhantomData;
use std::pin::Pin;

use crate::error::PropertySetAttempt;
use crate::interfaces::MonitorProgress;
use crate::interfaces::ProgressMonitor;
use crate::optimization_profile::OptimizationProfile;
use crate::Builder;
use crate::Error;
use crate::Result;
use cxx::UniquePtr;
use trtx_sys::nvinfer1::{self, IBuilderConfig};
use trtx_sys::{
    BuilderFlag, DeviceType, EngineCapability, HardwareCompatibilityLevel, MemoryPoolType,
    PreviewFeature, ProfilingVerbosity, RuntimePlatform, TilingOptimizationLevel,
};

#[cfg(not(feature = "enterprise"))]
use trtx_sys::ComputeCapability;

/// [`trtx_sys::nvinfer1::IBuilderConfig`] — C++ [`nvinfer1::IBuilderConfig`](https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/_static/cpp-api/classnvinfer1_1_1_i_builder_config.html).
///
/// Owns the native config object and, once one has been installed through
/// `set_progress_monitor`, the progress monitor registered on it.
pub struct BuilderConfig<'builder> {
    /// Owning handle to the native `IBuilderConfig`.
    pub(crate) inner: UniquePtr<IBuilderConfig>,
    /// Monitor installed via `set_progress_monitor`; `None` until that method is called.
    /// Stored here so the pinned allocation lives as long as this config does.
    progress_monitor: Option<Pin<Box<ProgressMonitor>>>,
    /// Zero-sized marker tying this config to the `'builder` lifetime of a `Builder`.
    _builder: PhantomData<&'builder Builder<'builder>>,
}

/// Debug output shows only the address of the wrapped native object (lower-case hex);
/// the remaining fields are elided.
impl std::fmt::Debug for BuilderConfig<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let address = format!("{:x}", self.inner.as_ptr() as usize);
        let mut out = f.debug_struct("BuilderConfig");
        out.field("inner", &address);
        out.finish_non_exhaustive()
    }
}

impl<'builder> BuilderConfig<'builder> {
    pub(crate) fn new(builder_config: *mut nvinfer1::IBuilderConfig) -> Result<Self> {
        #[cfg(not(feature = "mock"))]
        if builder_config.is_null() {
            return Err(Error::BuilderConfigCreationFailed);
        }
        Ok(Self {
            inner: unsafe { UniquePtr::from_raw(builder_config) },
            progress_monitor: None,
            _builder: Default::default(),
        })
    }

    /// See [IBuilderConfig::setProgressMonitor]
    /// The Rust bindings only allow setting the progress monitor once per builder config object
    pub fn set_progress_monitor(
        &mut self,
        progress_monitor: Box<dyn MonitorProgress>,
    ) -> Result<()> {
        let progress_monitor = ProgressMonitor::new(progress_monitor)?;
        if self.progress_monitor.is_some() {
            // would need to make sure that we don't destroy a monitor still in use
            // could offer this as an unsafe method for users who only set this when there is no
            // build process active. Or we only accept a ref to progress monitor and force user
            // via lifetimes to keep this alive for builder config lifetime
            panic!("Setting a progress monitor more than once not supported at the moment");
        }
        self.progress_monitor = Some(progress_monitor);
        #[cfg(not(feature = "mock"))]
        unsafe {
            self.inner.pin_mut().setProgressMonitor(
                self.progress_monitor
                    .as_mut()
                    .expect("progress_monitor can't be empty. we just set it")
                    .as_trt_progress_monitor(),
            )
        };
        Ok(())
    }

    /// Forwards `pool` and `size` to [IBuilderConfig::setMemoryPoolLimit].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_memory_pool_limit(&mut self, pool: MemoryPoolType, size: usize) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setMemoryPoolLimit(pool.into(), size);
        }
    }

    /// Forwards `verbosity` to [IBuilderConfig::setProfilingVerbosity].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_profiling_verbosity(&mut self, verbosity: ProfilingVerbosity) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setProfilingVerbosity(verbosity.into());
        }
    }

    /// Returns the value reported by [IBuilderConfig::getProfilingVerbosity].
    ///
    /// With the `mock` feature enabled this returns `ProfilingVerbosity::kNONE` without
    /// calling into the native object.
    pub fn profiling_verbosity(&self) -> ProfilingVerbosity {
        if cfg!(feature = "mock") {
            return ProfilingVerbosity::kNONE;
        }
        self.inner.getProfilingVerbosity().into()
    }

    #[deprecated = "use profiling_verbosity instead"]
    pub fn get_profiling_verbosity(&self) -> ProfilingVerbosity {
        self.profiling_verbosity()
    }

    /// Forwards `avg_timing` to [IBuilderConfig::setAvgTimingIterations].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_avg_timing_iterations(&mut self, avg_timing: i32) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setAvgTimingIterations(avg_timing);
        }
    }

    /// Returns the value reported by [IBuilderConfig::getAvgTimingIterations].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn avg_timing_iterations(&self) -> i32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getAvgTimingIterations()
    }

    #[deprecated = "use avg_timing_iterations instead"]
    pub fn get_avg_timing_iterations(&self) -> i32 {
        self.avg_timing_iterations()
    }

    /// Forwards `capability` to [IBuilderConfig::setEngineCapability].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_engine_capability(&mut self, capability: EngineCapability) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setEngineCapability(capability.into());
        }
    }

    /// Returns the value reported by [IBuilderConfig::getEngineCapability].
    ///
    /// With the `mock` feature enabled this returns `EngineCapability::kSTANDARD` without
    /// calling into the native object.
    pub fn engine_capability(&self) -> EngineCapability {
        if cfg!(feature = "mock") {
            return EngineCapability::kSTANDARD;
        }
        self.inner.getEngineCapability().into()
    }

    #[deprecated = "use engine_capability instead"]
    pub fn get_engine_capability(&self) -> EngineCapability {
        self.engine_capability()
    }

    /// Forwards the raw `flags` value to [IBuilderConfig::setFlags].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_flags(&mut self, flags: u32) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setFlags(flags);
        }
    }

    /// Returns the value reported by [IBuilderConfig::getFlags].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn flags(&self) -> u32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getFlags()
    }

    #[deprecated = "use flags instead"]
    pub fn get_flags(&self) -> u32 {
        self.flags()
    }

    /// Forwards `flag` to [IBuilderConfig::setFlag].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_flag(&mut self, flag: BuilderFlag) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setFlag(flag.into());
        }
    }

    /// Forwards `flag` to [IBuilderConfig::clearFlag].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn clear_flag(&mut self, flag: BuilderFlag) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.clearFlag(flag.into());
        }
    }

    /// Returns the value reported by [IBuilderConfig::getFlag] for `flag`.
    ///
    /// With the `mock` feature enabled this returns `false` without calling into the native
    /// object.
    pub fn flag(&self, flag: BuilderFlag) -> bool {
        if cfg!(feature = "mock") {
            return false;
        }
        self.inner.getFlag(flag.into())
    }

    #[deprecated = "use flag instead"]
    pub fn get_flag(&self, flag: BuilderFlag) -> bool {
        self.flag(flag)
    }

    /// Forwards `dla_core` to [IBuilderConfig::setDLACore].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_dla_core(&mut self, dla_core: i32) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setDLACore(dla_core);
        }
    }

    /// Returns the value reported by [IBuilderConfig::getDLACore].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn dla_core(&self) -> i32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getDLACore()
    }

    #[deprecated = "use dla_core instead"]
    pub fn get_dla_core(&self) -> i32 {
        self.dla_core()
    }

    /// Forwards `device_type` to [IBuilderConfig::setDefaultDeviceType].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_default_device_type(&mut self, device_type: DeviceType) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setDefaultDeviceType(device_type.into());
        }
    }

    /// Returns the value reported by [IBuilderConfig::getDefaultDeviceType].
    ///
    /// With the `mock` feature enabled this returns `DeviceType::kGPU` without calling into
    /// the native object.
    pub fn default_device_type(&self) -> DeviceType {
        if cfg!(feature = "mock") {
            return DeviceType::kGPU;
        }
        self.inner.getDefaultDeviceType().into()
    }

    #[deprecated = "use default_device_type instead"]
    pub fn get_default_device_type(&self) -> DeviceType {
        self.default_device_type()
    }

    /// Calls [IBuilderConfig::reset] on the wrapped native object.
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn reset(&mut self) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.reset();
        }
    }

    /// Returns the value reported by [IBuilderConfig::getNbOptimizationProfiles].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn nb_optimization_profiles(&self) -> i32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getNbOptimizationProfiles()
    }

    #[deprecated = "use nb_optimization_profiles instead"]
    pub fn get_nb_optimization_profiles(&self) -> i32 {
        self.nb_optimization_profiles()
    }

    /// See [IBuilderConfig::addOptimizationProfile].
    /// Returns the profile index (0-based) on success.
    pub fn add_optimization_profile(
        &mut self,
        profile: &mut OptimizationProfile<'_>,
    ) -> Result<i32> {
        #[cfg(not(feature = "mock"))]
        {
            let idx = unsafe {
                self.inner
                    .pin_mut()
                    .addOptimizationProfile(profile.inner.as_mut().get_unchecked_mut())
            };
            if idx >= 0 {
                Ok(idx)
            } else {
                Err(Error::Runtime("addOptimizationProfile failed".to_string()))
            }
        }
        #[cfg(feature = "mock")]
        Ok(0)
    }

    /// Forwards `sources` to [IBuilderConfig::setTacticSources].
    ///
    /// # Errors
    ///
    /// Returns `Error::FailedToSetProperty` when the native setter returns `false`. With the
    /// `mock` feature enabled the setter is not called and the result is always `Ok(())`.
    pub fn set_tactic_sources(&mut self, sources: u32) -> crate::Result<()> {
        // `||` short-circuits under `mock`, so the native setter is never invoked there.
        let accepted = cfg!(feature = "mock") || self.inner.pin_mut().setTacticSources(sources);
        if accepted {
            Ok(())
        } else {
            Err(crate::Error::FailedToSetProperty(
                PropertySetAttempt::BuilderConfigTacticSources,
            ))
        }
    }

    /// Returns the value reported by [IBuilderConfig::getTacticSources].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn tactic_sources(&self) -> u32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getTacticSources()
    }

    #[deprecated = "use tactic_sources instead"]
    pub fn get_tactic_sources(&self) -> u32 {
        self.tactic_sources()
    }

    /// Returns the value reported by [IBuilderConfig::getMemoryPoolLimit] for `pool`.
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn memory_pool_limit(&self, pool: MemoryPoolType) -> usize {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getMemoryPoolLimit(pool.into())
    }

    #[deprecated = "use memory_pool_limit instead"]
    pub fn get_memory_pool_limit(&self, pool: MemoryPoolType) -> usize {
        self.memory_pool_limit(pool)
    }

    /// Forwards `feature` and `enable` to [IBuilderConfig::setPreviewFeature].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_preview_feature(&mut self, feature: PreviewFeature, enable: bool) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setPreviewFeature(feature.into(), enable);
        }
    }

    /// Returns the value reported by [IBuilderConfig::getPreviewFeature] for `feature`.
    ///
    /// With the `mock` feature enabled this returns `false` without calling into the native
    /// object.
    pub fn preview_feature(&self, feature: PreviewFeature) -> bool {
        if cfg!(feature = "mock") {
            return false;
        }
        self.inner.getPreviewFeature(feature.into())
    }

    #[deprecated = "use preview_feature instead"]
    pub fn get_preview_feature(&self, feature: PreviewFeature) -> bool {
        self.preview_feature(feature)
    }

    /// Forwards `level` to [IBuilderConfig::setBuilderOptimizationLevel].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_builder_optimization_level(&mut self, level: i32) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setBuilderOptimizationLevel(level);
        }
    }

    /// Returns the value reported by [IBuilderConfig::getBuilderOptimizationLevel].
    ///
    /// Takes `&mut self` because the native getter is invoked through a pinned mutable
    /// reference. With the `mock` feature enabled this returns `0` without calling into the
    /// native object.
    pub fn builder_optimization_level(&mut self) -> i32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.pin_mut().getBuilderOptimizationLevel()
    }

    #[deprecated = "use builder_optimization_level instead"]
    pub fn get_builder_optimization_level(&mut self) -> i32 {
        self.builder_optimization_level()
    }

    /// Forwards `level` to [IBuilderConfig::setHardwareCompatibilityLevel].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_hardware_compatibility_level(&mut self, level: HardwareCompatibilityLevel) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setHardwareCompatibilityLevel(level.into());
        }
    }

    /// See [IBuilderConfig::getHardwareCompatibilityLevel]
    pub fn hardware_compatibility_level(&self) -> HardwareCompatibilityLevel {
        self.inner.getHardwareCompatibilityLevel().into()
    }

    #[deprecated = "use hardware_compatibility_level instead"]
    pub fn get_hardware_compatibility_level(&self) -> HardwareCompatibilityLevel {
        self.hardware_compatibility_level()
    }

    /// Forwards `nb_streams` to [IBuilderConfig::setMaxAuxStreams].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_max_aux_streams(&mut self, nb_streams: i32) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setMaxAuxStreams(nb_streams);
        }
    }

    /// Returns the value reported by [IBuilderConfig::getMaxAuxStreams].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn max_aux_streams(&self) -> i32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getMaxAuxStreams()
    }

    #[deprecated = "use max_aux_streams instead"]
    pub fn get_max_aux_streams(&self) -> i32 {
        self.max_aux_streams()
    }

    /// Forwards `platform` to [IBuilderConfig::setRuntimePlatform].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_runtime_platform(&mut self, platform: RuntimePlatform) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setRuntimePlatform(platform.into());
        }
    }

    /// Returns the value reported by [IBuilderConfig::getRuntimePlatform].
    ///
    /// With the `mock` feature enabled this returns `RuntimePlatform::kSAME_AS_BUILD` without
    /// calling into the native object.
    pub fn runtime_platform(&self) -> RuntimePlatform {
        if cfg!(feature = "mock") {
            return RuntimePlatform::kSAME_AS_BUILD;
        }
        self.inner.getRuntimePlatform().into()
    }

    #[deprecated = "use runtime_platform instead"]
    pub fn get_runtime_platform(&self) -> RuntimePlatform {
        self.runtime_platform()
    }

    /// Forwards `max_nb_tactics` to [IBuilderConfig::setMaxNbTactics].
    ///
    /// Compiles to a no-op when the `mock` feature is enabled.
    pub fn set_max_nb_tactics(&mut self, max_nb_tactics: i32) {
        #[cfg(not(feature = "mock"))]
        {
            let config = self.inner.pin_mut();
            config.setMaxNbTactics(max_nb_tactics);
        }
    }

    /// Returns the value reported by [IBuilderConfig::getMaxNbTactics].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn max_nb_tactics(&self) -> i32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getMaxNbTactics()
    }

    #[deprecated = "use max_nb_tactics instead"]
    pub fn get_max_nb_tactics(&self) -> i32 {
        self.max_nb_tactics()
    }

    /// Forwards `level` to [IBuilderConfig::setTilingOptimizationLevel].
    ///
    /// # Errors
    ///
    /// Returns `Error::FailedToSetProperty` when the native setter returns `false`. With the
    /// `mock` feature enabled the setter is not called and the result is always `Ok(())`.
    pub fn set_tiling_optimization_level(
        &mut self,
        level: TilingOptimizationLevel,
    ) -> crate::Result<()> {
        // `||` short-circuits under `mock`, so the native setter is never invoked there.
        let accepted = cfg!(feature = "mock")
            || self
                .inner
                .pin_mut()
                .setTilingOptimizationLevel(level.into());
        if accepted {
            Ok(())
        } else {
            Err(crate::Error::FailedToSetProperty(
                PropertySetAttempt::BuilderConfigTilingOptimizationLevel,
            ))
        }
    }

    /// Returns the value reported by [IBuilderConfig::getTilingOptimizationLevel].
    ///
    /// With the `mock` feature enabled this returns `TilingOptimizationLevel::kNONE` without
    /// calling into the native object.
    pub fn tiling_optimization_level(&self) -> TilingOptimizationLevel {
        if cfg!(feature = "mock") {
            return TilingOptimizationLevel::kNONE;
        }
        self.inner.getTilingOptimizationLevel().into()
    }

    #[deprecated = "use tiling_optimization_level instead"]
    pub fn get_tiling_optimization_level(&self) -> TilingOptimizationLevel {
        self.tiling_optimization_level()
    }

    /// Forwards `size` to [IBuilderConfig::setL2LimitForTiling].
    ///
    /// # Errors
    ///
    /// Returns `Error::FailedToSetProperty` when the native setter returns `false`. With the
    /// `mock` feature enabled the setter is not called and the result is always `Ok(())`.
    pub fn set_l2_limit_for_tiling(&mut self, size: i64) -> crate::Result<()> {
        // `||` short-circuits under `mock`, so the native setter is never invoked there.
        let accepted = cfg!(feature = "mock") || self.inner.pin_mut().setL2LimitForTiling(size);
        if accepted {
            Ok(())
        } else {
            Err(crate::Error::FailedToSetProperty(
                PropertySetAttempt::BuilderConfigL2LimitForTiling,
            ))
        }
    }

    /// Returns the value reported by [IBuilderConfig::getL2LimitForTiling].
    ///
    /// With the `mock` feature enabled this returns `0` without calling into the native object.
    pub fn l2_limit_for_tiling(&self) -> i64 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getL2LimitForTiling()
    }

    #[deprecated = "use l2_limit_for_tiling instead"]
    pub fn get_l2_limit_for_tiling(&self) -> i64 {
        self.l2_limit_for_tiling()
    }

    /// Forwards `max_nb_compute_capabilities` to [IBuilderConfig::setNbComputeCapabilities].
    ///
    /// Only compiled when the `enterprise` feature is disabled.
    ///
    /// # Errors
    ///
    /// Returns `Error::FailedToSetProperty` when the native setter returns `false`. With the
    /// `mock` feature enabled the setter is not called and the result is always `Ok(())`.
    #[cfg(not(feature = "enterprise"))]
    pub fn set_nb_compute_capabilities(
        &mut self,
        max_nb_compute_capabilities: i32,
    ) -> crate::Result<()> {
        // `||` short-circuits under `mock`, so the native setter is never invoked there.
        let accepted = cfg!(feature = "mock")
            || self
                .inner
                .pin_mut()
                .setNbComputeCapabilities(max_nb_compute_capabilities);
        if accepted {
            Ok(())
        } else {
            Err(crate::Error::FailedToSetProperty(
                PropertySetAttempt::BuilderConfigNbComputeCapabilities,
            ))
        }
    }

    /// Returns the value reported by [IBuilderConfig::getNbComputeCapabilities].
    ///
    /// Only compiled when the `enterprise` feature is disabled. With the `mock` feature
    /// enabled this returns `0` without calling into the native object.
    #[cfg(not(feature = "enterprise"))]
    pub fn nb_compute_capabilities(&self) -> i32 {
        if cfg!(feature = "mock") {
            return 0;
        }
        self.inner.getNbComputeCapabilities()
    }

    #[cfg(not(feature = "enterprise"))]
    #[deprecated = "use nb_compute_capabilities instead"]
    pub fn get_nb_compute_capabilities(&self) -> i32 {
        self.nb_compute_capabilities()
    }

    /// Forwards `compute_capability` and `index` to [IBuilderConfig::setComputeCapability].
    ///
    /// Only compiled when the `enterprise` feature is disabled.
    ///
    /// # Errors
    ///
    /// Returns `Error::FailedToSetProperty` when the native setter returns `false`. With the
    /// `mock` feature enabled the setter is not called and the result is always `Ok(())`.
    #[cfg(not(feature = "enterprise"))]
    pub fn set_compute_capability(
        &mut self,
        compute_capability: ComputeCapability,
        index: i32,
    ) -> crate::Result<()> {
        // `||` short-circuits under `mock`, so the native setter is never invoked there.
        let accepted = cfg!(feature = "mock")
            || self
                .inner
                .pin_mut()
                .setComputeCapability(compute_capability.into(), index);
        if accepted {
            Ok(())
        } else {
            Err(crate::Error::FailedToSetProperty(
                PropertySetAttempt::BuilderConfigComputeCapability,
            ))
        }
    }

    /// Returns the value reported by [IBuilderConfig::getComputeCapability] for `index`.
    ///
    /// Only compiled when the `enterprise` feature is disabled. With the `mock` feature
    /// enabled this returns `ComputeCapability::kNONE` without calling into the native object.
    #[cfg(not(feature = "enterprise"))]
    pub fn compute_capability(&self, index: i32) -> ComputeCapability {
        if cfg!(feature = "mock") {
            return ComputeCapability::kNONE;
        }
        self.inner.getComputeCapability(index).into()
    }

    #[cfg(not(feature = "enterprise"))]
    #[deprecated = "use compute_capability instead"]
    pub fn get_compute_capability(&self, index: i32) -> ComputeCapability {
        self.compute_capability(index)
    }
}

#[cfg(test)]
#[cfg(not(feature = "mock"))]
mod tests {
    use crate::builder::MemoryPoolType;
    use crate::interfaces::MonitorProgress;
    use crate::{Builder, DataType, Logger, NetworkDefinition};
    use std::ops::ControlFlow;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Number of identity layers chained together by `build_heavy_network`.
    const NUM_LAYERS: usize = 40;

    /// Test monitor: logs every callback to stdout and asks for cancellation once enough
    /// steps have completed.
    struct StdoutProgressMonitor {
        /// How many `step_complete` callbacks have been seen so far.
        step_count: AtomicU32,
        /// Completed-step count at which cancellation is requested.
        cancel_after: u32,
    }

    impl StdoutProgressMonitor {
        /// Creates a monitor whose step counter starts at zero.
        fn new(cancel_after: u32) -> Self {
            let step_count = AtomicU32::new(0);
            Self {
                step_count,
                cancel_after,
            }
        }
    }

    impl MonitorProgress for StdoutProgressMonitor {
        /// Logs the start of a phase.
        fn phase_start(&self, phase_name: &str, parent_phase: Option<&str>, num_steps: i32) {
            println!(
                "[progress] phase_start phase={:?} parent={:?} num_steps={}",
                phase_name, parent_phase, num_steps
            );
        }

        /// Logs the step, bumps the counter, and breaks once `cancel_after` steps are done.
        fn step_complete(&self, phase_name: &str, step: i32) -> ControlFlow<()> {
            let previous = self.step_count.fetch_add(1, Ordering::SeqCst);
            println!(
                "[progress] step_complete phase={:?} step={}",
                phase_name, step
            );
            // `fetch_add` yields the value before the increment, so add this step back in.
            let completed = previous + 1;
            if completed < self.cancel_after {
                return ControlFlow::Continue(());
            }
            println!("[progress] cancel requested after {} steps", completed);
            ControlFlow::Break(())
        }

        /// Logs the end of a phase.
        fn phase_finish(&self, phase_name: &str) {
            println!("[progress] phase_finish phase={:?}", phase_name);
        }
    }

    /// Creates a builder and a network made of one input followed by `NUM_LAYERS` chained,
    /// individually named identity layers; the final tensor is named and marked as output.
    fn build_heavy_network(logger: &Logger) -> crate::Result<(Builder<'_>, NetworkDefinition<'_>)> {
        let mut builder = Builder::new(logger)?;
        let mut network = builder.create_network(0)?;

        let mut current = network.add_input("input", DataType::kFLOAT, &[1, 4])?;
        for index in 0..NUM_LAYERS {
            let layer_name = format!("layer_{}", index);
            let mut identity = network.add_identity(&current)?;
            identity.set_name(&mut network, &layer_name)?;
            // The layer's first output feeds the next identity layer.
            current = identity.output(&network, 0)?;
        }
        current.set_name(&mut network, "output")?;
        network.mark_output(&current);

        Ok((builder, network))
    }

    /// A monitor that cancels after three steps must make the build fail.
    #[test]
    fn set_progress_monitor_cancel_build() {
        let logger = Logger::stderr().expect("logger");
        let (mut builder, mut network) = build_heavy_network(&logger).expect("build network");

        let mut builder_config = builder.create_config().expect("config");
        builder_config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24);
        builder_config
            .set_progress_monitor(Box::new(StdoutProgressMonitor::new(3)))
            .unwrap();

        let outcome = builder.build_serialized_network(&mut network, &mut builder_config);

        assert!(
            outcome.is_err(),
            "build should fail (cancelled by progress monitor)"
        );
    }

    /// A monitor with a cancel threshold of 10000 steps must let the build succeed.
    #[test]
    fn set_progress_monitor_progress_to_stdout() {
        let logger = Logger::stderr().expect("logger");
        let (mut builder, mut network) = build_heavy_network(&logger).expect("build network");

        let mut builder_config = builder.create_config().expect("config");
        builder_config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24);
        builder_config
            .set_progress_monitor(Box::new(StdoutProgressMonitor::new(10000)))
            .unwrap();

        let outcome = builder.build_serialized_network(&mut network, &mut builder_config);

        assert!(outcome.is_ok(), "build should succeed when not cancelling");
    }
}