use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;
/// Discriminant for the kind of connection a [`GraphEdge`] represents.
///
/// [`GraphEdge::message`] and [`GraphEdge::set_message`] use this tag to decide
/// which payload field of the edge they read or write.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GraphEdgeType {
    /// Edge whose payload is a [`SelfAttentionEdge`].
    SelfAttention,
    /// Edge whose payload is a [`DataFlowEdge`].
    DataFlow,
    /// Edge whose payload is a [`ResidualEdge`].
    Residual,
}
/// Normalization arrangement associated with a [`ResidualEdge`].
///
/// The two variants are read from their names as pre-norm and post-norm
/// placements. `Hash` is derived so the type can key a `HashMap`/`HashSet`,
/// matching [`GraphEdgeType`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SkipType {
    /// Pre-normalization arrangement.
    PreNorm,
    /// Post-normalization arrangement.
    PostNorm,
}
/// Payload of a [`GraphEdgeType::SelfAttention`] edge.
///
/// `message` doubles as the query projection: [`SelfAttentionEdge::with_qkv`]
/// stores its `q_proj` argument there. `key_proj` and `value_proj` are set by
/// that constructor or by their setters, and are `None` otherwise.
#[derive(Clone, Debug)]
pub struct SelfAttentionEdge {
    /// Attention weight carried by this edge.
    pub weight: f64,
    /// Attention-head index.
    pub head: usize,
    /// Layer index.
    pub layer: usize,
    /// Message tensor; holds the query projection for Q/K/V edges.
    pub message: Option<DenseTensor>,
    /// Key projection, if attached.
    pub key_proj: Option<DenseTensor>,
    /// Value projection, if attached.
    pub value_proj: Option<DenseTensor>,
}
impl SelfAttentionEdge {
    /// Creates an edge with the given attention `weight`, `head` and `layer`
    /// indices, and no tensors attached.
    pub fn new(weight: f64, head: usize, layer: usize) -> Self {
        Self {
            weight,
            head,
            layer,
            message: None,
            key_proj: None,
            value_proj: None,
        }
    }

    /// Creates an edge carrying `message`, with no key/value projections.
    pub fn with_message(weight: f64, head: usize, layer: usize, message: DenseTensor) -> Self {
        Self {
            message: Some(message),
            ..Self::new(weight, head, layer)
        }
    }

    /// Creates an edge carrying query, key and value projections.
    ///
    /// The query projection is stored in the `message` field, so both
    /// [`message`](Self::message) and the first element of
    /// [`get_qkv`](Self::get_qkv) return `q_proj`.
    pub fn with_qkv(
        weight: f64,
        head: usize,
        layer: usize,
        q_proj: DenseTensor,
        k_proj: DenseTensor,
        v_proj: DenseTensor,
    ) -> Self {
        Self {
            message: Some(q_proj),
            key_proj: Some(k_proj),
            value_proj: Some(v_proj),
            ..Self::new(weight, head, layer)
        }
    }

    /// Replaces the message (query) tensor.
    pub fn set_message(&mut self, message: DenseTensor) {
        self.message = Some(message);
    }

    /// Returns the message (query) tensor, if any.
    pub fn message(&self) -> Option<&DenseTensor> {
        self.message.as_ref()
    }

    /// Replaces the key projection.
    pub fn set_key_proj(&mut self, key: DenseTensor) {
        self.key_proj = Some(key);
    }

    /// Returns the key projection, if any.
    pub fn key_proj(&self) -> Option<&DenseTensor> {
        self.key_proj.as_ref()
    }

    /// Replaces the value projection.
    pub fn set_value_proj(&mut self, value: DenseTensor) {
        self.value_proj = Some(value);
    }

    /// Returns the value projection, if any.
    pub fn value_proj(&self) -> Option<&DenseTensor> {
        self.value_proj.as_ref()
    }

    /// Returns `(query, key, value)`, where the query is the `message` field.
    pub fn get_qkv(&self) -> (Option<&DenseTensor>, Option<&DenseTensor>, Option<&DenseTensor>) {
        (self.message.as_ref(), self.key_proj.as_ref(), self.value_proj.as_ref())
    }

    /// Returns `true` when query (`message`), key and value are all present.
    pub fn has_qkv(&self) -> bool {
        self.message.is_some() && self.key_proj.is_some() && self.value_proj.is_some()
    }

    /// Computes the scaled dot-product score `sum(q * k) / sqrt(d_k)`.
    ///
    /// The query is the `message` tensor and the key is `key_proj`. The tensors
    /// are multiplied element-wise over their flat data and summed.
    ///
    /// Returns `None` when any of the following holds:
    /// - `d_k` is not a finite, strictly positive number;
    /// - the query or key tensor is missing;
    /// - the two tensors differ in shape or are not 2-dimensional.
    pub fn compute_attention_score(&self, d_k: f64) -> Option<f64> {
        // A zero, negative, NaN or infinite `d_k` would make the scale factor
        // 0, NaN or inf and produce a meaningless (inf/NaN) score.
        if !d_k.is_finite() || d_k <= 0.0 {
            return None;
        }
        let q = self.message.as_ref()?;
        let k = self.key_proj.as_ref()?;
        // NOTE(review): only rank-2 tensors are scored; 1-D query/key vectors
        // yield `None`. Confirm this restriction is intended.
        if q.shape() != k.shape() || q.ndim() != 2 {
            return None;
        }
        let q_data = q.data();
        let k_data = k.data();
        let dot_product: f64 = q_data
            .iter()
            .zip(k_data.iter())
            .map(|(&q_val, &k_val)| q_val * k_val)
            .sum();
        Some(dot_product / d_k.sqrt())
    }
}
/// Payload of a [`GraphEdgeType::DataFlow`] edge.
#[derive(Clone, Debug)]
pub struct DataFlowEdge {
    /// Which transfer this edge represents.
    pub operation: DataFlowOp,
    /// Layer index.
    pub layer: usize,
    /// Tensor carried along the edge, if attached.
    pub message: Option<DenseTensor>,
}
/// Transfer described by a [`DataFlowEdge`].
///
/// The variant descriptions below are read from the variant names. `Hash` is
/// derived so the type can key a `HashMap`/`HashSet`, matching
/// [`GraphEdgeType`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFlowOp {
    /// Input into the attention sub-block.
    InputToAttention,
    /// Attention sub-block to its output.
    AttentionToOutput,
    /// Input into the feed-forward (FFN) sub-block.
    InputToFFN,
    /// Feed-forward (FFN) sub-block to its output.
    FFNToOutput,
    /// One layer to another.
    LayerToLayer,
}
impl DataFlowEdge {
    /// Builds a data-flow edge for `operation` at `layer`, with no message.
    pub fn new(operation: DataFlowOp, layer: usize) -> Self {
        let message = None;
        Self {
            operation,
            layer,
            message,
        }
    }

    /// Builds a data-flow edge for `operation` at `layer` carrying `message`.
    pub fn with_message(operation: DataFlowOp, layer: usize, message: DenseTensor) -> Self {
        let mut edge = Self::new(operation, layer);
        edge.message = Some(message);
        edge
    }

    /// Replaces the tensor carried by this edge.
    pub fn set_message(&mut self, message: DenseTensor) {
        self.message = Some(message);
    }

    /// Borrows the tensor carried by this edge, if any.
    pub fn message(&self) -> Option<&DenseTensor> {
        self.message.as_ref()
    }
}
/// Payload of a [`GraphEdgeType::Residual`] (skip) edge.
#[derive(Clone, Debug)]
pub struct ResidualEdge {
    /// Layer index.
    pub layer: usize,
    /// Normalization arrangement of the skip connection.
    pub skip_type: SkipType,
    /// Residual tensor carried along the edge, if attached.
    pub residual: Option<DenseTensor>,
}
impl ResidualEdge {
    /// Creates a residual edge for `layer` with `skip_type` and no tensor.
    pub fn new(layer: usize, skip_type: SkipType) -> Self {
        let residual = None;
        Self {
            layer,
            skip_type,
            residual,
        }
    }

    /// Creates a residual edge for `layer` with `skip_type` carrying `residual`.
    pub fn with_residual(layer: usize, skip_type: SkipType, residual: DenseTensor) -> Self {
        Self {
            residual: Some(residual),
            ..Self::new(layer, skip_type)
        }
    }

    /// Replaces the residual tensor.
    pub fn set_residual(&mut self, residual: DenseTensor) {
        self.residual = Some(residual);
    }

    /// Borrows the residual tensor, if any.
    pub fn residual(&self) -> Option<&DenseTensor> {
        self.residual.as_ref()
    }
}
/// A directed edge from `source` to `target`, tagged by [`GraphEdgeType`].
///
/// Every constructor on `GraphEdge` fills exactly one payload field, the one
/// matching `edge_type`, and leaves the other two `None`. The fields are
/// public, so this invariant is not enforced for values built by hand.
#[derive(Clone, Debug)]
pub struct GraphEdge {
    /// Tag selecting which payload field is meaningful.
    pub edge_type: GraphEdgeType,
    /// Source node identifier.
    pub source: usize,
    /// Target node identifier.
    pub target: usize,
    /// Payload for [`GraphEdgeType::SelfAttention`] edges.
    pub self_attention: Option<SelfAttentionEdge>,
    /// Payload for [`GraphEdgeType::DataFlow`] edges.
    pub data_flow: Option<DataFlowEdge>,
    /// Payload for [`GraphEdgeType::Residual`] edges.
    pub residual: Option<ResidualEdge>,
}
impl GraphEdge {
pub fn self_attention(source: usize, target: usize, weight: f64, head: usize, layer: usize) -> Self {
Self {
edge_type: GraphEdgeType::SelfAttention,
source,
target,
self_attention: Some(SelfAttentionEdge::new(weight, head, layer)),
data_flow: None,
residual: None,
}
}
pub fn data_flow(source: usize, target: usize, operation: DataFlowOp, layer: usize) -> Self {
Self {
edge_type: GraphEdgeType::DataFlow,
source,
target,
self_attention: None,
data_flow: Some(DataFlowEdge::new(operation, layer)),
residual: None,
}
}
pub fn residual(source: usize, target: usize, layer: usize, skip_type: SkipType) -> Self {
Self {
edge_type: GraphEdgeType::Residual,
source,
target,
self_attention: None,
data_flow: None,
residual: Some(ResidualEdge::new(layer, skip_type)),
}
}
pub fn get_self_attention(&self) -> Option<&SelfAttentionEdge> {
self.self_attention.as_ref()
}
pub fn get_data_flow(&self) -> Option<&DataFlowEdge> {
self.data_flow.as_ref()
}
pub fn get_residual(&self) -> Option<&ResidualEdge> {
self.residual.as_ref()
}
pub fn layer(&self) -> usize {
if let Some(sa) = &self.self_attention {
sa.layer
} else if let Some(df) = &self.data_flow {
df.layer
} else if let Some(res) = &self.residual {
res.layer
} else {
0
}
}
pub fn self_attention_with_message(
source: usize,
target: usize,
weight: f64,
head: usize,
layer: usize,
message: DenseTensor,
) -> Self {
Self {
edge_type: GraphEdgeType::SelfAttention,
source,
target,
self_attention: Some(SelfAttentionEdge::with_message(weight, head, layer, message)),
data_flow: None,
residual: None,
}
}
pub fn data_flow_with_message(
source: usize,
target: usize,
operation: DataFlowOp,
layer: usize,
message: DenseTensor,
) -> Self {
Self {
edge_type: GraphEdgeType::DataFlow,
source,
target,
self_attention: None,
data_flow: Some(DataFlowEdge::with_message(operation, layer, message)),
residual: None,
}
}
pub fn residual_with_tensor(
source: usize,
target: usize,
layer: usize,
skip_type: SkipType,
residual: DenseTensor,
) -> Self {
Self {
edge_type: GraphEdgeType::Residual,
source,
target,
self_attention: None,
data_flow: None,
residual: Some(ResidualEdge::with_residual(layer, skip_type, residual)),
}
}
pub fn message(&self) -> Option<&DenseTensor> {
match self.edge_type {
GraphEdgeType::SelfAttention => {
self.self_attention.as_ref().and_then(|sa| sa.message.as_ref())
}
GraphEdgeType::DataFlow => {
self.data_flow.as_ref().and_then(|df| df.message.as_ref())
}
GraphEdgeType::Residual => {
self.residual.as_ref().and_then(|r| r.residual.as_ref())
}
}
}
pub fn set_message(&mut self, message: DenseTensor) -> bool {
match self.edge_type {
GraphEdgeType::SelfAttention => {
if let Some(ref mut sa) = self.self_attention {
sa.set_message(message);
true
} else {
false
}
}
GraphEdgeType::DataFlow => {
if let Some(ref mut df) = self.data_flow {
df.set_message(message);
true
} else {
false
}
}
GraphEdgeType::Residual => {
if let Some(ref mut r) = self.residual {
r.set_residual(message);
true
} else {
false
}
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn self_attention_with_qkv(
source: usize,
target: usize,
weight: f64,
head: usize,
layer: usize,
q_proj: DenseTensor,
k_proj: DenseTensor,
v_proj: DenseTensor,
) -> Self {
Self {
edge_type: GraphEdgeType::SelfAttention,
source,
target,
self_attention: Some(SelfAttentionEdge::with_qkv(
weight, head, layer, q_proj, k_proj, v_proj,
)),
data_flow: None,
residual: None,
}
}
pub fn get_qkv(&self) -> (Option<&DenseTensor>, Option<&DenseTensor>, Option<&DenseTensor>) {
if let Some(sa) = &self.self_attention {
sa.get_qkv()
} else {
(None, None, None)
}
}
pub fn has_qkv(&self) -> bool {
self.self_attention.as_ref().is_some_and(|sa| sa.has_qkv())
}
pub fn key_proj(&self) -> Option<&DenseTensor> {
self.self_attention.as_ref().and_then(|sa| sa.key_proj())
}
pub fn value_proj(&self) -> Option<&DenseTensor> {
self.self_attention.as_ref().and_then(|sa| sa.value_proj())
}
pub fn compute_attention_score(&self, d_k: f64) -> Option<f64> {
self.self_attention.as_ref().and_then(|sa| sa.compute_attention_score(d_k))
}
}
#[cfg(test)]
mod tests {
    // `use super::*` also brings the parent's `DenseTensor` and `TensorBase`
    // imports into scope, so they need not be re-imported per test.
    use super::*;

    /// Builds the 2x2 tensor holding `[1, 2, 3, 4]` used by the tensor tests.
    fn sample_tensor() -> DenseTensor {
        DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
    }

    #[test]
    fn test_self_attention_edge() {
        let edge = GraphEdge::self_attention(0, 1, 0.8, 2, 5);
        assert_eq!(edge.edge_type, GraphEdgeType::SelfAttention);
        assert_eq!(edge.source, 0);
        assert_eq!(edge.target, 1);
        let sa = edge.get_self_attention().unwrap();
        assert_eq!(sa.weight, 0.8);
        assert_eq!(sa.head, 2);
        assert_eq!(sa.layer, 5);
    }

    #[test]
    fn test_data_flow_edge() {
        let edge = GraphEdge::data_flow(10, 20, DataFlowOp::InputToAttention, 3);
        assert_eq!(edge.edge_type, GraphEdgeType::DataFlow);
        assert_eq!(edge.source, 10);
        assert_eq!(edge.target, 20);
        let df = edge.get_data_flow().unwrap();
        assert_eq!(df.operation, DataFlowOp::InputToAttention);
        assert_eq!(df.layer, 3);
    }

    #[test]
    fn test_residual_edge() {
        let edge = GraphEdge::residual(5, 15, 7, SkipType::PreNorm);
        assert_eq!(edge.edge_type, GraphEdgeType::Residual);
        assert_eq!(edge.source, 5);
        assert_eq!(edge.target, 15);
        let res = edge.get_residual().unwrap();
        assert_eq!(res.layer, 7);
        assert!(matches!(res.skip_type, SkipType::PreNorm));
    }

    #[test]
    fn test_edge_layer() {
        let sa_edge = GraphEdge::self_attention(0, 1, 0.5, 1, 10);
        assert_eq!(sa_edge.layer(), 10);
        let df_edge = GraphEdge::data_flow(0, 1, DataFlowOp::LayerToLayer, 5);
        assert_eq!(df_edge.layer(), 5);
        let res_edge = GraphEdge::residual(0, 1, 3, SkipType::PostNorm);
        assert_eq!(res_edge.layer(), 3);
    }

    #[test]
    fn test_tensor_message_passing() {
        let message = sample_tensor();
        let mut sa_edge = GraphEdge::self_attention_with_message(0, 1, 0.8, 2, 5, message.clone());
        assert!(sa_edge.message().is_some());
        assert_eq!(sa_edge.message().unwrap().shape(), &[2, 2]);

        let df_edge = GraphEdge::data_flow_with_message(
            10,
            20,
            DataFlowOp::InputToAttention,
            3,
            message.clone(),
        );
        assert!(df_edge.message().is_some());

        let mut res_edge = GraphEdge::residual_with_tensor(5, 15, 7, SkipType::PreNorm, message);
        assert!(res_edge.message().is_some());

        // Replacing the message must report success and actually swap the
        // tensor (the new one has a different shape from the old one).
        let new_message = DenseTensor::from_vec(vec![5.0, 6.0], vec![2]);
        assert!(sa_edge.set_message(new_message.clone()));
        assert_eq!(sa_edge.message().unwrap().shape(), &[2]);
        assert!(res_edge.set_message(new_message));
        assert_eq!(res_edge.message().unwrap().shape(), &[2]);
    }

    #[test]
    fn test_qkv_and_attention_score() {
        let t = sample_tensor();
        let edge = GraphEdge::self_attention_with_qkv(0, 1, 1.0, 0, 0, t.clone(), t.clone(), t);
        assert!(edge.has_qkv());
        let (q, k, v) = edge.get_qkv();
        assert!(q.is_some() && k.is_some() && v.is_some());
        assert!(edge.key_proj().is_some());
        assert!(edge.value_proj().is_some());
        // (1*1 + 2*2 + 3*3 + 4*4) / sqrt(4) = 30 / 2
        assert_eq!(edge.compute_attention_score(4.0), Some(15.0));
    }

    #[test]
    fn test_non_attention_edge_has_no_qkv() {
        let edge = GraphEdge::data_flow(0, 1, DataFlowOp::FFNToOutput, 2);
        assert!(!edge.has_qkv());
        assert!(edge.key_proj().is_none());
        assert!(edge.value_proj().is_none());
        assert!(edge.compute_attention_score(4.0).is_none());
    }
}