1use bitflags::bitflags;
10
11bitflags! {
12 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
16 pub struct BuilderFlags: u32 {
17 const FP16 = 1 << 0;
18 const INT8 = 1 << 1;
19 const DEBUG_KERNELS = 1 << 2;
20 const GPU_FALLBACK = 1 << 3;
21 const REFIT = 1 << 4;
22 const DISABLE_TIMING_CACHE = 1 << 5;
23 const TF32 = 1 << 6;
24 const SPARSE_WEIGHTS = 1 << 7;
25 const SAFETY_SCOPE = 1 << 8;
26 const OBEY_PRECISION_CONSTRAINTS = 1 << 9;
27 const PREFER_PRECISION_CONSTRAINTS = 1 << 10;
28 const DIRECT_IO = 1 << 11;
29 const REJECT_EMPTY_ALGORITHMS = 1 << 12;
30 const BF16 = 1 << 13;
31 const FP8 = 1 << 14;
32 const STRIP_PLAN = 1 << 15;
33 const VERSION_COMPATIBLE = 1 << 16;
34 const EXCLUDE_LEAN_RUNTIME = 1 << 17;
35 }
36}
37
38bitflags! {
39 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
41 pub struct TacticSources: u32 {
42 const CUBLAS = 1 << 0;
43 const CUBLAS_LT = 1 << 1;
44 const CUDNN = 1 << 2;
45 const EDGE_MASK_CONVOLUTIONS = 1 << 3;
46 const JIT_CONVOLUTIONS = 1 << 4;
47 }
48}
49
/// Numeric precision requested for an engine build.
///
/// Each variant is translated into [`BuilderFlags`] by [`Precision::flags`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Precision {
    /// Full precision; the default. Still enables `TF32` (see [`Precision::flags`]).
    #[default]
    Fp32,
    /// Half precision.
    Fp16,
    /// bfloat16 precision.
    Bf16,
    /// 8-bit integer precision.
    Int8,
    /// 8-bit floating point; [`Precision::flags`] also enables `FP16` for this variant.
    Fp8,
    /// Enables every reduced-precision flag at once.
    Best,
}
63
64impl Precision {
65 pub fn flags(self) -> BuilderFlags {
66 match self {
67 Precision::Fp32 => BuilderFlags::TF32,
68 Precision::Fp16 => BuilderFlags::FP16 | BuilderFlags::TF32,
69 Precision::Bf16 => BuilderFlags::BF16 | BuilderFlags::TF32,
70 Precision::Int8 => BuilderFlags::INT8 | BuilderFlags::TF32,
71 Precision::Fp8 => BuilderFlags::FP8 | BuilderFlags::FP16 | BuilderFlags::TF32,
72 Precision::Best => {
73 BuilderFlags::FP16
74 | BuilderFlags::BF16
75 | BuilderFlags::INT8
76 | BuilderFlags::FP8
77 | BuilderFlags::TF32
78 }
79 }
80 }
81}
82
/// Device an engine is built for.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DeviceType {
    /// The GPU; the default.
    #[default]
    Gpu,
    /// A DLA core, identified by the contained index.
    ///
    /// NOTE(review): the index is a signed `i32` and nothing in this type validates it;
    /// presumably only non-negative core indices are meaningful — confirm at the consumer.
    Dla(i32),
}
91
/// How (and whether) a built engine can be refitted.
///
/// `IBuilderConfig::effective_flags` maps each variant onto [`BuilderFlags`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RefitPolicy {
    /// No refit support; adds no flags. The default.
    #[default]
    Disabled,
    /// Adds `REFIT`.
    OnDemand,
    /// Adds both `REFIT` and `STRIP_PLAN`.
    WeightsStreaming,
}
102
103#[derive(Debug, Clone)]
107pub struct IBuilderConfig {
108 pub precision: Precision,
109 pub device_type: DeviceType,
110 pub structured_sparsity: bool,
112 pub tactic_sources: TacticSources,
114 pub timing_cache: Option<Vec<u8>>,
116 pub refit: RefitPolicy,
118 pub workspace_bytes: usize,
120 pub dla_sram_bytes: usize,
123 pub extra_flags: BuilderFlags,
127}
128
129impl Default for IBuilderConfig {
130 fn default() -> Self {
131 Self {
132 precision: Precision::default(),
133 device_type: DeviceType::default(),
134 structured_sparsity: false,
135 tactic_sources: TacticSources::CUBLAS
136 | TacticSources::CUBLAS_LT
137 | TacticSources::CUDNN
138 | TacticSources::EDGE_MASK_CONVOLUTIONS
139 | TacticSources::JIT_CONVOLUTIONS,
140 timing_cache: None,
141 refit: RefitPolicy::default(),
142 workspace_bytes: 1 << 30, dla_sram_bytes: 0,
144 extra_flags: BuilderFlags::empty(),
145 }
146 }
147}
148
149impl IBuilderConfig {
150 pub fn new() -> Self {
151 Self::default()
152 }
153
154 pub fn with_precision(mut self, p: Precision) -> Self {
155 self.precision = p;
156 self
157 }
158
159 pub fn with_device(mut self, dt: DeviceType) -> Self {
160 self.device_type = dt;
161 self
162 }
163
164 pub fn with_sparsity(mut self, on: bool) -> Self {
165 self.structured_sparsity = on;
166 self
167 }
168
169 pub fn with_tactic_sources(mut self, ts: TacticSources) -> Self {
170 self.tactic_sources = ts;
171 self
172 }
173
174 pub fn with_timing_cache(mut self, cache: Vec<u8>) -> Self {
175 self.timing_cache = Some(cache);
176 self
177 }
178
179 pub fn with_refit(mut self, refit: RefitPolicy) -> Self {
180 self.refit = refit;
181 self
182 }
183
184 pub fn with_workspace_bytes(mut self, bytes: usize) -> Self {
185 self.workspace_bytes = bytes;
186 self
187 }
188
189 pub fn with_extra_flags(mut self, flags: BuilderFlags) -> Self {
190 self.extra_flags = flags;
191 self
192 }
193
194 pub fn effective_flags(&self) -> BuilderFlags {
198 let mut f = self.precision.flags() | self.extra_flags;
199 if self.structured_sparsity {
200 f |= BuilderFlags::SPARSE_WEIGHTS;
201 }
202 match self.refit {
203 RefitPolicy::Disabled => {}
204 RefitPolicy::OnDemand => f |= BuilderFlags::REFIT,
205 RefitPolicy::WeightsStreaming => {
206 f |= BuilderFlags::REFIT | BuilderFlags::STRIP_PLAN;
207 }
208 }
209 f
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter stores its value, and `effective_flags` combines them.
    #[test]
    fn builder_config_round_trip() {
        let cfg = IBuilderConfig::new()
            .with_precision(Precision::Best)
            .with_device(DeviceType::Dla(1))
            .with_sparsity(true)
            .with_refit(RefitPolicy::WeightsStreaming)
            .with_workspace_bytes(2 << 30)
            .with_extra_flags(BuilderFlags::DEBUG_KERNELS)
            .with_tactic_sources(TacticSources::CUBLAS | TacticSources::CUDNN)
            .with_timing_cache(vec![1, 2, 3, 4]);

        assert_eq!(cfg.precision, Precision::Best);
        assert!(cfg.structured_sparsity);
        assert_eq!(cfg.refit, RefitPolicy::WeightsStreaming);
        assert_eq!(cfg.device_type, DeviceType::Dla(1));
        assert_eq!(cfg.workspace_bytes, 2 << 30);
        assert_eq!(cfg.extra_flags, BuilderFlags::DEBUG_KERNELS);
        assert_eq!(cfg.timing_cache.as_deref(), Some(&[1u8, 2, 3, 4][..]));
        assert_eq!(
            cfg.tactic_sources,
            TacticSources::CUBLAS | TacticSources::CUDNN
        );
        assert!(!cfg.tactic_sources.contains(TacticSources::CUBLAS_LT));

        let flags = cfg.effective_flags();
        for expected in [
            BuilderFlags::FP16,
            BuilderFlags::BF16,
            BuilderFlags::INT8,
            BuilderFlags::FP8,
            BuilderFlags::TF32,
            BuilderFlags::REFIT,
            BuilderFlags::STRIP_PLAN,
            BuilderFlags::SPARSE_WEIGHTS,
            BuilderFlags::DEBUG_KERNELS,
        ] {
            assert!(flags.contains(expected), "missing {:?}", expected);
        }
    }

    /// Pins the precision -> flag mapping, including the exact FP32 result.
    #[test]
    fn precision_flag_mapping_is_stable() {
        assert_eq!(Precision::Fp32.flags(), BuilderFlags::TF32);
        assert!(Precision::Fp16.flags().contains(BuilderFlags::FP16));
        assert!(Precision::Bf16.flags().contains(BuilderFlags::BF16));
        assert!(Precision::Int8.flags().contains(BuilderFlags::INT8));
        assert!(Precision::Fp8.flags().contains(BuilderFlags::FP8));
        assert!(Precision::Fp8.flags().contains(BuilderFlags::FP16));

        let best = Precision::Best.flags();
        for f in [
            BuilderFlags::FP16,
            BuilderFlags::BF16,
            BuilderFlags::INT8,
            BuilderFlags::FP8,
            BuilderFlags::TF32,
        ] {
            assert!(best.contains(f), "Best is missing {:?}", f);
        }

        // Every precision carries TF32.
        for p in [
            Precision::Fp32,
            Precision::Fp16,
            Precision::Bf16,
            Precision::Int8,
            Precision::Fp8,
            Precision::Best,
        ] {
            assert!(p.flags().contains(BuilderFlags::TF32), "{:?} lacks TF32", p);
        }
    }

    #[test]
    fn refit_disabled_does_not_set_refit_flag() {
        let cfg = IBuilderConfig::new();
        let f = cfg.effective_flags();
        assert!(!f.contains(BuilderFlags::REFIT));
        assert!(!f.contains(BuilderFlags::STRIP_PLAN));
    }

    #[test]
    fn refit_on_demand_sets_refit_but_not_strip_plan() {
        let f = IBuilderConfig::new()
            .with_refit(RefitPolicy::OnDemand)
            .effective_flags();
        assert!(f.contains(BuilderFlags::REFIT));
        assert!(!f.contains(BuilderFlags::STRIP_PLAN));
    }

    /// Pins the documented defaults so accidental changes are caught.
    #[test]
    fn default_config_values() {
        let cfg = IBuilderConfig::default();
        assert_eq!(cfg.precision, Precision::Fp32);
        assert_eq!(cfg.device_type, DeviceType::Gpu);
        assert!(!cfg.structured_sparsity);
        assert_eq!(cfg.tactic_sources, TacticSources::all());
        assert!(cfg.timing_cache.is_none());
        assert_eq!(cfg.refit, RefitPolicy::Disabled);
        assert_eq!(cfg.workspace_bytes, 1 << 30);
        assert_eq!(cfg.dla_sram_bytes, 0);
        assert!(cfg.extra_flags.is_empty());
        // FP32 + no sparsity + refit disabled + no extras => only TF32.
        assert_eq!(cfg.effective_flags(), BuilderFlags::TF32);
    }
}