1use crate::tensor::DenseTensor;
38use crate::tensor::traits::TensorBase;
39
/// Tag stored in [`GraphEdge::edge_type`] that says which payload an edge carries.
///
/// `GraphEdge::message` and `GraphEdge::set_message` dispatch on this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphEdgeType {
    /// Set by the `GraphEdge::self_attention*` constructors; payload is a [`SelfAttentionEdge`].
    SelfAttention,
    /// Set by the `GraphEdge::data_flow*` constructors; payload is a [`DataFlowEdge`].
    DataFlow,
    /// Set by the `GraphEdge::residual*` constructors; payload is a [`ResidualEdge`].
    Residual,
}
50
/// Kind of skip connection recorded on a [`ResidualEdge`].
///
/// Derives `Hash` (like [`GraphEdgeType`]) so the tag can be used as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SkipType {
    /// NOTE(review): presumably a skip taken before normalization — confirm with callers.
    PreNorm,
    /// NOTE(review): presumably a skip taken after normalization — confirm with callers.
    PostNorm,
}
59
60#[derive(Debug, Clone)]
62pub struct SelfAttentionEdge {
63 pub weight: f64,
65 pub head: usize,
67 pub layer: usize,
69 pub message: Option<DenseTensor>,
72 pub key_proj: Option<DenseTensor>,
74 pub value_proj: Option<DenseTensor>,
76}
77
78impl SelfAttentionEdge {
79 pub fn new(weight: f64, head: usize, layer: usize) -> Self {
81 Self {
82 weight,
83 head,
84 layer,
85 message: None,
86 key_proj: None,
87 value_proj: None,
88 }
89 }
90
91 pub fn with_message(weight: f64, head: usize, layer: usize, message: DenseTensor) -> Self {
93 Self {
94 weight,
95 head,
96 layer,
97 message: Some(message),
98 key_proj: None,
99 value_proj: None,
100 }
101 }
102
103 pub fn with_qkv(
105 weight: f64,
106 head: usize,
107 layer: usize,
108 q_proj: DenseTensor,
109 k_proj: DenseTensor,
110 v_proj: DenseTensor,
111 ) -> Self {
112 Self {
113 weight,
114 head,
115 layer,
116 message: Some(q_proj),
117 key_proj: Some(k_proj),
118 value_proj: Some(v_proj),
119 }
120 }
121
122 pub fn set_message(&mut self, message: DenseTensor) {
124 self.message = Some(message);
125 }
126
127 pub fn message(&self) -> Option<&DenseTensor> {
129 self.message.as_ref()
130 }
131
132 pub fn set_key_proj(&mut self, key: DenseTensor) {
134 self.key_proj = Some(key);
135 }
136
137 pub fn key_proj(&self) -> Option<&DenseTensor> {
139 self.key_proj.as_ref()
140 }
141
142 pub fn set_value_proj(&mut self, value: DenseTensor) {
144 self.value_proj = Some(value);
145 }
146
147 pub fn value_proj(&self) -> Option<&DenseTensor> {
149 self.value_proj.as_ref()
150 }
151
152 pub fn get_qkv(&self) -> (Option<&DenseTensor>, Option<&DenseTensor>, Option<&DenseTensor>) {
154 (self.message.as_ref(), self.key_proj.as_ref(), self.value_proj.as_ref())
155 }
156
157 pub fn has_qkv(&self) -> bool {
159 self.message.is_some() && self.key_proj.is_some() && self.value_proj.is_some()
160 }
161
162 pub fn compute_attention_score(&self, d_k: f64) -> Option<f64> {
165 if let (Some(q), Some(k)) = (&self.message, &self.key_proj) {
166 if q.shape() == k.shape() && q.ndim() == 2 {
167 let q_data = q.data();
169 let k_data = k.data();
170
171 let dot_product: f64 = q_data.iter()
172 .zip(k_data.iter())
173 .map(|(&q_val, &k_val)| q_val * k_val)
174 .sum();
175
176 Some(dot_product / d_k.sqrt())
177 } else {
178 None
179 }
180 } else {
181 None
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
188pub struct DataFlowEdge {
189 pub operation: DataFlowOp,
191 pub layer: usize,
193 pub message: Option<DenseTensor>,
195}
196
/// Label describing which data-flow step a [`DataFlowEdge`] represents.
///
/// Derives `Hash` (like [`GraphEdgeType`]) so operations can be used as
/// `HashMap`/`HashSet` keys. Variant meanings below are read from the names;
/// nothing in this file enforces them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataFlowOp {
    /// Flow from an input into attention.
    InputToAttention,
    /// Flow from attention to an output.
    AttentionToOutput,
    /// Flow from an input into the FFN.
    InputToFFN,
    /// Flow from the FFN to an output.
    FFNToOutput,
    /// Flow from one layer to another.
    LayerToLayer,
}
211
212impl DataFlowEdge {
213 pub fn new(operation: DataFlowOp, layer: usize) -> Self {
215 Self {
216 operation,
217 layer,
218 message: None,
219 }
220 }
221
222 pub fn with_message(operation: DataFlowOp, layer: usize, message: DenseTensor) -> Self {
224 Self {
225 operation,
226 layer,
227 message: Some(message),
228 }
229 }
230
231 pub fn set_message(&mut self, message: DenseTensor) {
233 self.message = Some(message);
234 }
235
236 pub fn message(&self) -> Option<&DenseTensor> {
238 self.message.as_ref()
239 }
240}
241
242#[derive(Debug, Clone)]
244pub struct ResidualEdge {
245 pub layer: usize,
247 pub skip_type: SkipType,
249 pub residual: Option<DenseTensor>,
251}
252
253impl ResidualEdge {
254 pub fn new(layer: usize, skip_type: SkipType) -> Self {
256 Self {
257 layer,
258 skip_type,
259 residual: None,
260 }
261 }
262
263 pub fn with_residual(layer: usize, skip_type: SkipType, residual: DenseTensor) -> Self {
265 Self {
266 layer,
267 skip_type,
268 residual: Some(residual),
269 }
270 }
271
272 pub fn set_residual(&mut self, residual: DenseTensor) {
274 self.residual = Some(residual);
275 }
276
277 pub fn residual(&self) -> Option<&DenseTensor> {
279 self.residual.as_ref()
280 }
281}
282
283#[derive(Debug, Clone)]
285pub struct GraphEdge {
286 pub edge_type: GraphEdgeType,
288 pub source: usize,
290 pub target: usize,
292 pub self_attention: Option<SelfAttentionEdge>,
294 pub data_flow: Option<DataFlowEdge>,
296 pub residual: Option<ResidualEdge>,
298}
299
300impl GraphEdge {
301 pub fn self_attention(source: usize, target: usize, weight: f64, head: usize, layer: usize) -> Self {
303 Self {
304 edge_type: GraphEdgeType::SelfAttention,
305 source,
306 target,
307 self_attention: Some(SelfAttentionEdge::new(weight, head, layer)),
308 data_flow: None,
309 residual: None,
310 }
311 }
312
313 pub fn data_flow(source: usize, target: usize, operation: DataFlowOp, layer: usize) -> Self {
315 Self {
316 edge_type: GraphEdgeType::DataFlow,
317 source,
318 target,
319 self_attention: None,
320 data_flow: Some(DataFlowEdge::new(operation, layer)),
321 residual: None,
322 }
323 }
324
325 pub fn residual(source: usize, target: usize, layer: usize, skip_type: SkipType) -> Self {
327 Self {
328 edge_type: GraphEdgeType::Residual,
329 source,
330 target,
331 self_attention: None,
332 data_flow: None,
333 residual: Some(ResidualEdge::new(layer, skip_type)),
334 }
335 }
336
337 pub fn get_self_attention(&self) -> Option<&SelfAttentionEdge> {
339 self.self_attention.as_ref()
340 }
341
342 pub fn get_data_flow(&self) -> Option<&DataFlowEdge> {
344 self.data_flow.as_ref()
345 }
346
347 pub fn get_residual(&self) -> Option<&ResidualEdge> {
349 self.residual.as_ref()
350 }
351
352 pub fn layer(&self) -> usize {
354 if let Some(sa) = &self.self_attention {
355 sa.layer
356 } else if let Some(df) = &self.data_flow {
357 df.layer
358 } else if let Some(res) = &self.residual {
359 res.layer
360 } else {
361 0
362 }
363 }
364
365 pub fn self_attention_with_message(
367 source: usize,
368 target: usize,
369 weight: f64,
370 head: usize,
371 layer: usize,
372 message: DenseTensor,
373 ) -> Self {
374 Self {
375 edge_type: GraphEdgeType::SelfAttention,
376 source,
377 target,
378 self_attention: Some(SelfAttentionEdge::with_message(weight, head, layer, message)),
379 data_flow: None,
380 residual: None,
381 }
382 }
383
384 pub fn data_flow_with_message(
386 source: usize,
387 target: usize,
388 operation: DataFlowOp,
389 layer: usize,
390 message: DenseTensor,
391 ) -> Self {
392 Self {
393 edge_type: GraphEdgeType::DataFlow,
394 source,
395 target,
396 self_attention: None,
397 data_flow: Some(DataFlowEdge::with_message(operation, layer, message)),
398 residual: None,
399 }
400 }
401
402 pub fn residual_with_tensor(
404 source: usize,
405 target: usize,
406 layer: usize,
407 skip_type: SkipType,
408 residual: DenseTensor,
409 ) -> Self {
410 Self {
411 edge_type: GraphEdgeType::Residual,
412 source,
413 target,
414 self_attention: None,
415 data_flow: None,
416 residual: Some(ResidualEdge::with_residual(layer, skip_type, residual)),
417 }
418 }
419
420 pub fn message(&self) -> Option<&DenseTensor> {
422 match self.edge_type {
423 GraphEdgeType::SelfAttention => {
424 self.self_attention.as_ref().and_then(|sa| sa.message.as_ref())
425 }
426 GraphEdgeType::DataFlow => {
427 self.data_flow.as_ref().and_then(|df| df.message.as_ref())
428 }
429 GraphEdgeType::Residual => {
430 self.residual.as_ref().and_then(|r| r.residual.as_ref())
431 }
432 }
433 }
434
435 pub fn set_message(&mut self, message: DenseTensor) -> bool {
437 match self.edge_type {
438 GraphEdgeType::SelfAttention => {
439 if let Some(ref mut sa) = self.self_attention {
440 sa.set_message(message);
441 true
442 } else {
443 false
444 }
445 }
446 GraphEdgeType::DataFlow => {
447 if let Some(ref mut df) = self.data_flow {
448 df.set_message(message);
449 true
450 } else {
451 false
452 }
453 }
454 GraphEdgeType::Residual => {
455 if let Some(ref mut r) = self.residual {
456 r.set_residual(message);
457 true
458 } else {
459 false
460 }
461 }
462 }
463 }
464
465 #[allow(clippy::too_many_arguments)]
467 pub fn self_attention_with_qkv(
468 source: usize,
469 target: usize,
470 weight: f64,
471 head: usize,
472 layer: usize,
473 q_proj: DenseTensor,
474 k_proj: DenseTensor,
475 v_proj: DenseTensor,
476 ) -> Self {
477 Self {
478 edge_type: GraphEdgeType::SelfAttention,
479 source,
480 target,
481 self_attention: Some(SelfAttentionEdge::with_qkv(
482 weight, head, layer, q_proj, k_proj, v_proj,
483 )),
484 data_flow: None,
485 residual: None,
486 }
487 }
488
489 pub fn get_qkv(&self) -> (Option<&DenseTensor>, Option<&DenseTensor>, Option<&DenseTensor>) {
491 if let Some(sa) = &self.self_attention {
492 sa.get_qkv()
493 } else {
494 (None, None, None)
495 }
496 }
497
498 pub fn has_qkv(&self) -> bool {
500 self.self_attention.as_ref().is_some_and(|sa| sa.has_qkv())
501 }
502
503 pub fn key_proj(&self) -> Option<&DenseTensor> {
505 self.self_attention.as_ref().and_then(|sa| sa.key_proj())
506 }
507
508 pub fn value_proj(&self) -> Option<&DenseTensor> {
510 self.self_attention.as_ref().and_then(|sa| sa.value_proj())
511 }
512
513 pub fn compute_attention_score(&self, d_k: f64) -> Option<f64> {
515 self.self_attention.as_ref().and_then(|sa| sa.compute_attention_score(d_k))
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::traits::TensorBase;
    use crate::tensor::DenseTensor;

    /// Self-attention constructor fills the tag, endpoints and scalar metadata.
    #[test]
    fn test_self_attention_edge() {
        let edge = GraphEdge::self_attention(0, 1, 0.8, 2, 5);

        assert_eq!(edge.edge_type, GraphEdgeType::SelfAttention);
        assert_eq!(edge.source, 0);
        assert_eq!(edge.target, 1);

        let sa = edge.get_self_attention().unwrap();
        assert_eq!(sa.weight, 0.8);
        assert_eq!(sa.head, 2);
        assert_eq!(sa.layer, 5);
    }

    /// Data-flow constructor fills the tag, endpoints, operation and layer.
    #[test]
    fn test_data_flow_edge() {
        let edge = GraphEdge::data_flow(10, 20, DataFlowOp::InputToAttention, 3);

        assert_eq!(edge.edge_type, GraphEdgeType::DataFlow);
        assert_eq!(edge.source, 10);
        assert_eq!(edge.target, 20);

        let df = edge.get_data_flow().unwrap();
        assert_eq!(df.operation, DataFlowOp::InputToAttention);
        assert_eq!(df.layer, 3);
    }

    /// Residual constructor fills the tag, endpoints, layer and skip kind.
    #[test]
    fn test_residual_edge() {
        let edge = GraphEdge::residual(5, 15, 7, SkipType::PreNorm);

        assert_eq!(edge.edge_type, GraphEdgeType::Residual);
        assert_eq!(edge.source, 5);
        assert_eq!(edge.target, 15);

        let res = edge.get_residual().unwrap();
        assert_eq!(res.layer, 7);
        assert_eq!(res.skip_type, SkipType::PreNorm);
    }

    /// `layer()` reports the layer of whichever payload is populated, and 0 for none.
    #[test]
    fn test_edge_layer() {
        let sa_edge = GraphEdge::self_attention(0, 1, 0.5, 1, 10);
        assert_eq!(sa_edge.layer(), 10);

        let df_edge = GraphEdge::data_flow(0, 1, DataFlowOp::LayerToLayer, 5);
        assert_eq!(df_edge.layer(), 5);

        let res_edge = GraphEdge::residual(0, 1, 3, SkipType::PostNorm);
        assert_eq!(res_edge.layer(), 3);

        let empty = GraphEdge {
            edge_type: GraphEdgeType::DataFlow,
            source: 0,
            target: 1,
            self_attention: None,
            data_flow: None,
            residual: None,
        };
        assert_eq!(empty.layer(), 0);
    }

    /// Tensors passed to the `*_with_*` constructors are readable via `message()`,
    /// and `set_message` really replaces the stored tensor.
    #[test]
    fn test_tensor_message_passing() {
        let message = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let mut sa_edge = GraphEdge::self_attention_with_message(0, 1, 0.8, 2, 5, message.clone());
        assert!(sa_edge.message().is_some());
        assert_eq!(sa_edge.message().unwrap().shape(), &[2, 2]);

        let df_edge = GraphEdge::data_flow_with_message(
            10,
            20,
            DataFlowOp::InputToAttention,
            3,
            message.clone(),
        );
        assert!(df_edge.message().is_some());

        let res_edge = GraphEdge::residual_with_tensor(5, 15, 7, SkipType::PreNorm, message);
        assert!(res_edge.message().is_some());

        // The message was already `Some`, so check the shape to prove it was replaced.
        let new_message = DenseTensor::from_vec(vec![5.0, 6.0], vec![2]);
        assert!(sa_edge.set_message(new_message));
        assert_eq!(sa_edge.message().unwrap().shape(), &[2]);
    }

    /// `set_message` routes to the payload selected by `edge_type` and reports
    /// `false` when that payload slot is empty.
    #[test]
    fn test_set_message_dispatch() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0], vec![2]);

        let mut df_edge = GraphEdge::data_flow(0, 1, DataFlowOp::LayerToLayer, 1);
        assert!(df_edge.message().is_none());
        assert!(df_edge.set_message(tensor.clone()));
        assert!(df_edge.get_data_flow().unwrap().message().is_some());

        let mut res_edge = GraphEdge::residual(0, 1, 1, SkipType::PostNorm);
        assert!(res_edge.set_message(tensor.clone()));
        assert!(res_edge.get_residual().unwrap().residual().is_some());

        let mut empty = GraphEdge {
            edge_type: GraphEdgeType::SelfAttention,
            source: 0,
            target: 1,
            self_attention: None,
            data_flow: None,
            residual: None,
        };
        assert!(!empty.set_message(tensor));
        assert!(empty.message().is_none());
    }

    /// QKV accessors and the scaled dot-product score on a fully populated edge.
    #[test]
    fn test_qkv_and_attention_score() {
        let q = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let edge = GraphEdge::self_attention_with_qkv(0, 1, 0.8, 2, 5, q.clone(), q.clone(), q);

        assert!(edge.has_qkv());
        assert!(edge.key_proj().is_some());
        assert!(edge.value_proj().is_some());
        let (query, key, value) = edge.get_qkv();
        assert!(query.is_some() && key.is_some() && value.is_some());

        // dot(q, q) = 1 + 4 + 9 + 16 = 30; sqrt(4) = 2; score = 15.
        let score = edge.compute_attention_score(4.0).unwrap();
        assert!((score - 15.0).abs() < 1e-12);

        let plain = GraphEdge::self_attention(0, 1, 0.8, 2, 5);
        assert!(!plain.has_qkv());
        assert!(plain.compute_attention_score(4.0).is_none());

        let df_edge = GraphEdge::data_flow(0, 1, DataFlowOp::InputToFFN, 1);
        let (query, key, value) = df_edge.get_qkv();
        assert!(query.is_none() && key.is_none() && value.is_none());
        assert!(df_edge.key_proj().is_none());
        assert!(df_edge.value_proj().is_none());
    }
}