async_tensorrt/ffi/sync/
builder.rs1use cpp::cpp;
2
3use async_cuda::device::DeviceId;
4use async_cuda::ffi::device::Device;
5
6use crate::ffi::builder_config::BuilderConfig;
7use crate::ffi::memory::HostBuffer;
8use crate::ffi::network::{NetworkDefinition, NetworkDefinitionCreationFlags};
9use crate::ffi::optimization_profile::OptimizationProfile;
10use crate::ffi::result;
11
12type Result<T> = std::result::Result<T, crate::error::Error>;
13
/// Owning handle to a TensorRT `IBuilder` created through the C++ API.
///
/// The handle records the CUDA device that was current when it was created
/// (see [`Builder::new`]) so that the same device can be made current again
/// before the underlying builder is destroyed in `Drop`.
pub struct Builder {
    /// Opaque pointer to the C++ `IBuilder` object returned by `createInferBuilder`.
    addr: *mut std::ffi::c_void,
    /// Device that was current when the builder was created; re-selected in `Drop`.
    device: DeviceId,
}

// SAFETY: NOTE(review): `Builder` holds only a raw pointer to the C++ `IBuilder` and a
// device id. Moving it to another thread is assumed sound because `Drop` re-selects the
// recorded device before destroying the builder. Confirm that TensorRT permits an
// `IBuilder` to be used and destroyed from a thread other than the one that created it.
unsafe impl Send for Builder {}

// SAFETY: NOTE(review): in this file, the only `&self` methods are `platform_has_fast_int8`,
// `platform_has_fast_fp16` and `as_ptr`. The first two call `const` `IBuilder` queries, and
// every mutating operation requires `&mut Builder`. This assumes those `const` queries are
// safe to call concurrently. Confirm against TensorRT's thread-safety documentation.
unsafe impl Sync for Builder {}
35
36impl Builder {
37 pub fn new() -> Result<Self> {
38 let device = Device::get_or_panic();
39 let addr = cpp!(unsafe [] -> *mut std::ffi::c_void as "void*" {
40 return createInferBuilder(GLOBAL_LOGGER);
41 });
42 result!(addr, Builder { addr, device })
43 }
44
45 pub fn config(&mut self) -> BuilderConfig {
46 let internal = self.as_mut_ptr();
47 let internal = cpp!(unsafe [
48 internal as "void*"
49 ] -> *mut std::ffi::c_void as "void*" {
50 return ((IBuilder*) internal)->createBuilderConfig();
51 });
52 BuilderConfig::wrap(internal)
53 }
54
55 pub fn optimization_profile(&mut self) -> Result<OptimizationProfile> {
56 let internal = self.as_mut_ptr();
57 let optimization_profile_internal = cpp!(unsafe [
58 internal as "void*"
59 ] -> *mut std::ffi::c_void as "void*" {
60 return ((IBuilder*) internal)->createOptimizationProfile();
61 });
62 result!(
63 optimization_profile_internal,
64 OptimizationProfile::wrap(optimization_profile_internal, self)
65 )
66 }
67
68 pub fn add_default_optimization_profile(&mut self) -> Result<()> {
69 self.optimization_profile()?;
70 Ok(())
71 }
72
73 pub fn with_default_optimization_profile(mut self) -> Result<Self> {
74 self.optimization_profile()?;
75 Ok(self)
76 }
77
78 pub fn build_serialized_network(
79 &mut self,
80 network_definition: &mut NetworkDefinition,
81 config: BuilderConfig,
82 ) -> Result<HostBuffer> {
83 let internal = self.as_mut_ptr();
84 let internal_network_definition = network_definition.as_ptr();
85 let internal_builder_config = config.as_ptr();
86 let plan_internal = cpp!(unsafe [
87 internal as "void*",
88 internal_network_definition as "void*",
89 internal_builder_config as "void*"
90 ] -> *mut std::ffi::c_void as "void*" {
91 return ((IBuilder*) internal)->buildSerializedNetwork(
92 *((INetworkDefinition*) internal_network_definition),
93 *((IBuilderConfig*) internal_builder_config)
94 );
95 });
96 result!(plan_internal, HostBuffer::wrap(plan_internal))
97 }
98
99 pub fn network_definition(
100 &mut self,
101 flags: NetworkDefinitionCreationFlags,
102 ) -> NetworkDefinition {
103 let internal = self.as_mut_ptr();
104 let set_explicit_batch_size = match flags {
105 NetworkDefinitionCreationFlags::None => false,
106 NetworkDefinitionCreationFlags::ExplicitBatchSize => true,
107 };
108 let internal = cpp!(unsafe [
109 internal as "void*",
110 set_explicit_batch_size as "bool"
111 ] -> *mut std::ffi::c_void as "void*" {
112 std::uint32_t flags = 0;
113 if (set_explicit_batch_size) {
114 flags |= (1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
115 }
116 return ((IBuilder*) internal)->createNetworkV2(flags);
117 });
118 NetworkDefinition::wrap(internal)
119 }
120
121 pub fn platform_has_fast_int8(&self) -> bool {
122 let internal = self.as_ptr();
123 cpp!(unsafe [
124 internal as "const void*"
125 ] -> bool as "bool" {
126 return ((const IBuilder*) internal)->platformHasFastInt8();
127 })
128 }
129
130 pub fn platform_has_fast_fp16(&self) -> bool {
131 let internal = self.as_ptr();
132 cpp!(unsafe [
133 internal as "const void*"
134 ] -> bool as "bool" {
135 return ((const IBuilder*) internal)->platformHasFastFp16();
136 })
137 }
138
139 #[inline(always)]
140 pub fn as_ptr(&self) -> *const std::ffi::c_void {
141 self.addr
142 }
143
144 #[inline(always)]
145 pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
146 self.addr
147 }
148}
149
150impl Drop for Builder {
151 fn drop(&mut self) {
152 Device::set_or_panic(self.device);
153 let internal = self.as_mut_ptr();
154 cpp!(unsafe [
155 internal as "void*"
156 ] {
157 destroy((IBuilder*) internal);
158 });
159 }
160}