use ghostflow_core::Tensor;
use std::collections::HashMap;

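/// A graph over `num_nodes` nodes with dense per-node features and optional
/// per-edge features. Edges are directed `(src, dst)` pairs; an adjacency map
/// is built once in `new` for constant-time neighbor lookups.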
#[derive(Debug, Clone)]
pub struct Graph {
    pub num_nodes: usize,
    pub num_edges: usize,
    pub edges: Vec<(usize, usize)>,
    pub node_features: Tensor,
    pub edge_features: Option<Tensor>,
    adjacency: HashMap<usize, Vec<usize>>,
}

impl Graph {
    pub fn new(edges: Vec<(usize, usize)>, node_features: Tensor) -> Self {
        let num_nodes = node_features.dims()[0];
        let num_edges = edges.len();

        let mut adjacency = HashMap::new();
        for &(src, dst) in &edges {
            adjacency.entry(src).or_insert_with(Vec::new).push(dst);
        }

        Graph {
            num_nodes,
            num_edges,
            edges,
            node_features,
            edge_features: None,
            adjacency,
        }
    }

    pub fn with_edge_features(mut self, edge_features: Tensor) -> Self {
        self.edge_features = Some(edge_features);
        self
    }

    pub fn neighbors(&self, node: usize) -> &[usize] {
        self.adjacency.get(&node).map(|v| v.as_slice()).unwrap_or(&[])
    }

    pub fn degree(&self, node: usize) -> usize {
        self.neighbors(node).len()
    }

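    /// Dense symmetrically normalized adjacency with self-loops:
    /// A_hat = D^{-1/2} (A + I) D^{-1/2}, where D is the row-degree matrix of
    /// A + I. Note that for a directed edge list this normalizes by row
    /// degrees on both sides; supply both directions of each edge when the
    /// graph is undirected.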
    pub fn normalized_adjacency(&self) -> Tensor {
        let mut adj_data = vec![0.0f32; self.num_nodes * self.num_nodes];

        // Self-loops on the diagonal.
        for i in 0..self.num_nodes {
            adj_data[i * self.num_nodes + i] = 1.0;
        }
        for &(src, dst) in &self.edges {
            adj_data[src * self.num_nodes + dst] = 1.0;
        }

        // Row degrees of A + I.
        let mut degrees = vec![0.0f32; self.num_nodes];
        for i in 0..self.num_nodes {
            for j in 0..self.num_nodes {
                degrees[i] += adj_data[i * self.num_nodes + j];
            }
        }

        // Scale each nonzero entry by 1 / sqrt(deg(i) * deg(j)).
        for i in 0..self.num_nodes {
            for j in 0..self.num_nodes {
                let idx = i * self.num_nodes + j;
                if adj_data[idx] > 0.0 {
                    adj_data[idx] /= (degrees[i] * degrees[j]).sqrt();
                }
            }
        }

        Tensor::from_slice(&adj_data, &[self.num_nodes, self.num_nodes]).unwrap()
    }
}

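/// Graph convolution layer in the style of Kipf & Welling (2017):
/// H' = act(A_hat * H * W + b), with A_hat the normalized adjacency from
/// `Graph::normalized_adjacency`.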
pub struct GCNLayer {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl GCNLayer {
    pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> Self {
        let weight = Tensor::randn(&[in_features, out_features]);
        // `bias.is_some()` already records whether a bias is used, so no
        // separate flag field is stored.
        let bias = if use_bias {
            Some(Tensor::zeros(&[out_features]))
        } else {
            None
        };

        GCNLayer { weight, bias }
    }

    pub fn forward(&self, graph: &Graph, activation: bool) -> Tensor {
        let adj = graph.normalized_adjacency();
        let features = &graph.node_features;

        // H * W
        let hw = features.matmul(&self.weight).unwrap();

        // A_hat * (H * W)
        let mut output = adj.matmul(&hw).unwrap();

        if let Some(ref bias) = self.bias {
            output = output.add(bias).unwrap();
        }

        if activation {
            output = output.relu();
        }

        output
    }
}

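/// Graph attention layer in the spirit of GAT (Veličković et al., 2018).
/// `attention_coefficients` scores a node pair with a shared linear form
/// followed by LeakyReLU (negative slope 0.2); `forward` currently applies
/// only the per-head linear projection.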
pub struct GATLayer {
    weight: Tensor,
    // One scoring vector shared across all heads; per-head attention
    // parameters would be a straightforward extension.
    attention_weight: Tensor,
    bias: Option<Tensor>,
    num_heads: usize,
    dropout: f32,
}

impl GATLayer {
    pub fn new(in_features: usize, out_features: usize, num_heads: usize, dropout: f32) -> Self {
        let weight = Tensor::randn(&[in_features, out_features * num_heads]);
        let attention_weight = Tensor::randn(&[2 * out_features, 1]);
        let bias = Some(Tensor::zeros(&[out_features * num_heads]));

        GATLayer {
            weight,
            attention_weight,
            bias,
            num_heads,
            dropout,
        }
    }

    fn attention_coefficients(&self, node_i: &Tensor, node_j: &Tensor) -> f32 {
        // Concatenate the two node embeddings into a 1 x 2d row vector so the
        // matmul against the [2d, 1] scoring vector is well-formed.
        let data_i = node_i.data_f32();
        let data_j = node_j.data_f32();
        let mut concat_data = Vec::with_capacity(data_i.len() + data_j.len());
        concat_data.extend_from_slice(&data_i);
        concat_data.extend_from_slice(&data_j);

        let concat = Tensor::from_slice(&concat_data, &[1, concat_data.len()]).unwrap();
        let score = concat.matmul(&self.attention_weight).unwrap();

        // LeakyReLU with negative slope 0.2, as in the GAT paper.
        let data = score.data_f32();
        let alpha = 0.2;
        if data[0] > 0.0 {
            data[0]
        } else {
            alpha * data[0]
        }
    }

    pub fn forward(&self, graph: &Graph) -> Tensor {
        let features = &graph.node_features;

        // Per-head linear projection: [num_nodes, out_features * num_heads].
        let transformed = features.matmul(&self.weight).unwrap();

        // Attention scoring, neighborhood softmax, and dropout are not yet
        // wired in; this returns the biased projection only.
        if let Some(ref bias) = self.bias {
            transformed.add(bias).unwrap()
        } else {
            transformed
        }
    }
}

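/// GraphSAGE layer (Hamilton et al., 2017): each node combines a projection
/// of its own features with a projection of an aggregate of its neighbors',
/// h_v' = h_v * W_self + AGG({h_u : u in N(v)}) * W_neigh.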
pub struct GraphSAGELayer {
    weight_self: Tensor,
    weight_neighbor: Tensor,
    aggregator: AggregatorType,
}

#[derive(Debug, Clone, Copy)]
pub enum AggregatorType {
    Mean,
    Pool,
    LSTM,
}

impl GraphSAGELayer {
    pub fn new(in_features: usize, out_features: usize, aggregator: AggregatorType) -> Self {
        let weight_self = Tensor::randn(&[in_features, out_features]);
        let weight_neighbor = Tensor::randn(&[in_features, out_features]);

        GraphSAGELayer {
            weight_self,
            weight_neighbor,
            aggregator,
        }
    }

    /// Combines a non-empty set of neighbor feature tensors into one tensor.
    /// `forward` substitutes a zero tensor when a node has no neighbors, so
    /// this is never called with an empty slice (the original empty-slice
    /// check indexed element 0 of the empty slice and would have panicked).
    fn aggregate(&self, neighbor_features: &[Tensor]) -> Tensor {
        debug_assert!(!neighbor_features.is_empty());

        match self.aggregator {
            AggregatorType::Mean => {
                let sum = neighbor_features
                    .iter()
                    .fold(Tensor::zeros(neighbor_features[0].dims()), |acc, feat| {
                        acc.add(feat).unwrap()
                    });

                sum.div_scalar(neighbor_features.len() as f32)
            }
            // Placeholder: a full implementation would take an element-wise
            // max over projected neighbor features.
            AggregatorType::Pool => neighbor_features[0].clone(),
            // Placeholder: a full implementation would run an LSTM over a
            // permutation of the neighbor sequence.
            AggregatorType::LSTM => neighbor_features[0].clone(),
        }
    }

    pub fn forward(&self, graph: &Graph) -> Tensor {
        let features = &graph.node_features;
        let num_nodes = graph.num_nodes;
        let feature_dim = features.dims()[1];
        // Read the feature buffer once rather than once per element.
        let feature_data = features.data_f32();

        let mut output_data = Vec::new();

        for node in 0..num_nodes {
            let start = node * feature_dim;
            let node_feat =
                Tensor::from_slice(&feature_data[start..start + feature_dim], &[1, feature_dim])
                    .unwrap();

            let neighbors = graph.neighbors(node);
            let neighbor_feats: Vec<Tensor> = neighbors
                .iter()
                .map(|&n| {
                    let start = n * feature_dim;
                    Tensor::from_slice(
                        &feature_data[start..start + feature_dim],
                        &[1, feature_dim],
                    )
                    .unwrap()
                })
                .collect();

            // Nodes without neighbors aggregate to a zero vector.
            let aggregated = if !neighbor_feats.is_empty() {
                self.aggregate(&neighbor_feats)
            } else {
                Tensor::zeros(&[1, feature_dim])
            };

            // h_v' = h_v * W_self + AGG(neighbors) * W_neigh
            let self_part = node_feat.matmul(&self.weight_self).unwrap();
            let neighbor_part = aggregated.matmul(&self.weight_neighbor).unwrap();
            let combined = self_part.add(&neighbor_part).unwrap();

            output_data.extend(combined.data_f32());
        }

        let out_dim = self.weight_self.dims()[1];
        Tensor::from_slice(&output_data, &[num_nodes, out_dim]).unwrap()
    }
}

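/// Message Passing Neural Network layer (Gilmer et al., 2017). Holds the
/// message and update weight matrices; `forward` is currently an identity
/// pass-through and does not yet run message passing.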
pub struct MPNNLayer {
    message_fn: Tensor,
    update_fn: Tensor,
}

impl MPNNLayer {
    pub fn new(node_dim: usize, edge_dim: usize, hidden_dim: usize) -> Self {
        let message_fn = Tensor::randn(&[node_dim + edge_dim, hidden_dim]);
        let update_fn = Tensor::randn(&[node_dim + hidden_dim, node_dim]);

        MPNNLayer {
            message_fn,
            update_fn,
        }
    }

    pub fn forward(&self, graph: &Graph) -> Tensor {
        // Placeholder: returns node features unchanged. A full implementation
        // would compute messages with `message_fn` over edges and apply
        // `update_fn` to each node state.
        graph.node_features.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_creation() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        assert_eq!(graph.num_nodes, 3);
        assert_eq!(graph.num_edges, 3);
        assert_eq!(graph.neighbors(0).len(), 1);
    }

    #[test]
    fn test_gcn_layer() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let gcn = GCNLayer::new(4, 8, true);
        let output = gcn.forward(&graph, true);

        assert_eq!(output.dims(), &[3, 8]);
    }

    #[test]
    fn test_graphsage_layer() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let sage = GraphSAGELayer::new(4, 8, AggregatorType::Mean);
        let output = sage.forward(&graph);

        assert_eq!(output.dims(), &[3, 8]);
    }
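
    // Additional shape checks for the remaining layers. These assume only
    // behavior visible in this file: GATLayer::forward performs the per-head
    // projection (so the output width is out_features * num_heads), and
    // MPNNLayer::forward is currently a pass-through.
    #[test]
    fn test_gat_layer_projection_shape() {
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let gat = GATLayer::new(4, 8, 2, 0.1);
        let output = gat.forward(&graph);

        assert_eq!(output.dims(), &[3, 16]);
    }

    #[test]
    fn test_normalized_adjacency_values() {
        // A 3-cycle with self-loops has every row degree equal to 2, so each
        // nonzero entry of the normalized matrix is 1 / sqrt(2 * 2) = 0.5.
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let adj = graph.normalized_adjacency();
        assert_eq!(adj.dims(), &[3, 3]);
        let data = adj.data_f32();
        assert!((data[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_mpnn_passthrough() {
        let edges = vec![(0, 1), (1, 2)];
        let features = Tensor::randn(&[3, 4]);
        let graph = Graph::new(edges, features);

        let mpnn = MPNNLayer::new(4, 2, 8);
        let output = mpnn.forward(&graph);

        assert_eq!(output.dims(), &[3, 4]);
    }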
}