use super::cpg::{CodePropertyGraph, CpgEdgeKind, CpgNode, CpgNodeKind};
use std::collections::HashMap;
/// Hyper-parameters for the GNN semantic scorer.
///
/// None of these fields are read by the heuristics in this file; they are
/// stored on the scorer and exposed through `GnnSemanticScorer::config`.
#[derive(Debug, Clone)]
pub struct GnnConfig {
    /// Number of GNN layers (default `3`).
    pub num_layers: usize,
    /// Hidden-layer width (default `256`).
    pub hidden_dim: usize,
    /// Dropout rate (default `0.1`).
    pub dropout: f64,
    /// Whether edge features should be used (default `true`).
    pub use_edge_features: bool,
    /// Whether attention should be used (default `true`).
    pub use_attention: bool,
    /// Output embedding width (default `128`).
    pub embedding_dim: usize,
}

impl Default for GnnConfig {
    /// Returns the baseline configuration: 3 layers, 256 hidden units,
    /// 128-wide embeddings, dropout 0.1, edge features and attention enabled.
    fn default() -> Self {
        GnnConfig {
            embedding_dim: 128,
            hidden_dim: 256,
            num_layers: 3,
            dropout: 0.1,
            use_attention: true,
            use_edge_features: true,
        }
    }
}
/// Per-node input features for the GNN.
#[derive(Debug, Clone)]
pub struct NodeFeatures {
    /// The `id` of the CPG node these features were built from.
    pub node_idx: usize,
    /// Token-level features; left empty by [`NodeFeatures::from_cpg_node`].
    pub token_features: Vec<f32>,
    /// Hand-crafted structural features (layout documented on
    /// [`NodeFeatures::from_cpg_node`]).
    pub structural_features: Vec<f32>,
    /// Type features; left empty by [`NodeFeatures::from_cpg_node`].
    pub type_features: Vec<f32>,
}

impl NodeFeatures {
    /// Builds structural features for `node`.
    ///
    /// `structural_features` is always four values:
    /// `[depth / 20, child_count / 10, span_len / 1000, kind_code / 8]`,
    /// where `span_len = location.1 - location.0` (clamped at 0) and
    /// `kind_code` is 0..=6 for Function, Variable, Call, Branch, Loop,
    /// Assignment, Return and 7 for every other kind. Token and type
    /// features are left empty.
    pub fn from_cpg_node(node: &CpgNode, depth: usize, child_count: usize) -> Self {
        // `saturating_sub` guards against a malformed span (end < start), which
        // would otherwise panic in debug builds and wrap to a huge value in release.
        let span_len = node.location.1.saturating_sub(node.location.0) as f32;
        let kind_code: u8 = match node.kind {
            CpgNodeKind::Function => 0,
            CpgNodeKind::Variable => 1,
            CpgNodeKind::Call => 2,
            CpgNodeKind::Branch => 3,
            CpgNodeKind::Loop => 4,
            CpgNodeKind::Assignment => 5,
            CpgNodeKind::Return => 6,
            _ => 7,
        };
        Self {
            node_idx: node.id,
            token_features: Vec::new(),
            structural_features: vec![
                depth as f32 / 20.0,
                child_count as f32 / 10.0,
                span_len / 1000.0,
                f32::from(kind_code) / 8.0,
            ],
            type_features: Vec::new(),
        }
    }

    /// Total length of the vector returned by [`NodeFeatures::to_vector`].
    pub fn feature_dim(&self) -> usize {
        self.token_features.len() + self.structural_features.len() + self.type_features.len()
    }

    /// Concatenates token, structural and type features, in that order.
    pub fn to_vector(&self) -> Vec<f32> {
        let mut v = Vec::with_capacity(self.feature_dim());
        v.extend_from_slice(&self.token_features);
        v.extend_from_slice(&self.structural_features);
        v.extend_from_slice(&self.type_features);
        v
    }
}
/// One-hot encoded features for a single CPG edge.
#[derive(Debug, Clone)]
pub struct EdgeFeatures {
    /// Source endpoint as supplied by the caller.
    pub source: usize,
    /// Target endpoint as supplied by the caller.
    pub target: usize,
    /// Length-6 one-hot vector describing the edge category.
    pub edge_type: Vec<f32>,
}

impl EdgeFeatures {
    /// Encodes `kind` as a length-6 one-hot vector.
    ///
    /// Slot layout: 0 = AST (child/sibling), 1 = control flow, 2 = data flow,
    /// 3 = call / argument / return, 4 = type / inheritance. Slot 5 is never
    /// set by this function.
    pub fn from_edge_kind(source: usize, target: usize, kind: &CpgEdgeKind) -> Self {
        let hot_slot = match kind {
            CpgEdgeKind::AstChild | CpgEdgeKind::AstSibling => 0,
            CpgEdgeKind::CfgNext
            | CpgEdgeKind::CfgTrue
            | CpgEdgeKind::CfgFalse
            | CpgEdgeKind::CfgBack
            | CpgEdgeKind::CfgException => 1,
            CpgEdgeKind::DfgRead
            | CpgEdgeKind::DfgWrite
            | CpgEdgeKind::DfgFlow
            | CpgEdgeKind::DfgDepends => 2,
            CpgEdgeKind::Calls | CpgEdgeKind::Argument | CpgEdgeKind::Returns => 3,
            CpgEdgeKind::HasType | CpgEdgeKind::Inherits => 4,
        };
        let mut one_hot = vec![0.0; 6];
        one_hot[hot_slot] = 1.0;
        EdgeFeatures {
            source,
            target,
            edge_type: one_hot,
        }
    }
}
/// A potential problem reported by `GnnSemanticScorer::detect_issues`.
#[derive(Debug, Clone)]
pub struct SemanticIssue {
    /// The `id` of the CPG node the issue is attached to.
    pub node_idx: usize,
    /// Category of the issue.
    pub issue_type: IssueType,
    /// Heuristic confidence score.
    pub confidence: f64,
    /// Optional human-readable remediation hint.
    pub suggestion: Option<String>,
    /// Other node ids involved in the issue, if any.
    pub related_nodes: Vec<usize>,
}

/// Categories of semantic issue that can be reported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IssueType {
    /// A variable that looks like it was confused with another one.
    VariableMisuse,
    /// A type-related error.
    TypeError,
    /// A fallible result that is not explicitly handled.
    MissingErrorHandling,
    /// A possible dereference of a null / absent value.
    NullDereference,
    /// A binding that is written but apparently never read.
    UnusedBinding,
    /// Incorrect use of an API.
    ApiMisuse,
    /// A resource that may not be released.
    ResourceLeak,
    /// Any other anomalous pattern.
    Anomaly,
}
pub struct GnnSemanticScorer {
config: GnnConfig,
node_embeddings: HashMap<usize, Vec<f32>>,
}
impl GnnSemanticScorer {
pub fn new(config: GnnConfig) -> Self {
Self {
config,
node_embeddings: HashMap::new(),
}
}
pub fn default_scorer() -> Self {
Self::new(GnnConfig::default())
}
pub fn extract_features(&self, cpg: &CodePropertyGraph) -> GnnFeatures {
let mut node_features = Vec::new();
let mut edge_features = Vec::new();
let depths = cpg.compute_depths();
let child_counts = cpg.compute_child_counts();
for node in cpg.all_nodes() {
let depth = depths.get(&node.id).copied().unwrap_or(0);
let children = child_counts.get(&node.id).copied().unwrap_or(0);
node_features.push(NodeFeatures::from_cpg_node(node, depth, children));
}
for (source, target, edge) in cpg.all_edges() {
edge_features.push(EdgeFeatures::from_edge_kind(source, target, &edge.kind));
}
GnnFeatures {
node_features,
edge_features,
num_nodes: cpg.node_count(),
num_edges: cpg.edge_count(),
}
}
pub fn score_node(&self, _cpg: &CodePropertyGraph, node_idx: usize) -> f64 {
if let Some(embedding) = self.node_embeddings.get(&node_idx) {
let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
return (magnitude / 2.0).min(1.0) as f64;
}
0.0
}
pub fn detect_issues(&self, cpg: &CodePropertyGraph) -> Vec<SemanticIssue> {
let mut issues = Vec::new();
let node_indices: HashMap<usize, petgraph::graph::NodeIndex> =
cpg.nodes().map(|(idx, node)| (node.id, idx)).collect();
for node in cpg.all_nodes() {
if node.kind == CpgNodeKind::Variable {
let node_idx = match node_indices.get(&node.id) {
Some(idx) => *idx,
None => continue,
};
let incoming_data_flow = cpg
.edges_to(node_idx)
.filter(|(_, e)| matches!(e.kind, CpgEdgeKind::DfgFlow | CpgEdgeKind::DfgWrite))
.count();
let outgoing_data_flow = cpg
.edges_from(node_idx)
.filter(|(_, e)| matches!(e.kind, CpgEdgeKind::DfgFlow | CpgEdgeKind::DfgRead))
.count();
if incoming_data_flow > 0 && outgoing_data_flow == 0 {
issues.push(SemanticIssue {
node_idx: node.id,
issue_type: IssueType::UnusedBinding,
confidence: 0.6,
suggestion: Some("Variable may be unused".to_string()),
related_nodes: vec![],
});
}
}
if node.kind == CpgNodeKind::Call {
if let Some(name) = &node.name {
if matches!(name.as_str(), "unwrap" | "expect") {
issues.push(SemanticIssue {
node_idx: node.id,
issue_type: IssueType::MissingErrorHandling,
confidence: 0.75,
suggestion: Some(format!(
"Replace `{}` with explicit error handling",
name
)),
related_nodes: vec![],
});
}
}
}
}
issues
}
pub fn variable_misuse_candidates(
&self,
cpg: &CodePropertyGraph,
node_idx: usize,
) -> Vec<(String, f64)> {
let mut node_ref = None;
for n in cpg.all_nodes() {
if n.id == node_idx {
node_ref = Some(n);
break;
}
}
let node = match node_ref {
Some(n) => n,
None => return vec![],
};
if node.kind != CpgNodeKind::Variable {
return vec![];
}
let mut candidates = Vec::new();
let node_name = node.name.clone().unwrap_or_default();
for other in cpg.all_nodes() {
if other.id == node_idx {
continue;
}
if other.kind != CpgNodeKind::Variable {
continue;
}
if let Some(name) = &other.name {
if name != &node_name {
let score = self.compute_similarity(&node_name, name);
if score > 0.3 {
candidates.push((name.clone(), score));
}
}
}
}
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
candidates.truncate(5);
candidates
}
fn compute_similarity(&self, a: &str, b: &str) -> f64 {
if a.is_empty() || b.is_empty() {
return 0.0;
}
let chars_a: Vec<char> = a.chars().collect();
let chars_b: Vec<char> = b.chars().collect();
let bigrams_a: std::collections::HashSet<_> =
chars_a.windows(2).map(|w| (w[0], w[1])).collect();
let bigrams_b: std::collections::HashSet<_> =
chars_b.windows(2).map(|w| (w[0], w[1])).collect();
if bigrams_a.is_empty() || bigrams_b.is_empty() {
return if a == b { 1.0 } else { 0.0 };
}
let intersection = bigrams_a.intersection(&bigrams_b).count();
let union = bigrams_a.union(&bigrams_b).count();
if union == 0 {
0.0
} else {
intersection as f64 / union as f64
}
}
pub fn config(&self) -> &GnnConfig {
&self.config
}
}
/// Graph-level GNN input: per-node and per-edge features plus their counts.
#[derive(Debug, Clone)]
pub struct GnnFeatures {
    /// Features for each node.
    pub node_features: Vec<NodeFeatures>,
    /// Features for each edge; `source` / `target` are used as row indices by
    /// [`GnnFeatures::to_adjacency_list`].
    pub edge_features: Vec<EdgeFeatures>,
    /// Number of nodes; sizes the adjacency list.
    pub num_nodes: usize,
    /// Number of edges.
    pub num_edges: usize,
}

impl GnnFeatures {
    /// Returns an outgoing adjacency list with `num_nodes` rows.
    ///
    /// Row `i` holds the `target` of every edge whose `source` is `i`, in
    /// `edge_features` order. Edges with either endpoint `>= num_nodes` are
    /// silently skipped.
    pub fn to_adjacency_list(&self) -> Vec<Vec<usize>> {
        let bound = self.num_nodes;
        let mut rows: Vec<Vec<usize>> = (0..bound).map(|_| Vec::new()).collect();
        self.edge_features
            .iter()
            .filter(|edge| edge.source < bound && edge.target < bound)
            .for_each(|edge| rows[edge.source].push(edge.target));
        rows
    }

    /// Returns one row per node, each the node's flattened feature vector.
    pub fn to_node_matrix(&self) -> Vec<Vec<f32>> {
        let mut matrix = Vec::with_capacity(self.node_features.len());
        for features in &self.node_features {
            matrix.push(features.to_vector());
        }
        matrix
    }
}
#[cfg(test)]
mod tests {
    use super::super::cpg::CpgEdge;
    use super::*;

    /// Builds a CPG node whose span, position and AST kind are derived from
    /// `id` and `kind`, so tests only choose an id, a kind and a name.
    fn test_node(id: usize, kind: CpgNodeKind, name: Option<&str>) -> CpgNode {
        let ast_kind = format!("{:?}", kind);
        let start = id * 10;
        CpgNode {
            id,
            kind,
            name: name.map(String::from),
            location: (start, start + 1),
            position: (id, 0),
            ast_kind,
            properties: HashMap::new(),
        }
    }

    /// Wraps `kind` in an unlabelled CPG edge.
    fn test_edge(kind: CpgEdgeKind) -> CpgEdge {
        CpgEdge { label: None, kind }
    }

    #[test]
    fn test_gnn_config_default() {
        let cfg = GnnConfig::default();
        assert_eq!(cfg.num_layers, 3);
        assert_eq!(cfg.hidden_dim, 256);
        assert!(cfg.use_attention);
    }

    #[test]
    fn test_edge_features_encoding() {
        let features = EdgeFeatures::from_edge_kind(0, 1, &CpgEdgeKind::DfgFlow);
        assert_eq!(features.source, 0);
        assert_eq!(features.target, 1);
        // DfgFlow lands in the data-flow slot, and exactly one slot is hot.
        assert_eq!(features.edge_type[2], 1.0);
        let hot_slots = features.edge_type.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(hot_slots, 1);
    }

    #[test]
    fn test_similarity_computation() {
        let scorer = GnnSemanticScorer::default_scorer();
        // Shared prefix: high bigram overlap.
        assert!(scorer.compute_similarity("count", "counter") > 0.5);
        // No shared bigrams.
        assert!(scorer.compute_similarity("foo", "bar") < 0.3);
        // Identical names.
        assert!((scorer.compute_similarity("test", "test") - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_detect_issues_uses_graph_indices_not_node_ids() {
        // Node ids (20, then 10) deliberately do not follow insertion order;
        // the variable is written but never read.
        let mut cpg = CodePropertyGraph::new();
        let writer = cpg.add_node(test_node(20, CpgNodeKind::Assignment, Some("=")));
        let unused = cpg.add_node(test_node(10, CpgNodeKind::Variable, Some("unused")));
        cpg.add_edge(writer, unused, test_edge(CpgEdgeKind::DfgWrite));

        let issues = GnnSemanticScorer::default_scorer().detect_issues(&cpg);
        let flagged = issues
            .iter()
            .any(|issue| issue.node_idx == 10 && issue.issue_type == IssueType::UnusedBinding);
        assert!(flagged);
    }

    #[test]
    fn test_detects_unwrap_calls_as_missing_error_handling() {
        let mut cpg = CodePropertyGraph::new();
        cpg.add_node(test_node(42, CpgNodeKind::Call, Some("unwrap")));

        let issues = GnnSemanticScorer::default_scorer().detect_issues(&cpg);
        let flagged = issues.iter().any(|issue| {
            issue.node_idx == 42 && issue.issue_type == IssueType::MissingErrorHandling
        });
        assert!(flagged);
    }
}