// Silences two clippy lints for everything in this file's scope.
// NOTE(review): the reason for the blanket allow is not recorded here — presumably
// type names such as `FlashAttn*` repeat the module name and some request
// constructors take many arguments; confirm and document, or narrow the scope.
#![allow(clippy::module_name_repetitions, clippy::too_many_arguments)]
// Submodules. Selected items from each are re-exported further down in this file.
pub mod actor;
pub mod dispatch;
pub mod fa2;
pub mod fa3;
// `paged` is only compiled when the `paged` Cargo feature is enabled.
#[cfg(feature = "paged")]
pub mod paged;
pub mod prefill;
pub mod varlen;
// Private module, compiled only when the `cuda-runtime-tests` feature is enabled.
#[cfg(feature = "cuda-runtime-tests")]
mod cuda_real {
/// Opaque marker type: its only field is the zero-sized unit `()`, so it carries no data.
///
/// NOTE(review): presumably a placeholder for an NVRTC handle used by the
/// real-GPU test path — no use of it is visible in this file; confirm.
pub struct NvrtcRef {
// `pub(crate)` keeps the field private to this crate, so the struct cannot be
// built with a struct literal from outside the crate.
pub(crate) _opaque: (),
}
}
pub use actor::{FlashAttnActor, FlashAttnInner, FlashAttnMsg, FlashAttnProps};
pub use dispatch::{
lookup, Bf16, DType, DispatchError, DispatchKey, DispatchTable, FaBwdDispatch, FaFwdDispatch,
FaPagedFwdDispatch, GemmSupported, SmArch, DISPATCH_TABLE, F16,
};
#[cfg(feature = "fp8")]
pub use dispatch::{F8E4m3, F8E5m2};
pub use fa2::{Fa2BwdRequest, Fa2FwdRequest, MaskKind, PositionBias};
#[cfg(feature = "fp8")]
pub use fa3::Fa3FwdFp8Request;
pub use fa3::{Fa3FwdRequest, PersistentMode};
#[cfg(feature = "paged")]
pub use paged::{PagedAttentionRequest, PagedKvCache};
pub use prefill::{ChunkLayout, ChunkedPrefillRequest};
pub use varlen::{CumulativeSeqlens, VarlenFwdRequest};
/// Error type for the flash-attention module.
///
/// `Display` output for each variant is the string given in its `#[error(...)]`
/// attribute (generated by the `thiserror` derive). The sites that return these
/// errors live in the submodules and are not visible from this file.
#[derive(Debug, Clone, thiserror::Error)]
pub enum FlashAttnError {
/// Wraps a [`DispatchError`]. `#[from]` makes `thiserror` generate
/// `From<DispatchError>`, so `?` converts dispatch errors automatically,
/// and exposes the inner error through `Error::source`.
#[error("dispatch error: {0}")]
Dispatch(#[from] DispatchError),
/// FA3 was requested for an architecture older than sm_90a. Carries the
/// offending [`SmArch`], which is `Debug`-formatted into the message.
#[error("FA3 requires sm_90a or newer, got {0:?}")]
Fa3RequiresHopper(SmArch),
/// Dtype/request mismatch: an fp8 dtype was used without `Fa3FwdFp8Request`,
/// or `Fa3FwdFp8Request` was used with a non-fp8 dtype.
#[error("fp8 dtypes must use Fa3FwdFp8Request and vice versa")]
Fp8MustUseFp8Request,
/// The attention batch contained no sequences.
#[error("attention batch must contain at least one sequence")]
EmptyBatch,
/// A sequence length of zero was supplied.
#[error("seqlen must be > 0")]
ZeroSeqlen,
/// The cumulative sequence lengths overflow `batch_size * max_seqlen`.
/// NOTE(review): presumably means the cumulative total exceeds that product —
/// confirm against the check in `varlen`.
#[error("cumulative seqlens overflow batch_size * max_seqlen")]
SeqlenOverflow,
/// The paged KV cache was empty.
#[error("paged KV cache must be non-empty")]
EmptyPagedCache,
/// The paged block size (payload) is not one of 8, 16, 32, 64 or 128.
#[error("paged block_size {0} is not in (8, 16, 32, 64, 128)")]
InvalidPagedBlockSize(u32),
/// The paged cache's head dimension (`cache`) differs from the request's (`req`).
#[error("paged cache head_dim {cache} != request head_dim {req}")]
PagedHeadDimMismatch { cache: u32, req: u32 },
/// `index` is not a valid chunk index because it is >= `total`.
#[error("chunk_index {index} >= total_chunks {total}")]
ChunkIndexOutOfRange { index: u32, total: u32 },
/// The flash-attention actor is in mock mode (no GPU).
#[error("flashattn actor is in mock mode (no GPU)")]
MockMode,
}