Skip to main content

god_graph/transformer/graph_transformer/
edges.rs

1//! Edge types for graph-structured Transformer
2//!
3//! This module provides edge types that support tensor message passing
4//! between nodes in the computation graph.
5//!
6//! ## Edge Tensor Passing Semantics
7//!
8//! Edges in the GraphTransformer carry tensor messages for efficient computation:
9//!
10//! - **SelfAttention edges**: Carry Q/K/V projection tensors for attention computation
11//! - **DataFlow edges**: Carry activation tensors between layers
12//! - **Residual edges**: Carry identity passthrough tensors for residual connections
13//!
14//! ## Example
15//!
16//! ```rust
17//! use god_gragh::transformer::graph_transformer::edges::{GraphEdge, GraphEdgeType, SelfAttentionEdge};
18//! use god_gragh::tensor::DenseTensor;
19//! use god_gragh::tensor::traits::TensorBase;
20//!
21//! // Create Q/K/V projection tensors
22//! let q_proj = DenseTensor::zeros(vec![1, 64]); // Query projection
23//! let k_proj = DenseTensor::zeros(vec![1, 64]); // Key projection
24//! let v_proj = DenseTensor::zeros(vec![1, 64]); // Value projection
25//!
26//! // Create self-attention edge with QKV message
27//! let mut sa_edge = GraphEdge::self_attention_with_message(
28//!     0, 1, 0.5, 2, 0, q_proj
29//! );
30//!
31//! // Access the message tensor
32//! if let Some(msg) = sa_edge.message() {
33//!     println!("Message shape: {:?}", msg.shape());
34//! }
35//! ```
36
37use crate::tensor::DenseTensor;
38use crate::tensor::traits::TensorBase;
39
/// Kind of relationship an edge represents in the GraphTransformer.
///
/// The variant selects which payload of a `GraphEdge` is meaningful:
/// `self_attention`, `data_flow`, or `residual` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphEdgeType {
    /// Self-attention edge (attention weight, head, layer and Q/K/V projections)
    SelfAttention,
    /// Data flow edge (attention/FFN inputs and outputs, layer-to-layer activations)
    DataFlow,
    /// Residual (skip) connection edge carrying an identity passthrough tensor
    Residual,
}
50
/// Placement of layer normalization relative to the residual (skip) connection.
///
/// Derives `Hash` like `GraphEdgeType` so it can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SkipType {
    /// Pre-normalization (norm before attention/FFN)
    PreNorm,
    /// Post-normalization (norm after attention/FFN)
    PostNorm,
}
59
60/// Self-attention edge data
61#[derive(Debug, Clone)]
62pub struct SelfAttentionEdge {
63    /// Attention weight
64    pub weight: f64,
65    /// Attention head
66    pub head: usize,
67    /// Layer number
68    pub layer: usize,
69    /// Message tensor (Q/K/V projections)
70    /// For multi-head attention, this contains concatenated QKV projections
71    pub message: Option<DenseTensor>,
72    /// Optional separate K (key) projection
73    pub key_proj: Option<DenseTensor>,
74    /// Optional separate V (value) projection
75    pub value_proj: Option<DenseTensor>,
76}
77
78impl SelfAttentionEdge {
79    /// Create a new self-attention edge
80    pub fn new(weight: f64, head: usize, layer: usize) -> Self {
81        Self {
82            weight,
83            head,
84            layer,
85            message: None,
86            key_proj: None,
87            value_proj: None,
88        }
89    }
90
91    /// Create with message tensor (Q projection)
92    pub fn with_message(weight: f64, head: usize, layer: usize, message: DenseTensor) -> Self {
93        Self {
94            weight,
95            head,
96            layer,
97            message: Some(message),
98            key_proj: None,
99            value_proj: None,
100        }
101    }
102
103    /// Create with separate Q, K, V projections
104    pub fn with_qkv(
105        weight: f64,
106        head: usize,
107        layer: usize,
108        q_proj: DenseTensor,
109        k_proj: DenseTensor,
110        v_proj: DenseTensor,
111    ) -> Self {
112        Self {
113            weight,
114            head,
115            layer,
116            message: Some(q_proj),
117            key_proj: Some(k_proj),
118            value_proj: Some(v_proj),
119        }
120    }
121
122    /// Set the message tensor (Q projection)
123    pub fn set_message(&mut self, message: DenseTensor) {
124        self.message = Some(message);
125    }
126
127    /// Get the message tensor (Q projection)
128    pub fn message(&self) -> Option<&DenseTensor> {
129        self.message.as_ref()
130    }
131
132    /// Set the key projection
133    pub fn set_key_proj(&mut self, key: DenseTensor) {
134        self.key_proj = Some(key);
135    }
136
137    /// Get the key projection
138    pub fn key_proj(&self) -> Option<&DenseTensor> {
139        self.key_proj.as_ref()
140    }
141
142    /// Set the value projection
143    pub fn set_value_proj(&mut self, value: DenseTensor) {
144        self.value_proj = Some(value);
145    }
146
147    /// Get the value projection
148    pub fn value_proj(&self) -> Option<&DenseTensor> {
149        self.value_proj.as_ref()
150    }
151
152    /// Get all QKV projections if available
153    pub fn get_qkv(&self) -> (Option<&DenseTensor>, Option<&DenseTensor>, Option<&DenseTensor>) {
154        (self.message.as_ref(), self.key_proj.as_ref(), self.value_proj.as_ref())
155    }
156
157    /// Check if this edge has complete QKV projections
158    pub fn has_qkv(&self) -> bool {
159        self.message.is_some() && self.key_proj.is_some() && self.value_proj.is_some()
160    }
161
162    /// Compute attention score using Q and K projections
163    /// score = Q @ K^T / sqrt(d_k)
164    pub fn compute_attention_score(&self, d_k: f64) -> Option<f64> {
165        if let (Some(q), Some(k)) = (&self.message, &self.key_proj) {
166            if q.shape() == k.shape() && q.ndim() == 2 {
167                // Simple dot-product attention
168                let q_data = q.data();
169                let k_data = k.data();
170                
171                let dot_product: f64 = q_data.iter()
172                    .zip(k_data.iter())
173                    .map(|(&q_val, &k_val)| q_val * k_val)
174                    .sum();
175                
176                Some(dot_product / d_k.sqrt())
177            } else {
178                None
179            }
180        } else {
181            None
182        }
183    }
184}
185
186/// Data flow edge data
187#[derive(Debug, Clone)]
188pub struct DataFlowEdge {
189    /// Operation type
190    pub operation: DataFlowOp,
191    /// Layer number
192    pub layer: usize,
193    /// Message tensor being transferred
194    pub message: Option<DenseTensor>,
195}
196
/// Data flow operation types.
///
/// Derives `Hash` like `GraphEdgeType` so operations can key maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFlowOp {
    /// Input to attention
    InputToAttention,
    /// Attention to output
    AttentionToOutput,
    /// Input to FFN
    InputToFFN,
    /// FFN to output
    FFNToOutput,
    /// Layer output to next layer
    LayerToLayer,
}
211
212impl DataFlowEdge {
213    /// Create a new data flow edge
214    pub fn new(operation: DataFlowOp, layer: usize) -> Self {
215        Self {
216            operation,
217            layer,
218            message: None,
219        }
220    }
221
222    /// Create with message tensor
223    pub fn with_message(operation: DataFlowOp, layer: usize, message: DenseTensor) -> Self {
224        Self {
225            operation,
226            layer,
227            message: Some(message),
228        }
229    }
230
231    /// Set the message tensor
232    pub fn set_message(&mut self, message: DenseTensor) {
233        self.message = Some(message);
234    }
235
236    /// Get the message tensor
237    pub fn message(&self) -> Option<&DenseTensor> {
238        self.message.as_ref()
239    }
240}
241
242/// Residual connection edge data
243#[derive(Debug, Clone)]
244pub struct ResidualEdge {
245    /// Layer number
246    pub layer: usize,
247    /// Skip type
248    pub skip_type: SkipType,
249    /// Residual tensor (identity passthrough)
250    pub residual: Option<DenseTensor>,
251}
252
253impl ResidualEdge {
254    /// Create a new residual edge
255    pub fn new(layer: usize, skip_type: SkipType) -> Self {
256        Self {
257            layer,
258            skip_type,
259            residual: None,
260        }
261    }
262
263    /// Create with residual tensor
264    pub fn with_residual(layer: usize, skip_type: SkipType, residual: DenseTensor) -> Self {
265        Self {
266            layer,
267            skip_type,
268            residual: Some(residual),
269        }
270    }
271
272    /// Set the residual tensor
273    pub fn set_residual(&mut self, residual: DenseTensor) {
274        self.residual = Some(residual);
275    }
276
277    /// Get the residual tensor
278    pub fn residual(&self) -> Option<&DenseTensor> {
279        self.residual.as_ref()
280    }
281}
282
283/// Graph edge wrapper
284#[derive(Debug, Clone)]
285pub struct GraphEdge {
286    /// Edge type
287    pub edge_type: GraphEdgeType,
288    /// Source node ID
289    pub source: usize,
290    /// Target node ID
291    pub target: usize,
292    /// Optional self-attention data
293    pub self_attention: Option<SelfAttentionEdge>,
294    /// Optional data flow data
295    pub data_flow: Option<DataFlowEdge>,
296    /// Optional residual data
297    pub residual: Option<ResidualEdge>,
298}
299
300impl GraphEdge {
301    /// Create a self-attention edge
302    pub fn self_attention(source: usize, target: usize, weight: f64, head: usize, layer: usize) -> Self {
303        Self {
304            edge_type: GraphEdgeType::SelfAttention,
305            source,
306            target,
307            self_attention: Some(SelfAttentionEdge::new(weight, head, layer)),
308            data_flow: None,
309            residual: None,
310        }
311    }
312
313    /// Create a data flow edge
314    pub fn data_flow(source: usize, target: usize, operation: DataFlowOp, layer: usize) -> Self {
315        Self {
316            edge_type: GraphEdgeType::DataFlow,
317            source,
318            target,
319            self_attention: None,
320            data_flow: Some(DataFlowEdge::new(operation, layer)),
321            residual: None,
322        }
323    }
324
325    /// Create a residual edge
326    pub fn residual(source: usize, target: usize, layer: usize, skip_type: SkipType) -> Self {
327        Self {
328            edge_type: GraphEdgeType::Residual,
329            source,
330            target,
331            self_attention: None,
332            data_flow: None,
333            residual: Some(ResidualEdge::new(layer, skip_type)),
334        }
335    }
336
337    /// Get the self-attention data if applicable
338    pub fn get_self_attention(&self) -> Option<&SelfAttentionEdge> {
339        self.self_attention.as_ref()
340    }
341
342    /// Get the data flow info if applicable
343    pub fn get_data_flow(&self) -> Option<&DataFlowEdge> {
344        self.data_flow.as_ref()
345    }
346
347    /// Get the residual info if applicable
348    pub fn get_residual(&self) -> Option<&ResidualEdge> {
349        self.residual.as_ref()
350    }
351
352    /// Get the layer number
353    pub fn layer(&self) -> usize {
354        if let Some(sa) = &self.self_attention {
355            sa.layer
356        } else if let Some(df) = &self.data_flow {
357            df.layer
358        } else if let Some(res) = &self.residual {
359            res.layer
360        } else {
361            0
362        }
363    }
364
365    /// Create a self-attention edge with message tensor
366    pub fn self_attention_with_message(
367        source: usize,
368        target: usize,
369        weight: f64,
370        head: usize,
371        layer: usize,
372        message: DenseTensor,
373    ) -> Self {
374        Self {
375            edge_type: GraphEdgeType::SelfAttention,
376            source,
377            target,
378            self_attention: Some(SelfAttentionEdge::with_message(weight, head, layer, message)),
379            data_flow: None,
380            residual: None,
381        }
382    }
383
384    /// Create a data flow edge with message tensor
385    pub fn data_flow_with_message(
386        source: usize,
387        target: usize,
388        operation: DataFlowOp,
389        layer: usize,
390        message: DenseTensor,
391    ) -> Self {
392        Self {
393            edge_type: GraphEdgeType::DataFlow,
394            source,
395            target,
396            self_attention: None,
397            data_flow: Some(DataFlowEdge::with_message(operation, layer, message)),
398            residual: None,
399        }
400    }
401
402    /// Create a residual edge with tensor
403    pub fn residual_with_tensor(
404        source: usize,
405        target: usize,
406        layer: usize,
407        skip_type: SkipType,
408        residual: DenseTensor,
409    ) -> Self {
410        Self {
411            edge_type: GraphEdgeType::Residual,
412            source,
413            target,
414            self_attention: None,
415            data_flow: None,
416            residual: Some(ResidualEdge::with_residual(layer, skip_type, residual)),
417        }
418    }
419
420    /// Get the message tensor from this edge (if any)
421    pub fn message(&self) -> Option<&DenseTensor> {
422        match self.edge_type {
423            GraphEdgeType::SelfAttention => {
424                self.self_attention.as_ref().and_then(|sa| sa.message.as_ref())
425            }
426            GraphEdgeType::DataFlow => {
427                self.data_flow.as_ref().and_then(|df| df.message.as_ref())
428            }
429            GraphEdgeType::Residual => {
430                self.residual.as_ref().and_then(|r| r.residual.as_ref())
431            }
432        }
433    }
434
435    /// Set the message tensor on this edge
436    pub fn set_message(&mut self, message: DenseTensor) -> bool {
437        match self.edge_type {
438            GraphEdgeType::SelfAttention => {
439                if let Some(ref mut sa) = self.self_attention {
440                    sa.set_message(message);
441                    true
442                } else {
443                    false
444                }
445            }
446            GraphEdgeType::DataFlow => {
447                if let Some(ref mut df) = self.data_flow {
448                    df.set_message(message);
449                    true
450                } else {
451                    false
452                }
453            }
454            GraphEdgeType::Residual => {
455                if let Some(ref mut r) = self.residual {
456                    r.set_residual(message);
457                    true
458                } else {
459                    false
460                }
461            }
462        }
463    }
464
465    /// Create a self-attention edge with separate Q, K, V projections
466    #[allow(clippy::too_many_arguments)]
467    pub fn self_attention_with_qkv(
468        source: usize,
469        target: usize,
470        weight: f64,
471        head: usize,
472        layer: usize,
473        q_proj: DenseTensor,
474        k_proj: DenseTensor,
475        v_proj: DenseTensor,
476    ) -> Self {
477        Self {
478            edge_type: GraphEdgeType::SelfAttention,
479            source,
480            target,
481            self_attention: Some(SelfAttentionEdge::with_qkv(
482                weight, head, layer, q_proj, k_proj, v_proj,
483            )),
484            data_flow: None,
485            residual: None,
486        }
487    }
488
489    /// Get Q/K/V projections if available (SelfAttention edges only)
490    pub fn get_qkv(&self) -> (Option<&DenseTensor>, Option<&DenseTensor>, Option<&DenseTensor>) {
491        if let Some(sa) = &self.self_attention {
492            sa.get_qkv()
493        } else {
494            (None, None, None)
495        }
496    }
497
498    /// Check if this edge has complete QKV projections
499    pub fn has_qkv(&self) -> bool {
500        self.self_attention.as_ref().is_some_and(|sa| sa.has_qkv())
501    }
502
503    /// Get the key projection (SelfAttention edges only)
504    pub fn key_proj(&self) -> Option<&DenseTensor> {
505        self.self_attention.as_ref().and_then(|sa| sa.key_proj())
506    }
507
508    /// Get the value projection (SelfAttention edges only)
509    pub fn value_proj(&self) -> Option<&DenseTensor> {
510        self.self_attention.as_ref().and_then(|sa| sa.value_proj())
511    }
512
513    /// Compute attention score using Q and K projections (SelfAttention edges only)
514    pub fn compute_attention_score(&self, d_k: f64) -> Option<f64> {
515        self.self_attention.as_ref().and_then(|sa| sa.compute_attention_score(d_k))
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_self_attention_edge() {
        let edge = GraphEdge::self_attention(0, 1, 0.8, 2, 5);

        assert_eq!(edge.edge_type, GraphEdgeType::SelfAttention);
        assert_eq!(edge.source, 0);
        assert_eq!(edge.target, 1);

        let sa = edge.get_self_attention().unwrap();
        assert_eq!(sa.weight, 0.8);
        assert_eq!(sa.head, 2);
        assert_eq!(sa.layer, 5);
        // A freshly created edge carries no tensors.
        assert!(edge.message().is_none());
        assert!(!edge.has_qkv());
    }

    #[test]
    fn test_data_flow_edge() {
        let edge = GraphEdge::data_flow(10, 20, DataFlowOp::InputToAttention, 3);

        assert_eq!(edge.edge_type, GraphEdgeType::DataFlow);
        assert_eq!(edge.source, 10);
        assert_eq!(edge.target, 20);

        let df = edge.get_data_flow().unwrap();
        assert_eq!(df.operation, DataFlowOp::InputToAttention);
        assert_eq!(df.layer, 3);
    }

    #[test]
    fn test_residual_edge() {
        let edge = GraphEdge::residual(5, 15, 7, SkipType::PreNorm);

        assert_eq!(edge.edge_type, GraphEdgeType::Residual);
        assert_eq!(edge.source, 5);
        assert_eq!(edge.target, 15);

        let res = edge.get_residual().unwrap();
        assert_eq!(res.layer, 7);
        assert!(matches!(res.skip_type, SkipType::PreNorm));
    }

    #[test]
    fn test_edge_layer() {
        let sa_edge = GraphEdge::self_attention(0, 1, 0.5, 1, 10);
        assert_eq!(sa_edge.layer(), 10);

        let df_edge = GraphEdge::data_flow(0, 1, DataFlowOp::LayerToLayer, 5);
        assert_eq!(df_edge.layer(), 5);

        let res_edge = GraphEdge::residual(0, 1, 3, SkipType::PostNorm);
        assert_eq!(res_edge.layer(), 3);
    }

    #[test]
    fn test_tensor_message_passing() {
        use crate::tensor::DenseTensor;
        use crate::tensor::traits::TensorBase;

        // Create a message tensor
        let message = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        // Test self-attention edge with message
        let mut sa_edge = GraphEdge::self_attention_with_message(
            0, 1, 0.8, 2, 5, message.clone()
        );
        assert!(sa_edge.message().is_some());
        assert_eq!(sa_edge.message().unwrap().shape(), &[2, 2]);

        // Test data flow edge with message
        let df_edge = GraphEdge::data_flow_with_message(
            10, 20, DataFlowOp::InputToAttention, 3, message.clone()
        );
        assert_eq!(df_edge.message().unwrap().shape(), &[2, 2]);

        // Test residual edge with tensor
        let res_edge = GraphEdge::residual_with_tensor(
            5, 15, 7, SkipType::PreNorm, message
        );
        assert_eq!(res_edge.message().unwrap().shape(), &[2, 2]);

        // set_message on an edge with a payload succeeds and replaces the tensor
        let new_message = DenseTensor::from_vec(vec![5.0, 6.0], vec![2]);
        assert!(sa_edge.set_message(new_message));
        assert_eq!(sa_edge.message().unwrap().shape(), &[2]);
    }

    #[test]
    fn test_set_message_without_payload_fails() {
        let mut edge = GraphEdge::data_flow(0, 1, DataFlowOp::LayerToLayer, 4);
        // Break the edge_type/payload invariant on purpose.
        edge.data_flow = None;

        assert!(!edge.set_message(DenseTensor::from_vec(vec![1.0], vec![1])));
        assert!(edge.message().is_none());
        assert_eq!(edge.layer(), 0);
    }

    #[test]
    fn test_qkv_edge_and_attention_score() {
        use crate::tensor::traits::TensorBase;

        let q = DenseTensor::from_vec(vec![1.0, 2.0], vec![1, 2]);
        let k = DenseTensor::from_vec(vec![3.0, 4.0], vec![1, 2]);
        let v = DenseTensor::from_vec(vec![5.0, 6.0], vec![1, 2]);
        let edge = GraphEdge::self_attention_with_qkv(0, 1, 1.0, 0, 0, q, k, v);

        assert!(edge.has_qkv());
        let (q_ref, k_ref, v_ref) = edge.get_qkv();
        assert!(q_ref.is_some() && k_ref.is_some() && v_ref.is_some());
        assert_eq!(edge.key_proj().unwrap().shape(), &[1, 2]);
        assert_eq!(edge.value_proj().unwrap().shape(), &[1, 2]);

        // (1*3 + 2*4) / sqrt(4) = 5.5
        let score = edge.compute_attention_score(4.0).unwrap();
        assert!((score - 5.5).abs() < 1e-12);

        // Without K there is nothing to score against.
        let q_only = GraphEdge::self_attention_with_message(
            0, 1, 1.0, 0, 0, DenseTensor::from_vec(vec![1.0, 2.0], vec![1, 2])
        );
        assert!(!q_only.has_qkv());
        assert!(q_only.compute_attention_score(4.0).is_none());

        // Non-attention edges expose no projections.
        let df_edge = GraphEdge::data_flow(0, 1, DataFlowOp::LayerToLayer, 0);
        assert!(!df_edge.has_qkv());
        assert!(df_edge.key_proj().is_none());
        assert!(df_edge.compute_attention_score(4.0).is_none());
    }
}