Skip to main content

agentlib_core/
agent.rs

1use crate::memory::MemoryProvider;
2use crate::middleware::{Middleware, MiddlewarePipeline, ToolMiddlewareContext};
3use crate::provider::ModelProvider;
4use crate::reasoning::{ReasoningContext, ReasoningEngine};
5use crate::tool::ToolRegistry;
6use crate::types::{
7    AgentPolicy, ExecutionContext, MiddlewareScope, ModelMessage, ModelRequest, Role,
8};
9use anyhow::{anyhow, Result};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
13#[derive(Clone)]
14pub struct AgentConfig {
15    pub name: String,
16    pub description: Option<String>,
17    pub system_prompt: Option<String>,
18    pub provider: Option<Arc<dyn ModelProvider>>,
19    pub memory: Option<Arc<dyn MemoryProvider>>,
20    pub engine: Option<Arc<dyn ReasoningEngine>>,
21    pub policy: AgentPolicy,
22}
23
24pub trait Agent: Send + Sync {
25    fn name(&self) -> &str;
26    fn description(&self) -> Option<&str> {
27        None
28    }
29    fn system_prompt(&self) -> Option<&str> {
30        None
31    }
32    fn register_tools(self: Arc<Self>, registry: &mut ToolRegistry);
33}
34
35pub fn create_agent<T: Agent + 'static>(data: T) -> AgentInstance {
36    let data = Arc::new(data);
37    let instance = AgentInstance::new(AgentConfig {
38        name: data.name().to_string(),
39        description: data.description().map(|s| s.to_string()),
40        system_prompt: data.system_prompt().map(|s| s.to_string()),
41        provider: None,
42        memory: None,
43        engine: None,
44        policy: AgentPolicy::default(),
45    });
46
47    {
48        let mut tools = instance
49            .tools
50            .try_lock()
51            .expect("Failed to lock tools during init");
52        data.clone().register_tools(&mut tools);
53    }
54
55    instance
56}
57
58pub struct AgentInstance {
59    pub config: AgentConfig,
60    pub tools: Arc<Mutex<ToolRegistry>>,
61    pub middleware: MiddlewarePipeline,
62}
63
64impl AgentInstance {
65    pub fn new(config: AgentConfig) -> Self {
66        Self {
67            config,
68            tools: Arc::new(Mutex::new(ToolRegistry::new())),
69            middleware: MiddlewarePipeline::new(),
70        }
71    }
72
73    pub fn provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
74        self.config.provider = Some(provider);
75        self
76    }
77
78    pub fn memory(mut self, memory: Arc<dyn MemoryProvider>) -> Self {
79        self.config.memory = Some(memory);
80        self
81    }
82
83    pub fn engine(mut self, engine: Arc<dyn ReasoningEngine>) -> Self {
84        self.config.engine = Some(engine);
85        self
86    }
87
88    pub fn policy(mut self, policy: AgentPolicy) -> Self {
89        self.config.policy = policy;
90        self
91    }
92
93    pub fn system_prompt(&mut self, prompt: &str) -> &mut Self {
94        self.config.system_prompt = Some(prompt.to_string());
95        self
96    }
97
98    pub fn use_middleware(&mut self, middleware: Box<dyn Middleware>) -> &mut Self {
99        self.middleware.use_middleware(middleware);
100        self
101    }
102
103    pub async fn run(&self, input: &str) -> Result<String> {
104        let mut ctx = ExecutionContext::new(input.to_string());
105        self.middleware
106            .run(MiddlewareScope::RunBefore, &mut ctx)
107            .await?;
108
109        // Setup initial messages
110        if let Some(prompt) = &self.config.system_prompt {
111            ctx.messages.push(ModelMessage {
112                role: Role::System,
113                content: prompt.clone(),
114                tool_call_id: None,
115                tool_calls: None,
116            });
117        }
118
119        ctx.messages.push(ModelMessage {
120            role: Role::User,
121            content: input.to_string(),
122            tool_call_id: None,
123            tool_calls: None,
124        });
125
126        let provider = self
127            .config
128            .provider
129            .as_ref()
130            .ok_or_else(|| anyhow!("No provider configured"))?;
131
132        // Minimal Loop (ReAct-like)
133        let output = if let Some(engine) = &self.config.engine {
134            let mut r_ctx = ReasoningContext {
135                ctx: &mut ctx,
136                model: provider.as_ref(),
137                tools: &*self.tools.lock().await,
138                middleware: &self.middleware,
139                policy: self.config.policy.clone(),
140            };
141            engine.execute(&mut r_ctx).await?
142        } else {
143            // Default simple loop
144            loop {
145                self.middleware
146                    .run(MiddlewareScope::StepBefore, &mut ctx)
147                    .await?;
148                let tools = self.tools.lock().await;
149                let request = ModelRequest {
150                    messages: ctx.messages.clone(),
151                    tools: Some(tools.list()),
152                };
153                drop(tools);
154
155                let response = provider.complete(request).await?;
156
157                // Add usage
158                if let Some(usage) = &response.usage {
159                    ctx.usage.prompt_tokens += usage.prompt_tokens;
160                    ctx.usage.completion_tokens += usage.completion_tokens;
161                    ctx.usage.total_tokens += usage.total_tokens;
162                }
163
164                let message = response.message;
165                ctx.messages.push(message.clone());
166
167                if let Some(tool_calls) = message.tool_calls {
168                    for tool_call in tool_calls {
169                        let tools = self.tools.lock().await;
170
171                        let tool_mw_ctx = ToolMiddlewareContext {
172                            name: tool_call.name.clone(),
173                            args: tool_call.arguments.clone(),
174                            result: None,
175                        };
176
177                        self.middleware
178                            .run_tool(MiddlewareScope::ToolBefore, &mut ctx, tool_mw_ctx.clone())
179                            .await?;
180
181                        let result = tools
182                            .call_tool(&tool_call.name, tool_call.arguments)
183                            .await?;
184                        drop(tools);
185
186                        let mut tool_mw_ctx_after = tool_mw_ctx;
187                        tool_mw_ctx_after.result = Some(result.clone());
188                        self.middleware
189                            .run_tool(MiddlewareScope::ToolAfter, &mut ctx, tool_mw_ctx_after)
190                            .await?;
191
192                        ctx.messages.push(ModelMessage {
193                            role: Role::Tool,
194                            content: result.to_string(),
195                            tool_call_id: Some(tool_call.id),
196                            tool_calls: None,
197                        });
198                    }
199                } else {
200                    self.middleware
201                        .run(MiddlewareScope::StepAfter, &mut ctx)
202                        .await?;
203                    break message.content;
204                }
205                self.middleware
206                    .run(MiddlewareScope::StepAfter, &mut ctx)
207                    .await?;
208            }
209        };
210
211        self.middleware
212            .run(MiddlewareScope::RunAfter, &mut ctx)
213            .await?;
214        Ok(output)
215    }
216}