use crate::indexer::graphrag::{CodeGraph, CodeNode, CodeRelationship};
use crate::store::CodeBlock;
use anyhow::Result;
use std::collections::{HashMap, HashSet};
/// A pruned, task-focused view of the full code graph: the nodes and
/// relationships judged relevant to one task, plus derived metadata used
/// when rendering the subgraph for an LLM prompt.
#[derive(Debug, Clone)]
pub struct TaskFocusedSubgraph {
/// Included nodes, deduplicated by `id` (see `add_node`).
pub nodes: Vec<CodeNode>,
/// Included edges, deduplicated by (source, target, relation_type)
/// (see `add_relationship`).
pub relationships: Vec<CodeRelationship>,
/// File paths of every node ever passed to `add_node` (recorded even
/// when the node itself was a duplicate and skipped).
pub relevant_files: HashSet<String>,
/// Concept name -> relevance score; higher means more relevant.
pub key_concepts: HashMap<String, f32>, }
impl Default for TaskFocusedSubgraph {
/// Equivalent to [`TaskFocusedSubgraph::new`]: an empty subgraph.
fn default() -> Self {
Self::new()
}
}
impl TaskFocusedSubgraph {
    /// Creates an empty subgraph.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            relationships: Vec::new(),
            relevant_files: HashSet::new(),
            key_concepts: HashMap::new(),
        }
    }

    /// Rough token-count estimate used to keep the subgraph within a prompt
    /// budget. Uses flat per-node / per-relationship heuristics rather than
    /// measuring the actual serialized text.
    pub fn estimated_token_count(&self) -> usize {
        const TOKENS_PER_NODE: usize = 100;
        const TOKENS_PER_RELATIONSHIP: usize = 50;
        self.nodes.len() * TOKENS_PER_NODE + self.relationships.len() * TOKENS_PER_RELATIONSHIP
    }

    /// Renders the subgraph as a Markdown summary: key concepts, relevant
    /// files, key components grouped by kind, and the most common
    /// relationship types. Each section is truncated with an explicit
    /// "(and N more)" marker.
    pub fn to_markdown(&self) -> String {
        let mut markdown = String::new();
        markdown.push_str(&format!(
            "# Code Knowledge Graph: {} nodes, {} relationships\n\n",
            self.nodes.len(),
            self.relationships.len()
        ));

        if !self.key_concepts.is_empty() {
            markdown.push_str("## Key Concepts\n\n");
            let mut concepts: Vec<_> = self.key_concepts.iter().collect();
            // Sort by relevance (descending); break ties by name so output is
            // deterministic despite HashMap iteration order.
            concepts.sort_by(|a, b| {
                b.1.partial_cmp(a.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| a.0.cmp(b.0))
            });
            for (concept, relevance) in concepts.iter().take(10) {
                markdown.push_str(&format!("- **{}** (relevance: {:.2})\n", concept, relevance));
            }
            markdown.push('\n');
        }

        if !self.relevant_files.is_empty() {
            markdown.push_str("## Relevant Files\n\n");
            let mut files: Vec<_> = self.relevant_files.iter().collect();
            files.sort();
            for file in files.iter().take(15) {
                markdown.push_str(&format!("- `{}`\n", file));
            }
            if self.relevant_files.len() > 15 {
                markdown.push_str(&format!(
                    "- *(and {} more files)*\n",
                    self.relevant_files.len() - 15
                ));
            }
            markdown.push('\n');
        }

        if !self.nodes.is_empty() {
            markdown.push_str("## Key Components\n\n");
            let mut node_by_kind: HashMap<String, Vec<&CodeNode>> = HashMap::new();
            for node in &self.nodes {
                node_by_kind.entry(node.kind.clone()).or_default().push(node);
            }
            // FIX: iterate kinds in sorted order. Iterating the HashMap
            // directly made section order nondeterministic between runs.
            let mut kinds: Vec<_> = node_by_kind.iter().collect();
            kinds.sort_by(|a, b| a.0.cmp(b.0));

            let mut total_nodes_shown = 0;
            const MAX_NODES_TO_SHOW: usize = 20;
            for (kind, nodes) in kinds {
                // Naive pluralization: "### FNs", "### STRUCTs", etc.
                markdown.push_str(&format!("### {}s\n\n", kind.to_uppercase()));
                let mut shown_for_kind = 0;
                for node in nodes.iter().take(5) {
                    markdown.push_str(&format!("- **{}**: {}\n", node.name, node.description));
                    shown_for_kind += 1;
                    total_nodes_shown += 1;
                    if total_nodes_shown >= MAX_NODES_TO_SHOW {
                        break;
                    }
                }
                // FIX: count what was actually shown. The old code assumed 5
                // entries were always printed, so the "(and N more)" count was
                // wrong whenever the global cap broke the loop early.
                if nodes.len() > shown_for_kind {
                    markdown.push_str(&format!(
                        "- *(and {} more {}s)*\n",
                        nodes.len() - shown_for_kind,
                        kind
                    ));
                }
                markdown.push('\n');
                if total_nodes_shown >= MAX_NODES_TO_SHOW {
                    break;
                }
            }
        }

        if !self.relationships.is_empty() {
            markdown.push_str("## Relationships\n\n");
            let mut rels_by_type: HashMap<String, Vec<&CodeRelationship>> = HashMap::new();
            for rel in &self.relationships {
                rels_by_type
                    .entry(rel.relation_type.to_string())
                    .or_default()
                    .push(rel);
            }
            let mut rel_types: Vec<_> = rels_by_type.iter().collect();
            // Most common relation types first; ties broken by name so the
            // top-5 selection below is deterministic.
            rel_types.sort_by(|a, b| b.1.len().cmp(&a.1.len()).then_with(|| a.0.cmp(b.0)));
            for (rel_type, rels) in rel_types.iter().take(5) {
                markdown.push_str(&format!("### {} relationships\n\n", rel_type));
                for rel in rels.iter().take(3) {
                    // Display only the final path segment to keep lines short.
                    let source_name = rel.source.split('/').next_back().unwrap_or(&rel.source);
                    let target_name = rel.target.split('/').next_back().unwrap_or(&rel.target);
                    markdown.push_str(&format!("- `{}` → `{}`\n", source_name, target_name));
                }
                if rels.len() > 3 {
                    markdown.push_str(&format!(
                        "- *(and {} more {} relationships)*\n",
                        rels.len() - 3,
                        rel_type
                    ));
                }
                markdown.push('\n');
            }
        }
        markdown
    }

    /// Adds `node`, recording its file path in `relevant_files`. Nodes whose
    /// `id` is already present are skipped (the path is still recorded).
    pub fn add_node(&mut self, node: CodeNode) {
        self.relevant_files.insert(node.path.clone());
        if !self.nodes.iter().any(|n| n.id == node.id) {
            self.nodes.push(node);
        }
    }

    /// Adds `relationship` unless an identical (source, target, type) edge is
    /// already present.
    pub fn add_relationship(&mut self, relationship: CodeRelationship) {
        let exists = self.relationships.iter().any(|r| {
            r.source == relationship.source
                && r.target == relationship.target
                && r.relation_type == relationship.relation_type
        });
        if !exists {
            self.relationships.push(relationship);
        }
    }

    /// Records (or overwrites) a concept's relevance score.
    pub fn add_key_concept(&mut self, concept: String, relevance: f32) {
        self.key_concepts.insert(concept, relevance);
    }
}
/// Extracts task-focused subgraphs from a full `CodeGraph`, trimming the
/// selection once the subgraph's estimated token count exceeds the budget.
pub struct GraphOptimizer {
/// Approximate upper bound on `TaskFocusedSubgraph::estimated_token_count`.
pub max_token_budget: usize,
}
impl GraphOptimizer {
    /// Creates an optimizer that trims subgraphs to roughly
    /// `max_token_budget` estimated tokens.
    pub fn new(max_token_budget: usize) -> Self {
        Self { max_token_budget }
    }

    /// Extracts a task-focused subgraph: seeds it with the 20 nodes most
    /// similar to `query_embedding`, keeps relationships among them, then
    /// expands one hop to neighbor nodes (up to 20 more) while the token
    /// budget allows.
    ///
    /// `_task_description` is currently unused; selection is purely
    /// embedding-driven.
    pub async fn extract_task_subgraph(
        &self,
        _task_description: &str,
        query_embedding: &[f32],
        full_graph: &CodeGraph,
    ) -> Result<TaskFocusedSubgraph> {
        let mut subgraph = TaskFocusedSubgraph::new();

        // Seed with the nodes most similar to the query embedding, stopping
        // early if the budget is exhausted.
        let relevant_nodes = self.find_relevant_nodes(query_embedding, full_graph, 20)?;
        for (node, relevance) in &relevant_nodes {
            subgraph.add_node(node.clone());
            self.extract_key_concepts(&mut subgraph, node, *relevance);
            if subgraph.estimated_token_count() > self.max_token_budget {
                break;
            }
        }

        let node_ids: HashSet<String> = relevant_nodes
            .iter()
            .map(|(node, _)| node.id.clone())
            .collect();

        // Keep relationships whose endpoints are both seed nodes.
        for relationship in &full_graph.relationships {
            if node_ids.contains(&relationship.source) && node_ids.contains(&relationship.target) {
                subgraph.add_relationship(relationship.clone());
            }
        }

        // Collect the ids of one-hop neighbors of the seed set.
        let mut additional_nodes = HashSet::new();
        for relationship in &full_graph.relationships {
            if node_ids.contains(&relationship.source) && !node_ids.contains(&relationship.target) {
                additional_nodes.insert(relationship.target.clone());
            } else if node_ids.contains(&relationship.target)
                && !node_ids.contains(&relationship.source)
            {
                additional_nodes.insert(relationship.source.clone());
            }
        }

        // Track subgraph membership in a set so the endpoint checks below are
        // O(1) instead of a linear scan of `subgraph.nodes` per relationship.
        let mut subgraph_ids: HashSet<String> =
            subgraph.nodes.iter().map(|n| n.id.clone()).collect();

        // Add up to 20 neighbor nodes plus any relationships they complete.
        // FIX: the token-budget check now lives here, where the subgraph
        // actually grows. It previously sat in the neighbor-collection loop
        // above, where the token count never changed, so it either fired
        // immediately or never — and this growth loop had no check at all.
        let mut added = 0;
        for node_id in additional_nodes {
            if added >= 20 || subgraph.estimated_token_count() > self.max_token_budget {
                break;
            }
            if let Some(node) = full_graph.nodes.get(&node_id) {
                subgraph.add_node(node.clone());
                subgraph_ids.insert(node.id.clone());
                added += 1;
                for relationship in &full_graph.relationships {
                    if (relationship.source == node_id || relationship.target == node_id)
                        && subgraph_ids.contains(&relationship.source)
                        && subgraph_ids.contains(&relationship.target)
                    {
                        subgraph.add_relationship(relationship.clone());
                    }
                }
            }
        }
        Ok(subgraph)
    }

    /// Ranks every graph node by cosine similarity to the query embedding and
    /// returns the top `limit` as owned copies.
    fn find_relevant_nodes(
        &self,
        query_embedding: &[f32],
        graph: &CodeGraph,
        limit: usize,
    ) -> Result<Vec<(CodeNode, f32)>> {
        // Score borrowed nodes first and clone only the winners; the previous
        // version cloned every node in the graph before sorting.
        let mut similarities: Vec<(&CodeNode, f32)> = graph
            .nodes
            .values()
            .map(|node| (node, cosine_similarity(query_embedding, &node.embedding)))
            .collect();
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(similarities
            .into_iter()
            .take(limit)
            .map(|(node, score)| (node.clone(), score))
            .collect())
    }

    /// Seeds the subgraph's key-concept map from a node: its name (full
    /// relevance), its kind (discounted 0.8x), and technical-looking words
    /// from its description (discounted 0.5x).
    fn extract_key_concepts(
        &self,
        subgraph: &mut TaskFocusedSubgraph,
        node: &CodeNode,
        relevance: f32,
    ) {
        subgraph.add_key_concept(node.name.clone(), relevance);
        subgraph.add_key_concept(node.kind.clone(), relevance * 0.8);
        for word in node.description.split_whitespace().filter(|w| w.len() > 4) {
            if is_likely_technical_term(word) {
                subgraph.add_key_concept(
                    // Strip surrounding punctuation before storing.
                    word.trim_matches(|c: char| !c.is_alphanumeric()).to_string(),
                    relevance * 0.5,
                );
            }
        }
    }

    /// Builds a complete Markdown briefing for a task: the knowledge-graph
    /// summary followed by up to five relevant code snippets. Snippets longer
    /// than 20 lines are truncated to their first and last 10 lines.
    pub async fn generate_task_focused_view(
        &self,
        task_description: &str,
        query_embedding: &[f32],
        full_graph: &CodeGraph,
        code_blocks: &[CodeBlock],
    ) -> Result<String> {
        let subgraph = self
            .extract_task_subgraph(task_description, query_embedding, full_graph)
            .await?;
        let graph_markdown = subgraph.to_markdown();
        let relevant_snippets =
            self.find_relevant_code_snippets(query_embedding, &subgraph, code_blocks, 5)?;

        let mut result = String::new();
        result.push_str("# Task-Focused Code Overview\n\n");
        result.push_str(&format!("**Task:** {}\n\n", task_description));
        result.push_str("## Knowledge Graph Summary\n\n");
        result.push_str(&graph_markdown);

        if !relevant_snippets.is_empty() {
            result.push_str("## Relevant Code Snippets\n\n");
            for (idx, (block, similarity)) in relevant_snippets.iter().enumerate() {
                result.push_str(&format!(
                    "### Snippet {} (Relevance: {:.2})\n\n",
                    idx + 1,
                    similarity
                ));
                result.push_str(&format!("File: `{}`\n\n", block.path));

                // NOTE(review): filtering out symbols containing '_' also
                // hides ordinary snake_case function names — confirm intended.
                if !block.symbols.is_empty() {
                    let display_symbols: Vec<_> =
                        block.symbols.iter().filter(|s| !s.contains('_')).collect();
                    if !display_symbols.is_empty() {
                        result.push_str("**Symbols:** ");
                        let joined = display_symbols
                            .iter()
                            .map(|s| format!("`{}`", s))
                            .collect::<Vec<_>>()
                            .join(", ");
                        result.push_str(&joined);
                        result.push_str("\n\n");
                    }
                }

                // Fenced code block; omit the language tag for plain text.
                result.push_str("```");
                if !block.language.is_empty() && block.language != "text" {
                    result.push_str(&block.language);
                }
                result.push('\n');

                let lines: Vec<&str> = block.content.lines().collect();
                if lines.len() > 20 {
                    // Long snippet: first 10 lines, elision marker, last 10.
                    for line in lines.iter().take(10) {
                        result.push_str(line);
                        result.push('\n');
                    }
                    // `lines` is necessarily non-empty here (len > 20), so the
                    // old conditional " for brevity" suffix was always taken;
                    // it is now inlined.
                    result.push_str(&format!(
                        "// ... {} lines omitted ... for brevity\n",
                        lines.len() - 20
                    ));
                    for line in lines.iter().skip(lines.len() - 10) {
                        result.push_str(line);
                        result.push('\n');
                    }
                } else {
                    result.push_str(&block.content);
                    if !block.content.ends_with('\n') {
                        result.push('\n');
                    }
                }
                result.push_str("```\n\n");
            }
        }
        Ok(result)
    }

    /// Scores code blocks from the subgraph's relevant files against the
    /// query embedding, boosting blocks whose symbols match a key concept,
    /// and returns up to `limit` blocks scoring above 0.5, best first.
    fn find_relevant_code_snippets(
        &self,
        query_embedding: &[f32],
        subgraph: &TaskFocusedSubgraph,
        code_blocks: &[CodeBlock],
        limit: usize,
    ) -> Result<Vec<(CodeBlock, f32)>> {
        let mut relevant_blocks = Vec::new();
        for block in code_blocks
            .iter()
            .filter(|block| subgraph.relevant_files.contains(&block.path))
        {
            // `any` on an empty iterator is false, so the previous explicit
            // `!symbols.is_empty()` guard was redundant.
            let contains_key_concept = block
                .symbols
                .iter()
                .any(|symbol| subgraph.key_concepts.contains_key(symbol));
            let mut similarity =
                cosine_similarity(query_embedding, &generate_block_embedding(block)?);
            if contains_key_concept {
                // Boost blocks that mention a known concept by name; the
                // boosted score may exceed 1.0 but is only used for ranking.
                similarity *= 1.5;
            }
            if similarity > 0.5 {
                relevant_blocks.push((block.clone(), similarity));
            }
        }
        relevant_blocks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(relevant_blocks.into_iter().take(limit).collect())
    }
}
/// Derives a deterministic 128-dimensional pseudo-embedding from a block's
/// content hash, L2-normalized so cosine similarity reduces to a dot product.
/// Blocks with equal hashes always map to identical vectors.
fn generate_block_embedding(block: &CodeBlock) -> Result<Vec<f32>> {
    const DIM: usize = 128;
    let mut embedding = vec![0.0f32; DIM];
    // Spread the hash bytes across the vector; a byte landing on an already
    // written slot overwrites the earlier value.
    for (i, &byte) in block.hash.as_bytes().iter().enumerate() {
        embedding[i % DIM] = f32::from(byte) / 255.0;
    }
    // Normalize to unit length; an all-zero vector is left untouched.
    let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|v| *v /= norm);
    }
    Ok(embedding)
}
/// Cosine similarity of two vectors.
///
/// Returns 0.0 when the lengths differ or when either vector has zero
/// magnitude (instead of producing NaN from a division by zero).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|y| y * y).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Heuristically decides whether `word` looks like a technical identifier
/// rather than ordinary prose. A word qualifies when it has mixed case
/// (`CamelCase`), contains an underscore (`snake_case`), or is longer than
/// six characters — unless it is a known common English word.
fn is_likely_technical_term(word: &str) -> bool {
    // Strip surrounding punctuation but keep internal characters like '_'.
    let trimmed = word.trim_matches(|c: char| !c.is_alphanumeric());
    let lowered = trimmed.to_lowercase();
    let common_words = [
        "about", "after", "again", "below", "could", "every", "first", "found", "great", "house",
        "large", "learn", "never", "other", "place", "plant", "point", "right", "small", "sound",
        "spell", "still", "study", "their", "there", "these", "thing", "think", "three", "water",
        "where", "which", "world", "would", "write",
    ];
    if common_words.contains(&lowered.as_str()) {
        return false;
    }
    // BUG FIX: the mixed-case test previously ran on the lowercased string,
    // so it could never be true and short CamelCase identifiers were
    // rejected. Check the original (trimmed) word instead.
    let has_mixed_case = trimmed.chars().any(|c| c.is_uppercase())
        && trimmed.chars().any(|c| c.is_lowercase());
    let has_underscore = trimmed.contains('_');
    has_mixed_case || has_underscore || lowered.len() > 6
}