use async_trait::async_trait;
use rucora_core::agent::{Agent, AgentContext, AgentDecision, AgentError, AgentInput, AgentOutput};
use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::{ChatMessage, ChatRequest, LlmParams, Role};
use rucora_core::tool::Tool;
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::ToolRegistry;
use crate::agent::execution::DefaultExecution;
use crate::conversation::ConversationManager;
/// Reflection agent: drives a generate → reflect → improve loop against an LLM.
///
/// The decision logic lives in the `Agent::think` impl below; the actual run
/// loop is delegated to `execution`. Several fields are retained mainly for
/// introspection / builder symmetry, hence the `dead_code` allows.
pub struct ReflectAgent<P> {
    /// LLM provider; also shared with `execution` via `Arc`.
    #[allow(dead_code)]
    provider: Arc<P>,
    /// Model name sent on every `ChatRequest`.
    #[allow(dead_code)]
    model: String,
    /// Optional system prompt prepended to request messages.
    #[allow(dead_code)]
    system_prompt: Option<String>,
    /// Registered tools exposed to the LLM.
    #[allow(dead_code)]
    tools: ToolRegistry,
    /// Maximum number of reflect/improve iterations before returning.
    #[allow(dead_code)]
    max_iterations: usize,
    /// Target quality threshold in `[0.0, 1.0]`, referenced in the improve prompt.
    #[allow(dead_code)]
    quality_threshold: f32,
    /// Optional shared conversation history (present when enabled on the builder).
    #[allow(dead_code)]
    conversation_manager: Option<Arc<Mutex<ConversationManager>>>,
    /// Sampling parameters applied to every outgoing request.
    llm_params: LlmParams,
    /// Execution engine that runs the think/act loop.
    execution: DefaultExecution,
}
#[async_trait]
impl<P> Agent for ReflectAgent<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Decides the next action from the current step counter.
    ///
    /// Steps map onto iterations in pairs (`iteration = step / 2`):
    /// iteration 0 issues the initial generate request; once `iteration`
    /// reaches `max_iterations` the last assistant message is returned as the
    /// final result; otherwise odd steps reflect and even steps improve.
    ///
    /// NOTE(review): step 1 also yields `iteration == 0` and therefore
    /// re-issues the generate prompt rather than reflecting — confirm this is
    /// the intended pairing of steps to phases.
    async fn think(&self, context: &AgentContext) -> AgentDecision {
        let iteration = context.step / 2;
        if iteration == 0 {
            AgentDecision::Chat {
                request: Box::new(self._build_generate_request(context)),
            }
        } else if iteration >= self.max_iterations {
            AgentDecision::Return(self._build_final_result(context))
        } else if context.step % 2 == 1 {
            // Odd step within an iteration: critique the previous version.
            AgentDecision::Chat {
                request: Box::new(self._build_reflect_request(context, iteration)),
            }
        } else {
            // Even step: apply the critique and produce an improved version.
            AgentDecision::Chat {
                request: Box::new(self._build_improve_request(context, iteration)),
            }
        }
    }

    fn name(&self) -> &str {
        "reflect_agent"
    }

    fn description(&self) -> Option<&str> {
        Some("反思迭代 Agent,生成 - 反思 - 改进循环")
    }

    /// Runs the agent to completion through the shared execution engine.
    async fn run(&self, input: AgentInput) -> Result<AgentOutput, rucora_core::agent::AgentError> {
        self.execution.run(self, input).await
    }

    /// Streams channel events for a run via the execution engine.
    fn run_stream(
        &self,
        input: AgentInput,
    ) -> futures_util::stream::BoxStream<
        'static,
        Result<rucora_core::channel::types::ChannelEvent, rucora_core::agent::AgentError>,
    > {
        self.execution.run_stream_simple(input)
    }
}
impl<P> ReflectAgent<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Convenience wrapper: runs the agent in streaming mode and collects the
    /// streamed output into a single `String`.
    pub async fn run_stream_text(
        &self,
        input: impl Into<AgentInput>,
    ) -> Result<String, rucora_core::agent::AgentError> {
        let agent_input = input.into();
        self.execution.run_stream_text(agent_input).await
    }
}
impl<P> ReflectAgent<P>
where
    P: LlmProvider,
{
    /// Returns a builder for configuring a `ReflectAgent`.
    pub fn builder() -> ReflectAgentBuilder<P> {
        ReflectAgentBuilder::new()
    }

    /// Builds the initial generation request from the user's input text.
    fn _build_generate_request(&self, context: &AgentContext) -> ChatRequest {
        let prompt = format!(
            "请生成初始版本:{}\n\
            \n\
            要求:\n\
            1. 完整实现功能\n\
            2. 保证正确性\n\
            3. 尽可能详细\n\
            \n\
            稍后我会进行自我反思和改进。",
            context.input.text()
        );
        self._build_request(context, prompt)
    }

    /// Builds the self-reflection request for the given iteration.
    fn _build_reflect_request(&self, context: &AgentContext, iteration: usize) -> ChatRequest {
        let prompt = format!(
            "请反思第 {iteration} 版本的质量:\n\
            \n\
            反思维度:\n\
            1. **正确性**:是否有错误或遗漏?\n\
            2. **完整性**:是否覆盖所有需求?\n\
            3. **清晰度**:是否易于理解?\n\
            4. **优化空间**:哪些地方可以改进?\n\
            \n\
            请详细列出问题和改进建议。"
        );
        self._build_request(context, prompt)
    }

    /// Builds the improvement request for the given iteration, mentioning the
    /// configured quality threshold in the prompt.
    fn _build_improve_request(&self, context: &AgentContext, iteration: usize) -> ChatRequest {
        let prompt = format!(
            "请根据反思改进第 {} 版本:\n\
            \n\
            改进要求:\n\
            1. 修复所有发现的问题\n\
            2. 采纳所有合理的改进建议\n\
            3. 保持原有优点\n\
            4. 生成更高质量的版本\n\
            \n\
            目标质量阈值:{}",
            iteration, self.quality_threshold
        );
        self._build_request(context, prompt)
    }

    /// Packages the last assistant message (if any) as the final JSON result.
    fn _build_final_result(&self, context: &AgentContext) -> Value {
        if let Some(last_msg) = context.messages.last() {
            json!({
                "content": last_msg.content,
                "iterations": self.max_iterations,
                "completed": true
            })
        } else {
            json!({
                "content": "未能生成结果",
                "iterations": self.max_iterations,
                "completed": false
            })
        }
    }

    /// Assembles a `ChatRequest` from the conversation so far plus `prompt`,
    /// prepending the system prompt (once) and attaching tool definitions and
    /// sampling parameters.
    fn _build_request(&self, context: &AgentContext, prompt: String) -> ChatRequest {
        let mut messages = context.messages.clone();
        // Insert the system prompt only if one is configured and the history
        // does not already start with a system message.
        if let Some(ref sys_prompt) = self.system_prompt
            && (messages.is_empty() || messages.first().map(|m| &m.role) != Some(&Role::System))
        {
            messages.insert(0, ChatMessage::system(sys_prompt.clone()));
        }
        messages.push(ChatMessage::user(prompt));
        // Fetch tool definitions once (the original called `definitions()` twice).
        let definitions = self.tools.definitions();
        let mut request = ChatRequest {
            messages,
            model: Some(self.model.clone()),
            tools: (!definitions.is_empty()).then_some(definitions),
            ..Default::default()
        };
        self.llm_params.apply_to(&mut request);
        request
    }

    /// Names of all registered tools, borrowed from the registry.
    pub fn tools(&self) -> Vec<&str> {
        self.tools
            .tool_names()
            .into_iter()
            .map(|s| s.as_str())
            .collect()
    }
}
/// Builder for [`ReflectAgent`]; `provider` and `model` are required,
/// everything else has a sensible default.
pub struct ReflectAgentBuilder<P> {
    /// Required LLM provider; `try_build` fails when unset.
    provider: Option<P>,
    /// Optional system prompt prepended to requests and the conversation.
    system_prompt: Option<String>,
    /// Required model name; `try_build` fails when unset.
    model: Option<String>,
    /// Tools to expose to the LLM.
    tools: ToolRegistry,
    /// Maximum reflect/improve iterations (default 3).
    max_iterations: usize,
    /// Target quality threshold, clamped to `[0.0, 1.0]` (default 0.9).
    quality_threshold: f32,
    /// Whether to create a shared `ConversationManager` at build time.
    with_conversation: bool,
    /// Middleware chain handed to the execution engine.
    middleware_chain: crate::middleware::MiddlewareChain,
    /// LLM sampling parameters applied to every request.
    llm_params: LlmParams,
}
impl<P> ReflectAgentBuilder<P> {
    /// Creates a builder with defaults: 3 iterations, a 0.9 quality
    /// threshold, no tools, no middleware, and conversation tracking disabled.
    pub fn new() -> Self {
        Self {
            provider: None,
            model: None,
            system_prompt: None,
            tools: ToolRegistry::new(),
            llm_params: LlmParams::default(),
            middleware_chain: crate::middleware::MiddlewareChain::new(),
            with_conversation: false,
            max_iterations: 3,
            quality_threshold: 0.9,
        }
    }
}
impl<P> ReflectAgentBuilder<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Sets the LLM provider (required).
    pub fn provider(mut self, provider: P) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the system prompt prepended to every request.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the model name (required).
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Registers a single tool.
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        self.tools = self.tools.register(tool);
        self
    }

    /// Registers a collection of tools.
    pub fn tools<I, T>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Tool + 'static,
    {
        for tool in tools {
            self.tools = self.tools.register(tool);
        }
        self
    }

    /// Sets the maximum number of reflect/improve iterations.
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Sets the target quality threshold, clamped to `[0.0, 1.0]`.
    pub fn quality_threshold(mut self, threshold: f32) -> Self {
        self.quality_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, value: f32) -> Self {
        self.llm_params.temperature = Some(value);
        self
    }

    /// Sets nucleus-sampling `top_p`.
    pub fn top_p(mut self, value: f32) -> Self {
        self.llm_params.top_p = Some(value);
        self
    }

    /// Sets `top_k` sampling.
    pub fn top_k(mut self, value: u32) -> Self {
        self.llm_params.top_k = Some(value);
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn max_tokens(mut self, value: u32) -> Self {
        self.llm_params.max_tokens = Some(value);
        self
    }

    /// Sets the frequency penalty.
    pub fn frequency_penalty(mut self, value: f32) -> Self {
        self.llm_params.frequency_penalty = Some(value);
        self
    }

    /// Sets the presence penalty.
    pub fn presence_penalty(mut self, value: f32) -> Self {
        self.llm_params.presence_penalty = Some(value);
        self
    }

    /// Sets the stop sequences.
    pub fn stop(mut self, value: Vec<String>) -> Self {
        self.llm_params.stop = Some(value);
        self
    }

    /// Sets provider-specific extra parameters as raw JSON.
    pub fn extra_params(mut self, value: serde_json::Value) -> Self {
        self.llm_params.extra = Some(value);
        self
    }

    /// Replaces all LLM sampling parameters at once.
    pub fn llm_params(mut self, params: LlmParams) -> Self {
        self.llm_params = params;
        self
    }

    /// Enables or disables conversation history tracking.
    pub fn with_conversation(mut self, enabled: bool) -> Self {
        self.with_conversation = enabled;
        self
    }

    /// Replaces the middleware chain wholesale.
    pub fn with_middleware_chain(
        mut self,
        middleware_chain: crate::middleware::MiddlewareChain,
    ) -> Self {
        self.middleware_chain = middleware_chain;
        self
    }

    /// Appends a single middleware to the chain.
    pub fn with_middleware<M: crate::middleware::Middleware + 'static>(
        mut self,
        middleware: M,
    ) -> Self {
        self.middleware_chain = self.middleware_chain.with(middleware);
        self
    }

    /// Builds the agent.
    ///
    /// # Errors
    /// Returns `AgentError::Message` when `provider` or `model` is missing.
    pub fn try_build(self) -> Result<ReflectAgent<P>, AgentError> {
        let provider = self.provider.ok_or_else(|| {
            AgentError::Message("构建 ReflectAgent 失败:缺少 provider".to_string())
        })?;
        let model = self
            .model
            .ok_or_else(|| AgentError::Message("构建 ReflectAgent 失败:缺少 model".to_string()))?;
        // Conversation manager is only created when explicitly enabled,
        // seeded with the system prompt when one is configured.
        let conversation_manager = if self.with_conversation {
            let mut conv = ConversationManager::new();
            if let Some(ref prompt) = self.system_prompt {
                conv = conv.with_system_prompt(prompt.clone());
            }
            Some(Arc::new(Mutex::new(conv)))
        } else {
            None
        };
        let provider_arc = Arc::new(provider);
        // Each iteration consumes two steps (reflect + improve), hence `* 2`.
        // NOTE(review): `think` emits its final `Return` at step
        // `max_iterations * 2`; confirm the execution engine still invokes
        // `think` at that step under this `max_steps` budget.
        let execution =
            DefaultExecution::new(provider_arc.clone(), model.clone(), self.tools.clone())
                .with_system_prompt_opt(self.system_prompt.clone())
                .with_max_steps(self.max_iterations * 2)
                .with_conversation_manager(conversation_manager.clone())
                .with_middleware_chain(self.middleware_chain)
                .with_llm_params(self.llm_params.clone());
        Ok(ReflectAgent {
            provider: provider_arc,
            model,
            system_prompt: self.system_prompt,
            tools: self.tools,
            max_iterations: self.max_iterations,
            quality_threshold: self.quality_threshold,
            conversation_manager,
            llm_params: self.llm_params,
            execution,
        })
    }

    /// Builds the agent, panicking on configuration errors.
    ///
    /// # Panics
    /// Panics when `try_build` fails; prefer [`Self::try_build`] when the
    /// error should be handled by the caller.
    pub fn build(self) -> ReflectAgent<P> {
        self.try_build()
            .unwrap_or_else(|err| panic!("ReflectAgentBuilder::build 失败:{err}"))
    }
}
/// `Default` simply delegates to [`ReflectAgentBuilder::new`].
impl<P> Default for ReflectAgentBuilder<P> {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;
    use futures_util::stream::BoxStream;
    use rucora_core::error::ProviderError;
    use rucora_core::provider::types::{ChatResponse, ChatStreamChunk};

    /// Minimal provider: returns a canned chat response, streams nothing.
    struct MockProvider;

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
            Ok(ChatResponse {
                message: ChatMessage {
                    role: Role::Assistant,
                    content: "Mock response".to_string(),
                    name: None,
                },
                tool_calls: vec![],
                usage: None,
                finish_reason: None,
            })
        }

        fn stream_chat(
            &self,
            _request: ChatRequest,
        ) -> Result<BoxStream<'static, Result<ChatStreamChunk, ProviderError>>, ProviderError>
        {
            Ok(Box::pin(stream::empty()))
        }
    }

    #[test]
    fn test_reflect_agent_builder() {
        let agent = ReflectAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .model("gpt-4o-mini")
            .max_iterations(3)
            .quality_threshold(0.9)
            .build();
        // The original test only checked that `build` did not panic;
        // verify the configured agent's observable state as well.
        assert_eq!(agent.name(), "reflect_agent");
        assert!(agent.description().is_some());
        assert_eq!(agent.max_iterations, 3);
        assert!(agent.tools().is_empty());
    }

    #[test]
    fn test_quality_threshold_is_clamped() {
        let agent = ReflectAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .model("gpt-4o-mini")
            .quality_threshold(1.5)
            .build();
        assert_eq!(agent.quality_threshold, 1.0);
    }

    #[test]
    fn test_try_build_requires_provider_and_model() {
        // Missing provider and model.
        assert!(ReflectAgentBuilder::<MockProvider>::new().try_build().is_err());
        // Provider set but model still missing.
        assert!(
            ReflectAgentBuilder::<MockProvider>::new()
                .provider(MockProvider)
                .try_build()
                .is_err()
        );
    }
}