1use crate::memory::MemoryProvider;
2use crate::middleware::{Middleware, MiddlewarePipeline, ToolMiddlewareContext};
3use crate::provider::ModelProvider;
4use crate::reasoning::{ReasoningContext, ReasoningEngine};
5use crate::tool::ToolRegistry;
6use crate::types::{
7 AgentPolicy, ExecutionContext, MiddlewareScope, ModelMessage, ModelRequest, Role,
8};
9use anyhow::{anyhow, Result};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
13#[derive(Clone)]
14pub struct AgentConfig {
15 pub name: String,
16 pub description: Option<String>,
17 pub system_prompt: Option<String>,
18 pub provider: Option<Arc<dyn ModelProvider>>,
19 pub memory: Option<Arc<dyn MemoryProvider>>,
20 pub engine: Option<Arc<dyn ReasoningEngine>>,
21 pub policy: AgentPolicy,
22}
23
24pub trait Agent: Send + Sync {
25 fn name(&self) -> &str;
26 fn description(&self) -> Option<&str> {
27 None
28 }
29 fn system_prompt(&self) -> Option<&str> {
30 None
31 }
32 fn register_tools(self: Arc<Self>, registry: &mut ToolRegistry);
33}
34
35pub fn create_agent<T: Agent + 'static>(data: T) -> AgentInstance {
36 let data = Arc::new(data);
37 let instance = AgentInstance::new(AgentConfig {
38 name: data.name().to_string(),
39 description: data.description().map(|s| s.to_string()),
40 system_prompt: data.system_prompt().map(|s| s.to_string()),
41 provider: None,
42 memory: None,
43 engine: None,
44 policy: AgentPolicy::default(),
45 });
46
47 {
48 let mut tools = instance
49 .tools
50 .try_lock()
51 .expect("Failed to lock tools during init");
52 data.clone().register_tools(&mut tools);
53 }
54
55 instance
56}
57
58pub struct AgentInstance {
59 pub config: AgentConfig,
60 pub tools: Arc<Mutex<ToolRegistry>>,
61 pub middleware: MiddlewarePipeline,
62}
63
64impl AgentInstance {
65 pub fn new(config: AgentConfig) -> Self {
66 Self {
67 config,
68 tools: Arc::new(Mutex::new(ToolRegistry::new())),
69 middleware: MiddlewarePipeline::new(),
70 }
71 }
72
73 pub fn provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
74 self.config.provider = Some(provider);
75 self
76 }
77
78 pub fn memory(mut self, memory: Arc<dyn MemoryProvider>) -> Self {
79 self.config.memory = Some(memory);
80 self
81 }
82
83 pub fn engine(mut self, engine: Arc<dyn ReasoningEngine>) -> Self {
84 self.config.engine = Some(engine);
85 self
86 }
87
88 pub fn policy(mut self, policy: AgentPolicy) -> Self {
89 self.config.policy = policy;
90 self
91 }
92
93 pub fn system_prompt(&mut self, prompt: &str) -> &mut Self {
94 self.config.system_prompt = Some(prompt.to_string());
95 self
96 }
97
98 pub fn use_middleware(&mut self, middleware: Box<dyn Middleware>) -> &mut Self {
99 self.middleware.use_middleware(middleware);
100 self
101 }
102
103 pub async fn run(&self, input: &str) -> Result<String> {
104 let mut ctx = ExecutionContext::new(input.to_string());
105 self.middleware
106 .run(MiddlewareScope::RunBefore, &mut ctx)
107 .await?;
108
109 if let Some(prompt) = &self.config.system_prompt {
111 ctx.messages.push(ModelMessage {
112 role: Role::System,
113 content: prompt.clone(),
114 tool_call_id: None,
115 tool_calls: None,
116 });
117 }
118
119 ctx.messages.push(ModelMessage {
120 role: Role::User,
121 content: input.to_string(),
122 tool_call_id: None,
123 tool_calls: None,
124 });
125
126 let provider = self
127 .config
128 .provider
129 .as_ref()
130 .ok_or_else(|| anyhow!("No provider configured"))?;
131
132 let output = if let Some(engine) = &self.config.engine {
134 let mut r_ctx = ReasoningContext {
135 ctx: &mut ctx,
136 model: provider.as_ref(),
137 tools: &*self.tools.lock().await,
138 middleware: &self.middleware,
139 policy: self.config.policy.clone(),
140 };
141 engine.execute(&mut r_ctx).await?
142 } else {
143 loop {
145 self.middleware
146 .run(MiddlewareScope::StepBefore, &mut ctx)
147 .await?;
148 let tools = self.tools.lock().await;
149 let request = ModelRequest {
150 messages: ctx.messages.clone(),
151 tools: Some(tools.list()),
152 };
153 drop(tools);
154
155 let response = provider.complete(request).await?;
156
157 if let Some(usage) = &response.usage {
159 ctx.usage.prompt_tokens += usage.prompt_tokens;
160 ctx.usage.completion_tokens += usage.completion_tokens;
161 ctx.usage.total_tokens += usage.total_tokens;
162 }
163
164 let message = response.message;
165 ctx.messages.push(message.clone());
166
167 if let Some(tool_calls) = message.tool_calls {
168 for tool_call in tool_calls {
169 let tools = self.tools.lock().await;
170
171 let tool_mw_ctx = ToolMiddlewareContext {
172 name: tool_call.name.clone(),
173 args: tool_call.arguments.clone(),
174 result: None,
175 };
176
177 self.middleware
178 .run_tool(MiddlewareScope::ToolBefore, &mut ctx, tool_mw_ctx.clone())
179 .await?;
180
181 let result = tools
182 .call_tool(&tool_call.name, tool_call.arguments)
183 .await?;
184 drop(tools);
185
186 let mut tool_mw_ctx_after = tool_mw_ctx;
187 tool_mw_ctx_after.result = Some(result.clone());
188 self.middleware
189 .run_tool(MiddlewareScope::ToolAfter, &mut ctx, tool_mw_ctx_after)
190 .await?;
191
192 ctx.messages.push(ModelMessage {
193 role: Role::Tool,
194 content: result.to_string(),
195 tool_call_id: Some(tool_call.id),
196 tool_calls: None,
197 });
198 }
199 } else {
200 self.middleware
201 .run(MiddlewareScope::StepAfter, &mut ctx)
202 .await?;
203 break message.content;
204 }
205 self.middleware
206 .run(MiddlewareScope::StepAfter, &mut ctx)
207 .await?;
208 }
209 };
210
211 self.middleware
212 .run(MiddlewareScope::RunAfter, &mut ctx)
213 .await?;
214 Ok(output)
215 }
216}