pub mod block_sparse;
pub mod flash_attn;
pub mod fused_rope_attn;
pub mod gqa;
pub mod kv_cache;
pub mod mha;
pub mod ring_attention;
pub mod rope;
pub mod sliding_window;
pub mod speculative_decode;
pub use block_sparse::{BlockSparseAttentionPlan, BlockSparseConfig, BlockSparsePattern};
pub use flash_attn::backward::flash_attention_backward;
pub use flash_attn::decode::single_query_decode_attention;
pub use flash_attn::forward::{FlashAttentionConfig, flash_attention_forward};
pub use flash_attn::paged::{PagedAttentionConfig, paged_attention_decode};
pub use fused_rope_attn::{FusedRopeAttnConfig, FusedRopeAttnPlan};
pub use gqa::{GqaConfig, gqa_forward};
pub use kv_cache::{KvCache, KvCacheConfig};
pub use mha::multi_head_attention;
pub use ring_attention::{
RingAttentionConfig, RingAttentionDtype, RingAttentionPlan, RingAttentionStats, RingCommPlan,
RingStep,
};
pub use rope::apply_rope;
pub use sliding_window::{SlidingWindowConfig, sliding_window_attention};
/// Tile shape and capability flags for a tcgen05 attention kernel, derived
/// from a target SM version.
///
/// Instances are normally produced by [`Tcgen05AttentionConfig::new`]. All
/// fields are public, so a hand-built value is not guaranteed to agree with
/// what `new` would have produced for the same `sm_version`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Tcgen05AttentionConfig {
    /// SM version this configuration was built for (e.g. `100` for sm_100).
    pub sm_version: u32,
    /// Tile extent along the M dimension; `new` always sets this to 128.
    pub m_tile: usize,
    /// Tile extent along the N dimension; `new` always sets this to 256.
    pub n_tile: usize,
    /// Whether FP4 input is supported; `new` sets this to `sm_version >= 100`.
    pub supports_fp4: bool,
}

impl Tcgen05AttentionConfig {
    /// Lowest SM version that [`Self::is_valid`] accepts.
    const MIN_VALID_SM: u32 = 100;
    /// Lowest SM version for which `new` enables `supports_fp4`.
    const MIN_FP4_SM: u32 = 100;
    /// M tile extent assigned by `new`.
    const DEFAULT_M_TILE: usize = 128;
    /// N tile extent assigned by `new`.
    const DEFAULT_N_TILE: usize = 256;

    /// Builds the configuration for `sm_version`.
    ///
    /// The tile shape is fixed at 128 x 256 regardless of the version; only
    /// `supports_fp4` depends on it. No validation happens here, so versions
    /// below 100 are stored as-is; check them with [`Self::is_valid`].
    #[must_use]
    pub fn new(sm_version: u32) -> Self {
        let fp4_capable = Self::MIN_FP4_SM <= sm_version;
        Self {
            sm_version,
            m_tile: Self::DEFAULT_M_TILE,
            n_tile: Self::DEFAULT_N_TILE,
            supports_fp4: fp4_capable,
        }
    }

    /// Returns `true` when the stored `sm_version` is at least 100.
    ///
    /// Only `sm_version` is inspected; the tile fields and `supports_fp4`
    /// are not checked.
    ///
    /// NOTE(review): every version at or above 100 passes, including 120.
    /// Confirm against the PTX ISA target list for tcgen05 that all such
    /// targets actually provide the instruction family.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        Self::MIN_VALID_SM <= self.sm_version
    }
}
pub use speculative_decode::{
KvCheckpoint, KvManager, SpecDecConfig, SpecDecOutput, SpeculativeDecodeConfig,
SpeculativeDecodePlan, SpeculativeDecoder, SpeculativeKvManager, TokenVerificationResult,
VerificationResult, accept_token,
};
#[cfg(test)]
mod tests {
    //! Unit tests for [`Tcgen05AttentionConfig`]: validity by SM version,
    //! the fixed tile shape, and the FP4 capability flag.

    use super::*;

    // --- validity by SM version ---------------------------------------------

    #[test]
    fn tcgen05_valid_for_sm100() {
        assert!(
            Tcgen05AttentionConfig::new(100).is_valid(),
            "sm_100 must be valid for tcgen05"
        );
    }

    #[test]
    fn tcgen05_valid_for_sm120() {
        assert!(
            Tcgen05AttentionConfig::new(120).is_valid(),
            "sm_120 must be valid for tcgen05"
        );
    }

    #[test]
    fn tcgen05_invalid_for_sm90() {
        let accepted = Tcgen05AttentionConfig::new(90).is_valid();
        assert!(!accepted, "sm_90 should not be valid for tcgen05");
    }

    #[test]
    fn tcgen05_invalid_for_sm80() {
        let accepted = Tcgen05AttentionConfig::new(80).is_valid();
        assert!(!accepted, "sm_80 should not be valid for tcgen05");
    }

    // --- tile shape ---------------------------------------------------------

    #[test]
    fn tcgen05_tile_shape_m128_n256() {
        let Tcgen05AttentionConfig { m_tile, n_tile, .. } = Tcgen05AttentionConfig::new(100);
        assert_eq!(m_tile, 128, "tcgen05 M tile must be 128");
        assert_eq!(n_tile, 256, "tcgen05 N tile must be 256");
    }

    // --- FP4 capability -----------------------------------------------------

    #[test]
    fn tcgen05_supports_fp4_on_blackwell() {
        let blackwell = Tcgen05AttentionConfig::new(100);
        assert!(
            blackwell.supports_fp4,
            "Blackwell (sm_100) should support FP4 input"
        );
    }

    #[test]
    fn tcgen05_does_not_support_fp4_on_hopper() {
        let hopper = Tcgen05AttentionConfig::new(90);
        assert!(
            !hopper.supports_fp4,
            "Hopper (sm_90) should not support FP4 input"
        );
    }

    // --- stored fields ------------------------------------------------------

    #[test]
    fn tcgen05_sm_version_stored_correctly() {
        let Tcgen05AttentionConfig { sm_version, .. } = Tcgen05AttentionConfig::new(100);
        assert_eq!(sm_version, 100);
    }

    #[test]
    fn tcgen05_tile_shape_consistent_across_blackwell() {
        let first = Tcgen05AttentionConfig::new(100);
        let second = Tcgen05AttentionConfig::new(120);
        assert_eq!(
            first.m_tile, second.m_tile,
            "M tile must be the same for all Blackwell"
        );
        assert_eq!(
            first.n_tile, second.n_tile,
            "N tile must be the same for all Blackwell"
        );
    }
}