use std::collections::HashMap;
use aprender::citl::{
CodeReplacement, CompilerDiagnostic, CompilerSuggestion, DiagnosticSeverity, Difficulty,
ErrorCategory as AprenderErrorCategory, ErrorCode as AprenderErrorCode, GNNErrorEncoder,
ProgramFeedbackGraph, SourceSpan, SuggestionApplicability, TypeInfo,
};
use aprender::index::hnsw::HNSWIndex;
use aprender::primitives::Vector;
use crate::ast_embeddings::{AstEmbedder, AstEmbeddingConfig};
use crate::classifier::ErrorCategory;
use crate::error_patterns::ErrorPattern;
use crate::tarantula::TranspilerDecision;
/// Tunable settings for [`DepylerGnnEncoder`].
#[derive(Debug, Clone)]
pub struct GnnEncoderConfig {
    /// First size argument handed to the underlying GNN error encoder.
    pub hidden_dim: usize,
    /// Second size argument handed to the GNN error encoder; also the base
    /// term of `combined_dim`.
    pub output_dim: usize,
    /// Search results with a similarity below this value are discarded.
    pub similarity_threshold: f32,
    /// Upper bound on the number of results returned by a similarity query.
    pub max_similar: usize,
    /// Whether an HNSW index is created and maintained for lookups.
    pub use_hnsw: bool,
    /// Whether an AST embedder is created for combined embeddings.
    pub use_ast_embeddings: bool,
    /// Embedding dimension forwarded to the AST embedder configuration.
    pub ast_embedding_dim: usize,
    /// First argument forwarded to `HNSWIndex::new`.
    pub hnsw_m: usize,
    /// Second argument forwarded to `HNSWIndex::new`.
    pub hnsw_ef_construction: usize,
}
impl Default for GnnEncoderConfig {
    /// Default settings: 64 hidden / 256 output dimensions, a 0.7 similarity
    /// cut-off, at most 5 results, and both HNSW and AST embeddings enabled.
    fn default() -> Self {
        Self {
            // Encoder sizes.
            hidden_dim: 64,
            output_dim: 256,
            // Search behaviour.
            similarity_threshold: 0.7,
            max_similar: 5,
            // HNSW index.
            use_hnsw: true,
            hnsw_m: 16,
            hnsw_ef_construction: 200,
            // AST embeddings.
            use_ast_embeddings: true,
            ast_embedding_dim: 128,
        }
    }
}
/// An indexed error pattern together with its embedding and usage statistics.
#[derive(Debug, Clone)]
pub struct StructuralPattern {
    /// Identifier copied from the source `ErrorPattern`; also the map key.
    pub id: String,
    /// Compiler error code copied from the source pattern (e.g. `"E0308"`).
    pub error_code: String,
    /// Embedding vector produced by the GNN encoder for this pattern.
    pub embedding: Vec<f32>,
    /// The originating pattern, when one is available.
    pub error_pattern: Option<ErrorPattern>,
    /// Number of recorded match outcomes (successes and failures).
    pub match_count: u32,
    /// Running success estimate; seeded from the source pattern's confidence.
    pub success_rate: f64,
}
/// One hit returned by a similarity query.
#[derive(Debug, Clone)]
pub struct SimilarPattern {
    /// Identifier of the matched pattern.
    pub pattern_id: String,
    /// Similarity score between the query and the matched pattern.
    pub similarity: f32,
    /// Snapshot (clone) of the matched pattern at query time.
    pub pattern: StructuralPattern,
}
/// Counters describing how the encoder has been used so far.
#[derive(Debug, Clone, Default)]
pub struct GnnEncoderStats {
    /// Incremented once for every pattern that is indexed.
    pub patterns_indexed: usize,
    /// Incremented once for every similarity query.
    pub queries_performed: usize,
    /// Number of queries that returned at least one result.
    pub successful_matches: usize,
    /// Running mean, over successful queries, of each query's mean similarity.
    pub avg_similarity: f64,
    /// Number of queries answered through the HNSW index.
    pub hnsw_queries: usize,
    /// Number of queries answered by scanning every stored pattern.
    pub linear_queries: usize,
}
/// Wraps the GNN error encoder with a pattern store, an optional HNSW index
/// and an optional AST embedder.
pub struct DepylerGnnEncoder {
    /// Settings this encoder was built with.
    config: GnnEncoderConfig,
    /// Encoder that turns diagnostics plus source context into embeddings.
    encoder: GNNErrorEncoder,
    /// Present only when AST embeddings are enabled in `config`.
    ast_embedder: Option<AstEmbedder>,
    /// Indexed patterns keyed by pattern id.
    patterns: HashMap<String, StructuralPattern>,
    /// Present only when HNSW is enabled in `config`.
    hnsw_index: Option<HNSWIndex>,
    /// Ids that have been added to `hnsw_index`, in insertion order.
    hnsw_id_map: Vec<String>,
    /// Usage counters.
    stats: GnnEncoderStats,
}
impl DepylerGnnEncoder {
/// Builds an encoder from `config`.
///
/// The GNN encoder is always created from `hidden_dim` / `output_dim`. The AST
/// embedder and the HNSW index are created only when their respective
/// `use_*` flags are set; otherwise the corresponding field is `None`.
#[must_use]
pub fn new(config: GnnEncoderConfig) -> Self {
    let encoder = GNNErrorEncoder::new(config.hidden_dim, config.output_dim);

    let ast_embedder = config.use_ast_embeddings.then(|| {
        AstEmbedder::new(AstEmbeddingConfig {
            embedding_dim: config.ast_embedding_dim,
            ..AstEmbeddingConfig::default()
        })
    });

    // NOTE(review): the meaning of the trailing `0.0` argument is not visible
    // in this file; confirm against `HNSWIndex::new`.
    let hnsw_index = config
        .use_hnsw
        .then(|| HNSWIndex::new(config.hnsw_m, config.hnsw_ef_construction, 0.0));

    Self {
        config,
        encoder,
        ast_embedder,
        patterns: HashMap::new(),
        hnsw_index,
        hnsw_id_map: Vec::new(),
        stats: GnnEncoderStats::default(),
    }
}
/// Builds an encoder using `GnnEncoderConfig::default()`.
#[must_use]
pub fn with_defaults() -> Self {
    let config = GnnEncoderConfig::default();
    Self::new(config)
}
pub fn index_pattern(&mut self, pattern: &ErrorPattern, source_context: &str) {
let diagnostic = self.pattern_to_diagnostic(pattern);
let embedding = self.encoder.encode(&diagnostic, source_context);
let structural = StructuralPattern {
id: pattern.id.clone(),
error_code: pattern.error_code.clone(),
embedding: embedding.vector.clone(),
error_pattern: Some(pattern.clone()),
match_count: 0,
success_rate: pattern.confidence,
};
if let Some(ref mut hnsw) = self.hnsw_index {
let vector_f64: Vec<f64> = embedding.vector.iter().map(|&x| x as f64).collect();
hnsw.add(pattern.id.clone(), Vector::from_slice(&vector_f64));
self.hnsw_id_map.push(pattern.id.clone());
}
self.patterns.insert(pattern.id.clone(), structural);
self.stats.patterns_indexed += 1;
}
/// Indexes many `(pattern, source_context)` pairs and returns how many were
/// stored.
///
/// The HNSW index is not updated per pattern; instead it is rebuilt once at
/// the end, and only when HNSW is enabled and at least one pattern was stored.
/// An empty slice is a no-op that returns `0`.
pub fn batch_index_patterns(&mut self, patterns: &[(&ErrorPattern, &str)]) -> usize {
    let indexed = if patterns.is_empty() {
        0
    } else {
        self.index_patterns_without_hnsw(patterns)
    };

    if indexed > 0 && self.config.use_hnsw {
        self.rebuild_hnsw_index();
    }

    indexed
}
/// Stores every `(pattern, source_context)` pair in the pattern map without
/// touching the HNSW index, and returns the number of pairs processed.
///
/// `stats.patterns_indexed` grows by one per pair, including pairs whose id
/// overwrites an existing entry.
fn index_patterns_without_hnsw(&mut self, patterns: &[(&ErrorPattern, &str)]) -> usize {
    for &(pattern, source_context) in patterns {
        let structural = self.create_structural_pattern(pattern, source_context);
        self.patterns.insert(pattern.id.clone(), structural);
    }

    let processed = patterns.len();
    self.stats.patterns_indexed += processed;
    processed
}
/// Encodes `pattern` with `source_context` and wraps the result in a fresh
/// [`StructuralPattern`].
///
/// The new entry starts with `match_count == 0` and a `success_rate` equal to
/// the pattern's `confidence`.
fn create_structural_pattern(
    &self,
    pattern: &ErrorPattern,
    source_context: &str,
) -> StructuralPattern {
    let diagnostic = self.pattern_to_diagnostic(pattern);
    // The encoder output is owned here, so its vector can be moved out.
    let embedding = self.encoder.encode(&diagnostic, source_context).vector;

    StructuralPattern {
        id: pattern.id.clone(),
        error_code: pattern.error_code.clone(),
        embedding,
        error_pattern: Some(pattern.clone()),
        match_count: 0,
        success_rate: pattern.confidence,
    }
}
/// Returns the stored patterns most similar to the given error.
///
/// The error is encoded with `source_context`, then candidates come from the
/// HNSW index when it is present and non-empty, or from a linear scan
/// otherwise. Candidates below `similarity_threshold` are dropped, the rest
/// are sorted by descending similarity and cut to `max_similar` entries.
///
/// Statistics are updated as a side effect: `queries_performed` always, one of
/// `hnsw_queries` / `linear_queries`, and - when the result is non-empty -
/// `successful_matches` and the running `avg_similarity`.
pub fn find_similar(
    &mut self,
    error_code: &str,
    error_message: &str,
    source_context: &str,
) -> Vec<SimilarPattern> {
    self.stats.queries_performed += 1;

    let diagnostic = self.build_diagnostic(error_code, error_message);
    let query = self.encoder.encode(&diagnostic, source_context).vector;

    let candidates = if self.hnsw_index.is_some() && !self.hnsw_id_map.is_empty() {
        self.stats.hnsw_queries += 1;
        self.find_similar_hnsw(&query)
    } else {
        self.stats.linear_queries += 1;
        self.find_similar_linear(&query)
    };

    let threshold = self.config.similarity_threshold;
    let mut results: Vec<SimilarPattern> = candidates
        .into_iter()
        .filter(|candidate| candidate.similarity >= threshold)
        .collect();
    // NaN never satisfies `>=`, so every surviving score is comparable and the
    // `unwrap` below cannot fire.
    results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
    results.truncate(self.config.max_similar);

    if results.is_empty() {
        return results;
    }

    // Fold this query's mean similarity into the running mean.
    self.stats.successful_matches += 1;
    let successes = self.stats.successful_matches as f64;
    let total_sim: f64 = results.iter().map(|r| f64::from(r.similarity)).sum();
    self.stats.avg_similarity = (self.stats.avg_similarity * (successes - 1.0)
        + total_sim / results.len() as f64)
        / successes;

    results
}
fn find_similar_hnsw(&self, query_embedding: &[f32]) -> Vec<SimilarPattern> {
let mut results = Vec::new();
if let Some(ref hnsw) = self.hnsw_index {
let query_f64: Vec<f64> = query_embedding.iter().map(|&x| x as f64).collect();
let query_vector = Vector::from_slice(&query_f64);
let k = self.config.max_similar * 2;
let neighbors = hnsw.search(&query_vector, k);
for (pattern_id, distance) in neighbors {
let similarity = (1.0 - distance as f32).clamp(0.0, 1.0);
if let Some(pattern) = self.patterns.get(&pattern_id) {
results.push(SimilarPattern {
pattern_id: pattern_id.clone(),
similarity,
pattern: pattern.clone(),
});
}
}
}
results
}
/// Scores every stored pattern against `query_embedding` using cosine
/// similarity and returns all of them, unfiltered and unsorted.
fn find_similar_linear(&self, query_embedding: &[f32]) -> Vec<SimilarPattern> {
    self.patterns
        .iter()
        .map(|(id, candidate)| SimilarPattern {
            pattern_id: id.clone(),
            similarity: self.cosine_similarity(query_embedding, &candidate.embedding),
            pattern: candidate.clone(),
        })
        .collect()
}
/// Records that the pattern `pattern_id` led to a successful fix.
///
/// Increments the pattern's `match_count` and moves `success_rate` towards
/// `1.0` with `rate * 0.9 + 0.1`. Unknown ids are ignored.
pub fn record_match_success(&mut self, pattern_id: &str) {
    const DECAY: f64 = 0.9;
    const REWARD: f64 = 0.1;

    let entry = match self.patterns.get_mut(pattern_id) {
        Some(entry) => entry,
        None => return,
    };
    entry.match_count += 1;
    entry.success_rate = entry.success_rate * DECAY + REWARD;
}
/// Records that the pattern `pattern_id` did not lead to a successful fix.
///
/// Increments the pattern's `match_count` and scales `success_rate` by `0.9`.
/// Unknown ids are ignored.
pub fn record_match_failure(&mut self, pattern_id: &str) {
    const DECAY: f64 = 0.9;

    let entry = match self.patterns.get_mut(pattern_id) {
        Some(entry) => entry,
        None => return,
    };
    entry.match_count += 1;
    entry.success_rate *= DECAY;
}
/// Encodes a single error (code, message and surrounding source) into an
/// embedding vector using the GNN encoder. Does not touch the pattern store.
#[must_use]
pub fn encode_error(
    &self,
    error_code: &str,
    error_message: &str,
    source_context: &str,
) -> Vec<f32> {
    self.encoder
        .encode(
            &self.build_diagnostic(error_code, error_message),
            source_context,
        )
        .vector
}
#[must_use]
pub fn encode_combined(
&self,
error_code: &str,
error_message: &str,
python_source: &str,
rust_source: &str,
) -> Vec<f32> {
let gnn_embedding = self.encode_error(error_code, error_message, rust_source);
if let Some(ref ast_embedder) = self.ast_embedder {
let python_ast = ast_embedder.embed_python(python_source);
let rust_ast = ast_embedder.embed_rust(rust_source);
let mut combined = gnn_embedding;
combined.extend(&python_ast.vector);
combined.extend(&rust_ast.vector);
combined
} else {
gnn_embedding
}
}
/// Length of the configured combined embedding: `output_dim`, plus two AST
/// embeddings of `ast_embedding_dim` each when AST embeddings are enabled.
#[must_use]
pub fn combined_dim(&self) -> usize {
    let ast_dims = if self.config.use_ast_embeddings {
        self.config.ast_embedding_dim * 2
    } else {
        0
    };
    self.config.output_dim + ast_dims
}
/// Builds the program-feedback graph the GNN encoder derives from the given
/// error and its surrounding source.
#[must_use]
pub fn build_graph(
    &self,
    error_code: &str,
    error_message: &str,
    source_context: &str,
) -> ProgramFeedbackGraph {
    self.encoder.build_graph(
        &self.build_diagnostic(error_code, error_message),
        source_context,
    )
}
#[must_use]
pub fn stats(&self) -> &GnnEncoderStats {
&self.stats
}
/// `true` when an HNSW index exists and at least one id has been added to it;
/// this is the same condition `find_similar` uses to pick the HNSW path.
#[must_use]
pub fn is_hnsw_active(&self) -> bool {
    let has_index = self.hnsw_index.is_some();
    let has_entries = !self.hnsw_id_map.is_empty();
    has_index && has_entries
}
/// Number of ids recorded as added to the HNSW index.
#[must_use]
pub fn hnsw_size(&self) -> usize {
    let indexed_ids: &[String] = &self.hnsw_id_map;
    indexed_ids.len()
}
pub fn rebuild_hnsw_index(&mut self) {
if !self.config.use_hnsw {
return;
}
self.hnsw_index = Some(HNSWIndex::new(
self.config.hnsw_m,
self.config.hnsw_ef_construction,
0.0,
));
self.hnsw_id_map.clear();
let mut pattern_ids: Vec<_> = self.patterns.keys().cloned().collect();
pattern_ids.sort();
for pattern_id in pattern_ids {
if let Some(pattern) = self.patterns.get(&pattern_id) {
if let Some(ref mut hnsw) = self.hnsw_index {
let vector_f64: Vec<f64> =
pattern.embedding.iter().map(|&x| x as f64).collect();
hnsw.add(pattern_id.clone(), Vector::from_slice(&vector_f64));
self.hnsw_id_map.push(pattern_id);
}
}
}
}
#[must_use]
pub fn config(&self) -> &GnnEncoderConfig {
&self.config
}
/// Number of distinct pattern ids currently stored.
#[must_use]
pub fn pattern_count(&self) -> usize {
    let stored = &self.patterns;
    stored.len()
}
/// Iterates over every stored pattern, in the hash map's (unspecified) order.
pub fn patterns(&self) -> impl Iterator<Item = &StructuralPattern> {
    self.patterns.iter().map(|(_, stored)| stored)
}
fn pattern_to_diagnostic(&self, pattern: &ErrorPattern) -> CompilerDiagnostic {
let error_code = self.depyler_to_aprender_code(&pattern.error_code);
let span = SourceSpan::single_line("source.rs", 1, 1, 80);
let mut diagnostic = CompilerDiagnostic::new(
error_code,
DiagnosticSeverity::Error,
&pattern.error_pattern,
span.clone(),
);
if pattern.error_pattern.contains("expected") && pattern.error_pattern.contains("found") {
diagnostic = diagnostic
.with_expected(TypeInfo::new("ExpectedType"))
.with_found(TypeInfo::new("FoundType"));
}
if !pattern.fix_diff.is_empty() {
let suggestion = CompilerSuggestion::new(
"Apply fix",
SuggestionApplicability::MachineApplicable,
CodeReplacement::new(span, &pattern.fix_diff),
);
diagnostic = diagnostic.with_suggestion(suggestion);
}
diagnostic
}
/// Builds an error-severity `CompilerDiagnostic` for a raw error code and
/// message, using a fixed placeholder span in `"source.rs"`.
fn build_diagnostic(&self, error_code: &str, error_message: &str) -> CompilerDiagnostic {
    CompilerDiagnostic::new(
        self.depyler_to_aprender_code(error_code),
        DiagnosticSeverity::Error,
        error_message,
        SourceSpan::single_line("source.rs", 1, 1, 80),
    )
}
/// Maps a rustc error code string to an aprender error code, assigning it a
/// category and a difficulty.
///
/// Codes not listed below fall back to `Unknown` / `Medium`.
fn depyler_to_aprender_code(&self, code: &str) -> AprenderErrorCode {
    let (category, difficulty) = match code {
        "E0308" => (AprenderErrorCategory::TypeMismatch, Difficulty::Easy),
        "E0382" => (AprenderErrorCategory::Ownership, Difficulty::Medium),
        "E0502" | "E0503" => (AprenderErrorCategory::Ownership, Difficulty::Hard),
        "E0106" => (AprenderErrorCategory::Lifetime, Difficulty::Hard),
        "E0495" => (AprenderErrorCategory::Lifetime, Difficulty::Expert),
        "E0277" => (AprenderErrorCategory::TraitBound, Difficulty::Medium),
        "E0433" | "E0412" => (AprenderErrorCategory::Import, Difficulty::Easy),
        _ => (AprenderErrorCategory::Unknown, Difficulty::Medium),
    };
    AprenderErrorCode::new(code, category, difficulty)
}
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// has an L2 norm below `1e-10`.
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = l2_norm(a);
    let norm_b = l2_norm(b);

    // Guard against dividing by a (near-)zero magnitude.
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
}
/// Translates this crate's `ErrorCategory` into the aprender error category.
///
/// Categories without a dedicated counterpart map to `Unknown`.
#[must_use]
pub fn map_error_category(category: ErrorCategory) -> AprenderErrorCategory {
    let mapped = match category {
        ErrorCategory::MissingImport => AprenderErrorCategory::Import,
        ErrorCategory::TraitBound => AprenderErrorCategory::TraitBound,
        ErrorCategory::LifetimeError => AprenderErrorCategory::Lifetime,
        ErrorCategory::BorrowChecker => AprenderErrorCategory::Ownership,
        ErrorCategory::TypeMismatch => AprenderErrorCategory::TypeMismatch,
        _ => AprenderErrorCategory::Unknown,
    };
    mapped
}
/// Infers which transpiler decision a matched pattern points at.
///
/// An explicit `decision_type` recorded on the attached `ErrorPattern` wins.
/// When no pattern is attached, or the attached pattern carries no decision,
/// the decision is inferred from the error code; unrecognised codes yield
/// `None`.
///
/// Previously an attached pattern without a decision short-circuited to
/// `None`, so the error-code fallback was skipped for such patterns.
#[must_use]
pub fn infer_decision_from_match(pattern: &StructuralPattern) -> Option<TranspilerDecision> {
    let recorded = pattern
        .error_pattern
        .as_ref()
        .and_then(|error_pattern| error_pattern.decision_type);

    recorded.or_else(|| match pattern.error_code.as_str() {
        "E0308" | "E0277" => Some(TranspilerDecision::TypeInference),
        "E0382" | "E0502" | "E0503" => Some(TranspilerDecision::OwnershipInference),
        "E0106" | "E0495" => Some(TranspilerDecision::LifetimeInference),
        "E0433" | "E0412" => Some(TranspilerDecision::ImportGeneration),
        "E0599" | "E0609" => Some(TranspilerDecision::MethodTranslation),
        _ => None,
    })
}
// Unit tests for the GNN encoder wrapper: configuration defaults, encoding,
// single and batch indexing, similarity search and HNSW bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Configuration and construction -------------------------------------

    #[test]
    fn test_gnn_encoder_config_default() {
        let config = GnnEncoderConfig::default();
        assert_eq!(config.hidden_dim, 64);
        assert_eq!(config.output_dim, 256);
        assert!((config.similarity_threshold - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.max_similar, 5);
        assert!(config.use_hnsw);
    }

    #[test]
    fn test_gnn_encoder_creation() {
        let encoder = DepylerGnnEncoder::with_defaults();
        assert_eq!(encoder.pattern_count(), 0);
        assert_eq!(encoder.stats().patterns_indexed, 0);
    }

    // ---- Encoding, indexing and search --------------------------------------

    #[test]
    fn test_encode_error() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let embedding = encoder.encode_error(
            "E0308",
            "mismatched types: expected i32, found String",
            "let x: i32 = \"hello\";",
        );
        assert_eq!(embedding.len(), 256);
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.1 || norm < 0.1,
            "Embedding should be normalized or near-zero, got {}",
            norm
        );
    }

    #[test]
    fn test_index_pattern() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let pattern = ErrorPattern::new("E0308", "mismatched types", "+let x: i32 = 42;");
        encoder.index_pattern(&pattern, "let x: i32 = \"hello\";");
        assert_eq!(encoder.pattern_count(), 1);
        assert_eq!(encoder.stats().patterns_indexed, 1);
    }

    #[test]
    fn test_find_similar_empty() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let results = encoder.find_similar("E0308", "type mismatch", "let x = 5;");
        assert!(results.is_empty());
        assert_eq!(encoder.stats().queries_performed, 1);
        assert_eq!(encoder.stats().successful_matches, 0);
    }

    #[test]
    fn test_find_similar_with_patterns() {
        // Threshold 0.0 keeps every candidate.
        let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            similarity_threshold: 0.0,
            ..Default::default()
        });
        let pattern = ErrorPattern::new("E0308", "mismatched types", "+fix");
        encoder.index_pattern(&pattern, "let x: i32 = \"hello\";");
        let results = encoder.find_similar("E0308", "mismatched types", "let y: i32 = \"world\";");
        assert!(!results.is_empty());
        assert!(results[0].similarity > 0.0);
    }

    // ---- Match feedback -----------------------------------------------------

    #[test]
    fn test_record_match_success() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let pattern = ErrorPattern::new("E0308", "type error", "+fix");
        encoder.index_pattern(&pattern, "source");
        let pattern_id = encoder.patterns.keys().next().unwrap().clone();
        let initial_rate = encoder.patterns.get(&pattern_id).unwrap().success_rate;
        encoder.record_match_success(&pattern_id);
        let updated = encoder.patterns.get(&pattern_id).unwrap();
        // Pin the exact update rule rather than a loose lower bound, so a
        // success can never silently lower the rate.
        let expected = initial_rate * 0.9 + 0.1;
        assert_eq!(updated.match_count, 1);
        assert!(
            (updated.success_rate - expected).abs() < 1e-12,
            "expected success_rate {}, got {}",
            expected,
            updated.success_rate
        );
    }

    #[test]
    fn test_record_match_failure() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let pattern = ErrorPattern::new("E0308", "type error", "+fix");
        encoder.index_pattern(&pattern, "source");
        let pattern_id = encoder.patterns.keys().next().unwrap().clone();
        let initial_rate = encoder.patterns.get(&pattern_id).unwrap().success_rate;
        encoder.record_match_failure(&pattern_id);
        let updated = encoder.patterns.get(&pattern_id).unwrap();
        let expected = initial_rate * 0.9;
        assert_eq!(updated.match_count, 1);
        assert!(updated.success_rate <= initial_rate);
        assert!(
            (updated.success_rate - expected).abs() < 1e-12,
            "expected success_rate {}, got {}",
            expected,
            updated.success_rate
        );
    }

    // ---- Helpers and mappings -----------------------------------------------

    #[test]
    fn test_build_graph() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let graph = encoder.build_graph("E0308", "mismatched types", "let x: i32 = \"hello\";");
        assert!(graph.num_nodes() > 0);
    }

    #[test]
    fn test_cosine_similarity() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let v1 = vec![1.0, 0.0, 0.0];
        assert!((encoder.cosine_similarity(&v1, &v1) - 1.0).abs() < 0.001);
        let v2 = vec![0.0, 1.0, 0.0];
        assert!(encoder.cosine_similarity(&v1, &v2).abs() < 0.001);
        let empty: Vec<f32> = vec![];
        assert!(encoder.cosine_similarity(&empty, &empty).abs() < 0.001);
    }

    #[test]
    fn test_depyler_to_aprender_code() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let code = encoder.depyler_to_aprender_code("E0308");
        assert_eq!(code.category, AprenderErrorCategory::TypeMismatch);
        assert_eq!(code.difficulty, Difficulty::Easy);
        let code = encoder.depyler_to_aprender_code("E0382");
        assert_eq!(code.category, AprenderErrorCategory::Ownership);
        let code = encoder.depyler_to_aprender_code("E9999");
        assert_eq!(code.category, AprenderErrorCategory::Unknown);
    }

    #[test]
    fn test_map_error_category() {
        assert_eq!(
            map_error_category(ErrorCategory::TypeMismatch),
            AprenderErrorCategory::TypeMismatch
        );
        assert_eq!(
            map_error_category(ErrorCategory::BorrowChecker),
            AprenderErrorCategory::Ownership
        );
        assert_eq!(
            map_error_category(ErrorCategory::LifetimeError),
            AprenderErrorCategory::Lifetime
        );
    }

    #[test]
    fn test_infer_decision_from_match() {
        let pattern = StructuralPattern {
            id: "test".to_string(),
            error_code: "E0308".to_string(),
            embedding: vec![0.0; 256],
            error_pattern: None,
            match_count: 0,
            success_rate: 1.0,
        };
        assert_eq!(
            infer_decision_from_match(&pattern),
            Some(TranspilerDecision::TypeInference)
        );
    }

    // ---- Embedding properties -----------------------------------------------

    #[test]
    fn test_similar_errors_have_similar_embeddings() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let e1 = encoder.encode_error(
            "E0308",
            "mismatched types: expected i32, found String",
            "let x: i32 = \"hello\";",
        );
        let e2 = encoder.encode_error(
            "E0308",
            "mismatched types: expected i64, found &str",
            "let y: i64 = \"world\";",
        );
        let sim = encoder.cosine_similarity(&e1, &e2);
        assert!(
            sim > 0.0,
            "Similar errors should have positive similarity, got {}",
            sim
        );
    }

    #[test]
    fn test_different_errors_produce_valid_embeddings() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let e1 = encoder.encode_error("E0308", "mismatched types", "let x: i32 = \"hello\";");
        let e2 = encoder.encode_error(
            "E0382",
            "borrow of moved value",
            "let x = vec![1]; let y = x; x.push(1);",
        );
        assert_eq!(e1.len(), 256);
        assert_eq!(e2.len(), 256);
        assert!(e1.iter().all(|x| !x.is_nan() && x.is_finite()));
        assert!(e2.iter().all(|x| !x.is_nan() && x.is_finite()));
    }

    // ---- Combined (GNN + AST) embeddings ------------------------------------

    #[test]
    fn test_combined_embedding_dimension() {
        let encoder = DepylerGnnEncoder::with_defaults();
        assert_eq!(encoder.combined_dim(), 512);
        let config = GnnEncoderConfig {
            use_ast_embeddings: false,
            ..Default::default()
        };
        let encoder_no_ast = DepylerGnnEncoder::new(config);
        assert_eq!(encoder_no_ast.combined_dim(), 256);
    }

    #[test]
    fn test_encode_combined_with_ast() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let python_source = r#"
def add(a, b):
return a + b
"#;
        let rust_source = r#"
fn add(a: i32, b: i32) -> i32 {
a + b
}
"#;
        let combined =
            encoder.encode_combined("E0308", "mismatched types", python_source, rust_source);
        assert_eq!(combined.len(), encoder.combined_dim());
        assert_eq!(combined.len(), 512);
        assert!(combined.iter().all(|x| !x.is_nan() && x.is_finite()));
    }

    #[test]
    fn test_encode_combined_without_ast() {
        let config = GnnEncoderConfig {
            use_ast_embeddings: false,
            ..Default::default()
        };
        let encoder = DepylerGnnEncoder::new(config);
        let combined = encoder.encode_combined(
            "E0308",
            "mismatched types",
            "def foo(): pass",
            "fn foo() {}",
        );
        assert_eq!(combined.len(), 256);
    }

    #[test]
    fn test_combined_embedding_deterministic() {
        let encoder = DepylerGnnEncoder::with_defaults();
        let python = "def greet(name): return 'Hello ' + name";
        let rust = "fn greet(name: &str) -> String { format!(\"Hello {}\", name) }";
        let e1 = encoder.encode_combined("E0308", "type mismatch", python, rust);
        let e2 = encoder.encode_combined("E0308", "type mismatch", python, rust);
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_ast_embedder_initialized() {
        let encoder = DepylerGnnEncoder::with_defaults();
        assert!(encoder.ast_embedder.is_some());
        let config = GnnEncoderConfig {
            use_ast_embeddings: false,
            ..Default::default()
        };
        let encoder = DepylerGnnEncoder::new(config);
        assert!(encoder.ast_embedder.is_none());
    }

    // ---- HNSW index ---------------------------------------------------------

    #[test]
    fn test_phase3_hnsw_config_defaults() {
        let config = GnnEncoderConfig::default();
        assert!(config.use_hnsw);
        assert_eq!(config.hnsw_m, 16);
        assert_eq!(config.hnsw_ef_construction, 200);
    }

    #[test]
    fn test_phase3_hnsw_initialization() {
        let encoder = DepylerGnnEncoder::with_defaults();
        assert!(encoder.hnsw_index.is_some());
        // The index exists but holds nothing yet, so it is not "active".
        assert!(!encoder.is_hnsw_active());
        assert_eq!(encoder.hnsw_size(), 0);
    }

    #[test]
    fn test_phase3_hnsw_disabled() {
        let config = GnnEncoderConfig {
            use_hnsw: false,
            ..Default::default()
        };
        let encoder = DepylerGnnEncoder::new(config);
        assert!(encoder.hnsw_index.is_none());
        assert!(!encoder.is_hnsw_active());
    }

    #[test]
    fn test_phase3_hnsw_indexing() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let pattern = ErrorPattern::new("E0308", "mismatched types", "+fix");
        encoder.index_pattern(&pattern, "let x: i32 = \"hello\";");
        assert!(encoder.is_hnsw_active());
        assert_eq!(encoder.hnsw_size(), 1);
    }

    #[test]
    fn test_phase3_hnsw_multiple_patterns() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let patterns = [
            ErrorPattern::new("E0308", "type mismatch 1", "+fix1"),
            ErrorPattern::new("E0382", "borrow error", "+fix2"),
            ErrorPattern::new("E0433", "import error", "+fix3"),
        ];
        for pattern in &patterns {
            encoder.index_pattern(pattern, "source context");
        }
        assert_eq!(encoder.hnsw_size(), 3);
        assert_eq!(encoder.pattern_count(), 3);
    }

    #[test]
    fn test_phase3_hnsw_search_uses_index() {
        let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            similarity_threshold: 0.0,
            ..Default::default()
        });
        let pattern = ErrorPattern::new("E0308", "mismatched types", "+fix");
        encoder.index_pattern(&pattern, "let x: i32 = \"hello\";");
        let results = encoder.find_similar("E0308", "mismatched types", "let y: i32 = \"world\";");
        assert_eq!(encoder.stats().hnsw_queries, 1);
        assert_eq!(encoder.stats().linear_queries, 0);
        assert!(!results.is_empty());
    }

    #[test]
    fn test_phase3_linear_fallback_when_hnsw_disabled() {
        let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            use_hnsw: false,
            similarity_threshold: 0.0,
            ..Default::default()
        });
        let pattern = ErrorPattern::new("E0308", "mismatched types", "+fix");
        encoder.index_pattern(&pattern, "source");
        let _results = encoder.find_similar("E0308", "mismatched types", "source");
        assert_eq!(encoder.stats().hnsw_queries, 0);
        assert_eq!(encoder.stats().linear_queries, 1);
    }

    #[test]
    fn test_phase3_rebuild_hnsw_index() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        for i in 0..5 {
            let pattern = ErrorPattern::new("E0308", format!("error {}", i), "+fix");
            encoder.index_pattern(&pattern, "source");
        }
        assert_eq!(encoder.hnsw_size(), 5);
        encoder.rebuild_hnsw_index();
        assert_eq!(encoder.hnsw_size(), 5);
        assert!(encoder.is_hnsw_active());
    }

    #[test]
    fn test_phase3_hnsw_stats_tracking() {
        let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            similarity_threshold: 0.0,
            ..Default::default()
        });
        let pattern = ErrorPattern::new("E0308", "type error", "+fix");
        encoder.index_pattern(&pattern, "source");
        for _ in 0..3 {
            let _ = encoder.find_similar("E0308", "error", "source");
        }
        assert_eq!(encoder.stats().queries_performed, 3);
        assert_eq!(encoder.stats().hnsw_queries, 3);
        assert_eq!(encoder.stats().linear_queries, 0);
    }

    // ---- Batch indexing -----------------------------------------------------

    #[test]
    fn test_batch_index_patterns_empty() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let patterns: Vec<(&ErrorPattern, &str)> = vec![];
        let count = encoder.batch_index_patterns(&patterns);
        assert_eq!(count, 0);
        assert_eq!(encoder.pattern_count(), 0);
        assert!(!encoder.is_hnsw_active());
    }

    #[test]
    fn test_batch_index_patterns_single() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let pattern = ErrorPattern::new("E0308", "type mismatch", "+fix");
        let patterns = vec![(&pattern, "let x: i32 = \"hello\";")];
        let count = encoder.batch_index_patterns(&patterns);
        assert_eq!(count, 1);
        assert_eq!(encoder.pattern_count(), 1);
        assert!(encoder.is_hnsw_active());
        assert_eq!(encoder.hnsw_size(), 1);
    }

    #[test]
    fn test_batch_index_patterns_multiple() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let p1 = ErrorPattern::new("E0308", "type mismatch", "+fix1");
        let p2 = ErrorPattern::new("E0382", "borrow error", "+fix2");
        let p3 = ErrorPattern::new("E0433", "import error", "+fix3");
        let patterns = vec![
            (&p1, "let x: i32 = \"hello\";"),
            (&p2, "let x = vec![1]; let y = x; x.push(1);"),
            (&p3, "use unknown::module;"),
        ];
        let count = encoder.batch_index_patterns(&patterns);
        assert_eq!(count, 3);
        assert_eq!(encoder.pattern_count(), 3);
        assert_eq!(encoder.hnsw_size(), 3);
        assert_eq!(encoder.stats().patterns_indexed, 3);
    }

    #[test]
    fn test_batch_index_patterns_searchable() {
        let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            similarity_threshold: 0.0,
            ..Default::default()
        });
        let p1 = ErrorPattern::new("E0308", "mismatched types", "+fix");
        let p2 = ErrorPattern::new("E0382", "borrow of moved value", "+fix");
        let patterns = vec![(&p1, "let x: i32 = s;"), (&p2, "let y = x; x.push(1);")];
        encoder.batch_index_patterns(&patterns);
        let results = encoder.find_similar("E0308", "mismatched types", "let z: i32 = t;");
        assert!(!results.is_empty());
        assert_eq!(encoder.stats().hnsw_queries, 1);
    }

    #[test]
    fn test_batch_index_patterns_without_hnsw() {
        let mut encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            use_hnsw: false,
            similarity_threshold: 0.0,
            ..Default::default()
        });
        let p1 = ErrorPattern::new("E0308", "type error", "+fix");
        let p2 = ErrorPattern::new("E0277", "trait not implemented", "+fix");
        let patterns = vec![(&p1, "source1"), (&p2, "source2")];
        let count = encoder.batch_index_patterns(&patterns);
        assert_eq!(count, 2);
        assert_eq!(encoder.pattern_count(), 2);
        assert!(!encoder.is_hnsw_active());
        let results = encoder.find_similar("E0308", "type error", "source");
        assert!(!results.is_empty());
        assert_eq!(encoder.stats().linear_queries, 1);
    }

    #[test]
    fn test_batch_index_patterns_preserves_pattern_data() {
        let mut encoder = DepylerGnnEncoder::with_defaults();
        let p1 = ErrorPattern::new("E0308", "expected i32, found String", "- String\n+ i32");
        let patterns = vec![(&p1, "let x: i32 = \"hello\";")];
        encoder.batch_index_patterns(&patterns);
        let stored = encoder.patterns().next().unwrap();
        assert_eq!(stored.error_code, "E0308");
        assert!(stored.error_pattern.is_some());
        let original = stored.error_pattern.as_ref().unwrap();
        assert_eq!(original.error_pattern, "expected i32, found String");
        assert_eq!(original.fix_diff, "- String\n+ i32");
    }

    #[test]
    fn test_batch_vs_individual_indexing_equivalence() {
        let mut batch_encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            use_hnsw: false,
            ..Default::default()
        });
        let mut individual_encoder = DepylerGnnEncoder::new(GnnEncoderConfig {
            use_hnsw: false,
            ..Default::default()
        });
        let p1 = ErrorPattern::new("E0308", "type error", "+fix1");
        let p2 = ErrorPattern::new("E0382", "borrow error", "+fix2");
        let patterns = vec![(&p1, "source1"), (&p2, "source2")];
        batch_encoder.batch_index_patterns(&patterns);
        individual_encoder.index_pattern(&p1, "source1");
        individual_encoder.index_pattern(&p2, "source2");
        assert_eq!(
            batch_encoder.pattern_count(),
            individual_encoder.pattern_count()
        );
        assert_eq!(
            batch_encoder.stats().patterns_indexed,
            individual_encoder.stats().patterns_indexed
        );
    }
}