use async_trait::async_trait;
use rucora_core::agent::{Agent, AgentContext, AgentDecision, AgentError, AgentInput, AgentOutput};
use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::{ChatMessage, ChatRequest, LlmParams, Role};
use rucora_core::tool::Tool;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::ToolRegistry;
use crate::agent::execution::DefaultExecution;
use crate::conversation::ConversationManager;
/// ReAct agent: an explicit reason → act → observe loop driven by an LLM
/// provider, with the actual loop delegated to a `DefaultExecution` engine.
pub struct ReActAgent<P> {
    // Not read directly in this file; kept so the agent owns the provider it
    // was built with (the execution engine holds its own Arc clone).
    #[allow(dead_code)]
    provider: Arc<P>,
    /// Model name placed on every outgoing `ChatRequest`.
    model: String,
    /// Optional system prompt inserted at the front of the message history.
    system_prompt: Option<String>,
    /// Registered tools; names/definitions are advertised in prompts.
    tools: ToolRegistry,
    /// Upper bound on loop steps, surfaced in the "observe" prompt.
    max_steps: usize,
    // Not read directly in this file; the execution engine holds its own clone.
    #[allow(dead_code)]
    conversation_manager: Option<Arc<Mutex<ConversationManager>>>,
    /// Sampling parameters overlaid onto every request.
    llm_params: LlmParams,
    /// Engine implementing run / run_stream for this agent.
    execution: DefaultExecution,
}
#[async_trait]
impl<P> Agent for ReActAgent<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Picks the next ReAct phase from the loop state and issues a chat
    /// decision for it: first step → "think"; pending tool results →
    /// "observe"; otherwise → "act".
    async fn think(&self, context: &AgentContext) -> AgentDecision {
        let phase = if context.step == 0 {
            "think"
        } else if !context.tool_results.is_empty() {
            "observe"
        } else {
            "act"
        };
        AgentDecision::Chat {
            request: Box::new(self._build_react_prompt(context, phase)),
        }
    }

    fn name(&self) -> &str {
        "react_agent"
    }

    fn description(&self) -> Option<&str> {
        Some("ReAct Agent,显式的推理 + 行动循环")
    }

    /// Delegates the full agent loop to the execution engine.
    async fn run(&self, input: AgentInput) -> Result<AgentOutput, rucora_core::agent::AgentError> {
        self.execution.run(self, input).await
    }

    /// Delegates streaming to the execution engine's simple stream variant.
    fn run_stream(
        &self,
        input: AgentInput,
    ) -> futures_util::stream::BoxStream<
        'static,
        Result<rucora_core::channel::types::ChannelEvent, rucora_core::agent::AgentError>,
    > {
        self.execution.run_stream_simple(input)
    }
}
impl<P> ReActAgent<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Convenience wrapper: converts `input` into an `AgentInput` and
    /// delegates to the execution engine's `run_stream_text`.
    pub async fn run_stream_text(
        &self,
        input: impl Into<AgentInput>,
    ) -> Result<String, rucora_core::agent::AgentError> {
        let agent_input = input.into();
        self.execution.run_stream_text(agent_input).await
    }
}
impl<P> ReActAgent<P>
where
    P: LlmProvider,
{
    /// Returns a fresh builder for configuring and constructing a `ReActAgent`.
    pub fn builder() -> ReActAgentBuilder<P> {
        ReActAgentBuilder::new()
    }

    /// Builds the `ChatRequest` sent to the LLM for one ReAct phase.
    ///
    /// `phase` must be one of `"think"`, `"act"` or `"observe"` — the only
    /// values produced by `think`; any other value hits `unreachable!`.
    /// The request carries the prior conversation, the system prompt (inserted
    /// once at the front if not already present), the phase-specific user
    /// prompt (Chinese, user-facing), the tool definitions, and the
    /// configured LLM parameters.
    fn _build_react_prompt(&self, context: &AgentContext, phase: &str) -> ChatRequest {
        let prompt = match phase {
            // Initial analysis/planning prompt: restates the user input and
            // lists the available tool names.
            "think" => format!(
                "请分析问题:{}\n\
                 \n\
                 思考步骤:\n\
                 1. 理解用户需求\n\
                 2. 确定需要什么信息\n\
                 3. 规划使用哪些工具\n\
                 \n\
                 可用工具:{:?}\n\
                 \n\
                 请详细分析并规划步骤。",
                context.input.text(),
                self.tools.tool_names()
            ),
            // Action prompt: asks the model to pick a tool call.
            "act" => format!(
                "基于以上思考,请选择合适的工具行动。\n\
                 \n\
                 可用工具:{:?}\n\
                 \n\
                 如果需要调用工具,请使用工具调用格式。",
                self.tools.tool_names()
            ),
            // Observation prompt: asks the model to judge completion and
            // reports loop progress as step/max_steps.
            "observe" => format!(
                "观察工具执行结果,分析是否完成任务。\n\
                 \n\
                 如果完成,给出最终答案;否则继续思考下一步。\n\
                 \n\
                 当前步骤:{}/{}",
                context.step, self.max_steps
            ),
            _ => unreachable!(),
        };
        // Clone the history so we can prepend/append without mutating shared state.
        let mut messages = context.messages.clone();
        // Insert the system prompt only when the history doesn't already start
        // with a system message (avoids duplicating it across steps).
        if let Some(ref sys_prompt) = self.system_prompt
            && (messages.is_empty() || messages.first().map(|m| &m.role) != Some(&Role::System))
        {
            messages.insert(0, ChatMessage::system(sys_prompt.clone()));
        }
        messages.push(ChatMessage::user(prompt));
        let mut request = ChatRequest {
            messages,
            model: Some(self.model.clone()),
            tools: Some(self.tools.definitions()),
            ..Default::default()
        };
        // Overlay the user-configured sampling parameters onto the request.
        self.llm_params.apply_to(&mut request);
        request
    }

    /// Names of the registered tools, borrowed from the registry.
    pub fn tools(&self) -> Vec<&str> {
        self.tools
            .tool_names()
            .into_iter()
            .map(|s| s.as_str())
            .collect()
    }
}
/// Fluent builder for [`ReActAgent`].
///
/// `provider` and `model` are mandatory and checked in `try_build`;
/// everything else has a usable default.
pub struct ReActAgentBuilder<P> {
    /// Required LLM provider; `None` until `provider()` is called.
    provider: Option<P>,
    /// Optional system prompt for the agent and its conversation manager.
    system_prompt: Option<String>,
    /// Required model name; `None` until `model()` is called.
    model: Option<String>,
    /// Accumulated tool registrations.
    tools: ToolRegistry,
    /// Loop-step cap handed to the agent and execution engine.
    max_steps: usize,
    /// Whether to create a shared `ConversationManager` at build time.
    with_conversation: bool,
    /// Middleware applied by the execution engine.
    middleware_chain: crate::middleware::MiddlewareChain,
    /// Sampling parameters applied to every request.
    llm_params: LlmParams,
}
impl<P> ReActAgentBuilder<P> {
pub fn new() -> Self {
Self {
provider: None,
system_prompt: None,
model: None,
tools: ToolRegistry::new(),
max_steps: 15, with_conversation: false,
middleware_chain: crate::middleware::MiddlewareChain::new(),
llm_params: LlmParams::default(),
}
}
}
impl<P> ReActAgentBuilder<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Sets the LLM provider (required).
    pub fn provider(mut self, provider: P) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the system prompt used by the agent and (if enabled) the
    /// conversation manager.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the model name (required).
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Registers a single tool with the agent's registry.
    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
        self.tools = self.tools.register(tool);
        self
    }

    /// Registers every tool yielded by an iterator.
    pub fn tools<I, T>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Tool + 'static,
    {
        for tool in tools {
            self.tools = self.tools.register(tool);
        }
        self
    }

    /// Caps the number of reason/act loop iterations (default: 15).
    pub fn max_steps(mut self, max: usize) -> Self {
        self.max_steps = max;
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, value: f32) -> Self {
        self.llm_params.temperature = Some(value);
        self
    }

    /// Sets nucleus-sampling top-p.
    pub fn top_p(mut self, value: f32) -> Self {
        self.llm_params.top_p = Some(value);
        self
    }

    /// Sets top-k sampling.
    pub fn top_k(mut self, value: u32) -> Self {
        self.llm_params.top_k = Some(value);
        self
    }

    /// Sets the response token limit.
    pub fn max_tokens(mut self, value: u32) -> Self {
        self.llm_params.max_tokens = Some(value);
        self
    }

    /// Sets the frequency penalty.
    pub fn frequency_penalty(mut self, value: f32) -> Self {
        self.llm_params.frequency_penalty = Some(value);
        self
    }

    /// Sets the presence penalty.
    pub fn presence_penalty(mut self, value: f32) -> Self {
        self.llm_params.presence_penalty = Some(value);
        self
    }

    /// Sets the stop sequences.
    pub fn stop(mut self, value: Vec<String>) -> Self {
        self.llm_params.stop = Some(value);
        self
    }

    /// Sets provider-specific extra parameters as raw JSON.
    pub fn extra_params(mut self, value: serde_json::Value) -> Self {
        self.llm_params.extra = Some(value);
        self
    }

    /// Replaces the whole parameter set, overriding any values set through
    /// the individual setters above.
    pub fn llm_params(mut self, params: LlmParams) -> Self {
        self.llm_params = params;
        self
    }

    /// Enables or disables conversation-history management.
    pub fn with_conversation(mut self, enabled: bool) -> Self {
        self.with_conversation = enabled;
        self
    }

    /// Replaces the middleware chain wholesale.
    pub fn with_middleware_chain(
        mut self,
        middleware_chain: crate::middleware::MiddlewareChain,
    ) -> Self {
        self.middleware_chain = middleware_chain;
        self
    }

    /// Appends one middleware to the current chain.
    pub fn with_middleware<M: crate::middleware::Middleware + 'static>(
        mut self,
        middleware: M,
    ) -> Self {
        self.middleware_chain = self.middleware_chain.with(middleware);
        self
    }

    /// Builds the agent.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::Message` when `provider` or `model` was never set.
    pub fn try_build(self) -> Result<ReActAgent<P>, AgentError> {
        let provider = self.provider.ok_or_else(|| {
            AgentError::Message("构建 ReActAgent 失败:缺少 provider".to_string())
        })?;
        let model = self
            .model
            .ok_or_else(|| AgentError::Message("构建 ReActAgent 失败:缺少 model".to_string()))?;
        // Optionally create the shared conversation manager, seeded with the
        // system prompt when one was configured.
        let conversation_manager = if self.with_conversation {
            let mut conv = ConversationManager::new();
            if let Some(ref prompt) = self.system_prompt {
                conv = conv.with_system_prompt(prompt.clone());
            }
            Some(Arc::new(Mutex::new(conv)))
        } else {
            None
        };
        let provider_arc = Arc::new(provider);
        // The execution engine receives clones of everything it needs; the
        // agent below keeps its own copies (provider via the same Arc).
        let execution =
            DefaultExecution::new(provider_arc.clone(), model.clone(), self.tools.clone())
                .with_system_prompt_opt(self.system_prompt.clone())
                .with_max_steps(self.max_steps)
                .with_conversation_manager(conversation_manager.clone())
                .with_middleware_chain(self.middleware_chain)
                .with_llm_params(self.llm_params.clone());
        Ok(ReActAgent {
            provider: provider_arc,
            model,
            system_prompt: self.system_prompt,
            tools: self.tools,
            max_steps: self.max_steps,
            conversation_manager,
            llm_params: self.llm_params,
            execution,
        })
    }

    /// Builds the agent, panicking on failure.
    ///
    /// # Panics
    ///
    /// Panics when `provider` or `model` was never set; prefer `try_build`
    /// for recoverable handling.
    pub fn build(self) -> ReActAgent<P> {
        self.try_build()
            .unwrap_or_else(|err| panic!("ReActAgentBuilder::build 失败:{err}"))
    }
}
impl<P> Default for ReActAgentBuilder<P> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;
    use futures_util::stream::BoxStream;
    use rucora_core::error::ProviderError;
    use rucora_core::provider::types::{ChatResponse, ChatStreamChunk};

    /// Minimal provider stub: fixed chat reply, empty stream.
    struct MockProvider;

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
            Ok(ChatResponse {
                message: ChatMessage {
                    role: Role::Assistant,
                    content: "Mock response".to_string(),
                    name: None,
                },
                tool_calls: vec![],
                usage: None,
                finish_reason: None,
            })
        }

        fn stream_chat(
            &self,
            _request: ChatRequest,
        ) -> Result<BoxStream<'static, Result<ChatStreamChunk, ProviderError>>, ProviderError>
        {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Happy path: a fully-configured builder produces a working agent.
    #[test]
    fn test_react_agent_builder() {
        let agent = ReActAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .model("gpt-4o-mini")
            .max_steps(15)
            .build();
        assert_eq!(agent.name(), "react_agent");
    }

    /// try_build must fail gracefully (not panic) when provider is missing.
    #[test]
    fn test_try_build_requires_provider() {
        let result = ReActAgentBuilder::<MockProvider>::new()
            .model("gpt-4o-mini")
            .try_build();
        assert!(result.is_err());
    }

    /// try_build must fail gracefully (not panic) when model is missing.
    #[test]
    fn test_try_build_requires_model() {
        let result = ReActAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .try_build();
        assert!(result.is_err());
    }
}