use crate::error::AttentionResult;
/// Sparse mask description consumed by [`SparseAttention::compute_sparse`] and
/// produced by [`SparseAttention::generate_mask`].
///
/// Carries two lists of `usize` indices and an optional list of `f32` values.
// NOTE(review): presumably a coordinate-style layout in which `rows[i]` and
// `cols[i]` together name one entry and `values[i]` (when present) is its
// weight — confirm against implementors. Nothing in this type enforces that
// the three vectors have equal lengths.
#[derive(Clone, Debug)]
pub struct SparseMask {
/// Row indices of the mask entries.
pub rows: Vec<usize>,
/// Column indices of the mask entries.
pub cols: Vec<usize>,
/// Optional per-entry values; `None` when the mask carries no values.
pub values: Option<Vec<f32>>,
}
/// Description of a single edge, as passed in bulk to
/// [`GraphAttention::compute_with_edges`].
// NOTE(review): `src` and `dst` presumably index into the `node_features`
// slice given to `compute_with_edges` — confirm; no bounds are checked here.
#[derive(Clone, Debug)]
pub struct EdgeInfo {
/// Identifier of the edge's source endpoint.
pub src: usize,
/// Identifier of the edge's destination endpoint.
pub dst: usize,
/// Optional `f32` feature vector attached to the edge; `None` when absent.
pub features: Option<Vec<f32>>,
}
/// Base interface shared by every attention mechanism in this module.
///
/// Implementors must be `Send + Sync`. The other attention traits in this file
/// ([`GraphAttention`], [`GeometricAttention`], [`SparseAttention`],
/// [`TrainableAttention`]) all require this trait as a supertrait.
pub trait Attention: Send + Sync {
/// Computes attention for a single `query` over the given `keys` and `values`.
///
/// Returns an `f32` vector wrapped in [`AttentionResult`].
// NOTE(review): presumably `keys[i]` pairs with `values[i]` and every vector
// has length `dim()` — confirm; the signature itself does not enforce this.
fn compute(
&self,
query: &[f32],
keys: &[&[f32]],
values: &[&[f32]],
) -> AttentionResult<Vec<f32>>;
/// Same inputs as [`Attention::compute`], plus an optional boolean `mask`.
///
/// Passing `None` supplies no mask.
// NOTE(review): presumably one flag per key; whether `true` means "attend" or
// "mask out" is not defined in this file — verify against implementors.
fn compute_with_mask(
&self,
query: &[f32],
keys: &[&[f32]],
values: &[&[f32]],
mask: Option<&[bool]>,
) -> AttentionResult<Vec<f32>>;
/// Returns the dimension this attention mechanism is configured for.
// NOTE(review): presumably the expected length of query/key vectors — confirm.
fn dim(&self) -> usize;
/// Returns the number of attention heads.
///
/// The default implementation returns `1`; override it in implementors that
/// use a different number of heads.
fn num_heads(&self) -> usize {
1
}
}
/// Attention over node features connected by [`EdgeInfo`] edges.
///
/// Extends [`Attention`], so implementors must also provide the base methods.
pub trait GraphAttention: Attention {
/// Computes attention from per-node feature vectors and a list of edges.
///
/// Returns a list of `f32` vectors wrapped in [`AttentionResult`].
// NOTE(review): presumably one output vector per entry of `node_features`,
// with `EdgeInfo::src`/`dst` indexing into that slice — confirm.
fn compute_with_edges(
&self,
node_features: &[Vec<f32>],
edges: &[EdgeInfo],
) -> AttentionResult<Vec<Vec<f32>>>;
/// Computes a single `f32` value for one edge from its source and destination
/// features and, when supplied, the edge's own feature vector.
// NOTE(review): whether the returned value is a raw score or a normalised
// weight is not specified here — verify against implementors.
fn compute_edge_attention(
&self,
src_feature: &[f32],
dst_feature: &[f32],
edge_feature: Option<&[f32]>,
) -> AttentionResult<f32>;
}
/// Attention parameterised by a scalar `curvature`.
///
/// Extends [`Attention`] with a curvature-aware compute method and a pair of
/// projection methods.
// NOTE(review): the sign convention and valid range of `curvature` (e.g.
// negative for hyperbolic space) are not defined in this file — confirm.
pub trait GeometricAttention: Attention {
/// Computes attention for `query` over `keys` and `values` using the given
/// `curvature`.
///
/// Takes the same slice arguments as [`Attention::compute`] and returns an
/// `f32` vector wrapped in [`AttentionResult`].
fn compute_geometric(
&self,
query: &[f32],
keys: &[&[f32]],
values: &[&[f32]],
curvature: f32,
) -> AttentionResult<Vec<f32>>;
/// Projects `vector` into the geometric space selected by `curvature`,
/// returning a newly allocated vector.
fn project_to_geometric(&self, vector: &[f32], curvature: f32) -> AttentionResult<Vec<f32>>;
/// Projects `vector` out of the geometric space selected by `curvature`,
/// returning a newly allocated vector.
// NOTE(review): presumably the inverse of `project_to_geometric` for the same
// curvature — the trait does not state or enforce this.
fn project_from_geometric(&self, vector: &[f32], curvature: f32) -> AttentionResult<Vec<f32>>;
}
/// Attention driven by a [`SparseMask`].
///
/// Extends [`Attention`] with a sparse compute method and a mask generator.
pub trait SparseAttention: Attention {
/// Computes attention for `query` over `keys` and `values` using the
/// supplied sparse `mask`.
///
/// Returns an `f32` vector wrapped in [`AttentionResult`].
// NOTE(review): how the mask's row/column indices map onto the query and the
// key list is implementor-defined as far as this file shows — confirm.
fn compute_sparse(
&self,
query: &[f32],
keys: &[&[f32]],
values: &[&[f32]],
mask: &SparseMask,
) -> AttentionResult<Vec<f32>>;
/// Builds a [`SparseMask`] for the given `seq_len`.
// NOTE(review): presumably `seq_len` is the sequence length and bounds the
// indices in the returned mask — verify.
fn generate_mask(&self, seq_len: usize) -> AttentionResult<SparseMask>;
}
/// Gradient bundle returned by [`TrainableAttention::backward`] and consumed
/// by [`TrainableAttention::update_parameters`].
// NOTE(review): presumably `keys_grad[i]` / `values_grad[i]` correspond to the
// i-th key / value passed to `backward` — confirm against implementors.
#[derive(Clone, Debug)]
pub struct Gradients {
/// Gradient for the query vector.
pub query_grad: Vec<f32>,
/// Gradients for the keys, one inner vector per entry.
pub keys_grad: Vec<Vec<f32>>,
/// Gradients for the values, one inner vector per entry.
pub values_grad: Vec<Vec<f32>>,
/// Optional gradient for the attention weights; `None` when not provided.
pub attention_weights_grad: Option<Vec<f32>>,
}
/// Attention mechanism with forward/backward passes and a parameter update.
///
/// Extends [`Attention`]. `forward` and `backward` take `&self`;
/// `update_parameters` is the only method that requires `&mut self`.
pub trait TrainableAttention: Attention {
/// Runs the forward pass for `query` over `keys` and `values`.
///
/// Returns a pair of `f32` vectors wrapped in [`AttentionResult`].
// NOTE(review): presumably `(output, attention_weights)`, with the second
// element being what `backward` expects as `attention_weights` — confirm.
fn forward(
&self,
query: &[f32],
keys: &[&[f32]],
values: &[&[f32]],
) -> AttentionResult<(Vec<f32>, Vec<f32>)>;
/// Runs the backward pass and returns the resulting [`Gradients`].
///
/// Takes the upstream gradient `grad_output`, the `query`, `keys` and
/// `values` slices, and the `attention_weights`.
// NOTE(review): presumably the inputs must be the same ones given to the
// matching `forward` call — the signature does not enforce this.
fn backward(
&self,
grad_output: &[f32],
query: &[f32],
keys: &[&[f32]],
values: &[&[f32]],
attention_weights: &[f32],
) -> AttentionResult<Gradients>;
/// Updates the implementor's parameters from `gradients` using
/// `learning_rate`.
// NOTE(review): the update rule (e.g. plain gradient descent) is
// implementor-defined — verify.
fn update_parameters(
&mut self,
gradients: &Gradients,
learning_rate: f32,
) -> AttentionResult<()>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A mask built from three index pairs and no values reports those sizes.
    #[test]
    fn test_sparse_mask_creation() {
        let SparseMask { rows, cols, values } = SparseMask {
            rows: vec![0, 1, 2],
            cols: vec![0, 1, 2],
            values: None,
        };

        assert_eq!(rows.len(), 3);
        assert_eq!(cols.len(), 3);
        assert!(values.is_none());
    }

    /// An edge keeps its endpoints and the length of its feature vector.
    #[test]
    fn test_edge_info_creation() {
        let EdgeInfo { src, dst, features } = EdgeInfo {
            src: 0,
            dst: 1,
            features: Some(vec![0.5, 0.3]),
        };

        assert_eq!(src, 0);
        assert_eq!(dst, 1);
        // A missing feature vector fails this comparison, just as unwrapping would.
        assert_eq!(features.map(|f| f.len()), Some(2));
    }

    /// A gradient bundle keeps the sizes it was built with and may omit the
    /// attention-weight gradient.
    #[test]
    fn test_gradients_creation() {
        let Gradients {
            query_grad,
            keys_grad,
            attention_weights_grad,
            ..
        } = Gradients {
            query_grad: vec![0.1, 0.2],
            keys_grad: vec![vec![0.3, 0.4]],
            values_grad: vec![vec![0.5, 0.6]],
            attention_weights_grad: None,
        };

        assert_eq!(query_grad.len(), 2);
        assert_eq!(keys_grad.len(), 1);
        assert!(attention_weights_grad.is_none());
    }
}