use async_trait::async_trait;
use rucora_core::agent::{Agent, AgentContext, AgentDecision, AgentError, AgentInput, AgentOutput};
use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::{ChatMessage, LlmParams};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::agent::execution::DefaultExecution;
use crate::conversation::ConversationManager;
/// A chat-only agent: every decision it makes is a plain chat request with
/// tool calling disabled, optionally backed by a shared conversation history.
///
/// `P` is the LLM provider type; the `LlmProvider` bound is placed on the
/// `impl` blocks that need it rather than on the struct itself.
pub struct ChatAgent<P> {
// Shared handle to the provider; a clone of this `Arc` is also handed to `execution`.
provider: Arc<P>,
// Model name written into every chat request produced by `think`.
model: String,
// Optional system prompt; re-applied to the conversation after `clear_conversation`.
system_prompt: Option<String>,
// LLM parameters passed to `default_chat_request_with` when building a request.
llm_params: LlmParams,
// Conversation history shared with `execution`; `None` when conversation is disabled.
conversation_manager: Option<Arc<Mutex<ConversationManager>>>,
// Execution engine that `run`, `run_stream` and `run_stream_text` delegate to.
execution: DefaultExecution,
}
#[async_trait]
impl<P> Agent for ChatAgent<P>
where
P: LlmProvider + Send + Sync + 'static,
{
async fn think(&self, context: &AgentContext) -> AgentDecision {
AgentDecision::Chat {
request: Box::new({
let mut request = context.default_chat_request_with(&self.llm_params);
request.model = Some(self.model.clone());
request.tools = None; request
}),
}
}
fn name(&self) -> &str {
"chat_agent"
}
fn description(&self) -> Option<&str> {
Some("纯对话 Agent,支持多轮对话历史")
}
async fn run(&self, input: AgentInput) -> Result<AgentOutput, rucora_core::agent::AgentError> {
self.execution.run(self, input).await
}
fn run_stream(
&self,
input: AgentInput,
) -> futures_util::stream::BoxStream<
'static,
Result<rucora_core::channel::types::ChannelEvent, rucora_core::agent::AgentError>,
> {
self.execution.run_stream_simple(input)
}
}
impl<P> ChatAgent<P>
where
P: LlmProvider + Send + Sync + 'static,
{
pub async fn run_stream_text(
&self,
input: impl Into<AgentInput>,
) -> Result<String, rucora_core::agent::AgentError> {
self.execution.run_stream_text(input.into()).await
}
}
impl<P> ChatAgent<P> {
pub fn builder() -> ChatAgentBuilder<P> {
ChatAgentBuilder::new()
}
pub fn provider(&self) -> &P {
&self.provider
}
pub fn model(&self) -> &str {
&self.model
}
pub async fn get_conversation_history(&self) -> Option<Vec<ChatMessage>> {
match &self.conversation_manager {
Some(conv_arc) => {
let conv = conv_arc.lock().await;
Some(conv.get_messages().to_vec())
}
None => None,
}
}
pub async fn clear_conversation(&self) {
if let Some(ref conv_arc) = self.conversation_manager {
let mut conv = conv_arc.lock().await;
conv.clear();
if let Some(ref prompt) = self.system_prompt {
conv.ensure_system_prompt(prompt);
}
}
}
}
/// Builder for [`ChatAgent`].
///
/// `provider` and `model` are required: `try_build` returns an error when
/// either is missing. Everything else is optional.
pub struct ChatAgentBuilder<P> {
// Required LLM provider; `None` until `provider(..)` is called.
provider: Option<P>,
// Optional system prompt, forwarded to the conversation manager (if any) and the execution engine.
system_prompt: Option<String>,
// Required model name; `None` until `model(..)` is called.
model: Option<String>,
// LLM parameters; starts as `LlmParams::default()`.
llm_params: LlmParams,
// Whether `try_build` creates a `ConversationManager`.
with_conversation: bool,
// History cap; only applied when conversation is enabled and the value is greater than 0.
max_history_messages: usize,
// Middleware chain forwarded to the execution engine.
middleware_chain: crate::middleware::MiddlewareChain,
}
impl<P> ChatAgentBuilder<P> {
    /// Creates a builder with nothing configured: no provider, model or
    /// system prompt, default LLM parameters, conversation disabled, a
    /// history cap of `0` and an empty middleware chain.
    pub fn new() -> Self {
        let llm_params = LlmParams::default();
        let middleware_chain = crate::middleware::MiddlewareChain::new();

        Self {
            provider: None,
            system_prompt: None,
            model: None,
            llm_params,
            with_conversation: false,
            max_history_messages: 0,
            middleware_chain,
        }
    }
}
impl<P> ChatAgentBuilder<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Sets the LLM provider (required).
    pub fn provider(self, provider: P) -> Self {
        Self {
            provider: Some(provider),
            ..self
        }
    }

    /// Sets the system prompt.
    pub fn system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Sets the model name (required).
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the `temperature` LLM parameter.
    pub fn temperature(mut self, value: f32) -> Self {
        let params = &mut self.llm_params;
        params.temperature = Some(value);
        self
    }

    /// Sets the `top_p` LLM parameter.
    pub fn top_p(mut self, value: f32) -> Self {
        let params = &mut self.llm_params;
        params.top_p = Some(value);
        self
    }

    /// Sets the `top_k` LLM parameter.
    pub fn top_k(mut self, value: u32) -> Self {
        let params = &mut self.llm_params;
        params.top_k = Some(value);
        self
    }

    /// Sets the `max_tokens` LLM parameter.
    pub fn max_tokens(mut self, value: u32) -> Self {
        let params = &mut self.llm_params;
        params.max_tokens = Some(value);
        self
    }

    /// Sets the `frequency_penalty` LLM parameter.
    pub fn frequency_penalty(mut self, value: f32) -> Self {
        let params = &mut self.llm_params;
        params.frequency_penalty = Some(value);
        self
    }

    /// Sets the `presence_penalty` LLM parameter.
    pub fn presence_penalty(mut self, value: f32) -> Self {
        let params = &mut self.llm_params;
        params.presence_penalty = Some(value);
        self
    }

    /// Sets the `stop` sequences LLM parameter.
    pub fn stop(mut self, value: Vec<String>) -> Self {
        let params = &mut self.llm_params;
        params.stop = Some(value);
        self
    }

    /// Sets the `extra` LLM parameter to an arbitrary JSON value.
    pub fn extra_params(mut self, value: serde_json::Value) -> Self {
        let params = &mut self.llm_params;
        params.extra = Some(value);
        self
    }

    /// Replaces all LLM parameters at once, discarding any set so far.
    pub fn llm_params(self, params: LlmParams) -> Self {
        Self {
            llm_params: params,
            ..self
        }
    }

    /// Enables or disables the conversation manager.
    pub fn with_conversation(self, enabled: bool) -> Self {
        Self {
            with_conversation: enabled,
            ..self
        }
    }

    /// Sets the history cap. It is only applied by `try_build` when
    /// conversation is enabled and the value is greater than `0`.
    pub fn max_history_messages(self, max_messages: usize) -> Self {
        Self {
            max_history_messages: max_messages,
            ..self
        }
    }

    /// Replaces the whole middleware chain.
    pub fn with_middleware_chain(
        self,
        middleware_chain: crate::middleware::MiddlewareChain,
    ) -> Self {
        Self {
            middleware_chain,
            ..self
        }
    }

    /// Adds one middleware to the current chain.
    pub fn with_middleware<M: crate::middleware::Middleware + 'static>(
        self,
        middleware: M,
    ) -> Self {
        Self {
            middleware_chain: self.middleware_chain.with(middleware),
            ..self
        }
    }

    /// Builds the [`ChatAgent`].
    ///
    /// When conversation is enabled, a `ConversationManager` is created
    /// (seeded with the system prompt and history cap where configured) and
    /// shared between the agent and its execution engine.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Message`] when the provider or the model has
    /// not been set. The provider is checked first.
    pub fn try_build(self) -> Result<ChatAgent<P>, AgentError> {
        let Self {
            provider,
            system_prompt,
            model,
            llm_params,
            with_conversation,
            max_history_messages,
            middleware_chain,
        } = self;

        let provider = match provider {
            Some(found) => found,
            None => {
                return Err(AgentError::Message(
                    "构建 ChatAgent 失败:缺少 provider".to_string(),
                ));
            }
        };
        let model = match model {
            Some(found) => found,
            None => {
                return Err(AgentError::Message(
                    "构建 ChatAgent 失败:缺少 model".to_string(),
                ));
            }
        };

        let conversation_manager = with_conversation.then(|| {
            let manager = ConversationManager::new();
            let manager = match system_prompt.as_ref() {
                Some(prompt) => manager.with_system_prompt(prompt.clone()),
                None => manager,
            };
            let manager = if max_history_messages > 0 {
                manager.with_max_messages(max_history_messages)
            } else {
                manager
            };
            Arc::new(Mutex::new(manager))
        });

        let shared_provider = Arc::new(provider);
        // The engine gets its own handles to the provider and the conversation;
        // the agent keeps the originals.
        let execution = DefaultExecution::new(
            shared_provider.clone(),
            model.clone(),
            crate::agent::ToolRegistry::new(),
        )
        .with_system_prompt_opt(system_prompt.clone())
        .with_conversation_manager(conversation_manager.clone())
        .with_middleware_chain(middleware_chain)
        .with_llm_params(llm_params.clone());

        Ok(ChatAgent {
            provider: shared_provider,
            model,
            system_prompt,
            llm_params,
            conversation_manager,
            execution,
        })
    }

    /// Builds the [`ChatAgent`], panicking on failure.
    ///
    /// # Panics
    ///
    /// Panics when [`try_build`](Self::try_build) returns an error, i.e. when
    /// the provider or the model is missing.
    pub fn build(self) -> ChatAgent<P> {
        match self.try_build() {
            Ok(agent) => agent,
            Err(err) => panic!("ChatAgentBuilder::build 失败:{err}"),
        }
    }
}
impl<P> Default for ChatAgentBuilder<P> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;
    use futures_util::stream::BoxStream;
    use rucora_core::error::ProviderError;
    use rucora_core::provider::Role;
    use rucora_core::provider::types::{ChatRequest, ChatResponse, ChatStreamChunk};

    /// Provider stub: `chat` returns a fixed assistant message and
    /// `stream_chat` returns an empty stream.
    struct MockProvider;

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
            Ok(ChatResponse {
                message: ChatMessage {
                    role: Role::Assistant,
                    content: "Mock response".to_string(),
                    name: None,
                },
                tool_calls: vec![],
                usage: None,
                finish_reason: None,
            })
        }

        fn stream_chat(
            &self,
            _request: ChatRequest,
        ) -> Result<BoxStream<'static, Result<ChatStreamChunk, ProviderError>>, ProviderError>
        {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// A fully configured builder produces an agent that reflects its settings.
    #[test]
    fn test_chat_agent_builder() {
        let agent = ChatAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .model("gpt-4o-mini")
            .system_prompt("test")
            .with_conversation(true)
            .max_history_messages(20)
            .build();

        // Previously the agent was built and dropped without checking anything.
        assert_eq!(agent.model(), "gpt-4o-mini");
        assert_eq!(agent.name(), "chat_agent");
        assert!(agent.conversation_manager.is_some());
    }

    /// Conversation is off by default, so no manager is created.
    #[test]
    fn test_conversation_disabled_by_default() {
        let agent = ChatAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .model("gpt-4o-mini")
            .build();

        assert!(agent.conversation_manager.is_none());
    }

    /// `try_build` reports an error instead of panicking when the provider is missing.
    #[test]
    fn test_try_build_without_provider_fails() {
        let result = ChatAgentBuilder::<MockProvider>::new()
            .model("gpt-4o-mini")
            .try_build();

        assert!(matches!(result, Err(AgentError::Message(_))));
    }

    /// `try_build` reports an error instead of panicking when the model is missing.
    #[test]
    fn test_try_build_without_model_fails() {
        let result = ChatAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .try_build();

        assert!(matches!(result, Err(AgentError::Message(_))));
    }
}