use crate::config::Config;
use crate::indexer::graphrag::types::{CodeNode, CodeRelationship};
use crate::llm::{LlmClient, Message};
use anyhow::Result;
use serde::Deserialize;
use serde_json::json;
use std::collections::HashMap;
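/// AI-assisted enhancements for GraphRAG indexing: LLM-backed discovery of
/// architectural relationships and batched file description generation.
///
/// A minimal usage sketch (assumes a loaded `Config` and prepared `CodeNode`
/// slices; see the methods below for the actual entry points):
///
/// ```ignore
/// let enhancements = AIEnhancements::new(config, /* quiet */ false);
/// if enhancements.llm_enabled() {
///     let relationships = enhancements
///         .discover_relationships_with_ai_enhancement(&new_files, &all_nodes)
///         .await?;
/// }
/// ```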
pub struct AIEnhancements {
config: Config,
llm_client: Option<LlmClient>,
quiet: bool,
}
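/// Per-file payload assembled for a batched AI description request.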
#[derive(Debug, Clone)]
pub struct FileForAI {
pub file_id: String,
pub file_path: String,
pub language: String,
pub symbols: Vec<String>,
pub content_sample: String,
pub function_count: usize,
pub class_count: usize,
}
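/// Expected JSON shape of a batched description response from the LLM.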
#[derive(Debug, Deserialize)]
struct BatchDescriptionResponse {
descriptions: Vec<FileDescription>,
}
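/// A single `file_id` → description pair within a batch response.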
#[derive(Debug, Deserialize)]
struct FileDescription {
file_id: String,
description: String,
}
impl AIEnhancements {
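    /// Creates the enhancer; an LLM client is only constructed when
    /// `graphrag.use_llm` is set, and construction failures degrade to `None`.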
pub fn new(config: Config, quiet: bool) -> Self {
let llm_client = if config.graphrag.use_llm {
LlmClient::from_config(&config).ok()
} else {
None
};
Self {
config,
llm_client,
quiet,
}
}
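    /// True when LLM use is configured and a client was successfully created.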
pub fn llm_enabled(&self) -> bool {
self.config.graphrag.use_llm && self.llm_client.is_some()
}
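    /// Discovers relationships via the efficient heuristic pass, augments them
    /// with AI-identified architectural relationships, then sorts and dedups
    /// the combined list by (source, target, relation_type).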
pub async fn discover_relationships_with_ai_enhancement(
&self,
new_files: &[CodeNode],
all_nodes: &[CodeNode],
) -> Result<Vec<CodeRelationship>> {
        let mut relationships =
            crate::indexer::graphrag::relationships::RelationshipDiscovery::discover_relationships_efficiently(
                new_files, all_nodes,
            )
            .await?;
let ai_relationships = self
.discover_complex_relationships_with_ai(new_files, all_nodes)
.await?;
relationships.extend(ai_relationships);
        relationships.sort_by(|a, b| {
            (&a.source, &a.target, &a.relation_type)
                .cmp(&(&b.source, &b.target, &b.relation_type))
        });
relationships.dedup_by(|a, b| {
a.source == b.source && a.target == b.target && a.relation_type == b.relation_type
});
Ok(relationships)
}
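    /// Runs AI analysis over the files that pass the complexity heuristics,
    /// chunked by `ai_batch_size`; a failed batch is skipped rather than fatal.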
async fn discover_complex_relationships_with_ai(
&self,
new_files: &[CodeNode],
all_nodes: &[CodeNode],
) -> Result<Vec<CodeRelationship>> {
let mut ai_relationships = Vec::new();
let complex_files: Vec<&CodeNode> = new_files
.iter()
.filter(|node| self.should_use_ai_for_relationships(node))
.collect();
if complex_files.is_empty() {
if !self.quiet {
eprintln!("Debug: No files qualified for AI relationship analysis in this batch");
}
return Ok(ai_relationships);
}
if !self.quiet {
eprintln!(
"Info: AI analyzing {} files for architectural relationships",
complex_files.len()
);
}
let ai_batch_size = self.config.graphrag.llm.ai_batch_size;
for batch in complex_files.chunks(ai_batch_size) {
if let Ok(batch_relationships) = self
.analyze_architectural_relationships_batch(batch, all_nodes)
.await
{
ai_relationships.extend(batch_relationships);
}
}
Ok(ai_relationships)
}
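    /// Heuristic gate for AI analysis: at least two exports, fifty lines, or
    /// three symbols qualifies a file.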
fn should_use_ai_for_relationships(&self, node: &CodeNode) -> bool {
let has_meaningful_exports = node.exports.len() >= 2;
let is_substantial_file = node.size_lines >= 50;
let has_multiple_symbols = node.symbols.len() >= 3;
has_meaningful_exports || is_substantial_file || has_multiple_symbols
}
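    /// Prompts the relationship model with a batch of source files plus a
    /// shortlist of candidate targets, then keeps only relationships above the
    /// confidence threshold whose targets exist in the graph.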
async fn analyze_architectural_relationships_batch(
&self,
source_nodes: &[&CodeNode],
all_nodes: &[CodeNode],
) -> Result<Vec<CodeRelationship>> {
let system_prompt = String::from(
"You are an expert software architect. Analyze these code files and identify ARCHITECTURAL relationships.\n\
Focus on design patterns, dependency injection, factory patterns, observer patterns, etc.\n\
Look for relationships that go beyond simple imports - identify architectural significance.\n\n\
Respond with a JSON array of relationships. For each relationship, include:\n\
- source_path: relative path of source file\n\
- target_path: relative path of target file\n\
- relation_type: one of 'implements_pattern', 'dependency_injection', 'factory_creates', 'observer_pattern', 'strategy_pattern', 'adapter_pattern', 'decorator_pattern', 'architectural_dependency'\n\
- description: brief explanation of the architectural relationship\n\
- confidence: 0.0-1.0 confidence score\n\n"
);
        let mut batch_prompt = String::from("SOURCE FILES TO ANALYZE:\n");
for node in source_nodes {
batch_prompt.push_str(&format!(
"File: {}\nLanguage: {}\nKey symbols: {}\nExports: {}\n\n",
node.path,
node.language,
node.symbols
.iter()
.take(8)
.cloned()
.collect::<Vec<_>>()
.join(", "),
node.exports
.iter()
.take(5)
.cloned()
.collect::<Vec<_>>()
.join(", ")
));
}
batch_prompt.push_str("POTENTIAL RELATIONSHIP TARGETS:\n");
        // Candidate targets: skip the source files themselves and keep up to
        // ten nodes that either export something or are substantial in size.
        let relevant_targets: Vec<&CodeNode> = all_nodes
            .iter()
            .filter(|n| source_nodes.iter().all(|s| s.id != n.id))
            .filter(|n| !n.exports.is_empty() || n.size_lines > 100)
            .take(10)
            .collect();
for node in &relevant_targets {
batch_prompt.push_str(&format!(
"File: {}\nLanguage: {}\nExports: {}\n\n",
node.path,
node.language,
node.exports
.iter()
.take(3)
.cloned()
.collect::<Vec<_>>()
.join(", ")
));
}
batch_prompt.push_str("JSON Response:");
match self
.call_llm(
&self.config.graphrag.llm.relationship_model,
system_prompt,
batch_prompt,
None,
)
.await
{
Ok(response) => {
if let Ok(ai_relationships) = self.parse_ai_architectural_relationships(&response) {
                    let valid_relationships: Vec<CodeRelationship> = ai_relationships
                        .into_iter()
                        // Drop suggestions below the configured confidence threshold.
                        .filter(|rel| {
                            rel.confidence > self.config.graphrag.llm.confidence_threshold
                        })
                        // Keep only relationships whose target exists in the graph.
                        .filter(|rel| all_nodes.iter().any(|n| n.path == rel.target))
                        .map(|mut rel| {
                            rel.weight = self.config.graphrag.llm.architectural_weight;
                            rel
                        })
                        .collect();
Ok(valid_relationships)
} else {
Ok(Vec::new())
}
}
            Err(e) => {
                if !self.quiet {
                    eprintln!("Warning: AI architectural analysis failed: {}", e);
                }
                Ok(Vec::new())
            }
}
}
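    /// Parses the LLM's JSON array of relationships; an unparseable response
    /// yields an empty list rather than an error. Unknown relation types fall
    /// back to `RelationType::Imports`.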
fn parse_ai_architectural_relationships(
&self,
response: &str,
) -> Result<Vec<CodeRelationship>> {
#[derive(Deserialize)]
struct AiRelationship {
source_path: String,
target_path: String,
relation_type: String,
description: String,
confidence: f32,
}
if let Ok(ai_rels) = serde_json::from_str::<Vec<AiRelationship>>(response) {
let relationships = ai_rels
.into_iter()
.map(|ai_rel| CodeRelationship {
source: ai_rel.source_path,
target: ai_rel.target_path,
relation_type: ai_rel
.relation_type
.parse()
.unwrap_or(crate::indexer::graphrag::types::RelationType::Imports),
description: ai_rel.description,
confidence: ai_rel.confidence,
                    // Default weight; the architectural caller overwrites this
                    // with the configured architectural_weight.
                    weight: 0.9,
                })
.collect();
return Ok(relationships);
}
Ok(Vec::new())
}
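    /// Decides whether a file warrants an AI-generated description: it must
    /// have symbols, more than twenty lines, and a recognized language.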
pub fn should_use_ai_for_description(
&self,
symbols: &[String],
lines: u32,
language: &str,
) -> bool {
        let function_count = symbols.len();
        let has_substantial_content = lines > 20;
        let is_important_language = crate::indexer::languages::get_language(language).is_some();
let should_use = has_substantial_content && is_important_language && function_count > 0;
if !self.quiet {
eprintln!(
"🤖 AI Decision: file={} lines, symbols={}, language={}, use_ai={}",
lines,
symbols.len(),
language,
should_use
);
if should_use && !symbols.is_empty() {
eprintln!("🔍 Symbols found: {:?}", &symbols[..symbols.len().min(5)]);
}
}
should_use
}
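    /// Builds a token-budgeted code sample from a file's blocks, taking
    /// symbol-rich blocks first and eliding the middle of long ones.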
pub fn build_content_sample_for_ai(&self, file_blocks: &[&crate::store::CodeBlock]) -> String {
let mut sample = String::new();
let mut total_tokens = 0;
let max_tokens = self.config.graphrag.llm.max_sample_tokens;
let mut sorted_blocks: Vec<&crate::store::CodeBlock> = file_blocks.to_vec();
sorted_blocks.sort_by(|a, b| b.symbols.len().cmp(&a.symbols.len()));
for block in sorted_blocks {
let block_tokens = crate::embedding::count_tokens(&block.content);
if total_tokens + block_tokens >= max_tokens {
break;
}
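            // Keep the head and tail of long blocks; the middle is elided with "...".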
let block_content = if block.content.len() > 300 {
let start_chars: String = block.content.chars().take(150).collect();
let end_chars: String = block
.content
.chars()
.rev()
.take(150)
.collect::<String>()
.chars()
.rev()
.collect();
format!("{}\n...\n{}", start_chars, end_chars)
} else {
block.content.clone()
};
sample.push_str(&format!(
"// Block: {} symbols\n{}\n\n",
block.symbols.len(),
block_content
));
            // Reserve headroom for the "// Block:" header and spacing added above.
            total_tokens += block_tokens + 50;
        }
sample
}
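    /// Requests descriptions for a batch of files in a single LLM call; if the
    /// batch fails and `ai_batch_size > 1`, falls back to one call per file.
    ///
    /// Illustrative only; the field values below are hypothetical:
    ///
    /// ```ignore
    /// let files = vec![FileForAI {
    ///     file_id: "src/lib.rs".into(),
    ///     file_path: "src/lib.rs".into(),
    ///     language: "rust".into(),
    ///     symbols: vec!["function_parse".into()],
    ///     content_sample: "fn parse() {}".into(),
    ///     function_count: 1,
    ///     class_count: 0,
    /// }];
    /// let descriptions = enhancements.extract_ai_descriptions_batch(&files).await?;
    /// ```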
pub async fn extract_ai_descriptions_batch(
&self,
files: &[FileForAI],
) -> Result<HashMap<String, String>> {
if files.is_empty() {
return Ok(HashMap::new());
}
let user_message = self.build_batch_user_message(files);
let json_schema = self.create_batch_response_schema();
match self
.call_llm(
&self.config.graphrag.llm.description_model,
self.config.graphrag.llm.description_system_prompt.clone(),
user_message,
Some(json_schema),
)
.await
{
Ok(response) => self.parse_batch_response(&response, files),
Err(e) => {
if !self.quiet {
eprintln!(
"⚠️ Batch AI description failed for {} files: {}",
files.len(),
e
);
}
if self.config.graphrag.llm.ai_batch_size > 1 {
if !self.quiet {
eprintln!("🔄 Falling back to individual AI calls...");
}
self.fallback_to_individual_calls(files).await
} else {
Err(e)
}
}
}
}
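    /// Renders the batch request body: one `=== FILE N ===` section per file
    /// with its ID, language, stats, key symbols, and code sample.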
fn build_batch_user_message(&self, files: &[FileForAI]) -> String {
let mut message = format!(
"Analyze the following {} files and provide architectural descriptions:\n\n",
files.len()
);
for (index, file) in files.iter().enumerate() {
message.push_str(&format!("=== FILE {} ===\n", index + 1));
message.push_str(&format!("ID: {}\n", file.file_id));
message.push_str(&format!("Language: {}\n", file.language));
message.push_str(&format!(
"Stats: {} functions, {} classes/structs\n",
file.function_count, file.class_count
));
message.push_str(&format!(
"Key symbols: {}\n",
file.symbols
.iter()
.take(5)
.cloned()
.collect::<Vec<_>>()
.join(", ")
));
message.push_str(&format!("Code sample:\n{}\n\n", file.content_sample));
}
message.push_str("Provide a JSON response with descriptions for each file.");
message
}
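    /// JSON schema constraining the batch response to `{"descriptions": [...]}`
    /// where each entry requires `file_id` and `description`.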
fn create_batch_response_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"descriptions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file ID exactly as provided in the request"
},
"description": {
"type": "string",
"description": "Architectural description of the file (max 300 chars)"
}
},
"required": ["file_id", "description"],
"additionalProperties": false
}
}
},
"required": ["descriptions"],
"additionalProperties": false
})
}
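    /// Matches returned descriptions back to the requested files, truncating
    /// overlong ones and warning (unless quiet) about unknown or missing IDs.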
fn parse_batch_response(
&self,
response: &str,
files: &[FileForAI],
) -> Result<HashMap<String, String>> {
let parsed: BatchDescriptionResponse = serde_json::from_str(response)
.map_err(|e| anyhow::anyhow!("Failed to parse batch response: {}", e))?;
let mut results = HashMap::new();
for desc in parsed.descriptions {
if files.iter().any(|f| f.file_id == desc.file_id) {
                let cleaned_desc = if desc.description.len() > 300 {
                    // Truncate on a character boundary; slicing bytes at a
                    // fixed index can panic on multi-byte UTF-8.
                    let truncated: String = desc.description.chars().take(297).collect();
                    format!("{}...", truncated)
                } else {
                    desc.description
                };
results.insert(desc.file_id, cleaned_desc);
} else if !self.quiet {
eprintln!(
"⚠️ Received description for unknown file: {}",
desc.file_id
);
}
}
let missing_files: Vec<&str> = files
.iter()
.filter(|f| !results.contains_key(&f.file_id))
.map(|f| f.file_id.as_str())
.collect();
if !missing_files.is_empty() && !self.quiet {
eprintln!(
"⚠️ Missing descriptions for {} files: {:?}",
missing_files.len(),
missing_files
);
}
Ok(results)
}
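    /// Per-file fallback path; individual failures are logged and skipped so
    /// that partial results are still returned.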
async fn fallback_to_individual_calls(
&self,
files: &[FileForAI],
) -> Result<HashMap<String, String>> {
let mut results = HashMap::new();
for file in files {
match self
.extract_ai_description(
&file.content_sample,
&file.file_path,
&file.language,
&file.symbols,
)
.await
{
Ok(description) => {
results.insert(file.file_id.clone(), description);
}
Err(e) => {
if !self.quiet {
eprintln!(
"⚠️ Individual AI description failed for {}: {}",
file.file_id, e
);
}
}
}
}
Ok(results)
}
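    /// Sends a system/user message pair to the named model. The JSON schema
    /// parameter is accepted for callers that have one but is not currently
    /// forwarded to the client.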
async fn call_llm(
&self,
model_name: &str,
system: String,
prompt: String,
_json_schema: Option<serde_json::Value>,
) -> Result<String> {
        if self.llm_client.is_none() {
            return Err(anyhow::anyhow!("LLM client not initialized"));
        }
        // A client bound to the requested model is constructed for each call.
        let client = LlmClient::with_model(&self.config, model_name)?;
let messages = vec![Message::system(&system), Message::user(&prompt)];
let response = client.chat_completion(messages).await?;
Ok(response)
}
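    /// Generates a description for a single file from its content sample and
    /// symbol-derived stats, truncated to roughly 300 characters.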
pub async fn extract_ai_description(
&self,
content_sample: &str,
file_path: &str,
language: &str,
symbols: &[String],
) -> Result<String> {
let function_count = symbols
.iter()
.filter(|s| s.contains("function_") || s.contains("method_"))
.count();
let class_count = symbols
.iter()
.filter(|s| s.contains("class_") || s.contains("struct_"))
.count();
let user_message = format!(
"File: {}\nLanguage: {}\nStats: {} functions, {} classes/structs\nKey symbols: {}\n\nCode sample:\n{}",
std::path::Path::new(file_path).file_name().and_then(|s| s.to_str()).unwrap_or("unknown"),
language,
function_count,
class_count,
symbols.iter().take(5).cloned().collect::<Vec<_>>().join(", "),
content_sample
);
match self
.call_llm(
&self.config.graphrag.llm.description_model,
self.config.graphrag.llm.description_system_prompt.clone(),
user_message,
None,
)
.await
{
Ok(description) => {
                let cleaned = description.trim();
                if cleaned.len() > 300 {
                    // Truncate on a character boundary to avoid panicking on
                    // multi-byte UTF-8.
                    let truncated: String = cleaned.chars().take(297).collect();
                    Ok(format!("{}...", truncated))
                } else {
                    Ok(cleaned.to_string())
                }
}
Err(e) => {
if !self.quiet {
eprintln!("Warning: AI description failed for {}: {}", file_path, e);
}
Err(e)
}
}
}
}