impl ErrorEncoder {
/// Creates an encoder that produces 256-dimensional embeddings.
///
/// Equivalent to calling [`ErrorEncoder::with_dim`] with `256`.
#[must_use]
pub fn new() -> Self {
    Self::with_dim(256)
}
/// Creates an encoder whose output vectors have `dim` elements.
///
/// The error-code embedding table is pre-populated and the vocabulary starts empty.
#[must_use]
pub fn with_dim(dim: usize) -> Self {
    let error_code_embeddings = Self::init_error_code_embeddings();
    let vocab = HashMap::new();
    Self {
        dim,
        error_code_embeddings,
        vocab,
    }
}
/// Builds the fixed lookup table of embeddings for a hard-coded list of error codes.
///
/// Every code maps to a 64-element vector that is zero except for three slots derived
/// from the code's position `i` in the list: slot `i % 32` holds `1.0`, the slot sixteen
/// places further holds `0.5`, and the slot thirty-two places further holds `0.25`
/// (both offsets wrap modulo 64).
fn init_error_code_embeddings() -> HashMap<String, Vec<f32>> {
    const CODES: [&str; 19] = [
        "E0308", "E0382", "E0597", "E0599", "E0433", "E0432", "E0277", "E0425", "E0282",
        "E0412", "E0502", "E0499", "E0596", "E0507", "E0621", "E0106", "E0373", "E0495",
        "E0623",
    ];

    CODES
        .iter()
        .enumerate()
        .map(|(position, &code)| {
            let anchor = position % 32;
            let mut embedding = vec![0.0f32; 64];
            embedding[anchor] = 1.0;
            embedding[(anchor + 16) % 64] = 0.5;
            embedding[(anchor + 32) % 64] = 0.25;
            (code.to_string(), embedding)
        })
        .collect()
}
/// Encodes one compiler diagnostic, together with the source text it refers to, into an
/// [`ErrorEmbedding`].
///
/// The `self.dim`-element vector is filled in consecutive 64-slot segments; slots that
/// would fall at or beyond `self.dim` are skipped:
///
/// * `0..64`    - error-code embedding (table lookup, `hash_code` fallback for other codes)
/// * `64..128`  - source-context features for `diagnostic.span`
/// * `128..192` - expected/found type features
/// * `192..256` - message features
///
/// The vector is replaced by its normalized form when `Vector::normalize` returns `Ok`;
/// otherwise it is kept unnormalized. The embedding also carries a clone of the
/// diagnostic's code and a hash of the source lines covered by the span.
#[must_use]
pub fn encode(&self, diagnostic: &CompilerDiagnostic, source: &str) -> ErrorEmbedding {
    let code_segment = match self.error_code_embeddings.get(&diagnostic.code.code) {
        Some(known) => known.clone(),
        None => self.hash_code(&diagnostic.code.code),
    };
    // Order matters: segment `k` lands at offset `k * 64` in the output vector.
    let segments = [
        code_segment,
        self.extract_context_features(source, &diagnostic.span),
        self.extract_type_features(diagnostic),
        self.extract_message_features(&diagnostic.message),
    ];

    let mut vector = vec![0.0f32; self.dim];
    for (segment_index, segment) in segments.iter().enumerate() {
        let offset = segment_index * 64;
        // `zip` stops at whichever runs out first: the end of the vector or 64 values.
        let destination = vector.iter_mut().skip(offset);
        for (slot, &value) in destination.zip(segment.iter().take(64)) {
            *slot = value;
        }
    }

    let raw = Vector::from_slice(&vector);
    if let Ok(unit) = raw.normalize() {
        vector.copy_from_slice(unit.as_slice());
    }

    let context_hash = self.hash_context(source, &diagnostic.span);
    ErrorEmbedding::new(vector, diagnostic.code.clone(), context_hash)
}
/// Derives a fallback 64-element embedding for an error code that is not in the table.
///
/// Slot `i` is `0.5` when bit `i` of the code's `simple_hash` value is set, else `0.0`.
#[allow(clippy::unused_self)]
fn hash_code(&self, code: &str) -> Vec<f32> {
    let bits = Self::simple_hash(code);
    (0..64u32)
        .map(|bit| ((bits >> bit) & 1) as f32 * 0.5)
        .collect()
}
/// Computes 64 lexical features for the source lines covered by `span`.
///
/// The window starts at `span.line_start - 1` (saturating, so line numbers are
/// presumably 1-based) and ends before `span.line_end`, clamped to the number of lines
/// in `source`. An inverted or out-of-range window is treated as empty.
///
/// Layout of the returned vector:
///
/// * `0..16`  - per bucket `code_point % 16`, the character count divided by
///   (total characters + 1)
/// * `16..32` - `1.0` when the corresponding keyword occurs as a substring
/// * `32..48` - `1.0` when the corresponding punctuation pattern occurs
/// * `48`     - window length in lines divided by 10
/// * `49`     - `span.column_start` divided by 80
/// * `50..53` - presence of `"fn "`, `"impl "` and `"struct "`
/// * `53..64` - always `0.0`
#[allow(clippy::unused_self)]
fn extract_context_features(&self, source: &str, span: &SourceSpan) -> Vec<f32> {
    const KEYWORDS: [&str; 16] = [
        "let", "mut", "fn", "struct", "impl", "trait", "use", "mod", "pub", "self", "Self",
        "return", "if", "else", "match", "for",
    ];
    const PATTERNS: [&str; 16] = [
        "->", "=>", "::", "&mut", "&", "'", "<", ">", "()", "[]", "{}", ";", "=", ".", "?",
        "!",
    ];

    let presence = |found: bool| -> f32 {
        if found {
            1.0
        } else {
            0.0
        }
    };

    let all_lines: Vec<&str> = source.lines().collect();
    let first = span.line_start.saturating_sub(1);
    let last = span.line_end.min(all_lines.len());
    // `get` yields `None` for an inverted range, which becomes an empty window.
    let window: &[&str] = all_lines.get(first..last).unwrap_or_default();

    // Character histogram; line separators are not part of `lines()` output and so are
    // not counted.
    let mut buckets = [0u32; 16];
    for ch in window.iter().flat_map(|line| line.chars()) {
        buckets[(ch as usize) % 16] += 1;
    }
    // The extra 1.0 keeps the division defined for an empty window.
    let denominator = buckets.iter().sum::<u32>() as f32 + 1.0;

    let mut features = vec![0.0f32; 64];
    for (slot, &count) in features.iter_mut().zip(buckets.iter()) {
        *slot = count as f32 / denominator;
    }

    let context = window.join(" ");
    for (slot, keyword) in features[16..32].iter_mut().zip(KEYWORDS.iter()) {
        *slot = presence(context.contains(*keyword));
    }
    for (slot, pattern) in features[32..48].iter_mut().zip(PATTERNS.iter()) {
        *slot = presence(context.contains(*pattern));
    }

    features[48] = last.saturating_sub(first) as f32 / 10.0;
    features[49] = span.column_start as f32 / 80.0;
    features[50] = presence(context.contains("fn "));
    features[51] = presence(context.contains("impl "));
    features[52] = presence(context.contains("struct "));
    features
}
/// Encodes the diagnostic's optional expected and found types into 64 slots.
///
/// Slots `0..32` describe `diagnostic.expected` and slots `32..64` describe
/// `diagnostic.found`; a half stays all-zero when its type is `None`. Within a half:
///
/// * `+0` - `1.0` (type is present)
/// * `+1` - `1.0` when `is_reference`
/// * `+2` - `1.0` when `is_mutable`
/// * `+3` - number of generics divided by 4
/// * `+4..` - the flags returned by `type_to_features` for the base type name
fn extract_type_features(&self, diagnostic: &CompilerDiagnostic) -> Vec<f32> {
    fn flag(on: bool) -> f32 {
        if on {
            1.0
        } else {
            0.0
        }
    }

    let mut features = vec![0.0f32; 64];
    let (expected_half, found_half) = features.split_at_mut(32);

    if let Some(expected) = &diagnostic.expected {
        expected_half[..4].copy_from_slice(&[
            1.0,
            flag(expected.is_reference),
            flag(expected.is_mutable),
            expected.generics.len() as f32 / 4.0,
        ]);
        let base_flags = self.type_to_features(&expected.base);
        for (slot, value) in expected_half[4..].iter_mut().zip(base_flags) {
            *slot = value;
        }
    }

    if let Some(found) = &diagnostic.found {
        found_half[..4].copy_from_slice(&[
            1.0,
            flag(found.is_reference),
            flag(found.is_mutable),
            found.generics.len() as f32 / 4.0,
        ]);
        let base_flags = self.type_to_features(&found.base);
        for (slot, value) in found_half[4..].iter_mut().zip(base_flags) {
            *slot = value;
        }
    }

    features
}
/// Maps a type name to 16 flags, one per well-known type name.
///
/// Flag `i` is `1.0` when `type_name` contains the `i`-th entry of the table below as a
/// substring (case-sensitive), and `0.0` otherwise.
#[allow(clippy::unused_self)]
fn type_to_features(&self, type_name: &str) -> Vec<f32> {
    const KNOWN_TYPES: [&str; 16] = [
        "String", "str", "Vec", "Option", "Result", "Box", "i32", "i64", "u32", "u64", "f32",
        "f64", "bool", "char", "usize", "isize",
    ];

    KNOWN_TYPES
        .iter()
        .map(|needle| {
            if type_name.contains(*needle) {
                1.0
            } else {
                0.0
            }
        })
        .collect()
}
/// Computes 64 features from the diagnostic message text.
///
/// * `0..32`  - `1.0` when the lower-cased message contains the corresponding phrase
/// * `32`     - message length in bytes divided by 200, capped at `1.0`
/// * `33`     - whitespace-separated word count divided by 30, capped at `1.0`
/// * `34..64` - always `0.0`
#[allow(clippy::unused_self)]
fn extract_message_features(&self, message: &str) -> Vec<f32> {
    const PHRASES: [&str; 32] = [
        "mismatched types",
        "expected",
        "found",
        "borrow",
        "move",
        "lifetime",
        "cannot",
        "trait",
        "implement",
        "method",
        "function",
        "argument",
        "return",
        "value",
        "type",
        "reference",
        "mutable",
        "immutable",
        "borrowed",
        "owned",
        "copy",
        "clone",
        "bound",
        "satisfy",
        "require",
        "missing",
        "unknown",
        "unresolved",
        "import",
        "module",
        "crate",
        "use",
    ];

    let lowered = message.to_lowercase();
    let mut features: Vec<f32> = Vec::with_capacity(64);
    features.extend(PHRASES.iter().map(|phrase| {
        if lowered.contains(*phrase) {
            1.0
        } else {
            0.0
        }
    }));

    // The size features use the original message, not the lower-cased copy.
    features.push((message.len() as f32 / 200.0).min(1.0));
    features.push((message.split_whitespace().count() as f32 / 30.0).min(1.0));

    // Pad the unused tail with zeros.
    features.resize(64, 0.0);
    features
}
/// Hashes the source lines covered by `span`.
///
/// The window starts at `span.line_start - 1` (saturating) and ends before
/// `span.line_end`, clamped to the number of lines in `source`. The selected lines are
/// joined with `"\n"` and passed to `simple_hash`; an empty or inverted window hashes the
/// empty string.
#[allow(clippy::unused_self)]
fn hash_context(&self, source: &str, span: &SourceSpan) -> u64 {
    let all_lines: Vec<&str> = source.lines().collect();
    let first = span.line_start.saturating_sub(1);
    let last = span.line_end.min(all_lines.len());
    // `get` yields `None` for an inverted range, which becomes an empty window.
    let window: &[&str] = all_lines.get(first..last).unwrap_or_default();
    Self::simple_hash(&window.join("\n"))
}
/// Hashes the bytes of `s` into a `u64`.
///
/// Starts from `5381` and, for every byte, multiplies the running value by 33 and adds
/// the byte, with wrapping arithmetic throughout.
fn simple_hash(s: &str) -> u64 {
    s.bytes().fold(5381u64, |acc, byte| {
        acc.wrapping_mul(33).wrapping_add(u64::from(byte))
    })
}
}
/// The default encoder is whatever [`ErrorEncoder::new`] builds.
impl Default for ErrorEncoder {
/// Delegates to [`ErrorEncoder::new`].
fn default() -> Self {
Self::new()
}
}
/// Encoder state that bundles three convolution layers (`GCNConv`, `SAGEConv`, `GCNConv`)
/// with a base [`ErrorEncoder`].
///
/// NOTE(review): how the layers are applied is defined in this type's `impl`, which is
/// not visible here; the field descriptions below are inferred from names and types and
/// should be confirmed against it.
#[derive(Debug)]
// NOTE(review): blanket allowance; which fields are actually unused cannot be determined
// from this definition alone.
#[allow(dead_code)]
pub struct GNNErrorEncoder {
// Presumably the width of the intermediate layer outputs - TODO confirm.
hidden_dim: usize,
// Presumably the length of the final embedding - TODO confirm.
output_dim: usize,
// First `GCNConv` layer.
gcn1: GCNConv,
// `SAGEConv` layer; presumably applied between the two `GCNConv` layers - TODO confirm.
sage: SAGEConv,
// Second `GCNConv` layer.
gcn2: GCNConv,
// Presumably the size of the per-node type encoding - TODO confirm.
node_type_dim: usize,
// Non-graph encoder kept alongside the layers; its role here is not visible in this view.
base_encoder: ErrorEncoder,
}