pub mod alibi;
pub mod alibi_backward;
pub mod flash_decoding;
pub mod flash_sdpa;
pub mod flash_sdpa_backward;
pub mod flash_sdpa_varlen;
#[cfg(feature = "sm89")]
pub mod flash_sdpa_sm89;
pub mod kv_cache;
pub mod hyper_connection;
pub mod rope;
pub mod rope_backward;
pub mod rope_scaling;
pub mod sdpa;
pub mod sdpa_backward;
pub use alibi::{AlibiArgs, AlibiDescriptor, AlibiPlan};
pub use alibi_backward::{AlibiBackwardArgs, AlibiBackwardDescriptor, AlibiBackwardPlan};
pub use flash_decoding::{
FlashDecodingArgs, FlashDecodingDescriptor, FlashDecodingPlan, FLASH_DECODING_MAX_D,
};
pub use flash_sdpa::{FlashSdpaArgs, FlashSdpaDescriptor, FlashSdpaPlan, FLASH_SDPA_MAX_D};
pub use flash_sdpa_backward::{
FlashSdpaBackwardArgs, FlashSdpaBackwardDescriptor, FlashSdpaBackwardPlan,
};
pub use flash_sdpa_varlen::{
FlashSdpaVarlenArgs, FlashSdpaVarlenBackwardArgs, FlashSdpaVarlenBackwardPlan,
FlashSdpaVarlenDescriptor, FlashSdpaVarlenPlan,
};
#[cfg(feature = "sm89")]
pub use flash_sdpa_sm89::{FlashSdpaSm89Args, FlashSdpaSm89Descriptor, FlashSdpaSm89Plan};
pub use hyper_connection::{
HyperConnectionArgs, HyperConnectionDescriptor, HyperConnectionPlan,
};
pub use kv_cache::{KvCacheAppendArgs, KvCacheAppendDescriptor, KvCacheAppendPlan};
pub use rope::{RopeArgs, RopeDescriptor, RopePlan};
pub use rope_backward::{RopeBackwardArgs, RopeBackwardDescriptor, RopeBackwardPlan};
pub use rope_scaling::{RopeScaledTableBuilder, RopeScaling};
pub use sdpa::{SdpaArgs, SdpaDescriptor, SdpaPlan};
pub use sdpa_backward::{SdpaBackwardArgs, SdpaBackwardDescriptor, SdpaBackwardPlan};
use baracuda_cutlass::{Error, Result};
pub(crate) fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}
#[doc(hidden)]
#[allow(dead_code)]
pub(crate) fn map_status_pub(code: i32) -> Result<()> {
map_status(code)
}
/// Default base for rotary position embedding (RoPE) frequencies (`10000.0`).
///
/// NOTE(review): presumably the fallback used by the `rope` kernels when no
/// custom base is supplied — confirm against callers.
pub const ROPE_DEFAULT_BASE: f32 = 1.0e4;
pub mod batch_paged_decode;
pub mod batch_paged_decode_fp8;
pub mod batch_paged_prefill;
pub mod batch_ragged_prefill;
pub mod cascade_attn;
pub mod paged_kv_append;
pub mod sdpa_block_sparse;
#[cfg(feature = "mamba")]
pub mod ssd_chunk_scan;
#[cfg(feature = "mamba")]
pub mod selective_scan;
pub mod ring_attention;
pub use batch_paged_decode::{
BatchPagedDecodeArgs, BatchPagedDecodeDescriptor, BatchPagedDecodePlan,
PagedKvCacheDescriptor,
};
pub use batch_paged_prefill::{
BatchPagedPrefillArgs, BatchPagedPrefillDescriptor, BatchPagedPrefillPlan,
};
pub use batch_paged_decode_fp8::{
BatchPagedDecodeFp8Args, BatchPagedDecodeFp8Descriptor, BatchPagedDecodeFp8Plan, Fp8KvDtype,
};
pub use batch_ragged_prefill::{
BatchRaggedPrefillArgs, BatchRaggedPrefillDescriptor, BatchRaggedPrefillPlan,
};
pub use cascade_attn::{
CascadeAttentionArgs, CascadeAttentionDescriptor, CascadeAttentionPlan,
CascadeMergeStatesArgs, CascadeMergeStatesDescriptor, CascadeMergeStatesPlan,
};
pub use paged_kv_append::{
PagedKvAppendArgs, PagedKvAppendDescriptor, PagedKvAppendPlan,
};
pub use sdpa_block_sparse::{
SdpaBlockSparseArgs, SdpaBlockSparseDescriptor, SdpaBlockSparsePlan,
SDPA_BLOCK_SPARSE_MAX_BLOCK, SDPA_BLOCK_SPARSE_MAX_D,
};
#[cfg(feature = "mamba")]
pub use ssd_chunk_scan::{
SsdChunkScanArgs, SsdChunkScanBackwardArgs, SsdChunkScanBackwardDescriptor,
SsdChunkScanBackwardPlan, SsdChunkScanDescriptor, SsdChunkScanPlan,
};
#[cfg(feature = "mamba")]
pub use selective_scan::{
SelectiveScanArgs, SelectiveScanBackwardArgs, SelectiveScanBackwardDescriptor,
SelectiveScanBackwardPlan, SelectiveScanDescriptor, SelectiveScanPlan,
};
pub use ring_attention::{
RingAttentionArgs, RingAttentionDescriptor, RingAttentionPlan, RING_ATTENTION_HEAD_DIM,
};