use bitflags::bitflags;
bitflags! {
    /// Set of builder options, one bit per option.
    ///
    /// Combine with `|`. `IBuilderConfig::effective_flags` assembles the final
    /// set from the precision mode, sparsity, refit policy and extra flags.
    ///
    /// NOTE(review): bit positions are local to this type; if they are ever
    /// handed straight to an external builder API, confirm they match its ordinals.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
    pub struct BuilderFlags: u32 {
        // Precision-related bits (see `Precision::flags`).
        const FP16 = 0x0000_0001;
        const INT8 = 0x0000_0002;
        const DEBUG_KERNELS = 0x0000_0004;
        const GPU_FALLBACK = 0x0000_0008;
        /// Set by `effective_flags` for any refit policy other than `Disabled`.
        const REFIT = 0x0000_0010;
        const DISABLE_TIMING_CACHE = 0x0000_0020;
        /// Included by every `Precision` mapping.
        const TF32 = 0x0000_0040;
        /// Set by `effective_flags` when structured sparsity is enabled.
        const SPARSE_WEIGHTS = 0x0000_0080;
        const SAFETY_SCOPE = 0x0000_0100;
        const OBEY_PRECISION_CONSTRAINTS = 0x0000_0200;
        const PREFER_PRECISION_CONSTRAINTS = 0x0000_0400;
        const DIRECT_IO = 0x0000_0800;
        const REJECT_EMPTY_ALGORITHMS = 0x0000_1000;
        const BF16 = 0x0000_2000;
        const FP8 = 0x0000_4000;
        /// Set by `effective_flags` for `RefitPolicy::WeightsStreaming`.
        const STRIP_PLAN = 0x0000_8000;
        const VERSION_COMPATIBLE = 0x0001_0000;
        const EXCLUDE_LEAN_RUNTIME = 0x0002_0000;
    }
}
bitflags! {
    /// Tactic libraries the builder may draw from, one bit per source.
    ///
    /// `IBuilderConfig::default` enables all of them.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
    pub struct TacticSources: u32 {
        const CUBLAS = 0x01;
        const CUBLAS_LT = 0x02;
        const CUDNN = 0x04;
        const EDGE_MASK_CONVOLUTIONS = 0x08;
        const JIT_CONVOLUTIONS = 0x10;
    }
}
/// Precision mode requested for a build.
///
/// Each variant maps to a set of `BuilderFlags` through `Precision::flags`;
/// the default is `Fp32`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Precision {
    /// 32-bit floats; only the TF32 flag is enabled.
    #[default]
    Fp32,
    /// Allows FP16 in addition to TF32.
    Fp16,
    /// Allows BF16 in addition to TF32.
    Bf16,
    /// Allows INT8 in addition to TF32.
    Int8,
    /// Allows FP8 and FP16 in addition to TF32.
    Fp8,
    /// Allows every reduced-precision format known to `Precision::flags`.
    Best,
}
impl Precision {
    /// Returns the builder flags implied by this precision mode.
    ///
    /// Every variant includes `TF32`. On top of that, each variant enables the
    /// reduced-precision formats it allows: `Fp8` also turns on FP16, and `Best`
    /// turns on FP16, BF16, INT8 and FP8 together.
    pub fn flags(self) -> BuilderFlags {
        // Reduced-precision formats this mode opts into; TF32 is added below.
        let reduced = match self {
            Self::Fp32 => BuilderFlags::empty(),
            Self::Fp16 => BuilderFlags::FP16,
            Self::Bf16 => BuilderFlags::BF16,
            Self::Int8 => BuilderFlags::INT8,
            Self::Fp8 => BuilderFlags::FP8 | BuilderFlags::FP16,
            Self::Best => {
                BuilderFlags::FP16 | BuilderFlags::BF16 | BuilderFlags::INT8 | BuilderFlags::FP8
            }
        };
        reduced | BuilderFlags::TF32
    }
}
/// Device the build targets; defaults to the GPU.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DeviceType {
    /// Build for the GPU.
    #[default]
    Gpu,
    /// Build for a DLA core.
    ///
    /// NOTE(review): the payload presumably selects the DLA core index; negative
    /// values are not rejected anywhere in this module — confirm the valid range.
    Dla(i32),
}
/// Controls which refit-related flags `IBuilderConfig::effective_flags` emits.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum RefitPolicy {
    /// No refit flags are set (the default).
    #[default]
    Disabled,
    /// Sets `BuilderFlags::REFIT`.
    OnDemand,
    /// Sets both `BuilderFlags::REFIT` and `BuilderFlags::STRIP_PLAN`.
    WeightsStreaming,
}
/// Complete set of options for a build, assembled with the `with_*` builder methods.
///
/// Use `IBuilderConfig::effective_flags` to obtain the final `BuilderFlags`.
#[derive(Clone, Debug)]
pub struct IBuilderConfig {
    /// Precision mode; contributes `Precision::flags` to the effective flags.
    pub precision: Precision,
    /// Target device; defaults to `DeviceType::Gpu`.
    pub device_type: DeviceType,
    /// When true, `BuilderFlags::SPARSE_WEIGHTS` is added to the effective flags.
    pub structured_sparsity: bool,
    /// Tactic libraries the builder may use; all of them by default.
    pub tactic_sources: TacticSources,
    /// Opaque timing-cache blob; `None` means no cache is supplied.
    pub timing_cache: Option<Vec<u8>>,
    /// Refit policy; determines the REFIT / STRIP_PLAN flags.
    pub refit: RefitPolicy,
    /// Workspace size in bytes (1 GiB by default).
    pub workspace_bytes: usize,
    /// DLA SRAM budget in bytes (0 by default).
    /// NOTE(review): presumably only meaningful with `DeviceType::Dla` — confirm.
    pub dla_sram_bytes: usize,
    /// Additional flags OR-ed into the effective flags as-is.
    pub extra_flags: BuilderFlags,
}
impl Default for IBuilderConfig {
    /// The baseline configuration:
    ///
    /// - default precision, device and refit policy;
    /// - structured sparsity off;
    /// - every tactic source enabled;
    /// - no timing cache;
    /// - a 1 GiB workspace;
    /// - no DLA SRAM budget;
    /// - no extra flags.
    fn default() -> Self {
        const ONE_GIB: usize = 1024 * 1024 * 1024;

        // All known tactic sources are enabled out of the box.
        let tactic_sources = TacticSources::CUBLAS
            | TacticSources::CUBLAS_LT
            | TacticSources::CUDNN
            | TacticSources::EDGE_MASK_CONVOLUTIONS
            | TacticSources::JIT_CONVOLUTIONS;

        Self {
            precision: Default::default(),
            device_type: Default::default(),
            structured_sparsity: false,
            tactic_sources,
            timing_cache: None,
            refit: Default::default(),
            workspace_bytes: ONE_GIB,
            dla_sram_bytes: 0,
            extra_flags: BuilderFlags::empty(),
        }
    }
}
impl IBuilderConfig {
    /// Creates a configuration populated with `IBuilderConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the precision mode; the flags it implies come from `Precision::flags`.
    #[must_use]
    pub fn with_precision(mut self, p: Precision) -> Self {
        self.precision = p;
        self
    }

    /// Selects the target device (GPU or a DLA core).
    #[must_use]
    pub fn with_device(mut self, dt: DeviceType) -> Self {
        self.device_type = dt;
        self
    }

    /// Enables or disables structured sparsity.
    ///
    /// When enabled, `effective_flags` adds `BuilderFlags::SPARSE_WEIGHTS`.
    #[must_use]
    pub fn with_sparsity(mut self, on: bool) -> Self {
        self.structured_sparsity = on;
        self
    }

    /// Replaces the set of enabled tactic sources.
    #[must_use]
    pub fn with_tactic_sources(mut self, ts: TacticSources) -> Self {
        self.tactic_sources = ts;
        self
    }

    /// Supplies a timing-cache blob, replacing any previously set one.
    #[must_use]
    pub fn with_timing_cache(mut self, cache: Vec<u8>) -> Self {
        self.timing_cache = Some(cache);
        self
    }

    /// Sets the refit policy; see `effective_flags` for the flags each policy adds.
    #[must_use]
    pub fn with_refit(mut self, refit: RefitPolicy) -> Self {
        self.refit = refit;
        self
    }

    /// Sets the workspace size in bytes.
    #[must_use]
    pub fn with_workspace_bytes(mut self, bytes: usize) -> Self {
        self.workspace_bytes = bytes;
        self
    }

    /// Sets the DLA SRAM budget in bytes (defaults to 0).
    ///
    /// NOTE(review): presumably only meaningful when targeting `DeviceType::Dla`;
    /// this method does not check the device type.
    #[must_use]
    pub fn with_dla_sram_bytes(mut self, bytes: usize) -> Self {
        self.dla_sram_bytes = bytes;
        self
    }

    /// Replaces the extra flags; it does not merge them.
    ///
    /// Calling this twice keeps only the second value. To set several extra
    /// flags, combine them with `|` before passing them in.
    #[must_use]
    pub fn with_extra_flags(mut self, flags: BuilderFlags) -> Self {
        self.extra_flags = flags;
        self
    }

    /// Returns the complete flag set for a build.
    ///
    /// The result is the union of:
    /// - the precision flags (`Precision::flags`);
    /// - the extra flags;
    /// - `SPARSE_WEIGHTS` when structured sparsity is on;
    /// - the refit flags: `REFIT` for `OnDemand`, `REFIT | STRIP_PLAN` for
    ///   `WeightsStreaming`, nothing for `Disabled`.
    #[must_use]
    pub fn effective_flags(&self) -> BuilderFlags {
        let mut f = self.precision.flags() | self.extra_flags;
        if self.structured_sparsity {
            f |= BuilderFlags::SPARSE_WEIGHTS;
        }
        match self.refit {
            RefitPolicy::Disabled => {}
            RefitPolicy::OnDemand => f |= BuilderFlags::REFIT,
            // Weights streaming needs refit and additionally strips the plan.
            RefitPolicy::WeightsStreaming => {
                f |= BuilderFlags::REFIT | BuilderFlags::STRIP_PLAN;
            }
        }
        f
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_config_round_trip() {
        let cfg = IBuilderConfig::new()
            .with_precision(Precision::Best)
            .with_device(DeviceType::Dla(1))
            .with_sparsity(true)
            .with_refit(RefitPolicy::WeightsStreaming)
            .with_workspace_bytes(2 << 30)
            .with_extra_flags(BuilderFlags::DEBUG_KERNELS)
            .with_tactic_sources(TacticSources::CUBLAS | TacticSources::CUDNN)
            .with_timing_cache(vec![1, 2, 3, 4]);
        assert_eq!(cfg.precision, Precision::Best);
        assert!(cfg.structured_sparsity);
        assert!(matches!(cfg.refit, RefitPolicy::WeightsStreaming));
        assert!(matches!(cfg.device_type, DeviceType::Dla(1)));
        assert_eq!(cfg.workspace_bytes, 2 << 30);
        assert_eq!(cfg.timing_cache.as_deref(), Some(&[1u8, 2, 3, 4][..]));
        assert!(cfg.tactic_sources.contains(TacticSources::CUBLAS));
        assert!(!cfg.tactic_sources.contains(TacticSources::CUBLAS_LT));
        let flags = cfg.effective_flags();
        assert!(flags.contains(BuilderFlags::FP16));
        assert!(flags.contains(BuilderFlags::BF16));
        assert!(flags.contains(BuilderFlags::INT8));
        assert!(flags.contains(BuilderFlags::FP8));
        assert!(flags.contains(BuilderFlags::TF32));
        assert!(flags.contains(BuilderFlags::REFIT));
        assert!(flags.contains(BuilderFlags::STRIP_PLAN));
        assert!(flags.contains(BuilderFlags::SPARSE_WEIGHTS));
        assert!(flags.contains(BuilderFlags::DEBUG_KERNELS));
    }

    #[test]
    fn precision_flag_mapping_is_stable() {
        assert!(Precision::Fp16.flags().contains(BuilderFlags::FP16));
        assert!(Precision::Bf16.flags().contains(BuilderFlags::BF16));
        assert!(Precision::Int8.flags().contains(BuilderFlags::INT8));
        assert!(Precision::Fp8.flags().contains(BuilderFlags::FP8));
        let best = Precision::Best.flags();
        for f in [
            BuilderFlags::FP16,
            BuilderFlags::BF16,
            BuilderFlags::INT8,
            BuilderFlags::FP8,
            BuilderFlags::TF32,
        ] {
            assert!(best.contains(f), "Best is missing {:?}", f);
        }
    }

    #[test]
    fn fp32_maps_to_tf32_only() {
        // Fp32 must not silently enable any reduced-precision format.
        assert_eq!(Precision::Fp32.flags(), BuilderFlags::TF32);
    }

    #[test]
    fn refit_disabled_does_not_set_refit_flag() {
        let cfg = IBuilderConfig::new();
        let f = cfg.effective_flags();
        assert!(!f.contains(BuilderFlags::REFIT));
        assert!(!f.contains(BuilderFlags::STRIP_PLAN));
    }

    #[test]
    fn refit_on_demand_sets_refit_without_strip_plan() {
        // Covers the OnDemand arm of `effective_flags`, which was previously untested.
        let f = IBuilderConfig::new()
            .with_refit(RefitPolicy::OnDemand)
            .effective_flags();
        assert!(f.contains(BuilderFlags::REFIT));
        assert!(!f.contains(BuilderFlags::STRIP_PLAN));
    }

    #[test]
    fn default_config_values() {
        let cfg = IBuilderConfig::default();
        assert_eq!(cfg.precision, Precision::Fp32);
        assert_eq!(cfg.device_type, DeviceType::Gpu);
        assert!(!cfg.structured_sparsity);
        assert_eq!(cfg.tactic_sources, TacticSources::all());
        assert!(cfg.timing_cache.is_none());
        assert_eq!(cfg.refit, RefitPolicy::Disabled);
        assert_eq!(cfg.workspace_bytes, 1 << 30);
        assert_eq!(cfg.dla_sram_bytes, 0);
        assert!(cfg.extra_flags.is_empty());
        // With every optional feature off, only the Fp32 precision flag remains.
        assert_eq!(cfg.effective_flags(), BuilderFlags::TF32);
    }

    #[test]
    fn extra_flags_are_replaced_not_merged() {
        let cfg = IBuilderConfig::new()
            .with_extra_flags(BuilderFlags::DEBUG_KERNELS)
            .with_extra_flags(BuilderFlags::GPU_FALLBACK);
        assert_eq!(cfg.extra_flags, BuilderFlags::GPU_FALLBACK);
    }
}