ghostflow_nn/gnn.rs

//! Graph Neural Networks (GNN) module
//!
//! Implements various GNN architectures:
//! - Graph Convolutional Networks (GCN)
//! - Graph Attention Networks (GAT)
//! - GraphSAGE
//! - Message Passing Neural Networks (MPNN)

use ghostflow_core::Tensor;
use std::collections::HashMap;

/// Graph structure for GNN operations
#[derive(Debug, Clone)]
pub struct Graph {
    /// Number of nodes
    pub num_nodes: usize,
    /// Number of edges
    pub num_edges: usize,
    /// Edge list: (source, target) pairs
    pub edges: Vec<(usize, usize)>,
    /// Node features [num_nodes, feature_dim]
    pub node_features: Tensor,
    /// Edge features [num_edges, edge_feature_dim] (optional)
    pub edge_features: Option<Tensor>,
    /// Adjacency list (sparse representation)
    adjacency: HashMap<usize, Vec<usize>>,
}

impl Graph {
    /// Create a new graph from an edge list and node features
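    ///
    /// # Example
    ///
    /// A minimal sketch using `Tensor::randn` as elsewhere in this module
    /// (marked `ignore` since it depends on the `ghostflow_core` backend):
    ///
    /// ```ignore
    /// let edges = vec![(0, 1), (1, 2), (2, 0)];
    /// let features = Tensor::randn(&[3, 4]); // 3 nodes, 4 features each
    /// let graph = Graph::new(edges, features);
    /// assert_eq!(graph.num_nodes, 3);
    /// assert_eq!(graph.degree(0), 1); // edges are directed
    /// ```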
    pub fn new(edges: Vec<(usize, usize)>, node_features: Tensor) -> Self {
        let num_nodes = node_features.dims()[0];
        let num_edges = edges.len();

        // Build the adjacency list. Edges are treated as directed: for an
        // undirected graph, include both (u, v) and (v, u) in the edge list.
        let mut adjacency = HashMap::new();
        for &(src, dst) in &edges {
            adjacency.entry(src).or_insert_with(Vec::new).push(dst);
        }

        Graph {
            num_nodes,
            num_edges,
            edges,
            node_features,
            edge_features: None,
            adjacency,
        }
    }

    /// Add edge features (builder-style)
    pub fn with_edge_features(mut self, edge_features: Tensor) -> Self {
        self.edge_features = Some(edge_features);
        self
    }

    /// Get the out-neighbors of a node
    pub fn neighbors(&self, node: usize) -> &[usize] {
        self.adjacency.get(&node).map(|v| v.as_slice()).unwrap_or(&[])
    }

    /// Get the out-degree of a node
    pub fn degree(&self, node: usize) -> usize {
        self.neighbors(node).len()
    }

    /// Compute the symmetrically normalized adjacency matrix used by GCN:
    /// A_norm = D̂^(-1/2) · Â · D̂^(-1/2), where Â = A + I adds self-loops
    /// and D̂ is the diagonal degree matrix of Â.
    pub fn normalized_adjacency(&self) -> Tensor {
        let mut adj_data = vec![0.0f32; self.num_nodes * self.num_nodes];

        // Build the dense adjacency matrix with self-loops: Â = A + I
        for i in 0..self.num_nodes {
            adj_data[i * self.num_nodes + i] = 1.0; // Self-loop
        }
        for &(src, dst) in &self.edges {
            adj_data[src * self.num_nodes + dst] = 1.0;
        }

        // Compute the degrees of Â (row sums)
        let mut degrees = vec![0.0f32; self.num_nodes];
        for i in 0..self.num_nodes {
            for j in 0..self.num_nodes {
                degrees[i] += adj_data[i * self.num_nodes + j];
            }
        }

        // Normalize: each nonzero entry becomes Â_ij / sqrt(d_i * d_j)
        for i in 0..self.num_nodes {
            for j in 0..self.num_nodes {
                let idx = i * self.num_nodes + j;
                if adj_data[idx] > 0.0 {
                    adj_data[idx] /= (degrees[i] * degrees[j]).sqrt();
                }
            }
        }

        Tensor::from_slice(&adj_data, &[self.num_nodes, self.num_nodes]).unwrap()
    }
}

/// Graph Convolutional Network (GCN) layer
pub struct GCNLayer {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl GCNLayer {
    /// Create a new GCN layer
    pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> Self {
        let weight = Tensor::randn(&[in_features, out_features]);
        let bias = if use_bias {
            Some(Tensor::zeros(&[out_features]))
        } else {
            None
        };

        GCNLayer { weight, bias }
    }

    /// Forward pass: H' = σ(A_norm · H · W + b), where σ is ReLU when `activation` is true
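    ///
    /// # Example
    ///
    /// A sketch of a two-layer GCN (hypothetical usage; rebuilding the graph
    /// with `Graph::new` carries the hidden features into the second layer):
    ///
    /// ```ignore
    /// let graph = Graph::new(vec![(0, 1), (1, 2), (2, 0)], Tensor::randn(&[3, 4]));
    /// let gcn1 = GCNLayer::new(4, 16, true);
    /// let gcn2 = GCNLayer::new(16, 2, true);
    ///
    /// let hidden = gcn1.forward(&graph, true); // ReLU between layers
    /// let graph2 = Graph::new(graph.edges.clone(), hidden);
    /// let logits = gcn2.forward(&graph2, false); // no activation on the output
    /// ```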
    pub fn forward(&self, graph: &Graph, activation: bool) -> Tensor {
        let adj = graph.normalized_adjacency();
        let features = &graph.node_features;

        // H · W
        let hw = features.matmul(&self.weight).unwrap();

        // A_norm · (H · W)
        let mut output = adj.matmul(&hw).unwrap();

        // Add bias
        if let Some(ref bias) = self.bias {
            output = output.add(bias).unwrap();
        }

        // Apply activation (ReLU)
        if activation {
            output = output.relu();
        }

        output
    }
}

/// Graph Attention Network (GAT) layer
pub struct GATLayer {
    weight: Tensor,
    attention_weight: Tensor,
    bias: Option<Tensor>,
    /// Number of attention heads (consumed by the full attention computation)
    num_heads: usize,
    /// Dropout rate on attention coefficients (unused in this simplified version)
    dropout: f32,
}

impl GATLayer {
    /// Create a new GAT layer with multi-head attention
    pub fn new(in_features: usize, out_features: usize, num_heads: usize, dropout: f32) -> Self {
        // Projection to `out_features` per head, concatenated across heads
        let weight = Tensor::randn(&[in_features, out_features * num_heads]);
        // Attention vector `a` applied to a concatenated pair of per-head features
        let attention_weight = Tensor::randn(&[2 * out_features, 1]);
        let bias = Some(Tensor::zeros(&[out_features * num_heads]));

        GATLayer {
            weight,
            attention_weight,
            bias,
            num_heads,
            dropout,
        }
    }

    /// Compute the attention coefficient for a pair of (already transformed)
    /// per-head node features: e_ij = LeakyReLU(a^T [W·h_i || W·h_j])
    ///
    /// Not yet wired into `forward`; kept for the full attention computation.
    fn attention_coefficients(&self, node_i: &Tensor, node_j: &Tensor) -> f32 {
        // Concatenate the two feature vectors
        let data_i = node_i.data_f32();
        let data_j = node_j.data_f32();
        let mut concat_data = Vec::with_capacity(data_i.len() + data_j.len());
        concat_data.extend_from_slice(&data_i);
        concat_data.extend_from_slice(&data_j);

        // Shape as a row vector so the matmul with `attention_weight`
        // ([2 * out_features, 1]) yields a single score
        let concat = Tensor::from_slice(&concat_data, &[1, concat_data.len()]).unwrap();
        let score = concat.matmul(&self.attention_weight).unwrap();

        // LeakyReLU with negative slope alpha = 0.2
        let data = score.data_f32();
        let alpha = 0.2;
        if data[0] > 0.0 {
            data[0]
        } else {
            alpha * data[0]
        }
    }

    /// Forward pass (simplified): linear projection plus bias.
    ///
    /// A full multi-head GAT forward would also compute e_ij for every edge,
    /// softmax-normalize the coefficients over each node's neighborhood, and
    /// aggregate neighbor features weighted by attention, per head.
    pub fn forward(&self, graph: &Graph) -> Tensor {
        let features = &graph.node_features;

        // Linear projection: H · W → [num_nodes, out_features * num_heads]
        let transformed = features.matmul(&self.weight).unwrap();

        if let Some(ref bias) = self.bias {
            transformed.add(bias).unwrap()
        } else {
            transformed
        }
    }
}

/// GraphSAGE layer (Sample and Aggregate)
pub struct GraphSAGELayer {
    weight_self: Tensor,
    weight_neighbor: Tensor,
    aggregator: AggregatorType,
}

#[derive(Debug, Clone, Copy)]
pub enum AggregatorType {
    Mean,
    Pool,
    LSTM,
}

impl GraphSAGELayer {
    /// Create a new GraphSAGE layer
    pub fn new(in_features: usize, out_features: usize, aggregator: AggregatorType) -> Self {
        let weight_self = Tensor::randn(&[in_features, out_features]);
        let weight_neighbor = Tensor::randn(&[in_features, out_features]);

        GraphSAGELayer {
            weight_self,
            weight_neighbor,
            aggregator,
        }
    }

    /// Aggregate neighbor features (callers must pass a non-empty slice)
    fn aggregate(&self, neighbor_features: &[Tensor]) -> Tensor {
        debug_assert!(
            !neighbor_features.is_empty(),
            "aggregate requires at least one neighbor"
        );

        match self.aggregator {
            AggregatorType::Mean => {
                // Mean aggregation: element-wise sum divided by the neighbor count
                let sum = neighbor_features.iter()
                    .fold(Tensor::zeros(neighbor_features[0].dims()), |acc, feat| {
                        acc.add(feat).unwrap()
                    });

                sum.div_scalar(neighbor_features.len() as f32)
            }
            AggregatorType::Pool => {
                // Max-pooling aggregation (simplified: returns the first neighbor)
                neighbor_features[0].clone()
            }
            AggregatorType::LSTM => {
                // LSTM aggregation (simplified: returns the first neighbor)
                neighbor_features[0].clone()
            }
        }
    }

    /// Forward pass: h_v' = W_self · h_v + W_neighbor · AGG({h_u : u ∈ N(v)})
    /// (the nonlinearity σ from the original formulation is left to the caller)
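    ///
    /// # Example
    ///
    /// A sketch with the mean aggregator (hypothetical usage):
    ///
    /// ```ignore
    /// let graph = Graph::new(vec![(0, 1), (1, 0)], Tensor::randn(&[2, 4]));
    /// let sage = GraphSAGELayer::new(4, 8, AggregatorType::Mean);
    /// let output = sage.forward(&graph); // [2, 8]
    /// ```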
    pub fn forward(&self, graph: &Graph) -> Tensor {
        let features = &graph.node_features;
        let feat_data = features.data_f32();
        let num_nodes = graph.num_nodes;
        let feature_dim = features.dims()[1];

        let mut output_data = Vec::new();

        for node in 0..num_nodes {
            // Slice out this node's own feature row
            let node_row = &feat_data[node * feature_dim..(node + 1) * feature_dim];
            let node_feat = Tensor::from_slice(node_row, &[1, feature_dim]).unwrap();

            // Collect neighbor feature rows
            let neighbor_feats: Vec<Tensor> = graph.neighbors(node).iter()
                .map(|&n| {
                    let row = &feat_data[n * feature_dim..(n + 1) * feature_dim];
                    Tensor::from_slice(row, &[1, feature_dim]).unwrap()
                })
                .collect();

            // Aggregate neighbors (zeros for isolated nodes)
            let aggregated = if !neighbor_feats.is_empty() {
                self.aggregate(&neighbor_feats)
            } else {
                Tensor::zeros(&[1, feature_dim])
            };

            // Combine self and neighbor information
            let self_part = node_feat.matmul(&self.weight_self).unwrap();
            let neighbor_part = aggregated.matmul(&self.weight_neighbor).unwrap();
            let combined = self_part.add(&neighbor_part).unwrap();

            output_data.extend(combined.data_f32());
        }

        let out_dim = self.weight_self.dims()[1];
        Tensor::from_slice(&output_data, &[num_nodes, out_dim]).unwrap()
    }
}

/// Message Passing Neural Network (MPNN) layer
///
/// One round of message passing computes
/// m_v = Σ_{u ∈ N(v)} M(h_u, e_uv), followed by the update h_v' = U(h_v, m_v).
pub struct MPNNLayer {
    /// Weights of the message function M: [node_dim + edge_dim, hidden_dim]
    message_fn: Tensor,
    /// Weights of the update function U: [node_dim + hidden_dim, node_dim]
    update_fn: Tensor,
}

impl MPNNLayer {
    /// Create a new MPNN layer
    pub fn new(node_dim: usize, edge_dim: usize, hidden_dim: usize) -> Self {
        let message_fn = Tensor::randn(&[node_dim + edge_dim, hidden_dim]);
        let update_fn = Tensor::randn(&[node_dim + hidden_dim, node_dim]);

        MPNNLayer {
            message_fn,
            update_fn,
        }
    }

    /// Forward pass: message passing and node update (currently a stub)
    ///
    /// The full version would compute a message per edge via `message_fn`,
    /// sum incoming messages per node, and update node states via `update_fn`.
    pub fn forward(&self, graph: &Graph) -> Tensor {
        // Simplified: returns the input features unchanged
        graph.node_features.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_creation() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        assert_eq!(graph.num_nodes, 3);
        assert_eq!(graph.num_edges, 3);
        // Edges are directed, so node 0 has exactly one out-neighbor
        assert_eq!(graph.neighbors(0).len(), 1);
    }
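
    // Worked example for the GCN normalization: with edges in both directions
    // between two nodes, Â = [[1, 1], [1, 1]] and both degrees are 2, so every
    // entry of D̂^(-1/2) · Â · D̂^(-1/2) is 1 / sqrt(2 · 2) = 0.5. Assumes
    // `data_f32` exposes the flat f32 buffer, as used elsewhere in this file.
    #[test]
    fn test_normalized_adjacency() {
        let edges = vec![(0, 1), (1, 0)];
        let features = Tensor::randn(&[2, 4]);
        let graph = Graph::new(edges, features);

        let adj = graph.normalized_adjacency();
        assert_eq!(adj.dims(), &[2, 2]);
        for &v in adj.data_f32().iter() {
            assert!((v - 0.5).abs() < 1e-6);
        }
    }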

    #[test]
    fn test_gcn_layer() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let gcn = GCNLayer::new(4, 8, true);
        let output = gcn.forward(&graph, true);

        assert_eq!(output.dims(), &[3, 8]);
    }
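
    // The simplified GAT forward only applies the linear projection (plus
    // bias), so the output shape is [num_nodes, out_features * num_heads].
    #[test]
    fn test_gat_layer() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let gat = GATLayer::new(4, 8, 2, 0.1);
        let output = gat.forward(&graph);

        assert_eq!(output.dims(), &[3, 16]);
    }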

    #[test]
    fn test_graphsage_layer() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let sage = GraphSAGELayer::new(4, 8, AggregatorType::Mean);
        let output = sage.forward(&graph);

        assert_eq!(output.dims(), &[3, 8]);
    }
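
    // Builder-style attachment of per-edge features; the MPNN stub should
    // preserve the input feature shape.
    #[test]
    fn test_edge_features_and_mpnn() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let edge_feats = Tensor::randn(&[3, 2]); // [num_edges, edge_feature_dim]
        let graph = Graph::new(edges, features).with_edge_features(edge_feats);
        assert!(graph.edge_features.is_some());

        let mpnn = MPNNLayer::new(4, 2, 16);
        let output = mpnn.forward(&graph);
        assert_eq!(output.dims(), &[3, 4]);
    }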
}