use anyhow::{Result, anyhow};
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::generator::agent_executor::{AgentExecuteParams, extract, prompt, prompt_with_tools};
use crate::generator::preprocess::memory::{MemoryScope, ScopedKeys};
use crate::generator::research::memory::MemoryRetriever;
use crate::{
generator::context::GeneratorContext,
types::{
code::CodeInsight, code_releationship::RelationshipAnalysis,
project_structure::ProjectStructure,
},
utils::project_structure_formatter::ProjectStructureFormatter,
};
/// One input that can be fed into an agent's prompt.
#[derive(Debug, Clone, PartialEq)]
pub enum DataSource {
    /// An entry stored in generator memory, addressed by `scope` and `key`.
    MemoryData {
        /// Memory scope the entry lives in.
        scope: &'static str,
        /// Key of the entry within `scope`.
        key: &'static str,
    },
    /// The stored result of another agent, identified by that agent's type name.
    ResearchResult(String),
}
/// Well-known memory inputs produced by the preprocessing stage.
impl DataSource {
    /// Project directory structure.
    pub const PROJECT_STRUCTURE: DataSource = Self::preprocess_memory(ScopedKeys::PROJECT_STRUCTURE);
    /// Per-file code insights.
    pub const CODE_INSIGHTS: DataSource = Self::preprocess_memory(ScopedKeys::CODE_INSIGHTS);
    /// Dependency / relationship analysis.
    pub const DEPENDENCY_ANALYSIS: DataSource = Self::preprocess_memory(ScopedKeys::RELATIONSHIPS);
    /// The project's original (human-written) document, e.g. README.
    pub const README_CONTENT: DataSource = Self::preprocess_memory(ScopedKeys::ORIGINAL_DOCUMENT);

    /// Builds a `MemoryData` source in the preprocess scope for `key`.
    const fn preprocess_memory(key: &'static str) -> Self {
        Self::MemoryData {
            scope: MemoryScope::PREPROCESS,
            key,
        }
    }
}
/// Declares which inputs an agent uses when building its prompt.
#[derive(Debug, Clone)]
pub struct AgentDataConfig {
    /// Inputs that must be available; `StepForwardAgent::execute` returns an
    /// error before calling the LLM if any of them is missing.
    pub required_sources: Vec<DataSource>,
    /// Inputs that are rendered into the prompt when present and skipped otherwise.
    pub optional_sources: Vec<DataSource>,
}
/// Selects which LLM entry point an agent uses.
#[derive(Debug, Clone, PartialEq)]
pub enum LLMCallMode {
    /// Structured extraction (`extract`); the reply is parsed into the agent's output type.
    Extract,
    /// Plain text completion (`prompt`).
    // NOTE(review): the allow suggests no agent currently constructs this
    // variant; it is kept so agents can opt into free-form text output.
    #[allow(dead_code)]
    Prompt,
    /// Text completion with tool access (`prompt_with_tools`).
    PromptWithTools,
}
/// Limits and switches that control how research material is rendered into a prompt.
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Maximum number of code insights rendered.
    pub code_insights_limit: usize,
    /// Whether each insight's source summary is included.
    pub include_source_code: bool,
    /// Maximum number of core dependencies rendered.
    pub dependency_limit: usize,
    /// Byte length after which README text is cut off; `None` disables truncation.
    pub readme_truncate_length: Option<usize>,
}

impl Default for FormatterConfig {
    /// 50 insights, 50 dependencies, no source code, README capped at 16 KiB.
    fn default() -> Self {
        FormatterConfig {
            readme_truncate_length: Some(16 * 1024),
            dependency_limit: 50,
            include_source_code: false,
            code_insights_limit: 50,
        }
    }
}
/// Everything an agent needs to turn its inputs into LLM prompts.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// System prompt, used verbatim.
    pub system_prompt: String,
    /// Text at the very start of the user prompt, before the research material.
    pub opening_instruction: String,
    /// Text appended at the very end of the user prompt.
    pub closing_instruction: String,
    /// Which LLM entry point to call.
    pub llm_call_mode: LLMCallMode,
    /// Rendering limits for the research material.
    pub formatter_config: FormatterConfig,
}
/// Renders preprocessing artefacts and research results as Markdown
/// sections for inclusion in an LLM user prompt.
pub struct DataFormatter {
    /// Limits applied while rendering.
    config: FormatterConfig,
}
impl DataFormatter {
    /// Creates a formatter that applies the limits in `config`.
    pub fn new(config: FormatterConfig) -> Self {
        Self { config }
    }

    /// Renders the project name, root path and directory tree.
    ///
    /// The tree goes into a fenced `txt` code block. Leading line breaks and
    /// trailing whitespace are stripped from the rendered tree and explicit
    /// newlines are placed around it. The fence is then well-formed whether or
    /// not `format_as_tree` emits surrounding newlines.
    pub fn format_project_structure(&self, structure: &ProjectStructure) -> String {
        let project_tree_str = ProjectStructureFormatter::format_as_tree(structure);
        let tree = project_tree_str
            .trim_start_matches(|c: char| c == '\n' || c == '\r')
            .trim_end();
        format!(
            "### 项目结构信息\n项目名称: {}\n根目录: {}\n\n项目目录结构:\n```txt\n{}\n```\n",
            structure.project_name,
            structure.root_path.to_string_lossy(),
            tree
        )
    }

    /// Renders at most `code_insights_limit` insights as a numbered list.
    ///
    /// Each entry shows the file path and purpose, plus the detailed
    /// description when non-empty. When `include_source_code` is set, it also
    /// shows the source summary in a closed code fence.
    pub fn format_code_insights(&self, insights: &[CodeInsight]) -> String {
        let mut content = String::from("### 源码洞察摘要\n");
        for (i, insight) in insights
            .iter()
            .take(self.config.code_insights_limit)
            .enumerate()
        {
            content.push_str(&format!(
                "{}. 文件`{}`,用途类型为`{}`\n",
                i + 1,
                insight.code_dossier.file_path.to_string_lossy(),
                insight.code_dossier.code_purpose
            ));
            if !insight.detailed_description.is_empty() {
                content.push_str(&format!(" 详细描述: {}\n", &insight.detailed_description));
            }
            if self.config.include_source_code {
                // The fence must start on its own line and be closed. Otherwise
                // every later prompt section is swallowed into the code block.
                content.push_str(&format!(
                    " 源码详情:\n```code\n{}\n```\n\n",
                    &insight.code_dossier.source_summary
                ));
            }
        }
        content.push('\n');
        content
    }

    /// Renders the original README, cut to `readme_truncate_length` bytes when configured.
    ///
    /// The cut is moved back to the nearest UTF-8 character boundary, so
    /// multi-byte text (e.g. Chinese) never causes a slicing panic.
    pub fn format_readme_content(&self, readme: &str) -> String {
        let content = match self.config.readme_truncate_length {
            Some(limit) if readme.len() > limit => format!(
                "{}...(已截断)",
                Self::truncate_on_char_boundary(readme, limit)
            ),
            _ => readme.to_string(),
        };
        format!(
            "### 先前README内容(为人工录入的信息,不一定准确,仅供参考)\n{}\n\n",
            content
        )
    }

    /// Renders at most `dependency_limit` core dependencies as `from -> to (type)` lines.
    pub fn format_dependency_analysis(&self, deps: &RelationshipAnalysis) -> String {
        let mut content = String::from("### 依赖关系分析\n");
        for rel in deps
            .core_dependencies
            .iter()
            .take(self.config.dependency_limit)
        {
            content.push_str(&format!(
                "{} -> {} ({})\n",
                rel.from,
                rel.to,
                rel.dependency_type.as_str()
            ));
        }
        content.push('\n');
        content
    }

    /// Renders each research result as a `#### key:` heading followed by pretty-printed JSON.
    ///
    /// Entries are emitted in key order. `HashMap` iteration order is
    /// randomized per process, and sorting keeps the prompt text reproducible
    /// across runs (relevant if LLM responses are cached by prompt content).
    pub fn format_research_results(&self, results: &HashMap<String, serde_json::Value>) -> String {
        let mut entries: Vec<(&String, &serde_json::Value)> = results.iter().collect();
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let mut content = String::from("### 已有调研结果\n");
        for (key, value) in entries {
            content.push_str(&format!(
                "#### {}:\n{}\n\n",
                key,
                serde_json::to_string_pretty(value).unwrap_or_default()
            ));
        }
        content
    }

    /// Returns the longest prefix of `text` that is at most `max_bytes` bytes
    /// long and ends on a UTF-8 character boundary.
    fn truncate_on_char_boundary(text: &str, max_bytes: usize) -> &str {
        if max_bytes >= text.len() {
            return text;
        }
        let mut end = max_bytes;
        // Index 0 is always a boundary, so this loop terminates.
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }
}
/// Assembles the system and user prompts for an agent from its
/// `PromptTemplate` and a list of data sources.
pub struct GeneratorPromptBuilder {
    /// Template supplying the system prompt and the opening/closing instructions.
    template: PromptTemplate,
    /// Formatter configured from `template.formatter_config`.
    formatter: DataFormatter,
}
impl GeneratorPromptBuilder {
pub fn new(template: PromptTemplate) -> Self {
let formatter = DataFormatter::new(template.formatter_config.clone());
Self {
template,
formatter,
}
}
pub async fn build_prompts(
&self,
context: &GeneratorContext,
data_sources: &[DataSource],
) -> Result<(String, String)> {
let system_prompt = self.template.system_prompt.clone();
let user_prompt = self
.build_standard_user_prompt(context, data_sources)
.await?;
Ok((system_prompt, user_prompt))
}
async fn build_standard_user_prompt(
&self,
context: &GeneratorContext,
data_sources: &[DataSource],
) -> Result<String> {
let mut prompt = String::new();
prompt.push_str(&self.template.opening_instruction);
prompt.push_str("\n\n");
prompt.push_str("## 调研材料参考\n");
let mut research_results = HashMap::new();
for source in data_sources {
match source {
DataSource::MemoryData { scope, key } => match *key {
ScopedKeys::PROJECT_STRUCTURE => {
if let Some(structure) = context
.get_from_memory::<ProjectStructure>(scope, key)
.await
{
prompt.push_str(&self.formatter.format_project_structure(&structure));
}
}
ScopedKeys::CODE_INSIGHTS => {
if let Some(insights) = context
.get_from_memory::<Vec<CodeInsight>>(scope, key)
.await
{
prompt.push_str(&self.formatter.format_code_insights(&insights));
}
}
ScopedKeys::ORIGINAL_DOCUMENT => {
if let Some(readme) = context.get_from_memory::<String>(scope, key).await {
prompt.push_str(&self.formatter.format_readme_content(&readme));
}
}
ScopedKeys::RELATIONSHIPS => {
if let Some(deps) = context
.get_from_memory::<RelationshipAnalysis>(scope, key)
.await
{
prompt.push_str(&self.formatter.format_dependency_analysis(&deps));
}
}
_ => {}
},
DataSource::ResearchResult(agent_type) => {
if let Some(result) = context.get_research(agent_type).await {
research_results.insert(agent_type.clone(), result);
}
}
}
}
if !research_results.is_empty() {
prompt.push_str(&self.formatter.format_research_results(&research_results));
}
prompt.push_str(&self.template.closing_instruction);
Ok(prompt)
}
}
#[async_trait]
pub trait StepForwardAgent: Send + Sync {
type Output: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static;
fn agent_type(&self) -> String;
fn memory_scope_key(&self) -> String;
fn data_config(&self) -> AgentDataConfig;
fn prompt_template(&self) -> PromptTemplate;
fn post_process(&self, _result: &Self::Output, _context: &GeneratorContext) -> Result<()> {
Ok(())
}
async fn execute(&self, context: &GeneratorContext) -> Result<Self::Output> {
let config = self.data_config();
for source in &config.required_sources {
match source {
DataSource::MemoryData { scope, key } => {
if !context.has_memory_data(scope, key).await {
return Err(anyhow!("必需的数据源 {}:{} 不可用", scope, key));
}
}
DataSource::ResearchResult(agent_type) => {
if context.get_research(agent_type).await.is_none() {
return Err(anyhow!("必需的研究结果 {} 不可用", agent_type));
}
}
}
}
let all_sources = [config.required_sources, config.optional_sources].concat();
let template = self.prompt_template();
let prompt_builder = GeneratorPromptBuilder::new(template.clone());
let (system_prompt, user_prompt) =
prompt_builder.build_prompts(context, &all_sources).await?;
let params = AgentExecuteParams {
prompt_sys: system_prompt,
prompt_user: user_prompt,
cache_scope: format!("{}/{}", self.memory_scope_key(), self.agent_type()),
log_tag: self.agent_type().to_string(),
};
let result_value = match template.llm_call_mode {
LLMCallMode::Extract => {
let result: Self::Output = extract(context, params).await?;
serde_json::to_value(&result)?
}
LLMCallMode::Prompt => {
let result_text: String = prompt(context, params).await?;
serde_json::to_value(&result_text)?
}
LLMCallMode::PromptWithTools => {
let result_text: String = prompt_with_tools(context, params).await?;
serde_json::to_value(&result_text)?
}
};
context
.store_to_memory(
&self.memory_scope_key(),
&self.agent_type(),
result_value.clone(),
)
.await?;
if let Ok(typed_result) = serde_json::from_value::<Self::Output>(result_value) {
self.post_process(&typed_result, context)?;
println!("✅ Sub-Agent [{}]执行完成", self.agent_type());
Ok(typed_result)
} else {
Err(anyhow::format_err!(""))
}
}
}