Trait MessagePassing 

pub trait MessagePassing {
    // Required methods
    fn message(
        &self,
        x_src: &Tensor,
        x_tgt: &Tensor,
        edge_index: &AdjacencyMatrix,
    ) -> Tensor;
    fn aggregate(
        &self,
        messages: &Tensor,
        edge_index: &AdjacencyMatrix,
        num_nodes: usize,
    ) -> Tensor;
    fn update(&self, x: &Tensor, aggregated: &Tensor) -> Tensor;

    // Provided method
    fn propagate(&self, x: &Tensor, edge_index: &AdjacencyMatrix) -> Tensor { ... }
}

Message Passing Neural Network base trait.

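Defines the generic message-passing framework that underlies all GNN layers: `message` computes a message along each edge, `aggregate` combines the messages arriving at each node, `update` produces the new node representations, and `propagate` runs the full forward pass.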
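Written out, one message-passing step for node i usually takes the form below, where N(i) is the set of source nodes with an edge into i. Reading φ as `message` (with x_j taken from `x_src` and x_i from `x_tgt`), ⊕ as `aggregate`, and γ as `update` is an interpretation of the method names, not something this page states.

```latex
\mathbf{x}_i' = \gamma\left(\mathbf{x}_i,\; \bigoplus_{j \in \mathcal{N}(i)} \phi\left(\mathbf{x}_j, \mathbf{x}_i\right)\right)
```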

Required Methods

fn message(&self, x_src: &Tensor, x_tgt: &Tensor, edge_index: &AdjacencyMatrix) -> Tensor

Compute the messages sent from source nodes to target nodes along the edges in `edge_index`.
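A minimal sketch of what a `message` body might compute, assuming `x_src` and `x_tgt` are full node-feature matrices (one row per node) and that the adjacency can be read as a list of directed `(source, target)` edges. The crate's real `Tensor` and `AdjacencyMatrix` are not shown on this page, so both types below are hypothetical stand-ins, and the difference message `x_src[j] - x_tgt[i]` is only one illustrative choice.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: one row of features per node or edge.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor(pub Vec<Vec<f32>>);

/// Hypothetical stand-in for `AdjacencyMatrix`: directed edges as (source, target) pairs.
pub struct AdjacencyMatrix {
    pub edges: Vec<(usize, usize)>,
}

/// Sketch of a `message` body: for each edge (j -> i), emit x_src[j] - x_tgt[i].
/// The result has one row per edge, in the same order as `edge_index.edges`.
pub fn message(x_src: &Tensor, x_tgt: &Tensor, edge_index: &AdjacencyMatrix) -> Tensor {
    let rows: Vec<Vec<f32>> = edge_index
        .edges
        .iter()
        .map(|&(src, tgt)| {
            // Element-wise difference between the source row and the target row.
            x_src.0[src]
                .iter()
                .zip(&x_tgt.0[tgt])
                .map(|(s, t)| s - t)
                .collect::<Vec<f32>>()
        })
        .collect();
    Tensor(rows)
}
```

Emitting one row per edge keeps `message` independent of how the messages are later combined, which is the job of `aggregate`.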

fn aggregate(&self, messages: &Tensor, edge_index: &AdjacencyMatrix, num_nodes: usize) -> Tensor

Aggregate the incoming messages for each of the `num_nodes` nodes.
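A sketch of sum aggregation, assuming `messages` holds one row per edge in the same order as the edge list and that each message is added into the row of its target node. `Tensor` and `AdjacencyMatrix` are hypothetical stand-ins for the crate's types; mean or max aggregation would follow the same scatter pattern with a different reduction.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: one row of features per node or edge.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor(pub Vec<Vec<f32>>);

/// Hypothetical stand-in for `AdjacencyMatrix`: directed edges as (source, target) pairs.
pub struct AdjacencyMatrix {
    pub edges: Vec<(usize, usize)>,
}

/// Sketch of an `aggregate` body: scatter-add each edge's message into its target node.
/// Nodes without incoming edges keep an all-zero row.
pub fn aggregate(messages: &Tensor, edge_index: &AdjacencyMatrix, num_nodes: usize) -> Tensor {
    // Feature width is taken from the first message; an edgeless graph yields zero-width rows.
    let dim = messages.0.first().map_or(0, |row| row.len());
    let mut out = vec![vec![0.0_f32; dim]; num_nodes];
    for (msg, &(_src, tgt)) in messages.0.iter().zip(&edge_index.edges) {
        for (acc, v) in out[tgt].iter_mut().zip(msg) {
            *acc += *v;
        }
    }
    Tensor(out)
}
```

Passing `num_nodes` explicitly lets the output include nodes that receive no messages at all.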

fn update(&self, x: &Tensor, aggregated: &Tensor) -> Tensor

Update node representations based on aggregated messages.
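One possible `update`, sketched under the assumption that `x` and `aggregated` have the same shape (one row per node): add the aggregate to the current features and apply a ReLU. The residual-plus-ReLU rule and the stand-in `Tensor` type are illustrative choices, not taken from this crate; a learned layer would usually also apply trainable weights at this step.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: one row of features per node.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor(pub Vec<Vec<f32>>);

/// Sketch of an `update` body: new_x = relu(x + aggregated), computed row by row.
pub fn update(x: &Tensor, aggregated: &Tensor) -> Tensor {
    let rows = x
        .0
        .iter()
        .zip(&aggregated.0)
        .map(|(x_row, agg_row)| {
            x_row
                .iter()
                .zip(agg_row)
                // Residual sum followed by ReLU (clamp negatives to zero).
                .map(|(a, b)| (a + b).max(0.0))
                .collect::<Vec<f32>>()
        })
        .collect::<Vec<Vec<f32>>>();
    Tensor(rows)
}
```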

Provided Methods

fn propagate(&self, x: &Tensor, edge_index: &AdjacencyMatrix) -> Tensor

Run a full message-passing forward pass for node features `x` over the edges in `edge_index`.
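The default body of `propagate` is elided on this page. A plausible shape, assumed here rather than confirmed, is to call `message` with `x` as both source and target features, `aggregate` the result over `x`'s rows, and pass both to `update`. The sketch mirrors the trait with hypothetical stand-in types and a toy `SumConv` layer (also hypothetical) so the composition runs end to end.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: one row of features per node or edge.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor(pub Vec<Vec<f32>>);

/// Hypothetical stand-in for `AdjacencyMatrix`: directed edges as (source, target) pairs.
pub struct AdjacencyMatrix {
    pub edges: Vec<(usize, usize)>,
}

/// Local mirror of the trait, with an assumed default `propagate`.
pub trait MessagePassing {
    fn message(&self, x_src: &Tensor, x_tgt: &Tensor, edge_index: &AdjacencyMatrix) -> Tensor;
    fn aggregate(&self, messages: &Tensor, edge_index: &AdjacencyMatrix, num_nodes: usize) -> Tensor;
    fn update(&self, x: &Tensor, aggregated: &Tensor) -> Tensor;

    /// Assumed default: message -> aggregate -> update, using `x` as source and target features.
    fn propagate(&self, x: &Tensor, edge_index: &AdjacencyMatrix) -> Tensor {
        let messages = self.message(x, x, edge_index);
        let aggregated = self.aggregate(&messages, edge_index, x.0.len());
        self.update(x, &aggregated)
    }
}

/// Hypothetical toy layer: copy source features, sum them per target, add to the old features.
pub struct SumConv;

impl MessagePassing for SumConv {
    fn message(&self, x_src: &Tensor, _x_tgt: &Tensor, edge_index: &AdjacencyMatrix) -> Tensor {
        // One message per edge: the source node's feature row.
        Tensor(edge_index.edges.iter().map(|&(src, _)| x_src.0[src].clone()).collect())
    }

    fn aggregate(&self, messages: &Tensor, edge_index: &AdjacencyMatrix, num_nodes: usize) -> Tensor {
        // Sum the messages arriving at each target node.
        let dim = messages.0.first().map_or(0, |row| row.len());
        let mut out = vec![vec![0.0_f32; dim]; num_nodes];
        for (msg, &(_src, tgt)) in messages.0.iter().zip(&edge_index.edges) {
            for (acc, v) in out[tgt].iter_mut().zip(msg) {
                *acc += *v;
            }
        }
        Tensor(out)
    }

    fn update(&self, x: &Tensor, aggregated: &Tensor) -> Tensor {
        let rows = x
            .0
            .iter()
            .zip(&aggregated.0)
            .map(|(x_row, agg_row)| {
                // Pad a short (e.g. zero-width, edgeless) aggregate row with zeros.
                x_row
                    .iter()
                    .zip(agg_row.iter().chain(std::iter::repeat(&0.0_f32)))
                    .map(|(a, b)| a + b)
                    .collect::<Vec<f32>>()
            })
            .collect::<Vec<Vec<f32>>>();
        Tensor(rows)
    }
}
```

With `x = [[1], [2], [3]]` and edges `0 -> 1`, `2 -> 1`, `1 -> 0`, `SumConv.propagate` yields `[[3], [6], [3]]`: node 1 receives 1 + 3, node 0 receives 2, and node 2 receives nothing. Keeping `propagate` as a provided method means an implementor only writes the three required pieces.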

Implementors