use crate::config::Config;
use crate::indexer::graphrag::types::{CodeNode, CodeRelationship};
use crate::llm::{LlmClient, Message};
use anyhow::Result;
use serde::Deserialize;
use serde_json::json;
use std::collections::HashMap;
use std::time::Duration;
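/// LLM-backed enhancements for the GraphRAG indexer: architectural
/// relationship discovery and natural-language file descriptions. Holds an
/// `LlmClient` only when `graphrag.use_llm` is enabled in the config.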
pub struct AIEnhancements {
config: Config,
llm_client: Option<LlmClient>,
quiet: bool,
}
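/// Per-file summary handed to the LLM when requesting batch descriptions.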
#[derive(Debug, Clone)]
pub struct FileForAI {
pub file_id: String,
pub file_path: String,
pub language: String,
pub symbols: Vec<String>,
pub content_sample: String,
pub function_count: usize,
pub class_count: usize,
}
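/// Expected JSON shape of the LLM's batch-description response.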
#[derive(Debug, Deserialize)]
struct BatchDescriptionResponse {
descriptions: Vec<FileDescription>,
}
#[derive(Debug, Deserialize)]
struct FileDescription {
file_id: String,
description: String,
}
impl AIEnhancements {
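    /// Builds the enhancement layer from `config`. Fails fast when
    /// `graphrag.use_llm` is enabled but no client can be constructed, so
    /// indexing never proceeds without the analysis the user requested.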
pub fn new(config: Config, quiet: bool) -> Result<Self> {
let llm_client = if config.graphrag.use_llm {
Some(LlmClient::from_config(&config).map_err(|e| {
anyhow::anyhow!(
"LLM required for GraphRAG but unavailable: {}. \
Disable graphrag.use_llm or fix LLM configuration.",
e
)
})?)
} else {
None
};
Ok(Self {
config,
llm_client,
quiet,
})
}
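    /// True when LLM use is configured and a client was constructed.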
pub fn llm_enabled(&self) -> bool {
self.config.graphrag.use_llm && self.llm_client.is_some()
}
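    /// Discovers relationships for `new_files` by combining the cheap
    /// heuristic pass with LLM-detected architectural relationships, then
    /// sorts and deduplicates on (source, target, relation_type). The sort
    /// is required because `dedup_by` only removes adjacent duplicates.
    ///
    /// Minimal sketch (placeholder variables; assumes a constructed enhancer):
    ///
    /// ```ignore
    /// let rels = enhancer
    ///     .discover_relationships_with_ai_enhancement(&new_nodes, &all_nodes)
    ///     .await?;
    /// ```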
pub async fn discover_relationships_with_ai_enhancement(
&self,
new_files: &[CodeNode],
all_nodes: &[CodeNode],
) -> Result<Vec<CodeRelationship>> {
        let mut relationships =
            crate::indexer::graphrag::relationships::RelationshipDiscovery::discover_relationships_efficiently(
                new_files, all_nodes,
            )
            .await?;
let ai_relationships = self
.discover_complex_relationships_with_ai(new_files, all_nodes)
.await?;
relationships.extend(ai_relationships);
relationships.sort_by(|a, b| {
(a.source.clone(), a.target.clone(), a.relation_type.clone()).cmp(&(
b.source.clone(),
b.target.clone(),
b.relation_type.clone(),
))
});
relationships.dedup_by(|a, b| {
a.source == b.source && a.target == b.target && a.relation_type == b.relation_type
});
Ok(relationships)
}
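    /// Runs the LLM over files that pass `should_use_ai_for_relationships`,
    /// in batches of `ai_batch_size` to bound prompt size.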
async fn discover_complex_relationships_with_ai(
&self,
new_files: &[CodeNode],
all_nodes: &[CodeNode],
) -> Result<Vec<CodeRelationship>> {
let mut ai_relationships = Vec::new();
let complex_files: Vec<&CodeNode> = new_files
.iter()
.filter(|node| self.should_use_ai_for_relationships(node))
.collect();
if complex_files.is_empty() {
if !self.quiet {
eprintln!("Debug: No files qualified for AI relationship analysis in this batch");
}
return Ok(ai_relationships);
}
if !self.quiet {
eprintln!(
"Info: AI analyzing {} files for architectural relationships",
complex_files.len()
);
}
let ai_batch_size = self.config.graphrag.llm.ai_batch_size;
for batch in complex_files.chunks(ai_batch_size) {
let batch_relationships = self
.analyze_architectural_relationships_batch(batch, all_nodes)
.await?;
ai_relationships.extend(batch_relationships);
}
Ok(ai_relationships)
}
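    /// Heuristic gate for LLM relationship analysis: a file qualifies with
    /// at least two exports, fifty lines, or three symbols.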
fn should_use_ai_for_relationships(&self, node: &CodeNode) -> bool {
let has_meaningful_exports = node.exports.len() >= 2;
let is_substantial_file = node.size_lines >= 50;
let has_multiple_symbols = node.symbols.len() >= 3;
has_meaningful_exports || is_substantial_file || has_multiple_symbols
}
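    /// Prompts the LLM with one batch of source files plus up to ten
    /// candidate targets, requesting design-pattern-level relationships as
    /// schema-constrained JSON. Results below the confidence threshold or
    /// pointing at unknown targets are dropped.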
async fn analyze_architectural_relationships_batch(
&self,
source_nodes: &[&CodeNode],
all_nodes: &[CodeNode],
) -> Result<Vec<CodeRelationship>> {
let system_prompt = String::from(
"You are an expert software architect. Analyze these code files and identify ARCHITECTURAL relationships.\n\
Focus on design patterns, dependency injection, factory patterns, observer patterns, etc.\n\
Look for relationships that go beyond simple imports - identify architectural significance.\n\n\
For each relationship, provide:\n\
- source_path: relative path of the source file\n\
- target_path: relative path of the target file\n\
- relation_type: one of 'implements_pattern', 'dependency_injection', 'factory_creates', 'observer_pattern', 'strategy_pattern', 'adapter_pattern', 'decorator_pattern', 'architectural_dependency'\n\
- description: brief explanation of the architectural relationship\n\
- confidence: 0.0-1.0 confidence score\n\n\
Respond with JSON: {\"relationships\": [{\"source_path\": \"...\", \"target_path\": \"...\", \"relation_type\": \"...\", \"description\": \"...\", \"confidence\": 0.0}]}\n\n"
);
        let mut batch_prompt = String::from("SOURCE FILES TO ANALYZE:\n");
for node in source_nodes {
batch_prompt.push_str(&format!(
"File: {}\nLanguage: {}\nKey symbols: {}\nExports: {}\n\n",
node.path,
node.language,
node.symbols
.iter()
.take(8)
.cloned()
.collect::<Vec<_>>()
.join(", "),
node.exports
.iter()
.take(5)
.cloned()
.collect::<Vec<_>>()
.join(", ")
));
}
batch_prompt.push_str("POTENTIAL RELATIONSHIP TARGETS:\n");
        let relevant_targets: Vec<&CodeNode> = all_nodes
            .iter()
            // Skip the batch's own source files.
            .filter(|n| source_nodes.iter().all(|s| s.id != n.id))
            // Keep only files that export something or are substantial.
            .filter(|n| !n.exports.is_empty() || n.size_lines > 100)
            .take(10)
            .collect();
for node in &relevant_targets {
batch_prompt.push_str(&format!(
"File: {}\nLanguage: {}\nExports: {}\n\n",
node.path,
node.language,
node.exports
.iter()
.take(3)
.cloned()
.collect::<Vec<_>>()
.join(", ")
));
}
batch_prompt.push_str("JSON Response:");
let relationship_schema = json!({
"type": "object",
"properties": {
"relationships": {"type": "array", "items": {"type": "object", "properties": {
"source_path": {"type": "string"},
"target_path": {"type": "string"},
"relation_type": {"type": "string"},
"description": {"type": "string"},
"confidence": {"type": "number"}
}, "required": ["source_path", "target_path", "relation_type", "description", "confidence"]}}
},
"required": ["relationships"]
});
let response = self
.call_llm_json(
&self.config.graphrag.llm.relationship_model,
system_prompt,
batch_prompt,
Some(relationship_schema),
)
.await
.map_err(|e| {
anyhow::anyhow!(
"GraphRAG AI architectural analysis failed after retries: {}. \
Stopping indexing to prevent storing data without LLM analysis.",
e
)
})?;
let ai_relationships = self
.parse_ai_architectural_relationships(&response)
.unwrap_or_default();
let valid_relationships: Vec<CodeRelationship> = ai_relationships
.into_iter()
.filter(|rel| rel.confidence > self.config.graphrag.llm.confidence_threshold)
.filter(|rel| all_nodes.iter().any(|n| n.path == rel.target))
.map(|mut rel| {
rel.weight = self.config.graphrag.llm.architectural_weight;
rel
})
.collect();
Ok(valid_relationships)
}
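    /// Deserializes the LLM's relationship JSON, accepting either a
    /// `{"relationships": [...]}` wrapper or a bare array. Malformed
    /// responses yield an empty list rather than an error.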
fn parse_ai_architectural_relationships(
&self,
response: &serde_json::Value,
) -> Result<Vec<CodeRelationship>> {
#[derive(Deserialize)]
struct AiRelationship {
source_path: String,
target_path: String,
relation_type: String,
description: String,
confidence: f32,
}
let rels_value = response
.get("relationships")
.cloned()
.unwrap_or_else(|| response.clone());
if let Ok(ai_rels) = serde_json::from_value::<Vec<AiRelationship>>(rels_value) {
let relationships = ai_rels
.into_iter()
.map(|ai_rel| CodeRelationship {
source: ai_rel.source_path,
target: ai_rel.target_path,
                    relation_type: ai_rel
                        .relation_type
                        .parse()
                        // Unrecognized types degrade to the generic Imports kind.
                        .unwrap_or(crate::indexer::graphrag::types::RelationType::Imports),
description: ai_rel.description,
confidence: ai_rel.confidence,
                    // Default weight; overwritten with architectural_weight
                    // for relationships that pass validation.
                    weight: 0.9,
                })
.collect();
return Ok(relationships);
}
Ok(Vec::new())
}
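    /// A file gets an AI description only when it has symbols and its
    /// language is supported by the indexer.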
pub fn should_use_ai_for_description(
&self,
symbols: &[String],
_lines: u32,
language: &str,
) -> bool {
let is_supported_language = crate::indexer::languages::get_language(language).is_some();
!symbols.is_empty() && is_supported_language
}
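    /// Assembles a token-budgeted code sample, preferring blocks with the
    /// most symbols and trimming long blocks to their head and tail.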
pub fn build_content_sample_for_ai(&self, file_blocks: &[&crate::store::CodeBlock]) -> String {
let mut sample = String::new();
let mut total_tokens = 0;
let max_tokens = self.config.graphrag.llm.max_sample_tokens;
let mut sorted_blocks: Vec<&crate::store::CodeBlock> = file_blocks.to_vec();
sorted_blocks.sort_by(|a, b| b.symbols.len().cmp(&a.symbols.len()));
for block in sorted_blocks {
let block_tokens = crate::embedding::count_tokens(&block.content);
if total_tokens + block_tokens >= max_tokens {
break;
}
            // Keep the head and tail of long blocks so both signatures and
            // trailing logic survive truncation.
            let block_content = if block.content.len() > 300 {
let start_chars: String = block.content.chars().take(150).collect();
let end_chars: String = block
.content
.chars()
.rev()
.take(150)
.collect::<String>()
.chars()
.rev()
.collect();
format!("{}\n...\n{}", start_chars, end_chars)
} else {
block.content.clone()
};
sample.push_str(&format!(
"// Block: {} symbols\n{}\n\n",
block.symbols.len(),
block_content
));
            // Pad the running total to account for the per-block header and
            // separators added above.
            total_tokens += block_tokens + 50;
        }
sample
}
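    /// Retries after the initial attempt; four attempts in total.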
const MAX_BATCH_RETRIES: u32 = 3;
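    /// Generates 2-3 sentence descriptions for a batch of files in one LLM
    /// call, retrying with exponential backoff. Returns a map from `file_id`
    /// to description; callers should treat absent IDs as files the model
    /// skipped.
    ///
    /// Minimal sketch (placeholder variables; assumes `graphrag.use_llm` is
    /// enabled):
    ///
    /// ```ignore
    /// let descriptions = enhancer.extract_ai_descriptions_batch(&files).await?;
    /// for (file_id, description) in &descriptions {
    ///     println!("{file_id}: {description}");
    /// }
    /// ```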
pub async fn extract_ai_descriptions_batch(
&self,
files: &[FileForAI],
) -> Result<HashMap<String, String>> {
if files.is_empty() {
return Ok(HashMap::new());
}
let json_schema = self.create_batch_response_schema();
let mut last_error = None;
for attempt in 0..=Self::MAX_BATCH_RETRIES {
if attempt > 0 {
                // Exponential backoff: 5s, 10s, 20s.
                let delay = Duration::from_secs(5 * (1 << (attempt - 1)));
                let err_msg = last_error
.as_ref()
.map(|e: &anyhow::Error| e.to_string())
.unwrap_or_default();
if !self.quiet {
eprintln!(
"⚠️ AI batch attempt {}/{} failed ({}), retrying in {:?}...",
attempt,
Self::MAX_BATCH_RETRIES + 1,
err_msg,
delay
);
}
tokio::time::sleep(delay).await;
}
let user_message = self.build_batch_user_message(files);
match self
.call_llm_json(
&self.config.graphrag.llm.description_model,
self.config.graphrag.llm.description_system_prompt.clone(),
user_message,
Some(json_schema.clone()),
)
.await
{
Ok(response) => match self.parse_batch_response(&response, files) {
Ok(results) => return Ok(results),
Err(e) => {
last_error = Some(e);
}
},
Err(e) => {
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| {
anyhow::anyhow!(
"AI batch description failed for {} files after {} retries",
files.len(),
Self::MAX_BATCH_RETRIES
)
}))
}
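    /// Formats one prompt covering every file in the batch, tagging each
    /// with the exact `file_id` the model must echo back.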
fn build_batch_user_message(&self, files: &[FileForAI]) -> String {
let mut message = format!(
"Analyze the following {} files and provide architectural descriptions:\n\n",
files.len()
);
for (index, file) in files.iter().enumerate() {
message.push_str(&format!("=== FILE {} ===\n", index + 1));
message.push_str(&format!("ID: {}\n", file.file_id));
message.push_str(&format!("Language: {}\n", file.language));
message.push_str(&format!(
"Stats: {} functions, {} classes/structs\n",
file.function_count, file.class_count
));
message.push_str(&format!(
"Key symbols: {}\n",
file.symbols
.iter()
.take(5)
.cloned()
.collect::<Vec<_>>()
.join(", ")
));
message.push_str(&format!("Code sample:\n{}\n\n", file.content_sample));
}
message.push_str(
"Respond with JSON: {\"descriptions\": [{\"file_id\": \"<ID>\", \"description\": \"<2-3 sentence description>\"}]}\n\
Include one entry per file using the exact ID provided.",
);
message
}
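    /// JSON schema constraining the batch-description response shape.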
fn create_batch_response_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"descriptions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"file_id": {
"type": "string"
},
"description": {
"type": "string"
}
},
"required": ["file_id", "description"]
}
}
},
"required": ["descriptions"]
})
}
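    /// Maps returned descriptions back to known file IDs, truncating long
    /// entries and warning (unless quiet) about unknown or missing IDs.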
fn parse_batch_response(
&self,
response: &serde_json::Value,
files: &[FileForAI],
) -> Result<HashMap<String, String>> {
let parsed: BatchDescriptionResponse = serde_json::from_value(response.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse batch response: {}", e))?;
let mut results = HashMap::new();
for desc in parsed.descriptions {
if files.iter().any(|f| f.file_id == desc.file_id) {
                let cleaned_desc = if desc.description.len() > 300 {
                    // Truncate on a char boundary; byte slicing can panic on
                    // multi-byte UTF-8 sequences.
                    let truncated: String = desc.description.chars().take(297).collect();
                    format!("{}...", truncated)
                } else {
                    desc.description
                };
results.insert(desc.file_id, cleaned_desc);
} else if !self.quiet {
eprintln!(
"⚠️ Received description for unknown file: {}",
desc.file_id
);
}
}
let missing_files: Vec<&str> = files
.iter()
.filter(|f| !results.contains_key(&f.file_id))
.map(|f| f.file_id.as_str())
.collect();
if !missing_files.is_empty() && !self.quiet {
eprintln!(
"⚠️ Missing descriptions for {} files: {:?}",
missing_files.len(),
missing_files
);
}
Ok(results)
}
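    /// Plain-text chat completion against the named model.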
async fn call_llm(&self, model_name: &str, system: String, prompt: String) -> Result<String> {
let client = self.create_llm_client(model_name)?;
let messages = vec![Message::system(&system), Message::user(&prompt)];
client.chat_completion(messages).await
}
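    /// JSON chat completion against the named model, optionally constrained
    /// by a schema.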
async fn call_llm_json(
&self,
model_name: &str,
system: String,
prompt: String,
schema: Option<serde_json::Value>,
) -> Result<serde_json::Value> {
let client = self.create_llm_client(model_name)?;
let messages = vec![Message::system(&system), Message::user(&prompt)];
client.chat_completion_json(messages, schema).await
}
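    /// Creates a per-call client for `model_name`; the stored client serves
    /// only as proof that LLM support was initialized.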
fn create_llm_client(&self, model_name: &str) -> Result<LlmClient> {
if self.llm_client.is_some() {
LlmClient::with_model(&self.config, model_name)
} else {
Err(anyhow::anyhow!("LLM client not initialized"))
}
}
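    /// Single-file fallback: derives rough function/class counts from symbol
    /// name prefixes and requests one description from the LLM.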
pub async fn extract_ai_description(
&self,
content_sample: &str,
file_path: &str,
language: &str,
symbols: &[String],
) -> Result<String> {
let function_count = symbols
.iter()
.filter(|s| s.contains("function_") || s.contains("method_"))
.count();
let class_count = symbols
.iter()
.filter(|s| s.contains("class_") || s.contains("struct_"))
.count();
let user_message = format!(
"File: {}\nLanguage: {}\nStats: {} functions, {} classes/structs\nKey symbols: {}\n\nCode sample:\n{}",
std::path::Path::new(file_path).file_name().and_then(|s| s.to_str()).unwrap_or("unknown"),
language,
function_count,
class_count,
symbols.iter().take(5).cloned().collect::<Vec<_>>().join(", "),
content_sample
);
match self
.call_llm(
&self.config.graphrag.llm.description_model,
self.config.graphrag.llm.description_system_prompt.clone(),
user_message,
)
.await
{
Ok(description) => {
let cleaned = description.trim();
                if cleaned.len() > 300 {
                    // Truncate on a char boundary; byte slicing can panic on
                    // multi-byte UTF-8 sequences.
                    let truncated: String = cleaned.chars().take(297).collect();
                    Ok(format!("{}...", truncated))
                } else {
                    Ok(cleaned.to_string())
                }
}
Err(e) => {
if !self.quiet {
eprintln!("Warning: AI description failed for {}: {}", file_path, e);
}
Err(e)
}
}
}
}