use super::diagnostic::{CompilerDiagnostic, SourceSpan};
use super::ErrorCode;
use crate::autograd::Tensor;
use crate::nn::gnn::{AdjacencyMatrix, GCNConv, SAGEAggregation, SAGEConv};
use std::collections::HashMap;
use trueno::Vector;
/// An embedding vector bundled with the error code and context hash it was built with.
///
/// All fields are public; `new` performs no validation, so any vector length
/// (including zero) can be stored.
#[derive(Debug, Clone)]
pub struct ErrorEmbedding {
/// Raw embedding components; the length of this vector is what `dim()` reports.
pub vector: Vec<f32>,
/// Error code this embedding is tagged with.
pub error_code: ErrorCode,
/// Hash value supplied by the caller at construction time.
// NOTE(review): presumably a hash of the surrounding source context — confirm against the encoder.
pub context_hash: u64,
}
impl ErrorEmbedding {
    /// Builds an embedding from its raw parts without validating them.
    ///
    /// * `vector` - embedding components; its length becomes the embedding dimension.
    /// * `error_code` - error code to tag the embedding with.
    /// * `context_hash` - hash value stored alongside the vector.
    #[must_use]
    pub fn new(vector: Vec<f32>, error_code: ErrorCode, context_hash: u64) -> Self {
        Self {
            vector,
            error_code,
            context_hash,
        }
    }

    /// Returns the number of components in the embedding vector.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.vector.len()
    }

    /// Cosine similarity between this embedding and `other`.
    ///
    /// Returns `0.0` when the two vectors differ in length or when this vector
    /// is empty; otherwise the value computed by
    /// `crate::nn::functional::cosine_similarity_slice` is returned as-is.
    #[must_use]
    pub fn cosine_similarity(&self, other: &ErrorEmbedding) -> f32 {
        // Only non-empty vectors of equal length are handed to the helper.
        let comparable = !self.vector.is_empty() && self.vector.len() == other.vector.len();
        if comparable {
            crate::nn::functional::cosine_similarity_slice(&self.vector, &other.vector)
        } else {
            0.0
        }
    }

    /// L2 norm of the element-wise difference between this embedding and `other`.
    ///
    /// Returns `f32::MAX` when the two vectors differ in length, when this
    /// vector is empty, or when either the subtraction or the norm computation
    /// reports a failure.
    #[must_use]
    pub fn l2_distance(&self, other: &ErrorEmbedding) -> f32 {
        let same_len = self.vector.len() == other.vector.len();
        if !same_len || self.vector.is_empty() {
            return f32::MAX;
        }

        let lhs = Vector::from_slice(&self.vector);
        let rhs = Vector::from_slice(&other.vector);
        let delta = lhs.sub(&rhs);

        // Any failure along the way collapses to the same "incomparable" sentinel.
        delta.and_then(|d| d.norm_l2()).unwrap_or(f32::MAX)
    }
}
/// Encoder state: an embedding dimension, a table of embedding vectors keyed by
/// string, and a string-to-index vocabulary map.
// NOTE(review): no methods are defined on this type in this file's own text;
// presumably its impl comes from one of the `include!`d files at the bottom of
// this file — confirm.
#[derive(Debug)]
// NOTE(review): `dead_code` is silenced without a stated reason; presumably some
// fields are not read yet — confirm and document, or drop the allow.
#[allow(dead_code)]
pub struct ErrorEncoder {
// Embedding dimension. NOTE(review): presumably the length of every vector this encoder produces — confirm.
dim: usize,
// Embedding vectors keyed by string. NOTE(review): keys are presumably error-code strings — confirm.
error_code_embeddings: HashMap<String, Vec<f32>>,
// String-to-index map. NOTE(review): presumably token -> feature index — confirm.
vocab: HashMap<String, usize>,
}
// The rest of this module is spliced in textually from sibling files, so the
// included code shares this file's `use` declarations and item namespace.
// NOTE(review): imports that look unused above (e.g. `Tensor`, `GCNConv`,
// `SAGEConv`, `CompilerDiagnostic`) are presumably consumed by these files —
// confirm before removing any of them.
include!("gnn_error_encoder.rs");
include!("gnn_encoder_impl.rs");
include!("program_feedback_graph.rs");