use std::collections::HashMap;
use std::panic::AssertUnwindSafe;
use std::path::{Path, PathBuf};
use crate::ast::imports::get_imports;
use crate::callgraph::builder_v2::build_project_call_graph_v2;
use crate::callgraph::cross_file_types::ProjectCallGraphV2;
use crate::callgraph::BuildConfig;
use crate::cfg::extractor::get_cfg_context;
use crate::dfg::extractor::get_dfg_context;
use crate::semantic::types::CodeChunk;
use crate::types::{BlockType, RefType};
use crate::Language;
/// A [`CodeChunk`] enriched with structural context (signature, call-graph
/// neighbors, CFG/DFG summaries, file imports) for embedding generation.
/// Rendered to text by `build_embedding_text`.
#[derive(Debug, Clone)]
pub struct EmbeddingUnit {
/// The original chunk this unit was built from.
pub chunk: CodeChunk,
/// First line of the chunk's content, used as a cheap signature proxy.
pub signature: String,
/// Currently always empty — docstring extraction is not implemented yet.
pub docstring: String,
/// Names of functions this chunk calls (at most 5, from the call graph).
pub calls: Vec<String>,
/// Names of functions that call this chunk (at most 5, from the call graph).
pub called_by: Vec<String>,
/// Short "complexity/branches/loops" summary; empty when CFG extraction failed.
pub cfg_summary: String,
/// Short "vars/defs/uses" summary; empty when DFG extraction failed.
pub dfg_summary: String,
/// Comma-joined list of imported modules for the chunk's file (at most 10).
pub dependencies: String,
}
/// Per-file caches shared across all chunks of one enrichment pass, so each
/// file is read from disk and import-scanned at most once.
struct FileAnalysisCache {
/// Full file contents keyed by path; files that failed to read are absent.
file_sources: HashMap<PathBuf, String>,
/// Comma-joined import-module summary keyed by path; empty summaries are
/// not stored (lookup falls back to an empty string).
file_imports: HashMap<PathBuf, String>,
}
impl FileAnalysisCache {
fn build(chunks: &[CodeChunk]) -> Self {
let mut file_sources: HashMap<PathBuf, String> = HashMap::new();
let mut file_imports: HashMap<PathBuf, String> = HashMap::new();
let unique_paths: Vec<PathBuf> = {
let mut seen = std::collections::HashSet::new();
chunks
.iter()
.filter(|c| seen.insert(c.file_path.clone()))
.map(|c| c.file_path.clone())
.collect()
};
for path in &unique_paths {
if let Ok(source) = std::fs::read_to_string(path) {
file_sources.insert(path.clone(), source);
}
let imports_str = std::panic::catch_unwind(AssertUnwindSafe(|| {
let lang = chunks
.iter()
.find(|c| &c.file_path == path)
.map(|c| c.language);
if let Some(lang) = lang {
match get_imports(path, lang) {
Ok(imports) => {
let modules: Vec<String> = imports
.iter()
.map(|imp| imp.module.clone())
.collect();
if modules.is_empty() {
String::new()
} else {
let mut unique_modules: Vec<String> = modules;
unique_modules.sort();
unique_modules.dedup();
unique_modules.truncate(10);
unique_modules.join(", ")
}
}
Err(_) => String::new(),
}
} else {
String::new()
}
}))
.unwrap_or_default();
if !imports_str.is_empty() {
file_imports.insert(path.clone(), imports_str);
}
}
FileAnalysisCache {
file_sources,
file_imports,
}
}
}
/// Renders an [`EmbeddingUnit`] as a compact newline-separated text block for
/// an embedding model: name, signature, description, callers/callees,
/// CFG/DFG summaries, and dependencies — sections are omitted when empty.
///
/// The signature is capped at 200 bytes and the whole text at 2000 bytes.
/// Both caps snap to UTF-8 character boundaries: the original byte-indexed
/// slices (`&s[..200]`) panicked when the cut landed inside a multi-byte
/// character.
pub fn build_embedding_text(unit: &EmbeddingUnit) -> String {
    let mut parts = Vec::new();
    // Fall back to the file name when the chunk has no function name.
    let name = unit
        .chunk
        .function_name
        .as_deref()
        .unwrap_or_else(|| {
            unit.chunk
                .file_path
                .file_name()
                .and_then(|f| f.to_str())
                .unwrap_or("unknown")
        });
    parts.push(format!("Function: {}", name));
    if !unit.signature.is_empty() {
        let sig = truncate_on_char_boundary(&unit.signature, 200);
        parts.push(format!("Signature: {}", sig));
    }
    if !unit.docstring.is_empty() {
        parts.push(format!("Description: {}", unit.docstring));
    }
    if !unit.calls.is_empty() {
        let top_calls: Vec<&str> = unit.calls.iter().take(5).map(|s| s.as_str()).collect();
        parts.push(format!("Calls: {}", top_calls.join(", ")));
    }
    if !unit.called_by.is_empty() {
        let top_callers: Vec<&str> = unit.called_by.iter().take(5).map(|s| s.as_str()).collect();
        parts.push(format!("Called by: {}", top_callers.join(", ")));
    }
    if !unit.cfg_summary.is_empty() {
        parts.push(format!("Control flow: {}", unit.cfg_summary));
    }
    if !unit.dfg_summary.is_empty() {
        parts.push(format!("Data flow: {}", unit.dfg_summary));
    }
    if !unit.dependencies.is_empty() {
        parts.push(format!("Dependencies: {}", unit.dependencies));
    }
    let text = parts.join("\n");
    if text.len() > 2000 {
        truncate_on_char_boundary(&text, 2000).to_string()
    } else {
        text
    }
}

/// Returns the longest prefix of `s` that is at most `max_bytes` long and
/// ends on a UTF-8 character boundary. Plain byte slicing (`&s[..n]`) panics
/// when `n` falls inside a multi-byte character; this never does.
fn truncate_on_char_boundary(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    let mut end = max_bytes;
    while end > 0 && !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
/// Produces a one-line control-flow summary ("complexity=…, branches=…,
/// loops=…") for `function_name` within `source`, or an empty string when
/// CFG extraction fails.
fn build_cfg_summary_from_source(
    source: &str,
    function_name: &str,
    language: Language,
) -> String {
    let cfg = match get_cfg_context(source, function_name, language) {
        Ok(ctx) => ctx,
        Err(_) => return String::new(),
    };
    // Single pass over the blocks: count branch blocks and loop blocks
    // (both headers and bodies count toward `loops`).
    let mut branch_total = 0usize;
    let mut loop_total = 0usize;
    for block in &cfg.blocks {
        match block.block_type {
            BlockType::Branch => branch_total += 1,
            BlockType::LoopHeader | BlockType::LoopBody => loop_total += 1,
            _ => {}
        }
    }
    format!(
        "complexity={}, branches={}, loops={}",
        cfg.cyclomatic_complexity, branch_total, loop_total
    )
}
/// Produces a one-line data-flow summary ("vars=…, defs=…, uses=…") for
/// `function_name` within `source`, or an empty string when DFG extraction
/// fails.
fn build_dfg_summary_from_source(
    source: &str,
    function_name: &str,
    language: Language,
) -> String {
    let dfg = match get_dfg_context(source, function_name, language) {
        Ok(ctx) => ctx,
        Err(_) => return String::new(),
    };
    // Single pass over all references, tallying definitions and uses.
    let mut def_total = 0usize;
    let mut use_total = 0usize;
    for reference in &dfg.refs {
        match reference.ref_type {
            RefType::Definition => def_total += 1,
            RefType::Use => use_total += 1,
            _ => {}
        }
    }
    format!(
        "vars={}, defs={}, uses={}",
        dfg.variables.len(),
        def_total,
        use_total
    )
}
/// Enriches every chunk with call-graph, CFG/DFG, and import context,
/// returning exactly one [`EmbeddingUnit`] per input chunk (in order).
///
/// All analysis steps are wrapped in `catch_unwind`: a panic in any single
/// analysis degrades that unit to bare defaults instead of aborting the run.
pub fn enrich_chunks(chunks: &[CodeChunk], root: &Path) -> Vec<EmbeddingUnit> {
    if chunks.is_empty() {
        return Vec::new();
    }
    let file_cache = FileAnalysisCache::build(chunks);

    // Build one project call graph per language present in the input,
    // keyed by first appearance.
    let mut call_graphs: HashMap<Language, ProjectCallGraphV2> = HashMap::new();
    let mut languages_done = std::collections::HashSet::new();
    for chunk in chunks {
        let lang = chunk.language;
        if !languages_done.insert(lang) {
            continue;
        }
        let graph = std::panic::catch_unwind(AssertUnwindSafe(|| {
            let config = BuildConfig {
                language: lang.as_str().to_string(),
                ..Default::default()
            };
            // A build failure yields an empty graph rather than an error.
            build_project_call_graph_v2(root, config)
                .map(|ir| {
                    let mut built = ProjectCallGraphV2::new();
                    for edge in ir.edges {
                        built.add_edge(edge);
                    }
                    built
                })
                .unwrap_or_else(|_| ProjectCallGraphV2::new())
        }))
        .unwrap_or_else(|_| ProjectCallGraphV2::new());
        call_graphs.insert(lang, graph);
    }

    let mut units = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        let unit = std::panic::catch_unwind(AssertUnwindSafe(|| {
            enrich_single_chunk(chunk, root, &file_cache, &call_graphs)
        }))
        .unwrap_or_else(|_| EmbeddingUnit {
            chunk: chunk.clone(),
            signature: String::new(),
            docstring: String::new(),
            calls: Vec::new(),
            called_by: Vec::new(),
            cfg_summary: String::new(),
            dfg_summary: String::new(),
            dependencies: String::new(),
        });
        units.push(unit);
    }
    assert_eq!(
        units.len(),
        chunks.len(),
        "enrich_chunks must return exactly one EmbeddingUnit per input CodeChunk"
    );
    units
}
fn enrich_single_chunk(
chunk: &CodeChunk,
root: &Path,
file_cache: &FileAnalysisCache,
call_graphs: &HashMap<Language, ProjectCallGraphV2>,
) -> EmbeddingUnit {
let signature = chunk
.content
.lines()
.next()
.unwrap_or("")
.to_string();
let docstring = String::new();
let (calls, called_by) = if let Some(func_name) = &chunk.function_name {
if let Some(graph) = call_graphs.get(&chunk.language) {
let rel_path = chunk
.file_path
.strip_prefix(root)
.unwrap_or(&chunk.file_path);
let callees: Vec<String> = graph
.callees_of(rel_path, func_name)
.map(|e| e.dst_func.clone())
.take(5)
.collect();
let callers: Vec<String> = graph
.callers_of(rel_path, func_name)
.map(|e| e.src_func.clone())
.take(5)
.collect();
(callees, callers)
} else {
(Vec::new(), Vec::new())
}
} else {
(Vec::new(), Vec::new())
};
let cfg_summary = if let Some(func_name) = &chunk.function_name {
if let Some(source) = file_cache.file_sources.get(&chunk.file_path) {
std::panic::catch_unwind(AssertUnwindSafe(|| {
build_cfg_summary_from_source(source, func_name, chunk.language)
}))
.unwrap_or_default()
} else {
std::panic::catch_unwind(AssertUnwindSafe(|| {
build_cfg_summary_from_source(&chunk.content, func_name, chunk.language)
}))
.unwrap_or_default()
}
} else {
String::new()
};
let dfg_summary = if let Some(func_name) = &chunk.function_name {
if let Some(source) = file_cache.file_sources.get(&chunk.file_path) {
std::panic::catch_unwind(AssertUnwindSafe(|| {
build_dfg_summary_from_source(source, func_name, chunk.language)
}))
.unwrap_or_default()
} else {
std::panic::catch_unwind(AssertUnwindSafe(|| {
build_dfg_summary_from_source(&chunk.content, func_name, chunk.language)
}))
.unwrap_or_default()
}
} else {
String::new()
};
let dependencies = file_cache
.file_imports
.get(&chunk.file_path)
.cloned()
.unwrap_or_default();
EmbeddingUnit {
chunk: chunk.clone(),
signature,
docstring,
calls,
called_by,
cfg_summary,
dfg_summary,
dependencies,
}
}
/// Returns the lowercase hexadecimal MD5 digest of `source`, used as a
/// cheap content fingerprint (not for security).
pub fn content_hash_from_source(source: &str) -> String {
    let digest = md5::compute(source.as_bytes());
    format!("{:x}", digest)
}