dol 0.8.1

DOL (Design Ontology Language) - A declarative specification language for ontology-first development
gen MessagePassingLayer<NodeDim: u64, EdgeDim: u64, HiddenDim: u64> {
  // Learnable weight matrices for message and update functions.
  // Each is a flattened single-layer MLP in the layout read by mlp_forward:
  // HiddenDim rows of input_dim weights (row-major), followed by HiddenDim biases.
  has message_mlp_weights: Array<f64>
  has update_mlp_weights: Array<f64>

  // Aggregation strategy: "sum", "mean", or "max"
  has aggregation: string

  // Whether to incorporate edge features in message computation
  has use_edge_features: bool

  // Constraint: aggregation must be a valid type
  rule valid_aggregation {
    this.aggregation == "sum" ||
    this.aggregation == "mean" ||
    this.aggregation == "max"
  }

  // Law: The layer is S_n equivariant (permutation equivariance)
  // For any permutation P in S_n: forward(P * G) = P * forward(G)
  // NOTE(review): N is not one of this gen's type parameters; presumably it
  // denotes the graph's node count — confirm how the law is meant to bind it.
  law permutation_equivariance {
    forall perm: PermutationGroup<N>. forall graph: Graph<Array<f64>, Array<f64>>.
      let permuted_input = graph.permute_nodes(perm)
      let permuted_output = this.forward(graph).permute_nodes(perm)
      this.forward(permuted_input) == permuted_output
  }

  // Forward pass: implements EquivariantLayer<PermutationGroup>
  // Runs one round of message passing. For every node i, a message is computed
  // from each j in graph.neighbors(i) (source j -> destination i), the messages
  // are combined by aggregate(), and update() merges the result with i's own
  // features. Edges, adjacency and counts are copied through unchanged.
  fun forward(graph: Graph<Array<f64>, Array<f64>>) -> Graph<Array<f64>, Array<f64>> {
    let n = graph.node_count
    let new_features = (0..n).map(|i| {
      let neighbors = graph.neighbors(i)
      let messages = neighbors.map(|j| {
        let src = graph.get_node(j)
        let dst = graph.get_node(i)
        let edge = if this.use_edge_features {
          // When edge (j, i) carries no data, substitute an EdgeDim-sized zero
          // vector rather than an empty array: mlp_forward derives its weight
          // and bias offsets from the input length, so a shorter input would
          // silently read the message weights at the wrong positions.
          graph.get_edge_data(j, i).unwrap_or(this.edge_zeros())
        } else {
          // Ignored by message() when edge features are disabled
          []
        }
        this.message(src, dst, edge)
      })
      let aggregated = this.aggregate(messages)
      this.update(graph.get_node(i), aggregated)
    }).collect()
    return Graph {
      nodes: new_features,
      edges: graph.edges,
      adjacency: graph.adjacency,
      node_count: n,
      edge_count: graph.edge_count
    }
  }

  // Return trainable parameters: implements EquivariantLayer<PermutationGroup>
  // Order: [message_mlp_weights, update_mlp_weights]
  fun parameters() -> Vec<Array<f64>> {
    return [this.message_mlp_weights, this.update_mlp_weights]
  }

  // Set trainable parameters: implements EquivariantLayer<PermutationGroup>
  // Expects the same order produced by parameters(): params[0] is the message
  // MLP, params[1] the update MLP. NOTE(review): no length check is performed.
  fun set_parameters(params: Vec<Array<f64>>) -> () {
    this.message_mlp_weights = params[0]
    this.update_mlp_weights = params[1]
  }

  // Compute message from source node j to destination node i
  // M(h_src, h_dst, e_ji) = MLP_msg(concat(h_src, h_dst, e_ji))
  // Edge features are appended only when use_edge_features is true; otherwise
  // edge_features is ignored. Presumably node features have length NodeDim and
  // edge features EdgeDim, so the MLP input is 2*NodeDim (+ EdgeDim) — confirm.
  fun message(src_features: Array<f64>, dst_features: Array<f64>, edge_features: Array<f64>) -> Array<f64> {
    let input = if this.use_edge_features {
      src_features.concat(dst_features).concat(edge_features)
    } else {
      src_features.concat(dst_features)
    }
    return this.mlp_forward(input, this.message_mlp_weights)
  }

  // Aggregate incoming messages using sum, mean, or max
  // This is where permutation invariance is enforced: all three reductions
  // ignore the order of the messages. A node with no incoming messages gets
  // the HiddenDim zero vector, as does an aggregation name outside the rule.
  fun aggregate(messages: Vec<Array<f64>>) -> Array<f64> {
    if messages.length == 0 {
      return this.zeros()
    }
    match this.aggregation {
      "sum" => this.aggregate_sum(messages),
      "mean" => this.aggregate_mean(messages),
      "max" => this.aggregate_max(messages),
      otherwise => this.zeros()
    }
  }

  // Sum aggregation: element-wise sum of all message vectors
  // (the zero vector for an empty list)
  fun aggregate_sum(messages: Vec<Array<f64>>) -> Array<f64> {
    return messages.fold(this.zeros(), |acc, m| this.vec_add(acc, m))
  }

  // Mean aggregation: element-wise mean of all message vectors
  // An empty list yields the zero vector instead of dividing by zero,
  // consistent with aggregate().
  fun aggregate_mean(messages: Vec<Array<f64>>) -> Array<f64> {
    if messages.length == 0 {
      return this.zeros()
    }
    let sum = this.aggregate_sum(messages)
    let count = messages.length as f64
    return (0..HiddenDim).map(|i| sum[i] / count).collect()
  }

  // Max aggregation: element-wise max of all message vectors
  // Seeded with the first message (max is idempotent, so folding it in again
  // is harmless). An empty list yields the zero vector instead of indexing
  // messages[0] out of bounds, consistent with aggregate().
  fun aggregate_max(messages: Vec<Array<f64>>) -> Array<f64> {
    if messages.length == 0 {
      return this.zeros()
    }
    return messages.fold(messages[0].clone(), |acc, m| this.vec_max(acc, m))
  }

  // Helper: Element-wise vector addition over the first HiddenDim entries
  fun vec_add(a: Array<f64>, b: Array<f64>) -> Array<f64> {
    return (0..HiddenDim).map(|i| a[i] + b[i]).collect()
  }

  // Helper: Element-wise vector max over the first HiddenDim entries
  fun vec_max(a: Array<f64>, b: Array<f64>) -> Array<f64> {
    return (0..HiddenDim).map(|i| if a[i] > b[i] { a[i] } else { b[i] }).collect()
  }

  // Helper: Create zero vector of HiddenDim size
  fun zeros() -> Array<f64> {
    return (0..HiddenDim).map(|i| 0.0).collect()
  }

  // Helper: Create zero vector of EdgeDim size, used as the edge-feature
  // placeholder when an edge has no stored data
  fun edge_zeros() -> Array<f64> {
    return (0..EdgeDim).map(|i| 0.0).collect()
  }

  // Update node features using aggregated messages
  // h'_i = MLP_update(concat(h_i, aggregated_i))
  fun update(node_features: Array<f64>, aggregated: Array<f64>) -> Array<f64> {
    let input = node_features.concat(aggregated)
    return this.mlp_forward(input, this.update_mlp_weights)
  }

  // Helper: Simple MLP forward pass (single layer with ReLU)
  // Computes ReLU(W * input + b) with HiddenDim outputs.
  // weights must hold at least HiddenDim * input_dim + HiddenDim values:
  //   W[i][j] at index i * input_dim + j  (row-major, one row per output)
  //   b[i]    at index HiddenDim * input_dim + i
  // Because input_dim is taken from input.length, callers must always pass
  // inputs of the length the weights were laid out for.
  fun mlp_forward(input: Array<f64>, weights: Array<f64>) -> Array<f64> {
    // Weights are flattened: [W11, W12, ..., b1, b2, ...]
    let input_dim = input.length
    let output = (0..HiddenDim).map(|i| {
      let sum = (0..input_dim).fold(0.0, |acc, j| {
        acc + input[j] * weights[i * input_dim + j]
      })
      let bias = weights[HiddenDim * input_dim + i]
      let activated = sum + bias
      // ReLU activation
      if activated > 0.0 { activated } else { 0.0 }
    }).collect()
    return output
  }
}

docs {
  MessagePassingLayer<NodeDim, EdgeDim, HiddenDim> implements a permutation-equivariant
  neural network layer for graph-structured data. This is the foundational building
  block of Graph Neural Networks (GNNs) and the Message Passing Neural Network (MPNN)
  framework introduced by Gilmer et al. (2017).

  THE MESSAGE-AGGREGATE-UPDATE PARADIGM:
  Message passing operates in three distinct phases:
  1. MESSAGE: For each edge (j, i), compute a message m_ji from source j to target i
     using the features of both nodes and optionally the edge features.
  2. AGGREGATE: For each node i, combine all incoming messages using a permutation-
     invariant function (sum, mean, or max).
  3. UPDATE: Transform the aggregated messages with the original node features to
     produce updated node representations.

  TYPE PARAMETERS:
  - NodeDim: u64 - The dimension of input node feature vectors
  - EdgeDim: u64 - The dimension of edge feature vectors (0 if not using edges)
  - HiddenDim: u64 - The dimension of output node feature vectors (hidden layer size)

  FIELDS:
  - message_mlp_weights: Learnable parameters for the message function MLP
  - update_mlp_weights: Learnable parameters for the update function MLP
  - aggregation: string specifying aggregation type ("sum", "mean", or "max")
  - use_edge_features: Whether to incorporate edge features in message computation

  WHY AGGREGATION PROVIDES PERMUTATION EQUIVARIANCE:
  The key insight is that sum, mean, and max are all permutation-invariant functions:
  they produce the same output regardless of the order of their inputs. When we
  aggregate messages m_j1, m_j2, ..., m_jk from neighbors of node i, the result
  is the same no matter how we order those neighbors. This invariance at the
  aggregation step, combined with the per-node update, yields equivariance at
  the layer level: permuting nodes permutes the output features accordingly.

  EXPRESSIVENESS OF AGGREGATION FUNCTIONS:
  Different aggregation functions have different expressive power for distinguishing
  graphs (multisets of neighboring features):
  - SUM: Most expressive - can distinguish multisets by counting elements
  - MEAN: Less expressive - cannot distinguish {1,1} from {1}
  - MAX: Least expressive - only captures the maximum element

  The expressiveness hierarchy is: SUM > MEAN > MAX. This is formalized by the
  Weisfeiler-Lehman (WL) graph isomorphism test correspondence:
  - SUM aggregation, combined with injective message/update functions, matches
    1-WL test expressiveness (GIN architecture)
  - MEAN/MAX are strictly weaker than 1-WL

  CONNECTION TO WEISFEILER-LEHMAN TEST:
  The 1-WL test iteratively refines node colors by hashing the multiset of neighbor
  colors. Any message passing GNN is at most as powerful as 1-WL for distinguishing
  non-isomorphic graphs. With SUM aggregation and injective message/update functions
  (as in GIN), it matches 1-WL exactly. This theoretical limit explains why standard
  GNNs cannot distinguish certain graph pairs (e.g., two non-isomorphic d-regular
  graphs with the same number of nodes).

  NOTABLE GNN ARCHITECTURES AND THEIR AGGREGATIONS:
  - GCN (Kipf & Welling 2017): symmetric degree-normalized weighted SUM
    (1/sqrt(d_i d_j) weights), closely related to MEAN aggregation
  - GraphSAGE (Hamilton 2017): MEAN, MAX, or LSTM aggregation
  - GIN (Xu et al. 2019): SUM aggregation for maximum expressiveness
  - GAT (Velickovic 2018): Attention-weighted SUM aggregation
  - MPNN (Gilmer 2017): General framework with customizable aggregation

  IMPLEMENTS EQUIVARIANTLAYER<PERMUTATIONGROUP>:
  This gen satisfies the EquivariantLayer trait for the symmetric group S_n:
  - forward(graph): Applies one round of message passing
  - parameters(): Returns the learnable weight arrays
  - set_parameters(params): Updates the weight arrays

  THE PERMUTATION EQUIVARIANCE LAW:
  For any permutation P in S_n (represented as a PermutationGroup<N>):
    forward(P * G) = P * forward(G)
  where (P * G) denotes permuting the node indices of graph G. This ensures that
  relabeling nodes produces correspondingly relabeled outputs, which is essential
  because node ordering in a graph is arbitrary.

  MATHEMATICAL FORMULATION:
  Given a graph G = (V, E) with node features X in R^{n x d_in}:
    h_i^{(l+1)} = UPDATE(h_i^{(l)}, AGGREGATE_{j in N(i)} MESSAGE(h_j^{(l)}, h_i^{(l)}, e_ji))

  where:
  - h_i^{(l)} is the feature vector of node i at layer l
  - N(i) is the neighborhood of node i
  - MESSAGE, AGGREGATE, UPDATE are learnable functions

  STACKING MESSAGE PASSING LAYERS:
  Multiple MessagePassingLayer instances can be composed to increase the receptive
  field. After k layers, each node aggregates information from its k-hop neighborhood.
  The composition of equivariant layers is equivariant, so deep GNNs remain S_n-
  equivariant regardless of depth.

  OVER-SMOOTHING AND OVER-SQUASHING:
  Deep message passing networks face challenges:
  - Over-smoothing: Node features become indistinguishable after many layers
  - Over-squashing: Information bottleneck when aggregating from large neighborhoods
  Solutions include skip connections, normalization, and attention mechanisms.

  RELATED GENES:
  - Graph<N, E>: The input domain type with S_n symmetry
  - PermutationGroup<N>: The symmetry group this layer respects
  - AttentionLayer<NodeDim, HiddenDim>: Attention-based message passing variant
  - InvariantLayer<PermutationGroup>: For graph-level (invariant) predictions

  REFERENCES:
  - Gilmer et al., "Neural Message Passing for Quantum Chemistry" (ICML 2017)
  - Xu et al., "How Powerful are Graph Neural Networks?" (ICLR 2019)
  - Bronstein et al., "Geometric Deep Learning" (2021) - GDL Blueprint
  - Morris et al., "Weisfeiler and Leman Go Neural" (AAAI 2019)
}