pub mod attention;
pub mod checkpointing;
pub mod config;
pub mod decoder;
pub mod encoder;
pub mod error;
pub mod ffn;
pub mod flash_attention;
pub mod gqa;
pub mod kv_cache;
pub mod layers;
pub mod lora;
pub mod moe;
pub mod normalization;
pub mod normalization_variants;
pub mod patterns;
pub mod position;
pub mod presets;
pub mod quantization;
pub mod rule_attention;
pub mod rule_guided_decoder;
pub mod sliding_window;
pub mod sparse_attention;
pub mod speculative_decoding;
pub mod stacks;
pub mod trustformers_integration;
pub mod utils;
pub mod vision;
pub use attention::{MultiHeadAttention, SelfAttention};
pub use checkpointing::{CheckpointConfig, CheckpointStrategy};
pub use config::{AttentionConfig, FeedForwardConfig, TransformerLayerConfig};
pub use decoder::{Decoder, DecoderConfig};
pub use encoder::{Encoder, EncoderConfig};
pub use error::{Result, TrustformerError};
pub use ffn::{FeedForward, GatedFeedForward};
pub use flash_attention::{
FlashAttention, FlashAttentionConfig, FlashAttentionPreset, FlashAttentionStats,
FlashAttentionV2Config,
};
pub use gqa::{GQAConfig, GQAPreset, GQAStats, GroupedQueryAttention};
pub use kv_cache::{
CacheStats, CachedAttention, CachedAttentionError, InferenceStats, KVCache, KVCacheConfig,
KvCache, KvCacheError, PositionError, RelativePositionBias, RotaryPositionEmbedding,
};
pub use layers::{DecoderLayer, DecoderLayerConfig, EncoderLayer, EncoderLayerConfig};
pub use lora::{LoRAAttention, LoRAConfig, LoRALinear, LoRAPreset, LoRAStats};
pub use moe::{
combined_aux_loss, importance_loss, load_loss, BatchGatingStats, Expert, GatingDecision,
LinearExpert, MoELayer, MoeConfig, MoeError, MoeLayer, MoePreset, MoeStats, RouterType,
TopKGate,
};
pub use normalization::{LayerNorm, LayerNormConfig, RMSNorm};
pub use normalization_variants::{
BatchNorm, GroupNorm, InstanceNorm, NormStats, NormalizationError, RmsNorm, WeightNorm,
};
pub use patterns::{
AttentionMask, BlockSparseMask, CausalMask, GlobalLocalMask, LocalMask, RuleBasedMask,
RulePattern, StridedMask,
};
pub use position::{
AlibiPositionEncoding, LearnedPositionEncoding, PositionEncodingConfig, PositionEncodingType,
RelativePositionEncoding, RotaryPositionEncoding, SinusoidalPositionEncoding,
};
pub use presets::ModelPreset;
pub use quantization::{calibrate_linear, QuantizationError, QuantizedLinear};
pub use rule_attention::{
RuleAttentionConfig, RuleAttentionType, RuleBasedAttention, StructuredAttention,
};
pub use rule_guided_decoder::{
ConstraintVerdict, HardMask, LogitMasker, RuleConstraint, RuleGuidedBeamSearch,
RuleGuidedError, RuleGuidedResult, SoftPenaltyMask, TokenId, TokenSymbolMapper,
};
pub use sliding_window::{
SlidingWindowAttention, SlidingWindowConfig, SlidingWindowPreset, SlidingWindowStats,
};
pub use sparse_attention::{
build_mask, LocalAttention, SparseAttention, SparseAttentionConfig, SparseAttentionError,
SparseAttentionGraph, SparseAttentionGraphConfig, SparsePatternType,
};
pub use speculative_decoding::{
DraftModel, DraftProposal, FixedDistDraftModel, FixedDistTargetModel, LogProb, MockDraftModel,
MockTargetModel, SpecRng, SpeculativeDecoder, SpeculativeDecoderConfig,
SpeculativeDecodingError, SpeculativeDecodingResult, SpeculativeMetrics, TargetModel,
TargetScores,
};
pub use stacks::{DecoderStack, DecoderStackConfig, EncoderStack, EncoderStackConfig};
pub use trustformers_integration::{
CheckpointData, IntegrationConfig, ModelConfig, TensorLogicModel, TrustformersConverter,
TrustformersWeightLoader,
};
pub use utils::{decoder_stack_stats, encoder_stack_stats, ModelStats};
pub use vision::{
PatchEmbedding, PatchEmbeddingConfig, ViTPreset, VisionTransformer, VisionTransformerConfig,
};
/// Deprecated alias for [`AttentionConfig`].
///
/// Both names refer to the same type; as the deprecation note says, use
/// `AttentionConfig` directly instead.
#[deprecated(since = "0.1.0", note = "Use AttentionConfig instead")]
pub type AttnSpec = AttentionConfig;
/// Deprecated entry point that does nothing.
///
/// The `_spec` argument is accepted and ignored, and the function returns `()`.
/// As the deprecation note says, use `SelfAttention::build_attention_graph` instead.
#[deprecated(since = "0.1.0", note = "Use SelfAttention::build_attention_graph instead")]
pub fn self_attention_as_rules(_spec: &AttentionConfig) {
    // Intentionally empty: this function performs no work.
}
#[cfg(test)]
mod tests {
    // Smoke tests exercising the crate's re-exported API: each graph-building test
    // registers the input tensors by name, builds the component's graph into an
    // `EinsumGraph`, and checks that exactly one output is produced and the
    // resulting graph validates.
    use super::*;
    use tensorlogic_ir::EinsumGraph;

    /// Self-attention over `Q`, `K`, `V` yields a single output and a valid graph.
    #[test]
    fn test_end_to_end_self_attention() {
        let config = AttentionConfig::new(512, 8).expect("AttentionConfig::new(512, 8) should succeed");
        let attn = SelfAttention::new(config).expect("SelfAttention::new should accept a valid config");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = attn
            .build_attention_graph(&mut graph)
            .expect("build_attention_graph should succeed once Q, K, V are registered");
        assert_eq!(outputs.len(), 1, "self-attention should produce exactly one output");
        assert!(graph.validate().is_ok(), "self-attention graph should validate");
    }

    /// Multi-head attention over `Q`, `K`, `V` yields a single output and a valid graph.
    #[test]
    fn test_end_to_end_multi_head_attention() {
        let config = AttentionConfig::new(512, 8).expect("AttentionConfig::new(512, 8) should succeed");
        let mha = MultiHeadAttention::new(config)
            .expect("MultiHeadAttention::new should accept a valid config");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = mha
            .build_mha_graph(&mut graph)
            .expect("build_mha_graph should succeed once Q, K, V are registered");
        assert_eq!(outputs.len(), 1, "multi-head attention should produce exactly one output");
        assert!(graph.validate().is_ok(), "multi-head attention graph should validate");
    }

    /// The plain feed-forward block with weights and biases yields a single output.
    #[test]
    fn test_end_to_end_ffn() {
        let config = FeedForwardConfig::new(512, 2048);
        let ffn = FeedForward::new(config).expect("FeedForward::new should accept a valid config");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("W1");
        graph.add_tensor("b1");
        graph.add_tensor("W2");
        graph.add_tensor("b2");

        let outputs = ffn
            .build_ffn_graph(&mut graph)
            .expect("build_ffn_graph should succeed once x, W1, b1, W2, b2 are registered");
        assert_eq!(outputs.len(), 1, "feed-forward should produce exactly one output");
        assert!(graph.validate().is_ok(), "feed-forward graph should validate");
    }

    /// The gated feed-forward block yields a single output and a valid graph.
    #[test]
    fn test_end_to_end_gated_ffn() {
        let config = FeedForwardConfig::new(512, 2048);
        let glu = GatedFeedForward::new(config)
            .expect("GatedFeedForward::new should accept a valid config");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("W_gate");
        graph.add_tensor("W_value");
        graph.add_tensor("W_out");

        let outputs = glu
            .build_glu_graph(&mut graph)
            .expect("build_glu_graph should succeed once x, W_gate, W_value, W_out are registered");
        assert_eq!(outputs.len(), 1, "gated feed-forward should produce exactly one output");
        assert!(graph.validate().is_ok(), "gated feed-forward graph should validate");
    }

    /// `TransformerLayerConfig::new` forwards its arguments to the nested configs.
    #[test]
    fn test_transformer_layer_config() {
        let config = TransformerLayerConfig::new(512, 8, 2048)
            .expect("TransformerLayerConfig::new(512, 8, 2048) should succeed");

        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.attention.n_heads, 8);
        assert_eq!(config.feed_forward.d_ff, 2048);
        assert!(config.validate().is_ok(), "transformer layer config should validate");
    }

    /// The `with_*` builder methods on `AttentionConfig` set the matching fields.
    #[test]
    fn test_config_builder_pattern() {
        let config = AttentionConfig::new(512, 8)
            .expect("AttentionConfig::new(512, 8) should succeed")
            .with_causal(true)
            .with_dropout(0.1);

        assert!(config.causal, "with_causal(true) should set the causal flag");
        // Compare with a tolerance rather than `==` because dropout is a float.
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok(), "built attention config should validate");
    }

    /// The `with_*` builder methods on `FeedForwardConfig` set the matching fields.
    #[test]
    fn test_ffn_config_builder() {
        let config = FeedForwardConfig::new(512, 2048)
            .with_activation("relu")
            .with_dropout(0.1);

        assert_eq!(config.activation, "relu");
        // Compare with a tolerance rather than `==` because dropout is a float.
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok(), "built feed-forward config should validate");
    }

    /// `AttentionConfig::new(512, 7)` is rejected while `(512, 8)` is accepted.
    ///
    /// NOTE(review): presumably the rejection is because 512 is not divisible by 7;
    /// confirm against `AttentionConfig::new`.
    #[test]
    fn test_invalid_configurations() {
        let rejected = AttentionConfig::new(512, 7);
        assert!(rejected.is_err(), "AttentionConfig::new(512, 7) should be rejected");

        let accepted = AttentionConfig::new(512, 8);
        assert!(accepted.is_ok(), "AttentionConfig::new(512, 8) should be accepted");
    }
}