pub use baracuda_kernels_types::{
ActivationKind, ArchSku, BackendKind, BiasElement, BiasElementKind, Element, ElementKind,
EpilogueKind, F32Strict, IntElement, LayoutSku, MathPrecision, MatrixMut, MatrixRef,
PlanPreference, PrecisionGuarantee, S8, ScalarType, U8, VectorRef, Workspace,
};
pub use baracuda_kernels_types::Element as CutlassElement;
/// Plain-data description of a single GEMM problem: three `i32` extents plus
/// the layout SKU and epilogue kind to use.
///
/// NOTE(review): `m`/`n`/`k` presumably follow the usual GEMM convention
/// (`A` is m×k, `B` is k×n, `D` is m×n) — confirm against the kernel docs.
#[derive(Debug, Clone, Copy)]
pub struct GemmDescriptor {
    /// Problem extent `m`.
    pub m: i32,
    /// Problem extent `n`.
    pub n: i32,
    /// Problem extent `k`.
    pub k: i32,
    /// Operand layout selection.
    pub layout: LayoutSku,
    /// Epilogue selection.
    pub epilogue: EpilogueKind,
}
/// Operand views and scalars passed to a GEMM over element type `T`.
///
/// All views borrow for the lifetime `'a`; `d` is the only mutable view.
///
/// NOTE(review): presumably evaluated as `D = alpha * A·B + beta * C` with an
/// optional bias applied by the epilogue — confirm against the launcher.
#[derive(Debug)]
pub struct GemmArgs<'a, T>
where
    T: Element,
{
    /// Read-only `A` operand.
    pub a: MatrixRef<'a, T>,
    /// Read-only `B` operand.
    pub b: MatrixRef<'a, T>,
    /// Optional read-only `C` operand.
    pub c: Option<MatrixRef<'a, T>>,
    /// Mutable `D` operand.
    pub d: MatrixMut<'a, T>,
    /// Optional bias vector with the same element type as the matrices.
    pub bias: Option<VectorRef<'a, T>>,
    /// `alpha` scalar, typed as the element's associated `Scalar`.
    pub alpha: T::Scalar,
    /// `beta` scalar, typed as the element's associated `Scalar`.
    pub beta: T::Scalar,
}
/// Plain-data description of a batched GEMM: the three `i32` extents, a batch
/// count, and the layout SKU / epilogue kind to use.
///
/// NOTE(review): presumably every batch entry shares the same `m`/`n`/`k` —
/// confirm against the kernel docs.
#[derive(Debug, Clone, Copy)]
pub struct BatchedGemmDescriptor {
    /// Problem extent `m`.
    pub m: i32,
    /// Problem extent `n`.
    pub n: i32,
    /// Problem extent `k`.
    pub k: i32,
    /// Number of batch entries.
    pub batch_count: i32,
    /// Operand layout selection.
    pub layout: LayoutSku,
    /// Epilogue selection.
    pub epilogue: EpilogueKind,
}
/// Operand views, per-operand `i64` strides, and scalars for a batched GEMM
/// over element type `T`.
///
/// Unlike [`GemmArgs`], this struct carries no bias vector.
///
/// NOTE(review): each `stride_*` is presumably the distance between
/// consecutive batch entries of the matching operand, in elements — confirm
/// the unit against the launcher.
#[derive(Debug)]
pub struct BatchedGemmArgs<'a, T>
where
    T: Element,
{
    /// Read-only `A` operand.
    pub a: MatrixRef<'a, T>,
    /// Stride associated with `a`.
    pub stride_a: i64,
    /// Read-only `B` operand.
    pub b: MatrixRef<'a, T>,
    /// Stride associated with `b`.
    pub stride_b: i64,
    /// Optional read-only `C` operand.
    pub c: Option<MatrixRef<'a, T>>,
    /// Stride associated with `c`.
    pub stride_c: i64,
    /// Mutable `D` operand.
    pub d: MatrixMut<'a, T>,
    /// Stride associated with `d`.
    pub stride_d: i64,
    /// `alpha` scalar, typed as the element's associated `Scalar`.
    pub alpha: T::Scalar,
    /// `beta` scalar, typed as the element's associated `Scalar`.
    pub beta: T::Scalar,
}
/// One problem of a grouped GEMM: its own `m`/`n`/`k` extents together with
/// the operand views and scalars for that problem.
///
/// NOTE(review): presumably a grouped launch takes a collection of these, each
/// with independent shapes — confirm against the grouped-GEMM entry point.
#[derive(Debug)]
pub struct GroupedProblem<'a, T>
where
    T: Element,
{
    /// Problem extent `m`.
    pub m: i32,
    /// Problem extent `n`.
    pub n: i32,
    /// Problem extent `k`.
    pub k: i32,
    /// Read-only `A` operand.
    pub a: MatrixRef<'a, T>,
    /// Read-only `B` operand.
    pub b: MatrixRef<'a, T>,
    /// Optional read-only `C` operand.
    pub c: Option<MatrixRef<'a, T>>,
    /// Mutable `D` operand.
    pub d: MatrixMut<'a, T>,
    /// `alpha` scalar, typed as the element's associated `Scalar`.
    pub alpha: T::Scalar,
    /// `beta` scalar, typed as the element's associated `Scalar`.
    pub beta: T::Scalar,
}
/// Plain-data description of an integer GEMM problem; field-for-field the same
/// shape as [`GemmDescriptor`] (three `i32` extents, layout SKU, epilogue kind).
#[derive(Debug, Clone, Copy)]
pub struct IntGemmDescriptor {
    /// Problem extent `m`.
    pub m: i32,
    /// Problem extent `n`.
    pub n: i32,
    /// Problem extent `k`.
    pub k: i32,
    /// Operand layout selection.
    pub layout: LayoutSku,
    /// Epilogue selection.
    pub epilogue: EpilogueKind,
}
/// Operand views and scalars for an integer GEMM.
///
/// `T` is the matrix element type; `BT` is the bias element type and defaults
/// to `f32`. The `alpha`/`beta` scalars are always `f32`, independent of `T`.
#[derive(Debug)]
pub struct IntGemmArgs<'a, T, BT = f32>
where
    T: IntElement,
    BT: BiasElement,
{
    /// Read-only `A` operand.
    pub a: MatrixRef<'a, T>,
    /// Read-only `B` operand.
    pub b: MatrixRef<'a, T>,
    /// Optional read-only `C` operand.
    pub c: Option<MatrixRef<'a, T>>,
    /// Mutable `D` operand.
    pub d: MatrixMut<'a, T>,
    /// Optional bias vector whose element type is `BT`, not `T`.
    pub bias: Option<VectorRef<'a, BT>>,
    /// `alpha` scalar.
    pub alpha: f32,
    /// `beta` scalar.
    pub beta: f32,
}
/// Scheduling mode selector for grouped GEMM plans.
///
/// `DeviceOnly` is currently the sole variant and is also the `Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GroupedScheduleMode {
    /// The default (and only) mode.
    ///
    /// NOTE(review): presumably means problem scheduling happens on the
    /// device — confirm against the grouped kernel implementation.
    #[default]
    DeviceOnly,
}
/// Plan preference for grouped GEMM: a base [`PlanPreference`] plus a
/// [`GroupedScheduleMode`].
///
/// `Default` yields the default of each field.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroupedPlanPreference {
    /// The general plan preference.
    pub base: PlanPreference,
    /// Grouped-specific schedule mode.
    pub schedule: GroupedScheduleMode,
}
/// Identifies a GEMM kernel SKU by architecture, layout, epilogue, element
/// kind, and optional bias element kind.
///
/// Derives `Eq` and `Hash`, so values can serve as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GemmSku {
    /// Architecture selection.
    pub arch: ArchSku,
    /// Operand layout selection.
    pub layout: LayoutSku,
    /// Epilogue selection.
    pub epilogue: EpilogueKind,
    /// Element kind; the only field read by `precision_guarantee`.
    pub element: ElementKind,
    /// Bias element kind, if any.
    pub bias_element: Option<BiasElementKind>,
}
impl GemmSku {
    /// Builds the [`PrecisionGuarantee`] reported for this SKU.
    ///
    /// Only `self.element` is consulted; `arch`, `layout`, `epilogue`, and
    /// `bias_element` do not affect the result.
    ///
    /// * `math_precision` — mapped per element kind. Plain `F32` reports
    ///   `Tf32`, whereas `F32Strict` reports `F32`.
    /// * `accumulator` — `I32` for `S8`/`U8`/`S4`/`U4`/`Bin`, `F64` for `F64`,
    ///   and `F32` for every other kind.
    /// * `bit_stable_on_same_hardware` — `true` for `F32Strict` and for the
    ///   kinds that accumulate in `I32`; `false` otherwise.
    /// * `deterministic` — always `true`.
    pub fn precision_guarantee(self) -> PrecisionGuarantee {
        let element = self.element;

        // The kinds that accumulate in I32; they are also reported as bit-stable.
        let accumulates_in_i32 = matches!(
            element,
            ElementKind::S8
                | ElementKind::U8
                | ElementKind::S4
                | ElementKind::U4
                | ElementKind::Bin
        );

        // NOTE(review): the fall-through sends I32/I64/Bool and Complex64 to an
        // F32 accumulator — confirm that is intended for those kinds.
        let accumulator = if accumulates_in_i32 {
            ElementKind::I32
        } else if matches!(element, ElementKind::F64) {
            ElementKind::F64
        } else {
            ElementKind::F32
        };

        let bit_stable_on_same_hardware =
            accumulates_in_i32 || matches!(element, ElementKind::F32Strict);

        // Exhaustive on purpose: a new `ElementKind` variant must be classified here.
        let math_precision = match element {
            ElementKind::F16 => MathPrecision::F16,
            ElementKind::Bf16 => MathPrecision::Bf16,
            ElementKind::F32 => MathPrecision::Tf32,
            ElementKind::F32Strict | ElementKind::Complex32 => MathPrecision::F32,
            ElementKind::F64 | ElementKind::Complex64 => MathPrecision::F64,
            ElementKind::S8
            | ElementKind::U8
            | ElementKind::I32
            | ElementKind::I64
            | ElementKind::Bool => MathPrecision::Int8,
            ElementKind::S4 | ElementKind::U4 => MathPrecision::Int4,
            ElementKind::Bin => MathPrecision::Binary,
            ElementKind::Fp8E4M3 => MathPrecision::Fp8E4M3,
            ElementKind::Fp8E5M2 => MathPrecision::Fp8E5M2,
        };

        PrecisionGuarantee {
            math_precision,
            accumulator,
            bit_stable_on_same_hardware,
            deterministic: true,
        }
    }
}