use crate::indexer::graphrag::types::CodeNode;
use anyhow::Result;
use std::path::{Path, PathBuf};
/// Computes the cosine similarity between two vectors.
///
/// Returns `0.0` when the slices differ in length, or when either vector has
/// zero magnitude (which includes empty input), instead of dividing by zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
pub fn detect_project_root() -> Result<PathBuf> {
let current_dir = std::env::current_dir()?;
let mut dir = current_dir.as_path();
let indicators = [
"Cargo.toml",
"package.json",
".git",
"pyproject.toml",
"go.mod",
"pom.xml",
"build.gradle",
"composer.json",
];
loop {
for indicator in &indicators {
if dir.join(indicator).exists() {
return Ok(dir.to_path_buf());
}
}
match dir.parent() {
Some(parent) => dir = parent,
None => break,
}
}
Ok(current_dir)
}
/// Converts `absolute_path` into a path string relative to `project_root`.
///
/// Both paths are canonicalized when possible, so symlinks and `.`/`..`
/// components do not prevent a match. When `absolute_path` cannot be
/// canonicalized (for example because the file does not exist yet):
/// - a relative input is resolved against the canonical project root;
/// - an absolute input is used as-is and matched against the canonical root
///   first, then against `project_root` exactly as given.
///
/// The result is produced with `to_string_lossy`, so non-UTF-8 components are
/// replaced rather than rejected.
///
/// # Errors
/// Returns an error if the resolved path does not lie inside `project_root`.
pub fn to_relative_path(absolute_path: &str, project_root: &Path) -> Result<String> {
    let canonical_root = project_root
        .canonicalize()
        .unwrap_or_else(|_| project_root.to_path_buf());
    let abs_path = PathBuf::from(absolute_path);
    let canonical_abs = abs_path.canonicalize().unwrap_or_else(|_| {
        if abs_path.is_relative() {
            // Join onto the *canonical* root: joining the raw root and then
            // stripping the canonical one fails whenever `project_root` is
            // relative (e.g. ".") or reached through a symlink.
            canonical_root.join(&abs_path)
        } else {
            abs_path
        }
    });
    let relative = canonical_abs
        .strip_prefix(&canonical_root)
        // A non-canonicalizable absolute input keeps the spelling the caller
        // used, which may share the un-canonicalized root's prefix instead.
        .or_else(|_| canonical_abs.strip_prefix(project_root))
        .map_err(|_| {
            anyhow::anyhow!(
                "Path {} (canonical: {}) is not within project root {} (canonical: {})",
                absolute_path,
                canonical_abs.display(),
                project_root.display(),
                canonical_root.display()
            )
        })?;
    Ok(relative.to_string_lossy().to_string())
}
/// Prints `nodes` to standard output as pretty-printed JSON.
///
/// # Errors
/// Returns the `serde_json` serialization error, converted into `anyhow::Error`,
/// if `nodes` cannot be serialized.
pub fn render_graphrag_nodes_json(nodes: &[CodeNode]) -> Result<(), anyhow::Error> {
    let pretty = serde_json::to_string_pretty(nodes)?;
    println!("{}", pretty);
    Ok(())
}
/// Renders `nodes` as a plain-text report grouped by file path.
///
/// Layout:
/// - a header with the node count;
/// - per file, a `FILE:` line followed by each node's kind, name, ID,
///   description and symbols.
///
/// Files appear in sorted path order so the report is identical across runs.
/// Nodes within a file keep their input order. Symbols are sorted and
/// de-duplicated, and any symbol containing `_` is omitted. The ` Symbols:`
/// header is written only when at least one symbol remains.
///
/// Returns `"No matching nodes found."` when `nodes` is empty.
pub fn graphrag_nodes_to_text(nodes: &[CodeNode]) -> String {
    if nodes.is_empty() {
        return "No matching nodes found.".to_string();
    }
    let mut output = String::new();
    output.push_str(&format!("GRAPHRAG NODES ({} found)\n\n", nodes.len()));

    // BTreeMap instead of HashMap: HashMap iteration order is randomized per
    // process, which made the file order of the report change between runs.
    let mut nodes_by_file: std::collections::BTreeMap<&str, Vec<&CodeNode>> =
        std::collections::BTreeMap::new();
    for node in nodes {
        nodes_by_file
            .entry(node.path.as_str())
            .or_default()
            .push(node);
    }

    for (file_path, file_nodes) in &nodes_by_file {
        output.push_str(&format!("FILE: {}\n", file_path));
        for node in file_nodes {
            output.push_str(&format!(" {} {}\n", node.kind, node.name));
            output.push_str(&format!(" ID: {}\n", node.id));
            output.push_str(&format!(" Description: {}\n", node.description));
            // Sorted, de-duplicated view of the symbols without any containing '_'.
            // NOTE(review): the '_' filter presumably hides prefixed/internal
            // names (cf. `symbols_match`) — confirm it is not meant to drop
            // ordinary snake_case identifiers.
            let visible: std::collections::BTreeSet<_> = node
                .symbols
                .iter()
                .filter(|symbol| !symbol.contains('_'))
                .collect();
            if !visible.is_empty() {
                output.push_str(" Symbols:\n");
                for symbol in visible {
                    output.push_str(&format!(" - {}\n", symbol));
                }
            }
            output.push('\n');
        }
        output.push('\n');
    }
    output
}
/// Renders `nodes` as a Markdown report grouped by file path.
///
/// Layout:
/// - a `# Found N GraphRAG nodes` heading;
/// - per file, a `## File:` heading with one `###` section per node (kind,
///   name, ID, description, symbols), followed by a `---` rule.
///
/// Files appear in sorted path order so the report is identical across runs.
/// Nodes within a file keep their input order. Symbols are sorted and
/// de-duplicated, and any symbol containing `_` is omitted. The
/// `**Symbols:**` line is written only when at least one symbol remains.
///
/// Returns `"No matching nodes found."` when `nodes` is empty.
pub fn graphrag_nodes_to_markdown(nodes: &[CodeNode]) -> String {
    if nodes.is_empty() {
        return "No matching nodes found.".to_string();
    }
    let mut markdown = String::new();
    markdown.push_str(&format!("# Found {} GraphRAG nodes\n\n", nodes.len()));

    // BTreeMap instead of HashMap: HashMap iteration order is randomized per
    // process, which made the section order of the report change between runs.
    let mut nodes_by_file: std::collections::BTreeMap<&str, Vec<&CodeNode>> =
        std::collections::BTreeMap::new();
    for node in nodes {
        nodes_by_file
            .entry(node.path.as_str())
            .or_default()
            .push(node);
    }

    for (file_path, file_nodes) in &nodes_by_file {
        markdown.push_str(&format!("## File: {}\n\n", file_path));
        for node in file_nodes {
            markdown.push_str(&format!("### {} `{}`\n", node.kind, node.name));
            markdown.push_str(&format!("**ID:** {} \n", node.id));
            markdown.push_str(&format!("**Description:** {} \n", node.description));
            // Sorted, de-duplicated view of the symbols without any containing '_'.
            // NOTE(review): the '_' filter presumably hides prefixed/internal
            // names (cf. `symbols_match`) — confirm it is not meant to drop
            // ordinary snake_case identifiers.
            let visible: std::collections::BTreeSet<_> = node
                .symbols
                .iter()
                .filter(|symbol| !symbol.contains('_'))
                .collect();
            if !visible.is_empty() {
                markdown.push_str("**Symbols:** \n");
                for symbol in visible {
                    markdown.push_str(&format!("- `{}` \n", symbol));
                }
            }
            markdown.push('\n');
        }
        markdown.push_str("---\n\n");
    }
    markdown
}
/// Returns `true` when an import-side symbol name and an export-side symbol
/// name refer to the same symbol.
///
/// Identical strings always match. Otherwise, leading runs of `import_`,
/// `use_` and `from_` are trimmed from the import name, in that order.
/// Leading runs of `export_`, `pub_` and `public_` are trimmed from the export
/// name, in that order. The remainders are then compared.
///
/// Because each prefix is trimmed repeatedly (`trim_start_matches`),
/// `import_import_x` reduces to `x`.
pub fn symbols_match(import: &str, export: &str) -> bool {
    const IMPORT_PREFIXES: [&str; 3] = ["import_", "use_", "from_"];
    const EXPORT_PREFIXES: [&str; 3] = ["export_", "pub_", "public_"];

    // Applies each prefix trim in sequence, feeding one result into the next.
    fn strip_all<'a>(name: &'a str, prefixes: &[&str]) -> &'a str {
        prefixes
            .iter()
            .fold(name, |rest, &prefix| rest.trim_start_matches(prefix))
    }

    import == export
        || strip_all(import, &IMPORT_PREFIXES) == strip_all(export, &EXPORT_PREFIXES)
}
/// Returns `true` when one path is the immediate parent of the other.
///
/// The comparison is purely lexical: there is no filesystem access, and `.`
/// and `..` are not resolved. Argument order does not matter.
///
/// Before comparing:
/// - separators are normalized, so `/` and `\` are interchangeable;
/// - trailing separators are ignored, so `"src/"` is treated like `"src"`.
///
/// Two paths are parent/child when they differ by exactly one trailing
/// component.
pub fn is_parent_child_relationship(path1: &str, path2: &str) -> bool {
    use crate::utils::path::PathNormalizer;
    let normalized_path1 = PathNormalizer::normalize_separators(path1);
    let normalized_path2 = PathNormalizer::normalize_separators(path2);
    // Without trimming, a trailing '/' yields an empty last segment. That made
    // "src/" look unrelated to "src/main.rs" but a parent of "src".
    let path1_parts: Vec<&str> = normalized_path1.trim_end_matches('/').split('/').collect();
    let path2_parts: Vec<&str> = normalized_path2.trim_end_matches('/').split('/').collect();
    let (shorter, longer) = if path1_parts.len() < path2_parts.len() {
        (&path1_parts, &path2_parts)
    } else {
        (&path2_parts, &path1_parts)
    };
    longer.len() == shorter.len() + 1 && longer.starts_with(shorter)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parent/child detection must not depend on which separator style is used.
    #[test]
    fn test_is_parent_child_relationship_cross_platform() {
        // Pairs that differ by exactly one trailing component, in either order.
        let related = [
            ("src", "src/main.rs"),
            ("src/main.rs", "src"),
            ("src", "src\\main.rs"),
            ("src\\main.rs", "src"),
            ("src/utils", "src\\utils\\helper.rs"),
            ("src\\utils\\helper.rs", "src/utils"),
        ];
        // Siblings and paths under different top-level directories.
        let unrelated = [
            ("src/main.rs", "lib/helper.rs"),
            ("src\\main.rs", "lib\\helper.rs"),
            ("src/main.rs", "src/lib.rs"),
            ("src\\main.rs", "src\\lib.rs"),
        ];

        for &(a, b) in &related {
            assert!(
                is_parent_child_relationship(a, b),
                "expected {:?} and {:?} to be parent/child",
                a,
                b
            );
        }
        for &(a, b) in &unrelated {
            assert!(
                !is_parent_child_relationship(a, b),
                "expected {:?} and {:?} not to be parent/child",
                a,
                b
            );
        }
    }
}