pub mod dequantize_per_channel;
pub mod dequantize_per_channel_backward;
pub mod dequantize_per_tensor;
pub mod dequantize_per_tensor_backward;
pub mod fake_quantize;
pub mod fake_quantize_backward;
pub mod per_channel;
pub mod per_channel_backward;
pub mod per_tensor;
pub mod per_tensor_backward;
pub mod dequantize_per_group;
pub mod dequantize_per_group_backward;
pub mod dequantize_per_token;
pub mod dequantize_per_token_backward;
pub mod per_group;
pub mod per_group_backward;
pub mod per_token;
pub mod per_token_backward;
pub mod dynamic_range;
pub mod quantized_linear;
pub mod smoothquant;
pub mod gguf;
pub mod nf4;
pub use dequantize_per_channel::{
DequantizePerChannelArgs, DequantizePerChannelDescriptor, DequantizePerChannelPlan,
};
pub use dequantize_per_channel_backward::{
DequantizePerChannelBackwardArgs, DequantizePerChannelBackwardDescriptor,
DequantizePerChannelBackwardPlan,
};
pub use dequantize_per_tensor::{
DequantizePerTensorArgs, DequantizePerTensorDescriptor, DequantizePerTensorPlan,
};
pub use dequantize_per_tensor_backward::{
DequantizePerTensorBackwardArgs, DequantizePerTensorBackwardDescriptor,
DequantizePerTensorBackwardPlan,
};
pub use fake_quantize::{FakeQuantizeArgs, FakeQuantizeDescriptor, FakeQuantizePlan};
pub use fake_quantize_backward::{
FakeQuantizeBackwardArgs, FakeQuantizeBackwardDescriptor, FakeQuantizeBackwardPlan,
};
pub use per_channel::{QuantizePerChannelArgs, QuantizePerChannelDescriptor, QuantizePerChannelPlan};
pub use per_channel_backward::{
QuantizePerChannelBackwardArgs, QuantizePerChannelBackwardDescriptor,
QuantizePerChannelBackwardPlan,
};
pub use per_tensor::{QuantizePerTensorArgs, QuantizePerTensorDescriptor, QuantizePerTensorPlan};
pub use per_tensor_backward::{
QuantizePerTensorBackwardArgs, QuantizePerTensorBackwardDescriptor,
QuantizePerTensorBackwardPlan,
};
pub use per_token::{QuantizePerTokenArgs, QuantizePerTokenDescriptor, QuantizePerTokenPlan};
pub use per_token_backward::{
QuantizePerTokenBackwardArgs, QuantizePerTokenBackwardDescriptor, QuantizePerTokenBackwardPlan,
};
pub use dequantize_per_token::{
DequantizePerTokenArgs, DequantizePerTokenDescriptor, DequantizePerTokenPlan,
};
pub use dequantize_per_token_backward::{
DequantizePerTokenBackwardArgs, DequantizePerTokenBackwardDescriptor,
DequantizePerTokenBackwardPlan,
};
pub use per_group::{QuantizePerGroupArgs, QuantizePerGroupDescriptor, QuantizePerGroupPlan};
pub use per_group_backward::{
QuantizePerGroupBackwardArgs, QuantizePerGroupBackwardDescriptor, QuantizePerGroupBackwardPlan,
};
pub use dequantize_per_group::{
DequantizePerGroupArgs, DequantizePerGroupDescriptor, DequantizePerGroupPlan,
};
pub use dequantize_per_group_backward::{
DequantizePerGroupBackwardArgs, DequantizePerGroupBackwardDescriptor,
DequantizePerGroupBackwardPlan,
};
pub use dynamic_range::{
DynamicRangeMode, DynamicRangeQuantizeArgs, DynamicRangeQuantizeDescriptor,
DynamicRangeQuantizePlan, DynamicRangeScope,
};
pub use quantized_linear::{
QuantizedLinearArgs, QuantizedLinearDescriptor, QuantizedLinearPlan,
};
pub use smoothquant::{
SmoothQuantLinearArgs, SmoothQuantLinearDescriptor, SmoothQuantLinearPlan,
};
pub use gguf::{
BlockQ2K, BlockQ3K, BlockQ4_0, BlockQ4_1, BlockQ4K, BlockQ5_0, BlockQ5_1, BlockQ5K, BlockQ6K,
BlockQ8_0, BlockQ8K, GgufDequantizeArgs, GgufDequantizeDescriptor, GgufDequantizePlan,
GgufMmvqArgs, GgufMmvqDescriptor, GgufMmvqPlan,
};
pub use gguf::{
GgufMmvqBatchedActivation, GgufMmvqBatchedArgs, GgufMmvqBatchedDescriptor,
GgufMmvqBatchedFormat, GgufMmvqBatchedPlan,
};
pub use gguf::{GgufMmvqMultiMArgs, GgufMmvqMultiMDescriptor, GgufMmvqMultiMPlan};
pub use nf4::{
Nf4Activation, Nf4DequantizeArgs, Nf4DequantizePlan, Nf4Descriptor, Nf4MmvqArgs,
Nf4MmvqMultiMArgs, Nf4MmvqMultiMDescriptor, Nf4MmvqMultiMPlan, Nf4MmvqPlan, NF4_CODEBOOK,
};
use baracuda_cutlass::{Error, Result};
pub(crate) fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}
pub(crate) fn validate_input_element(
tin_kind: baracuda_kernels_types::ElementKind,
plan_name: &'static str,
) -> Result<()> {
use baracuda_kernels_types::ElementKind;
if !matches!(
tin_kind,
ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
) {
return Err(Error::Unsupported(plan_name));
}
Ok(())
}
pub(crate) fn validate_output_element(
tout_kind: baracuda_kernels_types::ElementKind,
plan_name: &'static str,
) -> Result<()> {
use baracuda_kernels_types::ElementKind;
if !matches!(tout_kind, ElementKind::S8 | ElementKind::U8) {
return Err(Error::Unsupported(plan_name));
}
Ok(())
}
/// Returns the default `(min, max)` integer range for `out_kind`.
///
/// - `S8` → `(-128, 127)`
/// - `U8` → `(0, 255)`
/// - any other kind → `None`
#[inline]
pub fn default_q_range(out_kind: baracuda_kernels_types::ElementKind) -> Option<(i32, i32)> {
    use baracuda_kernels_types::ElementKind;

    let bounds = match out_kind {
        ElementKind::S8 => (i32::from(i8::MIN), i32::from(i8::MAX)),
        ElementKind::U8 => (i32::from(u8::MIN), i32::from(u8::MAX)),
        _ => return None,
    };
    Some(bounds)
}