converge_knowledge/learning/
gnn.rs1use rand::Rng;
4
/// A single graph-neural-network layer that mixes a node embedding with an
/// attention-weighted average of its neighbours' embeddings.
pub struct GnnLayer {
    /// Width of the embeddings the layer works on; also the length of every
    /// vector the layer returns.
    input_dim: usize,

    /// Hidden width supplied at construction. It sizes `weights` and the
    /// initialisation range but is not read again afterwards, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    hidden_dim: usize,

    /// Number of attention heads the embedding dimensions are split into when
    /// neighbours are scored.
    num_heads: usize,

    /// Flat parameter buffer of `input_dim * hidden_dim` values, read by the
    /// linear transform with a row stride of `input_dim`.
    weights: Vec<f32>,

    /// Flat buffer of `num_heads * input_dim` values created at construction.
    /// No method visible in this file reads it, hence the `dead_code`
    /// allowance.
    #[allow(dead_code)]
    attention_weights: Vec<f32>,
}
27
28impl GnnLayer {
29 pub fn new(input_dim: usize, hidden_dim: usize, num_heads: usize) -> Self {
31 let mut rng = rand::thread_rng();
32
33 let scale = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
35
36 let weights: Vec<f32> = (0..input_dim * hidden_dim)
37 .map(|_| rng.gen_range(-scale..scale))
38 .collect();
39
40 let attention_weights: Vec<f32> = (0..num_heads * input_dim)
41 .map(|_| rng.gen_range(-scale..scale))
42 .collect();
43
44 Self {
45 input_dim,
46 hidden_dim,
47 num_heads,
48 weights,
49 attention_weights,
50 }
51 }
52
53 pub fn forward(
55 &self,
56 node_embedding: &[f32],
57 neighbor_embeddings: &[Vec<f32>],
58 edge_weights: &[f32],
59 ) -> Vec<f32> {
60 if neighbor_embeddings.is_empty() {
61 return self.linear_transform(node_embedding);
63 }
64
65 let attention_scores = self.compute_attention(node_embedding, neighbor_embeddings);
67
68 let combined_weights: Vec<f32> = attention_scores
70 .iter()
71 .zip(edge_weights.iter())
72 .map(|(a, e)| a * e)
73 .collect();
74
75 let weight_sum: f32 = combined_weights.iter().sum();
77 let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
78 combined_weights.iter().map(|w| w / weight_sum).collect()
79 } else {
80 vec![1.0 / neighbor_embeddings.len() as f32; neighbor_embeddings.len()]
81 };
82
83 let mut aggregated = vec![0.0f32; self.input_dim];
85 for (neighbor, &weight) in neighbor_embeddings.iter().zip(normalized_weights.iter()) {
86 for (i, &val) in neighbor.iter().enumerate() {
87 if i < self.input_dim {
88 aggregated[i] += val * weight;
89 }
90 }
91 }
92
93 let combined: Vec<f32> = node_embedding
95 .iter()
96 .zip(aggregated.iter())
97 .map(|(n, a)| 0.5 * n + 0.5 * a)
98 .collect();
99
100 let transformed = self.linear_transform(&combined);
102
103 transformed.into_iter().map(|x| x.max(0.0)).collect()
105 }
106
107 fn compute_attention(&self, query: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
109 let head_dim = self.input_dim / self.num_heads;
110
111 let scores: Vec<f32> = keys
112 .iter()
113 .map(|key| {
114 let mut score = 0.0f32;
115 for h in 0..self.num_heads {
116 let start = h * head_dim;
117 let end = (start + head_dim).min(query.len()).min(key.len());
118
119 let dot: f32 = query[start..end]
120 .iter()
121 .zip(key[start..end].iter())
122 .map(|(q, k)| q * k)
123 .sum();
124
125 score += dot / (head_dim as f32).sqrt();
126 }
127 score / self.num_heads as f32
128 })
129 .collect();
130
131 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
133 let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
134 let sum_exp: f32 = exp_scores.iter().sum();
135
136 exp_scores.iter().map(|e| e / sum_exp).collect()
137 }
138
139 fn linear_transform(&self, input: &[f32]) -> Vec<f32> {
141 let mut output = vec![0.0f32; self.input_dim];
142
143 for i in 0..self.input_dim {
144 for j in 0..input.len().min(self.input_dim) {
145 let weight_idx = i * self.input_dim + j;
146 if weight_idx < self.weights.len() {
147 output[i] += input[j] * self.weights[weight_idx];
148 }
149 }
150 }
151
152 output
153 }
154
155 pub fn update(&mut self, query: &[f32], target_score: f32, learning_rate: f32) {
157 for (i, weight) in self.weights.iter_mut().enumerate() {
159 let query_idx = i % query.len().max(1);
160 let gradient = query.get(query_idx).unwrap_or(&0.0) * (target_score - 1.0);
161 *weight += learning_rate * gradient * 0.01;
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// With neighbours present, the output keeps the layer's input width.
    #[test]
    fn test_gnn_layer_forward() {
        let gnn = GnnLayer::new(64, 128, 4);

        let node_embedding = [0.5_f32; 64];
        let neighbor_embeddings = [vec![0.3_f32; 64], vec![0.7_f32; 64]];
        let weights_per_edge = [0.8_f32, 0.6];

        let result = gnn.forward(&node_embedding, &neighbor_embeddings, &weights_per_edge);

        assert_eq!(result.len(), 64);
    }

    /// Without neighbours, the output still has the layer's input width.
    #[test]
    fn test_gnn_layer_no_neighbors() {
        let gnn = GnnLayer::new(32, 64, 2);

        let node_embedding = [0.5_f32; 32];
        let result = gnn.forward(&node_embedding, &[], &[]);

        assert_eq!(result.len(), 32);
    }
}
192}