use pgrx::prelude::*;
use serde::{Deserialize, Serialize};
pub mod flash;
pub mod multi_head;
pub mod operators;
pub mod scaled_dot;
pub use flash::FlashAttention;
pub use multi_head::MultiHeadAttention;
pub use scaled_dot::ScaledDotAttention;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, PostgresEnum)]
pub enum AttentionType {
ScaledDot,
MultiHead,
FlashV2,
Linear,
Gat,
Sparse,
Moe,
Cross,
Sliding,
Poincare,
}
impl Default for AttentionType {
fn default() -> Self {
AttentionType::ScaledDot
}
}
impl AttentionType {
pub fn name(&self) -> &'static str {
match self {
AttentionType::ScaledDot => "scaled_dot",
AttentionType::MultiHead => "multi_head",
AttentionType::FlashV2 => "flash_v2",
AttentionType::Linear => "linear",
AttentionType::Gat => "gat",
AttentionType::Sparse => "sparse",
AttentionType::Moe => "moe",
AttentionType::Cross => "cross",
AttentionType::Sliding => "sliding",
AttentionType::Poincare => "poincare",
}
}
pub fn complexity(&self) -> &'static str {
match self {
AttentionType::ScaledDot => "O(n²)",
AttentionType::MultiHead => "O(n²)",
AttentionType::FlashV2 => "O(n²) memory-efficient",
AttentionType::Linear => "O(n)",
AttentionType::Gat => "O(E) where E=edges",
AttentionType::Sparse => "O(n√n)",
AttentionType::Moe => "O(n*k) where k=experts",
AttentionType::Cross => "O(n*m)",
AttentionType::Sliding => "O(n*w) where w=window",
AttentionType::Poincare => "O(n²)",
}
}
pub fn best_for(&self) -> &'static str {
match self {
AttentionType::ScaledDot => "Small sequences (<512)",
AttentionType::MultiHead => "General purpose, parallel processing",
AttentionType::FlashV2 => "GPU acceleration, large sequences",
AttentionType::Linear => "Very long sequences (>4K)",
AttentionType::Gat => "Graph-structured data",
AttentionType::Sparse => "Ultra-long sequences (>16K)",
AttentionType::Moe => "Conditional computation, routing",
AttentionType::Cross => "Query-document matching",
AttentionType::Sliding => "Local context, streaming",
AttentionType::Poincare => "Hierarchical data structures",
}
}
}
impl std::str::FromStr for AttentionType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"scaled_dot" | "scaleddot" => Ok(AttentionType::ScaledDot),
"multi_head" | "multihead" => Ok(AttentionType::MultiHead),
"flash_v2" | "flashv2" | "flash" => Ok(AttentionType::FlashV2),
"linear" => Ok(AttentionType::Linear),
"gat" => Ok(AttentionType::Gat),
"sparse" => Ok(AttentionType::Sparse),
"moe" => Ok(AttentionType::Moe),
"cross" => Ok(AttentionType::Cross),
"sliding" => Ok(AttentionType::Sliding),
"poincare" | "poincaré" => Ok(AttentionType::Poincare),
_ => Err(format!("Unknown attention type: {}", s)),
}
}
}
pub trait Attention: Send + Sync {
fn attention_scores(&self, query: &[f32], keys: &[&[f32]]) -> Vec<f32>;
fn apply_attention(&self, scores: &[f32], values: &[&[f32]]) -> Vec<f32> {
assert_eq!(
scores.len(),
values.len(),
"Scores and values length mismatch"
);
if values.is_empty() {
return Vec::new();
}
let dim = values[0].len();
let mut result = vec![0.0; dim];
for (score, value) in scores.iter().zip(values.iter()) {
for (r, v) in result.iter_mut().zip(value.iter()) {
*r += score * v;
}
}
result
}
fn forward(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec<f32> {
let scores = self.attention_scores(query, keys);
self.apply_attention(&scores, values)
}
}
#[inline]
pub fn softmax(logits: &[f32]) -> Vec<f32> {
if logits.is_empty() {
return Vec::new();
}
let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exp_values: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
let sum: f32 = exp_values.iter().sum();
if sum > 0.0 {
exp_values.iter().map(|x| x / sum).collect()
} else {
vec![1.0 / logits.len() as f32; logits.len()]
}
}
#[inline]
pub fn softmax_inplace(logits: &mut [f32]) {
if logits.is_empty() {
return;
}
let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
for x in logits.iter_mut() {
*x = (*x - max_logit).exp();
}
let sum: f32 = logits.iter().sum();
if sum > 0.0 {
for x in logits.iter_mut() {
*x /= sum;
}
} else {
let uniform = 1.0 / logits.len() as f32;
for x in logits.iter_mut() {
*x = uniform;
}
}
}
#[cfg(feature = "pg_test")]
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_softmax() {
let logits = vec![1.0, 2.0, 3.0];
let result = softmax(&logits);
let sum: f32 = result.iter().sum();
assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
assert!(result[2] > result[1]);
assert!(result[1] > result[0]);
}
#[test]
fn test_softmax_inplace() {
let mut logits = vec![1.0, 2.0, 3.0];
softmax_inplace(&mut logits);
let sum: f32 = logits.iter().sum();
assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
assert!(logits[2] > logits[1]);
assert!(logits[1] > logits[0]);
}
#[test]
fn test_softmax_numerical_stability() {
let logits = vec![1000.0, 1001.0, 1002.0];
let result = softmax(&logits);
let sum: f32 = result.iter().sum();
assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
assert!(result.iter().all(|x| x.is_finite()));
}
#[test]
fn test_attention_type_parsing() {
assert_eq!(
"scaled_dot".parse::<AttentionType>().unwrap(),
AttentionType::ScaledDot
);
assert_eq!(
"flash_v2".parse::<AttentionType>().unwrap(),
AttentionType::FlashV2
);
assert_eq!(
"multi_head".parse::<AttentionType>().unwrap(),
AttentionType::MultiHead
);
assert!("unknown".parse::<AttentionType>().is_err());
}
}