use std::collections::HashMap;
/// Kind of relation an edge expresses between two nodes of a code graph.
///
/// Used as the `HashMap` key for the per-edge-type weight vectors held by
/// `CodeMPNNLayer`, which is why it is `Copy + Eq + Hash`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CodeEdgeType {
    /// Control-flow edge.
    ControlFlow,
    /// Data-flow edge.
    DataFlow,
    /// Syntax-tree parent/child edge.
    AstChild,
    /// Type-annotation edge.
    TypeAnnotation,
    /// Ownership edge.
    Ownership,
    /// Call edge.
    Call,
    /// Return edge.
    Return,
}
/// A vertex of a code graph: an identifier, a dense feature vector and a type label.
#[derive(Clone, Debug)]
pub struct CodeGraphNode {
    /// Identifier supplied by the caller. `CodeGraph::add_node` neither checks nor
    /// rewrites it; graph lookups use the insertion index instead.
    pub id: usize,
    /// Dense feature vector; its length is reported by [`CodeGraphNode::dim`].
    pub features: Vec<f64>,
    /// Free-form type label for the node.
    pub node_type: String,
}

impl CodeGraphNode {
    /// Builds a node from its identifier, features and anything convertible to a
    /// `String` type label.
    #[must_use]
    pub fn new(id: usize, features: Vec<f64>, node_type: impl Into<String>) -> Self {
        let node_type: String = node_type.into();
        Self {
            node_type,
            features,
            id,
        }
    }

    /// Number of entries in the feature vector.
    #[must_use]
    pub fn dim(&self) -> usize {
        Vec::len(&self.features)
    }
}
/// A typed connection between two node indices of a code graph, optionally carrying
/// its own feature vector.
#[derive(Clone, Debug)]
pub struct CodeGraphEdge {
    /// Index of one endpoint in the owning graph's node list.
    pub source: usize,
    /// Index of the other endpoint in the owning graph's node list.
    pub target: usize,
    /// Relation kind; selects which per-type weight vector the MPNN layer applies.
    pub edge_type: CodeEdgeType,
    /// Optional per-edge features; `None` unless attached via
    /// [`with_features`](Self::with_features).
    pub features: Option<Vec<f64>>,
}

impl CodeGraphEdge {
    /// Creates a featureless edge of kind `edge_type` between `source` and `target`.
    #[must_use]
    pub fn new(source: usize, target: usize, edge_type: CodeEdgeType) -> Self {
        let features = None;
        Self {
            source,
            target,
            edge_type,
            features,
        }
    }

    /// Builder-style setter that attaches `features`, replacing any previous vector.
    #[must_use]
    pub fn with_features(self, features: Vec<f64>) -> Self {
        Self {
            features: Some(features),
            ..self
        }
    }
}
/// A code graph: nodes with dense features joined by typed edges.
///
/// Edges are kept in insertion order and also indexed per node, so
/// [`neighbors`](Self::neighbors) reports every edge from both of its endpoints.
#[derive(Debug, Clone, Default)]
pub struct CodeGraph {
    nodes: Vec<CodeGraphNode>,
    edges: Vec<CodeGraphEdge>,
    /// `adj_list[v]` holds `(other_endpoint, edge_index)` for every edge incident to `v`.
    adj_list: Vec<Vec<(usize, usize)>>,
}

impl CodeGraph {
    /// Creates an empty graph.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `node` and returns its index.
    ///
    /// The node's own `id` field is not consulted. The returned index is what edges,
    /// [`node`](Self::node) and [`neighbors`](Self::neighbors) refer to.
    pub fn add_node(&mut self, node: CodeGraphNode) -> usize {
        let index = self.nodes.len();
        self.nodes.push(node);
        self.adj_list.push(Vec::new());
        index
    }

    /// Appends `edge` and indexes it under both of its endpoints.
    ///
    /// The edge is always stored, so it appears in [`edges`](Self::edges) and counts
    /// toward [`num_edges`](Self::num_edges). It is indexed under an endpoint only if
    /// that endpoint's node already exists. Adding the node later does not backfill
    /// the index. A self-loop (`source == target`) is indexed once.
    pub fn add_edge(&mut self, edge: CodeGraphEdge) {
        let edge_idx = self.edges.len();
        let (source, target) = (edge.source, edge.target);
        self.edges.push(edge);
        if source < self.adj_list.len() {
            self.adj_list[source].push((target, edge_idx));
        }
        // Without the `target != source` guard a self-loop would be listed twice in
        // its own node's neighbour list and weigh double in neighbour aggregation.
        if target != source && target < self.adj_list.len() {
            self.adj_list[target].push((source, edge_idx));
        }
    }

    /// Number of nodes added so far.
    #[must_use]
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges added so far, including any that reference missing nodes.
    #[must_use]
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }

    /// The node at insertion index `idx`, or `None` if out of range.
    #[must_use]
    pub fn node(&self, idx: usize) -> Option<&CodeGraphNode> {
        self.nodes.get(idx)
    }

    /// `(neighbour_index, edge_index)` pairs for the edges incident to `node_idx`.
    ///
    /// Returns an empty slice for an unknown `node_idx`.
    #[must_use]
    pub fn neighbors(&self, node_idx: usize) -> &[(usize, usize)] {
        self.adj_list.get(node_idx).map_or(&[], Vec::as_slice)
    }

    /// All nodes, in insertion order.
    #[must_use]
    pub fn nodes(&self) -> &[CodeGraphNode] {
        &self.nodes
    }

    /// All edges, in insertion order.
    #[must_use]
    pub fn edges(&self) -> &[CodeGraphEdge] {
        &self.edges
    }
}
/// One message-passing layer over a `CodeGraph`.
///
/// For every node the layer does three things:
/// 1. It builds one message per neighbour. The message multiplies the neighbour's
///    first `hidden_dim` features element-wise by the weight vector of the connecting
///    edge's [`CodeEdgeType`].
/// 2. It averages those messages. A node with no neighbours receives all zeros.
/// 3. It adds the average to the node's own features component-wise, truncates or
///    zero-pads the result to `out_dim`, and applies ReLU.
///
/// `hidden_dim` is the midpoint of `in_dim` and `out_dim`. All weights come from a
/// deterministic generator keyed by `seed`, so layers with equal dimensions and seed
/// are identical.
///
/// NOTE(review): `message_weights` and `update_weights` are initialised but not read
/// by [`forward`](Self::forward). Only the per-edge-type weights affect the output.
#[derive(Debug, Clone)]
pub struct CodeMPNNLayer {
    in_dim: usize,
    out_dim: usize,
    hidden_dim: usize,
    /// One length-`hidden_dim` weight vector per edge type, each independently seeded.
    edge_type_weights: HashMap<CodeEdgeType, Vec<f64>>,
    /// `in_dim * hidden_dim` weights (currently unused by `forward`).
    message_weights: Vec<f64>,
    /// `hidden_dim * out_dim` weights (currently unused by `forward`).
    update_weights: Vec<f64>,
    seed: u64,
}

impl CodeMPNNLayer {
    /// Seed used by [`new`](Self::new) until overridden with [`with_seed`](Self::with_seed).
    const DEFAULT_SEED: u64 = 42;

    /// Every edge type; each one gets its own weight vector.
    const EDGE_TYPES: [CodeEdgeType; 7] = [
        CodeEdgeType::ControlFlow,
        CodeEdgeType::DataFlow,
        CodeEdgeType::AstChild,
        CodeEdgeType::TypeAnnotation,
        CodeEdgeType::Ownership,
        CodeEdgeType::Call,
        CodeEdgeType::Return,
    ];

    /// Generator stream ids. Each weight tensor gets its own stream so no two tensors
    /// share a sequence.
    const MESSAGE_STREAM: u64 = 0;
    const UPDATE_STREAM: u64 = 1;
    const FIRST_EDGE_STREAM: u64 = 2;

    /// Creates a layer mapping `in_dim`-wide features to `out_dim`-wide features,
    /// initialised with the default seed.
    #[must_use]
    pub fn new(in_dim: usize, out_dim: usize) -> Self {
        let unseeded = Self {
            in_dim,
            out_dim,
            hidden_dim: usize::midpoint(in_dim, out_dim),
            edge_type_weights: HashMap::new(),
            message_weights: Vec::new(),
            update_weights: Vec::new(),
            seed: Self::DEFAULT_SEED,
        };
        unseeded.with_seed(Self::DEFAULT_SEED)
    }

    /// Re-initialises every weight tensor from `seed`.
    ///
    /// This includes the per-edge-type weights, which are the ones `forward` reads.
    /// Different seeds therefore produce different outputs.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        let hidden_dim = self.hidden_dim;
        self.seed = seed;
        self.edge_type_weights = Self::EDGE_TYPES
            .iter()
            .zip(Self::FIRST_EDGE_STREAM..)
            .map(|(&edge_type, stream)| {
                let weights = Self::init_weights(hidden_dim, Self::stream_seed(seed, stream));
                (edge_type, weights)
            })
            .collect();
        self.message_weights = Self::init_weights(
            self.in_dim * hidden_dim,
            Self::stream_seed(seed, Self::MESSAGE_STREAM),
        );
        self.update_weights = Self::init_weights(
            hidden_dim * self.out_dim,
            Self::stream_seed(seed, Self::UPDATE_STREAM),
        );
        self
    }

    /// Derives the generator seed for one weight tensor from the layer seed.
    ///
    /// Stream 0 maps to `seed` itself. The odd multiplier (the 64-bit golden-ratio
    /// constant) keeps consecutive stream ids far apart.
    fn stream_seed(seed: u64, stream: u64) -> u64 {
        seed ^ stream.wrapping_mul(0x9E37_79B9_7F4A_7C15)
    }

    /// Produces `size` deterministic pseudo-random weights in `[-scale, scale]`, where
    /// `scale = sqrt(2 / size)`.
    ///
    /// Uses a 64-bit linear congruential generator. The high 32 bits of each state are
    /// mapped to `[-1, 1]`.
    fn init_weights(size: usize, seed: u64) -> Vec<f64> {
        let scale = (2.0 / size as f64).sqrt();
        let mut weights = Vec::with_capacity(size);
        let mut hash = seed;
        for _ in 0..size {
            hash = hash.wrapping_mul(0x5851_f42d_4c95_7f2d).wrapping_add(1);
            let val = ((hash >> 32) as f64) / f64::from(u32::MAX) * 2.0 - 1.0;
            weights.push(val * scale);
        }
        weights
    }

    /// Message sent along an edge of kind `edge_type` by a node with `source_features`.
    ///
    /// The result has length `hidden_dim`. Entries past the end of `source_features`
    /// stay zero.
    fn compute_message(&self, source_features: &[f64], edge_type: CodeEdgeType) -> Vec<f64> {
        // `with_seed` inserts a vector for every variant, so this lookup cannot miss.
        let edge_weights = self
            .edge_type_weights
            .get(&edge_type)
            .expect("Edge type not found");
        let mut message = vec![0.0; self.hidden_dim];
        for ((slot, &feature), &weight) in message
            .iter_mut()
            .zip(source_features)
            .zip(edge_weights)
        {
            *slot = feature * weight;
        }
        message
    }

    /// Element-wise mean of `messages`, truncated to `hidden_dim`.
    ///
    /// Returns all zeros when there are no messages.
    fn aggregate_messages(&self, messages: &[Vec<f64>]) -> Vec<f64> {
        let mut aggregated = vec![0.0; self.hidden_dim];
        if messages.is_empty() {
            return aggregated;
        }
        for msg in messages {
            for (acc, &val) in aggregated.iter_mut().zip(msg) {
                *acc += val;
            }
        }
        let count = messages.len() as f64;
        for val in &mut aggregated {
            *val /= count;
        }
        aggregated
    }

    /// `ReLU(node_features + aggregated)` over the first `out_dim` components.
    ///
    /// Missing components on either side count as zero.
    fn update_features(&self, node_features: &[f64], aggregated: &[f64]) -> Vec<f64> {
        (0..self.out_dim)
            .map(|i| {
                let own = node_features.get(i).copied().unwrap_or(0.0);
                let incoming = aggregated.get(i).copied().unwrap_or(0.0);
                (own + incoming).max(0.0)
            })
            .collect()
    }

    /// Runs one round of message passing.
    ///
    /// Returns one `out_dim`-wide vector per node, in node-index order.
    #[must_use]
    pub fn forward(&self, graph: &CodeGraph) -> Vec<Vec<f64>> {
        graph
            .nodes()
            .iter()
            .enumerate()
            .map(|(node_idx, node)| {
                // Neighbours whose index has no node are skipped.
                let messages: Vec<Vec<f64>> = graph
                    .neighbors(node_idx)
                    .iter()
                    .filter_map(|&(neighbor_idx, edge_idx)| {
                        let neighbor = graph.node(neighbor_idx)?;
                        let edge_type = graph.edges()[edge_idx].edge_type;
                        Some(self.compute_message(&neighbor.features, edge_type))
                    })
                    .collect();
                let aggregated = self.aggregate_messages(&messages);
                self.update_features(&node.features, &aggregated)
            })
            .collect()
    }

    /// Input feature width this layer was built for.
    #[must_use]
    pub fn in_dim(&self) -> usize {
        self.in_dim
    }

    /// Output feature width produced by [`forward`](Self::forward).
    #[must_use]
    pub fn out_dim(&self) -> usize {
        self.out_dim
    }
}
/// A stack of [`CodeMPNNLayer`]s applied in sequence over one graph topology.
#[derive(Debug, Clone)]
pub struct CodeMPNN {
    layers: Vec<CodeMPNNLayer>,
}

impl CodeMPNN {
    /// Builds one layer per consecutive pair in `dims`.
    ///
    /// For example, `[8, 16, 4]` gives layers `8 -> 16` and `16 -> 4`.
    ///
    /// # Panics
    ///
    /// Panics if `dims` has fewer than two entries.
    #[must_use]
    pub fn new(dims: &[usize]) -> Self {
        assert!(dims.len() >= 2, "Need at least input and output dimensions");
        let layers = dims
            .windows(2)
            .map(|pair| CodeMPNNLayer::new(pair[0], pair[1]))
            .collect();
        Self { layers }
    }

    /// Runs every layer in order.
    ///
    /// Each layer's output replaces the node features for the next layer, while the
    /// edges stay unchanged. Returns one vector per node, in node-index order.
    #[must_use]
    pub fn forward(&self, graph: &CodeGraph) -> Vec<Vec<f64>> {
        // Clone once so the topology (edges + adjacency) is built a single time. Only
        // the node features change between layers.
        let mut working = graph.clone();
        for layer in &self.layers {
            let next = layer.forward(&working);
            // `nodes` is private to `CodeGraph` but visible here: both types live in
            // this module.
            for (node, features) in working.nodes.iter_mut().zip(next) {
                node.features = features;
            }
        }
        working.nodes.into_iter().map(|node| node.features).collect()
    }

    /// Number of layers in the stack.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Output width of the final layer, or `0` if there are no layers.
    #[must_use]
    pub fn out_dim(&self) -> usize {
        self.layers.last().map_or(0, CodeMPNNLayer::out_dim)
    }
}
pub mod pooling {
    //! Graph-level readouts that collapse per-node embeddings into one vector.
    //!
    //! Every function returns an empty vector for an empty input. The result is as
    //! wide as the widest embedding, so ragged inputs no longer panic. A component an
    //! embedding lacks contributes nothing to that component.

    /// Length of the longest embedding, or `0` for an empty input.
    fn widest(embeddings: &[Vec<f64>]) -> usize {
        embeddings.iter().map(Vec::len).max().unwrap_or(0)
    }

    /// Component-wise mean over all embeddings.
    ///
    /// Each component is divided by the total number of embeddings, so an embedding
    /// that lacks a component counts as zero there.
    #[must_use]
    pub fn mean_pool(embeddings: &[Vec<f64>]) -> Vec<f64> {
        let count = embeddings.len() as f64;
        // An empty input yields an empty sum, so the division never runs with count == 0.
        sum_pool(embeddings)
            .into_iter()
            .map(|total| total / count)
            .collect()
    }

    /// Component-wise maximum over all embeddings.
    ///
    /// NaN values never win the `>` comparison. A component whose values are all NaN
    /// therefore stays at negative infinity.
    #[must_use]
    pub fn max_pool(embeddings: &[Vec<f64>]) -> Vec<f64> {
        let mut result = vec![f64::NEG_INFINITY; widest(embeddings)];
        for emb in embeddings {
            for (best, &val) in result.iter_mut().zip(emb) {
                if val > *best {
                    *best = val;
                }
            }
        }
        result
    }

    /// Component-wise sum over all embeddings.
    #[must_use]
    pub fn sum_pool(embeddings: &[Vec<f64>]) -> Vec<f64> {
        let mut result = vec![0.0; widest(embeddings)];
        for emb in embeddings {
            for (acc, &val) in result.iter_mut().zip(emb) {
                *acc += val;
            }
        }
        result
    }
}
#[cfg(test)]
#[path = "mpnn_tests.rs"]
mod tests;