use serde::{Deserialize, Serialize};
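
/// Configuration for a cross-attention layer: the base multi-head settings
/// plus optional sub-configurations for the variant selected in `attention_type`.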
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossAttentionConfig {
    pub hidden_size: usize,
    pub num_heads: usize,
    /// Per-head dimension; defaults to `hidden_size / num_heads` when `None`.
    pub head_dim: Option<usize>,
    pub attention_dropout: f32,
    pub bias: bool,
    /// Attention-score scale; defaults to `1.0 / sqrt(head_dim)` when `None`.
    pub scale: Option<f32>,
    pub max_seq_len: usize,
    pub attention_type: CrossAttentionType,
    pub sparse_config: Option<SparseAttentionConfig>,
    pub hierarchical_config: Option<HierarchicalAttentionConfig>,
    pub adaptive_config: Option<AdaptiveAttentionConfig>,
    pub gated_config: Option<GatedAttentionConfig>,
}
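
/// Cross-attention variants supported by [`CrossAttentionConfig`].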
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CrossAttentionType {
    Standard,
    MultiHead,
    Sparse,
    Hierarchical,
    Adaptive,
    Gated,
}
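
/// Settings for the `Sparse` variant: sparsity pattern, target sparsity ratio,
/// and pattern-specific parameters (block size, number of random connections).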
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseAttentionConfig {
    pub pattern: SparsePattern,
    pub sparsity_ratio: f32,
    pub block_size: Option<usize>,
    pub random_connections: Option<usize>,
}
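
/// Sparsity patterns available for sparse cross-attention.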
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SparsePattern {
    Random,
    Block,
    Strided,
    Local,
    TopK,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalAttentionConfig {
    pub num_levels: usize,
    pub pooling_factor: usize,
    pub learnable_pooling: bool,
    pub aggregation_method: AggregationMethod,
}
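
/// How the per-level outputs of hierarchical attention are combined.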
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
    WeightedSum,
    Concatenation,
    MaxPooling,
    AvgPooling,
}
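
/// Settings for the `Adaptive` variant: the number and dimension of candidate
/// attention patterns, the selection temperature, and whether a single pattern
/// is selected (hard) or patterns are mixed (soft).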
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveAttentionConfig {
    pub num_patterns: usize,
    pub pattern_dim: usize,
    pub temperature: f32,
    pub hard_selection: bool,
}
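
/// Settings for the `Gated` variant: gate activation, whether separate gates
/// are learned, the gate's optional hidden dimension, and whether it has a bias.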
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatedAttentionConfig {
    pub gate_activation: GateActivation,
    pub separate_gates: bool,
    pub gate_hidden_dim: Option<usize>,
    pub gate_bias: bool,
}
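
/// Activation functions available for attention gates.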
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GateActivation {
    Sigmoid,
    Tanh,
    ReLU,
    GELU,
    Swish,
}
impl Default for CrossAttentionConfig {
    fn default() -> Self {
        Self {
            hidden_size: 512,
            num_heads: 8,
            head_dim: None,
            attention_dropout: 0.1,
            bias: true,
            scale: None,
            max_seq_len: 1024,
            attention_type: CrossAttentionType::Standard,
            sparse_config: None,
            hierarchical_config: None,
            adaptive_config: None,
            gated_config: None,
        }
    }
}
impl Default for SparseAttentionConfig {
    fn default() -> Self {
        Self {
            pattern: SparsePattern::Random,
            sparsity_ratio: 0.1,
            block_size: Some(64),
            random_connections: Some(32),
        }
    }
}
impl Default for HierarchicalAttentionConfig {
    fn default() -> Self {
        Self {
            num_levels: 3,
            pooling_factor: 2,
            learnable_pooling: true,
            aggregation_method: AggregationMethod::WeightedSum,
        }
    }
}
impl Default for AdaptiveAttentionConfig {
    fn default() -> Self {
        Self {
            num_patterns: 4,
            pattern_dim: 64,
            temperature: 1.0,
            hard_selection: false,
        }
    }
}
impl Default for GatedAttentionConfig {
    fn default() -> Self {
        Self {
            gate_activation: GateActivation::Sigmoid,
            separate_gates: false,
            gate_hidden_dim: None,
            gate_bias: true,
        }
    }
}
impl CrossAttentionConfig {
    /// Effective per-head dimension: the explicit `head_dim`, or `hidden_size / num_heads`.
    pub fn get_head_dim(&self) -> usize {
        self.head_dim.unwrap_or(self.hidden_size / self.num_heads)
    }

    /// Attention-score scaling factor: the explicit `scale`, or `1 / sqrt(head_dim)`.
    pub fn get_scale(&self) -> f32 {
        self.scale.unwrap_or(1.0 / (self.get_head_dim() as f32).sqrt())
    }

    /// Checks the configuration for internal consistency.
    pub fn validate(&self) -> Result<(), String> {
        if self.num_heads == 0 {
            return Err("num_heads must be greater than 0".to_string());
        }
        if !self.hidden_size.is_multiple_of(self.num_heads) {
            return Err("hidden_size must be divisible by num_heads".to_string());
        }
        if !(0.0..=1.0).contains(&self.attention_dropout) {
            return Err("attention_dropout must be between 0.0 and 1.0".to_string());
        }
        if let Some(sparse_config) = &self.sparse_config {
            if !(0.0..=1.0).contains(&sparse_config.sparsity_ratio) {
                return Err("sparsity_ratio must be between 0.0 and 1.0".to_string());
            }
        }
        if let Some(hierarchical_config) = &self.hierarchical_config {
            if hierarchical_config.num_levels == 0 {
                return Err("num_levels must be greater than 0".to_string());
            }
            if hierarchical_config.pooling_factor == 0 {
                return Err("pooling_factor must be greater than 0".to_string());
            }
        }
        if let Some(adaptive_config) = &self.adaptive_config {
            if adaptive_config.num_patterns == 0 {
                return Err("num_patterns must be greater than 0".to_string());
            }
            if adaptive_config.temperature <= 0.0 {
                return Err("temperature must be greater than 0".to_string());
            }
        }
        Ok(())
    }

    /// Convenience constructor for standard cross-attention.
    pub fn standard(hidden_size: usize, num_heads: usize) -> Self {
        Self {
            hidden_size,
            num_heads,
            attention_type: CrossAttentionType::Standard,
            ..Default::default()
        }
    }

    /// Convenience constructor for sparse cross-attention with the given sparsity ratio.
    pub fn sparse(hidden_size: usize, num_heads: usize, sparsity_ratio: f32) -> Self {
        Self {
            hidden_size,
            num_heads,
            attention_type: CrossAttentionType::Sparse,
            sparse_config: Some(SparseAttentionConfig {
                sparsity_ratio,
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// Convenience constructor for hierarchical cross-attention with the given number of levels.
    pub fn hierarchical(hidden_size: usize, num_heads: usize, num_levels: usize) -> Self {
        Self {
            hidden_size,
            num_heads,
            attention_type: CrossAttentionType::Hierarchical,
            hierarchical_config: Some(HierarchicalAttentionConfig {
                num_levels,
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// Convenience constructor for adaptive cross-attention with the given number of patterns.
    pub fn adaptive(hidden_size: usize, num_heads: usize, num_patterns: usize) -> Self {
        Self {
            hidden_size,
            num_heads,
            attention_type: CrossAttentionType::Adaptive,
            adaptive_config: Some(AdaptiveAttentionConfig {
                num_patterns,
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// Convenience constructor for gated cross-attention with default gate settings.
    pub fn gated(hidden_size: usize, num_heads: usize) -> Self {
        Self {
            hidden_size,
            num_heads,
            attention_type: CrossAttentionType::Gated,
            gated_config: Some(GatedAttentionConfig::default()),
            ..Default::default()
        }
    }
}
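
// A minimal usage sketch (illustrative only): it exercises the convenience
// constructors, the derived defaults for `head_dim`/`scale`, and `validate`,
// using nothing beyond the items defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_is_valid() {
        let config = CrossAttentionConfig::default();
        assert!(config.validate().is_ok());
        // head_dim falls back to hidden_size / num_heads = 512 / 8.
        assert_eq!(config.get_head_dim(), 64);
        // scale falls back to 1 / sqrt(head_dim).
        assert!((config.get_scale() - 1.0 / 64.0f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn sparse_constructor_rejects_invalid_sparsity() {
        // sparsity_ratio must lie in [0.0, 1.0].
        let config = CrossAttentionConfig::sparse(512, 8, 1.5);
        assert!(config.validate().is_err());
    }

    #[test]
    fn hidden_size_must_divide_by_num_heads() {
        let config = CrossAttentionConfig::standard(100, 7);
        assert!(config.validate().is_err());
    }
}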