pub mod attention;
pub mod config;
pub mod error;
pub mod graph;
pub mod hyperbolic;
pub mod moe;
pub mod sdk;
pub mod sparse;
pub mod training;
pub mod traits;
pub mod utils;
pub mod curvature;
pub mod topology;
pub mod transport;
pub mod info_bottleneck;
pub mod info_geometry;
pub mod pde_attention;
pub mod unified_report;
#[cfg(feature = "sheaf")]
pub mod sheaf;
pub use attention::{MLACache, MLAConfig, MLALayer, MemoryComparison};
pub use attention::{MultiHeadAttention, ScaledDotProductAttention};
pub use config::{AttentionConfig, GraphAttentionConfig, SparseAttentionConfig};
pub use error::{AttentionError, AttentionResult};
pub use hyperbolic::{
exp_map, log_map, mobius_add, poincare_distance, project_to_ball, HyperbolicAttention,
HyperbolicAttentionConfig, MixedCurvatureAttention, MixedCurvatureConfig,
};
pub use traits::{
Attention, EdgeInfo, GeometricAttention, Gradients, GraphAttention, SparseAttention,
SparseMask, TrainableAttention,
};
pub use sparse::{
AttentionMask, FlashAttention, LinearAttention, LocalGlobalAttention, SparseMaskBuilder,
};
pub use moe::{
Expert, ExpertType, HyperbolicExpert, LearnedRouter, LinearExpert, MoEAttention, MoEConfig,
Router, StandardExpert, TopKRouting,
};
pub use graph::{
DualSpaceAttention, DualSpaceConfig, EdgeFeaturedAttention, EdgeFeaturedConfig, GraphRoPE,
RoPEConfig,
};
pub use training::{
Adam, AdamW, CurriculumScheduler, CurriculumStage, DecayType, HardNegativeMiner, InfoNCELoss,
LocalContrastiveLoss, Loss, MiningStrategy, NegativeMiner, Optimizer, Reduction,
SpectralRegularization, TemperatureAnnealing, SGD,
};
pub use sdk::{presets, AttentionBuilder, AttentionPipeline};
pub use transport::{
CentroidCache, CentroidOTAttention, CentroidOTConfig, ProjectionCache,
SlicedWassersteinAttention, SlicedWassersteinConfig, WindowCache,
};
pub use curvature::{
ComponentQuantizer, FusedCurvatureConfig, MixedCurvatureCache, MixedCurvatureFusedAttention,
QuantizationConfig, QuantizedVector, TangentSpaceConfig, TangentSpaceMapper,
};
pub use topology::{
AttentionMode, AttentionPolicy, CoherenceMetric, PolicyConfig, TopologyGatedAttention,
TopologyGatedConfig, WindowCoherence,
};
pub use info_geometry::{FisherConfig, FisherMetric, NaturalGradient, NaturalGradientConfig};
pub use info_bottleneck::{DiagonalGaussian, IBConfig, InformationBottleneck, KLDivergence};
pub use pde_attention::{DiffusionAttention, DiffusionConfig, GraphLaplacian, LaplacianType};
#[cfg(feature = "sheaf")]
pub use sheaf::{
process_with_early_exit, ComputeLane, EarlyExit, EarlyExitConfig, EarlyExitResult,
EarlyExitStatistics, ExitReason, LaneStatistics, ResidualSparseMask, RestrictionMap,
RestrictionMapConfig, RoutingDecision, SheafAttention, SheafAttentionConfig,
SparseResidualAttention, SparseResidualConfig, SparsityStatistics, TokenRouter,
TokenRouterConfig,
};
pub use unified_report::{
AttentionRecommendation, GeometryReport, MetricType, MetricValue, ReportBuilder, ReportConfig,
};
/// Version string of this crate.
///
/// Read at compile time via `env!` from `CARGO_PKG_VERSION`, which Cargo sets from the
/// package manifest's `version` field.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
#[cfg(test)]
mod tests {
    use super::*;

    /// The version string is embedded at compile time and must never be blank.
    #[test]
    fn test_version() {
        assert!(
            !VERSION.is_empty(),
            "VERSION should be populated from CARGO_PKG_VERSION"
        );
    }

    /// Building a config with `dim = 64` and `num_heads = 4` through the builder
    /// keeps both values and yields a per-head dimension of 64 / 4 = 16.
    #[test]
    fn test_basic_attention_workflow() {
        let config = AttentionConfig::builder()
            .dim(64)
            .num_heads(4)
            .build()
            // dim 64 is evenly divisible by 4 heads, so the builder should accept it.
            .expect("dim=64, num_heads=4 should be a valid attention config");

        assert_eq!(config.dim, 64);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.head_dim(), 16);
    }
}