impl GNNErrorEncoder {
    /// Number of content feature slots (hash bits, keyword flags, shape statistics) that
    /// precede the `node_type_dim` node-type/metadata slots in every node feature vector.
    const BASE_FEATURE_DIM: usize = 64;

    /// Rust keywords recognised by the tokenizer.
    ///
    /// `tokenize_rust` moves every token found in this list to the front of a line's token
    /// list; `extract_token_features` one-hot encodes the first 16 entries. Keeping a single
    /// table guarantees the two stay in sync.
    const RUST_KEYWORDS: [&'static str; 30] = [
        "fn", "let", "mut", "struct", "impl", "trait", "use", "mod", "pub", "self", "Self",
        "return", "if", "else", "match", "for", "while", "loop", "break", "continue", "async",
        "await", "move", "ref", "where", "type", "const", "static", "enum", "union",
    ];

    /// Creates an encoder with a GCN -> GraphSAGE (mean aggregation) -> GCN layer stack.
    ///
    /// * `hidden_dim` - width of the first two layers' outputs.
    /// * `output_dim` - width of the final layer and of the produced embedding.
    ///
    /// The first layer's input width is derived from the same constants the feature
    /// extractors use to size node feature vectors, so the two cannot drift apart.
    #[must_use]
    pub fn new(hidden_dim: usize, output_dim: usize) -> Self {
        let node_type_dim = 8;
        let node_feature_dim = Self::BASE_FEATURE_DIM + node_type_dim;
        Self {
            hidden_dim,
            output_dim,
            gcn1: GCNConv::new(node_feature_dim, hidden_dim),
            sage: SAGEConv::new(hidden_dim, hidden_dim).with_aggregation(SAGEAggregation::Mean),
            gcn2: GCNConv::new(hidden_dim, output_dim),
            node_type_dim,
            base_encoder: ErrorEncoder::with_dim(64),
        }
    }

    /// Creates an encoder with `hidden_dim = 64` and `output_dim = 256`.
    #[must_use]
    pub fn default_config() -> Self {
        Self::new(64, 256)
    }

    /// Dimension of the embeddings produced by [`Self::encode_graph`].
    #[must_use]
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }

    /// Builds the program-feedback graph for one compiler diagnostic.
    ///
    /// Nodes are added in this order (which fixes their indices):
    /// 1. the diagnostic node,
    /// 2. an `ExpectedType` node if `diagnostic.expected` is set,
    /// 3. a `FoundType` node if `diagnostic.found` is set,
    /// 4. up to 20 `Ast` token nodes taken from the source lines covered by the span,
    /// 5. a `Suggestion` node for the first suggestion, if any.
    ///
    /// Edges go from the diagnostic node to every other node (`Expects`, `Found`, or
    /// `DiagnosticRefers`); consecutive AST token nodes are chained with `AstChild` edges.
    #[must_use]
    pub fn build_graph(
        &self,
        diagnostic: &CompilerDiagnostic,
        source: &str,
    ) -> ProgramFeedbackGraph {
        let mut graph = ProgramFeedbackGraph::new();

        let diag_features = self.extract_diagnostic_features(diagnostic);
        let diag_idx = graph.add_node(NodeType::Diagnostic, diag_features);

        let expected_idx = diagnostic.expected.as_ref().map(|expected| {
            let features = self.extract_type_node_features(&expected.base, true);
            graph.add_node(NodeType::ExpectedType, features)
        });
        let found_idx = diagnostic.found.as_ref().map(|found| {
            let features = self.extract_type_node_features(&found.base, false);
            graph.add_node(NodeType::FoundType, features)
        });

        let ast_indices = self.extract_ast_nodes(&mut graph, source, &diagnostic.span);

        if let Some(exp_idx) = expected_idx {
            graph.add_edge(diag_idx, exp_idx, EdgeType::Expects);
        }
        if let Some(fnd_idx) = found_idx {
            graph.add_edge(diag_idx, fnd_idx, EdgeType::Found);
        }
        for &ast_idx in &ast_indices {
            graph.add_edge(diag_idx, ast_idx, EdgeType::DiagnosticRefers);
        }
        // Chain tokens in source order so the GNN can propagate along the token sequence.
        for pair in ast_indices.windows(2) {
            graph.add_edge(pair[0], pair[1], EdgeType::AstChild);
        }

        // Only the first suggestion is represented in the graph.
        if let Some(suggestion) = diagnostic.suggestions.first() {
            let sugg_features = self.extract_suggestion_features(&suggestion.message);
            let sugg_idx = graph.add_node(NodeType::Suggestion, sugg_features);
            graph.add_edge(diag_idx, sugg_idx, EdgeType::DiagnosticRefers);
        }

        graph
    }

    /// Encodes a graph into an [`ErrorEmbedding`].
    ///
    /// The node features pass through GCN -> ReLU -> GraphSAGE -> ReLU -> GCN over the
    /// graph's adjacency. The node outputs are pooled with `mean_pool` and the result is
    /// passed through `normalize_embedding`. The error code and context hash are derived
    /// from the graph itself.
    ///
    /// An empty graph yields an all-zero vector of length `output_dim`, the placeholder
    /// code `"E0000"` (`Unknown`, `Easy`), and context hash `0`.
    #[must_use]
    pub fn encode_graph(&self, graph: &ProgramFeedbackGraph) -> ErrorEmbedding {
        if graph.num_nodes() == 0 {
            return ErrorEmbedding::new(
                vec![0.0; self.output_dim],
                ErrorCode::new(
                    "E0000",
                    super::ErrorCategory::Unknown,
                    super::Difficulty::Easy,
                ),
                0,
            );
        }

        let node_tensor = self.graph_to_tensor(graph);
        let adj = self.graph_to_adjacency(graph);

        let hidden = Self::relu(&self.gcn1.forward(&node_tensor, &adj));
        let hidden = Self::relu(&self.sage.forward(&hidden, &adj));
        let node_out = self.gcn2.forward(&hidden, &adj);

        let embedding = self.mean_pool(&node_out, graph.num_nodes());
        let normalized = self.normalize_embedding(&embedding);
        let error_code = self.extract_error_code_from_graph(graph);
        let context_hash = self.compute_graph_hash(graph);
        ErrorEmbedding::new(normalized, error_code, context_hash)
    }

    /// Builds the graph for `diagnostic` against `source` and encodes it.
    #[must_use]
    pub fn encode(&self, diagnostic: &CompilerDiagnostic, source: &str) -> ErrorEmbedding {
        let graph = self.build_graph(diagnostic, source);
        self.encode_graph(&graph)
    }

    /// Encodes a boolean as a `1.0` / `0.0` feature value.
    fn bool_feature(on: bool) -> f32 {
        if on {
            1.0
        } else {
            0.0
        }
    }

    /// Allocates a zeroed node feature vector of `BASE_FEATURE_DIM + node_type_dim` slots.
    fn empty_node_features(&self) -> Vec<f32> {
        vec![0.0; Self::BASE_FEATURE_DIM + self.node_type_dim]
    }

    /// Feature vector for the diagnostic node.
    ///
    /// Layout:
    /// * `0..32`  - low 32 bits of `simple_hash(code)`, one bit per slot,
    /// * `32..64` - whether the lower-cased message contains each of 32 keywords (substring
    ///   match, so e.g. "use" also fires on "because"),
    /// * `64`     - `1.0` (set only on diagnostic nodes in this impl),
    /// * `65..70` - one-hot flags for the TypeMismatch / Ownership / Lifetime / TraitBound /
    ///   Import categories (all zero for any other category),
    /// * `70`     - difficulty scaled to 0.25 / 0.5 / 0.75 / 1.0.
    fn extract_diagnostic_features(&self, diagnostic: &CompilerDiagnostic) -> Vec<f32> {
        const MESSAGE_KEYWORDS: [&str; 32] = [
            "type", "borrow", "move", "lifetime", "trait", "impl", "expected", "found", "cannot",
            "missing", "unknown", "value", "reference", "mutable", "method", "function",
            "argument", "return", "copy", "clone", "bound", "satisfy", "require", "import",
            "module", "crate", "use", "struct", "enum", "unsafe", "async", "await",
        ];

        let mut features = self.empty_node_features();

        let code_hash = Self::simple_hash(&diagnostic.code.code);
        for (bit, slot) in features[..32].iter_mut().enumerate() {
            *slot = ((code_hash >> bit) & 1) as f32;
        }

        let message = diagnostic.message.to_lowercase();
        for (slot, keyword) in features[32..64].iter_mut().zip(MESSAGE_KEYWORDS.iter()) {
            *slot = Self::bool_feature(message.contains(*keyword));
        }

        features[64] = 1.0;
        features[65] = Self::bool_feature(matches!(
            diagnostic.code.category,
            super::ErrorCategory::TypeMismatch
        ));
        features[66] = Self::bool_feature(matches!(
            diagnostic.code.category,
            super::ErrorCategory::Ownership
        ));
        features[67] = Self::bool_feature(matches!(
            diagnostic.code.category,
            super::ErrorCategory::Lifetime
        ));
        features[68] = Self::bool_feature(matches!(
            diagnostic.code.category,
            super::ErrorCategory::TraitBound
        ));
        features[69] = Self::bool_feature(matches!(
            diagnostic.code.category,
            super::ErrorCategory::Import
        ));
        features[70] = match diagnostic.code.difficulty {
            super::Difficulty::Easy => 0.25,
            super::Difficulty::Medium => 0.5,
            super::Difficulty::Hard => 0.75,
            super::Difficulty::Expert => 1.0,
        };
        features
    }

    /// Feature vector for an expected/found type node.
    ///
    /// Layout:
    /// * `0..32` - substring presence of 32 type patterns (overlapping patterns such as
    ///   "str"/"String" or "Cell"/"RefCell" can both fire),
    /// * `32`    - name length / 50,
    /// * `33`    - count of `<` / 3 (rough generic nesting),
    /// * `34`    - count of `&` / 2,
    /// * `35`    - count of `'` / 2 (lifetimes),
    /// * `65`/`66` - expected vs. found flag, `67` - `1.0` (type node marker).
    fn extract_type_node_features(&self, type_name: &str, is_expected: bool) -> Vec<f32> {
        const TYPE_PATTERNS: [&str; 32] = [
            "String", "str", "Vec", "Option", "Result", "Box", "i32", "i64", "u32", "u64", "f32",
            "f64", "bool", "char", "usize", "isize", "&", "mut", "'", "<", "impl", "dyn", "Rc",
            "Arc", "Cell", "RefCell", "Pin", "Future", "Iterator", "IntoIterator", "Clone",
            "Copy",
        ];

        let mut features = self.empty_node_features();

        for (slot, pattern) in features.iter_mut().zip(TYPE_PATTERNS.iter()) {
            *slot = Self::bool_feature(type_name.contains(*pattern));
        }

        features[32] = type_name.len() as f32 / 50.0;
        features[33] = type_name.matches('<').count() as f32 / 3.0;
        features[34] = type_name.matches('&').count() as f32 / 2.0;
        features[35] = type_name.matches('\'').count() as f32 / 2.0;

        // Slot 64 stays 0.0; only the diagnostic node sets it.
        features[65] = Self::bool_feature(is_expected);
        features[66] = Self::bool_feature(!is_expected);
        features[67] = 1.0;
        features
    }

    /// Adds one `NodeType::Ast` node per token from the source lines covered by `span`
    /// and returns the new node indices in source order.
    ///
    /// `span.line_start` is decremented by one before skipping and `span.line_end` is used
    /// as an inclusive bound, which suggests 1-based line numbers (TODO confirm against
    /// `SourceSpan`). At most 20 token nodes are added in total.
    fn extract_ast_nodes(
        &self,
        graph: &mut ProgramFeedbackGraph,
        source: &str,
        span: &SourceSpan,
    ) -> Vec<usize> {
        const MAX_AST_NODES: usize = 20;

        let tokens = source
            .lines()
            .take(span.line_end)
            .skip(span.line_start.saturating_sub(1))
            .flat_map(Self::tokenize_rust)
            .take(MAX_AST_NODES);

        let mut indices = Vec::with_capacity(MAX_AST_NODES);
        for token in tokens {
            let features = self.extract_token_features(&token);
            indices.push(graph.add_node(NodeType::Ast, features));
        }
        indices
    }

    /// Splits one source line into tokens.
    ///
    /// Runs of alphanumeric characters and `_` form word tokens. Each of the characters
    /// `: ; , . & * < > ( ) { } [ ] = - + ! ?` becomes a single-character token. All other
    /// characters (whitespace, quotes, `'`, `/`, ...) only act as separators. Tokens found in
    /// `RUST_KEYWORDS` are moved to the front (stable, so relative order is otherwise
    /// kept), and at most 10 tokens are returned.
    fn tokenize_rust(line: &str) -> Vec<String> {
        const MAX_TOKENS_PER_LINE: usize = 10;

        let mut tokens = Vec::new();
        let mut word = String::new();
        for c in line.chars() {
            if c.is_alphanumeric() || c == '_' {
                word.push(c);
                continue;
            }
            if !word.is_empty() {
                tokens.push(std::mem::take(&mut word));
            }
            if matches!(
                c,
                ':' | ';'
                    | ','
                    | '.'
                    | '&'
                    | '*'
                    | '<'
                    | '>'
                    | '('
                    | ')'
                    | '{'
                    | '}'
                    | '['
                    | ']'
                    | '='
                    | '-'
                    | '+'
                    | '!'
                    | '?'
            ) {
                tokens.push(c.to_string());
            }
        }
        if !word.is_empty() {
            tokens.push(word);
        }

        // `false < true`, so keywords (key `false`) sort first; `sort_by_key` is stable.
        tokens.sort_by_key(|token| !Self::RUST_KEYWORDS.contains(&token.as_str()));
        tokens.truncate(MAX_TOKENS_PER_LINE);
        tokens
    }

    /// Feature vector for an AST token node.
    ///
    /// Layout:
    /// * `0..32`  - low 32 bits of `simple_hash(token)`, one bit per slot,
    /// * `32..48` - one-hot exact match against the first 16 `RUST_KEYWORDS`,
    /// * `48`     - token length / 20,
    /// * `49`     - every char is uppercase or `_` (digits make this 0),
    /// * `50`     - starts with an uppercase char,
    /// * `51`     - every char is numeric,
    /// * `71`     - `1.0` (token node marker).
    fn extract_token_features(&self, token: &str) -> Vec<f32> {
        let mut features = self.empty_node_features();

        let hash = Self::simple_hash(token);
        for (bit, slot) in features[..32].iter_mut().enumerate() {
            *slot = ((hash >> bit) & 1) as f32;
        }

        // The 16-slot window limits the zip to the first 16 keywords.
        for (slot, keyword) in features[32..48].iter_mut().zip(Self::RUST_KEYWORDS.iter()) {
            *slot = Self::bool_feature(token == *keyword);
        }

        features[48] = token.len() as f32 / 20.0;
        features[49] = Self::bool_feature(token.chars().all(|c| c.is_uppercase() || c == '_'));
        features[50] = Self::bool_feature(token.starts_with(char::is_uppercase));
        features[51] = Self::bool_feature(token.chars().all(char::is_numeric));

        // Slot 64 stays 0.0; only the diagnostic node sets it.
        features[71] = 1.0;
        features
    }
}