// converge_knowledge/learning/gnn.rs
//! GNN layer implementation inspired by ruvector-gnn.

use rand::Rng;

/// A simplified GNN layer for embedding transformation.
///
/// This implements a message-passing mechanism where node embeddings
/// are updated based on their neighbors, similar to the ruvector GNN layer.
pub struct GnnLayer {
    /// Input (and output) dimensionality of node embeddings.
    input_dim: usize,

    /// Hidden dimensions (reserved for future multi-layer expansion).
    #[allow(dead_code)]
    hidden_dim: usize,

    /// Number of attention heads used when scoring neighbors.
    num_heads: usize,

    /// Linear transformation weights, row-major. Allocated with
    /// `input_dim * hidden_dim` entries in `new`, but `linear_transform`
    /// reads them as an `input_dim x input_dim` matrix.
    weights: Vec<f32>,

    /// Attention weights (reserved for future explicit attention-head reads).
    #[allow(dead_code)]
    attention_weights: Vec<f32>,
}

28impl GnnLayer {
29    /// Create a new GNN layer.
30    pub fn new(input_dim: usize, hidden_dim: usize, num_heads: usize) -> Self {
31        let mut rng = rand::thread_rng();
32
33        // Xavier initialization
34        let scale = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
35
36        let weights: Vec<f32> = (0..input_dim * hidden_dim)
37            .map(|_| rng.gen_range(-scale..scale))
38            .collect();
39
40        let attention_weights: Vec<f32> = (0..num_heads * input_dim)
41            .map(|_| rng.gen_range(-scale..scale))
42            .collect();
43
44        Self {
45            input_dim,
46            hidden_dim,
47            num_heads,
48            weights,
49            attention_weights,
50        }
51    }
52
53    /// Forward pass through the GNN layer.
54    pub fn forward(
55        &self,
56        node_embedding: &[f32],
57        neighbor_embeddings: &[Vec<f32>],
58        edge_weights: &[f32],
59    ) -> Vec<f32> {
60        if neighbor_embeddings.is_empty() {
61            // No neighbors: just apply linear transformation
62            return self.linear_transform(node_embedding);
63        }
64
65        // Compute attention scores for each neighbor
66        let attention_scores = self.compute_attention(node_embedding, neighbor_embeddings);
67
68        // Combine attention with edge weights
69        let combined_weights: Vec<f32> = attention_scores
70            .iter()
71            .zip(edge_weights.iter())
72            .map(|(a, e)| a * e)
73            .collect();
74
75        // Normalize weights
76        let weight_sum: f32 = combined_weights.iter().sum();
77        let normalized_weights: Vec<f32> = if weight_sum > 0.0 {
78            combined_weights.iter().map(|w| w / weight_sum).collect()
79        } else {
80            vec![1.0 / neighbor_embeddings.len() as f32; neighbor_embeddings.len()]
81        };
82
83        // Aggregate neighbor messages
84        let mut aggregated = vec![0.0f32; self.input_dim];
85        for (neighbor, &weight) in neighbor_embeddings.iter().zip(normalized_weights.iter()) {
86            for (i, &val) in neighbor.iter().enumerate() {
87                if i < self.input_dim {
88                    aggregated[i] += val * weight;
89                }
90            }
91        }
92
93        // Combine with node embedding (skip connection)
94        let combined: Vec<f32> = node_embedding
95            .iter()
96            .zip(aggregated.iter())
97            .map(|(n, a)| 0.5 * n + 0.5 * a)
98            .collect();
99
100        // Apply linear transformation
101        let transformed = self.linear_transform(&combined);
102
103        // Apply ReLU activation
104        transformed.into_iter().map(|x| x.max(0.0)).collect()
105    }
106
107    /// Compute attention scores for neighbors.
108    fn compute_attention(&self, query: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
109        let head_dim = self.input_dim / self.num_heads;
110
111        let scores: Vec<f32> = keys
112            .iter()
113            .map(|key| {
114                let mut score = 0.0f32;
115                for h in 0..self.num_heads {
116                    let start = h * head_dim;
117                    let end = (start + head_dim).min(query.len()).min(key.len());
118
119                    let dot: f32 = query[start..end]
120                        .iter()
121                        .zip(key[start..end].iter())
122                        .map(|(q, k)| q * k)
123                        .sum();
124
125                    score += dot / (head_dim as f32).sqrt();
126                }
127                score / self.num_heads as f32
128            })
129            .collect();
130
131        // Softmax
132        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
133        let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
134        let sum_exp: f32 = exp_scores.iter().sum();
135
136        exp_scores.iter().map(|e| e / sum_exp).collect()
137    }
138
139    /// Apply linear transformation.
140    fn linear_transform(&self, input: &[f32]) -> Vec<f32> {
141        let mut output = vec![0.0f32; self.input_dim];
142
143        for i in 0..self.input_dim {
144            for j in 0..input.len().min(self.input_dim) {
145                let weight_idx = i * self.input_dim + j;
146                if weight_idx < self.weights.len() {
147                    output[i] += input[j] * self.weights[weight_idx];
148                }
149            }
150        }
151
152        output
153    }
154
155    /// Update weights based on feedback.
156    pub fn update(&mut self, query: &[f32], target_score: f32, learning_rate: f32) {
157        // Simplified gradient update
158        for (i, weight) in self.weights.iter_mut().enumerate() {
159            let query_idx = i % query.len().max(1);
160            let gradient = query.get(query_idx).unwrap_or(&0.0) * (target_score - 1.0);
161            *weight += learning_rate * gradient * 0.01;
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// A node with two weighted neighbors must come back at input size.
    #[test]
    fn test_gnn_layer_forward() {
        let layer = GnnLayer::new(64, 128, 4);

        let node = vec![0.5_f32; 64];
        let neighbors: Vec<Vec<f32>> = vec![vec![0.3; 64], vec![0.7; 64]];
        let edges = vec![0.8_f32, 0.6];

        let out = layer.forward(&node, &neighbors, &edges);

        assert_eq!(out.len(), 64);
    }

    /// The no-neighbor fast path must also preserve the embedding size.
    #[test]
    fn test_gnn_layer_no_neighbors() {
        let layer = GnnLayer::new(32, 64, 2);

        let out = layer.forward(&vec![0.5_f32; 32], &[], &[]);

        assert_eq!(out.len(), 32);
    }
}