async_tensorrt/ffi/sync/builder.rs

1use cpp::cpp;
2
3use async_cuda::device::DeviceId;
4use async_cuda::ffi::device::Device;
5
6use crate::ffi::builder_config::BuilderConfig;
7use crate::ffi::memory::HostBuffer;
8use crate::ffi::network::{NetworkDefinition, NetworkDefinitionCreationFlags};
9use crate::ffi::optimization_profile::OptimizationProfile;
10use crate::ffi::result;
11
12type Result<T> = std::result::Result<T, crate::error::Error>;
13
14/// Synchronous implementation of [`crate::Builder`].
15///
16/// Refer to [`crate::Builder`] for documentation.
17pub struct Builder {
18    addr: *mut std::ffi::c_void,
19    device: DeviceId,
20}
21
/// Implements [`Send`] for [`Builder`].
///
/// # Safety
///
/// A `Builder` consists of a raw `IBuilder*` (`addr`) and the `DeviceId` that was current when
/// it was created. Moving a `Builder` to another thread moves that pointer with it. `Drop` makes
/// the recorded device current again before calling `destroy`, so destruction does not rely on
/// whichever device happens to be current on the dropping thread.
///
/// NOTE(review): an earlier justification claimed the TensorRT API is thread-safe for all
/// builder operations. TensorRT's developer guide instead says a builder may only be used by one
/// thread at a time. `Send` only requires that the builder can be handed from one thread to
/// another, which is believed to hold. Confirm this for the TensorRT version in use.
unsafe impl Send for Builder {}
28
/// Implements [`Sync`] for [`Builder`].
///
/// # Safety
///
/// In `impl Builder`, every method that calls a mutating `IBuilder` API takes `&mut self`. Those
/// APIs are `createBuilderConfig`, `createOptimizationProfile`, `buildSerializedNetwork` and
/// `createNetworkV2`. The borrow checker therefore stops those calls from overlapping with any
/// other access to the same builder.
///
/// Through a shared `&Builder`, only three methods are reachable:
/// - `platform_has_fast_int8` (a `const IBuilder*` query)
/// - `platform_has_fast_fp16` (a `const IBuilder*` query)
/// - `as_ptr`
///
/// NOTE(review): soundness rests on two assumptions:
/// - the `platformHasFast*` queries are safe to call concurrently;
/// - no other code uses the pointer returned by `as_ptr` to mutate the builder through a shared
///   reference.
///
/// TensorRT documents that a builder is usable by one thread at a time. Confirm that these const
/// queries are exempt from that rule.
unsafe impl Sync for Builder {}
35
36impl Builder {
37    pub fn new() -> Result<Self> {
38        let device = Device::get_or_panic();
39        let addr = cpp!(unsafe [] -> *mut std::ffi::c_void as "void*" {
40            return createInferBuilder(GLOBAL_LOGGER);
41        });
42        result!(addr, Builder { addr, device })
43    }
44
45    pub fn config(&mut self) -> BuilderConfig {
46        let internal = self.as_mut_ptr();
47        let internal = cpp!(unsafe [
48            internal as "void*"
49        ] -> *mut std::ffi::c_void as "void*" {
50            return ((IBuilder*) internal)->createBuilderConfig();
51        });
52        BuilderConfig::wrap(internal)
53    }
54
55    pub fn optimization_profile(&mut self) -> Result<OptimizationProfile> {
56        let internal = self.as_mut_ptr();
57        let optimization_profile_internal = cpp!(unsafe [
58            internal as "void*"
59        ] -> *mut std::ffi::c_void as "void*" {
60            return ((IBuilder*) internal)->createOptimizationProfile();
61        });
62        result!(
63            optimization_profile_internal,
64            OptimizationProfile::wrap(optimization_profile_internal, self)
65        )
66    }
67
68    pub fn add_default_optimization_profile(&mut self) -> Result<()> {
69        self.optimization_profile()?;
70        Ok(())
71    }
72
73    pub fn with_default_optimization_profile(mut self) -> Result<Self> {
74        self.optimization_profile()?;
75        Ok(self)
76    }
77
78    pub fn build_serialized_network(
79        &mut self,
80        network_definition: &mut NetworkDefinition,
81        config: BuilderConfig,
82    ) -> Result<HostBuffer> {
83        let internal = self.as_mut_ptr();
84        let internal_network_definition = network_definition.as_ptr();
85        let internal_builder_config = config.as_ptr();
86        let plan_internal = cpp!(unsafe [
87            internal as "void*",
88            internal_network_definition as "void*",
89            internal_builder_config as "void*"
90        ] -> *mut std::ffi::c_void as "void*" {
91            return ((IBuilder*) internal)->buildSerializedNetwork(
92                *((INetworkDefinition*) internal_network_definition),
93                *((IBuilderConfig*) internal_builder_config)
94            );
95        });
96        result!(plan_internal, HostBuffer::wrap(plan_internal))
97    }
98
99    pub fn network_definition(
100        &mut self,
101        flags: NetworkDefinitionCreationFlags,
102    ) -> NetworkDefinition {
103        let internal = self.as_mut_ptr();
104        let set_explicit_batch_size = match flags {
105            NetworkDefinitionCreationFlags::None => false,
106            NetworkDefinitionCreationFlags::ExplicitBatchSize => true,
107        };
108        let internal = cpp!(unsafe [
109            internal as "void*",
110            set_explicit_batch_size as "bool"
111        ] -> *mut std::ffi::c_void as "void*" {
112            std::uint32_t flags = 0;
113            if (set_explicit_batch_size) {
114                flags |= (1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
115            }
116            return ((IBuilder*) internal)->createNetworkV2(flags);
117        });
118        NetworkDefinition::wrap(internal)
119    }
120
121    pub fn platform_has_fast_int8(&self) -> bool {
122        let internal = self.as_ptr();
123        cpp!(unsafe [
124            internal as "const void*"
125        ] -> bool as "bool" {
126            return ((const IBuilder*) internal)->platformHasFastInt8();
127        })
128    }
129
130    pub fn platform_has_fast_fp16(&self) -> bool {
131        let internal = self.as_ptr();
132        cpp!(unsafe [
133            internal as "const void*"
134        ] -> bool as "bool" {
135            return ((const IBuilder*) internal)->platformHasFastFp16();
136        })
137    }
138
139    #[inline(always)]
140    pub fn as_ptr(&self) -> *const std::ffi::c_void {
141        self.addr
142    }
143
144    #[inline(always)]
145    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
146        self.addr
147    }
148}
149
150impl Drop for Builder {
151    fn drop(&mut self) {
152        Device::set_or_panic(self.device);
153        let internal = self.as_mut_ptr();
154        cpp!(unsafe [
155            internal as "void*"
156        ] {
157            destroy((IBuilder*) internal);
158        });
159    }
160}