gen MessagePassingLayer<NodeDim: u64, EdgeDim: u64, HiddenDim: u64> {
  // Learnable weights for the message and update MLPs.
  // Each array is consumed by mlp_forward as a flattened single layer:
  // HiddenDim rows of input_dim weights, followed by HiddenDim biases.
  has message_mlp_weights: Array<f64>
  has update_mlp_weights: Array<f64>

  // Aggregation strategy: "sum", "mean", or "max"
  has aggregation: string

  // Whether to incorporate edge features in message computation
  has use_edge_features: bool

  // Constraint: aggregation must name one of the three supported strategies
  rule valid_aggregation {
    this.aggregation == "sum" ||
    this.aggregation == "mean" ||
    this.aggregation == "max"
  }

  // Law: the layer is S_n equivariant (permutation equivariance).
  // For any permutation P in S_n: forward(P * G) = P * forward(G)
  // NOTE(review): N is not one of this gen's type parameters
  // (NodeDim, EdgeDim, HiddenDim); presumably it is the node count of the
  // quantified graph - confirm how N is bound.
  law permutation_equivariance {
    forall perm: PermutationGroup<N>. forall graph: Graph<Array<f64>, Array<f64>>.
    let permuted_input = graph.permute_nodes(perm)
    let permuted_output = this.forward(graph).permute_nodes(perm)
    this.forward(permuted_input) == permuted_output
  }

  // Forward pass: implements EquivariantLayer<PermutationGroup>
  // For every node i: builds one message per entry of graph.neighbors(i),
  // aggregates the messages, and feeds the node's own features plus the
  // aggregate through the update MLP. Edges, adjacency and counts are copied
  // from the input graph unchanged; only node features are replaced.
  fun forward(graph: Graph<Array<f64>, Array<f64>>) -> Graph<Array<f64>, Array<f64>> {
    let n = graph.node_count
    let new_features = (0..n).map(|i| {
      // The destination node is the same for every incoming message and for
      // the update step, so it is fetched once per node.
      let dst = graph.get_node(i)
      let messages = graph.neighbors(i).map(|j| {
        let src = graph.get_node(j)
        let edge = if this.use_edge_features {
          // A missing edge record is replaced by an all-zero vector of EdgeDim
          // entries. Falling back to [] would shorten the concatenated input,
          // and mlp_forward derives its row stride from input.length, so the
          // flattened weights would be read with the wrong stride and a
          // weight would be read in place of a bias.
          // Assumes stored edge vectors have EdgeDim entries - TODO confirm.
          graph.get_edge_data(j, i).unwrap_or((0..EdgeDim).map(|k| 0.0).collect())
        } else {
          // Ignored by message() when edge features are disabled.
          []
        }
        this.message(src, dst, edge)
      })
      let aggregated = this.aggregate(messages)
      this.update(dst, aggregated)
    }).collect()
    return Graph {
      nodes: new_features,
      edges: graph.edges,
      adjacency: graph.adjacency,
      node_count: n,
      edge_count: graph.edge_count
    }
  }

  // Return trainable parameters: implements EquivariantLayer<PermutationGroup>
  // Order is [message weights, update weights]; set_parameters expects the same order.
  fun parameters() -> Vec<Array<f64>> {
    return [this.message_mlp_weights, this.update_mlp_weights]
  }

  // Set trainable parameters: implements EquivariantLayer<PermutationGroup>
  // Precondition: params holds at least two arrays, in the order returned by
  // parameters(). Indices 0 and 1 are read without a length check.
  fun set_parameters(params: Vec<Array<f64>>) -> () {
    this.message_mlp_weights = params[0]
    this.update_mlp_weights = params[1]
  }

  // Compute message from source node to destination node
  // M(h_src, h_dst, e) = MLP_msg(concat(h_src, h_dst, e))
  // The edge vector is appended only when use_edge_features is set;
  // otherwise edge_features is ignored.
  fun message(src_features: Array<f64>, dst_features: Array<f64>, edge_features: Array<f64>) -> Array<f64> {
    let input = if this.use_edge_features {
      src_features.concat(dst_features).concat(edge_features)
    } else {
      src_features.concat(dst_features)
    }
    return this.mlp_forward(input, this.message_mlp_weights)
  }

  // Aggregate incoming messages using sum, mean, or max
  // This is where permutation invariance is enforced.
  // An empty message list (a node without neighbors) yields the zero vector.
  fun aggregate(messages: Vec<Array<f64>>) -> Array<f64> {
    if messages.length == 0 {
      return this.zeros()
    }
    match this.aggregation {
      "sum" => this.aggregate_sum(messages),
      "mean" => this.aggregate_mean(messages),
      "max" => this.aggregate_max(messages),
      // Not reachable while rule valid_aggregation holds; kept as a defensive default.
      otherwise => this.zeros()
    }
  }

  // Sum aggregation: element-wise sum of all message vectors
  // Folds from the zero vector, so an empty list yields zeros.
  fun aggregate_sum(messages: Vec<Array<f64>>) -> Array<f64> {
    return messages.fold(this.zeros(), |acc, m| this.vec_add(acc, m))
  }

  // Mean aggregation: element-wise mean of all message vectors
  // Returns zeros for an empty list instead of dividing by a zero count,
  // matching what aggregate() returns for a node without neighbors.
  fun aggregate_mean(messages: Vec<Array<f64>>) -> Array<f64> {
    if messages.length == 0 {
      return this.zeros()
    }
    let sum = this.aggregate_sum(messages)
    let count = messages.length as f64
    return (0..HiddenDim).map(|i| sum[i] / count).collect()
  }

  // Max aggregation: element-wise max of all message vectors
  // Returns zeros for an empty list instead of indexing messages[0],
  // matching what aggregate() returns for a node without neighbors.
  fun aggregate_max(messages: Vec<Array<f64>>) -> Array<f64> {
    if messages.length == 0 {
      return this.zeros()
    }
    // The fold is seeded with the first message and visits it again;
    // taking the max of a vector with itself leaves the result unchanged.
    return messages.fold(messages[0].clone(), |acc, m| this.vec_max(acc, m))
  }

  // Helper: Element-wise vector addition over the first HiddenDim entries
  fun vec_add(a: Array<f64>, b: Array<f64>) -> Array<f64> {
    return (0..HiddenDim).map(|i| a[i] + b[i]).collect()
  }

  // Helper: Element-wise vector max over the first HiddenDim entries
  fun vec_max(a: Array<f64>, b: Array<f64>) -> Array<f64> {
    return (0..HiddenDim).map(|i| if a[i] > b[i] { a[i] } else { b[i] }).collect()
  }

  // Helper: Create zero vector of HiddenDim size
  fun zeros() -> Array<f64> {
    return (0..HiddenDim).map(|i| 0.0).collect()
  }

  // Update node features using aggregated messages
  // h'_i = MLP_update(concat(h_i, aggregated_i))
  fun update(node_features: Array<f64>, aggregated: Array<f64>) -> Array<f64> {
    let input = node_features.concat(aggregated)
    return this.mlp_forward(input, this.update_mlp_weights)
  }

  // Helper: Simple MLP forward pass (single layer with ReLU)
  // Weight layout (flattened): row i occupies
  // weights[i * input_dim .. i * input_dim + input_dim - 1], and the bias of
  // output i is stored at weights[HiddenDim * input_dim + i].
  // Precondition: weights holds at least HiddenDim * input_dim + HiddenDim
  // entries, where input_dim is input.length; no bounds check is performed.
  fun mlp_forward(input: Array<f64>, weights: Array<f64>) -> Array<f64> {
    // The row stride comes from the actual input length, so every caller
    // must pass an input of the width the weights were laid out for.
    let input_dim = input.length
    let output = (0..HiddenDim).map(|i| {
      let weighted_sum = (0..input_dim).fold(0.0, |acc, j| {
        acc + input[j] * weights[i * input_dim + j]
      })
      let bias = weights[HiddenDim * input_dim + i]
      let pre_activation = weighted_sum + bias
      // ReLU activation
      if pre_activation > 0.0 { pre_activation } else { 0.0 }
    }).collect()
    return output
  }
}
docs {
MessagePassingLayer<NodeDim, EdgeDim, HiddenDim> implements a permutation-equivariant
neural network layer for graph-structured data. This is the foundational building
block of Graph Neural Networks (GNNs) and the Message Passing Neural Network (MPNN)
framework introduced by Gilmer et al. (2017).
THE MESSAGE-AGGREGATE-UPDATE PARADIGM:
Message passing operates in three distinct phases:
1. MESSAGE: For each edge (j, i), compute a message m_ji from source j to target i
using the features of both nodes and optionally the edge features.
2. AGGREGATE: For each node i, combine all incoming messages using a permutation-
invariant function (sum, mean, or max).
3. UPDATE: Transform the aggregated messages with the original node features to
produce updated node representations.
TYPE PARAMETERS:
- NodeDim: u64 - The dimension of input node feature vectors
- EdgeDim: u64 - The dimension of edge feature vectors (0 if not using edges)
- HiddenDim: u64 - The dimension of output node feature vectors (hidden layer size)
FIELDS:
- message_mlp_weights: Learnable parameters for the message function MLP
- update_mlp_weights: Learnable parameters for the update function MLP
- aggregation: string specifying aggregation type ("sum", "mean", or "max")
- use_edge_features: Whether to incorporate edge features in message computation
WHY AGGREGATION PROVIDES PERMUTATION EQUIVARIANCE:
The key insight is that sum, mean, and max are all permutation-invariant functions:
they produce the same output regardless of the order of their inputs. When we
aggregate messages m_j1, m_j2, ..., m_jk from neighbors of node i, the result
is the same no matter how we order those neighbors. This invariance at the
aggregation step, combined with the per-node update, yields equivariance at
the layer level: permuting nodes permutes the output features accordingly.
EXPRESSIVENESS OF AGGREGATION FUNCTIONS:
Different aggregation functions have different expressive power for distinguishing
multisets of neighboring features (and therefore for distinguishing graphs):
- SUM: Most expressive - can distinguish multisets by counting elements
- MEAN: Less expressive - cannot distinguish {1,1} from {1}
- MAX: Least expressive - only captures the maximum element
The expressiveness hierarchy is: SUM > MEAN > MAX. This is formalized by the
Weisfeiler-Lehman (WL) graph isomorphism test correspondence:
- SUM aggregation matches 1-WL test expressiveness (GIN architecture)
- MEAN/MAX are strictly weaker than 1-WL
CONNECTION TO WEISFEILER-LEHMAN TEST:
The 1-WL test iteratively refines node colors by hashing the multiset of neighbor
colors. Message passing with SUM aggregation can match 1-WL when its message and
update functions are injective, and no message passing GNN of this form exceeds it,
so GNNs are at most as powerful as 1-WL for distinguishing non-isomorphic graphs.
This theoretical limit explains why standard GNNs cannot distinguish certain graph
pairs (e.g., non-isomorphic k-regular graphs with the same number of nodes).
NOTABLE GNN ARCHITECTURES AND THEIR AGGREGATIONS:
- GCN (Kipf & Welling 2017): Degree-normalized (mean-like) aggregation
- GraphSAGE (Hamilton et al. 2017): MEAN, MAX-pooling, or LSTM aggregation
- GIN (Xu et al. 2019): SUM aggregation for maximum expressiveness
- GAT (Velickovic et al. 2018): Attention-weighted SUM aggregation
- MPNN (Gilmer et al. 2017): General framework with customizable aggregation
IMPLEMENTS EQUIVARIANTLAYER<PERMUTATIONGROUP>:
This gen satisfies the EquivariantLayer trait for the symmetric group S_n:
- forward(graph): Applies one round of message passing
- parameters(): Returns the learnable weight arrays
- set_parameters(params): Updates the weight arrays
THE PERMUTATION EQUIVARIANCE LAW:
For any permutation P in S_n (represented as a PermutationGroup<N>):
forward(P * G) = P * forward(G)
where (P * G) denotes permuting the node indices of graph G. This ensures that
relabeling nodes produces correspondingly relabeled outputs, which is essential
because node ordering in a graph is arbitrary.
MATHEMATICAL FORMULATION:
Given a graph G = (V, E) with node features X in R^{n x d_in}:
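h_i^{(l+1)} = UPDATE(h_i^{(l)}, AGGREGATE_{j in N(i)} MESSAGE(h_j^{(l)}, h_i^{(l)}, e_ji))
where:
- h_i^{(l)} is the feature vector of node i at layer l
- N(i) is the neighborhood of node i
- e_ji is the feature vector of edge (j, i), used only when use_edge_features is set
- MESSAGE and UPDATE are learnable functions; AGGREGATE is a fixed
permutation-invariant reduction (sum, mean, or max)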
STACKING MESSAGE PASSING LAYERS:
Multiple MessagePassingLayer instances can be composed to increase the receptive
field. After k layers, each node aggregates information from its k-hop neighborhood.
The composition of equivariant layers is equivariant, so deep GNNs remain S_n-
equivariant regardless of depth.
OVER-SMOOTHING AND OVER-SQUASHING:
Deep message passing networks face challenges:
- Over-smoothing: Node features become indistinguishable after many layers
- Over-squashing: Information bottleneck when aggregating from large neighborhoods
Solutions include skip connections, normalization, and attention mechanisms.
RELATED GENES:
- Graph<N, E>: The input domain type with S_n symmetry
- PermutationGroup<N>: The symmetry group this layer respects
- AttentionLayer<NodeDim, HiddenDim>: Attention-based message passing variant
- InvariantLayer<PermutationGroup>: For graph-level (invariant) predictions
REFERENCES:
- Gilmer et al., "Neural Message Passing for Quantum Chemistry" (ICML 2017)
- Xu et al., "How Powerful are Graph Neural Networks?" (ICLR 2019)
- Bronstein et al., "Geometric Deep Learning" (2021) - GDL Blueprint
- Morris et al., "Weisfeiler and Leman Go Neural" (AAAI 2019)
}