use async_trait::async_trait;
use rucora_core::agent::{Agent, AgentContext, AgentDecision, AgentError, AgentInput, AgentOutput};
use rucora_core::provider::LlmProvider;
use rucora_core::provider::types::LlmParams;
use std::sync::Arc;
use crate::agent::execution::DefaultExecution;
/// A minimal question-answering agent that answers with a single LLM call.
///
/// Construct it with [`SimpleAgentBuilder`] (or [`SimpleAgent::builder`]). Running the
/// agent is delegated to an internal [`DefaultExecution`], which the builder creates
/// with an empty tool registry and a step limit of 1.
pub struct SimpleAgent<P> {
    /// Shared handle to the LLM provider; the same `Arc` is also given to `execution`.
    provider: Arc<P>,
    /// Model identifier set on every chat request produced by `think`.
    model: String,
    /// Optional system prompt. The execution engine receives its own copy at build
    /// time. This field is kept for future inspection but is not read anywhere yet,
    /// which is the only reason the `dead_code` lint is silenced here.
    #[allow(dead_code)]
    system_prompt: Option<String>,
    /// LLM sampling parameters used when building chat requests in `think`.
    llm_params: LlmParams,
    /// Execution engine that drives `run`, `run_stream` and `run_stream_text`.
    execution: DefaultExecution,
}
#[async_trait]
impl<P> Agent for SimpleAgent<P>
where
P: LlmProvider + Send + Sync + 'static,
{
async fn think(&self, context: &AgentContext) -> AgentDecision {
AgentDecision::Chat {
request: Box::new({
let mut request = context.default_chat_request_with(&self.llm_params);
request.model = Some(self.model.clone());
request.tools = None; request
}),
}
}
fn name(&self) -> &str {
"simple_agent"
}
fn description(&self) -> Option<&str> {
Some("简单问答 Agent,一次调用直接返回结果")
}
async fn run(&self, input: AgentInput) -> Result<AgentOutput, rucora_core::agent::AgentError> {
self.execution.run(self, input).await
}
fn run_stream(
&self,
input: AgentInput,
) -> futures_util::stream::BoxStream<
'static,
Result<rucora_core::channel::types::ChannelEvent, rucora_core::agent::AgentError>,
> {
self.execution.run_stream_simple(input)
}
}
impl<P> SimpleAgent<P>
where
P: LlmProvider + Send + Sync + 'static,
{
pub async fn run_stream_text(
&self,
input: impl Into<AgentInput>,
) -> Result<String, rucora_core::agent::AgentError> {
self.execution.run_stream_text(input.into()).await
}
}
impl<P> SimpleAgent<P> {
    /// Returns a fresh, unconfigured [`SimpleAgentBuilder`].
    pub fn builder() -> SimpleAgentBuilder<P> {
        SimpleAgentBuilder::<P>::new()
    }

    /// Borrows the LLM provider this agent was built with.
    pub fn provider(&self) -> &P {
        &*self.provider
    }

    /// Returns the model identifier sent with each chat request.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
/// Builder for [`SimpleAgent`].
///
/// `provider` and `model` must be set before building; all other settings are optional.
pub struct SimpleAgentBuilder<P> {
    /// LLM provider; building fails if it is still `None`.
    provider: Option<P>,
    /// Optional system prompt forwarded to the execution engine at build time.
    system_prompt: Option<String>,
    /// Model identifier; building fails if it is still `None`.
    model: Option<String>,
    /// LLM sampling parameters (temperature, top_p, max_tokens, ...).
    llm_params: LlmParams,
    /// Middleware installed on the execution engine at build time.
    middleware_chain: crate::middleware::MiddlewareChain,
}
impl<P> SimpleAgentBuilder<P> {
    /// Creates a builder with no provider, model or system prompt.
    ///
    /// LLM params start from `LlmParams::default()`, and the middleware chain
    /// starts from `MiddlewareChain::new()`.
    pub fn new() -> Self {
        let llm_params = LlmParams::default();
        let middleware_chain = crate::middleware::MiddlewareChain::new();
        Self {
            provider: None,
            system_prompt: None,
            model: None,
            llm_params,
            middleware_chain,
        }
    }
}
// Setters do not need any bound on `P`; only `try_build`/`build` (below) do,
// because they construct the execution engine.
impl<P> SimpleAgentBuilder<P> {
    /// Sets the LLM provider the agent will call (required).
    #[must_use]
    pub fn provider(mut self, provider: P) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the system prompt handed to the execution engine.
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the model identifier placed on every chat request (required).
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets `LlmParams::temperature`.
    #[must_use]
    pub fn temperature(mut self, value: f32) -> Self {
        self.llm_params.temperature = Some(value);
        self
    }

    /// Sets `LlmParams::top_p`.
    #[must_use]
    pub fn top_p(mut self, value: f32) -> Self {
        self.llm_params.top_p = Some(value);
        self
    }

    /// Sets `LlmParams::top_k`.
    #[must_use]
    pub fn top_k(mut self, value: u32) -> Self {
        self.llm_params.top_k = Some(value);
        self
    }

    /// Sets `LlmParams::max_tokens`.
    #[must_use]
    pub fn max_tokens(mut self, value: u32) -> Self {
        self.llm_params.max_tokens = Some(value);
        self
    }

    /// Sets `LlmParams::frequency_penalty`.
    #[must_use]
    pub fn frequency_penalty(mut self, value: f32) -> Self {
        self.llm_params.frequency_penalty = Some(value);
        self
    }

    /// Sets `LlmParams::presence_penalty`.
    #[must_use]
    pub fn presence_penalty(mut self, value: f32) -> Self {
        self.llm_params.presence_penalty = Some(value);
        self
    }

    /// Sets `LlmParams::stop` (stop sequences).
    #[must_use]
    pub fn stop(mut self, value: Vec<String>) -> Self {
        self.llm_params.stop = Some(value);
        self
    }

    /// Sets `LlmParams::extra`, an arbitrary JSON value of additional parameters.
    #[must_use]
    pub fn extra_params(mut self, value: serde_json::Value) -> Self {
        self.llm_params.extra = Some(value);
        self
    }

    /// Replaces all LLM params at once.
    ///
    /// Values set earlier through the individual setters are discarded. Setters
    /// called afterwards still apply on top.
    #[must_use]
    pub fn llm_params(mut self, params: LlmParams) -> Self {
        self.llm_params = params;
        self
    }

    /// Replaces the whole middleware chain.
    ///
    /// Middleware added earlier via [`with_middleware`](Self::with_middleware) is discarded.
    #[must_use]
    pub fn with_middleware_chain(
        mut self,
        middleware_chain: crate::middleware::MiddlewareChain,
    ) -> Self {
        self.middleware_chain = middleware_chain;
        self
    }

    /// Adds `middleware` to the current chain via `MiddlewareChain::with`.
    #[must_use]
    pub fn with_middleware<M: crate::middleware::Middleware + 'static>(
        mut self,
        middleware: M,
    ) -> Self {
        self.middleware_chain = self.middleware_chain.with(middleware);
        self
    }
}

impl<P> SimpleAgentBuilder<P>
where
    P: LlmProvider + Send + Sync + 'static,
{
    /// Consumes the builder and assembles a [`SimpleAgent`].
    ///
    /// The execution engine is created with:
    /// - an empty tool registry,
    /// - a step limit of 1,
    /// - the configured system prompt, middleware chain and LLM params.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::Message` if no provider or no model was set. The
    /// provider is checked first.
    pub fn try_build(self) -> Result<SimpleAgent<P>, AgentError> {
        let provider = self.provider.ok_or_else(|| {
            AgentError::Message("构建 SimpleAgent 失败:缺少 provider".to_string())
        })?;
        let model = self
            .model
            .ok_or_else(|| AgentError::Message("构建 SimpleAgent 失败:缺少 model".to_string()))?;

        // The provider is shared between the agent (for `provider()`) and its engine.
        let provider_arc = Arc::new(provider);
        let execution = DefaultExecution::new(
            provider_arc.clone(),
            model.clone(),
            crate::agent::ToolRegistry::new(),
        )
        .with_system_prompt_opt(self.system_prompt.clone())
        // A simple agent answers in a single step and never loops on tool calls.
        .with_max_steps(1)
        .with_middleware_chain(self.middleware_chain)
        .with_llm_params(self.llm_params.clone());

        Ok(SimpleAgent {
            provider: provider_arc,
            model,
            system_prompt: self.system_prompt,
            llm_params: self.llm_params,
            execution,
        })
    }

    /// Like [`try_build`](Self::try_build), but panics on failure.
    ///
    /// # Panics
    ///
    /// Panics if the provider or model has not been set.
    #[must_use]
    pub fn build(self) -> SimpleAgent<P> {
        self.try_build()
            .unwrap_or_else(|err| panic!("SimpleAgentBuilder::build 失败:{err}"))
    }
}
impl<P> Default for SimpleAgentBuilder<P> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;
    use futures_util::stream::BoxStream;
    use rucora_core::error::ProviderError;
    use rucora_core::provider::types::{ChatRequest, ChatResponse, ChatStreamChunk};
    use rucora_core::provider::{ChatMessage, Role};

    /// Provider stub: `chat` always answers "Mock response"; `stream_chat` yields an empty stream.
    struct MockProvider;

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
            Ok(ChatResponse {
                message: ChatMessage {
                    role: Role::Assistant,
                    content: "Mock response".to_string(),
                    name: None,
                },
                tool_calls: vec![],
                usage: None,
                finish_reason: None,
            })
        }

        fn stream_chat(
            &self,
            _request: ChatRequest,
        ) -> Result<BoxStream<'static, Result<ChatStreamChunk, ProviderError>>, ProviderError>
        {
            Ok(Box::pin(stream::empty()))
        }
    }

    #[test]
    fn test_simple_agent_builder() {
        let _agent = SimpleAgentBuilder::<MockProvider>::new()
            .provider(MockProvider)
            .model("gpt-4o-mini")
            .system_prompt("test")
            .temperature(0.5)
            .build();
    }

    #[test]
    fn built_agent_exposes_model_and_metadata() {
        let agent = SimpleAgent::<MockProvider>::builder()
            .provider(MockProvider)
            .model("gpt-4o-mini")
            .build();
        assert_eq!(agent.model(), "gpt-4o-mini");
        assert_eq!(agent.name(), "simple_agent");
        assert!(agent.description().is_some());
    }

    #[test]
    fn try_build_fails_without_provider() {
        let result = SimpleAgentBuilder::<MockProvider>::new()
            .model("gpt-4o-mini")
            .try_build();
        assert!(matches!(
            result,
            Err(AgentError::Message(ref msg)) if msg.contains("provider")
        ));
    }

    #[test]
    fn try_build_fails_without_model() {
        let result = SimpleAgentBuilder::<MockProvider>::default()
            .provider(MockProvider)
            .try_build();
        assert!(matches!(
            result,
            Err(AgentError::Message(ref msg)) if msg.contains("model")
        ));
    }

    #[test]
    #[should_panic(expected = "SimpleAgentBuilder::build")]
    fn build_panics_when_required_fields_are_missing() {
        let _ = SimpleAgentBuilder::<MockProvider>::new().model("m").build();
    }
}